aboutsummaryrefslogtreecommitdiffstats
path: root/library/python/pytest/rewrite.py
blob: 38e80ebf5d023b75c836324e104a7930d3afd66b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import absolute_import
from __future__ import print_function

import ast

import py

from _pytest.assertion import rewrite

try:
    import importlib.util
except ImportError:
    pass

try:
    from pathlib import Path
except ImportError:
    pass

from __res import importer
import sys
import six


def _get_state(config):
    if hasattr(config, '_assertstate'):
        return config._assertstate
    return config._store[rewrite.assertstate_key]


class AssertionRewritingHook(rewrite.AssertionRewritingHook):
    """pytest assertion-rewriting import hook backed by ``__res.importer``.

    Module sources are obtained from ``importer`` rather than from the
    filesystem. Modules selected by ``_should_rewrite`` are rewritten and
    compiled eagerly during lookup, and the resulting code objects are parked
    in ``self.modules`` until the module is loaded/executed.
    """

    def __init__(self, *args, **kwargs):
        # Maps module name -> (code object, None); filled by _find_module and
        # drained by load_module. The second slot is always None here.
        self.modules = {}
        super(AssertionRewritingHook, self).__init__(*args, **kwargs)

    def find_module(self, name, path=None):
        """Legacy finder hook: return this hook if *name* was rewritten, else None."""
        co = self._find_module(name, path)
        if co is not None:
            return self

    def _find_module(self, name, path=None):
        """Rewrite module *name* if it is selected and return its code object.

        :param name: dotted module name being imported.
        :param path: the parent package's search path as passed by the import
            system; None for top-level modules.
        :return: the rewritten code object (also cached in ``self.modules``),
            or None if the module is not selected for rewriting, is a package,
            ``is_package`` raises ImportError, or the source cannot be
            obtained, parsed or compiled.
        """
        state = _get_state(self.config)
        if not self._should_rewrite(name, None, state):
            return None
        state.trace("find_module called for: %s" % name)

        try:
            if self.is_package(name):
                return None
        except ImportError:
            return None

        rewritten = self._rewritten_names
        if hasattr(rewritten, 'add'):
            # Set-based bookkeeping: only the name is recorded.
            rewritten.add(name)
        else:
            # Mapping-based bookkeeping. `path` is None (or may be empty) for
            # top-level modules, e.g. ones marked via register_assert_rewrite,
            # so it must not be indexed unconditionally.
            # NOTE(review): pytest appears to consult only the keys of this
            # mapping; confirm against the pinned pytest version.
            rewritten[name] = Path(path[0]) if path else None

        state.trace("rewriting %s" % name)
        co = _rewrite_test(self.config, name)
        if co is None:
            # Probably a SyntaxError in the test.
            return None
        self.modules[name] = co, None
        return co

    def find_spec(self, name, path=None, target=None):
        """PEP 451 finder hook: return a spec loaded by this hook, or None."""
        co = self._find_module(name, path)
        if co is not None:
            return importlib.util.spec_from_file_location(
                name,
                co.co_filename,
                loader=self,
            )

    def _should_rewrite(self, name, fn, state):
        # Modules under ``__tests__`` and every ``conftest`` are always
        # rewritten; anything else defers to the base class's
        # ``_is_marked_for_rewrite``. ``fn`` is unused here.
        if name.startswith("__tests__.") or name.endswith(".conftest"):
            return True

        return self._is_marked_for_rewrite(name, state)

    def is_package(self, name):
        """Delegate package detection to ``importer``."""
        return importer.is_package(name)

    def get_source(self, name):
        """Delegate source lookup to ``importer``."""
        return importer.get_source(name)

    if six.PY3:

        def load_module(self, module):
            """Execute the cached rewritten code for *module* (a module object).

            Unlike the legacy ``load_module(fullname)`` protocol, this takes the
            module object; in this file it is called from ``exec_module``.
            """
            co, _ = self.modules.pop(module.__name__)
            try:
                module.__file__ = co.co_filename
                module.__cached__ = None
                module.__loader__ = self
                module.__spec__ = importlib.util.spec_from_file_location(module.__name__, co.co_filename, loader=self)
                exec(co, module.__dict__)
            except:  # noqa
                # Don't leave a half-initialized module behind in sys.modules.
                if module.__name__ in sys.modules:
                    del sys.modules[module.__name__]
                raise
            return sys.modules[module.__name__]

        def exec_module(self, module):
            """Run *module* from the rewritten code cached during lookup, or
            fall back to the base implementation when nothing is cached."""
            if module.__name__ in self.modules:
                self.load_module(module)
            else:
                super(AssertionRewritingHook, self).exec_module(module)


def _rewrite_test(config, name):
    """Read the source of module *name*, rewrite its asserts and compile it.

    The source text and file name are looked up through ``importer`` (from
    ``__res``) rather than on the filesystem.

    :param config: the pytest config; used to fetch the assertion-rewrite state
        and passed through to ``rewrite.rewrite_asserts``.
    :param name: dotted name of the module to rewrite.
    :return: the compiled code object, or None when ``importer`` has no source
        for *name* or the (rewritten) source fails to parse or compile.
    """
    state = _get_state(config)

    source = importer.get_source(name)
    if source is None:
        return None

    path = importer.get_filename(name)

    try:
        tree = ast.parse(source, filename=path)
    except SyntaxError:
        # Let this pop up again in the real import.
        state.trace("failed to parse: %r" % (path,))
        return None
    # NOTE(review): these positional arguments fit a
    # ``rewrite_asserts(mod, module_path, config)`` signature; pytest releases
    # whose ``rewrite_asserts`` takes the source as its second argument would
    # bind them differently. Confirm against the pinned pytest version.
    rewrite.rewrite_asserts(tree, py.path.local(path), config)
    try:
        co = compile(tree, path, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (path,))
        return None
    return co