## ------------------------------------------------------------------------
## Copyright (C) 2018-2022 Carnegie Mellon University. All Rights Reserved.
## ------------------------------------------------------------------------
## See license information in LICENSE-OPENSOURCE.txt
## ------------------------------------------------------------------------

from __future__ import print_function
import sys
import pyfixbuf
import pyfixbuf.cert
import shared_utils
from string import punctuation


def check_dnsQName(dnsQN, bad_dnsQNs):
    """Return True if *dnsQN* equals, or is a subdomain of, any bad name.

    Matching is done on whole DNS labels: ``www.example.com`` and
    ``example.com`` match the bad name ``example.com``, but
    ``badexample.com`` does not.

    Both the query name and every bad name are normalized the same way before
    comparison:
      * leading/trailing punctuation is stripped, so a trailing root dot
        (``example.com.``), a leading dot (``.example.com``) or a wildcard
        prefix (``*.example.com``) are all treated as ``example.com``;
      * case is folded, because DNS names are case-insensitive (RFC 4343).

    Empty bad names (after normalization) are ignored so that a blank entry
    in the bad list cannot match anything.

    :param dnsQN: DNS query name taken from a record.
    :param bad_dnsQNs: iterable of bad domain names.
    :returns: True on the first match, otherwise False.
    """
    qname = dnsQN.strip(punctuation).lower()
    for bad_dnsQN in bad_dnsQNs:
        bad = bad_dnsQN.strip(punctuation).lower()
        if not bad:
            continue
        # Exact match, or a suffix match that starts on a label boundary.
        if qname == bad or qname.endswith("." + bad):
            return True
    return False


def check_stl(stl, bad_dnsQNames):
    """Report whether a subTemplateList carries a blacklisted ``dnsName``.

    Only lists whose ``template_id`` is 0xcf00 are examined. This is
    presumably the YAF DNS record template; confirm against the exporter.
    Any other template yields False without looking at its records. Records
    lacking a ``dnsName`` field are skipped.

    :param stl: subTemplateList object exposing ``template_id`` and iterable
        over its records.
    :param bad_dnsQNames: iterable of bad domain names (see check_dnsQName).
    :returns: True as soon as one record's dnsName matches, else False.
    """
    # Evaluation short-circuits: the template test comes first, and any()
    # stops at the first matching record.
    return stl.template_id == 0xcf00 and any(
        check_dnsQName(rec["dnsName"], bad_dnsQNames)
        for rec in stl
        if "dnsName" in rec
    )


def check_record(data, bad_dnsQNames):
    """Return True if any subTemplateList nested in *data* has a bad DNS name.

    The record's subTemplateMultiList is walked. For every entry that contains
    a subTemplateList, each value of each nested record that is a
    ``pyfixbuf.STL`` is handed to check_stl.

    :param data: a pyfixbuf record exposing ``as_dict()``.
    :param bad_dnsQNames: iterable of bad domain names (see check_dnsQName).
    :returns: True on the first match; False if there is no match or the
        record has no ``subTemplateMultiList`` field.
    """
    data_dict = data.as_dict()
    # Keep the try minimal: only the lookup is expected to raise KeyError.
    try:
        stml = data_dict["subTemplateMultiList"]
    except KeyError:
        return False

    for entry in stml:
        if "subTemplateList" not in entry:
            continue
        for record in entry:
            # Positional indexing mirrors how these records are accessed
            # elsewhere; fetch each value once.
            for i in range(len(record)):
                value = record[i]
                # isinstance (not type() ==) so subclasses of STL still count.
                if isinstance(value, pyfixbuf.STL) and check_stl(
                        value, bad_dnsQNames):
                    return True
    return False


def count_records(in_files, bad_dnsQNames):
    """Count the records in *in_files* that reference a bad DNS query name.

    A one-line summary is printed to stdout.

    :param in_files: input files, passed through to
        shared_utils.process_files.
    :param bad_dnsQNames: iterable of bad domain names (see check_dnsQName).
    :returns: the number of matching records.
    """
    # Information model extended with the elements from pyfixbuf.cert.
    model = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(model)

    matches = sum(
        1
        for rec in shared_utils.process_files(in_files, model)
        if check_record(rec, bad_dnsQNames)
    )

    print("Found %d records with bad DNS query name(s)." % matches)
    return matches


def filter_to_screen(in_files, bad_dnsQNames, out_fields, delim):
    """Print selected fields of every record that has a bad DNS query name.

    A header line with the names in *out_fields* separated by *delim* is
    printed first. Each matching record is then printed via
    shared_utils.print_fields.

    :param in_files: input files, passed through to
        shared_utils.process_files.
    :param bad_dnsQNames: iterable of bad domain names (see check_dnsQName).
    :param out_fields: sequence of field names to print.
    :param delim: separator placed between fields.
    """
    # Join the names directly. Prefixing every name with the delimiter and
    # then stripping string.punctuation has two problems: it leaves a leading
    # delimiter when the delimiter is whitespace (e.g. a tab), and it strips
    # punctuation that belongs to the first or last field name.
    print(delim.join(out_fields))

    # Information model extended with the elements from pyfixbuf.cert.
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)

    for data in shared_utils.process_files(in_files, infomodel):
        if check_record(data, bad_dnsQNames):
            shared_utils.print_fields(data, out_fields, delim, infomodel)


def filter_to_ipfix(in_files, bad_dnsQNames, out_file):
    """Write every record that has a bad DNS query name to an IPFIX file.

    A summary line is printed to stderr.

    :param in_files: input files, passed through to
        shared_utils.process_files.
    :param bad_dnsQNames: iterable of bad domain names (see check_dnsQName).
    :param out_file: path of the IPFIX file to create.
    :returns: the number of records written, matching count_records.
    """
    # Set up export: information model, file exporter, session and an
    # auto-template buffer bound to that exporter.
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)
    exporter = pyfixbuf.Exporter()
    exporter.init_file(out_file)
    exp_session = pyfixbuf.Session(infomodel)
    exp_buf = pyfixbuf.Buffer(auto=True)
    exp_buf.init_export(exp_session, exporter)

    found_count = 0
    for data in shared_utils.process_files(in_files, infomodel, exp_session):
        if check_record(data, bad_dnsQNames):
            found_count += 1
            shared_utils.print_to_ipfix(data, exp_buf)

    # Flush the last, partially filled message. Nothing here otherwise forces
    # it out, so trailing records could be lost. If print_to_ipfix already
    # emits on every call, this final emit is redundant but harmless.
    if found_count:
        exp_buf.emit()

    print("Wrote %d records to %s." % (found_count, out_file), file=sys.stderr)
    return found_count
