aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/tools/cython/Cython/Build/Inline.py
blob: 69684e03ff50f330b1c92df5c4a8676e45c421da (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
from __future__ import absolute_import

import hashlib
import inspect
import os
import re
import sys

from distutils.core import Distribution, Extension
from distutils.command.build_ext import build_ext

import Cython
from ..Compiler.Main import Context, default_options

from ..Compiler.Visitor import CythonTransform, EnvTransform
from ..Compiler.ParseTreeTransforms import SkipDeclarations
from ..Compiler.TreeFragment import parse_from_strings
from ..Compiler.StringEncoding import _unicode
from .Dependencies import strip_string_literals, cythonize, cached_function
from ..Compiler import Pipeline
from ..Utils import get_cython_cache_dir
import cython as cython_module


IS_PY3 = sys.version_info[0] >= 3

# Helper turning user-supplied ASCII source strings into text.
# Python 3: the argument is returned unchanged.
# Python 2: byte strings are decoded as ASCII, anything else is returned as-is.
if IS_PY3:
    to_unicode = lambda x: x
else:
    def to_unicode(s):
        return s.decode('ascii') if isinstance(s, bytes) else s

if sys.version_info < (3, 5):
    import imp

    def load_dynamic(name, module_path):
        """Load the compiled extension module at *module_path* as *name*."""
        return imp.load_dynamic(name, module_path)
else:
    import importlib.util
    from importlib.machinery import ExtensionFileLoader

    def load_dynamic(name, path):
        """Load the compiled extension module at *path* as module *name*.

        The module is registered in ``sys.modules`` under *name* (as
        ``imp.load_dynamic`` did), so later ``sys.modules`` lookups by that
        name find it instead of loading it again. If executing the module
        raises, the ``sys.modules`` entry is removed and the error propagates.
        """
        loader = ExtensionFileLoader(name, path)
        spec = importlib.util.spec_from_file_location(name, path, loader=loader)
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module visible by name.
            sys.modules.pop(name, None)
            raise
        return module

class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree transform that collects names which do not resolve in scope.

    Calling an instance on a tree runs the transform and returns the set of
    ``NameNode`` names for which ``self.current_env().lookup()`` returned a
    falsy result, instead of returning the tree.
    """
    def __init__(self):
        # Initialised explicitly through CythonTransform with a None context.
        # NOTE(review): presumably EnvTransform/SkipDeclarations need no extra
        # constructor arguments here -- confirm against Compiler.Visitor.
        CythonTransform.__init__(self, None)
        self.unbound = set()  # names seen without a visible declaration
    def visit_NameNode(self, node):
        # Record the name if nothing is found for it from the current scope.
        if not self.current_env().lookup(node.name):
            self.unbound.add(node.name)
        return node
    def __call__(self, node):
        # Run the normal traversal, then hand back the collected names.
        super(UnboundSymbols, self).__call__(node)
        return self.unbound


@cached_function
def unbound_symbols(code, context=None):
    """Return a tuple of names used by *code* that it does not bind.

    The code is parsed and pushed through the 'pyx' pipeline up to and
    including ``AnalyseDeclarationsTransform`` so that declarations are
    known. Then the names ``UnboundSymbols`` reports are collected, minus
    every name in the Python builtins module.

    :param code: Cython source text (ASCII bytes are accepted on Python 2).
    :param context: compilation Context; a fresh default one when None.
    """
    from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins

    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    tree = parse_from_strings('(tree fragment)', code)
    for phase in Pipeline.create_pipeline(context, 'pyx'):
        if phase is not None:
            tree = phase(tree)
            if isinstance(phase, AnalyseDeclarationsTransform):
                break
    builtin_names = set(dir(builtins))
    return tuple(UnboundSymbols()(tree) - builtin_names)


def unsafe_type(arg, context=None):
    """Like ``safe_type``, but declare exact ``int`` values as C ``long``.

    NOTE(review): presumably "unsafe" because a Python int need not fit in
    a C long.
    """
    if type(arg) is not int:
        return safe_type(arg, context)
    return 'long'


def safe_type(arg, context=None):
    """Return a Cython type declaration string for the runtime value *arg*.

    * list/tuple/dict/str -> their Python type name
    * complex -> 'double complex', float -> 'double', bool -> 'bint'
    * numpy.ndarray (only if numpy is already imported) -> typed ndarray buffer
    * other objects -> 'module.Class' for the first non-builtin class in the
      type's MRO that *context* finds declared as a type; otherwise 'object'.

    :param arg: value whose type should be declared.
    :param context: compilation context used for the cimportable-type lookup.
        When None, no lookup is done and such values are declared 'object'.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    if context is None:
        # Nothing to look declarations up in; previously this crashed with
        # AttributeError on None.find_module.
        return 'object'
    for base_type in py_type.__mro__:
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            # lookup() can come back empty (presumably when the module declares
            # no such name); that must not crash on ``None.is_type``.
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'


def _get_build_extension():
    """Create and finalise a distutils ``build_ext`` command.

    The underlying Distribution first parses whatever configuration files
    ``Distribution.find_config_files()`` reports, so distutils configuration
    is respected by inline builds.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command


@cached_function
def _create_context(cython_include_dirs):
    """Return a compilation Context searching *cython_include_dirs*.

    Callers pass a tuple, presumably so ``cached_function`` can memoise on
    it; the directories are copied into a list for the Context.
    """
    include_path = list(cython_include_dirs)
    return Context(include_path, default_options)


# Session memo used by cython_inline. Keys are either the source code
# (-> its unbound symbols) or (code, arg_sigs, key_hash) (-> compiled __invoke).
_cython_inline_cache = dict()
# Compilation context shared by calls that pass no extra include dirs.
_cython_inline_default_context = _create_context(('.',))


def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
    for symbol in unbound_symbols:
        if symbol not in kwds:
            if locals is None or globals is None:
                calling_frame = inspect.currentframe().f_back.f_back.f_back
                if locals is None:
                    locals = calling_frame.f_locals
                if globals is None:
                    globals = calling_frame.f_globals
            if symbol in locals:
                kwds[symbol] = locals[symbol]
            elif symbol in globals:
                kwds[symbol] = globals[symbol]
            else:
                print("Couldn't find %r" % symbol)


def _inline_key(orig_code, arg_sigs, language_level):
    """Return a SHA-1 hex digest naming the compiled module for one call shape.

    The digest covers the source code, the argument signatures, the Python
    version and executable, the language level and the Cython version.
    A change in any of these selects a different module.
    """
    parts = (orig_code, arg_sigs, sys.version_info, sys.executable,
             language_level, Cython.__version__)
    text = _unicode(parts)
    return hashlib.sha1(text.encode('utf-8')).hexdigest()


def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile *code* as the body of a Cython function and call it.

    Names used by *code* but not passed in **kwds are resolved from *locals*
    and then *globals*. If those are not given, they are taken from the frame
    two levels above this call. The generated .pyx and built extension live in
    *lib_dir*. Compiled entry points are also memoised for the session, keyed
    on the code, the argument types, the interpreter, the Cython version and
    the language level.

    :param code: Cython source forming the function body.
    :param get_type: callable ``(value, context) -> str`` giving the Cython
        type for each argument; None declares every argument as 'object'.
    :param lib_dir: directory for the generated .pyx and the built module.
    :param cython_include_dirs: extra include directories for cimports.
    :param cython_compiler_directives: directives passed to cythonize.
        NOTE: apart from 'language_level', directives are not part of the
        cache key, so changing them alone reuses an existing build.
    :param force: rebuild even when a built module file already exists.
    :param quiet: silence the parse warning here; also passed to cythonize.
    :param locals: mapping searched first for unbound names.
    :param globals: mapping searched second for unbound names.
    :param language_level: Cython language level. Defaults to the one in
        *cython_compiler_directives*, or '3str' if none is given there.
    :returns: the result of the compiled ``__invoke``. The generated function
        ends with ``return locals()`` unless *code* returns earlier.
    """
    if get_type is None:
        # Must accept the (value, context) call made below.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives) if cython_compiler_directives else {}
    if language_level is None:
        # Use the effective level (directive-provided or the '3str' default)
        # so that it also feeds the cache key; otherwise different
        # directive-provided levels would collide on a key of None.
        language_level = cython_compiler_directives.get('language_level', '3str')
    cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        # ``cython`` module arguments become cimports in the slow path and are
        # excluded from its signature, so exclude them here as well or the
        # lookup below can never hit.
        args = sorted((arg, value) for arg, value in kwds.items()
                      if value is not cython_module)
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.isdir(lib_dir):
            try:
                os.makedirs(lib_dir)
            except OSError:
                # Another process may have created it in the meantime.
                if not os.path.isdir(lib_dir):
                    raise
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for type_name, _ in arg_sigs:
                m = qualified.match(type_name)
                if m:
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
    return locals()
            """ % {'cimports': '\n'.join(cimports),
                   'module_body': module_body,
                   'params': params,
                   'func_body': func_body }
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            # Write UTF-8 explicitly (independent of the locale) so that
            # non-ASCII restored literals survive.
            import io
            with io.open(pyx_file, 'w', encoding='utf-8') as fh:
                fh.write(_unicode(module_code))
            extension = Extension(
                name=module_name,
                sources=[pyx_file],
                include_dirs=c_include_dirs,
                extra_compile_args=cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)


# Filename suffix of built extension modules, as returned by
# build_ext.get_ext_filename(''). None until the first cython_inline call
# that needs to build or load a module computes and caches it.
cython_inline.so_ext = None

_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the smallest space indentation shared by the code lines of *code*.

    Blank lines and comment lines (first non-space character is '#') are
    ignored when computing the common indent and are left unchanged. Only
    spaces count as indentation. The lines are re-joined with newlines.
    """
    min_indent = None
    lines = code.splitlines()
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or min_indent > indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        # Check this line's own first non-space character. Previously a stale
        # ``indent`` from the loop above was used, so a comment indented less
        # than the code could be sliced into bare text.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)


module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
def extract_func_code(code):
    """Split *code* into module-level code and a function body.

    Tabs are first replaced by single spaces. A non-blank, non-comment line
    starting in column 0 opens a section. The section is module-level if the
    line matches ``module_statement`` (cdef extern/class, cimport,
    ``from ... cimport``, ``from ... import *``), otherwise it is function code.
    Indented, blank and comment lines stay in the current section, so a blank
    line inside e.g. a ``cdef class`` body no longer splits the class.

    :returns: ``(module_code, function_body)``; every function-body line is
        prefixed with four spaces.
    """
    module = []
    function = []
    current = function
    code = code.replace('\t', ' ')
    lines = code.split('\n')
    for line in lines:
        stripped = line.strip()
        opens_section = (stripped and not stripped.startswith('#')
                         and not line.startswith(' '))
        if opens_section:
            current = module if module_statement.match(line) else function
        current.append(line)
    return '\n'.join(module), '    ' + '\n    '.join(function)


try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for interpreters whose ``inspect`` lacks getcallargs.

        Binds the positional and keyword arguments to *func*'s parameters
        (via ``inspect.getargspec``). Returns a dict mapping each parameter
        name, including the ``*args``/``**kwargs`` names when present, to its
        value. Raises TypeError on duplicate, unexpected or missing arguments.
        """
        params, star_args, star_kwds, defaults = inspect.getargspec(func)
        bound = {}
        if star_args is not None:
            bound[star_args] = arg_values[len(params):]
        bound.update(zip(params, arg_values))
        for name in [key for key in kwd_values if key in params]:
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)
        if star_kwds is not None:
            bound[star_kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        defaults = defaults or ()
        first_default = len(params) - len(defaults)
        for ix, name in enumerate(params):
            if name in bound:
                continue
            if ix < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = defaults[ix - first_default]
        return bound


def get_body(source):
    """Return the part of a function's source text after its first ':'.

    For a lambda (source starting with the ``lambda`` keyword, after optional
    whitespace) the expression is wrapped as a ``return`` statement.
    Raises ValueError if *source* contains no ':'.

    NOTE(review): the first ':' may belong to an annotation or a default
    value (e.g. ``def f(x: int):``), in which case the split is wrong.
    """
    ix = source.index(':')
    # The old check ``source[:5] == 'lambda'`` compared 5 characters with a
    # 6-character word, so this branch was unreachable.
    if re.match(r'\s*lambda\b', source):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]


# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Run a Python function's body through ``cython_inline`` on every call.

    The body text is extracted once, at construction, from the function's
    source. Each call binds the arguments against the original signature and
    passes them as keyword arguments to ``cython_inline``. The function's
    module globals are used for both local and global name resolution.
    """

    def __init__(self, f):
        self._f = f
        source = inspect.getsource(f)
        self._body = get_body(source)

    def __call__(self, *args, **kwds):
        bound = getcallargs(self._f, *args, **kwds)
        namespace = self._f.__globals__ if IS_PY3 else self._f.func_globals
        return cython_inline(self._body, locals=namespace, globals=namespace, **bound)