1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
from __future__ import absolute_import
from __future__ import print_function
import ast
import py
from _pytest.assertion import rewrite
try:
import importlib.util
except ImportError:
pass
from __res import importer
import sys
import six
def _get_state(config):
if hasattr(config, '_assertstate'):
return config._assertstate
return config._store[rewrite.assertstate_key]
class AssertionRewritingHook(rewrite.AssertionRewritingHook):
    """Assertion-rewriting import hook backed by ``__res.importer``.

    Module source, filenames and package-ness are obtained through
    ``importer`` rather than the base class's own lookup. Eligible modules are
    rewritten and compiled by ``_rewrite_test``. The resulting code objects are
    cached in ``self.modules`` until the loader executes them.
    """

    def __init__(self, *args, **kwargs):
        # name -> (code object, None). NOTE(review): the 2-tuple shape
        # presumably mirrors what the pytest base class's own loader expects
        # (it is the only path used on Python 2) -- confirm.
        self.modules = {}
        super(AssertionRewritingHook, self).__init__(*args, **kwargs)

    def find_module(self, name, path=None):
        """Legacy (PEP 302) finder entry point.

        Return this hook as the loader when *name* was rewritten and compiled
        successfully, otherwise ``None``.
        """
        co = self._find_module(name, path)
        if co is not None:
            return self
        return None

    def _find_module(self, name, path=None):
        """Rewrite and compile *name* if eligible; return the code object.

        Return ``None`` in any of these cases:

        * the module should not be rewritten;
        * it is a package, or the package check raises ``ImportError``;
        * parsing or compiling fails.

        On success the code object is cached in ``self.modules`` for the
        loader methods. *path* is accepted for finder-API compatibility but is
        not used.
        """
        state = _get_state(self.config)
        if not self._should_rewrite(name, None, state):
            return None
        state.trace("find_module called for: %s" % name)

        try:
            if self.is_package(name):
                return None
        except ImportError:
            return None

        # Recorded before rewriting, so the name is marked even if the
        # rewrite below fails.
        self._rewritten_names.add(name)

        state.trace("rewriting %s" % name)
        co = _rewrite_test(self.config, name)
        if co is None:
            # Probably a SyntaxError in the test.
            return None
        self.modules[name] = co, None
        return co

    def find_spec(self, name, path=None, target=None):
        """PEP 451 finder entry point.

        Return a spec loaded by this hook, or ``None`` if *name* is not
        handled here.
        """
        co = self._find_module(name, path)
        if co is not None:
            return importlib.util.spec_from_file_location(
                name,
                co.co_filename,
                loader=self,
            )
        return None

    def _should_rewrite(self, name, fn, state):
        """Decide whether module *name* gets its asserts rewritten.

        ``__tests__.*`` and ``*.conftest`` modules are always rewritten. All
        other modules defer to the base class's ``_is_marked_for_rewrite``.
        *fn* is not used here; the call in ``_find_module`` passes ``None``.
        """
        if name.startswith("__tests__.") or name.endswith(".conftest"):
            return True

        return self._is_marked_for_rewrite(name, state)

    def is_package(self, name):
        """Delegate the package check for *name* to ``importer``."""
        return importer.is_package(name)

    def get_source(self, name):
        """Delegate source lookup for *name* to ``importer``."""
        return importer.get_source(name)

    if six.PY3:
        def _exec_rewritten_module(self, module):
            """Execute the cached rewritten code for *module* in its namespace.

            Pop the code object from ``self.modules``, set the module's
            import-related attributes, then execute the code. On any failure,
            remove the module from ``sys.modules`` if present and re-raise.
            """
            co, _ = self.modules.pop(module.__name__)
            try:
                module.__file__ = co.co_filename
                module.__cached__ = None
                module.__loader__ = self
                module.__spec__ = importlib.util.spec_from_file_location(
                    module.__name__, co.co_filename, loader=self)
                exec(co, module.__dict__)
            except BaseException:
                # Do not leave a half-initialised module registered.
                if module.__name__ in sys.modules:
                    del sys.modules[module.__name__]
                raise

        def load_module(self, module):
            """Execute the rewritten *module* and return its ``sys.modules`` entry.

            Note that this takes a module object, not a module name. The
            module must already be registered in ``sys.modules``.
            """
            self._exec_rewritten_module(module)
            return sys.modules[module.__name__]

        def exec_module(self, module):
            """PEP 451 loader entry point.

            Execute the rewritten code if ``_find_module`` cached it for this
            module. Otherwise fall back to the base implementation.

            Unlike ``load_module``, this does not look the module up in
            ``sys.modules`` afterwards. Callers may therefore execute a module
            they have not registered there, without a spurious ``KeyError``.
            """
            if module.__name__ in self.modules:
                self._exec_rewritten_module(module)
            else:
                super(AssertionRewritingHook, self).exec_module(module)
def _rewrite_test(config, name):
    """Read module *name* via ``importer``, rewrite its asserts, and compile it.

    Return the compiled code object. Return ``None`` when the importer has no
    source for *name* or when parsing or compiling fails; in the failure case
    the error is traced on the assertion state.

    NOTE(review): ``rewrite.rewrite_asserts`` is called positionally as
    ``(tree, module_path, config)``. Some pytest versions instead take
    ``(mod, source, module_path, config)``. Confirm this call matches the
    pytest version in use.
    """
    state = _get_state(config)

    source = importer.get_source(name)
    if source is None:
        return None
    filename = importer.get_filename(name)

    try:
        module_ast = ast.parse(source, filename=filename)
    except SyntaxError:
        # Let the real import raise the SyntaxError again.
        state.trace("failed to parse: %r" % (filename,))
        return None

    rewrite.rewrite_asserts(module_ast, py.path.local(filename), config)

    try:
        return compile(module_ast, filename, "exec", dont_inherit=True)
    except SyntaxError:
        # The error might come from a bug in the assertion rewriter itself.
        # There is no cheap way to tell, so give up and return None.
        state.trace("failed to compile: %r" % (filename,))
        return None
|