diff options
author | AlexSm <alex@ydb.tech> | 2024-01-09 18:56:40 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 18:56:40 +0100 |
commit | e95f266d2a3e48e62015220588a4fd73d5d5a5cb (patch) | |
tree | a8a784b6931fe52ad5f511cfef85af14e5f63991 /contrib/python/pytest/py3/_pytest/assertion | |
parent | 50a65e3b48a82d5b51f272664da389f2e0b0c99a (diff) | |
download | ydb-e95f266d2a3e48e62015220588a4fd73d5d5a5cb.tar.gz |
Library import 6 (#888)
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
-rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/rewrite.py | 68 | ||||
-rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/util.py | 2 |
2 files changed, 51 insertions, 19 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py index ab83fee32b2..d1974bb3b4a 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py +++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py @@ -13,6 +13,7 @@ import struct import sys import tokenize import types +from collections import defaultdict from pathlib import Path from pathlib import PurePath from typing import Callable @@ -56,6 +57,10 @@ else: astNum = ast.Num +class Sentinel: + pass + + assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -63,6 +68,9 @@ PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}" PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT +# Special marker that denotes we have just left a scope definition +_SCOPE_END_MARKER = Sentinel() + class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -596,6 +604,13 @@ def _get_assertion_exprs(src: bytes) -> Dict[int, str]: return ret +def _get_ast_constant_value(value: astStr) -> object: + if sys.version_info >= (3, 8): + return value.value + else: + return value.s + + class AssertionRewriter(ast.NodeVisitor): """Assertion rewriting implementation. @@ -645,6 +660,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. + :scope: A tuple containing the current scope used for variables_overwrite. + :variables_overwrite: A dict filled with references to variables that change value within an assert. This happens when a variable is reassigned with the walrus operator @@ -666,7 +683,10 @@ class AssertionRewriter(ast.NodeVisitor): else: self.enable_assertion_pass_hook = False self.source = source - self.variables_overwrite: Dict[str, str] = {} + self.scope: tuple[ast.AST, ...] = () + self.variables_overwrite: defaultdict[ + tuple[ast.AST, ...], Dict[str, str] + ] = defaultdict(dict) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -687,11 +707,10 @@ class AssertionRewriter(ast.NodeVisitor): expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, astStr) + and isinstance(_get_ast_constant_value(item.value), str) ): - if sys.version_info >= (3, 8): - doc = item.value.value - else: - doc = item.value.s + doc = _get_ast_constant_value(item.value) + assert isinstance(doc, str) if self.is_rewrite_disabled(doc): return expect_docstring = False @@ -732,9 +751,17 @@ class AssertionRewriter(ast.NodeVisitor): mod.body[pos:pos] = imports # Collect asserts. - nodes: List[ast.AST] = [mod] + self.scope = (mod,) + nodes: List[Union[ast.AST, Sentinel]] = [mod] while nodes: node = nodes.pop() + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + self.scope = tuple((*self.scope, node)) + nodes.append(_SCOPE_END_MARKER) + if node == _SCOPE_END_MARKER: + self.scope = self.scope[:-1] + continue + assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): new: List[ast.AST] = [] @@ -1005,7 +1032,7 @@ class AssertionRewriter(ast.NodeVisitor): ] ): pytest_temp = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ v.left.target.id ] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp @@ -1048,17 +1075,20 @@ class AssertionRewriter(ast.NodeVisitor): new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: - arg = self.variables_overwrite[arg.id] # type:ignore[assignment] + if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( + self.scope, {} + ): + arg = self.variables_overwrite[self.scope][ + arg.id + ] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if ( - isinstance(keyword.value, ast.Name) - and keyword.value.id in self.variables_overwrite - ): - keyword.value = self.variables_overwrite[ + if isinstance( + keyword.value, ast.Name + ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): + keyword.value = self.variables_overwrite[self.scope][ keyword.value.id ] # type:ignore[assignment] res, expl = self.visit(keyword.value) @@ -1094,12 +1124,14 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: - comp.left = self.variables_overwrite[ + if isinstance( + comp.left, ast.Name + ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): + comp.left = self.variables_overwrite[self.scope][ comp.left.id ] # type:ignore[assignment] if isinstance(comp.left, namedExpr): - self.variables_overwrite[ + self.variables_overwrite[self.scope][ comp.left.target.id ] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) @@ -1119,7 +1151,7 @@ class AssertionRewriter(ast.NodeVisitor): and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ left_res.id ] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py index fc5dfdbd5ba..39ca5403e04 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/util.py +++ b/contrib/python/pytest/py3/_pytest/assertion/util.py @@ -132,7 +132,7 @@ def isiterable(obj: Any) -> bool: try: iter(obj) return not istext(obj) - except TypeError: + except Exception: return False |