from __future__ import division

import sys
import pdb 
from numpy import arange
from math import ceil

from PIL import Image, ImageDraw, ImageFont, ImageColor

from rave.util import get_units

# Directory that get_font() resolves font file names against.  This is a
# hard-coded, user-specific path; adjust it for the local machine.
font_path = '/Users/pgroce/Library/Fonts'


def get_font(name):
    """Return the path of font file ``name`` inside ``font_path``."""
    return "%(dir)s/%(file)s" % {'dir': font_path, 'file': name}


class Cartesian(object):
    """Convert from Cartesian (data) coordinates to PIL pixel coordinates.

    The plot occupies a ``width`` x ``height`` pixel box whose upper-left
    corner is at (``offset_x``, ``offset_y``).  The data ranges ``xlim`` and
    ``ylim`` (default ``(0, width)`` and ``(0, height)``) are mapped onto
    that box less its margins.  The y axis is flipped because PIL's origin
    is the upper left.  Results are rounded up to whole pixels.

    offset_x and offset_y are expressed in PIL's coordinate system
    (IOW, (0,0) is in the upper left.)

    A degenerate limit (low == high) is treated as a range of extent 1, so
    a value equal to that limit maps to the low end of the axis (left for
    x, bottom for y) instead of dividing by zero.

    TODO: There's some off-by-one goodness when y=0, and possibly also
    when x=0.  The workaround is to always use a little margin top and
    bottom.
    """
    def __init__(self, width, height, xlim=None, ylim=None,
                 offset_x=0, offset_y=0,
                 margin_left=0, margin_right=0,
                 margin_top=0, margin_bottom=0):
        if not xlim:
            xlim = (0, width)
        if not ylim:
            ylim = (0, height)
        self.width = width
        self.height = height
        self.xlim = xlim
        self.ylim = ylim
        self.offset_x = offset_x
        self.offset_y = offset_y
        self.margin_top = margin_top
        self.margin_bottom = margin_bottom
        self.margin_left = margin_left
        self.margin_right = margin_right

    def __str__(self):
        props = (
            ('width', "%d" % self.width),
            ('height', "%d" % self.height),
            # tuple() so list-valued limits format too.
            ('x'     , "%.2f-%.2f" % tuple(self.xlim)),
            ('y'     , "%.2f-%.2f" % tuple(self.ylim)),
            ('x offset', '%d' % self.offset_x),
            ('y offset', '%d' % self.offset_y),
            ('top margin', '%d' % self.margin_top),
            ('bottom margin', '%d' % self.margin_bottom),
            ('left margin', '%d' % self.margin_left),
            ('right margin', '%d' % self.margin_right),
        )
        return "<Cartesian: %s>" % '; '.join(
            ("%s: %s" % (k, v)) for k, v in props
        )

    def _height(self):
        """Pixel height available for plotting (height minus margins)."""
        return self.height - (self.margin_top + self.margin_bottom)

    def _width(self):
        """Pixel width available for plotting (width minus margins)."""
        return self.width - (self.margin_left + self.margin_right)

    @staticmethod
    def _span(lim):
        """Extent of the (low, high) range ``lim``; 1 if it is degenerate."""
        span = lim[1] - lim[0]
        return span if span != 0 else 1

    def x(self, old_x):
        """Map data value ``old_x`` to an integer PIL x pixel coordinate."""
        # Fraction of the way along the x range.
        x_pct = (old_x - self.xlim[0]) / self._span(self.xlim)
        # int(): math.ceil returns a float on Python 2.
        return int(ceil(self._width() * x_pct +
                        self.margin_left + self.offset_x))

    def y(self, old_y):
        """Map data value ``old_y`` to an integer PIL y pixel coordinate."""
        # Fraction of the way up the y range.
        y_pct = (old_y - self.ylim[0]) / self._span(self.ylim)
        # Invert the fraction, because PIL's y axis grows downward from the
        # upper left rather than upward from the lower left.
        return int(ceil(self._height() * (1 - y_pct) +
                        self.margin_top + self.offset_y))

class HorizontalGridLayout(object):
    """Split a ``width`` x ``height`` area into a ``rows`` x ``columns`` grid.

    Cells are equally sized and numbered row by row: left to right, then
    top to bottom.  ``self.cells[i]`` is the (x, y) pixel origin (upper-left
    corner) of cell ``i``.
    """
    def __init__(self, width, height, rows, columns):
        self.cell_width = width / columns
        self.cell_height = height / rows
        self.rows = rows
        self.columns = columns
        # Compute each origin from its row/column index instead of adding
        # cell_width repeatedly.  Repeated float addition can land just short
        # of ``width`` (e.g. ten additions of 0.1 give 0.999...), and then the
        # layout fails to wrap to the next row.
        self.cells = [(col * self.cell_width, row * self.cell_height)
                      for row in range(rows)
                      for col in range(columns)]

    def offset_idx(self, idx):
        """Return the (x, y) pixel origin of cell number ``idx``."""
        return self.cells[idx]

    def offset_xy(self, x, y):
        """Return the (x, y) pixel origin of the cell at grid position (x, y).

        First cell is at (0,0).  Cells are stored row by row, so ``x``
        selects the row and ``y`` the column (index ``x * columns + y``).
        """
        return self.cells[x * self.columns + y]



def timeseries_core(im, coords, value_data, width,
                    start_color=None, end_color=None,
                    min_color=None, max_color=None):
    """Draw ``value_data`` onto ``im`` as a connected line plot.

    Samples are spaced evenly over ``[0, width]`` in data coordinates and
    mapped to pixels with ``coords`` (a Cartesian).  Consecutive samples are
    joined with #333333 line segments.  Each optional color, when given,
    draws a small filled dot (radius 2) on one point:

    * ``start_color`` marks the first sample.
    * ``end_color`` marks the last sample.
    * ``max_color`` marks the first occurrence of the largest value.
    * ``min_color`` marks the first occurrence of the smallest value.

    An empty ``value_data`` draws nothing.
    """
    if len(value_data) == 0:
        return

    def steps(start, end, num_steps):
        """Yield ``num_steps`` evenly spaced values from start to end
        inclusive (just ``end`` when num_steps is 1)."""
        if num_steps > 1:
            step = (end - start) / (num_steps - 1)
            for i in range(num_steps - 1):
                yield start + i * step
        yield end

    xs = [coords.x(x) for x in steps(0, width, len(value_data))]
    ys = [coords.y(y) for y in value_data]
    points = list(zip(xs, ys))

    draw = ImageDraw.Draw(im)

    def highlight(point, col):
        """Draw a small filled dot of color ``col`` centred on ``point``."""
        px, py = point
        radius = 2
        draw.ellipse((px - radius, py - radius, px + radius, py + radius),
                     outline=col, fill=col)

    def first_point_of(value):
        """Pixel position of the first sample equal to ``value``, or None."""
        return next((p for p, v in zip(points, value_data) if v == value),
                    None)

    for prev, curr in zip(points, points[1:]):
        draw.line([prev[0], prev[1], curr[0], curr[1]], fill='#333333')

    if start_color:
        highlight(points[0], start_color)
    if end_color:
        highlight(points[-1], end_color)
    # The extremum search covers every sample, including index 0, so a
    # maximum or minimum on the very first point is highlighted too.
    if max_color:
        hi = first_point_of(max(value_data))
        if hi is not None:
            highlight(hi, max_color)
    if min_color:
        lo = first_point_of(min(value_data))
        if lo is not None:
            highlight(lo, min_color)


class MetricFormatter(object):
    """Formats numbers into "metric" (i.e., powers of ten) units for display.

    ``num_fmt`` and ``unit_fmt`` are %-style format strings applied to a
    dict with the following keys:

    * ``num``: the value scaled into its unit.
    * ``short``: the short prefix, e.g. 'k'.
    * ``long``: the long prefix, e.g. 'Kilo'.

    ``unit_fmt`` may be None to suppress the unit column.  Text is measured
    with the TrueType font file ``font`` at ``size``.
    """
    # (scale, short prefix, long prefix), largest first.
    _PREFIXES = (
        (10 ** 12, 'T', 'Tera'),
        (10 ** 9, 'G', 'Giga'),
        (10 ** 6, 'M', 'Mega'),
        (10 ** 3, 'k', 'Kilo'),
    )

    def __init__(self, font, size, color, num_fmt="%d", unit_fmt="%(short)s"):
        self.font = ImageFont.truetype(font, size)
        self.color = color
        self.num_fmt = num_fmt
        self.unit_fmt = unit_fmt

    def _get_units(self, number):
        """Return the format dict (num/short/long) for ``number``.

        The prefix is chosen from the magnitude, so negative numbers are
        scaled as well (e.g. -2500 -> num -2.5, short 'k').
        """
        magnitude = abs(number)
        for scale, short, long_name in self._PREFIXES:
            if magnitude >= scale:
                return {'num': number / scale, 'short': short,
                        'long': long_name}
        return {'num': number, 'short': '', 'long': ''}

    def format_number(self, number):
        """Numeric part of ``number`` rendered with ``num_fmt``."""
        return self.num_fmt % self._get_units(number)

    def format_unit(self, number):
        """Unit part of ``number`` rendered with ``unit_fmt`` ('' if None)."""
        if self.unit_fmt is None:
            return ''
        return self.unit_fmt % self._get_units(number)

    def width_of(self, txt):
        """Pixel width of ``txt`` in this font; 0 for None."""
        if txt is None:
            return 0
        return self.font.getsize(txt)[0]

    def height_of(self, txt):
        """Pixel height of ``txt`` in this font; 0 for None."""
        if txt is None:
            return 0
        return self.font.getsize(txt)[1]

    def number_max_width(self):
        """Width reserved for the number column."""
        # Arbitrarily measure 10**14 (formatted as 100T) as a representative
        # widest number.
        return self.width_of(self.format_number(10**14))

    def number_height(self):
        """Height of the number column's text."""
        return self.height_of(self.format_number(10**14))

    def unit_max_width(self):
        """Width reserved for the unit column; 0 when units are disabled."""
        if self.unit_fmt is None:
            return 0
        return self.width_of(self.format_unit(10**14))

    def unit_height(self):
        """Height of the unit column's text; 0 when units are disabled."""
        # Same None handling as the other unit methods; formatting None
        # would raise TypeError.
        if self.unit_fmt is None:
            return 0
        return self.height_of(self.unit_fmt %
                              {'num': 999.99, 'short': 'M', 'long': 'Mega'})

    def total_width(self, extra=None):
        """Width of number + unit columns plus the optional ``extra`` text."""
        return (self.number_max_width() +
                self.unit_max_width()   +
                self.width_of(extra))

    def max_height(self, extra=None):
        """Tallest of the number, unit and optional ``extra`` text."""
        return max(self.number_height(),
                   self.unit_height(),
                   self.height_of(extra))


def render_timeseries_primitive(out_file_name, value_data,
                                width, height, bgcolor="#ffebcd",
                                start_color=None, end_color=None,
                                max_color=None, min_color=None,
                                max_fmt=None, min_fmt=None):
    """Render a line plot of ``value_data`` to a PNG file.

    ``value_data`` is assumed to be a regular (i.e., no gaps) time series.
    The plot is drawn at twice the requested size and downsampled with
    antialiasing.  When ``min_fmt`` and/or ``max_fmt`` (formatter objects
    such as MetricFormatter) are given, the series minimum and/or maximum
    are written as text labels to the right of the plot, and the plot is
    narrowed to make room for them.  Returns ``out_file_name``.
    """
    scale = 2        # oversampling factor for antialiasing
    gutter = 10      # gap between the min and max labels, in final pixels
    big_width = width * scale
    big_height = height * scale

    # Reserve horizontal room for the labels.  Formatter measurements are
    # in final pixels, so they are doubled while still oversampled.
    plot_width = big_width - 4
    if min_fmt is not None:
        plot_width -= min_fmt.total_width() * scale
    if max_fmt is not None:
        if min_fmt is not None:
            plot_width -= gutter * scale
        plot_width -= max_fmt.total_width() * scale
    plot_height = big_height

    canvas = Image.new('RGB', (big_width, big_height),
                       ImageColor.getrgb(bgcolor))
    coords = Cartesian(plot_width, plot_height,
                       ylim=(min(value_data), max(value_data)),
                       margin_left=6, margin_right=6,
                       margin_top=6, margin_bottom=6)
    timeseries_core(canvas, coords, value_data, plot_width,
                    start_color, end_color, min_color, max_color)

    # From here on we draw on the downsampled image, so distances derived
    # from the oversampled canvas are halved.
    canvas = canvas.resize((width, height), Image.ANTIALIAS)

    if min_fmt is not None or max_fmt is not None:
        draw = ImageDraw.Draw(canvas)
        text_left = plot_width / 2 + 2    # just past the plot
        baseline = plot_height / 2 - 6    # 6px above the bottom edge

        def put_label(left, value, fmt):
            """Draw ``value`` as right-justified number and unit columns
            starting at ``left``; return the x just past both columns."""
            top = baseline - fmt.max_height()
            number = fmt.format_number(value)
            left += fmt.number_max_width() - fmt.width_of(number)
            draw.text((left, top), number, font=fmt.font, fill=fmt.color)
            left += fmt.width_of(number)
            unit = fmt.format_unit(value)
            left += fmt.unit_max_width() - fmt.width_of(unit)
            draw.text((left, top), unit, font=fmt.font, fill=fmt.color)
            return left + fmt.width_of(unit)

        cursor = text_left
        if min_fmt is not None:
            cursor = put_label(cursor, min(value_data), min_fmt)
        if max_fmt is not None:
            if min_fmt is not None:
                cursor += gutter
            cursor = put_label(cursor, max(value_data), max_fmt)

    canvas.save(out_file_name, 'PNG')
    return out_file_name

        
def histogram_core(im, coords, value_data, width,
                   fillcolor="#333333", strokecolor="#333333",
                   baselinecolor="#666666"):
    """Draw ``value_data`` as adjacent vertical bars onto ``im``.

    The bars evenly divide ``[0, width]`` in data coordinates, are mapped
    to pixels through ``coords``, and run from y=0 to each value.
    Zero-valued bars are skipped.  A horizontal baseline at y=0 spans
    ``coords.xlim``.  An empty ``value_data`` draws nothing.
    """
    if len(value_data) == 0:
        return  # nothing to draw, and bar_width would divide by zero
    bar_width = width / len(value_data)
    draw = ImageDraw.Draw(im)
    draw.line([(coords.x(coords.xlim[0]), (coords.y(0))),
               (coords.x(coords.xlim[1]), (coords.y(0)))], fill=baselinecolor)
    curr_x = 0
    for val in value_data:
        if val != 0:
            verts = ((curr_x, 0),
                     (curr_x, val),
                     (curr_x + bar_width, val),
                     (curr_x + bar_width, 0))
            draw.polygon([(coords.x(x), coords.y(y)) for x, y in verts],
                         outline=strokecolor, fill=fillcolor)
        curr_x += bar_width

def render_histogram_primitive(out_file_name, value_data,
                               width, height, bgcolor="#ffebcd",
                               fillcolor="#333333", strokecolor="#333333",
                               baselinecolor="#666666"):
    """Render ``value_data`` as a bar chart and save it as a PNG.

    Values is assumed to be a presorted sequence of (>=0) counts.  The y
    axis runs from 0 to the largest count.  The chart is drawn at 2x and
    downsampled with antialiasing.  Returns ``out_file_name``.
    """
    big_w, big_h = width * 2, height * 2

    coords = Cartesian(big_w, big_h,
                       ylim=(0, max(value_data)),
                       margin_left=6, margin_right=6,
                       margin_top=6, margin_bottom=6)
    canvas = Image.new('RGB', (big_w, big_h), ImageColor.getrgb(bgcolor))

    histogram_core(canvas, coords, value_data, big_w,
                   fillcolor, strokecolor, baselinecolor)

    canvas = canvas.resize((width, height), Image.ANTIALIAS)
    canvas.save(out_file_name, 'PNG')
    return out_file_name


def scatter_core(im, coords, x_data, y_data,
                 marker_radius=1, markercolor="#333333"):
    """Plot the points ``zip(x_data, y_data)`` onto ``im``.

    Each point is mapped to pixels through ``coords``.  When
    ``marker_radius`` is greater than 1, every point is drawn as a filled
    circle in ``markercolor``.  Otherwise each point is a single pixel
    written straight into the image's pixel-access object (the fast path).
    """
    rgb = ImageColor.getrgb(markercolor)
    pixel_buf = im.load()

    if marker_radius > 1:
        # A circle centred on a whole pixel can't hit marker_radius exactly.
        # The bounding box simply extends marker_radius pixels on each side
        # of the centre pixel.
        draw = ImageDraw.Draw(im)
        r = marker_radius
        for data_x, data_y in zip(x_data, y_data):
            cx = coords.x(data_x)
            cy = coords.y(data_y)
            draw.ellipse(((cx - r, cy - r), (cx + r, cy + r)),
                         outline=markercolor, fill=markercolor)
        return

    # Fast path: one pixel per point.
    for data_x, data_y in zip(x_data, y_data):
        pixel_buf[coords.x(data_x), coords.y(data_y)] = rgb


def render_scatter_primitive(out_file_name, x_data, y_data,
                             width, height, bgcolor="#ffebcd",
                             marker_radius=1, markercolor="#333333"):
    """Render a scatterplot of ``x_data`` against ``y_data`` to a PNG file.

    Each axis range is the data's min..max.  The plot is drawn at 2x and
    downsampled with antialiasing.  Returns ``out_file_name``.
    """
    big_w = width * 2
    big_h = height * 2

    coords = Cartesian(big_w, big_h,
                       xlim=(min(x_data), max(x_data)),
                       ylim=(min(y_data), max(y_data)),
                       margin_left=6, margin_right=6,
                       margin_top=6, margin_bottom=6)
    canvas = Image.new('RGB', (big_w, big_h), ImageColor.getrgb(bgcolor))

    scatter_core(canvas, coords, x_data, y_data, marker_radius, markercolor)

    canvas = canvas.resize((width, height), Image.ANTIALIAS)
    canvas.save(out_file_name, 'PNG')
    return out_file_name



def render_grid_test(out_file_name, x_vals, y_vals,
                     width, height, bgcolor="#ffebcd"):
    """Layout test: draw one scatterplot four times in a 2x2 grid.

    The same ``x_vals``/``y_vals`` points are plotted into every cell of a
    HorizontalGridLayout.  The image is drawn at 2x, downsampled with
    antialiasing, and saved as a PNG.  Returns ``out_file_name``.
    """
    big_w = width * 2
    big_h = height * 2

    canvas = Image.new('RGB', (big_w, big_h), ImageColor.getrgb(bgcolor))
    layout = HorizontalGridLayout(big_w, big_h, 2, 2)

    for cell in range(2 * 2):
        left, top = layout.offset_idx(cell)
        cell_coords = Cartesian(layout.cell_width, layout.cell_height,
                                offset_x=left, offset_y=top,
                                xlim=(min(x_vals), max(x_vals)),
                                ylim=(min(y_vals), max(y_vals)),
                                margin_left=6, margin_right=6,
                                margin_top=6, margin_bottom=6)
        scatter_core(canvas, cell_coords, x_vals, y_vals)

    canvas = canvas.resize((width, height), Image.ANTIALIAS)
    canvas.save(out_file_name, 'PNG')
    return out_file_name
        
def render_splom(out_file_name, all_data, width, height,
                 bgcolor="#ffebcd", markercolor="#333333", marker_radius=1,
                 fillcolor="#333333", strokecolor="#333333"):
    """Render a scatterplot matrix of ``all_data`` to a PNG file.

    ``all_data`` is a sequence of N equally sized data series.  The image is
    split into an N x N HorizontalGridLayout.  The cell returned by
    ``layout.offset_xy(i, j)`` plots series i (x axis) against series j
    (y axis).  Diagonal cells (i == j) instead draw series i as a bar chart
    via ``histogram_core``, using ``fillcolor``/``strokecolor``.  Scatter
    cells use ``marker_radius``/``markercolor``.  The image is drawn at 2x
    and downsampled with antialiasing.  Returns ``out_file_name``.
    """
    # Oversample the image for antialiasing later
    big_w = width * 2
    big_h = height * 2

    background_color = ImageColor.getrgb(bgcolor)
    im = Image.new('RGB', (big_w, big_h), background_color)
    n = len(all_data)
    layout = HorizontalGridLayout(big_w, big_h, n, n)

    for i in range(n):
        x_data = all_data[i]
        for j in range(n):
            y_data = all_data[j]
            off_x, off_y = layout.offset_xy(i, j)
            if i == j:
                # Bars rise from y=0, so the y range must include 0.
                # Otherwise the baseline and bars are drawn outside this
                # cell (into its neighbours or off the image).
                ylim = (min(0, min(y_data)), max(0, max(y_data)))
                coords = Cartesian(layout.cell_width, layout.cell_height,
                                   offset_x=off_x, offset_y=off_y,
                                   ylim=ylim,
                                   margin_left=6, margin_right=6,
                                   margin_top=6, margin_bottom=6)
                histogram_core(im, coords, y_data, layout.cell_width,
                               fillcolor=fillcolor, strokecolor=strokecolor)
            else:
                coords = Cartesian(layout.cell_width, layout.cell_height,
                                   offset_x=off_x, offset_y=off_y,
                                   xlim=(min(x_data), max(x_data)),
                                   ylim=(min(y_data), max(y_data)),
                                   margin_left=6, margin_right=6,
                                   margin_top=6, margin_bottom=6)
                scatter_core(im, coords, x_data, y_data,
                             marker_radius=marker_radius,
                             markercolor=markercolor)

    im = im.resize((width, height), Image.ANTIALIAS)
    im.save(out_file_name, 'PNG')
    return out_file_name
            
        
def render_font_test(out_file_name):
    """Render a sample string in VeraMono at size 9 and save it as a PNG.

    The text is drawn in black at the top-left of a 100x100 white canvas,
    for eyeballing font rendering.  Returns ``out_file_name``.
    """
    background_color = ImageColor.getrgb('#ffffff')
    # Resolve the font through get_font() so the font directory is
    # configured in one place (font_path) rather than hard-coded here.
    font = ImageFont.truetype(get_font('VeraMono.ttf'), 9)
    im = Image.new('RGB', (100, 100), background_color)

    draw = ImageDraw.Draw(im)
    draw.text((0, 0), "phil 12 / 34", font=font, fill="#000000")
    im.save(out_file_name, 'PNG')
    return out_file_name
 
    
                    
    
    
    
    

import pdb
if __name__ == '__main__':
    # Demo: render sample charts into the current directory.
    import math
    import random
    W = 600
    H = 100
    w = 120
    h = 20

    # time series: sin(x) for x in [0, 8) with min/max labels
    vals = [math.sin(x) for x in arange(0, 8, .01).tolist()]
    min_fmt = MetricFormatter(get_font('VeraMono.ttf'), 10, "#0000aa",
                              "%(num).2f", "%(short)s")
    max_fmt = MetricFormatter(get_font('VeraMono.ttf'), 10, "#00aa00",
                              "%(num).2f", "%(short)s")
    render_timeseries_primitive('timeseries.png', vals, W, h,
                                start_color="#aa0000", end_color="#aa0000",
                                max_fmt=max_fmt, min_fmt=min_fmt)

    # histogram data: how often each sum of ten random draws occurs.
    # randint's upper bound is inclusive, so drawing from 1..100 keeps every
    # sum within 10..1000 and thus inside the 1..1000 buckets (1..101 could
    # reach 1010 and raise KeyError).
    vals = dict((x, 0) for x in range(1, 1001))
    for _ in range(10000):
        samp = sum(random.randint(1, 100) for _ in range(10))
        vals[samp] += 1
    # render_histogram_primitive('histogram.png', vals.values(), w, h)

    # scatterplot
    # render_scatter_primitive('scatterplot.png', vals.keys(), vals.values(), w, h)

    # grid test
    # render_grid_test('grid_test.png', vals.keys(), vals.values(), W, H)

    # splom
    render_splom('splom.png', (list(vals.keys()), list(vals.values())), W, H)

    # font test
    render_font_test('font.png')
