aboutsummaryrefslogtreecommitdiffstats
path: root/library/python/pytest/rewrite.py
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/python/pytest/rewrite.py
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/python/pytest/rewrite.py')
-rw-r--r--library/python/pytest/rewrite.py123
1 files changed, 123 insertions, 0 deletions
diff --git a/library/python/pytest/rewrite.py b/library/python/pytest/rewrite.py
new file mode 100644
index 0000000000..ec188d847f
--- /dev/null
+++ b/library/python/pytest/rewrite.py
@@ -0,0 +1,123 @@
+from __future__ import absolute_import
+from __future__ import print_function
+
+import ast
+
+import py
+
+from _pytest.assertion import rewrite
+try:
+ import importlib.util
+except ImportError:
+ pass
+from __res import importer
+import sys
+import six
+
+
+def _get_state(config):
+ if hasattr(config, '_assertstate'):
+ return config._assertstate
+ return config._store[rewrite.assertstate_key]
+
+
+class AssertionRewritingHook(rewrite.AssertionRewritingHook):
+ def __init__(self, *args, **kwargs):
+ self.modules = {}
+ super(AssertionRewritingHook, self).__init__(*args, **kwargs)
+
+ def find_module(self, name, path=None):
+ co = self._find_module(name, path)
+ if co is not None:
+ return self
+
+ def _find_module(self, name, path=None):
+ state = _get_state(self.config)
+ if not self._should_rewrite(name, None, state):
+ return None
+ state.trace("find_module called for: %s" % name)
+
+ try:
+ if self.is_package(name):
+ return None
+ except ImportError:
+ return None
+
+ self._rewritten_names.add(name)
+
+ state.trace("rewriting %s" % name)
+ co = _rewrite_test(self.config, name)
+ if co is None:
+ # Probably a SyntaxError in the test.
+ return None
+ self.modules[name] = co, None
+ return co
+
+ def find_spec(self, name, path=None, target=None):
+ co = self._find_module(name, path)
+ if co is not None:
+ return importlib.util.spec_from_file_location(
+ name,
+ co.co_filename,
+ loader=self,
+ )
+
+ def _should_rewrite(self, name, fn, state):
+ if name.startswith("__tests__.") or name.endswith(".conftest"):
+ return True
+
+ return self._is_marked_for_rewrite(name, state)
+
+ def is_package(self, name):
+ return importer.is_package(name)
+
+ def get_source(self, name):
+ return importer.get_source(name)
+
+ if six.PY3:
+ def load_module(self, module):
+ co, _ = self.modules.pop(module.__name__)
+ try:
+ module.__file__ = co.co_filename
+ module.__cached__ = None
+ module.__loader__ = self
+ module.__spec__ = importlib.util.spec_from_file_location(module.__name__, co.co_filename, loader=self)
+ exec(co, module.__dict__)
+ except: # noqa
+ if module.__name__ in sys.modules:
+ del sys.modules[module.__name__]
+ raise
+ return sys.modules[module.__name__]
+
+ def exec_module(self, module):
+ if module.__name__ in self.modules:
+ self.load_module(module)
+ else:
+ super(AssertionRewritingHook, self).exec_module(module)
+
+
+def _rewrite_test(config, name):
+ """Try to read and rewrite *fn* and return the code object."""
+ state = _get_state(config)
+
+ source = importer.get_source(name)
+ if source is None:
+ return None
+
+ path = importer.get_filename(name)
+
+ try:
+ tree = ast.parse(source, filename=path)
+ except SyntaxError:
+ # Let this pop up again in the real import.
+ state.trace("failed to parse: %r" % (path,))
+ return None
+ rewrite.rewrite_asserts(tree, py.path.local(path), config)
+ try:
+ co = compile(tree, path, "exec", dont_inherit=True)
+ except SyntaxError:
+ # It's possible that this error is from some bug in the
+ # assertion rewriting, but I don't know of a fast way to tell.
+ state.trace("failed to compile: %r" % (path,))
+ return None
+ return co