"""
Shell functions for RAVE analyses.
"""

__version__ = "$Rev: 7149 $"


import sys, re, subprocess, time
from rave.plugins.decorators import *
from rave.plugins.dataset import Dataset
from rave.plugins.times import *
from datetime import datetime

# Maps nafscii heading tokens to the column names used in Datasets built
# by parse_nafscii.  Headings not listed here are kept unchanged.
nafscii_header_map = {
    'source': 'sensor',
    'sip': 'sip',
    'dip': 'dip',
    'sp': 'sport',
    'dp': 'dport',
    'proto': 'proto',
    'shosts': 'shosts',
    'dhosts': 'dhosts',
    'flo': 'flows',
    'pkt': 'packets',
    'oct': 'bytes',
}


def parse_nafscii(file):
    """Read whitespace-delimited nafscii text from *file* into a Dataset.

    The first line is the heading line.  Its first two tokens are replaced
    by a single 'time' column; the remaining tokens are renamed through
    nafscii_header_map (unknown headings pass through as-is).  In every
    following line the first two fields are joined with one space to form
    the 'time' value and the rest are stored as the strings they were read
    as.
    """
    heading_tokens = file.readline().split()
    column_names = ['time']
    for token in heading_tokens[2:]:
        column_names.append(nafscii_header_map.get(token, token))
    data = Dataset(column_names)
    for line in file:
        fields = line.split()
        timestamp = fields[0] + ' ' + fields[1]
        data.add_row([timestamp] + fields[2:])
    return data

# Field separator: a '|' together with any whitespace around it.
re_pipe_and_whitespace = re.compile(r'\s*\|\s*')

def parse_generic(file, column_names, delim=re_pipe_and_whitespace):
    """Build a Dataset named by *column_names* from delimited text.

    Each line of *file* is stripped of surrounding whitespace and split
    with the compiled regular expression *delim*; the resulting list of
    strings is added as one row.  No heading line is consumed and no type
    conversion is done.
    """
    dataset = Dataset(column_names)
    for raw_line in file:
        dataset.add_row(delim.split(raw_line.strip()))
    return dataset
    
def parse_generic_silk(header_map, file):
    """Parse '|'-delimited tool output with a heading line into a Dataset.

    The first line of *file* supplies the headings.  Each heading is
    renamed through *header_map* (headings missing from the map are kept
    unchanged) and the result becomes the Dataset's column names.  Values
    in every following line are converted according to the column name:

      * time, stime, etime, mintime, maxtime -> datetime_obj(value)
      * sport, dport, proto, records         -> int(value)
      * shosts, dhosts, flows, packets, bytes, dur, ppf, bpf,
        pct_of_total, cumul_pct              -> float(value)

    All other columns stay as strings.
    """
    time_names = ['time', 'stime', 'etime', 'mintime', 'maxtime']
    int_names = ['sport', 'dport', 'proto', 'records']
    # pct_of_total and cumul_pct come from rwstats rather than from
    # "generic silk" output, but they are converted here as well.
    float_names = ['shosts', 'dhosts', 'flows', 'packets', 'bytes', 'dur',
                   'ppf', 'bpf', 'pct_of_total', 'cumul_pct']

    def split_fields(text):
        # The last piece of the split is always discarded -- presumably the
        # empty string left behind by a trailing '|' on each line.
        return re_pipe_and_whitespace.split(text.strip())[:-1]

    column_names = [header_map.get(heading, heading)
                    for heading in split_fields(file.readline())]
    data = Dataset(column_names)

    # Work out once which column positions need which conversion.
    time_columns = []
    int_columns = []
    float_columns = []
    for position, name in enumerate(column_names):
        if name in time_names:
            time_columns.append(position)
        if name in int_names:
            int_columns.append(position)
        if name in float_names:
            float_columns.append(position)

    # Convert and store each data line.
    for line in file:
        row = split_fields(line)
        for position in time_columns:
            row[position] = datetime_obj(row[position])
        for position in int_columns:
            row[position] = int(row[position])
        for position in float_columns:
            row[position] = float(row[position])
        data.add_row(row)
    return data

# Maps rwtotal headings to Dataset column names.  Keys are case-sensitive:
# lower-case 'packets'/'bytes' and capitalised 'Packets'/'Bytes' are
# different headings and map to different columns.
rwtotal_header_map = {
    # every address-prefix/suffix key heading collapses to sip or dip
    'sIP_First8': 'sip',
    'sIP_First16': 'sip',
    'sIP_First24': 'sip',
    'sIP_Last8': 'sip',
    'sIP_Last16': 'sip',
    'sIP_Last24': 'sip',
    'dIP_First8': 'dip',
    'dIP_First16': 'dip',
    'dIP_First24': 'dip',
    'dIP_Last8': 'dip',
    'dIP_Last16': 'dip',
    'dIP_Last24': 'dip',
    # other key headings
    'sPort': 'sport',
    'dPort': 'dport',
    'protocol': 'proto',
    'elapsed': 'dur',
    'icmpTypeCod': 'icmp_tc',
    'packets': 'ppf',
    'bytes': 'bpf',
    # value headings
    'Records': 'flows',
    'Bytes': 'bytes',
    'Packets': 'packets',
}


def parse_rwtotal(file):
    """Parse rwtotal output from *file* into a Dataset.

    Delegates to parse_generic_silk using rwtotal_header_map.
    """
    return parse_generic_silk(rwtotal_header_map, file)

# Maps rwuniq headings to Dataset column names (keys are case-sensitive).
# NOTE(review): 'dur' is renamed to 'duration' here, while
# parse_generic_silk only converts a column called 'dur' to float, so this
# column stays a string -- confirm that is intended.
rwuniq_header_map = {
    'sensor': 'sensor',
    'sIP': 'sip',
    'dIP': 'dip',
    'sPort': 'sport',
    'dPort': 'dport',
    'pro': 'proto',
    'dur': 'duration',
    'packets': 'ppf',
    'bytes': 'bpf',
    'flags': 'flags',
    'sTime': 'stime',
    'eTime': 'etime',
    'in': 'intf_in',
    'out': 'intf_out',
    'nhIP': 'next_hop',
    'cla': 'class',
    'type': 'type',
    'Records': 'flows',
    'Packets': 'packets',
    'Bytes': 'bytes',
    'min_sTime': 'mintime',
    'max_eTime': 'maxtime',
    'sval': 'sval',
    'dval': 'dval',
}

def parse_rwuniq(file):
    """Parse rwuniq output from *file* into a Dataset.

    Delegates to parse_generic_silk using rwuniq_header_map.
    """
    return parse_generic_silk(rwuniq_header_map, file)

# Maps rwcount headings to Dataset column names.
rwcount_header_map = {
    'Date': 'time',
    'Records': 'flows',
    'Packets': 'packets',
    'Bytes': 'bytes',
}

def parse_rwcount(file):
    """Parse rwcount output from *file* into a Dataset.

    Delegates to parse_generic_silk using rwcount_header_map.
    """
    return parse_generic_silk(rwcount_header_map, file)

# Maps rwcut headings to Dataset column names (keys are case-sensitive).
rwcut_header_map = {
    'sIP': 'sip',
    'dIP': 'dip',
    'sPort': 'sport',
    'dPort': 'dport',
    'pro': 'proto',
    'packets': 'packets',
    'bytes': 'bytes',
    'flags': 'flags',
    'sTime': 'stime',
    'dur': 'dur',
    'eTime': 'etime',
    'sen': 'sensor',
    'in': 'intf_in',
    'out': 'intf_out',
    'nhIP': 'next_hop',
    'cla': 'class',
    'type': 'type',
    'sTime+msec': 'stime_msec',
    'eTime+msec': 'etime_msec',
    'dur+msec': 'dur_msec',
    'iTy': 'icmp_type',
    'iCo': 'icmp_code',
}

def parse_rwcut(file):
    """Parse rwcut output from *file* into a Dataset.

    Delegates to parse_generic_silk using rwcut_header_map.
    """
    return parse_generic_silk(rwcut_header_map, file)

# Maps rwaddrcount headings to Dataset column names.
# 'dIP' now maps to 'dip' (it previously mapped to 'sip', which looks like
# a copy-paste slip: every other header map in this module, and the
# sIP_Uniq/dIP_Uniq pair below, keep source and destination distinct).
# NOTE(review): confirm no caller relies on destination-keyed output being
# exposed under the 'sip' column name.
rwaddrcount_header_map = { 'sIP':        'sip',
                           'sIP_Uniq':   'shosts',
                           'dIP':        'dip',
                           'dIP_Uniq':   'dhosts',
                           'Bytes':      'bytes',
                           'Packets':    'packets',
                           'Records':    'records',
                           'Start_Time': 'mintime',
                           'End_Time':   'maxtime' }

def parse_rwaddrcount(file):
    """Parse rwaddrcount output from *file* into a Dataset.

    Delegates to parse_generic_silk using rwaddrcount_header_map.
    """
    return parse_generic_silk(rwaddrcount_header_map, file)

# Maps rwstats headings to Dataset column names.
# NOTE(review): 'protocol' is kept as 'protocol' rather than 'proto', so
# parse_generic_silk (which converts a column called 'proto' to int) leaves
# it as a string -- confirm that is intended.
rwstats_header_map = {
    'sIP': 'sip',
    'dIP': 'dip',
    'sPort': 'sport',
    'dPort': 'dport',
    'protocol': 'protocol',
    'icmpType': 'icmp_type',
    'icmpCode': 'icmp_code',
    'Records': 'records',
    '%_of_total': 'pct_of_total',
    'cumul_%': 'cumul_pct',
}

def parse_rwstats(file):
    """Parse rwstats output from *file* into a Dataset.

    The first two lines of the output are read and discarded: they carry
    summary information that has no place in the Dataset right now.  The
    remainder is handed to parse_generic_silk with rwstats_header_map.
    """
    for _ in range(2):
        file.readline()
    return parse_generic_silk(rwstats_header_map, file)

def cmd(cmd, kwds=None):
    """Build an argument list for one command, applying %-substitution.

    cmd  -- either a string, which is split on whitespace, or a list of
            argument strings.
    kwds -- mapping used as the right-hand side of ``arg % kwds`` for every
            argument (e.g. '--start-date=%(sdate)s').  Defaults to an empty
            mapping.

    Returns a new list of the substituted arguments; arguments that end up
    as the empty string are dropped.

    Raises TypeError if cmd is neither a string nor a list.
    """
    if not isinstance(cmd, (str, list)):
        raise TypeError('cmd must be string or list of strings')
    if kwds is None:
        # A fresh dict per call instead of a shared mutable default.
        kwds = {}
    if isinstance(cmd, str):
        # Command is a string, split on whitespace
        cmd = cmd.split()
    substituted = [arg % kwds for arg in cmd]
    return [arg for arg in substituted if arg != '']

def pipe(*cmds, **kwargs):
    """Start the commands as a pipeline and return the last one's stdout.

    pipe(cmd1, cmd2, cmd3) hooks stdout of cmd1 to stdin of cmd2, and so
    on.  Each command is passed to subprocess.Popen unchanged.  If exactly
    one positional argument is given and it is a list or tuple, it is taken
    to be the sequence of commands (so a single argv-list command must be
    wrapped: pipe([argv])).

    Keyword arguments:
    stdin  -- input for the first command: a file object, or a path which
              is opened for reading.  Default: inherit.
    stderr -- stderr for every command: a file object, or a path which is
              opened for appending.  Default: inherit.

    Returns the readable stdout pipe of the final command.  The function
    does not wait for any of the processes.  With no commands at all the
    stdin object itself is returned, as before.
    """
    in_file = kwargs.get('stdin')
    err_file = kwargs.get('stderr')
    close_in = False
    close_err = False
    try:
        # Files are opened inside the try so that a failure opening the
        # second one still closes the first.
        if isinstance(in_file, str):
            in_file = open(in_file, 'r')
            close_in = True
        if isinstance(err_file, str):
            err_file = open(err_file, 'a')
            close_err = True
        if len(cmds) == 1 and isinstance(cmds[0], (tuple, list)):
            cmds = cmds[0]
        in_fd = in_file
        upstream_out = None
        for command in cmds:
            try:
                proc = subprocess.Popen(command, stdin=in_fd,
                                        stderr=err_file,
                                        stdout=subprocess.PIPE)
            finally:
                # The child now owns its end of the previous pipe.  Close
                # our copy so the descriptor is not leaked and so the
                # upstream process sees a broken pipe if this one exits
                # early.
                if upstream_out is not None:
                    upstream_out.close()
            upstream_out = proc.stdout
            in_fd = proc.stdout
    finally:
        # Children hold their own copies of these descriptors.
        if close_in:
            in_file.close()
        if close_err:
            err_file.close()
    return in_fd

def exec_pipe(out_file, *cmds, **kwargs):
    """Run the commands as a pipeline, wait, and return the exit status.

    Works like pipe(), except that stdout of the final process goes to
    out_file and this function waits for the pipeline to finish.

    out_file -- a file object, a path (opened for writing, truncating), or
                None for no output redirection.
    cmds     -- the commands, each passed to subprocess.Popen unchanged.
                If exactly one is given and it is a list or tuple, it is
                taken to be the sequence of commands.

    Keyword arguments:
    stdin  -- input for the first command: file object or path (opened for
              reading).  Default: inherit.
    stderr -- stderr for every command: file object or path (opened for
              appending).  Default: inherit.

    Returns the exit status of the last command.
    Raises ValueError if no command is given (checked before any file is
    opened, so out_file is not truncated in that case).
    """
    if len(cmds) == 1 and isinstance(cmds[0], (tuple, list)):
        cmds = cmds[0]
    if not cmds:
        raise ValueError('exec_pipe requires at least one command')
    in_file = kwargs.get('stdin')
    err_file = kwargs.get('stderr')
    close_in = False
    close_err = False
    close_out = False
    try:
        # Files are opened inside the try so that a failure opening a later
        # one still closes the earlier ones.
        if isinstance(in_file, str):
            in_file = open(in_file, 'r')
            close_in = True
        if isinstance(err_file, str):
            err_file = open(err_file, 'a')
            close_err = True
        if isinstance(out_file, str):
            out_file = open(out_file, 'w')
            close_out = True
        procs = []
        in_fd = in_file
        last_index = len(cmds) - 1
        for index, command in enumerate(cmds):
            if index == last_index:
                out_fd = out_file
            else:
                out_fd = subprocess.PIPE
            try:
                proc = subprocess.Popen(command, stdin=in_fd, stdout=out_fd,
                                        stderr=err_file)
            finally:
                # The child now owns its end of the previous pipe.  Close
                # our copy so the descriptor is not leaked and so the
                # upstream process sees a broken pipe if this one exits
                # early.
                if procs:
                    procs[-1].stdout.close()
            procs.append(proc)
            in_fd = proc.stdout
        result = procs[-1].wait()
        # Reap the upstream processes too so they are not left as zombies.
        for proc in procs[:-1]:
            proc.wait()
    finally:
        if close_out:
            out_file.close()
        if close_in:
            in_file.close()
        if close_err:
            err_file.close()
    return result

# example:
# 
# @rdeco.op_file(mime_type='image/png',
#                typemap={'sdate': silk_datetime,
#                         'edate': silk_datetime,
#                         'by': int,
#                         'width': int,
#                         'height': int})
# def ts_flows(out_file, sdate, edate, by, width=800, height=600):
#     c1 = cmd(['rwfilter', '--type=in,out,null', '--start-date=%(sdate)s',
#               '--end-date=%(edate)s', '--proto=0-255', '--pass=stdout'],
#              {'sdate': sdate, 'edate': edate})
#     c2 = cmd(['rwcount', '--bin-size=%(by)d', '--load-scheme=1'],
#              {'by': by})
#     d = parse_rwcount(pipe(c1, c2))
#     render_timeseries(out_file, d['time'], d['flows'], w=width, h=height)
#
# or, equivalently:
#
# @rdeco.op_file(mime_type='image/png',
#                typemap={'sdate': silk_datetime,
#                         'edate': silk_datetime,
#                         'by': int,
#                         'width': int,
#                         'height': int})
# def ts_flows(out_file, sdate, edate, by, width=800, height=600):
#     c1 = cmd("""rwfilter --type=in,out,null --start-date=%(sdate)s
#                 --end-date=%(edate)s --proto=0-255 --pass=stdout""",
#              {'sdate': sdate, 'edate': edate})
#     c2 = cmd("""rwcount --bin-size=%(by)d --load-scheme=1""",
#              {'by': by})
#     d = parse_rwcount(pipe(c1, c2))
#     render_timeseries(out_file, d['time'], d['flows'], w=width, h=height)

# Parser lookup by name, used by get_bin_hour to turn command output into
# a Dataset.
parsers = {
    'nafscii': parse_nafscii,
    'rwtotal': parse_rwtotal,
    'rwuniq': parse_rwuniq,
    'rwcount': parse_rwcount,
}

def get_next_bin(bin_size):
    """Return (cache_size, next_bin_size) for a bin size given in seconds.

    cache_size is the span, in seconds, that get_bin_range steps through
    per iteration at this bin size; next_bin_size is the finer bin size it
    recurses with to obtain the underlying data.

    Supported values are 60, 300, 600, 3600 and any larger multiple of
    3600.  Raises ValueError for anything else.
    """
    fixed_sizes = {
        60: (3600, 60),
        300: (3600 * 3, 60),
        600: (3600 * 6, 300),
        3600: (86400, 600),
    }
    if bin_size in fixed_sizes:
        return fixed_sizes[bin_size]
    if bin_size > 3600 and (bin_size % 3600) == 0:
        # Floor division keeps the result an integer under true division
        # as well (Python 3 or "from __future__ import division").
        return (bin_size // 3600 * 86400, 3600)
    raise ValueError("can only deal with bin_size in 60,300,600,3600 "
                     "or a multiple of 3600")

# Caching strategy keyed on the 'etime' argument: if etime falls within
# the past 3 hours, the data expires 60 minutes from now; otherwise it
# never expires.
etime_var_strategy = crossover_strategy(
    duration_strategy(3600),
    forever,
    age=3 * 3600,
    age_arg='etime')

@op
@typemap(str, silk_datetime, list_of(str), dict)
@use_strategy(etime_var_strategy)
@version(20070301)
def get_bin_hour(parser, etime, cmds, kw={}):
    """Run a command pipeline for the hour containing *etime* and parse it.

    parser -- key into the module-level ``parsers`` table.
    etime  -- a time; only the first 13 characters of silk_datetime(etime)
              are used (presumably the date down to the hour -- confirm
              against silk_datetime).
    cmds   -- command templates; each is expanded with cmd() using *kw*
              plus an added 'hour' entry holding that 13-character string.
    kw     -- substitution values; copied, never modified.

    Returns the parsed output of the pipeline, or an empty Dataset when the
    requested hour is later than the current UTC hour.
    """
    hour = silk_datetime(etime)[0:13]
    current_hour = silk_datetime(datetime.utcnow())[0:13]
    if hour > current_hour:
        # Nothing can exist yet for an hour in the future.
        return Dataset()
    substitutions = dict(kw)
    substitutions['hour'] = hour
    pipeline = [cmd(template, substitutions) for template in cmds]
    parse = parsers[parser]
    return parse(pipe(pipeline))

# I think this is an artifact of an aborted change.
# TODO: verify and remove
#re_whitespace = re.compile(r'\s+')
#def normalize(x):
#    return  " ".join(re.split(re_whitespace, x))

@op
@typemap(str, str, one_of(None, str), list_of(str), list_of(str),
         silk_datetime, silk_datetime, dict, int)
@use_strategy(etime_var_strategy)
@version(20070301)
def get_bin_range(parser, time_col, label_col, value_cols, cmds,
                  stime, etime, kw, bin_size):
    """Collect data binned at *bin_size* seconds covering stime..etime.

    The range is walked in steps of the cache size that get_next_bin gives
    for bin_size, starting at bin_datetime(cache_size, stime) and stopping
    before bin_datetime(cache_size, etime) + cache_size or the current
    time, whichever comes first.  For bin_size 60 each step is fetched with
    get_bin_hour.  For larger sizes the step is fetched recursively at the
    next finer bin size and then re-binned: with bin_labeled_by when
    label_col is set, otherwise with bin_by.

    Returns one Dataset accumulating every step's data.  Because whole
    steps are fetched, the result may extend beyond stime..etime.
    """
    cache_size, next_bin_size = get_next_bin(bin_size)
    range_start = bin_datetime(cache_size, stime)
    range_stop = bin_datetime(cache_size, etime) + cache_size
    cursor = range_start
    now_secs = int(time.time())
    combined = Dataset()

    while cursor < range_stop and cursor < now_secs:
        if bin_size == 60:
            # Finest level: run the commands for this hour directly.
            chunk = get_bin_hour(parser, cursor, cmds, kw)
        else:
            # Fetch the same step at the finer bin size, then aggregate it
            # up to bin_size.
            finer = get_bin_range(parser, time_col, label_col, value_cols,
                                  cmds, cursor, cursor + cache_size, kw,
                                  next_bin_size)
            if label_col:
                labels = finer[label_col]
                times = finer[time_col]
                series = tuple(finer[name] for name in value_cols)
                columns = (label_col, time_col) + tuple(value_cols)
                chunk = Dataset(columns,
                                _data=bin_labeled_by(bin_size, labels,
                                                     times, *series))
            else:
                times = finer[time_col]
                series = tuple(finer[name] for name in value_cols)
                columns = (time_col,) + tuple(value_cols)
                chunk = Dataset(columns,
                                _data=bin_by(bin_size, times, *series))
        combined += chunk
        cursor = cursor + cache_size

    return combined

def get_binned(parser, time_col, label_col, value_cols, cmds,
               stime, etime, kw={}, bin_size=60):
    """Return binned data restricted to the closed interval stime..etime.

    Both bounds are converted with datetime_obj, the data is gathered with
    get_bin_range, and only rows whose time_col value lies between the two
    bounds (inclusive) are kept.

    NOTE(review): the result is built by handing Dataset a generator of
    rows, whereas other call sites in this module pass column names as the
    first argument -- confirm the constructor accepts this form.
    """
    stime = datetime_obj(stime)
    etime = datetime_obj(etime)
    gathered = get_bin_range(parser, time_col, label_col, value_cols, cmds,
                             stime, etime, kw=kw, bin_size=bin_size)
    return Dataset(row for row in gathered
                   if row[time_col] >= stime and row[time_col] <= etime)
    

# Names exported by "from <this module> import *".
__all__ = [
    'parse_nafscii',
    'parse_rwtotal',
    'parse_rwuniq',
    'parse_rwcount',
    'parse_generic',
    'cmd',
    'pipe',
    'exec_pipe',
    'get_binned',
]
