/*
 *  Copyright 2012-2025 Carnegie Mellon University
 *  See license information in LICENSE.txt.
 */
/*
 *  mediator_filter.c
 *
 *  Functions to limit which IPFIX records are processed based on values
 *  within the record.
 *
 *  ------------------------------------------------------------------------
 *  Authors: Emily Sarneso
 *  ------------------------------------------------------------------------
 *  @DISTRIBUTION_STATEMENT_BEGIN@
 *  super_mediator-1.13
 *
 *  Copyright 2025 Carnegie Mellon University.
 *
 *  NO WARRANTY. THIS CARNEGIE MELLON UNIVERSITY AND SOFTWARE ENGINEERING
 *  INSTITUTE MATERIAL IS FURNISHED ON AN "AS-IS" BASIS. CARNEGIE MELLON
 *  UNIVERSITY MAKES NO WARRANTIES OF ANY KIND, EITHER EXPRESSED OR IMPLIED,
 *  AS TO ANY MATTER INCLUDING, BUT NOT LIMITED TO, WARRANTY OF FITNESS FOR
 *  PURPOSE OR MERCHANTABILITY, EXCLUSIVITY, OR RESULTS OBTAINED FROM USE OF
 *  THE MATERIAL. CARNEGIE MELLON UNIVERSITY DOES NOT MAKE ANY WARRANTY OF
 *  ANY KIND WITH RESPECT TO FREEDOM FROM PATENT, TRADEMARK, OR COPYRIGHT
 *  INFRINGEMENT.
 *
 *  Licensed under a GNU GPL 2.0-style license, please see LICENSE.txt or
 *  contact permission@sei.cmu.edu for full terms.
 *
 *  [DISTRIBUTION STATEMENT A] This material has been approved for public
 *  release and unlimited distribution.  Please see Copyright notice for
 *  non-US Government use and distribution.
 *
 *  This Software includes and/or makes use of Third-Party Software each
 *  subject to its own license.
 *
 *  DM25-1447
 *  @DISTRIBUTION_STATEMENT_END@
 *  ------------------------------------------------------------------------
 */

#include "mediator_filter.h"
#include "mediator_util.h"

#if ENABLE_SKIPSET
#include SKIPSET_HEADER_NAME
#if HAVE_SILK_SKIPADDR_H
#include <silk/skipaddr.h>
#endif
#endif /* if ENABLE_SKIPSET */



/**
 * mdComparison
 *
 *    Evaluate the relation `val_one OPER val_two` for two unsigned 64-bit
 *    values, where OPER is the relational operator named by `oper`.
 *
 * @param val_one - the left-hand value
 * @param val_two - the right-hand value
 * @param oper - EQUAL, NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL,
 *               LESS_THAN, or LESS_THAN_OR_EQUAL
 * @return TRUE if the relation holds; FALSE if it does not, or if `oper`
 *         is any operator other than those listed above
 *
 */
static gboolean
mdComparison(
    uint64_t        val_one,
    uint64_t        val_two,
    fieldOperator   oper)
{
    /* Three-way ordering of the operands: -1, 0, or +1 */
    const int order = (val_one > val_two) - (val_one < val_two);

    switch (oper) {
      case EQUAL:                 return (order == 0);
      case NOT_EQUAL:             return (order != 0);
      case GREATER_THAN:          return (order > 0);
      case GREATER_THAN_OR_EQUAL: return (order >= 0);
      case LESS_THAN:             return (order < 0);
      case LESS_THAN_OR_EQUAL:    return (order <= 0);
      default:                    break;
    }

    /* Any other operator (e.g., the list operators) never matches here */
    return FALSE;
}

/**
 * mdCompareIPv6
 *
 *    Evaluate the relation `val_one OPER val_two` for two 16-octet IPv6
 *    addresses, where OPER is the relational operator named by `oper`.
 *
 * @param val_one - the left-hand address (16 octets)
 * @param val_two - the right-hand address (16 octets)
 * @param oper - EQUAL, NOT_EQUAL, GREATER_THAN, GREATER_THAN_OR_EQUAL,
 *               LESS_THAN, or LESS_THAN_OR_EQUAL
 * @return TRUE if the relation holds; FALSE if it does not, or if `oper`
 *         is any operator other than those listed above
 *
 */
static gboolean
mdCompareIPv6(
    const uint8_t   val_one[16],
    const uint8_t   val_two[16],
    fieldOperator   oper)
{
    /* memcmp() orders the arrays octet by octet, first octet most
     * significant; that is numeric order for addresses held in network byte
     * order (presumably the case here, as for IPFIX ipv6Address values). */
    const int   diff = memcmp(val_one, val_two, 16);
    gboolean    matched = FALSE;

    if (EQUAL == oper) {
        matched = (diff == 0);
    } else if (NOT_EQUAL == oper) {
        matched = (diff != 0);
    } else if (GREATER_THAN == oper) {
        matched = (diff > 0);
    } else if (GREATER_THAN_OR_EQUAL == oper) {
        matched = (diff >= 0);
    } else if (LESS_THAN == oper) {
        matched = (diff < 0);
    } else if (LESS_THAN_OR_EQUAL == oper) {
        matched = (diff <= 0);
    }

    return matched;
}

#if ENABLE_SKIPSET
/**
 * mdCompareIPSet
 *
 *    Check whether the address(es) of the flow record selected by
 *    `filter->field` are members of the IPset held by `filter`.
 *
 *    For IPv4 records, ANY_IP examines the source and then the destination
 *    address, SIP_ANY/SIP_V4 the source, and DIP_ANY/DIP_V4 the destination.
 *    For IPv6 records, ANY_IP/ANY_IP6, SIP_ANY/SIP_V6, and DIP_ANY/DIP_V6 are
 *    handled the same way, but only when HAVE_SKIPADDRSETV6 is set; otherwise
 *    IPv6 records never match.  Any other field yields FALSE.
 *
 * @param md_flow - the full flow record; its `rec` member is dereferenced
 * @param filter - the filter node; its `ipset` member is dereferenced
 *
 * @return TRUE if at least one examined address is in the IPset
 *
 */
static gboolean
mdCompareIPSet(
    const mdFullFlow_t  *md_flow,
    const md_filter_t   *filter)
{
    const md_main_template_t *rec = md_flow->rec;
    skipaddr_t                candidate;
    gboolean                  test_src = FALSE;
    gboolean                  test_dst = FALSE;

    if (md_flow->tmpl_attr.is_ipv6) {
#if HAVE_SKIPADDRSETV6
        /* Decide which IPv6 addresses the field refers to */
        switch (filter->field) {
          case ANY_IP:
          case ANY_IP6:
            test_src = TRUE;
            test_dst = TRUE;
            break;
          case SIP_ANY:
          case SIP_V6:
            test_src = TRUE;
            break;
          case DIP_ANY:
          case DIP_V6:
            test_dst = TRUE;
            break;
          default:
            break;
        }
        if (test_src) {
            skipaddrSetV6(&candidate, rec->sourceIPv6Address);
            if (skIPSetCheckAddress(filter->ipset->ipset, &candidate)) {
                return TRUE;
            }
        }
        if (test_dst) {
            skipaddrSetV6(&candidate, rec->destinationIPv6Address);
            if (skIPSetCheckAddress(filter->ipset->ipset, &candidate)) {
                return TRUE;
            }
        }
#endif  /* HAVE_SKIPADDRSETV6 */
        return FALSE;
    }

    /* Decide which IPv4 addresses the field refers to */
    switch (filter->field) {
      case ANY_IP:
        test_src = TRUE;
        test_dst = TRUE;
        break;
      case SIP_ANY:
      case SIP_V4:
        test_src = TRUE;
        break;
      case DIP_ANY:
      case DIP_V4:
        test_dst = TRUE;
        break;
      default:
        break;
    }
    if (test_src) {
        skipaddrSetV4(&candidate, &(rec->sourceIPv4Address));
        if (skIPSetCheckAddress(filter->ipset->ipset, &candidate)) {
            return TRUE;
        }
    }
    if (test_dst) {
        skipaddrSetV4(&candidate, &(rec->destinationIPv4Address));
        if (skIPSetCheckAddress(filter->ipset->ipset, &candidate)) {
            return TRUE;
        }
    }

    return FALSE;
}
#endif  /* ENABLE_SKIPSET */


/**
 * mdCollectionFilter
 *
 *    Loop through the filters and compare the filter expressions to the given
 *    flow record.
 *
 *    If `and_filter` is FALSE, one expression being TRUE causes the function
 *    to return TRUE.  The function returns FALSE when all expressions are
 *    FALSE.
 *
 *    If `and_filter` is TRUE, one expression being FALSE causes the function
 *    to return FALSE.  The function returns TRUE when all expressions are
 *    TRUE.
 *
 *    When `md_flow` is NULL or carries no flow record (stats, dedup, or DNS
 *    RR records), only the COLLECTOR and OBDOMAIN fields can match; any other
 *    field evaluates to FALSE.  A filter on a field this function does not
 *    support is reported with g_warning() and skipped, so it affects neither
 *    an OR nor an AND result.
 *
 *    @param filter - a list of filters, linked through `next`
 *    @param md_flow - the full flow record; may be NULL
 *    @param obdomain - the observation domain
 *    @param and_filter - whether the filtering expressions should be ANDed
 *    @param collector_id - the collector ID
 *    @return TRUE if the record passes the filter list per the rules above
 */
static gboolean
mdCollectionFilter(
    const md_filter_t   *filter,
    const mdFullFlow_t  *md_flow,
    uint32_t             obdomain,
    gboolean             and_filter,
    uint16_t             collector_id)
{
    const md_main_template_t *flow;
    const md_filter_t        *cfil;
    gboolean                  rc;

    flow = ((NULL == md_flow) ? NULL : md_flow->rec);

    for (cfil = filter; cfil != NULL; cfil = cfil->next) {
        rc = and_filter;
        if (!flow) {
            /* stats/dedup/dnsrr record - only 2 fields apply */
            switch (cfil->field) {
              case COLLECTOR:
                rc = mdComparison(collector_id, cfil->val.u64, cfil->oper);
                break;
              case OBDOMAIN:
                rc = mdComparison(obdomain, cfil->val.u64, cfil->oper);
                break;
              default:
                rc = FALSE;
                break;
            }
            goto maybe_return;
        }

#if ENABLE_SKIPSET
        if (cfil->ipset) {
            rc = mdCompareIPSet(md_flow, cfil);
            if (NOT_IN_LIST == cfil->oper) {
                rc = !rc;
            } else if (IN_LIST != cfil->oper) {
                g_error("Programmer error: Invalid operator %d for IPset",
                        cfil->oper);
            }
            goto maybe_return;
        }
#endif  /* ENABLE_SKIPSET */

        switch (cfil->field) {
          case ANY_IP6:
            rc = (mdCompareIPv6(flow->sourceIPv6Address,
                                cfil->val.ip6, cfil->oper) ||
                  mdCompareIPv6(flow->destinationIPv6Address,
                                cfil->val.ip6, cfil->oper));
            break;
          case ANY_IP:
            rc = (mdComparison(flow->sourceIPv4Address,
                               cfil->val.ip4, cfil->oper) ||
                  mdComparison(flow->destinationIPv4Address,
                               cfil->val.ip4, cfil->oper));
            break;
          case ANY_PORT:
            rc = (mdComparison(flow->sourceTransportPort,
                               cfil->val.u64, cfil->oper) ||
                  mdComparison(flow->destinationTransportPort,
                               cfil->val.u64, cfil->oper));
            break;
          case APPLICATION:
            rc = mdComparison(flow->silkAppLabel, cfil->val.u64, cfil->oper);
            break;
          case BYTES:
            rc = mdComparison(((md_flow->tmpl_attr.is_delta)
                               ? flow->octetDeltaCount
                               : flow->octetTotalCount),
                              cfil->val.u64, cfil->oper);
            break;
          case COLLECTOR:
            rc = mdComparison(collector_id, cfil->val.u64, cfil->oper);
            break;
          case DIP_V4:
            rc = mdComparison(flow->destinationIPv4Address,
                              cfil->val.ip4, cfil->oper);
            break;
          case DIP_V6:
            rc = mdCompareIPv6(flow->destinationIPv6Address,
                               cfil->val.ip6, cfil->oper);
            break;
          case DPORT:
            rc = mdComparison(flow->destinationTransportPort,
                              cfil->val.u64, cfil->oper);
            break;
          case EGRESS:
            rc = mdComparison(flow->egressInterface,
                              cfil->val.u64, cfil->oper);
            break;
          case IFLAGS:
            rc = mdComparison(flow->initialTCPFlags,
                              cfil->val.u64, cfil->oper);
            break;
          case INGRESS:
            rc = mdComparison(flow->ingressInterface,
                              cfil->val.u64, cfil->oper);
            break;
          case IPVERSION:
            /* Compare the record's IP version (4 or 6) with the filter's
             * operator like any other scalar field.  Previously only EQUAL
             * was honored, so e.g. "VERSION != 4" could never match. */
            rc = mdComparison((md_flow->tmpl_attr.is_ipv6 ? 6 : 4),
                              cfil->val.u64, cfil->oper);
            break;
          case OBDOMAIN:
            rc = mdComparison(obdomain, cfil->val.u64, cfil->oper);
            break;
          case PKTS:
            rc = mdComparison(((md_flow->tmpl_attr.is_delta)
                               ? flow->packetDeltaCount
                               : flow->packetTotalCount),
                              cfil->val.u64, cfil->oper);
            break;
          case PROTOCOL:
            rc = mdComparison(flow->protocolIdentifier,
                              cfil->val.u64, cfil->oper);
            break;
          case RBYTES:
            rc = mdComparison(((md_flow->tmpl_attr.is_delta)
                               ? flow->reverseOctetDeltaCount
                               : flow->reverseOctetTotalCount),
                              cfil->val.u64, cfil->oper);
            break;
          case RIFLAGS:
            rc = mdComparison(flow->reverseInitialTCPFlags,
                              cfil->val.u64, cfil->oper);
            break;
          case RPKTS:
            rc = mdComparison(((md_flow->tmpl_attr.is_delta)
                               ? flow->reversePacketDeltaCount
                               : flow->reversePacketTotalCount),
                              cfil->val.u64, cfil->oper);
            break;
          case RUFLAGS:
            rc = mdComparison(flow->reverseUnionTCPFlags,
                              cfil->val.u64, cfil->oper);
            break;
          case SIP_V4:
            rc = mdComparison(flow->sourceIPv4Address,
                              cfil->val.ip4, cfil->oper);
            break;
          case SIP_V6:
            rc = mdCompareIPv6(flow->sourceIPv6Address,
                               cfil->val.ip6, cfil->oper);
            break;
          case SPORT:
            rc = mdComparison(flow->sourceTransportPort,
                              cfil->val.u64, cfil->oper);
            break;
          case UFLAGS:
            rc = mdComparison(flow->unionTCPFlags,
                              cfil->val.u64, cfil->oper);
            break;
          case VLAN:
            rc = mdComparison(flow->vlanId, cfil->val.u64, cfil->oper);
            break;
          default:
            /* Unsupported field: skip it so it affects neither OR nor AND */
            g_warning("Unsupported Field %s in Filter .. Ignoring",
                      mdFieldName(cfil->field));
            continue;
        }

      maybe_return:
        /* Return early if (!and_filter and TRUE) or (and_filter and FALSE) */
        if (!and_filter) {
            if (rc) {
                return TRUE;
            }
        } else if (!rc) {
            return FALSE;
        }
    }

    /* if we've made it here we're either an OR filter with all false or an
     * AND filter with all true, so return value of and_filter */
    return and_filter;
}


/**
 * mdFilter
 *
 *    Decide whether a record passes the filter list `filter`.
 *
 * @param filter - a list of filters; NULL means "no filtering"
 * @param md_flow - the full flow record
 * @param obdomain - the observation domain
 * @param and_filter - true if filters should be ANDed
 * @param collector_id - the collector ID
 * @return TRUE when `filter` is NULL; otherwise the result of evaluating
 *         the list with mdCollectionFilter()
 *
 */
gboolean
mdFilter(
    const md_filter_t   *filter,
    const mdFullFlow_t  *md_flow,
    uint32_t             obdomain,
    gboolean             and_filter,
    uint16_t             collector_id)
{
    /* An absent filter list accepts every record. */
    return ((NULL != filter)
            ? mdCollectionFilter(filter, md_flow, obdomain,
                                 and_filter, collector_id)
            : TRUE);
}

/**
 * md_new_filter_node
 *
 *    Allocate a filter node from the GLib slice allocator with every member
 *    zeroed.  Release it with md_free_filter_node().
 *
 * @return the newly allocated node
 */
md_filter_t *
md_new_filter_node(
    void)
{
    return g_slice_new0(md_filter_t);
}

/**
 * md_free_filter_node
 *
 *    Free a single filter node allocated by md_new_filter_node().  When
 *    IPset support is compiled in and the node holds an IPset, the IPset is
 *    closed first.  Only `mf` itself is released; any node reachable through
 *    `mf->next` is left untouched.  Passing NULL does nothing.
 *
 * @param mf - the node to free, or NULL
 */
void
md_free_filter_node(
    md_filter_t  *mf)
{
    if (NULL == mf) {
        return;
    }

#if ENABLE_SKIPSET
    if (NULL != mf->ipset) {
        mdUtilIPSetClose(mf->ipset);
    }
#endif  /* ENABLE_SKIPSET */

    g_slice_free(md_filter_t, mf);
}


/**
 * mdFieldName
 *
 *    Return a printable name for the filter/print field `field`.
 *
 *    Known fields map to constant strings.  Any other value is formatted as
 *    "FIELD_<n>" into a static buffer; that result is overwritten by the
 *    next such call, and sharing the buffer makes this path non-reentrant.
 *
 * @param field - the field identifier
 * @return the field's name; never NULL
 */
const char *
mdFieldName(
    mdAcceptFilterField_t  field)
{
    /* Field-to-name lookup table, scanned linearly */
    static const struct {
        mdAcceptFilterField_t   id;
        const char             *label;
    } field_labels[] = {
        { SIP_ANY,              "SIP" },
        { DIP_ANY,              "DIP" },
        { SIP_V4,               "SIP_V4" },
        { DIP_V4,               "DIP_V4" },
        { SPORT,                "SPORT" },
        { DPORT,                "DPORT" },
        { PROTOCOL,             "PROTOCOL" },
        { APPLICATION,          "APPLICATION" },
        { SIP_V6,               "SIP_V6" },
        { DIP_V6,               "DIP_V6" },
        { ANY_IP6,              "ANY_IP6" },
        { ANY_IP,               "ANY_IP" },
        { ANY_PORT,             "ANY_PORT" },
        { OBDOMAIN,             "OBDOMAIN" },
        { IPVERSION,            "VERSION" },
        { VLAN,                 "VLAN" },
        { FLOWKEYHASH,          "FLOWKEYHASH" },
        { DURATION_PLACEHOLDER, "DURATION" },
        { STIME_PLACEHOLDER,    "STIME" },
        { ETIME_PLACEHOLDER,    "ETIME" },
        { STIME_EPOCH_MS,       "STIMEMS" },
        { ETIME_EPOCH_MS,       "ETIMEMS" },
        { SIP_INT,              "SIP_INT" },
        { DIP_INT,              "DIP_INT" },
        { RTT_PLACEHOLDER,      "RTT" },
        { PKTS,                 "PACKETS" },
        { RPKTS,                "RPACKETS" },
        { BYTES,                "BYTES" },
        { RBYTES,               "RBYTES" },
        { IFLAGS,               "IFLAGS" },
        { RIFLAGS,              "RIFLAGS" },
        { UFLAGS,               "UFLAGS" },
        { RUFLAGS,              "RUFLAGS" },
        { ATTRIBUTES,           "ATTRIBUTES" },
        { RATTRIBUTES,          "RATTRIBUTES" },
        { MAC,                  "MAC" },
        { DSTMAC,               "DSTMAC" },
        { TCPSEQ,               "TCPSEQ" },
        { RTCPSEQ,              "RTCPSEQ" },
        { ENTROPY,              "ENTROPY" },
        { RENTROPY,             "RENTROPY" },
        { ENDREASON,            "ENDREASON" },
        { OSNAME,               "OSNAME" },
        { OSVERSION,            "OSVERSION" },
        { ROSNAME,              "ROSNAME" },
        { ROSVERSION,           "ROSVERSION" },
        { FINGERPRINT,          "FINGERPRINT" },
        { RFINGERPRINT,         "RFINGERPRINT" },
        { DHCPFP,               "DHCPFP" },
        { DHCPVC,               "DHCPVC" },
        { RDHCPFP,              "RDHCPFP" },
        { RDHCPVC,              "RDHCPVC" },
        { INGRESS,              "INGRESS" },
        { EGRESS,               "EGRESS" },
        { DATABYTES,            "DATABYTES" },
        { RDATABYTES,           "RDATABYTES" },
        { ITIME,                "ITIME" },
        { RITIME,               "RITIME" },
        { STDITIME,             "STDITIME" },
        { RSTDITIME,            "RSTDITIME" },
        { TCPURG,               "TCPURG" },
        { RTCPURG,              "RTCPURG" },
        { SMALLPKTS,            "SMALLPKTS" },
        { RSMALLPKTS,           "RSMALLPKTS" },
        { LARGEPKTS,            "LARGEPKTS" },
        { RLARGEPKTS,           "RLARGEPKTS" },
        { NONEMPTYPKTS,         "NONEMPTYPKTS" },
        { RNONEMPTYPKTS,        "RNONEMPTYPKTS" },
        { MAXSIZE,              "MAXSIZE" },
        { RMAXSIZE,             "RMAXSIZE" },
        { STDPAYLEN,            "STDPAYLEN" },
        { RSTDPAYLEN,           "RSTDPAYLEN" },
        { FIRSTEIGHT,           "FIRSTEIGHT" },
        { DPI,                  "DPI" },
        { VLANINT,              "VLANINT" },
        { TOS,                  "TOS" },
        { RTOS,                 "RTOS" },
        { MPLS1,                "MPLS1" },
        { MPLS2,                "MPLS2" },
        { MPLS3,                "MPLS3" },
        { COLLECTOR,            "COLLECTOR" },
        { FIRSTNONEMPTY,        "FIRSTNONEMPTY" },
        { RFIRSTNONEMPTY,       "RFIRSTNONEMPTY" },
        { MPTCPSEQ,             "MPTCPSEQ" },
        { MPTCPTOKEN,           "MPTCPTOKEN" },
        { MPTCPMSS,             "MPTCPMSS" },
        { MPTCPID,              "MPTCPID" },
        { MPTCPFLAGS,           "MPTCPFLAGS" },
        { PAYLOAD,              "PAYLOAD" },
        { RPAYLOAD,             "RPAYLOAD" },
        { DHCPOPTIONS,          "DHCPOPTIONS" },
        { RDHCPOPTIONS,         "RDHCPOPTIONS" },
        { NDPI_MASTER,          "NDPI_MASTER" },
        { NDPI_SUB,             "NDPI_SUB" },
        { STIME_EPOCH,          "STIME_EPOCH" },
        { ETIME_EPOCH,          "ETIME_EPOCH" },
    };
    static char buf[32];
    size_t      idx;

    for (idx = 0; idx < sizeof(field_labels) / sizeof(field_labels[0]); ++idx)
    {
        if (field_labels[idx].id == field) {
            return field_labels[idx].label;
        }
    }

    /* Unknown field: fall back to its numeric value */
    snprintf(buf, sizeof(buf), "FIELD_%d", (int)field);
    return buf;
}


#ifdef HAVE_SPREAD
/**
 * mdSpreadExporterFilter
 *
 * loop through spread filters and add the groups
 * that should receive this flow.  If a group does not
 * have a filter, it automatically gets all flows.
 *
 */
int
mdSpreadExporterFilter(
    const md_spread_filter_t  *sf,
    const mdFullFlow_t        *md_flow,
    char                     **groups)
{
    const md_spread_filter_t *cfil;
    int                       num_groups = 0;

    for (cfil = sf; cfil != NULL; cfil = cfil->next) {
        if (cfil->filterList == NULL) {
            /* there is no filter for this group so it gets everything */
            groups[num_groups] = cfil->group;
            num_groups++;
        } else {
            if (mdCollectionFilter(cfil->filterList, md_flow, 0, 0, 0)) {
                if (num_groups < 10) {
                    groups[num_groups] = cfil->group;
                    num_groups++;
                }
            }
        }
    }

    return num_groups;
}

/**
 * md_new_spread_node
 *
 *    Allocate a spread-filter node from the GLib slice allocator with every
 *    member zeroed.  Release it with md_free_spread_node().
 *
 * @return the newly allocated node
 */
md_spread_filter_t *
md_new_spread_node(
    void)
{
    return g_slice_new0(md_spread_filter_t);
}

/**
 * md_free_spread_node
 *
 *    Free a spread-filter node allocated by md_new_spread_node().  Only the
 *    node itself is released; nothing it points to is freed here.  Passing
 *    NULL does nothing.
 *
 * @param ms - the node to free, or NULL
 */
void
md_free_spread_node(
    md_spread_filter_t  *ms)
{
    if (ms != NULL) {
        g_slice_free(md_spread_filter_t, ms);
    }
}

#endif  /* HAVE_SPREAD */
