aboutsummaryrefslogtreecommitdiffstats
path: root/library/python/pytest/rewrite.py
blob: ec188d847f18372ee086c2dc455e3dff6f3e035b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import absolute_import
from __future__ import print_function

import ast

import py

from _pytest.assertion import rewrite
try:
    import importlib.util
except ImportError:
    pass
from __res import importer
import sys
import six


def _get_state(config):
    if hasattr(config, '_assertstate'):
        return config._assertstate
    return config._store[rewrite.assertstate_key]


class AssertionRewritingHook(rewrite.AssertionRewritingHook):
    """pytest assertion-rewriting import hook backed by ``__res.importer``.

    Module sources, package status and filenames are obtained through
    ``importer`` instead of the file system. Rewritten code objects are
    compiled during lookup (``find_module``/``find_spec``) and cached in
    ``self.modules`` until the module is executed.
    """

    def __init__(self, *args, **kwargs):
        # module name -> (code object, None); the second slot is always None
        # here (presumably a placeholder for a cached-bytecode path expected
        # by the base class -- NOTE(review): confirm against pytest version).
        self.modules = {}
        super(AssertionRewritingHook, self).__init__(*args, **kwargs)

    def find_module(self, name, path=None):
        """Legacy (PEP 302) finder entry point.

        Returns this hook as the loader when *name* was selected for rewriting
        and its rewritten code is now cached; otherwise returns None.
        """
        if self._find_module(name, path) is None:
            return None
        return self

    def _find_module(self, name, path=None):
        """Rewrite and cache the code for *name* if this hook should handle it.

        Returns the compiled, assertion-rewritten code object (also stored in
        ``self.modules[name]``), or None when *name* is not selected for
        rewriting, is a package, raises ImportError from ``is_package``, or
        its source cannot be read, parsed or compiled. *path* is not used.
        """
        state = _get_state(self.config)
        if not self._should_rewrite(name, None, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Packages are left to the regular import machinery. An ImportError
        # here presumably means importer does not know this name at all.
        try:
            if self.is_package(name):
                return None
        except ImportError:
            return None

        self._rewritten_names.add(name)

        state.trace("rewriting %s" % name)
        co = _rewrite_test(self.config, name)
        if co is None:
            # Probably a SyntaxError in the test.
            return None
        self.modules[name] = co, None
        return co

    def find_spec(self, name, path=None, target=None):
        """PEP 451 finder entry point.

        Returns a file-location module spec with this hook as its loader for
        modules that were rewritten (and cached), otherwise None. *path* and
        *target* do not influence the lookup.
        """
        co = self._find_module(name, path)
        if co is None:
            return None
        return importlib.util.spec_from_file_location(
            name,
            co.co_filename,
            loader=self,
        )

    def _should_rewrite(self, name, fn, state):
        """Return True if the asserts of module *name* should be rewritten.

        Submodules of ``__tests__`` and any module whose last component is
        ``conftest`` are always rewritten; everything else defers to the base
        class's ``_is_marked_for_rewrite``. *fn* is unused here.
        """
        if name.startswith("__tests__.") or name.endswith(".conftest"):
            return True

        return self._is_marked_for_rewrite(name, state)

    def is_package(self, name):
        """Report whether *name* is a package, as answered by ``importer``."""
        return importer.is_package(name)

    def get_source(self, name):
        """Return the source of module *name* as provided by ``importer``."""
        return importer.get_source(name)

    if six.PY3:
        def load_module(self, module):
            """Execute the cached rewritten code for *module*.

            *module* is normally a module object (as passed by
            ``exec_module``); the module is returned from ``sys.modules``.
            Because ``find_module`` still exposes the legacy PEP 302 finder
            protocol, a caller may instead invoke ``load_module(fullname)``
            with a name string; that form is delegated to the base class
            rather than failing on ``str.__name__``.
            """
            if isinstance(module, str):
                return super(AssertionRewritingHook, self).load_module(module)
            co, _ = self.modules.pop(module.__name__)
            try:
                module.__file__ = co.co_filename
                module.__cached__ = None
                module.__loader__ = self
                module.__spec__ = importlib.util.spec_from_file_location(
                    module.__name__, co.co_filename, loader=self,
                )
                exec(co, module.__dict__)
            except BaseException:
                # Do not leave a half-initialized module importable; the
                # original exception is re-raised unchanged.
                if module.__name__ in sys.modules:
                    del sys.modules[module.__name__]
                raise
            return sys.modules[module.__name__]

        def exec_module(self, module):
            """Run *module* from cached rewritten code if present, else defer to the base class."""
            if module.__name__ in self.modules:
                self.load_module(module)
            else:
                super(AssertionRewritingHook, self).exec_module(module)


def _rewrite_test(config, name):
    """Read module *name*, rewrite its asserts and return the compiled code.

    The source text and filename are obtained from ``importer``; assertions
    are rewritten in the parsed AST by pytest's ``rewrite.rewrite_asserts``.

    :param config: pytest config object; used to fetch the assertion state
        (for tracing) and passed through to ``rewrite_asserts``.
    :param name: fully qualified module name to rewrite.
    :returns: the compiled code object, or None when ``importer`` has no
        source for *name* or the original/rewritten AST fails with a
        SyntaxError during parsing or compilation.
    """
    state = _get_state(config)

    source = importer.get_source(name)
    if source is None:
        return None

    path = importer.get_filename(name)

    try:
        tree = ast.parse(source, filename=path)
    except SyntaxError:
        # Let this pop up again in the real import: returning None makes the
        # hook decline the module, so the regular import reports the error.
        state.trace("failed to parse: %r" % (path,))
        return None
    rewrite.rewrite_asserts(tree, py.path.local(path), config)
    try:
        co = compile(tree, path, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (path,))
        return None
    return co