"""
Column-based dataset data manipulation tools.
"""

__version__ = "$Rev: 9633 $"

import pdb, sys
from types import FunctionType, TypeType, NoneType

def xlen(x):
    return xrange(len(x))

class DatasetIter(object):

    def __init__(self, dataset):
        self._dataset = dataset
        self._i = -1

    def __iter__(self):
        return self

    def next(self):
        self._i = self._i + 1
        if self._i >= len(self._dataset):
            raise StopIteration()
        return self._dataset[self._i]


class DatasetRow(object):
    """One row of a table, addressed by column name.

    A row works in one of two modes:

    * view mode (``_dataset`` given): reads are forwarded to
      ``_dataset[column][_index]``, so the row reflects the dataset's
      current contents;
    * standalone mode: *data* is an iterable of ``(column, value)`` pairs
      whose values are stored in the row.

    Like a dict, iterating a row yields its column names, ``in`` tests
    column names, and ``row[name]`` returns a value.
    """

    def __init__(self, data, _dataset=None, _index=None):
        # Compare with None explicitly: a Dataset defines __len__, so one
        # with zero rows is falsy but is still a valid backing dataset.
        if _dataset is not None:
            self._dataset = _dataset
            self._index = _index
        else:
            self._dataset = None
            # Materialise once so one-shot iterables (generators) work.
            pairs = list(data)
            self._columns = [c for (c, _) in pairs]
            self._col_ind = dict((c, i) for i, c in enumerate(self._columns))
            self._data = [d for (_, d) in pairs]

    def _column_names(self):
        """Return the column-name list this row exposes (not a copy)."""
        if self._dataset is not None:
            return self._dataset.columns
        return self._columns

    def __len__(self):
        """Return the number of columns."""
        return len(self._column_names())

    def __iter__(self):
        """Iterate over the column names."""
        return iter(self._column_names())

    def __contains__(self, item):
        """Return True if *item* is one of the column names."""
        return item in self._column_names()

    def __getitem__(self, c):
        """Return the value in column *c*.

        Raises KeyError for int keys (rows are addressed by name only).
        In standalone mode an unknown name raises KeyError; in view mode
        the lookup is delegated to the backing dataset.
        """
        if isinstance(c, int):
            raise KeyError("int keys may not be used on DatasetRows")
        if self._dataset is not None:
            return self._dataset[c][self._index]
        return self._data[self._col_ind[c]]

    def __repr__(self):
        return 'DatasetRow({' + ', '.join(repr(c) + ": " + repr(self[c])
                                          for c in self) + '})'


class Dataset(object):
    """A table of named columns.

    Values are stored in one of two layouts:

    * column layout -- ``self._data`` is a tuple holding one list of
      values per column, and ``self._row_data`` is None;
    * row layout -- ``self._row_data`` is a list of row tuples (values in
      column order), and ``self._data`` is None.

    Row layout is used while accumulating rows; most operations first
    call ``_columnize()`` to switch to column layout.  ``self.columns``
    lists the column names and ``self._col_ind`` maps each name to its
    position in that list.
    """

    def __init__(self, _columns=(), _data=None, _row_data=None, **kwargs):
        """Create a dataset.

        Accepted forms:

        * ``Dataset(columns)`` -- empty dataset with the given column names.
        * ``Dataset(columns, data)`` -- *data* is an iterable holding one
          sequence of values per column; the sequences are stored as-is,
          not copied.
        * ``Dataset(columns, _row_data=rows)`` -- *rows* is an iterable of
          row sequences in column order.
        * ``Dataset(rows)`` -- elements of the first argument that are
          dicts or DatasetRows are added as rows (see ``add_row``); any
          other element is taken as a column name.
        * ``Dataset(name=values, ...)`` -- keyword columns.  Columns of
          length 1 are repeated to match the longest column.

        Raises TypeError if keyword columns are combined with positional
        arguments, or if keyword columns have different lengths.
        """
        if (_columns or _data or _row_data) and kwargs:
            raise TypeError("illegal Dataset definition")

        if kwargs:
            _columns = list(kwargs.keys())
            _data = [list(kwargs[c]) for c in _columns]
            max_len = max(len(col) for col in _data)
            for i, col in enumerate(_data):
                if len(col) == 1:
                    # Broadcast a single value down the whole column.
                    _data[i] = col * max_len
                elif len(col) != max_len:
                    raise TypeError("not all columns are of equal length")

        # Start in (empty) row layout so rows in _columns can be appended.
        self.columns = []
        self._col_ind = {}
        self._data = None
        self._row_data = []

        for x in _columns:
            if isinstance(x, (DatasetRow, dict)):
                self.add_row(x)
            else:
                # Registers the name and pads any rows already added.
                self._append_empty_column(x)

        if self._row_data:
            # Rows were supplied directly; store them column-wise.
            self._columnize()
            return

        self._row_data = None
        if _data is not None:
            self._data = tuple(_data)
        elif _row_data is not None:
            self._row_data = [tuple(row) for row in _row_data]
        else:
            self._data = tuple([] for _ in self.columns)

    def __len__(self):
        """Return the number of rows."""
        if self._data:
            return len(self._data[0])
        if self._row_data:
            return len(self._row_data)
        return 0

    def __iter__(self):
        """Iterate over the rows as DatasetRow views."""
        return DatasetIter(self)

    def __contains__(self, item):
        """Return True if *item* is a column name of this dataset."""
        try:
            return item in self._col_ind
        except TypeError:
            # Unhashable objects can never be column names.
            return False

    def __getitem__(self, i):
        """Index the dataset.

        * column name (str/unicode) or None -> that column's value list
          (the stored list itself, not a copy);
        * slice -> a new Dataset holding the selected rows;
        * anything else (normally an int) -> a DatasetRow view of row *i*.
        """
        self._columnize()
        if isinstance(i, (basestring, NoneType)):
            return self._data[self._col_ind[i]]
        if isinstance(i, slice):
            return Dataset(self.columns, [col[i] for col in self._data])
        return DatasetRow(None, _dataset=self, _index=i)

    def add_row(self, r):
        """Append one row.

        *r* is either a sequence of values in column order, or a dict /
        DatasetRow keyed by column name.  For mappings, keys that are not
        yet columns become new columns (earlier rows get None there), and
        columns missing from the mapping get None.

        Raises TypeError if the row's width differs from the column count.
        """
        if isinstance(r, (DatasetRow, dict)):
            new_row = [None] * len(self.columns)
            for c in r:
                if c not in self._col_ind:
                    self._append_empty_column(c)
                    new_row.append(None)
                new_row[self._col_ind[c]] = r[c]
            r = new_row
        if len(r) != len(self.columns):
            raise TypeError('data width does not match dataset width')
        if self._data is not None:
            for col, value in zip(self._data, r):
                col.append(value)
        else:
            self._row_data.append(tuple(r))

    def _append_empty_column(self, c):
        """Register column *c* at the end, padding existing rows with None."""
        self.columns.append(c)
        self._col_ind[c] = len(self.columns) - 1
        if self._data is not None:
            self._data = self._data + ([None] * len(self),)
        else:
            self._row_data = [row + (None,) for row in self._row_data]

    def add_rows(self, data_rows):
        """Append several rows given as sequences in column order.

        *data_rows* may be any iterable; it is materialised first because
        column layout walks it once per column.  Raises TypeError, before
        anything is appended, if a row's width differs from the column
        count.
        """
        data_rows = [tuple(row) for row in data_rows]
        width = len(self.columns)
        for row in data_rows:
            if len(row) != width:
                raise TypeError('data width does not match dataset width')
        if self._data is not None:
            for i, col in enumerate(self._data):
                col.extend(row[i] for row in data_rows)
        else:
            self._row_data.extend(data_rows)

    def _columnize(self):
        """Switch to column layout; a no-op if already columnar."""
        if self._row_data is None:
            return
        rows = self._row_data
        self._data = tuple([row[i] for row in rows]
                           for i in xlen(self.columns))
        self._row_data = None

    def _take(self, indices):
        """Return a new Dataset with the rows at *indices*, in that order.

        Requires column layout (callers run _columnize() first).
        """
        return Dataset(self.columns,
                       [[col[i] for i in indices] for col in self._data])

    def filter(self, filter):
        """Return a new Dataset with the rows for which filter(row) is true.

        *filter* receives each row as a DatasetRow.
        """
        self._columnize()
        return self._take([i for i in xlen(self) if filter(self[i])])

    def filter_col(self, c, filter):
        """Return a new Dataset with the rows whose column *c* matches.

        *filter* may be any callable predicate (function, builtin, bound
        method, ...) applied to the value, a tuple of accepted values, or
        a single value compared with ``==``.
        """
        self._columnize()
        values = self._data[self._col_ind[c]]
        if callable(filter):
            matches = filter
        elif isinstance(filter, tuple):
            matches = lambda v: v in filter
        else:
            matches = lambda v: v == filter
        return self._take([i for i in xlen(self) if matches(values[i])])

    def cross_tab(self, key_columns, cross_column, value_column,
                  key_list=None,
                  cross_list=None,
                  value_default=None):
        """
        Cross tabulate a data table on some column.

        All column names given in key_columns will end up as unique
        keys (one row per value of the key_columns).

        All unique values of the cross_column will end up becoming the
        names of columns (after the key_columns).

        Each of the cross_column value columns will have the
        appropriate value from the value column.

        For proper use, key_columns+cross_column should be a unique
        key for the input dataset.  If it is not, then there will be
        multiple values for some key+cross combinations.  In that
        case, it is undefined which value will be taken.

        key_columns   -- a column name or a sequence of column names.
        key_list      -- optional iterable of key tuples to emit (one row
                         each); when empty/None, the distinct keys present
                         in the data are used.  Data rows whose key is not
                         listed are ignored.
        cross_list    -- optional list of cross values to use as columns;
                         defaults to the distinct values found.  New
                         column names are str() of each value.
        value_default -- filler for key+cross combinations with no data.
        """
        if isinstance(key_columns, basestring):
            key_columns = [key_columns]
        else:
            # Accept any sequence (e.g. a tuple); it is concatenated with
            # a list of labels below.
            key_columns = list(key_columns)

        self._columnize()

        cross_data = self._data[self._col_ind[cross_column]]
        key_data = [self._data[self._col_ind[k]] for k in key_columns]
        val_data = self._data[self._col_ind[value_column]]

        if cross_list is not None:
            labels = list(cross_list)
        else:
            labels = list(set(cross_data))
        label_pos = dict((label, i) for i, label in enumerate(labels))

        if key_list:
            key_values = key_list
        else:
            key_values = set(zip(*key_data))
        table = dict((k, [value_default] * len(labels)) for k in key_values)

        for i, value in enumerate(val_data):
            key = tuple(col[i] for col in key_data)
            if key in table and cross_data[i] in label_pos:
                table[key][label_pos[cross_data[i]]] = value

        result = Dataset(key_columns + [str(l) for l in labels],
                         _row_data=[k + tuple(cells)
                                    for k, cells in table.items()])
        result._columnize()
        return result

    def sort(self, cmp=None, key=None, reverse=False):
        """Return a new Dataset with the rows reordered.

        *cmp*, *key* and *reverse* work as for ``list.sort``, but *key*
        (or, without a key, *cmp*) receives DatasetRow objects.  When
        neither *cmp* nor *key* is given, rows are ordered by their
        values compared column by column, left to right.
        """
        self._columnize()
        if key is not None:
            row_key = lambda i: key(self[i])
        elif cmp is not None:
            row_key = lambda i: self[i]
        else:
            # DatasetRows have no ordering of their own; compare values.
            data = self._data
            row_key = lambda i: tuple(col[i] for col in data)
        order = list(xlen(self))
        order.sort(cmp=cmp, key=row_key, reverse=reverse)
        return self._take(order)

    def sort_col(self, c, cmp=None, key=None, reverse=False):
        """Return a new Dataset with the rows ordered by column *c*.

        *cmp*, *key* and *reverse* work as for ``list.sort`` and apply to
        the values of column *c*.
        """
        self._columnize()
        values = self._data[self._col_ind[c]]
        if key is None:
            row_key = lambda i: values[i]
        else:
            row_key = lambda i: key(values[i])
        order = list(xlen(self))
        order.sort(cmp=cmp, key=row_key, reverse=reverse)
        return self._take(order)

    def __add__(self, other):
        """Concatenate the rows of two datasets.

        The result has this dataset's columns followed by the columns
        only *other* has, each in original order; cells for columns a
        dataset lacks are None.  Raises TypeError when *other* is not a
        Dataset.
        """
        if not isinstance(other, Dataset):
            raise TypeError("unsupported operand type(s) for +: '%s' and '%s'"%
                            (type(self), type(other)))
        self._columnize()
        other._columnize()
        columns = []
        seen = set()
        for c in self.columns + other.columns:
            if c not in seen:
                seen.add(c)
                columns.append(c)
        n_self, n_other = len(self), len(other)
        data = []
        for c in columns:
            if c in self._col_ind:
                col = list(self._data[self._col_ind[c]])
            else:
                col = [None] * n_self
            if c in other._col_ind:
                col.extend(other._data[other._col_ind[c]])
            else:
                col.extend([None] * n_other)
            data.append(col)
        return Dataset(columns, data)

    def __repr__(self):
        self._columnize()
        return "Dataset(%s, ...) <%d rows>" % (repr(self.columns), len(self))

    def dump(self):
        """Return a string showing the column names and all column data."""
        self._columnize()
        return "Dataset(%s, %s)" % (repr(self.columns), repr(self._data))

    def sum_uniq(self, key_cols, sum_cols):
        """Group rows by *key_cols* and total *sum_cols* within each group.

        Returns a new Dataset with columns key_cols + sum_cols and one row
        per distinct key.  Totals start from the group's first row and
        accumulate with ``+=``.  Output row order follows dict iteration.
        """
        self._columnize()
        totals = {}
        for r in self:
            key = tuple(r[c] for c in key_cols)
            if key in totals:
                sums = totals[key]
                for i, c in enumerate(sum_cols):
                    sums[i] += r[c]
            else:
                totals[key] = [r[c] for c in sum_cols]
        return Dataset(tuple(key_cols) + tuple(sum_cols),
                       _row_data=[key + tuple(sums)
                                  for key, sums in totals.items()])

    def merge_uniq(self, key_cols, sum_cols, min_cols=(), max_cols=()):
        """Group rows by *key_cols*, aggregating other columns per group.

        Within each group *sum_cols* are totalled with ``+=``, *min_cols*
        keep their minimum and *max_cols* their maximum.  Returns a new
        Dataset with columns key_cols + sum_cols + min_cols + max_cols and
        one row per distinct key (order follows dict iteration).
        """
        self._columnize()
        sum_data = {}
        min_data = {}
        max_data = {}
        for r in self:
            key = tuple(r[c] for c in key_cols)
            if key not in sum_data:
                sum_data[key] = [r[c] for c in sum_cols]
                min_data[key] = [r[c] for c in min_cols]
                max_data[key] = [r[c] for c in max_cols]
                continue
            sums, mins, maxs = sum_data[key], min_data[key], max_data[key]
            for i, c in enumerate(sum_cols):
                sums[i] += r[c]
            for i, c in enumerate(min_cols):
                mins[i] = min(mins[i], r[c])
            for i, c in enumerate(max_cols):
                maxs[i] = max(maxs[i], r[c])
        return Dataset(tuple(key_cols) + tuple(sum_cols) +
                       tuple(min_cols) + tuple(max_cols),
                       _row_data=[(key + tuple(sum_data[key]) +
                                   tuple(min_data[key]) +
                                   tuple(max_data[key]))
                                  for key in sum_data])

    def eject(self):
        """Return ``(column_names, column_value_lists)``.

        Both outer lists are new, but the per-column value lists are the
        dataset's own lists.
        """
        self._columnize()
        return (list(self.columns), list(self._data))

    def map_dataset(self, column_mappers):
        """Return a new Dataset with some columns transformed.

        *column_mappers* is an iterable of ``(name, mapper)`` or
        ``(name, mapper, new_name)`` tuples applied in order.  A callable
        mapper is called on every value of the column; any other mapper
        is indexed (``mapper[value]``, e.g. a dict).  With *new_name* the
        column is also renamed, and later tuples see the new name.

        The result owns copies of every column, so later changes to it
        (e.g. add_row) never affect this dataset.

        Raises KeyError for an unknown column name.
        """
        self._columnize()
        columns = list(self.columns)
        data = [list(col) for col in self._data]

        def map_column(name, mapper, new_name=None):
            try:
                index = columns.index(name)
            except ValueError:
                raise KeyError("invalid column %s" % (name,))
            if callable(mapper):
                func = mapper
            else:
                func = lambda x: mapper[x]
            data[index] = [func(x) for x in data[index]]
            if new_name is not None:
                columns[index] = new_name

        for cm in column_mappers:
            map_column(*cm)
        return Dataset(columns, data)

    def __str__(self):
        """Render as a text table using printf-style column formats.

        Each column's format comes from its first value: "%d" for
        int/long, "%.2f" for float, "%s" otherwise.  Text columns are
        left-justified and numeric ones right-justified.  With no rows,
        only the space-joined column names are returned.
        """
        self._columnize()
        column_formats = []
        header_formats = []
        for name, values in zip(self.columns, self._data):
            if len(values) == 0:
                return ' '.join(self.columns)
            first = values[0]
            if isinstance(first, (int, long)):
                fmt = "%d"
            elif isinstance(first, float):
                fmt = "%.2f"
            else:
                fmt = "%s"
            # Values are wrapped in 1-tuples so that tuple values are
            # formatted as themselves rather than unpacked by `%`.
            width = len(str(name))
            for value in values:
                width = max(width, len(fmt % (value,)))
            if fmt[1] == 's':
                width = -width      # a negative width left-justifies
            column_formats.append(fmt[0] + str(width) + fmt[1:])
            header_formats.append('%' + str(width) + 's')

        lines = [' '.join(f % (name,)
                          for f, name in zip(header_formats, self.columns))]
        for i in xlen(self):
            lines.append(' '.join(f % (col[i],)
                                  for f, col in zip(column_formats,
                                                    self._data)))
        return '\n'.join(lines)

    def tostring(self):
        """Render the dataset as an aligned, space-separated text table.

        Each column is typed from its last non-None value: ints use "%d"
        and floats "%.2f" (both right-justified); anything else is shown
        with str() and left-justified.  None is shown as "None".  With no
        rows (or no columns) only the space-joined column names are
        returned.
        """
        def format_str(width, val):
            # Any non-string (None, numbers, tuples, ...) is rendered with
            # str() instead of failing.
            if not isinstance(val, basestring):
                val = str(val)
            return val.ljust(width)

        def format_int(width, val):
            text = "None" if val is None else "%d" % val
            return text.rjust(width)

        def format_float(width, val):
            text = "None" if val is None else "%.2f" % val
            return text.rjust(width)

        self._columnize()
        if not self._data or len(self._data[0]) == 0:
            return ' '.join(self.columns)

        # First pass: choose a formatter and a width for every column.
        col_formatters = []
        col_widths = []
        for name, values in zip(self.columns, self._data):
            sample = None
            for v in reversed(values):
                if v is not None:
                    sample = v
                    break
            if isinstance(sample, (int, long)):
                formatter = format_int
            elif isinstance(sample, float):
                formatter = format_float
            else:
                formatter = format_str
            col_formatters.append(formatter)
            # A width of 0 means "no padding", i.e. the natural length.
            width = max(len(formatter(0, v)) for v in values)
            col_widths.append(max(width, len(name)))

        # Second pass: header line, then one line per row.
        header = []
        for name, formatter, width in zip(self.columns, col_formatters,
                                          col_widths):
            if formatter is format_str:
                header.append(name.ljust(width))
            else:
                header.append(name.rjust(width))
        lines = [' '.join(header)]
        for i in xlen(self._data[0]):
            lines.append(' '.join(
                fmt(width, col[i])
                for fmt, width, col in zip(col_formatters, col_widths,
                                           self._data)))
        return "\n".join(lines)

        
def seq_index(a, x):
    """Return the first index of *x* in sequence *a*, or None if absent.

    Only lookup failures map to None: ValueError (x not in a) and
    TypeError (e.g. searching a str for a non-string).  Any other
    exception, including KeyboardInterrupt, propagates.
    """
    try:
        return a.index(x)
    except (ValueError, TypeError):
        return None

def seq_find(a, b):
    """Look up every item of *b* in sequence *a*.

    Returns a list parallel to *b* holding each item's first index in
    *a*, or None where seq_index cannot locate it.
    """
    found = []
    for item in b:
        found.append(seq_index(a, item))
    return found

def seq_set(a, ais, xs):
    """Scatter values into *a*: ``a[ais[i]] = xs[i]`` for each position i.

    Entries of *ais* that are None (e.g. items seq_find could not locate)
    are skipped.  *ais* may be any iterable.  An IndexError is raised if
    *xs* is shorter than needed for a non-None target.
    """
    for i, target in enumerate(ais):
        if target is not None:
            a[target] = xs[i]

# Names exported by ``from <this module> import *``.
__all__ = ['Dataset', 'seq_index', 'seq_find', 'seq_set']
