#!/usr/bin/env python

#######################################################################
# Copyright (C) 2008 by Carnegie Mellon University.
#
# @OPENSOURCE_HEADER_START@
#
# Use of the SILK system and related source code is subject to the terms
# of the following licenses:
#
# GNU Public License (GPL) Rights pursuant to Version 2, June 1991
# Government Purpose License Rights (GPLR) pursuant to DFARS 252.225-7013
#
# NO WARRANTY
#
# ANY INFORMATION, MATERIALS, SERVICES, INTELLECTUAL PROPERTY OR OTHER
# PROPERTY OR RIGHTS GRANTED OR PROVIDED BY CARNEGIE MELLON UNIVERSITY
# PURSUANT TO THIS LICENSE (HEREINAFTER THE "DELIVERABLES") ARE ON AN
# "AS-IS" BASIS. CARNEGIE MELLON UNIVERSITY MAKES NO WARRANTIES OF ANY
# KIND, EITHER EXPRESS OR IMPLIED AS TO ANY MATTER INCLUDING, BUT NOT
# LIMITED TO, WARRANTY OF FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY, INFORMATIONAL CONTENT, NONINFRINGEMENT, OR ERROR-FREE
# OPERATION. CARNEGIE MELLON UNIVERSITY SHALL NOT BE LIABLE FOR INDIRECT,
# SPECIAL OR CONSEQUENTIAL DAMAGES, SUCH AS LOSS OF PROFITS OR INABILITY
# TO USE SAID INTELLECTUAL PROPERTY, UNDER THIS LICENSE, REGARDLESS OF
# WHETHER SUCH PARTY WAS AWARE OF THE POSSIBILITY OF SUCH DAMAGES.
# LICENSEE AGREES THAT IT WILL NOT MAKE ANY WARRANTY ON BEHALF OF
# CARNEGIE MELLON UNIVERSITY, EXPRESS OR IMPLIED, TO ANY PERSON
# CONCERNING THE APPLICATION OF OR THE RESULTS TO BE OBTAINED WITH THE
# DELIVERABLES UNDER THIS LICENSE.
#
# Licensee hereby agrees to defend, indemnify, and hold harmless Carnegie
# Mellon University, its trustees, officers, employees, and agents from
# all claims or demands made against them (and any related losses,
# expenses, or attorney's fees) arising out of, or relating to Licensee's
# and/or its sub licensees' negligent use or willful misuse of or
# negligent conduct or willful misconduct regarding the Software,
# facilities, or other rights or assistance granted by Carnegie Mellon
# University under this License, including, but not limited to, any
# claims of product liability, personal injury, death, damage to
# property, or violation of any laws or regulations.
#
# Carnegie Mellon University Software Engineering Institute authored
# documents are sponsored by the U.S. Department of Defense under
# Contract F19628-00-C-0003. Carnegie Mellon University retains
# copyrights in all material produced under this contract. The U.S.
# Government retains a non-exclusive, royalty-free license to publish or
# reproduce these documents, or allow others to do so, for U.S.
# Government purposes only pursuant to the copyright license under the
# contract clause at 252.227.7013.
#
# @OPENSOURCE_HEADER_END@
#
#######################################################################
# $SiLK: rwidsquery 10790 2008-03-06 00:00:30Z tonyc $
#######################################################################
# rwidsquery
#
# Invoke rwfilter to pull flows matching Snort signatures/alerts
#######################################################################

# netsa-python isn't released yet
# from netsa.utils import get_protocol_number
from optparse import OptionParser, OptionValueError

import re
import datetime
import subprocess
import tempfile
import calendar
import subprocess
import sys


# BEGIN netsa-python code ----------------------------------------------------
def build_protocol_map(path='/etc/protocols'):
    """Build a map from IP protocol number to lower-case protocol name.

    path -- protocols database to read (default: /etc/protocols)

    Every non-blank line that does not start with '#' is split on
    whitespace; the first token is taken as the protocol name and the
    second as its number.  Lines that do not fit that shape are skipped.

    Returns a dict of {int: str}.  If the file cannot be opened or read,
    a minimal built-in map for ICMP, TCP and UDP is returned instead.
    """
    fallback = { 1: 'icmp', 6: 'tcp', 17: 'udp' }

    try:
        f = open(path, 'r')
    except (IOError, OSError):
        return fallback

    proto_map = {}
    try:
        try:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue
                tokens = line.split()
                try:
                    proto_map[int(tokens[1])] = tokens[0].lower()
                except (IndexError, ValueError):
                    # Malformed entry (missing or non-numeric protocol
                    # number): ignore it and keep going.
                    pass
        except (IOError, OSError, UnicodeError):
            # The file exists but could not be read or decoded.
            return fallback
    finally:
        f.close()

    return proto_map


# Lookup tables built once at import time: protocol number -> name, and
# the inverse, name -> protocol number.  items() is used instead of
# iteritems() so the statement runs on both Python 2 and Python 3.
protocol_map = build_protocol_map()
protocol_map_reverse = dict((name, number)
    for (number, name) in protocol_map.items())


def get_protocol_name(number):
    """Return the protocol name stored in protocol_map for number.

    Raises KeyError when the number has no entry in protocol_map.
    """
    proto_name = protocol_map[number]
    return proto_name

def get_protocol_number(name):
    """Return the protocol number for name, compared case-insensitively.

    Raises KeyError when the lower-cased name has no entry in
    protocol_map_reverse.
    """
    lowered = str.lower(name)
    return protocol_map_reverse[lowered]
# END netsa-python code ------------------------------------------------------

# Defaults for the --year and --tolerance options: the current local
# year, and one hour expressed in seconds.
default_year = datetime.date.today().year
default_tolerance = 60 * 60

def parse_options():
    """Parse the command line of this script.

    Everything after a literal "--" argument is set aside, untouched, as
    extra rwfilter arguments; only the part before it is handed to
    optparse.

    Returns a tuple (options, args, filterargs):
    options -- the optparse values object
    args -- remaining positional arguments (at most one: the input file)
    filterargs -- list of arguments that followed "--"

    Invalid usage is reported through OptionParser.error(), which prints
    a message and exits.
    """

    file_types = [ 'full', 'fast', 'rule' ]
    op = OptionParser(usage="""%prog [options] [infile]""")
    op.set_defaults(
        output_file="stdout",
        year=default_year,
        tolerance=default_tolerance,
        type='auto')
    op.add_option("-o", "--output-file", dest="output_file",
        help="write flow records to this file (default: stdout)")
    op.add_option("-t", "--intype", dest="type",
        help="input file type (one of 'fast', 'full', or 'rule')")
    op.add_option("-s", "--start-date", dest="start",
        help="start date for flow selection")
    op.add_option("-e", "--end-date", dest="end",
        help="end date for flow selection")
    # type="int" makes optparse reject non-numeric values up front
    # instead of letting them fail later with a traceback.
    op.add_option("-y", "--year", dest="year", type="int",
        help="year to be used for alert timestamps")
    op.add_option("--tolerance", dest="tolerance", type="int",
        help="time tolerance in seconds between alert and flow timestamps")
    op.add_option("-c", "--config-file", dest="config_file",
        help="Snort configuration file location")
    op.add_option("-m", "--mask", dest="mask",
        help="list of rwfilter predicates to mask")
    op.add_option("--dry-run", dest="dry_run", action="store_true",
        help="display rwfilter command without running it")
    op.add_option("-v", "--verbose", dest="verbose", action="store_true",
        help="print rwfilter command before it's invoked")

    # Split the raw argument list at the first "--".
    argv = sys.argv[1:]
    filterargs = []
    if '--' in argv:
        split_at = argv.index('--')
        argv, filterargs = argv[:split_at], argv[split_at + 1:]

    options, args = op.parse_args(argv)

    # The default type is 'auto', but nothing in this function detects
    # the input type, so it is rejected just like any unknown value;
    # the message for that case says what is actually missing.
    if options.type == 'auto':
        op.error("input file type must be specified with --intype "
            "(one of %s)" %(file_types))
    if options.type not in file_types:
        op.error("file type must be one of %s" %(file_types))
    if options.type == 'rule' and not (options.start and options.end):
        op.error("start and end times for snort rule queries must be specified")
    if options.type in ['full', 'fast'] and (options.start or options.end):
        op.error("start and end arguments not supported for alert file input")
    if options.tolerance < 0:
        op.error("tolerance must not be negative")
    if options.dry_run:
        # A dry run is only useful if the command is shown.
        options.verbose = True
    if len(args) > 1:
        op.error("too many command-line arguments provided")
    return options, args, filterargs

# Pattern used by process_alert() for input of type 'full'.
#
# Always captures mon, day, hour, min and sec from a timestamp written
# as "MM/DD-HH:MM:SS" (all numeric).  The rest is one optional group:
# when "sip:sport -> dip:dport" followed by a non-blank token appears
# later in the text, those are captured as sip, sport, dip, dport and
# proto; when it does not, all five groups are None.  Ports are
# required for the optional group to match.  re.DOTALL lets the ".*?"
# gap between the timestamp and the addresses span line breaks.
re_full = re.compile(r"""
    (?P<mon>\d+)/(?P<day>\d+)-(?P<hour>\d+):(?P<min>\d+):(?P<sec>\d+)
    (?:
    .*?
    (?P<sip>[\d.]+):(?P<sport>\d+)\s*->\s*(?P<dip>[\d.]+):(?P<dport>\d+)
    \s*
    (?P<proto>[\S]+)
    )?
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Pattern used by process_alert() for input of type 'fast'.
#
# Always captures mon (a word, looked up in calendar.month_abbr by
# process_alert), day, hour, min and sec from a timestamp shaped like
# "Mon DD HH:MM:SS".  The rest is one optional group: when
# "{proto} sip:sport -> dip:dport" appears later in the text, those are
# captured as proto, sip, sport, dip and dport; when it does not, all
# five groups are None.  Ports are required for the optional group to
# match.  re.DOTALL lets the ".*?" gap span line breaks.
re_fast = re.compile(r"""
    (?P<mon>\w+)\s*(?P<day>\d+)\s*(?P<hour>\d+):(?P<min>\d+):(?P<sec>\d+)
    (?:
    .*?
    {(?P<proto>[\S]+)}
    \s*
    (?P<sip>[\d.]+):(?P<sport>\d+)\s*->\s*(?P<dip>[\d.]+):(?P<dport>\d+)
    )?
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Pattern used by process_rule() to take a Snort rule apart.
#
# Captures, in order and separated by whitespace: action, proto, sip,
# sport, a direction operator built from '<', '>' and '-' (not
# captured), dip, dport, and finally the text between the opening
# parenthesis and the LAST closing parenthesis as "options".  The
# options group is greedy and re.DOTALL is set, so it can span lines.
re_rule = re.compile(r"""
    (?P<action>\w+)
    \s+
    (?P<proto>\w+)
    \s+
    (?P<sip>\S+)
    \s+
    (?P<sport>\S+)
    \s+
    [<>\-]+
    \s+
    (?P<dip>\S+)
    \s+
    (?P<dport>\S+)
    \s+
    \(
        (?P<options>.+)
    \)
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Matches one "option: value;" pair inside the parenthesised option
# list of a Snort rule.
#
# option -- a run of characters containing no whitespace, ':' or ';',
#           so a value-less option such as "nocase;" can never be
#           glued onto the name of the option that follows it
# value  -- the shortest text up to the next ';', with surrounding
#           whitespace excluded; it may contain ':' (e.g. quoted
#           message text), but not ';'
re_rule_options = re.compile(r"""
    (?P<option>[^\s:;]+)
    \s*
    :
    \s*
    (?P<value>[^;]+?)
    \s*
    ;
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# Snort rule options that process_rule() translates, mapped to the name
# of the rwfilter switch (without the leading "--") it emits for them.
supported_rule_opts = dict([
    ('ip_proto', 'protocol'),
    ('itype', 'icmp-type'),
    ('icode', 'icmp-code'),
    ('flags', 'tcp-flags'),
])

def _write_ipset(addrs, tempfiles):
    """Create a temporary IPset file holding addrs; return its name.

    The file object is appended to tempfiles so that the caller keeps
    it alive: a NamedTemporaryFile is deleted as soon as it is closed.
    The set is built by feeding the addresses, one per line, to
    "rwsetbuild stdin stdout".  If rwsetbuild cannot be run or exits
    with a non-zero status, a warning is written to stderr and the
    (then empty or partial) file name is still returned.
    """
    f = tempfile.NamedTemporaryFile(suffix='.set')
    tempfiles.append(f)
    try:
        proc = subprocess.Popen(['rwsetbuild', 'stdin', 'stdout'],
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        # communicate() needs bytes on Python 3; on Python 2 encoding
        # an ASCII str is a no-op.
        out, err = proc.communicate('\n'.join(addrs).encode('ascii'))
    except OSError:
        sys.stderr.write("warning: could not run rwsetbuild: %s\n"
            %(sys.exc_info()[1]))
        return f.name

    if proc.returncode != 0:
        sys.stderr.write("warning: rwsetbuild exited with status %d: %s\n"
            %(proc.returncode, err.decode('ascii', 'replace').strip()))
    f.write(out)
    f.flush()
    return f.name


def expand_ip_spec(prefix, val, tempfiles):
    """Convert a Snort address specification into rwfilter arguments.

    prefix -- 's' for the source side, 'd' for the destination side
    val -- address field of the rule: "any", a single address or block,
           or a bracketed, comma-separated list; a leading '!' negates
    tempfiles -- list that receives the temporary IPset files created
                 here (caller cleans them up)

    Returns a list of rwfilter switches, possibly empty:
    "any"           -> []
    addr            -> --<prefix>address=addr
    !addr           -> --not-<prefix>address=addr
    [a,b,...]       -> --<prefix>ipset=<temp file>
    ![a,b] / [!a,!b] -> --not-<prefix>ipset=<temp file>

    rwfilter doesn't currently support both --sipset and --not-sipset,
    so when a list mixes plain and negated addresses only the plain
    ones are used and the negated ones are ignored.
    """
    val = val.strip()
    if val == "any":
        return []

    negated = val.startswith('!')
    if negated:
        val = val[1:].strip()

    if '[' not in val:
        if negated:
            return ["--not-%saddress=%s" %(prefix, val)]
        return ["--%saddress=%s" %(prefix, val)]

    # Flatten the list: drop every bracket, split on commas, and discard
    # blank elements.
    items = [a.strip()
        for a in val.replace('[', '').replace(']', '').split(',')]
    items = [a for a in items if a]

    if negated:
        # "![a,b]" excludes every listed address.
        include = []
        exclude = [a.lstrip('!').strip() for a in items]
    else:
        include = [a for a in items if not a.startswith('!')]
        exclude = [a.lstrip('!').strip() for a in items if a.startswith('!')]
    exclude = [a for a in exclude if a]

    if include:
        return ['--%sipset=%s' %(prefix, _write_ipset(include, tempfiles))]
    if exclude:
        return ['--not-%sipset=%s' %(prefix, _write_ipset(exclude, tempfiles))]
    return []

def expand_port_spec(prefix, val):
    """Convert a Snort port specification into an rwfilter argument.

    prefix -- 's' for the source port, 'd' for the destination port
    val -- port field of the rule; accepted forms are
           any        no restriction
           123        a single port
           123:456    an inclusive range
           123:       123 and above
           :456       456 and below
           !<form>    everything except the port or range
           [a,b:c]    a list of non-negated ports and ranges

    Returns [] for "any", otherwise a one-element list holding
    "--<prefix>port=<list>" with ranges written as "low-high".

    Raises ValueError when val is not one of the forms above (for
    example an unexpanded "$VARIABLE"), names a port above 65535, has a
    reversed range, or excludes every port.
    """
    max_port = 65535

    val = val.strip()
    if val == "any":
        return []

    items = [p.strip()
        for p in val.replace('[', '').replace(']', '').split(',')]

    ranges = []
    for item in items:
        m = re.match(
            r'(?P<neg>!)?(?P<p1>\d+)?(?P<range>:)?(?P<p2>\d+)?$', item)
        # Every group is optional, so the pattern also matches text with
        # no port number at all; insist on at least one.
        if not m or (m.group('p1') is None and m.group('p2') is None):
            raise ValueError("unsupported port specification: %r" %(val,))
        if m.group('neg') and len(items) > 1:
            # A negation inside a list cannot be expressed as a plain
            # union of ranges.
            raise ValueError("unsupported port specification: %r" %(val,))

        low = 0
        if m.group('p1') is not None:
            low = int(m.group('p1'))
        if m.group('range'):
            high = max_port
            if m.group('p2') is not None:
                high = int(m.group('p2'))
        else:
            high = low
        if low > high or high > max_port:
            raise ValueError("invalid port range: %r" %(val,))

        if m.group('neg'):
            # Complement of [low, high] within [0, max_port]; a side
            # that would be empty is left out rather than emitted as an
            # invalid range such as "0--1".
            if low > 0:
                ranges.append("0-%d" %(low - 1))
            if high < max_port:
                ranges.append("%d-%d" %(high + 1, max_port))
            if not ranges:
                raise ValueError(
                    "port specification excludes every port: %r" %(val,))
        elif m.group('range'):
            ranges.append("%d-%d" %(low, high))
        else:
            ranges.append("%d" %(low))

    return ['--%sport=%s' %(prefix, ','.join(ranges))]


def process_rule(file, vars, tempfiles):
    """Convert a Snort rule into corresponding rwfilter arguments.

    file -- the input file to be processed
    vars -- dict of snort variables from snort.conf (may be None)
    tempfiles -- list of temporary files created (caller cleans them up)

    The whole input is joined into one string, "$NAME" references are
    replaced with their values from vars, and the rule header plus the
    options listed in supported_rule_opts are translated.

    Returns a list of rwfilter switches; the list is empty when the
    input does not look like a rule.  A non-numeric value after '<' or
    '>' in an ICMP option raises ValueError from int().
    """

    opts = []
    rule = ' '.join(file.readlines())
    if vars:
        # Substitute longer names first so that "$HOME_NET" is never
        # clobbered by a shorter variable such as "$HOME".
        for name in sorted(vars, key=len, reverse=True):
            rule = rule.replace("$%s" %(name), vars[name])

    matches = re_rule.search(rule)
    if not matches:
        return opts

    for prefix in [ 's', 'd' ]:
        opts += expand_ip_spec(prefix,
            matches.group(prefix + 'ip'), tempfiles)
        opts += expand_port_spec(prefix, matches.group(prefix + 'port'))

    for opt in re_rule_options.finditer(matches.group('options')):
        if opt.group('option') not in supported_rule_opts:
            continue
        optname = supported_rule_opts[opt.group('option')]
        optval = opt.group('value').strip()

        if optname.find('icmp') >= 0:
            # Snort and rwfilter denote ranges differently
            if optval.startswith('<'):
                optval = "0-%s" %( int(optval.replace('<', '')) - 1 )
            elif optval.startswith('>'):
                optval = "%s-255" %( int(optval.replace('>', '')) + 1 )
        elif optname.find('flags') >= 0:
            # Tweak flags syntax: drop the optional ",mask" part and
            # the '+' and '*' match modifiers, which are not flag
            # letters.  ('*' used to be rewritten to 'C', which added
            # CWR to the flag list.)
            if optval.find(',') > 0:
                optval = optval[:optval.index(',')]
            optval = optval.replace('+', '')
            optval = optval.replace('*', '')
            optval = optval.replace('1', 'C') # 1 == CWR
            optval = optval.replace('2', 'E') # 2 == ECE
        opts.append('--%s=%s' %(optname, optval))

    return opts

def process_alert(file, type, year, tolerance):
    """Convert a Snort alert into corresponding rwfilter arguments.

    file -- the input file to be processed
    type -- type of alerts in the file (either 'fast' or 'full')
    year -- the year to be assumed for dates, since Snort timestamps lack one
    tolerance -- number of seconds before and after the alert timestamp
                 within which flow start times are accepted

    Returns a list of rwfilter switches: --start-date and --end-date
    (hour precision), --stime (second precision) and, when the alert
    carries them, --saddress, --sport, --daddress, --dport and
    --protocol.  Returns None when type is neither 'full' nor 'fast'.

    Raises ValueError when no alert timestamp is found in the input or
    the timestamp is not a valid date, and KeyError when the alert names
    a protocol that get_protocol_number() does not know.
    """

    record = ' '.join(file.readlines())
    if type == 'full':
        matches = re_full.search(record)
    elif type == 'fast':
        matches = re_fast.search(record)
    else:
        return None

    if not matches:
        raise ValueError("no %s-format Snort alert found in input" %(type))

    if type == 'full':
        month = int(matches.group('mon'))
    else:
        try:
            month = list(calendar.month_abbr).index(matches.group('mon'))
        except ValueError:
            raise ValueError("unrecognized month %r in alert timestamp"
                %(matches.group('mon'),))

    dt = datetime.datetime(
        int(year),
        month,
        int(matches.group('day')),
        int(matches.group('hour')),
        int(matches.group('min')),
        int(matches.group('sec')))

    stime_min = dt - datetime.timedelta(seconds=tolerance)
    stime_max = dt + datetime.timedelta(seconds=tolerance)

    # --start-date/--end-date are given at hour precision, so truncate.
    start_date = datetime.datetime(stime_min.year, stime_min.month,
        stime_min.day, stime_min.hour)
    end_date = datetime.datetime(stime_max.year, stime_max.month,
        stime_max.day, stime_max.hour)

    opts = []
    opts.append('--start-date=%s' %(start_date.strftime("%Y/%m/%d:%H")))
    opts.append('--end-date=%s' %(end_date.strftime("%Y/%m/%d:%H")))

    opts.append('--stime=%s-%s' %(stime_min.strftime("%Y/%m/%d:%H:%M:%S"),
        stime_max.strftime("%Y/%m/%d:%H:%M:%S")))

    # The address part of both alert patterns is optional; when it did
    # not match, every one of these groups is None and must be skipped
    # (otherwise switches such as "--sport=None" would be emitted).
    if matches.group('sip'):
        opts.append("--saddress=%s" %(matches.group('sip')))

    sport = matches.group('sport')
    if sport and sport != "any":
        opts.append("--sport=%s" %(sport))

    if matches.group('dip'):
        opts.append("--daddress=%s" %(matches.group('dip')))

    dport = matches.group('dport')
    if dport and dport != "any":
        opts.append("--dport=%s" %(dport))

    if matches.group('proto'):
        opts.append("--protocol=%d" %(
            get_protocol_number(matches.group('proto'))))

    return opts

def get_snort_vars(file):
    """Read variable definitions from a Snort configuration file.

    file -- an open, iterable file object for snort.conf

    A definition is a line whose first whitespace-separated token is
    exactly "var", "ipvar" or "portvar", followed by a name and a value.
    The value is the rest of the line with surrounding whitespace
    removed.  Other lines, and definitions without a value, are ignored.
    When a name is defined more than once, the last definition wins.

    Returns a dict mapping variable names to their values.
    """
    keywords = ('var', 'ipvar', 'portvar')
    vars = {}
    for line in file:
        # Split off only the keyword and the name, so that the value is
        # kept intact even when it contains whitespace or the text "var".
        tokens = line.split(None, 2)
        if len(tokens) == 3 and tokens[0] in keywords:
            vars[tokens[1]] = tokens[2].strip()
    return vars

def main():
    """Build the rwfilter command line and, unless --dry-run, run it.

    The input (a file named on the command line, or stdin) is turned
    into rwfilter switches by process_alert() or process_rule().  Then
    masked switches are removed, the arguments given after "--" are
    appended, and "--pass" is pointed at the requested output file.

    Returns the exit status: rwfilter's own status when it was run, 0
    for a dry run, and 1 when the input could not be read or parsed or
    rwfilter could not be started.
    """

    cmdline = [ "rwfilter" ]
    tempfiles = []
    options, args, filterargs = parse_options()

    # Input is opened in text mode: the parsing code works on str.
    if len(args) == 1:
        try:
            f = open(args[0], 'r')
        except IOError:
            sys.stderr.write("couldn't open file %s: %s\n"
                %(args[0], sys.exc_info()[1]))
            return 1
    else:
        f = sys.stdin

    try:
        try:
            if options.type in [ 'full', 'fast' ]:
                # process_alert() returns None for a type it does not
                # handle; treat that as "no switches".
                cmdline += process_alert(f, options.type, options.year,
                    int(options.tolerance)) or []

            elif options.type == 'rule':
                vars = None
                if options.config_file:
                    try:
                        conf_file = open(options.config_file, 'r')
                    except IOError:
                        sys.stderr.write("couldn't load snort conf file %s\n"
                            %(options.config_file))
                        return 1
                    try:
                        vars = get_snort_vars(conf_file)
                    finally:
                        conf_file.close()

                cmdline.append('--start-date=%s' %(options.start))
                cmdline.append('--end-date=%s' %(options.end))

                cmdline.append('--stime=%s-%s' %(options.start, options.end))
                cmdline += process_rule(f, vars, tempfiles)
        except ValueError:
            # Unparseable input: report it instead of dumping a traceback.
            sys.stderr.write("rwidsquery: %s\n" %(sys.exc_info()[1]))
            return 1

        # Mask out rwfilter options the user doesn't want to filter on.
        # The program name in cmdline[0] is never subject to masking.
        if options.mask:
            masked = [name.strip() for name in options.mask.split(',')]
            cmdline = cmdline[:1] + [x for x in cmdline[1:]
                if x.split('=', 1)[0].replace('--', '') not in masked]

        # Add in extra rwfilter args from the command line
        cmdline += filterargs

        cmdline.append('--pass=%s' %(options.output_file))

        if options.verbose:
            sys.stderr.write(' '.join(cmdline) + '\n')
        if options.dry_run:
            return 0

        try:
            return subprocess.call(cmdline)
        except OSError:
            sys.stderr.write("couldn't run rwfilter: %s\n"
                %(sys.exc_info()[1]))
            return 1
    finally:
        if f is not sys.stdin:
            f.close()
        # Closing a NamedTemporaryFile deletes it, so this must happen
        # only after rwfilter has finished with the IPset files.
        for t in tempfiles:
            t.close()

if __name__ == '__main__':
    sys.exit(main())
