summaryrefslogtreecommitdiffstats
path: root/contrib/tools/cython/Cython/Compiler/Optimize.py
diff options
context:
space:
mode:
authoralexv-smirnov <[email protected]>2023-03-15 19:59:12 +0300
committeralexv-smirnov <[email protected]>2023-03-15 19:59:12 +0300
commit056bb284ccf8dd6793ec3a54ffa36c4fb2b9ad11 (patch)
tree4740980126f32e3af7937ba0ca5f83e59baa4ab0 /contrib/tools/cython/Cython/Compiler/Optimize.py
parent269126dcced1cc8b53eb4398b4a33e5142f10290 (diff)
add library/cpp/actors, ymake build to ydb oss export
Diffstat (limited to 'contrib/tools/cython/Cython/Compiler/Optimize.py')
-rw-r--r--contrib/tools/cython/Cython/Compiler/Optimize.py4857
1 files changed, 4857 insertions, 0 deletions
diff --git a/contrib/tools/cython/Cython/Compiler/Optimize.py b/contrib/tools/cython/Cython/Compiler/Optimize.py
new file mode 100644
index 00000000000..7e9435ba0aa
--- /dev/null
+++ b/contrib/tools/cython/Cython/Compiler/Optimize.py
@@ -0,0 +1,4857 @@
+from __future__ import absolute_import
+
+import re
+import sys
+import copy
+import codecs
+import itertools
+
+from . import TypeSlots
+from .ExprNodes import not_a_constant
+import cython
+cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object,
+ Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
+ UtilNodes=object, _py_int_types=object)
+
+if sys.version_info[0] >= 3:
+ _py_int_types = int
+ _py_string_types = (bytes, str)
+else:
+ _py_int_types = (int, long)
+ _py_string_types = (bytes, unicode)
+
+from . import Nodes
+from . import ExprNodes
+from . import PyrexTypes
+from . import Visitor
+from . import Builtin
+from . import UtilNodes
+from . import Options
+
+from .Code import UtilityCode, TempitaUtilityCode
+from .StringEncoding import EncodedString, bytes_literal, encoded_string
+from .Errors import error, warning
+from .ParseTreeTransforms import SkipDeclarations
+
+try:
+ from __builtin__ import reduce
+except ImportError:
+ from functools import reduce
+
+try:
+ from __builtin__ import basestring
+except ImportError:
+ basestring = str # Python 3
+
+
+def load_c_utility(name):
+ return UtilityCode.load_cached(name, "Optimize.c")
+
+
+def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
+ if isinstance(node, coercion_nodes):
+ return node.arg
+ return node
+
+
+def unwrap_node(node):
+ while isinstance(node, UtilNodes.ResultRefNode):
+ node = node.expression
+ return node
+
+
+def is_common_value(a, b):
+ a = unwrap_node(a)
+ b = unwrap_node(b)
+ if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
+ return a.name == b.name
+ if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
+ return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
+ return False
+
+
+def filter_none_node(node):
+ if node is not None and node.constant_result is None:
+ return None
+ return node
+
+
+class _YieldNodeCollector(Visitor.TreeVisitor):
+ """
+ YieldExprNode finder for generator expressions.
+ """
+ def __init__(self):
+ Visitor.TreeVisitor.__init__(self)
+ self.yield_stat_nodes = {}
+ self.yield_nodes = []
+
+ visit_Node = Visitor.TreeVisitor.visitchildren
+
+ def visit_YieldExprNode(self, node):
+ self.yield_nodes.append(node)
+ self.visitchildren(node)
+
+ def visit_ExprStatNode(self, node):
+ self.visitchildren(node)
+ if node.expr in self.yield_nodes:
+ self.yield_stat_nodes[node.expr] = node
+
+ # everything below these nodes is out of scope:
+
+ def visit_GeneratorExpressionNode(self, node):
+ pass
+
+ def visit_LambdaNode(self, node):
+ pass
+
+ def visit_FuncDefNode(self, node):
+ pass
+
+
+def _find_single_yield_expression(node):
+ yield_statements = _find_yield_statements(node)
+ if len(yield_statements) != 1:
+ return None, None
+ return yield_statements[0]
+
+
+def _find_yield_statements(node):
+ collector = _YieldNodeCollector()
+ collector.visitchildren(node)
+ try:
+ yield_statements = [
+ (yield_node.arg, collector.yield_stat_nodes[yield_node])
+ for yield_node in collector.yield_nodes
+ ]
+ except KeyError:
+ # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
+ yield_statements = []
+ return yield_statements
+
+
+class IterationTransform(Visitor.EnvTransform):
+ """Transform some common for-in loop patterns into efficient C loops:
+
+ - for-in-dict loop becomes a while loop calling PyDict_Next()
+ - for-in-enumerate is replaced by an external counter variable
+ - for-in-range loop becomes a plain C for loop
+ """
+ def visit_PrimaryCmpNode(self, node):
+ if node.is_ptr_contains():
+
+ # for t in operand2:
+ # if operand1 == t:
+ # res = True
+ # break
+ # else:
+ # res = False
+
+ pos = node.pos
+ result_ref = UtilNodes.ResultRefNode(node)
+ if node.operand2.is_subscript:
+ base_type = node.operand2.base.type.base_type
+ else:
+ base_type = node.operand2.type.base_type
+ target_handle = UtilNodes.TempHandle(base_type)
+ target = target_handle.ref(pos)
+ cmp_node = ExprNodes.PrimaryCmpNode(
+ pos, operator=u'==', operand1=node.operand1, operand2=target)
+ if_body = Nodes.StatListNode(
+ pos,
+ stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
+ Nodes.BreakStatNode(pos)])
+ if_node = Nodes.IfStatNode(
+ pos,
+ if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
+ else_clause=None)
+ for_loop = UtilNodes.TempsBlockNode(
+ pos,
+ temps = [target_handle],
+ body = Nodes.ForInStatNode(
+ pos,
+ target=target,
+ iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
+ body=if_node,
+ else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
+ for_loop = for_loop.analyse_expressions(self.current_env())
+ for_loop = self.visit(for_loop)
+ new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
+
+ if node.operator == 'not_in':
+ new_node = ExprNodes.NotNode(pos, operand=new_node)
+ return new_node
+
+ else:
+ self.visitchildren(node)
+ return node
+
+ def visit_ForInStatNode(self, node):
+ self.visitchildren(node)
+ return self._optimise_for_loop(node, node.iterator.sequence)
+
+ def _optimise_for_loop(self, node, iterable, reversed=False):
+ annotation_type = None
+ if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation:
+ annotation = iterable.entry.annotation
+ if annotation.is_subscript:
+ annotation = annotation.base # container base type
+ # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
+ if annotation.is_name:
+ if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
+ annotation_type = Builtin.dict_type
+ elif annotation.name == 'Dict':
+ annotation_type = Builtin.dict_type
+ if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
+ annotation_type = Builtin.set_type
+ elif annotation.name in ('Set', 'FrozenSet'):
+ annotation_type = Builtin.set_type
+
+ if Builtin.dict_type in (iterable.type, annotation_type):
+ # like iterating over dict.keys()
+ if reversed:
+ # CPython raises an error here: not a sequence
+ return node
+ return self._transform_dict_iteration(
+ node, dict_obj=iterable, method=None, keys=True, values=False)
+
+ if (Builtin.set_type in (iterable.type, annotation_type) or
+ Builtin.frozenset_type in (iterable.type, annotation_type)):
+ if reversed:
+ # CPython raises an error here: not a sequence
+ return node
+ return self._transform_set_iteration(node, iterable)
+
+ # C array (slice) iteration?
+ if iterable.type.is_ptr or iterable.type.is_array:
+ return self._transform_carray_iteration(node, iterable, reversed=reversed)
+ if iterable.type is Builtin.bytes_type:
+ return self._transform_bytes_iteration(node, iterable, reversed=reversed)
+ if iterable.type is Builtin.unicode_type:
+ return self._transform_unicode_iteration(node, iterable, reversed=reversed)
+
+ # the rest is based on function calls
+ if not isinstance(iterable, ExprNodes.SimpleCallNode):
+ return node
+
+ if iterable.args is None:
+ arg_count = iterable.arg_tuple and len(iterable.arg_tuple.args) or 0
+ else:
+ arg_count = len(iterable.args)
+ if arg_count and iterable.self is not None:
+ arg_count -= 1
+
+ function = iterable.function
+ # dict iteration?
+ if function.is_attribute and not reversed and not arg_count:
+ base_obj = iterable.self or function.obj
+ method = function.attribute
+ # in Py3, items() is equivalent to Py2's iteritems()
+ is_safe_iter = self.global_scope().context.language_level >= 3
+
+ if not is_safe_iter and method in ('keys', 'values', 'items'):
+ # try to reduce this to the corresponding .iter*() methods
+ if isinstance(base_obj, ExprNodes.CallNode):
+ inner_function = base_obj.function
+ if (inner_function.is_name and inner_function.name == 'dict'
+ and inner_function.entry
+ and inner_function.entry.is_builtin):
+ # e.g. dict(something).items() => safe to use .iter*()
+ is_safe_iter = True
+
+ keys = values = False
+ if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
+ keys = True
+ elif method == 'itervalues' or (is_safe_iter and method == 'values'):
+ values = True
+ elif method == 'iteritems' or (is_safe_iter and method == 'items'):
+ keys = values = True
+
+ if keys or values:
+ return self._transform_dict_iteration(
+ node, base_obj, method, keys, values)
+
+ # enumerate/reversed ?
+ if iterable.self is None and function.is_name and \
+ function.entry and function.entry.is_builtin:
+ if function.name == 'enumerate':
+ if reversed:
+ # CPython raises an error here: not a sequence
+ return node
+ return self._transform_enumerate_iteration(node, iterable)
+ elif function.name == 'reversed':
+ if reversed:
+ # CPython raises an error here: not a sequence
+ return node
+ return self._transform_reversed_iteration(node, iterable)
+
+ # range() iteration?
+ if Options.convert_range and 1 <= arg_count <= 3 and (
+ iterable.self is None and
+ function.is_name and function.name in ('range', 'xrange') and
+ function.entry and function.entry.is_builtin):
+ if node.target.type.is_int or node.target.type.is_enum:
+ return self._transform_range_iteration(node, iterable, reversed=reversed)
+ if node.target.type.is_pyobject:
+ # Assume that small integer ranges (C long >= 32bit) are best handled in C as well.
+ for arg in (iterable.arg_tuple.args if iterable.args is None else iterable.args):
+ if isinstance(arg, ExprNodes.IntNode):
+ if arg.has_constant_result() and -2**30 <= arg.constant_result < 2**30:
+ continue
+ break
+ else:
+ return self._transform_range_iteration(node, iterable, reversed=reversed)
+
+ return node
+
+ def _transform_reversed_iteration(self, node, reversed_function):
+ args = reversed_function.arg_tuple.args
+ if len(args) == 0:
+ error(reversed_function.pos,
+ "reversed() requires an iterable argument")
+ return node
+ elif len(args) > 1:
+ error(reversed_function.pos,
+ "reversed() takes exactly 1 argument")
+ return node
+ arg = args[0]
+
+ # reversed(list/tuple) ?
+ if arg.type in (Builtin.tuple_type, Builtin.list_type):
+ node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
+ node.iterator.reversed = True
+ return node
+
+ return self._optimise_for_loop(node, arg, reversed=True)
+
+ PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_char_ptr_type, [
+ PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
+ ])
+
+ PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ssize_t_type, [
+ PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
+ ])
+
+ def _transform_bytes_iteration(self, node, slice_node, reversed=False):
+ target_type = node.target.type
+ if not target_type.is_int and target_type is not Builtin.bytes_type:
+ # bytes iteration returns bytes objects in Py2, but
+ # integers in Py3
+ return node
+
+ unpack_temp_node = UtilNodes.LetRefNode(
+ slice_node.as_none_safe_node("'NoneType' is not iterable"))
+
+ slice_base_node = ExprNodes.PythonCapiCallNode(
+ slice_node.pos, "PyBytes_AS_STRING",
+ self.PyBytes_AS_STRING_func_type,
+ args = [unpack_temp_node],
+ is_temp = 0,
+ )
+ len_node = ExprNodes.PythonCapiCallNode(
+ slice_node.pos, "PyBytes_GET_SIZE",
+ self.PyBytes_GET_SIZE_func_type,
+ args = [unpack_temp_node],
+ is_temp = 0,
+ )
+
+ return UtilNodes.LetNode(
+ unpack_temp_node,
+ self._transform_carray_iteration(
+ node,
+ ExprNodes.SliceIndexNode(
+ slice_node.pos,
+ base = slice_base_node,
+ start = None,
+ step = None,
+ stop = len_node,
+ type = slice_base_node.type,
+ is_temp = 1,
+ ),
+ reversed = reversed))
+
+ PyUnicode_READ_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ucs4_type, [
+ PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
+ PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
+ ])
+
+ init_unicode_iteration_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_int_type, [
+ PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
+ ],
+ exception_value = '-1')
+
+ def _transform_unicode_iteration(self, node, slice_node, reversed=False):
+ if slice_node.is_literal:
+ # try to reduce to byte iteration for plain Latin-1 strings
+ try:
+ bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
+ except UnicodeEncodeError:
+ pass
+ else:
+ bytes_slice = ExprNodes.SliceIndexNode(
+ slice_node.pos,
+ base=ExprNodes.BytesNode(
+ slice_node.pos, value=bytes_value,
+ constant_result=bytes_value,
+ type=PyrexTypes.c_const_char_ptr_type).coerce_to(
+ PyrexTypes.c_const_uchar_ptr_type, self.current_env()),
+ start=None,
+ stop=ExprNodes.IntNode(
+ slice_node.pos, value=str(len(bytes_value)),
+ constant_result=len(bytes_value),
+ type=PyrexTypes.c_py_ssize_t_type),
+ type=Builtin.unicode_type, # hint for Python conversion
+ )
+ return self._transform_carray_iteration(node, bytes_slice, reversed)
+
+ unpack_temp_node = UtilNodes.LetRefNode(
+ slice_node.as_none_safe_node("'NoneType' is not iterable"))
+
+ start_node = ExprNodes.IntNode(
+ node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
+ length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
+ end_node = length_temp.ref(node.pos)
+ if reversed:
+ relation1, relation2 = '>', '>='
+ start_node, end_node = end_node, start_node
+ else:
+ relation1, relation2 = '<=', '<'
+
+ kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
+ data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
+ counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
+
+ target_value = ExprNodes.PythonCapiCallNode(
+ slice_node.pos, "__Pyx_PyUnicode_READ",
+ self.PyUnicode_READ_func_type,
+ args = [kind_temp.ref(slice_node.pos),
+ data_temp.ref(slice_node.pos),
+ counter_temp.ref(node.target.pos)],
+ is_temp = False,
+ )
+ if target_value.type != node.target.type:
+ target_value = target_value.coerce_to(node.target.type,
+ self.current_env())
+ target_assign = Nodes.SingleAssignmentNode(
+ pos = node.target.pos,
+ lhs = node.target,
+ rhs = target_value)
+ body = Nodes.StatListNode(
+ node.pos,
+ stats = [target_assign, node.body])
+
+ loop_node = Nodes.ForFromStatNode(
+ node.pos,
+ bound1=start_node, relation1=relation1,
+ target=counter_temp.ref(node.target.pos),
+ relation2=relation2, bound2=end_node,
+ step=None, body=body,
+ else_clause=node.else_clause,
+ from_range=True)
+
+ setup_node = Nodes.ExprStatNode(
+ node.pos,
+ expr = ExprNodes.PythonCapiCallNode(
+ slice_node.pos, "__Pyx_init_unicode_iteration",
+ self.init_unicode_iteration_func_type,
+ args = [unpack_temp_node,
+ ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
+ type=PyrexTypes.c_py_ssize_t_ptr_type),
+ ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
+ type=PyrexTypes.c_void_ptr_ptr_type),
+ ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
+ type=PyrexTypes.c_int_ptr_type),
+ ],
+ is_temp = True,
+ result_is_used = False,
+ utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
+ ))
+ return UtilNodes.LetNode(
+ unpack_temp_node,
+ UtilNodes.TempsBlockNode(
+ node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
+ body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
+
+ def _transform_carray_iteration(self, node, slice_node, reversed=False):
+ neg_step = False
+ if isinstance(slice_node, ExprNodes.SliceIndexNode):
+ slice_base = slice_node.base
+ start = filter_none_node(slice_node.start)
+ stop = filter_none_node(slice_node.stop)
+ step = None
+ if not stop:
+ if not slice_base.type.is_pyobject:
+ error(slice_node.pos, "C array iteration requires known end index")
+ return node
+
+ elif slice_node.is_subscript:
+ assert isinstance(slice_node.index, ExprNodes.SliceNode)
+ slice_base = slice_node.base
+ index = slice_node.index
+ start = filter_none_node(index.start)
+ stop = filter_none_node(index.stop)
+ step = filter_none_node(index.step)
+ if step:
+ if not isinstance(step.constant_result, _py_int_types) \
+ or step.constant_result == 0 \
+ or step.constant_result > 0 and not stop \
+ or step.constant_result < 0 and not start:
+ if not slice_base.type.is_pyobject:
+ error(step.pos, "C array iteration requires known step size and end index")
+ return node
+ else:
+ # step sign is handled internally by ForFromStatNode
+ step_value = step.constant_result
+ if reversed:
+ step_value = -step_value
+ neg_step = step_value < 0
+ step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
+ value=str(abs(step_value)),
+ constant_result=abs(step_value))
+
+ elif slice_node.type.is_array:
+ if slice_node.type.size is None:
+ error(slice_node.pos, "C array iteration requires known end index")
+ return node
+ slice_base = slice_node
+ start = None
+ stop = ExprNodes.IntNode(
+ slice_node.pos, value=str(slice_node.type.size),
+ type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
+ step = None
+
+ else:
+ if not slice_node.type.is_pyobject:
+ error(slice_node.pos, "C array iteration requires known end index")
+ return node
+
+ if start:
+ start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+ if stop:
+ stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+ if stop is None:
+ if neg_step:
+ stop = ExprNodes.IntNode(
+ slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
+ else:
+ error(slice_node.pos, "C array iteration requires known step size and end index")
+ return node
+
+ if reversed:
+ if not start:
+ start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0,
+ type=PyrexTypes.c_py_ssize_t_type)
+ # if step was provided, it was already negated above
+ start, stop = stop, start
+
+ ptr_type = slice_base.type
+ if ptr_type.is_array:
+ ptr_type = ptr_type.element_ptr_type()
+ carray_ptr = slice_base.coerce_to_simple(self.current_env())
+
+ if start and start.constant_result != 0:
+ start_ptr_node = ExprNodes.AddNode(
+ start.pos,
+ operand1=carray_ptr,
+ operator='+',
+ operand2=start,
+ type=ptr_type)
+ else:
+ start_ptr_node = carray_ptr
+
+ if stop and stop.constant_result != 0:
+ stop_ptr_node = ExprNodes.AddNode(
+ stop.pos,
+ operand1=ExprNodes.CloneNode(carray_ptr),
+ operator='+',
+ operand2=stop,
+ type=ptr_type
+ ).coerce_to_simple(self.current_env())
+ else:
+ stop_ptr_node = ExprNodes.CloneNode(carray_ptr)
+
+ counter = UtilNodes.TempHandle(ptr_type)
+ counter_temp = counter.ref(node.target.pos)
+
+ if slice_base.type.is_string and node.target.type.is_pyobject:
+ # special case: char* -> bytes/unicode
+ if slice_node.type is Builtin.unicode_type:
+ target_value = ExprNodes.CastNode(
+ ExprNodes.DereferenceNode(
+ node.target.pos, operand=counter_temp,
+ type=ptr_type.base_type),
+ PyrexTypes.c_py_ucs4_type).coerce_to(
+ node.target.type, self.current_env())
+ else:
+ # char* -> bytes coercion requires slicing, not indexing
+ target_value = ExprNodes.SliceIndexNode(
+ node.target.pos,
+ start=ExprNodes.IntNode(node.target.pos, value='0',
+ constant_result=0,
+ type=PyrexTypes.c_int_type),
+ stop=ExprNodes.IntNode(node.target.pos, value='1',
+ constant_result=1,
+ type=PyrexTypes.c_int_type),
+ base=counter_temp,
+ type=Builtin.bytes_type,
+ is_temp=1)
+ elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
+ # Allow iteration with pointer target to avoid copy.
+ target_value = counter_temp
+ else:
+ # TODO: can this safely be replaced with DereferenceNode() as above?
+ target_value = ExprNodes.IndexNode(
+ node.target.pos,
+ index=ExprNodes.IntNode(node.target.pos, value='0',
+ constant_result=0,
+ type=PyrexTypes.c_int_type),
+ base=counter_temp,
+ type=ptr_type.base_type)
+
+ if target_value.type != node.target.type:
+ target_value = target_value.coerce_to(node.target.type,
+ self.current_env())
+
+ target_assign = Nodes.SingleAssignmentNode(
+ pos = node.target.pos,
+ lhs = node.target,
+ rhs = target_value)
+
+ body = Nodes.StatListNode(
+ node.pos,
+ stats = [target_assign, node.body])
+
+ relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)
+
+ for_node = Nodes.ForFromStatNode(
+ node.pos,
+ bound1=start_ptr_node, relation1=relation1,
+ target=counter_temp,
+ relation2=relation2, bound2=stop_ptr_node,
+ step=step, body=body,
+ else_clause=node.else_clause,
+ from_range=True)
+
+ return UtilNodes.TempsBlockNode(
+ node.pos, temps=[counter],
+ body=for_node)
+
+ def _transform_enumerate_iteration(self, node, enumerate_function):
+ args = enumerate_function.arg_tuple.args
+ if len(args) == 0:
+ error(enumerate_function.pos,
+ "enumerate() requires an iterable argument")
+ return node
+ elif len(args) > 2:
+ error(enumerate_function.pos,
+ "enumerate() takes at most 2 arguments")
+ return node
+
+ if not node.target.is_sequence_constructor:
+ # leave this untouched for now
+ return node
+ targets = node.target.args
+ if len(targets) != 2:
+ # leave this untouched for now
+ return node
+
+ enumerate_target, iterable_target = targets
+ counter_type = enumerate_target.type
+
+ if not counter_type.is_pyobject and not counter_type.is_int:
+ # nothing we can do here, I guess
+ return node
+
+ if len(args) == 2:
+ start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
+ else:
+ start = ExprNodes.IntNode(enumerate_function.pos,
+ value='0',
+ type=counter_type,
+ constant_result=0)
+ temp = UtilNodes.LetRefNode(start)
+
+ inc_expression = ExprNodes.AddNode(
+ enumerate_function.pos,
+ operand1 = temp,
+ operand2 = ExprNodes.IntNode(node.pos, value='1',
+ type=counter_type,
+ constant_result=1),
+ operator = '+',
+ type = counter_type,
+ #inplace = True, # not worth using in-place operation for Py ints
+ is_temp = counter_type.is_pyobject
+ )
+
+ loop_body = [
+ Nodes.SingleAssignmentNode(
+ pos = enumerate_target.pos,
+ lhs = enumerate_target,
+ rhs = temp),
+ Nodes.SingleAssignmentNode(
+ pos = enumerate_target.pos,
+ lhs = temp,
+ rhs = inc_expression)
+ ]
+
+ if isinstance(node.body, Nodes.StatListNode):
+ node.body.stats = loop_body + node.body.stats
+ else:
+ loop_body.append(node.body)
+ node.body = Nodes.StatListNode(
+ node.body.pos,
+ stats = loop_body)
+
+ node.target = iterable_target
+ node.item = node.item.coerce_to(iterable_target.type, self.current_env())
+ node.iterator.sequence = args[0]
+
+ # recurse into loop to check for further optimisations
+ return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
+
+ def _find_for_from_node_relations(self, neg_step_value, reversed):
+ if reversed:
+ if neg_step_value:
+ return '<', '<='
+ else:
+ return '>', '>='
+ else:
+ if neg_step_value:
+ return '>=', '>'
+ else:
+ return '<=', '<'
+
+ def _transform_range_iteration(self, node, range_function, reversed=False):
+ args = range_function.arg_tuple.args
+ if len(args) < 3:
+ step_pos = range_function.pos
+ step_value = 1
+ step = ExprNodes.IntNode(step_pos, value='1', constant_result=1)
+ else:
+ step = args[2]
+ step_pos = step.pos
+ if not isinstance(step.constant_result, _py_int_types):
+ # cannot determine step direction
+ return node
+ step_value = step.constant_result
+ if step_value == 0:
+ # will lead to an error elsewhere
+ return node
+ step = ExprNodes.IntNode(step_pos, value=str(step_value),
+ constant_result=step_value)
+
+ if len(args) == 1:
+ bound1 = ExprNodes.IntNode(range_function.pos, value='0',
+ constant_result=0)
+ bound2 = args[0].coerce_to_integer(self.current_env())
+ else:
+ bound1 = args[0].coerce_to_integer(self.current_env())
+ bound2 = args[1].coerce_to_integer(self.current_env())
+
+ relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)
+
+ bound2_ref_node = None
+ if reversed:
+ bound1, bound2 = bound2, bound1
+ abs_step = abs(step_value)
+ if abs_step != 1:
+ if (isinstance(bound1.constant_result, _py_int_types) and
+ isinstance(bound2.constant_result, _py_int_types)):
+ # calculate final bounds now
+ if step_value < 0:
+ begin_value = bound2.constant_result
+ end_value = bound1.constant_result
+ bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1
+ else:
+ begin_value = bound1.constant_result
+ end_value = bound2.constant_result
+ bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1
+
+ bound1 = ExprNodes.IntNode(
+ bound1.pos, value=str(bound1_value), constant_result=bound1_value,
+ type=PyrexTypes.spanning_type(bound1.type, bound2.type))
+ else:
+ # evaluate the same expression as above at runtime
+ bound2_ref_node = UtilNodes.LetRefNode(bound2)
+ bound1 = self._build_range_step_calculation(
+ bound1, bound2_ref_node, step, step_value)
+
+ if step_value < 0:
+ step_value = -step_value
+ step.value = str(step_value)
+ step.constant_result = step_value
+ step = step.coerce_to_integer(self.current_env())
+
+ if not bound2.is_literal:
+ # stop bound must be immutable => keep it in a temp var
+ bound2_is_temp = True
+ bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2)
+ else:
+ bound2_is_temp = False
+
+ for_node = Nodes.ForFromStatNode(
+ node.pos,
+ target=node.target,
+ bound1=bound1, relation1=relation1,
+ relation2=relation2, bound2=bound2,
+ step=step, body=node.body,
+ else_clause=node.else_clause,
+ from_range=True)
+ for_node.set_up_loop(self.current_env())
+
+ if bound2_is_temp:
+ for_node = UtilNodes.LetNode(bound2, for_node)
+
+ return for_node
+
+ def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value):
+ abs_step = abs(step_value)
+ spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type)
+ if step.type.is_int and abs_step < 0x7FFF:
+ # Avoid loss of integer precision warnings.
+ spanning_step_type = PyrexTypes.spanning_type(spanning_type, PyrexTypes.c_int_type)
+ else:
+ spanning_step_type = PyrexTypes.spanning_type(spanning_type, step.type)
+ if step_value < 0:
+ begin_value = bound2_ref_node
+ end_value = bound1
+ final_op = '-'
+ else:
+ begin_value = bound1
+ end_value = bound2_ref_node
+ final_op = '+'
+
+ step_calculation_node = ExprNodes.binop_node(
+ bound1.pos,
+ operand1=ExprNodes.binop_node(
+ bound1.pos,
+ operand1=bound2_ref_node,
+ operator=final_op, # +/-
+ operand2=ExprNodes.MulNode(
+ bound1.pos,
+ operand1=ExprNodes.IntNode(
+ bound1.pos,
+ value=str(abs_step),
+ constant_result=abs_step,
+ type=spanning_step_type),
+ operator='*',
+ operand2=ExprNodes.DivNode(
+ bound1.pos,
+ operand1=ExprNodes.SubNode(
+ bound1.pos,
+ operand1=ExprNodes.SubNode(
+ bound1.pos,
+ operand1=begin_value,
+ operator='-',
+ operand2=end_value,
+ type=spanning_type),
+ operator='-',
+ operand2=ExprNodes.IntNode(
+ bound1.pos,
+ value='1',
+ constant_result=1),
+ type=spanning_step_type),
+ operator='//',
+ operand2=ExprNodes.IntNode(
+ bound1.pos,
+ value=str(abs_step),
+ constant_result=abs_step,
+ type=spanning_step_type),
+ type=spanning_step_type),
+ type=spanning_step_type),
+ type=spanning_step_type),
+ operator=final_op, # +/-
+ operand2=ExprNodes.IntNode(
+ bound1.pos,
+ value='1',
+ constant_result=1),
+ type=spanning_type)
+ return step_calculation_node
+
+ def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
+ temps = []
+ temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
+ temps.append(temp)
+ dict_temp = temp.ref(dict_obj.pos)
+ temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
+ temps.append(temp)
+ pos_temp = temp.ref(node.pos)
+
+ key_target = value_target = tuple_target = None
+ if keys and values:
+ if node.target.is_sequence_constructor:
+ if len(node.target.args) == 2:
+ key_target, value_target = node.target.args
+ else:
+ # unusual case that may or may not lead to an error
+ return node
+ else:
+ tuple_target = node.target
+ elif keys:
+ key_target = node.target
+ else:
+ value_target = node.target
+
+ if isinstance(node.body, Nodes.StatListNode):
+ body = node.body
+ else:
+ body = Nodes.StatListNode(pos = node.body.pos,
+ stats = [node.body])
+
+ # keep original length to guard against dict modification
+ dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
+ temps.append(dict_len_temp)
+ dict_len_temp_addr = ExprNodes.AmpersandNode(
+ node.pos, operand=dict_len_temp.ref(dict_obj.pos),
+ type=PyrexTypes.c_ptr_type(dict_len_temp.type))
+ temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
+ temps.append(temp)
+ is_dict_temp = temp.ref(node.pos)
+ is_dict_temp_addr = ExprNodes.AmpersandNode(
+ node.pos, operand=is_dict_temp,
+ type=PyrexTypes.c_ptr_type(temp.type))
+
+ iter_next_node = Nodes.DictIterationNextNode(
+ dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
+ key_target, value_target, tuple_target,
+ is_dict_temp)
+ iter_next_node = iter_next_node.analyse_expressions(self.current_env())
+ body.stats[0:0] = [iter_next_node]
+
+ if method:
+ method_node = ExprNodes.StringNode(
+ dict_obj.pos, is_identifier=True, value=method)
+ dict_obj = dict_obj.as_none_safe_node(
+ "'NoneType' object has no attribute '%{0}s'".format('.30' if len(method) <= 30 else ''),
+ error = "PyExc_AttributeError",
+ format_args = [method])
+ else:
+ method_node = ExprNodes.NullNode(dict_obj.pos)
+ dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")
+
+ def flag_node(value):
+ value = value and 1 or 0
+ return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)
+
+ result_code = [
+ Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs = pos_temp,
+ rhs = ExprNodes.IntNode(node.pos, value='0',
+ constant_result=0)),
+ Nodes.SingleAssignmentNode(
+ dict_obj.pos,
+ lhs = dict_temp,
+ rhs = ExprNodes.PythonCapiCallNode(
+ dict_obj.pos,
+ "__Pyx_dict_iterator",
+ self.PyDict_Iterator_func_type,
+ utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"),
+ args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type),
+ method_node, dict_len_temp_addr, is_dict_temp_addr,
+ ],
+ is_temp=True,
+ )),
+ Nodes.WhileStatNode(
+ node.pos,
+ condition = None,
+ body = body,
+ else_clause = node.else_clause
+ )
+ ]
+
+ return UtilNodes.TempsBlockNode(
+ node.pos, temps=temps,
+ body=Nodes.StatListNode(
+ node.pos,
+ stats = result_code
+ ))
+
+ PyDict_Iterator_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None),
+ PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
+ ])
+
+ PySet_Iterator_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("is_set", PyrexTypes.c_int_type, None),
+ PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None),
+ ])
+
+ def _transform_set_iteration(self, node, set_obj):
+ temps = []
+ temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
+ temps.append(temp)
+ set_temp = temp.ref(set_obj.pos)
+ temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
+ temps.append(temp)
+ pos_temp = temp.ref(node.pos)
+
+ if isinstance(node.body, Nodes.StatListNode):
+ body = node.body
+ else:
+ body = Nodes.StatListNode(pos = node.body.pos,
+ stats = [node.body])
+
+ # keep original length to guard against set modification
+ set_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
+ temps.append(set_len_temp)
+ set_len_temp_addr = ExprNodes.AmpersandNode(
+ node.pos, operand=set_len_temp.ref(set_obj.pos),
+ type=PyrexTypes.c_ptr_type(set_len_temp.type))
+ temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
+ temps.append(temp)
+ is_set_temp = temp.ref(node.pos)
+ is_set_temp_addr = ExprNodes.AmpersandNode(
+ node.pos, operand=is_set_temp,
+ type=PyrexTypes.c_ptr_type(temp.type))
+
+ value_target = node.target
+ iter_next_node = Nodes.SetIterationNextNode(
+ set_temp, set_len_temp.ref(set_obj.pos), pos_temp, value_target, is_set_temp)
+ iter_next_node = iter_next_node.analyse_expressions(self.current_env())
+ body.stats[0:0] = [iter_next_node]
+
+ def flag_node(value):
+ value = value and 1 or 0
+ return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)
+
+ result_code = [
+ Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs=pos_temp,
+ rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0)),
+ Nodes.SingleAssignmentNode(
+ set_obj.pos,
+ lhs=set_temp,
+ rhs=ExprNodes.PythonCapiCallNode(
+ set_obj.pos,
+ "__Pyx_set_iterator",
+ self.PySet_Iterator_func_type,
+ utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
+ args=[set_obj, flag_node(set_obj.type is Builtin.set_type),
+ set_len_temp_addr, is_set_temp_addr,
+ ],
+ is_temp=True,
+ )),
+ Nodes.WhileStatNode(
+ node.pos,
+ condition=None,
+ body=body,
+ else_clause=node.else_clause,
+ )
+ ]
+
+ return UtilNodes.TempsBlockNode(
+ node.pos, temps=temps,
+ body=Nodes.StatListNode(
+ node.pos,
+ stats = result_code
+ ))
+
+
+class SwitchTransform(Visitor.EnvTransform):
+ """
+ This transformation tries to turn long if statements into C switch statements.
+ The requirement is that every clause be an (or of) var == value, where the var
+ is common among all clauses and both var and value are ints.
+ """
+ NO_MATCH = (None, None, None)
+
+ def extract_conditions(self, cond, allow_not_in):
+ while True:
+ if isinstance(cond, (ExprNodes.CoerceToTempNode,
+ ExprNodes.CoerceToBooleanNode)):
+ cond = cond.arg
+ elif isinstance(cond, ExprNodes.BoolBinopResultNode):
+ cond = cond.arg.arg
+ elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
+ # this is what we get from the FlattenInListTransform
+ cond = cond.subexpression
+ elif isinstance(cond, ExprNodes.TypecastNode):
+ cond = cond.operand
+ else:
+ break
+
+ if isinstance(cond, ExprNodes.PrimaryCmpNode):
+ if cond.cascade is not None:
+ return self.NO_MATCH
+ elif cond.is_c_string_contains() and \
+ isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
+ not_in = cond.operator == 'not_in'
+ if not_in and not allow_not_in:
+ return self.NO_MATCH
+ if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
+ cond.operand2.contains_surrogates():
+ # dealing with surrogates leads to different
+ # behaviour on wide and narrow Unicode
+ # platforms => refuse to optimise this case
+ return self.NO_MATCH
+ return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
+ elif not cond.is_python_comparison():
+ if cond.operator == '==':
+ not_in = False
+ elif allow_not_in and cond.operator == '!=':
+ not_in = True
+ else:
+ return self.NO_MATCH
+ # this looks somewhat silly, but it does the right
+ # checks for NameNode and AttributeNode
+ if is_common_value(cond.operand1, cond.operand1):
+ if cond.operand2.is_literal:
+ return not_in, cond.operand1, [cond.operand2]
+ elif getattr(cond.operand2, 'entry', None) \
+ and cond.operand2.entry.is_const:
+ return not_in, cond.operand1, [cond.operand2]
+ if is_common_value(cond.operand2, cond.operand2):
+ if cond.operand1.is_literal:
+ return not_in, cond.operand2, [cond.operand1]
+ elif getattr(cond.operand1, 'entry', None) \
+ and cond.operand1.entry.is_const:
+ return not_in, cond.operand2, [cond.operand1]
+ elif isinstance(cond, ExprNodes.BoolBinopNode):
+ if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
+ allow_not_in = (cond.operator == 'and')
+ not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
+ not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
+ if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
+ if (not not_in_1) or allow_not_in:
+ return not_in_1, t1, c1+c2
+ return self.NO_MATCH
+
+ def extract_in_string_conditions(self, string_literal):
+ if isinstance(string_literal, ExprNodes.UnicodeNode):
+ charvals = list(map(ord, set(string_literal.value)))
+ charvals.sort()
+ return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
+ constant_result=charval)
+ for charval in charvals ]
+ else:
+ # this is a bit tricky as Py3's bytes type returns
+ # integers on iteration, whereas Py2 returns 1-char byte
+ # strings
+ characters = string_literal.value
+ characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
+ characters.sort()
+ return [ ExprNodes.CharNode(string_literal.pos, value=charval,
+ constant_result=charval)
+ for charval in characters ]
+
+ def extract_common_conditions(self, common_var, condition, allow_not_in):
+ not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
+ if var is None:
+ return self.NO_MATCH
+ elif common_var is not None and not is_common_value(var, common_var):
+ return self.NO_MATCH
+ elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]):
+ return self.NO_MATCH
+ return not_in, var, conditions
+
+ def has_duplicate_values(self, condition_values):
+ # duplicated values don't work in a switch statement
+ seen = set()
+ for value in condition_values:
+ if value.has_constant_result():
+ if value.constant_result in seen:
+ return True
+ seen.add(value.constant_result)
+ else:
+ # this isn't completely safe as we don't know the
+ # final C value, but this is about the best we can do
+ try:
+ if value.entry.cname in seen:
+ return True
+ except AttributeError:
+ return True # play safe
+ seen.add(value.entry.cname)
+ return False
+
+ def visit_IfStatNode(self, node):
+ if not self.current_directives.get('optimize.use_switch'):
+ self.visitchildren(node)
+ return node
+
+ common_var = None
+ cases = []
+ for if_clause in node.if_clauses:
+ _, common_var, conditions = self.extract_common_conditions(
+ common_var, if_clause.condition, False)
+ if common_var is None:
+ self.visitchildren(node)
+ return node
+ cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos,
+ conditions=conditions,
+ body=if_clause.body))
+
+ condition_values = [
+ cond for case in cases for cond in case.conditions]
+ if len(condition_values) < 2:
+ self.visitchildren(node)
+ return node
+ if self.has_duplicate_values(condition_values):
+ self.visitchildren(node)
+ return node
+
+ # Recurse into body subtrees that we left untouched so far.
+ self.visitchildren(node, 'else_clause')
+ for case in cases:
+ self.visitchildren(case, 'body')
+
+ common_var = unwrap_node(common_var)
+ switch_node = Nodes.SwitchStatNode(pos=node.pos,
+ test=common_var,
+ cases=cases,
+ else_clause=node.else_clause)
+ return switch_node
+
+ def visit_CondExprNode(self, node):
+ if not self.current_directives.get('optimize.use_switch'):
+ self.visitchildren(node)
+ return node
+
+ not_in, common_var, conditions = self.extract_common_conditions(
+ None, node.test, True)
+ if common_var is None \
+ or len(conditions) < 2 \
+ or self.has_duplicate_values(conditions):
+ self.visitchildren(node)
+ return node
+
+ return self.build_simple_switch_statement(
+ node, common_var, conditions, not_in,
+ node.true_val, node.false_val)
+
+ def visit_BoolBinopNode(self, node):
+ if not self.current_directives.get('optimize.use_switch'):
+ self.visitchildren(node)
+ return node
+
+ not_in, common_var, conditions = self.extract_common_conditions(
+ None, node, True)
+ if common_var is None \
+ or len(conditions) < 2 \
+ or self.has_duplicate_values(conditions):
+ self.visitchildren(node)
+ node.wrap_operands(self.current_env()) # in case we changed the operands
+ return node
+
+ return self.build_simple_switch_statement(
+ node, common_var, conditions, not_in,
+ ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
+ ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
+
+ def visit_PrimaryCmpNode(self, node):
+ if not self.current_directives.get('optimize.use_switch'):
+ self.visitchildren(node)
+ return node
+
+ not_in, common_var, conditions = self.extract_common_conditions(
+ None, node, True)
+ if common_var is None \
+ or len(conditions) < 2 \
+ or self.has_duplicate_values(conditions):
+ self.visitchildren(node)
+ return node
+
+ return self.build_simple_switch_statement(
+ node, common_var, conditions, not_in,
+ ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
+ ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
+
+ def build_simple_switch_statement(self, node, common_var, conditions,
+ not_in, true_val, false_val):
+ result_ref = UtilNodes.ResultRefNode(node)
+ true_body = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs=result_ref,
+ rhs=true_val.coerce_to(node.type, self.current_env()),
+ first=True)
+ false_body = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs=result_ref,
+ rhs=false_val.coerce_to(node.type, self.current_env()),
+ first=True)
+
+ if not_in:
+ true_body, false_body = false_body, true_body
+
+ cases = [Nodes.SwitchCaseNode(pos = node.pos,
+ conditions = conditions,
+ body = true_body)]
+
+ common_var = unwrap_node(common_var)
+ switch_node = Nodes.SwitchStatNode(pos = node.pos,
+ test = common_var,
+ cases = cases,
+ else_clause = false_body)
+ replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node)
+ return replacement
+
+ def visit_EvalWithTempExprNode(self, node):
+ if not self.current_directives.get('optimize.use_switch'):
+ self.visitchildren(node)
+ return node
+
+ # drop unused expression temp from FlattenInListTransform
+ orig_expr = node.subexpression
+ temp_ref = node.lazy_temp
+ self.visitchildren(node)
+ if node.subexpression is not orig_expr:
+ # node was restructured => check if temp is still used
+ if not Visitor.tree_contains(node.subexpression, temp_ref):
+ return node.subexpression
+ return node
+
+ visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+
+class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
+ """
+ This transformation flattens "x in [val1, ..., valn]" into a sequential list
+ of comparisons.
+ """
+
+ def visit_PrimaryCmpNode(self, node):
+ self.visitchildren(node)
+ if node.cascade is not None:
+ return node
+ elif node.operator == 'in':
+ conjunction = 'or'
+ eq_or_neq = '=='
+ elif node.operator == 'not_in':
+ conjunction = 'and'
+ eq_or_neq = '!='
+ else:
+ return node
+
+ if not isinstance(node.operand2, (ExprNodes.TupleNode,
+ ExprNodes.ListNode,
+ ExprNodes.SetNode)):
+ return node
+
+ args = node.operand2.args
+ if len(args) == 0:
+ # note: lhs may have side effects
+ return node
+
+ if any([arg.is_starred for arg in args]):
+ # Starred arguments do not directly translate to comparisons or "in" tests.
+ return node
+
+ lhs = UtilNodes.ResultRefNode(node.operand1)
+
+ conds = []
+ temps = []
+ for arg in args:
+ try:
+ # Trial optimisation to avoid redundant temp
+ # assignments. However, since is_simple() is meant to
+ # be called after type analysis, we ignore any errors
+ # and just play safe in that case.
+ is_simple_arg = arg.is_simple()
+ except Exception:
+ is_simple_arg = False
+ if not is_simple_arg:
+ # must evaluate all non-simple RHS before doing the comparisons
+ arg = UtilNodes.LetRefNode(arg)
+ temps.append(arg)
+ cond = ExprNodes.PrimaryCmpNode(
+ pos = node.pos,
+ operand1 = lhs,
+ operator = eq_or_neq,
+ operand2 = arg,
+ cascade = None)
+ conds.append(ExprNodes.TypecastNode(
+ pos = node.pos,
+ operand = cond,
+ type = PyrexTypes.c_bint_type))
+ def concat(left, right):
+ return ExprNodes.BoolBinopNode(
+ pos = node.pos,
+ operator = conjunction,
+ operand1 = left,
+ operand2 = right)
+
+ condition = reduce(concat, conds)
+ new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
+ for temp in temps[::-1]:
+ new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
+ return new_node
+
+ visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+
+class DropRefcountingTransform(Visitor.VisitorTransform):
+ """Drop ref-counting in safe places.
+ """
+ visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+ def visit_ParallelAssignmentNode(self, node):
+ """
+ Parallel swap assignments like 'a,b = b,a' are safe.
+ """
+ left_names, right_names = [], []
+ left_indices, right_indices = [], []
+ temps = []
+
+ for stat in node.stats:
+ if isinstance(stat, Nodes.SingleAssignmentNode):
+ if not self._extract_operand(stat.lhs, left_names,
+ left_indices, temps):
+ return node
+ if not self._extract_operand(stat.rhs, right_names,
+ right_indices, temps):
+ return node
+ elif isinstance(stat, Nodes.CascadedAssignmentNode):
+ # FIXME
+ return node
+ else:
+ return node
+
+ if left_names or right_names:
+ # lhs/rhs names must be a non-redundant permutation
+ lnames = [ path for path, n in left_names ]
+ rnames = [ path for path, n in right_names ]
+ if set(lnames) != set(rnames):
+ return node
+ if len(set(lnames)) != len(right_names):
+ return node
+
+ if left_indices or right_indices:
+ # base name and index of index nodes must be a
+ # non-redundant permutation
+ lindices = []
+ for lhs_node in left_indices:
+ index_id = self._extract_index_id(lhs_node)
+ if not index_id:
+ return node
+ lindices.append(index_id)
+ rindices = []
+ for rhs_node in right_indices:
+ index_id = self._extract_index_id(rhs_node)
+ if not index_id:
+ return node
+ rindices.append(index_id)
+
+ if set(lindices) != set(rindices):
+ return node
+ if len(set(lindices)) != len(right_indices):
+ return node
+
+ # really supporting IndexNode requires support in
+ # __Pyx_GetItemInt(), so let's stop short for now
+ return node
+
+ temp_args = [t.arg for t in temps]
+ for temp in temps:
+ temp.use_managed_ref = False
+
+ for _, name_node in left_names + right_names:
+ if name_node not in temp_args:
+ name_node.use_managed_ref = False
+
+ for index_node in left_indices + right_indices:
+ index_node.use_managed_ref = False
+
+ return node
+
+ def _extract_operand(self, node, names, indices, temps):
+ node = unwrap_node(node)
+ if not node.type.is_pyobject:
+ return False
+ if isinstance(node, ExprNodes.CoerceToTempNode):
+ temps.append(node)
+ node = node.arg
+ name_path = []
+ obj_node = node
+ while obj_node.is_attribute:
+ if obj_node.is_py_attr:
+ return False
+ name_path.append(obj_node.member)
+ obj_node = obj_node.obj
+ if obj_node.is_name:
+ name_path.append(obj_node.name)
+ names.append( ('.'.join(name_path[::-1]), node) )
+ elif node.is_subscript:
+ if node.base.type != Builtin.list_type:
+ return False
+ if not node.index.type.is_int:
+ return False
+ if not node.base.is_name:
+ return False
+ indices.append(node)
+ else:
+ return False
+ return True
+
+ def _extract_index_id(self, index_node):
+ base = index_node.base
+ index = index_node.index
+ if isinstance(index, ExprNodes.NameNode):
+ index_val = index.name
+ elif isinstance(index, ExprNodes.ConstNode):
+ # FIXME:
+ return None
+ else:
+ return None
+ return (base.name, index_val)
+
+
+class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
+ """Optimize some common calls to builtin types *before* the type
+ analysis phase and *after* the declarations analysis phase.
+
+ This transform cannot make use of any argument types, but it can
+ restructure the tree in a way that the type analysis phase can
+ respond to.
+
+ Introducing C function calls here may not be a good idea. Move
+ them to the OptimizeBuiltinCalls transform instead, which runs
+ after type analysis.
+ """
+ # only intercept on call nodes
+ visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+ def visit_SimpleCallNode(self, node):
+ self.visitchildren(node)
+ function = node.function
+ if not self._function_is_builtin_name(function):
+ return node
+ return self._dispatch_to_handler(node, function, node.args)
+
+ def visit_GeneralCallNode(self, node):
+ self.visitchildren(node)
+ function = node.function
+ if not self._function_is_builtin_name(function):
+ return node
+ arg_tuple = node.positional_args
+ if not isinstance(arg_tuple, ExprNodes.TupleNode):
+ return node
+ args = arg_tuple.args
+ return self._dispatch_to_handler(
+ node, function, args, node.keyword_args)
+
+ def _function_is_builtin_name(self, function):
+ if not function.is_name:
+ return False
+ env = self.current_env()
+ entry = env.lookup(function.name)
+ if entry is not env.builtin_scope().lookup_here(function.name):
+ return False
+ # if entry is None, it's at least an undeclared name, so likely builtin
+ return True
+
+ def _dispatch_to_handler(self, node, function, args, kwargs=None):
+ if kwargs is None:
+ handler_name = '_handle_simple_function_%s' % function.name
+ else:
+ handler_name = '_handle_general_function_%s' % function.name
+ handle_call = getattr(self, handler_name, None)
+ if handle_call is not None:
+ if kwargs is None:
+ return handle_call(node, args)
+ else:
+ return handle_call(node, args, kwargs)
+ return node
+
+ def _inject_capi_function(self, node, cname, func_type, utility_code=None):
+ node.function = ExprNodes.PythonCapiFunctionNode(
+ node.function.pos, node.function.name, cname, func_type,
+ utility_code = utility_code)
+
+ def _error_wrong_arg_count(self, function_name, node, args, expected=None):
+ if not expected: # None or 0
+ arg_str = ''
+ elif isinstance(expected, basestring) or expected > 1:
+ arg_str = '...'
+ elif expected == 1:
+ arg_str = 'x'
+ else:
+ arg_str = ''
+ if expected is not None:
+ expected_str = 'expected %s, ' % expected
+ else:
+ expected_str = ''
+ error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
+ function_name, arg_str, expected_str, len(args)))
+
+ # specific handlers for simple call nodes
+
+ def _handle_simple_function_float(self, node, pos_args):
+ if not pos_args:
+ return ExprNodes.FloatNode(node.pos, value='0.0')
+ if len(pos_args) > 1:
+ self._error_wrong_arg_count('float', node, pos_args, 1)
+ arg_type = getattr(pos_args[0], 'type', None)
+ if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
+ return pos_args[0]
+ return node
+
+ def _handle_simple_function_slice(self, node, pos_args):
+ arg_count = len(pos_args)
+ start = step = None
+ if arg_count == 1:
+ stop, = pos_args
+ elif arg_count == 2:
+ start, stop = pos_args
+ elif arg_count == 3:
+ start, stop, step = pos_args
+ else:
+ self._error_wrong_arg_count('slice', node, pos_args)
+ return node
+ return ExprNodes.SliceNode(
+ node.pos,
+ start=start or ExprNodes.NoneNode(node.pos),
+ stop=stop,
+ step=step or ExprNodes.NoneNode(node.pos))
+
+ def _handle_simple_function_ord(self, node, pos_args):
+ """Unpack ord('X').
+ """
+ if len(pos_args) != 1:
+ return node
+ arg = pos_args[0]
+ if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
+ if len(arg.value) == 1:
+ return ExprNodes.IntNode(
+ arg.pos, type=PyrexTypes.c_long_type,
+ value=str(ord(arg.value)),
+ constant_result=ord(arg.value)
+ )
+ elif isinstance(arg, ExprNodes.StringNode):
+ if arg.unicode_value and len(arg.unicode_value) == 1 \
+ and ord(arg.unicode_value) <= 255: # Py2/3 portability
+ return ExprNodes.IntNode(
+ arg.pos, type=PyrexTypes.c_int_type,
+ value=str(ord(arg.unicode_value)),
+ constant_result=ord(arg.unicode_value)
+ )
+ return node
+
+ # sequence processing
+
+ def _handle_simple_function_all(self, node, pos_args):
+ """Transform
+
+ _result = all(p(x) for L in LL for x in L)
+
+ into
+
+ for L in LL:
+ for x in L:
+ if not p(x):
+ return False
+ else:
+ return True
+ """
+ return self._transform_any_all(node, pos_args, False)
+
+ def _handle_simple_function_any(self, node, pos_args):
+ """Transform
+
+ _result = any(p(x) for L in LL for x in L)
+
+ into
+
+ for L in LL:
+ for x in L:
+ if p(x):
+ return True
+ else:
+ return False
+ """
+ return self._transform_any_all(node, pos_args, True)
+
+ def _transform_any_all(self, node, pos_args, is_any):
+ if len(pos_args) != 1:
+ return node
+ if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ return node
+ gen_expr_node = pos_args[0]
+ generator_body = gen_expr_node.def_node.gbody
+ loop_node = generator_body.body
+ yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
+ if yield_expression is None:
+ return node
+
+ if is_any:
+ condition = yield_expression
+ else:
+ condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)
+
+ test_node = Nodes.IfStatNode(
+ yield_expression.pos, else_clause=None, if_clauses=[
+ Nodes.IfClauseNode(
+ yield_expression.pos,
+ condition=condition,
+ body=Nodes.ReturnStatNode(
+ node.pos,
+ value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any))
+ )]
+ )
+ loop_node.else_clause = Nodes.ReturnStatNode(
+ node.pos,
+ value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any))
+
+ Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)
+
+ return ExprNodes.InlinedGeneratorExpressionNode(
+ gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')
+
+ PySequence_List_func_type = PyrexTypes.CFuncType(
+ Builtin.list_type,
+ [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
+
+ def _handle_simple_function_sorted(self, node, pos_args):
+ """Transform sorted(genexpr) and sorted([listcomp]) into
+ [listcomp].sort(). CPython just reads the iterable into a
+ list and calls .sort() on it. Expanding the iterable in a
+ listcomp is still faster and the result can be sorted in
+ place.
+ """
+ if len(pos_args) != 1:
+ return node
+
+ arg = pos_args[0]
+ if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
+ list_node = pos_args[0]
+ loop_node = list_node.loop
+
+ elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
+ gen_expr_node = arg
+ loop_node = gen_expr_node.loop
+ yield_statements = _find_yield_statements(loop_node)
+ if not yield_statements:
+ return node
+
+ list_node = ExprNodes.InlinedGeneratorExpressionNode(
+ node.pos, gen_expr_node, orig_func='sorted',
+ comprehension_type=Builtin.list_type)
+
+ for yield_expression, yield_stat_node in yield_statements:
+ append_node = ExprNodes.ComprehensionAppendNode(
+ yield_expression.pos,
+ expr=yield_expression,
+ target=list_node.target)
+ Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
+
+ elif arg.is_sequence_constructor:
+ # sorted([a, b, c]) or sorted((a, b, c)). The result is always a list,
+ # so starting off with a fresh one is more efficient.
+ list_node = loop_node = arg.as_list()
+
+ else:
+ # Interestingly, PySequence_List works on a lot of non-sequence
+ # things as well.
+ list_node = loop_node = ExprNodes.PythonCapiCallNode(
+ node.pos, "PySequence_List", self.PySequence_List_func_type,
+ args=pos_args, is_temp=True)
+
+ result_node = UtilNodes.ResultRefNode(
+ pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
+ list_assign_node = Nodes.SingleAssignmentNode(
+ node.pos, lhs=result_node, rhs=list_node, first=True)
+
+ sort_method = ExprNodes.AttributeNode(
+ node.pos, obj=result_node, attribute=EncodedString('sort'),
+ # entry ? type ?
+ needs_none_check=False)
+ sort_node = Nodes.ExprStatNode(
+ node.pos, expr=ExprNodes.SimpleCallNode(
+ node.pos, function=sort_method, args=[]))
+
+ sort_node.analyse_declarations(self.current_env())
+
+ return UtilNodes.TempResultFromStatNode(
+ result_node,
+ Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node]))
+
+ def __handle_simple_function_sum(self, node, pos_args):
+ """Transform sum(genexpr) into an equivalent inlined aggregation loop.
+ """
+ if len(pos_args) not in (1,2):
+ return node
+ if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
+ ExprNodes.ComprehensionNode)):
+ return node
+ gen_expr_node = pos_args[0]
+ loop_node = gen_expr_node.loop
+
+ if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
+ yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
+ # FIXME: currently nonfunctional
+ yield_expression = None
+ if yield_expression is None:
+ return node
+ else: # ComprehensionNode
+ yield_stat_node = gen_expr_node.append
+ yield_expression = yield_stat_node.expr
+ try:
+ if not yield_expression.is_literal or not yield_expression.type.is_int:
+ return node
+ except AttributeError:
+ return node # in case we don't have a type yet
+ # special case: old Py2 backwards compatible "sum([int_const for ...])"
+ # can safely be unpacked into a genexpr
+
+ if len(pos_args) == 1:
+ start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
+ else:
+ start = pos_args[1]
+
+ result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
+ add_node = Nodes.SingleAssignmentNode(
+ yield_expression.pos,
+ lhs = result_ref,
+ rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
+ )
+
+ Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)
+
+ exec_code = Nodes.StatListNode(
+ node.pos,
+ stats = [
+ Nodes.SingleAssignmentNode(
+ start.pos,
+ lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
+ rhs = start,
+ first = True),
+ loop_node
+ ])
+
+ return ExprNodes.InlinedGeneratorExpressionNode(
+ gen_expr_node.pos, loop = exec_code, result_node = result_ref,
+ expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
+ has_local_scope = gen_expr_node.has_local_scope)
+
+ def _handle_simple_function_min(self, node, pos_args):
+ return self._optimise_min_max(node, pos_args, '<')
+
+ def _handle_simple_function_max(self, node, pos_args):
+ return self._optimise_min_max(node, pos_args, '>')
+
+ def _optimise_min_max(self, node, args, operator):
+ """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
+ """
+ if len(args) <= 1:
+ if len(args) == 1 and args[0].is_sequence_constructor:
+ args = args[0].args
+ if len(args) <= 1:
+ # leave this to Python
+ return node
+
+ cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
+
+ last_result = args[0]
+ for arg_node in cascaded_nodes:
+ result_ref = UtilNodes.ResultRefNode(last_result)
+ last_result = ExprNodes.CondExprNode(
+ arg_node.pos,
+ true_val = arg_node,
+ false_val = result_ref,
+ test = ExprNodes.PrimaryCmpNode(
+ arg_node.pos,
+ operand1 = arg_node,
+ operator = operator,
+ operand2 = result_ref,
+ )
+ )
+ last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
+
+ for ref_node in cascaded_nodes[::-1]:
+ last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
+
+ return last_result
+
+ # builtin type creation
+
+ def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
+ if not pos_args:
+ return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
+ # This is a bit special - for iterables (including genexps),
+ # Python actually overallocates and resizes a newly created
+ # tuple incrementally while reading items, which we can't
+ # easily do without explicit node support. Instead, we read
+ # the items into a list and then copy them into a tuple of the
+ # final size. This takes up to twice as much memory, but will
+ # have to do until we have real support for genexps.
+ result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
+ if result is not node:
+ return ExprNodes.AsTupleNode(node.pos, arg=result)
+ return node
+
+ def _handle_simple_function_frozenset(self, node, pos_args):
+ """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
+ """
+ if len(pos_args) != 1:
+ return node
+ if pos_args[0].is_sequence_constructor and not pos_args[0].args:
+ del pos_args[0]
+ elif isinstance(pos_args[0], ExprNodes.ListNode):
+ pos_args[0] = pos_args[0].as_tuple()
+ return node
+
+ def _handle_simple_function_list(self, node, pos_args):
+ if not pos_args:
+ return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
+ return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
+
+ def _handle_simple_function_set(self, node, pos_args):
+ if not pos_args:
+ return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
+ return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
+
+ def _transform_list_set_genexpr(self, node, pos_args, target_type):
+ """Replace set(genexpr) and list(genexpr) by an inlined comprehension.
+ """
+ if len(pos_args) > 1:
+ return node
+ if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ return node
+ gen_expr_node = pos_args[0]
+ loop_node = gen_expr_node.loop
+
+ yield_statements = _find_yield_statements(loop_node)
+ if not yield_statements:
+ return node
+
+ result_node = ExprNodes.InlinedGeneratorExpressionNode(
+ node.pos, gen_expr_node,
+ orig_func='set' if target_type is Builtin.set_type else 'list',
+ comprehension_type=target_type)
+
+ for yield_expression, yield_stat_node in yield_statements:
+ append_node = ExprNodes.ComprehensionAppendNode(
+ yield_expression.pos,
+ expr=yield_expression,
+ target=result_node.target)
+ Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
+
+ return result_node
+
+ def _handle_simple_function_dict(self, node, pos_args):
+ """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }
+ """
+ if len(pos_args) == 0:
+ return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
+ if len(pos_args) > 1:
+ return node
+ if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ return node
+ gen_expr_node = pos_args[0]
+ loop_node = gen_expr_node.loop
+
+ yield_statements = _find_yield_statements(loop_node)
+ if not yield_statements:
+ return node
+
+ for yield_expression, _ in yield_statements:
+ if not isinstance(yield_expression, ExprNodes.TupleNode):
+ return node
+ if len(yield_expression.args) != 2:
+ return node
+
+ result_node = ExprNodes.InlinedGeneratorExpressionNode(
+ node.pos, gen_expr_node, orig_func='dict',
+ comprehension_type=Builtin.dict_type)
+
+ for yield_expression, yield_stat_node in yield_statements:
+ append_node = ExprNodes.DictComprehensionAppendNode(
+ yield_expression.pos,
+ key_expr=yield_expression.args[0],
+ value_expr=yield_expression.args[1],
+ target=result_node.target)
+ Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
+
+ return result_node
+
+ # specific handlers for general call nodes
+
+ def _handle_general_function_dict(self, node, pos_args, kwargs):
+ """Replace dict(a=b,c=d,...) by the underlying keyword dict
+ construction which is done anyway.
+ """
+ if len(pos_args) > 0:
+ return node
+ if not isinstance(kwargs, ExprNodes.DictNode):
+ return node
+ return kwargs
+
+
+class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
+ visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+ def get_constant_value_node(self, name_node):
+ if name_node.cf_state is None:
+ return None
+ if name_node.cf_state.cf_is_null:
+ return None
+ entry = self.current_env().lookup(name_node.name)
+ if not entry or (not entry.cf_assignments
+ or len(entry.cf_assignments) != 1):
+ # not just a single assignment in all closures
+ return None
+ return entry.cf_assignments[0].rhs
+
+ def visit_SimpleCallNode(self, node):
+ self.visitchildren(node)
+ if not self.current_directives.get('optimize.inline_defnode_calls'):
+ return node
+ function_name = node.function
+ if not function_name.is_name:
+ return node
+ function = self.get_constant_value_node(function_name)
+ if not isinstance(function, ExprNodes.PyCFunctionNode):
+ return node
+ inlined = ExprNodes.InlinedDefNodeCallNode(
+ node.pos, function_name=function_name,
+ function=function, args=node.args)
+ if inlined.can_be_inlined():
+ return self.replace(node, inlined)
+ return node
+
+
+class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
+ Visitor.MethodDispatcherTransform):
+ """Optimize some common methods calls and instantiation patterns
+ for builtin types *after* the type analysis phase.
+
+ Running after type analysis, this transform can only perform
+ function replacements that do not alter the function return type
+ in a way that was not anticipated by the type analysis.
+ """
+ ### cleanup to avoid redundant coercions to/from Python types
+
+ def visit_PyTypeTestNode(self, node):
+ """Flatten redundant type checks after tree changes.
+ """
+ self.visitchildren(node)
+ return node.reanalyse()
+
+ def _visit_TypecastNode(self, node):
+ # disabled - the user may have had a reason to put a type
+ # cast, even if it looks redundant to Cython
+ """
+ Drop redundant type casts.
+ """
+ self.visitchildren(node)
+ if node.type == node.operand.type:
+ return node.operand
+ return node
+
+ def visit_ExprStatNode(self, node):
+ """
+ Drop dead code and useless coercions.
+ """
+ self.visitchildren(node)
+ if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
+ node.expr = node.expr.arg
+ expr = node.expr
+ if expr is None or expr.is_none or expr.is_literal:
+ # Expression was removed or is dead code => remove ExprStatNode as well.
+ return None
+ if expr.is_name and expr.entry and (expr.entry.is_local or expr.entry.is_arg):
+ # Ignore dead references to local variables etc.
+ return None
+ return node
+
+ def visit_CoerceToBooleanNode(self, node):
+ """Drop redundant conversion nodes after tree changes.
+ """
+ self.visitchildren(node)
+ arg = node.arg
+ if isinstance(arg, ExprNodes.PyTypeTestNode):
+ arg = arg.arg
+ if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
+ if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
+ return arg.arg.coerce_to_boolean(self.current_env())
+ return node
+
+ PyNumber_Float_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
+ ])
+
+ def visit_CoerceToPyTypeNode(self, node):
+ """Drop redundant conversion nodes after tree changes."""
+ self.visitchildren(node)
+ arg = node.arg
+ if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
+ arg = arg.arg
+ if isinstance(arg, ExprNodes.PythonCapiCallNode):
+ if arg.function.name == 'float' and len(arg.args) == 1:
+ # undo redundant Py->C->Py coercion
+ func_arg = arg.args[0]
+ if func_arg.type is Builtin.float_type:
+ return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
+ elif func_arg.type.is_pyobject:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
+ args=[func_arg],
+ py_name='float',
+ is_temp=node.is_temp,
+ result_is_used=node.result_is_used,
+ ).coerce_to(node.type, self.current_env())
+ return node
+
+ def visit_CoerceFromPyTypeNode(self, node):
+ """Drop redundant conversion nodes after tree changes.
+
+ Also, optimise away calls to Python's builtin int() and
+ float() if the result is going to be coerced back into a C
+ type anyway.
+ """
+ self.visitchildren(node)
+ arg = node.arg
+ if not arg.type.is_pyobject:
+ # no Python conversion left at all, just do a C coercion instead
+ if node.type != arg.type:
+ arg = arg.coerce_to(node.type, self.current_env())
+ return arg
+ if isinstance(arg, ExprNodes.PyTypeTestNode):
+ arg = arg.arg
+ if arg.is_literal:
+ if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or
+ node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or
+ node.type.is_int and isinstance(arg, ExprNodes.BoolNode)):
+ return arg.coerce_to(node.type, self.current_env())
+ elif isinstance(arg, ExprNodes.CoerceToPyTypeNode):
+ if arg.type is PyrexTypes.py_object_type:
+ if node.type.assignable_from(arg.arg.type):
+ # completely redundant C->Py->C coercion
+ return arg.arg.coerce_to(node.type, self.current_env())
+ elif arg.type is Builtin.unicode_type:
+ if arg.arg.type.is_unicode_char and node.type.is_unicode_char:
+ return arg.arg.coerce_to(node.type, self.current_env())
+ elif isinstance(arg, ExprNodes.SimpleCallNode):
+ if node.type.is_int or node.type.is_float:
+ return self._optimise_numeric_cast_call(node, arg)
+ elif arg.is_subscript:
+ index_node = arg.index
+ if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
+ index_node = index_node.arg
+ if index_node.type.is_int:
+ return self._optimise_int_indexing(node, arg, index_node)
+ return node
+
+ PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_char_type, [
+ PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
+ PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
+ ],
+ exception_value = "((char)-1)",
+ exception_check = True)
+
+ def _optimise_int_indexing(self, coerce_node, arg, index_node):
+ env = self.current_env()
+ bound_check_bool = env.directives['boundscheck'] and 1 or 0
+ if arg.base.type is Builtin.bytes_type:
+ if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
+ # bytes[index] -> char
+ bound_check_node = ExprNodes.IntNode(
+ coerce_node.pos, value=str(bound_check_bool),
+ constant_result=bound_check_bool)
+ node = ExprNodes.PythonCapiCallNode(
+ coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
+ self.PyBytes_GetItemInt_func_type,
+ args=[
+ arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
+ index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
+ bound_check_node,
+ ],
+ is_temp=True,
+ utility_code=UtilityCode.load_cached(
+ 'bytes_index', 'StringTools.c'))
+ if coerce_node.type is not PyrexTypes.c_char_type:
+ node = node.coerce_to(coerce_node.type, env)
+ return node
+ return coerce_node
+
+ float_float_func_types = dict(
+ (float_type, PyrexTypes.CFuncType(
+ float_type, [
+ PyrexTypes.CFuncTypeArg("arg", float_type, None)
+ ]))
+ for float_type in (PyrexTypes.c_float_type, PyrexTypes.c_double_type, PyrexTypes.c_longdouble_type))
+
+ def _optimise_numeric_cast_call(self, node, arg):
+ function = arg.function
+ args = None
+ if isinstance(arg, ExprNodes.PythonCapiCallNode):
+ args = arg.args
+ elif isinstance(function, ExprNodes.NameNode):
+ if function.type.is_builtin_type and isinstance(arg.arg_tuple, ExprNodes.TupleNode):
+ args = arg.arg_tuple.args
+
+ if args is None or len(args) != 1:
+ return node
+ func_arg = args[0]
+ if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
+ func_arg = func_arg.arg
+ elif func_arg.type.is_pyobject:
+ # play it safe: Python conversion might work on all sorts of things
+ return node
+
+ if function.name == 'int':
+ if func_arg.type.is_int or node.type.is_int:
+ if func_arg.type == node.type:
+ return func_arg
+ elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
+ return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type)
+ elif func_arg.type.is_float and node.type.is_numeric:
+ if func_arg.type.math_h_modifier == 'l':
+ # Work around missing Cygwin definition.
+ truncl = '__Pyx_truncl'
+ else:
+ truncl = 'trunc' + func_arg.type.math_h_modifier
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, truncl,
+ func_type=self.float_float_func_types[func_arg.type],
+ args=[func_arg],
+ py_name='int',
+ is_temp=node.is_temp,
+ result_is_used=node.result_is_used,
+ ).coerce_to(node.type, self.current_env())
+ elif function.name == 'float':
+ if func_arg.type.is_float or node.type.is_float:
+ if func_arg.type == node.type:
+ return func_arg
+ elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
+ return ExprNodes.TypecastNode(
+ node.pos, operand=func_arg, type=node.type)
+ return node
+
+ def _error_wrong_arg_count(self, function_name, node, args, expected=None):
+ if not expected: # None or 0
+ arg_str = ''
+ elif isinstance(expected, basestring) or expected > 1:
+ arg_str = '...'
+ elif expected == 1:
+ arg_str = 'x'
+ else:
+ arg_str = ''
+ if expected is not None:
+ expected_str = 'expected %s, ' % expected
+ else:
+ expected_str = ''
+ error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
+ function_name, arg_str, expected_str, len(args)))
+
+ ### generic fallbacks
+
+ def _handle_function(self, node, function_name, function, arg_list, kwargs):
+ return node
+
+ def _handle_method(self, node, type_name, attr_name, function,
+ arg_list, is_unbound_method, kwargs):
+ """
+ Try to inject C-API calls for unbound method calls to builtin types.
+ While the method declarations in Builtin.py already handle this, we
+ can additionally resolve bound and unbound methods here that were
+ assigned to variables ahead of time.
+ """
+ if kwargs:
+ return node
+ if not function or not function.is_attribute or not function.obj.is_name:
+ # cannot track unbound method calls over more than one indirection as
+ # the names might have been reassigned in the meantime
+ return node
+ type_entry = self.current_env().lookup(type_name)
+ if not type_entry:
+ return node
+ method = ExprNodes.AttributeNode(
+ node.function.pos,
+ obj=ExprNodes.NameNode(
+ function.pos,
+ name=type_name,
+ entry=type_entry,
+ type=type_entry.type),
+ attribute=attr_name,
+ is_called=True).analyse_as_type_attribute(self.current_env())
+ if method is None:
+ return self._optimise_generic_builtin_method_call(
+ node, attr_name, function, arg_list, is_unbound_method)
+ args = node.args
+ if args is None and node.arg_tuple:
+ args = node.arg_tuple.args
+ call_node = ExprNodes.SimpleCallNode(
+ node.pos,
+ function=method,
+ args=args)
+ if not is_unbound_method:
+ call_node.self = function.obj
+ call_node.analyse_c_function_call(self.current_env())
+ call_node.analysed = True
+ return call_node.coerce_to(node.type, self.current_env())
+
+ ### builtin types
+
+ def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
+ """
+ Try to inject an unbound method call for a call to a method of a known builtin type.
+ This enables caching the underlying C function of the method at runtime.
+ """
+ arg_count = len(arg_list)
+ if is_unbound_method or arg_count >= 3 or not (function.is_attribute and function.is_py_attr):
+ return node
+ if not function.obj.type.is_builtin_type:
+ return node
+ if function.obj.type.name in ('basestring', 'type'):
+ # these allow different actual types => unsafe
+ return node
+ return ExprNodes.CachedBuiltinMethodCallNode(
+ node, function.obj, attr_name, arg_list)
+
+ PyObject_Unicode_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
+ ])
+
+ def _handle_simple_function_unicode(self, node, function, pos_args):
+ """Optimise single argument calls to unicode().
+ """
+ if len(pos_args) != 1:
+ if len(pos_args) == 0:
+ return ExprNodes.UnicodeNode(node.pos, value=EncodedString(), constant_result=u'')
+ return node
+ arg = pos_args[0]
+ if arg.type is Builtin.unicode_type:
+ if not arg.may_be_none():
+ return arg
+ cname = "__Pyx_PyUnicode_Unicode"
+ utility_code = UtilityCode.load_cached('PyUnicode_Unicode', 'StringTools.c')
+ else:
+ cname = "__Pyx_PyObject_Unicode"
+ utility_code = UtilityCode.load_cached('PyObject_Unicode', 'StringTools.c')
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, cname, self.PyObject_Unicode_func_type,
+ args=pos_args,
+ is_temp=node.is_temp,
+ utility_code=utility_code,
+ py_name="unicode")
+
+ def visit_FormattedValueNode(self, node):
+ """Simplify or avoid plain string formatting of a unicode value.
+ This seems misplaced here, but plain unicode formatting is essentially
+ a call to the unicode() builtin, which is optimised right above.
+ """
+ self.visitchildren(node)
+ if node.value.type is Builtin.unicode_type and not node.c_format_spec and not node.format_spec:
+ if not node.conversion_char or node.conversion_char == 's':
+ # value is definitely a unicode string and we don't format it any special
+ return self._handle_simple_function_unicode(node, None, [node.value])
+ return node
+
+ PyDict_Copy_func_type = PyrexTypes.CFuncType(
+ Builtin.dict_type, [
+ PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
+ ])
+
+ def _handle_simple_function_dict(self, node, function, pos_args):
+ """Replace dict(some_dict) by PyDict_Copy(some_dict).
+ """
+ if len(pos_args) != 1:
+ return node
+ arg = pos_args[0]
+ if arg.type is Builtin.dict_type:
+ arg = arg.as_none_safe_node("'NoneType' is not iterable")
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
+ args = [arg],
+ is_temp = node.is_temp
+ )
+ return node
+
+ PySequence_List_func_type = PyrexTypes.CFuncType(
+ Builtin.list_type,
+ [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
+
+ def _handle_simple_function_list(self, node, function, pos_args):
+ """Turn list(ob) into PySequence_List(ob).
+ """
+ if len(pos_args) != 1:
+ return node
+ arg = pos_args[0]
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "PySequence_List", self.PySequence_List_func_type,
+ args=pos_args, is_temp=node.is_temp)
+
+ PyList_AsTuple_func_type = PyrexTypes.CFuncType(
+ Builtin.tuple_type, [
+ PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
+ ])
+
+ def _handle_simple_function_tuple(self, node, function, pos_args):
+ """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple.
+ """
+ if len(pos_args) != 1 or not node.is_temp:
+ return node
+ arg = pos_args[0]
+ if arg.type is Builtin.tuple_type and not arg.may_be_none():
+ return arg
+ if arg.type is Builtin.list_type:
+ pos_args[0] = arg.as_none_safe_node(
+ "'NoneType' object is not iterable")
+
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
+ args=pos_args, is_temp=node.is_temp)
+ else:
+ return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type)
+
+ PySet_New_func_type = PyrexTypes.CFuncType(
+ Builtin.set_type, [
+ PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
+ ])
+
+ def _handle_simple_function_set(self, node, function, pos_args):
+ if len(pos_args) != 1:
+ return node
+ if pos_args[0].is_sequence_constructor:
+ # We can optimise set([x,y,z]) safely into a set literal,
+ # but only if we create all items before adding them -
+ # adding an item may raise an exception if it is not
+ # hashable, but creating the later items may have
+ # side-effects.
+ args = []
+ temps = []
+ for arg in pos_args[0].args:
+ if not arg.is_simple():
+ arg = UtilNodes.LetRefNode(arg)
+ temps.append(arg)
+ args.append(arg)
+ result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
+ self.replace(node, result)
+ for temp in temps[::-1]:
+ result = UtilNodes.EvalWithTempExprNode(temp, result)
+ return result
+ else:
+ # PySet_New(it) is better than a generic Python call to set(it)
+ return self.replace(node, ExprNodes.PythonCapiCallNode(
+ node.pos, "PySet_New",
+ self.PySet_New_func_type,
+ args=pos_args,
+ is_temp=node.is_temp,
+ py_name="set"))
+
+ PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
+ Builtin.frozenset_type, [
+ PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
+ ])
+
+ def _handle_simple_function_frozenset(self, node, function, pos_args):
+ if not pos_args:
+ pos_args = [ExprNodes.NullNode(node.pos)]
+ elif len(pos_args) > 1:
+ return node
+ elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none():
+ return pos_args[0]
+ # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_PyFrozenSet_New",
+ self.PyFrozenSet_New_func_type,
+ args=pos_args,
+ is_temp=node.is_temp,
+ utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
+ py_name="frozenset")
+
+ PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_double_type, [
+ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
+ ],
+ exception_value = "((double)-1)",
+ exception_check = True)
+
+ def _handle_simple_function_float(self, node, function, pos_args):
+ """Transform float() into either a C type cast or a faster C
+ function call.
+ """
+ # Note: this requires the float() function to be typed as
+ # returning a C 'double'
+ if len(pos_args) == 0:
+ return ExprNodes.FloatNode(
+ node, value="0.0", constant_result=0.0
+ ).coerce_to(Builtin.float_type, self.current_env())
+ elif len(pos_args) != 1:
+ self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
+ return node
+ func_arg = pos_args[0]
+ if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
+ func_arg = func_arg.arg
+ if func_arg.type is PyrexTypes.c_double_type:
+ return func_arg
+ elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
+ return ExprNodes.TypecastNode(
+ node.pos, operand=func_arg, type=node.type)
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_PyObject_AsDouble",
+ self.PyObject_AsDouble_func_type,
+ args = pos_args,
+ is_temp = node.is_temp,
+ utility_code = load_c_utility('pyobject_as_double'),
+ py_name = "float")
+
+ PyNumber_Int_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
+ ])
+
+ PyInt_FromDouble_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None)
+ ])
+
+ def _handle_simple_function_int(self, node, function, pos_args):
+ """Transform int() into a faster C function call.
+ """
+ if len(pos_args) == 0:
+ return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
+ type=PyrexTypes.py_object_type)
+ elif len(pos_args) != 1:
+ return node # int(x, base)
+ func_arg = pos_args[0]
+ if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
+ if func_arg.arg.type.is_float:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type,
+ args=[func_arg.arg], is_temp=True, py_name='int',
+ utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c"))
+ else:
+ return node # handled in visit_CoerceFromPyTypeNode()
+ if func_arg.type.is_pyobject and node.type.is_pyobject:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type,
+ args=pos_args, is_temp=True, py_name='int')
+ return node
+
+ def _handle_simple_function_bool(self, node, function, pos_args):
+ """Transform bool(x) into a type coercion to a boolean.
+ """
+ if len(pos_args) == 0:
+ return ExprNodes.BoolNode(
+ node.pos, value=False, constant_result=False
+ ).coerce_to(Builtin.bool_type, self.current_env())
+ elif len(pos_args) != 1:
+ self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
+ return node
+ else:
+ # => !!<bint>(x) to make sure it's exactly 0 or 1
+ operand = pos_args[0].coerce_to_boolean(self.current_env())
+ operand = ExprNodes.NotNode(node.pos, operand = operand)
+ operand = ExprNodes.NotNode(node.pos, operand = operand)
+ # coerce back to Python object as that's the result we are expecting
+ return operand.coerce_to_pyobject(self.current_env())
+
+ ### builtin functions
+
+ Pyx_strlen_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_size_t_type, [
+ PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)
+ ])
+
+ Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_size_t_type, [
+ PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)
+ ])
+
+ PyObject_Size_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ssize_t_type, [
+ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
+ ],
+ exception_value="-1")
+
+ _map_to_capi_len_function = {
+ Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH",
+ Builtin.bytes_type: "PyBytes_GET_SIZE",
+ Builtin.bytearray_type: 'PyByteArray_GET_SIZE',
+ Builtin.list_type: "PyList_GET_SIZE",
+ Builtin.tuple_type: "PyTuple_GET_SIZE",
+ Builtin.set_type: "PySet_GET_SIZE",
+ Builtin.frozenset_type: "PySet_GET_SIZE",
+ Builtin.dict_type: "PyDict_Size",
+ }.get
+
+ _ext_types_with_pysize = set(["cpython.array.array"])
+
+ def _handle_simple_function_len(self, node, function, pos_args):
+ """Replace len(char*) by the equivalent call to strlen(),
+ len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and
+ len(known_builtin_type) by an equivalent C-API call.
+ """
+ if len(pos_args) != 1:
+ self._error_wrong_arg_count('len', node, pos_args, 1)
+ return node
+ arg = pos_args[0]
+ if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
+ arg = arg.arg
+ if arg.type.is_string:
+ new_node = ExprNodes.PythonCapiCallNode(
+ node.pos, "strlen", self.Pyx_strlen_func_type,
+ args = [arg],
+ is_temp = node.is_temp,
+ utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
+ elif arg.type.is_pyunicode_ptr:
+ new_node = ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
+ args = [arg],
+ is_temp = node.is_temp)
+ elif arg.type.is_memoryviewslice:
+ func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_size_t_type, [
+ PyrexTypes.CFuncTypeArg("memoryviewslice", arg.type, None)
+ ], nogil=True)
+ new_node = ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_MemoryView_Len", func_type,
+ args=[arg], is_temp=node.is_temp)
+ elif arg.type.is_pyobject:
+ cfunc_name = self._map_to_capi_len_function(arg.type)
+ if cfunc_name is None:
+ arg_type = arg.type
+ if ((arg_type.is_extension_type or arg_type.is_builtin_type)
+ and arg_type.entry.qualified_name in self._ext_types_with_pysize):
+ cfunc_name = 'Py_SIZE'
+ else:
+ return node
+ arg = arg.as_none_safe_node(
+ "object of type 'NoneType' has no len()")
+ new_node = ExprNodes.PythonCapiCallNode(
+ node.pos, cfunc_name, self.PyObject_Size_func_type,
+ args=[arg], is_temp=node.is_temp)
+ elif arg.type.is_unicode_char:
+ return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
+ type=node.type)
+ else:
+ return node
+ if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
+ new_node = new_node.coerce_to(node.type, self.current_env())
+ return new_node
+
+ Pyx_Type_func_type = PyrexTypes.CFuncType(
+ Builtin.type_type, [
+ PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
+ ])
+
+ def _handle_simple_function_type(self, node, function, pos_args):
+ """Replace type(o) by a macro call to Py_TYPE(o).
+ """
+ if len(pos_args) != 1:
+ return node
+ node = ExprNodes.PythonCapiCallNode(
+ node.pos, "Py_TYPE", self.Pyx_Type_func_type,
+ args = pos_args,
+ is_temp = False)
+ return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
+
+ Py_type_check_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_bint_type, [
+ PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
+ ])
+
+ def _handle_simple_function_isinstance(self, node, function, pos_args):
+ """Replace isinstance() checks against builtin types by the
+ corresponding C-API call.
+ """
+ if len(pos_args) != 2:
+ return node
+ arg, types = pos_args
+ temps = []
+ if isinstance(types, ExprNodes.TupleNode):
+ types = types.args
+ if len(types) == 1 and not types[0].type is Builtin.type_type:
+ return node # nothing to improve here
+ if arg.is_attribute or not arg.is_simple():
+ arg = UtilNodes.ResultRefNode(arg)
+ temps.append(arg)
+ elif types.type is Builtin.type_type:
+ types = [types]
+ else:
+ return node
+
+ tests = []
+ test_nodes = []
+ env = self.current_env()
+ for test_type_node in types:
+ builtin_type = None
+ if test_type_node.is_name:
+ if test_type_node.entry:
+ entry = env.lookup(test_type_node.entry.name)
+ if entry and entry.type and entry.type.is_builtin_type:
+ builtin_type = entry.type
+ if builtin_type is Builtin.type_type:
+ # all types have type "type", but there's only one 'type'
+ if entry.name != 'type' or not (
+ entry.scope and entry.scope.is_builtin_scope):
+ builtin_type = None
+ if builtin_type is not None:
+ type_check_function = entry.type.type_check_function(exact=False)
+ if type_check_function in tests:
+ continue
+ tests.append(type_check_function)
+ type_check_args = [arg]
+ elif test_type_node.type is Builtin.type_type:
+ type_check_function = '__Pyx_TypeCheck'
+ type_check_args = [arg, test_type_node]
+ else:
+ if not test_type_node.is_literal:
+ test_type_node = UtilNodes.ResultRefNode(test_type_node)
+ temps.append(test_type_node)
+ type_check_function = 'PyObject_IsInstance'
+ type_check_args = [arg, test_type_node]
+ test_nodes.append(
+ ExprNodes.PythonCapiCallNode(
+ test_type_node.pos, type_check_function, self.Py_type_check_func_type,
+ args=type_check_args,
+ is_temp=True,
+ ))
+
+ def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
+ or_node = make_binop_node(node.pos, 'or', a, b)
+ or_node.type = PyrexTypes.c_bint_type
+ or_node.wrap_operands(env)
+ return or_node
+
+ test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
+ for temp in temps[::-1]:
+ test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
+ return test_node
+
+ def _handle_simple_function_ord(self, node, function, pos_args):
+ """Unpack ord(Py_UNICODE) and ord('X').
+ """
+ if len(pos_args) != 1:
+ return node
+ arg = pos_args[0]
+ if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
+ if arg.arg.type.is_unicode_char:
+ return ExprNodes.TypecastNode(
+ arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type
+ ).coerce_to(node.type, self.current_env())
+ elif isinstance(arg, ExprNodes.UnicodeNode):
+ if len(arg.value) == 1:
+ return ExprNodes.IntNode(
+ arg.pos, type=PyrexTypes.c_int_type,
+ value=str(ord(arg.value)),
+ constant_result=ord(arg.value)
+ ).coerce_to(node.type, self.current_env())
+ elif isinstance(arg, ExprNodes.StringNode):
+ if arg.unicode_value and len(arg.unicode_value) == 1 \
+ and ord(arg.unicode_value) <= 255: # Py2/3 portability
+ return ExprNodes.IntNode(
+ arg.pos, type=PyrexTypes.c_int_type,
+ value=str(ord(arg.unicode_value)),
+ constant_result=ord(arg.unicode_value)
+ ).coerce_to(node.type, self.current_env())
+ return node
+
+ ### special methods
+
+ Pyx_tp_new_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
+ ])
+
+ Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
+ PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
+ ])
+
+ def _handle_any_slot__new__(self, node, function, args,
+ is_unbound_method, kwargs=None):
+ """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()
+ """
+ obj = function.obj
+ if not is_unbound_method or len(args) < 1:
+ return node
+ type_arg = args[0]
+ if not obj.is_name or not type_arg.is_name:
+ # play safe
+ return node
+ if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
+ # not a known type, play safe
+ return node
+ if not type_arg.type_entry or not obj.type_entry:
+ if obj.name != type_arg.name:
+ return node
+ # otherwise, we know it's a type and we know it's the same
+ # type for both - that should do
+ elif type_arg.type_entry != obj.type_entry:
+ # different types - may or may not lead to an error at runtime
+ return node
+
+ args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
+ args_tuple = args_tuple.analyse_types(
+ self.current_env(), skip_children=True)
+
+ if type_arg.type_entry:
+ ext_type = type_arg.type_entry.type
+ if (ext_type.is_extension_type and ext_type.typeobj_cname and
+ ext_type.scope.global_scope() == self.current_env().global_scope()):
+ # known type in current module
+ tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
+ slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
+ if slot_func_cname:
+ cython_scope = self.context.cython_scope
+ PyTypeObjectPtr = PyrexTypes.CPtrType(
+ cython_scope.lookup('PyTypeObject').type)
+ pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
+ ext_type, [
+ PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
+ PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
+ ])
+
+ type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
+ if not kwargs:
+ kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type) # hack?
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, slot_func_cname,
+ pyx_tp_new_kwargs_func_type,
+ args=[type_arg, args_tuple, kwargs],
+ may_return_none=False,
+ is_temp=True)
+ else:
+ # arbitrary variable, needs a None check for safety
+ type_arg = type_arg.as_none_safe_node(
+ "object.__new__(X): X is not a type object (NoneType)")
+
+ utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
+ if kwargs:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
+ args=[type_arg, args_tuple, kwargs],
+ utility_code=utility_code,
+ is_temp=node.is_temp
+ )
+ else:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
+ args=[type_arg, args_tuple],
+ utility_code=utility_code,
+ is_temp=node.is_temp
+ )
+
+ ### methods of builtin types
+
+ PyObject_Append_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_returncode_type, [
+ PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
+ ],
+ exception_value="-1")
+
+ def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
+ """Optimistic optimisation as X.append() is almost always
+ referring to a list.
+ """
+ if len(args) != 2 or node.result_is_used or node.function.entry:
+ return node
+
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
+ args=args,
+ may_return_none=False,
+ is_temp=node.is_temp,
+ result_is_used=False,
+ utility_code=load_c_utility('append')
+ )
+
+ def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method):
+ """Replace list.extend([...]) for short sequence literals values by sequential appends
+ to avoid creating an intermediate sequence argument.
+ """
+ if len(args) != 2:
+ return node
+ obj, value = args
+ if not value.is_sequence_constructor:
+ return node
+ items = list(value.args)
+ if value.mult_factor is not None or len(items) > 8:
+ # Appending wins for short sequences but slows down when multiple resize operations are needed.
+ # This seems to be a good enough limit that avoids repeated resizing.
+ if False and isinstance(value, ExprNodes.ListNode):
+ # One would expect that tuples are more efficient here, but benchmarking with
+ # Py3.5 and Py3.7 suggests that they are not. Probably worth revisiting at some point.
+ # Might be related to the usage of PySequence_FAST() in CPython's list.extend(),
+ # which is probably tuned more towards lists than tuples (and rightly so).
+ tuple_node = args[1].as_tuple().analyse_types(self.current_env(), skip_children=True)
+ Visitor.recursively_replace_node(node, args[1], tuple_node)
+ return node
+ wrapped_obj = self._wrap_self_arg(obj, function, is_unbound_method, 'extend')
+ if not items:
+ # Empty sequences are not likely to occur, but why waste a call to list.extend() for them?
+ wrapped_obj.result_is_used = node.result_is_used
+ return wrapped_obj
+ cloned_obj = obj = wrapped_obj
+ if len(items) > 1 and not obj.is_simple():
+ cloned_obj = UtilNodes.LetRefNode(obj)
+ # Use ListComp_Append() for all but the last item and finish with PyList_Append()
+ # to shrink the list storage size at the very end if necessary.
+ temps = []
+ arg = items[-1]
+ if not arg.is_simple():
+ arg = UtilNodes.LetRefNode(arg)
+ temps.append(arg)
+ new_node = ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type,
+ args=[cloned_obj, arg],
+ is_temp=True,
+ utility_code=load_c_utility("ListAppend"))
+ for arg in items[-2::-1]:
+ if not arg.is_simple():
+ arg = UtilNodes.LetRefNode(arg)
+ temps.append(arg)
+ new_node = ExprNodes.binop_node(
+ node.pos, '|',
+ ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type,
+ args=[cloned_obj, arg], py_name="extend",
+ is_temp=True,
+ utility_code=load_c_utility("ListCompAppend")),
+ new_node,
+ type=PyrexTypes.c_returncode_type,
+ )
+ new_node.result_is_used = node.result_is_used
+ if cloned_obj is not obj:
+ temps.append(cloned_obj)
+ for temp in temps:
+ new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
+ new_node.result_is_used = node.result_is_used
+ return new_node
+
+ PyByteArray_Append_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_returncode_type, [
+ PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
+ ],
+ exception_value="-1")
+
+ PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_returncode_type, [
+ PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
+ ],
+ exception_value="-1")
+
+ def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
+ if len(args) != 2:
+ return node
+ func_name = "__Pyx_PyByteArray_Append"
+ func_type = self.PyByteArray_Append_func_type
+
+ value = unwrap_coerced_node(args[1])
+ if value.type.is_int or isinstance(value, ExprNodes.IntNode):
+ value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
+ utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
+ elif value.is_string_literal:
+ if not value.can_coerce_to_char_literal():
+ return node
+ value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
+ utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
+ elif value.type.is_pyobject:
+ func_name = "__Pyx_PyByteArray_AppendObject"
+ func_type = self.PyByteArray_AppendObject_func_type
+ utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
+ else:
+ return node
+
+ new_node = ExprNodes.PythonCapiCallNode(
+ node.pos, func_name, func_type,
+ args=[args[0], value],
+ may_return_none=False,
+ is_temp=node.is_temp,
+ utility_code=utility_code,
+ )
+ if node.result_is_used:
+ new_node = new_node.coerce_to(node.type, self.current_env())
+ return new_node
+
+ PyObject_Pop_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
+ ])
+
+ PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None),
+ ],
+ has_varargs=True) # to fake the additional macro args that lack a proper C type
+
+ def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
+ return self._handle_simple_method_object_pop(
+ node, function, args, is_unbound_method, is_list=True)
+
+ def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
+ """Optimistic optimisation as X.pop([n]) is almost always
+ referring to a list.
+ """
+ if not args:
+ return node
+ obj = args[0]
+ if is_list:
+ type_name = 'List'
+ obj = obj.as_none_safe_node(
+ "'NoneType' object has no attribute '%.30s'",
+ error="PyExc_AttributeError",
+ format_args=['pop'])
+ else:
+ type_name = 'Object'
+ if len(args) == 1:
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_Py%s_Pop" % type_name,
+ self.PyObject_Pop_func_type,
+ args=[obj],
+ may_return_none=True,
+ is_temp=node.is_temp,
+ utility_code=load_c_utility('pop'),
+ )
+ elif len(args) == 2:
+ index = unwrap_coerced_node(args[1])
+ py_index = ExprNodes.NoneNode(index.pos)
+ orig_index_type = index.type
+ if not index.type.is_int:
+ if isinstance(index, ExprNodes.IntNode):
+ py_index = index.coerce_to_pyobject(self.current_env())
+ index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+ elif is_list:
+ if index.type.is_pyobject:
+ py_index = index.coerce_to_simple(self.current_env())
+ index = ExprNodes.CloneNode(py_index)
+ index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+ else:
+ return node
+ elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type):
+ return node
+ elif isinstance(index, ExprNodes.IntNode):
+ py_index = index.coerce_to_pyobject(self.current_env())
+ # real type might still be larger at runtime
+ if not orig_index_type.is_int:
+ orig_index_type = index.type
+ if not orig_index_type.create_to_py_utility_code(self.current_env()):
+ return node
+ convert_func = orig_index_type.to_py_function
+ conversion_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)])
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "__Pyx_Py%s_PopIndex" % type_name,
+ self.PyObject_PopIndex_func_type,
+ args=[obj, py_index, index,
+ ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0),
+ constant_result=orig_index_type.signed and 1 or 0,
+ type=PyrexTypes.c_int_type),
+ ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type,
+ orig_index_type.empty_declaration_code()),
+ ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)],
+ may_return_none=True,
+ is_temp=node.is_temp,
+ utility_code=load_c_utility("pop_index"),
+ )
+
+ return node
+
+ single_param_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_returncode_type, [
+ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
+ ],
+ exception_value = "-1")
+
+ def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
+ """Call PyList_Sort() instead of the 0-argument l.sort().
+ """
+ if len(args) != 1:
+ return node
+ return self._substitute_method_call(
+ node, function, "PyList_Sort", self.single_param_func_type,
+ 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
+
+ Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
+ ])
+
+ def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
+ """Replace dict.get() by a call to PyDict_GetItem().
+ """
+ if len(args) == 2:
+ args.append(ExprNodes.NoneNode(node.pos))
+ elif len(args) != 3:
+ self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
+ return node
+
+ return self._substitute_method_call(
+ node, function,
+ "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
+ 'get', is_unbound_method, args,
+ may_return_none = True,
+ utility_code = load_c_utility("dict_getitem_default"))
+
+ Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
+ ])
+
+ def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
+ """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
+ """
+ if len(args) == 2:
+ args.append(ExprNodes.NoneNode(node.pos))
+ elif len(args) != 3:
+ self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
+ return node
+ key_type = args[1].type
+ if key_type.is_builtin_type:
+ is_safe_type = int(key_type.name in
+ 'str bytes unicode float int long bool')
+ elif key_type is PyrexTypes.py_object_type:
+ is_safe_type = -1 # don't know
+ else:
+ is_safe_type = 0 # definitely not
+ args.append(ExprNodes.IntNode(
+ node.pos, value=str(is_safe_type), constant_result=is_safe_type))
+
+ return self._substitute_method_call(
+ node, function,
+ "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
+ 'setdefault', is_unbound_method, args,
+ may_return_none=True,
+ utility_code=load_c_utility('dict_setdefault'))
+
+ PyDict_Pop_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.py_object_type, [
+ PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
+ ])
+
+ def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method):
+ """Replace dict.pop() by a call to _PyDict_Pop().
+ """
+ if len(args) == 2:
+ args.append(ExprNodes.NullNode(node.pos))
+ elif len(args) != 3:
+ self._error_wrong_arg_count('dict.pop', node, args, "2 or 3")
+ return node
+
+ return self._substitute_method_call(
+ node, function,
+ "__Pyx_PyDict_Pop", self.PyDict_Pop_func_type,
+ 'pop', is_unbound_method, args,
+ may_return_none=True,
+ utility_code=load_c_utility('py_dict_pop'))
+
+ Pyx_BinopInt_func_types = dict(
+ ((ctype, ret_type), PyrexTypes.CFuncType(
+ ret_type, [
+ PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("cval", ctype, None),
+ PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
+ PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None),
+ ], exception_value=None if ret_type.is_pyobject else ret_type.exception_value))
+ for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type)
+ for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type)
+ )
+
+ def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('And', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Or', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Xor', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method):
+ if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
+ return node
+ if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
+ return node
+ return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method):
+ if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
+ return node
+ if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
+ return node
+ return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_div('Remainder', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_div('Divide', node, function, args, is_unbound_method)
+
+ def _optimise_num_div(self, operator, node, function, args, is_unbound_method):
+ if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0:
+ return node
+ if isinstance(args[1], ExprNodes.IntNode):
+ if not (-2**30 <= args[1].constant_result <= 2**30):
+ return node
+ elif isinstance(args[1], ExprNodes.FloatNode):
+ if not (-2**53 <= args[1].constant_result <= 2**53):
+ return node
+ else:
+ return node
+ return self._optimise_num_binop(operator, node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Divide', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
+
+ def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method):
+ return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
+
+ def _optimise_num_binop(self, operator, node, function, args, is_unbound_method):
+ """
+ Optimise math operators for (likely) float or small integer operations.
+ """
+ if len(args) != 2:
+ return node
+
+ if node.type.is_pyobject:
+ ret_type = PyrexTypes.py_object_type
+ elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'):
+ ret_type = PyrexTypes.c_bint_type
+ else:
+ return node
+
+ # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
+ # Prefer constants on RHS as they allows better size control for some operators.
+ num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
+ if isinstance(args[1], num_nodes):
+ if args[0].type is not PyrexTypes.py_object_type:
+ return node
+ numval = args[1]
+ arg_order = 'ObjC'
+ elif isinstance(args[0], num_nodes):
+ if args[1].type is not PyrexTypes.py_object_type:
+ return node
+ numval = args[0]
+ arg_order = 'CObj'
+ else:
+ return node
+
+ if not numval.has_constant_result():
+ return node
+
+ is_float = isinstance(numval, ExprNodes.FloatNode)
+ num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type
+ if is_float:
+ if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
+ return node
+ elif operator == 'Divide':
+ # mixed old-/new-style division is not currently optimised for integers
+ return node
+ elif abs(numval.constant_result) > 2**30:
+ # Cut off at an integer border that is still safe for all operations.
+ return node
+
+ if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'):
+ if args[1].constant_result == 0:
+ # Don't optimise division by 0. :)
+ return node
+
+ args = list(args)
+ args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)(
+ numval.pos, value=numval.value, constant_result=numval.constant_result,
+ type=num_type))
+ inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
+ args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
+ if is_float or operator not in ('Eq', 'Ne'):
+ # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument.
+ zerodivision_check = arg_order == 'CObj' and (
+ not node.cdivision if isinstance(node, ExprNodes.DivNode) else False)
+ args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check))
+
+ utility_code = TempitaUtilityCode.load_cached(
+ "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop",
+ "Optimize.c",
+ context=dict(op=operator, order=arg_order, ret_type=ret_type))
+
+ call_node = self._substitute_method_call(
+ node, function,
+ "__Pyx_Py%s_%s%s%s" % (
+ 'Float' if is_float else 'Int',
+ '' if ret_type.is_pyobject else 'Bool',
+ operator,
+ arg_order),
+ self.Pyx_BinopInt_func_types[(num_type, ret_type)],
+ '__%s__' % operator[:3].lower(), is_unbound_method, args,
+ may_return_none=True,
+ with_none_check=False,
+ utility_code=utility_code)
+
+ if node.type.is_pyobject and not ret_type.is_pyobject:
+ call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type)
+ return call_node
+
+ ### unicode type methods
+
+ PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_bint_type, [
+ PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
+ ])
+
+ def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
+ if is_unbound_method or len(args) != 1:
+ return node
+ ustring = args[0]
+ if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
+ not ustring.arg.type.is_unicode_char:
+ return node
+ uchar = ustring.arg
+ method_name = function.attribute
+ if method_name == 'istitle':
+ # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
+ utility_code = UtilityCode.load_cached(
+ "py_unicode_istitle", "StringTools.c")
+ function_name = '__Pyx_Py_UNICODE_ISTITLE'
+ else:
+ utility_code = None
+ function_name = 'Py_UNICODE_%s' % method_name.upper()
+ func_call = self._substitute_method_call(
+ node, function,
+ function_name, self.PyUnicode_uchar_predicate_func_type,
+ method_name, is_unbound_method, [uchar],
+ utility_code = utility_code)
+ if node.type.is_pyobject:
+ func_call = func_call.coerce_to_pyobject(self.current_env)
+ return func_call
+
+ _handle_simple_method_unicode_isalnum = _inject_unicode_predicate
+ _handle_simple_method_unicode_isalpha = _inject_unicode_predicate
+ _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
+ _handle_simple_method_unicode_isdigit = _inject_unicode_predicate
+ _handle_simple_method_unicode_islower = _inject_unicode_predicate
+ _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
+ _handle_simple_method_unicode_isspace = _inject_unicode_predicate
+ _handle_simple_method_unicode_istitle = _inject_unicode_predicate
+ _handle_simple_method_unicode_isupper = _inject_unicode_predicate
+
+ PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ucs4_type, [
+ PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
+ ])
+
+ def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
+ if is_unbound_method or len(args) != 1:
+ return node
+ ustring = args[0]
+ if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
+ not ustring.arg.type.is_unicode_char:
+ return node
+ uchar = ustring.arg
+ method_name = function.attribute
+ function_name = 'Py_UNICODE_TO%s' % method_name.upper()
+ func_call = self._substitute_method_call(
+ node, function,
+ function_name, self.PyUnicode_uchar_conversion_func_type,
+ method_name, is_unbound_method, [uchar])
+ if node.type.is_pyobject:
+ func_call = func_call.coerce_to_pyobject(self.current_env)
+ return func_call
+
+ _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
+ _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
+ _handle_simple_method_unicode_title = _inject_unicode_character_conversion
+
+ PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
+ Builtin.list_type, [
+ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
+ ])
+
+ def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
+ """Replace unicode.splitlines(...) by a direct call to the
+ corresponding C-API function.
+ """
+ if len(args) not in (1,2):
+ self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
+ return node
+ self._inject_bint_default_argument(node, args, 1, False)
+
+ return self._substitute_method_call(
+ node, function,
+ "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
+ 'splitlines', is_unbound_method, args)
+
+ PyUnicode_Split_func_type = PyrexTypes.CFuncType(
+ Builtin.list_type, [
+ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
+ ]
+ )
+
+ def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
+ """Replace unicode.split(...) by a direct call to the
+ corresponding C-API function.
+ """
+ if len(args) not in (1,2,3):
+ self._error_wrong_arg_count('unicode.split', node, args, "1-3")
+ return node
+ if len(args) < 2:
+ args.append(ExprNodes.NullNode(node.pos))
+ self._inject_int_default_argument(
+ node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
+
+ return self._substitute_method_call(
+ node, function,
+ "PyUnicode_Split", self.PyUnicode_Split_func_type,
+ 'split', is_unbound_method, args)
+
+ PyUnicode_Join_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
+ ])
+
+ def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
+ """
+ unicode.join() builds a list first => see if we can do this more efficiently
+ """
+ if len(args) != 2:
+ self._error_wrong_arg_count('unicode.join', node, args, "2")
+ return node
+ if isinstance(args[1], ExprNodes.GeneratorExpressionNode):
+ gen_expr_node = args[1]
+ loop_node = gen_expr_node.loop
+
+ yield_statements = _find_yield_statements(loop_node)
+ if yield_statements:
+ inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
+ node.pos, gen_expr_node, orig_func='list',
+ comprehension_type=Builtin.list_type)
+
+ for yield_expression, yield_stat_node in yield_statements:
+ append_node = ExprNodes.ComprehensionAppendNode(
+ yield_expression.pos,
+ expr=yield_expression,
+ target=inlined_genexpr.target)
+
+ Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
+
+ args[1] = inlined_genexpr
+
+ return self._substitute_method_call(
+ node, function,
+ "PyUnicode_Join", self.PyUnicode_Join_func_type,
+ 'join', is_unbound_method, args)
+
+ PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_bint_type, [
+ PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode
+ PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
+ ],
+ exception_value = '-1')
+
+ def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'unicode', 'endswith',
+ unicode_tailmatch_utility_code, +1)
+
+ def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'unicode', 'startswith',
+ unicode_tailmatch_utility_code, -1)
+
+ def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
+ method_name, utility_code, direction):
+ """Replace unicode.startswith(...) and unicode.endswith(...)
+ by a direct call to the corresponding C-API function.
+ """
+ if len(args) not in (2,3,4):
+ self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
+ return node
+ self._inject_int_default_argument(
+ node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
+ self._inject_int_default_argument(
+ node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
+ args.append(ExprNodes.IntNode(
+ node.pos, value=str(direction), type=PyrexTypes.c_int_type))
+
+ method_call = self._substitute_method_call(
+ node, function,
+ "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
+ self.PyString_Tailmatch_func_type,
+ method_name, is_unbound_method, args,
+ utility_code = utility_code)
+ return method_call.coerce_to(Builtin.bool_type, self.current_env())
+
+ PyUnicode_Find_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ssize_t_type, [
+ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
+ ],
+ exception_value = '-2')
+
+ def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
+ return self._inject_unicode_find(
+ node, function, args, is_unbound_method, 'find', +1)
+
+ def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
+ return self._inject_unicode_find(
+ node, function, args, is_unbound_method, 'rfind', -1)
+
+ def _inject_unicode_find(self, node, function, args, is_unbound_method,
+ method_name, direction):
+ """Replace unicode.find(...) and unicode.rfind(...) by a
+ direct call to the corresponding C-API function.
+ """
+ if len(args) not in (2,3,4):
+ self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
+ return node
+ self._inject_int_default_argument(
+ node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
+ self._inject_int_default_argument(
+ node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
+ args.append(ExprNodes.IntNode(
+ node.pos, value=str(direction), type=PyrexTypes.c_int_type))
+
+ method_call = self._substitute_method_call(
+ node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
+ method_name, is_unbound_method, args)
+ return method_call.coerce_to_pyobject(self.current_env())
+
+ PyUnicode_Count_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_py_ssize_t_type, [
+ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
+ ],
+ exception_value = '-1')
+
+ def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
+ """Replace unicode.count(...) by a direct call to the
+ corresponding C-API function.
+ """
+ if len(args) not in (2,3,4):
+ self._error_wrong_arg_count('unicode.count', node, args, "2-4")
+ return node
+ self._inject_int_default_argument(
+ node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
+ self._inject_int_default_argument(
+ node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
+
+ method_call = self._substitute_method_call(
+ node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
+ 'count', is_unbound_method, args)
+ return method_call.coerce_to_pyobject(self.current_env())
+
+ PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
+ ])
+
+ def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
+ """Replace unicode.replace(...) by a direct call to the
+ corresponding C-API function.
+ """
+ if len(args) not in (3,4):
+ self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
+ return node
+ self._inject_int_default_argument(
+ node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
+
+ return self._substitute_method_call(
+ node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
+ 'replace', is_unbound_method, args)
+
+ PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
+ Builtin.bytes_type, [
+ PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
+ PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
+ ])
+
+ PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
+ Builtin.bytes_type, [
+ PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
+ ])
+
+ _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII',
+ 'unicode_escape', 'raw_unicode_escape']
+
+ _special_codecs = [ (name, codecs.getencoder(name))
+ for name in _special_encodings ]
+
+ def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
+ """Replace unicode.encode(...) by a direct C-API call to the
+ corresponding codec.
+ """
+ if len(args) < 1 or len(args) > 3:
+ self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
+ return node
+
+ string_node = args[0]
+
+ if len(args) == 1:
+ null_node = ExprNodes.NullNode(node.pos)
+ return self._substitute_method_call(
+ node, function, "PyUnicode_AsEncodedString",
+ self.PyUnicode_AsEncodedString_func_type,
+ 'encode', is_unbound_method, [string_node, null_node, null_node])
+
+ parameters = self._unpack_encoding_and_error_mode(node.pos, args)
+ if parameters is None:
+ return node
+ encoding, encoding_node, error_handling, error_handling_node = parameters
+
+ if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
+ # constant, so try to do the encoding at compile time
+ try:
+ value = string_node.value.encode(encoding, error_handling)
+ except:
+ # well, looks like we can't
+ pass
+ else:
+ value = bytes_literal(value, encoding)
+ return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)
+
+ if encoding and error_handling == 'strict':
+ # try to find a specific encoder function
+ codec_name = self._find_special_codec_name(encoding)
+ if codec_name is not None and '-' not in codec_name:
+ encode_function = "PyUnicode_As%sString" % codec_name
+ return self._substitute_method_call(
+ node, function, encode_function,
+ self.PyUnicode_AsXyzString_func_type,
+ 'encode', is_unbound_method, [string_node])
+
+ return self._substitute_method_call(
+ node, function, "PyUnicode_AsEncodedString",
+ self.PyUnicode_AsEncodedString_func_type,
+ 'encode', is_unbound_method,
+ [string_node, encoding_node, error_handling_node])
+
+ PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
+ ]))
+
+ _decode_c_string_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
+ ])
+
+ _decode_bytes_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
+ ])
+
+ _decode_cpp_string_func_type = None # lazy init
+
+ def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
+ """Replace char*.decode() by a direct C-API call to the
+ corresponding codec, possibly resolving a slice on the char*.
+ """
+ if not (1 <= len(args) <= 3):
+ self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
+ return node
+
+ # normalise input nodes
+ string_node = args[0]
+ start = stop = None
+ if isinstance(string_node, ExprNodes.SliceIndexNode):
+ index_node = string_node
+ string_node = index_node.base
+ start, stop = index_node.start, index_node.stop
+ if not start or start.constant_result == 0:
+ start = None
+ if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
+ string_node = string_node.arg
+
+ string_type = string_node.type
+ if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
+ if is_unbound_method:
+ string_node = string_node.as_none_safe_node(
+ "descriptor '%s' requires a '%s' object but received a 'NoneType'",
+ format_args=['decode', string_type.name])
+ else:
+ string_node = string_node.as_none_safe_node(
+ "'NoneType' object has no attribute '%.30s'",
+ error="PyExc_AttributeError",
+ format_args=['decode'])
+ elif not string_type.is_string and not string_type.is_cpp_string:
+ # nothing to optimise here
+ return node
+
+ parameters = self._unpack_encoding_and_error_mode(node.pos, args)
+ if parameters is None:
+ return node
+ encoding, encoding_node, error_handling, error_handling_node = parameters
+
+ if not start:
+ start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
+ elif not start.type.is_int:
+ start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+ if stop and not stop.type.is_int:
+ stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+
+ # try to find a specific encoder function
+ codec_name = None
+ if encoding is not None:
+ codec_name = self._find_special_codec_name(encoding)
+ if codec_name is not None:
+ if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
+ codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
+ else:
+ codec_cname = "PyUnicode_Decode%s" % codec_name
+ decode_function = ExprNodes.RawCNameExprNode(
+ node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
+ encoding_node = ExprNodes.NullNode(node.pos)
+ else:
+ decode_function = ExprNodes.NullNode(node.pos)
+
+ # build the helper function call
+ temps = []
+ if string_type.is_string:
+ # C string
+ if not stop:
+ # use strlen() to find the string length, just as CPython would
+ if not string_node.is_name:
+ string_node = UtilNodes.LetRefNode(string_node) # used twice
+ temps.append(string_node)
+ stop = ExprNodes.PythonCapiCallNode(
+ string_node.pos, "strlen", self.Pyx_strlen_func_type,
+ args=[string_node],
+ is_temp=False,
+ utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
+ ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
+ helper_func_type = self._decode_c_string_func_type
+ utility_code_name = 'decode_c_string'
+ elif string_type.is_cpp_string:
+ # C++ std::string
+ if not stop:
+ stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
+ constant_result=ExprNodes.not_a_constant)
+ if self._decode_cpp_string_func_type is None:
+ # lazy init to reuse the C++ string type
+ self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("string", string_type, None),
+ PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
+ ])
+ helper_func_type = self._decode_cpp_string_func_type
+ utility_code_name = 'decode_cpp_string'
+ else:
+ # Python bytes/bytearray object
+ if not stop:
+ stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
+ constant_result=ExprNodes.not_a_constant)
+ helper_func_type = self._decode_bytes_func_type
+ if string_type is Builtin.bytes_type:
+ utility_code_name = 'decode_bytes'
+ else:
+ utility_code_name = 'decode_bytearray'
+
+ node = ExprNodes.PythonCapiCallNode(
+ node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
+ args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
+ is_temp=node.is_temp,
+ utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
+ )
+
+ for temp in temps[::-1]:
+ node = UtilNodes.EvalWithTempExprNode(temp, node)
+ return node
+
+ _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
+
+ def _find_special_codec_name(self, encoding):
+ try:
+ requested_codec = codecs.getencoder(encoding)
+ except LookupError:
+ return None
+ for name, codec in self._special_codecs:
+ if codec == requested_codec:
+ if '_' in name:
+ name = ''.join([s.capitalize()
+ for s in name.split('_')])
+ return name
+ return None
+
+ def _unpack_encoding_and_error_mode(self, pos, args):
+ null_node = ExprNodes.NullNode(pos)
+
+ if len(args) >= 2:
+ encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
+ if encoding_node is None:
+ return None
+ else:
+ encoding = None
+ encoding_node = null_node
+
+ if len(args) == 3:
+ error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
+ if error_handling_node is None:
+ return None
+ if error_handling == 'strict':
+ error_handling_node = null_node
+ else:
+ error_handling = 'strict'
+ error_handling_node = null_node
+
+ return (encoding, encoding_node, error_handling, error_handling_node)
+
+ def _unpack_string_and_cstring_node(self, node):
+ if isinstance(node, ExprNodes.CoerceToPyTypeNode):
+ node = node.arg
+ if isinstance(node, ExprNodes.UnicodeNode):
+ encoding = node.value
+ node = ExprNodes.BytesNode(
+ node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type)
+ elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
+ encoding = node.value.decode('ISO-8859-1')
+ node = ExprNodes.BytesNode(
+ node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type)
+ elif node.type is Builtin.bytes_type:
+ encoding = None
+ node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env())
+ elif node.type.is_string:
+ encoding = None
+ else:
+ encoding = node = None
+ return encoding, node
+
+ def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'str', 'endswith',
+ str_tailmatch_utility_code, +1)
+
+ def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'str', 'startswith',
+ str_tailmatch_utility_code, -1)
+
+ def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'bytes', 'endswith',
+ bytes_tailmatch_utility_code, +1)
+
+ def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'bytes', 'startswith',
+ bytes_tailmatch_utility_code, -1)
+
+ ''' # disabled for now, enable when we consider it worth it (see StringTools.c)
+ def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'bytearray', 'endswith',
+ bytes_tailmatch_utility_code, +1)
+
+ def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
+ return self._inject_tailmatch(
+ node, function, args, is_unbound_method, 'bytearray', 'startswith',
+ bytes_tailmatch_utility_code, -1)
+ '''
+
+ ### helpers
+
+ def _substitute_method_call(self, node, function, name, func_type,
+ attr_name, is_unbound_method, args=(),
+ utility_code=None, is_temp=None,
+ may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
+ with_none_check=True):
+ args = list(args)
+ if with_none_check and args:
+ args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name)
+ if is_temp is None:
+ is_temp = node.is_temp
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, name, func_type,
+ args = args,
+ is_temp = is_temp,
+ utility_code = utility_code,
+ may_return_none = may_return_none,
+ result_is_used = node.result_is_used,
+ )
+
+ def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name):
+ if self_arg.is_literal:
+ return self_arg
+ if is_unbound_method:
+ self_arg = self_arg.as_none_safe_node(
+ "descriptor '%s' requires a '%s' object but received a 'NoneType'",
+ format_args=[attr_name, self_arg.type.name])
+ else:
+ self_arg = self_arg.as_none_safe_node(
+ "'NoneType' object has no attribute '%{0}s'".format('.30' if len(attr_name) <= 30 else ''),
+ error="PyExc_AttributeError",
+ format_args=[attr_name])
+ return self_arg
+
+ def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
+ assert len(args) >= arg_index
+ if len(args) == arg_index:
+ args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
+ type=type, constant_result=default_value))
+ else:
+ args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
+
+ def _inject_bint_default_argument(self, node, args, arg_index, default_value):
+ assert len(args) >= arg_index
+ if len(args) == arg_index:
+ default_value = bool(default_value)
+ args.append(ExprNodes.BoolNode(node.pos, value=default_value,
+ constant_result=default_value))
+ else:
+ args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
+
+
+unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
+bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
+str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
+
+
+class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
+ """Calculate the result of constant expressions to store it in
+ ``expr_node.constant_result``, and replace trivial cases by their
+ constant result.
+
+ General rules:
+
+ - We calculate float constants to make them available to the
+ compiler, but we do not aggregate them into a single literal
+ node to prevent any loss of precision.
+
+ - We recursively calculate constants from non-literal nodes to
+ make them available to the compiler, but we only aggregate
+ literal nodes at each step. Non-literal nodes are never merged
+ into a single node.
+ """
+
+ def __init__(self, reevaluate=False):
+ """
+ The reevaluate argument specifies whether constant values that were
+ previously computed should be recomputed.
+ """
+ super(ConstantFolding, self).__init__()
+ self.reevaluate = reevaluate
+
+ def _calculate_const(self, node):
+ if (not self.reevaluate and
+ node.constant_result is not ExprNodes.constant_value_not_set):
+ return
+
+ # make sure we always set the value
+ not_a_constant = ExprNodes.not_a_constant
+ node.constant_result = not_a_constant
+
+ # check if all children are constant
+ children = self.visitchildren(node)
+ for child_result in children.values():
+ if type(child_result) is list:
+ for child in child_result:
+ if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
+ return
+ elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
+ return
+
+ # now try to calculate the real constant value
+ try:
+ node.calculate_constant_result()
+# if node.constant_result is not ExprNodes.not_a_constant:
+# print node.__class__.__name__, node.constant_result
+ except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
+ # ignore all 'normal' errors here => no constant result
+ pass
+ except Exception:
+ # this looks like a real error
+ import traceback, sys
+ traceback.print_exc(file=sys.stdout)
+
+ NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
+ ExprNodes.IntNode, ExprNodes.FloatNode]
+
+ def _widest_node_class(self, *nodes):
+ try:
+ return self.NODE_TYPE_ORDER[
+ max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
+ except ValueError:
+ return None
+
+ def _bool_node(self, node, value):
+ value = bool(value)
+ return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)
+
+ def visit_ExprNode(self, node):
+ self._calculate_const(node)
+ return node
+
+ def visit_UnopNode(self, node):
+ self._calculate_const(node)
+ if not node.has_constant_result():
+ if node.operator == '!':
+ return self._handle_NotNode(node)
+ return node
+ if not node.operand.is_literal:
+ return node
+ if node.operator == '!':
+ return self._bool_node(node, node.constant_result)
+ elif isinstance(node.operand, ExprNodes.BoolNode):
+ return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)),
+ type=PyrexTypes.c_int_type,
+ constant_result=int(node.constant_result))
+ elif node.operator == '+':
+ return self._handle_UnaryPlusNode(node)
+ elif node.operator == '-':
+ return self._handle_UnaryMinusNode(node)
+ return node
+
+ _negate_operator = {
+ 'in': 'not_in',
+ 'not_in': 'in',
+ 'is': 'is_not',
+ 'is_not': 'is'
+ }.get
+
+ def _handle_NotNode(self, node):
+ operand = node.operand
+ if isinstance(operand, ExprNodes.PrimaryCmpNode):
+ operator = self._negate_operator(operand.operator)
+ if operator:
+ node = copy.copy(operand)
+ node.operator = operator
+ node = self.visit_PrimaryCmpNode(node)
+ return node
+
+ def _handle_UnaryMinusNode(self, node):
+ def _negate(value):
+ if value.startswith('-'):
+ value = value[1:]
+ else:
+ value = '-' + value
+ return value
+
+ node_type = node.operand.type
+ if isinstance(node.operand, ExprNodes.FloatNode):
+ # this is a safe operation
+ return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value),
+ type=node_type,
+ constant_result=node.constant_result)
+ if node_type.is_int and node_type.signed or \
+ isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
+ return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value),
+ type=node_type,
+ longness=node.operand.longness,
+ constant_result=node.constant_result)
+ return node
+
+ def _handle_UnaryPlusNode(self, node):
+ if (node.operand.has_constant_result() and
+ node.constant_result == node.operand.constant_result):
+ return node.operand
+ return node
+
+ def visit_BoolBinopNode(self, node):
+ self._calculate_const(node)
+ if not node.operand1.has_constant_result():
+ return node
+ if node.operand1.constant_result:
+ if node.operator == 'and':
+ return node.operand2
+ else:
+ return node.operand1
+ else:
+ if node.operator == 'and':
+ return node.operand1
+ else:
+ return node.operand2
+
+ def visit_BinopNode(self, node):
+ self._calculate_const(node)
+ if node.constant_result is ExprNodes.not_a_constant:
+ return node
+ if isinstance(node.constant_result, float):
+ return node
+ operand1, operand2 = node.operand1, node.operand2
+ if not operand1.is_literal or not operand2.is_literal:
+ return node
+
+ # now inject a new constant node with the calculated value
+ try:
+ type1, type2 = operand1.type, operand2.type
+ if type1 is None or type2 is None:
+ return node
+ except AttributeError:
+ return node
+
+ if type1.is_numeric and type2.is_numeric:
+ widest_type = PyrexTypes.widest_numeric_type(type1, type2)
+ else:
+ widest_type = PyrexTypes.py_object_type
+
+ target_class = self._widest_node_class(operand1, operand2)
+ if target_class is None:
+ return node
+ elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
+ # C arithmetic results in at least an int type
+ target_class = ExprNodes.IntNode
+ elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
+ # C arithmetic results in at least an int type
+ target_class = ExprNodes.IntNode
+
+ if target_class is ExprNodes.IntNode:
+ unsigned = getattr(operand1, 'unsigned', '') and \
+ getattr(operand2, 'unsigned', '')
+ longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
+ len(getattr(operand2, 'longness', '')))]
+ new_node = ExprNodes.IntNode(pos=node.pos,
+ unsigned=unsigned, longness=longness,
+ value=str(int(node.constant_result)),
+ constant_result=int(node.constant_result))
+ # IntNode is smart about the type it chooses, so we just
+ # make sure we were not smarter this time
+ if widest_type.is_pyobject or new_node.type.is_pyobject:
+ new_node.type = PyrexTypes.py_object_type
+ else:
+ new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
+ else:
+ if target_class is ExprNodes.BoolNode:
+ node_value = node.constant_result
+ else:
+ node_value = str(node.constant_result)
+ new_node = target_class(pos=node.pos, type = widest_type,
+ value = node_value,
+ constant_result = node.constant_result)
+ return new_node
+
+ def visit_AddNode(self, node):
+ self._calculate_const(node)
+ if node.constant_result is ExprNodes.not_a_constant:
+ return node
+ if node.operand1.is_string_literal and node.operand2.is_string_literal:
+ # some people combine string literals with a '+'
+ str1, str2 = node.operand1, node.operand2
+ if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode):
+ bytes_value = None
+ if str1.bytes_value is not None and str2.bytes_value is not None:
+ if str1.bytes_value.encoding == str2.bytes_value.encoding:
+ bytes_value = bytes_literal(
+ str1.bytes_value + str2.bytes_value,
+ str1.bytes_value.encoding)
+ string_value = EncodedString(node.constant_result)
+ return ExprNodes.UnicodeNode(
+ str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
+ elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
+ if str1.value.encoding == str2.value.encoding:
+ bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
+ return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
+ # all other combinations are rather complicated
+ # to get right in Py2/3: encodings, unicode escapes, ...
+ return self.visit_BinopNode(node)
+
+ def visit_MulNode(self, node):
+ self._calculate_const(node)
+ if node.operand1.is_sequence_constructor:
+ return self._calculate_constant_seq(node, node.operand1, node.operand2)
+ if isinstance(node.operand1, ExprNodes.IntNode) and \
+ node.operand2.is_sequence_constructor:
+ return self._calculate_constant_seq(node, node.operand2, node.operand1)
+ if node.operand1.is_string_literal:
+ return self._multiply_string(node, node.operand1, node.operand2)
+ elif node.operand2.is_string_literal:
+ return self._multiply_string(node, node.operand2, node.operand1)
+ return self.visit_BinopNode(node)
+
+ def _multiply_string(self, node, string_node, multiplier_node):
+ multiplier = multiplier_node.constant_result
+ if not isinstance(multiplier, _py_int_types):
+ return node
+ if not (node.has_constant_result() and isinstance(node.constant_result, _py_string_types)):
+ return node
+ if len(node.constant_result) > 256:
+ # Too long for static creation, leave it to runtime. (-> arbitrary limit)
+ return node
+
+ build_string = encoded_string
+ if isinstance(string_node, ExprNodes.BytesNode):
+ build_string = bytes_literal
+ elif isinstance(string_node, ExprNodes.StringNode):
+ if string_node.unicode_value is not None:
+ string_node.unicode_value = encoded_string(
+ string_node.unicode_value * multiplier,
+ string_node.unicode_value.encoding)
+ build_string = encoded_string if string_node.value.is_unicode else bytes_literal
+ elif isinstance(string_node, ExprNodes.UnicodeNode):
+ if string_node.bytes_value is not None:
+ string_node.bytes_value = bytes_literal(
+ string_node.bytes_value * multiplier,
+ string_node.bytes_value.encoding)
+ else:
+ assert False, "unknown string node type: %s" % type(string_node)
+ string_node.value = build_string(
+ string_node.value * multiplier,
+ string_node.value.encoding)
+ # follow constant-folding and use unicode_value in preference
+ if isinstance(string_node, ExprNodes.StringNode) and string_node.unicode_value is not None:
+ string_node.constant_result = string_node.unicode_value
+ else:
+ string_node.constant_result = string_node.value
+ return string_node
+
+ def _calculate_constant_seq(self, node, sequence_node, factor):
+ if factor.constant_result != 1 and sequence_node.args:
+ if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0:
+ del sequence_node.args[:]
+ sequence_node.mult_factor = None
+ elif sequence_node.mult_factor is not None:
+ if (isinstance(factor.constant_result, _py_int_types) and
+ isinstance(sequence_node.mult_factor.constant_result, _py_int_types)):
+ value = sequence_node.mult_factor.constant_result * factor.constant_result
+ sequence_node.mult_factor = ExprNodes.IntNode(
+ sequence_node.mult_factor.pos,
+ value=str(value), constant_result=value)
+ else:
+ # don't know if we can combine the factors, so don't
+ return self.visit_BinopNode(node)
+ else:
+ sequence_node.mult_factor = factor
+ return sequence_node
+
+ def visit_ModNode(self, node):
+ self.visitchildren(node)
+ if isinstance(node.operand1, ExprNodes.UnicodeNode) and isinstance(node.operand2, ExprNodes.TupleNode):
+ if not node.operand2.mult_factor:
+ fstring = self._build_fstring(node.operand1.pos, node.operand1.value, node.operand2.args)
+ if fstring is not None:
+ return fstring
+ return self.visit_BinopNode(node)
+
+ _parse_string_format_regex = (
+ u'(%(?:' # %...
+ u'(?:[-0-9]+|[ ])?' # width (optional) or space prefix fill character (optional)
+ u'(?:[.][0-9]+)?' # precision (optional)
+ u')?.)' # format type (or something different for unsupported formats)
+ )
+
+ def _build_fstring(self, pos, ustring, format_args):
+ # Issues formatting warnings instead of errors since we really only catch a few errors by accident.
+ args = iter(format_args)
+ substrings = []
+ can_be_optimised = True
+ for s in re.split(self._parse_string_format_regex, ustring):
+ if not s:
+ continue
+ if s == u'%%':
+ substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(u'%'), constant_result=u'%'))
+ continue
+ if s[0] != u'%':
+ if s[-1] == u'%':
+ warning(pos, "Incomplete format: '...%s'" % s[-3:], level=1)
+ can_be_optimised = False
+ substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(s), constant_result=s))
+ continue
+ format_type = s[-1]
+ try:
+ arg = next(args)
+ except StopIteration:
+ warning(pos, "Too few arguments for format placeholders", level=1)
+ can_be_optimised = False
+ break
+ if arg.is_starred:
+ can_be_optimised = False
+ break
+ if format_type in u'asrfdoxX':
+ format_spec = s[1:]
+ conversion_char = None
+ if format_type in u'doxX' and u'.' in format_spec:
+ # Precision is not allowed for integers in format(), but ok in %-formatting.
+ can_be_optimised = False
+ elif format_type in u'ars':
+ format_spec = format_spec[:-1]
+ conversion_char = format_type
+ if format_spec.startswith('0'):
+ format_spec = '>' + format_spec[1:] # right-alignment '%05s' spells '{:>5}'
+ elif format_type == u'd':
+ # '%d' formatting supports float, but '{obj:d}' does not => convert to int first.
+ conversion_char = 'd'
+
+ if format_spec.startswith('-'):
+ format_spec = '<' + format_spec[1:] # left-alignment '%-5s' spells '{:<5}'
+
+ substrings.append(ExprNodes.FormattedValueNode(
+ arg.pos, value=arg,
+ conversion_char=conversion_char,
+ format_spec=ExprNodes.UnicodeNode(
+ pos, value=EncodedString(format_spec), constant_result=format_spec)
+ if format_spec else None,
+ ))
+ else:
+ # keep it simple for now ...
+ can_be_optimised = False
+ break
+
+ if not can_be_optimised:
+ # Print all warnings we can find before finally giving up here.
+ return None
+
+ try:
+ next(args)
+ except StopIteration: pass
+ else:
+ warning(pos, "Too many arguments for format placeholders", level=1)
+ return None
+
+ node = ExprNodes.JoinedStrNode(pos, values=substrings)
+ return self.visit_JoinedStrNode(node)
+
+ def visit_FormattedValueNode(self, node):
+ self.visitchildren(node)
+ conversion_char = node.conversion_char or 's'
+ if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value:
+ node.format_spec = None
+ if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode):
+ value = EncodedString(node.value.value)
+ if value.isdigit():
+ return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
+ if node.format_spec is None and conversion_char == 's':
+ value = None
+ if isinstance(node.value, ExprNodes.UnicodeNode):
+ value = node.value.value
+ elif isinstance(node.value, ExprNodes.StringNode):
+ value = node.value.unicode_value
+ if value is not None:
+ return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
+ return node
+
+ def visit_JoinedStrNode(self, node):
+ """
+ Clean up after the parser by discarding empty Unicode strings and merging
+ substring sequences. Empty or single-value join lists are not uncommon
+ because f-string format specs are always parsed into JoinedStrNodes.
+ """
+ self.visitchildren(node)
+ unicode_node = ExprNodes.UnicodeNode
+
+ values = []
+ for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)):
+ if is_unode_group:
+ substrings = list(substrings)
+ unode = substrings[0]
+ if len(substrings) > 1:
+ value = EncodedString(u''.join(value.value for value in substrings))
+ unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value)
+ # ignore empty Unicode strings
+ if unode.value:
+ values.append(unode)
+ else:
+ values.extend(substrings)
+
+ if not values:
+ value = EncodedString('')
+ node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value)
+ elif len(values) == 1:
+ node = values[0]
+ elif len(values) == 2:
+ # reduce to string concatenation
+ node = ExprNodes.binop_node(node.pos, '+', *values)
+ else:
+ node.values = values
+ return node
+
+ def visit_MergedDictNode(self, node):
+ """Unpack **args in place if we can."""
+ self.visitchildren(node)
+ args = []
+ items = []
+
+ def add(arg):
+ if arg.is_dict_literal:
+ if items:
+ items[0].key_value_pairs.extend(arg.key_value_pairs)
+ else:
+ items.append(arg)
+ elif isinstance(arg, ExprNodes.MergedDictNode):
+ for child_arg in arg.keyword_args:
+ add(child_arg)
+ else:
+ if items:
+ args.append(items[0])
+ del items[:]
+ args.append(arg)
+
+ for arg in node.keyword_args:
+ add(arg)
+ if items:
+ args.append(items[0])
+
+ if len(args) == 1:
+ arg = args[0]
+ if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode):
+ return arg
+ node.keyword_args[:] = args
+ self._calculate_const(node)
+ return node
+
+ def visit_MergedSequenceNode(self, node):
+ """Unpack *args in place if we can."""
+ self.visitchildren(node)
+
+ is_set = node.type is Builtin.set_type
+ args = []
+ values = []
+
+ def add(arg):
+ if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor):
+ if values:
+ values[0].args.extend(arg.args)
+ else:
+ values.append(arg)
+ elif isinstance(arg, ExprNodes.MergedSequenceNode):
+ for child_arg in arg.args:
+ add(child_arg)
+ else:
+ if values:
+ args.append(values[0])
+ del values[:]
+ args.append(arg)
+
+ for arg in node.args:
+ add(arg)
+ if values:
+ args.append(values[0])
+
+ if len(args) == 1:
+ arg = args[0]
+ if ((is_set and arg.is_set_literal) or
+ (arg.is_sequence_constructor and arg.type is node.type) or
+ isinstance(arg, ExprNodes.MergedSequenceNode)):
+ return arg
+ node.args[:] = args
+ self._calculate_const(node)
+ return node
+
+ def visit_SequenceNode(self, node):
+ """Unpack *args in place if we can."""
+ self.visitchildren(node)
+ args = []
+ for arg in node.args:
+ if not arg.is_starred:
+ args.append(arg)
+ elif arg.target.is_sequence_constructor and not arg.target.mult_factor:
+ args.extend(arg.target.args)
+ else:
+ args.append(arg)
+ node.args[:] = args
+ self._calculate_const(node)
+ return node
+
+ def visit_PrimaryCmpNode(self, node):
+ # calculate constant partial results in the comparison cascade
+ self.visitchildren(node, ['operand1'])
+ left_node = node.operand1
+ cmp_node = node
+ while cmp_node is not None:
+ self.visitchildren(cmp_node, ['operand2'])
+ right_node = cmp_node.operand2
+ cmp_node.constant_result = not_a_constant
+ if left_node.has_constant_result() and right_node.has_constant_result():
+ try:
+ cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
+ except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
+ pass # ignore all 'normal' errors here => no constant result
+ left_node = right_node
+ cmp_node = cmp_node.cascade
+
+ if not node.cascade:
+ if node.has_constant_result():
+ return self._bool_node(node, node.constant_result)
+ return node
+
+ # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
+ cascades = [[node.operand1]]
+ final_false_result = []
+
+ def split_cascades(cmp_node):
+ if cmp_node.has_constant_result():
+ if not cmp_node.constant_result:
+ # False => short-circuit
+ final_false_result.append(self._bool_node(cmp_node, False))
+ return
+ else:
+ # True => discard and start new cascade
+ cascades.append([cmp_node.operand2])
+ else:
+ # not constant => append to current cascade
+ cascades[-1].append(cmp_node)
+ if cmp_node.cascade:
+ split_cascades(cmp_node.cascade)
+
+ split_cascades(node)
+
+ cmp_nodes = []
+ for cascade in cascades:
+ if len(cascade) < 2:
+ continue
+ cmp_node = cascade[1]
+ pcmp_node = ExprNodes.PrimaryCmpNode(
+ cmp_node.pos,
+ operand1=cascade[0],
+ operator=cmp_node.operator,
+ operand2=cmp_node.operand2,
+ constant_result=not_a_constant)
+ cmp_nodes.append(pcmp_node)
+
+ last_cmp_node = pcmp_node
+ for cmp_node in cascade[2:]:
+ last_cmp_node.cascade = cmp_node
+ last_cmp_node = cmp_node
+ last_cmp_node.cascade = None
+
+ if final_false_result:
+ # last cascade was constant False
+ cmp_nodes.append(final_false_result[0])
+ elif not cmp_nodes:
+ # only constants, but no False result
+ return self._bool_node(node, True)
+ node = cmp_nodes[0]
+ if len(cmp_nodes) == 1:
+ if node.has_constant_result():
+ return self._bool_node(node, node.constant_result)
+ else:
+ for cmp_node in cmp_nodes[1:]:
+ node = ExprNodes.BoolBinopNode(
+ node.pos,
+ operand1=node,
+ operator='and',
+ operand2=cmp_node,
+ constant_result=not_a_constant)
+ return node
+
+ def visit_CondExprNode(self, node):
+ self._calculate_const(node)
+ if not node.test.has_constant_result():
+ return node
+ if node.test.constant_result:
+ return node.true_val
+ else:
+ return node.false_val
+
+ def visit_IfStatNode(self, node):
+ self.visitchildren(node)
+ # eliminate dead code based on constant condition results
+ if_clauses = []
+ for if_clause in node.if_clauses:
+ condition = if_clause.condition
+ if condition.has_constant_result():
+ if condition.constant_result:
+ # always true => subsequent clauses can safely be dropped
+ node.else_clause = if_clause.body
+ break
+ # else: false => drop clause
+ else:
+ # unknown result => normal runtime evaluation
+ if_clauses.append(if_clause)
+ if if_clauses:
+ node.if_clauses = if_clauses
+ return node
+ elif node.else_clause:
+ return node.else_clause
+ else:
+ return Nodes.StatListNode(node.pos, stats=[])
+
+ def visit_SliceIndexNode(self, node):
+ self._calculate_const(node)
+ # normalise start/stop values
+ if node.start is None or node.start.constant_result is None:
+ start = node.start = None
+ else:
+ start = node.start.constant_result
+ if node.stop is None or node.stop.constant_result is None:
+ stop = node.stop = None
+ else:
+ stop = node.stop.constant_result
+ # cut down sliced constant sequences
+ if node.constant_result is not not_a_constant:
+ base = node.base
+ if base.is_sequence_constructor and base.mult_factor is None:
+ base.args = base.args[start:stop]
+ return base
+ elif base.is_string_literal:
+ base = base.as_sliced_node(start, stop)
+ if base is not None:
+ return base
+ return node
+
+ def visit_ComprehensionNode(self, node):
+ self.visitchildren(node)
+ if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
+ # loop was pruned already => transform into literal
+ if node.type is Builtin.list_type:
+ return ExprNodes.ListNode(
+ node.pos, args=[], constant_result=[])
+ elif node.type is Builtin.set_type:
+ return ExprNodes.SetNode(
+ node.pos, args=[], constant_result=set())
+ elif node.type is Builtin.dict_type:
+ return ExprNodes.DictNode(
+ node.pos, key_value_pairs=[], constant_result={})
+ return node
+
+ def visit_ForInStatNode(self, node):
+ self.visitchildren(node)
+ sequence = node.iterator.sequence
+ if isinstance(sequence, ExprNodes.SequenceNode):
+ if not sequence.args:
+ if node.else_clause:
+ return node.else_clause
+ else:
+ # don't break list comprehensions
+ return Nodes.StatListNode(node.pos, stats=[])
+ # iterating over a list literal? => tuples are more efficient
+ if isinstance(sequence, ExprNodes.ListNode):
+ node.iterator.sequence = sequence.as_tuple()
+ return node
+
+ def visit_WhileStatNode(self, node):
+ self.visitchildren(node)
+ if node.condition and node.condition.has_constant_result():
+ if node.condition.constant_result:
+ node.condition = None
+ node.else_clause = None
+ else:
+ return node.else_clause
+ return node
+
+ def visit_ExprStatNode(self, node):
+ self.visitchildren(node)
+ if not isinstance(node.expr, ExprNodes.ExprNode):
+ # ParallelRangeTransform does this ...
+ return node
+ # drop unused constant expressions
+ if node.expr.has_constant_result():
+ return None
+ return node
+
+ # in the future, other nodes can have their own handler method here
+ # that can replace them with a constant result node
+
+ visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+
+class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
+ """
+ This visitor handles several commuting optimizations, and is run
+ just before the C code generation phase.
+
+ The optimizations currently implemented in this class are:
+ - eliminate None assignment and refcounting for first assignment.
+ - isinstance -> typecheck for cdef types
+ - eliminate checks for None and/or types that became redundant after tree changes
+ - eliminate useless string formatting steps
+ - replace Python function calls that look like method calls by a faster PyMethodCallNode
+ """
+ in_loop = False
+
+ def visit_SingleAssignmentNode(self, node):
+ """Avoid redundant initialisation of local variables before their
+ first assignment.
+ """
+ self.visitchildren(node)
+ if node.first:
+ lhs = node.lhs
+ lhs.lhs_of_first_assignment = True
+ return node
+
+ def visit_SimpleCallNode(self, node):
+ """
+ Replace generic calls to isinstance(x, type) by a more efficient type check.
+ Replace likely Python method calls by a specialised PyMethodCallNode.
+ """
+ self.visitchildren(node)
+ function = node.function
+ if function.type.is_cfunction and function.is_name:
+ if function.name == 'isinstance' and len(node.args) == 2:
+ type_arg = node.args[1]
+ if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
+ cython_scope = self.context.cython_scope
+ function.entry = cython_scope.lookup('PyObject_TypeCheck')
+ function.type = function.entry.type
+ PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
+ node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
+ elif (node.is_temp and function.type.is_pyobject and self.current_directives.get(
+ "optimize.unpack_method_calls_in_pyinit"
+ if not self.in_loop and self.current_env().is_module_scope
+ else "optimize.unpack_method_calls")):
+ # optimise simple Python methods calls
+ if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not (
+ node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)):
+ # simple call, now exclude calls to objects that are definitely not methods
+ may_be_a_method = True
+ if function.type is Builtin.type_type:
+ may_be_a_method = False
+ elif function.is_attribute:
+ if function.entry and function.entry.type.is_cfunction:
+ # optimised builtin method
+ may_be_a_method = False
+ elif function.is_name:
+ entry = function.entry
+ if entry.is_builtin or entry.type.is_cfunction:
+ may_be_a_method = False
+ elif entry.cf_assignments:
+ # local functions/classes are definitely not methods
+ non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode)
+ may_be_a_method = any(
+ assignment.rhs and not isinstance(assignment.rhs, non_method_nodes)
+ for assignment in entry.cf_assignments)
+ if may_be_a_method:
+ if (node.self and function.is_attribute and
+ isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self):
+ # function self object was moved into a CloneNode => undo
+ function.obj = function.obj.arg
+ node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
+ node, function=function, arg_tuple=node.arg_tuple, type=node.type))
+ return node
+
+ def visit_NumPyMethodCallNode(self, node):
+ # Exclude from replacement above.
+ self.visitchildren(node)
+ return node
+
+ def visit_PyTypeTestNode(self, node):
+ """Remove tests for alternatively allowed None values from
+ type tests when we know that the argument cannot be None
+ anyway.
+ """
+ self.visitchildren(node)
+ if not node.notnone:
+ if not node.arg.may_be_none():
+ node.notnone = True
+ return node
+
+ def visit_NoneCheckNode(self, node):
+ """Remove None checks from expressions that definitely do not
+ carry a None value.
+ """
+ self.visitchildren(node)
+ if not node.arg.may_be_none():
+ return node.arg
+ return node
+
+ def visit_LoopNode(self, node):
+ """Remember when we enter a loop as some expensive optimisations might still be worth it there.
+ """
+ old_val = self.in_loop
+ self.in_loop = True
+ self.visitchildren(node)
+ self.in_loop = old_val
+ return node
+
+
+class ConsolidateOverflowCheck(Visitor.CythonTransform):
+ """
+ This class facilitates the sharing of overflow checking among all nodes
+ of a nested arithmetic expression. For example, given the expression
+ a*b + c, where a, b, and x are all possibly overflowing ints, the entire
+ sequence will be evaluated and the overflow bit checked only at the end.
+ """
+ overflow_bit_node = None
+
+ def visit_Node(self, node):
+ if self.overflow_bit_node is not None:
+ saved = self.overflow_bit_node
+ self.overflow_bit_node = None
+ self.visitchildren(node)
+ self.overflow_bit_node = saved
+ else:
+ self.visitchildren(node)
+ return node
+
+ def visit_NumBinopNode(self, node):
+ if node.overflow_check and node.overflow_fold:
+ top_level_overflow = self.overflow_bit_node is None
+ if top_level_overflow:
+ self.overflow_bit_node = node
+ else:
+ node.overflow_bit_node = self.overflow_bit_node
+ node.overflow_check = False
+ self.visitchildren(node)
+ if top_level_overflow:
+ self.overflow_bit_node = None
+ else:
+ self.visitchildren(node)
+ return node