## ------------------------------------------------------------------------
## Copyright (C) 2018-2022 Carnegie Mellon University. All Rights Reserved.
## ------------------------------------------------------------------------
## See license information in LICENSE-OPENSOURCE.txt
## ------------------------------------------------------------------------

from __future__ import print_function
import sys
import pyfixbuf
import pyfixbuf.cert
import shared_utils
from string import punctuation

# Names of the information elements this script reads from each record
# (looked up in the dict returned by ``data.as_dict()``).
SPORT = "sourceTransportPort"        # source port of the flow
DPORT = "destinationTransportPort"   # destination port of the flow
APPLABEL = "silkAppLabel"            # application label assigned to the flow


def check_ports_applabel(data):
    """Return True when a record's app label matches neither of its ports.

    The record is converted with ``data.as_dict()``; if any of the APPLABEL,
    SPORT or DPORT keys is absent the record is not flagged.  Records where
    both ports are >= 1024, or whose app label is 0, are also not flagged.
    Otherwise the record is flagged when the label equals neither port.
    """
    record = data.as_dict()
    try:
        label, src, dst = record[APPLABEL], record[SPORT], record[DPORT]
    except KeyError:
        # Without all three fields there is nothing to compare.
        return False

    # Skip flows with no port below 1024 and flows with no app label.
    if (src >= 1024 and dst >= 1024) or label == 0:
        return False

    # Mismatch: the label is not the value of either port.
    return src != label and dst != label


def count_records(in_files):
    """Count records whose service port disagrees with their app label.

    Builds an InfoModel extended with the CERT elements and streams every
    record from ``in_files`` through ``shared_utils.process_files``.  It
    tallies the records for which ``check_ports_applabel`` is true, prints
    the total to stdout and returns it.
    """
    model = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(model)

    mismatches = sum(
        1
        for record in shared_utils.process_files(in_files, model)
        if check_ports_applabel(record)
    )

    print("Found %d records with service port/app label mismatch." % mismatches)
    return mismatches


def output_to_screen(in_files, out_fields, delim):
    """Print a header line, then the selected fields of each mismatched record.

    :param in_files: input files, passed to ``shared_utils.process_files``.
    :param out_fields: field names to print, in column order.
    :param delim: string placed between columns.
    """
    # Join the names directly. The previous approach prefixed every name with
    # ``delim`` and then called ``strip(punctuation)``, which had two problems:
    # it left a leading delimiter in place when ``delim`` is whitespace
    # (e.g. a tab), and it also stripped punctuation from the first and last
    # field names themselves.
    print(delim.join(out_fields))

    # Set up InfoModel
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)

    # Print only the records whose app label matches neither port.
    for data in shared_utils.process_files(in_files, infomodel):
        if check_ports_applabel(data):
            shared_utils.print_fields(data, out_fields, delim, infomodel)


def output_to_ipfix(in_files, out_file):
    """Write every mismatched record from ``in_files`` to an IPFIX file.

    :param in_files: input files, passed to ``shared_utils.process_files``.
    :param out_file: path of the IPFIX file the exporter writes to.
    :returns: the number of records handed to the export buffer.
    """
    # Set up export capabilities
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)
    exporter = pyfixbuf.Exporter()
    exporter.init_file(out_file)
    exp_session = pyfixbuf.Session(infomodel)
    exp_buf = pyfixbuf.Buffer(auto=True)
    exp_buf.init_export(exp_session, exporter)

    # Initialize a filtered record counter
    found_count = 0

    for data in shared_utils.process_files(in_files, infomodel, exp_session):
        if check_ports_applabel(data):
            found_count += 1
            shared_utils.print_to_ipfix(data, exp_buf)

    # Flush the final, partially filled message to the exporter. Without this,
    # records still pending in the buffer may never reach out_file.
    # NOTE(review): if shared_utils.print_to_ipfix already emits per record,
    # this call should have nothing left to send -- confirm.
    exp_buf.emit()

    print("Wrote %d records to %s." % (found_count, out_file), file=sys.stderr)
    return found_count
