1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
|
from __future__ import absolute_import
import hashlib
import inspect
import os
import re
import sys
from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext
import Cython
from ..Compiler.Main import Context, default_options
from ..Compiler.Visitor import CythonTransform, EnvTransform
from ..Compiler.ParseTreeTransforms import SkipDeclarations
from ..Compiler.TreeFragment import parse_from_strings
from ..Compiler.StringEncoding import _unicode
from .Dependencies import strip_string_literals, cythonize, cached_function
from ..Compiler import Pipeline
from ..Utils import get_cython_cache_dir
import cython as cython_module
IS_PY3 = sys.version_info >= (3,)

# Helper that turns user-supplied ASCII byte strings into unicode text.
if IS_PY3:
    # On Python 3 the value is handed back unchanged.
    to_unicode = lambda x: x
else:
    def to_unicode(s):
        """Decode *s* as ASCII if it is a byte string; otherwise return it as is."""
        return s.decode('ascii') if isinstance(s, bytes) else s
if sys.version_info >= (3, 5):
    import importlib.util
    from importlib.machinery import ExtensionFileLoader

    def load_dynamic(name, path):
        """Load the extension module file at *path* as module *name* and return it."""
        loader = ExtensionFileLoader(name, path)
        module_spec = importlib.util.spec_from_file_location(name, loader=loader)
        loaded = importlib.util.module_from_spec(module_spec)
        module_spec.loader.exec_module(loaded)
        return loaded
else:
    import imp

    def load_dynamic(name, module_path):
        """Load the extension module file at *module_path* as module *name*."""
        return imp.load_dynamic(name, module_path)
class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree transform that collects the names with no entry in the current scope."""

    def __init__(self):
        # Initialised through CythonTransform directly, passing no context.
        CythonTransform.__init__(self, None)
        self.unbound = set()

    def visit_NameNode(self, node):
        """Remember the node's name when the current environment cannot resolve it."""
        env = self.current_env()
        name = node.name
        if not env.lookup(name):
            self.unbound.add(name)
        return node

    def __call__(self, node):
        """Run the transform over *node* and return the set of unbound names."""
        super(UnboundSymbols, self).__call__(node)
        return self.unbound
@cached_function
def unbound_symbols(code, context=None):
    """Return a tuple of the names used in *code* that are neither bound nor builtins.

    The code is parsed from a string and pushed through the 'pyx' compiler
    pipeline up to and including the declaration analysis phase; the names
    that still do not resolve afterwards are collected by UnboundSymbols.

    :param code: source text to analyse (byte strings are decoded first).
    :param context: compiler Context to use; a fresh one with an empty
        include path is created when omitted.
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
    tree = parse_from_strings('(tree fragment)', code)
    for phase in Pipeline.create_pipeline(context, 'pyx'):
        if phase is not None:
            tree = phase(tree)
            # Nothing after declaration analysis is needed for name lookup.
            if isinstance(phase, AnalyseDeclarationsTransform):
                break
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins
    unbound = UnboundSymbols()(tree)
    builtin_names = set(dir(builtins))
    return tuple(unbound - builtin_names)
def unsafe_type(arg, context=None):
    """Return the Cython type name for *arg*, declaring plain ints as ``long``.

    Any value whose exact type is not ``int`` is delegated to safe_type().
    """
    return 'long' if type(arg) is int else safe_type(arg, context)
def safe_type(arg, context=None):
    """Return a Cython type name that can always hold *arg*.

    Exact ``list``/``tuple``/``dict``/``str`` instances map to their own
    name, ``complex`` to ``double complex``, ``float`` to ``double`` and
    ``bool`` to ``bint``.  NumPy arrays (only checked when numpy is already
    imported) map to a typed ``numpy.ndarray`` buffer declaration.  For any
    other value the type's MRO is searched for a class that *context* knows
    as a Cython type; if none is found the result is ``'object'``.

    :param arg: the runtime value to describe.
    :param context: compiler context used to look up extension types.  When
        it is ``None`` no lookup is possible and ``'object'`` is returned.
    :return: the type name as a string.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    else:
        if context is None:
            # Without a context there is nothing to query; previously this
            # path failed with AttributeError on ``None.find_module``.
            return 'object'
        for base_type in py_type.__mro__:
            if base_type.__module__ in ('__builtin__', 'builtins'):
                return 'object'
            module = context.find_module(base_type.__module__, need_pxd=False)
            if module:
                entry = module.lookup(base_type.__name__)
                # Guard against the name not being declared in that module.
                if entry is not None and entry.is_type:
                    return '%s.%s' % (base_type.__module__, base_type.__name__)
        return 'object'
def _get_build_extension():
    """Return a finalized distutils ``build_ext`` command for a new Distribution."""
    dist = Distribution()
    # Parse the distutils configuration files first, so that the build
    # respects the distutils configuration.
    dist.parse_config_files(dist.find_config_files())
    command = build_ext(dist)
    command.finalize_options()
    return command
@cached_function
def _create_context(cython_include_dirs):
    """Return a compiler Context that uses *cython_include_dirs* as include path.

    The argument is expected to be a tuple: cython_inline() converts its
    include directories to one before calling this cached function.
    """
    include_path = list(cython_include_dirs)
    return Context(include_path, default_options)


# Session-wide cache filled by cython_inline(): maps source code to its
# unbound symbols, and (code, arg_sigs, key_hash) to the compiled function.
_cython_inline_cache = {}
# Context used by cython_inline() when no include directories are given.
_cython_inline_default_context = _create_context(('.',))
def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
for symbol in unbound_symbols:
if symbol not in kwds:
if locals is None or globals is None:
calling_frame = inspect.currentframe().f_back.f_back.f_back
if locals is None:
locals = calling_frame.f_locals
if globals is None:
globals = calling_frame.f_globals
if symbol in locals:
kwds[symbol] = locals[symbol]
elif symbol in globals:
kwds[symbol] = globals[symbol]
else:
print("Couldn't find %r" % symbol)
def _inline_key(orig_code, arg_sigs, language_level):
    """Return the SHA-1 hex digest identifying one compiled inline snippet.

    The digest covers the source code, the argument signatures, the running
    interpreter (version and executable), the language level and the Cython
    version.
    """
    key_parts = (orig_code, arg_sigs, sys.version_info, sys.executable,
                 language_level, Cython.__version__)
    digest = hashlib.sha1(_unicode(key_parts).encode('utf-8'))
    return digest.hexdigest()
def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile *code* into an extension module function, call it, return its result.

    The code becomes the body of a generated ``__invoke`` function whose
    parameters are the entries of *kwds* plus any unbound names that could
    be resolved from *locals* / *globals*.  The generated function ends with
    ``return locals()``, so code that does not return by itself yields a
    dict of its local variables.

    :param code: source text to compile and run.
    :param get_type: callable ``(value, context)`` returning the Cython type
        name used to declare each argument; ``None`` declares every argument
        as ``object``.
    :param lib_dir: directory that receives the generated ``.pyx`` file and
        the built module; created if it does not exist.
    :param cython_include_dirs: include directories for the Cython compiler
        (``['.']`` is used for cythonize when empty).
    :param cython_compiler_directives: compiler directives; the mapping is
        copied, never modified in place.
    :param force: rebuild even if the module file already exists.
    :param quiet: suppress the parse-failure message; also passed on to
        cythonize.
    :param locals: namespace searched first for unbound names.
    :param globals: namespace searched second for unbound names.
    :param language_level: Cython language level; ``'3str'`` is used when it
        is given neither here nor in the directives.
    :param kwds: values for the names used by *code*.
    :return: whatever the compiled ``__invoke`` function returns.

    When *locals* / *globals* are omitted they are read from the frame two
    levels up the call stack.
    NOTE(review): that depth presumably accounts for a public wrapper that
    calls this function -- confirm against the callers.
    """
    if get_type is None:
        # get_type is always called with (value, context), so the fallback
        # has to accept the context argument as well.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives) if cython_compiler_directives else {}
    if language_level is None and 'language_level' not in cython_compiler_directives:
        language_level = '3str'
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        args = sorted(kwds.items())
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    # These frame lookups must stay in this function body: the depth is
    # counted from here.
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    # Arguments bound to the cython module itself become cimports instead
    # of function parameters.
    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]
    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for type_name, _ in arg_sigs:
                m = qualified.match(type_name)
                if m:
                    qualifier = m.group(1)
                    cimports.append('\ncimport %s' % qualifier)
                    # one special case
                    if qualifier == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
                        # cflags.append('-Wno-unused')
            # The final "return locals()" is appended to the code before it
            # is split, so it receives exactly the same indentation as the
            # rest of the function body instead of relying on a hard-coded
            # indent in the template below.
            module_body, func_body = extract_func_code(code + '\nreturn locals()')
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
""" % {'cimports': '\n'.join(cimports),
       'module_body': module_body,
       'params': params,
       'func_body': func_body}
            # Put the string literals (and comments) that were stripped for
            # parsing back into the generated source.
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            with open(pyx_file, 'w') as fh:
                fh.write(module_code)
            extension = Extension(
                name=module_name,
                sources=[pyx_file],
                include_dirs=c_include_dirs,
                extra_compile_args=cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)


# Cached extension-module filename suffix used by cython_inline above.
# The None placeholder is replaced with the actual value by the first
# cython_inline invocation that has to locate a module file.
cython_inline.so_ext = None
_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the indentation shared by all code lines of *code*.

    The common indentation is the smallest number of leading spaces found
    on any line that is neither blank nor a comment.  Blank lines and
    comment-only lines do not take part in that computation and are
    returned unchanged.

    :param code: source text, possibly indented as a whole.
    :return: the lines joined with ``'\\n'``, code lines dedented.
    """
    lines = code.splitlines()

    # First pass: find the smallest indentation among real code lines.
    min_indent = None
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or min_indent > indent:
            min_indent = indent

    # Second pass: dedent the code lines.  The comment test uses each
    # line's own first non-space column; reusing the column left over from
    # the first loop misclassified lines.
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split *code* into module-level text and function-body text.

    Tabs are replaced by a space first.  Every unindented line decides
    where it and the indented lines after it go: lines matching
    ``module_statement`` open a module-level section, all others a
    function-body section.  Lines before the first unindented line belong
    to the function body.

    :return: ``(module_text, function_text)``; every line of the function
        text is prefixed with the body indentation.
    """
    module_lines = []
    function_lines = []
    target = function_lines
    for line in code.replace('\t', ' ').split('\n'):
        if not line.startswith(' '):
            target = module_lines if module_statement.match(line) else function_lines
        target.append(line)
    module_text = '\n'.join(module_lines)
    function_text = ' ' + '\n '.join(function_lines)
    return module_text, function_text
try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Map the given call arguments to the parameter names of *func*.

        Fallback for interpreters whose ``inspect`` module has no
        ``getcallargs``.  Returns a dict from parameter name to value and
        raises TypeError for duplicate, unexpected or missing arguments.
        """
        arg_names, varargs_name, varkw_name, default_values = inspect.getargspec(func)
        bound = {}
        if varargs_name is not None:
            # Surplus positional values go to the *args parameter.
            bound[varargs_name] = arg_values[len(arg_names):]
        bound.update(zip(arg_names, arg_values))
        for name in list(kwd_values):
            if name not in arg_names:
                continue
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)
        # Whatever is left in kwd_values matched no named parameter.
        if varkw_name is not None:
            bound[varkw_name] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        if default_values is None:
            default_values = ()
        first_default = len(arg_names) - len(default_values)
        for position, name in enumerate(arg_names):
            if name in bound:
                continue
            if position < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = default_values[position - first_default]
        return bound
def get_body(source):
    """Return the body part of a function's source text.

    The body is everything after the first ``':'`` in *source*.  For source
    that starts with ``lambda`` the expression is turned into a ``return``
    statement, so that it can be used as a function body.

    :param source: source text of a ``def`` function or a lambda.
    :return: the body text.
    :raises ValueError: if *source* contains no ``':'``.
    """
    ix = source.index(':')
    # 'lambda' is six characters long; comparing a five-character slice
    # against it could never match.
    if source.startswith('lambda'):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]
# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs a Python function's body through cython_inline()."""

    def __init__(self, f):
        self._f = f
        source = inspect.getsource(f)
        self._body = get_body(source)

    def __call__(self, *args, **kwds):
        """Bind the call arguments to the parameter names and run the compiled body."""
        call_args = getcallargs(self._f, *args, **kwds)
        # The function's globals serve as both the locals and the globals
        # namespace for symbol lookup.
        if IS_PY3:
            namespace = self._f.__globals__
        else:
            namespace = self._f.func_globals
        return cython_inline(self._body, locals=namespace, globals=namespace, **call_args)
|