"""
Namespace maintenance module. Namespaces are packages (i.e., modules that 
can contain other modules) which are dynamically loaded from locations 
that the administrator specifies in a configuration file. Unlike normal 
modules, RAVE reloads namespaces and modules beneath them when they change 
(if and only if the user uses the op_from_export() function to retrieve 
operations from the namespace as they are needed).
"""
from ConfigParser import ConfigParser, NoOptionError
import imp
import copy
import os
import os.path
import inspect
import sys
import __builtin__
from sets import Set
import threading
import errno
import pdb

import rave.exceptions as rexcept
import rave.log as rlog
import rave.plugins.decorators as rdeco

# Logger factory for this module, obtained from rlog.log_factory.
# Called as get_log() or get_log("<name>") elsewhere in this file.
get_log = rlog.log_factory("org.cert.rave.plugins.names")



# Module variables

# Public API of this module.
__all__ = [ 'load_namespaces', 'unload_namespaces', 'load_config'
          , 'load_config_fp', 'exported_names', 'op_from_export', 'init']


# Keeps track of the modules currently in the process of loading.
# Thus, modules that load other modules in the namespace won't
# traverse parts of the module tree that have already been loaded.
# Entries are added and removed by load_toplevel() and load_children().
loads_in_progress = Set()


# Metadata about modules loaded under a namespace, including
# the namespace name, and which symbols are exported under
# what names.
#
# Structure:
#
# all_namespaces = {
#     'namespacename' : <NamespaceInfo>
#   , 'namespacename' : <NamespaceInfo>
#   , ...
# }
all_namespaces = {}

# ConfigParser object containing the namespace configuration.
# None until load_config_fp() has run; while it is None the import
# hook passes every import straight through to the original __import__.
namespace_config = None


# Lock for operations on all_namespaces (and possibly sys.modules)
# that require exclusivity. It is re-entrant (RLock) because functions
# wrapped with @atomic call each other (e.g. op_from_export() calls
# load_namespaces()).
big_hairy_lock = threading.RLock()

# Constant used by modules underneath a namespace to refer to
# the namespace, usually so they can load some other module
# in the namespace, refer to a symbol in the namespace module, etc.
# The import hook replaces it with the importing module's own
# namespace name.
prefix = '__namespace__'


# Parts of this module can't use the normal logging infrastructure
# because they're called during import. (I'm not sure why this is.)
# To see debug messages from them on STDERR, change this function.
def debug(msg, *vars):
    """
    Debug hook for code that runs during import; a no-op by default.

    Parameters:
    msg: str
        %-style format string
    vars:
        values to interpolate into msg

    To actually see the messages, enable the commented-out print below.
    """
    pass
    #print >>sys.stderr, msg % vars

# Print out the value namespace_config, for debugging
def dump_nscfg():
    """
    Write the current namespace configuration to the debug log, one
    "[section]" line per section followed by its "option=value" lines.
    Logs a single notice instead if no configuration has been loaded.
    """
    log = get_log("namespace_config")
    cfg = namespace_config
    if cfg is None:
        log.debug("No namespace configuration file specified")
        return
    for section in cfg.sections():
        log.debug("[%s]", section)
        for option in cfg.options(section):
            value = cfg.get(section, option)
            log.debug("%s=%s", option, value)

# Determine the name of the namespace module from a module name.
# This will either be the name of the module itself or
# the name of its root parent module, e.g.:
# nsname_from_modname('a') -> 'a'
# nsname_from_modname('a.b.c') -> 'a'
def nsname_from_modname(modname):
    """
    Return the part of a dotted module name that precedes the first dot
    (the whole name when it contains no dot).
    """
    return modname.split('.', 1)[0]

# Decorator for functions that require exclusive access to
# all_namespaces
def atomic(fn):
    """
    Decorator: run fn while holding big_hairy_lock.

    The lock is always released, whether fn returns or raises. The
    wrapper takes on fn's __name__, __doc__, __module__ and attribute
    dictionary so that decorated functions stay identifiable in logs,
    tracebacks and help() instead of all appearing as 'wrapped'.

    Parameters:
    fn: callable
        function requiring exclusive access to all_namespaces

    Returns: callable
    A wrapper accepting the same arguments as fn and returning fn's
    result.
    """
    def wrapped(*args, **kwargs):
        #get_log().debug("Acquiring big hairy lock for %s", fn.__name__)
        big_hairy_lock.acquire()
        try:
            return fn(*args, **kwargs)
        finally:
            #get_log().debug("Releasing big hairy lock")
            big_hairy_lock.release()

    # Copy identifying metadata by hand (rather than functools.wraps)
    # so this also works on interpreters that predate functools.
    for attr in ('__name__', '__doc__', '__module__'):
        try:
            setattr(wrapped, attr, getattr(fn, attr))
        except (AttributeError, TypeError):
            # Best effort only: arbitrary callables may lack these
            # attributes; the wrapper still works without them.
            pass
    wrapped.__dict__.update(getattr(fn, '__dict__', {}))
    return wrapped



#   Update all_namespaces with additional references to the
#   module referred to by modname, or to the symbols in fromlist
#
#   This isn't perfect. It doesn't handle 'import foo as bar,' or
#   anything the caller might do with the reference once he has it
#   (like reassign it or delete it). It also doesn't handle references
#   in local namespaces which might (among other things) affect closures.
def add_references(modname, referrer_name, fromlist):
    """
    Record in all_namespaces which names an import binds in the
    importing module, so the bindings can be refreshed or removed when
    the imported module is reloaded or unloaded.

    Parameters:
    modname: str
        fully-qualified name of the module being imported; its top-level
        component must already be a key of all_namespaces
    referrer_name: str
        name of the module performing the import
    fromlist: None | sequence of str
        fromlist argument from __import__; None or empty for a plain
        'import a.b.c'

    Raises: KeyError
    If the namespace is not in all_namespaces, or (for from-imports)
    modname is not in sys.modules.
    """
    module_topname = nsname_from_modname(modname)
    nsinfo = all_namespaces[module_topname]
    if not fromlist:
        # 'import foo.bar.baz' -- only the top-level name 'foo' is bound
        # in the importer. An empty fromlist is treated the same way as
        # None (explicit __import__ callers commonly pass [] or ()).
        nsinfo.add_module_reference(
            module_topname, referrer_name, module_topname)
    elif fromlist[0] == '*':
        # 'from foo.bar.baz import *'
        for sym in get_public_interface(sys.modules[modname]):
            nsinfo.add_symbol_reference(
                modname, referrer_name, sym, sym)
    else:
        # 'from foo.bar import baz, luhrman'
        mod = sys.modules[modname]
        for item in fromlist:
            # It's an error to specify non-existent symbols in the
            # fromlist, but the original __import__ will report it, so
            # it is ignored here.
            if not hasattr(mod, item):
                continue
            if inspect.ismodule(getattr(mod, item)):
                # The imported item is itself a module: track it as a
                # module reference bound under the name 'item'.
                nsinfo.add_module_reference(
                    '.'.join((modname, item)), referrer_name, item)
            else:
                nsinfo.add_symbol_reference(
                    modname, referrer_name, item, item)

def is_contained_in(a, b):
    """
    Returns True if a is a path contained on the filesystem
    (logically, based on the filename) in b. Otherwise, False.

    Both paths are made absolute and normalized first; the filesystem
    itself is not consulted. A path is not considered to be contained
    in itself.
    """
    # Normalize the paths
    a = os.path.normpath(os.path.abspath(a))
    b = os.path.normpath(os.path.abspath(b))
    # A bare a.startswith(b) is not enough: "/a/b/cde/f" starts with
    # "/a/b/c" without being inside it. Requiring the separator right
    # after b rules that out. A normalized root directory already ends
    # with the separator, so it must not get a second one (otherwise
    # nothing would ever be contained in the root).
    if not b.endswith(os.path.sep):
        b += os.path.sep
    # The length check rejects a == b when b is the root directory.
    return len(a) > len(b) and a.startswith(b)

    

def resolve_relative_import(nspath, name, locals, globals, fromlist):
    """
    When importing a module with the same name as a namespace from
    within another namespace, figure out if we are importing from
    the other namespace (an error) or trying to import something
    in their own namespace that incidentally conflicts.

    To determine this, we import the module. If the import succeeds
    and the imported module lives under our namespace (as indicated
    by nspath), return the module. Otherwise, raise an ImportError.

    Parameters:
    nspath: str
        path for the namespace under which this module should live
    name: str
        name argument from __import__
    locals, globals: dict
        Forwarded, in this order, as the second and third arguments of
        the original __import__ (which expects globals, then locals).
        NOTE(review): the import hook in this file passes its own
        (globals, locals) positionally into these two slots, so the
        values reach __import__ in the right order even though the
        parameter names read the other way round.
    fromlist: list
        fromlist argument from __import__

    Returns: Module object
    The imported module, when it lives under nspath.

    Raises: ImportError
    If the module cannot be imported, or it was imported but does not
    live under nspath (including modules that have no __file__).
    """
    error = "Cannot load modules from one namespace into another"
    # Remember exactly what was loaded before the import, so that only
    # a module this call brought in is ever evicted from sys.modules.
    already_loaded = list(sys.modules.keys())
    try:
        mod = _old_import(name, locals, globals, fromlist)
    except ImportError:
        # Couldn't load the module. Assume the user is trying to import
        # from another namespace, but it hasn't been loaded yet. (This
        # seems more likely than someone loading a non-existent module
        # that just happens to share names with a namespace.)
        raise ImportError(error)
    # Is this module rooted under the namespace path? Modules without
    # a __file__ (e.g. built-ins) cannot be.
    modpath = getattr(mod, '__file__', None)
    if modpath is not None and is_contained_in(modpath, nspath):
        return mod
    # Unload the module from sys.modules if we loaded it
    if mod.__name__ not in already_loaded:
        sys.modules.pop(mod.__name__, None)
    raise ImportError(error)


@atomic
def _import_hook(name, globals=None, locals=None, fromlist=None, level=None):
    """
    Replacement for the builtin __import__ that makes namespace modules
    load (and reload) on demand and records who imports them.

    Imports are passed straight to the original __import__ when the
    importer cannot be identified (no globals, or globals without a
    '__name__') or when no namespace configuration has been loaded.
    Otherwise:
      * a module inside a namespace that imports a name matching a
        configured namespace is handled by resolve_relative_import();
      * a name starting with the '__namespace__' prefix is rewritten to
        the importer's own namespace;
      * if the (possibly rewritten) name belongs to a configured
        namespace, that namespace is (re)loaded and the reference is
        recorded before the import is delegated.

    Parameters:
    name, globals, locals, fromlist:
        as for the builtin __import__
    level: None | int
        the optional fifth __import__ argument; forwarded to the
        original __import__ only when the caller supplied it, so the
        hook also works on interpreters whose __import__ has no such
        argument

    Returns: Module object, as returned by the original __import__.
    """
    debug("import('%s')", name)

    def delegate(modname):
        # Call the original __import__ with exactly the optional
        # arguments this hook was given.
        if level is None:
            return _old_import(modname, globals, locals, fromlist)
        return _old_import(modname, globals, locals, fromlist, level)

    # If we don't know globals, we can't do anything useful
    # here. Luckily, this is usually going to be some third-party
    # module anyway, so we have no business doing anything with it.
    #
    # The same holds true if namespace_config isn't loaded, or if the
    # globals carry no '__name__' (e.g. code exec'd in a bare dict).
    caller_name = None
    if globals is not None and namespace_config is not None:
        caller_name = globals.get('__name__')
    if caller_name is None:
        return delegate(name)

    # Shorthand function for a klunky thing we do a lot
    def is_namespace_name(x):
        return namespace_config.has_option('namespaces', x)

    # Identify caller's location
    caller_topname = nsname_from_modname(caller_name)
    module_topname = nsname_from_modname(name)
    if is_namespace_name(caller_topname):
        # Caller is importing from within a namespace
        if is_namespace_name(module_topname):
            # Caller is importing something that might be from
            # another namespace
            return resolve_relative_import(
                namespace_config.get('namespaces', caller_topname),
                name, globals, locals, fromlist
            )
        if module_topname == prefix:
            # Caller is looking for a module in his own namespace
            name = name.replace(prefix, caller_topname, 1)
            module_topname = caller_topname
    if is_namespace_name(module_topname):
        # Caller is requesting a namespace module
        load_namespaces(module_topname)
        add_references(name, caller_name, fromlist)
    return delegate(name)


# Install the import hook at module import time: keep a handle on the
# __import__ that was in place (the hook delegates to it), then replace
# the builtin with _import_hook.
_old_import = __builtin__.__import__
__builtin__.__import__ = _import_hook



def get_source_filename(mod):
    """
    Return the '.py' filename corresponding to mod.__file__.

    A name already ending in '.py' is returned unchanged; otherwise
    everything after the last '.' in the name is replaced by 'py'.

    Raises: rexcept.RaveException
    If the filename contains no '.' at all.
    """
    loaded_from = mod.__file__
    if loaded_from.endswith('.py'):
        return loaded_from
    pieces = loaded_from.rsplit('.', 1)
    if len(pieces) < 2:
        raise rexcept.RaveException(
            "Can't cope with module filename: %s" % loaded_from)
    return pieces[0] + ".py"


def module_new_or_changed(modname):
    """
    Return True if a module is not yet loaded, or if an already-loaded
    module has changed on disk.

    "Not yet loaded" means no mtime is recorded for modname in
    all_namespaces. "Changed" means the current mtime of the module's
    source file differs from the recorded one.

    Raises:
    rexcept.NoSuchModule
        if the source file has been removed since it was last loaded
    OSError
        for any other failure to stat the source file
    """
    nsname = nsname_from_modname(modname)
    try:
        last_loaded_mtime = all_namespaces[nsname].get_mtime(modname)
    except KeyError:
        # Unknown namespace or unknown module: never loaded
        return True

    try:
        source = get_source_filename(sys.modules[modname])
        file_mtime = os.stat(source).st_mtime
    except OSError:
        # sys.exc_info() rather than 'except ... as' keeps this block
        # valid on old and new interpreters alike.
        err = sys.exc_info()[1]
        if err.errno == errno.ENOENT:
            # The file has been removed since it was last loaded.
            raise rexcept.NoSuchModule(modname)
        # Any other stat failure (permissions, I/O error, ...) must not
        # be swallowed: there is no mtime to compare against.
        raise
    #get_log().debug("Should I load %s? %s" % (modname, rc))
    return file_mtime != last_loaded_mtime


def module_removed(modname):
    """
    Return True if a currently-loaded module has been removed from disk,
    i.e. it is in sys.modules but its __file__ no longer exists.
    A module that isn't loaded is reported as not removed (False).
    """
    if modname not in sys.modules:
        # Module isn't loaded
        return False
    loaded_from = sys.modules[modname].__file__
    return not os.path.exists(loaded_from)
        

def update_bindings(referrent, referrer_mod, reference_names):
    """
    Rebind names in a referring module to a new object.

    Every name in reference_names that still exists as an attribute of
    referrer_mod is pointed at referrent; names that no longer exist
    there are skipped.

    Returns: Set
    The names that were actually rebound.
    """
    still_bound = Set()
    for attr in reference_names:
        if not hasattr(referrer_mod, attr):
            continue
        setattr(referrer_mod, attr, referrent)
        still_bound.add(attr)
    return still_bound


def update_references(referrent, referrers):
    """
    Update references to something in other modules, and the
    data structure used to keep track of them.

    Parameters:
    referrent:
        the object the references should now point to
    referrers: dict
        maps referring module name -> names bound in that module

    Returns: dict
    A new mapping of the same shape holding, for every referring module
    that is still in sys.modules, the names that were rebound (as
    returned by update_bindings). Modules that are no longer loaded are
    dropped.
    """
    new_referrers = {}
    for referrer_modname, referring_names in referrers.items():
        try:
            referrer_mod = sys.modules[referrer_modname]
        except KeyError:
            # A missing sys.modules entry raises KeyError (not
            # AttributeError): the referrer is gone, so forget it.
            continue
        new_referring_names = update_bindings(
            referrent, referrer_mod, referring_names)
        new_referrers[referrer_modname] = new_referring_names
    return new_referrers


def remove_references(referrers):
    """
    Delete references in the referring modules to no-longer-existing
    symbols.

    Parameters:
    referrers: dict
        maps referring module name -> names bound in that module

    Referring modules that are no longer in sys.modules, and names that
    are no longer bound in a referring module, are silently skipped.
    """
    for referrer_modname, referring_names in referrers.items():
        try:
            referrer_mod = sys.modules[referrer_modname]
        except KeyError:
            # A missing sys.modules entry raises KeyError (not
            # AttributeError): nothing to clean up in an unloaded module.
            continue
        for name in referring_names:
            try:
                delattr(referrer_mod, name)
            except AttributeError:
                pass # Ignore if symbol no longer exists


def get_modules_from_all(mod):
    """
    Return the entries of mod.__all__ that name (or may name) modules.

    An entry qualifies if it is not currently an attribute of mod -- it
    is then assumed to be a not-yet-loaded submodule of a package -- or
    if the attribute it names is a module object. A module without
    __all__ yields an empty list.
    """
    missing = object()
    names = getattr(mod, '__all__', missing)
    if names is missing:
        return []

    def may_be_module(entry):
        if not hasattr(mod, entry):
            # Not loaded (yet): assume it is a submodule
            return True
        return inspect.ismodule(getattr(mod, entry))

    return filter(may_be_module, names)


def get_public_interface(mod):
    """
    Return the names a 'from mod import *' would consider public:
    mod.__all__ if the module defines it, otherwise every name in
    dir(mod) that does not start with an underscore.
    """
    missing = object()
    declared = getattr(mod, '__all__', missing)
    if declared is not missing:
        return declared
    return [attr for attr in dir(mod) if not attr.startswith('_')]


class NamespaceInfo(object):
    """
    A dual collection of ModuleInfo objects and exported symbols
    (all FileOperations) that make up a namespace.

    Avoid manipulating ModuleInfo objects directly.
    Call through the appropriate NamespaceInfo object instead.
    """

    def __init__(self):
        # modinfos: ModuleInfo objects, keyed by module name as it
        # appears in sys.modules
        self.modinfos = {}
        # exports: Symbols exported by this namespace, keyed by exported
        # name (basically a flattened version of all the __export__
        # symbols in the namespace)
        self.exports = {}
        self.lock = threading.Lock()

    def modinfo(self, modname):
        """
        Return the ModuleInfo for modname, creating it (and merging the
        module's exports) on first use.

        Raises KeyError if a new record is needed but modname is not in
        sys.modules; nothing is registered in that case.
        """
        if modname not in self.modinfos:
            # Look the module up before registering anything, so a
            # failed lookup doesn't leave a half-initialized record.
            mod = sys.modules[modname]
            self.modinfos[modname] = ModuleInfo()
            self.update_exports(mod)
        return self.modinfos[modname]

    def module_names(self):
        """Return the names of all modules tracked in this namespace."""
        return self.modinfos.keys()

    def get_mtime(self, modname):
        """
        Return the mtime recorded for modname. Raises KeyError if the
        module is not tracked.
        """
        return self.modinfos[modname].mtime

    def update_mtime(self, modname, new_mtime):
        """Record new_mtime as the load-time mtime of modname."""
        self.modinfo(modname).mtime = new_mtime

    def add_module_reference(self, modname, referrer_name, reference):
        """
        Note that module referrer_name refers to module modname under
        the name 'reference'.
        """
        self.modinfo(modname).add_module_reference(
            referrer_name, reference)

    def add_symbol_reference(self, modname, referrer_name, symbol, reference):
        """
        Note that module referrer_name refers to 'symbol' of module
        modname under the name 'reference'.
        """
        self.modinfo(modname).add_symbol_reference(
            referrer_name, symbol, reference)

    def clear_references(self, modname):
        """
        Delete, from every referring module that is still loaded, the
        names recorded as referring to modname or to its symbols.
        """
        string_types = (str, type(u''))
        for referrer_name, reference in self.modinfo(modname).all_references():
            refmod = sys.modules.get(referrer_name)
            if refmod is None:
                continue
            # A recorded reference may be a single name or a collection
            # of names (the update_* functions store collections).
            if isinstance(reference, string_types):
                names = (reference,)
            else:
                names = tuple(reference)
            for name in names:
                if hasattr(refmod, name):
                    delattr(refmod, name)

    def update_references(self, mod):
        """Repoint all recorded references at the (re)loaded module mod."""
        modinfo = self.modinfo(mod.__name__)
        modinfo.update_module_references(mod)
        modinfo.update_symbol_references(mod)

    def update_exports(self, mod, old_mod=None):
        """
        Merge mod.__export__ (if any) into this namespace's exports and
        retire exports that only the previous version old_mod provided.
        """
        if old_mod is not None and hasattr(old_mod, '__export__'):
            old_exports = old_mod.__export__
        else:
            old_exports = {}

        if hasattr(mod, '__export__'):
            get_log().debug("updating exports with %s", mod.__export__)
            new_exports = mod.__export__
            self.exports.update(new_exports)
        else:
            new_exports = {}

        # Remove name in old_exports if it's not in new_exports and
        # hasn't been redefined by something else in self.exports
        for name in old_exports.keys():
            if ( name not in new_exports
                 and name in self.exports
                 and old_exports[name] == self.exports[name] ):
                self.exports.pop(name)

    def clear_exports(self, modname):
        """
        Remove from this namespace the exports that modname currently
        provides. Does nothing if the module is no longer in sys.modules
        or has no __export__.
        """
        mod = sys.modules.get(modname)
        exported = getattr(mod, '__export__', None)
        if not exported:
            return
        for name, sym in exported.items():
            if name not in self.exports:
                # Already removed; nothing to do
                continue
            if sym != self.exports[name]:
                # Someone's redefined this symbol, possibly in
                # another module under the namespace. We could
                # make some noises here....I'll go with ignoring
                # it for now.
                continue
            self.exports.pop(name)

    def exported_names(self):
        """Return the names of all symbols exported by this namespace."""
        get_log().debug("keys: %s", self.exports.keys())
        return self.exports.keys()

    def query(self, name):
        """
        Return the symbol exported as 'name'. Raises KeyError if there
        is no such export.
        """
        return self.exports[name]

class ModuleInfo(object):
    """
    Meta-information about modules loaded into namespaces. In particular,
    ModuleInfo stores info on modules that refer to this module or to
    symbols in this module.

    Note that the ModuleInfo doesn't store a reference to the module
    object with which it is associated. If you need the module, grab it
    from sys.modules. (The ModuleInfo doesn't store the name either, so
    you will need to know that a priori.)

    References are always stored as collections of names, which is the
    shape update_references() hands back when a module is reloaded.
    """
    def __init__(self):
        # The mtime of the file on disk representing this module, at
        # the time the module was loaded
        self.mtime = 0
        # Modules containing a reference to this module. Keys are
        # referring module names; values are collections of the names
        # used there to refer to the module
        self.referrers = {}
        # Modules containing a reference to a symbol in this module.
        # Keys are symbols; values are dictionaries identical in
        # structure to self.referrers (key = referring module name,
        # value = collection of reference names)
        self.symbol_refs = {}

    def add_module_reference(self, referrer_name, reference):
        """
        Record that module referrer_name refers to this module under
        the name 'reference'.
        """
        self.referrers.setdefault(referrer_name, set()).add(reference)

    def add_symbol_reference(self, referrer_name, symbol, reference):
        """
        Record that module referrer_name refers to 'symbol' of this
        module under the name 'reference'.
        """
        referrers = self.symbol_refs.setdefault(symbol, {})
        referrers.setdefault(referrer_name, set()).add(reference)

    def module_references(self):
        """Yield (referrer_name, reference_name) for each module reference."""
        for referrer_name, names in self.referrers.items():
            for name in names:
                yield referrer_name, name

    def symbol_references(self):
        """
        Yield (symbol_name, referrer_name, reference_name) for each
        symbol reference.
        """
        for symname, referrers in self.symbol_refs.items():
            for referrer_name, names in referrers.items():
                for name in names:
                    yield symname, referrer_name, name

    def all_references(self):
        """
        Yield (referrer_name, reference_name) for every recorded
        reference: module references first, then symbol references.
        """
        for item in self.module_references():
            yield item
        for symname, referrer_name, name in self.symbol_references():
            yield referrer_name, name

    def update_module_references(self, mod):
        """Repoint every recorded module reference at mod."""
        self.referrers = update_references(mod, self.referrers)

    def update_symbol_references(self, mod):
        """
        Repoint every recorded symbol reference at the attribute of the
        same name in mod. References to symbols that mod no longer has
        are deleted from the referring modules and forgotten.
        """
        new_symbol_refs = {}
        for symname, referrers in self.symbol_refs.items():
            # ignore some reserved names
            if symname in (  '__all__', '__builtins__', '__doc__'
                           , '__file__', '__name__', '__path__'):
                continue
            try:
                get_log().debug("Getting %s", symname)
                sym = getattr(mod, symname)
                get_log().debug("Got it")
            except AttributeError:
                # sym existed and was referenced in an older version of
                # this module, but went away.
                get_log().debug("Removing stale references to %s.%s",
                    mod.__name__, symname)
                remove_references(referrers)
                continue
            new_referrers = update_references(sym, referrers)
            if len(new_referrers) != 0:
                new_symbol_refs[symname] = new_referrers
        self.symbol_refs = new_symbol_refs


# Wrap any Operations in __export__ with FileOperations
def convert_exports(mod):
    """
    Replace mod.__export__ with a dictionary holding, for every entry
    whose value reports itself exportable (rave_exportable() returns a
    true value), the result of that value's export() method.

    Entries that are not exportable -- including values that have no
    rave_exportable() at all -- are dropped with a warning. Modules
    without an __export__ attribute are left untouched.
    """
    if not hasattr(mod, '__export__'):
        return

    get_log().debug("Converting exports: %s", mod.__export__)
    converted = {}
    for key, candidate in mod.__export__.items():
        try:
            usable = candidate.rave_exportable()
        except AttributeError:
            # No rave_exportable(): not something we can export
            usable = False
        if usable:
            converted[key] = candidate.export()
        else:
            get_log().warn(
                "%s: Removing %s from export list "
                "(can only export @ops and @op_files)", mod.__name__, key)
    mod.__export__ = converted



def load_mod(modname, modname_on_disk, modpaths):
    """
    Load (or reload) a module from disk and refresh the bookkeeping
    kept about it in all_namespaces.

    Parameters:
    modname: str
        name under which the module is registered in sys.modules
    modname_on_disk: str
        name of the module or package to look for in modpaths
    modpaths: list of str
        directories to search

    Raises: whatever imp.find_module / imp.load_module raise, e.g.
    ImportError if the module cannot be found.
    """
    # Load the module
    old_mod = sys.modules.pop(modname, None) # Ensure we load from disk
    get_log().debug("find_module('%s', '%s')", modname_on_disk, modpaths)
    fd, fname, stuff = imp.find_module(modname_on_disk, modpaths)
    try:
        mod = imp.load_module(modname, fd, fname, stuff)
    finally:
        # imp.load_module never closes the file returned by
        # find_module; that is the caller's job. It is None for
        # packages.
        if fd is not None:
            fd.close()
    convert_exports(mod)
    # Update meta-information about the module:
    nsname = nsname_from_modname(mod.__name__)
    if nsname not in all_namespaces:
        all_namespaces[nsname] = NamespaceInfo()
    nsinfo = all_namespaces[nsname]
    # Note: there's a little bit of a race condition here -- someone could
    # replace the file on the filesystem between loading and getting the
    # mtime.
    nsinfo.update_mtime(
        mod.__name__, os.stat(get_source_filename(mod)).st_mtime)
    nsinfo.update_references(mod)
    nsinfo.update_exports(mod, old_mod)


def unload_mod(nsinfo, modname):
    """
    Forget a module: drop its exports from nsinfo, delete the names
    other modules use to refer to it, then remove it from sys.modules
    (if it is there).
    """
    nsinfo.clear_exports(modname)
    nsinfo.clear_references(modname)
    if modname in sys.modules:
        del sys.modules[modname]


# Load the child modules/packages of a given module
def load_children(mod):
    """
    Recursively (re)load the child modules/packages that mod lists in
    its __all__. Modules without __all__ have no children to load.

    A child is loaded only if module_new_or_changed() says so, but its
    own children are visited either way. Children whose load is already
    under way (tracked in loads_in_progress) are skipped.
    """
    if not hasattr(mod, '__all__'):
        return
    childbases = get_modules_from_all(mod)
    debug("Child modules: %s", childbases)
    for childbase in childbases:
        childname = "%s.%s" % (mod.__name__, childbase)
        if childname in loads_in_progress:
            debug("%s load already in progress: skipping", childname)
            continue
        debug("%s is not in progress: loading", childname)
        loads_in_progress.add(childname)
        try:
            needs_load = module_new_or_changed(childname)
            if needs_load:
                get_log().info("Loading module %s" % childname)
                load_mod(childname, childbase, mod.__path__)
                # Bind the fresh module object on its parent
                setattr(mod, childbase, sys.modules[childname])
            load_children(sys.modules[childname])
        finally:
            loads_in_progress.remove(childname)


# Load a single namespace-level module and all its children
def load_toplevel(nsname, nspath):
    """
    Load a single namespace-level module and all its children.

    Parameters:
    nsname: str
        name of the namespace; the top-level module is registered
        under this name
    nspath: str
        filesystem path of the namespace package; trailing slashes are
        ignored

    Raises: rexcept.ConfigurationError
    If nspath cannot be split into a parent directory and a directory
    name.
    """
    loads_in_progress.add(nsname)
    try:
        if module_new_or_changed(nsname):
            get_log().info("Loading namespace %s (from %s)", nsname, nspath)
            parent, dirname = os.path.split(nspath.rstrip('/'))
            if not (parent and dirname):
                raise rexcept.ConfigurationError("Syntax error in namespace entry")
            load_mod(nsname, dirname, [parent])
        load_children(sys.modules[nsname])
    finally:
        loads_in_progress.remove(nsname)


def load_config_fp(fp, default_nsroot='/'):
    """
    Read the namespace configuration from an open file and install it
    as the module-wide namespace_config.

    Parameters:
    fp: file-like object
        configuration in ConfigParser format, with a [namespaces]
        section mapping namespace names to paths
    default_nsroot: str
        value of the 'nsroot' default available to the configuration

    Raises: rexcept.ConfigurationError
    If the [namespaces] section defines no namespaces.
    """
    global namespace_config
    cfg = ConfigParser({'nsroot':default_nsroot})
    cfg.readfp(fp)
    # options() also reports the parser defaults ('nsroot'), which are
    # not namespaces; leave them out so that an empty section is
    # actually detected.
    defaults = cfg.defaults()
    namespace_names = [name for name in cfg.options('namespaces')
                       if name not in defaults]
    if not namespace_names:
        try:
            errstr = "No namespaces defined in %s" % fp.name
        except AttributeError:
            errstr = "No namespaces defined in config file"
        raise rexcept.ConfigurationError(errstr)
    for nsname in namespace_names:
        # normalize multiple slashes
        nspath = cfg.get('namespaces', nsname)
        cfg.set('namespaces', nsname, os.path.normpath(nspath))
    namespace_config = cfg


def load_config(f, default_nsroot='/'):
    """
    Load the namespace configuration from f, which is either a
    file-like object (anything with a 'read' attribute) or the name of
    a file to open. A file opened here is closed again before
    returning; a file-like object passed in is left open.

    Returns whatever load_config_fp() returns.
    """
    try:
        f.read
    except AttributeError:
        pass
    else:
        # Already file-like: the caller owns it
        return load_config_fp(f, default_nsroot)

    handle = open(f, 'r')
    try:
        return load_config_fp(handle, default_nsroot)
    finally:
        handle.close()


@atomic
def load_namespaces(*to_load):
    """
    Load (or reload, if changed) the named namespaces, or every
    configured namespace when called without arguments.

    Returns: (int, dict)
    The number of namespaces processed without error, and a dictionary
    mapping the name of each namespace that failed to the
    sys.exc_info() tuple of its failure.

    Raises: RuntimeError
    If no namespace configuration has been loaded yet.
    """
    cfg = namespace_config
    if cfg is None:
        raise RuntimeError(
            "Load namespace configuration with names.load_config() "
            "before loading namespaces."
        )
    if not to_load:
        to_load = cfg.options('namespaces')
    get_log().debug("Namespaces to load: %s", to_load)
    loaded_count = 0
    failures = {}
    for nsname in to_load:
        try:
            # rave is a reserved namespace. nsroot (and presumably
            # other defaults) appear unbidden (thanks, ConfigParser
            # implementation!)
            if nsname in ('rave', 'nsroot'):
                get_log().debug("'%s' is a reserved namespace. Ignoring", nsname)
                continue
            try:
                nspath = cfg.get('namespaces', nsname)
            except NoOptionError:
                raise rexcept.NoSuchNamespace(nsname)
            if not nspath:
                raise rexcept.ConfigurationError("No path defined")
            load_toplevel(nsname, nspath)
            loaded_count += 1
        except Exception:
            # Keep going: one broken namespace must not stop the rest
            failures[nsname] = sys.exc_info()
    return (loaded_count, failures)


@atomic
def unload_namespaces(*to_unload):
    """
    Unload the named namespaces, or every loaded namespace when called
    without arguments. Names that are not (or no longer) loaded --
    including a name given more than once -- are ignored.
    """
    if not to_unload:
        # Copy the keys: the dictionary shrinks as we go
        to_unload = list(all_namespaces.keys())
    for nsname in to_unload:
        nsinfo = all_namespaces.pop(nsname, None)
        if nsinfo is None:
            # ignore already-unloaded namespaces
            continue
        for modname in nsinfo.module_names():
            unload_mod(nsinfo, modname)

@atomic
def exported_names():
    """
    Return a dictionary mapping each loaded namespace's name to the
    names it exports.
    """
    log = get_log()
    log.debug("all_namespaces: %s", str(all_namespaces))
    names_by_namespace = {}
    for nsname, ns in all_namespaces.items():
        names_by_namespace[nsname] = ns.exported_names()
    return names_by_namespace

@atomic
def op_from_export(namespace, queryname):
    """
    Identify and return the Operation that maps to the given export name in
    the given namespace.

    The namespace is (re)loaded first, so changes on disk are picked up.

    Raises:
    rexcept.NoSuchNamespace
        if the namespace is reserved and therefore never loaded
    rexcept.NoSuchOperation
        if the namespace has no export named queryname, or if loading
        the namespace failed with a NoSuchOperation
    rexcept.NamespaceLoadError
        if loading the namespace failed for any other reason
    """
    get_log().debug("Getting %s from %s", queryname, namespace)
    numloaded, exceptions = load_namespaces(namespace)
    if not numloaded:
        if namespace not in exceptions:
            # Nothing loaded and nothing failed: load_namespaces skipped
            # the name because it is reserved.
            raise rexcept.NoSuchNamespace(namespace)
        exc_type, exc_value = exceptions[namespace][:2]
        if issubclass(exc_type, rexcept.NoSuchOperation):
            raise exc_value
        raise rexcept.NamespaceLoadError(namespace, exceptions[namespace])
    try:
        return all_namespaces[namespace].query(queryname)
    except KeyError:
        raise rexcept.NoSuchOperation('%s/%s' % (namespace, queryname))

@atomic
def init(cfg, default_nsroot="/"):
    """
    Initialize namespaces. This is equivalent to:
        load_config(cfg, default_nsroot)
        load_namespaces()

    Parameters:
    cfg: str | file-like object
        namespace configuration, as accepted by load_config()
    default_nsroot: str
        passed through to load_config()

    Returns: the result of load_namespaces(), i.e. a tuple of the
    number of namespaces loaded and a dictionary of the failures.
    """
    load_config(cfg, default_nsroot)
    return load_namespaces()
