aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
diff options
context:
space:
mode:
authorarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-02-14 00:49:36 +0300
committerarcadia-devtools <arcadia-devtools@yandex-team.ru>2022-02-14 00:49:36 +0300
commit82cfd1b7cab2d843cdf5467d9737f72597a493bd (patch)
tree1dfdcfe81a1a6b193ceacc2a828c521b657a339b /contrib/python/pytest/py3/_pytest/assertion/rewrite.py
parent3df7211d3e3691f8e33b0a1fb1764fe810d59302 (diff)
downloadydb-82cfd1b7cab2d843cdf5467d9737f72597a493bd.tar.gz
intermediate changes
ref:68b1302de4b5da30b6bdf02193f7a2604d8b5cf8
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion/rewrite.py')
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/rewrite.py97
1 files changed, 54 insertions, 43 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
index 37ff076aab..88ac6cab36 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -19,6 +19,7 @@ from typing import Callable
from typing import Dict
from typing import IO
from typing import Iterable
+from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
@@ -27,8 +28,7 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
-import py
-
+from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr
from _pytest._version import version
from _pytest.assertion import util
@@ -37,14 +37,15 @@ from _pytest.assertion.util import ( # noqa: F401
)
from _pytest.config import Config
from _pytest.main import Session
+from _pytest.pathlib import absolutepath
from _pytest.pathlib import fnmatch_ex
-from _pytest.store import StoreKey
+from _pytest.stash import StashKey
if TYPE_CHECKING:
from _pytest.assertion import AssertionState
-assertstate_key = StoreKey["AssertionState"]()
+assertstate_key = StashKey["AssertionState"]()
# pytest caches rewritten pycs in pycache dirs
@@ -63,7 +64,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
self.session: Optional[Session] = None
- self._rewritten_names: Set[str] = set()
+ self._rewritten_names: Dict[str, Path] = {}
self._must_rewrite: Set[str] = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
@@ -87,7 +88,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
) -> Optional[importlib.machinery.ModuleSpec]:
if self._writing_pyc:
return None
- state = self.config._store[assertstate_key]
+ state = self.config.stash[assertstate_key]
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name)
@@ -131,9 +132,9 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
assert module.__spec__ is not None
assert module.__spec__.origin is not None
fn = Path(module.__spec__.origin)
- state = self.config._store[assertstate_key]
+ state = self.config.stash[assertstate_key]
- self._rewritten_names.add(module.__name__)
+ self._rewritten_names[module.__name__] = fn
# The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the
@@ -215,7 +216,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return True
if self.session is not None:
- if self.session.isinitpath(py.path.local(fn)):
+ if self.session.isinitpath(absolutepath(fn)):
state.trace(f"matched test file (was specified on cmdline): {fn!r}")
return True
@@ -275,6 +276,16 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
with open(pathname, "rb") as f:
return f.read()
+ if sys.version_info >= (3, 10):
+
+ def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources: # type: ignore
+ if sys.version_info < (3, 11):
+ from importlib.readers import FileReader
+ else:
+ from importlib.resources.readers import FileReader
+
+ return FileReader(types.SimpleNamespace(path=self._rewritten_names[name]))
+
def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
@@ -333,7 +344,7 @@ else:
try:
_write_pyc_fp(fp, source_stat, co)
- os.rename(proc_pyc, os.fspath(pyc))
+ os.rename(proc_pyc, pyc)
except OSError as e:
state.trace(f"error writing pyc file at {pyc}: {e}")
# we ignore any failure to write the cache file
@@ -347,13 +358,12 @@ else:
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
"""Read and rewrite *fn* and return the code object."""
- fn_ = os.fspath(fn)
- stat = os.stat(fn_)
- with open(fn_, "rb") as f:
- source = f.read()
- tree = ast.parse(source, filename=fn_)
- rewrite_asserts(tree, source, fn_, config)
- co = compile(tree, fn_, "exec", dont_inherit=True)
+ stat = os.stat(fn)
+ source = fn.read_bytes()
+ strfn = str(fn)
+ tree = ast.parse(source, filename=strfn)
+ rewrite_asserts(tree, source, strfn, config)
+ co = compile(tree, strfn, "exec", dont_inherit=True)
return stat, co
@@ -365,14 +375,14 @@ def _read_pyc(
Return rewritten code if successful or None if not.
"""
try:
- fp = open(os.fspath(pyc), "rb")
+ fp = open(pyc, "rb")
except OSError:
return None
with fp:
# https://www.python.org/dev/peps/pep-0552/
has_flags = sys.version_info >= (3, 7)
try:
- stat_result = os.stat(os.fspath(source))
+ stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(16 if has_flags else 12)
@@ -428,7 +438,18 @@ def _saferepr(obj: object) -> str:
sequences, especially '\n{' and '\n}' are likely to be present in
JSON reprs.
"""
- return saferepr(obj).replace("\n", "\\n")
+ maxsize = _get_maxsize_for_saferepr(util._config)
+ return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
+
+
+def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
+ """Get `maxsize` configuration for saferepr based on the given config object."""
+ verbosity = config.getoption("verbose") if config is not None else 0
+ if verbosity >= 2:
+ return None
+ if verbosity >= 1:
+ return DEFAULT_REPR_MAX_SIZE * 10
+ return DEFAULT_REPR_MAX_SIZE
def _format_assertmsg(obj: object) -> str:
@@ -495,7 +516,7 @@ def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
def _check_if_assertion_pass_impl() -> bool:
"""Check if any plugins implement the pytest_assertion_pass hook
- in order not to generate explanation unecessarily (might be expensive)."""
+ in order not to generate explanation unnecessarily (might be expensive)."""
return True if util._assertion_pass else False
@@ -528,21 +549,14 @@ BINOP_MAP = {
}
-def set_location(node, lineno, col_offset):
- """Set node location information recursively."""
-
- def _fix(node, lineno, col_offset):
- if "lineno" in node._attributes:
- node.lineno = lineno
- if "col_offset" in node._attributes:
- node.col_offset = col_offset
- for child in ast.iter_child_nodes(node):
- _fix(child, lineno, col_offset)
-
- _fix(node, lineno, col_offset)
- return node
+def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
+ """Recursively yield node and all its children in depth-first order."""
+ yield node
+ for child in ast.iter_child_nodes(node):
+ yield from traverse_node(child)
+@functools.lru_cache(maxsize=1)
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
"""Return a mapping from {lineno: "assertion test expression"}."""
ret: Dict[int, str] = {}
@@ -664,10 +678,6 @@ class AssertionRewriter(ast.NodeVisitor):
self.enable_assertion_pass_hook = False
self.source = source
- @functools.lru_cache(maxsize=1)
- def _assert_expr_to_lineno(self) -> Dict[int, str]:
- return _get_assertion_exprs(self.source)
-
def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
if not mod.body:
@@ -854,7 +864,7 @@ class AssertionRewriter(ast.NodeVisitor):
"assertion is always true, perhaps remove parentheses?"
),
category=None,
- filename=os.fspath(self.module_path),
+ filename=self.module_path,
lineno=assert_.lineno,
)
@@ -895,7 +905,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Passed
fmt_pass = self.helper("_format_explanation", msg)
- orig = self._assert_expr_to_lineno()[assert_.lineno]
+ orig = _get_assertion_exprs(self.source)[assert_.lineno]
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
@@ -946,9 +956,10 @@ class AssertionRewriter(ast.NodeVisitor):
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, ast.NameConstant(None))
self.statements.append(clear)
- # Fix line numbers.
+ # Fix locations (line numbers/column offsets).
for stmt in self.statements:
- set_location(stmt, assert_.lineno, assert_.col_offset)
+ for node in traverse_node(stmt):
+ ast.copy_location(node, assert_)
return self.statements
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
@@ -1095,7 +1106,7 @@ def try_makedirs(cache_dir: Path) -> bool:
Returns True if successful or if it already exists.
"""
try:
- os.makedirs(os.fspath(cache_dir), exist_ok=True)
+ os.makedirs(cache_dir, exist_ok=True)
except (FileNotFoundError, NotADirectoryError, FileExistsError):
# One of the path components was not a directory:
# - we're in a zip file