aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pytest/py3/_pytest/assertion
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /contrib/python/pytest/py3/_pytest/assertion
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'contrib/python/pytest/py3/_pytest/assertion')
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/__init__.py179
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/rewrite.py1066
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/truncate.py95
-rw-r--r--contrib/python/pytest/py3/_pytest/assertion/util.py463
4 files changed, 1803 insertions, 0 deletions
diff --git a/contrib/python/pytest/py3/_pytest/assertion/__init__.py b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
new file mode 100644
index 0000000000..ee7fa6a3af
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/assertion/__init__.py
@@ -0,0 +1,179 @@
+"""
+support for presenting detailed information in failing assertions.
+"""
+import sys
+from typing import Any
+from typing import List
+from typing import Optional
+
+from _pytest.assertion import rewrite
+from _pytest.assertion import truncate
+from _pytest.assertion import util
+from _pytest.assertion.rewrite import assertstate_key
+from _pytest.compat import TYPE_CHECKING
+from _pytest.config import Config
+from _pytest.config import hookimpl
+
+if TYPE_CHECKING:
+ from _pytest.main import Session
+
+
+def pytest_addoption(parser):
+ group = parser.getgroup("debugconfig")
+ group.addoption(
+ "--assert",
+ action="store",
+ dest="assertmode",
+ choices=("rewrite", "plain"),
+ default="rewrite",
+ metavar="MODE",
+ help="""Control assertion debugging tools. 'plain'
+ performs no assertion debugging. 'rewrite'
+ (the default) rewrites assert statements in
+ test modules on import to provide assert
+ expression information.""",
+ )
+ parser.addini(
+ "enable_assertion_pass_hook",
+ type="bool",
+ default=False,
+ help="Enables the pytest_assertion_pass hook."
+ "Make sure to delete any previously generated pyc cache files.",
+ )
+
+
+def register_assert_rewrite(*names) -> None:
+ """Register one or more module names to be rewritten on import.
+
+ This function will make sure that this module or all modules inside
+ the package will get their assert statements rewritten.
+ Thus you should make sure to call this before the module is
+ actually imported, usually in your __init__.py if you are a plugin
+ using a package.
+
+ :raise TypeError: if the given module names are not strings.
+ """
+ for name in names:
+ if not isinstance(name, str):
+ msg = "expected module names as *args, got {0} instead"
+ raise TypeError(msg.format(repr(names)))
+ for hook in sys.meta_path:
+ if isinstance(hook, rewrite.AssertionRewritingHook):
+ importhook = hook
+ break
+ else:
+ # TODO(typing): Add a protocol for mark_rewrite() and use it
+ # for importhook and for PytestPluginManager.rewrite_hook.
+ importhook = DummyRewriteHook() # type: ignore
+ importhook.mark_rewrite(*names)
+
+
+class DummyRewriteHook:
+ """A no-op import hook for when rewriting is disabled."""
+
+ def mark_rewrite(self, *names):
+ pass
+
+
+class AssertionState:
+ """State for the assertion plugin."""
+
+ def __init__(self, config, mode):
+ self.mode = mode
+ self.trace = config.trace.root.get("assertion")
+ self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
+
+
+def install_importhook(config):
+ """Try to install the rewrite hook, raise SystemError if it fails."""
+ config._store[assertstate_key] = AssertionState(config, "rewrite")
+ config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
+ sys.meta_path.insert(0, hook)
+ config._store[assertstate_key].trace("installed rewrite import hook")
+
+ def undo():
+ hook = config._store[assertstate_key].hook
+ if hook is not None and hook in sys.meta_path:
+ sys.meta_path.remove(hook)
+
+ config.add_cleanup(undo)
+ return hook
+
+
+def pytest_collection(session: "Session") -> None:
+ # this hook is only called when test modules are collected
+ # so for example not in the master process of pytest-xdist
+ # (which does not collect test modules)
+ assertstate = session.config._store.get(assertstate_key, None)
+ if assertstate:
+ if assertstate.hook is not None:
+ assertstate.hook.set_session(session)
+
+
+@hookimpl(tryfirst=True, hookwrapper=True)
+def pytest_runtest_protocol(item):
+ """Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
+
+ The newinterpret and rewrite modules will use util._reprcompare if
+ it exists to use custom reporting via the
+ pytest_assertrepr_compare hook. This sets up this custom
+ comparison for the test.
+ """
+
+ def callbinrepr(op, left, right):
+ # type: (str, object, object) -> Optional[str]
+ """Call the pytest_assertrepr_compare hook and prepare the result
+
+ This uses the first result from the hook and then ensures the
+ following:
+ * Overly verbose explanations are truncated unless configured otherwise
+ (eg. if running in verbose mode).
+ * Embedded newlines are escaped to help util.format_explanation()
+ later.
+ * If the rewrite mode is used embedded %-characters are replaced
+ to protect later % formatting.
+
+ The result can be formatted by util.format_explanation() for
+ pretty printing.
+ """
+ hook_result = item.ihook.pytest_assertrepr_compare(
+ config=item.config, op=op, left=left, right=right
+ )
+ for new_expl in hook_result:
+ if new_expl:
+ new_expl = truncate.truncate_if_required(new_expl, item)
+ new_expl = [line.replace("\n", "\\n") for line in new_expl]
+ res = "\n~".join(new_expl)
+ if item.config.getvalue("assertmode") == "rewrite":
+ res = res.replace("%", "%%")
+ return res
+ return None
+
+ saved_assert_hooks = util._reprcompare, util._assertion_pass
+ util._reprcompare = callbinrepr
+
+ if item.ihook.pytest_assertion_pass.get_hookimpls():
+
+ def call_assertion_pass_hook(lineno, orig, expl):
+ item.ihook.pytest_assertion_pass(
+ item=item, lineno=lineno, orig=orig, expl=expl
+ )
+
+ util._assertion_pass = call_assertion_pass_hook
+
+ yield
+
+ util._reprcompare, util._assertion_pass = saved_assert_hooks
+
+
+def pytest_sessionfinish(session):
+ assertstate = session.config._store.get(assertstate_key, None)
+ if assertstate:
+ if assertstate.hook is not None:
+ assertstate.hook.set_session(None)
+
+
+def pytest_assertrepr_compare(
+ config: Config, op: str, left: Any, right: Any
+) -> Optional[List[str]]:
+ return util.assertrepr_compare(config=config, op=op, left=left, right=right)
diff --git a/contrib/python/pytest/py3/_pytest/assertion/rewrite.py b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
new file mode 100644
index 0000000000..f84127dcaf
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/assertion/rewrite.py
@@ -0,0 +1,1066 @@
+"""Rewrite assertion AST to produce nice error messages"""
+import ast
+import errno
+import functools
+import importlib.abc
+import importlib.machinery
+import importlib.util
+import io
+import itertools
+import marshal
+import os
+import struct
+import sys
+import tokenize
+import types
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Set
+from typing import Tuple
+
+from _pytest._io.saferepr import saferepr
+from _pytest._version import version
+from _pytest.assertion import util
+from _pytest.assertion.util import ( # noqa: F401
+ format_explanation as _format_explanation,
+)
+from _pytest.compat import fspath
+from _pytest.compat import TYPE_CHECKING
+from _pytest.pathlib import fnmatch_ex
+from _pytest.pathlib import Path
+from _pytest.pathlib import PurePath
+from _pytest.store import StoreKey
+
+if TYPE_CHECKING:
+ from _pytest.assertion import AssertionState # noqa: F401
+
+
+assertstate_key = StoreKey["AssertionState"]()
+
+
+# pytest caches rewritten pycs in pycache dirs
+PYTEST_TAG = "{}-pytest-{}".format(sys.implementation.cache_tag, version)
+PYC_EXT = ".py" + (__debug__ and "c" or "o")
+PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
+
+
+class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
+ """PEP302/PEP451 import hook which rewrites asserts."""
+
+ def __init__(self, config):
+ self.config = config
+ try:
+ self.fnpats = config.getini("python_files")
+ except ValueError:
+ self.fnpats = ["test_*.py", "*_test.py"]
+ self.session = None
+ self._rewritten_names = set() # type: Set[str]
+ self._must_rewrite = set() # type: Set[str]
+ # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
+ # which might result in infinite recursion (#3506)
+ self._writing_pyc = False
+ self._basenames_to_check_rewrite = {"conftest"}
+ self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
+ self._session_paths_checked = False
+
+ def set_session(self, session):
+ self.session = session
+ self._session_paths_checked = False
+
+ # Indirection so we can mock calls to find_spec originated from the hook during testing
+ _find_spec = importlib.machinery.PathFinder.find_spec
+
+ def find_spec(self, name, path=None, target=None):
+ if self._writing_pyc:
+ return None
+ state = self.config._store[assertstate_key]
+ if self._early_rewrite_bailout(name, state):
+ return None
+ state.trace("find_module called for: %s" % name)
+
+ spec = self._find_spec(name, path)
+ if (
+ # the import machinery could not find a file to import
+ spec is None
+ # this is a namespace package (without `__init__.py`)
+ # there's nothing to rewrite there
+ # python3.5 - python3.6: `namespace`
+ # python3.7+: `None`
+ or spec.origin == "namespace"
+ or spec.origin is None
+ # we can only rewrite source files
+ or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
+ # if the file doesn't exist, we can't rewrite it
+ or not os.path.exists(spec.origin)
+ ):
+ return None
+ else:
+ fn = spec.origin
+
+ if not self._should_rewrite(name, fn, state):
+ return None
+
+ return importlib.util.spec_from_file_location(
+ name,
+ fn,
+ loader=self,
+ submodule_search_locations=spec.submodule_search_locations,
+ )
+
+ def create_module(self, spec):
+ return None # default behaviour is fine
+
+ def exec_module(self, module):
+ fn = Path(module.__spec__.origin)
+ state = self.config._store[assertstate_key]
+
+ self._rewritten_names.add(module.__name__)
+
+ # The requested module looks like a test file, so rewrite it. This is
+ # the most magical part of the process: load the source, rewrite the
+ # asserts, and load the rewritten source. We also cache the rewritten
+ # module code in a special pyc. We must be aware of the possibility of
+ # concurrent pytest processes rewriting and loading pycs. To avoid
+ # tricky race conditions, we maintain the following invariant: The
+ # cached pyc is always a complete, valid pyc. Operations on it must be
+ # atomic. POSIX's atomic rename comes in handy.
+ write = not sys.dont_write_bytecode
+ cache_dir = get_cache_dir(fn)
+ if write:
+ ok = try_makedirs(cache_dir)
+ if not ok:
+ write = False
+ state.trace("read only directory: {}".format(cache_dir))
+
+ cache_name = fn.name[:-3] + PYC_TAIL
+ pyc = cache_dir / cache_name
+ # Notice that even if we're in a read-only directory, I'm going
+ # to check for a cached pyc. This may not be optimal...
+ co = _read_pyc(fn, pyc, state.trace)
+ if co is None:
+ state.trace("rewriting {!r}".format(fn))
+ source_stat, co = _rewrite_test(fn, self.config)
+ if write:
+ self._writing_pyc = True
+ try:
+ _write_pyc(state, co, source_stat, pyc)
+ finally:
+ self._writing_pyc = False
+ else:
+ state.trace("found cached rewritten pyc for {}".format(fn))
+ exec(co, module.__dict__)
+
+ def _early_rewrite_bailout(self, name, state):
+ """This is a fast way to get out of rewriting modules.
+
+ Profiling has shown that the call to PathFinder.find_spec (inside of
+ the find_spec from this class) is a major slowdown, so, this method
+ tries to filter what we're sure won't be rewritten before getting to
+ it.
+ """
+ if self.session is not None and not self._session_paths_checked:
+ self._session_paths_checked = True
+ for path in self.session._initialpaths:
+ # Make something as c:/projects/my_project/path.py ->
+ # ['c:', 'projects', 'my_project', 'path.py']
+ parts = str(path).split(os.path.sep)
+ # add 'path' to basenames to be checked.
+ self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
+
+ # Note: conftest already by default in _basenames_to_check_rewrite.
+ parts = name.split(".")
+ if parts[-1] in self._basenames_to_check_rewrite:
+ return False
+
+ # For matching the name it must be as if it was a filename.
+ path = PurePath(os.path.sep.join(parts) + ".py")
+
+ for pat in self.fnpats:
+ # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
+ # on the name alone because we need to match against the full path
+ if os.path.dirname(pat):
+ return False
+ if fnmatch_ex(pat, path):
+ return False
+
+ if self._is_marked_for_rewrite(name, state):
+ return False
+
+ state.trace("early skip of rewriting module: {}".format(name))
+ return True
+
+ def _should_rewrite(self, name, fn, state):
+ # always rewrite conftest files
+ if os.path.basename(fn) == "conftest.py":
+ state.trace("rewriting conftest file: {!r}".format(fn))
+ return True
+
+ if self.session is not None:
+ if self.session.isinitpath(fn):
+ state.trace(
+ "matched test file (was specified on cmdline): {!r}".format(fn)
+ )
+ return True
+
+ # modules not passed explicitly on the command line are only
+ # rewritten if they match the naming convention for test files
+ fn_path = PurePath(fn)
+ for pat in self.fnpats:
+ if fnmatch_ex(pat, fn_path):
+ state.trace("matched test file {!r}".format(fn))
+ return True
+
+ return self._is_marked_for_rewrite(name, state)
+
+ def _is_marked_for_rewrite(self, name: str, state):
+ try:
+ return self._marked_for_rewrite_cache[name]
+ except KeyError:
+ for marked in self._must_rewrite:
+ if name == marked or name.startswith(marked + "."):
+ state.trace(
+ "matched marked file {!r} (from {!r})".format(name, marked)
+ )
+ self._marked_for_rewrite_cache[name] = True
+ return True
+
+ self._marked_for_rewrite_cache[name] = False
+ return False
+
+ def mark_rewrite(self, *names: str) -> None:
+ """Mark import names as needing to be rewritten.
+
+ The named module or package as well as any nested modules will
+ be rewritten on import.
+ """
+ already_imported = (
+ set(names).intersection(sys.modules).difference(self._rewritten_names)
+ )
+ for name in already_imported:
+ mod = sys.modules[name]
+ if not AssertionRewriter.is_rewrite_disabled(
+ mod.__doc__ or ""
+ ) and not isinstance(mod.__loader__, type(self)):
+ self._warn_already_imported(name)
+ self._must_rewrite.update(names)
+ self._marked_for_rewrite_cache.clear()
+
+ def _warn_already_imported(self, name):
+ from _pytest.warning_types import PytestAssertRewriteWarning
+ from _pytest.warnings import _issue_warning_captured
+
+ _issue_warning_captured(
+ PytestAssertRewriteWarning(
+ "Module already imported so cannot be rewritten: %s" % name
+ ),
+ self.config.hook,
+ stacklevel=5,
+ )
+
+ def get_data(self, pathname):
+ """Optional PEP302 get_data API."""
+ with open(pathname, "rb") as f:
+ return f.read()
+
+
+def _write_pyc_fp(fp, source_stat, co):
+ # Technically, we don't have to have the same pyc format as
+ # (C)Python, since these "pycs" should never be seen by builtin
+ # import. However, there's little reason deviate.
+ fp.write(importlib.util.MAGIC_NUMBER)
+ # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
+ mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
+ size = source_stat.st_size & 0xFFFFFFFF
+ # "<LL" stands for 2 unsigned longs, little-ending
+ fp.write(struct.pack("<LL", mtime, size))
+ fp.write(marshal.dumps(co))
+
+
+if sys.platform == "win32":
+ from atomicwrites import atomic_write
+
+ def _write_pyc(state, co, source_stat, pyc):
+ try:
+ with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
+ _write_pyc_fp(fp, source_stat, co)
+ except EnvironmentError as e:
+ state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
+ # we ignore any failure to write the cache file
+ # there are many reasons, permission-denied, pycache dir being a
+ # file etc.
+ return False
+ return True
+
+
+else:
+
+ def _write_pyc(state, co, source_stat, pyc):
+ proc_pyc = "{}.{}".format(pyc, os.getpid())
+ try:
+ fp = open(proc_pyc, "wb")
+ except EnvironmentError as e:
+ state.trace(
+ "error writing pyc file at {}: errno={}".format(proc_pyc, e.errno)
+ )
+ return False
+
+ try:
+ _write_pyc_fp(fp, source_stat, co)
+ os.rename(proc_pyc, fspath(pyc))
+ except BaseException as e:
+ state.trace("error writing pyc file at {}: errno={}".format(pyc, e.errno))
+ # we ignore any failure to write the cache file
+ # there are many reasons, permission-denied, pycache dir being a
+ # file etc.
+ return False
+ finally:
+ fp.close()
+ return True
+
+
+def _rewrite_test(fn, config):
+ """read and rewrite *fn* and return the code object."""
+ fn = fspath(fn)
+ stat = os.stat(fn)
+ with open(fn, "rb") as f:
+ source = f.read()
+ tree = ast.parse(source, filename=fn)
+ rewrite_asserts(tree, source, fn, config)
+ co = compile(tree, fn, "exec", dont_inherit=True)
+ return stat, co
+
+
+def _read_pyc(source, pyc, trace=lambda x: None):
+ """Possibly read a pytest pyc containing rewritten code.
+
+ Return rewritten code if successful or None if not.
+ """
+ try:
+ fp = open(fspath(pyc), "rb")
+ except IOError:
+ return None
+ with fp:
+ try:
+ stat_result = os.stat(fspath(source))
+ mtime = int(stat_result.st_mtime)
+ size = stat_result.st_size
+ data = fp.read(12)
+ except EnvironmentError as e:
+ trace("_read_pyc({}): EnvironmentError {}".format(source, e))
+ return None
+ # Check for invalid or out of date pyc file.
+ if (
+ len(data) != 12
+ or data[:4] != importlib.util.MAGIC_NUMBER
+ or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
+ ):
+ trace("_read_pyc(%s): invalid or out of date pyc" % source)
+ return None
+ try:
+ co = marshal.load(fp)
+ except Exception as e:
+ trace("_read_pyc({}): marshal.load error {}".format(source, e))
+ return None
+ if not isinstance(co, types.CodeType):
+ trace("_read_pyc(%s): not a code object" % source)
+ return None
+ return co
+
+
+def rewrite_asserts(mod, source, module_path=None, config=None):
+ """Rewrite the assert statements in mod."""
+ AssertionRewriter(module_path, config, source).run(mod)
+
+
+def _saferepr(obj):
+ """Get a safe repr of an object for assertion error messages.
+
+ The assertion formatting (util.format_explanation()) requires
+ newlines to be escaped since they are a special character for it.
+ Normally assertion.util.format_explanation() does this but for a
+ custom repr it is possible to contain one of the special escape
+ sequences, especially '\n{' and '\n}' are likely to be present in
+ JSON reprs.
+
+ """
+ return saferepr(obj).replace("\n", "\\n")
+
+
+def _format_assertmsg(obj):
+ """Format the custom assertion message given.
+
+ For strings this simply replaces newlines with '\n~' so that
+ util.format_explanation() will preserve them instead of escaping
+ newlines. For other objects saferepr() is used first.
+
+ """
+ # reprlib appears to have a bug which means that if a string
+ # contains a newline it gets escaped, however if an object has a
+ # .__repr__() which contains newlines it does not get escaped.
+ # However in either case we want to preserve the newline.
+ replaces = [("\n", "\n~"), ("%", "%%")]
+ if not isinstance(obj, str):
+ obj = saferepr(obj)
+ replaces.append(("\\n", "\n~"))
+
+ for r1, r2 in replaces:
+ obj = obj.replace(r1, r2)
+
+ return obj
+
+
+def _should_repr_global_name(obj):
+ if callable(obj):
+ return False
+
+ try:
+ return not hasattr(obj, "__name__")
+ except Exception:
+ return True
+
+
+def _format_boolop(explanations, is_or):
+ explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
+ if isinstance(explanation, str):
+ return explanation.replace("%", "%%")
+ else:
+ return explanation.replace(b"%", b"%%")
+
+
+def _call_reprcompare(ops, results, expls, each_obj):
+ # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
+ for i, res, expl in zip(range(len(ops)), results, expls):
+ try:
+ done = not res
+ except Exception:
+ done = True
+ if done:
+ break
+ if util._reprcompare is not None:
+ custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
+ if custom is not None:
+ return custom
+ return expl
+
+
+def _call_assertion_pass(lineno, orig, expl):
+ # type: (int, str, str) -> None
+ if util._assertion_pass is not None:
+ util._assertion_pass(lineno, orig, expl)
+
+
+def _check_if_assertion_pass_impl():
+ # type: () -> bool
+ """Checks if any plugins implement the pytest_assertion_pass hook
+ in order not to generate explanation unecessarily (might be expensive)"""
+ return True if util._assertion_pass else False
+
+
+UNARY_MAP = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
+
+BINOP_MAP = {
+ ast.BitOr: "|",
+ ast.BitXor: "^",
+ ast.BitAnd: "&",
+ ast.LShift: "<<",
+ ast.RShift: ">>",
+ ast.Add: "+",
+ ast.Sub: "-",
+ ast.Mult: "*",
+ ast.Div: "/",
+ ast.FloorDiv: "//",
+ ast.Mod: "%%", # escaped for string formatting
+ ast.Eq: "==",
+ ast.NotEq: "!=",
+ ast.Lt: "<",
+ ast.LtE: "<=",
+ ast.Gt: ">",
+ ast.GtE: ">=",
+ ast.Pow: "**",
+ ast.Is: "is",
+ ast.IsNot: "is not",
+ ast.In: "in",
+ ast.NotIn: "not in",
+ ast.MatMult: "@",
+}
+
+
+def set_location(node, lineno, col_offset):
+ """Set node location information recursively."""
+
+ def _fix(node, lineno, col_offset):
+ if "lineno" in node._attributes:
+ node.lineno = lineno
+ if "col_offset" in node._attributes:
+ node.col_offset = col_offset
+ for child in ast.iter_child_nodes(node):
+ _fix(child, lineno, col_offset)
+
+ _fix(node, lineno, col_offset)
+ return node
+
+
+def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
+ """Returns a mapping from {lineno: "assertion test expression"}"""
+ ret = {} # type: Dict[int, str]
+
+ depth = 0
+ lines = [] # type: List[str]
+ assert_lineno = None # type: Optional[int]
+ seen_lines = set() # type: Set[int]
+
+ def _write_and_reset() -> None:
+ nonlocal depth, lines, assert_lineno, seen_lines
+ assert assert_lineno is not None
+ ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
+ depth = 0
+ lines = []
+ assert_lineno = None
+ seen_lines = set()
+
+ tokens = tokenize.tokenize(io.BytesIO(src).readline)
+ for tp, source, (lineno, offset), _, line in tokens:
+ if tp == tokenize.NAME and source == "assert":
+ assert_lineno = lineno
+ elif assert_lineno is not None:
+ # keep track of depth for the assert-message `,` lookup
+ if tp == tokenize.OP and source in "([{":
+ depth += 1
+ elif tp == tokenize.OP and source in ")]}":
+ depth -= 1
+
+ if not lines:
+ lines.append(line[offset:])
+ seen_lines.add(lineno)
+ # a non-nested comma separates the expression from the message
+ elif depth == 0 and tp == tokenize.OP and source == ",":
+ # one line assert with message
+ if lineno in seen_lines and len(lines) == 1:
+ offset_in_trimmed = offset + len(lines[-1]) - len(line)
+ lines[-1] = lines[-1][:offset_in_trimmed]
+ # multi-line assert with message
+ elif lineno in seen_lines:
+ lines[-1] = lines[-1][:offset]
+ # multi line assert with escapd newline before message
+ else:
+ lines.append(line[:offset])
+ _write_and_reset()
+ elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
+ _write_and_reset()
+ elif lines and lineno not in seen_lines:
+ lines.append(line)
+ seen_lines.add(lineno)
+
+ return ret
+
+
+class AssertionRewriter(ast.NodeVisitor):
+ """Assertion rewriting implementation.
+
+ The main entrypoint is to call .run() with an ast.Module instance,
+ this will then find all the assert statements and rewrite them to
+ provide intermediate values and a detailed assertion error. See
+ http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
+ for an overview of how this works.
+
+ The entry point here is .run() which will iterate over all the
+ statements in an ast.Module and for each ast.Assert statement it
+ finds call .visit() with it. Then .visit_Assert() takes over and
+ is responsible for creating new ast statements to replace the
+ original assert statement: it rewrites the test of an assertion
+ to provide intermediate values and replace it with an if statement
+ which raises an assertion error with a detailed explanation in
+ case the expression is false and calls pytest_assertion_pass hook
+ if expression is true.
+
+ For this .visit_Assert() uses the visitor pattern to visit all the
+ AST nodes of the ast.Assert.test field, each visit call returning
+ an AST node and the corresponding explanation string. During this
+ state is kept in several instance attributes:
+
+ :statements: All the AST statements which will replace the assert
+ statement.
+
+ :variables: This is populated by .variable() with each variable
+ used by the statements so that they can all be set to None at
+ the end of the statements.
+
+ :variable_counter: Counter to create new unique variables needed
+ by statements. Variables are created using .variable() and
+ have the form of "@py_assert0".
+
+ :expl_stmts: The AST statements which will be executed to get
+ data from the assertion. This is the code which will construct
+ the detailed assertion message that is used in the AssertionError
+ or for the pytest_assertion_pass hook.
+
+ :explanation_specifiers: A dict filled by .explanation_param()
+ with %-formatting placeholders and their corresponding
+ expressions to use in the building of an assertion message.
+ This is used by .pop_format_context() to build a message.
+
+ :stack: A stack of the explanation_specifiers dicts maintained by
+ .push_format_context() and .pop_format_context() which allows
+ to build another %-formatted string while already building one.
+
+ This state is reset on every new assert statement visited and used
+ by the other visitors.
+
+ """
+
+ def __init__(self, module_path, config, source):
+ super().__init__()
+ self.module_path = module_path
+ self.config = config
+ if config is not None:
+ self.enable_assertion_pass_hook = config.getini(
+ "enable_assertion_pass_hook"
+ )
+ else:
+ self.enable_assertion_pass_hook = False
+ self.source = source
+
+ @functools.lru_cache(maxsize=1)
+ def _assert_expr_to_lineno(self):
+ return _get_assertion_exprs(self.source)
+
+ def run(self, mod: ast.Module) -> None:
+ """Find all assert statements in *mod* and rewrite them."""
+ if not mod.body:
+ # Nothing to do.
+ return
+ # Insert some special imports at the top of the module but after any
+ # docstrings and __future__ imports.
+ aliases = [
+ ast.alias("builtins", "@py_builtins"),
+ ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
+ ]
+ doc = getattr(mod, "docstring", None)
+ expect_docstring = doc is None
+ if doc is not None and self.is_rewrite_disabled(doc):
+ return
+ pos = 0
+ lineno = 1
+ for item in mod.body:
+ if (
+ expect_docstring
+ and isinstance(item, ast.Expr)
+ and isinstance(item.value, ast.Str)
+ ):
+ doc = item.value.s
+ if self.is_rewrite_disabled(doc):
+ return
+ expect_docstring = False
+ elif (
+ not isinstance(item, ast.ImportFrom)
+ or item.level > 0
+ or item.module != "__future__"
+ ):
+ lineno = item.lineno
+ break
+ pos += 1
+ else:
+ lineno = item.lineno
+ imports = [
+ ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
+ ]
+ mod.body[pos:pos] = imports
+ # Collect asserts.
+ nodes = [mod] # type: List[ast.AST]
+ while nodes:
+ node = nodes.pop()
+ for name, field in ast.iter_fields(node):
+ if isinstance(field, list):
+ new = [] # type: List
+ for i, child in enumerate(field):
+ if isinstance(child, ast.Assert):
+ # Transform assert.
+ new.extend(self.visit(child))
+ else:
+ new.append(child)
+ if isinstance(child, ast.AST):
+ nodes.append(child)
+ setattr(node, name, new)
+ elif (
+ isinstance(field, ast.AST)
+ # Don't recurse into expressions as they can't contain
+ # asserts.
+ and not isinstance(field, ast.expr)
+ ):
+ nodes.append(field)
+
+ @staticmethod
+ def is_rewrite_disabled(docstring):
+ return "PYTEST_DONT_REWRITE" in docstring
+
+ def variable(self):
+ """Get a new variable."""
+ # Use a character invalid in python identifiers to avoid clashing.
+ name = "@py_assert" + str(next(self.variable_counter))
+ self.variables.append(name)
+ return name
+
+ def assign(self, expr):
+ """Give *expr* a name."""
+ name = self.variable()
+ self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
+ return ast.Name(name, ast.Load())
+
+ def display(self, expr):
+ """Call saferepr on the expression."""
+ return self.helper("_saferepr", expr)
+
+ def helper(self, name, *args):
+ """Call a helper in this module."""
+ py_name = ast.Name("@pytest_ar", ast.Load())
+ attr = ast.Attribute(py_name, name, ast.Load())
+ return ast.Call(attr, list(args), [])
+
+ def builtin(self, name):
+ """Return the builtin called *name*."""
+ builtin_name = ast.Name("@py_builtins", ast.Load())
+ return ast.Attribute(builtin_name, name, ast.Load())
+
+ def explanation_param(self, expr):
+ """Return a new named %-formatting placeholder for expr.
+
+ This creates a %-formatting placeholder for expr in the
+ current formatting context, e.g. ``%(py0)s``. The placeholder
+ and expr are placed in the current format context so that it
+ can be used on the next call to .pop_format_context().
+
+ """
+ specifier = "py" + str(next(self.variable_counter))
+ self.explanation_specifiers[specifier] = expr
+ return "%(" + specifier + ")s"
+
+ def push_format_context(self):
+ """Create a new formatting context.
+
+ The format context is used for when an explanation wants to
+ have a variable value formatted in the assertion message. In
+ this case the value required can be added using
+ .explanation_param(). Finally .pop_format_context() is used
+ to format a string of %-formatted values as added by
+ .explanation_param().
+
+ """
+ self.explanation_specifiers = {} # type: Dict[str, ast.expr]
+ self.stack.append(self.explanation_specifiers)
+
+ def pop_format_context(self, expl_expr):
+ """Format the %-formatted string with current format context.
+
+ The expl_expr should be an ast.Str instance constructed from
+ the %-placeholders created by .explanation_param(). This will
+ add the required code to format said string to .expl_stmts and
+ return the ast.Name instance of the formatted string.
+
+ """
+ current = self.stack.pop()
+ if self.stack:
+ self.explanation_specifiers = self.stack[-1]
+ keys = [ast.Str(key) for key in current.keys()]
+ format_dict = ast.Dict(keys, list(current.values()))
+ form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
+ name = "@py_format" + str(next(self.variable_counter))
+ if self.enable_assertion_pass_hook:
+ self.format_variables.append(name)
+ self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
+ return ast.Name(name, ast.Load())
+
+ def generic_visit(self, node):
+ """Handle expressions we don't have custom code for."""
+ assert isinstance(node, ast.expr)
+ res = self.assign(node)
+ return res, self.explanation_param(self.display(res))
+
+ def visit_Assert(self, assert_):
+ """Return the AST statements to replace the ast.Assert instance.
+
+ This rewrites the test of an assertion to provide
+ intermediate values and replace it with an if statement which
+ raises an assertion error with a detailed explanation in case
+ the expression is false.
+
+ """
+ if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
+ from _pytest.warning_types import PytestAssertRewriteWarning
+ import warnings
+
+ warnings.warn_explicit(
+ PytestAssertRewriteWarning(
+ "assertion is always true, perhaps remove parentheses?"
+ ),
+ category=None,
+ filename=fspath(self.module_path),
+ lineno=assert_.lineno,
+ )
+
+ self.statements = [] # type: List[ast.stmt]
+ self.variables = [] # type: List[str]
+ self.variable_counter = itertools.count()
+
+ if self.enable_assertion_pass_hook:
+ self.format_variables = [] # type: List[str]
+
+ self.stack = [] # type: List[Dict[str, ast.expr]]
+ self.expl_stmts = [] # type: List[ast.stmt]
+ self.push_format_context()
+ # Rewrite assert into a bunch of statements.
+ top_condition, explanation = self.visit(assert_.test)
+
+ negation = ast.UnaryOp(ast.Not(), top_condition)
+
+ if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
+ msg = self.pop_format_context(ast.Str(explanation))
+
+ # Failed
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ gluestr = "\n>assert "
+ else:
+ assertmsg = ast.Str("")
+ gluestr = "assert "
+ err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
+ err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
+ err_name = ast.Name("AssertionError", ast.Load())
+ fmt = self.helper("_format_explanation", err_msg)
+ exc = ast.Call(err_name, [fmt], [])
+ raise_ = ast.Raise(exc, None)
+ statements_fail = []
+ statements_fail.extend(self.expl_stmts)
+ statements_fail.append(raise_)
+
+ # Passed
+ fmt_pass = self.helper("_format_explanation", msg)
+ orig = self._assert_expr_to_lineno()[assert_.lineno]
+ hook_call_pass = ast.Expr(
+ self.helper(
+ "_call_assertion_pass",
+ ast.Num(assert_.lineno),
+ ast.Str(orig),
+ fmt_pass,
+ )
+ )
+ # If any hooks implement assert_pass hook
+ hook_impl_test = ast.If(
+ self.helper("_check_if_assertion_pass_impl"),
+ self.expl_stmts + [hook_call_pass],
+ [],
+ )
+ statements_pass = [hook_impl_test]
+
+ # Test for assertion condition
+ main_test = ast.If(negation, statements_fail, statements_pass)
+ self.statements.append(main_test)
+ if self.format_variables:
+ variables = [
+ ast.Name(name, ast.Store()) for name in self.format_variables
+ ]
+ clear_format = ast.Assign(variables, ast.NameConstant(None))
+ self.statements.append(clear_format)
+
+ else: # Original assertion rewriting
+ # Create failure message.
+ body = self.expl_stmts
+ self.statements.append(ast.If(negation, body, []))
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ explanation = "\n>assert " + explanation
+ else:
+ assertmsg = ast.Str("")
+ explanation = "assert " + explanation
+ template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
+ msg = self.pop_format_context(template)
+ fmt = self.helper("_format_explanation", msg)
+ err_name = ast.Name("AssertionError", ast.Load())
+ exc = ast.Call(err_name, [fmt], [])
+ raise_ = ast.Raise(exc, None)
+
+ body.append(raise_)
+
+ # Clear temporary variables by setting them to None.
+ if self.variables:
+ variables = [ast.Name(name, ast.Store()) for name in self.variables]
+ clear = ast.Assign(variables, ast.NameConstant(None))
+ self.statements.append(clear)
+ # Fix line numbers.
+ for stmt in self.statements:
+ set_location(stmt, assert_.lineno, assert_.col_offset)
+ return self.statements
+
+ def visit_Name(self, name):
+ # Display the repr of the name if it's a local variable or
+ # _should_repr_global_name() thinks it's acceptable.
+ locs = ast.Call(self.builtin("locals"), [], [])
+ inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
+ dorepr = self.helper("_should_repr_global_name", name)
+ test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
+ expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
+ return name, self.explanation_param(expr)
+
+ def visit_BoolOp(self, boolop):
+ res_var = self.variable()
+ expl_list = self.assign(ast.List([], ast.Load()))
+ app = ast.Attribute(expl_list, "append", ast.Load())
+ is_or = int(isinstance(boolop.op, ast.Or))
+ body = save = self.statements
+ fail_save = self.expl_stmts
+ levels = len(boolop.values) - 1
+ self.push_format_context()
+ # Process each operand, short-circuiting if needed.
+ for i, v in enumerate(boolop.values):
+ if i:
+ fail_inner = [] # type: List[ast.stmt]
+ # cond is set in a prior loop iteration below
+ self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
+ self.expl_stmts = fail_inner
+ self.push_format_context()
+ res, expl = self.visit(v)
+ body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
+ expl_format = self.pop_format_context(ast.Str(expl))
+ call = ast.Call(app, [expl_format], [])
+ self.expl_stmts.append(ast.Expr(call))
+ if i < levels:
+ cond = res # type: ast.expr
+ if is_or:
+ cond = ast.UnaryOp(ast.Not(), cond)
+ inner = [] # type: List[ast.stmt]
+ self.statements.append(ast.If(cond, inner, []))
+ self.statements = body = inner
+ self.statements = save
+ self.expl_stmts = fail_save
+ expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
+ expl = self.pop_format_context(expl_template)
+ return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
+
+ def visit_UnaryOp(self, unary):
+ pattern = UNARY_MAP[unary.op.__class__]
+ operand_res, operand_expl = self.visit(unary.operand)
+ res = self.assign(ast.UnaryOp(unary.op, operand_res))
+ return res, pattern % (operand_expl,)
+
+ def visit_BinOp(self, binop):
+ symbol = BINOP_MAP[binop.op.__class__]
+ left_expr, left_expl = self.visit(binop.left)
+ right_expr, right_expl = self.visit(binop.right)
+ explanation = "({} {} {})".format(left_expl, symbol, right_expl)
+ res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
+ return res, explanation
+
+ def visit_Call(self, call):
+ """
+ visit `ast.Call` nodes
+ """
+ new_func, func_expl = self.visit(call.func)
+ arg_expls = []
+ new_args = []
+ new_kwargs = []
+ for arg in call.args:
+ res, expl = self.visit(arg)
+ arg_expls.append(expl)
+ new_args.append(res)
+ for keyword in call.keywords:
+ res, expl = self.visit(keyword.value)
+ new_kwargs.append(ast.keyword(keyword.arg, res))
+ if keyword.arg:
+ arg_expls.append(keyword.arg + "=" + expl)
+ else: # **args have `arg` keywords with an .arg of None
+ arg_expls.append("**" + expl)
+
+ expl = "{}({})".format(func_expl, ", ".join(arg_expls))
+ new_call = ast.Call(new_func, new_args, new_kwargs)
+ res = self.assign(new_call)
+ res_expl = self.explanation_param(self.display(res))
+ outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
+ return res, outer_expl
+
+ def visit_Starred(self, starred):
+ # From Python 3.5, a Starred node can appear in a function call
+ res, expl = self.visit(starred.value)
+ new_starred = ast.Starred(res, starred.ctx)
+ return new_starred, "*" + expl
+
+ def visit_Attribute(self, attr):
+ if not isinstance(attr.ctx, ast.Load):
+ return self.generic_visit(attr)
+ value, value_expl = self.visit(attr.value)
+ res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
+ res_expl = self.explanation_param(self.display(res))
+ pat = "%s\n{%s = %s.%s\n}"
+ expl = pat % (res_expl, res_expl, value_expl, attr.attr)
+ return res, expl
+
+ def visit_Compare(self, comp: ast.Compare):
+ self.push_format_context()
+ left_res, left_expl = self.visit(comp.left)
+ if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
+ left_expl = "({})".format(left_expl)
+ res_variables = [self.variable() for i in range(len(comp.ops))]
+ load_names = [ast.Name(v, ast.Load()) for v in res_variables]
+ store_names = [ast.Name(v, ast.Store()) for v in res_variables]
+ it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
+ expls = []
+ syms = []
+ results = [left_res]
+ for i, op, next_operand in it:
+ next_res, next_expl = self.visit(next_operand)
+ if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
+ next_expl = "({})".format(next_expl)
+ results.append(next_res)
+ sym = BINOP_MAP[op.__class__]
+ syms.append(ast.Str(sym))
+ expl = "{} {} {}".format(left_expl, sym, next_expl)
+ expls.append(ast.Str(expl))
+ res_expr = ast.Compare(left_res, [op], [next_res])
+ self.statements.append(ast.Assign([store_names[i]], res_expr))
+ left_res, left_expl = next_res, next_expl
+ # Use pytest.assertion.util._reprcompare if that's available.
+ expl_call = self.helper(
+ "_call_reprcompare",
+ ast.Tuple(syms, ast.Load()),
+ ast.Tuple(load_names, ast.Load()),
+ ast.Tuple(expls, ast.Load()),
+ ast.Tuple(results, ast.Load()),
+ )
+ if len(comp.ops) > 1:
+ res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
+ else:
+ res = load_names[0]
+ return res, self.explanation_param(self.pop_format_context(expl_call))
+
+
+def try_makedirs(cache_dir) -> bool:
+ """Attempts to create the given directory and sub-directories exist, returns True if
+ successful or it already exists"""
+ try:
+ os.makedirs(fspath(cache_dir), exist_ok=True)
+ except (FileNotFoundError, NotADirectoryError, FileExistsError):
+ # One of the path components was not a directory:
+ # - we're in a zip file
+ # - it is a file
+ return False
+ except PermissionError:
+ return False
+ except OSError as e:
+ # as of now, EROFS doesn't have an equivalent OSError-subclass
+ if e.errno == errno.EROFS:
+ return False
+ raise
+ return True
+
+
+def get_cache_dir(file_path: Path) -> Path:
+ """Returns the cache directory to write .pyc files for the given .py file path"""
+ if sys.version_info >= (3, 8) and sys.pycache_prefix:
+ # given:
+ # prefix = '/tmp/pycs'
+ # path = '/home/user/proj/test_app.py'
+ # we want:
+ # '/tmp/pycs/home/user/proj'
+ return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1])
+ else:
+ # classic pycache directory
+ return file_path.parent / "__pycache__"
diff --git a/contrib/python/pytest/py3/_pytest/assertion/truncate.py b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
new file mode 100644
index 0000000000..d97b05b441
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/assertion/truncate.py
@@ -0,0 +1,95 @@
+"""
+Utilities for truncating assertion output.
+
+Current default behaviour is to truncate assertion explanations at
+~8 terminal lines, unless running in "-vv" mode or running on CI.
+"""
+import os
+
+DEFAULT_MAX_LINES = 8
+DEFAULT_MAX_CHARS = 8 * 80
+USAGE_MSG = "use '-vv' to show"
+
+
+def truncate_if_required(explanation, item, max_length=None):
+ """
+ Truncate this assertion explanation if the given test item is eligible.
+ """
+ if _should_truncate_item(item):
+ return _truncate_explanation(explanation)
+ return explanation
+
+
+def _should_truncate_item(item):
+ """
+ Whether or not this test item is eligible for truncation.
+ """
+ verbose = item.config.option.verbose
+ return verbose < 2 and not _running_on_ci()
+
+
+def _running_on_ci():
+ """Check if we're currently running on a CI system."""
+ env_vars = ["CI", "BUILD_NUMBER"]
+ return any(var in os.environ for var in env_vars)
+
+
+def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
+ """
+ Truncate given list of strings that makes up the assertion explanation.
+
+ Truncates to either 8 lines, or 640 characters - whichever the input reaches
+ first. The remaining lines will be replaced by a usage message.
+ """
+
+ if max_lines is None:
+ max_lines = DEFAULT_MAX_LINES
+ if max_chars is None:
+ max_chars = DEFAULT_MAX_CHARS
+
+ # Check if truncation required
+ input_char_count = len("".join(input_lines))
+ if len(input_lines) <= max_lines and input_char_count <= max_chars:
+ return input_lines
+
+ # Truncate first to max_lines, and then truncate to max_chars if max_chars
+ # is exceeded.
+ truncated_explanation = input_lines[:max_lines]
+ truncated_explanation = _truncate_by_char_count(truncated_explanation, max_chars)
+
+ # Add ellipsis to final line
+ truncated_explanation[-1] = truncated_explanation[-1] + "..."
+
+ # Append useful message to explanation
+ truncated_line_count = len(input_lines) - len(truncated_explanation)
+ truncated_line_count += 1 # Account for the part-truncated final line
+ msg = "...Full output truncated"
+ if truncated_line_count == 1:
+ msg += " ({} line hidden)".format(truncated_line_count)
+ else:
+ msg += " ({} lines hidden)".format(truncated_line_count)
+ msg += ", {}".format(USAGE_MSG)
+ truncated_explanation.extend(["", str(msg)])
+ return truncated_explanation
+
+
+def _truncate_by_char_count(input_lines, max_chars):
+ # Check if truncation required
+ if len("".join(input_lines)) <= max_chars:
+ return input_lines
+
+ # Find point at which input length exceeds total allowed length
+ iterated_char_count = 0
+ for iterated_index, input_line in enumerate(input_lines):
+ if iterated_char_count + len(input_line) > max_chars:
+ break
+ iterated_char_count += len(input_line)
+
+ # Create truncated explanation with modified final line
+ truncated_result = input_lines[:iterated_index]
+ final_line = input_lines[iterated_index]
+ if final_line:
+ final_line_truncate_point = max_chars - iterated_char_count
+ final_line = final_line[:final_line_truncate_point]
+ truncated_result.append(final_line)
+ return truncated_result
diff --git a/contrib/python/pytest/py3/_pytest/assertion/util.py b/contrib/python/pytest/py3/_pytest/assertion/util.py
new file mode 100644
index 0000000000..7d525aa4c4
--- /dev/null
+++ b/contrib/python/pytest/py3/_pytest/assertion/util.py
@@ -0,0 +1,463 @@
+"""Utilities for assertion debugging"""
+import collections.abc
+import pprint
+from typing import AbstractSet
+from typing import Any
+from typing import Callable
+from typing import Iterable
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
+import _pytest._code
+from _pytest import outcomes
+from _pytest._io.saferepr import _pformat_dispatch
+from _pytest._io.saferepr import safeformat
+from _pytest._io.saferepr import saferepr
+from _pytest.compat import ATTRS_EQ_FIELD
+
+# The _reprcompare attribute on the util module is used by the new assertion
+# interpretation code and assertion rewriter to detect this plugin was
+# loaded and in turn call the hooks defined here as part of the
+# DebugInterpreter.
+_reprcompare = None # type: Optional[Callable[[str, object, object], Optional[str]]]
+
+# Works similarly as _reprcompare attribute. Is populated with the hook call
+# when pytest_runtest_setup is called.
+_assertion_pass = None # type: Optional[Callable[[int, str, str], None]]
+
+
+def format_explanation(explanation: str) -> str:
+ """This formats an explanation
+
+ Normally all embedded newlines are escaped, however there are
+ three exceptions: \n{, \n} and \n~. The first two are intended
+ cover nested explanations, see function and attribute explanations
+ for examples (.visit_Call(), visit_Attribute()). The last one is
+ for when one explanation needs to span multiple lines, e.g. when
+ displaying diffs.
+ """
+ lines = _split_explanation(explanation)
+ result = _format_lines(lines)
+ return "\n".join(result)
+
+
+def _split_explanation(explanation: str) -> List[str]:
+ """Return a list of individual lines in the explanation
+
+ This will return a list of lines split on '\n{', '\n}' and '\n~'.
+ Any other newlines will be escaped and appear in the line as the
+ literal '\n' characters.
+ """
+ raw_lines = (explanation or "").split("\n")
+ lines = [raw_lines[0]]
+ for values in raw_lines[1:]:
+ if values and values[0] in ["{", "}", "~", ">"]:
+ lines.append(values)
+ else:
+ lines[-1] += "\\n" + values
+ return lines
+
+
+def _format_lines(lines: Sequence[str]) -> List[str]:
+ """Format the individual lines
+
+ This will replace the '{', '}' and '~' characters of our mini
+ formatting language with the proper 'where ...', 'and ...' and ' +
+ ...' text, taking care of indentation along the way.
+
+ Return a list of formatted lines.
+ """
+ result = list(lines[:1])
+ stack = [0]
+ stackcnt = [0]
+ for line in lines[1:]:
+ if line.startswith("{"):
+ if stackcnt[-1]:
+ s = "and "
+ else:
+ s = "where "
+ stack.append(len(result))
+ stackcnt[-1] += 1
+ stackcnt.append(0)
+ result.append(" +" + " " * (len(stack) - 1) + s + line[1:])
+ elif line.startswith("}"):
+ stack.pop()
+ stackcnt.pop()
+ result[stack[-1]] += line[1:]
+ else:
+ assert line[0] in ["~", ">"]
+ stack[-1] += 1
+ indent = len(stack) if line.startswith("~") else len(stack) - 1
+ result.append(" " * indent + line[1:])
+ assert len(stack) == 1
+ return result
+
+
+def issequence(x: Any) -> bool:
+ return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
+
+
+def istext(x: Any) -> bool:
+ return isinstance(x, str)
+
+
+def isdict(x: Any) -> bool:
+ return isinstance(x, dict)
+
+
+def isset(x: Any) -> bool:
+ return isinstance(x, (set, frozenset))
+
+
+def isdatacls(obj: Any) -> bool:
+ return getattr(obj, "__dataclass_fields__", None) is not None
+
+
+def isattrs(obj: Any) -> bool:
+ return getattr(obj, "__attrs_attrs__", None) is not None
+
+
+def isiterable(obj: Any) -> bool:
+ try:
+ iter(obj)
+ return not istext(obj)
+ except TypeError:
+ return False
+
+
+def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
+ """Return specialised explanations for some operators/operands"""
+ verbose = config.getoption("verbose")
+ if verbose > 1:
+ left_repr = safeformat(left)
+ right_repr = safeformat(right)
+ else:
+ # XXX: "15 chars indentation" is wrong
+ # ("E AssertionError: assert "); should use term width.
+ maxsize = (
+ 80 - 15 - len(op) - 2
+ ) // 2 # 15 chars indentation, 1 space around op
+ left_repr = saferepr(left, maxsize=maxsize)
+ right_repr = saferepr(right, maxsize=maxsize)
+
+ summary = "{} {} {}".format(left_repr, op, right_repr)
+
+ explanation = None
+ try:
+ if op == "==":
+ if istext(left) and istext(right):
+ explanation = _diff_text(left, right, verbose)
+ else:
+ if issequence(left) and issequence(right):
+ explanation = _compare_eq_sequence(left, right, verbose)
+ elif isset(left) and isset(right):
+ explanation = _compare_eq_set(left, right, verbose)
+ elif isdict(left) and isdict(right):
+ explanation = _compare_eq_dict(left, right, verbose)
+ elif type(left) == type(right) and (isdatacls(left) or isattrs(left)):
+ type_fn = (isdatacls, isattrs)
+ explanation = _compare_eq_cls(left, right, verbose, type_fn)
+ elif verbose > 0:
+ explanation = _compare_eq_verbose(left, right)
+ if isiterable(left) and isiterable(right):
+ expl = _compare_eq_iterable(left, right, verbose)
+ if explanation is not None:
+ explanation.extend(expl)
+ else:
+ explanation = expl
+ elif op == "not in":
+ if istext(left) and istext(right):
+ explanation = _notin_text(left, right, verbose)
+ except outcomes.Exit:
+ raise
+ except Exception:
+ explanation = [
+ "(pytest_assertion plugin: representation of details failed: {}.".format(
+ _pytest._code.ExceptionInfo.from_current()._getreprcrash()
+ ),
+ " Probably an object has a faulty __repr__.)",
+ ]
+
+ if not explanation:
+ return None
+
+ return [summary] + explanation
+
+
+def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
+ """Return the explanation for the diff between text.
+
+ Unless --verbose is used this will skip leading and trailing
+ characters which are identical to keep the diff minimal.
+ """
+ from difflib import ndiff
+
+ explanation = [] # type: List[str]
+
+ if verbose < 1:
+ i = 0 # just in case left or right has zero length
+ for i in range(min(len(left), len(right))):
+ if left[i] != right[i]:
+ break
+ if i > 42:
+ i -= 10 # Provide some context
+ explanation = [
+ "Skipping %s identical leading characters in diff, use -v to show" % i
+ ]
+ left = left[i:]
+ right = right[i:]
+ if len(left) == len(right):
+ for i in range(len(left)):
+ if left[-i] != right[-i]:
+ break
+ if i > 42:
+ i -= 10 # Provide some context
+ explanation += [
+ "Skipping {} identical trailing "
+ "characters in diff, use -v to show".format(i)
+ ]
+ left = left[:-i]
+ right = right[:-i]
+ keepends = True
+ if left.isspace() or right.isspace():
+ left = repr(str(left))
+ right = repr(str(right))
+ explanation += ["Strings contain only whitespace, escaping them using repr()"]
+ # "right" is the expected base against which we compare "left",
+ # see https://github.com/pytest-dev/pytest/issues/3333
+ explanation += [
+ line.strip("\n")
+ for line in ndiff(right.splitlines(keepends), left.splitlines(keepends))
+ ]
+ return explanation
+
+
+def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
+ keepends = True
+ left_lines = repr(left).splitlines(keepends)
+ right_lines = repr(right).splitlines(keepends)
+
+ explanation = [] # type: List[str]
+ explanation += ["+" + line for line in left_lines]
+ explanation += ["-" + line for line in right_lines]
+
+ return explanation
+
+
+def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
+ """Move opening/closing parenthesis/bracket to own lines."""
+ opening = lines[0][:1]
+ if opening in ["(", "[", "{"]:
+ lines[0] = " " + lines[0][1:]
+ lines[:] = [opening] + lines
+ closing = lines[-1][-1:]
+ if closing in [")", "]", "}"]:
+ lines[-1] = lines[-1][:-1] + ","
+ lines[:] = lines + [closing]
+
+
+def _compare_eq_iterable(
+ left: Iterable[Any], right: Iterable[Any], verbose: int = 0
+) -> List[str]:
+ if not verbose:
+ return ["Use -v to get the full diff"]
+ # dynamic import to speedup pytest
+ import difflib
+
+ left_formatting = pprint.pformat(left).splitlines()
+ right_formatting = pprint.pformat(right).splitlines()
+
+ # Re-format for different output lengths.
+ lines_left = len(left_formatting)
+ lines_right = len(right_formatting)
+ if lines_left != lines_right:
+ left_formatting = _pformat_dispatch(left).splitlines()
+ right_formatting = _pformat_dispatch(right).splitlines()
+
+ if lines_left > 1 or lines_right > 1:
+ _surrounding_parens_on_own_lines(left_formatting)
+ _surrounding_parens_on_own_lines(right_formatting)
+
+ explanation = ["Full diff:"]
+ # "right" is the expected base against which we compare "left",
+ # see https://github.com/pytest-dev/pytest/issues/3333
+ explanation.extend(
+ line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting)
+ )
+ return explanation
+
+
+def _compare_eq_sequence(
+ left: Sequence[Any], right: Sequence[Any], verbose: int = 0
+) -> List[str]:
+ comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
+ explanation = [] # type: List[str]
+ len_left = len(left)
+ len_right = len(right)
+ for i in range(min(len_left, len_right)):
+ if left[i] != right[i]:
+ if comparing_bytes:
+ # when comparing bytes, we want to see their ascii representation
+ # instead of their numeric values (#5260)
+ # using a slice gives us the ascii representation:
+ # >>> s = b'foo'
+ # >>> s[0]
+ # 102
+ # >>> s[0:1]
+ # b'f'
+ left_value = left[i : i + 1]
+ right_value = right[i : i + 1]
+ else:
+ left_value = left[i]
+ right_value = right[i]
+
+ explanation += [
+ "At index {} diff: {!r} != {!r}".format(i, left_value, right_value)
+ ]
+ break
+
+ if comparing_bytes:
+ # when comparing bytes, it doesn't help to show the "sides contain one or more
+ # items" longer explanation, so skip it
+
+ return explanation
+
+ len_diff = len_left - len_right
+ if len_diff:
+ if len_diff > 0:
+ dir_with_more = "Left"
+ extra = saferepr(left[len_right])
+ else:
+ len_diff = 0 - len_diff
+ dir_with_more = "Right"
+ extra = saferepr(right[len_left])
+
+ if len_diff == 1:
+ explanation += [
+ "{} contains one more item: {}".format(dir_with_more, extra)
+ ]
+ else:
+ explanation += [
+ "%s contains %d more items, first extra item: %s"
+ % (dir_with_more, len_diff, extra)
+ ]
+ return explanation
+
+
+def _compare_eq_set(
+ left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
+) -> List[str]:
+ explanation = []
+ diff_left = left - right
+ diff_right = right - left
+ if diff_left:
+ explanation.append("Extra items in the left set:")
+ for item in diff_left:
+ explanation.append(saferepr(item))
+ if diff_right:
+ explanation.append("Extra items in the right set:")
+ for item in diff_right:
+ explanation.append(saferepr(item))
+ return explanation
+
+
+def _compare_eq_dict(
+ left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
+) -> List[str]:
+ explanation = [] # type: List[str]
+ set_left = set(left)
+ set_right = set(right)
+ common = set_left.intersection(set_right)
+ same = {k: left[k] for k in common if left[k] == right[k]}
+ if same and verbose < 2:
+ explanation += ["Omitting %s identical items, use -vv to show" % len(same)]
+ elif same:
+ explanation += ["Common items:"]
+ explanation += pprint.pformat(same).splitlines()
+ diff = {k for k in common if left[k] != right[k]}
+ if diff:
+ explanation += ["Differing items:"]
+ for k in diff:
+ explanation += [saferepr({k: left[k]}) + " != " + saferepr({k: right[k]})]
+ extra_left = set_left - set_right
+ len_extra_left = len(extra_left)
+ if len_extra_left:
+ explanation.append(
+ "Left contains %d more item%s:"
+ % (len_extra_left, "" if len_extra_left == 1 else "s")
+ )
+ explanation.extend(
+ pprint.pformat({k: left[k] for k in extra_left}).splitlines()
+ )
+ extra_right = set_right - set_left
+ len_extra_right = len(extra_right)
+ if len_extra_right:
+ explanation.append(
+ "Right contains %d more item%s:"
+ % (len_extra_right, "" if len_extra_right == 1 else "s")
+ )
+ explanation.extend(
+ pprint.pformat({k: right[k] for k in extra_right}).splitlines()
+ )
+ return explanation
+
+
+def _compare_eq_cls(
+ left: Any,
+ right: Any,
+ verbose: int,
+ type_fns: Tuple[Callable[[Any], bool], Callable[[Any], bool]],
+) -> List[str]:
+ isdatacls, isattrs = type_fns
+ if isdatacls(left):
+ all_fields = left.__dataclass_fields__
+ fields_to_check = [field for field, info in all_fields.items() if info.compare]
+ elif isattrs(left):
+ all_fields = left.__attrs_attrs__
+ fields_to_check = [
+ field.name for field in all_fields if getattr(field, ATTRS_EQ_FIELD)
+ ]
+
+ same = []
+ diff = []
+ for field in fields_to_check:
+ if getattr(left, field) == getattr(right, field):
+ same.append(field)
+ else:
+ diff.append(field)
+
+ explanation = []
+ if same and verbose < 2:
+ explanation.append("Omitting %s identical items, use -vv to show" % len(same))
+ elif same:
+ explanation += ["Matching attributes:"]
+ explanation += pprint.pformat(same).splitlines()
+ if diff:
+ explanation += ["Differing attributes:"]
+ for field in diff:
+ explanation += [
+ ("%s: %r != %r") % (field, getattr(left, field), getattr(right, field))
+ ]
+ return explanation
+
+
+def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
+ index = text.find(term)
+ head = text[:index]
+ tail = text[index + len(term) :]
+ correct_text = head + tail
+ diff = _diff_text(text, correct_text, verbose)
+ newdiff = ["%s is contained here:" % saferepr(term, maxsize=42)]
+ for line in diff:
+ if line.startswith("Skipping"):
+ continue
+ if line.startswith("- "):
+ continue
+ if line.startswith("+ "):
+ newdiff.append(" " + line[2:])
+ else:
+ newdiff.append(line)
+ return newdiff