#! /usr/bin/env python
## ------------------------------------------------------------------------
## Copyright (C) 2018-2022 Carnegie Mellon University. All Rights Reserved.
## ------------------------------------------------------------------------
## See license information in LICENSE-OPENSOURCE.txt
## ------------------------------------------------------------------------

from __future__ import print_function
import pyfixbuf
import pyfixbuf.cert
import os
import argparse
import shared_utils


def list_iterator(lst, tmpl_id=0):
    '''
    Walk a subTemplateList (STL) or subTemplateMultiList (STML) depth-first
    and yield the records it contains, including records found in any STL or
    STML nested inside those records.

    lst     -- a pyfixbuf.STL or pyfixbuf.STML. Any other value (for example a
               basicList) yields nothing.
    tmpl_id -- if non-zero, only records whose template ID equals tmpl_id are
               yielded. Lists nested beneath non-matching records are still
               searched. If 0 (the default), every record is yielded.

    Records are yielded in pre-order: each record comes before the records of
    the lists nested inside it.

    Raises TypeError if iterating a list produces something that is not a
    pyfixbuf.Record.
    '''
    if isinstance(lst, pyfixbuf.STML):
        # An STML is a sequence of entries; iterating each entry gives records.
        records = (record for entry in lst for record in entry)
    elif isinstance(lst, pyfixbuf.STL):
        records = iter(lst)
    else:
        return

    for record in records:
        if not isinstance(record, pyfixbuf.Record):
            raise TypeError('iterator returned non-record type')
        if tmpl_id == 0 or record.template.template_id == tmpl_id:
            yield record
        # Descend into list-valued fields of this record. The filter is passed
        # down, so matching happens at every depth without re-checking here.
        for value in record:
            if isinstance(value, (pyfixbuf.STL, pyfixbuf.STML)):
                for sub_record in list_iterator(value, tmpl_id):
                    yield sub_record


if __name__ == '__main__':
    # Create command-line argument parser and add arguments
    parser = argparse.ArgumentParser(prog='list_iterator.py',
                                     description="""
Prints the template IDs used in the STMLs and STLs for each record
in the input file(s).""")

    parser.add_argument('--input', '--in', nargs=1, required=True,
                        help='input file or directory (no *)')

    args = parser.parse_args()
    # nargs=1 stores a one-element list.
    input_path = args.input[0]

    # Handle input directory vs. single input file. For a directory, keep only
    # regular files (subdirectories are presumably not valid input for
    # shared_utils.process_files). Sort the names so the output order is
    # reproducible, because os.listdir returns them in arbitrary order.
    if os.path.isdir(input_path):
        candidates = (os.path.join(input_path, name)
                      for name in os.listdir(input_path))
        in_files = sorted(path for path in candidates if os.path.isfile(path))
    else:
        in_files = [input_path]

    # Set up collection capabilities
    infomodel = pyfixbuf.InfoModel()
    pyfixbuf.cert.add_elements_to_model(infomodel)

    top_level_record_count = 0
    list_record_count = 0

    for data in shared_utils.process_files(in_files, infomodel):
        top_level_record_count += 1
        print("Record %d:" % top_level_record_count)
        # Walk the record's field values directly instead of looking each list
        # up by element name. A name lookup reaches only the first element
        # with that name, so a template with two same-named list elements
        # would report the first list twice and never the second.
        for value in data:
            if isinstance(value, (pyfixbuf.STL, pyfixbuf.STML)):
                for record in list_iterator(value):
                    print("List Record Template ID: %x"
                          % record.template.template_id)
                    list_record_count += 1

    print("Total Top-Level Records: %d" % top_level_record_count)
    print("Total List Records: %d" % list_record_count)
