diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/python/pytest/rewrite.py | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/python/pytest/rewrite.py')
-rw-r--r-- | library/python/pytest/rewrite.py | 123 |
1 files changed, 123 insertions, 0 deletions
diff --git a/library/python/pytest/rewrite.py b/library/python/pytest/rewrite.py new file mode 100644 index 0000000000..ec188d847f --- /dev/null +++ b/library/python/pytest/rewrite.py @@ -0,0 +1,123 @@ +from __future__ import absolute_import +from __future__ import print_function + +import ast + +import py + +from _pytest.assertion import rewrite +try: + import importlib.util +except ImportError: + pass +from __res import importer +import sys +import six + + +def _get_state(config): + if hasattr(config, '_assertstate'): + return config._assertstate + return config._store[rewrite.assertstate_key] + + +class AssertionRewritingHook(rewrite.AssertionRewritingHook): + def __init__(self, *args, **kwargs): + self.modules = {} + super(AssertionRewritingHook, self).__init__(*args, **kwargs) + + def find_module(self, name, path=None): + co = self._find_module(name, path) + if co is not None: + return self + + def _find_module(self, name, path=None): + state = _get_state(self.config) + if not self._should_rewrite(name, None, state): + return None + state.trace("find_module called for: %s" % name) + + try: + if self.is_package(name): + return None + except ImportError: + return None + + self._rewritten_names.add(name) + + state.trace("rewriting %s" % name) + co = _rewrite_test(self.config, name) + if co is None: + # Probably a SyntaxError in the test. + return None + self.modules[name] = co, None + return co + + def find_spec(self, name, path=None, target=None): + co = self._find_module(name, path) + if co is not None: + return importlib.util.spec_from_file_location( + name, + co.co_filename, + loader=self, + ) + + def _should_rewrite(self, name, fn, state): + if name.startswith("__tests__.") or name.endswith(".conftest"): + return True + + return self._is_marked_for_rewrite(name, state) + + def is_package(self, name): + return importer.is_package(name) + + def get_source(self, name): + return importer.get_source(name) + + if six.PY3: + def load_module(self, module): + co, _ = self.modules.pop(module.__name__) + try: + module.__file__ = co.co_filename + module.__cached__ = None + module.__loader__ = self + module.__spec__ = importlib.util.spec_from_file_location(module.__name__, co.co_filename, loader=self) + exec(co, module.__dict__) + except: # noqa + if module.__name__ in sys.modules: + del sys.modules[module.__name__] + raise + return sys.modules[module.__name__] + + def exec_module(self, module): + if module.__name__ in self.modules: + self.load_module(module) + else: + super(AssertionRewritingHook, self).exec_module(module) + + +def _rewrite_test(config, name): + """Try to read and rewrite *fn* and return the code object.""" + state = _get_state(config) + + source = importer.get_source(name) + if source is None: + return None + + path = importer.get_filename(name) + + try: + tree = ast.parse(source, filename=path) + except SyntaxError: + # Let this pop up again in the real import. + state.trace("failed to parse: %r" % (path,)) + return None + rewrite.rewrite_asserts(tree, py.path.local(path), config) + try: + co = compile(tree, path, "exec", dont_inherit=True) + except SyntaxError: + # It's possible that this error is from some bug in the + # assertion rewriting, but I don't know of a fast way to tell. + state.trace("failed to compile: %r" % (path,)) + return None + return co |