# -*- coding: utf-8 -*-
"""Utilities related to importing modules and symbols by name."""
from __future__ import absolute_import, unicode_literals

import importlib
import os
import sys
import warnings
from contextlib import contextmanager

from kombu.utils.imports import symbol_by_name

from celery.five import reload

#: File path of the original ``__main__`` script, taken from the
#: ``MP_MAIN_FILE`` environment variable that Billiard exports when
#: execv is enabled (``None`` when the variable is unset).  A module whose
#: ``__file__`` equals this path is treated as ``__main__``, so its task
#: names can be rewritten to use ``App.main``.
MP_MAIN_FILE = os.getenv('MP_MAIN_FILE')

# Public names exported by ``from celery.utils.imports import *``.
# (The tuple must be closed; the scraped original was missing the ``)``.)
__all__ = (
    'NotAPackage', 'qualname', 'instantiate', 'symbol_by_name',
    'cwd_in_path', 'find_module', 'import_from_cwd',
    'reload_from_cwd', 'module_file', 'gen_task_name',
)

class NotAPackage(Exception):
    """Raised when a dotted import path goes through a non-package module.

    :func:`find_module` raises it with the dotted prefix that imported
    successfully but has no ``__path__`` (i.e. cannot contain submodules).
    """
if sys.version_info > (3, 3):  # pragma: no cover
    def qualname(obj):
        """Return the dotted name of *obj*, or of its class.

        Objects without ``__name__`` (e.g. plain instances) are resolved to
        their class first.  ``__qualname__`` is preferred; if it contains
        no dot it is prefixed with ``__module__``, while a qualname that
        already contains a dot (nested class, method) is returned as-is.
        Objects that expose ``__name__`` but no ``__qualname__`` fall back
        to ``__name__``.
        """
        if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
            obj = obj.__class__
        q = getattr(obj, '__qualname__', None)
        if q is None:
            # Previously ``'.' not in None`` raised TypeError here; use the
            # plain name, as the pre-3.3 implementation below does.
            q = obj.__name__
        if '.' not in q:
            q = '.'.join((obj.__module__, q))
        return q
else:
    def qualname(obj):  # noqa
        """Return the dotted name of *obj* (or of its class) as ``module.name``."""
        if not hasattr(obj, '__name__') and hasattr(obj, '__class__'):
            obj = obj.__class__
        return '.'.join((obj.__module__, obj.__name__))
def instantiate(name, *args, **kwargs):
    """Resolve *name* to a class (or other callable) and call it.

    The dotted *name* is looked up with :func:`symbol_by_name`; all
    remaining positional and keyword arguments are forwarded to it.

    See Also:
        :func:`symbol_by_name`.
    """
    factory = symbol_by_name(name)
    return factory(*args, **kwargs)
@contextmanager
def cwd_in_path():
    """Context adding the current working directory to sys.path.

    If the working directory is already on ``sys.path`` nothing is changed
    and the context yields ``None``.  Otherwise it is inserted at position
    0, the context yields the directory, and it is removed again on exit.
    """
    here = os.getcwd()
    if here in sys.path:
        yield
        return

    sys.path.insert(0, here)
    try:
        yield here
    finally:
        try:
            sys.path.remove(here)
        except ValueError:  # pragma: no cover
            # Someone else already removed it inside the block; that's fine.
            pass
def find_module(module, path=None, imp=None):
    """Import a dotted module name with the CWD on ``sys.path``.

    Dotted-name counterpart of :func:`imp.find_module`; note that the
    module is actually imported (through *imp*) and returned.

    Arguments:
        module (str): Dotted module name.
        path: Accepted for signature compatibility; not used.
        imp (Callable): Importer taking a module name; defaults to
            :func:`importlib.import_module`.

    Raises:
        NotAPackage: if the import fails and one of the leading
            dot-separated prefixes imports fine but has no ``__path__``.
        ImportError: the original import error in every other case.
    """
    importer = importlib.import_module if imp is None else imp
    with cwd_in_path():
        try:
            return importer(module)
        except ImportError:
            # Try to give a more specific error: walk every proper prefix
            # ('a', 'a.b', ... but never the full name) looking for a
            # segment that is a plain module rather than a package.
            segments = module.split('.')
            for end in range(1, len(segments)):
                prefix = '.'.join(segments[:end])
                try:
                    parent = importer(prefix)
                except ImportError:
                    # Can't diagnose further; re-raise the original error.
                    break
                try:
                    parent.__path__
                except AttributeError:
                    raise NotAPackage(prefix)
            # Bare raise: re-raises the ImportError from importing *module*.
            raise
def import_from_cwd(module, imp=None, package=None):
    """Import a module while the current directory is on ``sys.path``.

    The working directory is put at the front of ``sys.path`` for the
    duration of the import (unless it is already present), so modules
    located there normally have precedence over those elsewhere on
    ``sys.path``.

    Arguments:
        module (str): Name of the module to import.
        imp (Callable): Importer, called as ``imp(module, package=package)``;
            defaults to :func:`importlib.import_module`.
        package (str): Passed through to *imp* as ``package``.
    """
    importer = importlib.import_module if imp is None else imp
    with cwd_in_path():
        imported = importer(module, package=package)
    return imported
def reload_from_cwd(module, reloader=None):
    """Reload *module* while the current directory is on ``sys.path``.

    Arguments:
        module: The already-imported module to reload.
        reloader (Callable): Called as ``reloader(module)``; defaults to
            ``reload`` imported from :mod:`celery.five`.

    Returns:
        Whatever *reloader* returns.
    """
    do_reload = reload if reloader is None else reloader
    with cwd_in_path():
        result = do_reload(module)
    return result
def module_file(module):
    """Return the correct original file name of a module.

    If ``module.__file__`` names compiled bytecode -- ``.pyc``, or the
    ``.pyo`` files written by Python 2 under ``-O`` -- the corresponding
    ``.py`` name is returned; any other file name is returned unchanged.
    """
    name = module.__file__
    # Both bytecode suffixes differ from '.py' only by one trailing char.
    return name[:-1] if name.endswith(('.pyc', '.pyo')) else name
def gen_task_name(app, name, module_name):
    """Generate task name from name/module pair.

    Arguments:
        app: Application instance; only ``app.main`` is read.
        name (str): Task name within its module.
        module_name (str): Name of the module defining the task; a falsy
            value is treated as ``'__main__'``.

    Returns:
        str: ``'<app.main>.<name>'`` when the task belongs to ``__main__``
        and ``app.main`` is set, otherwise ``'<module_name>.<name>'`` with
        empty parts skipped.
    """
    module_name = module_name or '__main__'
    # Fix for shell_plus (Issue #366): the module may be missing
    # from sys.modules.
    module = sys.modules.get(module_name)
    if module is not None:
        module_name = module.__name__
        # - If the task module is used as the __main__ script
        # - we need to rewrite the module part of the task name
        # - to match App.main.
        # Modules without ``__file__`` (builtins, an interactive __main__)
        # cannot be the MP_MAIN_FILE script, so don't crash on them.
        if MP_MAIN_FILE and getattr(module, '__file__', None) == MP_MAIN_FILE:
            # - see the comment about :envvar:`MP_MAIN_FILE`.
            module_name = '__main__'
    if module_name == '__main__' and app.main:
        return '.'.join([app.main, name])
    return '.'.join(p for p in (module_name, name) if p)
def load_extension_class_names(namespace):
    """Yield ``(name, 'module:attr')`` pairs for entry points in *namespace*.

    Entry points are enumerated with ``pkg_resources.iter_entry_points``;
    nothing is yielded when ``pkg_resources`` cannot be imported.  Only the
    first element of each entry point's attribute path is used.
    """
    try:
        from pkg_resources import iter_entry_points
    except ImportError:  # pragma: no cover
        return

    for ep in iter_entry_points(namespace):
        # Yield the entry-point name with its target: load_extension_classes
        # unpacks these as ``(name, class_name)`` pairs.
        yield ep.name, ':'.join([ep.module_name, ep.attrs[0]])


def load_extension_classes(namespace):
    """Yield ``(name, cls)`` for each loadable extension in *namespace*.

    Targets that fail to load with ImportError or SyntaxError are skipped
    with a :func:`warnings.warn` message instead of propagating the error.
    """
    for name, class_name in load_extension_class_names(namespace):
        try:
            cls = symbol_by_name(class_name)
        except (ImportError, SyntaxError) as exc:
            warnings.warn(
                'Cannot load {0} extension {1!r}: {2!r}'.format(
                    namespace, class_name, exc))
        else:
            yield name, cls