diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2025-05-05 12:31:52 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2025-05-05 12:41:33 +0300 |
commit | 6ff49ec58061f642c3a2f83c61eba12820787dfc (patch) | |
tree | c733ec9bdb15ed280080d31dea8725bfec717acd /contrib/python/pytest/py3/_pytest/assertion | |
parent | eefca8305c6a545cc6b16dca3eb0d91dcef2adcd (diff) | |
download | ydb-6ff49ec58061f642c3a2f83c61eba12820787dfc.tar.gz |
Intermediate changes
commit_hash:8b3bb826b17db8329ed1221f545c0645f12c552d
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
4 files changed, 223 insertions, 148 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py index a46e58136ba..2bce0ec7cb5 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/__init__.py +++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py @@ -1,4 +1,5 @@ """Support for presenting detailed information in failing assertions.""" + import sys from typing import Any from typing import Generator @@ -15,6 +16,7 @@ from _pytest.config import hookimpl from _pytest.config.argparsing import Parser from _pytest.nodes import Item + if TYPE_CHECKING: from _pytest.main import Session @@ -42,6 +44,14 @@ def pytest_addoption(parser: Parser) -> None: help="Enables the pytest_assertion_pass hook. " "Make sure to delete any previously generated pyc cache files.", ) + Config._add_verbosity_ini( + parser, + Config.VERBOSITY_ASSERTIONS, + help=( + "Specify a verbosity level for assertions, overriding the main level. " + "Higher levels will provide more detailed explanation when an assertion fails." + ), + ) def register_assert_rewrite(*names: str) -> None: @@ -112,15 +122,14 @@ def pytest_collection(session: "Session") -> None: assertstate.hook.set_session(session) -@hookimpl(tryfirst=True, hookwrapper=True) -def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: +@hookimpl(wrapper=True, tryfirst=True) +def pytest_runtest_protocol(item: Item) -> Generator[None, object, object]: """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks. The rewrite module will use util._reprcompare if it exists to use custom reporting via the pytest_assertrepr_compare hook. This sets up this custom comparison for the test. """ - ihook = item.ihook def callbinrepr(op, left: object, right: object) -> Optional[str]: @@ -162,10 +171,11 @@ def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]: util._assertion_pass = call_assertion_pass_hook - yield - - util._reprcompare, util._assertion_pass = saved_assert_hooks - util._config = None + try: + return (yield) + finally: + util._reprcompare, util._assertion_pass = saved_assert_hooks + util._config = None def pytest_sessionfinish(session: "Session") -> None: diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py index d1974bb3b4a..0ab6eaa1393 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py +++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py @@ -1,5 +1,7 @@ """Rewrite assertion AST to produce nice error messages.""" + import ast +from collections import defaultdict import errno import functools import importlib.abc @@ -9,13 +11,12 @@ import io import itertools import marshal import os +from pathlib import Path +from pathlib import PurePath import struct import sys import tokenize import types -from collections import defaultdict -from pathlib import Path -from pathlib import PurePath from typing import Callable from typing import Dict from typing import IO @@ -33,29 +34,20 @@ from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import saferepr from _pytest._version import version from _pytest.assertion import util -from _pytest.assertion.util import ( # noqa: F401 - format_explanation as _format_explanation, -) from _pytest.config import Config from _pytest.main import Session from _pytest.pathlib import absolutepath from _pytest.pathlib import fnmatch_ex from _pytest.stash import StashKey + +# fmt: off +from _pytest.assertion.util import format_explanation as _format_explanation # noqa:F401, isort:skip +# fmt:on + if TYPE_CHECKING: from _pytest.assertion import AssertionState -if sys.version_info >= (3, 8): - namedExpr = ast.NamedExpr - astNameConstant = ast.Constant - astStr = ast.Constant - astNum = ast.Constant -else: - namedExpr = ast.Expr - astNameConstant = ast.NameConstant - astStr = ast.Str - astNum = ast.Num - class Sentinel: pass @@ -437,7 +429,10 @@ def _saferepr(obj: object) -> str: def _get_maxsize_for_saferepr(config: Optional[Config]) -> Optional[int]: """Get `maxsize` configuration for saferepr based on the given config object.""" - verbosity = config.getoption("verbose") if config is not None else 0 + if config is None: + verbosity = 0 + else: + verbosity = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) if verbosity >= 2: return None if verbosity >= 1: @@ -604,13 +599,6 @@ def _get_assertion_exprs(src: bytes) -> Dict[int, str]: return ret -def _get_ast_constant_value(value: astStr) -> object: - if sys.version_info >= (3, 8): - return value.value - else: - return value.s - - class AssertionRewriter(ast.NodeVisitor): """Assertion rewriting implementation. @@ -706,11 +694,10 @@ class AssertionRewriter(ast.NodeVisitor): if ( expect_docstring and isinstance(item, ast.Expr) - and isinstance(item.value, astStr) - and isinstance(_get_ast_constant_value(item.value), str) + and isinstance(item.value, ast.Constant) + and isinstance(item.value.value, str) ): - doc = _get_ast_constant_value(item.value) - assert isinstance(doc, str) + doc = item.value.value if self.is_rewrite_disabled(doc): return expect_docstring = False @@ -850,7 +837,7 @@ class AssertionRewriter(ast.NodeVisitor): current = self.stack.pop() if self.stack: self.explanation_specifiers = self.stack[-1] - keys = [astStr(key) for key in current.keys()] + keys = [ast.Constant(key) for key in current.keys()] format_dict = ast.Dict(keys, list(current.values())) form = ast.BinOp(expl_expr, ast.Mod(), format_dict) name = "@py_format" + str(next(self.variable_counter)) @@ -874,9 +861,10 @@ class AssertionRewriter(ast.NodeVisitor): the expression is false. """ if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1: - from _pytest.warning_types import PytestAssertRewriteWarning import warnings + from _pytest.warning_types import PytestAssertRewriteWarning + # TODO: This assert should not be needed. assert self.module_path is not None warnings.warn_explicit( @@ -904,16 +892,16 @@ class AssertionRewriter(ast.NodeVisitor): negation = ast.UnaryOp(ast.Not(), top_condition) if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook - msg = self.pop_format_context(astStr(explanation)) + msg = self.pop_format_context(ast.Constant(explanation)) # Failed if assert_.msg: assertmsg = self.helper("_format_assertmsg", assert_.msg) gluestr = "\n>assert " else: - assertmsg = astStr("") + assertmsg = ast.Constant("") gluestr = "assert " - err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg) + err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg) err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation) err_name = ast.Name("AssertionError", ast.Load()) fmt = self.helper("_format_explanation", err_msg) @@ -929,8 +917,8 @@ class AssertionRewriter(ast.NodeVisitor): hook_call_pass = ast.Expr( self.helper( "_call_assertion_pass", - astNum(assert_.lineno), - astStr(orig), + ast.Constant(assert_.lineno), + ast.Constant(orig), fmt_pass, ) ) @@ -949,7 +937,7 @@ class AssertionRewriter(ast.NodeVisitor): variables = [ ast.Name(name, ast.Store()) for name in self.format_variables ] - clear_format = ast.Assign(variables, astNameConstant(None)) + clear_format = ast.Assign(variables, ast.Constant(None)) self.statements.append(clear_format) else: # Original assertion rewriting @@ -960,9 +948,9 @@ class AssertionRewriter(ast.NodeVisitor): assertmsg = self.helper("_format_assertmsg", assert_.msg) explanation = "\n>assert " + explanation else: - assertmsg = astStr("") + assertmsg = ast.Constant("") explanation = "assert " + explanation - template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation)) + template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation)) msg = self.pop_format_context(template) fmt = self.helper("_format_explanation", msg) err_name = ast.Name("AssertionError", ast.Load()) @@ -974,7 +962,7 @@ class AssertionRewriter(ast.NodeVisitor): # Clear temporary variables by setting them to None. if self.variables: variables = [ast.Name(name, ast.Store()) for name in self.variables] - clear = ast.Assign(variables, astNameConstant(None)) + clear = ast.Assign(variables, ast.Constant(None)) self.statements.append(clear) # Fix locations (line numbers/column offsets). for stmt in self.statements: @@ -982,26 +970,26 @@ class AssertionRewriter(ast.NodeVisitor): ast.copy_location(node, assert_) return self.statements - def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]: + def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]: # This method handles the 'walrus operator' repr of the target # name if it's a local variable or _should_repr_global_name() # thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id # type: ignore[attr-defined] - inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs]) + inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) dorepr = self.helper("_should_repr_global_name", name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) - expr = ast.IfExp(test, self.display(name), astStr(target_id)) + expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) return name, self.explanation_param(expr) def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: # Display the repr of the name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) - inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs]) + inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs]) dorepr = self.helper("_should_repr_global_name", name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) - expr = ast.IfExp(test, self.display(name), astStr(name.id)) + expr = ast.IfExp(test, self.display(name), ast.Constant(name.id)) return name, self.explanation_param(expr) def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: @@ -1020,10 +1008,10 @@ class AssertionRewriter(ast.NodeVisitor): # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner - # Check if the left operand is a namedExpr and the value has already been visited + # Check if the left operand is a ast.NamedExpr and the value has already been visited if ( isinstance(v, ast.Compare) - and isinstance(v.left, namedExpr) + and isinstance(v.left, ast.NamedExpr) and v.left.target.id in [ ast_expr.id @@ -1032,14 +1020,12 @@ class AssertionRewriter(ast.NodeVisitor): ] ): pytest_temp = self.variable() - self.variables_overwrite[self.scope][ - v.left.target.id - ] = v.left # type:ignore[assignment] + self.variables_overwrite[self.scope][v.left.target.id] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) - expl_format = self.pop_format_context(astStr(expl)) + expl_format = self.pop_format_context(ast.Constant(expl)) call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: @@ -1051,7 +1037,7 @@ class AssertionRewriter(ast.NodeVisitor): self.statements = body = inner self.statements = save self.expl_stmts = fail_save - expl_template = self.helper("_format_boolop", expl_list, astNum(is_or)) + expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or)) expl = self.pop_format_context(expl_template) return ast.Name(res_var, ast.Load()), self.explanation_param(expl) @@ -1078,9 +1064,7 @@ class AssertionRewriter(ast.NodeVisitor): if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( self.scope, {} ): - arg = self.variables_overwrite[self.scope][ - arg.id - ] # type:ignore[assignment] + arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) @@ -1088,9 +1072,7 @@ class AssertionRewriter(ast.NodeVisitor): if isinstance( keyword.value, ast.Name ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): - keyword.value = self.variables_overwrite[self.scope][ - keyword.value.id - ] # type:ignore[assignment] + keyword.value = self.variables_overwrite[self.scope][keyword.value.id] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: @@ -1127,13 +1109,9 @@ class AssertionRewriter(ast.NodeVisitor): if isinstance( comp.left, ast.Name ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): - comp.left = self.variables_overwrite[self.scope][ - comp.left.id - ] # type:ignore[assignment] - if isinstance(comp.left, namedExpr): - self.variables_overwrite[self.scope][ - comp.left.target.id - ] = comp.left # type:ignore[assignment] + comp.left = self.variables_overwrite[self.scope][comp.left.id] # type:ignore[assignment] + if isinstance(comp.left, ast.NamedExpr): + self.variables_overwrite[self.scope][comp.left.target.id] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1146,22 +1124,20 @@ class AssertionRewriter(ast.NodeVisitor): results = [left_res] for i, op, next_operand in it: if ( - isinstance(next_operand, namedExpr) + isinstance(next_operand, ast.NamedExpr) and isinstance(left_res, ast.Name) and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][ - left_res.id - ] = next_operand # type:ignore[assignment] + self.variables_overwrite[self.scope][left_res.id] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" results.append(next_res) sym = BINOP_MAP[op.__class__] - syms.append(astStr(sym)) + syms.append(ast.Constant(sym)) expl = f"{left_expl} {sym} {next_expl}" - expls.append(astStr(expl)) + expls.append(ast.Constant(expl)) res_expr = ast.Compare(left_res, [op], [next_res]) self.statements.append(ast.Assign([store_names[i]], res_expr)) left_res, left_expl = next_res, next_expl @@ -1205,7 +1181,7 @@ def try_makedirs(cache_dir: Path) -> bool: def get_cache_dir(file_path: Path) -> Path: """Return the cache directory to write .pyc files for the given .py file path.""" - if sys.version_info >= (3, 8) and sys.pycache_prefix: + if sys.pycache_prefix: # given: # prefix = '/tmp/pycs' # path = '/home/user/proj/test_app.py' diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py index dfd6f65d281..902d4baf846 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/truncate.py +++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py @@ -1,12 +1,14 @@ """Utilities for truncating assertion output. Current default behaviour is to truncate assertion explanations at -~8 terminal lines, unless running in "-vv" mode or running on CI. +terminal lines, unless running with an assertions verbosity level of at least 2 or running on CI. """ + from typing import List from typing import Optional from _pytest.assertion import util +from _pytest.config import Config from _pytest.nodes import Item @@ -26,7 +28,7 @@ def truncate_if_required( def _should_truncate_item(item: Item) -> bool: """Whether or not this test item is eligible for truncation.""" - verbose = item.config.option.verbose + verbose = item.config.get_verbosity(Config.VERBOSITY_ASSERTIONS) return verbose < 2 and not util.running_on_ci() diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py index 39ca5403e04..a7074115d65 100644 --- a/contrib/python/pytest/py3/_pytest/assertion/util.py +++ b/contrib/python/pytest/py3/_pytest/assertion/util.py @@ -1,4 +1,5 @@ """Utilities for assertion debugging.""" + import collections.abc import os import pprint @@ -7,18 +8,21 @@ from typing import Any from typing import Callable from typing import Iterable from typing import List +from typing import Literal from typing import Mapping from typing import Optional +from typing import Protocol from typing import Sequence from unicodedata import normalize -import _pytest._code from _pytest import outcomes -from _pytest._io.saferepr import _pformat_dispatch +import _pytest._code +from _pytest._io.pprint import PrettyPrinter from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr_unlimited from _pytest.config import Config + # The _reprcompare attribute on the util module is used by the new assertion # interpretation code and assertion rewriter to detect this plugin was # loaded and in turn call the hooks defined here as part of the @@ -33,6 +37,11 @@ _assertion_pass: Optional[Callable[[int, str, str], None]] = None _config: Optional[Config] = None +class _HighlightFunc(Protocol): + def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: + """Apply highlighting to the given source.""" + + def format_explanation(explanation: str) -> str: r"""Format an explanation. @@ -161,7 +170,7 @@ def assertrepr_compare( config, op: str, left: Any, right: Any, use_ascii: bool = False ) -> Optional[List[str]]: """Return specialised explanations for some operators/operands.""" - verbose = config.getoption("verbose") + verbose = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) # Strings which normalize equal are often hard to distinguish when printed; use ascii() to make this easier. # See issue #3246. @@ -185,14 +194,31 @@ def assertrepr_compare( right_repr = saferepr(right, maxsize=maxsize, use_ascii=use_ascii) summary = f"{left_repr} {op} {right_repr}" + highlighter = config.get_terminal_writer()._highlight explanation = None try: if op == "==": - explanation = _compare_eq_any(left, right, verbose) + explanation = _compare_eq_any(left, right, highlighter, verbose) elif op == "not in": if istext(left) and istext(right): explanation = _notin_text(left, right, verbose) + elif op == "!=": + if isset(left) and isset(right): + explanation = ["Both sets are equal"] + elif op == ">=": + if isset(left) and isset(right): + explanation = _compare_gte_set(left, right, highlighter, verbose) + elif op == "<=": + if isset(left) and isset(right): + explanation = _compare_lte_set(left, right, highlighter, verbose) + elif op == ">": + if isset(left) and isset(right): + explanation = _compare_gt_set(left, right, highlighter, verbose) + elif op == "<": + if isset(left) and isset(right): + explanation = _compare_lt_set(left, right, highlighter, verbose) + except outcomes.Exit: raise except Exception: @@ -206,10 +232,14 @@ def assertrepr_compare( if not explanation: return None + if explanation[0] != "": + explanation = [""] + explanation return [summary] + explanation -def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: +def _compare_eq_any( + left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0 +) -> List[str]: explanation = [] if istext(left) and istext(right): explanation = _diff_text(left, right, verbose) @@ -222,23 +252,23 @@ def _compare_eq_any(left: Any, right: Any, verbose: int = 0) -> List[str]: other_side = right if isinstance(left, ApproxBase) else left explanation = approx_side._repr_compare(other_side) - elif type(left) == type(right) and ( + elif type(left) is type(right) and ( isdatacls(left) or isattrs(left) or isnamedtuple(left) ): # Note: unlike dataclasses/attrs, namedtuples compare only the # field values, not the type or field names. But this branch # intentionally only handles the same-type case, which was often # used in older code bases before dataclasses/attrs were available. - explanation = _compare_eq_cls(left, right, verbose) + explanation = _compare_eq_cls(left, right, highlighter, verbose) elif issequence(left) and issequence(right): - explanation = _compare_eq_sequence(left, right, verbose) + explanation = _compare_eq_sequence(left, right, highlighter, verbose) elif isset(left) and isset(right): - explanation = _compare_eq_set(left, right, verbose) + explanation = _compare_eq_set(left, right, highlighter, verbose) elif isdict(left) and isdict(right): - explanation = _compare_eq_dict(left, right, verbose) + explanation = _compare_eq_dict(left, right, highlighter, verbose) if isiterable(left) and isiterable(right): - expl = _compare_eq_iterable(left, right, verbose) + expl = _compare_eq_iterable(left, right, highlighter, verbose) explanation.extend(expl) return explanation @@ -273,8 +303,8 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: if i > 42: i -= 10 # Provide some context explanation += [ - "Skipping {} identical trailing " - "characters in diff, use -v to show".format(i) + f"Skipping {i} identical trailing " + "characters in diff, use -v to show" ] left = left[:-i] right = right[:-i] @@ -292,51 +322,40 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]: return explanation -def _surrounding_parens_on_own_lines(lines: List[str]) -> None: - """Move opening/closing parenthesis/bracket to own lines.""" - opening = lines[0][:1] - if opening in ["(", "[", "{"]: - lines[0] = " " + lines[0][1:] - lines[:] = [opening] + lines - closing = lines[-1][-1:] - if closing in [")", "]", "}"]: - lines[-1] = lines[-1][:-1] + "," - lines[:] = lines + [closing] - - def _compare_eq_iterable( - left: Iterable[Any], right: Iterable[Any], verbose: int = 0 + left: Iterable[Any], + right: Iterable[Any], + highligher: _HighlightFunc, + verbose: int = 0, ) -> List[str]: if verbose <= 0 and not running_on_ci(): return ["Use -v to get more diff"] # dynamic import to speedup pytest import difflib - left_formatting = pprint.pformat(left).splitlines() - right_formatting = pprint.pformat(right).splitlines() + left_formatting = PrettyPrinter().pformat(left).splitlines() + right_formatting = PrettyPrinter().pformat(right).splitlines() - # Re-format for different output lengths. - lines_left = len(left_formatting) - lines_right = len(right_formatting) - if lines_left != lines_right: - left_formatting = _pformat_dispatch(left).splitlines() - right_formatting = _pformat_dispatch(right).splitlines() - - if lines_left > 1 or lines_right > 1: - _surrounding_parens_on_own_lines(left_formatting) - _surrounding_parens_on_own_lines(right_formatting) - - explanation = ["Full diff:"] + explanation = ["", "Full diff:"] # "right" is the expected base against which we compare "left", # see https://github.com/pytest-dev/pytest/issues/3333 explanation.extend( - line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting) + highligher( + "\n".join( + line.rstrip() + for line in difflib.ndiff(right_formatting, left_formatting) + ), + lexer="diff", + ).splitlines() ) return explanation def _compare_eq_sequence( - left: Sequence[Any], right: Sequence[Any], verbose: int = 0 + left: Sequence[Any], + right: Sequence[Any], + highlighter: _HighlightFunc, + verbose: int = 0, ) -> List[str]: comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) explanation: List[str] = [] @@ -359,7 +378,10 @@ def _compare_eq_sequence( left_value = left[i] right_value = right[i] - explanation += [f"At index {i} diff: {left_value!r} != {right_value!r}"] + explanation.append( + f"At index {i} diff:" + f" {highlighter(repr(left_value))} != {highlighter(repr(right_value))}" + ) break if comparing_bytes: @@ -379,34 +401,91 @@ def _compare_eq_sequence( extra = saferepr(right[len_left]) if len_diff == 1: - explanation += [f"{dir_with_more} contains one more item: {extra}"] + explanation += [ + f"{dir_with_more} contains one more item: {highlighter(extra)}" + ] else: explanation += [ "%s contains %d more items, first extra item: %s" - % (dir_with_more, len_diff, extra) + % (dir_with_more, len_diff, highlighter(extra)) ] return explanation def _compare_eq_set( - left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0 + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> List[str]: + explanation = [] + explanation.extend(_set_one_sided_diff("left", left, right, highlighter)) + explanation.extend(_set_one_sided_diff("right", right, left, highlighter)) + return explanation + + +def _compare_gt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> List[str]: + explanation = _compare_gte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_lt_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> List[str]: + explanation = _compare_lte_set(left, right, highlighter) + if not explanation: + return ["Both sets are equal"] + return explanation + + +def _compare_gte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> List[str]: + return _set_one_sided_diff("right", right, left, highlighter) + + +def _compare_lte_set( + left: AbstractSet[Any], + right: AbstractSet[Any], + highlighter: _HighlightFunc, + verbose: int = 0, +) -> List[str]: + return _set_one_sided_diff("left", left, right, highlighter) + + +def _set_one_sided_diff( + posn: str, + set1: AbstractSet[Any], + set2: AbstractSet[Any], + highlighter: _HighlightFunc, ) -> List[str]: explanation = [] - diff_left = left - right - diff_right = right - left - if diff_left: - explanation.append("Extra items in the left set:") - for item in diff_left: - explanation.append(saferepr(item)) - if diff_right: - explanation.append("Extra items in the right set:") - for item in diff_right: - explanation.append(saferepr(item)) + diff = set1 - set2 + if diff: + explanation.append(f"Extra items in the {posn} set:") + for item in diff: + explanation.append(highlighter(saferepr(item))) return explanation def _compare_eq_dict( - left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0 + left: Mapping[Any, Any], + right: Mapping[Any, Any], + highlighter: _HighlightFunc, + verbose: int = 0, ) -> List[str]: explanation: List[str] = [] set_left = set(left) @@ -417,12 +496,16 @@ def _compare_eq_dict( explanation += ["Omitting %s identical items, use -vv to show" % len(same)] elif same: explanation += ["Common items:"] - explanation += pprint.pformat(same).splitlines() + explanation += highlighter(pprint.pformat(same)).splitlines() diff = {k for k in common if left[k] != right[k]} if diff: explanation += ["Differing items:"] for k in diff: - explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})] + explanation += [ + highlighter(saferepr({k: left[k]})) + + " != " + + highlighter(saferepr({k: right[k]})) + ] extra_left = set_left - set_right len_extra_left = len(extra_left) if len_extra_left: @@ -431,7 +514,7 @@ def _compare_eq_dict( % (len_extra_left, "" if len_extra_left == 1 else "s") ) explanation.extend( - pprint.pformat({k: left[k] for k in extra_left}).splitlines() + highlighter(pprint.pformat({k: left[k] for k in extra_left})).splitlines() ) extra_right = set_right - set_left len_extra_right = len(extra_right) @@ -441,12 +524,14 @@ def _compare_eq_dict( % (len_extra_right, "" if len_extra_right == 1 else "s") ) explanation.extend( - pprint.pformat({k: right[k] for k in extra_right}).splitlines() + highlighter(pprint.pformat({k: right[k] for k in extra_right})).splitlines() ) return explanation -def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: +def _compare_eq_cls( + left: Any, right: Any, highlighter: _HighlightFunc, verbose: int +) -> List[str]: if not has_default_eq(left): return [] if isdatacls(left): @@ -478,21 +563,23 @@ def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]: explanation.append("Omitting %s identical items, use -vv to show" % len(same)) elif same: explanation += ["Matching attributes:"] - explanation += pprint.pformat(same).splitlines() + explanation += highlighter(pprint.pformat(same)).splitlines() if diff: explanation += ["Differing attributes:"] - explanation += pprint.pformat(diff).splitlines() + explanation += highlighter(pprint.pformat(diff)).splitlines() for field in diff: field_left = getattr(left, field) field_right = getattr(right, field) explanation += [ "", - "Drill down into differing attribute %s:" % field, - ("%s%s: %r != %r") % (indent, field, field_left, field_right), + f"Drill down into differing attribute {field}:", + f"{indent}{field}: {highlighter(repr(field_left))} != {highlighter(repr(field_right))}", ] explanation += [ indent + line - for line in _compare_eq_any(field_left, field_right, verbose) + for line in _compare_eq_any( + field_left, field_right, highlighter, verbose + ) ] return explanation |