diff options
| author | shadchin <[email protected]> | 2022-02-10 16:44:39 +0300 |
|---|---|---|
| committer | Daniil Cherednik <[email protected]> | 2022-02-10 16:44:39 +0300 |
| commit | e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch) | |
| tree | 64175d5cadab313b3e7039ebaa06c5bc3295e274 /contrib/python/pytest/py3/_pytest/assertion | |
| parent | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff) | |
Restoring authorship annotation for <[email protected]>. Commit 2 of 2.
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/__init__.py | 154 | ||||
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/rewrite.py | 1146 | ||||
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/truncate.py | 48 | ||||
| -rw-r--r-- | contrib/python/pytest/py3/_pytest/assertion/util.py | 610 |
4 files changed, 979 insertions, 979 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py index 430eb2791b0..a18cf198df0 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/__init__.py +++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py @@ -1,25 +1,25 @@ -"""Support for presenting detailed information in failing assertions.""" +"""Support for presenting detailed information in failing assertions.""" import sys -from typing import Any -from typing import Generator -from typing import List -from typing import Optional -from typing import TYPE_CHECKING +from typing import Any +from typing import Generator +from typing import List +from typing import Optional +from typing import TYPE_CHECKING from _pytest.assertion import rewrite from _pytest.assertion import truncate from _pytest.assertion import util -from _pytest.assertion.rewrite import assertstate_key -from _pytest.config import Config -from _pytest.config import hookimpl -from _pytest.config.argparsing import Parser -from _pytest.nodes import Item +from _pytest.assertion.rewrite import assertstate_key +from _pytest.config import Config +from _pytest.config import hookimpl +from _pytest.config.argparsing import Parser +from _pytest.nodes import Item -if TYPE_CHECKING: - from _pytest.main import Session +if TYPE_CHECKING: + from _pytest.main import Session - -def pytest_addoption(parser: Parser) -> None: + +def pytest_addoption(parser: Parser) -> None: group = parser.getgroup("debugconfig") group.addoption( "--assert", @@ -28,23 +28,23 @@ def pytest_addoption(parser: Parser) -> None: choices=("rewrite", "plain"), default="rewrite", metavar="MODE", - help=( - "Control assertion debugging tools.\n" - "'plain' performs no assertion debugging.\n" - "'rewrite' (the default) rewrites assert statements in test modules" - " on import to provide assert expression information." - ), + help=( + "Control assertion debugging tools.\n" + "'plain' performs no assertion debugging.\n" + "'rewrite' (the default) rewrites assert statements in test modules" + " on import to provide assert expression information." + ), + ) + parser.addini( + "enable_assertion_pass_hook", + type="bool", + default=False, + help="Enables the pytest_assertion_pass hook." + "Make sure to delete any previously generated pyc cache files.", ) - parser.addini( - "enable_assertion_pass_hook", - type="bool", - default=False, - help="Enables the pytest_assertion_pass hook." - "Make sure to delete any previously generated pyc cache files.", - ) -def register_assert_rewrite(*names: str) -> None: +def register_assert_rewrite(*names: str) -> None: """Register one or more module names to be rewritten on import. This function will make sure that this module or all modules inside @@ -53,48 +53,48 @@ def register_assert_rewrite(*names: str) -> None: actually imported, usually in your __init__.py if you are a plugin using a package. - :raises TypeError: If the given module names are not strings. + :raises TypeError: If the given module names are not strings. """ for name in names: if not isinstance(name, str): - msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable] + msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable] raise TypeError(msg.format(repr(names))) for hook in sys.meta_path: if isinstance(hook, rewrite.AssertionRewritingHook): importhook = hook break else: - # TODO(typing): Add a protocol for mark_rewrite() and use it - # for importhook and for PytestPluginManager.rewrite_hook. - importhook = DummyRewriteHook() # type: ignore + # TODO(typing): Add a protocol for mark_rewrite() and use it + # for importhook and for PytestPluginManager.rewrite_hook. + importhook = DummyRewriteHook() # type: ignore importhook.mark_rewrite(*names) -class DummyRewriteHook: +class DummyRewriteHook: """A no-op import hook for when rewriting is disabled.""" - def mark_rewrite(self, *names: str) -> None: + def mark_rewrite(self, *names: str) -> None: pass -class AssertionState: +class AssertionState: """State for the assertion plugin.""" - def __init__(self, config: Config, mode) -> None: + def __init__(self, config: Config, mode) -> None: self.mode = mode self.trace = config.trace.root.get("assertion") - self.hook: Optional[rewrite.AssertionRewritingHook] = None + self.hook: Optional[rewrite.AssertionRewritingHook] = None -def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: +def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: """Try to install the rewrite hook, raise SystemError if it fails.""" - config._store[assertstate_key] = AssertionState(config, "rewrite") - config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) + config._store[assertstate_key] = AssertionState(config, "rewrite") + config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config) sys.meta_path.insert(0, hook) - config._store[assertstate_key].trace("installed rewrite import hook") + config._store[assertstate_key].trace("installed rewrite import hook") - def undo() -> None: - hook = config._store[assertstate_key].hook + def undo() -> None: + hook = config._store[assertstate_key].hook if hook is not None and hook in sys.meta_path: sys.meta_path.remove(hook) @@ -102,30 +102,30 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook: return hook -def pytest_collection(session: "Session") -> None: - # This hook is only called when test modules are collected +def pytest_collection(session: "Session") -> None: + # This hook is only called when test modules are collected # so for example not in the master process of pytest-xdist - # (which does not collect test modules). - assertstate = session.config._store.get(assertstate_key, None) + # (which does not collect test modules). + assertstate = session.config._store.get(assertstate_key, None) if assertstate: if assertstate.hook is not None: assertstate.hook.set_session(session) -@hookimpl(tryfirst=True, hookwrapper=True) -def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: - """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks. +@hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: + """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks. - The rewrite module will use util._reprcompare if it exists to use custom - reporting via the pytest_assertrepr_compare hook. This sets up this custom + The rewrite module will use util._reprcompare if it exists to use custom + reporting via the pytest_assertrepr_compare hook. This sets up this custom comparison for the test. """ - ihook = item.ihook + ihook = item.ihook + + def callbinrepr(op, left: object, right: object) -> Optional[str]: + """Call the pytest_assertrepr_compare hook and prepare the result. - def callbinrepr(op, left: object, right: object) -> Optional[str]: - """Call the pytest_assertrepr_compare hook and prepare the result. - This uses the first result from the hook and then ensures the following: * Overly verbose explanations are truncated unless configured otherwise @@ -138,42 +138,42 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: The result can be formatted by util.format_explanation() for pretty printing. """ - hook_result = ihook.pytest_assertrepr_compare( + hook_result = ihook.pytest_assertrepr_compare( config=item.config, op=op, left=left, right=right ) for new_expl in hook_result: if new_expl: new_expl = truncate.truncate_if_required(new_expl, item) new_expl = [line.replace("\n", "\\n") for line in new_expl] - res = "\n~".join(new_expl) + res = "\n~".join(new_expl) if item.config.getvalue("assertmode") == "rewrite": res = res.replace("%", "%%") return res - return None + return None - saved_assert_hooks = util._reprcompare, util._assertion_pass + saved_assert_hooks = util._reprcompare, util._assertion_pass util._reprcompare = callbinrepr - if ihook.pytest_assertion_pass.get_hookimpls(): + if ihook.pytest_assertion_pass.get_hookimpls(): + + def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None: + ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl) + + util._assertion_pass = call_assertion_pass_hook + + yield - def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None: - ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl) + util._reprcompare, util._assertion_pass = saved_assert_hooks - util._assertion_pass = call_assertion_pass_hook - yield - - util._reprcompare, util._assertion_pass = saved_assert_hooks - - -def pytest_sessionfinish(session: "Session") -> None: - assertstate = session.config._store.get(assertstate_key, None) +def pytest_sessionfinish(session: "Session") -> None: + assertstate = session.config._store.get(assertstate_key, None) if assertstate: if assertstate.hook is not None: assertstate.hook.set_session(None) -def pytest_assertrepr_compare( - config: Config, op: str, left: Any, right: Any -) -> Optional[List[str]]: - return util.assertrepr_compare(config=config, op=op, left=left, right=right) +def pytest_assertrepr_compare( + config: Config, op: str, left: Any, right: Any +) -> Optional[List[str]]: + return util.assertrepr_compare(config=config, op=op, left=left, right=right) diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py index c79dfd9a686..37ff076aab5 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py +++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py @@ -1,140 +1,140 @@ -"""Rewrite assertion AST to produce nice error messages.""" +"""Rewrite assertion AST to produce nice error messages.""" import ast import errno -import functools -import importlib.abc -import importlib.machinery -import importlib.util -import io +import functools +import importlib.abc +import importlib.machinery +import importlib.util +import io import itertools import marshal import os import struct import sys -import tokenize +import tokenize import types -from pathlib import Path -from pathlib import PurePath -from typing import Callable -from typing import Dict -from typing import IO -from typing import Iterable -from typing import List -from typing import Optional -from typing import Sequence -from typing import Set -from typing import Tuple -from typing import TYPE_CHECKING -from typing import Union +from pathlib import Path +from pathlib import PurePath +from typing import Callable +from typing import Dict +from typing import IO +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union -import py - -from _pytest._io.saferepr import saferepr -from _pytest._version import version +import py + +from _pytest._io.saferepr import saferepr +from _pytest._version import version from _pytest.assertion import util -from _pytest.assertion.util import ( # noqa: F401 - format_explanation as _format_explanation, -) -from _pytest.config import Config -from _pytest.main import Session +from _pytest.assertion.util import ( # noqa: F401 + format_explanation as _format_explanation, +) +from _pytest.config import Config +from _pytest.main import Session from _pytest.pathlib import fnmatch_ex -from _pytest.store import StoreKey +from _pytest.store import StoreKey -if TYPE_CHECKING: - from _pytest.assertion import AssertionState +if TYPE_CHECKING: + from _pytest.assertion import AssertionState -assertstate_key = StoreKey["AssertionState"]() +assertstate_key = StoreKey["AssertionState"]() -# pytest caches rewritten pycs in pycache dirs -PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}" -PYC_EXT = ".py" + (__debug__ and "c" or "o") -PYC_TAIL = "." + PYTEST_TAG + PYC_EXT +# pytest caches rewritten pycs in pycache dirs +PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}" +PYC_EXT = ".py" + (__debug__ and "c" or "o") +PYC_TAIL = "." + PYTEST_TAG + PYC_EXT -class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): - """PEP302/PEP451 import hook which rewrites asserts.""" +class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): + """PEP302/PEP451 import hook which rewrites asserts.""" - def __init__(self, config: Config) -> None: + def __init__(self, config: Config) -> None: self.config = config - try: - self.fnpats = config.getini("python_files") - except ValueError: - self.fnpats = ["test_*.py", "*_test.py"] - self.session: Optional[Session] = None - self._rewritten_names: Set[str] = set() - self._must_rewrite: Set[str] = set() + try: + self.fnpats = config.getini("python_files") + except ValueError: + self.fnpats = ["test_*.py", "*_test.py"] + self.session: Optional[Session] = None + self._rewritten_names: Set[str] = set() + self._must_rewrite: Set[str] = set() # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, # which might result in infinite recursion (#3506) self._writing_pyc = False self._basenames_to_check_rewrite = {"conftest"} - self._marked_for_rewrite_cache: Dict[str, bool] = {} + self._marked_for_rewrite_cache: Dict[str, bool] = {} self._session_paths_checked = False - def set_session(self, session: Optional[Session]) -> None: + def set_session(self, session: Optional[Session]) -> None: self.session = session self._session_paths_checked = False - # Indirection so we can mock calls to find_spec originated from the hook during testing - _find_spec = importlib.machinery.PathFinder.find_spec + # Indirection so we can mock calls to find_spec originated from the hook during testing + _find_spec = importlib.machinery.PathFinder.find_spec - def find_spec( - self, - name: str, - path: Optional[Sequence[Union[str, bytes]]] = None, - target: Optional[types.ModuleType] = None, - ) -> Optional[importlib.machinery.ModuleSpec]: + def find_spec( + self, + name: str, + path: Optional[Sequence[Union[str, bytes]]] = None, + target: Optional[types.ModuleType] = None, + ) -> Optional[importlib.machinery.ModuleSpec]: if self._writing_pyc: return None - state = self.config._store[assertstate_key] + state = self.config._store[assertstate_key] if self._early_rewrite_bailout(name, state): return None state.trace("find_module called for: %s" % name) - - # Type ignored because mypy is confused about the `self` binding here. - spec = self._find_spec(name, path) # type: ignore - if ( - # the import machinery could not find a file to import - spec is None - # this is a namespace package (without `__init__.py`) - # there's nothing to rewrite there - # python3.6: `namespace` - # python3.7+: `None` - or spec.origin == "namespace" - or spec.origin is None - # we can only rewrite source files - or not isinstance(spec.loader, importlib.machinery.SourceFileLoader) - # if the file doesn't exist, we can't rewrite it - or not os.path.exists(spec.origin) - ): - return None + + # Type ignored because mypy is confused about the `self` binding here. + spec = self._find_spec(name, path) # type: ignore + if ( + # the import machinery could not find a file to import + spec is None + # this is a namespace package (without `__init__.py`) + # there's nothing to rewrite there + # python3.6: `namespace` + # python3.7+: `None` + or spec.origin == "namespace" + or spec.origin is None + # we can only rewrite source files + or not isinstance(spec.loader, importlib.machinery.SourceFileLoader) + # if the file doesn't exist, we can't rewrite it + or not os.path.exists(spec.origin) + ): + return None else: - fn = spec.origin + fn = spec.origin - if not self._should_rewrite(name, fn, state): + if not self._should_rewrite(name, fn, state): return None - return importlib.util.spec_from_file_location( - name, - fn, - loader=self, - submodule_search_locations=spec.submodule_search_locations, - ) + return importlib.util.spec_from_file_location( + name, + fn, + loader=self, + submodule_search_locations=spec.submodule_search_locations, + ) + + def create_module( + self, spec: importlib.machinery.ModuleSpec + ) -> Optional[types.ModuleType]: + return None # default behaviour is fine + + def exec_module(self, module: types.ModuleType) -> None: + assert module.__spec__ is not None + assert module.__spec__.origin is not None + fn = Path(module.__spec__.origin) + state = self.config._store[assertstate_key] + + self._rewritten_names.add(module.__name__) - def create_module( - self, spec: importlib.machinery.ModuleSpec - ) -> Optional[types.ModuleType]: - return None # default behaviour is fine - - def exec_module(self, module: types.ModuleType) -> None: - assert module.__spec__ is not None - assert module.__spec__.origin is not None - fn = Path(module.__spec__.origin) - state = self.config._store[assertstate_key] - - self._rewritten_names.add(module.__name__) - # The requested module looks like a test file, so rewrite it. This is # the most magical part of the process: load the source, rewrite the # asserts, and load the rewritten source. We also cache the rewritten @@ -144,21 +144,21 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) # cached pyc is always a complete, valid pyc. Operations on it must be # atomic. POSIX's atomic rename comes in handy. write = not sys.dont_write_bytecode - cache_dir = get_cache_dir(fn) + cache_dir = get_cache_dir(fn) if write: - ok = try_makedirs(cache_dir) - if not ok: - write = False - state.trace(f"read only directory: {cache_dir}") - - cache_name = fn.name[:-3] + PYC_TAIL - pyc = cache_dir / cache_name + ok = try_makedirs(cache_dir) + if not ok: + write = False + state.trace(f"read only directory: {cache_dir}") + + cache_name = fn.name[:-3] + PYC_TAIL + pyc = cache_dir / cache_name # Notice that even if we're in a read-only directory, I'm going # to check for a cached pyc. This may not be optimal... - co = _read_pyc(fn, pyc, state.trace) + co = _read_pyc(fn, pyc, state.trace) if co is None: - state.trace(f"rewriting {fn!r}") - source_stat, co = _rewrite_test(fn, self.config) + state.trace(f"rewriting {fn!r}") + source_stat, co = _rewrite_test(fn, self.config) if write: self._writing_pyc = True try: @@ -166,23 +166,23 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) finally: self._writing_pyc = False else: - state.trace(f"found cached rewritten pyc for {fn}") - exec(co, module.__dict__) + state.trace(f"found cached rewritten pyc for {fn}") + exec(co, module.__dict__) - def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool: - """A fast way to get out of rewriting modules. - - Profiling has shown that the call to PathFinder.find_spec (inside of - the find_spec from this class) is a major slowdown, so, this method - tries to filter what we're sure won't be rewritten before getting to - it. + def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool: + """A fast way to get out of rewriting modules. + + Profiling has shown that the call to PathFinder.find_spec (inside of + the find_spec from this class) is a major slowdown, so, this method + tries to filter what we're sure won't be rewritten before getting to + it. """ if self.session is not None and not self._session_paths_checked: self._session_paths_checked = True - for initial_path in self.session._initialpaths: + for initial_path in self.session._initialpaths: # Make something as c:/projects/my_project/path.py -> # ['c:', 'projects', 'my_project', 'path.py'] - parts = str(initial_path).split(os.path.sep) + parts = str(initial_path).split(os.path.sep) # add 'path' to basenames to be checked. self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0]) @@ -205,44 +205,44 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) if self._is_marked_for_rewrite(name, state): return False - state.trace(f"early skip of rewriting module: {name}") + state.trace(f"early skip of rewriting module: {name}") return True - def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool: + def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool: # always rewrite conftest files - if os.path.basename(fn) == "conftest.py": - state.trace(f"rewriting conftest file: {fn!r}") + if os.path.basename(fn) == "conftest.py": + state.trace(f"rewriting conftest file: {fn!r}") return True if self.session is not None: - if self.session.isinitpath(py.path.local(fn)): - state.trace(f"matched test file (was specified on cmdline): {fn!r}") + if self.session.isinitpath(py.path.local(fn)): + state.trace(f"matched test file (was specified on cmdline): {fn!r}") return True # modules not passed explicitly on the command line are only # rewritten if they match the naming convention for test files - fn_path = PurePath(fn) + fn_path = PurePath(fn) for pat in self.fnpats: - if fnmatch_ex(pat, fn_path): - state.trace(f"matched test file {fn!r}") + if fnmatch_ex(pat, fn_path): + state.trace(f"matched test file {fn!r}") return True return self._is_marked_for_rewrite(name, state) - def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool: + def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool: try: return self._marked_for_rewrite_cache[name] except KeyError: for marked in self._must_rewrite: if name == marked or name.startswith(marked + "."): - state.trace(f"matched marked file {name!r} (from {marked!r})") + state.trace(f"matched marked file {name!r} (from {marked!r})") self._marked_for_rewrite_cache[name] = True return True self._marked_for_rewrite_cache[name] = False return False - def mark_rewrite(self, *names: str) -> None: + def mark_rewrite(self, *names: str) -> None: """Mark import names as needing to be rewritten. The named module or package as well as any nested modules will @@ -252,155 +252,155 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) set(names).intersection(sys.modules).difference(self._rewritten_names) ) for name in already_imported: - mod = sys.modules[name] + mod = sys.modules[name] if not AssertionRewriter.is_rewrite_disabled( - mod.__doc__ or "" - ) and not isinstance(mod.__loader__, type(self)): + mod.__doc__ or "" + ) and not isinstance(mod.__loader__, type(self)): self._warn_already_imported(name) self._must_rewrite.update(names) self._marked_for_rewrite_cache.clear() - def _warn_already_imported(self, name: str) -> None: - from _pytest.warning_types import PytestAssertRewriteWarning + def _warn_already_imported(self, name: str) -> None: + from _pytest.warning_types import PytestAssertRewriteWarning - self.config.issue_config_time_warning( - PytestAssertRewriteWarning( - "Module already imported so cannot be rewritten: %s" % name - ), + self.config.issue_config_time_warning( + PytestAssertRewriteWarning( + "Module already imported so cannot be rewritten: %s" % name + ), stacklevel=5, ) - def get_data(self, pathname: Union[str, bytes]) -> bytes: - """Optional PEP302 get_data API.""" + def get_data(self, pathname: Union[str, bytes]) -> bytes: + """Optional PEP302 get_data API.""" with open(pathname, "rb") as f: return f.read() -def _write_pyc_fp( - fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType -) -> None: +def _write_pyc_fp( + fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType +) -> None: # Technically, we don't have to have the same pyc format as # (C)Python, since these "pycs" should never be seen by builtin - # import. However, there's little reason to deviate. - fp.write(importlib.util.MAGIC_NUMBER) - # https://www.python.org/dev/peps/pep-0552/ - if sys.version_info >= (3, 7): - flags = b"\x00\x00\x00\x00" - fp.write(flags) - # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) - mtime = int(source_stat.st_mtime) & 0xFFFFFFFF - size = source_stat.st_size & 0xFFFFFFFF - # "<LL" stands for 2 unsigned longs, little-endian. - fp.write(struct.pack("<LL", mtime, size)) - fp.write(marshal.dumps(co)) + # import. However, there's little reason to deviate. + fp.write(importlib.util.MAGIC_NUMBER) + # https://www.python.org/dev/peps/pep-0552/ + if sys.version_info >= (3, 7): + flags = b"\x00\x00\x00\x00" + fp.write(flags) + # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) + mtime = int(source_stat.st_mtime) & 0xFFFFFFFF + size = source_stat.st_size & 0xFFFFFFFF + # "<LL" stands for 2 unsigned longs, little-endian. + fp.write(struct.pack("<LL", mtime, size)) + fp.write(marshal.dumps(co)) + + +if sys.platform == "win32": + from atomicwrites import atomic_write + + def _write_pyc( + state: "AssertionState", + co: types.CodeType, + source_stat: os.stat_result, + pyc: Path, + ) -> bool: + try: + with atomic_write(os.fspath(pyc), mode="wb", overwrite=True) as fp: + _write_pyc_fp(fp, source_stat, co) + except OSError as e: + state.trace(f"error writing pyc file at {pyc}: {e}") + # we ignore any failure to write the cache file + # there are many reasons, permission-denied, pycache dir being a + # file etc. + return False + return True + +else: -if sys.platform == "win32": - from atomicwrites import atomic_write + def _write_pyc( + state: "AssertionState", + co: types.CodeType, + source_stat: os.stat_result, + pyc: Path, + ) -> bool: + proc_pyc = f"{pyc}.{os.getpid()}" + try: + fp = open(proc_pyc, "wb") + except OSError as e: + state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") + return False - def _write_pyc( - state: "AssertionState", - co: types.CodeType, - source_stat: os.stat_result, - pyc: Path, - ) -> bool: - try: - with atomic_write(os.fspath(pyc), mode="wb", overwrite=True) as fp: - _write_pyc_fp(fp, source_stat, co) - except OSError as e: - state.trace(f"error writing pyc file at {pyc}: {e}") - # we ignore any failure to write the cache file - # there are many reasons, permission-denied, pycache dir being a - # file etc. - return False - return True + try: + _write_pyc_fp(fp, source_stat, co) + os.rename(proc_pyc, os.fspath(pyc)) + except OSError as e: + state.trace(f"error writing pyc file at {pyc}: {e}") + # we ignore any failure to write the cache file + # there are many reasons, permission-denied, pycache dir being a + # file etc. + return False + finally: + fp.close() + return True -else: - - def _write_pyc( - state: "AssertionState", - co: types.CodeType, - source_stat: os.stat_result, - pyc: Path, - ) -> bool: - proc_pyc = f"{pyc}.{os.getpid()}" - try: - fp = open(proc_pyc, "wb") - except OSError as e: - state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") - return False - - try: - _write_pyc_fp(fp, source_stat, co) - os.rename(proc_pyc, os.fspath(pyc)) - except OSError as e: - state.trace(f"error writing pyc file at {pyc}: {e}") - # we ignore any failure to write the cache file - # there are many reasons, permission-denied, pycache dir being a - # file etc. - return False - finally: - fp.close() - return True - - -def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: - """Read and rewrite *fn* and return the code object.""" - fn_ = os.fspath(fn) - stat = os.stat(fn_) - with open(fn_, "rb") as f: - source = f.read() - tree = ast.parse(source, filename=fn_) - rewrite_asserts(tree, source, fn_, config) - co = compile(tree, fn_, "exec", dont_inherit=True) +def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: + """Read and rewrite *fn* and return the code object.""" + fn_ = os.fspath(fn) + stat = os.stat(fn_) + with open(fn_, "rb") as f: + source = f.read() + tree = ast.parse(source, filename=fn_) + rewrite_asserts(tree, source, fn_, config) + co = compile(tree, fn_, "exec", dont_inherit=True) return stat, co -def _read_pyc( - source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None -) -> Optional[types.CodeType]: +def _read_pyc( + source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None +) -> Optional[types.CodeType]: """Possibly read a pytest pyc containing rewritten code. Return rewritten code if successful or None if not. """ try: - fp = open(os.fspath(pyc), "rb") - except OSError: + fp = open(os.fspath(pyc), "rb") + except OSError: return None with fp: - # https://www.python.org/dev/peps/pep-0552/ - has_flags = sys.version_info >= (3, 7) + # https://www.python.org/dev/peps/pep-0552/ + has_flags = sys.version_info >= (3, 7) try: - stat_result = os.stat(os.fspath(source)) - mtime = int(stat_result.st_mtime) - size = stat_result.st_size - data = fp.read(16 if has_flags else 12) - except OSError as e: - trace(f"_read_pyc({source}): OSError {e}") + stat_result = os.stat(os.fspath(source)) + mtime = int(stat_result.st_mtime) + size = stat_result.st_size + data = fp.read(16 if has_flags else 12) + except OSError as e: + trace(f"_read_pyc({source}): OSError {e}") return None # Check for invalid or out of date pyc file. - if len(data) != (16 if has_flags else 12): - trace("_read_pyc(%s): invalid pyc (too short)" % source) + if len(data) != (16 if has_flags else 12): + trace("_read_pyc(%s): invalid pyc (too short)" % source) + return None + if data[:4] != importlib.util.MAGIC_NUMBER: + trace("_read_pyc(%s): invalid pyc (bad magic number)" % source) + return None + if has_flags and data[4:8] != b"\x00\x00\x00\x00": + trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) + return None + mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8] + if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: + trace("_read_pyc(%s): out of date" % source) + return None + size_data = data[12 if has_flags else 8 : 16 if has_flags else 12] + if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: + trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) return None - if data[:4] != importlib.util.MAGIC_NUMBER: - trace("_read_pyc(%s): invalid pyc (bad magic number)" % source) - return None - if has_flags and data[4:8] != b"\x00\x00\x00\x00": - trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) - return None - mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8] - if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: - trace("_read_pyc(%s): out of date" % source) - return None - size_data = data[12 if has_flags else 8 : 16 if has_flags else 12] - if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: - trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) - return None try: co = marshal.load(fp) except Exception as e: - trace(f"_read_pyc({source}): marshal.load error {e}") + trace(f"_read_pyc({source}): marshal.load error {e}") return None if not isinstance(co, types.CodeType): trace("_read_pyc(%s): not a code object" % source) @@ -408,18 +408,18 @@ def _read_pyc( return co -def rewrite_asserts( - mod: ast.Module, - source: bytes, - module_path: Optional[str] = None, - config: Optional[Config] = None, -) -> None: +def rewrite_asserts( + mod: ast.Module, + source: bytes, + module_path: Optional[str] = None, + config: Optional[Config] = None, +) -> None: """Rewrite the assert statements in mod.""" - AssertionRewriter(module_path, config, source).run(mod) + AssertionRewriter(module_path, config, source).run(mod) -def _saferepr(obj: object) -> str: - r"""Get a safe repr of an object for assertion error messages. +def _saferepr(obj: object) -> str: + r"""Get a safe repr of an object for assertion error messages. The assertion formatting (util.format_explanation()) requires newlines to be escaped since they are a special character for it. @@ -428,24 +428,24 @@ def _saferepr(obj: object) -> str: sequences, especially '\n{' and '\n}' are likely to be present in JSON reprs. """ - return saferepr(obj).replace("\n", "\\n") + return saferepr(obj).replace("\n", "\\n") -def _format_assertmsg(obj: object) -> str: - r"""Format the custom assertion message given. +def _format_assertmsg(obj: object) -> str: + r"""Format the custom assertion message given. For strings this simply replaces newlines with '\n~' so that util.format_explanation() will preserve them instead of escaping - newlines. For other objects saferepr() is used first. + newlines. For other objects saferepr() is used first. """ # reprlib appears to have a bug which means that if a string # contains a newline it gets escaped, however if an object has a # .__repr__() which contains newlines it does not get escaped. # However in either case we want to preserve the newline. - replaces = [("\n", "\n~"), ("%", "%%")] - if not isinstance(obj, str): - obj = saferepr(obj) - replaces.append(("\\n", "\n~")) + replaces = [("\n", "\n~"), ("%", "%%")] + if not isinstance(obj, str): + obj = saferepr(obj) + replaces.append(("\\n", "\n~")) for r1, r2 in replaces: obj = obj.replace(r1, r2) @@ -453,27 +453,27 @@ def _format_assertmsg(obj: object) -> str: return obj -def _should_repr_global_name(obj: object) -> bool: - if callable(obj): - return False +def _should_repr_global_name(obj: object) -> bool: + if callable(obj): + return False - try: - return not hasattr(obj, "__name__") - except Exception: - return True + try: + return not hasattr(obj, "__name__") + except Exception: + return True - -def _format_boolop(explanations: Iterable[str], is_or: bool) -> str: + +def _format_boolop(explanations: Iterable[str], is_or: bool) -> str: explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" - return explanation.replace("%", "%%") + return explanation.replace("%", "%%") -def _call_reprcompare( - ops: Sequence[str], - results: Sequence[bool], - expls: Sequence[str], - each_obj: Sequence[object], -) -> str: +def _call_reprcompare( + ops: Sequence[str], + results: Sequence[bool], + expls: Sequence[str], + each_obj: Sequence[object], +) -> str: for i, res, expl in zip(range(len(ops)), results, expls): try: done = not res @@ -488,20 +488,20 @@ def _call_reprcompare( return expl -def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None: - if util._assertion_pass is not None: - util._assertion_pass(lineno, orig, expl) +def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None: + if util._assertion_pass is not None: + util._assertion_pass(lineno, orig, expl) + - -def _check_if_assertion_pass_impl() -> bool: - """Check if any plugins implement the pytest_assertion_pass hook - in order not to generate explanation unecessarily (might be expensive).""" - return True if util._assertion_pass else False - - -UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} - -BINOP_MAP = { +def _check_if_assertion_pass_impl() -> bool: + """Check if any plugins implement the pytest_assertion_pass hook + in order not to generate explanation unecessarily (might be expensive).""" + return True if util._assertion_pass else False + + +UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} + +BINOP_MAP = { ast.BitOr: "|", ast.BitXor: "^", ast.BitAnd: "&", @@ -524,7 +524,7 @@ BINOP_MAP = { ast.IsNot: "is not", ast.In: "in", ast.NotIn: "not in", - ast.MatMult: "@", + ast.MatMult: "@", } @@ -543,60 +543,60 @@ def set_location(node, lineno, col_offset): return node -def _get_assertion_exprs(src: bytes) -> Dict[int, str]: - """Return a mapping from {lineno: "assertion test expression"}.""" - ret: Dict[int, str] = {} - - depth = 0 - lines: List[str] = [] - assert_lineno: Optional[int] = None - seen_lines: Set[int] = set() - - def _write_and_reset() -> None: - nonlocal depth, lines, assert_lineno, seen_lines - assert assert_lineno is not None - ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\") - depth = 0 - lines = [] - assert_lineno = None - seen_lines = set() - - tokens = tokenize.tokenize(io.BytesIO(src).readline) - for tp, source, (lineno, offset), _, line in tokens: - if tp == tokenize.NAME and source == "assert": - assert_lineno = lineno - elif assert_lineno is not None: - # keep track of depth for the assert-message `,` lookup - if tp == tokenize.OP and source in "([{": - depth += 1 - elif tp == tokenize.OP and source in ")]}": - depth -= 1 - - if not lines: - lines.append(line[offset:]) - seen_lines.add(lineno) - # a non-nested comma separates the expression from the message - elif depth == 0 and tp == tokenize.OP and source == ",": - # one line assert with message - if lineno in seen_lines and len(lines) == 1: - offset_in_trimmed = offset + len(lines[-1]) - len(line) - lines[-1] = lines[-1][:offset_in_trimmed] - # multi-line assert with message - elif lineno in seen_lines: - lines[-1] = lines[-1][:offset] - # multi line assert with escapd newline before message - else: - lines.append(line[:offset]) - _write_and_reset() - elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}: - _write_and_reset() - elif lines and lineno not in seen_lines: - lines.append(line) - seen_lines.add(lineno) - - return ret - - +def _get_assertion_exprs(src: bytes) -> Dict[int, str]: + """Return a mapping from {lineno: "assertion test expression"}.""" + ret: Dict[int, str] = {} + + depth = 0 + lines: List[str] = [] + assert_lineno: Optional[int] = None + seen_lines: Set[int] = set() + + def _write_and_reset() -> None: + nonlocal depth, lines, assert_lineno, seen_lines + assert assert_lineno is not None + ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\") + depth = 0 + lines = [] + assert_lineno = None + seen_lines = set() + + tokens = tokenize.tokenize(io.BytesIO(src).readline) + for tp, source, (lineno, offset), _, line in tokens: + if tp == tokenize.NAME and source == "assert": + assert_lineno = lineno + elif assert_lineno is not None: + # keep track of depth for the assert-message `,` lookup + if tp == tokenize.OP and source in "([{": + depth += 1 + elif tp == tokenize.OP and source in ")]}": + depth -= 1 + + if not lines: + lines.append(line[offset:]) + seen_lines.add(lineno) + # a non-nested comma separates the expression from the message + elif depth == 0 and tp == tokenize.OP and source == ",": + # one line assert with message + if lineno in seen_lines and len(lines) == 1: + offset_in_trimmed = offset + len(lines[-1]) - len(line) + lines[-1] = lines[-1][:offset_in_trimmed] + # multi-line assert with message + elif lineno in seen_lines: + lines[-1] = lines[-1][:offset] + # multi line assert with escapd newline before message + else: + lines.append(line[:offset]) + _write_and_reset() + elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}: + _write_and_reset() + elif lines and lineno not in seen_lines: + lines.append(line) + seen_lines.add(lineno) + + return ret + + class AssertionRewriter(ast.NodeVisitor): """Assertion rewriting implementation. @@ -613,8 +613,8 @@ class AssertionRewriter(ast.NodeVisitor): original assert statement: it rewrites the test of an assertion to provide intermediate values and replace it with an if statement which raises an assertion error with a detailed explanation in - case the expression is false and calls pytest_assertion_pass hook - if expression is true. + case the expression is false and calls pytest_assertion_pass hook + if expression is true. For this .visit_Assert() uses the visitor pattern to visit all the AST nodes of the ast.Assert.test field, each visit call returning @@ -632,10 +632,10 @@ class AssertionRewriter(ast.NodeVisitor): by statements. Variables are created using .variable() and have the form of "@py_assert0". - :expl_stmts: The AST statements which will be executed to get - data from the assertion. This is the code which will construct - the detailed assertion message that is used in the AssertionError - or for the pytest_assertion_pass hook. + :expl_stmts: The AST statements which will be executed to get + data from the assertion. This is the code which will construct + the detailed assertion message that is used in the AssertionError + or for the pytest_assertion_pass hook. :explanation_specifiers: A dict filled by .explanation_param() with %-formatting placeholders and their corresponding @@ -650,32 +650,32 @@ class AssertionRewriter(ast.NodeVisitor): by the other visitors. """ - def __init__( - self, module_path: Optional[str], config: Optional[Config], source: bytes - ) -> None: - super().__init__() + def __init__( + self, module_path: Optional[str], config: Optional[Config], source: bytes + ) -> None: + super().__init__() self.module_path = module_path self.config = config - if config is not None: - self.enable_assertion_pass_hook = config.getini( - "enable_assertion_pass_hook" - ) - else: - self.enable_assertion_pass_hook = False - self.source = source + if config is not None: + self.enable_assertion_pass_hook = config.getini( + "enable_assertion_pass_hook" + ) + else: + self.enable_assertion_pass_hook = False + self.source = source + + @functools.lru_cache(maxsize=1) + def _assert_expr_to_lineno(self) -> Dict[int, str]: + return _get_assertion_exprs(self.source) - @functools.lru_cache(maxsize=1) - def _assert_expr_to_lineno(self) -> Dict[int, str]: - return _get_assertion_exprs(self.source) - - def run(self, mod: ast.Module) -> None: + def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" if not mod.body: # Nothing to do. return - - # We'll insert some special imports at the top of the module, but after any - # docstrings and __future__ imports, so first figure out where that is. + + # We'll insert some special imports at the top of the module, but after any + # docstrings and __future__ imports, so first figure out where that is. doc = getattr(mod, "docstring", None) expect_docstring = doc is None if doc is not None and self.is_rewrite_disabled(doc): @@ -693,48 +693,48 @@ class AssertionRewriter(ast.NodeVisitor): return expect_docstring = False elif ( - isinstance(item, ast.ImportFrom) - and item.level == 0 - and item.module == "__future__" + isinstance(item, ast.ImportFrom) + and item.level == 0 + and item.module == "__future__" ): - pass - else: + pass + else: break pos += 1 - # Special case: for a decorated function, set the lineno to that of the - # first decorator, not the `def`. Issue #4984. - if isinstance(item, ast.FunctionDef) and item.decorator_list: - lineno = item.decorator_list[0].lineno + # Special case: for a decorated function, set the lineno to that of the + # first decorator, not the `def`. Issue #4984. + if isinstance(item, ast.FunctionDef) and item.decorator_list: + lineno = item.decorator_list[0].lineno else: lineno = item.lineno - # Now actually insert the special imports. - if sys.version_info >= (3, 10): - aliases = [ - ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), - ast.alias( - "_pytest.assertion.rewrite", - "@pytest_ar", - lineno=lineno, - col_offset=0, - ), - ] - else: - aliases = [ - ast.alias("builtins", "@py_builtins"), - ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), - ] + # Now actually insert the special imports. + if sys.version_info >= (3, 10): + aliases = [ + ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), + ast.alias( + "_pytest.assertion.rewrite", + "@pytest_ar", + lineno=lineno, + col_offset=0, + ), + ] + else: + aliases = [ + ast.alias("builtins", "@py_builtins"), + ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), + ] imports = [ ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases ] mod.body[pos:pos] = imports - + # Collect asserts. - nodes: List[ast.AST] = [mod] + nodes: List[ast.AST] = [mod] while nodes: node = nodes.pop() for name, field in ast.iter_fields(node): if isinstance(field, list): - new: List[ast.AST] = [] + new: List[ast.AST] = [] for i, child in enumerate(field): if isinstance(child, ast.Assert): # Transform assert. @@ -753,38 +753,38 @@ class AssertionRewriter(ast.NodeVisitor): nodes.append(field) @staticmethod - def is_rewrite_disabled(docstring: str) -> bool: + def is_rewrite_disabled(docstring: str) -> bool: return "PYTEST_DONT_REWRITE" in docstring - def variable(self) -> str: + def variable(self) -> str: """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. name = "@py_assert" + str(next(self.variable_counter)) self.variables.append(name) return name - def assign(self, expr: ast.expr) -> ast.Name: + def assign(self, expr: ast.expr) -> ast.Name: """Give *expr* a name.""" name = self.variable() self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) return ast.Name(name, ast.Load()) - def display(self, expr: ast.expr) -> ast.expr: - """Call saferepr on the expression.""" - return self.helper("_saferepr", expr) + def display(self, expr: ast.expr) -> ast.expr: + """Call saferepr on the expression.""" + return self.helper("_saferepr", expr) - def helper(self, name: str, *args: ast.expr) -> ast.expr: + def helper(self, name: str, *args: ast.expr) -> ast.expr: """Call a helper in this module.""" py_name = ast.Name("@pytest_ar", ast.Load()) - attr = ast.Attribute(py_name, name, ast.Load()) - return ast.Call(attr, list(args), []) + attr = ast.Attribute(py_name, name, ast.Load()) + return ast.Call(attr, list(args), []) - def builtin(self, name: str) -> ast.Attribute: + def builtin(self, name: str) -> ast.Attribute: """Return the builtin called *name*.""" builtin_name = ast.Name("@py_builtins", ast.Load()) return ast.Attribute(builtin_name, name, ast.Load()) - def explanation_param(self, expr: ast.expr) -> str: + def explanation_param(self, expr: ast.expr) -> str: """Return a new named %-formatting placeholder for expr. This creates a %-formatting placeholder for expr in the @@ -796,7 +796,7 @@ class AssertionRewriter(ast.NodeVisitor): self.explanation_specifiers[specifier] = expr return "%(" + specifier + ")s" - def push_format_context(self) -> None: + def push_format_context(self) -> None: """Create a new formatting context. The format context is used for when an explanation wants to @@ -806,15 +806,15 @@ class AssertionRewriter(ast.NodeVisitor): to format a string of %-formatted values as added by .explanation_param(). """ - self.explanation_specifiers: Dict[str, ast.expr] = {} + self.explanation_specifiers: Dict[str, ast.expr] = {} self.stack.append(self.explanation_specifiers) - def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: + def pop_format_context(self, expl_expr: ast.expr) -> ast.Name: """Format the %-formatted string with current format context. - The expl_expr should be an str ast.expr instance constructed from + The expl_expr should be an str ast.expr instance constructed from the %-placeholders created by .explanation_param(). This will - add the required code to format said string to .expl_stmts and + add the required code to format said string to .expl_stmts and return the ast.Name instance of the formatted string. """ current = self.stack.pop() @@ -824,18 +824,18 @@ class AssertionRewriter(ast.NodeVisitor): format_dict = ast.Dict(keys, list(current.values())) form = ast.BinOp(expl_expr, ast.Mod(), format_dict) name = "@py_format" + str(next(self.variable_counter)) - if self.enable_assertion_pass_hook: - self.format_variables.append(name) - self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) + if self.enable_assertion_pass_hook: + self.format_variables.append(name) + self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form)) return ast.Name(name, ast.Load()) - def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: + def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]: """Handle expressions we don't have custom code for.""" assert isinstance(node, ast.expr) res = self.assign(node) return res, self.explanation_param(self.display(res)) - def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: + def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: """Return the AST statements to replace the ast.Assert instance. This rewrites the test of an assertion to provide @@ -844,173 +844,173 @@ class AssertionRewriter(ast.NodeVisitor): the expression is false. """ if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1: - from _pytest.warning_types import PytestAssertRewriteWarning + from _pytest.warning_types import PytestAssertRewriteWarning import warnings - # TODO: This assert should not be needed. - assert self.module_path is not None + # TODO: This assert should not be needed. + assert self.module_path is not None warnings.warn_explicit( - PytestAssertRewriteWarning( - "assertion is always true, perhaps remove parentheses?" - ), + PytestAssertRewriteWarning( + "assertion is always true, perhaps remove parentheses?" + ), category=None, - filename=os.fspath(self.module_path), + filename=os.fspath(self.module_path), lineno=assert_.lineno, ) - self.statements: List[ast.stmt] = [] - self.variables: List[str] = [] + self.statements: List[ast.stmt] = [] + self.variables: List[str] = [] self.variable_counter = itertools.count() - - if self.enable_assertion_pass_hook: - self.format_variables: List[str] = [] - - self.stack: List[Dict[str, ast.expr]] = [] - self.expl_stmts: List[ast.stmt] = [] + + if self.enable_assertion_pass_hook: + self.format_variables: List[str] = [] + + self.stack: List[Dict[str, ast.expr]] = [] + self.expl_stmts: List[ast.stmt] = [] self.push_format_context() # Rewrite assert into a bunch of statements. top_condition, explanation = self.visit(assert_.test) - - negation = ast.UnaryOp(ast.Not(), top_condition) - - if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook - msg = self.pop_format_context(ast.Str(explanation)) - - # Failed - if assert_.msg: - assertmsg = self.helper("_format_assertmsg", assert_.msg) - gluestr = "\n>assert " - else: - assertmsg = ast.Str("") - gluestr = "assert " - err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) - err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) - err_name = ast.Name("AssertionError", ast.Load()) - fmt = self.helper("_format_explanation", err_msg) - exc = ast.Call(err_name, [fmt], []) - raise_ = ast.Raise(exc, None) - statements_fail = [] - statements_fail.extend(self.expl_stmts) - statements_fail.append(raise_) - - # Passed - fmt_pass = self.helper("_format_explanation", msg) - orig = self._assert_expr_to_lineno()[assert_.lineno] - hook_call_pass = ast.Expr( - self.helper( - "_call_assertion_pass", - ast.Num(assert_.lineno), - ast.Str(orig), - fmt_pass, - ) - ) - # If any hooks implement assert_pass hook - hook_impl_test = ast.If( - self.helper("_check_if_assertion_pass_impl"), - self.expl_stmts + [hook_call_pass], - [], - ) - statements_pass = [hook_impl_test] - - # Test for assertion condition - main_test = ast.If(negation, statements_fail, statements_pass) - self.statements.append(main_test) - if self.format_variables: - variables = [ - ast.Name(name, ast.Store()) for name in self.format_variables - ] - clear_format = ast.Assign(variables, ast.NameConstant(None)) - self.statements.append(clear_format) - - else: # Original assertion rewriting - # Create failure message. - body = self.expl_stmts - self.statements.append(ast.If(negation, body, [])) - if assert_.msg: - assertmsg = self.helper("_format_assertmsg", assert_.msg) - explanation = "\n>assert " + explanation - else: - assertmsg = ast.Str("") - explanation = "assert " + explanation - template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) - msg = self.pop_format_context(template) - fmt = self.helper("_format_explanation", msg) - err_name = ast.Name("AssertionError", ast.Load()) - exc = ast.Call(err_name, [fmt], []) + + negation = ast.UnaryOp(ast.Not(), top_condition) + + if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook + msg = self.pop_format_context(ast.Str(explanation)) + + # Failed + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + gluestr = "\n>assert " + else: + assertmsg = ast.Str("") + gluestr = "assert " + err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg) + err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) + err_name = ast.Name("AssertionError", ast.Load()) + fmt = self.helper("_format_explanation", err_msg) + exc = ast.Call(err_name, [fmt], []) raise_ = ast.Raise(exc, None) - - body.append(raise_) - + statements_fail = [] + statements_fail.extend(self.expl_stmts) + statements_fail.append(raise_) + + # Passed + fmt_pass = self.helper("_format_explanation", msg) + orig = self._assert_expr_to_lineno()[assert_.lineno] + hook_call_pass = ast.Expr( + self.helper( + "_call_assertion_pass", + ast.Num(assert_.lineno), + ast.Str(orig), + fmt_pass, + ) + ) + # If any hooks implement assert_pass hook + hook_impl_test = ast.If( + self.helper("_check_if_assertion_pass_impl"), + self.expl_stmts + [hook_call_pass], + [], + ) + statements_pass = [hook_impl_test] + + # Test for assertion condition + main_test = ast.If(negation, statements_fail, statements_pass) + self.statements.append(main_test) + if self.format_variables: + variables = [ + ast.Name(name, ast.Store()) for name in self.format_variables + ] + clear_format = ast.Assign(variables, ast.NameConstant(None)) + self.statements.append(clear_format) + + else: # Original assertion rewriting + # Create failure message. + body = self.expl_stmts + self.statements.append(ast.If(negation, body, [])) + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + explanation = "\n>assert " + explanation + else: + assertmsg = ast.Str("") + explanation = "assert " + explanation + template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) + msg = self.pop_format_context(template) + fmt = self.helper("_format_explanation", msg) + err_name = ast.Name("AssertionError", ast.Load()) + exc = ast.Call(err_name, [fmt], []) + raise_ = ast.Raise(exc, None) + + body.append(raise_) + # Clear temporary variables by setting them to None. if self.variables: variables = [ast.Name(name, ast.Store()) for name in self.variables] - clear = ast.Assign(variables, ast.NameConstant(None)) + clear = ast.Assign(variables, ast.NameConstant(None)) self.statements.append(clear) # Fix line numbers. for stmt in self.statements: set_location(stmt, assert_.lineno, assert_.col_offset) return self.statements - def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: + def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: # Display the repr of the name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. - locs = ast.Call(self.builtin("locals"), [], []) + locs = ast.Call(self.builtin("locals"), [], []) inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs]) - dorepr = self.helper("_should_repr_global_name", name) + dorepr = self.helper("_should_repr_global_name", name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) return name, self.explanation_param(expr) - def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: + def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: res_var = self.variable() expl_list = self.assign(ast.List([], ast.Load())) app = ast.Attribute(expl_list, "append", ast.Load()) is_or = int(isinstance(boolop.op, ast.Or)) body = save = self.statements - fail_save = self.expl_stmts + fail_save = self.expl_stmts levels = len(boolop.values) - 1 self.push_format_context() - # Process each operand, short-circuiting if needed. + # Process each operand, short-circuiting if needed. for i, v in enumerate(boolop.values): if i: - fail_inner: List[ast.stmt] = [] + fail_inner: List[ast.stmt] = [] # cond is set in a prior loop iteration below - self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa - self.expl_stmts = fail_inner + self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa + self.expl_stmts = fail_inner self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) expl_format = self.pop_format_context(ast.Str(expl)) - call = ast.Call(app, [expl_format], []) - self.expl_stmts.append(ast.Expr(call)) + call = ast.Call(app, [expl_format], []) + self.expl_stmts.append(ast.Expr(call)) if i < levels: - cond: ast.expr = res + cond: ast.expr = res if is_or: cond = ast.UnaryOp(ast.Not(), cond) - inner: List[ast.stmt] = [] + inner: List[ast.stmt] = [] self.statements.append(ast.If(cond, inner, [])) self.statements = body = inner self.statements = save - self.expl_stmts = fail_save - expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) + self.expl_stmts = fail_save + expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) expl = self.pop_format_context(expl_template) return ast.Name(res_var, ast.Load()), self.explanation_param(expl) - def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]: - pattern = UNARY_MAP[unary.op.__class__] + def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]: + pattern = UNARY_MAP[unary.op.__class__] operand_res, operand_expl = self.visit(unary.operand) res = self.assign(ast.UnaryOp(unary.op, operand_res)) return res, pattern % (operand_expl,) - def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]: - symbol = BINOP_MAP[binop.op.__class__] + def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]: + symbol = BINOP_MAP[binop.op.__class__] left_expr, left_expl = self.visit(binop.left) right_expr, right_expl = self.visit(binop.right) - explanation = f"({left_expl} {symbol} {right_expl})" + explanation = f"({left_expl} {symbol} {right_expl})" res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) return res, explanation - def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: + def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: new_func, func_expl = self.visit(call.func) arg_expls = [] new_args = [] @@ -1027,20 +1027,20 @@ class AssertionRewriter(ast.NodeVisitor): else: # **args have `arg` keywords with an .arg of None arg_expls.append("**" + expl) - expl = "{}({})".format(func_expl, ", ".join(arg_expls)) + expl = "{}({})".format(func_expl, ", ".join(arg_expls)) new_call = ast.Call(new_func, new_args, new_kwargs) res = self.assign(new_call) res_expl = self.explanation_param(self.display(res)) - outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}" + outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}" return res, outer_expl - def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]: - # A Starred node can appear in a function call. + def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]: + # A Starred node can appear in a function call. res, expl = self.visit(starred.value) new_starred = ast.Starred(res, starred.ctx) return new_starred, "*" + expl - def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: + def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: if not isinstance(attr.ctx, ast.Load): return self.generic_visit(attr) value, value_expl = self.visit(attr.value) @@ -1050,11 +1050,11 @@ class AssertionRewriter(ast.NodeVisitor): expl = pat % (res_expl, res_expl, value_expl, attr.attr) return res, expl - def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: + def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): - left_expl = f"({left_expl})" + left_expl = f"({left_expl})" res_variables = [self.variable() for i in range(len(comp.ops))] load_names = [ast.Name(v, ast.Load()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables] @@ -1065,61 +1065,61 @@ class AssertionRewriter(ast.NodeVisitor): for i, op, next_operand in it: next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): - next_expl = f"({next_expl})" + next_expl = f"({next_expl})" results.append(next_res) - sym = BINOP_MAP[op.__class__] + sym = BINOP_MAP[op.__class__] syms.append(ast.Str(sym)) - expl = f"{left_expl} {sym} {next_expl}" + expl = f"{left_expl} {sym} {next_expl}" expls.append(ast.Str(expl)) res_expr = ast.Compare(left_res, [op], [next_res]) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl # Use pytest.assertion.util._reprcompare if that's available. expl_call = self.helper( - "_call_reprcompare", + "_call_reprcompare", ast.Tuple(syms, ast.Load()), ast.Tuple(load_names, ast.Load()), ast.Tuple(expls, ast.Load()), ast.Tuple(results, ast.Load()), ) if len(comp.ops) > 1: - res: ast.expr = ast.BoolOp(ast.And(), load_names) + res: ast.expr = ast.BoolOp(ast.And(), load_names) else: res = load_names[0] return res, self.explanation_param(self.pop_format_context(expl_call)) - - -def try_makedirs(cache_dir: Path) -> bool: - """Attempt to create the given directory and sub-directories exist. - - Returns True if successful or if it already exists. - """ - try: - os.makedirs(os.fspath(cache_dir), exist_ok=True) - except (FileNotFoundError, NotADirectoryError, FileExistsError): - # One of the path components was not a directory: - # - we're in a zip file - # - it is a file - return False - except PermissionError: - return False - except OSError as e: - # as of now, EROFS doesn't have an equivalent OSError-subclass - if e.errno == errno.EROFS: - return False - raise - return True - - -def get_cache_dir(file_path: Path) -> Path: - """Return the cache directory to write .pyc files for the given .py file path.""" - if sys.version_info >= (3, 8) and sys.pycache_prefix: - # given: - # prefix = '/tmp/pycs' - # path = '/home/user/proj/test_app.py' - # we want: - # '/tmp/pycs/home/user/proj' - return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1]) - else: - # classic pycache directory - return file_path.parent / "__pycache__" + + +def try_makedirs(cache_dir: Path) -> bool: + """Attempt to create the given directory and sub-directories exist. + + Returns True if successful or if it already exists. + """ + try: + os.makedirs(os.fspath(cache_dir), exist_ok=True) + except (FileNotFoundError, NotADirectoryError, FileExistsError): + # One of the path components was not a directory: + # - we're in a zip file + # - it is a file + return False + except PermissionError: + return False + except OSError as e: + # as of now, EROFS doesn't have an equivalent OSError-subclass + if e.errno == errno.EROFS: + return False + raise + return True + + +def get_cache_dir(file_path: Path) -> Path: + """Return the cache directory to write .pyc files for the given .py file path.""" + if sys.version_info >= (3, 8) and sys.pycache_prefix: + # given: + # prefix = '/tmp/pycs' + # path = '/home/user/proj/test_app.py' + # we want: + # '/tmp/pycs/home/user/proj' + return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1]) + else: + # classic pycache directory + return file_path.parent / "__pycache__" diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py index 00a2697363b..5ba9ddca75a 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/truncate.py +++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py @@ -1,47 +1,47 @@ -"""Utilities for truncating assertion output. +"""Utilities for truncating assertion output. Current default behaviour is to truncate assertion explanations at ~8 terminal lines, unless running in "-vv" mode or running on CI. """ import os -from typing import List -from typing import Optional +from typing import List +from typing import Optional + +from _pytest.nodes import Item + -from _pytest.nodes import Item - - DEFAULT_MAX_LINES = 8 DEFAULT_MAX_CHARS = 8 * 80 USAGE_MSG = "use '-vv' to show" -def truncate_if_required( - explanation: List[str], item: Item, max_length: Optional[int] = None -) -> List[str]: - """Truncate this assertion explanation if the given test item is eligible.""" +def truncate_if_required( + explanation: List[str], item: Item, max_length: Optional[int] = None +) -> List[str]: + """Truncate this assertion explanation if the given test item is eligible.""" if _should_truncate_item(item): return _truncate_explanation(explanation) return explanation -def _should_truncate_item(item: Item) -> bool: - """Whether or not this test item is eligible for truncation.""" +def _should_truncate_item(item: Item) -> bool: + """Whether or not this test item is eligible for truncation.""" verbose = item.config.option.verbose return verbose < 2 and not _running_on_ci() -def _running_on_ci() -> bool: +def _running_on_ci() -> bool: """Check if we're currently running on a CI system.""" env_vars = ["CI", "BUILD_NUMBER"] return any(var in os.environ for var in env_vars) -def _truncate_explanation( - input_lines: List[str], - max_lines: Optional[int] = None, - max_chars: Optional[int] = None, -) -> List[str]: - """Truncate given list of strings that makes up the assertion explanation. +def _truncate_explanation( + input_lines: List[str], + max_lines: Optional[int] = None, + max_chars: Optional[int] = None, +) -> List[str]: + """Truncate given list of strings that makes up the assertion explanation. Truncates to either 8 lines, or 640 characters - whichever the input reaches first. The remaining lines will be replaced by a usage message. @@ -70,15 +70,15 @@ def _truncate_explanation( truncated_line_count += 1 # Account for the part-truncated final line msg = "...Full output truncated" if truncated_line_count == 1: - msg += f" ({truncated_line_count} line hidden)" + msg += f" ({truncated_line_count} line hidden)" else: - msg += f" ({truncated_line_count} lines hidden)" - msg += f", {USAGE_MSG}" - truncated_explanation.extend(["", str(msg)]) + msg += f" ({truncated_line_count} lines hidden)" + msg += f", {USAGE_MSG}" + truncated_explanation.extend(["", str(msg)]) return truncated_explanation -def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]: +def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]: # Check if truncation required if len("".join(input_lines)) <= max_chars: return input_lines diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py index 60e8f3a6567..da1ffd15e37 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/util.py +++ b/contrib/python/pytest/py3/_pytest/assertion/util.py @@ -1,34 +1,34 @@ -"""Utilities for assertion debugging.""" -import collections.abc +"""Utilities for assertion debugging.""" +import collections.abc import pprint -from typing import AbstractSet -from typing import Any -from typing import Callable -from typing import Iterable -from typing import List -from typing import Mapping -from typing import Optional -from typing import Sequence +from typing import AbstractSet +from typing import Any +from typing import Callable +from typing import Iterable +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence import _pytest._code -from _pytest import outcomes -from _pytest._io.saferepr import _pformat_dispatch -from _pytest._io.saferepr import safeformat -from _pytest._io.saferepr import saferepr +from _pytest import outcomes +from _pytest._io.saferepr import _pformat_dispatch +from _pytest._io.saferepr import safeformat +from _pytest._io.saferepr import saferepr # The _reprcompare attribute on the util module is used by the new assertion # interpretation code and assertion rewriter to detect this plugin was # loaded and in turn call the hooks defined here as part of the # DebugInterpreter. -_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None +_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None -# Works similarly as _reprcompare attribute. Is populated with the hook call -# when pytest_runtest_setup is called. -_assertion_pass: Optional[Callable[[int, str, str], None]] = None +# Works similarly as _reprcompare attribute. Is populated with the hook call +# when pytest_runtest_setup is called. +_assertion_pass: Optional[Callable[[int, str, str], None]] = None -def format_explanation(explanation: str) -> str: - r"""Format an explanation. +def format_explanation(explanation: str) -> str: + r"""Format an explanation. Normally all embedded newlines are escaped, however there are three exceptions: \n{, \n} and \n~. The first two are intended @@ -39,17 +39,17 @@ def format_explanation(explanation: str) -> str: """ lines = _split_explanation(explanation) result = _format_lines(lines) - return "\n".join(result) + return "\n".join(result) -def _split_explanation(explanation: str) -> List[str]: - r"""Return a list of individual lines in the explanation. +def _split_explanation(explanation: str) -> List[str]: + r"""Return a list of individual lines in the explanation. This will return a list of lines split on '\n{', '\n}' and '\n~'. Any other newlines will be escaped and appear in the line as the literal '\n' characters. """ - raw_lines = (explanation or "").split("\n") + raw_lines = (explanation or "").split("\n") lines = [raw_lines[0]] for values in raw_lines[1:]: if values and values[0] in ["{", "}", "~", ">"]: @@ -59,28 +59,28 @@ def _split_explanation(explanation: str) -> List[str]: return lines -def _format_lines(lines: Sequence[str]) -> List[str]: - """Format the individual lines. +def _format_lines(lines: Sequence[str]) -> List[str]: + """Format the individual lines. - This will replace the '{', '}' and '~' characters of our mini formatting - language with the proper 'where ...', 'and ...' and ' + ...' text, taking - care of indentation along the way. + This will replace the '{', '}' and '~' characters of our mini formatting + language with the proper 'where ...', 'and ...' and ' + ...' text, taking + care of indentation along the way. Return a list of formatted lines. """ - result = list(lines[:1]) + result = list(lines[:1]) stack = [0] stackcnt = [0] for line in lines[1:]: if line.startswith("{"): if stackcnt[-1]: - s = "and " + s = "and " else: - s = "where " + s = "where " stack.append(len(result)) stackcnt[-1] += 1 stackcnt.append(0) - result.append(" +" + " " * (len(stack) - 1) + s + line[1:]) + result.append(" +" + " " * (len(stack) - 1) + s + line[1:]) elif line.startswith("}"): stack.pop() stackcnt.pop() @@ -89,79 +89,79 @@ def _format_lines(lines: Sequence[str]) -> List[str]: assert line[0] in ["~", ">"] stack[-1] += 1 indent = len(stack) if line.startswith("~") else len(stack) - 1 - result.append(" " * indent + line[1:]) + result.append(" " * indent + line[1:]) assert len(stack) == 1 return result -def issequence(x: Any) -> bool: - return isinstance(x, collections.abc.Sequence) and not isinstance(x, str) +def issequence(x: Any) -> bool: + return isinstance(x, collections.abc.Sequence) and not isinstance(x, str) -def istext(x: Any) -> bool: - return isinstance(x, str) +def istext(x: Any) -> bool: + return isinstance(x, str) -def isdict(x: Any) -> bool: - return isinstance(x, dict) +def isdict(x: Any) -> bool: + return isinstance(x, dict) -def isset(x: Any) -> bool: - return isinstance(x, (set, frozenset)) +def isset(x: Any) -> bool: + return isinstance(x, (set, frozenset)) + + +def isnamedtuple(obj: Any) -> bool: + return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None + + +def isdatacls(obj: Any) -> bool: + return getattr(obj, "__dataclass_fields__", None) is not None + + +def isattrs(obj: Any) -> bool: + return getattr(obj, "__attrs_attrs__", None) is not None + + +def isiterable(obj: Any) -> bool: + try: + iter(obj) + return not istext(obj) + except TypeError: + return False + + +def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]: + """Return specialised explanations for some operators/operands.""" + verbose = config.getoption("verbose") + if verbose > 1: + left_repr = safeformat(left) + right_repr = safeformat(right) + else: + # XXX: "15 chars indentation" is wrong + # ("E AssertionError: assert "); should use term width. + maxsize = ( + 80 - 15 - len(op) - 2 + ) // 2 # 15 chars indentation, 1 space around op + left_repr = saferepr(left, maxsize=maxsize) + right_repr = saferepr(right, maxsize=maxsize) + + summary = f"{left_repr} {op} {right_repr}" - -def isnamedtuple(obj: Any) -> bool: - return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None - - -def isdatacls(obj: Any) -> bool: - return getattr(obj, "__dataclass_fields__", None) is not None - - -def isattrs(obj: Any) -> bool: - return getattr(obj, "__attrs_attrs__", None) is not None - - -def isiterable(obj: Any) -> bool: - try: - iter(obj) - return not istext(obj) - except TypeError: - return False - - -def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]: - """Return specialised explanations for some operators/operands.""" - verbose = config.getoption("verbose") - if verbose > 1: - left_repr = safeformat(left) - right_repr = safeformat(right) - else: - # XXX: "15 chars indentation" is wrong - # ("E AssertionError: assert "); should use term width. - maxsize = ( - 80 - 15 - len(op) - 2 - ) // 2 # 15 chars indentation, 1 space around op - left_repr = saferepr(left, maxsize=maxsize) - right_repr = saferepr(right, maxsize=maxsize) - - summary = f"{left_repr} {op} {right_repr}" - explanation = None try: if op == "==": - explanation = _compare_eq_any(left, right, verbose) + explanation = _compare_eq_any(left, right, verbose) elif op == "not in": if istext(left) and istext(right): explanation = _notin_text(left, right, verbose) - except outcomes.Exit: - raise + except outcomes.Exit: + raise except Exception: explanation = [ - "(pytest_assertion plugin: representation of details failed: {}.".format( - _pytest._code.ExceptionInfo.from_current()._getreprcrash() - ), - " Probably an object has a faulty __repr__.)", + "(pytest_assertion plugin: representation of details failed: {}.".format( + _pytest._code.ExceptionInfo.from_current()._getreprcrash() + ), + " Probably an object has a faulty __repr__.)", ] if not explanation: @@ -170,44 +170,44 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[ return [summary] + explanation -def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: - explanation = [] - if istext(left) and istext(right): - explanation = _diff_text(left, right, verbose) - else: - if type(left) == type(right) and ( - isdatacls(left) or isattrs(left) or isnamedtuple(left) - ): - # Note: unlike dataclasses/attrs, namedtuples compare only the - # field values, not the type or field names. But this branch - # intentionally only handles the same-type case, which was often - # used in older code bases before dataclasses/attrs were available. - explanation = _compare_eq_cls(left, right, verbose) - elif issequence(left) and issequence(right): - explanation = _compare_eq_sequence(left, right, verbose) - elif isset(left) and isset(right): - explanation = _compare_eq_set(left, right, verbose) - elif isdict(left) and isdict(right): - explanation = _compare_eq_dict(left, right, verbose) - elif verbose > 0: - explanation = _compare_eq_verbose(left, right) - if isiterable(left) and isiterable(right): - expl = _compare_eq_iterable(left, right, verbose) - explanation.extend(expl) - return explanation - - -def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: - """Return the explanation for the diff between text. +def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: + explanation = [] + if istext(left) and istext(right): + explanation = _diff_text(left, right, verbose) + else: + if type(left) == type(right) and ( + isdatacls(left) or isattrs(left) or isnamedtuple(left) + ): + # Note: unlike dataclasses/attrs, namedtuples compare only the + # field values, not the type or field names. But this branch + # intentionally only handles the same-type case, which was often + # used in older code bases before dataclasses/attrs were available. + explanation = _compare_eq_cls(left, right, verbose) + elif issequence(left) and issequence(right): + explanation = _compare_eq_sequence(left, right, verbose) + elif isset(left) and isset(right): + explanation = _compare_eq_set(left, right, verbose) + elif isdict(left) and isdict(right): + explanation = _compare_eq_dict(left, right, verbose) + elif verbose > 0: + explanation = _compare_eq_verbose(left, right) + if isiterable(left) and isiterable(right): + expl = _compare_eq_iterable(left, right, verbose) + explanation.extend(expl) + return explanation + + +def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: + """Return the explanation for the diff between text. Unless --verbose is used this will skip leading and trailing characters which are identical to keep the diff minimal. """ from difflib import ndiff - explanation: List[str] = [] + explanation: List[str] = [] - if verbose < 1: + if verbose < 1: i = 0 # just in case left or right has zero length for i in range(min(len(left), len(right))): if left[i] != right[i]: @@ -215,7 +215,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: if i > 42: i -= 10 # Provide some context explanation = [ - "Skipping %s identical leading characters in diff, use -v to show" % i + "Skipping %s identical leading characters in diff, use -v to show" % i ] left = left[i:] right = right[i:] @@ -226,8 +226,8 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: if i > 42: i -= 10 # Provide some context explanation += [ - "Skipping {} identical trailing " - "characters in diff, use -v to show".format(i) + "Skipping {} identical trailing " + "characters in diff, use -v to show".format(i) ] left = left[:-i] right = right[:-i] @@ -235,243 +235,243 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: if left.isspace() or right.isspace(): left = repr(str(left)) right = repr(str(right)) - explanation += ["Strings contain only whitespace, escaping them using repr()"] - # "right" is the expected base against which we compare "left", - # see https://github.com/pytest-dev/pytest/issues/3333 + explanation += ["Strings contain only whitespace, escaping them using repr()"] + # "right" is the expected base against which we compare "left", + # see https://github.com/pytest-dev/pytest/issues/3333 explanation += [ line.strip("\n") - for line in ndiff(right.splitlines(keepends), left.splitlines(keepends)) + for line in ndiff(right.splitlines(keepends), left.splitlines(keepends)) ] return explanation -def _compare_eq_verbose(left: Any, right: Any) -> List[str]: - keepends = True - left_lines = repr(left).splitlines(keepends) - right_lines = repr(right).splitlines(keepends) - - explanation: List[str] = [] - explanation += ["+" + line for line in left_lines] - explanation += ["-" + line for line in right_lines] - - return explanation - - -def _surrounding_parens_on_own_lines(lines: List[str]) -> None: - """Move opening/closing parenthesis/bracket to own lines.""" - opening = lines[0][:1] - if opening in ["(", "[", "{"]: - lines[0] = " " + lines[0][1:] - lines[:] = [opening] + lines - closing = lines[-1][-1:] - if closing in [")", "]", "}"]: - lines[-1] = lines[-1][:-1] + "," - lines[:] = lines + [closing] - - -def _compare_eq_iterable( - left: Iterable[Any], right: Iterable[Any], verbose: int = 0 -) -> List[str]: +def _compare_eq_verbose(left: Any, right: Any) -> List[str]: + keepends = True + left_lines = repr(left).splitlines(keepends) + right_lines = repr(right).splitlines(keepends) + + explanation: List[str] = [] + explanation += ["+" + line for line in left_lines] + explanation += ["-" + line for line in right_lines] + + return explanation + + +def _surrounding_parens_on_own_lines(lines: List[str]) -> None: + """Move opening/closing parenthesis/bracket to own lines.""" + opening = lines[0][:1] + if opening in ["(", "[", "{"]: + lines[0] = " " + lines[0][1:] + lines[:] = [opening] + lines + closing = lines[-1][-1:] + if closing in [")", "]", "}"]: + lines[-1] = lines[-1][:-1] + "," + lines[:] = lines + [closing] + + +def _compare_eq_iterable( + left: Iterable[Any], right: Iterable[Any], verbose: int = 0 +) -> List[str]: if not verbose: - return ["Use -v to get the full diff"] + return ["Use -v to get the full diff"] # dynamic import to speedup pytest import difflib - left_formatting = pprint.pformat(left).splitlines() - right_formatting = pprint.pformat(right).splitlines() - - # Re-format for different output lengths. - lines_left = len(left_formatting) - lines_right = len(right_formatting) - if lines_left != lines_right: - left_formatting = _pformat_dispatch(left).splitlines() - right_formatting = _pformat_dispatch(right).splitlines() - - if lines_left > 1 or lines_right > 1: - _surrounding_parens_on_own_lines(left_formatting) - _surrounding_parens_on_own_lines(right_formatting) - - explanation = ["Full diff:"] - # "right" is the expected base against which we compare "left", - # see https://github.com/pytest-dev/pytest/issues/3333 + left_formatting = pprint.pformat(left).splitlines() + right_formatting = pprint.pformat(right).splitlines() + + # Re-format for different output lengths. + lines_left = len(left_formatting) + lines_right = len(right_formatting) + if lines_left != lines_right: + left_formatting = _pformat_dispatch(left).splitlines() + right_formatting = _pformat_dispatch(right).splitlines() + + if lines_left > 1 or lines_right > 1: + _surrounding_parens_on_own_lines(left_formatting) + _surrounding_parens_on_own_lines(right_formatting) + + explanation = ["Full diff:"] + # "right" is the expected base against which we compare "left", + # see https://github.com/pytest-dev/pytest/issues/3333 explanation.extend( - line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting) + line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting) ) return explanation -def _compare_eq_sequence( - left: Sequence[Any], right: Sequence[Any], verbose: int = 0 -) -> List[str]: - comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) - explanation: List[str] = [] - len_left = len(left) - len_right = len(right) - for i in range(min(len_left, len_right)): +def _compare_eq_sequence( + left: Sequence[Any], right: Sequence[Any], verbose: int = 0 +) -> List[str]: + comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) + explanation: List[str] = [] + len_left = len(left) + len_right = len(right) + for i in range(min(len_left, len_right)): if left[i] != right[i]: - if comparing_bytes: - # when comparing bytes, we want to see their ascii representation - # instead of their numeric values (#5260) - # using a slice gives us the ascii representation: - # >>> s = b'foo' - # >>> s[0] - # 102 - # >>> s[0:1] - # b'f' - left_value = left[i : i + 1] - right_value = right[i : i + 1] - else: - left_value = left[i] - right_value = right[i] - - explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"] + if comparing_bytes: + # when comparing bytes, we want to see their ascii representation + # instead of their numeric values (#5260) + # using a slice gives us the ascii representation: + # >>> s = b'foo' + # >>> s[0] + # 102 + # >>> s[0:1] + # b'f' + left_value = left[i : i + 1] + right_value = right[i : i + 1] + else: + left_value = left[i] + right_value = right[i] + + explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"] break - - if comparing_bytes: - # when comparing bytes, it doesn't help to show the "sides contain one or more - # items" longer explanation, so skip it - - return explanation - - len_diff = len_left - len_right - if len_diff: - if len_diff > 0: - dir_with_more = "Left" - extra = saferepr(left[len_right]) - else: - len_diff = 0 - len_diff - dir_with_more = "Right" - extra = saferepr(right[len_left]) - - if len_diff == 1: - explanation += [f"{dir_with_more} contains one more item: {extra}"] - else: - explanation += [ - "%s contains %d more items, first extra item: %s" - % (dir_with_more, len_diff, extra) - ] + + if comparing_bytes: + # when comparing bytes, it doesn't help to show the "sides contain one or more + # items" longer explanation, so skip it + + return explanation + + len_diff = len_left - len_right + if len_diff: + if len_diff > 0: + dir_with_more = "Left" + extra = saferepr(left[len_right]) + else: + len_diff = 0 - len_diff + dir_with_more = "Right" + extra = saferepr(right[len_left]) + + if len_diff == 1: + explanation += [f"{dir_with_more} contains one more item: {extra}"] + else: + explanation += [ + "%s contains %d more items, first extra item: %s" + % (dir_with_more, len_diff, extra) + ] return explanation -def _compare_eq_set( - left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0 -) -> List[str]: +def _compare_eq_set( + left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0 +) -> List[str]: explanation = [] diff_left = left - right diff_right = right - left if diff_left: - explanation.append("Extra items in the left set:") + explanation.append("Extra items in the left set:") for item in diff_left: - explanation.append(saferepr(item)) + explanation.append(saferepr(item)) if diff_right: - explanation.append("Extra items in the right set:") + explanation.append("Extra items in the right set:") for item in diff_right: - explanation.append(saferepr(item)) + explanation.append(saferepr(item)) return explanation -def _compare_eq_dict( - left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0 -) -> List[str]: - explanation: List[str] = [] - set_left = set(left) - set_right = set(right) - common = set_left.intersection(set_right) +def _compare_eq_dict( + left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0 +) -> List[str]: + explanation: List[str] = [] + set_left = set(left) + set_right = set(right) + common = set_left.intersection(set_right) same = {k: left[k] for k in common if left[k] == right[k]} if same and verbose < 2: - explanation += ["Omitting %s identical items, use -vv to show" % len(same)] + explanation += ["Omitting %s identical items, use -vv to show" % len(same)] elif same: - explanation += ["Common items:"] + explanation += ["Common items:"] explanation += pprint.pformat(same).splitlines() diff = {k for k in common if left[k] != right[k]} if diff: - explanation += ["Differing items:"] + explanation += ["Differing items:"] for k in diff: - explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})] - extra_left = set_left - set_right - len_extra_left = len(extra_left) - if len_extra_left: - explanation.append( - "Left contains %d more item%s:" - % (len_extra_left, "" if len_extra_left == 1 else "s") - ) + explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})] + extra_left = set_left - set_right + len_extra_left = len(extra_left) + if len_extra_left: + explanation.append( + "Left contains %d more item%s:" + % (len_extra_left, "" if len_extra_left == 1 else "s") + ) explanation.extend( pprint.pformat({k: left[k] for k in extra_left}).splitlines() ) - extra_right = set_right - set_left - len_extra_right = len(extra_right) - if len_extra_right: - explanation.append( - "Right contains %d more item%s:" - % (len_extra_right, "" if len_extra_right == 1 else "s") - ) + extra_right = set_right - set_left + len_extra_right = len(extra_right) + if len_extra_right: + explanation.append( + "Right contains %d more item%s:" + % (len_extra_right, "" if len_extra_right == 1 else "s") + ) explanation.extend( pprint.pformat({k: right[k] for k in extra_right}).splitlines() ) return explanation -def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: - if isdatacls(left): - all_fields = left.__dataclass_fields__ - fields_to_check = [field for field, info in all_fields.items() if info.compare] - elif isattrs(left): - all_fields = left.__attrs_attrs__ - fields_to_check = [field.name for field in all_fields if getattr(field, "eq")] - elif isnamedtuple(left): - fields_to_check = left._fields - else: - assert False - - indent = " " - same = [] - diff = [] - for field in fields_to_check: - if getattr(left, field) == getattr(right, field): - same.append(field) - else: - diff.append(field) - - explanation = [] - if same or diff: - explanation += [""] - if same and verbose < 2: - explanation.append("Omitting %s identical items, use -vv to show" % len(same)) - elif same: - explanation += ["Matching attributes:"] - explanation += pprint.pformat(same).splitlines() - if diff: - explanation += ["Differing attributes:"] - explanation += pprint.pformat(diff).splitlines() - for field in diff: - field_left = getattr(left, field) - field_right = getattr(right, field) - explanation += [ - "", - "Drill down into differing attribute %s:" % field, - ("%s%s: %r != %r") % (indent, field, field_left, field_right), - ] - explanation += [ - indent + line - for line in _compare_eq_any(field_left, field_right, verbose) - ] - return explanation - - -def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]: +def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: + if isdatacls(left): + all_fields = left.__dataclass_fields__ + fields_to_check = [field for field, info in all_fields.items() if info.compare] + elif isattrs(left): + all_fields = left.__attrs_attrs__ + fields_to_check = [field.name for field in all_fields if getattr(field, "eq")] + elif isnamedtuple(left): + fields_to_check = left._fields + else: + assert False + + indent = " " + same = [] + diff = [] + for field in fields_to_check: + if getattr(left, field) == getattr(right, field): + same.append(field) + else: + diff.append(field) + + explanation = [] + if same or diff: + explanation += [""] + if same and verbose < 2: + explanation.append("Omitting %s identical items, use -vv to show" % len(same)) + elif same: + explanation += ["Matching attributes:"] + explanation += pprint.pformat(same).splitlines() + if diff: + explanation += ["Differing attributes:"] + explanation += pprint.pformat(diff).splitlines() + for field in diff: + field_left = getattr(left, field) + field_right = getattr(right, field) + explanation += [ + "", + "Drill down into differing attribute %s:" % field, + ("%s%s: %r != %r") % (indent, field, field_left, field_right), + ] + explanation += [ + indent + line + for line in _compare_eq_any(field_left, field_right, verbose) + ] + return explanation + + +def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]: index = text.find(term) head = text[:index] tail = text[index + len(term) :] correct_text = head + tail - diff = _diff_text(text, correct_text, verbose) - newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)] + diff = _diff_text(text, correct_text, verbose) + newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)] for line in diff: - if line.startswith("Skipping"): + if line.startswith("Skipping"): continue - if line.startswith("- "): + if line.startswith("- "): continue - if line.startswith("+ "): - newdiff.append(" " + line[2:]) + if line.startswith("+ "): + newdiff.append(" " + line[2:]) else: newdiff.append(line) return newdiff |
