## ------------------------------------------------------------------------
## Copyright (C) 2018-2022 Carnegie Mellon University. All Rights Reserved.
## ------------------------------------------------------------------------
## See license information in LICENSE-OPENSOURCE.txt
## ------------------------------------------------------------------------

from __future__ import print_function
import sys
import pyfixbuf
import pyfixbuf.cert
import shared_utils
from string import punctuation


def check_fields_to_match(fields_to_match, bad_IPs, data):
    """Report whether any of the named fields of a record holds a bad IP.

    :param fields_to_match: iterable of field (information element) names to
        inspect in the record.
    :param bad_IPs: collection of IP strings to look for; membership is tested
        with ``in`` against the string form of each field value.
    :param data: record object exposing ``as_dict()``.
    :returns: ``True`` as soon as one inspected field's stringified value is in
        *bad_IPs*, otherwise ``False``.

    A field missing from the record is looked up as ``None`` and therefore
    compared as the string ``"None"``.
    """
    record = data.as_dict()
    # any() stops at the first hit, just like an early return in a loop.
    return any(str(record.get(name)) in bad_IPs for name in fields_to_match)


def count_records(in_files, bad_IPs, fields_to_match):
    """Count the records in *in_files* that carry a bad IP in a chosen field.

    :param in_files: input file(s) handed to ``shared_utils.process_files``.
    :param bad_IPs: collection of IP strings considered "bad".
    :param fields_to_match: field names checked via ``check_fields_to_match``.
    :returns: the number of matching records (also printed to stdout).
    """
    # The information model is extended with the CERT elements before the
    # files are read.
    model = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(model)

    records = shared_utils.process_files(in_files, model)
    total = sum(
        1 for rec in records
        if check_fields_to_match(fields_to_match, bad_IPs, rec)
    )

    print("Found %d records with bad IP(s) in desired field(s)." % total)
    return total


def filter_to_screen(in_files, bad_IPs, fields_to_match, out_fields, delim):
    """Print selected fields of every record that carries a bad IP.

    A header line naming *out_fields*, separated by *delim*, is printed first.
    Then, for each record in which one of *fields_to_match* holds a value in
    *bad_IPs*, ``shared_utils.print_fields`` is called with that record, the
    requested *out_fields*, *delim* and the information model.

    :param in_files: input file(s) handed to ``shared_utils.process_files``.
    :param bad_IPs: collection of IP strings considered "bad".
    :param fields_to_match: field names checked via ``check_fields_to_match``.
    :param out_fields: sequence of field names to print for matching records.
    :param delim: separator string placed between header field names.
    """
    # Join the names directly. The previous approach prefixed every name with
    # the delimiter and then stripped ASCII punctuation from both ends. That
    # left a leading delimiter whenever *delim* was not punctuation (e.g. a
    # tab, a space, or ", ") and could also eat punctuation belonging to the
    # first or last field name.
    print(delim.join(out_fields))

    # Set up InfoModel with the CERT information elements.
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)

    # Iterate through all records in the passed input files.
    for data in shared_utils.process_files(in_files, infomodel):
        # Only records with a bad IP in one of the desired fields are shown.
        if check_fields_to_match(fields_to_match, bad_IPs, data):
            shared_utils.print_fields(data, out_fields, delim, infomodel)


def filter_to_ipfix(in_files, bad_IPs, fields_to_match, out_file):
    """Write every record that carries a bad IP to an IPFIX file.

    :param in_files: input file(s) handed to ``shared_utils.process_files``.
    :param bad_IPs: collection of IP strings considered "bad".
    :param fields_to_match: field names checked via ``check_fields_to_match``.
    :param out_file: path of the IPFIX file the matching records go to.
    :returns: the number of records written (also reported on stderr).
    """
    # The information model is required: it backs both the export session
    # and the reading of the input files below.
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)

    # Set up export capabilities.
    exporter = pyfixbuf.Exporter()
    exporter.init_file(out_file)
    exp_session = pyfixbuf.Session(infomodel)
    exp_buf = pyfixbuf.Buffer(auto=True)
    exp_buf.init_export(exp_session, exporter)

    found_count = 0
    for data in shared_utils.process_files(in_files, infomodel, exp_session):
        if check_fields_to_match(fields_to_match, bad_IPs, data):
            # A bad IP was found in a desired field; export the record.
            found_count += 1
            shared_utils.print_to_ipfix(data, exp_buf)

    # Flush whatever is still pending in the final message explicitly,
    # rather than relying on buffer teardown to write it to out_file.
    exp_buf.emit()

    print("Wrote %d records to %s." % (found_count, out_file), file=sys.stderr)
    return found_count
