aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pure-eval/pure_eval/utils.py
blob: a8a37302daa000d736ef9d935813892a725cbede (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from collections import OrderedDict, deque
from datetime import date, time, datetime
from decimal import Decimal
from fractions import Fraction
import ast
import enum
import typing


class CannotEval(Exception):
    """
    Exception whose ``repr()`` and ``str()`` are simply the class name.

    Any constructor arguments are ignored when the exception is rendered.
    """

    def __repr__(self):
        return type(self).__name__

    __str__ = __repr__


def is_any(x, *args):
    """
    Return True if ``x`` is the very same object (identity, not equality)
    as at least one of ``args``; return False otherwise, including when
    no ``args`` are given.
    """
    for candidate in args:
        if x is candidate:
            return True
    return False


def of_type(x, *types):
    """
    Return ``x`` unchanged if its exact type is one of ``types``.

    The comparison is by identity on ``type(x)``, so instances of
    subclasses do not qualify.

    Raises:
        CannotEval: if ``type(x)`` is none of ``types``.
    """
    if not is_any(type(x), *types):
        raise CannotEval
    return x


def of_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return ``x`` unchanged if ``is_standard_types`` accepts it with the
    given keyword options.

    Raises:
        CannotEval: if ``is_standard_types`` returns a falsy result.
    """
    accepted = is_standard_types(x, check_dict_values=check_dict_values, deep=deep)
    if not accepted:
        raise CannotEval
    return x


def is_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Report whether ``_is_standard_types_deep`` considers ``x`` to be made
    of standard types.

    Only the boolean part of the helper's ``(flag, length)`` result is
    returned. A ``RecursionError`` raised while walking ``x`` is treated
    as "not standard" and yields False.
    """
    try:
        verdict, _ = _is_standard_types_deep(x, check_dict_values, deep)
    except RecursionError:
        return False
    return verdict


def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
    """
    Check whether ``x`` has one of a fixed set of exact builtin/stdlib types.

    Types are matched by identity on ``type(x)``, so subclasses are rejected.

    Args:
        x: the object to inspect.
        check_dict_values: when true (and ``deep`` is true), the values of
            ``dict``/``OrderedDict`` objects are inspected as well as the keys;
            otherwise only the keys are.
        deep: when true, the contents of containers (and the three fields of
            a ``slice``) are inspected recursively.

    Returns:
        A ``(is_standard, length)`` tuple. ``length`` is 0 for scalars and
        slices, ``len(x)`` for containers, plus the lengths reported for any
        nested items that were visited.
    """
    typ = type(x)

    scalar_types = (
        str,
        int,
        bool,
        float,
        bytes,
        complex,
        date,
        time,
        datetime,
        Fraction,
        Decimal,
        type(None),
        object,
    )
    if is_any(typ, *scalar_types):
        return True, 0

    container_types = (tuple, frozenset, list, set, dict, OrderedDict, deque, slice)
    if not is_any(typ, *container_types):
        return False, 0

    # A slice has no len(); it contributes nothing to the running size.
    length = 0 if typ is slice else len(x)
    assert isinstance(deep, bool)
    if not deep:
        return True, length

    if typ is slice:
        items = [x.start, x.stop, x.step]
    elif check_dict_values and (typ is dict or typ is OrderedDict):
        # Flatten (key, value) pairs so that both are inspected.
        items = (v for pair in x.items() for v in pair)
    else:
        items = x

    for item in items:
        # Stop as soon as the accumulated size passes the cap while there
        # is still something left to inspect.
        if length > 100000:
            return False, length
        item_ok, item_length = _is_standard_types_deep(
            item, check_dict_values, deep
        )
        if not item_ok:
            return False, length
        length += item_length
    return True, length


class _E(enum.Enum):
    pass


class _C:
    def foo(self): pass  # pragma: nocover

    def bar(self): pass  # pragma: nocover

    @classmethod
    def cm(cls): pass  # pragma: nocover

    @staticmethod
    def sm(): pass  # pragma: nocover


# Sample objects keyed by their expected name. ``safe_name_types`` is built
# from the *types* of these values, so each entry stands for a kind of object.
safe_name_samples = {
    # Builtin function.
    "len": len,
    # Methods of a builtin type, accessed on the class.
    "append": list.append,
    "__add__": list.__add__,
    # Methods of a builtin type, accessed on an instance.
    "insert": list().insert,
    "__mul__": list().__mul__,
    # Raw class-dict entry of a builtin type (no descriptor binding).
    "fromkeys": vars(dict)["fromkeys"],
    # Plain Python function and a function defined in a class body.
    "is_any": is_any,
    "__repr__": CannotEval.__repr__,
    # Methods of a user-defined class: bound, via the class, class/static.
    "foo": _C().foo,
    "bar": _C.bar,
    "cm": _C.cm,
    "sm": _C.sm,
    # Module.
    "ast": ast,
    # Ordinary class and enum class.
    "CannotEval": CannotEval,
    "_E": _E,
}

# A handful of ``typing`` generics, keyed by attribute name, whose types are
# collected into ``typing_annotation_types``.
typing_annotation_samples = {
    name: getattr(typing, name)
    for name in ("List", "Dict", "Tuple", "Set", "Callable", "Mapping")
}

# Distinct types of the sample objects above, as a tuple.
safe_name_types = tuple(set(map(type, safe_name_samples.values())))


# Distinct types of the sampled ``typing`` generics, as a tuple.
typing_annotation_types = tuple(set(map(type, typing_annotation_samples.values())))


def eq_checking_types(a, b):
    """
    Compare ``a`` and ``b`` for equality, but only if they have exactly the
    same type; values of different types are never considered equal
    (so ``1`` and ``1.0``, or ``True`` and ``1``, give False).
    """
    if type(a) is not type(b):
        return False
    return a == b


def ast_name(node):
    """
    Return the identifier carried by an AST node: ``id`` for an
    ``ast.Name``, ``attr`` for an ``ast.Attribute``, and None for
    anything else.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None


def safe_name(value):
    """
    Return a name for ``value`` based on its exact type, or None.

    - If ``type(value)`` is one of ``safe_name_types``, return
      ``value.__name__``.
    - ``typing.Optional`` and ``typing.Union`` map to the fixed strings
      ``"Optional"`` and ``"Union"``.
    - If ``type(value)`` is one of ``typing_annotation_types``, return the
      ``__name__`` attribute if present and truthy, else ``_name`` (which
      may itself be missing, giving None).
    - Otherwise return None.
    """
    typ = type(value)
    if is_any(typ, *safe_name_types):
        return value.__name__

    for special, label in ((typing.Optional, "Optional"), (typing.Union, "Union")):
        if value is special:
            return label

    if not is_any(typ, *typing_annotation_types):
        return None

    name = getattr(value, "__name__", None)
    return name if name else getattr(value, "_name", None)


def has_ast_name(value, node):
    """
    Report whether the name ``safe_name`` gives for ``value`` matches the
    identifier ``ast_name`` extracts from ``node``.

    False is returned if ``safe_name`` does not produce an exact ``str``.
    """
    expected = safe_name(value)
    return type(expected) is str and eq_checking_types(ast_name(node), expected)


def copy_ast_without_context(x):
    """
    Recursively rebuild an AST, leaving out every ``ctx`` field.

    - ``ast.AST`` nodes are re-created via their own class from copies of
      their ``_fields`` (skipping ``ctx`` and any field not set on the node).
    - Lists are copied element by element.
    - Any other value is returned as is.
    """
    if isinstance(x, ast.AST):
        kwargs = {}
        for field in x._fields:
            if field == 'ctx' or not hasattr(x, field):
                continue
            kwargs[field] = copy_ast_without_context(getattr(x, field))
        return type(x)(**kwargs)

    if isinstance(x, list):
        return [copy_ast_without_context(element) for element in x]

    return x


def ensure_dict(x):
    """
    Build a new dict from ``x``; fall back to an empty dict when ``dict(x)``
    raises any ``Exception`` (i.e. for invalid, non-dict-like inputs).
    """
    try:
        result = dict(x)
    except Exception:
        # Deliberate best-effort: bad input simply becomes an empty mapping.
        result = {}
    return result