diff options
author | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:24:06 +0300 |
---|---|---|
committer | nkozlovskiy <nmk@ydb.tech> | 2023-09-29 12:41:34 +0300 |
commit | e0e3e1717e3d33762ce61950504f9637a6e669ed (patch) | |
tree | bca3ff6939b10ed60c3d5c12439963a1146b9711 /contrib/python/pytest/py2/_pytest/assertion/rewrite.py | |
parent | 38f2c5852db84c7b4d83adfcb009eb61541d1ccd (diff) | |
download | ydb-e0e3e1717e3d33762ce61950504f9637a6e669ed.tar.gz |
add ydb deps
Diffstat (limited to 'contrib/python/pytest/py2/_pytest/assertion/rewrite.py')
-rw-r--r-- | contrib/python/pytest/py2/_pytest/assertion/rewrite.py | 1072 |
1 files changed, 1072 insertions, 0 deletions
diff --git a/contrib/python/pytest/py2/_pytest/assertion/rewrite.py b/contrib/python/pytest/py2/_pytest/assertion/rewrite.py new file mode 100644 index 0000000000..6cfd81a32f --- /dev/null +++ b/contrib/python/pytest/py2/_pytest/assertion/rewrite.py @@ -0,0 +1,1072 @@ +# -*- coding: utf-8 -*- +"""Rewrite assertion AST to produce nice error messages""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning, module="_pytest.assertion.rewrite") + +import ast +import errno +import imp +import itertools +import marshal +import os +import re +import string +import struct +import sys +import types + +import atomicwrites +import py +import six + +from _pytest._io.saferepr import saferepr +from _pytest.assertion import util +from _pytest.assertion.util import ( # noqa: F401 + format_explanation as _format_explanation, +) +from _pytest.compat import spec_from_file_location +from _pytest.pathlib import fnmatch_ex +from _pytest.pathlib import PurePath + +# pytest caches rewritten pycs in __pycache__. +if hasattr(imp, "get_tag"): + PYTEST_TAG = imp.get_tag() + "-PYTEST" +else: + if hasattr(sys, "pypy_version_info"): + impl = "pypy" + elif sys.platform == "java": + impl = "jython" + else: + impl = "cpython" + ver = sys.version_info + PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1]) + del ver, impl + +PYC_EXT = ".py" + (__debug__ and "c" or "o") +PYC_TAIL = "." + PYTEST_TAG + PYC_EXT + +ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3 + +if sys.version_info >= (3, 5): + ast_Call = ast.Call +else: + + def ast_Call(a, b, c): + return ast.Call(a, b, c, None, None) + + +class AssertionRewritingHook(object): + """PEP302 Import hook which rewrites asserts.""" + + def __init__(self, config): + self.config = config + try: + self.fnpats = config.getini("python_files") + except ValueError: + self.fnpats = ["test_*.py", "*_test.py"] + self.session = None + self.modules = {} + self._rewritten_names = set() + self._must_rewrite = set() + # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file, + # which might result in infinite recursion (#3506) + self._writing_pyc = False + self._basenames_to_check_rewrite = {"conftest"} + self._marked_for_rewrite_cache = {} + self._session_paths_checked = False + + def set_session(self, session): + self.session = session + self._session_paths_checked = False + + def _imp_find_module(self, name, path=None): + """Indirection so we can mock calls to find_module originated from the hook during testing""" + return imp.find_module(name, path) + + def find_module(self, name, path=None): + if self._writing_pyc: + return None + state = self.config._assertstate + if self._early_rewrite_bailout(name, state): + return None + state.trace("find_module called for: %s" % name) + names = name.rsplit(".", 1) + lastname = names[-1] + pth = None + if path is not None: + # Starting with Python 3.3, path is a _NamespacePath(), which + # causes problems if not converted to list. + path = list(path) + if len(path) == 1: + pth = path[0] + if pth is None: + try: + fd, fn, desc = self._imp_find_module(lastname, path) + except ImportError: + return None + if fd is not None: + fd.close() + tp = desc[2] + if tp == imp.PY_COMPILED: + if hasattr(imp, "source_from_cache"): + try: + fn = imp.source_from_cache(fn) + except ValueError: + # Python 3 doesn't like orphaned but still-importable + # .pyc files. + fn = fn[:-1] + else: + fn = fn[:-1] + elif tp != imp.PY_SOURCE: + # Don't know what this is. + return None + else: + fn = os.path.join(pth, name.rpartition(".")[2] + ".py") + + fn_pypath = py.path.local(fn) + if not self._should_rewrite(name, fn_pypath, state): + return None + + self._rewritten_names.add(name) + + # The requested module looks like a test file, so rewrite it. This is + # the most magical part of the process: load the source, rewrite the + # asserts, and load the rewritten source. We also cache the rewritten + # module code in a special pyc. We must be aware of the possibility of + # concurrent pytest processes rewriting and loading pycs. To avoid + # tricky race conditions, we maintain the following invariant: The + # cached pyc is always a complete, valid pyc. Operations on it must be + # atomic. POSIX's atomic rename comes in handy. + write = not sys.dont_write_bytecode + cache_dir = os.path.join(fn_pypath.dirname, "__pycache__") + if write: + try: + os.mkdir(cache_dir) + except OSError: + e = sys.exc_info()[1].errno + if e == errno.EEXIST: + # Either the __pycache__ directory already exists (the + # common case) or it's blocked by a non-dir node. In the + # latter case, we'll ignore it in _write_pyc. + pass + elif e in [errno.ENOENT, errno.ENOTDIR]: + # One of the path components was not a directory, likely + # because we're in a zip file. + write = False + elif e in [errno.EACCES, errno.EROFS, errno.EPERM]: + state.trace("read only directory: %r" % fn_pypath.dirname) + write = False + else: + raise + cache_name = fn_pypath.basename[:-3] + PYC_TAIL + pyc = os.path.join(cache_dir, cache_name) + # Notice that even if we're in a read-only directory, I'm going + # to check for a cached pyc. This may not be optimal... + co = _read_pyc(fn_pypath, pyc, state.trace) + if co is None: + state.trace("rewriting %r" % (fn,)) + source_stat, co = _rewrite_test(self.config, fn_pypath) + if co is None: + # Probably a SyntaxError in the test. + return None + if write: + self._writing_pyc = True + try: + _write_pyc(state, co, source_stat, pyc) + finally: + self._writing_pyc = False + else: + state.trace("found cached rewritten pyc for %r" % (fn,)) + self.modules[name] = co, pyc + return self + + def _early_rewrite_bailout(self, name, state): + """ + This is a fast way to get out of rewriting modules. Profiling has + shown that the call to imp.find_module (inside of the find_module + from this class) is a major slowdown, so, this method tries to + filter what we're sure won't be rewritten before getting to it. + """ + if self.session is not None and not self._session_paths_checked: + self._session_paths_checked = True + for path in self.session._initialpaths: + # Make something as c:/projects/my_project/path.py -> + # ['c:', 'projects', 'my_project', 'path.py'] + parts = str(path).split(os.path.sep) + # add 'path' to basenames to be checked. + self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0]) + + # Note: conftest already by default in _basenames_to_check_rewrite. + parts = name.split(".") + if parts[-1] in self._basenames_to_check_rewrite: + return False + + # For matching the name it must be as if it was a filename. + path = PurePath(os.path.sep.join(parts) + ".py") + + for pat in self.fnpats: + # if the pattern contains subdirectories ("tests/**.py" for example) we can't bail out based + # on the name alone because we need to match against the full path + if os.path.dirname(pat): + return False + if fnmatch_ex(pat, path): + return False + + if self._is_marked_for_rewrite(name, state): + return False + + state.trace("early skip of rewriting module: %s" % (name,)) + return True + + def _should_rewrite(self, name, fn_pypath, state): + # always rewrite conftest files + fn = str(fn_pypath) + if fn_pypath.basename == "conftest.py": + state.trace("rewriting conftest file: %r" % (fn,)) + return True + + if self.session is not None: + if self.session.isinitpath(fn): + state.trace("matched test file (was specified on cmdline): %r" % (fn,)) + return True + + # modules not passed explicitly on the command line are only + # rewritten if they match the naming convention for test files + for pat in self.fnpats: + if fn_pypath.fnmatch(pat): + state.trace("matched test file %r" % (fn,)) + return True + + return self._is_marked_for_rewrite(name, state) + + def _is_marked_for_rewrite(self, name, state): + try: + return self._marked_for_rewrite_cache[name] + except KeyError: + for marked in self._must_rewrite: + if name == marked or name.startswith(marked + "."): + state.trace("matched marked file %r (from %r)" % (name, marked)) + self._marked_for_rewrite_cache[name] = True + return True + + self._marked_for_rewrite_cache[name] = False + return False + + def mark_rewrite(self, *names): + """Mark import names as needing to be rewritten. + + The named module or package as well as any nested modules will + be rewritten on import. + """ + already_imported = ( + set(names).intersection(sys.modules).difference(self._rewritten_names) + ) + for name in already_imported: + if not AssertionRewriter.is_rewrite_disabled( + sys.modules[name].__doc__ or "" + ): + self._warn_already_imported(name) + self._must_rewrite.update(names) + self._marked_for_rewrite_cache.clear() + + def _warn_already_imported(self, name): + from _pytest.warning_types import PytestAssertRewriteWarning + from _pytest.warnings import _issue_warning_captured + + _issue_warning_captured( + PytestAssertRewriteWarning( + "Module already imported so cannot be rewritten: %s" % name + ), + self.config.hook, + stacklevel=5, + ) + + def load_module(self, name): + co, pyc = self.modules.pop(name) + if name in sys.modules: + # If there is an existing module object named 'fullname' in + # sys.modules, the loader must use that existing module. (Otherwise, + # the reload() builtin will not work correctly.) + mod = sys.modules[name] + else: + # I wish I could just call imp.load_compiled here, but __file__ has to + # be set properly. In Python 3.2+, this all would be handled correctly + # by load_compiled. + mod = sys.modules[name] = imp.new_module(name) + try: + mod.__file__ = co.co_filename + # Normally, this attribute is 3.2+. + mod.__cached__ = pyc + mod.__loader__ = self + # Normally, this attribute is 3.4+ + mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self) + exec(co, mod.__dict__) + except: # noqa + if name in sys.modules: + del sys.modules[name] + raise + return sys.modules[name] + + def is_package(self, name): + try: + fd, fn, desc = self._imp_find_module(name) + except ImportError: + return False + if fd is not None: + fd.close() + tp = desc[2] + return tp == imp.PKG_DIRECTORY + + def get_data(self, pathname): + """Optional PEP302 get_data API. + """ + with open(pathname, "rb") as f: + return f.read() + + +def _write_pyc(state, co, source_stat, pyc): + # Technically, we don't have to have the same pyc format as + # (C)Python, since these "pycs" should never be seen by builtin + # import. However, there's little reason deviate, and I hope + # sometime to be able to use imp.load_compiled to load them. (See + # the comment in load_module above.) + try: + with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp: + fp.write(imp.get_magic()) + # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) + mtime = int(source_stat.mtime) & 0xFFFFFFFF + size = source_stat.size & 0xFFFFFFFF + # "<LL" stands for 2 unsigned longs, little-ending + fp.write(struct.pack("<LL", mtime, size)) + fp.write(marshal.dumps(co)) + except EnvironmentError as e: + state.trace("error writing pyc file at %s: errno=%s" % (pyc, e.errno)) + # we ignore any failure to write the cache file + # there are many reasons, permission-denied, __pycache__ being a + # file etc. + return False + return True + + +RN = "\r\n".encode("utf-8") +N = "\n".encode("utf-8") + +cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+") +BOM_UTF8 = "\xef\xbb\xbf" + + +def _rewrite_test(config, fn): + """Try to read and rewrite *fn* and return the code object.""" + state = config._assertstate + try: + stat = fn.stat() + source = fn.read("rb") + except EnvironmentError: + return None, None + if ASCII_IS_DEFAULT_ENCODING: + # ASCII is the default encoding in Python 2. Without a coding + # declaration, Python 2 will complain about any bytes in the file + # outside the ASCII range. Sadly, this behavior does not extend to + # compile() or ast.parse(), which prefer to interpret the bytes as + # latin-1. (At least they properly handle explicit coding cookies.) To + # preserve this error behavior, we could force ast.parse() to use ASCII + # as the encoding by inserting a coding cookie. Unfortunately, that + # messes up line numbers. Thus, we have to check ourselves if anything + # is outside the ASCII range in the case no encoding is explicitly + # declared. For more context, see issue #269. Yay for Python 3 which + # gets this right. + end1 = source.find("\n") + end2 = source.find("\n", end1 + 1) + if ( + not source.startswith(BOM_UTF8) + and cookie_re.match(source[0:end1]) is None + and cookie_re.match(source[end1 + 1 : end2]) is None + ): + if hasattr(state, "_indecode"): + # encodings imported us again, so don't rewrite. + return None, None + state._indecode = True + try: + try: + source.decode("ascii") + except UnicodeDecodeError: + # Let it fail in real import. + return None, None + finally: + del state._indecode + try: + tree = ast.parse(source, filename=fn.strpath) + except SyntaxError: + # Let this pop up again in the real import. + state.trace("failed to parse: %r" % (fn,)) + return None, None + rewrite_asserts(tree, fn, config) + try: + co = compile(tree, fn.strpath, "exec", dont_inherit=True) + except SyntaxError: + # It's possible that this error is from some bug in the + # assertion rewriting, but I don't know of a fast way to tell. + state.trace("failed to compile: %r" % (fn,)) + return None, None + return stat, co + + +def _read_pyc(source, pyc, trace=lambda x: None): + """Possibly read a pytest pyc containing rewritten code. + + Return rewritten code if successful or None if not. + """ + try: + fp = open(pyc, "rb") + except IOError: + return None + with fp: + try: + mtime = int(source.mtime()) + size = source.size() + data = fp.read(12) + except EnvironmentError as e: + trace("_read_pyc(%s): EnvironmentError %s" % (source, e)) + return None + # Check for invalid or out of date pyc file. + if ( + len(data) != 12 + or data[:4] != imp.get_magic() + or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF) + ): + trace("_read_pyc(%s): invalid or out of date pyc" % source) + return None + try: + co = marshal.load(fp) + except Exception as e: + trace("_read_pyc(%s): marshal.load error %s" % (source, e)) + return None + if not isinstance(co, types.CodeType): + trace("_read_pyc(%s): not a code object" % source) + return None + return co + + +def rewrite_asserts(mod, module_path=None, config=None): + """Rewrite the assert statements in mod.""" + AssertionRewriter(module_path, config).run(mod) + + +def _saferepr(obj): + """Get a safe repr of an object for assertion error messages. + + The assertion formatting (util.format_explanation()) requires + newlines to be escaped since they are a special character for it. + Normally assertion.util.format_explanation() does this but for a + custom repr it is possible to contain one of the special escape + sequences, especially '\n{' and '\n}' are likely to be present in + JSON reprs. + + """ + r = saferepr(obj) + # only occurs in python2.x, repr must return text in python3+ + if isinstance(r, bytes): + # Represent unprintable bytes as `\x##` + r = u"".join( + u"\\x{:x}".format(ord(c)) if c not in string.printable else c.decode() + for c in r + ) + return r.replace(u"\n", u"\\n") + + +def _format_assertmsg(obj): + """Format the custom assertion message given. + + For strings this simply replaces newlines with '\n~' so that + util.format_explanation() will preserve them instead of escaping + newlines. For other objects saferepr() is used first. + + """ + # reprlib appears to have a bug which means that if a string + # contains a newline it gets escaped, however if an object has a + # .__repr__() which contains newlines it does not get escaped. + # However in either case we want to preserve the newline. + replaces = [(u"\n", u"\n~"), (u"%", u"%%")] + if not isinstance(obj, six.string_types): + obj = saferepr(obj) + replaces.append((u"\\n", u"\n~")) + + if isinstance(obj, bytes): + replaces = [(r1.encode(), r2.encode()) for r1, r2 in replaces] + + for r1, r2 in replaces: + obj = obj.replace(r1, r2) + + return obj + + +def _should_repr_global_name(obj): + if callable(obj): + return False + + try: + return not hasattr(obj, "__name__") + except Exception: + return True + + +def _format_boolop(explanations, is_or): + explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")" + if isinstance(explanation, six.text_type): + return explanation.replace(u"%", u"%%") + else: + return explanation.replace(b"%", b"%%") + + +def _call_reprcompare(ops, results, expls, each_obj): + for i, res, expl in zip(range(len(ops)), results, expls): + try: + done = not res + except Exception: + done = True + if done: + break + if util._reprcompare is not None: + custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1]) + if custom is not None: + return custom + return expl + + +unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"} + +binop_map = { + ast.BitOr: "|", + ast.BitXor: "^", + ast.BitAnd: "&", + ast.LShift: "<<", + ast.RShift: ">>", + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", + ast.FloorDiv: "//", + ast.Mod: "%%", # escaped for string formatting + ast.Eq: "==", + ast.NotEq: "!=", + ast.Lt: "<", + ast.LtE: "<=", + ast.Gt: ">", + ast.GtE: ">=", + ast.Pow: "**", + ast.Is: "is", + ast.IsNot: "is not", + ast.In: "in", + ast.NotIn: "not in", +} +# Python 3.5+ compatibility +try: + binop_map[ast.MatMult] = "@" +except AttributeError: + pass + +# Python 3.4+ compatibility +if hasattr(ast, "NameConstant"): + _NameConstant = ast.NameConstant +else: + + def _NameConstant(c): + return ast.Name(str(c), ast.Load()) + + +def set_location(node, lineno, col_offset): + """Set node location information recursively.""" + + def _fix(node, lineno, col_offset): + if "lineno" in node._attributes: + node.lineno = lineno + if "col_offset" in node._attributes: + node.col_offset = col_offset + for child in ast.iter_child_nodes(node): + _fix(child, lineno, col_offset) + + _fix(node, lineno, col_offset) + return node + + +class AssertionRewriter(ast.NodeVisitor): + """Assertion rewriting implementation. + + The main entrypoint is to call .run() with an ast.Module instance, + this will then find all the assert statements and rewrite them to + provide intermediate values and a detailed assertion error. See + http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html + for an overview of how this works. + + The entry point here is .run() which will iterate over all the + statements in an ast.Module and for each ast.Assert statement it + finds call .visit() with it. Then .visit_Assert() takes over and + is responsible for creating new ast statements to replace the + original assert statement: it rewrites the test of an assertion + to provide intermediate values and replace it with an if statement + which raises an assertion error with a detailed explanation in + case the expression is false. + + For this .visit_Assert() uses the visitor pattern to visit all the + AST nodes of the ast.Assert.test field, each visit call returning + an AST node and the corresponding explanation string. During this + state is kept in several instance attributes: + + :statements: All the AST statements which will replace the assert + statement. + + :variables: This is populated by .variable() with each variable + used by the statements so that they can all be set to None at + the end of the statements. + + :variable_counter: Counter to create new unique variables needed + by statements. Variables are created using .variable() and + have the form of "@py_assert0". + + :on_failure: The AST statements which will be executed if the + assertion test fails. This is the code which will construct + the failure message and raises the AssertionError. + + :explanation_specifiers: A dict filled by .explanation_param() + with %-formatting placeholders and their corresponding + expressions to use in the building of an assertion message. + This is used by .pop_format_context() to build a message. + + :stack: A stack of the explanation_specifiers dicts maintained by + .push_format_context() and .pop_format_context() which allows + to build another %-formatted string while already building one. + + This state is reset on every new assert statement visited and used + by the other visitors. + + """ + + def __init__(self, module_path, config): + super(AssertionRewriter, self).__init__() + self.module_path = module_path + self.config = config + + def run(self, mod): + """Find all assert statements in *mod* and rewrite them.""" + if not mod.body: + # Nothing to do. + return + # Insert some special imports at the top of the module but after any + # docstrings and __future__ imports. + aliases = [ + ast.alias(six.moves.builtins.__name__, "@py_builtins"), + ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), + ] + doc = getattr(mod, "docstring", None) + expect_docstring = doc is None + if doc is not None and self.is_rewrite_disabled(doc): + return + pos = 0 + lineno = 1 + for item in mod.body: + if ( + expect_docstring + and isinstance(item, ast.Expr) + and isinstance(item.value, ast.Str) + ): + doc = item.value.s + if self.is_rewrite_disabled(doc): + return + expect_docstring = False + elif ( + not isinstance(item, ast.ImportFrom) + or item.level > 0 + or item.module != "__future__" + ): + lineno = item.lineno + break + pos += 1 + else: + lineno = item.lineno + imports = [ + ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases + ] + mod.body[pos:pos] = imports + # Collect asserts. + nodes = [mod] + while nodes: + node = nodes.pop() + for name, field in ast.iter_fields(node): + if isinstance(field, list): + new = [] + for i, child in enumerate(field): + if isinstance(child, ast.Assert): + # Transform assert. + new.extend(self.visit(child)) + else: + new.append(child) + if isinstance(child, ast.AST): + nodes.append(child) + setattr(node, name, new) + elif ( + isinstance(field, ast.AST) + # Don't recurse into expressions as they can't contain + # asserts. + and not isinstance(field, ast.expr) + ): + nodes.append(field) + + @staticmethod + def is_rewrite_disabled(docstring): + return "PYTEST_DONT_REWRITE" in docstring + + def variable(self): + """Get a new variable.""" + # Use a character invalid in python identifiers to avoid clashing. + name = "@py_assert" + str(next(self.variable_counter)) + self.variables.append(name) + return name + + def assign(self, expr): + """Give *expr* a name.""" + name = self.variable() + self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr)) + return ast.Name(name, ast.Load()) + + def display(self, expr): + """Call saferepr on the expression.""" + return self.helper("_saferepr", expr) + + def helper(self, name, *args): + """Call a helper in this module.""" + py_name = ast.Name("@pytest_ar", ast.Load()) + attr = ast.Attribute(py_name, name, ast.Load()) + return ast_Call(attr, list(args), []) + + def builtin(self, name): + """Return the builtin called *name*.""" + builtin_name = ast.Name("@py_builtins", ast.Load()) + return ast.Attribute(builtin_name, name, ast.Load()) + + def explanation_param(self, expr): + """Return a new named %-formatting placeholder for expr. + + This creates a %-formatting placeholder for expr in the + current formatting context, e.g. ``%(py0)s``. The placeholder + and expr are placed in the current format context so that it + can be used on the next call to .pop_format_context(). + + """ + specifier = "py" + str(next(self.variable_counter)) + self.explanation_specifiers[specifier] = expr + return "%(" + specifier + ")s" + + def push_format_context(self): + """Create a new formatting context. + + The format context is used for when an explanation wants to + have a variable value formatted in the assertion message. In + this case the value required can be added using + .explanation_param(). Finally .pop_format_context() is used + to format a string of %-formatted values as added by + .explanation_param(). + + """ + self.explanation_specifiers = {} + self.stack.append(self.explanation_specifiers) + + def pop_format_context(self, expl_expr): + """Format the %-formatted string with current format context. + + The expl_expr should be an ast.Str instance constructed from + the %-placeholders created by .explanation_param(). This will + add the required code to format said string to .on_failure and + return the ast.Name instance of the formatted string. + + """ + current = self.stack.pop() + if self.stack: + self.explanation_specifiers = self.stack[-1] + keys = [ast.Str(key) for key in current.keys()] + format_dict = ast.Dict(keys, list(current.values())) + form = ast.BinOp(expl_expr, ast.Mod(), format_dict) + name = "@py_format" + str(next(self.variable_counter)) + self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form)) + return ast.Name(name, ast.Load()) + + def generic_visit(self, node): + """Handle expressions we don't have custom code for.""" + assert isinstance(node, ast.expr) + res = self.assign(node) + return res, self.explanation_param(self.display(res)) + + def visit_Assert(self, assert_): + """Return the AST statements to replace the ast.Assert instance. + + This rewrites the test of an assertion to provide + intermediate values and replace it with an if statement which + raises an assertion error with a detailed explanation in case + the expression is false. + + """ + if isinstance(assert_.test, ast.Tuple) and len(assert_.test.elts) >= 1: + from _pytest.warning_types import PytestAssertRewriteWarning + import warnings + + warnings.warn_explicit( + PytestAssertRewriteWarning( + "assertion is always true, perhaps remove parentheses?" + ), + category=None, + filename=str(self.module_path), + lineno=assert_.lineno, + ) + + self.statements = [] + self.variables = [] + self.variable_counter = itertools.count() + self.stack = [] + self.on_failure = [] + self.push_format_context() + # Rewrite assert into a bunch of statements. + top_condition, explanation = self.visit(assert_.test) + # If in a test module, check if directly asserting None, in order to warn [Issue #3191] + if self.module_path is not None: + self.statements.append( + self.warn_about_none_ast( + top_condition, module_path=self.module_path, lineno=assert_.lineno + ) + ) + # Create failure message. + body = self.on_failure + negation = ast.UnaryOp(ast.Not(), top_condition) + self.statements.append(ast.If(negation, body, [])) + if assert_.msg: + assertmsg = self.helper("_format_assertmsg", assert_.msg) + explanation = "\n>assert " + explanation + else: + assertmsg = ast.Str("") + explanation = "assert " + explanation + template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation)) + msg = self.pop_format_context(template) + fmt = self.helper("_format_explanation", msg) + err_name = ast.Name("AssertionError", ast.Load()) + exc = ast_Call(err_name, [fmt], []) + if sys.version_info[0] >= 3: + raise_ = ast.Raise(exc, None) + else: + raise_ = ast.Raise(exc, None, None) + body.append(raise_) + # Clear temporary variables by setting them to None. + if self.variables: + variables = [ast.Name(name, ast.Store()) for name in self.variables] + clear = ast.Assign(variables, _NameConstant(None)) + self.statements.append(clear) + # Fix line numbers. + for stmt in self.statements: + set_location(stmt, assert_.lineno, assert_.col_offset) + return self.statements + + def warn_about_none_ast(self, node, module_path, lineno): + """ + Returns an AST issuing a warning if the value of node is `None`. + This is used to warn the user when asserting a function that asserts + internally already. + See issue #3191 for more details. + """ + + # Using parse because it is different between py2 and py3. + AST_NONE = ast.parse("None").body[0].value + val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE]) + send_warning = ast.parse( + """ +from _pytest.warning_types import PytestAssertRewriteWarning +from warnings import warn_explicit +warn_explicit( + PytestAssertRewriteWarning('asserting the value None, please use "assert is None"'), + category=None, + filename={filename!r}, + lineno={lineno}, +) + """.format( + filename=module_path.strpath, lineno=lineno + ) + ).body + return ast.If(val_is_none, send_warning, []) + + def visit_Name(self, name): + # Display the repr of the name if it's a local variable or + # _should_repr_global_name() thinks it's acceptable. + locs = ast_Call(self.builtin("locals"), [], []) + inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs]) + dorepr = self.helper("_should_repr_global_name", name) + test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) + expr = ast.IfExp(test, self.display(name), ast.Str(name.id)) + return name, self.explanation_param(expr) + + def visit_BoolOp(self, boolop): + res_var = self.variable() + expl_list = self.assign(ast.List([], ast.Load())) + app = ast.Attribute(expl_list, "append", ast.Load()) + is_or = int(isinstance(boolop.op, ast.Or)) + body = save = self.statements + fail_save = self.on_failure + levels = len(boolop.values) - 1 + self.push_format_context() + # Process each operand, short-circuting if needed. + for i, v in enumerate(boolop.values): + if i: + fail_inner = [] + # cond is set in a prior loop iteration below + self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa + self.on_failure = fail_inner + self.push_format_context() + res, expl = self.visit(v) + body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) + expl_format = self.pop_format_context(ast.Str(expl)) + call = ast_Call(app, [expl_format], []) + self.on_failure.append(ast.Expr(call)) + if i < levels: + cond = res + if is_or: + cond = ast.UnaryOp(ast.Not(), cond) + inner = [] + self.statements.append(ast.If(cond, inner, [])) + self.statements = body = inner + self.statements = save + self.on_failure = fail_save + expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or)) + expl = self.pop_format_context(expl_template) + return ast.Name(res_var, ast.Load()), self.explanation_param(expl) + + def visit_UnaryOp(self, unary): + pattern = unary_map[unary.op.__class__] + operand_res, operand_expl = self.visit(unary.operand) + res = self.assign(ast.UnaryOp(unary.op, operand_res)) + return res, pattern % (operand_expl,) + + def visit_BinOp(self, binop): + symbol = binop_map[binop.op.__class__] + left_expr, left_expl = self.visit(binop.left) + right_expr, right_expl = self.visit(binop.right) + explanation = "(%s %s %s)" % (left_expl, symbol, right_expl) + res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) + return res, explanation + + def visit_Call_35(self, call): + """ + visit `ast.Call` nodes on Python3.5 and after + """ + new_func, func_expl = self.visit(call.func) + arg_expls = [] + new_args = [] + new_kwargs = [] + for arg in call.args: + res, expl = self.visit(arg) + arg_expls.append(expl) + new_args.append(res) + for keyword in call.keywords: + res, expl = self.visit(keyword.value) + new_kwargs.append(ast.keyword(keyword.arg, res)) + if keyword.arg: + arg_expls.append(keyword.arg + "=" + expl) + else: # **args have `arg` keywords with an .arg of None + arg_expls.append("**" + expl) + + expl = "%s(%s)" % (func_expl, ", ".join(arg_expls)) + new_call = ast.Call(new_func, new_args, new_kwargs) + res = self.assign(new_call) + res_expl = self.explanation_param(self.display(res)) + outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl) + return res, outer_expl + + def visit_Starred(self, starred): + # From Python 3.5, a Starred node can appear in a function call + res, expl = self.visit(starred.value) + new_starred = ast.Starred(res, starred.ctx) + return new_starred, "*" + expl + + def visit_Call_legacy(self, call): + """ + visit `ast.Call nodes on 3.4 and below` + """ + new_func, func_expl = self.visit(call.func) + arg_expls = [] + new_args = [] + new_kwargs = [] + new_star = new_kwarg = None + for arg in call.args: + res, expl = self.visit(arg) + new_args.append(res) + arg_expls.append(expl) + for keyword in call.keywords: + res, expl = self.visit(keyword.value) + new_kwargs.append(ast.keyword(keyword.arg, res)) + arg_expls.append(keyword.arg + "=" + expl) + if call.starargs: + new_star, expl = self.visit(call.starargs) + arg_expls.append("*" + expl) + if call.kwargs: + new_kwarg, expl = self.visit(call.kwargs) + arg_expls.append("**" + expl) + expl = "%s(%s)" % (func_expl, ", ".join(arg_expls)) + new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg) + res = self.assign(new_call) + res_expl = self.explanation_param(self.display(res)) + outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl) + return res, outer_expl + + # ast.Call signature changed on 3.5, + # conditionally change which methods is named + # visit_Call depending on Python version + if sys.version_info >= (3, 5): + visit_Call = visit_Call_35 + else: + visit_Call = visit_Call_legacy + + def visit_Attribute(self, attr): + if not isinstance(attr.ctx, ast.Load): + return self.generic_visit(attr) + value, value_expl = self.visit(attr.value) + res = self.assign(ast.Attribute(value, attr.attr, ast.Load())) + res_expl = self.explanation_param(self.display(res)) + pat = "%s\n{%s = %s.%s\n}" + expl = pat % (res_expl, res_expl, value_expl, attr.attr) + return res, expl + + def visit_Compare(self, comp): + self.push_format_context() + left_res, left_expl = self.visit(comp.left) + if isinstance(comp.left, (ast.Compare, ast.BoolOp)): + left_expl = "({})".format(left_expl) + res_variables = [self.variable() for i in range(len(comp.ops))] + load_names = [ast.Name(v, ast.Load()) for v in res_variables] + store_names = [ast.Name(v, ast.Store()) for v in res_variables] + it = zip(range(len(comp.ops)), comp.ops, comp.comparators) + expls = [] + syms = [] + results = [left_res] + for i, op, next_operand in it: + next_res, next_expl = self.visit(next_operand) + if isinstance(next_operand, (ast.Compare, ast.BoolOp)): + next_expl = "({})".format(next_expl) + results.append(next_res) + sym = binop_map[op.__class__] + syms.append(ast.Str(sym)) + expl = "%s %s %s" % (left_expl, sym, next_expl) + expls.append(ast.Str(expl)) + res_expr = ast.Compare(left_res, [op], [next_res]) + self.statements.append(ast.Assign([store_names[i]], res_expr)) + left_res, left_expl = next_res, next_expl + # Use pytest.assertion.util._reprcompare if that's available. + expl_call = self.helper( + "_call_reprcompare", + ast.Tuple(syms, ast.Load()), + ast.Tuple(load_names, ast.Load()), + ast.Tuple(expls, ast.Load()), + ast.Tuple(results, ast.Load()), + ) + if len(comp.ops) > 1: + res = ast.BoolOp(ast.And(), load_names) + else: + res = load_names[0] + return res, self.explanation_param(self.pop_format_context(expl_call)) |