diff options
author | shadchin <shadchin@yandex-team.ru> | 2022-02-10 16:44:39 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:39 +0300 |
commit | e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch) | |
tree | 64175d5cadab313b3e7039ebaa06c5bc3295e274 /library/python/pytest/rewrite.py | |
parent | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff) | |
download | ydb-e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0.tar.gz |
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 2 of 2.
Diffstat (limited to 'library/python/pytest/rewrite.py')
-rw-r--r-- | library/python/pytest/rewrite.py | 120 |
1 files changed, 60 insertions, 60 deletions
diff --git a/library/python/pytest/rewrite.py b/library/python/pytest/rewrite.py index afd0abc782..4cebcb1edd 100644 --- a/library/python/pytest/rewrite.py +++ b/library/python/pytest/rewrite.py @@ -3,13 +3,13 @@ from __future__ import print_function import ast -import py - +import py + from _pytest.assertion import rewrite -try: - import importlib.util -except ImportError: - pass +try: + import importlib.util +except ImportError: + pass try: from pathlib import Path @@ -17,29 +17,29 @@ except ImportError: pass from __res import importer -import sys -import six +import sys +import six + + +def _get_state(config): + if hasattr(config, '_assertstate'): + return config._assertstate + return config._store[rewrite.assertstate_key] -def _get_state(config): - if hasattr(config, '_assertstate'): - return config._assertstate - return config._store[rewrite.assertstate_key] - - class AssertionRewritingHook(rewrite.AssertionRewritingHook): - def __init__(self, *args, **kwargs): - self.modules = {} - super(AssertionRewritingHook, self).__init__(*args, **kwargs) + def __init__(self, *args, **kwargs): + self.modules = {} + super(AssertionRewritingHook, self).__init__(*args, **kwargs) def find_module(self, name, path=None): - co = self._find_module(name, path) - if co is not None: - return self - - def _find_module(self, name, path=None): - state = _get_state(self.config) - if not self._should_rewrite(name, None, state): + co = self._find_module(name, path) + if co is not None: + return self + + def _find_module(self, name, path=None): + state = _get_state(self.config) + if not self._should_rewrite(name, None, state): return None state.trace("find_module called for: %s" % name) @@ -60,18 +60,18 @@ class AssertionRewritingHook(rewrite.AssertionRewritingHook): # Probably a SyntaxError in the test. return None self.modules[name] = co, None - return co - - def find_spec(self, name, path=None, target=None): - co = self._find_module(name, path) - if co is not None: - return importlib.util.spec_from_file_location( - name, - co.co_filename, - loader=self, - ) - - def _should_rewrite(self, name, fn, state): + return co + + def find_spec(self, name, path=None, target=None): + co = self._find_module(name, path) + if co is not None: + return importlib.util.spec_from_file_location( + name, + co.co_filename, + loader=self, + ) + + def _should_rewrite(self, name, fn, state): if name.startswith("__tests__.") or name.endswith(".conftest"): return True @@ -83,31 +83,31 @@ class AssertionRewritingHook(rewrite.AssertionRewritingHook): def get_source(self, name): return importer.get_source(name) - if six.PY3: - def load_module(self, module): - co, _ = self.modules.pop(module.__name__) - try: - module.__file__ = co.co_filename - module.__cached__ = None - module.__loader__ = self - module.__spec__ = importlib.util.spec_from_file_location(module.__name__, co.co_filename, loader=self) - exec(co, module.__dict__) - except: # noqa - if module.__name__ in sys.modules: - del sys.modules[module.__name__] - raise - return sys.modules[module.__name__] - - def exec_module(self, module): - if module.__name__ in self.modules: - self.load_module(module) - else: - super(AssertionRewritingHook, self).exec_module(module) - - + if six.PY3: + def load_module(self, module): + co, _ = self.modules.pop(module.__name__) + try: + module.__file__ = co.co_filename + module.__cached__ = None + module.__loader__ = self + module.__spec__ = importlib.util.spec_from_file_location(module.__name__, co.co_filename, loader=self) + exec(co, module.__dict__) + except: # noqa + if module.__name__ in sys.modules: + del sys.modules[module.__name__] + raise + return sys.modules[module.__name__] + + def exec_module(self, module): + if module.__name__ in self.modules: + self.load_module(module) + else: + super(AssertionRewritingHook, self).exec_module(module) + + def _rewrite_test(config, name): """Try to read and rewrite *fn* and return the code object.""" - state = _get_state(config) + state = _get_state(config) source = importer.get_source(name) if source is None: @@ -121,7 +121,7 @@ def _rewrite_test(config, name): # Let this pop up again in the real import. state.trace("failed to parse: %r" % (path,)) return None - rewrite.rewrite_asserts(tree, py.path.local(path), config) + rewrite.rewrite_asserts(tree, py.path.local(path), config) try: co = compile(tree, path, "exec", dont_inherit=True) except SyntaxError: |