from __future__ import absolute_import from __future__ import print_function import ast import py from _pytest.assertion import rewrite try: import importlib.util except ImportError: pass try: from pathlib import Path except ImportError: pass from __res import importer import sys import six def _get_state(config): if hasattr(config, '_assertstate'): return config._assertstate return config._store[rewrite.assertstate_key] class AssertionRewritingHook(rewrite.AssertionRewritingHook): def __init__(self, *args, **kwargs): self.modules = {} super(AssertionRewritingHook, self).__init__(*args, **kwargs) def find_module(self, name, path=None): co = self._find_module(name, path) if co is not None: return self def _find_module(self, name, path=None): state = _get_state(self.config) if not self._should_rewrite(name, None, state): return None state.trace("find_module called for: %s" % name) try: if self.is_package(name): return None except ImportError: return None if hasattr(self._rewritten_names, 'add'): self._rewritten_names.add(name) else: self._rewritten_names[name] = Path(path[0]) state.trace("rewriting %s" % name) co = _rewrite_test(self.config, name) if co is None: # Probably a SyntaxError in the test. return None self.modules[name] = co, None return co def find_spec(self, name, path=None, target=None): co = self._find_module(name, path) if co is not None: return importlib.util.spec_from_file_location( name, co.co_filename, loader=self, ) def _should_rewrite(self, name, fn, state): if name.startswith("__tests__.") or name.endswith(".conftest"): return True return self._is_marked_for_rewrite(name, state) def is_package(self, name): return importer.is_package(name) def get_source(self, name): return importer.get_source(name) if six.PY3: def load_module(self, module): co, _ = self.modules.pop(module.__name__) try: module.__file__ = co.co_filename module.__cached__ = None module.__loader__ = self module.__spec__ = importlib.util.spec_from_file_location(module.__name__, co.co_filename, loader=self) exec(co, module.__dict__) except: # noqa if module.__name__ in sys.modules: del sys.modules[module.__name__] raise return sys.modules[module.__name__] def exec_module(self, module): if module.__name__ in self.modules: self.load_module(module) else: super(AssertionRewritingHook, self).exec_module(module) def _rewrite_test(config, name): """Try to read and rewrite *fn* and return the code object.""" state = _get_state(config) source = importer.get_source(name) if source is None: return None path = importer.get_filename(name) try: tree = ast.parse(source, filename=path) except SyntaxError: # Let this pop up again in the real import. state.trace("failed to parse: %r" % (path,)) return None rewrite.rewrite_asserts(tree, py.path.local(path), config) try: co = compile(tree, path, "exec", dont_inherit=True) except SyntaxError: # It's possible that this error is from some bug in the # assertion rewriting, but I don't know of a fast way to tell. state.trace("failed to compile: %r" % (path,)) return None return co