"""
Classes used to support threading in RAVE.
"""
import pdb
import bisect
import copy
import threading
import time

import rave.log as rlog

# Module-wide logger factory. Code in this file calls get_log(name) to obtain
# a logger for that name (e.g. "lock", "unlock", "Worker.run.<thread>").
# NOTE(review): presumably names are scoped under "org.cert.rave.threads" --
# confirm against rave.log.log_factory.
get_log = rlog.log_factory("org.cert.rave.threads")


class ShutdownRequest(Exception):
    """
    Signals that the application is shutting down and the current
    thread should abandon what it is doing.
    """

class Barrier(object):
    """
    Thread barrier lock.

    This class functions as a reverse semaphore. increment() and
    decrement() adjust a counter; wait() blocks until the counter reaches
    zero or the timeout expires, whichever comes first. Waiters are
    notifyAll()ed by the decrement() call that brings the count to zero.
    """
    def __init__(self):
        self.count = 0
        # Guards self.count and wakes threads blocked in wait().
        self.cond = threading.Condition()

    def increment(self):
        """Add one to the count."""
        get_log("lock").debug("%s increment self.cond [%s]", self,
                              threading.currentThread())
        self.cond.acquire()
        try:
            self.count += 1
        finally:
            get_log("unlock").debug("%s increment self.cond [%s]", self,
                                    threading.currentThread())
            self.cond.release()

    def decrement(self):
        """
        Subtract one from the count and wake every waiter if the count
        has reached zero.
        """
        get_log("lock").debug("%s decrement self.cond [%s]", self,
                              threading.currentThread())
        self.cond.acquire()
        try:
            self.count -= 1
            if self.count == 0:
                # Without this nothing ever wakes a thread blocked in wait().
                self.cond.notifyAll()
        finally:
            get_log("unlock").debug("%s decrement self.cond [%s]", self,
                                    threading.currentThread())
            self.cond.release()

    def wait(self, timeout=0):
        """
        Block until the count is zero or `timeout` seconds have passed.

        timeout -- seconds to wait; None waits indefinitely. The default
                   of 0 never blocks: the call returns immediately
                   whether or not the count is zero.
        """
        get_log("lock").debug("%s wait [%s]", self,
                              threading.currentThread())
        self.cond.acquire()
        try:
            deadline = None
            if timeout is not None:
                deadline = time.time() + timeout
            # Re-check the count after every wake-up so that an unrelated
            # notify cannot end the wait early.
            while self.count != 0:
                remaining = None
                if deadline is not None:
                    remaining = deadline - time.time()
                    if remaining <= 0:
                        break
                get_log("unlock").debug("%s wait(wait) [%s]", self,
                                        threading.currentThread())
                self.cond.wait(remaining)
                get_log("lock").debug("%s wait(wait) [%s]", self,
                                      threading.currentThread())
        finally:
            get_log("unlock").debug("%s wait [%s]", self,
                                    threading.currentThread())
            self.cond.release()

class Worker(threading.Thread):
    """
    A thread that runs one assigned callable at a time.

    assign() hands the worker a job; stop() asks it to exit and blocks
    until it has. The worker holds self.has_work for its whole life
    except while it is waiting for work, so assign() and stop() block
    while a job is running.
    """
    def __init__(self):
        super(Worker, self).__init__()
        # The callable to run next, or None when idle.
        self.work = None
        # Guards self.work / self.should_stop and signals changes to them.
        self.has_work = threading.Condition()
        self.should_stop = False
        # Set once run() has finished, whatever the reason.
        self.evt_stopped = threading.Event()

    def run(self):
        """
        Work loop: wait for a job, run it, repeat until stop() is called.

        Errors raised by a job are logged and do not end the loop.
        """
        log = get_log("Worker.run.%s" % self.getName())
        try:
            if self.should_stop:
                log.debug("Will not start: shutdown has been called")
                return
            log.debug("Worker starting")
            get_log("lock").debug("%s run [%s]", self,
                                  threading.currentThread())
            self.has_work.acquire()
            try:
                log.debug("Entering work loop")
                while True:
                    # Test the predicate before waiting: a stop() or
                    # assign() that ran before we took the lock has already
                    # sent its notify, so waiting first would hang.
                    while self.work is None and not self.should_stop:
                        get_log("unlock").debug("%s run(wait) [%s]", self,
                                                threading.currentThread())
                        self.has_work.wait()
                        get_log("lock").debug("%s run(wait) [%s]", self,
                                              threading.currentThread())
                    if self.should_stop:
                        break
                    try:
                        try:
                            self.work()
                        except:
                            # Deliberately broad: a failing job must not
                            # take the worker thread down with it.
                            log.exception("Uncaught error in job")
                    finally:
                        self.work = None

                log.debug("Left work loop. End of life....")
            finally:
                get_log("unlock").debug("%s run [%s]", self,
                                        threading.currentThread())
                self.has_work.release()
        finally:
            # Always signal end of life, including the early return and
            # unexpected errors; stop() blocks on this event.
            self.evt_stopped.set()

    def assign(self, job):
        """
        Give the worker a callable to run and wake it up.

        Blocks while the worker is busy with a previous job.
        """
        get_log("lock").debug("%s assign [%s]", self,
                              threading.currentThread())
        self.has_work.acquire()
        try:
            self.work = job
            get_log("unlock").debug("%s assign(notify) [%s]", self,
                                    threading.currentThread())
            self.has_work.notify()
            get_log("lock").debug("%s assign(notify) [%s]", self,
                                  threading.currentThread())
        finally:
            get_log("unlock").debug("%s assign [%s]", self,
                                    threading.currentThread())
            self.has_work.release()

    def stop(self, barrier):
        """
        Ask the worker to exit and block until it has done so.

        barrier -- object whose decrement() is called once the worker has
                   stopped; may be None.
        """
        log = get_log("Worker.stop")
        log.debug("Stopping %s", self)
        get_log("lock").debug("%s stop [%s]", self,
                              threading.currentThread())
        self.has_work.acquire()
        log.debug("Lock acquired for %s", self)
        try:
            log.debug("setting self.should_stop to True")
            self.should_stop = True
            log.debug("notify()ing self.has_work")
            get_log("unlock").debug("%s stop(notifyAll) [%s]", self,
                                    threading.currentThread())
            self.has_work.notifyAll()
            get_log("lock").debug("%s stop(notifyAll) [%s]", self,
                                  threading.currentThread())
        finally:
            get_log("unlock").debug("%s stop [%s]", self,
                                    threading.currentThread())
            self.has_work.release()

        log.debug("%s: Waiting for evt_stopped", self)
        self.evt_stopped.wait()
        if barrier is not None:
            # Report completion so the caller's barrier count can return
            # to zero.
            barrier.decrement()


class Job(object):
    """
    A callable scheduled to run at a given time.

    func   -- zero-argument callable invoked when the job is called.
    run_at -- time at which the job should run, as a number comparable
              with time.time().

    Jobs order by run_at, earliest first.
    """
    def __init__(self, func, run_at):
        self.func = func
        self.run_at = run_at

    def __call__(self):
        """Run the wrapped callable; its return value is discarded."""
        self.func()

    def __cmp__(self, other):
        """Three-way comparison on run_at (-1, 0 or 1)."""
        if self.run_at < other.run_at:
            return -1
        if self.run_at > other.run_at:
            return 1
        return 0

    def __lt__(self, other):
        """
        Rich comparison on run_at, so that ordering also works where
        __cmp__ is not consulted.
        """
        return self.run_at < other.run_at

    def __str__(self):
        # Not every callable has a name (e.g. callable instances); fall
        # back to the callable's own string form in that case.
        fname = getattr(self.func, "__name__", None)
        if fname is None:
            fname = str(self.func)
        return "<Job '%s', runs at %f>" % (fname, self.run_at)

    def next_job(self):
        """
        Get job that should follow up on this one, or None if
        such a job does not exist.
        """
        return None


class RecurrentJob(Job):
    """
    A Job that reschedules itself: each instance is due `interval`
    seconds after it was created, and next_job() builds the following
    occurrence.
    """
    def __init__(self, func, interval):
        due = time.time() + interval
        super(RecurrentJob, self).__init__(func, due)
        self.interval = interval

    def next_job(self):
        """Return a new RecurrentJob for the same callable and interval."""
        successor = RecurrentJob(self.func, self.interval)
        log = get_log('RecurrentJob.next_job')
        log.debug("Returning %s", successor)
        return successor


class PriorityQueue(object):
    """
    Minimal priority queue backed by a list kept in sorted order, so
    the smallest item is always at the front. Not thread-safe.
    """
    def __init__(self):
        self.items = []

    def push(self, item):
        """Insert item at its sorted position, after any equal items."""
        bisect.insort_right(self.items, item)

    def pop(self):
        """Remove and return the smallest item; IndexError when empty."""
        smallest = self.items.pop(0)
        return smallest

    def peek(self):
        """Return the smallest item, leaving it queued; IndexError when empty."""
        smallest = self.items[0]
        return smallest

class JobQueue(object):
    """
    Time-ordered queue of jobs with a blocking get().

    Producers call add(); a single consumer thread calls get(), which
    returns the earliest job once its run_at time has arrived.
    signal_stop() makes get() raise ShutdownRequest.
    """
    def __init__(self):
        self.queue = PriorityQueue()
    #   Guards the queue and the flags below; notified by add() and
    #   signal_stop()
        self.job_awaits = threading.Condition()
    #   Set by add() when a new job has been added to the queue; cleared
    #   by get() each time it re-examines the queue. get() does not rely
    #   on it: it re-reads the queue itself after every wake-up.
        self.job_added = False
        self.should_stop = False

    def _get_free_time(self):
        """
        Get time interval between now and next scheduled job. If there is
        no job currently scheduled, return None.

        The result is None (nothing queued), a positive float (seconds
        until the earliest job is due) or zero/negative (the earliest
        job is already due).
        """
        try:
            return self.queue.peek().run_at - time.time()
        except IndexError:
            return None

    def get(self):
        """
        Get a job that needs to be run from queue, blocking until one
        is available.

        The only thread ever calling get for a given instance is the
        one running in the Scheduler object that owns this JobQueue
        instance.

        Raises ShutdownRequest if signal_stop() has been called.
        """
        log = get_log("JobQueue.get")
        get_log("lock").debug("%s get [%s]", self,
                              threading.currentThread())
        self.job_awaits.acquire()
        try:
            while True:
            #   Check under the lock, before every wait: a stop request
            #   that arrived while the caller was busy elsewhere has
            #   already sent its notify and would otherwise be missed.
                if self.should_stop:
                    log.debug("shutdown request. shutting down")
                    raise ShutdownRequest()
                self.job_added = False
            #   Always decide from the queue itself rather than from why
            #   wait() returned, so an early or unrelated wake-up can
            #   never pop a job too soon or pop from an empty queue.
                free_time = self._get_free_time()
                if free_time is not None and free_time <= 0:
                    log.debug("time to return the job")
                    return self.queue.pop()
                if free_time is None:
                    log.debug("Waiting forever for job to need doing")
                else:
                    log.debug("Waiting %f for job to need doing", free_time)

                get_log("unlock").debug("%s get(wait) [%s]", self,
                                        threading.currentThread())
                self.job_awaits.wait(free_time)
                get_log("lock").debug("%s get(wait) [%s]", self,
                                      threading.currentThread())
                log.debug("done waiting")
        finally:
            get_log("unlock").debug("%s get [%s]", self,
                                    threading.currentThread())
            self.job_awaits.release()

    def add(self, job):
        """
        Add a job to the queue and wake the thread blocked in get(),
        if any, so it can recompute how long to wait.
        """
        get_log('JobQueue.add').debug("Adding %s to job queue", job)
        get_log('lock').debug("%s add [%s]", self,
                              threading.currentThread())
        self.job_awaits.acquire()
        try:
            self.queue.push(job)
            self.job_added = True
            get_log('unlock').debug("%s add(notify) [%s]", self,
                                    threading.currentThread())
            self.job_awaits.notify()
            get_log('lock').debug("%s add(notify) [%s]", self,
                                  threading.currentThread())
        finally:
            get_log('unlock').debug("%s add [%s]", self,
                                    threading.currentThread())
            self.job_awaits.release()

    def signal_stop(self):
        """
        Request that get's calling thread shut down as soon as possible.
        (Right now if it's in get, or on the next call to get otherwise.)
        """
        get_log('lock').debug("%s signal_stop [%s]", self,
                              threading.currentThread())
        self.job_awaits.acquire()
        try:
            self.should_stop = True
            get_log('unlock').debug("%s signal_stop(notify) [%s]", self,
                                    threading.currentThread())
        #   notifyAll so that shutdown reaches every waiter
            self.job_awaits.notifyAll()
            get_log('lock').debug("%s signal_stop(notify) [%s]", self,
                                  threading.currentThread())
        finally:
            get_log('unlock').debug("%s signal_stop [%s]", self,
                                    threading.currentThread())
            self.job_awaits.release()



class WorkerPool(object):
    """
    A queue of available worker threads.

    Creates and starts num_threads Worker threads. acquire() hands out
    an idle worker, blocking until one is free; release() puts a worker
    back; signal_stop() makes acquire() raise ShutdownRequest.
    """
    def __init__(self, num_threads):
        self.all = [Worker() for _ in range(num_threads)]
        # Idle workers in the order they will be handed out.
        self.avail = list(self.all)
        # Guards self.avail / self.should_stop and signals changes to them.
        self.worker_is_free = threading.Condition()
        self.should_stop = False
        for w in self.all:
            w.start()

    def all_workers(self):
        """Return an iterator over every worker, idle or busy."""
        return iter(self.all)

    def acquire(self):
        """
        Acquire a worker from the pool. Should only ever be called by the
        thread running in the Scheduler instance which owns this WorkerPool.

        Blocks until a worker is available. Raises ShutdownRequest if
        signal_stop() has been called.
        """
        log = get_log("WorkerPool.acquire")
        log.debug("Acquiring lock on worker_is_free")
        get_log('lock').debug("%s acquire [%s]", self,
                              threading.currentThread())
        self.worker_is_free.acquire()
        log.debug("Got lock")
        try:
            while True:
                # Check under the lock, before every wait, so a stop
                # request made while the caller was busy is not missed.
                if self.should_stop:
                    log.debug("Shutdown request. Shutting down....")
                    raise ShutdownRequest()
                if self.avail:
                    break
                log.debug("No available workers. Waiting for a free one...")
                get_log('unlock').debug("%s acquire(wait) [%s]", self,
                                        threading.currentThread())
                self.worker_is_free.wait()
                get_log('lock').debug("%s acquire(wait) [%s]", self,
                                      threading.currentThread())
            log.debug("Got a worker")
            worker = self.avail.pop(0)
            log.debug("Returning worker '%s'", worker.getName())
            return worker
        finally:
            get_log('unlock').debug("%s acquire [%s]", self,
                                    threading.currentThread())
            self.worker_is_free.release()

    def release(self, thread):
        """
        Release a worker back into the pool. Should only ever be called by
        the worker thread itself.
        """
        log = get_log("WorkerPool.release")
        log.debug(
            "Releasing %s back to pool. Acquiring lock....", thread.getName()
        )
        get_log('lock').debug("%s release [%s]", self,
                              threading.currentThread())
        self.worker_is_free.acquire()
        try:
            log.debug("Returning %s to pool", thread.getName())
            self.avail.append(thread)
            log.debug("Notifying worker_is_free")
            get_log('unlock').debug("%s release(notify) [%s]", self,
                                    threading.currentThread())
            self.worker_is_free.notify()
            get_log('lock').debug("%s release(notify) [%s]", self,
                                  threading.currentThread())
        finally:
            log.debug("Releasing worker_is_free")
            get_log('unlock').debug("%s release [%s]", self,
                                    threading.currentThread())
            self.worker_is_free.release()

    def signal_stop(self):
        """
        Request that acquire's calling thread shut down as soon as possible.
        (Either right now if it's waiting in acquire, or on the next call to
        acquire otherwise.)
        """
        get_log('lock').debug("%s signal_stop [%s]", self,
                              threading.currentThread())
        self.worker_is_free.acquire()
        try:
            self.should_stop = True
            get_log('unlock').debug("%s signal_stop(notify) [%s]", self,
                                    threading.currentThread())
            self.worker_is_free.notify()
            get_log('lock').debug("%s signal_stop(notify) [%s]", self,
                                  threading.currentThread())
        finally:
            get_log('unlock').debug("%s signal_stop [%s]", self,
                                    threading.currentThread())
            self.worker_is_free.release()



class Scheduler(threading.Thread):
    """
    Thread that takes due jobs from a JobQueue and hands each one to a
    worker from a WorkerPool.

    Jobs are added with the schedule*() methods, which may be called
    before or after the thread is started. stop() shuts the scheduler
    and its workers down.
    """

    def make_job_runner(self, worker, job):
        """
        Wrap `job` in a callable for `worker` to execute.

        After the job finishes (normally or with an error) the wrapper
        returns the worker to the pool and queues the job's follow-up
        job, if it has one.
        """
        def job_runner():
            log = get_log("job_runner.%s" % worker.getName())
            try:
                job()
            finally:
                log.debug("Releasing %s", worker.getName())
                self.workers.release(worker)
                next_job = job.next_job()
                if next_job:
                    self.jobs.add(next_job)
        return job_runner

    def __init__(self, num_threads=1):
        super(Scheduler, self).__init__()
        self.jobs = JobQueue()
        self.workers = WorkerPool(num_threads)
        # Set once the scheduler and its workers have shut down.
        self.evt_stopped = threading.Event()
        # True once run() has begun executing.
        self.is_started = False

    def _stop_workers(self, log):
        """Stop every worker in the pool, one after another."""
        b = Barrier()
        for w in self.workers.all_workers():
            b.increment()
            log.debug("Stopping %s", w)
            w.stop(b)
        log.debug("Waiting for workers to stop")
        b.wait()

    def run(self):
        """
        Main loop: get the next due job, acquire a worker, assign the job.

        Leaves the loop on ShutdownRequest or on any other error, then
        stops all workers and sets evt_stopped.
        """
        self.is_started = True
        log = get_log("Scheduler.run.%s" % self.getName())
        log.debug("Scheduler starting")
        try:
            while True:
                log.debug("Getting a job...")
                job = self.jobs.get()
                log.debug("got (%s). Acquiring a worker...", job)
                worker = self.workers.acquire()
                job_runner = self.make_job_runner(worker, job)
                log.debug("Assigning %s to %s", job, worker.getName())
                worker.assign(job_runner)
        except ShutdownRequest:
            log.debug("Received shutdown request")
        except:
            log.exception("Error in scheduler")
    #   We were dropped out of the main loop, either by error
    #   or design. Clean up and head out
        try:
            try:
                self._stop_workers(log)
            except:
                log.exception("Error cleaning up in scheduler")
                raise
        finally:
        #   Always signal completion, even when cleanup failed; otherwise
        #   a thread blocked in stop() would never return.
            self.evt_stopped.set()

    def schedule_at(self, func, at):
        """
        Schedule a job to run at a certain point in time.
        """
        self.jobs.add(Job(func, at))

    def schedule_after(self, func, delay):
        """
        Schedule a job to run after a specified delay
        """
        self.jobs.add(Job(func, time.time() + delay))

    def schedule_every(self, func, interval):
        """
        Schedule a job to run repeatedly every interval seconds.
        """
        self.jobs.add(RecurrentJob(func, interval))

    def schedule(self, func):
        """
        Schedule a job to be done as soon as possible (but not before
        existing jobs that should already have been run).

        This is exactly equivalent to schedule_at(func, time.time())
        """
        self.jobs.add(Job(func, time.time()))

    def stop(self, timeout=None):
        """
        Shut the scheduler down.

        If the scheduler thread has been started, signal it to stop and
        wait up to `timeout` seconds (None: indefinitely) for it to
        finish. Otherwise stop the workers directly.
        """
        log = get_log("Scheduler.stop")
    #   isAlive() covers the window between start() returning and run()
    #   setting is_started; taking the other branch in that window would
    #   leave the scheduler thread running with nobody to signal it.
        if self.is_started or self.isAlive():
            log.debug("Signaling JobQueue")
            self.jobs.signal_stop()
            log.debug("Signaling WorkerPool")
            self.workers.signal_stop()
            self.evt_stopped.wait(timeout)
        else:
        #   Special case, just stop all the workers
        #   and call it a day
            self._stop_workers(log)
            self.evt_stopped.set()

    def is_stopped(self):
        """Return True once the scheduler has shut down."""
        return self.evt_stopped.isSet()



