diff options
author | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-02-14 00:49:36 +0300 |
---|---|---|
committer | arcadia-devtools <arcadia-devtools@yandex-team.ru> | 2022-02-14 00:49:36 +0300 |
commit | 82cfd1b7cab2d843cdf5467d9737f72597a493bd (patch) | |
tree | 1dfdcfe81a1a6b193ceacc2a828c521b657a339b /contrib/python/pytest/py3/_pytest/assertion | |
parent | 3df7211d3e3691f8e33b0a1fb1764fe810d59302 (diff) | |
download | ydb-82cfd1b7cab2d843cdf5467d9737f72597a493bd.tar.gz |
intermediate changes
ref:68b1302de4b5da30b6bdf02193f7a2604d8b5cf8
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
4 files changed, 88 insertions, 60 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py index a18cf198df..480a26ad86 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/__init__.py +++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py @@ -88,13 +88,13 @@ class AssertionState: def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: """Try to install the rewrite hook, raise SystemError if it fails.""" - config._store[assertstate_key] = AssertionState(config, "rewrite") - config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) + config.stash[assertstate_key] = AssertionState(config, "rewrite") + config.stash[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) sys.meta_path.insert(0, hook) - config._store[assertstate_key].trace("installed rewrite import hook") + config.stash[assertstate_key].trace("installed rewrite import hook") def undo() -> None: - hook = config._store[assertstate_key].hook + hook = config.stash[assertstate_key].hook if hook is not None and hook in sys.meta_path: sys.meta_path.remove(hook) @@ -104,9 +104,9 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: def pytest_collection(session: "Session") -> None: # This hook is only called when test modules are collected - # so for example not in the master process of pytest-xdist + # so for example not in the managing process of pytest-xdist # (which does not collect test modules). - assertstate = session.config._store.get(assertstate_key, None) + assertstate = session.config.stash.get(assertstate_key, None) if assertstate: if assertstate.hook is not None: assertstate.hook.set_session(session) @@ -153,6 +153,7 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: saved_assert_hooks = util._reprcompare, util._assertion_pass util._reprcompare = callbinrepr + util._config = item.config if ihook.pytest_assertion_pass.get_hookimpls(): @@ -164,10 +165,11 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: yield util._reprcompare, util._assertion_pass = saved_assert_hooks + util._config = None def pytest_sessionfinish(session: "Session") -> None: - assertstate = session.config._store.get(assertstate_key, None) + assertstate = session.config.stash.get(assertstate_key, None) if assertstate: if assertstate.hook is not None: assertstate.hook.set_session(None) diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py index 37ff076aab..88ac6cab36 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py +++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py @@ -19,6 +19,7 @@ from typing import Callable from typing import Dict from typing import IO from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Sequence @@ -27,8 +28,7 @@ from typing import Tuple from typing import TYPE_CHECKING from typing import Union -import py - +from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import saferepr from _pytest._version import version from _pytest.assertion import util @@ -37,14 +37,15 @@ from _pytest.assertion.util import ( # noqa: F401 ) from _pytest.config import Config from _pytest.main import Session +from _pytest.pathlib import absolutepath from _pytest.pathlib import fnmatch_ex -from _pytest.store import StoreKey +from _pytest.stash import StashKey if TYPE_CHECKING: from _pytest.assertion import AssertionState -assertstate_key = StoreKey["AssertionState"]() +assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -63,7 +64,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) except ValueError: self.fnpats = ["test_*.py", "*_test.py"] self.session: Optional[Session] = None - self._rewritten_names: Set[str] = set() + self._rewritten_names: Dict[str, Path] = {} self._must_rewrite: Set[str] = set() # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # which might result in infinite recursion (#3506) @@ -87,7 +88,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) ) -> Optional[importlib.machinery.ModuleSpec]: if self._writing_pyc: return None - state = self.config._store[assertstate_key] + state = self.config.stash[assertstate_key] if self._early_rewrite_bailout(name, state): return None state.trace("find_module called for: %s" % name) @@ -131,9 +132,9 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) assert module.__spec__ is not None assert module.__spec__.origin is not None fn = Path(module.__spec__.origin) - state = self.config._store[assertstate_key] + state = self.config.stash[assertstate_key] - self._rewritten_names.add(module.__name__) + self._rewritten_names[module.__name__] = fn # The requested module looks like a test file, so rewrite it. This is # the most magical part of the process: load the source, rewrite the @@ -215,7 +216,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) return True if self.session is not None: - if self.session.isinitpath(py.path.local(fn)): + if self.session.isinitpath(absolutepath(fn)): state.trace(f"matched test file (was specified on cmdline): {fn!r}") return True @@ -275,6 +276,16 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) with open(pathname, "rb") as f: return f.read() + if sys.version_info >= (3, 10): + + def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources: # type: ignore + if sys.version_info < (3, 11): + from importlib.readers import FileReader + else: + from importlib.resources.readers import FileReader + + return FileReader(types.SimpleNamespace(path=self._rewritten_names[name])) + def _write_pyc_fp( fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType @@ -333,7 +344,7 @@ else: try: _write_pyc_fp(fp, source_stat, co) - os.rename(proc_pyc, os.fspath(pyc)) + os.rename(proc_pyc, pyc) except OSError as e: state.trace(f"error writing pyc file at {pyc}: {e}") # we ignore any failure to write the cache file @@ -347,13 +358,12 @@ else: def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: """Read and rewrite *fn* and return the code object.""" - fn_ = os.fspath(fn) - stat = os.stat(fn_) - with open(fn_, "rb") as f: - source = f.read() - tree = ast.parse(source, filename=fn_) - rewrite_asserts(tree, source, fn_, config) - co = compile(tree, fn_, "exec", dont_inherit=True) + stat = os.stat(fn) + source = fn.read_bytes() + strfn = str(fn) + tree = ast.parse(source, filename=strfn) + rewrite_asserts(tree, source, strfn, config) + co = compile(tree, strfn, "exec", dont_inherit=True) return stat, co @@ -365,14 +375,14 @@ def _read_pyc( Return rewritten code if successful or None if not. """ try: - fp = open(os.fspath(pyc), "rb") + fp = open(pyc, "rb") except OSError: return None with fp: # https://www.python.org/dev/peps/pep-0552/ has_flags = sys.version_info >= (3, 7) try: - stat_result = os.stat(os.fspath(source)) + stat_result = os.stat(source) mtime = int(stat_result.st_mtime) size = stat_result.st_size data = fp.read(16 if has_flags else 12) @@ -428,7 +438,18 @@ def _saferepr(obj: object) -> str: sequences, especially '\n{' and '\n}' are likely to be present in JSON reprs. """ - return saferepr(obj).replace("\n", "\\n") + maxsize = _get_maxsize_for_saferepr(util._config) + return saferepr(obj, maxsize=maxsize).replace("\n", "\\n") + + +def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]: + """Get `maxsize` configuration for saferepr based on the given config object.""" + verbosity = config.getoption("verbose") if config is not None else 0 + if verbosity >= 2: + return None + if verbosity >= 1: + return DEFAULT_REPR_MAX_SIZE * 10 + return DEFAULT_REPR_MAX_SIZE def _format_assertmsg(obj: object) -> str: @@ -495,7 +516,7 @@ def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None: def _check_if_assertion_pass_impl() -> bool: """Check if any plugins implement the pytest_assertion_pass hook - in order not to generate explanation unecessarily (might be expensive).""" + in order not to generate explanation unnecessarily (might be expensive).""" return True if util._assertion_pass else False @@ -528,21 +549,14 @@ BINOP_MAP = { } -def set_location(node, lineno, col_offset): - """Set node location information recursively.""" - - def _fix(node, lineno, col_offset): - if "lineno" in node._attributes: - node.lineno = lineno - if "col_offset" in node._attributes: - node.col_offset = col_offset - for child in ast.iter_child_nodes(node): - _fix(child, lineno, col_offset) - - _fix(node, lineno, col_offset) - return node +def traverse_node(node: ast.AST) -> Iterator[ast.AST]: + """Recursively yield node and all its children in depth-first order.""" + yield node + for child in ast.iter_child_nodes(node): + yield from traverse_node(child) +@functools.lru_cache(maxsize=1) def _get_assertion_exprs(src: bytes) -> Dict[int, str]: """Return a mapping from {lineno: "assertion test expression"}.""" ret: Dict[int, str] = {} @@ -664,10 +678,6 @@ class AssertionRewriter(ast.NodeVisitor): self.enable_assertion_pass_hook = False self.source = source - @functools.lru_cache(maxsize=1) - def _assert_expr_to_lineno(self) -> Dict[int, str]: - return _get_assertion_exprs(self.source) - def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" if not mod.body: @@ -854,7 +864,7 @@ class AssertionRewriter(ast.NodeVisitor): "assertion is always true, perhaps remove parentheses?" ), category=None, - filename=os.fspath(self.module_path), + filename=self.module_path, lineno=assert_.lineno, ) @@ -895,7 +905,7 @@ class AssertionRewriter(ast.NodeVisitor): # Passed fmt_pass = self.helper("_format_explanation", msg) - orig = self._assert_expr_to_lineno()[assert_.lineno] + orig = _get_assertion_exprs(self.source)[assert_.lineno] hook_call_pass = ast.Expr( self.helper( "_call_assertion_pass", @@ -946,9 +956,10 @@ class AssertionRewriter(ast.NodeVisitor): variables = [ast.Name(name, ast.Store()) for name in self.variables] clear = ast.Assign(variables, ast.NameConstant(None)) self.statements.append(clear) - # Fix line numbers. + # Fix locations (line numbers/column offsets). for stmt in self.statements: - set_location(stmt, assert_.lineno, assert_.col_offset) + for node in traverse_node(stmt): + ast.copy_location(node, assert_) return self.statements def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: @@ -1095,7 +1106,7 @@ def try_makedirs(cache_dir: Path) -> bool: Returns True if successful or if it already exists. """ try: - os.makedirs(os.fspath(cache_dir), exist_ok=True) + os.makedirs(cache_dir, exist_ok=True) except (FileNotFoundError, NotADirectoryError, FileExistsError): # One of the path components was not a directory: # - we're in a zip file diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py index 5ba9ddca75..ce148dca09 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/truncate.py +++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py @@ -3,10 +3,10 @@ Current default behaviour is to truncate assertion explanations at ~8 terminal lines, unless running in "-vv" mode or running on CI. """ -import os from typing import List from typing import Optional +from _pytest.assertion import util from _pytest.nodes import Item @@ -27,13 +27,7 @@ def truncate_if_required( def _should_truncate_item(item: Item) -> bool: """Whether or not this test item is eligible for truncation.""" verbose = item.config.option.verbose - return verbose < 2 and not _running_on_ci() - - -def _running_on_ci() -> bool: - """Check if we're currently running on a CI system.""" - env_vars = ["CI", "BUILD_NUMBER"] - return any(var in os.environ for var in env_vars) + return verbose < 2 and not util.running_on_ci() def _truncate_explanation( diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py index da1ffd15e3..19f1089c20 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/util.py +++ b/contrib/python/pytest/py3/_pytest/assertion/util.py @@ -1,5 +1,6 @@ """Utilities for assertion debugging.""" import collections.abc +import os import pprint from typing import AbstractSet from typing import Any @@ -15,6 +16,7 @@ from _pytest import outcomes from _pytest._io.saferepr import _pformat_dispatch from _pytest._io.saferepr import safeformat from _pytest._io.saferepr import saferepr +from _pytest.config import Config # The _reprcompare attribute on the util module is used by the new assertion # interpretation code and assertion rewriter to detect this plugin was @@ -26,6 +28,9 @@ _reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None # when pytest_runtest_setup is called. _assertion_pass: Optional[Callable[[int, str, str], None]] = None +# Config object which is assigned during pytest_runtest_protocol. +_config: Optional[Config] = None + def format_explanation(explanation: str) -> str: r"""Format an explanation. @@ -175,7 +180,15 @@ def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: if istext(left) and istext(right): explanation = _diff_text(left, right, verbose) else: - if type(left) == type(right) and ( + from _pytest.python_api import ApproxBase + + if isinstance(left, ApproxBase) or isinstance(right, ApproxBase): + # Although the common order should be obtained == expected, this ensures both ways + approx_side = left if isinstance(left, ApproxBase) else right + other_side = right if isinstance(left, ApproxBase) else left + + explanation = approx_side._repr_compare(other_side) + elif type(left) == type(right) and ( isdatacls(left) or isattrs(left) or isnamedtuple(left) ): # Note: unlike dataclasses/attrs, namedtuples compare only the @@ -191,9 +204,11 @@ def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: explanation = _compare_eq_dict(left, right, verbose) elif verbose > 0: explanation = _compare_eq_verbose(left, right) + if isiterable(left) and isiterable(right): expl = _compare_eq_iterable(left, right, verbose) explanation.extend(expl) + return explanation @@ -272,7 +287,7 @@ def _surrounding_parens_on_own_lines(lines: List[str]) -> None: def _compare_eq_iterable( left: Iterable[Any], right: Iterable[Any], verbose: int = 0 ) -> List[str]: - if not verbose: + if not verbose and not running_on_ci(): return ["Use -v to get the full diff"] # dynamic import to speedup pytest import difflib @@ -475,3 +490,9 @@ def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]: else: newdiff.append(line) return newdiff + + +def running_on_ci() -> bool: + """Check if we're currently running on a CI system.""" + env_vars = ["CI", "BUILD_NUMBER"] + return any(var in os.environ for var in env_vars) |