#! /usr/bin/env python
## ------------------------------------------------------------------------
## Copyright (C) 2018-2022 Carnegie Mellon University. All Rights Reserved.
## ------------------------------------------------------------------------
## See license information in LICENSE-OPENSOURCE.txt
## ------------------------------------------------------------------------

from __future__ import print_function
import sys
import os
import pyfixbuf
import shared_utils
import argparse
import time
import copy
import contextlib
import re

# Directory containing this script (symlinks resolved); the main script uses
# it to locate cert_ipfix.xml relative to the script's location.
SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))

# When True, the main script measures and prints the elapsed processing time.
RECORD_TIME = False


@contextlib.contextmanager
def open_file_or_stdout(filename=None):
    """Context manager that yields a writable text stream.

    When *filename* is empty, ``None`` or ``'-'`` the stream is
    ``sys.stdout`` and it is left open afterwards.  Otherwise *filename*
    is opened for writing (truncating any existing content) and closed
    when the ``with`` body ends, whether or not it raised.
    """
    if not filename or filename == '-':
        # Standard output is only borrowed; never close it here.
        yield sys.stdout
        return

    with open(filename, "w") as stream:
        yield stream


class Args(object):
    """Empty attribute holder; the main script passes an instance to
    ``parser.parse_args(namespace=...)`` so parsed options land on it."""


def natural_sort_key(s, _nsre=re.compile(r'([0-9]+)')):
    """Return a sort key that orders embedded numbers numerically.

    The string is split on runs of ASCII digits; each digit run becomes an
    ``int`` and every other piece is lower-cased, so ``"f2"`` sorts before
    ``"f10"``.

    :param s: the string to build a key for
    :param _nsre: compiled pattern with exactly one capture group matching
        a digit run (cached as a default argument; callers normally omit it)
    :returns: list alternating ``str, int, str, ...`` (always starts with a
        ``str``, possibly empty)
    """
    # re.split() with one capture group returns the pieces in strict
    # alternation: text at even indices, captured digit runs at odd ones.
    # Converting by position (rather than by str.isdigit()) avoids calling
    # int() on characters such as superscript digits that isdigit() accepts
    # but int() rejects, and keeps ints out of the text slots so keys stay
    # mutually comparable.
    return [int(text) if index % 2 else text.lower()
            for index, text in enumerate(_nsre.split(s))]


class Template(object):
    """One IPFIX template seen in the input.

    Holds the template ID, the set of template IDs that appeared nested
    inside it, and the set of field names it carried.  Template names and
    child templates are looked up in the module-level ``TEMPLATES`` tree.
    """

    # Indentation / connector strings used when printing.
    indent1 = '    '
    indent2 = '        '
    indentc = "|--"

    def __init__(self, tid, children, fields):
        """Store the template ID, child-ID collection and field collection."""
        self.tid = tid
        self.children = children
        self.fields = fields

    @staticmethod
    def _label(tid):
        """Return ``name(hexid)`` when a name is known for *tid*, else ``hexid``."""
        hex_id = format(tid, 'x')
        name = TEMPLATES.options.get(tid)
        if name is None:
            return hex_id
        return name + "(" + hex_id + ")"

    def print_template(self, f):
        """Write this template's label followed by its fields to stream *f*."""
        print(self.indent1 + self._label(self.tid), file=f)

        # fields is a set; sort so the output is stable from run to run.
        for field in sorted(self.fields):
            print(self.indent2 + field, file=f)

    def print_children(self, f, depth):
        """Recursively write the child-template hierarchy to stream *f*.

        :param f: writable text stream
        :param depth: nesting level; controls the indentation of each line
        """
        for child in sorted(self.children):
            print((self.indent1 * depth) + self.indentc + self._label(child),
                  file=f)

            # A child ID can be recorded without any record of that template
            # having been processed (e.g. an empty subTemplateList), in which
            # case there is no Template object to descend into.
            child_tmpl = TEMPLATES.templates.get(child)
            if child_tmpl is not None:
                child_tmpl.print_children(f, depth + 1)


class Tree(object):
    """Everything collected from the input, plus the output switches.

    Attributes (all initialised in ``__init__``):
      roots           top-level Template objects keyed by template ID
      templates       nested (sub) Template objects keyed by template ID
      field_paths     set of formatted field-path strings
      yaf_flag        print "top" for YAF top-level template IDs
      screen_flag     set from the --screen option
      fieldspec_flag  format field paths as field specs
      options         template ID -> template name
    """

    def __init__(self):
        self.roots = {}
        self.templates = {}
        self.field_paths = set()
        self.yaf_flag = False
        self.screen_flag = False
        self.fieldspec_flag = False
        self.options = {}

    def print_tree(self, filename):
        """Write the template hierarchy to *filename* ('-' means stdout)."""
        with open_file_or_stdout(filename) as f:
            print("Template hierarchy:", file=f)
            for tid, tmpl in self.roots.items():
                # Use this instance's own name table rather than the
                # module-level singleton.
                name = self.options.get(tid)
                if name is not None:
                    name_string = name + "(" + format(tid, 'x') + ")"
                else:
                    name_string = format(tid, 'x')
                print(name_string, file=f)
                tmpl.print_children(f, 1)

    def print_tmpl_fields(self, filename):
        """Write top-level fields, then each sub-template with its fields,
        to *filename* ('-' means stdout)."""
        top_level_fields = set()
        with open_file_or_stdout(filename) as f:
            print("Fields in templates:", file=f)
            print("Top-level Fields:", file=f)
            for tmpl in self.roots.values():
                # tmpl.fields is a set; sort for run-to-run stable output.
                for field in sorted(tmpl.fields):
                    # Each field name is listed once even if several
                    # top-level templates carry it.
                    if field not in top_level_fields:
                        print(tmpl.indent1 + field, file=f)
                        top_level_fields.add(field)

            print("DPI Templates and Fields:", file=f)
            for tmpl in self.templates.values():
                tmpl.print_template(f)

    def print_field_paths(self, filename):
        """Write all collected field paths, sorted, to *filename*
        ('-' means stdout)."""
        with open_file_or_stdout(filename) as f:
            print("Field paths:", file=f)
            for path in sorted(self.field_paths):
                print(path, file=f)


# Single module-wide Tree instance.  The process_* functions fill it in, the
# format_* helpers and Template methods read its flags and name table, and
# the main script prints from it.
TEMPLATES = Tree()   # Global variable to hold tree of templates


def is_yaf_template(template):
    return ((template & 0xF000) >> 12) == 0xb


def format_template(template):
    """Render a template ID as one component of a field path.

    * YAF mode and the ID is a YAF template: ``""`` in field-spec mode,
      otherwise ``"top"``.
    * Field-spec mode: the template's name if one is known, otherwise the
      ID as zero-padded hex (``0x....``).
    * Otherwise: the ID as zero-padded hex.
    """
    as_fieldspec = TEMPLATES.fieldspec_flag

    if TEMPLATES.yaf_flag and is_yaf_template(template):
        if as_fieldspec:
            return ""
        return "top"

    hex_id = "%#06x" % template
    if as_fieldspec:
        return TEMPLATES.options.get(template, hex_id)
    return hex_id


def format_basic_list(field, count):
    """Render the basicList component of a field path.

    In field-spec mode the result is ``basicList[<field>]``; otherwise it
    is ``basicList (<count>) - <field>``.
    """
    if not TEMPLATES.fieldspec_flag:
        return "basicList (%d) - %s" % (count, field)
    return "basicList[%s]" % field


def format_field(field, ancestors=None,
                 bl_field=None, bl_count=None):
    """Build the full path string for one field.

    The result is ``<field>: `` (or ``<field>: ipfix:`` in field-spec
    mode) followed by the formatted ancestor templates and a final leaf,
    joined by ``" - "`` (or ``"/"`` in field-spec mode).

    :param field: name of the field
    :param ancestors: template IDs from the top-level template down to the
        template holding the field; ``None`` or empty means no ancestors
    :param bl_field: element name of the basicList (used only when
        *bl_count* is given)
    :param bl_count: index of the basicList within its template; when
        ``None`` the leaf is *field* itself, otherwise it is the formatted
        basicList
    """
    fieldspec = TEMPLATES.fieldspec_flag

    if fieldspec:
        joiner = "/"
        head = field + ": ipfix:"
    else:
        joiner = " - "
        head = field + ": "

    pieces = []
    for tid in ancestors or ():
        # In field-spec mode YAF templates are left out of the path.
        if fieldspec and is_yaf_template(tid):
            continue
        pieces.append(format_template(tid))

    if bl_count is None:
        pieces.append(field)
    else:
        pieces.append(format_basic_list(bl_field, bl_count))

    return head + joiner.join(pieces)


def process_subtmpl(data, passed_ancestors):
    """Record the fields, children and field paths of a nested record.

    Walks every field of *data*.  basicLists and plain fields are added to
    the template's field set and to ``TEMPLATES.field_paths``;
    subTemplateList / subTemplateMultiList entries are recorded as children
    and processed recursively.  The resulting Template is stored in
    ``TEMPLATES.templates`` under the record's template ID.

    :param data: the nested record to inspect
    :param passed_ancestors: list of template IDs leading to this record;
        it is copied, never modified
    """
    template_id = data.template.template_id

    # Work on a private copy so appending this template's ID does not leak
    # into the caller's list (template IDs are ints, a shallow copy suffices).
    ancestors = list(passed_ancestors)
    ancestors.append(template_id)

    # Reuse the Template from earlier records of the same ID so fields and
    # children accumulate across records.
    tmpl = TEMPLATES.templates.get(template_id)
    if tmpl is None:
        tmpl = Template(template_id, set(), set())

    bl_count = 0  # index of the next basicList within this record

    for field in data.iterfields():
        value = field.value
        curr_field = field.ie.name
        datatype = field.ie.type

        if datatype == pyfixbuf.DataType.BASIC_LIST:
            element_name = value.element.name
            tmpl.fields.add(
                "basicList(" + str(bl_count) + "): " + element_name)
            TEMPLATES.field_paths.add(
                format_field(element_name, ancestors=ancestors,
                             bl_field=element_name, bl_count=bl_count))
            bl_count += 1

        elif datatype == pyfixbuf.DataType.SUB_TMPL_LIST:
            tmpl.children.add(value.template_id)
            for record in value:
                process_subtmpl(record, ancestors)

        elif datatype == pyfixbuf.DataType.SUB_TMPL_MULTI_LIST:
            for entry in value:
                for record in entry:
                    # set.add() returns None, so its result must not be
                    # assigned back to tmpl.children.
                    tmpl.children.add(record.template.template_id)
                    process_subtmpl(record, ancestors)

        else:
            tmpl.fields.add(curr_field)
            TEMPLATES.field_paths.add(
                format_field(curr_field, ancestors=ancestors))

    TEMPLATES.templates[template_id] = tmpl


def process_toptmpl(data):
    """Record the fields, children and field paths of a top-level record.

    Every field of *data* is examined: basicLists and ordinary fields are
    added to the template's field set and to ``TEMPLATES.field_paths``;
    subTemplateList / subTemplateMultiList contents are registered as
    children and handed to ``process_subtmpl``.  The Template is then
    stored in ``TEMPLATES.roots`` under the record's template ID.
    """
    top_id = data.template.template_id

    # Accumulate into the existing root when this template was seen before.
    root = TEMPLATES.roots.get(top_id)
    if root is None:
        root = Template(top_id, set(), set())

    lineage = [top_id]
    basic_lists_seen = 0

    for item in data.iterfields():
        payload = item.value
        ie_name = item.ie.name
        ie_type = item.ie.type

        if ie_type == pyfixbuf.DataType.BASIC_LIST:
            element_name = payload.element.name
            label = "basicList(" + str(basic_lists_seen) + "): " + element_name
            root.fields.add(label)
            TEMPLATES.field_paths.add(
                format_field(element_name, ancestors=lineage,
                             bl_field=element_name,
                             bl_count=basic_lists_seen))
            basic_lists_seen += 1
            continue

        if ie_type == pyfixbuf.DataType.SUB_TMPL_LIST:
            root.children.add(payload.template_id)
            for record in payload:
                process_subtmpl(record, lineage)
            continue

        if ie_type == pyfixbuf.DataType.SUB_TMPL_MULTI_LIST:
            for entry in payload:
                for record in entry:
                    root.children.add(record.template.template_id)
                    process_subtmpl(record, lineage)
            continue

        # Any other element is an ordinary field of the top-level template.
        root.fields.add(ie_name)
        TEMPLATES.field_paths.add(format_field(ie_name, ancestors=lineage))

    TEMPLATES.roots[top_id] = root


def see_data(in_files):
    """Placeholder: accepts *in_files* and does nothing."""
    return None


if __name__ == '__main__':
    # Script entry point: parse options, read every input file, fill the
    # global TEMPLATES tree, then write the requested reports.
    shared_utils.print_warning_message()

    in_files = []
    args = Args()
    tmpl_hierarchy_filename = "tmpl_hierarchy.txt"
    tmpl_fields_filename = "fields_in_tmpls.txt"
    field_paths_filename = "field_paths.txt"
    all_flag = True

    # Create command-line argument parser and add arguments
    parser = argparse.ArgumentParser(
        prog=sys.argv[0],
        description='description: Visualize data in an IPFIX file')
    parser.add_argument('--yaf', action='store_true',
                        help='print "top" instead of top-level template IDs')
    parser.add_argument('--screen', action='store_true',
                        help='print outputs to screen instead of file(s)')
    parser.add_argument('--fpaths-fieldspec', action='store_true',
                        help='show field paths as field specs')

    input_group = parser.add_argument_group('required input arguments')
    input_group.add_argument('--input', metavar='INPUT/PATH', nargs=1,
                             required=True,
                             help='input file or directory (no *)')

    indy_group = parser.add_argument_group(
        'optional output arguments', '(prints all outputs by default)')
    indy_group.add_argument(
        '--tree', action='store_true',
        help='Print template hierarchy to "%s"' % tmpl_hierarchy_filename)
    indy_group.add_argument(
        '--fpaths', action='store_true',
        help='Print field paths to "%s"' % field_paths_filename)
    indy_group.add_argument(
        '--fields', action='store_true',
        help='Print fields in templates to "%s"' % tmpl_fields_filename)

    # Add all command-line arguments to instance of class Args
    parser.parse_args(namespace=args)

    # handle input directory vs single input file
    input_path = args.input[0]
    if os.path.isdir(input_path):
        for filename in os.listdir(input_path):
            in_files.append(os.path.join(input_path, filename))
    else:
        in_files.append(input_path)

    if args.yaf:
        TEMPLATES.yaf_flag = True

    if args.screen:
        TEMPLATES.screen_flag = True

    if args.fpaths_fieldspec:
        TEMPLATES.fieldspec_flag = True

    # Selecting any individual report turns off the "print everything"
    # default.
    if args.tree or args.fpaths or args.fields:
        all_flag = False

    # Set up collection capabilities
    infomodel = pyfixbuf.InfoModel()
    infomodel.read_from_xml_file(
        os.path.join(
            SCRIPT_PATH,
            "..",
            "src",
            "pyfixbuf",
            "cert",
            "cert_ipfix.xml"
        )
    )
    templateId = infomodel.get_element("templateId")
    informationElementName = infomodel.get_element("informationElementName")

    # time.clock() was removed in Python 3.8; use perf_counter() where it
    # exists and fall back to clock() only on interpreters that lack it.
    timer = getattr(time, "perf_counter", None) or time.clock

    if RECORD_TIME:
        start_time = timer()

    # Iterate through all records
    print("Processing data...")
    for data in shared_utils.process_files(in_files, infomodel):
        tmpl = data.template
        if templateId in tmpl:
            # Record carrying templateId: remember the template's name.
            TEMPLATES.options[data["templateId"]] = data["templateName"]
        elif informationElementName not in tmpl:
            # Records carrying informationElementName are skipped; anything
            # else is treated as a top-level data record.
            process_toptmpl(data)

    # Print "Visualizations" to files
    if all_flag or args.tree:
        TEMPLATES.print_tree("-" if args.screen else tmpl_hierarchy_filename)
    if all_flag or args.fields:
        TEMPLATES.print_tmpl_fields("-" if args.screen else tmpl_fields_filename)
    if all_flag or args.fpaths:
        TEMPLATES.print_field_paths("-" if args.screen else field_paths_filename)

    if RECORD_TIME:
        end_time = timer()
        elapsed = end_time - start_time
        print("Time elapsed: %f seconds" % elapsed)
