diff options
author | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-02-14 00:49:36 +0300 |
---|---|---|
committer | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-02-14 00:49:36 +0300 |
commit | 82cfd1b7cab2d843cdf5467d9737f72597a493bd (patch) | |
tree | 1dfdcfe81a1a6b193ceacc2a828c521b657a339b /contrib/python/pytest/py3/_pytest/assertion/rewrite.py | |
parent | 3df7211d3e3691f8e33b0a1fb1764fe810d59302 (diff) | |
download | ydb-82cfd1b7cab2d843cdf5467d9737f72597a493bd.tar.gz |
intermediate changes
ref:68b1302de4b5da30b6bdf02193f7a2604d8b5cf8
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion/rewrite.py')
-rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/rewrite.py | 97 |
1 files changed, 54 insertions, 43 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py index 37ff076aab..88ac6cab36 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py +++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py @@ -19,6 +19,7 @@ from typing import Callable from typing import Dict from typing import IO from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Sequence @@ -27,8 +28,7 @@ from typing import Tuple from typing import TYPE_CHECKING from typing import Union -import py - +from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import saferepr from _pytest._version import version from _pytest.assertion import util @@ -37,14 +37,15 @@ from _pytest.assertion.util import ( # noqa: F401 ) from _pytest.config import Config from _pytest.main import Session +from _pytest.pathlib import absolutepath from _pytest.pathlib import fnmatch_ex -from _pytest.store import StoreKey +from _pytest.stash import StashKey if TYPE_CHECKING: from _pytest.assertion import AssertionState -assertstate_key = StoreKey["AssertionState"]() +assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -63,7 +64,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) except ValueError: self.fnpats = ["test_*.py", "*_test.py"] self.session: Optional[Session] = None - self._rewritten_names: Set[str] = set() + self._rewritten_names: Dict[str, Path] = {} self._must_rewrite: Set[str] = set() # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # which might result in infinite recursion (#3506) @@ -87,7 +88,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) ) -> Optional[importlib.machinery.ModuleSpec]: if self._writing_pyc: return None - state = self.config._store[assertstate_key] + state = self.config.stash[assertstate_key] if self._early_rewrite_bailout(name, state): return None state.trace("find_module called for: %s" % name) @@ -131,9 +132,9 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) assert module.__spec__ is not None assert module.__spec__.origin is not None fn = Path(module.__spec__.origin) - state = self.config._store[assertstate_key] + state = self.config.stash[assertstate_key] - self._rewritten_names.add(module.__name__) + self._rewritten_names[module.__name__] = fn # The requested module looks like a test file, so rewrite it. This is # the most magical part of the process: load the source, rewrite the @@ -215,7 +216,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) return True if self.session is not None: - if self.session.isinitpath(py.path.local(fn)): + if self.session.isinitpath(absolutepath(fn)): state.trace(f"matched test file (was specified on cmdline): {fn!r}") return True @@ -275,6 +276,16 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) with open(pathname, "rb") as f: return f.read() + if sys.version_info >= (3, 10): + + def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources: # type: ignore + if sys.version_info < (3, 11): + from importlib.readers import FileReader + else: + from importlib.resources.readers import FileReader + + return FileReader(types.SimpleNamespace(path=self._rewritten_names[name])) + def _write_pyc_fp( fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType @@ -333,7 +344,7 @@ else: try: _write_pyc_fp(fp, source_stat, co) - os.rename(proc_pyc, os.fspath(pyc)) + os.rename(proc_pyc, pyc) except OSError as e: state.trace(f"error writing pyc file at {pyc}: {e}") # we ignore any failure to write the cache file @@ -347,13 +358,12 @@ else: def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: """Read and rewrite *fn* and return the code object.""" - fn_ = os.fspath(fn) - stat = os.stat(fn_) - with open(fn_, "rb") as f: - source = f.read() - tree = ast.parse(source, filename=fn_) - rewrite_asserts(tree, source, fn_, config) - co = compile(tree, fn_, "exec", dont_inherit=True) + stat = os.stat(fn) + source = fn.read_bytes() + strfn = str(fn) + tree = ast.parse(source, filename=strfn) + rewrite_asserts(tree, source, strfn, config) + co = compile(tree, strfn, "exec", dont_inherit=True) return stat, co @@ -365,14 +375,14 @@ def _read_pyc( Return rewritten code if successful or None if not. """ try: - fp = open(os.fspath(pyc), "rb") + fp = open(pyc, "rb") except OSError: return None with fp: # https://www.python.org/dev/peps/pep-0552/ has_flags = sys.version_info >= (3, 7) try: - stat_result = os.stat(os.fspath(source)) + stat_result = os.stat(source) mtime = int(stat_result.st_mtime) size = stat_result.st_size data = fp.read(16 if has_flags else 12) @@ -428,7 +438,18 @@ def _saferepr(obj: object) -> str: sequences, especially '\n{' and '\n}' are likely to be present in JSON reprs. """ - return saferepr(obj).replace("\n", "\\n") + maxsize = _get_maxsize_for_saferepr(util._config) + return saferepr(obj, maxsize=maxsize).replace("\n", "\\n") + + +def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]: + """Get `maxsize` configuration for saferepr based on the given config object.""" + verbosity = config.getoption("verbose") if config is not None else 0 + if verbosity >= 2: + return None + if verbosity >= 1: + return DEFAULT_REPR_MAX_SIZE * 10 + return DEFAULT_REPR_MAX_SIZE def _format_assertmsg(obj: object) -> str: @@ -495,7 +516,7 @@ def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None: def _check_if_assertion_pass_impl() -> bool: """Check if any plugins implement the pytest_assertion_pass hook - in order not to generate explanation unecessarily (might be expensive).""" + in order not to generate explanation unnecessarily (might be expensive).""" return True if util._assertion_pass else False @@ -528,21 +549,14 @@ BINOP_MAP = { } -def set_location(node, lineno, col_offset): - """Set node location information recursively.""" - - def _fix(node, lineno, col_offset): - if "lineno" in node._attributes: - node.lineno = lineno - if "col_offset" in node._attributes: - node.col_offset = col_offset - for child in ast.iter_child_nodes(node): - _fix(child, lineno, col_offset) - - _fix(node, lineno, col_offset) - return node +def traverse_node(node: ast.AST) -> Iterator[ast.AST]: + """Recursively yield node and all its children in depth-first order.""" + yield node + for child in ast.iter_child_nodes(node): + yield from traverse_node(child) +@functools.lru_cache(maxsize=1) def _get_assertion_exprs(src: bytes) -> Dict[int, str]: """Return a mapping from {lineno: "assertion test expression"}.""" ret: Dict[int, str] = {} @@ -664,10 +678,6 @@ class AssertionRewriter(ast.NodeVisitor): self.enable_assertion_pass_hook = False self.source = source - @functools.lru_cache(maxsize=1) - def _assert_expr_to_lineno(self) -> Dict[int, str]: - return _get_assertion_exprs(self.source) - def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" if not mod.body: @@ -854,7 +864,7 @@ class AssertionRewriter(ast.NodeVisitor): "assertion is always true, perhaps remove parentheses?" ), category=None, - filename=os.fspath(self.module_path), + filename=self.module_path, lineno=assert_.lineno, ) @@ -895,7 +905,7 @@ class AssertionRewriter(ast.NodeVisitor): # Passed fmt_pass = self.helper("_format_explanation", msg) - orig = self._assert_expr_to_lineno()[assert_.lineno] + orig = _get_assertion_exprs(self.source)[assert_.lineno] hook_call_pass = ast.Expr( self.helper( "_call_assertion_pass", @@ -946,9 +956,10 @@ class AssertionRewriter(ast.NodeVisitor): variables = [ast.Name(name, ast.Store()) for name in self.variables] clear = ast.Assign(variables, ast.NameConstant(None)) self.statements.append(clear) - # Fix line numbers. + # Fix locations (line numbers/column offsets). for stmt in self.statements: - set_location(stmt, assert_.lineno, assert_.col_offset) + for node in traverse_node(stmt): + ast.copy_location(node, assert_) return self.statements def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: @@ -1095,7 +1106,7 @@ def try_makedirs(cache_dir: Path) -> bool: Returns True if successful or if it already exists. """ try: - os.makedirs(os.fspath(cache_dir), exist_ok=True) + os.makedirs(cache_dir, exist_ok=True) except (FileNotFoundError, NotADirectoryError, FileExistsError): # One of the path components was not a directory: # - we're in a zip file |