Source code for ensign.csv2tensor

#!/usr/bin/env python
# ENSIGN rights
""" Converts tabular data into a sparse tensor.
"""

import copy
from datetime import datetime, date, timedelta
from datetime import time as date_time
from ipaddress import ip_address
import itertools as it
import linecache
from glob import glob
import os
from pathlib import Path
import sys
import time

import dask.dataframe as dd
from dask.distributed import Client, Variable
import numpy as np
import pandas as pd

from ensign import sptensor as spt
from ensign.constants import *
import ensign.ensign_io.ensign_logging as ensign_logging

log = ensign_logging.get_logger()
num_failures_global = None
#------------------------------------------------------------------------------
# Constants
#------------------------------------------------------------------------------
MAX_INT = sys.maxsize
MAX_FLOAT = sys.float_info.max
FLOAT_SIG_DIGS = 12
# AGG_SPLIT is a parameter used to parallelize group-by aggregations.
# This value must be tuned for optimal performance. Lower values generally
# provide small performance improvements. Larger values can lead to larger
# performance improvements, or to degradations when the value is too large.
AGG_SPLIT = 3
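# Illustrative sketch (comment only, not executed): in this module AGG_SPLIT
# is passed to Dask group-by reductions as `split_out`, e.g.
#   df.groupby(cols).size(split_out=AGG_SPLIT)
# which spreads the aggregated result across AGG_SPLIT output partitions.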
NUM_BRO_HEADER_LINES = 7
IP4_FAILURE_REPLACEMENT_STR = "x.x.x.x"
IP4_FAILURE_REPLACEMENT_INT = 0
FUSED_COLUMN_NAME_DELIMITER = "__"
VALID_NUM_BIN_TYPES = ['none', 'log10', 'binsize', 'round', 'cyclic'] 
VALID_IP_BIN_TYPES = ['ipsubnet', 'ipv6_hextets']
VALID_TIME_BIN_TYPES = ['second', 'minute', 'hour', 'day', 'month', 'year', \
                        'minute_of_hour', 'hour_of_day', 'day_of_week', \
                        'day_of_month', 'month_of_year']
VALUE_TENSOR_AGG_METHODS = ['sum', 'max', 'min', 'max_abs', 'min_abs', 'first', \
                            'last', 'mean', 'prod', 'idxmin', 'idxmax']
DATETIME_DIRECTIVES = ['a', 'A', 'w', 'd', 'b', 'B', 'm', 'y', 'Y', 'H', 'I', \
                       'p', 'M', 'S', 'f', 'z', 'Z', 'j', 'U', 'W', 'c', 'x', \
                       'X', '%', 'G', 'u', 'V', 'E']
VALID_TYPES = ['str', 'int64', 'float64', 'ip', 'datetime', 'timestamp', 'date', 'time']
TIME_TYPES = ['datetime', 'timestamp', 'date', 'time']
PANDAS_ENSIGN_TYPE_DICT = {
    'str': np.object,
    'int64': np.int64,
    'float64': np.float64,
    'ip': np.object,
    'datetime': np.dtype('<M8[ns]'),
    'timestamp': np.dtype('<M8[ns]'),
    'date': np.dtype('<M8[ns]'),
    'time': np.dtype('<M8[ns]')
}
BINNING_ACCEPTABLE_NP_TYPES = {
    'none': [np.object, np.int64, np.float64, np.dtype('<M8[ns]')],
    'binsize': [np.int64, np.float64],
    'ipsubnet': [np.object],
    'ipv6_hextets': [np.object],
    'log10': [np.int64, np.float64],
    'round': [np.float64],
    'cyclic': [np.int64, np.float64],
    'second': [np.dtype('<M8[ns]')],
    'minute': [np.dtype('<M8[ns]')],
    'hour': [np.dtype('<M8[ns]')],
    'day': [np.dtype('<M8[ns]')],
    'month': [np.dtype('<M8[ns]')],
    'year': [np.dtype('<M8[ns]')],
    'minute_of_hour': [np.dtype('<M8[ns]')],
    'hour_of_day': [np.dtype('<M8[ns]')],
    'day_of_week': [np.dtype('<M8[ns]')],
    'day_of_month': [np.dtype('<M8[ns]')],
    'month_of_year': [np.dtype('<M8[ns]')],
}
BINNING_ACCEPTABLE_TYPES = {
    'none': ['str', 'int64', 'float64', 'datetime', 'timestamp', 'ip', 'date', 'time'],
    'binsize': ['int64', 'float64'],
    'ipsubnet': ['ip'],
    'ipv6_hextets': ['ip'],
    'log10': ['int64', 'float64'],
    'round': ['float64'],
    'cyclic': ['int64', 'float64'],
    'second': ['datetime', 'timestamp', 'time'],
    'minute': ['datetime', 'timestamp', 'time'],
    'hour': ['datetime', 'timestamp', 'time'],
    'day': ['datetime', 'timestamp', 'date'],
    'month': ['datetime', 'timestamp', 'date'],
    'year': ['datetime', 'timestamp', 'date'],
    'minute_of_hour': ['datetime', 'timestamp', 'time'],
    'hour_of_day': ['datetime', 'timestamp', 'time'],
    'day_of_week': ['datetime', 'timestamp', 'date'],
    'day_of_month': ['datetime', 'timestamp', 'date'],
    'month_of_year': ['datetime', 'timestamp', 'date'],
}
TIME_BINNING_FORMATS = {
    'second': '%Y-%m-%d %H:%M:%S',
    'minute': '%Y-%m-%d %H:%M:00',
    'hour': '%Y-%m-%d %H:00:00',
    'day': '%Y-%m-%d',
    'month': '%Y-%m',
    'year': '%Y',
    'minute_of_hour': '%M',
    'hour_of_day': '%H',
    'day_of_week': '%w_%A',
    'day_of_month': '%d',
    'month_of_year': '%m_%B'
}
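# Illustrative sketch (comment only, not executed): these strftime formats
# truncate a value to its bin label, e.g. with the 'hour' format
#   pd.Timestamp('2021-03-04 15:42:07').strftime('%Y-%m-%d %H:00:00')
# yields '2021-03-04 15:00:00'.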
#------------------------------------------------------------------------------
# Private functions
#------------------------------------------------------------------------------
def _get_mode_labels(df, column, sort):
    labels = None
    if column in sort:
        labels = df[column].drop_duplicates().reset_index(drop=True).sort_values()
    else:
        labels = df[column].drop_duplicates().reset_index(drop=True)

    labels_dict = {v: k for k, v in enumerate(labels)}
    return (labels, labels_dict)

def _get_none_queries(label, column):
    return column + ' == ' + str(label)

def _get_log10_queries(label, column):
    label = int(float(label))
    if label == 0:
        return column + ' > -10 AND ' + column + ' < 10'
    if label < 0:
        label = abs(label)
        return column + ' > -' + str(10**(label+1)) + ' AND ' \
             + column + ' <= -' + str(10**label)
    else:
        return column + ' >= ' + str(10**label) + ' AND ' \
             + column + ' < ' + str(10**(label+1))

def _get_round_queries(label, precision, column):
    label = float(label)
    upper = round(label + 10**(-precision), FLOAT_SIG_DIGS)
    return column + ' >= ' + str(label) + ' AND ' \
         + column + ' < ' + str(upper)

def _get_binsize_queries(label, binsize, column):
    label = float(label)
    lower = round(binsize * label, FLOAT_SIG_DIGS)
    upper = round(binsize * (label + 1), FLOAT_SIG_DIGS)
    return column + ' >= ' + str(lower) + ' AND ' \
         + column + ' < ' + str(upper)

def _get_cyclic_queries(label, numbins, binsize, columns):
    period = round(numbins * binsize, FLOAT_SIG_DIGS)
    return '(' + columns + ' - MIN) % ' + str(period) \
               + ' / ' + str(binsize) + ' == ' + str(label)

def _get_ipsubnet_queries(label, mask, column):
    if '+' in mask:
        ipv4_mask, ipv6_mask = mask.split('+')
    elif '.' in mask:
        ipv4_mask = mask
        ipv6_mask = None
    else:
        ipv4_mask = None
        ipv6_mask = mask
    if '.' in label:
        if ipv4_mask:
            return column + ' & ' + ipv4_mask + ' == ' + label
        else:
            return column + ' == ' + label
    else:
        if ipv6_mask:
            return column + ' & ' + ipv6_mask + ' == ' + label
        else:
            return column + ' == ' + label

def _get_ipv6_hextets_queries(label, mask, column):
    num_hextets, end = mask.split(':')
    if end == 'MSB':
        mask = ':'.join(int(num_hextets)*['ffff',]) + '::'
    else:
        mask = '::' + ':'.join(int(num_hextets)*['ffff',]) 
    if ':' in label:
        return column + ' & ' + mask + ' == ' + label
    else:
        return column + ' == ' + label

def _get_time_queries(label, time_unit, column):
    if time_unit == 'year':
        t0 = date(int(label), 1, 1)
        t1 = date(int(label) + 1, 1, 1)
    elif time_unit == 'month':
        year, month = map(int, label.split('-'))
        t0 = date(year, month, 1)
        if month == 12:
            t1 = date(year + 1, 1, 1)
        else:
            t1 = date(year, month + 1, 1)
    elif time_unit == 'day':
        year, month, day = map(int, label.split('-'))
        t0 = date(year, month, day)
        t1 = t0 + timedelta(days=1)
    else: 
        if ' ' in label:
            label_date, label_time = label.split(' ')
            year, month, day = map(int, label_date.split('-'))
            hour, minute, second = map(int, label_time.split(':'))
            t0 = datetime(year, month, day, hour, minute, second)
            if time_unit == 'hour':
                t1 = t0 + timedelta(hours=1)
            elif time_unit == 'minute':
                t1 = t0 + timedelta(minutes=1)
            else:
                t1 = t0 + timedelta(seconds=1)
        else:
            hour, minute, second = map(int, label.split(':'))
            t0 = date_time(hour, minute, second)
            if time_unit == 'hour':
                if hour == 23:
                    return column + ' >= ' + label
                else:
                    t1 = date_time(hour+1, 0, 0)
            elif time_unit == 'minute':
                if hour == 23 and minute == 59:
                    return column + ' >= ' + label
                elif minute == 59:
                    t1 = date_time(hour+1, 0, 0)
                else:
                    t1 = date_time(hour, minute+1, 0)
            else:
                if hour == 23 and minute == 59 and second == 59:
                    return column + ' >= ' + label
                elif minute == 59 and second == 59:
                    t1 = date_time(hour+1, 0, 0)
                elif second == 59:
                    t1 = date_time(hour, minute+1, 0)
                else:
                    t1 = date_time(hour, minute, second+1)

    return column + ' >= ' + str(t0) + ' AND ' + column + ' < ' + str(t1)

def _get_cyclic_time_queries(label, time_unit, column):
    return time_unit.upper() + '(' + column + ') == ' + label

def _get_queries(label_list, bin_specs, joiner, column_name):
    if not isinstance(bin_specs, list):
        bin_specs = [bin_specs]
    queries = []
    labels_split = list(map(lambda label: str(label).split(joiner), label_list))
    column_split = column_name.split(FUSED_COLUMN_NAME_DELIMITER) 
    
    for col_id, bin_spec in enumerate(bin_specs):
        labels = [unfused_labels[col_id] for unfused_labels in labels_split]
        column = column_split[col_id]

        if '=' in bin_spec:
            bin_type, bin_value = bin_spec.split('=')
        else:
            bin_type = bin_spec

        if bin_type == 'none':
            queries.append(list(map(lambda label: 
                _get_none_queries(label, column), labels)))
        elif bin_type == 'log10':
            queries.append(list(map(lambda label: 
                _get_log10_queries(label, column), labels)))
        elif bin_type == 'round':
            queries.append(list(map(lambda label: 
                _get_round_queries(label, int(bin_value), column), labels)))
        elif bin_type == 'binsize':
            queries.append(list(map(lambda label: 
                _get_binsize_queries(label, float(bin_value), column), labels)))
        elif bin_type == 'cyclic':
            numbins, binsize = map(float, bin_value.split(':'))
            queries.append(list(map(lambda label:
                _get_cyclic_queries(label, numbins, binsize, column), labels)))
        elif bin_type == 'ipsubnet':
            queries.append(list(map(lambda label:
                _get_ipsubnet_queries(label, bin_value, column), labels)))
        elif bin_type == 'ipv6_hextets':
            queries.append(list(map(lambda label:
                _get_ipv6_hextets_queries(label, bin_value, column), labels)))
        elif bin_type in ['second', 'minute', 'hour', 'day', 'month', 'year']:
            queries.append(list(map(lambda label: 
                _get_time_queries(label, bin_type, column), labels)))
        elif bin_type in ['minute_of_hour', 'hour_of_day', 'day_of_week', \
                        'day_of_month', 'month_of_year']:
            queries.append(list(map(lambda label:
                _get_cyclic_time_queries(label, bin_type, column), labels)))
        
    if len(queries) == 1:
        return queries[0]
    else:
        return ['(' + ') AND ('.join([q[i] for q in queries]) + ')' 
                for i in range(len(labels))]
        
def _get_labels(df, columns, types, binning, sort, joiner, gen_queries):
    label_lists = []
    query_lists = [] if gen_queries else None
    for mode_id, column in enumerate(columns):
        labels, labels_dict = _get_mode_labels(df, column, sort)
        df[column] = df[column].map(labels_dict)
        if types[mode_id] == 'int64':
            labels = labels.replace(MAX_INT, 'nan')
        elif types[mode_id] == 'float64':
            labels = labels.replace(MAX_FLOAT, np.nan)
        labels = labels.values.tolist()
        label_lists.append(labels)
        if gen_queries:
            query_lists.append(_get_queries(labels, binning[mode_id], joiner, column))
    return df, label_lists, query_lists

def _build_sptensor(df, columns, labels, spt_backtrack, queries):
    tensor = spt.SPTensor(order=len(columns), 
                          mode_sizes=list(map(len, labels)), 
                          mode_names=columns, 
                          labels=labels)
    tensor.nnz = len(df.index)
    tensor.entries = df[columns+['val_idx']]
    tensor.spt_backtrack = spt_backtrack
    tensor.queries = queries
    return tensor

def _build_boolean_tensor(df, gen_backtrack, dask_client, failure_counter):
    if gen_backtrack:
        log.info('Generating spt_backtrack information ...')

        mode_names = df.columns.to_list()
        mode_names.remove('backtrack')

        # group the data by the mode entries
        grouped_data = df.groupby(mode_names)
        df = grouped_data.first().reset_index()
        df['val_idx'] = 1
        df = df.drop(columns=['backtrack'])
        df = _compute_df(df, dask_client)
        _reset_failure_count(failure_counter)

        # aggregate backtracking data in each group
        spt_backtrack = grouped_data.apply(lambda group: list(group.backtrack), meta=list)
        if dask_client:
            spt_backtrack = dask_client.compute(spt_backtrack, sync=True)
        else:
            spt_backtrack = spt_backtrack.compute(sync=True, scheduler='threads')
        
        df = pd.concat([df.to_frame(), spt_backtrack.to_frame()], axis=1)
        df.columns = ['val_idx', 'backtrack']
        df = df.reset_index()

        return df, list(df['backtrack'])
    else:
        df = df.drop_duplicates().reset_index(drop=True)
        df['val_idx'] = 1
        df = _compute_df(df, dask_client)
        _reset_failure_count(failure_counter)
        return df, None

def _build_count_tensor(df, columns, gen_backtrack, dask_client, failure_counter):
    if gen_backtrack:
        log.info('Generating spt_backtrack information ...')

        mode_names = df.columns.to_list()
        mode_names.remove('backtrack')

        df['val_idx'] = 1

        # group the data by the mode entries, then calculate count
        grouped_data = df.groupby(mode_names)
        df = grouped_data.size(split_out=AGG_SPLIT)
        df = _compute_df(df, dask_client)
        _reset_failure_count(failure_counter)

        # aggregate backtracking data in each group
        spt_backtrack = grouped_data.apply(lambda group: list(group.backtrack), meta=list)
        if dask_client:
            spt_backtrack = dask_client.compute(spt_backtrack, sync=True)
        else:
            spt_backtrack = spt_backtrack.compute(sync=True, scheduler='threads')
        
        df = pd.concat([df.to_frame(), spt_backtrack.to_frame()], axis=1)
        df.columns = ['val_idx', 'backtrack']
        df = df.reset_index()

        return df, list(df['backtrack'])
    else:
        df = df.groupby(df.columns.tolist()).size(split_out=AGG_SPLIT)
        df = _compute_df(df, dask_client)
        _reset_failure_count(failure_counter)

        df = df.reset_index().rename(columns={0: 'val_idx'})
        return df, None

def _aggregate_tensor_values(grouped_df, agg_method, value_col_name):
    def chunk(grouped):
        return grouped.max(), grouped.min()
    def agg(chunk_maxes, chunk_mins):
        return chunk_maxes.max(), chunk_mins.min()
    def max_finalize(maxima, minima):
        return pd.DataFrame([maxima, minima]).agg(lambda s: max(s, key=abs))
    def min_finalize(maxima, minima):
        return pd.DataFrame([maxima, minima]).agg(lambda s: min(s, key=abs))

    if agg_method == 'sum':
        return grouped_df.sum(split_out=AGG_SPLIT)
    elif agg_method == 'max':
        return grouped_df.max(split_out=AGG_SPLIT)
    elif agg_method == 'min':
        return grouped_df.min(split_out=AGG_SPLIT)
    elif agg_method == 'first':
        return grouped_df.first(split_out=AGG_SPLIT)
    elif agg_method == 'last':
        return grouped_df.last(split_out=AGG_SPLIT)
    elif agg_method == 'mean':
        return grouped_df.mean(split_out=AGG_SPLIT)
    elif agg_method == 'prod':
        return grouped_df.prod(split_out=AGG_SPLIT)
    elif agg_method == 'idxmin':
        return grouped_df.idxmin(split_out=AGG_SPLIT)
    elif agg_method == 'idxmax':
        return grouped_df.idxmax(split_out=AGG_SPLIT)
    elif agg_method == 'max_abs':
        max_abs = dd.Aggregation('max_abs', chunk, agg, finalize=max_finalize)
        return grouped_df.agg(max_abs, split_out=AGG_SPLIT)
    elif agg_method == 'min_abs':
        min_abs = dd.Aggregation('min_abs', chunk, agg, finalize=min_finalize)
        return grouped_df.agg(min_abs, split_out=AGG_SPLIT)

def _build_value_tensor(df, entries, columns, gen_backtrack, dask_client, failure_counter):
    value_col_name, agg_method = entries.split('=')[1].split(':')
    grouped_values = df[columns + [value_col_name]].groupby(columns)
    
    if gen_backtrack:
        log.info('Generating spt_backtrack information ...')

        grouped_backtrack = df[columns + ['backtrack']].groupby(columns)

        # group the data by the mode entries, then aggregate the value column
        df = _aggregate_tensor_values(grouped_values, agg_method, value_col_name)
        df = _compute_df(df, dask_client)
        _reset_failure_count(failure_counter)

        # aggregate backtracking data in each group
        spt_backtrack = grouped_backtrack.apply(lambda group: list(group.backtrack), meta=list)
        if dask_client:
            spt_backtrack = dask_client.compute(spt_backtrack, sync=True)
        else:
            spt_backtrack = spt_backtrack.compute(sync=True, scheduler='threads')
        
        df = pd.concat([df, spt_backtrack.to_frame()], axis=1)
        df.columns = ['val_idx', 'backtrack']
        df = df.reset_index()

        return df, list(df['backtrack'])
    else:
        df = _aggregate_tensor_values(grouped_values, agg_method, value_col_name)
        df = _compute_df(df, dask_client)
        _reset_failure_count(failure_counter)
        df.columns = ['val_idx']
        df = df.reset_index()
        return df, None

def _calc_entries(df, entries, columns, gen_backtrack, dask_client, failure_counter):
    if 'value' in entries:
        return _build_value_tensor(df, entries, columns, gen_backtrack, dask_client, failure_counter)
    elif entries == "boolean":
        return _build_boolean_tensor(df, gen_backtrack, dask_client, failure_counter)
    elif entries == "count":
        return _build_count_tensor(df, columns, gen_backtrack, dask_client, failure_counter)

def _fuse_names(joiner):
    def _fuse_func(x):
        return joiner.join(x.astype(str))
    return _fuse_func

def _fuse_columns(df, fuse_columns, joiner, header, binning, sort):
    if fuse_columns is None or len(fuse_columns) == 0:
        return df, header, binning, sort

    for columns in fuse_columns:
        log.info("  Fusing columns: {}".format(columns))
        new_col_name = ""
        new_binning = []
        sort_flag = 0
        for i, column in enumerate(columns):
            if i != 0:
                new_col_name += FUSED_COLUMN_NAME_DELIMITER
            new_col_name = new_col_name + str(column)
            index = header.index(column)
            new_binning.append(binning[index])
            header.pop(index)
            binning.pop(index)
            if column in sort:
                sort.remove(column)
                sort_flag = 1
        header.append(new_col_name)
        binning.append(new_binning)
        if sort_flag:
            sort.append(new_col_name)
            sort_flag = 0
        df[new_col_name] = df[columns].apply(
            _fuse_names(joiner), axis=1, meta=(new_col_name, object))

    return df, header, binning, sort

def _ip_to_subnet(dotted_ip, ipv4_subnet_mask, ipv6_subnet_mask):
    if ':' in dotted_ip:
        if ipv6_subnet_mask is None:
            return dotted_ip

        ip_list = dotted_ip.split(':') # Split IP string into hextets
        while len(ip_list) < 8: # Expand :: abbreviation
            ip_list.insert(ip_list.index(''), '0')
        while '' in ip_list: # Replace '' with '0'
            ip_list[ip_list.index('')] = '0'

        mask_ip_list = ipv6_subnet_mask.split(':')
        while len(mask_ip_list) < 8: # Expand :: abbreviation
            mask_ip_list.insert(mask_ip_list.index(''), '0')
        while '' in mask_ip_list: # Replace '' with '0'
            mask_ip_list[mask_ip_list.index('')] = '0'

        masked_ip = [] # Mask
        for hextet, mask in zip(ip_list, mask_ip_list):
            masked_ip.append(str(hex(int(hextet, base=16) & int(mask, base=16))[2:]))

        return ':'.join(masked_ip)

    else:
        if ipv4_subnet_mask is None:
            return dotted_ip

        masked_ip = []
        for octet, mask in zip(dotted_ip.split('.'), ipv4_subnet_mask.split('.')):
            masked_ip.append(str(int(octet) & int(mask)))
        
        return '.'.join(masked_ip)
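
# Illustrative sketch (comment only, not executed): with an IPv4 /24 mask,
#   _ip_to_subnet('192.168.10.37', '255.255.255.0', None)
# returns '192.168.10.0'; IPv6 addresses are masked hextet by hextet instead.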

def _bin_by_ipsubnet(df, column, ipv4_subnet_mask, ipv6_subnet_mask):
    df[column] = df[column].apply(
        lambda x: _ip_to_subnet(x, ipv4_subnet_mask, ipv6_subnet_mask), 
        meta=(column, str))
    return df

def _least_significant_ipv6_hextets(ip_string, num_hextets):
    ip_list = ip_string.split(':') # Split IP string into hextets

    while len(ip_list) < 8: # Expand :: abbreviation
        ip_list.insert(ip_list.index(''), '0')

    ip_list = ip_list[-num_hextets:] # Keep the last num_hextets hextets

    while '' in ip_list:
        ip_list[ip_list.index('')] = '0'

    return '::' + ':'.join(ip_list)

def _most_significant_ipv6_hextets(ip_string, num_hextets):
    ip_list = ip_string.split(':') # Split IP string into hextets

    while len(ip_list) < 8: # Expand :: abbreviation
        ip_list.insert(ip_list.index(''), '0')

    ip_list = ip_list[:num_hextets] # Keep the first num_hextets hextets

    while '' in ip_list:
        ip_list[ip_list.index('')] = '0'

    return ':'.join(ip_list) + '::'

def _ipv6_hextets(df, column, num_hextets, direction):
    if direction == 'MSB':
        df[column] = df[column].apply(
            lambda x: _most_significant_ipv6_hextets(x, num_hextets) if ':' in x else x,
            meta=(column, str))
    elif direction == 'LSB':
        df[column] = df[column].apply(
            lambda x: _least_significant_ipv6_hextets(x, num_hextets) if ':' in x else x,
            meta=(column, str))
    return df

def _check_datetime(failure_counter):
    def datetime_checker(val):
        if pd.isna(val):
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to format '{val}' as a datetime."
            log.info(msg)
        return str(val)
    return datetime_checker

def _check_time(failure_counter):
    def time_checker(val):
        if pd.isna(val):
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to format '{val}' as a datetime."
            log.info(msg)
            return str(val)
        else:
            return str(val)[11:] # Cut off date portion
    return time_checker

def _check_date(failure_counter):
    def date_checker(val):
        if pd.isna(val):
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to format '{val}' as a datetime."
            log.info(msg)
            return str(val)
        else:
            return str(val)[:10] # Cut off time portion
    return date_checker

def _format_datetime(out_format, failure_counter):
    def datetime_formatter(val):
        if pd.isna(val):
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to format '{val}' as a datetime."
            log.info(msg)
            return str(val)
        else:
            return val.strftime(out_format)
    return datetime_formatter

def _bin_by_timespan(df, column, timespan, col_type, failure_counter):
    out_format = ""
    if timespan in TIME_BINNING_FORMATS.keys():
        out_format = TIME_BINNING_FORMATS[timespan]
        if col_type == 'time' and timespan != 'minute_of_hour' and timespan != 'hour_of_day':
            out_format = out_format[9:] # Cut off date portion
    else:
        msg = f"Unrecognized time binning '{timespan}'"
        log.error(msg)
        raise ValueError(msg)
    df[column] = df[column].apply(
        _format_datetime(out_format, failure_counter), 
        meta=(column, str))
    return df

def _bin_by_cycle(df, column, num_bins, bin_size, column_type, dask_client):
    if dask_client:
        col_min = float(dask_client.compute(df[column].min(), sync=True))
    else:
        col_min = float(df[column].min().compute(scheduler='threads'))

    period = num_bins * bin_size
    if column_type == 'float64':
        df[column] = df[column].apply(
            lambda x: x if (x == MAX_FLOAT or x == np.inf) 
                        else (x - col_min) % period // bin_size, 
            meta=(column, float))
    elif column_type == 'int64':
        df[column] = df[column].apply(
            lambda x: x if x == MAX_INT
                        else int((x - col_min) % period / bin_size), 
            meta=(column, int))
    return df

def _bin_by_rounding(df, column, prec):
    df[column] = df[column].apply(
        lambda x: x if (x == MAX_FLOAT or x == np.inf) else round(x, prec), 
        meta=(column, float))
    return df

def _bin_by_bin_size(df, column, column_type, binsize):
    if column_type == 'float64':
        df[column] = df[column].apply(
            lambda x: x if (x == MAX_FLOAT or x == np.inf) \
                        else np.floor(round(x / binsize, FLOAT_SIG_DIGS)), 
            meta=(column, float))
    elif column_type == 'int64':
        df[column] = df[column].apply(
            lambda x: x if (x == MAX_INT) 
                        else np.floor(round(x / binsize, FLOAT_SIG_DIGS)), 
            meta=(column, int))
    return df

def _ensign_log_binning_float(x):
    if x == 0:
        return 0
    elif x == MAX_FLOAT:
        return x
    elif x > 0 and x != np.inf:
        return int(np.log10(x))
    elif x < 0 and x != -np.inf:
        return int(-1 * np.log10(-1 * x))
    else:
        return x

def _ensign_log_binning_int(x):
    if x == 0:
        return 0
    elif x > 0 and x != MAX_INT:
        return int(np.log10(x))
    elif x < 0:
        return int(-1 * np.log10(-1 * x))
    else:
        return x

def _bin_by_log(df, column, column_type):
    if column_type == 'float64':
        df[column] = df[column].apply(
            _ensign_log_binning_float, 
            meta=(column, float))
    elif column_type == 'int64':
        df[column] = df[column].apply(
            _ensign_log_binning_int, 
            meta=(column, int))
    return df

def _bin_columns(df, columns, binning, types, dask_client, failure_counter):
    for mode_id, bin_spec in enumerate(binning):
        log.info(f"  Binning mode {mode_id} ({bin_spec})")
        bin_type = 'none'
        bin_value = ''
        column = columns[mode_id]
        if "=" in bin_spec:
            split_spec = bin_spec.split("=")
            bin_type = split_spec[0]
            bin_value = split_spec[1]
        else:
            bin_type = bin_spec
        
        if bin_type == "log10":
            df = _bin_by_log(df, column, types[mode_id])
        elif bin_type == "binsize":
            bin_size = float(bin_value)
            df = _bin_by_bin_size(df, column, types[mode_id], bin_size)
        elif bin_type == "round":
            prec = int(bin_value)
            df = _bin_by_rounding(df, column, prec)
        elif bin_type == "cyclic":
            num_bins, bin_size = bin_value.split(":")
            num_bins = int(num_bins)
            bin_size = float(bin_size)
            df = _bin_by_cycle(
                df, column, num_bins, bin_size, types[mode_id], dask_client)
        elif bin_type == "ipsubnet":
            if '+' in bin_value:
                ipv4_subnet_mask, ipv6_subnet_mask = bin_value.split('+')
            elif '.' in bin_value:
                ipv4_subnet_mask, ipv6_subnet_mask = bin_value, None
            elif ':' in bin_value:
                ipv4_subnet_mask, ipv6_subnet_mask = None, bin_value
            df = _bin_by_ipsubnet(df, column, ipv4_subnet_mask, ipv6_subnet_mask)
        elif bin_type == "ipv6_hextets":
            num_hextets, direction = bin_value.split(":")
            num_hextets = int(num_hextets)
            df = _ipv6_hextets(df, column, num_hextets, direction)
        elif bin_type in VALID_TIME_BIN_TYPES:
            df = _bin_by_timespan(df, column, bin_type, types[mode_id], failure_counter)
        elif bin_type == "none" and (types[mode_id] == 'datetime' or 
                                     types[mode_id] == 'timestamp'):
            # Check for missing values in datetime column
            df[column] = df[column].apply(
                _check_datetime(failure_counter),
                meta=(column, np.dtype('<M8[ns]')))
        elif bin_type == "none" and types[mode_id] == 'time':
            # Check for missing values in datetime column
            # and cut off date portion
            df[column] = df[column].apply(
                _check_time(failure_counter),
                meta=(column, str))
        elif bin_type == "none" and types[mode_id] == 'date':
            # Check for missing values in datetime column
            # and cut off time portion
            df[column] = df[column].apply(
                _check_date(failure_counter),
                meta=(column, str))

    return df

def _combine_and_filter_dfs(dfs, columns, entries, binning, sort, gen_backtrack, in_memory):
    aux_cols = []
    if gen_backtrack:
        aux_cols.append('backtrack')

        if in_memory:
            backtrack_dfs = []
            for i, df in enumerate(dfs):
                npartitions = df.npartitions
                df = df.repartition(npartitions=1).reset_index()
                df = df.repartition(npartitions=npartitions).reset_index()
                df['backtrack'] = df.level_0.map(lambda x: (i, x))
                df = df.drop(columns=['level_0', 'index'])
                backtrack_dfs.append(df)
            dfs = backtrack_dfs

    df = dd.concat(dfs)

    if 'value' in entries:
        aux_cols.append(entries.split('=')[1].split(':')[0])

    filter_cols = copy.deepcopy(columns + aux_cols)
    df = df[filter_cols]

    columns, sort = _rename_duplicate_columns(columns, binning, sort)
    df.columns = columns + aux_cols

    return df, columns

def _compute_df(df, dask_client):
    if dask_client:
        df = dask_client.compute(df, sync=True)
    else:
        df = df.compute(sync=True, scheduler='threads')
    return df
    
def _reset_failure_count(failure_counter):
    if not failure_counter:
        return
    if failure_counter['distributed']:
        failure_counter['num_failures'].set(0)
    else:
        global num_failures_global
        num_failures_global = 0

#------------------------------------------------------------------------------
# Validation
#------------------------------------------------------------------------------
def _validate_bro_log(infile):
    # Check that the first 8 lines start with '#'
    with open(infile, "r") as f:
        line_num = 0
        for line in f:
            if line_num <= NUM_BRO_HEADER_LINES:
                try:
                    line.index('#')
                except:
                    msg = f'Infile: {infile} is not a valid Bro/Zeek log'
                    log.error(msg)
                    raise TypeError(msg)
            else:
                break
            line_num += 1

    return True

def _validate_datetime_format_string(format_string):
    # Checks the character after % to make sure it is a valid directive.
    msg = f"Invalid time type: {format_string} not a valid datetime format string."
    for i, char in enumerate(format_string):
        if char == '%':
            try: # In case format string ends with %
                if format_string[i+1] not in DATETIME_DIRECTIVES:
                    log.error(msg)
                    raise TypeError(msg)
            except:
                log.error(msg)
                raise TypeError(msg)
    return True

def _validate_subnet_mask(subnet_mask):
    # raises an error if not a valid IPv4 or IPv6 address
    ip_address_test = ip_address(subnet_mask)

    return True

def _validate_columns(columns, df, arg):
    if columns is None:
        return True

    if not _is_list_of_str(columns):
        msg = f"'{arg}' argument invalid: not None or list of str: {columns}"
        log.error(msg)
        raise TypeError(msg)
    
    for column in columns:
        if column not in df.columns:
            msg = f"Column '{column}' invalid: not found in input data"
            log.error(msg)
            raise TypeError(msg)

    return True

def _validate_types(types):
    if types is None:
        return True

    if not _is_list_of_str(types):
        msg = f"'types' invalid: not None or list of str: {types}"
        log.error(msg)
        raise TypeError(msg)
    
    for type_spec in types:
        if type_spec not in VALID_TYPES:
            msg = f"Type '{type_spec}' invalid: not a recognized type"
            log.error(msg)
            raise TypeError(msg)

    return True

def _validate_binning(binning, types):
    if binning is None:
        return True
    if not _is_list_of_str(binning):
        msg = f"'binning' invalid: must be None or list of str: {binning}"
        log.error(msg)
        raise TypeError(msg)
    
    for mode_id, bin_spec in enumerate(binning):
        has_value = False
        bin_type = bin_spec
        bin_value = None

        if '=' in bin_spec:
            has_value = True
            bin_type = bin_spec.split('=')[0]
            bin_value = bin_spec.split('=')[1]

        if bin_type not in VALID_NUM_BIN_TYPES and\
           bin_type not in VALID_IP_BIN_TYPES and\
           bin_type not in VALID_TIME_BIN_TYPES:
            msg = f"Bin type '{bin_type}' invalid: not a recognized bin type"
            log.error(msg)
            raise TypeError(msg)

        if (bin_type == "binsize" or bin_type == "round" or\
            bin_type == "ipsubnet") and has_value == False:
            msg = f"Bin type '{bin_type}' invalid: requires '=<value>'"
            log.error(msg)
            raise TypeError(msg)

        if bin_type == "cyclic" and has_value == False:
            msg = f"Bin type '{bin_type}' invalid: requires '=<num_bins>:<bin_size>'"
            log.error(msg)
            raise TypeError(msg)

        if bin_type == 'ipsubnet':
            if '+' in bin_value:
                ipv4_subnet_mask, ipv6_subnet_mask = bin_value.split('+')
                if not _validate_subnet_mask(ipv4_subnet_mask) and\
                   not _validate_subnet_mask(ipv6_subnet_mask):
                    msg = f"Bin type '{bin_type}' invalid: requires a valid subnet mask"
                    log.error(msg)
                    raise TypeError(msg)
            else:
                if not _validate_subnet_mask(bin_value):
                    msg = f"Bin type '{bin_type}' invalid: requires a valid subnet mask"
                    log.error(msg)
                    raise TypeError(msg)

        if bin_type == "binsize": 
            try:
                float(bin_value)
            except:
                msg = f"Bin type '{bin_type}' invalid: requires a decimal size"
                log.error(msg)
                raise ValueError(msg)

            if float(bin_value) <= 0.0:
                msg = f"Bin type '{bin_type}' invalid: requires a size > 0.0"
                log.error(msg)
                raise ValueError(msg)

        if bin_type == "round":
            try:
                int(bin_value)
            except:
                msg = f"Bin type '{bin_type}' invalid: requires a positive integer precision"
                log.error(msg)
                raise ValueError(msg)

            if float(bin_value) <= 0.0:
                msg = f"Bin type '{bin_type}' invalid: requires a positive integer precision"
                log.error(msg)
                raise ValueError(msg)

        if bin_type == "cyclic":
            try:
                num_bins, bin_size = bin_value.split(":")
            except:
                msg = f"Bin type '{bin_type}' invalid: requires colon-separated bin count and size."
                log.error(msg)
                raise ValueError(msg)

            try:
                int(num_bins)
            except:
                msg = f"Bin type '{bin_type}' invalid: requires a positive integer number of bins"
                log.error(msg)
                raise ValueError(msg)
            if int(num_bins) <= 0:
                msg = f"Bin type '{bin_type}' invalid: requires a positive integer number of bins"
                log.error(msg)
                raise ValueError(msg)

            try:
                float(bin_size)
            except:
                msg = f"Bin type '{bin_type}' invalid: requires a decimal size."
                log.error(msg)
                raise ValueError(msg)
            if float(bin_size) <= 0.0:
                msg = f"Bin type '{bin_type}' invalid: requires a size > 0.0"
                log.error(msg)
                raise ValueError(msg)

        if bin_type == "ipv6_hextets":
            try:
                num_hextets, direction = bin_value.split(":")
            except:
                msg = f"Bin type '{bin_type}' invalid: requires colon-separated hextet count and direction."
                log.error(msg)
                raise ValueError(msg)

            try:
                int(num_hextets)
            except:
                msg = f"Bin type '{bin_type}' invalid: requires a positive integer number of hextets"
                log.error(msg)
                raise ValueError(msg)

            if int(num_hextets) <= 0 or int(num_hextets) >= 8:
                msg = f"Bin type '{bin_type}' invalid: requires a positive integer number of hextets less than 8."
                log.error(msg)
                raise ValueError(msg)

            if direction != 'MSB' and direction != 'LSB':
                msg = f"Bin type '{bin_type}' invalid: direction must be 'MSB' or 'LSB'"
                log.error(msg)
                raise ValueError(msg)

    return True

def _validate_entries(entries):
    if not (entries == 'count' or entries == 'boolean'):
        msg = f"Entry type '{entries}' invalid. Only 'count', 'boolean', and 'value=<column_name>:<aggregation_method> are supported."
        if 'value=' in entries:
            _, value_args = entries.split('=')
            if ':' not in value_args:
                log.error(msg)
                raise TypeError(msg)
            col_name, agg_method = value_args.split(':')
            if agg_method not in VALUE_TENSOR_AGG_METHODS:
                msg = f"Aggregation method for value-tensor construction invalid. Valid aggregation methods include: {VALUE_TENSOR_AGG_METHODS}"
                log.error(msg)
                raise TypeError(msg)
        else:
            log.error(msg)
            raise TypeError(msg)
    return True

def _validate_config(df, columns, types, binning, entries, sort, 
                     fuse_columns, joiner):
    columns_are_valid = _validate_columns(columns, df, 'columns')
    sort_columns_are_valid = _validate_columns(sort, df, 'sort')
    types_are_valid = _validate_types(types)
    binning_is_valid = _validate_binning(binning, types)

    if not set(sort).issubset(set(columns)):
        msg = f"Columns to sort ({sort}) is not a subset of the columns selected to be a part of the tensor ({columns})."
        log.error(msg)
        raise TypeError(msg)

    if not (columns_are_valid and types_are_valid and 
            binning_is_valid and sort_columns_are_valid):
        return False

    if len(columns) != len(types) or len(columns) != len(binning):
        msg = ("Number of columns, types, and binning must be equal"
               f"\n  columns length  : {len(columns)}"
               f"\n  types length    : {len(types)}"
               f"\n  binnings length : {len(binning)}")
        log.error(msg)
        raise TypeError(msg)

    if fuse_columns is not None:
        # Check if each set of columns to be fused has at least two columns
        for col_group in fuse_columns:
            if len(col_group) < 2:
                msg = "Number of columns to fuse must be greater than one"
                log.error(msg)
                raise TypeError(msg)

            for col in col_group:
                if '_by_' in col:
                    col = col.split('_by_')[0]
                if col not in df.columns.values.tolist():
                    msg = f"Fuse column {col} does not exist."
                    log.error(msg)
                    raise TypeError(msg)

    for col, col_type, bin_type in zip(columns, types, binning):
        np_col_type = PANDAS_ENSIGN_TYPE_DICT[col_type]
        if np_col_type != df[col].dtype:
            msg = f'Dtype of column {col}: {df[col].dtype} does not match declared type {col_type} ({np_col_type})'
            log.error(msg)
            raise TypeError(msg)

        if '=' in bin_type:
            idx = bin_type.index('=')
            bin_type = bin_type[:idx]
        if np_col_type not in BINNING_ACCEPTABLE_NP_TYPES[bin_type]:
            msg = f'Column type {col_type} not appropriate for binning specification {bin_type}. Check documentation for proper usage.'
            log.error(msg)
            raise TypeError(msg)

        if col_type not in BINNING_ACCEPTABLE_TYPES[bin_type]:
            msg = f'Column type {col_type} not appropriate for binning specification {bin_type}. Check documentation for proper usage.'
            log.error(msg)
            raise TypeError(msg)

    return True

#------------------------------------------------------------------------------
# Helpers
#------------------------------------------------------------------------------
def _check_failures_exceeded(num_failures, max_failures, distributed):
    max_fails_msg = 'Maximum number of missing/corrupted values ({}) exceeded.'
    if distributed:
        num_failures.set(num_failures.get() + 1)
        if num_failures.get() > max_failures:
            raise ValueError(max_fails_msg.format(max_failures))
    else:
        global num_failures_global
        num_failures_global += 1
        if num_failures_global > max_failures:
            raise ValueError(max_fails_msg.format(max_failures))

def _get_failure_counter(max_failures, dask_client):
    """
    Helper function to track the number of values that fail to be typed.

    This dictionary is used to call the _check_failures_exceeded function 
    defined above.

    Failures are counted in the _type_float, _type_int, _format_datetime,
    _date_checker, _time_checker, and _datetime_checker functions. The typing
    functions run when the CSV is read in, whereas _format_datetime runs
    during the binning process. If a time column is not binned,
    _datetime_checker is used instead.

    This split exists because float64 and int64 columns are typed with custom
    'converter' functions, which makes it possible to add failure counting.
    Temporal columns are typed with the pandas to_datetime function, which
    does not accommodate our custom failure-counting code, so their failures
    are counted at binning time, where custom binning functions are used.
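
    Example (illustrative; assumes max_failures=10 and no Dask client)::

        counter = _get_failure_counter(max_failures=10, dask_client=None)
        # counter == {'num_failures': 0, 'max_failures': 10, 'distributed': 0}
        # The counter is then passed to the converters, which call
        # _check_failures_exceeded(**counter) on each value that fails.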
    """
    failure_counter = None
    if max_failures is not None:
        if max_failures < 0:
            msg = 'Number of failures tolerated cannot be less than zero.'
            log.error(msg)
            raise ValueError(msg)
        distributed = 0
        if dask_client is not None:
            distributed = 1
            num_failures = Variable('num_failures')
            num_failures.set(0)
        else:
            global num_failures_global
            num_failures_global = 0
            num_failures = num_failures_global
        failure_counter = {
            'num_failures': num_failures, 
            'max_failures': max_failures, 
            'distributed': distributed
        }
    return failure_counter

def _is_list_of_str(arg):
    if not (not hasattr(arg, "strip") and
            (hasattr(arg, "__getitem__") or
             hasattr(arg, "__iter__"))):
        return False

    if len(arg) > 0:
        for element in arg:
            if not isinstance(element, str):
                return False

    return True

def _rename_duplicate_columns(columns, binning, sort):
    cols, counts = np.unique(columns, return_counts=True)
    column_counts = dict(zip(cols, counts))

    renamed_columns = []
    for i in range(len(columns)):
        if column_counts[columns[i]] > 1:
            new_col_name = columns[i] + '_by_' + binning[i]
            renamed_columns.append(new_col_name)
            if columns[i] in sort:
                sort.append(new_col_name)
        else:
            renamed_columns.append(columns[i])

    if len(renamed_columns) != len(set(renamed_columns)):
        raise ValueError('Duplicate columns must have distinct binning.')

    return renamed_columns, sort

def _prepare_columns_types_binning_sort(df, columns, types, binning, sort, entries, bro_log=False):
    _validate_entries(entries)

    if columns is None or len(columns) == 0:
        if isinstance(df, dd.DataFrame):
            columns = list(df.columns)
        elif isinstance(df, str):
            if bro_log:
                columns = _get_bro_log_header(df)
            else:
                columns = linecache.getline(df, 1).strip().split(',')

    if types is None or len(types) == 0:
        types = ['str'] * len(columns)
    if binning is None or len(binning) == 0:
        binning = ['none'] * len(columns)
    if sort is None:
        sort = []
        # Automatically sort time based columns
        for i, col_type in enumerate(types):
            if col_type in TIME_TYPES:
                sort.append(columns[i])

    if 'value' in entries:
        value_col_name, agg_method = entries.split('=')[1].split(':')

        if value_col_name in columns:
            types[columns.index(value_col_name)] = 'float64'

    return (columns, types, binning, sort) 

def _get_bro_log_header(logfile):
    with open(logfile, "r") as f:
        line_num = 0
        for line in f:
            line_num += 1
            if line_num < NUM_BRO_HEADER_LINES:
                continue
            elif line_num == NUM_BRO_HEADER_LINES:
                header = line.split()[1:]
            else:
                break
    return header

def _set_log_level(log_level):
    if log_level not in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL', 10, 20, 30, 40, 50]:
        log.warning("Invalid log-level provided. Defaulting to logging.INFO")
        log_level = 20
    ensign_logging.set_msg_only_output()
    if log_level=='DEBUG': log_level = 10
    if log_level=='INFO': log_level = 20
    if log_level=='WARNING': log_level = 30
    if log_level=='ERROR': log_level = 40
    if log_level=='CRITICAL': log_level = 50
    log.setLevel(log_level)

def _ensign_read_csv(input_file, columns, types, binning, header=None, delimiter=',',
                     entries='count', drop_na=False, failure_counter=None, backtrack=False):
    """
    Read a CSV into a DataFrame and perform efficient preprocessing
    according to typing and binning specifications.
    For use with ensign.csv2tensor.df2tensor()

    Parameters
    ----------
    input_file : str
        Path to CSV.
    **kwargs
        See ensign.csv2tensor.csv2tensor()

    Returns
    -------
    df : dask.dataframe.DataFrame
        DataFrame containing data from input_file.
    entries : str
        Possibly updated entry specification; the value column name is
        rewritten when that column is also used as a tensor mode.
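
    Examples
    --------
    Illustrative sketch only; 'flows.csv' and its 'src'/'ts' columns are
    hypothetical::

        df, entries = _ensign_read_csv(
            'flows.csv', columns=['src', 'ts'], types=['str', 'timestamp'],
            binning=['none', 'hour'])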
    """
    value_converter = {}
    if 'value' in entries:
        value_converter[entries.split('=')[1].split(':')[0]] = _type_value(failure_counter)

    if header is not None:
        converters, datetime_cols, timestamp_cols = _get_converters(
            types, columns, binning, failure_counter)
        converters.update(value_converter)

        sample_line = linecache.getline(input_file, 1).strip().split(',')

        if sample_line == header:
            csv_df = dd.read_csv(input_file, comment='#', names=header, 
                sep=delimiter, header=0, usecols=columns+list(value_converter), 
                converters=converters, engine='c') # header=0 skips the first line
        else:
            csv_df = dd.read_csv(input_file, comment='#', names=header, 
                sep=delimiter, usecols=columns+list(value_converter), 
                converters=converters, engine='c')
        
    else:
        if columns is None:
            columns = linecache.getline(input_file, 1).strip().split(',')

        converters, datetime_cols, timestamp_cols = _get_converters(
            types, columns, binning, failure_counter)
        converters.update(value_converter)

        csv_df = dd.read_csv(input_file, comment='#', 
            usecols=columns+list(value_converter), sep=delimiter, 
            converters=converters, engine='c')

    _reset_failure_count(failure_counter)

    for col in timestamp_cols:
        csv_df[col] = dd.to_datetime(csv_df[col], unit='s', errors='coerce')
    for col in datetime_cols:
        csv_df[col] = dd.to_datetime(
            csv_df[col], infer_datetime_format=True, errors='coerce')

    if backtrack:
        npartitions = csv_df.npartitions
        csv_df = csv_df.repartition(npartitions=1).reset_index()
        csv_df = csv_df.repartition(npartitions=npartitions).reset_index()
        csv_df['backtrack'] = csv_df.level_0.map(lambda x: (input_file, x))
        csv_df = csv_df.drop(columns=['level_0', 'index'])

    # Drop rows where value column is NaN
    if value_converter:
        value_col_name = list(value_converter)[0]
        if value_col_name in columns:
            csv_df[value_col_name+'_val'] = csv_df[value_col_name]
            value_col_name = value_col_name + '_val'
            agg_method = entries.split('=')[1].split(':')[1]
            entries = f'value={value_col_name}:{agg_method}'
            csv_df[value_col_name] = csv_df[value_col_name].astype(float)
            csv_df[value_col_name] = csv_df[value_col_name].replace(np.inf, np.nan)
            csv_df[value_col_name] = csv_df[value_col_name].replace(MAX_INT, np.nan)
        csv_df = csv_df.loc[csv_df[value_col_name].notnull()]

    if drop_na:
        for col, col_type in zip(columns, types):
            if col_type == 'int64':
                csv_df[col] = csv_df[col].replace(MAX_INT, np.nan)
                
        csv_df = csv_df.dropna()

        for col, col_type in zip(columns, types):
            if col_type == 'int64':
                csv_df[col] = csv_df[col].astype(int)
    else:
        for col, col_type in zip(columns, types):
            if col_type == 'float64':
                csv_df[col] = csv_df[col].replace(np.nan, MAX_FLOAT)

    return csv_df, entries

def _ensign_read_bro_log(input_file, columns, types, binning, header=None,
                         entries='count', validate_bro_log=False, 
                         drop_na=False, failure_counter=None, backtrack=False):
    """
    Read a Bro/Zeek log into a DataFrame and perform efficient preprocessing
    according to typing and binning specifications.
    For use with ensign.csv2tensor.df2tensor()

    Parameters
    ----------
    input_file : str
        Path to Bro/Zeek log.
    **kwargs
        See ensign.csv2tensor.csv2tensor()

    Returns
    -------
    df : dask.dataframe.DataFrame
        DataFrame containing data from input_file.
    entries : str
        Possibly updated entry specification; the value column name is
        rewritten when that column is also used as a tensor mode.
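
    Examples
    --------
    Illustrative sketch only; 'conn.log' and its field names are
    hypothetical::

        df, entries = _ensign_read_bro_log(
            'conn.log', columns=['id.orig_h', 'ts'],
            types=['ip', 'timestamp'], binning=['none', 'minute'])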
    """
    # Validate the Bro/Zeek log first if the user requested validation.
    if ((validate_bro_log and _validate_bro_log(input_file)) or not validate_bro_log):
        value_converter = {}
        if 'value' in entries:
            value_converter[entries.split('=')[1].split(':')[0]] = _type_value(failure_counter)

        if header is None:
            header = _get_bro_log_header(input_file)

        # Grab a temporary sample of infile to check that the number of header
        # names matches the number of columns actually present in infile.
        tmp_csv_df = dd.read_csv(
            input_file, comment='#', sep='\t', dtype='str').head()

        if (len(tmp_csv_df.columns.values.tolist()) != len(header)):
            msg = "Columns header names length does not match number of columns in CSV file."
            log.error(msg)
            raise TypeError(msg)

        # Read in infile for good.
        if columns is None:
            columns = header
        converters, datetime_cols, timestamp_cols = _get_converters(
            types, columns, binning, failure_counter)
        converters.update(value_converter)

        csv_df = dd.read_csv(input_file, sep='\t', low_memory=False, 
            names=header, header=None, comment='#', engine='c',
            usecols=columns+list(value_converter), converters=converters)

        _reset_failure_count(failure_counter)

        for col in timestamp_cols:
            csv_df[col] = dd.to_datetime(csv_df[col], unit='s', errors='coerce')
        for col in datetime_cols:
            csv_df[col] = dd.to_datetime(
                csv_df[col], infer_datetime_format=True, errors='coerce')

        if backtrack:
            npartitions = csv_df.npartitions
            csv_df = csv_df.repartition(npartitions=1).reset_index()
            csv_df = csv_df.repartition(npartitions=npartitions).reset_index()
            csv_df['backtrack'] = csv_df.level_0.map(lambda x: (input_file, x))
            csv_df = csv_df.drop(columns=['level_0', 'index'])

        # Drop rows where value column is NaN
        if value_converter:
            value_col_name = list(value_converter)[0]
            if value_col_name in columns:
                csv_df[value_col_name+'_val'] = csv_df[value_col_name]
                value_col_name = value_col_name + '_val'
                agg_method = entries.split('=')[1].split(':')[1]
                entries = f'value={value_col_name}:{agg_method}'
                csv_df[value_col_name] = csv_df[value_col_name].astype(float)
                csv_df[value_col_name] = csv_df[value_col_name].replace(np.inf, np.nan)
            csv_df = csv_df.loc[csv_df[value_col_name].notna()]

        if drop_na:
            for col, col_type in zip(columns, types):
                if col_type == 'int64':
                    csv_df[col] = csv_df[col].replace(MAX_INT, np.nan)

            csv_df = csv_df.dropna()

            for col, col_type in zip(columns, types):
                if col_type == 'int64':
                    csv_df[col] = csv_df[col].astype(int)
        else:
            for col, col_type in zip(columns, types):
                if col_type == 'float64':
                    csv_df[col] = csv_df[col].replace(np.nan, MAX_FLOAT)

    return csv_df, entries

def _type_float(true_type='float64', failure_counter=None):
    def float_typer(val):
        try:
            typed_val = float(val)
        except:
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to type '{val}' as {true_type}."
            log.info(msg)
            typed_val = np.nan
        return typed_val
    return float_typer

def _type_int(failure_counter=None):
    def int_typer(val):
        try:
            typed_val = int(val)
        except:
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to type '{val}' as int64."
            log.info(msg)
            typed_val = MAX_INT
        return typed_val
    return int_typer

def _type_value(failure_counter=None):
    def value_typer(val):
        try:
            typed_val = float(val)
            if typed_val == np.inf:
                if failure_counter is not None:
                    _check_failures_exceeded(**failure_counter)
                msg = f"Failed to type '{val}' as a valid float."
                log.info(msg)
                typed_val = np.nan
        except:
            if failure_counter is not None:
                _check_failures_exceeded(**failure_counter)
            msg = f"Failed to type '{val}' as a valid float."
            log.info(msg)
            typed_val = np.nan
        return typed_val
    return value_typer

def _get_converters(types, columns, binning, failure_counter):
    # Build a converter for each selected column based on its declared type;
    # datetime-like columns are recorded so they can be parsed after the read.
    converters = {}
    datetime_cols = []
    timestamp_cols = []
    for col, col_type, col_bin in zip(columns, types, binning):        
        if col_type == 'int64':
            converters[col] = _type_int(failure_counter)
        elif col_type == 'float64':
            converters[col] = _type_float('float64', failure_counter)
        elif col_type == 'datetime':
            converters[col] = str
            datetime_cols.append(col)
        elif col_type == 'date':
            converters[col] = str
            datetime_cols.append(col)
        elif col_type == 'time':
            converters[col] = str
            datetime_cols.append(col)
        elif col_type == 'timestamp':
            converters[col] = _type_float('timestamp', failure_counter)
            timestamp_cols.append(col)
        else:
            converters[col] = str
            
    return converters, datetime_cols, timestamp_cols

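# Sketch of how the converter map produced by _get_converters() plugs into a
# Dask CSV read, mirroring the read_csv call earlier in this module. The file
# name and column layout below are hypothetical.
def _example_read_with_converters(path='flows.csv'):
    columns = ['src_ip', 'bytes', 'ts']
    types = ['ip', 'int64', 'timestamp']
    binning = ['none', 'log10', 'minute']
    converters, datetime_cols, timestamp_cols = _get_converters(
        types, columns, binning, None)
    df = dd.read_csv(path, usecols=columns, converters=converters)
    # Timestamp/datetime columns are parsed after the read, as in the readers above.
    for col in timestamp_cols:
        df[col] = dd.to_datetime(df[col], unit='s', errors='coerce')
    for col in datetime_cols:
        df[col] = dd.to_datetime(df[col], errors='coerce')
    return df
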
def _process_file_paths(filepaths):
    if isinstance(filepaths, str):
        filepaths = filepaths.split(' ')

    size = 0
    processed_filepaths = []
    for filepath in filepaths:
        if not glob(filepath):
            msg = f"Input: {filepath} does not exist or is not readable."
            log.error(msg)
            raise ValueError(msg)

        for input_file in sorted(glob(filepath)):
            processed_filepaths.append(input_file)
            size += os.path.getsize(input_file)

    return processed_filepaths, size < 1024**3

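# Illustrative call (hypothetical paths): a space-separated string or a list
# of globs both work, and the second return value flags whether the combined
# input is under 1 GiB, i.e. small enough to keep gen_backtrack enabled.
#
#     files, are_files_small = _process_file_paths('logs/*.csv extra.csv')
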
def df2tensor(dfs, dask_client=None, columns=None, types=None, binning=None,
              entries='count', sort=None, fuse_columns=None, joiner='__',
              gen_backtrack=False, gen_queries=False, failure_counter=None,
              verbose=False, in_memory=True):
    """ Variant of csv2tensor where in-memory DataFrames are passed instead
    of paths to files on disk. csv2tensor is preferred when possible.

    Parameters
    ----------
    dfs : list of dask.dataframe.DataFrame
        DataFrames to convert into a tensor. Each column must be typed as
        specified with the `types` argument. Any missing values must be
        bucketed as np.inf or np.nan in float64 columns, sys.maxsize in
        int64 columns, and pd.NaT in datetime/timestamp columns.
    dask_client : dask.distributed.Client
        Client object connected to a distributed Dask scheduler. If None,
        the default local threaded scheduler is used. Default: None.
    columns : list of str
        A list containing names of columns to be chosen for tensor
        construction. Default: None
    types : list of str
        The expected type of the columns list entry at the corresponding
        position. Options are: 'str', 'float64', 'int64', 'datetime',
        'timestamp', and 'ip'. Columns typed as 'datetime' or 'timestamp'
        will be sorted automatically if `sort` is None. Default: None
    binning : list of str
        The binning technique to use for the columns list entry at the
        corresponding position. Options are: 'none', 'binsize=<float>',
        'cyclic=<int:float>', 'log10', 'round=<int>',
        'ipv6_hextets=<num_hextets>:['MSB'|'LSB']',
        '[<ipv4_mask>+<ipv6_mask> | <ipv4_mask> | <ipv6_mask>]', 'second',
        'minute', 'hour', 'day', 'month', 'year', 'minute-of-hour',
        'hour-of-day', 'day-of-week', 'day-of-month', and 'month-of-year'.
        Default: 'none'
    sort : list of str
        List of column names to sort. Mode labels of these columns will be
        sorted when mapped to indices. Sorting columns can increase run
        time. Default: only 'timestamp' and 'datetime' columns are sorted.
    fuse_columns : list of list of str
        Lists of columns to fuse into single columns, e.g.
        [['col1', 'col2'], ['col3', 'col4']] would fuse col1 and col2 into
        a single column named col1__col2, and col3 and col4 into a column
        named col3__col4. Default: no columns are fused.
    joiner : str
        Delimiter separating the values in a fused column. Default: '__'
    entries : str
        Tensor entry calculation method. Legal values are 'count',
        'boolean', and 'value=<column_name>:<aggregation_method>'. Valid
        aggregation methods are 'sum', 'max', 'min', 'max_abs', 'min_abs',
        'first', 'last', 'mean', 'prod', 'idxmin', and 'idxmax'. Modes that
        are used as value columns will be typed as 'float64'.
        Default: 'count'
    gen_backtrack : bool
        If True, generate backtracking information from tensor to input
        files. This is a map from tensor entries to source lines in the
        original CSV file(s) and is helpful for pulling data associated
        with specific sets of entries. This method has known scalability
        limitations and the option is ignored for data over 1GB. For large
        files use `gen_queries` as an alternative. Default: False.
    gen_queries : bool
        If True, generate a map from each bin in each mode to a set of
        selection criteria that can be parsed to construct a query for
        finding original data lines. Scalable alternative to gen_backtrack.
        Default: False.
    failure_counter : dict
        Internal bookkeeping used to cap the number of values that fail to
        be typed (see `missing_vals_limit` in csv2tensor). Default: None.
    verbose : bool
        Verbose output. Default: False.
    in_memory : bool
        Whether df2tensor() is being called standalone (True) or from
        within csv2tensor() (False). Default: True.

    Returns
    -------
    tensor : ensign.sptensor.SPTensor
        The sparse tensor produced from the input DataFrame(s) and input
        parameters.
    """
    if verbose:
        _set_log_level('DEBUG')
    else:
        _set_log_level('WARNING')

    if not isinstance(dfs, list):
        dfs = [dfs]
    for df in dfs:
        if not isinstance(df, dd.DataFrame):
            msg = f"Invalid value for argument 'dfs': {df}. Only Dask DataFrames are accepted."
            log.error(msg)
            raise TypeError(msg)

    columns, types, binning, sort = _prepare_columns_types_binning_sort(
        dfs[0], columns, types, binning, sort, entries)

    log.info("Validating ...")
    for df in dfs:
        is_valid_config = _validate_config(df, columns, types, binning,
                                           entries, sort, fuse_columns, joiner)
        if not is_valid_config:
            msg = "Invalid configuration, could not convert dataframe to tensor"
            log.error(msg)
            raise ValueError(msg)

    log.info("Filtering ...")
    df, columns = _combine_and_filter_dfs(
        dfs, columns, entries, binning, sort, gen_backtrack, in_memory)
    log.info("Binning ...")
    df = _bin_columns(df, columns, binning, types, dask_client, failure_counter)
    log.info("Fusing ...")
    df, columns, binning, sort = _fuse_columns(
        df, fuse_columns, joiner, columns, binning, sort)
    log.info("Calculating tensor entries ...")
    df, spt_backtrack = _calc_entries(df, entries, columns, gen_backtrack,
                                      dask_client, failure_counter)
    log.info("Constructing labels ...")
    df, labels, queries = _get_labels(df, columns, types, binning, sort,
                                      joiner, gen_queries)
    log.info("Building sparse tensor ...")
    tensor = _build_sptensor(df, columns, labels, spt_backtrack, queries)

    if dask_client:
        dask_client.close()

    return tensor

def csv2tensor(filepaths, distributed=False, columns=None, binning=None,
               types=None, sort=None, entries='count', fuse_columns=None,
               joiner='__', delimiter=',', bro_log=False, header=None,
               validate_bro_log=True, gen_backtrack=False, gen_queries=False,
               drop_missing_values=False, missing_vals_limit=None,
               verbose=False):
    """ Creates a sparse tensor from one or more CSV files or Bro/Zeek logs.

    The columns of the DataFrame(s) will become the modes (dimensions) of
    the tensor. Choose these columns carefully; using 3-6 columns is
    recommended. This is specified with the `columns` argument.

    The set of indices of each mode corresponds to the unique set of values
    in the corresponding column. Therefore the values of each column need
    to be discretized. This is done with the `binning` argument. Each
    binning scheme requires the associated column to be of a particular
    type. Specify the types with the `types` argument.

    Parameters
    ----------
    filepaths : list of str
        Path(s) to input file(s). If multiple input files are specified,
        other options such as 'types' and 'binning' are applied identically
        to all files.
    distributed : bool or str
        Whether or not to use the Dask distributed scheduler. If True, the
        Dask scheduler address is assumed to be '127.0.0.1:8786'. If False,
        the local threaded scheduler is used. If str, it should contain a
        Dask scheduler address. Default: use the threaded scheduler. The
        distributed scheduler is strongly recommended.
    columns : list of str
        A list containing names of columns to be chosen for tensor
        construction. Default: None
    types : list of str
        The expected type of the columns list entry at the corresponding
        position. Options are: 'str', 'float64', 'int64', 'datetime',
        'date', 'time', 'timestamp', and 'ip'. Columns typed as 'datetime',
        'date', 'time', or 'timestamp' will be sorted automatically if
        `sort` is None. Default: None
    binning : list of str
        The binning technique to use for the columns list entry at the
        corresponding position. Options are: 'none', 'binsize=<float>',
        'cyclic=<int:float>', 'log10', 'round=<int>',
        'ipv6_hextets=<num_hextets>:['MSB'|'LSB']',
        '[<ipv4_mask>+<ipv6_mask> | <ipv4_mask> | <ipv6_mask>]', 'second',
        'minute', 'hour', 'day', 'month', 'year', 'minute-of-hour',
        'hour-of-day', 'day-of-week', 'day-of-month', and 'month-of-year'.
        Default: 'none'
    sort : list of str
        List of column names to sort. Mode labels of these columns will be
        sorted when mapped to indices. Sorting columns can increase run
        time. Default: only 'timestamp', 'datetime', 'date', and 'time'
        columns are sorted.
    fuse_columns : list of list of str
        Lists of columns to fuse into single columns, e.g.
        [['col1', 'col2'], ['col3', 'col4']] would fuse col1 and col2 into
        a single column named col1__col2, and col3 and col4 into a column
        named col3__col4. Default: no columns are fused.
    joiner : str
        Delimiter separating the values in a fused column. Default: '__'
    delimiter : str
        Delimiter separating the columns in the CSV file(s). Interpreted as
        a regular expression if longer than a single character.
        Default: ','
    entries : str
        Tensor entry calculation method. Legal values are 'count',
        'boolean', and 'value=<column_name>:<aggregation_method>'. Valid
        aggregation methods are 'sum', 'max', 'min', 'max_abs', 'min_abs',
        'first', 'last', 'mean', 'prod', 'idxmin', and 'idxmax'. Modes that
        are used as value columns will be typed as 'float64'.
        Default: 'count'
    bro_log : bool
        If True, treat input as a Bro/Zeek log.
        Default: False (treat input as a CSV)
    header : list of str
        For use with files that don't have headers. Specifies the names of
        the columns of the input file. Must have the same number of column
        names as there are columns in the input file. Default: None.
    validate_bro_log : bool
        If `bro_log` is True, validate that the input is indeed a Bro/Zeek
        log by checking the file header. Ignored when `bro_log` is False.
        Default: True.
    gen_backtrack : bool
        If True, generate backtracking information from tensor to input
        files. This is a map from tensor entries to source lines in the
        original CSV file(s) and is helpful for pulling data associated
        with specific sets of entries. This method has known scalability
        limitations and the option is ignored for data over 1GB. For large
        files use `gen_queries` as an alternative. Default: False.
    gen_queries : bool
        If True, generate a map from each bin in each mode to a set of
        selection criteria that can be parsed to construct a query for
        finding original data lines. Scalable alternative to gen_backtrack.
        Default: False.
    drop_missing_values : bool
        If True, drop rows where any entry fails to be typed as a float64,
        int64, date, time, datetime, or timestamp. Otherwise, bucket the
        missing/corrupted values as NaN or NaT. Default: bucket missing
        values.
    missing_vals_limit : int
        Cap on the number of values that may fail to be typed. Once this
        limit is reached, csv2tensor exits with an error. `None` means do
        not cap the number of failures. Default: do not limit.
    verbose : bool
        Verbose output. Default: False.

    Returns
    -------
    tensor : ensign.sptensor.SPTensor
        The sparse tensor produced from the CSV input file(s) and input
        parameters.
    """
    start = time.time()
    if verbose:
        _set_log_level('DEBUG')
    else:
        _set_log_level('WARNING')

    filepaths, are_files_small = _process_file_paths(filepaths)
    gen_backtrack = gen_backtrack and are_files_small

    if header is not None and columns is None:
        columns = header

    columns, types, binning, sort = _prepare_columns_types_binning_sort(
        filepaths[0], columns, types, binning, sort, entries, bro_log)
    _validate_types(types)

    if columns is not None:
        if not _is_list_of_str(columns):
            msg = f"'columns' argument invalid: not None or list of str: {columns}"
            log.error(msg)
            raise TypeError(msg)

    dask_client = None
    if distributed:
        if isinstance(distributed, str):
            try:
                dask_client = Client(distributed)
            except OSError:
                msg = f'There is no Dask scheduler running at the specified address: {distributed}. Check the Dask scheduler logs to find the proper address.'
                log.error(msg)
                raise
        else:
            dask_client = Client(f'127.0.0.1:{DASK_DEFAULT_PORT}')

    # Used to track the number of values that can't be typed. More details
    # about the implementation are in the definition of _get_failure_counter().
    failure_counter = _get_failure_counter(missing_vals_limit, dask_client)

    try:
        dfs = []
        for input_file in filepaths:
            log.info(f"Reading {input_file} ...")
            if bro_log:
                csv_df, entries = _ensign_read_bro_log(
                    input_file, columns, types, binning, header, entries,
                    validate_bro_log, drop_missing_values, failure_counter,
                    gen_backtrack)
            else:
                if not isinstance(delimiter, str):
                    raise TypeError('The column delimiter must be a string.')
                csv_df, entries = _ensign_read_csv(
                    input_file, columns, types, binning, header, delimiter,
                    entries, drop_missing_values, failure_counter,
                    gen_backtrack)
            log.info(f"Dataframe has {csv_df.npartitions} partitions")
            dfs.append(csv_df)
        log.info("Building tensor ...")
        tensor = df2tensor(dfs, dask_client, columns, types, binning, entries,
                           sort, fuse_columns, joiner, gen_backtrack,
                           gen_queries, failure_counter, verbose, False)
    finally:
        if dask_client is not None:
            dask_client.close()

    end = time.time()
    log.info(f'Built tensor in {end - start} seconds')

    if len(tensor.entries) == 0:
        log.warning('WARNING: Tensor is empty. Check types of columns and value.')

    return tensor
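
# Minimal usage sketch for csv2tensor(); the file path, column names, types,
# and binning choices below are hypothetical.
#
#     tensor = csv2tensor(
#         ['conn_sample.csv'],
#         columns=['ts', 'src_ip', 'dst_port'],
#         types=['timestamp', 'ip', 'int64'],
#         binning=['minute', 'none', 'none'],
#         entries='count')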