aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pytest/py3/_pytest/assertion
diff options
context:
space:
mode:
authorrobot-piglet <robot-piglet@yandex-team.com>2025-05-05 12:31:52 +0300
committerrobot-piglet <robot-piglet@yandex-team.com>2025-05-05 12:41:33 +0300
commit6ff49ec58061f642c3a2f83c61eba12820787dfc (patch)
treec733ec9bdb15ed280080d31dea8725bfec717acd /contrib/python/pytest/py3/_pytest/assertion
parenteefca8305c6a545cc6b16dca3eb0d91dcef2adcd (diff)
downloadydb-6ff49ec58061f642c3a2f83c61eba12820787dfc.tar.gz
Intermediate changes
commit_hash:8b3bb826b17db8329ed1221f545c0645f12c552d
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/__init__.py24
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/rewrite.py120
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/truncate.py6
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/util.py221
4 files changed, 223 insertions, 148 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
index a46e58136ba..2bce0ec7cb5 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/__init__.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
@@ -1,4 +1,5 @@
"""Support for presenting detailed information in failing assertions."""
+
import sys
from typing import Any
from typing import Generator
@@ -15,6 +16,7 @@ from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
+
if TYPE_CHECKING:
from _pytest.main import Session
@@ -42,6 +44,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Enables the pytest_assertion_pass hook. "
"Make sure to delete any previously generated pyc cache files.",
)
+ Config._add_verbosity_ini(
+ parser,
+ Config.VERBOSITY_ASSERTIONS,
+ help=(
+ "Specify a verbosity level for assertions, overriding the main level. "
+ "Higher levels will provide more detailed explanation when an assertion fails."
+ ),
+ )
def register_assert_rewrite(*names: str) -> None:
@@ -112,15 +122,14 @@ def pytest_collection(session: "Session") -> None:
assertstate.hook.set_session(session)
-@hookimpl(tryfirst=True, hookwrapper=True)
-def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
+@hookimpl(wrapper=True, tryfirst=True)
+def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]:
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks.
The rewrite module will use util._reprcompare if it exists to use custom
reporting via the pytest_assertrepr_compare hook. This sets up this custom
comparison for the test.
"""
-
ihook = item.ihook
def callbinrepr(op, left: object, right: object) -> Optional[str]:
@@ -162,10 +171,11 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
util._assertion_pass = call_assertion_pass_hook
- yield
-
- util._reprcompare, util._assertion_pass = saved_assert_hooks
- util._config = None
+ try:
+ return (yield)
+ finally:
+ util._reprcompare, util._assertion_pass = saved_assert_hooks
+ util._config = None
def pytest_sessionfinish(session: "Session") -> None:
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
index d1974bb3b4a..0ab6eaa1393 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -1,5 +1,7 @@
"""Rewrite assertion AST to produce nice error messages."""
+
import ast
+from collections import defaultdict
import errno
import functools
import importlib.abc
@@ -9,13 +11,12 @@ import io
import itertools
import marshal
import os
+from pathlib import Path
+from pathlib import PurePath
import struct
import sys
import tokenize
import types
-from collections import defaultdict
-from pathlib import Path
-from pathlib import PurePath
from typing import Callable
from typing import Dict
from typing import IO
@@ -33,29 +34,20 @@ from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr
from _pytest._version import version
from _pytest.assertion import util
-from _pytest.assertion.util import ( # noqa: F401
- format_explanation as _format_explanation,
-)
from _pytest.config import Config
from _pytest.main import Session
from _pytest.pathlib import absolutepath
from _pytest.pathlib import fnmatch_ex
from _pytest.stash import StashKey
+
+# fmt: off
+from _pytest.assertion.util import format_explanation as _format_explanation # noqa:F401, isort:skip
+# fmt:on
+
if TYPE_CHECKING:
from _pytest.assertion import AssertionState
-if sys.version_info >= (3, 8):
- namedExpr = ast.NamedExpr
- astNameConstant = ast.Constant
- astStr = ast.Constant
- astNum = ast.Constant
-else:
- namedExpr = ast.Expr
- astNameConstant = ast.NameConstant
- astStr = ast.Str
- astNum = ast.Num
-
class Sentinel:
pass
@@ -437,7 +429,10 @@ def _saferepr(obj: object) -> str:
def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]:
"""Get `maxsize` configuration for saferepr based on the given config object."""
- verbosity = config.getoption("verbose") if config is not None else 0
+ if config is None:
+ verbosity = 0
+ else:
+ verbosity = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
if verbosity >= 2:
return None
if verbosity >= 1:
@@ -604,13 +599,6 @@ def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
return ret
-def _get_ast_constant_value(value: astStr) -> object:
- if sys.version_info >= (3, 8):
- return value.value
- else:
- return value.s
-
-
class AssertionRewriter(ast.NodeVisitor):
"""Assertion rewriting implementation.
@@ -706,11 +694,10 @@ class AssertionRewriter(ast.NodeVisitor):
if (
expect_docstring
and isinstance(item, ast.Expr)
- and isinstance(item.value, astStr)
- and isinstance(_get_ast_constant_value(item.value), str)
+ and isinstance(item.value, ast.Constant)
+ and isinstance(item.value.value, str)
):
- doc = _get_ast_constant_value(item.value)
- assert isinstance(doc, str)
+ doc = item.value.value
if self.is_rewrite_disabled(doc):
return
expect_docstring = False
@@ -850,7 +837,7 @@ class AssertionRewriter(ast.NodeVisitor):
current = self.stack.pop()
if self.stack:
self.explanation_specifiers = self.stack[-1]
- keys = [astStr(key) for key in current.keys()]
+ keys = [ast.Constant(key) for key in current.keys()]
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
@@ -874,9 +861,10 @@ class AssertionRewriter(ast.NodeVisitor):
the expression is false.
"""
if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
- from _pytest.warning_types import PytestAssertRewriteWarning
import warnings
+ from _pytest.warning_types import PytestAssertRewriteWarning
+
# TODO: This assert should not be needed.
assert self.module_path is not None
warnings.warn_explicit(
@@ -904,16 +892,16 @@ class AssertionRewriter(ast.NodeVisitor):
negation = ast.UnaryOp(ast.Not(), top_condition)
if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
- msg = self.pop_format_context(astStr(explanation))
+ msg = self.pop_format_context(ast.Constant(explanation))
# Failed
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
- assertmsg = astStr("")
+ assertmsg = ast.Constant("")
gluestr = "assert "
- err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
+ err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
@@ -929,8 +917,8 @@ class AssertionRewriter(ast.NodeVisitor):
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
- astNum(assert_.lineno),
- astStr(orig),
+ ast.Constant(assert_.lineno),
+ ast.Constant(orig),
fmt_pass,
)
)
@@ -949,7 +937,7 @@ class AssertionRewriter(ast.NodeVisitor):
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
- clear_format = ast.Assign(variables, astNameConstant(None))
+ clear_format = ast.Assign(variables, ast.Constant(None))
self.statements.append(clear_format)
else: # Original assertion rewriting
@@ -960,9 +948,9 @@ class AssertionRewriter(ast.NodeVisitor):
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
- assertmsg = astStr("")
+ assertmsg = ast.Constant("")
explanation = "assert " + explanation
- template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
+ template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
@@ -974,7 +962,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
- clear = ast.Assign(variables, astNameConstant(None))
+ clear = ast.Assign(variables, ast.Constant(None))
self.statements.append(clear)
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
@@ -982,26 +970,26 @@ class AssertionRewriter(ast.NodeVisitor):
ast.copy_location(node, assert_)
return self.statements
- def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
+ def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id # type: ignore[attr-defined]
- inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
+ inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
- expr = ast.IfExp(test, self.display(name), astStr(target_id))
+ expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
return name, self.explanation_param(expr)
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
- inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
+ inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
- expr = ast.IfExp(test, self.display(name), astStr(name.id))
+ expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
@@ -1020,10 +1008,10 @@ class AssertionRewriter(ast.NodeVisitor):
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
- # Check if the left operand is a namedExpr and the value has already been visited
+ # Check if the left operand is a ast.NamedExpr and the value has already been visited
if (
isinstance(v, ast.Compare)
- and isinstance(v.left, namedExpr)
+ and isinstance(v.left, ast.NamedExpr)
and v.left.target.id
in [
ast_expr.id
@@ -1032,14 +1020,12 @@ class AssertionRewriter(ast.NodeVisitor):
]
):
pytest_temp = self.variable()
- self.variables_overwrite[self.scope][
- v.left.target.id
- ] = v.left # type:ignore[assignment]
+ self.variables_overwrite[self.scope][v.left.target.id] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
- expl_format = self.pop_format_context(astStr(expl))
+ expl_format = self.pop_format_context(ast.Constant(expl))
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
@@ -1051,7 +1037,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = body = inner
self.statements = save
self.expl_stmts = fail_save
- expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
+ expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
@@ -1078,9 +1064,7 @@ class AssertionRewriter(ast.NodeVisitor):
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
self.scope, {}
):
- arg = self.variables_overwrite[self.scope][
- arg.id
- ] # type:ignore[assignment]
+ arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
@@ -1088,9 +1072,7 @@ class AssertionRewriter(ast.NodeVisitor):
if isinstance(
keyword.value, ast.Name
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
- keyword.value = self.variables_overwrite[self.scope][
- keyword.value.id
- ] # type:ignore[assignment]
+ keyword.value = self.variables_overwrite[self.scope][keyword.value.id] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
@@ -1127,13 +1109,9 @@ class AssertionRewriter(ast.NodeVisitor):
if isinstance(
comp.left, ast.Name
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
- comp.left = self.variables_overwrite[self.scope][
- comp.left.id
- ] # type:ignore[assignment]
- if isinstance(comp.left, namedExpr):
- self.variables_overwrite[self.scope][
- comp.left.target.id
- ] = comp.left # type:ignore[assignment]
+ comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment]
+ if isinstance(comp.left, ast.NamedExpr):
+ self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
@@ -1146,22 +1124,20 @@ class AssertionRewriter(ast.NodeVisitor):
results = [left_res]
for i, op, next_operand in it:
if (
- isinstance(next_operand, namedExpr)
+ isinstance(next_operand, ast.NamedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
- self.variables_overwrite[self.scope][
- left_res.id
- ] = next_operand # type:ignore[assignment]
+ self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
next_expl = f"({next_expl})"
results.append(next_res)
sym = BINOP_MAP[op.__class__]
- syms.append(astStr(sym))
+ syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
- expls.append(astStr(expl))
+ expls.append(ast.Constant(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
@@ -1205,7 +1181,7 @@ def try_makedirs(cache_dir: Path) -> bool:
def get_cache_dir(file_path: Path) -> Path:
"""Return the cache directory to write .pyc files for the given .py file path."""
- if sys.version_info >= (3, 8) and sys.pycache_prefix:
+ if sys.pycache_prefix:
# given:
# prefix = '/tmp/pycs'
# path = '/home/user/proj/test_app.py'
diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
index dfd6f65d281..902d4baf846 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/truncate.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
@@ -1,12 +1,14 @@
"""Utilities for truncating assertion output.
Current default behaviour is to truncate assertion explanations at
-~8 terminal lines, unless running in "-vv" mode or running on CI.
+terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI.
"""
+
from typing import List
from typing import Optional
from _pytest.assertion import util
+from _pytest.config import Config
from _pytest.nodes import Item
@@ -26,7 +28,7 @@ def truncate_if_required(
def _should_truncate_item(item: Item) -> bool:
"""Whether or not this test item is eligible for truncation."""
- verbose = item.config.option.verbose
+ verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
return verbose < 2 and not util.running_on_ci()
diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py
index 39ca5403e04..a7074115d65 100644
--- a/contrib/python/pytest/py3/_pytest/assertion/util.py
+++ b/contrib/python/pytest/py3/_pytest/assertion/util.py
@@ -1,4 +1,5 @@
"""Utilities for assertion debugging."""
+
import collections.abc
import os
import pprint
@@ -7,18 +8,21 @@ from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
+from typing import Literal
from typing import Mapping
from typing import Optional
+from typing import Protocol
from typing import Sequence
from unicodedata import normalize
-import _pytest._code
from _pytest import outcomes
-from _pytest._io.saferepr import _pformat_dispatch
+import _pytest._code
+from _pytest._io.pprint import PrettyPrinter
from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited
from _pytest.config import Config
+
# The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the
@@ -33,6 +37,11 @@ _assertion_pass: Optional[Callable[[int, str, str], None]] = None
_config: Optional[Config] = None
+class _HighlightFunc(Protocol):
+ def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
+ """Apply highlighting to the given source."""
+
+
def format_explanation(explanation: str) -> str:
r"""Format an explanation.
@@ -161,7 +170,7 @@ def assertrepr_compare(
config, op: str, left: Any, right: Any, use_ascii: bool = False
) -> Optional[List[str]]:
"""Return specialised explanations for some operators/operands."""
- verbose = config.getoption("verbose")
+ verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS)
# Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier.
# See issue #3246.
@@ -185,14 +194,31 @@ def assertrepr_compare(
right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii)
summary = f"{left_repr} {op} {right_repr}"
+ highlighter = config.get_terminal_writer()._highlight
explanation = None
try:
if op == "==":
- explanation = _compare_eq_any(left, right, verbose)
+ explanation = _compare_eq_any(left, right, highlighter, verbose)
elif op == "not in":
if istext(left) and istext(right):
explanation = _notin_text(left, right, verbose)
+ elif op == "!=":
+ if isset(left) and isset(right):
+ explanation = ["Both sets are equal"]
+ elif op == ">=":
+ if isset(left) and isset(right):
+ explanation = _compare_gte_set(left, right, highlighter, verbose)
+ elif op == "<=":
+ if isset(left) and isset(right):
+ explanation = _compare_lte_set(left, right, highlighter, verbose)
+ elif op == ">":
+ if isset(left) and isset(right):
+ explanation = _compare_gt_set(left, right, highlighter, verbose)
+ elif op == "<":
+ if isset(left) and isset(right):
+ explanation = _compare_lt_set(left, right, highlighter, verbose)
+
except outcomes.Exit:
raise
except Exception:
@@ -206,10 +232,14 @@ def assertrepr_compare(
if not explanation:
return None
+ if explanation[0] != "":
+ explanation = [""] + explanation
return [summary] + explanation
-def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
+def _compare_eq_any(
+ left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
+) -> List[str]:
explanation = []
if istext(left) and istext(right):
explanation = _diff_text(left, right, verbose)
@@ -222,23 +252,23 @@ def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]:
other_side = right if isinstance(left, ApproxBase) else left
explanation = approx_side._repr_compare(other_side)
- elif type(left) == type(right) and (
+ elif type(left) is type(right) and (
isdatacls(left) or isattrs(left) or isnamedtuple(left)
):
# Note: unlike dataclasses/attrs, namedtuples compare only the
# field values, not the type or field names. But this branch
# intentionally only handles the same-type case, which was often
# used in older code bases before dataclasses/attrs were available.
- explanation = _compare_eq_cls(left, right, verbose)
+ explanation = _compare_eq_cls(left, right, highlighter, verbose)
elif issequence(left) and issequence(right):
- explanation = _compare_eq_sequence(left, right, verbose)
+ explanation = _compare_eq_sequence(left, right, highlighter, verbose)
elif isset(left) and isset(right):
- explanation = _compare_eq_set(left, right, verbose)
+ explanation = _compare_eq_set(left, right, highlighter, verbose)
elif isdict(left) and isdict(right):
- explanation = _compare_eq_dict(left, right, verbose)
+ explanation = _compare_eq_dict(left, right, highlighter, verbose)
if isiterable(left) and isiterable(right):
- expl = _compare_eq_iterable(left, right, verbose)
+ expl = _compare_eq_iterable(left, right, highlighter, verbose)
explanation.extend(expl)
return explanation
@@ -273,8 +303,8 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
if i > 42:
i -= 10 # Provide some context
explanation += [
- "Skipping {} identical trailing "
- "characters in diff, use -v to show".format(i)
+ f"Skipping {i} identical trailing "
+ "characters in diff, use -v to show"
]
left = left[:-i]
right = right[:-i]
@@ -292,51 +322,40 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
return explanation
-def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
- """Move opening/closing parenthesis/bracket to own lines."""
- opening = lines[0][:1]
- if opening in ["(", "[", "{"]:
- lines[0] = " " + lines[0][1:]
- lines[:] = [opening] + lines
- closing = lines[-1][-1:]
- if closing in [")", "]", "}"]:
- lines[-1] = lines[-1][:-1] + ","
- lines[:] = lines + [closing]
-
-
def _compare_eq_iterable(
- left: Iterable[Any], right: Iterable[Any], verbose: int = 0
+ left: Iterable[Any],
+ right: Iterable[Any],
+ highligher: _HighlightFunc,
+ verbose: int = 0,
) -> List[str]:
if verbose <= 0 and not running_on_ci():
return ["Use -v to get more diff"]
# dynamic import to speedup pytest
import difflib
- left_formatting = pprint.pformat(left).splitlines()
- right_formatting = pprint.pformat(right).splitlines()
+ left_formatting = PrettyPrinter().pformat(left).splitlines()
+ right_formatting = PrettyPrinter().pformat(right).splitlines()
- # Re-format for different output lengths.
- lines_left = len(left_formatting)
- lines_right = len(right_formatting)
- if lines_left != lines_right:
- left_formatting = _pformat_dispatch(left).splitlines()
- right_formatting = _pformat_dispatch(right).splitlines()
-
- if lines_left > 1 or lines_right > 1:
- _surrounding_parens_on_own_lines(left_formatting)
- _surrounding_parens_on_own_lines(right_formatting)
-
- explanation = ["Full diff:"]
+ explanation = ["", "Full diff:"]
# "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend(
- line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting)
+ highligher(
+ "\n".join(
+ line.rstrip()
+ for line in difflib.ndiff(right_formatting, left_formatting)
+ ),
+ lexer="diff",
+ ).splitlines()
)
return explanation
def _compare_eq_sequence(
- left: Sequence[Any], right: Sequence[Any], verbose: int = 0
+ left: Sequence[Any],
+ right: Sequence[Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
) -> List[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation: List[str] = []
@@ -359,7 +378,10 @@ def _compare_eq_sequence(
left_value = left[i]
right_value = right[i]
- explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"]
+ explanation.append(
+ f"At index {i} diff:"
+ f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}"
+ )
break
if comparing_bytes:
@@ -379,34 +401,91 @@ def _compare_eq_sequence(
extra = saferepr(right[len_left])
if len_diff == 1:
- explanation += [f"{dir_with_more} contains one more item: {extra}"]
+ explanation += [
+ f"{dir_with_more} contains one more item: {highlighter(extra)}"
+ ]
else:
explanation += [
"%s contains %d more items, first extra item: %s"
- % (dir_with_more, len_diff, extra)
+ % (dir_with_more, len_diff, highlighter(extra))
]
return explanation
def _compare_eq_set(
- left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
+ left: AbstractSet[Any],
+ right: AbstractSet[Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
+) -> List[str]:
+ explanation = []
+ explanation.extend(_set_one_sided_diff("left", left, right, highlighter))
+ explanation.extend(_set_one_sided_diff("right", right, left, highlighter))
+ return explanation
+
+
+def _compare_gt_set(
+ left: AbstractSet[Any],
+ right: AbstractSet[Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
+) -> List[str]:
+ explanation = _compare_gte_set(left, right, highlighter)
+ if not explanation:
+ return ["Both sets are equal"]
+ return explanation
+
+
+def _compare_lt_set(
+ left: AbstractSet[Any],
+ right: AbstractSet[Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
+) -> List[str]:
+ explanation = _compare_lte_set(left, right, highlighter)
+ if not explanation:
+ return ["Both sets are equal"]
+ return explanation
+
+
+def _compare_gte_set(
+ left: AbstractSet[Any],
+ right: AbstractSet[Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
+) -> List[str]:
+ return _set_one_sided_diff("right", right, left, highlighter)
+
+
+def _compare_lte_set(
+ left: AbstractSet[Any],
+ right: AbstractSet[Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
+) -> List[str]:
+ return _set_one_sided_diff("left", left, right, highlighter)
+
+
+def _set_one_sided_diff(
+ posn: str,
+ set1: AbstractSet[Any],
+ set2: AbstractSet[Any],
+ highlighter: _HighlightFunc,
) -> List[str]:
explanation = []
- diff_left = left - right
- diff_right = right - left
- if diff_left:
- explanation.append("Extra items in the left set:")
- for item in diff_left:
- explanation.append(saferepr(item))
- if diff_right:
- explanation.append("Extra items in the right set:")
- for item in diff_right:
- explanation.append(saferepr(item))
+ diff = set1 - set2
+ if diff:
+ explanation.append(f"Extra items in the {posn} set:")
+ for item in diff:
+ explanation.append(highlighter(saferepr(item)))
return explanation
def _compare_eq_dict(
- left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
+ left: Mapping[Any, Any],
+ right: Mapping[Any, Any],
+ highlighter: _HighlightFunc,
+ verbose: int = 0,
) -> List[str]:
explanation: List[str] = []
set_left = set(left)
@@ -417,12 +496,16 @@ def _compare_eq_dict(
explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
elif same:
explanation += ["Common items:"]
- explanation += pprint.pformat(same).splitlines()
+ explanation += highlighter(pprint.pformat(same)).splitlines()
diff = {k for k in common if left[k] != right[k]}
if diff:
explanation += ["Differing items:"]
for k in diff:
- explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})]
+ explanation += [
+ highlighter(saferepr({k: left[k]}))
+ + " != "
+ + highlighter(saferepr({k: right[k]}))
+ ]
extra_left = set_left - set_right
len_extra_left = len(extra_left)
if len_extra_left:
@@ -431,7 +514,7 @@ def _compare_eq_dict(
% (len_extra_left, "" if len_extra_left == 1 else "s")
)
explanation.extend(
- pprint.pformat({k: left[k] for k in extra_left}).splitlines()
+ highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines()
)
extra_right = set_right - set_left
len_extra_right = len(extra_right)
@@ -441,12 +524,14 @@ def _compare_eq_dict(
% (len_extra_right, "" if len_extra_right == 1 else "s")
)
explanation.extend(
- pprint.pformat({k: right[k] for k in extra_right}).splitlines()
+ highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines()
)
return explanation
-def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
+def _compare_eq_cls(
+ left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
+) -> List[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
@@ -478,21 +563,23 @@ def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
explanation.append("Omitting %s identical items, use -vv to show" % len(same))
elif same:
explanation += ["Matching attributes:"]
- explanation += pprint.pformat(same).splitlines()
+ explanation += highlighter(pprint.pformat(same)).splitlines()
if diff:
explanation += ["Differing attributes:"]
- explanation += pprint.pformat(diff).splitlines()
+ explanation += highlighter(pprint.pformat(diff)).splitlines()
for field in diff:
field_left = getattr(left, field)
field_right = getattr(right, field)
explanation += [
"",
- "Drill down into differing attribute %s:" % field,
- ("%s%s: %r != %r") % (indent, field, field_left, field_right),
+ f"Drill down into differing attribute {field}:",
+ f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}",
]
explanation += [
indent + line
- for line in _compare_eq_any(field_left, field_right, verbose)
+ for line in _compare_eq_any(
+ field_left, field_right, highlighter, verbose
+ )
]
return explanation