aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/parso/py3/tests/test_parser_tree.py
blob: b994b9bbb88d2d30c02d639be98249c9d30bdb60 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
# -*- coding: utf-8    # This file contains Unicode characters.

from textwrap import dedent

import pytest

from parso import parse
from parso.python import tree
from parso.tree import search_ancestor


class TestsFunctionAndLambdaParsing:
    """Attribute checks shared by a ``def`` node and a ``lambda`` node.

    Each ``FIXTURES`` entry is ``(source, expected)``. ``expected`` holds the
    values the parsed node should report. A key that is missing falls back to
    the default used by the individual test (for example: no annotation, not
    a generator, no yields).

    NOTE(review): ``call_sig`` is not checked by any test in this class.
    """

    FIXTURES = [
        ('def my_function(x, y, z) -> str:\n    return x + y * z\n', {
            'name': 'my_function',
            'call_sig': 'my_function(x, y, z)',
            'params': ['x', 'y', 'z'],
            'annotation': "str",
        }),
        ('lambda x, y, z: x + y * z\n', {
            'name': '<lambda>',
            'call_sig': '<lambda>(x, y, z)',
            'params': ['x', 'y', 'z'],
        }),
    ]

    @pytest.fixture(params=FIXTURES)
    def node(self, request):
        """Parse one fixture source and return its funcdef/lambdef node."""
        parsed = parse(dedent(request.param[0]), version='3.10')
        # Hand the expectations to the ``expected`` fixture, which reads them
        # back from the keywords of the same test item.
        request.keywords['expected'] = request.param[1]
        child = parsed.children[0]
        # The lambda fixture is an expression statement, so its node sits
        # one level down inside a simple_stmt.
        if child.type == 'simple_stmt':
            child = child.children[0]
        return child

    @pytest.fixture()
    def expected(self, request, node):
        """Return the expectations dict stored by the ``node`` fixture.

        Requesting ``node`` guarantees it has already run and populated
        ``request.keywords['expected']``.
        """
        return request.keywords['expected']

    def test_name(self, node, expected):
        """A ``def`` node exposes its name as a ``tree.Name`` leaf."""
        # Lambdas are anonymous, so only non-lambda nodes have a name to check.
        if node.type != 'lambdef':
            assert isinstance(node.name, tree.Name)
            assert node.name.value == expected['name']

    def test_params(self, node, expected):
        """``get_params`` returns a list of ``tree.Param`` in source order."""
        params = node.get_params()
        assert isinstance(params, list)
        assert all(isinstance(x, tree.Param) for x in params)
        assert [str(x.name.value) for x in params] == list(expected['params'])

    def test_is_generator(self, node, expected):
        """``is_generator`` reports whether the body contains a yield."""
        assert node.is_generator() is expected.get('is_generator', False)

    def test_yields(self, node, expected):
        """``iter_yield_exprs`` finds yields only when the fixture expects them.

        This test used to call ``is_generator()`` again, which duplicated
        ``test_is_generator``. It now exercises the yield iterator directly.
        """
        yield_exprs = list(node.iter_yield_exprs())
        assert bool(yield_exprs) == expected.get('yields', False)

    def test_annotation(self, node, expected):
        """The return annotation matches the fixture, or is absent."""
        expected_annotation = expected.get('annotation', None)
        if expected_annotation is None:
            assert node.annotation is None
        else:
            assert node.annotation.value == expected_annotation


def test_end_pos_line(each_version):
    """Call expressions end after their closing paren, whitespace included.

    Regression test for jedi issue #150.
    """
    source = "x()\nx( )\nx(  )\nx (  )\n"

    statements = parse(source, version=each_version).children[:-1]
    for line, stmt in enumerate(statements, start=1):
        call = stmt.children[0]
        # Each successive line adds one more space, so the end column
        # grows in step with the line number.
        assert call.end_pos == (line, line + 2)


def test_default_param(each_version):
    """A parameter with only a default has no annotation and no star."""
    source = 'def x(foo=42): pass'
    funcdef = parse(source, version=each_version).children[0]
    [param] = funcdef.get_params()

    assert param.default.value == '42'
    assert param.annotation is None
    assert not param.star_count


def test_annotation_param(each_version):
    """A parameter with only an annotation has no default and no star."""
    source = 'def x(foo: 3): pass'
    funcdef = parse(source, version=each_version).children[0]
    [param] = funcdef.get_params()

    assert param.default is None
    assert param.annotation.value == '3'
    assert not param.star_count


def test_annotation_params(each_version):
    """Each of several annotated parameters keeps its own annotation."""
    source = 'def x(foo: 3, bar: 4): pass'
    funcdef = parse(source, version=each_version).children[0]
    first, second = funcdef.get_params()

    for param, annotation in ((first, '3'), (second, '4')):
        assert param.default is None
        assert param.annotation.value == annotation
        assert not param.star_count


def test_default_and_annotation_param(each_version):
    """A parameter can carry both an annotation and a default value."""
    source = 'def x(foo:3=42): pass'
    funcdef = parse(source, version=each_version).children[0]
    [param] = funcdef.get_params()

    assert param.default.value == '42'
    assert param.annotation.value == '3'
    assert not param.star_count


def get_yield_exprs(code, version):
    """Parse *code* with grammar *version* and list the yields of its first child."""
    module = parse(code, version=version)
    first_child = module.children[0]
    return list(first_child.iter_yield_exprs())


def get_return_stmts(code, version=None):
    """Parse *code* and list the return statements of its first child.

    ``version`` is forwarded to ``parse`` so callers can pin a grammar, as
    ``get_yield_exprs`` already allows. ``None`` keeps the previous behaviour
    of using parso's default grammar.
    """
    return list(parse(code, version=version).children[0].iter_return_stmts())


def get_raise_stmts(code, child, version=None):
    """Parse *code* and list the raise statements of top-level child *child*.

    ``child`` is the index into the module's children. ``version`` is
    forwarded to ``parse`` so callers can pin a grammar, as
    ``get_yield_exprs`` already allows. ``None`` keeps the previous behaviour
    of using parso's default grammar.
    """
    return list(parse(code, version=version).children[child].iter_raise_stmts())


def test_yields(each_version):
    """Bare yields come back as keyword leaves; yields with a value as yield_expr."""
    # Bare ``yield``: the keyword leaf itself is returned.
    [found] = get_yield_exprs('def x(): yield', each_version)
    assert found.value == 'yield'
    assert found.type == 'keyword'

    # ``yield`` with a value: the yield_expr node is returned.
    [found] = get_yield_exprs('def x(): (yield 1)', each_version)
    assert found.type == 'yield_expr'

    # A parenthesised bare ``yield`` inside a list is still the keyword leaf.
    [found] = get_yield_exprs('def x(): [1, (yield)]', each_version)
    assert found.type == 'keyword'


def test_yield_from():
    """``yield from`` is reported as a yield_expr node."""
    yield_exprs = get_yield_exprs('def x(): (yield from 1)', '3.8')
    [only] = yield_exprs
    assert only.type == 'yield_expr'


def test_returns():
    """A bare ``return`` is a keyword leaf; ``return <value>`` is a return_stmt."""
    [found] = get_return_stmts('def x(): return')
    assert found.value == 'return'
    assert found.type == 'keyword'

    [found] = get_return_stmts('def x(): return 1')
    assert found.type == 'return_stmt'


def test_raises():
    """``iter_raise_stmts`` lists a function's own raises, skipping nested defs."""
    code = """
def single_function():
    raise Exception
def top_function():
    def inner_function():
        raise NotImplementedError()
    inner_function()
    raise Exception
def top_function_three():
    try:
        raise NotImplementedError()
    except NotImplementedError:
        pass
    raise Exception
    """

    # A plain function lists its single raise statement.
    raise_stmts = get_raise_stmts(code, 0)
    assert len(raise_stmts) == 1

    # Raises inside a nested function are not listed for the enclosing
    # function, so only its own ``raise Exception`` is reported.
    raise_stmts = get_raise_stmts(code, 1)
    assert len(raise_stmts) == 1

    # Raises inside a try/except body are listed along with the outer one.
    raise_stmts = get_raise_stmts(code, 2)
    assert len(raise_stmts) == 2


@pytest.mark.parametrize(
    'code, name_index, is_definition, include_setitem', [
        ('x = 3', 0, True, False),
        ('x.y = 3', 0, False, False),
        ('x.y = 3', 1, True, False),
        ('x.y = u.v = z', 0, False, False),
        ('x.y = u.v = z', 1, True, False),
        ('x.y = u.v = z', 2, False, False),
        ('x.y = u.v, w = z', 3, True, False),
        ('x.y = u.v, w = z', 4, True, False),
        ('x.y = u.v, w = z', 5, False, False),

        ('x, y = z', 0, True, False),
        ('x, y = z', 1, True, False),
        ('x, y = z', 2, False, False),
        ('x[0], y = z', 2, False, False),
        ('x[0] = z', 0, False, False),
        ('x[0], y = z', 0, False, False),
        ('x[0], y = z', 2, False, True),
        ('x[0] = z', 0, True, True),
        ('x[0], y = z', 0, True, True),
        ('x: int = z', 0, True, False),
        ('x: int = z', 1, False, False),
        ('x: int = z', 2, False, False),
        ('x: int', 0, True, False),
        ('x: int', 1, False, False),
    ]
)
def test_is_definition(code, name_index, is_definition, include_setitem):
    """The ``name_index``-th name leaf in *code* reports the expected definition status.

    ``name_index`` counts only leaves of type ``name``, starting at 0.
    ``include_setitem`` is passed straight through to ``Name.is_definition``.
    """
    module = parse(code, version='3.8')
    name = module.get_first_leaf()
    # Walk the leaves in source order, counting only name leaves, until the
    # requested one is reached.
    while True:
        if name.type == 'name':
            if name_index == 0:
                break
            name_index -= 1
        name = name.get_next_leaf()

    assert name.is_definition(include_setitem=include_setitem) == is_definition


def test_iter_funcdefs():
    """``iter_funcdefs`` yields plain, async and decorated defs, skipping broken ones."""
    source = dedent('''
        def normal(): ...
        async def asyn(): ...
        @dec
        def dec_normal(): ...
        @dec1
        @dec2
        async def dec_async(): ...
        def broken
        ''')
    module = parse(source, version='3.8')

    names = [funcdef.name.value for funcdef in module.iter_funcdefs()]

    assert names == ['normal', 'asyn', 'dec_normal', 'dec_async']


def test_with_stmt_get_test_node_from_name():
    """Each name bound by a ``with`` item maps back to that item's context expression."""
    source = "with A as X.Y, B as (Z), C as Q[0], D as Q['foo']: pass"
    with_stmt = parse(source, version='3').children[0]

    defined_names = with_stmt.get_defined_names(include_setitem=True)
    context_values = [
        with_stmt.get_test_node_from_name(defined).value
        for defined in defined_names
    ]

    assert context_values == ["A", "B", "C", "D"]


# Small shared tree for the search_ancestor cases below:
# module (file_input) -> first child (the ``x + y`` arith_expr) -> ``x`` leaf.
sample_module = parse('x + y')
sample_node = sample_module.children[0]
sample_leaf = sample_module.get_first_leaf()


@pytest.mark.parametrize(
    'node,node_types,expected_ancestor', [
        (sample_module, ('file_input',), None),
        (sample_node, ('arith_expr',), None),
        (sample_node, ('file_input', 'eval_input'), sample_module),
        (sample_leaf, ('name',), None),
        (sample_leaf, ('arith_expr',), sample_node),
        (sample_leaf, ('file_input',), sample_module),
        (sample_leaf, ('file_input', 'arith_expr'), sample_node),
        (sample_leaf, ('shift_expr',), None),
        (sample_leaf, ('name', 'shift_expr',), None),
        (sample_leaf, (), None),
    ]
)
def test_search_ancestor(node, node_types, expected_ancestor):
    """The method and the deprecated free function both return the expected ancestor.

    The node itself is never returned, and an empty type list matches nothing.
    """
    found_by_method = node.search_ancestor(*node_types)
    assert found_by_method is expected_ancestor

    # Deprecated module-level helper; must agree with the method form.
    found_by_function = search_ancestor(node, *node_types)
    assert found_by_function is expected_ancestor