#! /usr/bin/env python
## ------------------------------------------------------------------------
## Copyright (C) 2018-2022 Carnegie Mellon University. All Rights Reserved.
## ------------------------------------------------------------------------
## See license information in LICENSE-OPENSOURCE.txt
## ------------------------------------------------------------------------

from __future__ import print_function
import argparse
import sys
import os
import IPFinder_utils
import shared_utils

try:
    import silk
    have_ipsets = True
except ImportError:
    have_ipsets = False


class Args(object):
    ''' Empty attribute container; passed to argparse as the namespace
    object so that each parsed option becomes an attribute of the
    instance. '''


if __name__ == '__main__':
    # Command-line driver: parse the options, build the list of input files,
    # the list of IP addresses to look for and the IPFIX fields to compare
    # them against, then hand everything to the matching IPFinder_utils
    # routine (print fields, count records, or write an IPFIX file).
    shared_utils.print_warning_message()

    in_files = []
    bad_IPs = []
    fields_to_match = []
    args = Args()

    # Create command-line argument parser and add arguments
    parser = argparse.ArgumentParser(prog='IPFinder.py',
                                     description="""
Find IPFIX records that contain IP addresses. """)

    parser.add_argument('--input', nargs=1, required=True,
                        help='input file or directory (no *)')

    # Exactly one source of addresses must be given.  --ipset is only
    # offered when the silk module could be imported.
    ip_group = parser.add_mutually_exclusive_group(required=True)
    if have_ipsets:
        ip_group.add_argument('--ipset', nargs=1,
                              help='IPSet of IP addresses to filter records by')
    ip_group.add_argument('--ip', nargs=1,
                          help='IP address to filter records by')

    # Exactly one of --dip / --sip / --anyIP selects the field(s) to compare.
    ipfield_group = parser.add_mutually_exclusive_group(required=True)
    ipfield_group.add_argument('--dip', action='store_true',
                               help='Filter by destination IP')
    ipfield_group.add_argument('--sip', action='store_true',
                               help='Filter by source IP')
    ipfield_group.add_argument('--anyIP', action='store_true',
                               help='Filter by source or destination IP')

    # Exactly one output mode must be chosen.
    output_group = parser.add_mutually_exclusive_group(required=True)
    output_group.add_argument('--outfields', nargs='+',
                              help="Fields to print from filtered records")
    output_group.add_argument('--outfile', nargs=1,
                              help="IPFIX file to output filtered records to")
    output_group.add_argument('--count', action='store_true',
                              help="Print total count of found records")

    # nargs=1 makes argparse store a one-element list, so the default must be
    # a list as well; otherwise args.delim[0] would index into a string and
    # only work by accident for a single-character default.
    parser.add_argument('--delim', nargs=1, default=['|'],
                        help="field delimiter for printed fields")

    # Add all command-line arguments to instance of class Args
    parser.parse_args(namespace=args)

    # handle input directory vs single input file
    input_path = args.input[0]
    if os.path.isdir(input_path):
        # os.listdir() returns entries in arbitrary order; sort them so that
        # repeated runs process the files in the same order.
        for filename in sorted(os.listdir(input_path)):
            candidate = os.path.join(input_path, filename)
            # Only the top level of the directory is scanned; nested
            # directories are not input files, so skip them.
            if os.path.isdir(candidate):
                continue
            in_files.append(candidate)
    else:
        in_files.append(input_path)

    # handle single IP vs IPSet
    if args.ip:
        bad_IPs.append(args.ip[0])
    else:
        # --ip and --ipset form a required mutually exclusive group, so
        # reaching this branch means --ipset was supplied (which in turn
        # means silk was importable).
        ipset = silk.IPSet.load(args.ipset[0])
        ipset.convert(4)    # convert to an IPv4 set
        for ip_addr in ipset:
            bad_IPs.append(str(ip_addr))

    # handle --sip, --dip, --anyIP
    if args.dip:
        fields_to_match.append('destinationIPv4Address')
    elif args.sip:
        fields_to_match.append('sourceIPv4Address')
    else:
        # --anyIP: a match on either address field is accepted
        fields_to_match.append('destinationIPv4Address')
        fields_to_match.append('sourceIPv4Address')

    # handle outputting field values to console or records to ipfix file
    if args.outfields:
        # output selected fields to console
        IPFinder_utils.filter_to_screen(in_files, bad_IPs, fields_to_match,
                                        args.outfields, args.delim[0])
    elif args.count:
        # output total count of filtered records
        # Status messages go to stderr (as in the --outfile branch) so that
        # stdout carries only results.
        print('Counting filtered records...', file=sys.stderr)
        # NOTE(review): the returned count is not printed here; presumably
        # count_records() reports it itself -- confirm, otherwise --count
        # produces no visible result.
        count = IPFinder_utils.count_records(in_files, bad_IPs, fields_to_match)
    else:
        # output filtered records to IPFIX file
        print('Filtering to %s...' % args.outfile[0], file=sys.stderr)
        IPFinder_utils.filter_to_ipfix(in_files, bad_IPs, fields_to_match,
                                       args.outfile[0])
