aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
diff options
context:
space:
mode:
authornkozlovskiy <nmk@ydb.tech>2023-09-29 12:24:06 +0300
committernkozlovskiy <nmk@ydb.tech>2023-09-29 12:41:34 +0300
commite0e3e1717e3d33762ce61950504f9637a6e669ed (patch)
treebca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/pytest/py3/_pytest/assertion/rewrite.py
parent38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff)
downloadydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz
add ydb deps
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion/rewrite.py')
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/rewrite.py1185
1 files changed, 1185 insertions, 0 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
new file mode 100644
index 0000000000..ab83fee32b
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -0,0 +1,1185 @@
+"""Rewrite assertion AST to produce nice error messages."""
+import ast
+import errno
+import functools
+import importlib.abc
+import importlib.machinery
+import importlib.util
+import io
+import itertools
+import marshal
+import os
+import struct
+import sys
+import tokenize
+import types
+from pathlib import Path
+from pathlib import PurePath
+from typing import Callable
+from typing import Dict
+from typing import IO
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
+
+from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
+from _pytest._io.saferepr import saferepr
+from _pytest._version import version
+from _pytest.assertion import util
+from _pytest.assertion.util import ( # noqa: F401
+ format_explanation as _format_explanation,
+)
+from _pytest.config import Config
+from _pytest.main import Session
+from _pytest.pathlib import absolutepath
+from _pytest.pathlib import fnmatch_ex
+from _pytest.stash import StashKey
+
+if TYPE_CHECKING:
+ from _pytest.assertion import AssertionState
+
+if sys.version_info >= (3, 8):
+ namedExpr = ast.NamedExpr
+ astNameConstant = ast.Constant
+ astStr = ast.Constant
+ astNum = ast.Constant
+else:
+ namedExpr = ast.Expr
+ astNameConstant = ast.NameConstant
+ astStr = ast.Str
+ astNum = ast.Num
+
+
+assertstate_key = StashKey["AssertionState"]()
+
+# pytest caches rewritten pycs in pycache dirs
+PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
+PYC_EXT = ".py" + (__debug__ and "c" or "o")
+PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
+
+
+class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
+ """PEP302/PEP451 import hook which rewrites asserts."""
+
+ def __init__(self, config: Config) -> None:
+ self.config = config
+ try:
+ self.fnpats = config.getini("python_files")
+ except ValueError:
+ self.fnpats = ["test_*.py", "*_test.py"]
+ self.session: Optional[Session] = None
+ self._rewritten_names: Dict[str, Path] = {}
+ self._must_rewrite: Set[str] = set()
+ # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
+ # which might result in infinite recursion (#3506)
+ self._writing_pyc = False
+ self._basenames_to_check_rewrite = {"conftest"}
+ self._marked_for_rewrite_cache: Dict[str, bool] = {}
+ self._session_paths_checked = False
+
+ def set_session(self, session: Optional[Session]) -> None:
+ self.session = session
+ self._session_paths_checked = False
+
+ # Indirection so we can mock calls to find_spec originated from the hook during testing
+ _find_spec = importlib.machinery.PathFinder.find_spec
+
+ def find_spec(
+ self,
+ name: str,
+ path: Optional[Sequence[Union[str, bytes]]] = None,
+ target: Optional[types.ModuleType] = None,
+ ) -> Optional[importlib.machinery.ModuleSpec]:
+ if self._writing_pyc:
+ return None
+ state = self.config.stash[assertstate_key]
+ if self._early_rewrite_bailout(name, state):
+ return None
+ state.trace("find_module called for: %s" % name)
+
+ # Type ignored because mypy is confused about the `self` binding here.
+ spec = self._find_spec(name, path) # type: ignore
+ if (
+ # the import machinery could not find a file to import
+ spec is None
+ # this is a namespace package (without `__init__.py`)
+ # there's nothing to rewrite there
+ or spec.origin is None
+ # we can only rewrite source files
+ or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
+ # if the file doesn't exist, we can't rewrite it
+ or not os.path.exists(spec.origin)
+ ):
+ return None
+ else:
+ fn = spec.origin
+
+ if not self._should_rewrite(name, fn, state):
+ return None
+
+ return importlib.util.spec_from_file_location(
+ name,
+ fn,
+ loader=self,
+ submodule_search_locations=spec.submodule_search_locations,
+ )
+
+ def create_module(
+ self, spec: importlib.machinery.ModuleSpec
+ ) -> Optional[types.ModuleType]:
+ return None # default behaviour is fine
+
+ def exec_module(self, module: types.ModuleType) -> None:
+ assert module.__spec__ is not None
+ assert module.__spec__.origin is not None
+ fn = Path(module.__spec__.origin)
+ state = self.config.stash[assertstate_key]
+
+ self._rewritten_names[module.__name__] = fn
+
+ # The requested module looks like a test file, so rewrite it. This is
+ # the most magical part of the process: load the source, rewrite the
+ # asserts, and load the rewritten source. We also cache the rewritten
+ # module code in a special pyc. We must be aware of the possibility of
+ # concurrent pytest processes rewriting and loading pycs. To avoid
+ # tricky race conditions, we maintain the following invariant: The
+ # cached pyc is always a complete, valid pyc. Operations on it must be
+ # atomic. POSIX's atomic rename comes in handy.
+ write = not sys.dont_write_bytecode
+ cache_dir = get_cache_dir(fn)
+ if write:
+ ok = try_makedirs(cache_dir)
+ if not ok:
+ write = False
+ state.trace(f"read only directory: {cache_dir}")
+
+ cache_name = fn.name[:-3] + PYC_TAIL
+ pyc = cache_dir / cache_name
+ # Notice that even if we're in a read-only directory, I'm going
+ # to check for a cached pyc. This may not be optimal...
+ co = _read_pyc(fn, pyc, state.trace)
+ if co is None:
+ state.trace(f"rewriting {fn!r}")
+ source_stat, co = _rewrite_test(fn, self.config)
+ if write:
+ self._writing_pyc = True
+ try:
+ _write_pyc(state, co, source_stat, pyc)
+ finally:
+ self._writing_pyc = False
+ else:
+ state.trace(f"found cached rewritten pyc for {fn}")
+ exec(co, module.__dict__)
+
+ def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
+ """A fast way to get out of rewriting modules.
+
+ Profiling has shown that the call to PathFinder.find_spec (inside of
+ the find_spec from this class) is a major slowdown, so, this method
+ tries to filter what we're sure won't be rewritten before getting to
+ it.
+ """
+ if self.session is not None and not self._session_paths_checked:
+ self._session_paths_checked = True
+ for initial_path in self.session._initialpaths:
+ # Make something as c:/projects/my_project/path.py ->
+ # ['c:', 'projects', 'my_project', 'path.py']
+ parts = str(initial_path).split(os.sep)
+ # add 'path' to basenames to be checked.
+ self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
+
+ # Note: conftest already by default in _basenames_to_check_rewrite.
+ parts = name.split(".")
+ if parts[-1] in self._basenames_to_check_rewrite:
+ return False
+
+ # For matching the name it must be as if it was a filename.
+ path = PurePath(*parts).with_suffix(".py")
+
+ for pat in self.fnpats:
+ # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
+ # on the name alone because we need to match against the full path
+ if os.path.dirname(pat):
+ return False
+ if fnmatch_ex(pat, path):
+ return False
+
+ if self._is_marked_for_rewrite(name, state):
+ return False
+
+ state.trace(f"early skip of rewriting module: {name}")
+ return True
+
+ def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
+ # always rewrite conftest files
+ if os.path.basename(fn) == "conftest.py":
+ state.trace(f"rewriting conftest file: {fn!r}")
+ return True
+
+ if self.session is not None:
+ if self.session.isinitpath(absolutepath(fn)):
+ state.trace(f"matched test file (was specified on cmdline): {fn!r}")
+ return True
+
+ # modules not passed explicitly on the command line are only
+ # rewritten if they match the naming convention for test files
+ fn_path = PurePath(fn)
+ for pat in self.fnpats:
+ if fnmatch_ex(pat, fn_path):
+ state.trace(f"matched test file {fn!r}")
+ return True
+
+ return self._is_marked_for_rewrite(name, state)
+
+ def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
+ try:
+ return self._marked_for_rewrite_cache[name]
+ except KeyError:
+ for marked in self._must_rewrite:
+ if name == marked or name.startswith(marked + "."):
+ state.trace(f"matched marked file {name!r} (from {marked!r})")
+ self._marked_for_rewrite_cache[name] = True
+ return True
+
+ self._marked_for_rewrite_cache[name] = False
+ return False
+
+ def mark_rewrite(self, *names: str) -> None:
+ """Mark import names as needing to be rewritten.
+
+ The named module or package as well as any nested modules will
+ be rewritten on import.
+ """
+ already_imported = (
+ set(names).intersection(sys.modules).difference(self._rewritten_names)
+ )
+ for name in already_imported:
+ mod = sys.modules[name]
+ if not AssertionRewriter.is_rewrite_disabled(
+ mod.__doc__ or ""
+ ) and not isinstance(mod.__loader__, type(self)):
+ self._warn_already_imported(name)
+ self._must_rewrite.update(names)
+ self._marked_for_rewrite_cache.clear()
+
+ def _warn_already_imported(self, name: str) -> None:
+ from _pytest.warning_types import PytestAssertRewriteWarning
+
+ self.config.issue_config_time_warning(
+ PytestAssertRewriteWarning(
+ "Module already imported so cannot be rewritten: %s" % name
+ ),
+ stacklevel=5,
+ )
+
+ def get_data(self, pathname: Union[str, bytes]) -> bytes:
+ """Optional PEP302 get_data API."""
+ with open(pathname, "rb") as f:
+ return f.read()
+
+ if sys.version_info >= (3, 10):
+ if sys.version_info >= (3, 12):
+ from importlib.resources.abc import TraversableResources
+ else:
+ from importlib.abc import TraversableResources
+
+ def get_resource_reader(self, name: str) -> TraversableResources: # type: ignore
+ if sys.version_info < (3, 11):
+ from importlib.readers import FileReader
+ else:
+ from importlib.resources.readers import FileReader
+
+ return FileReader( # type:ignore[no-any-return]
+ types.SimpleNamespace(path=self._rewritten_names[name])
+ )
+
+
+def _write_pyc_fp(
+ fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
+) -> None:
+ # Technically, we don't have to have the same pyc format as
+ # (C)Python, since these "pycs" should never be seen by builtin
+ # import. However, there's little reason to deviate.
+ fp.write(importlib.util.MAGIC_NUMBER)
+ # https://www.python.org/dev/peps/pep-0552/
+ flags = b"\x00\x00\x00\x00"
+ fp.write(flags)
+ # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
+ mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
+ size = source_stat.st_size & 0xFFFFFFFF
+ # "<LL" stands for 2 unsigned longs, little-endian.
+ fp.write(struct.pack("<LL", mtime, size))
+ fp.write(marshal.dumps(co))
+
+
+def _write_pyc(
+ state: "AssertionState",
+ co: types.CodeType,
+ source_stat: os.stat_result,
+ pyc: Path,
+) -> bool:
+ proc_pyc = f"{pyc}.{os.getpid()}"
+ try:
+ with open(proc_pyc, "wb") as fp:
+ _write_pyc_fp(fp, source_stat, co)
+ except OSError as e:
+ state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
+ return False
+
+ try:
+ os.replace(proc_pyc, pyc)
+ except OSError as e:
+ state.trace(f"error writing pyc file at {pyc}: {e}")
+ # we ignore any failure to write the cache file
+ # there are many reasons, permission-denied, pycache dir being a
+ # file etc.
+ return False
+ return True
+
+
+def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
+ """Read and rewrite *fn* and return the code object."""
+ stat = os.stat(fn)
+ source = fn.read_bytes()
+ strfn = str(fn)
+ tree = ast.parse(source, filename=strfn)
+ rewrite_asserts(tree, source, strfn, config)
+ co = compile(tree, strfn, "exec", dont_inherit=True)
+ return stat, co
+
+
+def _read_pyc(
+ source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
+) -> Optional[types.CodeType]:
+ """Possibly read a pytest pyc containing rewritten code.
+
+ Return rewritten code if successful or None if not.
+ """
+ try:
+ fp = open(pyc, "rb")
+ except OSError:
+ return None
+ with fp:
+ try:
+ stat_result = os.stat(source)
+ mtime = int(stat_result.st_mtime)
+ size = stat_result.st_size
+ data = fp.read(16)
+ except OSError as e:
+ trace(f"_read_pyc({source}): OSError {e}")
+ return None
+ # Check for invalid or out of date pyc file.
+ if len(data) != (16):
+ trace("_read_pyc(%s): invalid pyc (too short)" % source)
+ return None
+ if data[:4] != importlib.util.MAGIC_NUMBER:
+ trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
+ return None
+ if data[4:8] != b"\x00\x00\x00\x00":
+ trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
+ return None
+ mtime_data = data[8:12]
+ if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
+ trace("_read_pyc(%s): out of date" % source)
+ return None
+ size_data = data[12:16]
+ if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
+ trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
+ return None
+ try:
+ co = marshal.load(fp)
+ except Exception as e:
+ trace(f"_read_pyc({source}): marshal.load error {e}")
+ return None
+ if not isinstance(co, types.CodeType):
+ trace("_read_pyc(%s): not a code object" % source)
+ return None
+ return co
+
+
+def rewrite_asserts(
+ mod: ast.Module,
+ source: bytes,
+ module_path: Optional[str] = None,
+ config: Optional[Config] = None,
+) -> None:
+ """Rewrite the assert statements in mod."""
+ AssertionRewriter(module_path, config, source).run(mod)
+
+
+def _saferepr(obj: object) -> str:
+ r"""Get a safe repr of an object for assertion error messages.
+
+ The assertion formatting (util.format_explanation()) requires
+ newlines to be escaped since they are a special character for it.
+ Normally assertion.util.format_explanation() does this but for a
+ custom repr it is possible to contain one of the special escape
+ sequences, especially '\n{' and '\n}' are likely to be present in
+ JSON reprs.
+ """
+ maxsize = _get_maxsize_for_saferepr(util._config)
+ return saferepr(obj, maxsize=maxsize).replace("\n", "\\n")
+
+
+def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
+ """Get `maxsize` configuration for saferepr based on the given config object."""
+ verbosity = config.getoption("verbose") if config is not None else 0
+ if verbosity >= 2:
+ return None
+ if verbosity >= 1:
+ return DEFAULT_REPR_MAX_SIZE * 10
+ return DEFAULT_REPR_MAX_SIZE
+
+
+def _format_assertmsg(obj: object) -> str:
+ r"""Format the custom assertion message given.
+
+ For strings this simply replaces newlines with '\n~' so that
+ util.format_explanation() will preserve them instead of escaping
+ newlines. For other objects saferepr() is used first.
+ """
+ # reprlib appears to have a bug which means that if a string
+ # contains a newline it gets escaped, however if an object has a
+ # .__repr__() which contains newlines it does not get escaped.
+ # However in either case we want to preserve the newline.
+ replaces = [("\n", "\n~"), ("%", "%%")]
+ if not isinstance(obj, str):
+ obj = saferepr(obj)
+ replaces.append(("\\n", "\n~"))
+
+ for r1, r2 in replaces:
+ obj = obj.replace(r1, r2)
+
+ return obj
+
+
+def _should_repr_global_name(obj: object) -> bool:
+ if callable(obj):
+ return False
+
+ try:
+ return not hasattr(obj, "__name__")
+ except Exception:
+ return True
+
+
+def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
+ explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
+ return explanation.replace("%", "%%")
+
+
+def _call_reprcompare(
+ ops: Sequence[str],
+ results: Sequence[bool],
+ expls: Sequence[str],
+ each_obj: Sequence[object],
+) -> str:
+ for i, res, expl in zip(range(len(ops)), results, expls):
+ try:
+ done = not res
+ except Exception:
+ done = True
+ if done:
+ break
+ if util._reprcompare is not None:
+ custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
+ if custom is not None:
+ return custom
+ return expl
+
+
+def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
+ if util._assertion_pass is not None:
+ util._assertion_pass(lineno, orig, expl)
+
+
+def _check_if_assertion_pass_impl() -> bool:
+ """Check if any plugins implement the pytest_assertion_pass hook
+ in order not to generate explanation unnecessarily (might be expensive)."""
+ return True if util._assertion_pass else False
+
+
+UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
+
+BINOP_MAP = {
+ ast.BitOr: "|",
+ ast.BitXor: "^",
+ ast.BitAnd: "&",
+ ast.LShift: "<<",
+ ast.RShift: ">>",
+ ast.Add: "+",
+ ast.Sub: "-",
+ ast.Mult: "*",
+ ast.Div: "/",
+ ast.FloorDiv: "//",
+ ast.Mod: "%%", # escaped for string formatting
+ ast.Eq: "==",
+ ast.NotEq: "!=",
+ ast.Lt: "<",
+ ast.LtE: "<=",
+ ast.Gt: ">",
+ ast.GtE: ">=",
+ ast.Pow: "**",
+ ast.Is: "is",
+ ast.IsNot: "is not",
+ ast.In: "in",
+ ast.NotIn: "not in",
+ ast.MatMult: "@",
+}
+
+
+def traverse_node(node: ast.AST) -> Iterator[ast.AST]:
+ """Recursively yield node and all its children in depth-first order."""
+ yield node
+ for child in ast.iter_child_nodes(node):
+ yield from traverse_node(child)
+
+
+@functools.lru_cache(maxsize=1)
+def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
+ """Return a mapping from {lineno: "assertion test expression"}."""
+ ret: Dict[int, str] = {}
+
+ depth = 0
+ lines: List[str] = []
+ assert_lineno: Optional[int] = None
+ seen_lines: Set[int] = set()
+
+ def _write_and_reset() -> None:
+ nonlocal depth, lines, assert_lineno, seen_lines
+ assert assert_lineno is not None
+ ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
+ depth = 0
+ lines = []
+ assert_lineno = None
+ seen_lines = set()
+
+ tokens = tokenize.tokenize(io.BytesIO(src).readline)
+ for tp, source, (lineno, offset), _, line in tokens:
+ if tp == tokenize.NAME and source == "assert":
+ assert_lineno = lineno
+ elif assert_lineno is not None:
+ # keep track of depth for the assert-message `,` lookup
+ if tp == tokenize.OP and source in "([{":
+ depth += 1
+ elif tp == tokenize.OP and source in ")]}":
+ depth -= 1
+
+ if not lines:
+ lines.append(line[offset:])
+ seen_lines.add(lineno)
+ # a non-nested comma separates the expression from the message
+ elif depth == 0 and tp == tokenize.OP and source == ",":
+ # one line assert with message
+ if lineno in seen_lines and len(lines) == 1:
+ offset_in_trimmed = offset + len(lines[-1]) - len(line)
+ lines[-1] = lines[-1][:offset_in_trimmed]
+ # multi-line assert with message
+ elif lineno in seen_lines:
+ lines[-1] = lines[-1][:offset]
+ # multi line assert with escapd newline before message
+ else:
+ lines.append(line[:offset])
+ _write_and_reset()
+ elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
+ _write_and_reset()
+ elif lines and lineno not in seen_lines:
+ lines.append(line)
+ seen_lines.add(lineno)
+
+ return ret
+
+
+class AssertionRewriter(ast.NodeVisitor):
+ """Assertion rewriting implementation.
+
+ The main entrypoint is to call .run() with an ast.Module instance,
+ this will then find all the assert statements and rewrite them to
+ provide intermediate values and a detailed assertion error. See
+ http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
+ for an overview of how this works.
+
+ The entry point here is .run() which will iterate over all the
+ statements in an ast.Module and for each ast.Assert statement it
+ finds call .visit() with it. Then .visit_Assert() takes over and
+ is responsible for creating new ast statements to replace the
+ original assert statement: it rewrites the test of an assertion
+ to provide intermediate values and replace it with an if statement
+ which raises an assertion error with a detailed explanation in
+ case the expression is false and calls pytest_assertion_pass hook
+ if expression is true.
+
+ For this .visit_Assert() uses the visitor pattern to visit all the
+ AST nodes of the ast.Assert.test field, each visit call returning
+ an AST node and the corresponding explanation string. During this
+ state is kept in several instance attributes:
+
+ :statements: All the AST statements which will replace the assert
+ statement.
+
+ :variables: This is populated by .variable() with each variable
+ used by the statements so that they can all be set to None at
+ the end of the statements.
+
+ :variable_counter: Counter to create new unique variables needed
+ by statements. Variables are created using .variable() and
+ have the form of "@py_assert0".
+
+ :expl_stmts: The AST statements which will be executed to get
+ data from the assertion. This is the code which will construct
+ the detailed assertion message that is used in the AssertionError
+ or for the pytest_assertion_pass hook.
+
+ :explanation_specifiers: A dict filled by .explanation_param()
+ with %-formatting placeholders and their corresponding
+ expressions to use in the building of an assertion message.
+ This is used by .pop_format_context() to build a message.
+
+ :stack: A stack of the explanation_specifiers dicts maintained by
+ .push_format_context() and .pop_format_context() which allows
+ to build another %-formatted string while already building one.
+
+ :variables_overwrite: A dict filled with references to variables
+ that change value within an assert. This happens when a variable is
+ reassigned with the walrus operator
+
+ This state, except the variables_overwrite, is reset on every new assert
+ statement visited and used by the other visitors.
+ """
+
+ def __init__(
+ self, module_path: Optional[str], config: Optional[Config], source: bytes
+ ) -> None:
+ super().__init__()
+ self.module_path = module_path
+ self.config = config
+ if config is not None:
+ self.enable_assertion_pass_hook = config.getini(
+ "enable_assertion_pass_hook"
+ )
+ else:
+ self.enable_assertion_pass_hook = False
+ self.source = source
+ self.variables_overwrite: Dict[str, str] = {}
+
+ def run(self, mod: ast.Module) -> None:
+ """Find all assert statements in *mod* and rewrite them."""
+ if not mod.body:
+ # Nothing to do.
+ return
+
+ # We'll insert some special imports at the top of the module, but after any
+ # docstrings and __future__ imports, so first figure out where that is.
+ doc = getattr(mod, "docstring", None)
+ expect_docstring = doc is None
+ if doc is not None and self.is_rewrite_disabled(doc):
+ return
+ pos = 0
+ item = None
+ for item in mod.body:
+ if (
+ expect_docstring
+ and isinstance(item, ast.Expr)
+ and isinstance(item.value, astStr)
+ ):
+ if sys.version_info >= (3, 8):
+ doc = item.value.value
+ else:
+ doc = item.value.s
+ if self.is_rewrite_disabled(doc):
+ return
+ expect_docstring = False
+ elif (
+ isinstance(item, ast.ImportFrom)
+ and item.level == 0
+ and item.module == "__future__"
+ ):
+ pass
+ else:
+ break
+ pos += 1
+ # Special case: for a decorated function, set the lineno to that of the
+ # first decorator, not the `def`. Issue #4984.
+ if isinstance(item, ast.FunctionDef) and item.decorator_list:
+ lineno = item.decorator_list[0].lineno
+ else:
+ lineno = item.lineno
+ # Now actually insert the special imports.
+ if sys.version_info >= (3, 10):
+ aliases = [
+ ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
+ ast.alias(
+ "_pytest.assertion.rewrite",
+ "@pytest_ar",
+ lineno=lineno,
+ col_offset=0,
+ ),
+ ]
+ else:
+ aliases = [
+ ast.alias("builtins", "@py_builtins"),
+ ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
+ ]
+ imports = [
+ ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
+ ]
+ mod.body[pos:pos] = imports
+
+ # Collect asserts.
+ nodes: List[ast.AST] = [mod]
+ while nodes:
+ node = nodes.pop()
+ for name, field in ast.iter_fields(node):
+ if isinstance(field, list):
+ new: List[ast.AST] = []
+ for i, child in enumerate(field):
+ if isinstance(child, ast.Assert):
+ # Transform assert.
+ new.extend(self.visit(child))
+ else:
+ new.append(child)
+ if isinstance(child, ast.AST):
+ nodes.append(child)
+ setattr(node, name, new)
+ elif (
+ isinstance(field, ast.AST)
+ # Don't recurse into expressions as they can't contain
+ # asserts.
+ and not isinstance(field, ast.expr)
+ ):
+ nodes.append(field)
+
+ @staticmethod
+ def is_rewrite_disabled(docstring: str) -> bool:
+ return "PYTEST_DONT_REWRITE" in docstring
+
+ def variable(self) -> str:
+ """Get a new variable."""
+ # Use a character invalid in python identifiers to avoid clashing.
+ name = "@py_assert" + str(next(self.variable_counter))
+ self.variables.append(name)
+ return name
+
+ def assign(self, expr: ast.expr) -> ast.Name:
+ """Give *expr* a name."""
+ name = self.variable()
+ self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
+ return ast.Name(name, ast.Load())
+
+ def display(self, expr: ast.expr) -> ast.expr:
+ """Call saferepr on the expression."""
+ return self.helper("_saferepr", expr)
+
+ def helper(self, name: str, *args: ast.expr) -> ast.expr:
+ """Call a helper in this module."""
+ py_name = ast.Name("@pytest_ar", ast.Load())
+ attr = ast.Attribute(py_name, name, ast.Load())
+ return ast.Call(attr, list(args), [])
+
+ def builtin(self, name: str) -> ast.Attribute:
+ """Return the builtin called *name*."""
+ builtin_name = ast.Name("@py_builtins", ast.Load())
+ return ast.Attribute(builtin_name, name, ast.Load())
+
+ def explanation_param(self, expr: ast.expr) -> str:
+ """Return a new named %-formatting placeholder for expr.
+
+ This creates a %-formatting placeholder for expr in the
+ current formatting context, e.g. ``%(py0)s``. The placeholder
+ and expr are placed in the current format context so that it
+ can be used on the next call to .pop_format_context().
+ """
+ specifier = "py" + str(next(self.variable_counter))
+ self.explanation_specifiers[specifier] = expr
+ return "%(" + specifier + ")s"
+
+ def push_format_context(self) -> None:
+ """Create a new formatting context.
+
+ The format context is used for when an explanation wants to
+ have a variable value formatted in the assertion message. In
+ this case the value required can be added using
+ .explanation_param(). Finally .pop_format_context() is used
+ to format a string of %-formatted values as added by
+ .explanation_param().
+ """
+ self.explanation_specifiers: Dict[str, ast.expr] = {}
+ self.stack.append(self.explanation_specifiers)
+
+ def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
+ """Format the %-formatted string with current format context.
+
+ The expl_expr should be an str ast.expr instance constructed from
+ the %-placeholders created by .explanation_param(). This will
+ add the required code to format said string to .expl_stmts and
+ return the ast.Name instance of the formatted string.
+ """
+ current = self.stack.pop()
+ if self.stack:
+ self.explanation_specifiers = self.stack[-1]
+ keys = [astStr(key) for key in current.keys()]
+ format_dict = ast.Dict(keys, list(current.values()))
+ form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
+ name = "@py_format" + str(next(self.variable_counter))
+ if self.enable_assertion_pass_hook:
+ self.format_variables.append(name)
+ self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
+ return ast.Name(name, ast.Load())
+
+ def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
+ """Handle expressions we don't have custom code for."""
+ assert isinstance(node, ast.expr)
+ res = self.assign(node)
+ return res, self.explanation_param(self.display(res))
+
+ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
+ """Return the AST statements to replace the ast.Assert instance.
+
+ This rewrites the test of an assertion to provide
+ intermediate values and replace it with an if statement which
+ raises an assertion error with a detailed explanation in case
+ the expression is false.
+ """
+ if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
+ from _pytest.warning_types import PytestAssertRewriteWarning
+ import warnings
+
+ # TODO: This assert should not be needed.
+ assert self.module_path is not None
+ warnings.warn_explicit(
+ PytestAssertRewriteWarning(
+ "assertion is always true, perhaps remove parentheses?"
+ ),
+ category=None,
+ filename=self.module_path,
+ lineno=assert_.lineno,
+ )
+
+ self.statements: List[ast.stmt] = []
+ self.variables: List[str] = []
+ self.variable_counter = itertools.count()
+
+ if self.enable_assertion_pass_hook:
+ self.format_variables: List[str] = []
+
+ self.stack: List[Dict[str, ast.expr]] = []
+ self.expl_stmts: List[ast.stmt] = []
+ self.push_format_context()
+ # Rewrite assert into a bunch of statements.
+ top_condition, explanation = self.visit(assert_.test)
+
+ negation = ast.UnaryOp(ast.Not(), top_condition)
+
+ if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
+ msg = self.pop_format_context(astStr(explanation))
+
+ # Failed
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ gluestr = "\n>assert "
+ else:
+ assertmsg = astStr("")
+ gluestr = "assert "
+ err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
+ err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
+ err_name = ast.Name("AssertionError", ast.Load())
+ fmt = self.helper("_format_explanation", err_msg)
+ exc = ast.Call(err_name, [fmt], [])
+ raise_ = ast.Raise(exc, None)
+ statements_fail = []
+ statements_fail.extend(self.expl_stmts)
+ statements_fail.append(raise_)
+
+ # Passed
+ fmt_pass = self.helper("_format_explanation", msg)
+ orig = _get_assertion_exprs(self.source)[assert_.lineno]
+ hook_call_pass = ast.Expr(
+ self.helper(
+ "_call_assertion_pass",
+ astNum(assert_.lineno),
+ astStr(orig),
+ fmt_pass,
+ )
+ )
+ # If any hooks implement assert_pass hook
+ hook_impl_test = ast.If(
+ self.helper("_check_if_assertion_pass_impl"),
+ self.expl_stmts + [hook_call_pass],
+ [],
+ )
+ statements_pass = [hook_impl_test]
+
+ # Test for assertion condition
+ main_test = ast.If(negation, statements_fail, statements_pass)
+ self.statements.append(main_test)
+ if self.format_variables:
+ variables = [
+ ast.Name(name, ast.Store()) for name in self.format_variables
+ ]
+ clear_format = ast.Assign(variables, astNameConstant(None))
+ self.statements.append(clear_format)
+
+ else: # Original assertion rewriting
+ # Create failure message.
+ body = self.expl_stmts
+ self.statements.append(ast.If(negation, body, []))
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ explanation = "\n>assert " + explanation
+ else:
+ assertmsg = astStr("")
+ explanation = "assert " + explanation
+ template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
+ msg = self.pop_format_context(template)
+ fmt = self.helper("_format_explanation", msg)
+ err_name = ast.Name("AssertionError", ast.Load())
+ exc = ast.Call(err_name, [fmt], [])
+ raise_ = ast.Raise(exc, None)
+
+ body.append(raise_)
+
+ # Clear temporary variables by setting them to None.
+ if self.variables:
+ variables = [ast.Name(name, ast.Store()) for name in self.variables]
+ clear = ast.Assign(variables, astNameConstant(None))
+ self.statements.append(clear)
+ # Fix locations (line numbers/column offsets).
+ for stmt in self.statements:
+ for node in traverse_node(stmt):
+ ast.copy_location(node, assert_)
+ return self.statements
+
+ def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
+ # This method handles the 'walrus operator' repr of the target
+ # name if it's a local variable or _should_repr_global_name()
+ # thinks it's acceptable.
+ locs = ast.Call(self.builtin("locals"), [], [])
+ target_id = name.target.id # type: ignore[attr-defined]
+ inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
+ dorepr = self.helper("_should_repr_global_name", name)
+ test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
+ expr = ast.IfExp(test, self.display(name), astStr(target_id))
+ return name, self.explanation_param(expr)
+
+ def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
+ # Display the repr of the name if it's a local variable or
+ # _should_repr_global_name() thinks it's acceptable.
+ locs = ast.Call(self.builtin("locals"), [], [])
+ inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
+ dorepr = self.helper("_should_repr_global_name", name)
+ test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
+ expr = ast.IfExp(test, self.display(name), astStr(name.id))
+ return name, self.explanation_param(expr)
+
+ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
+ res_var = self.variable()
+ expl_list = self.assign(ast.List([], ast.Load()))
+ app = ast.Attribute(expl_list, "append", ast.Load())
+ is_or = int(isinstance(boolop.op, ast.Or))
+ body = save = self.statements
+ fail_save = self.expl_stmts
+ levels = len(boolop.values) - 1
+ self.push_format_context()
+ # Process each operand, short-circuiting if needed.
+ for i, v in enumerate(boolop.values):
+ if i:
+ fail_inner: List[ast.stmt] = []
+ # cond is set in a prior loop iteration below
+ self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
+ self.expl_stmts = fail_inner
+ # Check if the left operand is a namedExpr and the value has already been visited
+ if (
+ isinstance(v, ast.Compare)
+ and isinstance(v.left, namedExpr)
+ and v.left.target.id
+ in [
+ ast_expr.id
+ for ast_expr in boolop.values[:i]
+ if hasattr(ast_expr, "id")
+ ]
+ ):
+ pytest_temp = self.variable()
+ self.variables_overwrite[
+ v.left.target.id
+ ] = v.left # type:ignore[assignment]
+ v.left.target.id = pytest_temp
+ self.push_format_context()
+ res, expl = self.visit(v)
+ body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
+ expl_format = self.pop_format_context(astStr(expl))
+ call = ast.Call(app, [expl_format], [])
+ self.expl_stmts.append(ast.Expr(call))
+ if i < levels:
+ cond: ast.expr = res
+ if is_or:
+ cond = ast.UnaryOp(ast.Not(), cond)
+ inner: List[ast.stmt] = []
+ self.statements.append(ast.If(cond, inner, []))
+ self.statements = body = inner
+ self.statements = save
+ self.expl_stmts = fail_save
+ expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
+ expl = self.pop_format_context(expl_template)
+ return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
+
+ def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
+ pattern = UNARY_MAP[unary.op.__class__]
+ operand_res, operand_expl = self.visit(unary.operand)
+ res = self.assign(ast.UnaryOp(unary.op, operand_res))
+ return res, pattern % (operand_expl,)
+
+ def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
+ symbol = BINOP_MAP[binop.op.__class__]
+ left_expr, left_expl = self.visit(binop.left)
+ right_expr, right_expl = self.visit(binop.right)
+ explanation = f"({left_expl} {symbol} {right_expl})"
+ res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
+ return res, explanation
+
+ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+ new_func, func_expl = self.visit(call.func)
+ arg_expls = []
+ new_args = []
+ new_kwargs = []
+ for arg in call.args:
+ if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
+ arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
+ res, expl = self.visit(arg)
+ arg_expls.append(expl)
+ new_args.append(res)
+ for keyword in call.keywords:
+ if (
+ isinstance(keyword.value, ast.Name)
+ and keyword.value.id in self.variables_overwrite
+ ):
+ keyword.value = self.variables_overwrite[
+ keyword.value.id
+ ] # type:ignore[assignment]
+ res, expl = self.visit(keyword.value)
+ new_kwargs.append(ast.keyword(keyword.arg, res))
+ if keyword.arg:
+ arg_expls.append(keyword.arg + "=" + expl)
+ else: # **args have `arg` keywords with an .arg of None
+ arg_expls.append("**" + expl)
+
+ expl = "{}({})".format(func_expl, ", ".join(arg_expls))
+ new_call = ast.Call(new_func, new_args, new_kwargs)
+ res = self.assign(new_call)
+ res_expl = self.explanation_param(self.display(res))
+ outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
+ return res, outer_expl
+
+ def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
+ # A Starred node can appear in a function call.
+ res, expl = self.visit(starred.value)
+ new_starred = ast.Starred(res, starred.ctx)
+ return new_starred, "*" + expl
+
+ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
+ if not isinstance(attr.ctx, ast.Load):
+ return self.generic_visit(attr)
+ value, value_expl = self.visit(attr.value)
+ res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
+ res_expl = self.explanation_param(self.display(res))
+ pat = "%s\n{%s = %s.%s\n}"
+ expl = pat % (res_expl, res_expl, value_expl, attr.attr)
+ return res, expl
+
+ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
+ self.push_format_context()
+ # We first check if we have overwritten a variable in the previous assert
+ if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
+ comp.left = self.variables_overwrite[
+ comp.left.id
+ ] # type:ignore[assignment]
+ if isinstance(comp.left, namedExpr):
+ self.variables_overwrite[
+ comp.left.target.id
+ ] = comp.left # type:ignore[assignment]
+ left_res, left_expl = self.visit(comp.left)
+ if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
+ left_expl = f"({left_expl})"
+ res_variables = [self.variable() for i in range(len(comp.ops))]
+ load_names = [ast.Name(v, ast.Load()) for v in res_variables]
+ store_names = [ast.Name(v, ast.Store()) for v in res_variables]
+ it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
+ expls = []
+ syms = []
+ results = [left_res]
+ for i, op, next_operand in it:
+ if (
+ isinstance(next_operand, namedExpr)
+ and isinstance(left_res, ast.Name)
+ and next_operand.target.id == left_res.id
+ ):
+ next_operand.target.id = self.variable()
+ self.variables_overwrite[
+ left_res.id
+ ] = next_operand # type:ignore[assignment]
+ next_res, next_expl = self.visit(next_operand)
+ if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
+ next_expl = f"({next_expl})"
+ results.append(next_res)
+ sym = BINOP_MAP[op.__class__]
+ syms.append(astStr(sym))
+ expl = f"{left_expl} {sym} {next_expl}"
+ expls.append(astStr(expl))
+ res_expr = ast.Compare(left_res, [op], [next_res])
+ self.statements.append(ast.Assign([store_names[i]], res_expr))
+ left_res, left_expl = next_res, next_expl
+ # Use pytest.assertion.util._reprcompare if that's available.
+ expl_call = self.helper(
+ "_call_reprcompare",
+ ast.Tuple(syms, ast.Load()),
+ ast.Tuple(load_names, ast.Load()),
+ ast.Tuple(expls, ast.Load()),
+ ast.Tuple(results, ast.Load()),
+ )
+ if len(comp.ops) > 1:
+ res: ast.expr = ast.BoolOp(ast.And(), load_names)
+ else:
+ res = load_names[0]
+
+ return res, self.explanation_param(self.pop_format_context(expl_call))
+
+
+def try_makedirs(cache_dir: Path) -> bool:
+ """Attempt to create the given directory and sub-directories exist.
+
+ Returns True if successful or if it already exists.
+ """
+ try:
+ os.makedirs(cache_dir, exist_ok=True)
+ except (FileNotFoundError, NotADirectoryError, FileExistsError):
+ # One of the path components was not a directory:
+ # - we're in a zip file
+ # - it is a file
+ return False
+ except PermissionError:
+ return False
+ except OSError as e:
+ # as of now, EROFS doesn't have an equivalent OSError-subclass
+ if e.errno == errno.EROFS:
+ return False
+ raise
+ return True
+
+
+def get_cache_dir(file_path: Path) -> Path:
+ """Return the cache directory to write .pyc files for the given .py file path."""
+ if sys.version_info >= (3, 8) and sys.pycache_prefix:
+ # given:
+ # prefix = '/tmp/pycs'
+ # path = '/home/user/proj/test_app.py'
+ # we want:
+ # '/tmp/pycs/home/user/proj'
+ return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
+ else:
+ # classic pycache directory
+ return file_path.parent / "__pycache__"