summaryrefslogtreecommitdiffstats
path: root/contrib/python/pytest/py3/_pytest/assertion
diff options
context:
space:
mode:
authorshadchin <[email protected]>2022-02-10 16:44:39 +0300
committerDaniil Cherednik <[email protected]>2022-02-10 16:44:39 +0300
commite9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch)
tree64175d5cadab313b3e7039ebaa06c5bc3295e274 /contrib/python/pytest/py3/_pytest/assertion
parent2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff)
Restoring authorship annotation for <[email protected]>. Commit 2 of 2.
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/__init__.py154
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/rewrite.py1146
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/truncate.py48
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/util.py610
4 files changed, 979 insertions, 979 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
index 430eb2791b0..a18cf198df0 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
@@ -1,25 +1,25 @@
-"""Support for presenting detailed information in failing assertions."""
+"""Support for presenting detailed information in failing assertions."""
import sys
-from typing import Any
-from typing import Generator
-from typing import List
-from typing import Optional
-from typing import TYPE_CHECKING
+from typing import Any
+from typing import Generator
+from typing import List
+from typing import Optional
+from typing import TYPE_CHECKING
from _pytest.assertion import rewrite
from _pytest.assertion import truncate
from _pytest.assertion import util
-from _pytest.assertion.rewrite import assertstate_key
-from _pytest.config import Config
-from _pytest.config import hookimpl
-from _pytest.config.argparsing import Parser
-from _pytest.nodes import Item
+from _pytest.assertion.rewrite import assertstate_key
+from _pytest.config import Config
+from _pytest.config import hookimpl
+from _pytest.config.argparsing import Parser
+from _pytest.nodes import Item
-if TYPE_CHECKING:
- from _pytest.main import Session
+if TYPE_CHECKING:
+ from _pytest.main import Session
-
-def pytest_addoption(parser: Parser) -> None:
+
+def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
@@ -28,23 +28,23 @@ def pytest_addoption(parser: Parser) -> None:
choices=("rewrite", "plain"),
default="rewrite",
metavar="MODE",
- help=(
- "Control assertion debugging tools.\n"
- "'plain' performs no assertion debugging.\n"
- "'rewrite' (the default) rewrites assert statements in test modules"
- " on import to provide assert expression information."
- ),
+ help=(
+ "Control assertion debugging tools.\n"
+ "'plain' performs no assertion debugging.\n"
+ "'rewrite' (the default) rewrites assert statements in test modules"
+ " on import to provide assert expression information."
+ ),
+ )
+ parser.addini(
+ "enable_assertion_pass_hook",
+ type="bool",
+ default=False,
+ help="Enables the pytest_assertion_pass hook."
+ "Make sure to delete any previously generated pyc cache files.",
)
- parser.addini(
- "enable_assertion_pass_hook",
- type="bool",
- default=False,
- help="Enables the pytest_assertion_pass hook."
- "Make sure to delete any previously generated pyc cache files.",
- )
-def register_assert_rewrite(*names: str) -> None:
+def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
@@ -53,48 +53,48 @@ def register_assert_rewrite(*names: str) -> None:
actually imported, usually in your __init__.py if you are a plugin
using a package.
- :raises TypeError: If the given module names are not strings.
+ :raises TypeError: If the given module names are not strings.
"""
for name in names:
if not isinstance(name, str):
- msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
+ msg = "expected module names as *args, got {0} instead" # type: ignore[unreachable]
raise TypeError(msg.format(repr(names)))
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
importhook = hook
break
else:
- # TODO(typing): Add a protocol for mark_rewrite() and use it
- # for importhook and for PytestPluginManager.rewrite_hook.
- importhook = DummyRewriteHook() # type: ignore
+ # TODO(typing): Add a protocol for mark_rewrite() and use it
+ # for importhook and for PytestPluginManager.rewrite_hook.
+ importhook = DummyRewriteHook() # type: ignore
importhook.mark_rewrite(*names)
-class DummyRewriteHook:
+class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
- def mark_rewrite(self, *names: str) -> None:
+ def mark_rewrite(self, *names: str) -> None:
pass
-class AssertionState:
+class AssertionState:
"""State for the assertion plugin."""
- def __init__(self, config: Config, mode) -> None:
+ def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
- self.hook: Optional[rewrite.AssertionRewritingHook] = None
+ self.hook: Optional[rewrite.AssertionRewritingHook] = None
-def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
+def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
- config._store[assertstate_key] = AssertionState(config, "rewrite")
- config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
+ config._store[assertstate_key] = AssertionState(config, "rewrite")
+ config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
- config._store[assertstate_key].trace("installed rewrite import hook")
+ config._store[assertstate_key].trace("installed rewrite import hook")
- def undo() -> None:
- hook = config._store[assertstate_key].hook
+ def undo() -> None:
+ hook = config._store[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
@@ -102,30 +102,30 @@ def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
return hook
-def pytest_collection(session: "Session") -> None:
- # This hook is only called when test modules are collected
+def pytest_collection(session: "Session") -> None:
+ # This hook is only called when test modules are collected
# so for example not in the master process of pytest-xdist
- # (which does not collect test modules).
- assertstate = session.config._store.get(assertstate_key, None)
+ # (which does not collect test modules).
+ assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(session)
-@hookimpl(tryfirst=True, hookwrapper=True)
-def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
- """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
+@hookimpl(tryfirst=True, hookwrapper=True)
+def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
+ """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
- The rewrite module will use util._reprcompare if it exists to use custom
- reporting via the pytest_assertrepr_compare hook. This sets up this custom
+ The rewrite module will use util._reprcompare if it exists to use custom
+ reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
- ihook = item.ihook
+ ihook = item.ihook
+
+ def callbinrepr(op, left: object, right: object) -> Optional[str]:
+ """Call the pytest_assertrepr_compare hook and prepare the result.
- def callbinrepr(op, left: object, right: object) -> Optional[str]:
- """Call the pytest_assertrepr_compare hook and prepare the result.
-
This uses the first result from the hook and then ensures the
following:
* Overly verbose explanations are truncated unless configured otherwise
@@ -138,42 +138,42 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
The result can be formatted by util.format_explanation() for
pretty printing.
"""
- hook_result = ihook.pytest_assertrepr_compare(
+ hook_result = ihook.pytest_assertrepr_compare(
config=item.config, op=op, left=left, right=right
)
for new_expl in hook_result:
if new_expl:
new_expl = truncate.truncate_if_required(new_expl, item)
new_expl = [line.replace("\n", "\\n") for line in new_expl]
- res = "\n~".join(new_expl)
+ res = "\n~".join(new_expl)
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
- return None
+ return None
- saved_assert_hooks = util._reprcompare, util._assertion_pass
+ saved_assert_hooks = util._reprcompare, util._assertion_pass
util._reprcompare = callbinrepr
- if ihook.pytest_assertion_pass.get_hookimpls():
+ if ihook.pytest_assertion_pass.get_hookimpls():
+
+ def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
+ ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
+
+ util._assertion_pass = call_assertion_pass_hook
+
+ yield
- def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
- ihook.pytest_assertion_pass(item=item, lineno=lineno, orig=orig, expl=expl)
+ util._reprcompare, util._assertion_pass = saved_assert_hooks
- util._assertion_pass = call_assertion_pass_hook
- yield
-
- util._reprcompare, util._assertion_pass = saved_assert_hooks
-
-
-def pytest_sessionfinish(session: "Session") -> None:
- assertstate = session.config._store.get(assertstate_key, None)
+def pytest_sessionfinish(session: "Session") -> None:
+ assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:
assertstate.hook.set_session(None)
-def pytest_assertrepr_compare(
- config: Config, op: str, left: Any, right: Any
-) -> Optional[List[str]]:
- return util.assertrepr_compare(config=config, op=op, left=left, right=right)
+def pytest_assertrepr_compare(
+ config: Config, op: str, left: Any, right: Any
+) -> Optional[List[str]]:
+ return util.assertrepr_compare(config=config, op=op, left=left, right=right)
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
index c79dfd9a686..37ff076aab5 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -1,140 +1,140 @@
-"""Rewrite assertion AST to produce nice error messages."""
+"""Rewrite assertion AST to produce nice error messages."""
import ast
import errno
-import functools
-import importlib.abc
-import importlib.machinery
-import importlib.util
-import io
+import functools
+import importlib.abc
+import importlib.machinery
+import importlib.util
+import io
import itertools
import marshal
import os
import struct
import sys
-import tokenize
+import tokenize
import types
-from pathlib import Path
-from pathlib import PurePath
-from typing import Callable
-from typing import Dict
-from typing import IO
-from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Set
-from typing import Tuple
-from typing import TYPE_CHECKING
-from typing import Union
+from pathlib import Path
+from pathlib import PurePath
+from typing import Callable
+from typing import Dict
+from typing import IO
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
-import py
-
-from _pytest._io.saferepr import saferepr
-from _pytest._version import version
+import py
+
+from _pytest._io.saferepr import saferepr
+from _pytest._version import version
from _pytest.assertion import util
-from _pytest.assertion.util import ( # noqa: F401
- format_explanation as _format_explanation,
-)
-from _pytest.config import Config
-from _pytest.main import Session
+from _pytest.assertion.util import ( # noqa: F401
+ format_explanation as _format_explanation,
+)
+from _pytest.config import Config
+from _pytest.main import Session
from _pytest.pathlib import fnmatch_ex
-from _pytest.store import StoreKey
+from _pytest.store import StoreKey
-if TYPE_CHECKING:
- from _pytest.assertion import AssertionState
+if TYPE_CHECKING:
+ from _pytest.assertion import AssertionState
-assertstate_key = StoreKey["AssertionState"]()
+assertstate_key = StoreKey["AssertionState"]()
-# pytest caches rewritten pycs in pycache dirs
-PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
-PYC_EXT = ".py" + (__debug__ and "c" or "o")
-PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
+# pytest caches rewritten pycs in pycache dirs
+PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
+PYC_EXT = ".py" + (__debug__ and "c" or "o")
+PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
-class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
- """PEP302/PEP451 import hook which rewrites asserts."""
+class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
+ """PEP302/PEP451 import hook which rewrites asserts."""
- def __init__(self, config: Config) -> None:
+ def __init__(self, config: Config) -> None:
self.config = config
- try:
- self.fnpats = config.getini("python_files")
- except ValueError:
- self.fnpats = ["test_*.py", "*_test.py"]
- self.session: Optional[Session] = None
- self._rewritten_names: Set[str] = set()
- self._must_rewrite: Set[str] = set()
+ try:
+ self.fnpats = config.getini("python_files")
+ except ValueError:
+ self.fnpats = ["test_*.py", "*_test.py"]
+ self.session: Optional[Session] = None
+ self._rewritten_names: Set[str] = set()
+ self._must_rewrite: Set[str] = set()
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"}
- self._marked_for_rewrite_cache: Dict[str, bool] = {}
+ self._marked_for_rewrite_cache: Dict[str, bool] = {}
self._session_paths_checked = False
- def set_session(self, session: Optional[Session]) -> None:
+ def set_session(self, session: Optional[Session]) -> None:
self.session = session
self._session_paths_checked = False
- # Indirection so we can mock calls to find_spec originated from the hook during testing
- _find_spec = importlib.machinery.PathFinder.find_spec
+ # Indirection so we can mock calls to find_spec originated from the hook during testing
+ _find_spec = importlib.machinery.PathFinder.find_spec
- def find_spec(
- self,
- name: str,
- path: Optional[Sequence[Union[str, bytes]]] = None,
- target: Optional[types.ModuleType] = None,
- ) -> Optional[importlib.machinery.ModuleSpec]:
+ def find_spec(
+ self,
+ name: str,
+ path: Optional[Sequence[Union[str, bytes]]] = None,
+ target: Optional[types.ModuleType] = None,
+ ) -> Optional[importlib.machinery.ModuleSpec]:
if self._writing_pyc:
return None
- state = self.config._store[assertstate_key]
+ state = self.config._store[assertstate_key]
if self._early_rewrite_bailout(name, state):
return None
state.trace("find_module called for: %s" % name)
-
- # Type ignored because mypy is confused about the `self` binding here.
- spec = self._find_spec(name, path) # type: ignore
- if (
- # the import machinery could not find a file to import
- spec is None
- # this is a namespace package (without `__init__.py`)
- # there's nothing to rewrite there
- # python3.6: `namespace`
- # python3.7+: `None`
- or spec.origin == "namespace"
- or spec.origin is None
- # we can only rewrite source files
- or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
- # if the file doesn't exist, we can't rewrite it
- or not os.path.exists(spec.origin)
- ):
- return None
+
+ # Type ignored because mypy is confused about the `self` binding here.
+ spec = self._find_spec(name, path) # type: ignore
+ if (
+ # the import machinery could not find a file to import
+ spec is None
+ # this is a namespace package (without `__init__.py`)
+ # there's nothing to rewrite there
+ # python3.6: `namespace`
+ # python3.7+: `None`
+ or spec.origin == "namespace"
+ or spec.origin is None
+ # we can only rewrite source files
+ or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
+ # if the file doesn't exist, we can't rewrite it
+ or not os.path.exists(spec.origin)
+ ):
+ return None
else:
- fn = spec.origin
+ fn = spec.origin
- if not self._should_rewrite(name, fn, state):
+ if not self._should_rewrite(name, fn, state):
return None
- return importlib.util.spec_from_file_location(
- name,
- fn,
- loader=self,
- submodule_search_locations=spec.submodule_search_locations,
- )
+ return importlib.util.spec_from_file_location(
+ name,
+ fn,
+ loader=self,
+ submodule_search_locations=spec.submodule_search_locations,
+ )
+
+ def create_module(
+ self, spec: importlib.machinery.ModuleSpec
+ ) -> Optional[types.ModuleType]:
+ return None # default behaviour is fine
+
+ def exec_module(self, module: types.ModuleType) -> None:
+ assert module.__spec__ is not None
+ assert module.__spec__.origin is not None
+ fn = Path(module.__spec__.origin)
+ state = self.config._store[assertstate_key]
+
+ self._rewritten_names.add(module.__name__)
- def create_module(
- self, spec: importlib.machinery.ModuleSpec
- ) -> Optional[types.ModuleType]:
- return None # default behaviour is fine
-
- def exec_module(self, module: types.ModuleType) -> None:
- assert module.__spec__ is not None
- assert module.__spec__.origin is not None
- fn = Path(module.__spec__.origin)
- state = self.config._store[assertstate_key]
-
- self._rewritten_names.add(module.__name__)
-
# The requested module looks like a test file, so rewrite it. This is
# the most magical part of the process: load the source, rewrite the
# asserts, and load the rewritten source. We also cache the rewritten
@@ -144,21 +144,21 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
# cached pyc is always a complete, valid pyc. Operations on it must be
# atomic. POSIX's atomic rename comes in handy.
write = not sys.dont_write_bytecode
- cache_dir = get_cache_dir(fn)
+ cache_dir = get_cache_dir(fn)
if write:
- ok = try_makedirs(cache_dir)
- if not ok:
- write = False
- state.trace(f"read only directory: {cache_dir}")
-
- cache_name = fn.name[:-3] + PYC_TAIL
- pyc = cache_dir / cache_name
+ ok = try_makedirs(cache_dir)
+ if not ok:
+ write = False
+ state.trace(f"read only directory: {cache_dir}")
+
+ cache_name = fn.name[:-3] + PYC_TAIL
+ pyc = cache_dir / cache_name
# Notice that even if we're in a read-only directory, I'm going
# to check for a cached pyc. This may not be optimal...
- co = _read_pyc(fn, pyc, state.trace)
+ co = _read_pyc(fn, pyc, state.trace)
if co is None:
- state.trace(f"rewriting {fn!r}")
- source_stat, co = _rewrite_test(fn, self.config)
+ state.trace(f"rewriting {fn!r}")
+ source_stat, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
@@ -166,23 +166,23 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
finally:
self._writing_pyc = False
else:
- state.trace(f"found cached rewritten pyc for {fn}")
- exec(co, module.__dict__)
+ state.trace(f"found cached rewritten pyc for {fn}")
+ exec(co, module.__dict__)
- def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
- """A fast way to get out of rewriting modules.
-
- Profiling has shown that the call to PathFinder.find_spec (inside of
- the find_spec from this class) is a major slowdown, so, this method
- tries to filter what we're sure won't be rewritten before getting to
- it.
+ def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
+ """A fast way to get out of rewriting modules.
+
+ Profiling has shown that the call to PathFinder.find_spec (inside of
+ the find_spec from this class) is a major slowdown, so, this method
+ tries to filter what we're sure won't be rewritten before getting to
+ it.
"""
if self.session is not None and not self._session_paths_checked:
self._session_paths_checked = True
- for initial_path in self.session._initialpaths:
+ for initial_path in self.session._initialpaths:
# Make something as c:/projects/my_project/path.py ->
# ['c:', 'projects', 'my_project', 'path.py']
- parts = str(initial_path).split(os.path.sep)
+ parts = str(initial_path).split(os.path.sep)
# add 'path' to basenames to be checked.
self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
@@ -205,44 +205,44 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
if self._is_marked_for_rewrite(name, state):
return False
- state.trace(f"early skip of rewriting module: {name}")
+ state.trace(f"early skip of rewriting module: {name}")
return True
- def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
+ def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
# always rewrite conftest files
- if os.path.basename(fn) == "conftest.py":
- state.trace(f"rewriting conftest file: {fn!r}")
+ if os.path.basename(fn) == "conftest.py":
+ state.trace(f"rewriting conftest file: {fn!r}")
return True
if self.session is not None:
- if self.session.isinitpath(py.path.local(fn)):
- state.trace(f"matched test file (was specified on cmdline): {fn!r}")
+ if self.session.isinitpath(py.path.local(fn)):
+ state.trace(f"matched test file (was specified on cmdline): {fn!r}")
return True
# modules not passed explicitly on the command line are only
# rewritten if they match the naming convention for test files
- fn_path = PurePath(fn)
+ fn_path = PurePath(fn)
for pat in self.fnpats:
- if fnmatch_ex(pat, fn_path):
- state.trace(f"matched test file {fn!r}")
+ if fnmatch_ex(pat, fn_path):
+ state.trace(f"matched test file {fn!r}")
return True
return self._is_marked_for_rewrite(name, state)
- def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
+ def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
for marked in self._must_rewrite:
if name == marked or name.startswith(marked + "."):
- state.trace(f"matched marked file {name!r} (from {marked!r})")
+ state.trace(f"matched marked file {name!r} (from {marked!r})")
self._marked_for_rewrite_cache[name] = True
return True
self._marked_for_rewrite_cache[name] = False
return False
- def mark_rewrite(self, *names: str) -> None:
+ def mark_rewrite(self, *names: str) -> None:
"""Mark import names as needing to be rewritten.
The named module or package as well as any nested modules will
@@ -252,155 +252,155 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
set(names).intersection(sys.modules).difference(self._rewritten_names)
)
for name in already_imported:
- mod = sys.modules[name]
+ mod = sys.modules[name]
if not AssertionRewriter.is_rewrite_disabled(
- mod.__doc__ or ""
- ) and not isinstance(mod.__loader__, type(self)):
+ mod.__doc__ or ""
+ ) and not isinstance(mod.__loader__, type(self)):
self._warn_already_imported(name)
self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
- def _warn_already_imported(self, name: str) -> None:
- from _pytest.warning_types import PytestAssertRewriteWarning
+ def _warn_already_imported(self, name: str) -> None:
+ from _pytest.warning_types import PytestAssertRewriteWarning
- self.config.issue_config_time_warning(
- PytestAssertRewriteWarning(
- "Module already imported so cannot be rewritten: %s" % name
- ),
+ self.config.issue_config_time_warning(
+ PytestAssertRewriteWarning(
+ "Module already imported so cannot be rewritten: %s" % name
+ ),
stacklevel=5,
)
- def get_data(self, pathname: Union[str, bytes]) -> bytes:
- """Optional PEP302 get_data API."""
+ def get_data(self, pathname: Union[str, bytes]) -> bytes:
+ """Optional PEP302 get_data API."""
with open(pathname, "rb") as f:
return f.read()
-def _write_pyc_fp(
- fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
-) -> None:
+def _write_pyc_fp(
+ fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
+) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
- # import. However, there's little reason to deviate.
- fp.write(importlib.util.MAGIC_NUMBER)
- # https://www.python.org/dev/peps/pep-0552/
- if sys.version_info >= (3, 7):
- flags = b"\x00\x00\x00\x00"
- fp.write(flags)
- # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
- mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
- size = source_stat.st_size & 0xFFFFFFFF
- # "<LL" stands for 2 unsigned longs, little-endian.
- fp.write(struct.pack("<LL", mtime, size))
- fp.write(marshal.dumps(co))
+ # import. However, there's little reason to deviate.
+ fp.write(importlib.util.MAGIC_NUMBER)
+ # https://www.python.org/dev/peps/pep-0552/
+ if sys.version_info >= (3, 7):
+ flags = b"\x00\x00\x00\x00"
+ fp.write(flags)
+ # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
+ mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
+ size = source_stat.st_size & 0xFFFFFFFF
+ # "<LL" stands for 2 unsigned longs, little-endian.
+ fp.write(struct.pack("<LL", mtime, size))
+ fp.write(marshal.dumps(co))
+
+
+if sys.platform == "win32":
+ from atomicwrites import atomic_write
+
+ def _write_pyc(
+ state: "AssertionState",
+ co: types.CodeType,
+ source_stat: os.stat_result,
+ pyc: Path,
+ ) -> bool:
+ try:
+ with atomic_write(os.fspath(pyc), mode="wb", overwrite=True) as fp:
+ _write_pyc_fp(fp, source_stat, co)
+ except OSError as e:
+ state.trace(f"error writing pyc file at {pyc}: {e}")
+ # we ignore any failure to write the cache file
+ # there are many reasons, permission-denied, pycache dir being a
+ # file etc.
+ return False
+ return True
+
+else:
-if sys.platform == "win32":
- from atomicwrites import atomic_write
+ def _write_pyc(
+ state: "AssertionState",
+ co: types.CodeType,
+ source_stat: os.stat_result,
+ pyc: Path,
+ ) -> bool:
+ proc_pyc = f"{pyc}.{os.getpid()}"
+ try:
+ fp = open(proc_pyc, "wb")
+ except OSError as e:
+ state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
+ return False
- def _write_pyc(
- state: "AssertionState",
- co: types.CodeType,
- source_stat: os.stat_result,
- pyc: Path,
- ) -> bool:
- try:
- with atomic_write(os.fspath(pyc), mode="wb", overwrite=True) as fp:
- _write_pyc_fp(fp, source_stat, co)
- except OSError as e:
- state.trace(f"error writing pyc file at {pyc}: {e}")
- # we ignore any failure to write the cache file
- # there are many reasons, permission-denied, pycache dir being a
- # file etc.
- return False
- return True
+ try:
+ _write_pyc_fp(fp, source_stat, co)
+ os.rename(proc_pyc, os.fspath(pyc))
+ except OSError as e:
+ state.trace(f"error writing pyc file at {pyc}: {e}")
+ # we ignore any failure to write the cache file
+ # there are many reasons, permission-denied, pycache dir being a
+ # file etc.
+ return False
+ finally:
+ fp.close()
+ return True
-else:
-
- def _write_pyc(
- state: "AssertionState",
- co: types.CodeType,
- source_stat: os.stat_result,
- pyc: Path,
- ) -> bool:
- proc_pyc = f"{pyc}.{os.getpid()}"
- try:
- fp = open(proc_pyc, "wb")
- except OSError as e:
- state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
- return False
-
- try:
- _write_pyc_fp(fp, source_stat, co)
- os.rename(proc_pyc, os.fspath(pyc))
- except OSError as e:
- state.trace(f"error writing pyc file at {pyc}: {e}")
- # we ignore any failure to write the cache file
- # there are many reasons, permission-denied, pycache dir being a
- # file etc.
- return False
- finally:
- fp.close()
- return True
-
-
-def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
- """Read and rewrite *fn* and return the code object."""
- fn_ = os.fspath(fn)
- stat = os.stat(fn_)
- with open(fn_, "rb") as f:
- source = f.read()
- tree = ast.parse(source, filename=fn_)
- rewrite_asserts(tree, source, fn_, config)
- co = compile(tree, fn_, "exec", dont_inherit=True)
+def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
+ """Read and rewrite *fn* and return the code object."""
+ fn_ = os.fspath(fn)
+ stat = os.stat(fn_)
+ with open(fn_, "rb") as f:
+ source = f.read()
+ tree = ast.parse(source, filename=fn_)
+ rewrite_asserts(tree, source, fn_, config)
+ co = compile(tree, fn_, "exec", dont_inherit=True)
return stat, co
-def _read_pyc(
- source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
-) -> Optional[types.CodeType]:
+def _read_pyc(
+ source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
+) -> Optional[types.CodeType]:
"""Possibly read a pytest pyc containing rewritten code.
Return rewritten code if successful or None if not.
"""
try:
- fp = open(os.fspath(pyc), "rb")
- except OSError:
+ fp = open(os.fspath(pyc), "rb")
+ except OSError:
return None
with fp:
- # https://www.python.org/dev/peps/pep-0552/
- has_flags = sys.version_info >= (3, 7)
+ # https://www.python.org/dev/peps/pep-0552/
+ has_flags = sys.version_info >= (3, 7)
try:
- stat_result = os.stat(os.fspath(source))
- mtime = int(stat_result.st_mtime)
- size = stat_result.st_size
- data = fp.read(16 if has_flags else 12)
- except OSError as e:
- trace(f"_read_pyc({source}): OSError {e}")
+ stat_result = os.stat(os.fspath(source))
+ mtime = int(stat_result.st_mtime)
+ size = stat_result.st_size
+ data = fp.read(16 if has_flags else 12)
+ except OSError as e:
+ trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
- if len(data) != (16 if has_flags else 12):
- trace("_read_pyc(%s): invalid pyc (too short)" % source)
+ if len(data) != (16 if has_flags else 12):
+ trace("_read_pyc(%s): invalid pyc (too short)" % source)
+ return None
+ if data[:4] != importlib.util.MAGIC_NUMBER:
+ trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
+ return None
+ if has_flags and data[4:8] != b"\x00\x00\x00\x00":
+ trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
+ return None
+ mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8]
+ if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
+ trace("_read_pyc(%s): out of date" % source)
+ return None
+ size_data = data[12 if has_flags else 8 : 16 if has_flags else 12]
+ if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
+ trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
- if data[:4] != importlib.util.MAGIC_NUMBER:
- trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
- return None
- if has_flags and data[4:8] != b"\x00\x00\x00\x00":
- trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
- return None
- mtime_data = data[8 if has_flags else 4 : 12 if has_flags else 8]
- if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
- trace("_read_pyc(%s): out of date" % source)
- return None
- size_data = data[12 if has_flags else 8 : 16 if has_flags else 12]
- if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
- trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
- return None
try:
co = marshal.load(fp)
except Exception as e:
- trace(f"_read_pyc({source}): marshal.load error {e}")
+ trace(f"_read_pyc({source}): marshal.load error {e}")
return None
if not isinstance(co, types.CodeType):
trace("_read_pyc(%s): not a code object" % source)
@@ -408,18 +408,18 @@ def _read_pyc(
return co
-def rewrite_asserts(
- mod: ast.Module,
- source: bytes,
- module_path: Optional[str] = None,
- config: Optional[Config] = None,
-) -> None:
+def rewrite_asserts(
+ mod: ast.Module,
+ source: bytes,
+ module_path: Optional[str] = None,
+ config: Optional[Config] = None,
+) -> None:
"""Rewrite the assert statements in mod."""
- AssertionRewriter(module_path, config, source).run(mod)
+ AssertionRewriter(module_path, config, source).run(mod)
-def _saferepr(obj: object) -> str:
- r"""Get a safe repr of an object for assertion error messages.
+def _saferepr(obj: object) -> str:
+ r"""Get a safe repr of an object for assertion error messages.
The assertion formatting (util.format_explanation()) requires
newlines to be escaped since they are a special character for it.
@@ -428,24 +428,24 @@ def _saferepr(obj: object) -> str:
sequences, especially '\n{' and '\n}' are likely to be present in
JSON reprs.
"""
- return saferepr(obj).replace("\n", "\\n")
+ return saferepr(obj).replace("\n", "\\n")
-def _format_assertmsg(obj: object) -> str:
- r"""Format the custom assertion message given.
+def _format_assertmsg(obj: object) -> str:
+ r"""Format the custom assertion message given.
For strings this simply replaces newlines with '\n~' so that
util.format_explanation() will preserve them instead of escaping
- newlines. For other objects saferepr() is used first.
+ newlines. For other objects saferepr() is used first.
"""
# reprlib appears to have a bug which means that if a string
# contains a newline it gets escaped, however if an object has a
# .__repr__() which contains newlines it does not get escaped.
# However in either case we want to preserve the newline.
- replaces = [("\n", "\n~"), ("%", "%%")]
- if not isinstance(obj, str):
- obj = saferepr(obj)
- replaces.append(("\\n", "\n~"))
+ replaces = [("\n", "\n~"), ("%", "%%")]
+ if not isinstance(obj, str):
+ obj = saferepr(obj)
+ replaces.append(("\\n", "\n~"))
for r1, r2 in replaces:
obj = obj.replace(r1, r2)
@@ -453,27 +453,27 @@ def _format_assertmsg(obj: object) -> str:
return obj
-def _should_repr_global_name(obj: object) -> bool:
- if callable(obj):
- return False
+def _should_repr_global_name(obj: object) -> bool:
+ if callable(obj):
+ return False
- try:
- return not hasattr(obj, "__name__")
- except Exception:
- return True
+ try:
+ return not hasattr(obj, "__name__")
+ except Exception:
+ return True
-
-def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
+
+def _format_boolop(explanations: Iterable[str], is_or: bool) -> str:
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
- return explanation.replace("%", "%%")
+ return explanation.replace("%", "%%")
-def _call_reprcompare(
- ops: Sequence[str],
- results: Sequence[bool],
- expls: Sequence[str],
- each_obj: Sequence[object],
-) -> str:
+def _call_reprcompare(
+ ops: Sequence[str],
+ results: Sequence[bool],
+ expls: Sequence[str],
+ each_obj: Sequence[object],
+) -> str:
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
@@ -488,20 +488,20 @@ def _call_reprcompare(
return expl
-def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
- if util._assertion_pass is not None:
- util._assertion_pass(lineno, orig, expl)
+def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
+ if util._assertion_pass is not None:
+ util._assertion_pass(lineno, orig, expl)
+
-
-def _check_if_assertion_pass_impl() -> bool:
- """Check if any plugins implement the pytest_assertion_pass hook
- in order not to generate explanation unecessarily (might be expensive)."""
- return True if util._assertion_pass else False
-
-
-UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
-
-BINOP_MAP = {
+def _check_if_assertion_pass_impl() -> bool:
+ """Check if any plugins implement the pytest_assertion_pass hook
+ in order not to generate explanation unecessarily (might be expensive)."""
+ return True if util._assertion_pass else False
+
+
+UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
+
+BINOP_MAP = {
ast.BitOr: "|",
ast.BitXor: "^",
ast.BitAnd: "&",
@@ -524,7 +524,7 @@ BINOP_MAP = {
ast.IsNot: "is not",
ast.In: "in",
ast.NotIn: "not in",
- ast.MatMult: "@",
+ ast.MatMult: "@",
}
@@ -543,60 +543,60 @@ def set_location(node, lineno, col_offset):
return node
-def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
- """Return a mapping from {lineno: "assertion test expression"}."""
- ret: Dict[int, str] = {}
-
- depth = 0
- lines: List[str] = []
- assert_lineno: Optional[int] = None
- seen_lines: Set[int] = set()
-
- def _write_and_reset() -> None:
- nonlocal depth, lines, assert_lineno, seen_lines
- assert assert_lineno is not None
- ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
- depth = 0
- lines = []
- assert_lineno = None
- seen_lines = set()
-
- tokens = tokenize.tokenize(io.BytesIO(src).readline)
- for tp, source, (lineno, offset), _, line in tokens:
- if tp == tokenize.NAME and source == "assert":
- assert_lineno = lineno
- elif assert_lineno is not None:
- # keep track of depth for the assert-message `,` lookup
- if tp == tokenize.OP and source in "([{":
- depth += 1
- elif tp == tokenize.OP and source in ")]}":
- depth -= 1
-
- if not lines:
- lines.append(line[offset:])
- seen_lines.add(lineno)
- # a non-nested comma separates the expression from the message
- elif depth == 0 and tp == tokenize.OP and source == ",":
- # one line assert with message
- if lineno in seen_lines and len(lines) == 1:
- offset_in_trimmed = offset + len(lines[-1]) - len(line)
- lines[-1] = lines[-1][:offset_in_trimmed]
- # multi-line assert with message
- elif lineno in seen_lines:
- lines[-1] = lines[-1][:offset]
- # multi line assert with escapd newline before message
- else:
- lines.append(line[:offset])
- _write_and_reset()
- elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
- _write_and_reset()
- elif lines and lineno not in seen_lines:
- lines.append(line)
- seen_lines.add(lineno)
-
- return ret
-
-
+def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
+ """Return a mapping from {lineno: "assertion test expression"}."""
+ ret: Dict[int, str] = {}
+
+ depth = 0
+ lines: List[str] = []
+ assert_lineno: Optional[int] = None
+ seen_lines: Set[int] = set()
+
+ def _write_and_reset() -> None:
+ nonlocal depth, lines, assert_lineno, seen_lines
+ assert assert_lineno is not None
+ ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
+ depth = 0
+ lines = []
+ assert_lineno = None
+ seen_lines = set()
+
+ tokens = tokenize.tokenize(io.BytesIO(src).readline)
+ for tp, source, (lineno, offset), _, line in tokens:
+ if tp == tokenize.NAME and source == "assert":
+ assert_lineno = lineno
+ elif assert_lineno is not None:
+ # keep track of depth for the assert-message `,` lookup
+ if tp == tokenize.OP and source in "([{":
+ depth += 1
+ elif tp == tokenize.OP and source in ")]}":
+ depth -= 1
+
+ if not lines:
+ lines.append(line[offset:])
+ seen_lines.add(lineno)
+ # a non-nested comma separates the expression from the message
+ elif depth == 0 and tp == tokenize.OP and source == ",":
+ # one line assert with message
+ if lineno in seen_lines and len(lines) == 1:
+ offset_in_trimmed = offset + len(lines[-1]) - len(line)
+ lines[-1] = lines[-1][:offset_in_trimmed]
+ # multi-line assert with message
+ elif lineno in seen_lines:
+ lines[-1] = lines[-1][:offset]
+ # multi line assert with escapd newline before message
+ else:
+ lines.append(line[:offset])
+ _write_and_reset()
+ elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
+ _write_and_reset()
+ elif lines and lineno not in seen_lines:
+ lines.append(line)
+ seen_lines.add(lineno)
+
+ return ret
+
+
class AssertionRewriter(ast.NodeVisitor):
"""Assertion rewriting implementation.
@@ -613,8 +613,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in
- case the expression is false and calls pytest_assertion_pass hook
- if expression is true.
+ case the expression is false and calls pytest_assertion_pass hook
+ if expression is true.
For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning
@@ -632,10 +632,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and
have the form of "@py_assert0".
- :expl_stmts: The AST statements which will be executed to get
- data from the assertion. This is the code which will construct
- the detailed assertion message that is used in the AssertionError
- or for the pytest_assertion_pass hook.
+ :expl_stmts: The AST statements which will be executed to get
+ data from the assertion. This is the code which will construct
+ the detailed assertion message that is used in the AssertionError
+ or for the pytest_assertion_pass hook.
:explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding
@@ -650,32 +650,32 @@ class AssertionRewriter(ast.NodeVisitor):
by the other visitors.
"""
- def __init__(
- self, module_path: Optional[str], config: Optional[Config], source: bytes
- ) -> None:
- super().__init__()
+ def __init__(
+ self, module_path: Optional[str], config: Optional[Config], source: bytes
+ ) -> None:
+ super().__init__()
self.module_path = module_path
self.config = config
- if config is not None:
- self.enable_assertion_pass_hook = config.getini(
- "enable_assertion_pass_hook"
- )
- else:
- self.enable_assertion_pass_hook = False
- self.source = source
+ if config is not None:
+ self.enable_assertion_pass_hook = config.getini(
+ "enable_assertion_pass_hook"
+ )
+ else:
+ self.enable_assertion_pass_hook = False
+ self.source = source
+
+ @functools.lru_cache(maxsize=1)
+ def _assert_expr_to_lineno(self) -> Dict[int, str]:
+ return _get_assertion_exprs(self.source)
- @functools.lru_cache(maxsize=1)
- def _assert_expr_to_lineno(self) -> Dict[int, str]:
- return _get_assertion_exprs(self.source)
-
- def run(self, mod: ast.Module) -> None:
+ def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
if not mod.body:
# Nothing to do.
return
-
- # We'll insert some special imports at the top of the module, but after any
- # docstrings and __future__ imports, so first figure out where that is.
+
+ # We'll insert some special imports at the top of the module, but after any
+ # docstrings and __future__ imports, so first figure out where that is.
doc = getattr(mod, "docstring", None)
expect_docstring = doc is None
if doc is not None and self.is_rewrite_disabled(doc):
@@ -693,48 +693,48 @@ class AssertionRewriter(ast.NodeVisitor):
return
expect_docstring = False
elif (
- isinstance(item, ast.ImportFrom)
- and item.level == 0
- and item.module == "__future__"
+ isinstance(item, ast.ImportFrom)
+ and item.level == 0
+ and item.module == "__future__"
):
- pass
- else:
+ pass
+ else:
break
pos += 1
- # Special case: for a decorated function, set the lineno to that of the
- # first decorator, not the `def`. Issue #4984.
- if isinstance(item, ast.FunctionDef) and item.decorator_list:
- lineno = item.decorator_list[0].lineno
+ # Special case: for a decorated function, set the lineno to that of the
+ # first decorator, not the `def`. Issue #4984.
+ if isinstance(item, ast.FunctionDef) and item.decorator_list:
+ lineno = item.decorator_list[0].lineno
else:
lineno = item.lineno
- # Now actually insert the special imports.
- if sys.version_info >= (3, 10):
- aliases = [
- ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
- ast.alias(
- "_pytest.assertion.rewrite",
- "@pytest_ar",
- lineno=lineno,
- col_offset=0,
- ),
- ]
- else:
- aliases = [
- ast.alias("builtins", "@py_builtins"),
- ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
- ]
+ # Now actually insert the special imports.
+ if sys.version_info >= (3, 10):
+ aliases = [
+ ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0),
+ ast.alias(
+ "_pytest.assertion.rewrite",
+ "@pytest_ar",
+ lineno=lineno,
+ col_offset=0,
+ ),
+ ]
+ else:
+ aliases = [
+ ast.alias("builtins", "@py_builtins"),
+ ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
+ ]
imports = [
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
]
mod.body[pos:pos] = imports
-
+
# Collect asserts.
- nodes: List[ast.AST] = [mod]
+ nodes: List[ast.AST] = [mod]
while nodes:
node = nodes.pop()
for name, field in ast.iter_fields(node):
if isinstance(field, list):
- new: List[ast.AST] = []
+ new: List[ast.AST] = []
for i, child in enumerate(field):
if isinstance(child, ast.Assert):
# Transform assert.
@@ -753,38 +753,38 @@ class AssertionRewriter(ast.NodeVisitor):
nodes.append(field)
@staticmethod
- def is_rewrite_disabled(docstring: str) -> bool:
+ def is_rewrite_disabled(docstring: str) -> bool:
return "PYTEST_DONT_REWRITE" in docstring
- def variable(self) -> str:
+ def variable(self) -> str:
"""Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter))
self.variables.append(name)
return name
- def assign(self, expr: ast.expr) -> ast.Name:
+ def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
- def display(self, expr: ast.expr) -> ast.expr:
- """Call saferepr on the expression."""
- return self.helper("_saferepr", expr)
+ def display(self, expr: ast.expr) -> ast.expr:
+ """Call saferepr on the expression."""
+ return self.helper("_saferepr", expr)
- def helper(self, name: str, *args: ast.expr) -> ast.expr:
+ def helper(self, name: str, *args: ast.expr) -> ast.expr:
"""Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load())
- attr = ast.Attribute(py_name, name, ast.Load())
- return ast.Call(attr, list(args), [])
+ attr = ast.Attribute(py_name, name, ast.Load())
+ return ast.Call(attr, list(args), [])
- def builtin(self, name: str) -> ast.Attribute:
+ def builtin(self, name: str) -> ast.Attribute:
"""Return the builtin called *name*."""
builtin_name = ast.Name("@py_builtins", ast.Load())
return ast.Attribute(builtin_name, name, ast.Load())
- def explanation_param(self, expr: ast.expr) -> str:
+ def explanation_param(self, expr: ast.expr) -> str:
"""Return a new named %-formatting placeholder for expr.
This creates a %-formatting placeholder for expr in the
@@ -796,7 +796,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s"
- def push_format_context(self) -> None:
+ def push_format_context(self) -> None:
"""Create a new formatting context.
The format context is used for when an explanation wants to
@@ -806,15 +806,15 @@ class AssertionRewriter(ast.NodeVisitor):
to format a string of %-formatted values as added by
.explanation_param().
"""
- self.explanation_specifiers: Dict[str, ast.expr] = {}
+ self.explanation_specifiers: Dict[str, ast.expr] = {}
self.stack.append(self.explanation_specifiers)
- def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
+ def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
"""Format the %-formatted string with current format context.
- The expl_expr should be an str ast.expr instance constructed from
+ The expl_expr should be an str ast.expr instance constructed from
the %-placeholders created by .explanation_param(). This will
- add the required code to format said string to .expl_stmts and
+ add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
"""
current = self.stack.pop()
@@ -824,18 +824,18 @@ class AssertionRewriter(ast.NodeVisitor):
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
- if self.enable_assertion_pass_hook:
- self.format_variables.append(name)
- self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
+ if self.enable_assertion_pass_hook:
+ self.format_variables.append(name)
+ self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
- def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
+ def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
"""Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr)
res = self.assign(node)
return res, self.explanation_param(self.display(res))
- def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
+ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
"""Return the AST statements to replace the ast.Assert instance.
This rewrites the test of an assertion to provide
@@ -844,173 +844,173 @@ class AssertionRewriter(ast.NodeVisitor):
the expression is false.
"""
if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
- from _pytest.warning_types import PytestAssertRewriteWarning
+ from _pytest.warning_types import PytestAssertRewriteWarning
import warnings
- # TODO: This assert should not be needed.
- assert self.module_path is not None
+ # TODO: This assert should not be needed.
+ assert self.module_path is not None
warnings.warn_explicit(
- PytestAssertRewriteWarning(
- "assertion is always true, perhaps remove parentheses?"
- ),
+ PytestAssertRewriteWarning(
+ "assertion is always true, perhaps remove parentheses?"
+ ),
category=None,
- filename=os.fspath(self.module_path),
+ filename=os.fspath(self.module_path),
lineno=assert_.lineno,
)
- self.statements: List[ast.stmt] = []
- self.variables: List[str] = []
+ self.statements: List[ast.stmt] = []
+ self.variables: List[str] = []
self.variable_counter = itertools.count()
-
- if self.enable_assertion_pass_hook:
- self.format_variables: List[str] = []
-
- self.stack: List[Dict[str, ast.expr]] = []
- self.expl_stmts: List[ast.stmt] = []
+
+ if self.enable_assertion_pass_hook:
+ self.format_variables: List[str] = []
+
+ self.stack: List[Dict[str, ast.expr]] = []
+ self.expl_stmts: List[ast.stmt] = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
-
- negation = ast.UnaryOp(ast.Not(), top_condition)
-
- if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
- msg = self.pop_format_context(ast.Str(explanation))
-
- # Failed
- if assert_.msg:
- assertmsg = self.helper("_format_assertmsg", assert_.msg)
- gluestr = "\n>assert "
- else:
- assertmsg = ast.Str("")
- gluestr = "assert "
- err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
- err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
- err_name = ast.Name("AssertionError", ast.Load())
- fmt = self.helper("_format_explanation", err_msg)
- exc = ast.Call(err_name, [fmt], [])
- raise_ = ast.Raise(exc, None)
- statements_fail = []
- statements_fail.extend(self.expl_stmts)
- statements_fail.append(raise_)
-
- # Passed
- fmt_pass = self.helper("_format_explanation", msg)
- orig = self._assert_expr_to_lineno()[assert_.lineno]
- hook_call_pass = ast.Expr(
- self.helper(
- "_call_assertion_pass",
- ast.Num(assert_.lineno),
- ast.Str(orig),
- fmt_pass,
- )
- )
- # If any hooks implement assert_pass hook
- hook_impl_test = ast.If(
- self.helper("_check_if_assertion_pass_impl"),
- self.expl_stmts + [hook_call_pass],
- [],
- )
- statements_pass = [hook_impl_test]
-
- # Test for assertion condition
- main_test = ast.If(negation, statements_fail, statements_pass)
- self.statements.append(main_test)
- if self.format_variables:
- variables = [
- ast.Name(name, ast.Store()) for name in self.format_variables
- ]
- clear_format = ast.Assign(variables, ast.NameConstant(None))
- self.statements.append(clear_format)
-
- else: # Original assertion rewriting
- # Create failure message.
- body = self.expl_stmts
- self.statements.append(ast.If(negation, body, []))
- if assert_.msg:
- assertmsg = self.helper("_format_assertmsg", assert_.msg)
- explanation = "\n>assert " + explanation
- else:
- assertmsg = ast.Str("")
- explanation = "assert " + explanation
- template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
- msg = self.pop_format_context(template)
- fmt = self.helper("_format_explanation", msg)
- err_name = ast.Name("AssertionError", ast.Load())
- exc = ast.Call(err_name, [fmt], [])
+
+ negation = ast.UnaryOp(ast.Not(), top_condition)
+
+ if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
+ msg = self.pop_format_context(ast.Str(explanation))
+
+ # Failed
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ gluestr = "\n>assert "
+ else:
+ assertmsg = ast.Str("")
+ gluestr = "assert "
+ err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
+ err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
+ err_name = ast.Name("AssertionError", ast.Load())
+ fmt = self.helper("_format_explanation", err_msg)
+ exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
-
- body.append(raise_)
-
+ statements_fail = []
+ statements_fail.extend(self.expl_stmts)
+ statements_fail.append(raise_)
+
+ # Passed
+ fmt_pass = self.helper("_format_explanation", msg)
+ orig = self._assert_expr_to_lineno()[assert_.lineno]
+ hook_call_pass = ast.Expr(
+ self.helper(
+ "_call_assertion_pass",
+ ast.Num(assert_.lineno),
+ ast.Str(orig),
+ fmt_pass,
+ )
+ )
+ # If any hooks implement assert_pass hook
+ hook_impl_test = ast.If(
+ self.helper("_check_if_assertion_pass_impl"),
+ self.expl_stmts + [hook_call_pass],
+ [],
+ )
+ statements_pass = [hook_impl_test]
+
+ # Test for assertion condition
+ main_test = ast.If(negation, statements_fail, statements_pass)
+ self.statements.append(main_test)
+ if self.format_variables:
+ variables = [
+ ast.Name(name, ast.Store()) for name in self.format_variables
+ ]
+ clear_format = ast.Assign(variables, ast.NameConstant(None))
+ self.statements.append(clear_format)
+
+ else: # Original assertion rewriting
+ # Create failure message.
+ body = self.expl_stmts
+ self.statements.append(ast.If(negation, body, []))
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ explanation = "\n>assert " + explanation
+ else:
+ assertmsg = ast.Str("")
+ explanation = "assert " + explanation
+ template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
+ msg = self.pop_format_context(template)
+ fmt = self.helper("_format_explanation", msg)
+ err_name = ast.Name("AssertionError", ast.Load())
+ exc = ast.Call(err_name, [fmt], [])
+ raise_ = ast.Raise(exc, None)
+
+ body.append(raise_)
+
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
- clear = ast.Assign(variables, ast.NameConstant(None))
+ clear = ast.Assign(variables, ast.NameConstant(None))
self.statements.append(clear)
# Fix line numbers.
for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset)
return self.statements
- def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
+ def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
- locs = ast.Call(self.builtin("locals"), [], [])
+ locs = ast.Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
- dorepr = self.helper("_should_repr_global_name", name)
+ dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
return name, self.explanation_param(expr)
- def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
+ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
res_var = self.variable()
expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements
- fail_save = self.expl_stmts
+ fail_save = self.expl_stmts
levels = len(boolop.values) - 1
self.push_format_context()
- # Process each operand, short-circuiting if needed.
+ # Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values):
if i:
- fail_inner: List[ast.stmt] = []
+ fail_inner: List[ast.stmt] = []
# cond is set in a prior loop iteration below
- self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
- self.expl_stmts = fail_inner
+ self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
+ self.expl_stmts = fail_inner
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl))
- call = ast.Call(app, [expl_format], [])
- self.expl_stmts.append(ast.Expr(call))
+ call = ast.Call(app, [expl_format], [])
+ self.expl_stmts.append(ast.Expr(call))
if i < levels:
- cond: ast.expr = res
+ cond: ast.expr = res
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
- inner: List[ast.stmt] = []
+ inner: List[ast.stmt] = []
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
- self.expl_stmts = fail_save
- expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
+ self.expl_stmts = fail_save
+ expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
- def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
- pattern = UNARY_MAP[unary.op.__class__]
+ def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
+ pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,)
- def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
- symbol = BINOP_MAP[binop.op.__class__]
+ def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
+ symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
- explanation = f"({left_expl} {symbol} {right_expl})"
+ explanation = f"({left_expl} {symbol} {right_expl})"
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation
- def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
+ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
@@ -1027,20 +1027,20 @@ class AssertionRewriter(ast.NodeVisitor):
else: # **args have `arg` keywords with an .arg of None
arg_expls.append("**" + expl)
- expl = "{}({})".format(func_expl, ", ".join(arg_expls))
+ expl = "{}({})".format(func_expl, ", ".join(arg_expls))
new_call = ast.Call(new_func, new_args, new_kwargs)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
- outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
+ outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
return res, outer_expl
- def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
- # A Starred node can appear in a function call.
+ def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
+ # A Starred node can appear in a function call.
res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl
- def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
+ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
@@ -1050,11 +1050,11 @@ class AssertionRewriter(ast.NodeVisitor):
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl
- def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
+ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
- left_expl = f"({left_expl})"
+ left_expl = f"({left_expl})"
res_variables = [self.variable() for i in range(len(comp.ops))]
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
@@ -1065,61 +1065,61 @@ class AssertionRewriter(ast.NodeVisitor):
for i, op, next_operand in it:
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
- next_expl = f"({next_expl})"
+ next_expl = f"({next_expl})"
results.append(next_res)
- sym = BINOP_MAP[op.__class__]
+ sym = BINOP_MAP[op.__class__]
syms.append(ast.Str(sym))
- expl = f"{left_expl} {sym} {next_expl}"
+ expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Str(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
# Use pytest.assertion.util._reprcompare if that's available.
expl_call = self.helper(
- "_call_reprcompare",
+ "_call_reprcompare",
ast.Tuple(syms, ast.Load()),
ast.Tuple(load_names, ast.Load()),
ast.Tuple(expls, ast.Load()),
ast.Tuple(results, ast.Load()),
)
if len(comp.ops) > 1:
- res: ast.expr = ast.BoolOp(ast.And(), load_names)
+ res: ast.expr = ast.BoolOp(ast.And(), load_names)
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))
-
-
-def try_makedirs(cache_dir: Path) -> bool:
- """Attempt to create the given directory and sub-directories exist.
-
- Returns True if successful or if it already exists.
- """
- try:
- os.makedirs(os.fspath(cache_dir), exist_ok=True)
- except (FileNotFoundError, NotADirectoryError, FileExistsError):
- # One of the path components was not a directory:
- # - we're in a zip file
- # - it is a file
- return False
- except PermissionError:
- return False
- except OSError as e:
- # as of now, EROFS doesn't have an equivalent OSError-subclass
- if e.errno == errno.EROFS:
- return False
- raise
- return True
-
-
-def get_cache_dir(file_path: Path) -> Path:
- """Return the cache directory to write .pyc files for the given .py file path."""
- if sys.version_info >= (3, 8) and sys.pycache_prefix:
- # given:
- # prefix = '/tmp/pycs'
- # path = '/home/user/proj/test_app.py'
- # we want:
- # '/tmp/pycs/home/user/proj'
- return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
- else:
- # classic pycache directory
- return file_path.parent / "__pycache__"
+
+
+def try_makedirs(cache_dir: Path) -> bool:
+ """Attempt to create the given directory and sub-directories exist.
+
+ Returns True if successful or if it already exists.
+ """
+ try:
+ os.makedirs(os.fspath(cache_dir), exist_ok=True)
+ except (FileNotFoundError, NotADirectoryError, FileExistsError):
+ # One of the path components was not a directory:
+ # - we're in a zip file
+ # - it is a file
+ return False
+ except PermissionError:
+ return False
+ except OSError as e:
+ # as of now, EROFS doesn't have an equivalent OSError-subclass
+ if e.errno == errno.EROFS:
+ return False
+ raise
+ return True
+
+
+def get_cache_dir(file_path: Path) -> Path:
+ """Return the cache directory to write .pyc files for the given .py file path."""
+ if sys.version_info >= (3, 8) and sys.pycache_prefix:
+ # given:
+ # prefix = '/tmp/pycs'
+ # path = '/home/user/proj/test_app.py'
+ # we want:
+ # '/tmp/pycs/home/user/proj'
+ return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
+ else:
+ # classic pycache directory
+ return file_path.parent / "__pycache__"
diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
index 00a2697363b..5ba9ddca75a 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/truncate.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
@@ -1,47 +1,47 @@
-"""Utilities for truncating assertion output.
+"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
~8 terminal lines, unless running in "-vv" mode or running on CI.
"""
import os
-from typing import List
-from typing import Optional
+from typing import List
+from typing import Optional
+
+from _pytest.nodes import Item
+
-from _pytest.nodes import Item
-
-
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = 8 * 80
USAGE_MSG = "use '-vv' to show"
-def truncate_if_required(
- explanation: List[str], item: Item, max_length: Optional[int] = None
-) -> List[str]:
- """Truncate this assertion explanation if the given test item is eligible."""
+def truncate_if_required(
+ explanation: List[str], item: Item, max_length: Optional[int] = None
+) -> List[str]:
+ """Truncate this assertion explanation if the given test item is eligible."""
if _should_truncate_item(item):
return _truncate_explanation(explanation)
return explanation
-def _should_truncate_item(item: Item) -> bool:
- """Whether or not this test item is eligible for truncation."""
+def _should_truncate_item(item: Item) -> bool:
+ """Whether or not this test item is eligible for truncation."""
verbose = item.config.option.verbose
return verbose < 2 and not _running_on_ci()
-def _running_on_ci() -> bool:
+def _running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)
-def _truncate_explanation(
- input_lines: List[str],
- max_lines: Optional[int] = None,
- max_chars: Optional[int] = None,
-) -> List[str]:
- """Truncate given list of strings that makes up the assertion explanation.
+def _truncate_explanation(
+ input_lines: List[str],
+ max_lines: Optional[int] = None,
+ max_chars: Optional[int] = None,
+) -> List[str]:
+ """Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches
first. The remaining lines will be replaced by a usage message.
@@ -70,15 +70,15 @@ def _truncate_explanation(
truncated_line_count += 1 # Account for the part-truncated final line
msg = "...Full output truncated"
if truncated_line_count == 1:
- msg += f" ({truncated_line_count} line hidden)"
+ msg += f" ({truncated_line_count} line hidden)"
else:
- msg += f" ({truncated_line_count} lines hidden)"
- msg += f", {USAGE_MSG}"
- truncated_explanation.extend(["", str(msg)])
+ msg += f" ({truncated_line_count} lines hidden)"
+ msg += f", {USAGE_MSG}"
+ truncated_explanation.extend(["", str(msg)])
return truncated_explanation
-def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
+def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required
if len("".join(input_lines)) <= max_chars:
return input_lines
diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py
index 60e8f3a6567..da1ffd15e37 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/util.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/util.py
@@ -1,34 +1,34 @@
-"""Utilities for assertion debugging."""
-import collections.abc
+"""Utilities for assertion debugging."""
+import collections.abc
import pprint
-from typing import AbstractSet
-from typing import Any
-from typing import Callable
-from typing import Iterable
-from typing import List
-from typing import Mapping
-from typing import Optional
-from typing import Sequence
+from typing import AbstractSet
+from typing import Any
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
import _pytest._code
-from _pytest import outcomes
-from _pytest._io.saferepr import _pformat_dispatch
-from _pytest._io.saferepr import safeformat
-from _pytest._io.saferepr import saferepr
+from _pytest import outcomes
+from _pytest._io.saferepr import _pformat_dispatch
+from _pytest._io.saferepr import safeformat
+from _pytest._io.saferepr import saferepr
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
# DebugInterpreter.
-_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
+_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
-# Works similarly as _reprcompare attribute. Is populated with the hook call
-# when pytest_runtest_setup is called.
-_assertion_pass: Optional[Callable[[int, str, str], None]] = None
+# Works similarly as _reprcompare attribute. Is populated with the hook call
+# when pytest_runtest_setup is called.
+_assertion_pass: Optional[Callable[[int, str, str], None]] = None
-def format_explanation(explanation: str) -> str:
- r"""Format an explanation.
+def format_explanation(explanation: str) -> str:
+ r"""Format an explanation.
Normally all embedded newlines are escaped, however there are
three exceptions: \n{, \n} and \n~. The first two are intended
@@ -39,17 +39,17 @@ def format_explanation(explanation: str) -> str:
"""
lines = _split_explanation(explanation)
result = _format_lines(lines)
- return "\n".join(result)
+ return "\n".join(result)
-def _split_explanation(explanation: str) -> List[str]:
- r"""Return a list of individual lines in the explanation.
+def _split_explanation(explanation: str) -> List[str]:
+ r"""Return a list of individual lines in the explanation.
This will return a list of lines split on '\n{', '\n}' and '\n~'.
Any other newlines will be escaped and appear in the line as the
literal '\n' characters.
"""
- raw_lines = (explanation or "").split("\n")
+ raw_lines = (explanation or "").split("\n")
lines = [raw_lines[0]]
for values in raw_lines[1:]:
if values and values[0] in ["{", "}", "~", ">"]:
@@ -59,28 +59,28 @@ def _split_explanation(explanation: str) -> List[str]:
return lines
-def _format_lines(lines: Sequence[str]) -> List[str]:
- """Format the individual lines.
+def _format_lines(lines: Sequence[str]) -> List[str]:
+ """Format the individual lines.
- This will replace the '{', '}' and '~' characters of our mini formatting
- language with the proper 'where ...', 'and ...' and ' + ...' text, taking
- care of indentation along the way.
+ This will replace the '{', '}' and '~' characters of our mini formatting
+ language with the proper 'where ...', 'and ...' and ' + ...' text, taking
+ care of indentation along the way.
Return a list of formatted lines.
"""
- result = list(lines[:1])
+ result = list(lines[:1])
stack = [0]
stackcnt = [0]
for line in lines[1:]:
if line.startswith("{"):
if stackcnt[-1]:
- s = "and "
+ s = "and "
else:
- s = "where "
+ s = "where "
stack.append(len(result))
stackcnt[-1] += 1
stackcnt.append(0)
- result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
+ result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
elif line.startswith("}"):
stack.pop()
stackcnt.pop()
@@ -89,79 +89,79 @@ def _format_lines(lines: Sequence[str]) -> List[str]:
assert line[0] in ["~", ">"]
stack[-1] += 1
indent = len(stack) if line.startswith("~") else len(stack) - 1
- result.append(" " * indent + line[1:])
+ result.append(" " * indent + line[1:])
assert len(stack) == 1
return result
-def issequence(x: Any) -> bool:
- return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
+def issequence(x: Any) -> bool:
+ return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
-def istext(x: Any) -> bool:
- return isinstance(x, str)
+def istext(x: Any) -> bool:
+ return isinstance(x, str)
-def isdict(x: Any) -> bool:
- return isinstance(x, dict)
+def isdict(x: Any) -> bool:
+ return isinstance(x, dict)
-def isset(x: Any) -> bool:
- return isinstance(x, (set, frozenset))
+def isset(x: Any) -> bool:
+ return isinstance(x, (set, frozenset))
+
+
+def isnamedtuple(obj: Any) -> bool:
+ return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
+
+
+def isdatacls(obj: Any) -> bool:
+ return getattr(obj, "__dataclass_fields__", None) is not None
+
+
+def isattrs(obj: Any) -> bool:
+ return getattr(obj, "__attrs_attrs__", None) is not None
+
+
+def isiterable(obj: Any) -> bool:
+ try:
+ iter(obj)
+ return not istext(obj)
+ except TypeError:
+ return False
+
+
+def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
+ """Return specialised explanations for some operators/operands."""
+ verbose = config.getoption("verbose")
+ if verbose > 1:
+ left_repr = safeformat(left)
+ right_repr = safeformat(right)
+ else:
+ # XXX: "15 chars indentation" is wrong
+ # ("E AssertionError: assert "); should use term width.
+ maxsize = (
+ 80 - 15 - len(op) - 2
+ ) // 2 # 15 chars indentation, 1 space around op
+ left_repr = saferepr(left, maxsize=maxsize)
+ right_repr = saferepr(right, maxsize=maxsize)
+
+ summary = f"{left_repr} {op} {right_repr}"
-
-def isnamedtuple(obj: Any) -> bool:
- return isinstance(obj, tuple) and getattr(obj, "_fields", None) is not None
-
-
-def isdatacls(obj: Any) -> bool:
- return getattr(obj, "__dataclass_fields__", None) is not None
-
-
-def isattrs(obj: Any) -> bool:
- return getattr(obj, "__attrs_attrs__", None) is not None
-
-
-def isiterable(obj: Any) -> bool:
- try:
- iter(obj)
- return not istext(obj)
- except TypeError:
- return False
-
-
-def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
- """Return specialised explanations for some operators/operands."""
- verbose = config.getoption("verbose")
- if verbose > 1:
- left_repr = safeformat(left)
- right_repr = safeformat(right)
- else:
- # XXX: "15 chars indentation" is wrong
- # ("E AssertionError: assert "); should use term width.
- maxsize = (
- 80 - 15 - len(op) - 2
- ) // 2 # 15 chars indentation, 1 space around op
- left_repr = saferepr(left, maxsize=maxsize)
- right_repr = saferepr(right, maxsize=maxsize)
-
- summary = f"{left_repr} {op} {right_repr}"
-
explanation = None
try:
if op == "==":
- explanation = _compare_eq_any(left, right, verbose)
+ explanation = _compare_eq_any(left, right, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
- except outcomes.Exit:
- raise
+ except outcomes.Exit:
+ raise
except Exception:
explanation = [
- "(pytest_assertion plugin: representation of details failed: {}.".format(
- _pytest._code.ExceptionInfo.from_current()._getreprcrash()
- ),
- " Probably an object has a faulty __repr__.)",
+ "(pytest_assertion plugin: representation of details failed: {}.".format(
+ _pytest._code.ExceptionInfo.from_current()._getreprcrash()
+ ),
+ " Probably an object has a faulty __repr__.)",
]
if not explanation:
@@ -170,44 +170,44 @@ def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[
return [summary] + explanation
-def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
- explanation = []
- if istext(left) and istext(right):
- explanation = _diff_text(left, right, verbose)
- else:
- if type(left) == type(right) and (
- isdatacls(left) or isattrs(left) or isnamedtuple(left)
- ):
- # Note: unlike dataclasses/attrs, namedtuples compare only the
- # field values, not the type or field names. But this branch
- # intentionally only handles the same-type case, which was often
- # used in older code bases before dataclasses/attrs were available.
- explanation = _compare_eq_cls(left, right, verbose)
- elif issequence(left) and issequence(right):
- explanation = _compare_eq_sequence(left, right, verbose)
- elif isset(left) and isset(right):
- explanation = _compare_eq_set(left, right, verbose)
- elif isdict(left) and isdict(right):
- explanation = _compare_eq_dict(left, right, verbose)
- elif verbose > 0:
- explanation = _compare_eq_verbose(left, right)
- if isiterable(left) and isiterable(right):
- expl = _compare_eq_iterable(left, right, verbose)
- explanation.extend(expl)
- return explanation
-
-
-def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
- """Return the explanation for the diff between text.
+def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
+ explanation = []
+ if istext(left) and istext(right):
+ explanation = _diff_text(left, right, verbose)
+ else:
+ if type(left) == type(right) and (
+ isdatacls(left) or isattrs(left) or isnamedtuple(left)
+ ):
+ # Note: unlike dataclasses/attrs, namedtuples compare only the
+ # field values, not the type or field names. But this branch
+ # intentionally only handles the same-type case, which was often
+ # used in older code bases before dataclasses/attrs were available.
+ explanation = _compare_eq_cls(left, right, verbose)
+ elif issequence(left) and issequence(right):
+ explanation = _compare_eq_sequence(left, right, verbose)
+ elif isset(left) and isset(right):
+ explanation = _compare_eq_set(left, right, verbose)
+ elif isdict(left) and isdict(right):
+ explanation = _compare_eq_dict(left, right, verbose)
+ elif verbose > 0:
+ explanation = _compare_eq_verbose(left, right)
+ if isiterable(left) and isiterable(right):
+ expl = _compare_eq_iterable(left, right, verbose)
+ explanation.extend(expl)
+ return explanation
+
+
+def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
+ """Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing
characters which are identical to keep the diff minimal.
"""
from difflib import ndiff
- explanation: List[str] = []
+ explanation: List[str] = []
- if verbose < 1:
+ if verbose < 1:
i = 0 # just in case left or right has zero length
for i in range(min(len(left), len(right))):
if left[i] != right[i]:
@@ -215,7 +215,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42:
i -= 10 # Provide some context
explanation = [
- "Skipping %s identical leading characters in diff, use -v to show" % i
+ "Skipping %s identical leading characters in diff, use -v to show" % i
]
left = left[i:]
right = right[i:]
@@ -226,8 +226,8 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42:
i -= 10 # Provide some context
explanation += [
- "Skipping {} identical trailing "
- "characters in diff, use -v to show".format(i)
+ "Skipping {} identical trailing "
+ "characters in diff, use -v to show".format(i)
]
left = left[:-i]
right = right[:-i]
@@ -235,243 +235,243 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if left.isspace() or right.isspace():
left = repr(str(left))
right = repr(str(right))
- explanation += ["Strings contain only whitespace, escaping them using repr()"]
- # "right" is the expected base against which we compare "left",
- # see https://github.com/pytest-dev/pytest/issues/3333
+ explanation += ["Strings contain only whitespace, escaping them using repr()"]
+ # "right" is the expected base against which we compare "left",
+ # see https://github.com/pytest-dev/pytest/issues/3333
explanation += [
line.strip("\n")
- for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
+ for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
]
return explanation
-def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
- keepends = True
- left_lines = repr(left).splitlines(keepends)
- right_lines = repr(right).splitlines(keepends)
-
- explanation: List[str] = []
- explanation += ["+" + line for line in left_lines]
- explanation += ["-" + line for line in right_lines]
-
- return explanation
-
-
-def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
- """Move opening/closing parenthesis/bracket to own lines."""
- opening = lines[0][:1]
- if opening in ["(", "[", "{"]:
- lines[0] = " " + lines[0][1:]
- lines[:] = [opening] + lines
- closing = lines[-1][-1:]
- if closing in [")", "]", "}"]:
- lines[-1] = lines[-1][:-1] + ","
- lines[:] = lines + [closing]
-
-
-def _compare_eq_iterable(
- left: Iterable[Any], right: Iterable[Any], verbose: int = 0
-) -> List[str]:
+def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
+ keepends = True
+ left_lines = repr(left).splitlines(keepends)
+ right_lines = repr(right).splitlines(keepends)
+
+ explanation: List[str] = []
+ explanation += ["+" + line for line in left_lines]
+ explanation += ["-" + line for line in right_lines]
+
+ return explanation
+
+
+def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
+ """Move opening/closing parenthesis/bracket to own lines."""
+ opening = lines[0][:1]
+ if opening in ["(", "[", "{"]:
+ lines[0] = " " + lines[0][1:]
+ lines[:] = [opening] + lines
+ closing = lines[-1][-1:]
+ if closing in [")", "]", "}"]:
+ lines[-1] = lines[-1][:-1] + ","
+ lines[:] = lines + [closing]
+
+
+def _compare_eq_iterable(
+ left: Iterable[Any], right: Iterable[Any], verbose: int = 0
+) -> List[str]:
if not verbose:
- return ["Use -v to get the full diff"]
+ return ["Use -v to get the full diff"]
# dynamic import to speedup pytest
import difflib
- left_formatting = pprint.pformat(left).splitlines()
- right_formatting = pprint.pformat(right).splitlines()
-
- # Re-format for different output lengths.
- lines_left = len(left_formatting)
- lines_right = len(right_formatting)
- if lines_left != lines_right:
- left_formatting = _pformat_dispatch(left).splitlines()
- right_formatting = _pformat_dispatch(right).splitlines()
-
- if lines_left > 1 or lines_right > 1:
- _surrounding_parens_on_own_lines(left_formatting)
- _surrounding_parens_on_own_lines(right_formatting)
-
- explanation = ["Full diff:"]
- # "right" is the expected base against which we compare "left",
- # see https://github.com/pytest-dev/pytest/issues/3333
+ left_formatting = pprint.pformat(left).splitlines()
+ right_formatting = pprint.pformat(right).splitlines()
+
+ # Re-format for different output lengths.
+ lines_left = len(left_formatting)
+ lines_right = len(right_formatting)
+ if lines_left != lines_right:
+ left_formatting = _pformat_dispatch(left).splitlines()
+ right_formatting = _pformat_dispatch(right).splitlines()
+
+ if lines_left > 1 or lines_right > 1:
+ _surrounding_parens_on_own_lines(left_formatting)
+ _surrounding_parens_on_own_lines(right_formatting)
+
+ explanation = ["Full diff:"]
+ # "right" is the expected base against which we compare "left",
+ # see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
- line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting)
+ line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting)
)
return explanation
-def _compare_eq_sequence(
- left: Sequence[Any], right: Sequence[Any], verbose: int = 0
-) -> List[str]:
- comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
- explanation: List[str] = []
- len_left = len(left)
- len_right = len(right)
- for i in range(min(len_left, len_right)):
+def _compare_eq_sequence(
+ left: Sequence[Any], right: Sequence[Any], verbose: int = 0
+) -> List[str]:
+ comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
+ explanation: List[str] = []
+ len_left = len(left)
+ len_right = len(right)
+ for i in range(min(len_left, len_right)):
if left[i] != right[i]:
- if comparing_bytes:
- # when comparing bytes, we want to see their ascii representation
- # instead of their numeric values (#5260)
- # using a slice gives us the ascii representation:
- # >>> s = b'foo'
- # >>> s[0]
- # 102
- # >>> s[0:1]
- # b'f'
- left_value = left[i : i + 1]
- right_value = right[i : i + 1]
- else:
- left_value = left[i]
- right_value = right[i]
-
- explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"]
+ if comparing_bytes:
+ # when comparing bytes, we want to see their ascii representation
+ # instead of their numeric values (#5260)
+ # using a slice gives us the ascii representation:
+ # >>> s = b'foo'
+ # >>> s[0]
+ # 102
+ # >>> s[0:1]
+ # b'f'
+ left_value = left[i : i + 1]
+ right_value = right[i : i + 1]
+ else:
+ left_value = left[i]
+ right_value = right[i]
+
+ explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"]
break
-
- if comparing_bytes:
- # when comparing bytes, it doesn't help to show the "sides contain one or more
- # items" longer explanation, so skip it
-
- return explanation
-
- len_diff = len_left - len_right
- if len_diff:
- if len_diff > 0:
- dir_with_more = "Left"
- extra = saferepr(left[len_right])
- else:
- len_diff = 0 - len_diff
- dir_with_more = "Right"
- extra = saferepr(right[len_left])
-
- if len_diff == 1:
- explanation += [f"{dir_with_more} contains one more item: {extra}"]
- else:
- explanation += [
- "%s contains %d more items, first extra item: %s"
- % (dir_with_more, len_diff, extra)
- ]
+
+ if comparing_bytes:
+ # when comparing bytes, it doesn't help to show the "sides contain one or more
+ # items" longer explanation, so skip it
+
+ return explanation
+
+ len_diff = len_left - len_right
+ if len_diff:
+ if len_diff > 0:
+ dir_with_more = "Left"
+ extra = saferepr(left[len_right])
+ else:
+ len_diff = 0 - len_diff
+ dir_with_more = "Right"
+ extra = saferepr(right[len_left])
+
+ if len_diff == 1:
+ explanation += [f"{dir_with_more} contains one more item: {extra}"]
+ else:
+ explanation += [
+ "%s contains %d more items, first extra item: %s"
+ % (dir_with_more, len_diff, extra)
+ ]
return explanation
-def _compare_eq_set(
- left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
-) -> List[str]:
+def _compare_eq_set(
+ left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
+) -> List[str]:
explanation = []
diff_left = left - right
diff_right = right - left
if diff_left:
- explanation.append("Extra items in the left set:")
+ explanation.append("Extra items in the left set:")
for item in diff_left:
- explanation.append(saferepr(item))
+ explanation.append(saferepr(item))
if diff_right:
- explanation.append("Extra items in the right set:")
+ explanation.append("Extra items in the right set:")
for item in diff_right:
- explanation.append(saferepr(item))
+ explanation.append(saferepr(item))
return explanation
-def _compare_eq_dict(
- left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
-) -> List[str]:
- explanation: List[str] = []
- set_left = set(left)
- set_right = set(right)
- common = set_left.intersection(set_right)
+def _compare_eq_dict(
+ left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
+) -> List[str]:
+ explanation: List[str] = []
+ set_left = set(left)
+ set_right = set(right)
+ common = set_left.intersection(set_right)
same = {k: left[k] for k in common if left[k] == right[k]}
if same and verbose < 2:
- explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
+ explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
elif same:
- explanation += ["Common items:"]
+ explanation += ["Common items:"]
explanation += pprint.pformat(same).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
- explanation += ["Differing items:"]
+ explanation += ["Differing items:"]
for k in diff:
- explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})]
- extra_left = set_left - set_right
- len_extra_left = len(extra_left)
- if len_extra_left:
- explanation.append(
- "Left contains %d more item%s:"
- % (len_extra_left, "" if len_extra_left == 1 else "s")
- )
+ explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})]
+ extra_left = set_left - set_right
+ len_extra_left = len(extra_left)
+ if len_extra_left:
+ explanation.append(
+ "Left contains %d more item%s:"
+ % (len_extra_left, "" if len_extra_left == 1 else "s")
+ )
explanation.extend(
pprint.pformat({k: left[k] for k in extra_left}).splitlines()
)
- extra_right = set_right - set_left
- len_extra_right = len(extra_right)
- if len_extra_right:
- explanation.append(
- "Right contains %d more item%s:"
- % (len_extra_right, "" if len_extra_right == 1 else "s")
- )
+ extra_right = set_right - set_left
+ len_extra_right = len(extra_right)
+ if len_extra_right:
+ explanation.append(
+ "Right contains %d more item%s:"
+ % (len_extra_right, "" if len_extra_right == 1 else "s")
+ )
explanation.extend(
pprint.pformat({k: right[k] for k in extra_right}).splitlines()
)
return explanation
-def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
- if isdatacls(left):
- all_fields = left.__dataclass_fields__
- fields_to_check = [field for field, info in all_fields.items() if info.compare]
- elif isattrs(left):
- all_fields = left.__attrs_attrs__
- fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
- elif isnamedtuple(left):
- fields_to_check = left._fields
- else:
- assert False
-
- indent = " "
- same = []
- diff = []
- for field in fields_to_check:
- if getattr(left, field) == getattr(right, field):
- same.append(field)
- else:
- diff.append(field)
-
- explanation = []
- if same or diff:
- explanation += [""]
- if same and verbose < 2:
- explanation.append("Omitting %s identical items, use -vv to show" % len(same))
- elif same:
- explanation += ["Matching attributes:"]
- explanation += pprint.pformat(same).splitlines()
- if diff:
- explanation += ["Differing attributes:"]
- explanation += pprint.pformat(diff).splitlines()
- for field in diff:
- field_left = getattr(left, field)
- field_right = getattr(right, field)
- explanation += [
- "",
- "Drill down into differing attribute %s:" % field,
- ("%s%s: %r != %r") % (indent, field, field_left, field_right),
- ]
- explanation += [
- indent + line
- for line in _compare_eq_any(field_left, field_right, verbose)
- ]
- return explanation
-
-
-def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
+def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
+ if isdatacls(left):
+ all_fields = left.__dataclass_fields__
+ fields_to_check = [field for field, info in all_fields.items() if info.compare]
+ elif isattrs(left):
+ all_fields = left.__attrs_attrs__
+ fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
+ elif isnamedtuple(left):
+ fields_to_check = left._fields
+ else:
+ assert False
+
+ indent = " "
+ same = []
+ diff = []
+ for field in fields_to_check:
+ if getattr(left, field) == getattr(right, field):
+ same.append(field)
+ else:
+ diff.append(field)
+
+ explanation = []
+ if same or diff:
+ explanation += [""]
+ if same and verbose < 2:
+ explanation.append("Omitting %s identical items, use -vv to show" % len(same))
+ elif same:
+ explanation += ["Matching attributes:"]
+ explanation += pprint.pformat(same).splitlines()
+ if diff:
+ explanation += ["Differing attributes:"]
+ explanation += pprint.pformat(diff).splitlines()
+ for field in diff:
+ field_left = getattr(left, field)
+ field_right = getattr(right, field)
+ explanation += [
+ "",
+ "Drill down into differing attribute %s:" % field,
+ ("%s%s: %r != %r") % (indent, field, field_left, field_right),
+ ]
+ explanation += [
+ indent + line
+ for line in _compare_eq_any(field_left, field_right, verbose)
+ ]
+ return explanation
+
+
+def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
index = text.find(term)
head = text[:index]
tail = text[index + len(term) :]
correct_text = head + tail
- diff = _diff_text(text, correct_text, verbose)
- newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)]
+ diff = _diff_text(text, correct_text, verbose)
+ newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)]
for line in diff:
- if line.startswith("Skipping"):
+ if line.startswith("Skipping"):
continue
- if line.startswith("- "):
+ if line.startswith("- "):
continue
- if line.startswith("+ "):
- newdiff.append(" " + line[2:])
+ if line.startswith("+ "):
+ newdiff.append(" " + line[2:])
else:
newdiff.append(line)
return newdiff