"""
Generates data for host characterization:
    * # bytes of traffic where host was acting as a server for web traffic.
    * # bytes of traffic where host was acting as a server for mail traffic.
    * # bytes of traffic where host was acting as a server for other traffic.
        * Protocol, port/ICMP message type (if relevant) of largest subtype
          in that traffic
    * # bytes of traffic where host was acting as a client for web traffic.
    * # bytes of traffic where host was acting as a client for mail traffic.
    * # bytes of traffic where host was acting as a client for other traffic.
        * Protocol, port/ICMP message type (if relevant) of largest subtype
          in that traffic
    * # bytes of all other traffic
        * Protocol, port/ICMP message type (if relevant) of largest subtype
          in that traffic
"""
from rave.plugins.decorators import *
from rave.plugins.shell import datetime_obj, cmd, pipe
from rave.log import log_factory


from sets import Set
from datetime import timedelta
import logging
import subprocess

get_log = log_factory()
#logging.basicConfig(level=logging.DEBUG)

# Abbreviations:
# hsp : host server ports
# hcp : host client ports
# hwmp: host web/mail ports

# Port families: each entry maps a family name to the set of port numbers
# that belong to that family.
pf = {
    'mail': Set([25, 143, 109, 110, 220, 993, 995]),
    'other': Set([
        20, 21, 53, 67, 68, 88,
        123, 161, 162, 213, 389, 396,
        500, 543, 636, 749, 750, 760,
        761, 2105, 3268,
    ]),
    'routing': Set([179, 520, 521]),
    'simple': Set([7, 9, 11, 13, 17, 19]),
    'targeted': Set([135, 137, 138, 139, 445, 1433, 1434, 3127, 4899]),
    'terminal': Set([22, 23, 512, 513, 514, 3389, 5950, 5951, 6000]),
    'web': Set([80, 443, 1080, 3128, 8080]),
    'p2p': Set([
        411, 412, 1214, 4662, 5737, 6346, 6347,
        6699, 6881, 6882, 6883, 6884, 6885, 6886,
    ]),
}

# Every port that appears in at least one family.
pf_allfams = Set(reduce(Set.union, pf.values()))



class TimeRange(object):
    """
    A bounded time range delimited by two datetime objects.

    Attributes:
    start
        beginning of the range (result of datetime_obj(start))
    end
        end of the range (result of datetime_obj(end))
    """
    def __init__(self, start, end):
        """
        Parameters:
        start
            start time as a silk datetime specifier
        end
            end time as a silk datetime specifier
        """
        self.start = datetime_obj(start)
        self.end = datetime_obj(end)

    def mk_recent(self, pct):
        """
        Make a new TimeRange covering the most recent part of this one.

        NOTE: despite the parameter name, the new range spans
        (end - start) // pct, i.e. the final 1/pct of this range.
        pct=20 therefore yields the last twentieth of the range, not the
        last 20 percent.
        """
        td = (self.end - self.start) // pct
    #   The bounds are handed back to the constructor as datetime objects;
    #   presumably datetime_obj() passes those through -- verify.
        return TimeRange(self.end - td, self.end)

    def __str__(self):
        return "<TimeRange: %s - %s>" % (self.start, self.end)

    def __eq__(self, other):
        "Two ranges are equal when both of their bounds are equal."
        try:
            return self.start == other.start and self.end == other.end
        except AttributeError:
        #   Not range-like: let Python fall back to its default comparison
        #   instead of raising from an equality test.
            return NotImplemented

    def __ne__(self, other):
        "Inverse of __eq__ (Python 2 does not derive it automatically)."
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        "True when this range ends before other starts."
        return self.end < other.start

    def __gt__(self, other):
        "True when this range starts after other ends."
        return self.start > other.end

    def __contains__(self, key):
        "True when the range key lies strictly inside this range."
        return self.start < key.start and self.end > key.end

    def __len__(self):
        """
        Length of the range in whole seconds.

        __len__ must return a non-negative integer, so the timedelta is
        converted to seconds (microseconds are dropped) and an inverted
        range reports 0.
        """
        delta = self.end - self.start
        return max(0, delta.days * 86400 + delta.seconds)

def timerange(src):
    """
    Coerce src into a TimeRange.

    Parameters:
    src
        either an existing TimeRange (returned unchanged) or a sequence of
        exactly two items, used as the start and end of a new TimeRange

    Raises TypeError for anything else.
    """
    if isinstance(src, TimeRange):
        return src
    try:
        is_pair = len(src) == 2
    except TypeError:
    #   Values without a length get the same explicit error as sequences
    #   of the wrong size.
        is_pair = False
    if is_pair:
        return TimeRange(src[0], src[1])
#   Wrap src in a 1-tuple: formatting a tuple directly would make "%"
#   treat its items as separate arguments and fail with an unrelated
#   message.
    raise TypeError("Illegal value for TimeRange object: %s" % (src,))


# Time-related functions that can't quite make it onto the TimeRange class

def d2s(fmt, d):
    """
    Render datetime d through the %-style template fmt.

    fmt may refer to the named fields Y, M, D, h, m and s, which are bound
    to d's year, month, day, hour, minute and second respectively.
    """
    fields = dict(
        Y=d.year, M=d.month, D=d.day,
        h=d.hour, m=d.minute, s=d.second
    )
    return fmt % fields

def interval(r):
    """
    Format the bounds of r as a single "start-end" string.

    Parameters:
    r
        object with datetime 'start' and 'end' attributes (a TimeRange)

    Each bound is rendered as YYYY/MM/DD:hh:mm:ss with every field
    zero-padded. The day field in particular must be zero-padded: a
    space-padded day would put a blank inside the value, and the result is
    embedded in whitespace-separated command lines (rwfilter --stime).
    """
    timefmt = "%(Y).4d/%(M).2d/%(D).2d:%(h).2d:%(m).2d:%(s).2d"
    return "%s-%s" % (d2s(timefmt, r.start), d2s(timefmt, r.end))



class OtherList(object):
    """
    Data structure for 'other' category in plot.

    Attributes:
    num_bytes
        running total of all bytes recorded through update()
    peak
        (proto, port, num_bytes) of the largest single update seen so far;
        (0, 0, 0) until the first non-empty update
    """
    def __init__(self):
        self.num_bytes = 0
        self.peak = (0, 0, 0)

    def update(self, proto, port, num_bytes):
        """
        Update running tally of total and peak traffic.

        For ICMP traffic, use message type as port.
        For all other non-TCP/UDP traffic, use 0 as port.
        """
        self.num_bytes += num_bytes
    #   Strict comparison: on a tie the earlier entry stays the peak.
        if self.peak[2] < num_bytes:
            self.peak = (proto, port, num_bytes)

    def __str__(self):
        return "<OtherList: %d bytes (Peak: proto=%d, port=%d, %d bytes)>" % (
            self.num_bytes, self.peak[0], self.peak[1], self.peak[2]
        )

class Orientation(object):
    """
    Traffic sharing a particular orientation (client or server).

    Attributes:
    families
        dict mapping the family names 'web' and 'mail' to byte counts
    other
        OtherList tallying traffic of this orientation that belongs to
        neither family
    """
    def __init__(self):
        self.families = {'web': 0, 'mail': 0}
    #   Other (non-web or -mail) that still has this orientation
        self.other = OtherList()

    def update_family(self, family, num_bytes):
        """
        Add num_bytes to the tally of family ('web' or 'mail') and return
        the new total for that family. Unknown families raise KeyError.
        """
        self.families[family] += num_bytes
        return self.families[family]

    def update_other(self, proto, port, num_bytes):
        "Record non-web, non-mail traffic in the 'other' tally."
        return self.other.update(proto, port, num_bytes)

    def dump(self):
        "Print the per-family totals and the 'other' tally to stdout."
    #   Single-argument print written with parentheses: it prints the same
    #   text under the Python 2 print statement and still parses as a
    #   function call on Python 3.
        for k, v in self.families.items():
            print("\t%s: %s" % (k, str(v)))
        print("other: %s" % str(self.other))

    def totals(self):
        "Return the byte counts as a (web, mail, other) tuple."
        return (
              self.families['web']
            , self.families['mail']
            , self.other.num_bytes
        )


class HostCharacterization(object):
    """
    A series of data representing the host characterization.

    Attributes:
    server
        Orientation tallying traffic recorded through update_server()
    client
        Orientation tallying traffic recorded through update_client()
    other
        OtherList tallying traffic that fits neither orientation
    """
    def __init__(self):
        self.server = Orientation()
        self.client = Orientation()
    #   Traffic that doesn't fit in either the server or client Orientation
        self.other = OtherList()

    def update(self, orientation, proto, port, num_bytes):
        """
        Update host characterization. Assumes that all TCP/UDP traffic
        it sees comes from or goes to at least one well-known port.
        (In other words, that it is "oriented" traffic.)

        Parameters:
        orientation
            the Orientation (self.server or self.client) to credit
        proto
            IP protocol number; anything other than 6 or 17 is credited to
            self.other regardless of orientation
        port
            port used to pick the 'web' or 'mail' family
        num_bytes
            number of bytes to add
        """
        if proto not in (6, 17):
            return self.other.update(proto, port, num_bytes)
        for f in 'web', 'mail':
            if port in pf[f]:
                return orientation.update_family(f, num_bytes)
    #   Not in web or mail port families
        return orientation.update_other(proto, port, num_bytes)

    def update_client(self, proto, port, num_bytes):
        "Record traffic in which the host acted as a client."
        return self.update(self.client, proto, port, num_bytes)

    def update_server(self, proto, port, num_bytes):
        "Record traffic in which the host acted as a server."
        return self.update(self.server, proto, port, num_bytes)

    def update_other(self, proto, port, num_bytes):
        "Record traffic that has no client/server orientation."
        return self.other.update(proto, port, num_bytes)

    def as_tuple(self):
        """
        Return the byte totals as a 7-tuple, in this order:
        (client mail, client web, server web, server mail, server other,
         unoriented other, client other).
        """
        cw, cm, co = self.client.totals()
        sw, sm, so = self.server.totals()
        return (cm, cw, sw, sm, so, self.other.num_bytes, co)

    def dump(self):
        "Print the server, client and unoriented tallies to stdout."
    #   Single-argument print written with parentheses: it prints the same
    #   text under the Python 2 print statement and still parses as a
    #   function call on Python 3.
        print("SERVER:")
        self.server.dump()
        print("CLIENT:")
        self.client.dump()
        print("OTHER:")
        print("\t%s" % str(self.other))


# I don't know if silk data is de facto encoded as an octet stream (i.e.,
# network byte-order and whatnot). 
@op_file
@typemap(hostip=str, r=timerange, types=str)
@mime_type('application/octet-stream')
@use_strategy(duration_strategy(5*60))
def silk_host_data(out_file, hostip, r, types='in,out'):
    """
    Create a file containing all flows going to or from hostip in the
    specified time range and types of traffic.
    Parameters:
    out_file
        Where to write binary silk data file. (The caller in this file
        does not pass it, so it is presumably supplied by the op_file
        decorator -- verify.)
    hostip
        IP of host whose traffic to pull
    r
        TimeRange object denoting beginning and end of time under
        consideration
    types
        string of silk types to be pulled from. Passed directly to the
        rwfilter '--type' switch
    Returns:
        out_file
    """
    datefmt = "%(Y).4d/%(M).2d/%(D).2d:%(h).2d"
#   Build the argument vector directly instead of formatting one string
#   and splitting it on whitespace, so a blank inside any value (such as
#   the output path) cannot be broken into separate arguments.
    rwf = [
        'rwfilter',
        '--type', types,
        '--start-date=%s' % d2s(datefmt, r.start),
        '--end-date=%s' % d2s(datefmt, r.end),
        '--stime=%s' % interval(r),
        '--any-address=%s' % hostip,
        '--pass=%s' % out_file,
    ]
    get_log().debug("RWF: %s" % ' '.join(rwf))
#   Shouldn't use rave.shell.pipe() here because we need to wait for the
#   command to finish.
    rc = subprocess.call(rwf)
    if rc != 0:
    #   Report the failure but keep returning out_file as before, so
    #   callers see no new exception from this function.
        get_log().error("rwfilter exited with status %d", rc)
    return out_file

def run_pipeline(silk_cmds, typemap=None):
    """
    Run and parse delimited output.
    Parameters
        silk_cmds - sequence of command strings; each is wrapped with cmd()
        and the results are chained together with pipe()
        typemap - sequence of type conversion functions. there must be one
        for each delimited item. When omitted, the items are returned as
        unconverted strings. (Inside this function the name shadows the
        'typemap' decorator used elsewhere in the module.)
    Returns
        generator of tuples containing output, mapped by typemap

    Being a generator, nothing is executed until the result is iterated.
    """
    for c in silk_cmds:
        get_log().debug("COMMAND: %s" % c)
    out = pipe(*(cmd(c) for c in silk_cmds))
    for line in out.readlines():
    #   Drop the line terminator and the single trailing '|' delimiter.
    #   Testing for each explicitly, instead of chopping two characters
    #   blindly, keeps the last field intact when either one is absent.
        text = line.rstrip('\r\n')
        if text.endswith('|'):
            text = text[:-1]
        items = text.split('|')
        if typemap is None:
            yield tuple(items)
        else:
            yield tuple(x(y) for x, y in zip(typemap, items))
    get_log().debug("DONE WITH PIPELINE")


def filter_normal(fname, addrtype, hostip, wellknown, r):
    """
    Select TCP/UDP flows of one host that pair a low port (0-1023) on one
    side with a high port (1024-65535) on the other, and total their bytes
    per protocol and low port.
    Parameters:
    fname
        Name of file containing prefiltered silk data
    addrtype
        'saddress' or 'daddress' -- Whether hostip is the source or
        destination of traffic
    hostip
        IP to filter for
    wellknown
        'sport' or 'dport' -- which side carries the low (well-known) port;
        the opposite side is required to be a high port
    r
        TimeRange restricting the flow start times
    Returns:
        generator from run_pipeline(), parsed as (int, int, long) per line
    """
#   The high-port side is whichever side is not the well-known one.
    ephemeral = 'sport'
    if wellknown == 'sport':
        ephemeral = 'dport'
    rwfilter_args = (
        addrtype, hostip,
        wellknown, ephemeral,
        interval(r), fname
    )
    rwfilter_cmd = """rwfilter
          --protocol=6,17
          --%s=%s 
          --%s=0-1023
          --%s=1024-65535 
          --stime=%s
          --pass=stdout %s
          """ % rwfilter_args
    rwuniq_cmd = """rwuniq --no-titles --no-columns
          --bytes --fields=proto,%s
          """ % wellknown
    return run_pipeline((rwfilter_cmd, rwuniq_cmd), typemap=(int, int, long))


def filter_upper(fname, addrtype, hostip, porttype, r):
    """
    Select TCP/UDP flows of one host in which both ports are high
    (1024-65535), and total their bytes per protocol and port.
    Parameters:
    fname
        Name of file containing prefiltered silk data
    addrtype
        'saddress' or 'daddress' -- Whether hostip is the source or
        destination of traffic
    hostip
        IP to filter for
    porttype
        'sport' or 'dport' -- whether to aggregate by source or
        destination port
    r
        TimeRange restricting the flow start times
    Returns:
        generator from run_pipeline(), parsed as (int, int, long) per line
    """
    rwfilter_cmd = """rwfilter
          --protocol=6,17
          --%s=%s
          --sport=1024-65535
          --dport=1024-65535
          --stime=%s
          --pass=stdout %s
          """ % (addrtype, hostip, interval(r), fname)
    rwuniq_cmd = """rwuniq --no-titles --no-columns
          --bytes --fields=proto,%s
          """ % porttype
    stages = (rwfilter_cmd, rwuniq_cmd)
    return run_pipeline(stages, typemap=(int, int, long))

def filter_lower(fname, addrtype, hostip, porttype, r):
    """
    Select TCP/UDP flows of one host in which both ports are low (0-1023),
    and total their bytes per protocol and port.
    Parameters:
    fname
        Name of file containing prefiltered silk data
    addrtype
        'saddress' or 'daddress' -- Whether hostip is the source or
        destination of traffic
    hostip
        IP to filter for
    porttype
        'sport' or 'dport' -- whether to aggregate by source or
        destination port
    r
        TimeRange restricting the flow start times
    Returns:
        generator from run_pipeline(), parsed as (int, int, long) per line
    """
    rwfilter_cmd = """rwfilter
          --protocol=6,17
          --%s=%s
          --sport=0-1023
          --dport=0-1023
          --stime=%s
          --pass=stdout %s
          """ % (addrtype, hostip, interval(r), fname)
    rwuniq_cmd = """rwuniq --no-titles --no-columns
          --bytes --fields=proto,%s
          """ % porttype
    stages = (rwfilter_cmd, rwuniq_cmd)
    return run_pipeline(stages, typemap=(int, int, long))

def filter_icmp(fname, r):
    """
    Total the ICMP (protocol 1) traffic in fname within time range r.

    Yields (first, second, bytes) tuples, where first and second are the
    two whitespace-separated integers of rwtotal's --icmp-code key column,
    in the order rwtotal prints them. The caller in this file treats the
    first as the ICMP message type -- confirm against rwtotal's output.
    """
    rwfilter_cmd = """rwfilter
          --protocol=1
          --stime=%s
          --pass=stdout %s
          """ % (interval(r), fname)
    rwtotal_cmd = "rwtotal --no-title --no-columns --skip-zeroes --icmp-code"
    rows = run_pipeline(
        (rwfilter_cmd, rwtotal_cmd), typemap=(str, int, long, int)
    )
#   rwtotal can't be told to leave out the flow and packet counts, so they
#   are parsed and then dropped.
    for key, _flows, nbytes, _packets in rows:
        first, second = key.split()
        yield int(first), int(second), nbytes


def filter_other(fname, r):
    """
    Total the traffic in fname, within time range r, whose protocol is
    neither ICMP (1), TCP (6) nor UDP (17).

    Yields (proto, bytes) tuples, one per line of rwtotal output.
    """
    rwfilter_cmd = """rwfilter
          --protocol=2-5,7-16,18-255
          --stime=%s
          --pass=stdout %s
          """ % (interval(r), fname)
    rwtotal_cmd = """rwtotal --no-titles --no-columns
          --proto
          """
    rows = run_pipeline(
        (rwfilter_cmd, rwtotal_cmd), typemap=(int, int, long, int)
    )
#   rwtotal can't be told to leave out the flow and packet counts, so they
#   are parsed and then dropped.
    for proto, _flows, nbytes, _packets in rows:
        yield proto, nbytes

def get_host_char(datafile, hostip, r):
    """
    Build the characterization of hostip from prefiltered silk data.

    Parameters:
    datafile
        name of the file holding the host's flows
    hostip
        IP whose traffic is being characterized
    r
        TimeRange restricting which flows are counted

    Returns:
        (HostCharacterization, total_bytes), where total_bytes is the sum
        of the bytes reported by every filter that was run
    """
    hc = HostCharacterization()
    total_bytes = 0

    filter_info = {
    #   Inbound traffic has the hostip as its destination address. When the
    #   hostip is acting as a server, the destination port will be well-known
    #   and the source port will be ephemeral. When the hostip is acting as a
    #   client, the source port will be well-known and the destination port
    #   will be ephemeral.
        'ib': {
            'addr': 'daddress',
            'tx': {
                'server': ('dport', hc.update_server),
                'client': ('sport', hc.update_client),
            },
        },
    #   Outbound traffic has the hostip as its source address. When the hostip
    #   is acting as a server, the source port will be well-known and the
    #   destination port will be ephemeral. When the hostip is acting as a
    #   client, the destination port will be well-known and the source port
    #   will be ephemeral.
        'ob': {
            'addr': 'saddress',
            'tx': {
                'server': ('sport', hc.update_server),
                'client': ('dport', hc.update_client),
            },
        },
    }

#   TCP/UDP: traffic between a well-known port and an ephemeral port,
#   credited to the server or client orientation.
    for dname, direction in filter_info.items():
        get_log().debug("Collecting datafile in %s direction", dname)
        addr = direction['addr']
        for role, (wellknown, record) in direction['tx'].items():
            for proto, port, nbytes in filter_normal(
                    datafile, addr, hostip, wellknown, r):
                record(proto, port, nbytes)
                total_bytes += nbytes

#   Do upper-port-to-upper-port and lower-port-to-lower-port traffic.
#   Organize by the port on hostip.
    unoriented_queries = (
        (filter_upper, 'saddress', 'sport'),
        (filter_upper, 'daddress', 'dport'),
        (filter_lower, 'saddress', 'sport'),
        (filter_lower, 'daddress', 'dport'),
    )
    for query, addrtype, porttype in unoriented_queries:
        for proto, port, nbytes in query(
                datafile, addrtype, hostip, porttype, r):
            hc.update_other(proto, port, nbytes)
            total_bytes += nbytes

#   Do ICMP traffic -- put ICMP message type in port (discard
#   code for now)
    icmp_bytes = 0
    for msg_type, code, nbytes in filter_icmp(datafile, r):
        hc.update_other(1, msg_type, nbytes)
        icmp_bytes += nbytes
    get_log().debug("Total ICMP bytes: %d", icmp_bytes)
    total_bytes += icmp_bytes

#   Do everything else -- port == 0
    other_bytes = 0
    for proto, nbytes in filter_other(datafile, r):
        hc.update_other(proto, 0, nbytes)
        other_bytes += nbytes
    get_log().debug("Total other bytes: %d", other_bytes)
    total_bytes += other_bytes

    return hc, total_bytes

# Cache for 10 minutes if it's less than 3 hours old, then for 2 days
agg_cache_strat = crossover_strategy(
    duration_strategy(10 * 60),
    duration_strategy(2 * 24 * 60 * 60),
    age=3 * 60 * 60
)

@op
@typemap(str, datetime_obj, datetime_obj, str)
@use_strategy(agg_cache_strat)
def aggregate_by_hwmp(hostip, stime, etime, types):
    """
    Characterize the traffic of hostip between stime and etime.

    Parameters:
    hostip
        IP of host whose traffic to characterize
    stime, etime
        start and end of the time range
    types
        silk traffic types, handed on to silk_host_data()

    Returns:
        (all, recent, total): the HostCharacterization for the whole
        range, the HostCharacterization for the range produced by
        mk_recent(20), and the total byte count for the whole range.
    """
    full_range = TimeRange(stime, etime)
    recent_range = full_range.mk_recent(20)
    get_log().debug("Total time range: %s" % full_range)
    get_log().debug("Recent time range: %s" % recent_range)
#   One data file is pulled for the whole range and reused for both passes.
    datafile = silk_host_data(hostip, full_range, types)
    get_log().debug("Getting total range data")
    overall, total = get_host_char(datafile, hostip, full_range)
    get_log().debug("=====================================")
    get_log().debug("Getting recent data")
#   The byte total of the recent pass is not reported.
    recent, _recent_total = get_host_char(datafile, hostip, recent_range)
    return overall, recent, total



















def main(hostip, start, end):
    """
    Characterize hostip between start and end and print the results.

    Prints the dump of the whole-range and recent characterizations, the
    total byte count, and a block of assignment-style lines holding the
    as_tuple() values and the peak entries.
    """
#   'overall' rather than 'all', so the builtin is not shadowed.
    overall, recent, total = aggregate_by_hwmp(hostip, start, end, "in,out")
#   Single-argument prints written with parentheses: they print the same
#   text under the Python 2 print statement and still parse as function
#   calls on Python 3.
    print("------------------ ALL DATA ----------------")
    overall.dump()
    print("------------------ RECENT DATA ----------------")
    recent.dump()
    print("TOTAL BYTES: %d" % total)
    print("-------CODE------")
    print("""
all    = %s
recent = %s
total  = %s

all_client_other_peak = %s
all_server_other_peak = %s
other_peak =  %s
    """ % (
          str(overall.as_tuple())
        , str(recent.as_tuple())
        , total
        , str(overall.client.other.peak)
        , str(overall.server.other.peak)
        , str(overall.other.peak)
    ))

def runmain():
    "Run main() for a hard-coded sample network and day."
#   Alternative queries kept for manual experimentation:
    #main('140.221.158.251', "2005/11/14:00:00:00", "2005/11/14:23:59:59")
    #main('140.0.0.0/8', "2005/11/14:22:47:59", "2005/11/14:23:59:59")
    network = '140.0.0.0/8'
    day_start = "2005/11/14:00:00:00"
    day_end = "2005/11/14:23:59:59"
    main(network, day_start, day_end)
    
# When executed as a script, run the hard-coded query in runmain().
if __name__ == '__main__':
    runmain()
