summaryrefslogtreecommitdiffstats
path: root/library/python/pytest/rewrite.py
blob: e14bd3279b233cd456a7a49059c4de2e5e842859 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import absolute_import 
from __future__ import print_function 
 
import ast 
 
import py

from _pytest.assertion import rewrite 
try:
    import importlib.util
except ImportError:
    pass

try:
    from pathlib import Path
except ImportError:
    pass

from __res import importer 
import sys
import six
 
 
def _get_state(config):
    if hasattr(config, '_assertstate'):
        return config._assertstate
    return config._store[rewrite.assertstate_key]


class AssertionRewritingHook(rewrite.AssertionRewritingHook):
    """pytest assertion-rewriting hook for modules served by ``__res.importer``.

    Module sources are obtained through ``importer`` (see ``_rewrite_test``),
    rewritten and compiled; the resulting code objects are cached in
    ``self.modules`` until the module is executed.
    """

    def __init__(self, *args, **kwargs):
        # name -> (code object, None); filled by _find_module and popped by
        # load_module when the module is executed.
        self.modules = {}
        super(AssertionRewritingHook, self).__init__(*args, **kwargs)

    def find_module(self, name, path=None):
        """Legacy finder entry point: return ``self`` if *name* was rewritten, else None."""
        co = self._find_module(name, path)
        if co is not None:
            return self
        return None

    def _find_module(self, name, path=None):
        """Rewrite module *name* if it is selected for rewriting.

        :param name: dotted module name.
        :param path: the parent package's search path, as handed to the
            finder (None for top-level modules).
        :returns: the rewritten code object, or None when the module is not
            selected for rewriting, is a package, ``is_package`` raises
            ImportError, or the source could not be read/parsed/compiled.
            On success the code object is also cached in ``self.modules``.
        """
        state = _get_state(self.config)
        if not self._should_rewrite(name, None, state):
            return None
        state.trace("find_module called for: %s" % name)

        try:
            if self.is_package(name):
                return None
        except ImportError:
            return None

        if hasattr(self._rewritten_names, 'add'):
            # Set-style bookkeeping in the base hook.
            self._rewritten_names.add(name)
        else:
            # Mapping-style bookkeeping (name -> Path). ``path`` is None (or
            # empty) for top-level modules, where ``path[0]`` would raise, so
            # fall back to the module's own file name from the importer.
            location = path[0] if path else importer.get_filename(name)
            if location is not None:
                self._rewritten_names[name] = Path(location)

        state.trace("rewriting %s" % name)
        co = _rewrite_test(self.config, name)
        if co is None:
            # Probably a SyntaxError in the test.
            return None
        self.modules[name] = co, None
        return co

    def find_spec(self, name, path=None, target=None):
        """Spec-based finder entry point: return a spec loaded by ``self``, or None."""
        co = self._find_module(name, path)
        if co is not None:
            return importlib.util.spec_from_file_location(
                name,
                co.co_filename,
                loader=self,
            )
        return None

    def _should_rewrite(self, name, fn, state):
        """Rewrite ``__tests__.*`` modules, ``*.conftest`` modules and explicitly marked ones."""
        if name.startswith("__tests__.") or name.endswith(".conftest"):
            return True

        return self._is_marked_for_rewrite(name, state)

    def is_package(self, name):
        """Delegate package detection to the resource importer."""
        return importer.is_package(name)

    def get_source(self, name):
        """Delegate source lookup to the resource importer."""
        return importer.get_source(name)

    if six.PY3:
        def load_module(self, module):
            """Execute the cached rewritten code for *module* and return the module.

            Note: this takes the module object (it is driven by ``exec_module``
            below), not a module name.
            """
            co, _ = self.modules.pop(module.__name__)
            try:
                module.__file__ = co.co_filename
                module.__cached__ = None
                module.__loader__ = self
                module.__spec__ = importlib.util.spec_from_file_location(
                    module.__name__, co.co_filename, loader=self
                )
                exec(co, module.__dict__)
            except BaseException:
                # Do not leave a half-initialised module registered.
                if module.__name__ in sys.modules:
                    del sys.modules[module.__name__]
                raise
            # A module executed manually (module_from_spec + exec_module) is
            # not necessarily registered in sys.modules; return the module
            # object itself instead of raising KeyError after a successful run.
            return sys.modules.get(module.__name__, module)

        def exec_module(self, module):
            """Run *module* from the rewrite cache if present, else via the base hook."""
            if module.__name__ in self.modules:
                self.load_module(module)
            else:
                super(AssertionRewritingHook, self).exec_module(module)


def _rewrite_test(config, name):
    """Read module *name* through ``importer``, rewrite its asserts and compile it.

    :param config: pytest config; used to fetch the assertion state (for
        tracing) and passed on to ``rewrite.rewrite_asserts``.
    :param name: dotted module name as understood by ``importer``.
    :returns: the compiled code object, or None when the importer has no
        source for *name* or the source fails to parse or compile (the error
        is left to pop up again in the real import).
    """
    state = _get_state(config)

    source = importer.get_source(name)
    if source is None:
        return None

    path = importer.get_filename(name)

    try:
        tree = ast.parse(source, filename=path)
    except SyntaxError:
        # Let this pop up again in the real import.
        state.trace("failed to parse: %r" % (path,))
        return None
    # NOTE(review): these positional arguments fit a
    # ``rewrite_asserts(mod, module_path, config)`` signature; some pytest
    # releases take ``(mod, source, module_path, config)`` instead -- confirm
    # against the pytest version this module is built with.
    rewrite.rewrite_asserts(tree, py.path.local(path), config)
    try:
        co = compile(tree, path, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (path,))
        return None
    return co