aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pytest/py2/_pytest/assertion/rewrite.py
diff options
context:
space:
mode:
authornkozlovskiy <nmk@ydb.tech>2023-09-29 12:24:06 +0300
committernkozlovskiy <nmk@ydb.tech>2023-09-29 12:41:34 +0300
commite0e3e1717e3d33762ce61950504f9637a6e669ed (patch)
treebca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/pytest/py2/_pytest/assertion/rewrite.py
parent38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff)
downloadydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz
add ydb deps
Diffstat (limited to 'contrib/python/pytest/py2/_pytest/assertion/rewrite.py')
-rw-r--r--contrib/python/pytest/py2/_pytest/assertion/rewrite.py1072
1 files changed, 1072 insertions, 0 deletions
diff --git a/contrib/python/pytest/py2/_pytest/assertion/rewrite.py b/contrib/python/pytest/py2/_pytest/assertion/rewrite.py
new file mode 100644
index 0000000000..6cfd81a32f
--- /dev/null
+++ b/contrib/python/pytest/py2/_pytest/assertion/rewrite.py
@@ -0,0 +1,1072 @@
+# -*- coding: utf-8 -*-
+"""Rewrite assertion AST to produce nice error messages"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+warnings.filterwarnings("ignore", category=DeprecationWarning, module="_pytest.assertion.rewrite")
+
+import ast
+import errno
+import imp
+import itertools
+import marshal
+import os
+import re
+import string
+import struct
+import sys
+import types
+
+import atomicwrites
+import py
+import six
+
+from _pytest._io.saferepr import saferepr
+from _pytest.assertion import util
+from _pytest.assertion.util import ( # noqa: F401
+ format_explanation as _format_explanation,
+)
+from _pytest.compat import spec_from_file_location
+from _pytest.pathlib import fnmatch_ex
+from _pytest.pathlib import PurePath
+
+# pytest caches rewritten pycs in __pycache__.
+if hasattr(imp, "get_tag"):
+ PYTEST_TAG = imp.get_tag() + "-PYTEST"
+else:
+ if hasattr(sys, "pypy_version_info"):
+ impl = "pypy"
+ elif sys.platform == "java":
+ impl = "jython"
+ else:
+ impl = "cpython"
+ ver = sys.version_info
+ PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
+ del ver, impl
+
+PYC_EXT = ".py" + (__debug__ and "c" or "o")
+PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
+
+ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3
+
+if sys.version_info >= (3, 5):
+ ast_Call = ast.Call
+else:
+
+ def ast_Call(a, b, c):
+ return ast.Call(a, b, c, None, None)
+
+
+class AssertionRewritingHook(object):
+ """PEP302 Import hook which rewrites asserts."""
+
+ def __init__(self, config):
+ self.config = config
+ try:
+ self.fnpats = config.getini("python_files")
+ except ValueError:
+ self.fnpats = ["test_*.py", "*_test.py"]
+ self.session = None
+ self.modules = {}
+ self._rewritten_names = set()
+ self._must_rewrite = set()
+ # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
+ # which might result in infinite recursion (#3506)
+ self._writing_pyc = False
+ self._basenames_to_check_rewrite = {"conftest"}
+ self._marked_for_rewrite_cache = {}
+ self._session_paths_checked = False
+
+ def set_session(self, session):
+ self.session = session
+ self._session_paths_checked = False
+
+ def _imp_find_module(self, name, path=None):
+ """Indirection so we can mock calls to find_module originated from the hook during testing"""
+ return imp.find_module(name, path)
+
+ def find_module(self, name, path=None):
+ if self._writing_pyc:
+ return None
+ state = self.config._assertstate
+ if self._early_rewrite_bailout(name, state):
+ return None
+ state.trace("find_module called for: %s" % name)
+ names = name.rsplit(".", 1)
+ lastname = names[-1]
+ pth = None
+ if path is not None:
+ # Starting with Python 3.3, path is a _NamespacePath(), which
+ # causes problems if not converted to list.
+ path = list(path)
+ if len(path) == 1:
+ pth = path[0]
+ if pth is None:
+ try:
+ fd, fn, desc = self._imp_find_module(lastname, path)
+ except ImportError:
+ return None
+ if fd is not None:
+ fd.close()
+ tp = desc[2]
+ if tp == imp.PY_COMPILED:
+ if hasattr(imp, "source_from_cache"):
+ try:
+ fn = imp.source_from_cache(fn)
+ except ValueError:
+ # Python 3 doesn't like orphaned but still-importable
+ # .pyc files.
+ fn = fn[:-1]
+ else:
+ fn = fn[:-1]
+ elif tp != imp.PY_SOURCE:
+ # Don't know what this is.
+ return None
+ else:
+ fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
+
+ fn_pypath = py.path.local(fn)
+ if not self._should_rewrite(name, fn_pypath, state):
+ return None
+
+ self._rewritten_names.add(name)
+
+ # The requested module looks like a test file, so rewrite it. This is
+ # the most magical part of the process: load the source, rewrite the
+ # asserts, and load the rewritten source. We also cache the rewritten
+ # module code in a special pyc. We must be aware of the possibility of
+ # concurrent pytest processes rewriting and loading pycs. To avoid
+ # tricky race conditions, we maintain the following invariant: The
+ # cached pyc is always a complete, valid pyc. Operations on it must be
+ # atomic. POSIX's atomic rename comes in handy.
+ write = not sys.dont_write_bytecode
+ cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
+ if write:
+ try:
+ os.mkdir(cache_dir)
+ except OSError:
+ e = sys.exc_info()[1].errno
+ if e == errno.EEXIST:
+ # Either the __pycache__ directory already exists (the
+ # common case) or it's blocked by a non-dir node. In the
+ # latter case, we'll ignore it in _write_pyc.
+ pass
+ elif e in [errno.ENOENT, errno.ENOTDIR]:
+ # One of the path components was not a directory, likely
+ # because we're in a zip file.
+ write = False
+ elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
+ state.trace("read only directory: %r" % fn_pypath.dirname)
+ write = False
+ else:
+ raise
+ cache_name = fn_pypath.basename[:-3] + PYC_TAIL
+ pyc = os.path.join(cache_dir, cache_name)
+ # Notice that even if we're in a read-only directory, I'm going
+ # to check for a cached pyc. This may not be optimal...
+ co = _read_pyc(fn_pypath, pyc, state.trace)
+ if co is None:
+ state.trace("rewriting %r" % (fn,))
+ source_stat, co = _rewrite_test(self.config, fn_pypath)
+ if co is None:
+ # Probably a SyntaxError in the test.
+ return None
+ if write:
+ self._writing_pyc = True
+ try:
+ _write_pyc(state, co, source_stat, pyc)
+ finally:
+ self._writing_pyc = False
+ else:
+ state.trace("found cached rewritten pyc for %r" % (fn,))
+ self.modules[name] = co, pyc
+ return self
+
+ def _early_rewrite_bailout(self, name, state):
+ """
+ This is a fast way to get out of rewriting modules. Profiling has
+ shown that the call to imp.find_module (inside of the find_module
+ from this class) is a major slowdown, so, this method tries to
+ filter what we're sure won't be rewritten before getting to it.
+ """
+ if self.session is not None and not self._session_paths_checked:
+ self._session_paths_checked = True
+ for path in self.session._initialpaths:
+ # Make something as c:/projects/my_project/path.py ->
+ # ['c:', 'projects', 'my_project', 'path.py']
+ parts = str(path).split(os.path.sep)
+ # add 'path' to basenames to be checked.
+ self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
+
+ # Note: conftest already by default in _basenames_to_check_rewrite.
+ parts = name.split(".")
+ if parts[-1] in self._basenames_to_check_rewrite:
+ return False
+
+ # For matching the name it must be as if it was a filename.
+ path = PurePath(os.path.sep.join(parts) + ".py")
+
+ for pat in self.fnpats:
+ # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based
+ # on the name alone because we need to match against the full path
+ if os.path.dirname(pat):
+ return False
+ if fnmatch_ex(pat, path):
+ return False
+
+ if self._is_marked_for_rewrite(name, state):
+ return False
+
+ state.trace("early skip of rewriting module: %s" % (name,))
+ return True
+
+ def _should_rewrite(self, name, fn_pypath, state):
+ # always rewrite conftest files
+ fn = str(fn_pypath)
+ if fn_pypath.basename == "conftest.py":
+ state.trace("rewriting conftest file: %r" % (fn,))
+ return True
+
+ if self.session is not None:
+ if self.session.isinitpath(fn):
+ state.trace("matched test file (was specified on cmdline): %r" % (fn,))
+ return True
+
+ # modules not passed explicitly on the command line are only
+ # rewritten if they match the naming convention for test files
+ for pat in self.fnpats:
+ if fn_pypath.fnmatch(pat):
+ state.trace("matched test file %r" % (fn,))
+ return True
+
+ return self._is_marked_for_rewrite(name, state)
+
+ def _is_marked_for_rewrite(self, name, state):
+ try:
+ return self._marked_for_rewrite_cache[name]
+ except KeyError:
+ for marked in self._must_rewrite:
+ if name == marked or name.startswith(marked + "."):
+ state.trace("matched marked file %r (from %r)" % (name, marked))
+ self._marked_for_rewrite_cache[name] = True
+ return True
+
+ self._marked_for_rewrite_cache[name] = False
+ return False
+
+ def mark_rewrite(self, *names):
+ """Mark import names as needing to be rewritten.
+
+ The named module or package as well as any nested modules will
+ be rewritten on import.
+ """
+ already_imported = (
+ set(names).intersection(sys.modules).difference(self._rewritten_names)
+ )
+ for name in already_imported:
+ if not AssertionRewriter.is_rewrite_disabled(
+ sys.modules[name].__doc__ or ""
+ ):
+ self._warn_already_imported(name)
+ self._must_rewrite.update(names)
+ self._marked_for_rewrite_cache.clear()
+
+ def _warn_already_imported(self, name):
+ from _pytest.warning_types import PytestAssertRewriteWarning
+ from _pytest.warnings import _issue_warning_captured
+
+ _issue_warning_captured(
+ PytestAssertRewriteWarning(
+ "Module already imported so cannot be rewritten: %s" % name
+ ),
+ self.config.hook,
+ stacklevel=5,
+ )
+
+ def load_module(self, name):
+ co, pyc = self.modules.pop(name)
+ if name in sys.modules:
+ # If there is an existing module object named 'fullname' in
+ # sys.modules, the loader must use that existing module. (Otherwise,
+ # the reload() builtin will not work correctly.)
+ mod = sys.modules[name]
+ else:
+ # I wish I could just call imp.load_compiled here, but __file__ has to
+ # be set properly. In Python 3.2+, this all would be handled correctly
+ # by load_compiled.
+ mod = sys.modules[name] = imp.new_module(name)
+ try:
+ mod.__file__ = co.co_filename
+ # Normally, this attribute is 3.2+.
+ mod.__cached__ = pyc
+ mod.__loader__ = self
+ # Normally, this attribute is 3.4+
+ mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
+ exec(co, mod.__dict__)
+ except: # noqa
+ if name in sys.modules:
+ del sys.modules[name]
+ raise
+ return sys.modules[name]
+
+ def is_package(self, name):
+ try:
+ fd, fn, desc = self._imp_find_module(name)
+ except ImportError:
+ return False
+ if fd is not None:
+ fd.close()
+ tp = desc[2]
+ return tp == imp.PKG_DIRECTORY
+
+ def get_data(self, pathname):
+ """Optional PEP302 get_data API.
+ """
+ with open(pathname, "rb") as f:
+ return f.read()
+
+
+def _write_pyc(state, co, source_stat, pyc):
+ # Technically, we don't have to have the same pyc format as
+ # (C)Python, since these "pycs" should never be seen by builtin
+ # import. However, there's little reason deviate, and I hope
+ # sometime to be able to use imp.load_compiled to load them. (See
+ # the comment in load_module above.)
+ try:
+ with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
+ fp.write(imp.get_magic())
+ # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
+ mtime = int(source_stat.mtime) & 0xFFFFFFFF
+ size = source_stat.size & 0xFFFFFFFF
+ # "<LL" stands for 2 unsigned longs, little-ending
+ fp.write(struct.pack("<LL", mtime, size))
+ fp.write(marshal.dumps(co))
+ except EnvironmentError as e:
+ state.trace("error writing pyc file at %s: errno=%s" % (pyc, e.errno))
+ # we ignore any failure to write the cache file
+ # there are many reasons, permission-denied, __pycache__ being a
+ # file etc.
+ return False
+ return True
+
+
+RN = "\r\n".encode("utf-8")
+N = "\n".encode("utf-8")
+
+cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
+BOM_UTF8 = "\xef\xbb\xbf"
+
+
+def _rewrite_test(config, fn):
+ """Try to read and rewrite *fn* and return the code object."""
+ state = config._assertstate
+ try:
+ stat = fn.stat()
+ source = fn.read("rb")
+ except EnvironmentError:
+ return None, None
+ if ASCII_IS_DEFAULT_ENCODING:
+ # ASCII is the default encoding in Python 2. Without a coding
+ # declaration, Python 2 will complain about any bytes in the file
+ # outside the ASCII range. Sadly, this behavior does not extend to
+ # compile() or ast.parse(), which prefer to interpret the bytes as
+ # latin-1. (At least they properly handle explicit coding cookies.) To
+ # preserve this error behavior, we could force ast.parse() to use ASCII
+ # as the encoding by inserting a coding cookie. Unfortunately, that
+ # messes up line numbers. Thus, we have to check ourselves if anything
+ # is outside the ASCII range in the case no encoding is explicitly
+ # declared. For more context, see issue #269. Yay for Python 3 which
+ # gets this right.
+ end1 = source.find("\n")
+ end2 = source.find("\n", end1 + 1)
+ if (
+ not source.startswith(BOM_UTF8)
+ and cookie_re.match(source[0:end1]) is None
+ and cookie_re.match(source[end1 + 1 : end2]) is None
+ ):
+ if hasattr(state, "_indecode"):
+ # encodings imported us again, so don't rewrite.
+ return None, None
+ state._indecode = True
+ try:
+ try:
+ source.decode("ascii")
+ except UnicodeDecodeError:
+ # Let it fail in real import.
+ return None, None
+ finally:
+ del state._indecode
+ try:
+ tree = ast.parse(source, filename=fn.strpath)
+ except SyntaxError:
+ # Let this pop up again in the real import.
+ state.trace("failed to parse: %r" % (fn,))
+ return None, None
+ rewrite_asserts(tree, fn, config)
+ try:
+ co = compile(tree, fn.strpath, "exec", dont_inherit=True)
+ except SyntaxError:
+ # It's possible that this error is from some bug in the
+ # assertion rewriting, but I don't know of a fast way to tell.
+ state.trace("failed to compile: %r" % (fn,))
+ return None, None
+ return stat, co
+
+
+def _read_pyc(source, pyc, trace=lambda x: None):
+ """Possibly read a pytest pyc containing rewritten code.
+
+ Return rewritten code if successful or None if not.
+ """
+ try:
+ fp = open(pyc, "rb")
+ except IOError:
+ return None
+ with fp:
+ try:
+ mtime = int(source.mtime())
+ size = source.size()
+ data = fp.read(12)
+ except EnvironmentError as e:
+ trace("_read_pyc(%s): EnvironmentError %s" % (source, e))
+ return None
+ # Check for invalid or out of date pyc file.
+ if (
+ len(data) != 12
+ or data[:4] != imp.get_magic()
+ or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
+ ):
+ trace("_read_pyc(%s): invalid or out of date pyc" % source)
+ return None
+ try:
+ co = marshal.load(fp)
+ except Exception as e:
+ trace("_read_pyc(%s): marshal.load error %s" % (source, e))
+ return None
+ if not isinstance(co, types.CodeType):
+ trace("_read_pyc(%s): not a code object" % source)
+ return None
+ return co
+
+
+def rewrite_asserts(mod, module_path=None, config=None):
+ """Rewrite the assert statements in mod."""
+ AssertionRewriter(module_path, config).run(mod)
+
+
+def _saferepr(obj):
+ """Get a safe repr of an object for assertion error messages.
+
+ The assertion formatting (util.format_explanation()) requires
+ newlines to be escaped since they are a special character for it.
+ Normally assertion.util.format_explanation() does this but for a
+ custom repr it is possible to contain one of the special escape
+ sequences, especially '\n{' and '\n}' are likely to be present in
+ JSON reprs.
+
+ """
+ r = saferepr(obj)
+ # only occurs in python2.x, repr must return text in python3+
+ if isinstance(r, bytes):
+ # Represent unprintable bytes as `\x##`
+ r = u"".join(
+ u"\\x{:x}".format(ord(c)) if c not in string.printable else c.decode()
+ for c in r
+ )
+ return r.replace(u"\n", u"\\n")
+
+
+def _format_assertmsg(obj):
+ """Format the custom assertion message given.
+
+ For strings this simply replaces newlines with '\n~' so that
+ util.format_explanation() will preserve them instead of escaping
+ newlines. For other objects saferepr() is used first.
+
+ """
+ # reprlib appears to have a bug which means that if a string
+ # contains a newline it gets escaped, however if an object has a
+ # .__repr__() which contains newlines it does not get escaped.
+ # However in either case we want to preserve the newline.
+ replaces = [(u"\n", u"\n~"), (u"%", u"%%")]
+ if not isinstance(obj, six.string_types):
+ obj = saferepr(obj)
+ replaces.append((u"\\n", u"\n~"))
+
+ if isinstance(obj, bytes):
+ replaces = [(r1.encode(), r2.encode()) for r1, r2 in replaces]
+
+ for r1, r2 in replaces:
+ obj = obj.replace(r1, r2)
+
+ return obj
+
+
+def _should_repr_global_name(obj):
+ if callable(obj):
+ return False
+
+ try:
+ return not hasattr(obj, "__name__")
+ except Exception:
+ return True
+
+
+def _format_boolop(explanations, is_or):
+ explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
+ if isinstance(explanation, six.text_type):
+ return explanation.replace(u"%", u"%%")
+ else:
+ return explanation.replace(b"%", b"%%")
+
+
+def _call_reprcompare(ops, results, expls, each_obj):
+ for i, res, expl in zip(range(len(ops)), results, expls):
+ try:
+ done = not res
+ except Exception:
+ done = True
+ if done:
+ break
+ if util._reprcompare is not None:
+ custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
+ if custom is not None:
+ return custom
+ return expl
+
+
+unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
+
+binop_map = {
+ ast.BitOr: "|",
+ ast.BitXor: "^",
+ ast.BitAnd: "&",
+ ast.LShift: "<<",
+ ast.RShift: ">>",
+ ast.Add: "+",
+ ast.Sub: "-",
+ ast.Mult: "*",
+ ast.Div: "/",
+ ast.FloorDiv: "//",
+ ast.Mod: "%%", # escaped for string formatting
+ ast.Eq: "==",
+ ast.NotEq: "!=",
+ ast.Lt: "<",
+ ast.LtE: "<=",
+ ast.Gt: ">",
+ ast.GtE: ">=",
+ ast.Pow: "**",
+ ast.Is: "is",
+ ast.IsNot: "is not",
+ ast.In: "in",
+ ast.NotIn: "not in",
+}
+# Python 3.5+ compatibility
+try:
+ binop_map[ast.MatMult] = "@"
+except AttributeError:
+ pass
+
+# Python 3.4+ compatibility
+if hasattr(ast, "NameConstant"):
+ _NameConstant = ast.NameConstant
+else:
+
+ def _NameConstant(c):
+ return ast.Name(str(c), ast.Load())
+
+
+def set_location(node, lineno, col_offset):
+ """Set node location information recursively."""
+
+ def _fix(node, lineno, col_offset):
+ if "lineno" in node._attributes:
+ node.lineno = lineno
+ if "col_offset" in node._attributes:
+ node.col_offset = col_offset
+ for child in ast.iter_child_nodes(node):
+ _fix(child, lineno, col_offset)
+
+ _fix(node, lineno, col_offset)
+ return node
+
+
+class AssertionRewriter(ast.NodeVisitor):
+ """Assertion rewriting implementation.
+
+ The main entrypoint is to call .run() with an ast.Module instance,
+ this will then find all the assert statements and rewrite them to
+ provide intermediate values and a detailed assertion error. See
+ http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
+ for an overview of how this works.
+
+ The entry point here is .run() which will iterate over all the
+ statements in an ast.Module and for each ast.Assert statement it
+ finds call .visit() with it. Then .visit_Assert() takes over and
+ is responsible for creating new ast statements to replace the
+ original assert statement: it rewrites the test of an assertion
+ to provide intermediate values and replace it with an if statement
+ which raises an assertion error with a detailed explanation in
+ case the expression is false.
+
+ For this .visit_Assert() uses the visitor pattern to visit all the
+ AST nodes of the ast.Assert.test field, each visit call returning
+ an AST node and the corresponding explanation string. During this
+ state is kept in several instance attributes:
+
+ :statements: All the AST statements which will replace the assert
+ statement.
+
+ :variables: This is populated by .variable() with each variable
+ used by the statements so that they can all be set to None at
+ the end of the statements.
+
+ :variable_counter: Counter to create new unique variables needed
+ by statements. Variables are created using .variable() and
+ have the form of "@py_assert0".
+
+ :on_failure: The AST statements which will be executed if the
+ assertion test fails. This is the code which will construct
+ the failure message and raises the AssertionError.
+
+ :explanation_specifiers: A dict filled by .explanation_param()
+ with %-formatting placeholders and their corresponding
+ expressions to use in the building of an assertion message.
+ This is used by .pop_format_context() to build a message.
+
+ :stack: A stack of the explanation_specifiers dicts maintained by
+ .push_format_context() and .pop_format_context() which allows
+ to build another %-formatted string while already building one.
+
+ This state is reset on every new assert statement visited and used
+ by the other visitors.
+
+ """
+
+ def __init__(self, module_path, config):
+ super(AssertionRewriter, self).__init__()
+ self.module_path = module_path
+ self.config = config
+
+ def run(self, mod):
+ """Find all assert statements in *mod* and rewrite them."""
+ if not mod.body:
+ # Nothing to do.
+ return
+ # Insert some special imports at the top of the module but after any
+ # docstrings and __future__ imports.
+ aliases = [
+ ast.alias(six.moves.builtins.__name__, "@py_builtins"),
+ ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
+ ]
+ doc = getattr(mod, "docstring", None)
+ expect_docstring = doc is None
+ if doc is not None and self.is_rewrite_disabled(doc):
+ return
+ pos = 0
+ lineno = 1
+ for item in mod.body:
+ if (
+ expect_docstring
+ and isinstance(item, ast.Expr)
+ and isinstance(item.value, ast.Str)
+ ):
+ doc = item.value.s
+ if self.is_rewrite_disabled(doc):
+ return
+ expect_docstring = False
+ elif (
+ not isinstance(item, ast.ImportFrom)
+ or item.level > 0
+ or item.module != "__future__"
+ ):
+ lineno = item.lineno
+ break
+ pos += 1
+ else:
+ lineno = item.lineno
+ imports = [
+ ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
+ ]
+ mod.body[pos:pos] = imports
+ # Collect asserts.
+ nodes = [mod]
+ while nodes:
+ node = nodes.pop()
+ for name, field in ast.iter_fields(node):
+ if isinstance(field, list):
+ new = []
+ for i, child in enumerate(field):
+ if isinstance(child, ast.Assert):
+ # Transform assert.
+ new.extend(self.visit(child))
+ else:
+ new.append(child)
+ if isinstance(child, ast.AST):
+ nodes.append(child)
+ setattr(node, name, new)
+ elif (
+ isinstance(field, ast.AST)
+ # Don't recurse into expressions as they can't contain
+ # asserts.
+ and not isinstance(field, ast.expr)
+ ):
+ nodes.append(field)
+
+ @staticmethod
+ def is_rewrite_disabled(docstring):
+ return "PYTEST_DONT_REWRITE" in docstring
+
+ def variable(self):
+ """Get a new variable."""
+ # Use a character invalid in python identifiers to avoid clashing.
+ name = "@py_assert" + str(next(self.variable_counter))
+ self.variables.append(name)
+ return name
+
+ def assign(self, expr):
+ """Give *expr* a name."""
+ name = self.variable()
+ self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
+ return ast.Name(name, ast.Load())
+
+ def display(self, expr):
+ """Call saferepr on the expression."""
+ return self.helper("_saferepr", expr)
+
+ def helper(self, name, *args):
+ """Call a helper in this module."""
+ py_name = ast.Name("@pytest_ar", ast.Load())
+ attr = ast.Attribute(py_name, name, ast.Load())
+ return ast_Call(attr, list(args), [])
+
+ def builtin(self, name):
+ """Return the builtin called *name*."""
+ builtin_name = ast.Name("@py_builtins", ast.Load())
+ return ast.Attribute(builtin_name, name, ast.Load())
+
+ def explanation_param(self, expr):
+ """Return a new named %-formatting placeholder for expr.
+
+ This creates a %-formatting placeholder for expr in the
+ current formatting context, e.g. ``%(py0)s``. The placeholder
+ and expr are placed in the current format context so that it
+ can be used on the next call to .pop_format_context().
+
+ """
+ specifier = "py" + str(next(self.variable_counter))
+ self.explanation_specifiers[specifier] = expr
+ return "%(" + specifier + ")s"
+
+ def push_format_context(self):
+ """Create a new formatting context.
+
+ The format context is used for when an explanation wants to
+ have a variable value formatted in the assertion message. In
+ this case the value required can be added using
+ .explanation_param(). Finally .pop_format_context() is used
+ to format a string of %-formatted values as added by
+ .explanation_param().
+
+ """
+ self.explanation_specifiers = {}
+ self.stack.append(self.explanation_specifiers)
+
+ def pop_format_context(self, expl_expr):
+ """Format the %-formatted string with current format context.
+
+ The expl_expr should be an ast.Str instance constructed from
+ the %-placeholders created by .explanation_param(). This will
+ add the required code to format said string to .on_failure and
+ return the ast.Name instance of the formatted string.
+
+ """
+ current = self.stack.pop()
+ if self.stack:
+ self.explanation_specifiers = self.stack[-1]
+ keys = [ast.Str(key) for key in current.keys()]
+ format_dict = ast.Dict(keys, list(current.values()))
+ form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
+ name = "@py_format" + str(next(self.variable_counter))
+ self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
+ return ast.Name(name, ast.Load())
+
+ def generic_visit(self, node):
+ """Handle expressions we don't have custom code for."""
+ assert isinstance(node, ast.expr)
+ res = self.assign(node)
+ return res, self.explanation_param(self.display(res))
+
+ def visit_Assert(self, assert_):
+ """Return the AST statements to replace the ast.Assert instance.
+
+ This rewrites the test of an assertion to provide
+ intermediate values and replace it with an if statement which
+ raises an assertion error with a detailed explanation in case
+ the expression is false.
+
+ """
+ if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1:
+ from _pytest.warning_types import PytestAssertRewriteWarning
+ import warnings
+
+ warnings.warn_explicit(
+ PytestAssertRewriteWarning(
+ "assertion is always true, perhaps remove parentheses?"
+ ),
+ category=None,
+ filename=str(self.module_path),
+ lineno=assert_.lineno,
+ )
+
+ self.statements = []
+ self.variables = []
+ self.variable_counter = itertools.count()
+ self.stack = []
+ self.on_failure = []
+ self.push_format_context()
+ # Rewrite assert into a bunch of statements.
+ top_condition, explanation = self.visit(assert_.test)
+ # If in a test module, check if directly asserting None, in order to warn [Issue #3191]
+ if self.module_path is not None:
+ self.statements.append(
+ self.warn_about_none_ast(
+ top_condition, module_path=self.module_path, lineno=assert_.lineno
+ )
+ )
+ # Create failure message.
+ body = self.on_failure
+ negation = ast.UnaryOp(ast.Not(), top_condition)
+ self.statements.append(ast.If(negation, body, []))
+ if assert_.msg:
+ assertmsg = self.helper("_format_assertmsg", assert_.msg)
+ explanation = "\n>assert " + explanation
+ else:
+ assertmsg = ast.Str("")
+ explanation = "assert " + explanation
+ template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
+ msg = self.pop_format_context(template)
+ fmt = self.helper("_format_explanation", msg)
+ err_name = ast.Name("AssertionError", ast.Load())
+ exc = ast_Call(err_name, [fmt], [])
+ if sys.version_info[0] >= 3:
+ raise_ = ast.Raise(exc, None)
+ else:
+ raise_ = ast.Raise(exc, None, None)
+ body.append(raise_)
+ # Clear temporary variables by setting them to None.
+ if self.variables:
+ variables = [ast.Name(name, ast.Store()) for name in self.variables]
+ clear = ast.Assign(variables, _NameConstant(None))
+ self.statements.append(clear)
+ # Fix line numbers.
+ for stmt in self.statements:
+ set_location(stmt, assert_.lineno, assert_.col_offset)
+ return self.statements
+
+ def warn_about_none_ast(self, node, module_path, lineno):
+ """
+ Returns an AST issuing a warning if the value of node is `None`.
+ This is used to warn the user when asserting a function that asserts
+ internally already.
+ See issue #3191 for more details.
+ """
+
+ # Using parse because it is different between py2 and py3.
+ AST_NONE = ast.parse("None").body[0].value
+ val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE])
+ send_warning = ast.parse(
+ """
+from _pytest.warning_types import PytestAssertRewriteWarning
+from warnings import warn_explicit
+warn_explicit(
+ PytestAssertRewriteWarning('asserting the value None, please use "assert is None"'),
+ category=None,
+ filename={filename!r},
+ lineno={lineno},
+)
+ """.format(
+ filename=module_path.strpath, lineno=lineno
+ )
+ ).body
+ return ast.If(val_is_none, send_warning, [])
+
+ def visit_Name(self, name):
+ # Display the repr of the name if it's a local variable or
+ # _should_repr_global_name() thinks it's acceptable.
+ locs = ast_Call(self.builtin("locals"), [], [])
+ inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
+ dorepr = self.helper("_should_repr_global_name", name)
+ test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
+ expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
+ return name, self.explanation_param(expr)
+
+ def visit_BoolOp(self, boolop):
+ res_var = self.variable()
+ expl_list = self.assign(ast.List([], ast.Load()))
+ app = ast.Attribute(expl_list, "append", ast.Load())
+ is_or = int(isinstance(boolop.op, ast.Or))
+ body = save = self.statements
+ fail_save = self.on_failure
+ levels = len(boolop.values) - 1
+ self.push_format_context()
+ # Process each operand, short-circuting if needed.
+ for i, v in enumerate(boolop.values):
+ if i:
+ fail_inner = []
+ # cond is set in a prior loop iteration below
+ self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
+ self.on_failure = fail_inner
+ self.push_format_context()
+ res, expl = self.visit(v)
+ body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
+ expl_format = self.pop_format_context(ast.Str(expl))
+ call = ast_Call(app, [expl_format], [])
+ self.on_failure.append(ast.Expr(call))
+ if i < levels:
+ cond = res
+ if is_or:
+ cond = ast.UnaryOp(ast.Not(), cond)
+ inner = []
+ self.statements.append(ast.If(cond, inner, []))
+ self.statements = body = inner
+ self.statements = save
+ self.on_failure = fail_save
+ expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
+ expl = self.pop_format_context(expl_template)
+ return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
+
+ def visit_UnaryOp(self, unary):
+ pattern = unary_map[unary.op.__class__]
+ operand_res, operand_expl = self.visit(unary.operand)
+ res = self.assign(ast.UnaryOp(unary.op, operand_res))
+ return res, pattern % (operand_expl,)
+
+ def visit_BinOp(self, binop):
+ symbol = binop_map[binop.op.__class__]
+ left_expr, left_expl = self.visit(binop.left)
+ right_expr, right_expl = self.visit(binop.right)
+ explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
+ res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
+ return res, explanation
+
+ def visit_Call_35(self, call):
+ """
+ visit `ast.Call` nodes on Python3.5 and after
+ """
+ new_func, func_expl = self.visit(call.func)
+ arg_expls = []
+ new_args = []
+ new_kwargs = []
+ for arg in call.args:
+ res, expl = self.visit(arg)
+ arg_expls.append(expl)
+ new_args.append(res)
+ for keyword in call.keywords:
+ res, expl = self.visit(keyword.value)
+ new_kwargs.append(ast.keyword(keyword.arg, res))
+ if keyword.arg:
+ arg_expls.append(keyword.arg + "=" + expl)
+ else: # **args have `arg` keywords with an .arg of None
+ arg_expls.append("**" + expl)
+
+ expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
+ new_call = ast.Call(new_func, new_args, new_kwargs)
+ res = self.assign(new_call)
+ res_expl = self.explanation_param(self.display(res))
+ outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
+ return res, outer_expl
+
+ def visit_Starred(self, starred):
+ # From Python 3.5, a Starred node can appear in a function call
+ res, expl = self.visit(starred.value)
+ new_starred = ast.Starred(res, starred.ctx)
+ return new_starred, "*" + expl
+
+ def visit_Call_legacy(self, call):
+ """
+ visit `ast.Call nodes on 3.4 and below`
+ """
+ new_func, func_expl = self.visit(call.func)
+ arg_expls = []
+ new_args = []
+ new_kwargs = []
+ new_star = new_kwarg = None
+ for arg in call.args:
+ res, expl = self.visit(arg)
+ new_args.append(res)
+ arg_expls.append(expl)
+ for keyword in call.keywords:
+ res, expl = self.visit(keyword.value)
+ new_kwargs.append(ast.keyword(keyword.arg, res))
+ arg_expls.append(keyword.arg + "=" + expl)
+ if call.starargs:
+ new_star, expl = self.visit(call.starargs)
+ arg_expls.append("*" + expl)
+ if call.kwargs:
+ new_kwarg, expl = self.visit(call.kwargs)
+ arg_expls.append("**" + expl)
+ expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
+ new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
+ res = self.assign(new_call)
+ res_expl = self.explanation_param(self.display(res))
+ outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
+ return res, outer_expl
+
+ # ast.Call signature changed on 3.5,
+ # conditionally change which methods is named
+ # visit_Call depending on Python version
+ if sys.version_info >= (3, 5):
+ visit_Call = visit_Call_35
+ else:
+ visit_Call = visit_Call_legacy
+
+ def visit_Attribute(self, attr):
+ if not isinstance(attr.ctx, ast.Load):
+ return self.generic_visit(attr)
+ value, value_expl = self.visit(attr.value)
+ res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
+ res_expl = self.explanation_param(self.display(res))
+ pat = "%s\n{%s = %s.%s\n}"
+ expl = pat % (res_expl, res_expl, value_expl, attr.attr)
+ return res, expl
+
+ def visit_Compare(self, comp):
+ self.push_format_context()
+ left_res, left_expl = self.visit(comp.left)
+ if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
+ left_expl = "({})".format(left_expl)
+ res_variables = [self.variable() for i in range(len(comp.ops))]
+ load_names = [ast.Name(v, ast.Load()) for v in res_variables]
+ store_names = [ast.Name(v, ast.Store()) for v in res_variables]
+ it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
+ expls = []
+ syms = []
+ results = [left_res]
+ for i, op, next_operand in it:
+ next_res, next_expl = self.visit(next_operand)
+ if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
+ next_expl = "({})".format(next_expl)
+ results.append(next_res)
+ sym = binop_map[op.__class__]
+ syms.append(ast.Str(sym))
+ expl = "%s %s %s" % (left_expl, sym, next_expl)
+ expls.append(ast.Str(expl))
+ res_expr = ast.Compare(left_res, [op], [next_res])
+ self.statements.append(ast.Assign([store_names[i]], res_expr))
+ left_res, left_expl = next_res, next_expl
+ # Use pytest.assertion.util._reprcompare if that's available.
+ expl_call = self.helper(
+ "_call_reprcompare",
+ ast.Tuple(syms, ast.Load()),
+ ast.Tuple(load_names, ast.Load()),
+ ast.Tuple(expls, ast.Load()),
+ ast.Tuple(results, ast.Load()),
+ )
+ if len(comp.ops) > 1:
+ res = ast.BoolOp(ast.And(), load_names)
+ else:
+ res = load_names[0]
+ return res, self.explanation_param(self.pop_format_context(expl_call))