1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
|
from collections import OrderedDict, deque
from datetime import date, time, datetime
from decimal import Decimal
from fractions import Fraction
import ast
import enum
import typing
class CannotEval(Exception):
    """
    Raised when a value fails one of the checks in this module
    (for example ``of_type`` or ``of_standard_types``).

    Both ``repr()`` and ``str()`` of an instance are just the class name,
    regardless of any constructor arguments.
    """

    def __repr__(self):
        cls = self.__class__
        return cls.__name__

    # str() deliberately shares the repr() implementation.
    __str__ = __repr__
def is_any(x, *args):
    """
    Return True if ``x`` is the very same object (``is``) as one of ``args``.

    Equality is never consulted, only identity. With no ``args`` the
    result is False.
    """
    for candidate in args:
        if x is candidate:
            return True
    return False
def of_type(x, *types):
    """
    Return ``x`` unchanged if its exact type is one of ``types``.

    The comparison is by identity on ``type(x)``, so instances of
    subclasses do not match.

    Raises:
        CannotEval: if ``type(x)`` is none of ``types``.
    """
    if not is_any(type(x), *types):
        raise CannotEval
    return x
def of_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Return ``x`` unchanged if ``is_standard_types`` accepts it.

    ``check_dict_values`` and ``deep`` are forwarded as-is to
    ``is_standard_types``.

    Raises:
        CannotEval: if ``is_standard_types`` returns a falsy result.
    """
    accepted = is_standard_types(x, check_dict_values=check_dict_values, deep=deep)
    if not accepted:
        raise CannotEval
    return x
def is_standard_types(x, *, check_dict_values: bool, deep: bool):
    """
    Tell whether ``x`` passes the ``_is_standard_types_deep`` check.

    Only the boolean part of the helper's ``(is_standard, length)`` result
    is returned. If the recursive walk exceeds the interpreter's recursion
    limit, the value is treated as non-standard and False is returned.
    """
    try:
        is_standard, _ = _is_standard_types_deep(x, check_dict_values, deep)
    except RecursionError:
        return False
    return is_standard
def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
    """
    Check whether ``x`` is built only from a fixed set of standard types.

    Types are compared by identity on ``type(x)``, so subclasses never match.

    Args:
        x: the value to inspect.
        check_dict_values: when true, both keys and values of a ``dict`` /
            ``OrderedDict`` are inspected; otherwise only the keys are.
        deep: when false, a container is accepted without looking at its
            elements.

    Returns:
        A pair ``(is_standard, length)``. ``length`` is 0 for scalars, slices
        and rejected types; for containers it is ``len(x)`` plus the lengths
        accumulated from the nested elements visited so far.
    """
    kind = type(x)

    scalar_types = (
        str,
        int,
        bool,
        float,
        bytes,
        complex,
        date,
        time,
        datetime,
        Fraction,
        Decimal,
        type(None),
        object,
    )
    if any(kind is scalar for scalar in scalar_types):
        return True, 0

    container_types = (tuple, frozenset, list, set, dict, OrderedDict, deque, slice)
    if not any(kind is container for container in container_types):
        return False, 0

    # A slice has no len(); its three parts are only visited when ``deep``.
    total = 0 if kind is slice else len(x)

    assert isinstance(deep, bool)
    if not deep:
        return True, total

    if check_dict_values and (kind is dict or kind is OrderedDict):
        # Flatten (key, value) pairs so keys and values are both inspected.
        children = (member for entry in x.items() for member in entry)
    elif kind is slice:
        children = [x.start, x.stop, x.step]
    else:
        children = x

    for child in children:
        # Give up on structures whose running element count is too large.
        if total > 100000:
            return False, total
        child_ok, child_total = _is_standard_types_deep(
            child, check_dict_values, deep
        )
        if not child_ok:
            return False, total
        total += child_total

    return True, total
class _E(enum.Enum):
    """Member-less enum, used in ``safe_name_samples`` as a sample Enum class."""
class _C:
    """
    Sample class whose members provide one object of each method flavour
    (instance method, class method, static method) for ``safe_name_samples``.
    The methods do nothing.
    """

    def foo(self):  # pragma: nocover
        pass

    def bar(self):  # pragma: nocover
        pass

    @classmethod
    def cm(cls):  # pragma: nocover
        pass

    @staticmethod
    def sm():  # pragma: nocover
        pass
# One representative object per kind of value whose ``__name__`` may be read
# by ``safe_name``. Only the *types* of these values matter; they are gathered
# into ``safe_name_types`` below. The keys are simply each sample's name.
safe_name_samples = {
    # built-in callables and the various descriptor / wrapper flavours
    "len": len,
    "append": list.append,
    "__add__": list.__add__,
    "insert": [].insert,
    "__mul__": [].__mul__,
    "fromkeys": dict.__dict__['fromkeys'],
    # functions written in Python
    "is_any": is_any,
    "__repr__": CannotEval.__repr__,
    # attributes of a Python class, looked up on an instance and on the class
    "foo": _C().foo,
    "bar": _C.bar,
    "cm": _C.cm,
    "sm": _C.sm,
    # a module, a plain class and an Enum class
    "ast": ast,
    "CannotEval": CannotEval,
    "_E": _E,
}

# A handful of ``typing`` aliases, used only to learn which types such
# annotation objects have (see ``typing_annotation_types``).
typing_annotation_samples = {
    alias: getattr(typing, alias)
    for alias in "List Dict Tuple Set Callable Mapping".split()
}

# De-duplicated exact types of the samples above, as tuples so they can be
# splatted into ``is_any``.
safe_name_types = tuple(set(map(type, safe_name_samples.values())))

typing_annotation_types = tuple(set(map(type, typing_annotation_samples.values())))
def eq_checking_types(a, b):
    """
    Compare ``a`` and ``b`` with ``==``, but only if they have the exact same type.

    Returns False when ``type(a)`` is not ``type(b)`` (so ``1`` and ``1.0``,
    or ``True`` and ``1``, are unequal here); otherwise returns the result
    of ``a == b``.
    """
    if type(a) is not type(b):
        return False
    return a == b
def ast_name(node):
    """
    Return the identifier carried by an AST node, if it has one.

    Returns ``node.id`` for an ``ast.Name``, ``node.attr`` for an
    ``ast.Attribute``, and None for anything else.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None
def safe_name(value):
    """
    Return a name for ``value``, or None if it is not a recognised kind.

    * If the exact type of ``value`` is one of ``safe_name_types``, its
      ``__name__`` is returned.
    * ``typing.Optional`` and ``typing.Union`` map to ``"Optional"`` and
      ``"Union"``.
    * If the exact type is one of ``typing_annotation_types``, the
      ``__name__`` attribute is returned when present and truthy, otherwise
      ``_name`` (or None when that is missing too).
    * Anything else yields None.
    """
    kind = type(value)
    if is_any(kind, *safe_name_types):
        return value.__name__
    if value is typing.Optional:
        return "Optional"
    if value is typing.Union:
        return "Union"
    if is_any(kind, *typing_annotation_types):
        public_name = getattr(value, "__name__", None)
        return public_name or getattr(value, "_name", None)
    return None
def has_ast_name(value, node):
    """
    Tell whether ``node`` names ``value``.

    The name obtained from ``safe_name(value)`` must be exactly a ``str``;
    it is then compared with ``ast_name(node)`` via ``eq_checking_types``.
    Returns False when ``safe_name`` gives anything other than a ``str``.
    """
    expected = safe_name(value)
    if type(expected) is str:
        return eq_checking_types(ast_name(node), expected)
    return False
def copy_ast_without_context(x):
    """
    Return a recursive copy of an AST fragment with every ``ctx`` field unset.

    * An ``ast.AST`` node is rebuilt as a new node of the same class from the
      fields listed in ``_fields`` that are actually set, except ``ctx``.
      Only ``_fields`` are consulted, so nothing outside them is carried over.
    * A list is copied element by element.
    * Any other value is returned as-is.

    The input is never modified.
    """
    if isinstance(x, ast.AST):
        kwargs = {
            field: copy_ast_without_context(getattr(x, field))
            for field in x._fields
            if field != 'ctx'
            if hasattr(x, field)
        }
        node = type(x)(**kwargs)
        # Python 3.13+ node constructors fill an omitted ``ctx`` with
        # ``Load()``; remove it so the copy really has no context.
        if hasattr(node, 'ctx'):
            del node.ctx
        return node
    elif isinstance(x, list):
        return list(map(copy_ast_without_context, x))
    else:
        return x
def ensure_dict(x):
    """
    Convert ``x`` to a new dict, tolerating invalid non-dict inputs.

    Returns ``dict(x)`` when that call succeeds, and an empty dict when it
    raises any ``Exception``.
    """
    try:
        result = dict(x)
    except Exception:
        # Deliberate best effort: whatever dict() rejects becomes {}.
        result = {}
    return result
|