1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
|
from __future__ import absolute_import
from __future__ import print_function
import ast
import py
from _pytest.assertion import rewrite
try:
import importlib.util
except ImportError:
pass
try:
from pathlib import Path
except ImportError:
pass
from __res import importer
import sys
import six
def _get_state(config):
    """Return the assertion-rewriting state associated with *config*.

    If *config* carries an ``_assertstate`` attribute, that object is returned
    as-is (presumably the layout used by older pytest releases). Otherwise the
    state is looked up in ``config._store`` under ``rewrite.assertstate_key``.
    """
    if not hasattr(config, '_assertstate'):
        return config._store[rewrite.assertstate_key]
    return config._assertstate
class AssertionRewritingHook(rewrite.AssertionRewritingHook):
    """pytest assertion-rewriting import hook backed by ``__res.importer``.

    Test modules (``__tests__.*``), ``*.conftest`` modules and modules marked
    for rewrite are read from the resource ``importer`` instead of the
    filesystem. They are rewritten by ``_rewrite_test`` and executed from
    the resulting code object.
    """

    def __init__(self, *args, **kwargs):
        # module name -> (rewritten code object, None). Filled by
        # ``_find_module``; the entry is popped when the module is loaded.
        self.modules = {}
        super(AssertionRewritingHook, self).__init__(*args, **kwargs)

    def find_module(self, name, path=None):
        """Legacy (PEP 302) finder entry point.

        :param name: dotted name of the module being imported.
        :param path: parent package search path, or ``None`` for top-level
            modules.
        :returns: ``self`` (acting as the loader) if *name* was rewritten,
            otherwise ``None``.
        """
        co = self._find_module(name, path)
        if co is not None:
            return self
        return None

    def _find_module(self, name, path=None):
        """Rewrite module *name* if it qualifies and cache its code object.

        :param name: dotted name of the module being imported.
        :param path: parent package search path as passed to the finder;
            ``None`` (or empty) for top-level modules.
        :returns: the rewritten code object, or ``None`` in these cases:
            *name* should not be rewritten, is a package, or is unknown to
            ``importer``; or its source fails to parse or compile.
        """
        state = _get_state(self.config)
        if not self._should_rewrite(name, None, state):
            return None
        state.trace("find_module called for: %s" % name)

        # Packages are declined (left to the other finders). An ImportError
        # from the resource importer is treated the same way.
        try:
            if self.is_package(name):
                return None
        except ImportError:
            return None

        # ``_rewritten_names`` is a set or a dict depending on the pytest
        # version in use; support both.
        if hasattr(self._rewritten_names, 'add'):
            self._rewritten_names.add(name)
        elif path:
            self._rewritten_names[name] = Path(path[0])
        else:
            # Top-level modules reach the finder with ``path=None``, so there
            # is no parent directory to record. Indexing ``path`` here would
            # raise TypeError and abort the import.
            # NOTE(review): assumes pytest only checks membership (keys) of
            # this dict — confirm for the pinned pytest version.
            self._rewritten_names[name] = None

        state.trace("rewriting %s" % name)
        co = _rewrite_test(self.config, name)
        if co is None:
            # Probably a SyntaxError in the test.
            return None
        self.modules[name] = co, None
        return co

    def find_spec(self, name, path=None, target=None):
        """PEP 451 finder entry point.

        :returns: a module spec whose loader is ``self`` and whose origin is
            the rewritten code's filename if *name* was rewritten, otherwise
            ``None``.
        """
        co = self._find_module(name, path)
        if co is None:
            return None
        return importlib.util.spec_from_file_location(
            name,
            co.co_filename,
            loader=self,
        )

    def _should_rewrite(self, name, fn, state):
        """Decide whether module *name* gets its asserts rewritten.

        Test modules under ``__tests__.`` and ``*.conftest`` modules always
        qualify; anything else defers to the base class's marked-for-rewrite
        check. *fn* is accepted for signature compatibility and is unused.
        """
        if name.startswith("__tests__.") or name.endswith(".conftest"):
            return True
        return self._is_marked_for_rewrite(name, state)

    def is_package(self, name):
        """Delegate to the resource importer."""
        return importer.is_package(name)

    def get_source(self, name):
        """Delegate to the resource importer."""
        return importer.get_source(name)

    # These loader methods are defined only on Python 3; on Python 2 the
    # base class's implementations (if any) are used instead.
    if six.PY3:
        def load_module(self, module):
            """Execute the cached rewritten code inside *module*.

            :param module: module object whose name has an entry in
                ``self.modules`` (a KeyError is raised otherwise).
            :returns: ``sys.modules[module.__name__]``.
            """
            co, _ = self.modules.pop(module.__name__)
            try:
                module.__file__ = co.co_filename
                module.__cached__ = None
                module.__loader__ = self
                module.__spec__ = importlib.util.spec_from_file_location(
                    module.__name__, co.co_filename, loader=self)
                exec(co, module.__dict__)
            except:  # noqa
                # Deliberately bare: drop the half-initialised module for
                # every exception type, then re-raise it unchanged.
                if module.__name__ in sys.modules:
                    del sys.modules[module.__name__]
                raise
            # This method never inserts *module* into ``sys.modules`` itself;
            # it relies on the caller having done so.
            return sys.modules[module.__name__]

        def exec_module(self, module):
            """Run rewritten modules from their cached code; defer the rest."""
            if module.__name__ in self.modules:
                self.load_module(module)
            else:
                super(AssertionRewritingHook, self).exec_module(module)
def _rewrite_test(config, name):
    """Read module *name* from the resource importer, rewrite its asserts, compile it.

    :param config: pytest config. It is used to look up the assertion state
        and is passed through to ``rewrite.rewrite_asserts``.
    :param name: dotted module name to look up via ``importer``.
    :returns: the compiled code object, or ``None`` in these cases:
        ``importer`` has no source for *name*, the source fails to parse, or
        the rewritten tree fails to compile.
    """
    state = _get_state(config)
    source = importer.get_source(name)
    if source is None:
        return None
    path = importer.get_filename(name)
    try:
        tree = ast.parse(source, filename=path)
    except SyntaxError:
        # Let this pop up again in the real import: returning None makes the
        # hook decline the module.
        state.trace("failed to parse: %r" % (path,))
        return None
    rewrite.rewrite_asserts(tree, py.path.local(path), config)
    try:
        # dont_inherit=True: this file's own ``from __future__`` flags must
        # not leak into the compiled test module.
        co = compile(tree, path, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (path,))
        return None
    return co
|