From 82cfd1b7cab2d843cdf5467d9737f72597a493bd Mon Sep 17 00:00:00 2001
From: arcadia-devtools <arcadia-devtools@yandex-team.ru>
Date: Mon, 14 Feb 2022 00:49:36 +0300
Subject: intermediate changes ref:68b1302de4b5da30b6bdf02193f7a2604d8b5cf8

---
 .../python/pytest/py3/_pytest/assertion/rewrite.py | 97 ++++++++++++----------
 1 file changed, 54 insertions(+), 43 deletions(-)

(limited to 'contrib/python/pytest/py3/_pytest/assertion/rewrite.py')

diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
index 37ff076aab..88ac6cab36 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -19,6 +19,7 @@ from typing import Callable
 from typing import Dict
 from typing import IO
 from typing import Iterable
+from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -27,8 +28,7 @@ from typing import Tuple
 from typing import TYPE_CHECKING
 from typing import Union
 
-import py
-
+from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
 from _pytest._io.saferepr import saferepr
 from _pytest._version import version
 from _pytest.assertion import util
@@ -37,14 +37,15 @@ from _pytest.assertion.util import (  # noqa: F401
 )
 from _pytest.config import Config
 from _pytest.main import Session
+from _pytest.pathlib import absolutepath
 from _pytest.pathlib import fnmatch_ex
-from _pytest.store import StoreKey
+from _pytest.stash import StashKey
 
 if TYPE_CHECKING:
     from _pytest.assertion import AssertionState
 
 
-assertstate_key = StoreKey["AssertionState"]()
+assertstate_key = StashKey["AssertionState"]()
 
 
 # pytest caches rewritten pycs in pycache dirs
@@ -63,7 +64,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
         except ValueError:
             self.fnpats = ["test_*.py", "*_test.py"]
         self.session: Optional[Session] = None
-        self._rewritten_names: Set[str] = set()
+        self._rewritten_names: Dict[str, Path] = {}
         self._must_rewrite: Set[str] = set()
         # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
         # which might result in infinite recursion (#3506)
@@ -87,7 +88,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
     ) -> Optional[importlib.machinery.ModuleSpec]:
         if self._writing_pyc:
             return None
-        state = self.config._store[assertstate_key]
+        state = self.config.stash[assertstate_key]
         if self._early_rewrite_bailout(name, state):
             return None
         state.trace("find_module called for: %s" % name)
@@ -131,9 +132,9 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
         assert module.__spec__ is not None
         assert module.__spec__.origin is not None
         fn = Path(module.__spec__.origin)
-        state = self.config._store[assertstate_key]
+        state = self.config.stash[assertstate_key]
 
-        self._rewritten_names.add(module.__name__)
+        self._rewritten_names[module.__name__] = fn
 
         # The requested module looks like a test file, so rewrite it. This is
         # the most magical part of the process: load the source, rewrite the
@@ -215,7 +216,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
             return True
 
         if self.session is not None:
-            if self.session.isinitpath(py.path.local(fn)):
+            if self.session.isinitpath(absolutepath(fn)):
                 state.trace(f"matched test file (was specified on cmdline): {fn!r}")
                 return True
 
@@ -275,6 +276,16 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
         with open(pathname, "rb") as f:
             return f.read()
 
+    if sys.version_info >= (3, 10):
+
+        def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources:  # type: ignore
+            if sys.version_info < (3, 11):
+                from importlib.readers import FileReader
+            else:
+                from importlib.resources.readers import FileReader
+
+            return FileReader(types.SimpleNamespace(path=self._rewritten_names[name]))
+
 
 def _write_pyc_fp(
     fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
@@ -333,7 +344,7 @@ else:
 
         try:
             _write_pyc_fp(fp, source_stat, co)
-            os.rename(proc_pyc, os.fspath(pyc))
+            os.rename(proc_pyc, pyc)
         except OSError as e:
             state.trace(f"error writing pyc file at {pyc}: {e}")
             # we ignore any failure to write the cache file
@@ -347,13 +358,12 @@ else:
 
 def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
     """Read and rewrite *fn* and return the code object."""
-    fn_ = os.fspath(fn)
-    stat = os.stat(fn_)
-    with open(fn_, "rb") as f:
-        source = f.read()
-    tree = ast.parse(source, filename=fn_)
-    rewrite_asserts(tree, source, fn_, config)
-    co = compile(tree, fn_, "exec", dont_inherit=True)
+    stat = os.stat(fn)
+    source = fn.read_bytes()
+    strfn = str(fn)
+    tree = ast.parse(source, filename=strfn)
+    rewrite_asserts(tree, source, strfn, config)
+    co = compile(tree, strfn, "exec", dont_inherit=True)
     return stat, co
 
 
@@ -365,14 +375,14 @@ def _read_pyc(
     Return rewritten code if successful or None if not.
     """
     try:
-        fp = open(os.fspath(pyc), "rb")
+        fp = open(pyc, "rb")
     except OSError:
         return None
     with fp:
         # https://www.python.org/dev/peps/pep-0552/
         has_flags = sys.version_info >= (3, 7)
         try:
-            stat_result = os.stat(os.fspath(source))
+            stat_result = os.stat(source)
             mtime = int(stat_result.st_mtime)
             size = stat_result.st_size
             data = fp.read(16 if has_flags else 12)
@@ -428,7 +438,18 @@ def _saferepr(obj: object) -> str:
     sequences, especially '\n{' and '\n}' are likely to be present in
     JSON reprs.
     """
-    return saferepr(obj).replace("\n", "\\n")
+    maxsize = _get_maxsize_for_saferepr(util._config)
+    return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
+
+
+def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
+    """Get `maxsize` configuration for saferepr based on the given config object."""
+    verbosity = config.getoption("verbose") if config is not None else 0
+    if verbosity >= 2:
+        return None
+    if verbosity >= 1:
+        return DEFAULT_REPR_MAX_SIZE * 10
+    return DEFAULT_REPR_MAX_SIZE
 
 
 def _format_assertmsg(obj: object) -> str:
@@ -495,7 +516,7 @@ def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
 
 def _check_if_assertion_pass_impl() -> bool:
     """Check if any plugins implement the pytest_assertion_pass hook
-    in order not to generate explanation unecessarily (might be expensive)."""
+    in order not to generate explanation unnecessarily (might be expensive)."""
     return True if util._assertion_pass else False
 
 
@@ -528,21 +549,14 @@ BINOP_MAP = {
 }
 
 
-def set_location(node, lineno, col_offset):
-    """Set node location information recursively."""
-
-    def _fix(node, lineno, col_offset):
-        if "lineno" in node._attributes:
-            node.lineno = lineno
-        if "col_offset" in node._attributes:
-            node.col_offset = col_offset
-        for child in ast.iter_child_nodes(node):
-            _fix(child, lineno, col_offset)
-
-    _fix(node, lineno, col_offset)
-    return node
+def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
+    """Recursively yield node and all its children in depth-first order."""
+    yield node
+    for child in ast.iter_child_nodes(node):
+        yield from traverse_node(child)
 
 
+@functools.lru_cache(maxsize=1)
 def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
     """Return a mapping from {lineno: "assertion test expression"}."""
     ret: Dict[int, str] = {}
@@ -664,10 +678,6 @@ class AssertionRewriter(ast.NodeVisitor):
             self.enable_assertion_pass_hook = False
         self.source = source
 
-    @functools.lru_cache(maxsize=1)
-    def _assert_expr_to_lineno(self) -> Dict[int, str]:
-        return _get_assertion_exprs(self.source)
-
     def run(self, mod: ast.Module) -> None:
         """Find all assert statements in *mod* and rewrite them."""
         if not mod.body:
@@ -854,7 +864,7 @@ class AssertionRewriter(ast.NodeVisitor):
                     "assertion is always true, perhaps remove parentheses?"
                 ),
                 category=None,
-                filename=os.fspath(self.module_path),
+                filename=self.module_path,
                 lineno=assert_.lineno,
             )
 
@@ -895,7 +905,7 @@ class AssertionRewriter(ast.NodeVisitor):
 
             # Passed
             fmt_pass = self.helper("_format_explanation", msg)
-            orig = self._assert_expr_to_lineno()[assert_.lineno]
+            orig = _get_assertion_exprs(self.source)[assert_.lineno]
             hook_call_pass = ast.Expr(
                 self.helper(
                     "_call_assertion_pass",
@@ -946,9 +956,10 @@ class AssertionRewriter(ast.NodeVisitor):
             variables = [ast.Name(name, ast.Store()) for name in self.variables]
             clear = ast.Assign(variables, ast.NameConstant(None))
             self.statements.append(clear)
-        # Fix line numbers.
+        # Fix locations (line numbers/column offsets).
         for stmt in self.statements:
-            set_location(stmt, assert_.lineno, assert_.col_offset)
+            for node in traverse_node(stmt):
+                ast.copy_location(node, assert_)
         return self.statements
 
     def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
@@ -1095,7 +1106,7 @@ def try_makedirs(cache_dir: Path) -> bool:
     Returns True if successful or if it already exists.
     """
     try:
-        os.makedirs(os.fspath(cache_dir), exist_ok=True)
+        os.makedirs(cache_dir, exist_ok=True)
     except (FileNotFoundError, NotADirectoryError, FileExistsError):
         # One of the path components was not a directory:
         # - we're in a zip file
-- 
cgit v1.2.3