diff options
author | Anton Samokhvalov <pg83@yandex.ru> | 2022-02-10 16:45:15 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:45:15 +0300 |
commit | 72cb13b4aff9bc9cf22e49251bc8fd143f82538f (patch) | |
tree | da2c34829458c7d4e74bdfbdf85dff449e9e7fb8 /contrib/tools/cython/Cython/Compiler/Optimize.py | |
parent | 778e51ba091dc39e7b7fcab2b9cf4dbedfb6f2b5 (diff) | |
download | ydb-72cb13b4aff9bc9cf22e49251bc8fd143f82538f.tar.gz |
Restoring authorship annotation for Anton Samokhvalov <pg83@yandex.ru>. Commit 1 of 2.
Diffstat (limited to 'contrib/tools/cython/Cython/Compiler/Optimize.py')
-rw-r--r-- | contrib/tools/cython/Cython/Compiler/Optimize.py | 6992 |
1 files changed, 3496 insertions, 3496 deletions
diff --git a/contrib/tools/cython/Cython/Compiler/Optimize.py b/contrib/tools/cython/Cython/Compiler/Optimize.py index 3cb77efe2c..fa479644cd 100644 --- a/contrib/tools/cython/Cython/Compiler/Optimize.py +++ b/contrib/tools/cython/Cython/Compiler/Optimize.py @@ -1,18 +1,18 @@ -from __future__ import absolute_import - +from __future__ import absolute_import + import re import sys import copy import codecs import itertools -from . import TypeSlots -from .ExprNodes import not_a_constant -import cython +from . import TypeSlots +from .ExprNodes import not_a_constant +import cython cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object, - Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, + Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, UtilNodes=object, _py_int_types=object) - + if sys.version_info[0] >= 3: _py_int_types = int _py_string_types = (bytes, str) @@ -20,61 +20,61 @@ else: _py_int_types = (int, long) _py_string_types = (bytes, unicode) -from . import Nodes -from . import ExprNodes -from . import PyrexTypes -from . import Visitor -from . import Builtin -from . import UtilNodes -from . import Options - +from . import Nodes +from . import ExprNodes +from . import PyrexTypes +from . import Visitor +from . import Builtin +from . import UtilNodes +from . import Options + from .Code import UtilityCode, TempitaUtilityCode from .StringEncoding import EncodedString, bytes_literal, encoded_string from .Errors import error, warning -from .ParseTreeTransforms import SkipDeclarations - -try: - from __builtin__ import reduce -except ImportError: - from functools import reduce - -try: - from __builtin__ import basestring -except ImportError: - basestring = str # Python 3 - - -def load_c_utility(name): - return UtilityCode.load_cached(name, "Optimize.c") - - -def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): - if isinstance(node, coercion_nodes): - return node.arg - return node - - -def unwrap_node(node): - while isinstance(node, UtilNodes.ResultRefNode): - node = node.expression - return node - - -def is_common_value(a, b): - a = unwrap_node(a) - b = unwrap_node(b) - if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): - return a.name == b.name - if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): - return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute - return False - - -def filter_none_node(node): - if node is not None and node.constant_result is None: - return None - return node - +from .ParseTreeTransforms import SkipDeclarations + +try: + from __builtin__ import reduce +except ImportError: + from functools import reduce + +try: + from __builtin__ import basestring +except ImportError: + basestring = str # Python 3 + + +def load_c_utility(name): + return UtilityCode.load_cached(name, "Optimize.c") + + +def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): + if isinstance(node, coercion_nodes): + return node.arg + return node + + +def unwrap_node(node): + while isinstance(node, UtilNodes.ResultRefNode): + node = node.expression + return node + + +def is_common_value(a, b): + a = unwrap_node(a) + b = unwrap_node(b) + if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): + return a.name == b.name + if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): + return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute + return False + + +def filter_none_node(node): + if node is not None and node.constant_result is None: + return None + return node + class _YieldNodeCollector(Visitor.TreeVisitor): """ @@ -129,66 +129,66 @@ def _find_yield_statements(node): return yield_statements -class IterationTransform(Visitor.EnvTransform): - """Transform some common for-in loop patterns into efficient C loops: - - - for-in-dict loop becomes a while loop calling PyDict_Next() - - for-in-enumerate is replaced by an external counter variable - - for-in-range loop becomes a plain C for loop - """ - def visit_PrimaryCmpNode(self, node): - if node.is_ptr_contains(): - - # for t in operand2: - # if operand1 == t: - # res = True - # break - # else: - # res = False - - pos = node.pos - result_ref = UtilNodes.ResultRefNode(node) +class IterationTransform(Visitor.EnvTransform): + """Transform some common for-in loop patterns into efficient C loops: + + - for-in-dict loop becomes a while loop calling PyDict_Next() + - for-in-enumerate is replaced by an external counter variable + - for-in-range loop becomes a plain C for loop + """ + def visit_PrimaryCmpNode(self, node): + if node.is_ptr_contains(): + + # for t in operand2: + # if operand1 == t: + # res = True + # break + # else: + # res = False + + pos = node.pos + result_ref = UtilNodes.ResultRefNode(node) if node.operand2.is_subscript: - base_type = node.operand2.base.type.base_type - else: - base_type = node.operand2.type.base_type - target_handle = UtilNodes.TempHandle(base_type) - target = target_handle.ref(pos) - cmp_node = ExprNodes.PrimaryCmpNode( - pos, operator=u'==', operand1=node.operand1, operand2=target) - if_body = Nodes.StatListNode( - pos, - stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)), - Nodes.BreakStatNode(pos)]) - if_node = Nodes.IfStatNode( - pos, - if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)], - else_clause=None) - for_loop = UtilNodes.TempsBlockNode( - pos, - temps = [target_handle], - body = Nodes.ForInStatNode( - pos, - target=target, - iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), - body=if_node, - else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) - for_loop = for_loop.analyse_expressions(self.current_env()) - for_loop = self.visit(for_loop) - new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) - - if node.operator == 'not_in': - new_node = ExprNodes.NotNode(pos, operand=new_node) - return new_node - - else: - self.visitchildren(node) - return node - - def visit_ForInStatNode(self, node): - self.visitchildren(node) - return self._optimise_for_loop(node, node.iterator.sequence) - + base_type = node.operand2.base.type.base_type + else: + base_type = node.operand2.type.base_type + target_handle = UtilNodes.TempHandle(base_type) + target = target_handle.ref(pos) + cmp_node = ExprNodes.PrimaryCmpNode( + pos, operator=u'==', operand1=node.operand1, operand2=target) + if_body = Nodes.StatListNode( + pos, + stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)), + Nodes.BreakStatNode(pos)]) + if_node = Nodes.IfStatNode( + pos, + if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)], + else_clause=None) + for_loop = UtilNodes.TempsBlockNode( + pos, + temps = [target_handle], + body = Nodes.ForInStatNode( + pos, + target=target, + iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), + body=if_node, + else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) + for_loop = for_loop.analyse_expressions(self.current_env()) + for_loop = self.visit(for_loop) + new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) + + if node.operator == 'not_in': + new_node = ExprNodes.NotNode(pos, operand=new_node) + return new_node + + else: + self.visitchildren(node) + return node + + def visit_ForInStatNode(self, node): + self.visitchildren(node) + return self._optimise_for_loop(node, node.iterator.sequence) + def _optimise_for_loop(self, node, iterable, reversed=False): annotation_type = None if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation: @@ -207,13 +207,13 @@ class IterationTransform(Visitor.EnvTransform): annotation_type = Builtin.set_type if Builtin.dict_type in (iterable.type, annotation_type): - # like iterating over dict.keys() - if reversed: - # CPython raises an error here: not a sequence - return node - return self._transform_dict_iteration( + # like iterating over dict.keys() + if reversed: + # CPython raises an error here: not a sequence + return node + return self._transform_dict_iteration( node, dict_obj=iterable, method=None, keys=True, values=False) - + if (Builtin.set_type in (iterable.type, annotation_type) or Builtin.frozenset_type in (iterable.type, annotation_type)): if reversed: @@ -221,70 +221,70 @@ class IterationTransform(Visitor.EnvTransform): return node return self._transform_set_iteration(node, iterable) - # C array (slice) iteration? + # C array (slice) iteration? if iterable.type.is_ptr or iterable.type.is_array: return self._transform_carray_iteration(node, iterable, reversed=reversed) if iterable.type is Builtin.bytes_type: return self._transform_bytes_iteration(node, iterable, reversed=reversed) if iterable.type is Builtin.unicode_type: return self._transform_unicode_iteration(node, iterable, reversed=reversed) - - # the rest is based on function calls + + # the rest is based on function calls if not isinstance(iterable, ExprNodes.SimpleCallNode): - return node - + return node + if iterable.args is None: arg_count = iterable.arg_tuple and len(iterable.arg_tuple.args) or 0 - else: + else: arg_count = len(iterable.args) if arg_count and iterable.self is not None: - arg_count -= 1 - + arg_count -= 1 + function = iterable.function - # dict iteration? - if function.is_attribute and not reversed and not arg_count: + # dict iteration? + if function.is_attribute and not reversed and not arg_count: base_obj = iterable.self or function.obj - method = function.attribute - # in Py3, items() is equivalent to Py2's iteritems() - is_safe_iter = self.global_scope().context.language_level >= 3 - - if not is_safe_iter and method in ('keys', 'values', 'items'): - # try to reduce this to the corresponding .iter*() methods + method = function.attribute + # in Py3, items() is equivalent to Py2's iteritems() + is_safe_iter = self.global_scope().context.language_level >= 3 + + if not is_safe_iter and method in ('keys', 'values', 'items'): + # try to reduce this to the corresponding .iter*() methods if isinstance(base_obj, ExprNodes.CallNode): - inner_function = base_obj.function - if (inner_function.is_name and inner_function.name == 'dict' - and inner_function.entry - and inner_function.entry.is_builtin): - # e.g. dict(something).items() => safe to use .iter*() - is_safe_iter = True - - keys = values = False - if method == 'iterkeys' or (is_safe_iter and method == 'keys'): - keys = True - elif method == 'itervalues' or (is_safe_iter and method == 'values'): - values = True - elif method == 'iteritems' or (is_safe_iter and method == 'items'): - keys = values = True - - if keys or values: - return self._transform_dict_iteration( - node, base_obj, method, keys, values) - - # enumerate/reversed ? + inner_function = base_obj.function + if (inner_function.is_name and inner_function.name == 'dict' + and inner_function.entry + and inner_function.entry.is_builtin): + # e.g. dict(something).items() => safe to use .iter*() + is_safe_iter = True + + keys = values = False + if method == 'iterkeys' or (is_safe_iter and method == 'keys'): + keys = True + elif method == 'itervalues' or (is_safe_iter and method == 'values'): + values = True + elif method == 'iteritems' or (is_safe_iter and method == 'items'): + keys = values = True + + if keys or values: + return self._transform_dict_iteration( + node, base_obj, method, keys, values) + + # enumerate/reversed ? if iterable.self is None and function.is_name and \ - function.entry and function.entry.is_builtin: - if function.name == 'enumerate': - if reversed: - # CPython raises an error here: not a sequence - return node + function.entry and function.entry.is_builtin: + if function.name == 'enumerate': + if reversed: + # CPython raises an error here: not a sequence + return node return self._transform_enumerate_iteration(node, iterable) - elif function.name == 'reversed': - if reversed: - # CPython raises an error here: not a sequence - return node + elif function.name == 'reversed': + if reversed: + # CPython raises an error here: not a sequence + return node return self._transform_reversed_iteration(node, iterable) - - # range() iteration? + + # range() iteration? if Options.convert_range and 1 <= arg_count <= 3 and ( iterable.self is None and function.is_name and function.name in ('range', 'xrange') and @@ -300,466 +300,466 @@ class IterationTransform(Visitor.EnvTransform): break else: return self._transform_range_iteration(node, iterable, reversed=reversed) - - return node - - def _transform_reversed_iteration(self, node, reversed_function): - args = reversed_function.arg_tuple.args - if len(args) == 0: - error(reversed_function.pos, - "reversed() requires an iterable argument") - return node - elif len(args) > 1: - error(reversed_function.pos, - "reversed() takes exactly 1 argument") - return node - arg = args[0] - - # reversed(list/tuple) ? - if arg.type in (Builtin.tuple_type, Builtin.list_type): - node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable") - node.iterator.reversed = True - return node - - return self._optimise_for_loop(node, arg, reversed=True) - - PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_char_ptr_type, [ - PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) - ]) - - PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_py_ssize_t_type, [ - PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) - ]) - - def _transform_bytes_iteration(self, node, slice_node, reversed=False): - target_type = node.target.type - if not target_type.is_int and target_type is not Builtin.bytes_type: - # bytes iteration returns bytes objects in Py2, but - # integers in Py3 - return node - - unpack_temp_node = UtilNodes.LetRefNode( - slice_node.as_none_safe_node("'NoneType' is not iterable")) - - slice_base_node = ExprNodes.PythonCapiCallNode( - slice_node.pos, "PyBytes_AS_STRING", - self.PyBytes_AS_STRING_func_type, - args = [unpack_temp_node], - is_temp = 0, - ) - len_node = ExprNodes.PythonCapiCallNode( - slice_node.pos, "PyBytes_GET_SIZE", - self.PyBytes_GET_SIZE_func_type, - args = [unpack_temp_node], - is_temp = 0, - ) - - return UtilNodes.LetNode( - unpack_temp_node, - self._transform_carray_iteration( - node, - ExprNodes.SliceIndexNode( - slice_node.pos, - base = slice_base_node, - start = None, - step = None, - stop = len_node, - type = slice_base_node.type, - is_temp = 1, - ), - reversed = reversed)) - - PyUnicode_READ_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_py_ucs4_type, [ - PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None), - PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None), - PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None) - ]) - - init_unicode_iteration_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_int_type, [ - PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None), - PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None), - PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None) - ], - exception_value = '-1') - - def _transform_unicode_iteration(self, node, slice_node, reversed=False): - if slice_node.is_literal: - # try to reduce to byte iteration for plain Latin-1 strings - try: + + return node + + def _transform_reversed_iteration(self, node, reversed_function): + args = reversed_function.arg_tuple.args + if len(args) == 0: + error(reversed_function.pos, + "reversed() requires an iterable argument") + return node + elif len(args) > 1: + error(reversed_function.pos, + "reversed() takes exactly 1 argument") + return node + arg = args[0] + + # reversed(list/tuple) ? + if arg.type in (Builtin.tuple_type, Builtin.list_type): + node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable") + node.iterator.reversed = True + return node + + return self._optimise_for_loop(node, arg, reversed=True) + + PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_char_ptr_type, [ + PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) + ]) + + PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) + ]) + + def _transform_bytes_iteration(self, node, slice_node, reversed=False): + target_type = node.target.type + if not target_type.is_int and target_type is not Builtin.bytes_type: + # bytes iteration returns bytes objects in Py2, but + # integers in Py3 + return node + + unpack_temp_node = UtilNodes.LetRefNode( + slice_node.as_none_safe_node("'NoneType' is not iterable")) + + slice_base_node = ExprNodes.PythonCapiCallNode( + slice_node.pos, "PyBytes_AS_STRING", + self.PyBytes_AS_STRING_func_type, + args = [unpack_temp_node], + is_temp = 0, + ) + len_node = ExprNodes.PythonCapiCallNode( + slice_node.pos, "PyBytes_GET_SIZE", + self.PyBytes_GET_SIZE_func_type, + args = [unpack_temp_node], + is_temp = 0, + ) + + return UtilNodes.LetNode( + unpack_temp_node, + self._transform_carray_iteration( + node, + ExprNodes.SliceIndexNode( + slice_node.pos, + base = slice_base_node, + start = None, + step = None, + stop = len_node, + type = slice_base_node.type, + is_temp = 1, + ), + reversed = reversed)) + + PyUnicode_READ_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ucs4_type, [ + PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None), + PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None), + PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None) + ]) + + init_unicode_iteration_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_int_type, [ + PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None), + PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None), + PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None) + ], + exception_value = '-1') + + def _transform_unicode_iteration(self, node, slice_node, reversed=False): + if slice_node.is_literal: + # try to reduce to byte iteration for plain Latin-1 strings + try: bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1') - except UnicodeEncodeError: - pass - else: - bytes_slice = ExprNodes.SliceIndexNode( - slice_node.pos, - base=ExprNodes.BytesNode( - slice_node.pos, value=bytes_value, - constant_result=bytes_value, + except UnicodeEncodeError: + pass + else: + bytes_slice = ExprNodes.SliceIndexNode( + slice_node.pos, + base=ExprNodes.BytesNode( + slice_node.pos, value=bytes_value, + constant_result=bytes_value, type=PyrexTypes.c_const_char_ptr_type).coerce_to( PyrexTypes.c_const_uchar_ptr_type, self.current_env()), - start=None, - stop=ExprNodes.IntNode( - slice_node.pos, value=str(len(bytes_value)), - constant_result=len(bytes_value), - type=PyrexTypes.c_py_ssize_t_type), - type=Builtin.unicode_type, # hint for Python conversion - ) - return self._transform_carray_iteration(node, bytes_slice, reversed) - - unpack_temp_node = UtilNodes.LetRefNode( - slice_node.as_none_safe_node("'NoneType' is not iterable")) - - start_node = ExprNodes.IntNode( - node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type) - length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) - end_node = length_temp.ref(node.pos) - if reversed: - relation1, relation2 = '>', '>=' - start_node, end_node = end_node, start_node - else: - relation1, relation2 = '<=', '<' - - kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) - data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type) - counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) - - target_value = ExprNodes.PythonCapiCallNode( - slice_node.pos, "__Pyx_PyUnicode_READ", - self.PyUnicode_READ_func_type, - args = [kind_temp.ref(slice_node.pos), - data_temp.ref(slice_node.pos), - counter_temp.ref(node.target.pos)], - is_temp = False, - ) - if target_value.type != node.target.type: - target_value = target_value.coerce_to(node.target.type, - self.current_env()) - target_assign = Nodes.SingleAssignmentNode( - pos = node.target.pos, - lhs = node.target, - rhs = target_value) - body = Nodes.StatListNode( - node.pos, - stats = [target_assign, node.body]) - - loop_node = Nodes.ForFromStatNode( - node.pos, - bound1=start_node, relation1=relation1, - target=counter_temp.ref(node.target.pos), - relation2=relation2, bound2=end_node, - step=None, body=body, - else_clause=node.else_clause, - from_range=True) - - setup_node = Nodes.ExprStatNode( - node.pos, - expr = ExprNodes.PythonCapiCallNode( - slice_node.pos, "__Pyx_init_unicode_iteration", - self.init_unicode_iteration_func_type, - args = [unpack_temp_node, - ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos), - type=PyrexTypes.c_py_ssize_t_ptr_type), - ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos), - type=PyrexTypes.c_void_ptr_ptr_type), - ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos), - type=PyrexTypes.c_int_ptr_type), - ], - is_temp = True, - result_is_used = False, - utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"), - )) - return UtilNodes.LetNode( - unpack_temp_node, - UtilNodes.TempsBlockNode( - node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp], - body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node]))) - - def _transform_carray_iteration(self, node, slice_node, reversed=False): - neg_step = False - if isinstance(slice_node, ExprNodes.SliceIndexNode): - slice_base = slice_node.base - start = filter_none_node(slice_node.start) - stop = filter_none_node(slice_node.stop) - step = None - if not stop: - if not slice_base.type.is_pyobject: - error(slice_node.pos, "C array iteration requires known end index") - return node - + start=None, + stop=ExprNodes.IntNode( + slice_node.pos, value=str(len(bytes_value)), + constant_result=len(bytes_value), + type=PyrexTypes.c_py_ssize_t_type), + type=Builtin.unicode_type, # hint for Python conversion + ) + return self._transform_carray_iteration(node, bytes_slice, reversed) + + unpack_temp_node = UtilNodes.LetRefNode( + slice_node.as_none_safe_node("'NoneType' is not iterable")) + + start_node = ExprNodes.IntNode( + node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type) + length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) + end_node = length_temp.ref(node.pos) + if reversed: + relation1, relation2 = '>', '>=' + start_node, end_node = end_node, start_node + else: + relation1, relation2 = '<=', '<' + + kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) + data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type) + counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) + + target_value = ExprNodes.PythonCapiCallNode( + slice_node.pos, "__Pyx_PyUnicode_READ", + self.PyUnicode_READ_func_type, + args = [kind_temp.ref(slice_node.pos), + data_temp.ref(slice_node.pos), + counter_temp.ref(node.target.pos)], + is_temp = False, + ) + if target_value.type != node.target.type: + target_value = target_value.coerce_to(node.target.type, + self.current_env()) + target_assign = Nodes.SingleAssignmentNode( + pos = node.target.pos, + lhs = node.target, + rhs = target_value) + body = Nodes.StatListNode( + node.pos, + stats = [target_assign, node.body]) + + loop_node = Nodes.ForFromStatNode( + node.pos, + bound1=start_node, relation1=relation1, + target=counter_temp.ref(node.target.pos), + relation2=relation2, bound2=end_node, + step=None, body=body, + else_clause=node.else_clause, + from_range=True) + + setup_node = Nodes.ExprStatNode( + node.pos, + expr = ExprNodes.PythonCapiCallNode( + slice_node.pos, "__Pyx_init_unicode_iteration", + self.init_unicode_iteration_func_type, + args = [unpack_temp_node, + ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos), + type=PyrexTypes.c_py_ssize_t_ptr_type), + ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos), + type=PyrexTypes.c_void_ptr_ptr_type), + ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos), + type=PyrexTypes.c_int_ptr_type), + ], + is_temp = True, + result_is_used = False, + utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"), + )) + return UtilNodes.LetNode( + unpack_temp_node, + UtilNodes.TempsBlockNode( + node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp], + body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node]))) + + def _transform_carray_iteration(self, node, slice_node, reversed=False): + neg_step = False + if isinstance(slice_node, ExprNodes.SliceIndexNode): + slice_base = slice_node.base + start = filter_none_node(slice_node.start) + stop = filter_none_node(slice_node.stop) + step = None + if not stop: + if not slice_base.type.is_pyobject: + error(slice_node.pos, "C array iteration requires known end index") + return node + elif slice_node.is_subscript: - assert isinstance(slice_node.index, ExprNodes.SliceNode) - slice_base = slice_node.base - index = slice_node.index - start = filter_none_node(index.start) - stop = filter_none_node(index.stop) - step = filter_none_node(index.step) - if step: + assert isinstance(slice_node.index, ExprNodes.SliceNode) + slice_base = slice_node.base + index = slice_node.index + start = filter_none_node(index.start) + stop = filter_none_node(index.stop) + step = filter_none_node(index.step) + if step: if not isinstance(step.constant_result, _py_int_types) \ - or step.constant_result == 0 \ - or step.constant_result > 0 and not stop \ - or step.constant_result < 0 and not start: - if not slice_base.type.is_pyobject: - error(step.pos, "C array iteration requires known step size and end index") - return node - else: - # step sign is handled internally by ForFromStatNode - step_value = step.constant_result - if reversed: - step_value = -step_value - neg_step = step_value < 0 - step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, - value=str(abs(step_value)), - constant_result=abs(step_value)) - - elif slice_node.type.is_array: - if slice_node.type.size is None: - error(slice_node.pos, "C array iteration requires known end index") - return node - slice_base = slice_node - start = None - stop = ExprNodes.IntNode( - slice_node.pos, value=str(slice_node.type.size), - type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size) - step = None - - else: - if not slice_node.type.is_pyobject: - error(slice_node.pos, "C array iteration requires known end index") - return node - - if start: - start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - if stop: - stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - if stop is None: - if neg_step: - stop = ExprNodes.IntNode( - slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1) - else: - error(slice_node.pos, "C array iteration requires known step size and end index") - return node - - if reversed: - if not start: - start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0, - type=PyrexTypes.c_py_ssize_t_type) - # if step was provided, it was already negated above - start, stop = stop, start - - ptr_type = slice_base.type - if ptr_type.is_array: - ptr_type = ptr_type.element_ptr_type() - carray_ptr = slice_base.coerce_to_simple(self.current_env()) - - if start and start.constant_result != 0: - start_ptr_node = ExprNodes.AddNode( - start.pos, - operand1=carray_ptr, - operator='+', - operand2=start, - type=ptr_type) - else: - start_ptr_node = carray_ptr - - if stop and stop.constant_result != 0: - stop_ptr_node = ExprNodes.AddNode( - stop.pos, - operand1=ExprNodes.CloneNode(carray_ptr), - operator='+', - operand2=stop, - type=ptr_type - ).coerce_to_simple(self.current_env()) - else: - stop_ptr_node = ExprNodes.CloneNode(carray_ptr) - - counter = UtilNodes.TempHandle(ptr_type) - counter_temp = counter.ref(node.target.pos) - - if slice_base.type.is_string and node.target.type.is_pyobject: - # special case: char* -> bytes/unicode - if slice_node.type is Builtin.unicode_type: - target_value = ExprNodes.CastNode( - ExprNodes.DereferenceNode( - node.target.pos, operand=counter_temp, - type=ptr_type.base_type), - PyrexTypes.c_py_ucs4_type).coerce_to( - node.target.type, self.current_env()) - else: - # char* -> bytes coercion requires slicing, not indexing - target_value = ExprNodes.SliceIndexNode( - node.target.pos, - start=ExprNodes.IntNode(node.target.pos, value='0', - constant_result=0, - type=PyrexTypes.c_int_type), - stop=ExprNodes.IntNode(node.target.pos, value='1', - constant_result=1, - type=PyrexTypes.c_int_type), - base=counter_temp, - type=Builtin.bytes_type, - is_temp=1) - elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type): - # Allow iteration with pointer target to avoid copy. - target_value = counter_temp - else: - # TODO: can this safely be replaced with DereferenceNode() as above? - target_value = ExprNodes.IndexNode( - node.target.pos, - index=ExprNodes.IntNode(node.target.pos, value='0', - constant_result=0, - type=PyrexTypes.c_int_type), - base=counter_temp, - type=ptr_type.base_type) - - if target_value.type != node.target.type: - target_value = target_value.coerce_to(node.target.type, - self.current_env()) - - target_assign = Nodes.SingleAssignmentNode( - pos = node.target.pos, - lhs = node.target, - rhs = target_value) - - body = Nodes.StatListNode( - node.pos, - stats = [target_assign, node.body]) - - relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed) - - for_node = Nodes.ForFromStatNode( - node.pos, - bound1=start_ptr_node, relation1=relation1, - target=counter_temp, - relation2=relation2, bound2=stop_ptr_node, - step=step, body=body, - else_clause=node.else_clause, - from_range=True) - - return UtilNodes.TempsBlockNode( - node.pos, temps=[counter], - body=for_node) - - def _transform_enumerate_iteration(self, node, enumerate_function): - args = enumerate_function.arg_tuple.args - if len(args) == 0: - error(enumerate_function.pos, - "enumerate() requires an iterable argument") - return node - elif len(args) > 2: - error(enumerate_function.pos, - "enumerate() takes at most 2 arguments") - return node - - if not node.target.is_sequence_constructor: - # leave this untouched for now - return node - targets = node.target.args - if len(targets) != 2: - # leave this untouched for now - return node - - enumerate_target, iterable_target = targets - counter_type = enumerate_target.type - - if not counter_type.is_pyobject and not counter_type.is_int: - # nothing we can do here, I guess - return node - - if len(args) == 2: - start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env()) - else: - start = ExprNodes.IntNode(enumerate_function.pos, - value='0', - type=counter_type, - constant_result=0) - temp = UtilNodes.LetRefNode(start) - - inc_expression = ExprNodes.AddNode( - enumerate_function.pos, - operand1 = temp, - operand2 = ExprNodes.IntNode(node.pos, value='1', - type=counter_type, - constant_result=1), - operator = '+', - type = counter_type, - #inplace = True, # not worth using in-place operation for Py ints - is_temp = counter_type.is_pyobject - ) - - loop_body = [ - Nodes.SingleAssignmentNode( - pos = enumerate_target.pos, - lhs = enumerate_target, - rhs = temp), - Nodes.SingleAssignmentNode( - pos = enumerate_target.pos, - lhs = temp, - rhs = inc_expression) - ] - - if isinstance(node.body, Nodes.StatListNode): - node.body.stats = loop_body + node.body.stats - else: - loop_body.append(node.body) - node.body = Nodes.StatListNode( - node.body.pos, - stats = loop_body) - - node.target = iterable_target - node.item = node.item.coerce_to(iterable_target.type, self.current_env()) - node.iterator.sequence = args[0] - - # recurse into loop to check for further optimisations - return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence)) - - def _find_for_from_node_relations(self, neg_step_value, reversed): - if reversed: - if neg_step_value: - return '<', '<=' - else: - return '>', '>=' - else: - if neg_step_value: - return '>=', '>' - else: - return '<=', '<' - - def _transform_range_iteration(self, node, range_function, reversed=False): - args = range_function.arg_tuple.args - if len(args) < 3: - step_pos = range_function.pos - step_value = 1 + or step.constant_result == 0 \ + or step.constant_result > 0 and not stop \ + or step.constant_result < 0 and not start: + if not slice_base.type.is_pyobject: + error(step.pos, "C array iteration requires known step size and end index") + return node + else: + # step sign is handled internally by ForFromStatNode + step_value = step.constant_result + if reversed: + step_value = -step_value + neg_step = step_value < 0 + step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, + value=str(abs(step_value)), + constant_result=abs(step_value)) + + elif slice_node.type.is_array: + if slice_node.type.size is None: + error(slice_node.pos, "C array iteration requires known end index") + return node + slice_base = slice_node + start = None + stop = ExprNodes.IntNode( + slice_node.pos, value=str(slice_node.type.size), + type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size) + step = None + + else: + if not slice_node.type.is_pyobject: + error(slice_node.pos, "C array iteration requires known end index") + return node + + if start: + start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + if stop: + stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + if stop is None: + if neg_step: + stop = ExprNodes.IntNode( + slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1) + else: + error(slice_node.pos, "C array iteration requires known step size and end index") + return node + + if reversed: + if not start: + start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0, + type=PyrexTypes.c_py_ssize_t_type) + # if step was provided, it was already negated above + start, stop = stop, start + + ptr_type = slice_base.type + if ptr_type.is_array: + ptr_type = ptr_type.element_ptr_type() + carray_ptr = slice_base.coerce_to_simple(self.current_env()) + + if start and start.constant_result != 0: + start_ptr_node = ExprNodes.AddNode( + start.pos, + operand1=carray_ptr, + operator='+', + operand2=start, + type=ptr_type) + else: + start_ptr_node = carray_ptr + + if stop and stop.constant_result != 0: + stop_ptr_node = ExprNodes.AddNode( + stop.pos, + operand1=ExprNodes.CloneNode(carray_ptr), + operator='+', + operand2=stop, + type=ptr_type + ).coerce_to_simple(self.current_env()) + else: + stop_ptr_node = ExprNodes.CloneNode(carray_ptr) + + counter = UtilNodes.TempHandle(ptr_type) + counter_temp = counter.ref(node.target.pos) + + if slice_base.type.is_string and node.target.type.is_pyobject: + # special case: char* -> bytes/unicode + if slice_node.type is Builtin.unicode_type: + target_value = ExprNodes.CastNode( + ExprNodes.DereferenceNode( + node.target.pos, operand=counter_temp, + type=ptr_type.base_type), + PyrexTypes.c_py_ucs4_type).coerce_to( + node.target.type, self.current_env()) + else: + # char* -> bytes coercion requires slicing, not indexing + target_value = ExprNodes.SliceIndexNode( + node.target.pos, + start=ExprNodes.IntNode(node.target.pos, value='0', + constant_result=0, + type=PyrexTypes.c_int_type), + stop=ExprNodes.IntNode(node.target.pos, value='1', + constant_result=1, + type=PyrexTypes.c_int_type), + base=counter_temp, + type=Builtin.bytes_type, + is_temp=1) + elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type): + # Allow iteration with pointer target to avoid copy. + target_value = counter_temp + else: + # TODO: can this safely be replaced with DereferenceNode() as above? + target_value = ExprNodes.IndexNode( + node.target.pos, + index=ExprNodes.IntNode(node.target.pos, value='0', + constant_result=0, + type=PyrexTypes.c_int_type), + base=counter_temp, + type=ptr_type.base_type) + + if target_value.type != node.target.type: + target_value = target_value.coerce_to(node.target.type, + self.current_env()) + + target_assign = Nodes.SingleAssignmentNode( + pos = node.target.pos, + lhs = node.target, + rhs = target_value) + + body = Nodes.StatListNode( + node.pos, + stats = [target_assign, node.body]) + + relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed) + + for_node = Nodes.ForFromStatNode( + node.pos, + bound1=start_ptr_node, relation1=relation1, + target=counter_temp, + relation2=relation2, bound2=stop_ptr_node, + step=step, body=body, + else_clause=node.else_clause, + from_range=True) + + return UtilNodes.TempsBlockNode( + node.pos, temps=[counter], + body=for_node) + + def _transform_enumerate_iteration(self, node, enumerate_function): + args = enumerate_function.arg_tuple.args + if len(args) == 0: + error(enumerate_function.pos, + "enumerate() requires an iterable argument") + return node + elif len(args) > 2: + error(enumerate_function.pos, + "enumerate() takes at most 2 arguments") + return node + + if not node.target.is_sequence_constructor: + # leave this untouched for now + return node + targets = node.target.args + if len(targets) != 2: + # leave this untouched for now + return node + + enumerate_target, iterable_target = targets + counter_type = enumerate_target.type + + if not counter_type.is_pyobject and not counter_type.is_int: + # nothing we can do here, I guess + return node + + if len(args) == 2: + start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env()) + else: + start = ExprNodes.IntNode(enumerate_function.pos, + value='0', + type=counter_type, + constant_result=0) + temp = UtilNodes.LetRefNode(start) + + inc_expression = ExprNodes.AddNode( + enumerate_function.pos, + operand1 = temp, + operand2 = ExprNodes.IntNode(node.pos, value='1', + type=counter_type, + constant_result=1), + operator = '+', + type = counter_type, + #inplace = True, # not worth using in-place operation for Py ints + is_temp = counter_type.is_pyobject + ) + + loop_body = [ + Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = enumerate_target, + rhs = temp), + Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = temp, + rhs = inc_expression) + ] + + if isinstance(node.body, Nodes.StatListNode): + node.body.stats = loop_body + node.body.stats + else: + loop_body.append(node.body) + node.body = Nodes.StatListNode( + node.body.pos, + stats = loop_body) + + node.target = iterable_target + node.item = node.item.coerce_to(iterable_target.type, self.current_env()) + node.iterator.sequence = args[0] + + # recurse into loop to check for further optimisations + return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence)) + + def _find_for_from_node_relations(self, neg_step_value, reversed): + if reversed: + if neg_step_value: + return '<', '<=' + else: + return '>', '>=' + else: + if neg_step_value: + return '>=', '>' + else: + return '<=', '<' + + def _transform_range_iteration(self, node, range_function, reversed=False): + args = range_function.arg_tuple.args + if len(args) < 3: + step_pos = range_function.pos + step_value = 1 step = ExprNodes.IntNode(step_pos, value='1', constant_result=1) - else: - step = args[2] - step_pos = step.pos + else: + step = args[2] + step_pos = step.pos if not isinstance(step.constant_result, _py_int_types): - # cannot determine step direction - return node - step_value = step.constant_result - if step_value == 0: - # will lead to an error elsewhere - return node + # cannot determine step direction + return node + step_value = step.constant_result + if step_value == 0: + # will lead to an error elsewhere + return node step = ExprNodes.IntNode(step_pos, value=str(step_value), constant_result=step_value) - - if len(args) == 1: - bound1 = ExprNodes.IntNode(range_function.pos, value='0', - constant_result=0) - bound2 = args[0].coerce_to_integer(self.current_env()) - else: - bound1 = args[0].coerce_to_integer(self.current_env()) - bound2 = args[1].coerce_to_integer(self.current_env()) - - relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed) - + + if len(args) == 1: + bound1 = ExprNodes.IntNode(range_function.pos, value='0', + constant_result=0) + bound2 = args[0].coerce_to_integer(self.current_env()) + else: + bound1 = args[0].coerce_to_integer(self.current_env()) + bound2 = args[1].coerce_to_integer(self.current_env()) + + relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed) + bound2_ref_node = None - if reversed: - bound1, bound2 = bound2, bound1 + if reversed: + bound1, bound2 = bound2, bound1 abs_step = abs(step_value) if abs_step != 1: if (isinstance(bound1.constant_result, _py_int_types) and @@ -773,7 +773,7 @@ class IterationTransform(Visitor.EnvTransform): begin_value = bound1.constant_result end_value = bound2.constant_result bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1 - + bound1 = ExprNodes.IntNode( bound1.pos, value=str(bound1_value), constant_result=bound1_value, type=PyrexTypes.spanning_type(bound1.type, bound2.type)) @@ -785,32 +785,32 @@ class IterationTransform(Visitor.EnvTransform): if step_value < 0: step_value = -step_value - step.value = str(step_value) - step.constant_result = step_value - step = step.coerce_to_integer(self.current_env()) - - if not bound2.is_literal: - # stop bound must be immutable => keep it in a temp var - bound2_is_temp = True + step.value = str(step_value) + step.constant_result = step_value + step = step.coerce_to_integer(self.current_env()) + + if not bound2.is_literal: + # stop bound must be immutable => keep it in a temp var + bound2_is_temp = True bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2) - else: - bound2_is_temp = False - - for_node = Nodes.ForFromStatNode( - node.pos, - target=node.target, - bound1=bound1, relation1=relation1, - relation2=relation2, bound2=bound2, - step=step, body=node.body, - else_clause=node.else_clause, - from_range=True) + else: + bound2_is_temp = False + + for_node = Nodes.ForFromStatNode( + node.pos, + target=node.target, + bound1=bound1, relation1=relation1, + relation2=relation2, bound2=bound2, + step=step, body=node.body, + else_clause=node.else_clause, + from_range=True) for_node.set_up_loop(self.current_env()) - - if bound2_is_temp: - for_node = UtilNodes.LetNode(bound2, for_node) - - return for_node - + + if bound2_is_temp: + for_node = UtilNodes.LetNode(bound2, for_node) + + return for_node + def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value): abs_step = abs(step_value) spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type) @@ -875,114 +875,114 @@ class IterationTransform(Visitor.EnvTransform): type=spanning_type) return step_calculation_node - def _transform_dict_iteration(self, node, dict_obj, method, keys, values): - temps = [] - temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) - temps.append(temp) - dict_temp = temp.ref(dict_obj.pos) - temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) - temps.append(temp) - pos_temp = temp.ref(node.pos) - - key_target = value_target = tuple_target = None - if keys and values: - if node.target.is_sequence_constructor: - if len(node.target.args) == 2: - key_target, value_target = node.target.args - else: - # unusual case that may or may not lead to an error - return node - else: - tuple_target = node.target - elif keys: - key_target = node.target - else: - value_target = node.target - - if isinstance(node.body, Nodes.StatListNode): - body = node.body - else: - body = Nodes.StatListNode(pos = node.body.pos, - stats = [node.body]) - - # keep original length to guard against dict modification - dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) - temps.append(dict_len_temp) - dict_len_temp_addr = ExprNodes.AmpersandNode( - node.pos, operand=dict_len_temp.ref(dict_obj.pos), - type=PyrexTypes.c_ptr_type(dict_len_temp.type)) - temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) - temps.append(temp) - is_dict_temp = temp.ref(node.pos) - is_dict_temp_addr = ExprNodes.AmpersandNode( - node.pos, operand=is_dict_temp, - type=PyrexTypes.c_ptr_type(temp.type)) - - iter_next_node = Nodes.DictIterationNextNode( - dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, - key_target, value_target, tuple_target, - is_dict_temp) - iter_next_node = iter_next_node.analyse_expressions(self.current_env()) - body.stats[0:0] = [iter_next_node] - - if method: - method_node = ExprNodes.StringNode( - dict_obj.pos, is_identifier=True, value=method) - dict_obj = dict_obj.as_none_safe_node( + def _transform_dict_iteration(self, node, dict_obj, method, keys, values): + temps = [] + temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) + temps.append(temp) + dict_temp = temp.ref(dict_obj.pos) + temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) + temps.append(temp) + pos_temp = temp.ref(node.pos) + + key_target = value_target = tuple_target = None + if keys and values: + if node.target.is_sequence_constructor: + if len(node.target.args) == 2: + key_target, value_target = node.target.args + else: + # unusual case that may or may not lead to an error + return node + else: + tuple_target = node.target + elif keys: + key_target = node.target + else: + value_target = node.target + + if isinstance(node.body, Nodes.StatListNode): + body = node.body + else: + body = Nodes.StatListNode(pos = node.body.pos, + stats = [node.body]) + + # keep original length to guard against dict modification + dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type) + temps.append(dict_len_temp) + dict_len_temp_addr = ExprNodes.AmpersandNode( + node.pos, operand=dict_len_temp.ref(dict_obj.pos), + type=PyrexTypes.c_ptr_type(dict_len_temp.type)) + temp = UtilNodes.TempHandle(PyrexTypes.c_int_type) + temps.append(temp) + is_dict_temp = temp.ref(node.pos) + is_dict_temp_addr = ExprNodes.AmpersandNode( + node.pos, operand=is_dict_temp, + type=PyrexTypes.c_ptr_type(temp.type)) + + iter_next_node = Nodes.DictIterationNextNode( + dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp, + key_target, value_target, tuple_target, + is_dict_temp) + iter_next_node = iter_next_node.analyse_expressions(self.current_env()) + body.stats[0:0] = [iter_next_node] + + if method: + method_node = ExprNodes.StringNode( + dict_obj.pos, is_identifier=True, value=method) + dict_obj = dict_obj.as_none_safe_node( "'NoneType' object has no attribute '%{0}s'".format('.30' if len(method) <= 30 else ''), - error = "PyExc_AttributeError", - format_args = [method]) - else: - method_node = ExprNodes.NullNode(dict_obj.pos) - dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable") - - def flag_node(value): - value = value and 1 or 0 - return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value) - - result_code = [ - Nodes.SingleAssignmentNode( - node.pos, - lhs = pos_temp, - rhs = ExprNodes.IntNode(node.pos, value='0', - constant_result=0)), - Nodes.SingleAssignmentNode( - dict_obj.pos, - lhs = dict_temp, - rhs = ExprNodes.PythonCapiCallNode( - dict_obj.pos, - "__Pyx_dict_iterator", - self.PyDict_Iterator_func_type, - utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"), - args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type), - method_node, dict_len_temp_addr, is_dict_temp_addr, - ], - is_temp=True, - )), - Nodes.WhileStatNode( - node.pos, - condition = None, - body = body, - else_clause = node.else_clause - ) - ] - - return UtilNodes.TempsBlockNode( - node.pos, temps=temps, - body=Nodes.StatListNode( - node.pos, - stats = result_code - )) - - PyDict_Iterator_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None), - PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None), - PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None), - ]) - + error = "PyExc_AttributeError", + format_args = [method]) + else: + method_node = ExprNodes.NullNode(dict_obj.pos) + dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable") + + def flag_node(value): + value = value and 1 or 0 + return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value) + + result_code = [ + Nodes.SingleAssignmentNode( + node.pos, + lhs = pos_temp, + rhs = ExprNodes.IntNode(node.pos, value='0', + constant_result=0)), + Nodes.SingleAssignmentNode( + dict_obj.pos, + lhs = dict_temp, + rhs = ExprNodes.PythonCapiCallNode( + dict_obj.pos, + "__Pyx_dict_iterator", + self.PyDict_Iterator_func_type, + utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"), + args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type), + method_node, dict_len_temp_addr, is_dict_temp_addr, + ], + is_temp=True, + )), + Nodes.WhileStatNode( + node.pos, + condition = None, + body = body, + else_clause = node.else_clause + ) + ] + + return UtilNodes.TempsBlockNode( + node.pos, temps=temps, + body=Nodes.StatListNode( + node.pos, + stats = result_code + )) + + PyDict_Iterator_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None), + PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None), + PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None), + ]) + PySet_Iterator_func_type = PyrexTypes.CFuncType( PyrexTypes.py_object_type, [ PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None), @@ -990,7 +990,7 @@ class IterationTransform(Visitor.EnvTransform): PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None), PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None), ]) - + def _transform_set_iteration(self, node, set_obj): temps = [] temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) @@ -1063,561 +1063,561 @@ class IterationTransform(Visitor.EnvTransform): )) -class SwitchTransform(Visitor.EnvTransform): - """ - This transformation tries to turn long if statements into C switch statements. - The requirement is that every clause be an (or of) var == value, where the var - is common among all clauses and both var and value are ints. - """ - NO_MATCH = (None, None, None) - - def extract_conditions(self, cond, allow_not_in): - while True: - if isinstance(cond, (ExprNodes.CoerceToTempNode, - ExprNodes.CoerceToBooleanNode)): - cond = cond.arg - elif isinstance(cond, ExprNodes.BoolBinopResultNode): - cond = cond.arg.arg - elif isinstance(cond, UtilNodes.EvalWithTempExprNode): - # this is what we get from the FlattenInListTransform - cond = cond.subexpression - elif isinstance(cond, ExprNodes.TypecastNode): - cond = cond.operand - else: - break - - if isinstance(cond, ExprNodes.PrimaryCmpNode): - if cond.cascade is not None: - return self.NO_MATCH - elif cond.is_c_string_contains() and \ - isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)): - not_in = cond.operator == 'not_in' - if not_in and not allow_not_in: - return self.NO_MATCH - if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \ - cond.operand2.contains_surrogates(): - # dealing with surrogates leads to different - # behaviour on wide and narrow Unicode - # platforms => refuse to optimise this case - return self.NO_MATCH - return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2) - elif not cond.is_python_comparison(): - if cond.operator == '==': - not_in = False - elif allow_not_in and cond.operator == '!=': - not_in = True - else: - return self.NO_MATCH - # this looks somewhat silly, but it does the right - # checks for NameNode and AttributeNode - if is_common_value(cond.operand1, cond.operand1): - if cond.operand2.is_literal: - return not_in, cond.operand1, [cond.operand2] - elif getattr(cond.operand2, 'entry', None) \ - and cond.operand2.entry.is_const: - return not_in, cond.operand1, [cond.operand2] - if is_common_value(cond.operand2, cond.operand2): - if cond.operand1.is_literal: - return not_in, cond.operand2, [cond.operand1] - elif getattr(cond.operand1, 'entry', None) \ - and cond.operand1.entry.is_const: - return not_in, cond.operand2, [cond.operand1] - elif isinstance(cond, ExprNodes.BoolBinopNode): - if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'): - allow_not_in = (cond.operator == 'and') - not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in) - not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in) - if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2): - if (not not_in_1) or allow_not_in: - return not_in_1, t1, c1+c2 - return self.NO_MATCH - - def extract_in_string_conditions(self, string_literal): - if isinstance(string_literal, ExprNodes.UnicodeNode): - charvals = list(map(ord, set(string_literal.value))) - charvals.sort() - return [ ExprNodes.IntNode(string_literal.pos, value=str(charval), - constant_result=charval) - for charval in charvals ] - else: - # this is a bit tricky as Py3's bytes type returns - # integers on iteration, whereas Py2 returns 1-char byte - # strings - characters = string_literal.value - characters = list(set([ characters[i:i+1] for i in range(len(characters)) ])) - characters.sort() - return [ ExprNodes.CharNode(string_literal.pos, value=charval, - constant_result=charval) - for charval in characters ] - - def extract_common_conditions(self, common_var, condition, allow_not_in): - not_in, var, conditions = self.extract_conditions(condition, allow_not_in) - if var is None: - return self.NO_MATCH - elif common_var is not None and not is_common_value(var, common_var): - return self.NO_MATCH - elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]): - return self.NO_MATCH - return not_in, var, conditions - - def has_duplicate_values(self, condition_values): - # duplicated values don't work in a switch statement - seen = set() - for value in condition_values: - if value.has_constant_result(): - if value.constant_result in seen: - return True - seen.add(value.constant_result) - else: - # this isn't completely safe as we don't know the - # final C value, but this is about the best we can do - try: - if value.entry.cname in seen: - return True - except AttributeError: - return True # play safe - seen.add(value.entry.cname) - return False - - def visit_IfStatNode(self, node): - if not self.current_directives.get('optimize.use_switch'): - self.visitchildren(node) - return node - - common_var = None - cases = [] - for if_clause in node.if_clauses: - _, common_var, conditions = self.extract_common_conditions( - common_var, if_clause.condition, False) - if common_var is None: - self.visitchildren(node) - return node +class SwitchTransform(Visitor.EnvTransform): + """ + This transformation tries to turn long if statements into C switch statements. + The requirement is that every clause be an (or of) var == value, where the var + is common among all clauses and both var and value are ints. + """ + NO_MATCH = (None, None, None) + + def extract_conditions(self, cond, allow_not_in): + while True: + if isinstance(cond, (ExprNodes.CoerceToTempNode, + ExprNodes.CoerceToBooleanNode)): + cond = cond.arg + elif isinstance(cond, ExprNodes.BoolBinopResultNode): + cond = cond.arg.arg + elif isinstance(cond, UtilNodes.EvalWithTempExprNode): + # this is what we get from the FlattenInListTransform + cond = cond.subexpression + elif isinstance(cond, ExprNodes.TypecastNode): + cond = cond.operand + else: + break + + if isinstance(cond, ExprNodes.PrimaryCmpNode): + if cond.cascade is not None: + return self.NO_MATCH + elif cond.is_c_string_contains() and \ + isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)): + not_in = cond.operator == 'not_in' + if not_in and not allow_not_in: + return self.NO_MATCH + if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \ + cond.operand2.contains_surrogates(): + # dealing with surrogates leads to different + # behaviour on wide and narrow Unicode + # platforms => refuse to optimise this case + return self.NO_MATCH + return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2) + elif not cond.is_python_comparison(): + if cond.operator == '==': + not_in = False + elif allow_not_in and cond.operator == '!=': + not_in = True + else: + return self.NO_MATCH + # this looks somewhat silly, but it does the right + # checks for NameNode and AttributeNode + if is_common_value(cond.operand1, cond.operand1): + if cond.operand2.is_literal: + return not_in, cond.operand1, [cond.operand2] + elif getattr(cond.operand2, 'entry', None) \ + and cond.operand2.entry.is_const: + return not_in, cond.operand1, [cond.operand2] + if is_common_value(cond.operand2, cond.operand2): + if cond.operand1.is_literal: + return not_in, cond.operand2, [cond.operand1] + elif getattr(cond.operand1, 'entry', None) \ + and cond.operand1.entry.is_const: + return not_in, cond.operand2, [cond.operand1] + elif isinstance(cond, ExprNodes.BoolBinopNode): + if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'): + allow_not_in = (cond.operator == 'and') + not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in) + not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in) + if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2): + if (not not_in_1) or allow_not_in: + return not_in_1, t1, c1+c2 + return self.NO_MATCH + + def extract_in_string_conditions(self, string_literal): + if isinstance(string_literal, ExprNodes.UnicodeNode): + charvals = list(map(ord, set(string_literal.value))) + charvals.sort() + return [ ExprNodes.IntNode(string_literal.pos, value=str(charval), + constant_result=charval) + for charval in charvals ] + else: + # this is a bit tricky as Py3's bytes type returns + # integers on iteration, whereas Py2 returns 1-char byte + # strings + characters = string_literal.value + characters = list(set([ characters[i:i+1] for i in range(len(characters)) ])) + characters.sort() + return [ ExprNodes.CharNode(string_literal.pos, value=charval, + constant_result=charval) + for charval in characters ] + + def extract_common_conditions(self, common_var, condition, allow_not_in): + not_in, var, conditions = self.extract_conditions(condition, allow_not_in) + if var is None: + return self.NO_MATCH + elif common_var is not None and not is_common_value(var, common_var): + return self.NO_MATCH + elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]): + return self.NO_MATCH + return not_in, var, conditions + + def has_duplicate_values(self, condition_values): + # duplicated values don't work in a switch statement + seen = set() + for value in condition_values: + if value.has_constant_result(): + if value.constant_result in seen: + return True + seen.add(value.constant_result) + else: + # this isn't completely safe as we don't know the + # final C value, but this is about the best we can do + try: + if value.entry.cname in seen: + return True + except AttributeError: + return True # play safe + seen.add(value.entry.cname) + return False + + def visit_IfStatNode(self, node): + if not self.current_directives.get('optimize.use_switch'): + self.visitchildren(node) + return node + + common_var = None + cases = [] + for if_clause in node.if_clauses: + _, common_var, conditions = self.extract_common_conditions( + common_var, if_clause.condition, False) + if common_var is None: + self.visitchildren(node) + return node cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos, conditions=conditions, body=if_clause.body)) - - condition_values = [ - cond for case in cases for cond in case.conditions] - if len(condition_values) < 2: - self.visitchildren(node) - return node - if self.has_duplicate_values(condition_values): - self.visitchildren(node) - return node - + + condition_values = [ + cond for case in cases for cond in case.conditions] + if len(condition_values) < 2: + self.visitchildren(node) + return node + if self.has_duplicate_values(condition_values): + self.visitchildren(node) + return node + # Recurse into body subtrees that we left untouched so far. self.visitchildren(node, 'else_clause') for case in cases: self.visitchildren(case, 'body') - common_var = unwrap_node(common_var) + common_var = unwrap_node(common_var) switch_node = Nodes.SwitchStatNode(pos=node.pos, test=common_var, cases=cases, else_clause=node.else_clause) - return switch_node - - def visit_CondExprNode(self, node): - if not self.current_directives.get('optimize.use_switch'): - self.visitchildren(node) - return node - - not_in, common_var, conditions = self.extract_common_conditions( - None, node.test, True) - if common_var is None \ + return switch_node + + def visit_CondExprNode(self, node): + if not self.current_directives.get('optimize.use_switch'): + self.visitchildren(node) + return node + + not_in, common_var, conditions = self.extract_common_conditions( + None, node.test, True) + if common_var is None \ or len(conditions) < 2 \ or self.has_duplicate_values(conditions): - self.visitchildren(node) - return node - - return self.build_simple_switch_statement( - node, common_var, conditions, not_in, - node.true_val, node.false_val) - - def visit_BoolBinopNode(self, node): - if not self.current_directives.get('optimize.use_switch'): - self.visitchildren(node) - return node - - not_in, common_var, conditions = self.extract_common_conditions( - None, node, True) - if common_var is None \ + self.visitchildren(node) + return node + + return self.build_simple_switch_statement( + node, common_var, conditions, not_in, + node.true_val, node.false_val) + + def visit_BoolBinopNode(self, node): + if not self.current_directives.get('optimize.use_switch'): + self.visitchildren(node) + return node + + not_in, common_var, conditions = self.extract_common_conditions( + None, node, True) + if common_var is None \ or len(conditions) < 2 \ or self.has_duplicate_values(conditions): - self.visitchildren(node) - node.wrap_operands(self.current_env()) # in case we changed the operands - return node - - return self.build_simple_switch_statement( - node, common_var, conditions, not_in, - ExprNodes.BoolNode(node.pos, value=True, constant_result=True), - ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) - - def visit_PrimaryCmpNode(self, node): - if not self.current_directives.get('optimize.use_switch'): - self.visitchildren(node) - return node - - not_in, common_var, conditions = self.extract_common_conditions( - None, node, True) - if common_var is None \ + self.visitchildren(node) + node.wrap_operands(self.current_env()) # in case we changed the operands + return node + + return self.build_simple_switch_statement( + node, common_var, conditions, not_in, + ExprNodes.BoolNode(node.pos, value=True, constant_result=True), + ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) + + def visit_PrimaryCmpNode(self, node): + if not self.current_directives.get('optimize.use_switch'): + self.visitchildren(node) + return node + + not_in, common_var, conditions = self.extract_common_conditions( + None, node, True) + if common_var is None \ or len(conditions) < 2 \ or self.has_duplicate_values(conditions): - self.visitchildren(node) - return node - - return self.build_simple_switch_statement( - node, common_var, conditions, not_in, - ExprNodes.BoolNode(node.pos, value=True, constant_result=True), - ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) - - def build_simple_switch_statement(self, node, common_var, conditions, - not_in, true_val, false_val): - result_ref = UtilNodes.ResultRefNode(node) - true_body = Nodes.SingleAssignmentNode( - node.pos, - lhs=result_ref, - rhs=true_val.coerce_to(node.type, self.current_env()), - first=True) - false_body = Nodes.SingleAssignmentNode( - node.pos, - lhs=result_ref, - rhs=false_val.coerce_to(node.type, self.current_env()), - first=True) - - if not_in: - true_body, false_body = false_body, true_body - - cases = [Nodes.SwitchCaseNode(pos = node.pos, - conditions = conditions, - body = true_body)] - - common_var = unwrap_node(common_var) - switch_node = Nodes.SwitchStatNode(pos = node.pos, - test = common_var, - cases = cases, - else_clause = false_body) - replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node) - return replacement - - def visit_EvalWithTempExprNode(self, node): - if not self.current_directives.get('optimize.use_switch'): - self.visitchildren(node) - return node - - # drop unused expression temp from FlattenInListTransform - orig_expr = node.subexpression - temp_ref = node.lazy_temp - self.visitchildren(node) - if node.subexpression is not orig_expr: - # node was restructured => check if temp is still used - if not Visitor.tree_contains(node.subexpression, temp_ref): - return node.subexpression - return node - - visit_Node = Visitor.VisitorTransform.recurse_to_children - - -class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): - """ - This transformation flattens "x in [val1, ..., valn]" into a sequential list - of comparisons. - """ - - def visit_PrimaryCmpNode(self, node): - self.visitchildren(node) - if node.cascade is not None: - return node - elif node.operator == 'in': - conjunction = 'or' - eq_or_neq = '==' - elif node.operator == 'not_in': - conjunction = 'and' - eq_or_neq = '!=' - else: - return node - - if not isinstance(node.operand2, (ExprNodes.TupleNode, - ExprNodes.ListNode, - ExprNodes.SetNode)): - return node - - args = node.operand2.args - if len(args) == 0: - # note: lhs may have side effects - return node - + self.visitchildren(node) + return node + + return self.build_simple_switch_statement( + node, common_var, conditions, not_in, + ExprNodes.BoolNode(node.pos, value=True, constant_result=True), + ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) + + def build_simple_switch_statement(self, node, common_var, conditions, + not_in, true_val, false_val): + result_ref = UtilNodes.ResultRefNode(node) + true_body = Nodes.SingleAssignmentNode( + node.pos, + lhs=result_ref, + rhs=true_val.coerce_to(node.type, self.current_env()), + first=True) + false_body = Nodes.SingleAssignmentNode( + node.pos, + lhs=result_ref, + rhs=false_val.coerce_to(node.type, self.current_env()), + first=True) + + if not_in: + true_body, false_body = false_body, true_body + + cases = [Nodes.SwitchCaseNode(pos = node.pos, + conditions = conditions, + body = true_body)] + + common_var = unwrap_node(common_var) + switch_node = Nodes.SwitchStatNode(pos = node.pos, + test = common_var, + cases = cases, + else_clause = false_body) + replacement = UtilNodes.TempResultFromStatNode(result_ref, switch_node) + return replacement + + def visit_EvalWithTempExprNode(self, node): + if not self.current_directives.get('optimize.use_switch'): + self.visitchildren(node) + return node + + # drop unused expression temp from FlattenInListTransform + orig_expr = node.subexpression + temp_ref = node.lazy_temp + self.visitchildren(node) + if node.subexpression is not orig_expr: + # node was restructured => check if temp is still used + if not Visitor.tree_contains(node.subexpression, temp_ref): + return node.subexpression + return node + + visit_Node = Visitor.VisitorTransform.recurse_to_children + + +class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): + """ + This transformation flattens "x in [val1, ..., valn]" into a sequential list + of comparisons. + """ + + def visit_PrimaryCmpNode(self, node): + self.visitchildren(node) + if node.cascade is not None: + return node + elif node.operator == 'in': + conjunction = 'or' + eq_or_neq = '==' + elif node.operator == 'not_in': + conjunction = 'and' + eq_or_neq = '!=' + else: + return node + + if not isinstance(node.operand2, (ExprNodes.TupleNode, + ExprNodes.ListNode, + ExprNodes.SetNode)): + return node + + args = node.operand2.args + if len(args) == 0: + # note: lhs may have side effects + return node + if any([arg.is_starred for arg in args]): # Starred arguments do not directly translate to comparisons or "in" tests. return node - lhs = UtilNodes.ResultRefNode(node.operand1) - - conds = [] - temps = [] - for arg in args: - try: - # Trial optimisation to avoid redundant temp - # assignments. However, since is_simple() is meant to - # be called after type analysis, we ignore any errors - # and just play safe in that case. - is_simple_arg = arg.is_simple() - except Exception: - is_simple_arg = False - if not is_simple_arg: - # must evaluate all non-simple RHS before doing the comparisons - arg = UtilNodes.LetRefNode(arg) - temps.append(arg) - cond = ExprNodes.PrimaryCmpNode( - pos = node.pos, - operand1 = lhs, - operator = eq_or_neq, - operand2 = arg, - cascade = None) - conds.append(ExprNodes.TypecastNode( - pos = node.pos, - operand = cond, - type = PyrexTypes.c_bint_type)) - def concat(left, right): - return ExprNodes.BoolBinopNode( - pos = node.pos, - operator = conjunction, - operand1 = left, - operand2 = right) - - condition = reduce(concat, conds) - new_node = UtilNodes.EvalWithTempExprNode(lhs, condition) - for temp in temps[::-1]: - new_node = UtilNodes.EvalWithTempExprNode(temp, new_node) - return new_node - - visit_Node = Visitor.VisitorTransform.recurse_to_children - - -class DropRefcountingTransform(Visitor.VisitorTransform): - """Drop ref-counting in safe places. - """ - visit_Node = Visitor.VisitorTransform.recurse_to_children - - def visit_ParallelAssignmentNode(self, node): - """ - Parallel swap assignments like 'a,b = b,a' are safe. - """ - left_names, right_names = [], [] - left_indices, right_indices = [], [] - temps = [] - - for stat in node.stats: - if isinstance(stat, Nodes.SingleAssignmentNode): - if not self._extract_operand(stat.lhs, left_names, - left_indices, temps): - return node - if not self._extract_operand(stat.rhs, right_names, - right_indices, temps): - return node - elif isinstance(stat, Nodes.CascadedAssignmentNode): - # FIXME - return node - else: - return node - - if left_names or right_names: - # lhs/rhs names must be a non-redundant permutation - lnames = [ path for path, n in left_names ] - rnames = [ path for path, n in right_names ] - if set(lnames) != set(rnames): - return node - if len(set(lnames)) != len(right_names): - return node - - if left_indices or right_indices: - # base name and index of index nodes must be a - # non-redundant permutation - lindices = [] - for lhs_node in left_indices: - index_id = self._extract_index_id(lhs_node) - if not index_id: - return node - lindices.append(index_id) - rindices = [] - for rhs_node in right_indices: - index_id = self._extract_index_id(rhs_node) - if not index_id: - return node - rindices.append(index_id) - - if set(lindices) != set(rindices): - return node - if len(set(lindices)) != len(right_indices): - return node - - # really supporting IndexNode requires support in - # __Pyx_GetItemInt(), so let's stop short for now - return node - - temp_args = [t.arg for t in temps] - for temp in temps: - temp.use_managed_ref = False - - for _, name_node in left_names + right_names: - if name_node not in temp_args: - name_node.use_managed_ref = False - - for index_node in left_indices + right_indices: - index_node.use_managed_ref = False - - return node - - def _extract_operand(self, node, names, indices, temps): - node = unwrap_node(node) - if not node.type.is_pyobject: - return False - if isinstance(node, ExprNodes.CoerceToTempNode): - temps.append(node) - node = node.arg - name_path = [] - obj_node = node + lhs = UtilNodes.ResultRefNode(node.operand1) + + conds = [] + temps = [] + for arg in args: + try: + # Trial optimisation to avoid redundant temp + # assignments. However, since is_simple() is meant to + # be called after type analysis, we ignore any errors + # and just play safe in that case. + is_simple_arg = arg.is_simple() + except Exception: + is_simple_arg = False + if not is_simple_arg: + # must evaluate all non-simple RHS before doing the comparisons + arg = UtilNodes.LetRefNode(arg) + temps.append(arg) + cond = ExprNodes.PrimaryCmpNode( + pos = node.pos, + operand1 = lhs, + operator = eq_or_neq, + operand2 = arg, + cascade = None) + conds.append(ExprNodes.TypecastNode( + pos = node.pos, + operand = cond, + type = PyrexTypes.c_bint_type)) + def concat(left, right): + return ExprNodes.BoolBinopNode( + pos = node.pos, + operator = conjunction, + operand1 = left, + operand2 = right) + + condition = reduce(concat, conds) + new_node = UtilNodes.EvalWithTempExprNode(lhs, condition) + for temp in temps[::-1]: + new_node = UtilNodes.EvalWithTempExprNode(temp, new_node) + return new_node + + visit_Node = Visitor.VisitorTransform.recurse_to_children + + +class DropRefcountingTransform(Visitor.VisitorTransform): + """Drop ref-counting in safe places. + """ + visit_Node = Visitor.VisitorTransform.recurse_to_children + + def visit_ParallelAssignmentNode(self, node): + """ + Parallel swap assignments like 'a,b = b,a' are safe. + """ + left_names, right_names = [], [] + left_indices, right_indices = [], [] + temps = [] + + for stat in node.stats: + if isinstance(stat, Nodes.SingleAssignmentNode): + if not self._extract_operand(stat.lhs, left_names, + left_indices, temps): + return node + if not self._extract_operand(stat.rhs, right_names, + right_indices, temps): + return node + elif isinstance(stat, Nodes.CascadedAssignmentNode): + # FIXME + return node + else: + return node + + if left_names or right_names: + # lhs/rhs names must be a non-redundant permutation + lnames = [ path for path, n in left_names ] + rnames = [ path for path, n in right_names ] + if set(lnames) != set(rnames): + return node + if len(set(lnames)) != len(right_names): + return node + + if left_indices or right_indices: + # base name and index of index nodes must be a + # non-redundant permutation + lindices = [] + for lhs_node in left_indices: + index_id = self._extract_index_id(lhs_node) + if not index_id: + return node + lindices.append(index_id) + rindices = [] + for rhs_node in right_indices: + index_id = self._extract_index_id(rhs_node) + if not index_id: + return node + rindices.append(index_id) + + if set(lindices) != set(rindices): + return node + if len(set(lindices)) != len(right_indices): + return node + + # really supporting IndexNode requires support in + # __Pyx_GetItemInt(), so let's stop short for now + return node + + temp_args = [t.arg for t in temps] + for temp in temps: + temp.use_managed_ref = False + + for _, name_node in left_names + right_names: + if name_node not in temp_args: + name_node.use_managed_ref = False + + for index_node in left_indices + right_indices: + index_node.use_managed_ref = False + + return node + + def _extract_operand(self, node, names, indices, temps): + node = unwrap_node(node) + if not node.type.is_pyobject: + return False + if isinstance(node, ExprNodes.CoerceToTempNode): + temps.append(node) + node = node.arg + name_path = [] + obj_node = node while obj_node.is_attribute: - if obj_node.is_py_attr: - return False - name_path.append(obj_node.member) - obj_node = obj_node.obj + if obj_node.is_py_attr: + return False + name_path.append(obj_node.member) + obj_node = obj_node.obj if obj_node.is_name: - name_path.append(obj_node.name) - names.append( ('.'.join(name_path[::-1]), node) ) + name_path.append(obj_node.name) + names.append( ('.'.join(name_path[::-1]), node) ) elif node.is_subscript: - if node.base.type != Builtin.list_type: - return False - if not node.index.type.is_int: - return False + if node.base.type != Builtin.list_type: + return False + if not node.index.type.is_int: + return False if not node.base.is_name: - return False - indices.append(node) - else: - return False - return True - - def _extract_index_id(self, index_node): - base = index_node.base - index = index_node.index - if isinstance(index, ExprNodes.NameNode): - index_val = index.name - elif isinstance(index, ExprNodes.ConstNode): - # FIXME: - return None - else: - return None - return (base.name, index_val) - - -class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): - """Optimize some common calls to builtin types *before* the type - analysis phase and *after* the declarations analysis phase. - - This transform cannot make use of any argument types, but it can - restructure the tree in a way that the type analysis phase can - respond to. - - Introducing C function calls here may not be a good idea. Move - them to the OptimizeBuiltinCalls transform instead, which runs - after type analysis. - """ - # only intercept on call nodes - visit_Node = Visitor.VisitorTransform.recurse_to_children - - def visit_SimpleCallNode(self, node): - self.visitchildren(node) - function = node.function - if not self._function_is_builtin_name(function): - return node - return self._dispatch_to_handler(node, function, node.args) - - def visit_GeneralCallNode(self, node): - self.visitchildren(node) - function = node.function - if not self._function_is_builtin_name(function): - return node - arg_tuple = node.positional_args - if not isinstance(arg_tuple, ExprNodes.TupleNode): - return node - args = arg_tuple.args - return self._dispatch_to_handler( - node, function, args, node.keyword_args) - - def _function_is_builtin_name(self, function): - if not function.is_name: - return False - env = self.current_env() - entry = env.lookup(function.name) - if entry is not env.builtin_scope().lookup_here(function.name): - return False - # if entry is None, it's at least an undeclared name, so likely builtin - return True - - def _dispatch_to_handler(self, node, function, args, kwargs=None): - if kwargs is None: - handler_name = '_handle_simple_function_%s' % function.name - else: - handler_name = '_handle_general_function_%s' % function.name - handle_call = getattr(self, handler_name, None) - if handle_call is not None: - if kwargs is None: - return handle_call(node, args) - else: - return handle_call(node, args, kwargs) - return node - - def _inject_capi_function(self, node, cname, func_type, utility_code=None): - node.function = ExprNodes.PythonCapiFunctionNode( - node.function.pos, node.function.name, cname, func_type, - utility_code = utility_code) - - def _error_wrong_arg_count(self, function_name, node, args, expected=None): - if not expected: # None or 0 - arg_str = '' - elif isinstance(expected, basestring) or expected > 1: - arg_str = '...' - elif expected == 1: - arg_str = 'x' - else: - arg_str = '' - if expected is not None: - expected_str = 'expected %s, ' % expected - else: - expected_str = '' - error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % ( - function_name, arg_str, expected_str, len(args))) - - # specific handlers for simple call nodes - - def _handle_simple_function_float(self, node, pos_args): - if not pos_args: - return ExprNodes.FloatNode(node.pos, value='0.0') - if len(pos_args) > 1: - self._error_wrong_arg_count('float', node, pos_args, 1) - arg_type = getattr(pos_args[0], 'type', None) - if arg_type in (PyrexTypes.c_double_type, Builtin.float_type): - return pos_args[0] - return node - - def _handle_simple_function_slice(self, node, pos_args): - arg_count = len(pos_args) - start = step = None - if arg_count == 1: - stop, = pos_args - elif arg_count == 2: - start, stop = pos_args - elif arg_count == 3: - start, stop, step = pos_args - else: - self._error_wrong_arg_count('slice', node, pos_args) - return node - return ExprNodes.SliceNode( - node.pos, - start=start or ExprNodes.NoneNode(node.pos), - stop=stop, - step=step or ExprNodes.NoneNode(node.pos)) - + return False + indices.append(node) + else: + return False + return True + + def _extract_index_id(self, index_node): + base = index_node.base + index = index_node.index + if isinstance(index, ExprNodes.NameNode): + index_val = index.name + elif isinstance(index, ExprNodes.ConstNode): + # FIXME: + return None + else: + return None + return (base.name, index_val) + + +class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): + """Optimize some common calls to builtin types *before* the type + analysis phase and *after* the declarations analysis phase. + + This transform cannot make use of any argument types, but it can + restructure the tree in a way that the type analysis phase can + respond to. + + Introducing C function calls here may not be a good idea. Move + them to the OptimizeBuiltinCalls transform instead, which runs + after type analysis. + """ + # only intercept on call nodes + visit_Node = Visitor.VisitorTransform.recurse_to_children + + def visit_SimpleCallNode(self, node): + self.visitchildren(node) + function = node.function + if not self._function_is_builtin_name(function): + return node + return self._dispatch_to_handler(node, function, node.args) + + def visit_GeneralCallNode(self, node): + self.visitchildren(node) + function = node.function + if not self._function_is_builtin_name(function): + return node + arg_tuple = node.positional_args + if not isinstance(arg_tuple, ExprNodes.TupleNode): + return node + args = arg_tuple.args + return self._dispatch_to_handler( + node, function, args, node.keyword_args) + + def _function_is_builtin_name(self, function): + if not function.is_name: + return False + env = self.current_env() + entry = env.lookup(function.name) + if entry is not env.builtin_scope().lookup_here(function.name): + return False + # if entry is None, it's at least an undeclared name, so likely builtin + return True + + def _dispatch_to_handler(self, node, function, args, kwargs=None): + if kwargs is None: + handler_name = '_handle_simple_function_%s' % function.name + else: + handler_name = '_handle_general_function_%s' % function.name + handle_call = getattr(self, handler_name, None) + if handle_call is not None: + if kwargs is None: + return handle_call(node, args) + else: + return handle_call(node, args, kwargs) + return node + + def _inject_capi_function(self, node, cname, func_type, utility_code=None): + node.function = ExprNodes.PythonCapiFunctionNode( + node.function.pos, node.function.name, cname, func_type, + utility_code = utility_code) + + def _error_wrong_arg_count(self, function_name, node, args, expected=None): + if not expected: # None or 0 + arg_str = '' + elif isinstance(expected, basestring) or expected > 1: + arg_str = '...' + elif expected == 1: + arg_str = 'x' + else: + arg_str = '' + if expected is not None: + expected_str = 'expected %s, ' % expected + else: + expected_str = '' + error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % ( + function_name, arg_str, expected_str, len(args))) + + # specific handlers for simple call nodes + + def _handle_simple_function_float(self, node, pos_args): + if not pos_args: + return ExprNodes.FloatNode(node.pos, value='0.0') + if len(pos_args) > 1: + self._error_wrong_arg_count('float', node, pos_args, 1) + arg_type = getattr(pos_args[0], 'type', None) + if arg_type in (PyrexTypes.c_double_type, Builtin.float_type): + return pos_args[0] + return node + + def _handle_simple_function_slice(self, node, pos_args): + arg_count = len(pos_args) + start = step = None + if arg_count == 1: + stop, = pos_args + elif arg_count == 2: + start, stop = pos_args + elif arg_count == 3: + start, stop, step = pos_args + else: + self._error_wrong_arg_count('slice', node, pos_args) + return node + return ExprNodes.SliceNode( + node.pos, + start=start or ExprNodes.NoneNode(node.pos), + stop=stop, + step=step or ExprNodes.NoneNode(node.pos)) + def _handle_simple_function_ord(self, node, pos_args): """Unpack ord('X'). """ @@ -1640,59 +1640,59 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): constant_result=ord(arg.unicode_value) ) return node - + # sequence processing - - def _handle_simple_function_all(self, node, pos_args): - """Transform - + + def _handle_simple_function_all(self, node, pos_args): + """Transform + _result = all(p(x) for L in LL for x in L) - - into - - for L in LL: - for x in L: + + into + + for L in LL: + for x in L: if not p(x): return False - else: + else: return True - """ - return self._transform_any_all(node, pos_args, False) - - def _handle_simple_function_any(self, node, pos_args): - """Transform - + """ + return self._transform_any_all(node, pos_args, False) + + def _handle_simple_function_any(self, node, pos_args): + """Transform + _result = any(p(x) for L in LL for x in L) - - into - - for L in LL: - for x in L: + + into + + for L in LL: + for x in L: if p(x): return True - else: + else: return False - """ - return self._transform_any_all(node, pos_args, True) - - def _transform_any_all(self, node, pos_args, is_any): - if len(pos_args) != 1: - return node - if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): - return node - gen_expr_node = pos_args[0] + """ + return self._transform_any_all(node, pos_args, True) + + def _transform_any_all(self, node, pos_args, is_any): + if len(pos_args) != 1: + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] generator_body = gen_expr_node.def_node.gbody loop_node = generator_body.body yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) - if yield_expression is None: - return node - - if is_any: - condition = yield_expression - else: + if yield_expression is None: + return node + + if is_any: + condition = yield_expression + else: condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression) - - test_node = Nodes.IfStatNode( + + test_node = Nodes.IfStatNode( yield_expression.pos, else_clause=None, if_clauses=[ Nodes.IfClauseNode( yield_expression.pos, @@ -1703,27 +1703,27 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): )] ) loop_node.else_clause = Nodes.ReturnStatNode( - node.pos, + node.pos, value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any)) - + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node) - - return ExprNodes.InlinedGeneratorExpressionNode( + + return ExprNodes.InlinedGeneratorExpressionNode( gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all') - + PySequence_List_func_type = PyrexTypes.CFuncType( Builtin.list_type, [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)]) - def _handle_simple_function_sorted(self, node, pos_args): - """Transform sorted(genexpr) and sorted([listcomp]) into - [listcomp].sort(). CPython just reads the iterable into a - list and calls .sort() on it. Expanding the iterable in a - listcomp is still faster and the result can be sorted in - place. - """ - if len(pos_args) != 1: - return node + def _handle_simple_function_sorted(self, node, pos_args): + """Transform sorted(genexpr) and sorted([listcomp]) into + [listcomp].sort(). CPython just reads the iterable into a + list and calls .sort() on it. Expanding the iterable in a + listcomp is still faster and the result can be sorted in + place. + """ + if len(pos_args) != 1: + return node arg = pos_args[0] if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type: @@ -1732,11 +1732,11 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): elif isinstance(arg, ExprNodes.GeneratorExpressionNode): gen_expr_node = arg - loop_node = gen_expr_node.loop + loop_node = gen_expr_node.loop yield_statements = _find_yield_statements(loop_node) if not yield_statements: - return node - + return node + list_node = ExprNodes.InlinedGeneratorExpressionNode( node.pos, gen_expr_node, orig_func='sorted', comprehension_type=Builtin.list_type) @@ -1747,187 +1747,187 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): expr=yield_expression, target=list_node.target) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) - + elif arg.is_sequence_constructor: # sorted([a, b, c]) or sorted((a, b, c)). The result is always a list, # so starting off with a fresh one is more efficient. list_node = loop_node = arg.as_list() - else: + else: # Interestingly, PySequence_List works on a lot of non-sequence # things as well. list_node = loop_node = ExprNodes.PythonCapiCallNode( node.pos, "PySequence_List", self.PySequence_List_func_type, args=pos_args, is_temp=True) - - result_node = UtilNodes.ResultRefNode( + + result_node = UtilNodes.ResultRefNode( pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False) list_assign_node = Nodes.SingleAssignmentNode( node.pos, lhs=result_node, rhs=list_node, first=True) - - sort_method = ExprNodes.AttributeNode( + + sort_method = ExprNodes.AttributeNode( node.pos, obj=result_node, attribute=EncodedString('sort'), - # entry ? type ? + # entry ? type ? needs_none_check=False) - sort_node = Nodes.ExprStatNode( + sort_node = Nodes.ExprStatNode( node.pos, expr=ExprNodes.SimpleCallNode( node.pos, function=sort_method, args=[])) - - sort_node.analyse_declarations(self.current_env()) - - return UtilNodes.TempResultFromStatNode( - result_node, + + sort_node.analyse_declarations(self.current_env()) + + return UtilNodes.TempResultFromStatNode( + result_node, Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node])) - + def __handle_simple_function_sum(self, node, pos_args): - """Transform sum(genexpr) into an equivalent inlined aggregation loop. - """ - if len(pos_args) not in (1,2): - return node - if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode, - ExprNodes.ComprehensionNode)): - return node - gen_expr_node = pos_args[0] - loop_node = gen_expr_node.loop - - if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode): + """Transform sum(genexpr) into an equivalent inlined aggregation loop. + """ + if len(pos_args) not in (1,2): + return node + if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode, + ExprNodes.ComprehensionNode)): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + + if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode): yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) # FIXME: currently nonfunctional yield_expression = None - if yield_expression is None: - return node + if yield_expression is None: + return node else: # ComprehensionNode - yield_stat_node = gen_expr_node.append - yield_expression = yield_stat_node.expr - try: - if not yield_expression.is_literal or not yield_expression.type.is_int: - return node - except AttributeError: - return node # in case we don't have a type yet - # special case: old Py2 backwards compatible "sum([int_const for ...])" - # can safely be unpacked into a genexpr - - if len(pos_args) == 1: - start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) - else: - start = pos_args[1] - - result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type) - add_node = Nodes.SingleAssignmentNode( - yield_expression.pos, - lhs = result_ref, - rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) - ) - + yield_stat_node = gen_expr_node.append + yield_expression = yield_stat_node.expr + try: + if not yield_expression.is_literal or not yield_expression.type.is_int: + return node + except AttributeError: + return node # in case we don't have a type yet + # special case: old Py2 backwards compatible "sum([int_const for ...])" + # can safely be unpacked into a genexpr + + if len(pos_args) == 1: + start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) + else: + start = pos_args[1] + + result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type) + add_node = Nodes.SingleAssignmentNode( + yield_expression.pos, + lhs = result_ref, + rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) + ) + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node) - - exec_code = Nodes.StatListNode( - node.pos, - stats = [ - Nodes.SingleAssignmentNode( - start.pos, - lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref), - rhs = start, - first = True), - loop_node - ]) - - return ExprNodes.InlinedGeneratorExpressionNode( - gen_expr_node.pos, loop = exec_code, result_node = result_ref, - expr_scope = gen_expr_node.expr_scope, orig_func = 'sum', - has_local_scope = gen_expr_node.has_local_scope) - - def _handle_simple_function_min(self, node, pos_args): - return self._optimise_min_max(node, pos_args, '<') - - def _handle_simple_function_max(self, node, pos_args): - return self._optimise_min_max(node, pos_args, '>') - - def _optimise_min_max(self, node, args, operator): - """Replace min(a,b,...) and max(a,b,...) by explicit comparison code. - """ - if len(args) <= 1: - if len(args) == 1 and args[0].is_sequence_constructor: - args = args[0].args + + exec_code = Nodes.StatListNode( + node.pos, + stats = [ + Nodes.SingleAssignmentNode( + start.pos, + lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref), + rhs = start, + first = True), + loop_node + ]) + + return ExprNodes.InlinedGeneratorExpressionNode( + gen_expr_node.pos, loop = exec_code, result_node = result_ref, + expr_scope = gen_expr_node.expr_scope, orig_func = 'sum', + has_local_scope = gen_expr_node.has_local_scope) + + def _handle_simple_function_min(self, node, pos_args): + return self._optimise_min_max(node, pos_args, '<') + + def _handle_simple_function_max(self, node, pos_args): + return self._optimise_min_max(node, pos_args, '>') + + def _optimise_min_max(self, node, args, operator): + """Replace min(a,b,...) and max(a,b,...) by explicit comparison code. + """ + if len(args) <= 1: + if len(args) == 1 and args[0].is_sequence_constructor: + args = args[0].args if len(args) <= 1: - # leave this to Python - return node - - cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:])) - - last_result = args[0] - for arg_node in cascaded_nodes: - result_ref = UtilNodes.ResultRefNode(last_result) - last_result = ExprNodes.CondExprNode( - arg_node.pos, - true_val = arg_node, - false_val = result_ref, - test = ExprNodes.PrimaryCmpNode( - arg_node.pos, - operand1 = arg_node, - operator = operator, - operand2 = result_ref, - ) - ) - last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result) - - for ref_node in cascaded_nodes[::-1]: - last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) - - return last_result - + # leave this to Python + return node + + cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:])) + + last_result = args[0] + for arg_node in cascaded_nodes: + result_ref = UtilNodes.ResultRefNode(last_result) + last_result = ExprNodes.CondExprNode( + arg_node.pos, + true_val = arg_node, + false_val = result_ref, + test = ExprNodes.PrimaryCmpNode( + arg_node.pos, + operand1 = arg_node, + operator = operator, + operand2 = result_ref, + ) + ) + last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result) + + for ref_node in cascaded_nodes[::-1]: + last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) + + return last_result + # builtin type creation - def _DISABLED_handle_simple_function_tuple(self, node, pos_args): - if not pos_args: - return ExprNodes.TupleNode(node.pos, args=[], constant_result=()) - # This is a bit special - for iterables (including genexps), - # Python actually overallocates and resizes a newly created - # tuple incrementally while reading items, which we can't - # easily do without explicit node support. Instead, we read - # the items into a list and then copy them into a tuple of the - # final size. This takes up to twice as much memory, but will - # have to do until we have real support for genexps. - result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) - if result is not node: - return ExprNodes.AsTupleNode(node.pos, arg=result) - return node - - def _handle_simple_function_frozenset(self, node, pos_args): - """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient. - """ - if len(pos_args) != 1: - return node - if pos_args[0].is_sequence_constructor and not pos_args[0].args: - del pos_args[0] - elif isinstance(pos_args[0], ExprNodes.ListNode): - pos_args[0] = pos_args[0].as_tuple() - return node - - def _handle_simple_function_list(self, node, pos_args): - if not pos_args: - return ExprNodes.ListNode(node.pos, args=[], constant_result=[]) - return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) - - def _handle_simple_function_set(self, node, pos_args): - if not pos_args: - return ExprNodes.SetNode(node.pos, args=[], constant_result=set()) - return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type) - - def _transform_list_set_genexpr(self, node, pos_args, target_type): + def _DISABLED_handle_simple_function_tuple(self, node, pos_args): + if not pos_args: + return ExprNodes.TupleNode(node.pos, args=[], constant_result=()) + # This is a bit special - for iterables (including genexps), + # Python actually overallocates and resizes a newly created + # tuple incrementally while reading items, which we can't + # easily do without explicit node support. Instead, we read + # the items into a list and then copy them into a tuple of the + # final size. This takes up to twice as much memory, but will + # have to do until we have real support for genexps. + result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) + if result is not node: + return ExprNodes.AsTupleNode(node.pos, arg=result) + return node + + def _handle_simple_function_frozenset(self, node, pos_args): + """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient. + """ + if len(pos_args) != 1: + return node + if pos_args[0].is_sequence_constructor and not pos_args[0].args: + del pos_args[0] + elif isinstance(pos_args[0], ExprNodes.ListNode): + pos_args[0] = pos_args[0].as_tuple() + return node + + def _handle_simple_function_list(self, node, pos_args): + if not pos_args: + return ExprNodes.ListNode(node.pos, args=[], constant_result=[]) + return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type) + + def _handle_simple_function_set(self, node, pos_args): + if not pos_args: + return ExprNodes.SetNode(node.pos, args=[], constant_result=set()) + return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type) + + def _transform_list_set_genexpr(self, node, pos_args, target_type): """Replace set(genexpr) and list(genexpr) by an inlined comprehension. - """ - if len(pos_args) > 1: - return node - if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): - return node - gen_expr_node = pos_args[0] - loop_node = gen_expr_node.loop - + """ + if len(pos_args) > 1: + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + yield_statements = _find_yield_statements(loop_node) if not yield_statements: - return node - + return node + result_node = ExprNodes.InlinedGeneratorExpressionNode( node.pos, gen_expr_node, orig_func='set' if target_type is Builtin.set_type else 'list', @@ -1939,31 +1939,31 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): expr=yield_expression, target=result_node.target) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) - + return result_node - - def _handle_simple_function_dict(self, node, pos_args): + + def _handle_simple_function_dict(self, node, pos_args): """Replace dict( (a,b) for ... ) by an inlined { a:b for ... } - """ - if len(pos_args) == 0: - return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={}) - if len(pos_args) > 1: - return node - if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): - return node - gen_expr_node = pos_args[0] - loop_node = gen_expr_node.loop - + """ + if len(pos_args) == 0: + return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={}) + if len(pos_args) > 1: + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + yield_statements = _find_yield_statements(loop_node) if not yield_statements: - return node - + return node + for yield_expression, _ in yield_statements: if not isinstance(yield_expression, ExprNodes.TupleNode): return node if len(yield_expression.args) != 2: return node - + result_node = ExprNodes.InlinedGeneratorExpressionNode( node.pos, gen_expr_node, orig_func='dict', comprehension_type=Builtin.dict_type) @@ -1975,90 +1975,90 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): value_expr=yield_expression.args[1], target=result_node.target) Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) - + return result_node - - # specific handlers for general call nodes - - def _handle_general_function_dict(self, node, pos_args, kwargs): - """Replace dict(a=b,c=d,...) by the underlying keyword dict - construction which is done anyway. - """ - if len(pos_args) > 0: - return node - if not isinstance(kwargs, ExprNodes.DictNode): - return node - return kwargs - - -class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform): - visit_Node = Visitor.VisitorTransform.recurse_to_children - - def get_constant_value_node(self, name_node): - if name_node.cf_state is None: - return None - if name_node.cf_state.cf_is_null: - return None - entry = self.current_env().lookup(name_node.name) - if not entry or (not entry.cf_assignments - or len(entry.cf_assignments) != 1): - # not just a single assignment in all closures - return None - return entry.cf_assignments[0].rhs - - def visit_SimpleCallNode(self, node): - self.visitchildren(node) - if not self.current_directives.get('optimize.inline_defnode_calls'): - return node - function_name = node.function - if not function_name.is_name: - return node - function = self.get_constant_value_node(function_name) - if not isinstance(function, ExprNodes.PyCFunctionNode): - return node - inlined = ExprNodes.InlinedDefNodeCallNode( - node.pos, function_name=function_name, - function=function, args=node.args) - if inlined.can_be_inlined(): - return self.replace(node, inlined) - return node - - + + # specific handlers for general call nodes + + def _handle_general_function_dict(self, node, pos_args, kwargs): + """Replace dict(a=b,c=d,...) by the underlying keyword dict + construction which is done anyway. + """ + if len(pos_args) > 0: + return node + if not isinstance(kwargs, ExprNodes.DictNode): + return node + return kwargs + + +class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform): + visit_Node = Visitor.VisitorTransform.recurse_to_children + + def get_constant_value_node(self, name_node): + if name_node.cf_state is None: + return None + if name_node.cf_state.cf_is_null: + return None + entry = self.current_env().lookup(name_node.name) + if not entry or (not entry.cf_assignments + or len(entry.cf_assignments) != 1): + # not just a single assignment in all closures + return None + return entry.cf_assignments[0].rhs + + def visit_SimpleCallNode(self, node): + self.visitchildren(node) + if not self.current_directives.get('optimize.inline_defnode_calls'): + return node + function_name = node.function + if not function_name.is_name: + return node + function = self.get_constant_value_node(function_name) + if not isinstance(function, ExprNodes.PyCFunctionNode): + return node + inlined = ExprNodes.InlinedDefNodeCallNode( + node.pos, function_name=function_name, + function=function, args=node.args) + if inlined.can_be_inlined(): + return self.replace(node, inlined) + return node + + class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, Visitor.MethodDispatcherTransform): - """Optimize some common methods calls and instantiation patterns - for builtin types *after* the type analysis phase. - - Running after type analysis, this transform can only perform - function replacements that do not alter the function return type - in a way that was not anticipated by the type analysis. - """ - ### cleanup to avoid redundant coercions to/from Python types - + """Optimize some common methods calls and instantiation patterns + for builtin types *after* the type analysis phase. + + Running after type analysis, this transform can only perform + function replacements that do not alter the function return type + in a way that was not anticipated by the type analysis. + """ + ### cleanup to avoid redundant coercions to/from Python types + def visit_PyTypeTestNode(self, node): - """Flatten redundant type checks after tree changes. - """ - self.visitchildren(node) + """Flatten redundant type checks after tree changes. + """ + self.visitchildren(node) return node.reanalyse() - - def _visit_TypecastNode(self, node): - # disabled - the user may have had a reason to put a type - # cast, even if it looks redundant to Cython - """ - Drop redundant type casts. - """ - self.visitchildren(node) - if node.type == node.operand.type: - return node.operand - return node - - def visit_ExprStatNode(self, node): - """ + + def _visit_TypecastNode(self, node): + # disabled - the user may have had a reason to put a type + # cast, even if it looks redundant to Cython + """ + Drop redundant type casts. + """ + self.visitchildren(node) + if node.type == node.operand.type: + return node.operand + return node + + def visit_ExprStatNode(self, node): + """ Drop dead code and useless coercions. - """ - self.visitchildren(node) - if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode): - node.expr = node.expr.arg + """ + self.visitchildren(node) + if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode): + node.expr = node.expr.arg expr = node.expr if expr is None or expr.is_none or expr.is_literal: # Expression was removed or is dead code => remove ExprStatNode as well. @@ -2066,20 +2066,20 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if expr.is_name and expr.entry and (expr.entry.is_local or expr.entry.is_arg): # Ignore dead references to local variables etc. return None - return node - - def visit_CoerceToBooleanNode(self, node): - """Drop redundant conversion nodes after tree changes. - """ - self.visitchildren(node) - arg = node.arg - if isinstance(arg, ExprNodes.PyTypeTestNode): - arg = arg.arg - if isinstance(arg, ExprNodes.CoerceToPyTypeNode): - if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type): - return arg.arg.coerce_to_boolean(self.current_env()) - return node - + return node + + def visit_CoerceToBooleanNode(self, node): + """Drop redundant conversion nodes after tree changes. + """ + self.visitchildren(node) + arg = node.arg + if isinstance(arg, ExprNodes.PyTypeTestNode): + arg = arg.arg + if isinstance(arg, ExprNodes.CoerceToPyTypeNode): + if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type): + return arg.arg.coerce_to_boolean(self.current_env()) + return node + PyNumber_Float_func_type = PyrexTypes.CFuncType( PyrexTypes.py_object_type, [ PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None) @@ -2107,80 +2107,80 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ).coerce_to(node.type, self.current_env()) return node - def visit_CoerceFromPyTypeNode(self, node): - """Drop redundant conversion nodes after tree changes. - - Also, optimise away calls to Python's builtin int() and - float() if the result is going to be coerced back into a C - type anyway. - """ - self.visitchildren(node) - arg = node.arg - if not arg.type.is_pyobject: - # no Python conversion left at all, just do a C coercion instead + def visit_CoerceFromPyTypeNode(self, node): + """Drop redundant conversion nodes after tree changes. + + Also, optimise away calls to Python's builtin int() and + float() if the result is going to be coerced back into a C + type anyway. + """ + self.visitchildren(node) + arg = node.arg + if not arg.type.is_pyobject: + # no Python conversion left at all, just do a C coercion instead if node.type != arg.type: arg = arg.coerce_to(node.type, self.current_env()) return arg - if isinstance(arg, ExprNodes.PyTypeTestNode): - arg = arg.arg - if arg.is_literal: - if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or - node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or - node.type.is_int and isinstance(arg, ExprNodes.BoolNode)): - return arg.coerce_to(node.type, self.current_env()) - elif isinstance(arg, ExprNodes.CoerceToPyTypeNode): - if arg.type is PyrexTypes.py_object_type: - if node.type.assignable_from(arg.arg.type): - # completely redundant C->Py->C coercion - return arg.arg.coerce_to(node.type, self.current_env()) + if isinstance(arg, ExprNodes.PyTypeTestNode): + arg = arg.arg + if arg.is_literal: + if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or + node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or + node.type.is_int and isinstance(arg, ExprNodes.BoolNode)): + return arg.coerce_to(node.type, self.current_env()) + elif isinstance(arg, ExprNodes.CoerceToPyTypeNode): + if arg.type is PyrexTypes.py_object_type: + if node.type.assignable_from(arg.arg.type): + # completely redundant C->Py->C coercion + return arg.arg.coerce_to(node.type, self.current_env()) elif arg.type is Builtin.unicode_type: if arg.arg.type.is_unicode_char and node.type.is_unicode_char: return arg.arg.coerce_to(node.type, self.current_env()) - elif isinstance(arg, ExprNodes.SimpleCallNode): - if node.type.is_int or node.type.is_float: - return self._optimise_numeric_cast_call(node, arg) + elif isinstance(arg, ExprNodes.SimpleCallNode): + if node.type.is_int or node.type.is_float: + return self._optimise_numeric_cast_call(node, arg) elif arg.is_subscript: - index_node = arg.index - if isinstance(index_node, ExprNodes.CoerceToPyTypeNode): - index_node = index_node.arg - if index_node.type.is_int: - return self._optimise_int_indexing(node, arg, index_node) - return node - - PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_char_type, [ - PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None), - PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None), - ], - exception_value = "((char)-1)", - exception_check = True) - - def _optimise_int_indexing(self, coerce_node, arg, index_node): - env = self.current_env() - bound_check_bool = env.directives['boundscheck'] and 1 or 0 - if arg.base.type is Builtin.bytes_type: - if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type): - # bytes[index] -> char - bound_check_node = ExprNodes.IntNode( - coerce_node.pos, value=str(bound_check_bool), - constant_result=bound_check_bool) - node = ExprNodes.PythonCapiCallNode( - coerce_node.pos, "__Pyx_PyBytes_GetItemInt", - self.PyBytes_GetItemInt_func_type, - args=[ - arg.base.as_none_safe_node("'NoneType' object is not subscriptable"), - index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env), - bound_check_node, - ], - is_temp=True, - utility_code=UtilityCode.load_cached( - 'bytes_index', 'StringTools.c')) - if coerce_node.type is not PyrexTypes.c_char_type: - node = node.coerce_to(coerce_node.type, env) - return node - return coerce_node - + index_node = arg.index + if isinstance(index_node, ExprNodes.CoerceToPyTypeNode): + index_node = index_node.arg + if index_node.type.is_int: + return self._optimise_int_indexing(node, arg, index_node) + return node + + PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_char_type, [ + PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None), + PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None), + ], + exception_value = "((char)-1)", + exception_check = True) + + def _optimise_int_indexing(self, coerce_node, arg, index_node): + env = self.current_env() + bound_check_bool = env.directives['boundscheck'] and 1 or 0 + if arg.base.type is Builtin.bytes_type: + if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type): + # bytes[index] -> char + bound_check_node = ExprNodes.IntNode( + coerce_node.pos, value=str(bound_check_bool), + constant_result=bound_check_bool) + node = ExprNodes.PythonCapiCallNode( + coerce_node.pos, "__Pyx_PyBytes_GetItemInt", + self.PyBytes_GetItemInt_func_type, + args=[ + arg.base.as_none_safe_node("'NoneType' object is not subscriptable"), + index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env), + bound_check_node, + ], + is_temp=True, + utility_code=UtilityCode.load_cached( + 'bytes_index', 'StringTools.c')) + if coerce_node.type is not PyrexTypes.c_char_type: + node = node.coerce_to(coerce_node.type, env) + return node + return coerce_node + float_float_func_types = dict( (float_type, PyrexTypes.CFuncType( float_type, [ @@ -2188,8 +2188,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ])) for float_type in (PyrexTypes.c_float_type, PyrexTypes.c_double_type, PyrexTypes.c_longdouble_type)) - def _optimise_numeric_cast_call(self, node, arg): - function = arg.function + def _optimise_numeric_cast_call(self, node, arg): + function = arg.function args = None if isinstance(arg, ExprNodes.PythonCapiCallNode): args = arg.args @@ -2198,19 +2198,19 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, args = arg.arg_tuple.args if args is None or len(args) != 1: - return node - func_arg = args[0] - if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): - func_arg = func_arg.arg - elif func_arg.type.is_pyobject: + return node + func_arg = args[0] + if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): + func_arg = func_arg.arg + elif func_arg.type.is_pyobject: # play it safe: Python conversion might work on all sorts of things - return node + return node - if function.name == 'int': - if func_arg.type.is_int or node.type.is_int: - if func_arg.type == node.type: - return func_arg - elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: + if function.name == 'int': + if func_arg.type.is_int or node.type.is_int: + if func_arg.type == node.type: + return func_arg + elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type) elif func_arg.type.is_float and node.type.is_numeric: if func_arg.type.math_h_modifier == 'l': @@ -2226,80 +2226,80 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, is_temp=node.is_temp, result_is_used=node.result_is_used, ).coerce_to(node.type, self.current_env()) - elif function.name == 'float': - if func_arg.type.is_float or node.type.is_float: - if func_arg.type == node.type: - return func_arg - elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: - return ExprNodes.TypecastNode( - node.pos, operand=func_arg, type=node.type) - return node - - def _error_wrong_arg_count(self, function_name, node, args, expected=None): - if not expected: # None or 0 - arg_str = '' - elif isinstance(expected, basestring) or expected > 1: - arg_str = '...' - elif expected == 1: - arg_str = 'x' - else: - arg_str = '' - if expected is not None: - expected_str = 'expected %s, ' % expected - else: - expected_str = '' - error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % ( - function_name, arg_str, expected_str, len(args))) - - ### generic fallbacks - - def _handle_function(self, node, function_name, function, arg_list, kwargs): - return node - - def _handle_method(self, node, type_name, attr_name, function, - arg_list, is_unbound_method, kwargs): - """ - Try to inject C-API calls for unbound method calls to builtin types. - While the method declarations in Builtin.py already handle this, we - can additionally resolve bound and unbound methods here that were - assigned to variables ahead of time. - """ - if kwargs: - return node - if not function or not function.is_attribute or not function.obj.is_name: - # cannot track unbound method calls over more than one indirection as - # the names might have been reassigned in the meantime - return node - type_entry = self.current_env().lookup(type_name) - if not type_entry: - return node - method = ExprNodes.AttributeNode( - node.function.pos, - obj=ExprNodes.NameNode( - function.pos, - name=type_name, - entry=type_entry, - type=type_entry.type), - attribute=attr_name, + elif function.name == 'float': + if func_arg.type.is_float or node.type.is_float: + if func_arg.type == node.type: + return func_arg + elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: + return ExprNodes.TypecastNode( + node.pos, operand=func_arg, type=node.type) + return node + + def _error_wrong_arg_count(self, function_name, node, args, expected=None): + if not expected: # None or 0 + arg_str = '' + elif isinstance(expected, basestring) or expected > 1: + arg_str = '...' + elif expected == 1: + arg_str = 'x' + else: + arg_str = '' + if expected is not None: + expected_str = 'expected %s, ' % expected + else: + expected_str = '' + error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % ( + function_name, arg_str, expected_str, len(args))) + + ### generic fallbacks + + def _handle_function(self, node, function_name, function, arg_list, kwargs): + return node + + def _handle_method(self, node, type_name, attr_name, function, + arg_list, is_unbound_method, kwargs): + """ + Try to inject C-API calls for unbound method calls to builtin types. + While the method declarations in Builtin.py already handle this, we + can additionally resolve bound and unbound methods here that were + assigned to variables ahead of time. + """ + if kwargs: + return node + if not function or not function.is_attribute or not function.obj.is_name: + # cannot track unbound method calls over more than one indirection as + # the names might have been reassigned in the meantime + return node + type_entry = self.current_env().lookup(type_name) + if not type_entry: + return node + method = ExprNodes.AttributeNode( + node.function.pos, + obj=ExprNodes.NameNode( + function.pos, + name=type_name, + entry=type_entry, + type=type_entry.type), + attribute=attr_name, is_called=True).analyse_as_type_attribute(self.current_env()) - if method is None: + if method is None: return self._optimise_generic_builtin_method_call( node, attr_name, function, arg_list, is_unbound_method) - args = node.args - if args is None and node.arg_tuple: - args = node.arg_tuple.args - call_node = ExprNodes.SimpleCallNode( - node.pos, - function=method, - args=args) - if not is_unbound_method: - call_node.self = function.obj - call_node.analyse_c_function_call(self.current_env()) - call_node.analysed = True - return call_node.coerce_to(node.type, self.current_env()) - - ### builtin types - + args = node.args + if args is None and node.arg_tuple: + args = node.arg_tuple.args + call_node = ExprNodes.SimpleCallNode( + node.pos, + function=method, + args=args) + if not is_unbound_method: + call_node.self = function.obj + call_node.analyse_c_function_call(self.current_env()) + call_node.analysed = True + return call_node.coerce_to(node.type, self.current_env()) + + ### builtin types + def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method): """ Try to inject an unbound method call for a call to a method of a known builtin type. @@ -2356,26 +2356,26 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return self._handle_simple_function_unicode(node, None, [node.value]) return node - PyDict_Copy_func_type = PyrexTypes.CFuncType( - Builtin.dict_type, [ - PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None) - ]) - - def _handle_simple_function_dict(self, node, function, pos_args): - """Replace dict(some_dict) by PyDict_Copy(some_dict). - """ - if len(pos_args) != 1: - return node - arg = pos_args[0] - if arg.type is Builtin.dict_type: - arg = arg.as_none_safe_node("'NoneType' is not iterable") - return ExprNodes.PythonCapiCallNode( - node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, - args = [arg], - is_temp = node.is_temp - ) - return node - + PyDict_Copy_func_type = PyrexTypes.CFuncType( + Builtin.dict_type, [ + PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None) + ]) + + def _handle_simple_function_dict(self, node, function, pos_args): + """Replace dict(some_dict) by PyDict_Copy(some_dict). + """ + if len(pos_args) != 1: + return node + arg = pos_args[0] + if arg.type is Builtin.dict_type: + arg = arg.as_none_safe_node("'NoneType' is not iterable") + return ExprNodes.PythonCapiCallNode( + node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, + args = [arg], + is_temp = node.is_temp + ) + return node + PySequence_List_func_type = PyrexTypes.CFuncType( Builtin.list_type, [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)]) @@ -2390,141 +2390,141 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, node.pos, "PySequence_List", self.PySequence_List_func_type, args=pos_args, is_temp=node.is_temp) - PyList_AsTuple_func_type = PyrexTypes.CFuncType( - Builtin.tuple_type, [ - PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None) - ]) - - def _handle_simple_function_tuple(self, node, function, pos_args): + PyList_AsTuple_func_type = PyrexTypes.CFuncType( + Builtin.tuple_type, [ + PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None) + ]) + + def _handle_simple_function_tuple(self, node, function, pos_args): """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple. - """ + """ if len(pos_args) != 1 or not node.is_temp: - return node - arg = pos_args[0] - if arg.type is Builtin.tuple_type and not arg.may_be_none(): - return arg + return node + arg = pos_args[0] + if arg.type is Builtin.tuple_type and not arg.may_be_none(): + return arg if arg.type is Builtin.list_type: pos_args[0] = arg.as_none_safe_node( "'NoneType' object is not iterable") - + return ExprNodes.PythonCapiCallNode( node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type, args=pos_args, is_temp=node.is_temp) else: return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type) - - PySet_New_func_type = PyrexTypes.CFuncType( - Builtin.set_type, [ - PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None) - ]) - - def _handle_simple_function_set(self, node, function, pos_args): - if len(pos_args) != 1: - return node - if pos_args[0].is_sequence_constructor: - # We can optimise set([x,y,z]) safely into a set literal, - # but only if we create all items before adding them - - # adding an item may raise an exception if it is not - # hashable, but creating the later items may have - # side-effects. - args = [] - temps = [] - for arg in pos_args[0].args: - if not arg.is_simple(): - arg = UtilNodes.LetRefNode(arg) - temps.append(arg) - args.append(arg) - result = ExprNodes.SetNode(node.pos, is_temp=1, args=args) + + PySet_New_func_type = PyrexTypes.CFuncType( + Builtin.set_type, [ + PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None) + ]) + + def _handle_simple_function_set(self, node, function, pos_args): + if len(pos_args) != 1: + return node + if pos_args[0].is_sequence_constructor: + # We can optimise set([x,y,z]) safely into a set literal, + # but only if we create all items before adding them - + # adding an item may raise an exception if it is not + # hashable, but creating the later items may have + # side-effects. + args = [] + temps = [] + for arg in pos_args[0].args: + if not arg.is_simple(): + arg = UtilNodes.LetRefNode(arg) + temps.append(arg) + args.append(arg) + result = ExprNodes.SetNode(node.pos, is_temp=1, args=args) self.replace(node, result) - for temp in temps[::-1]: - result = UtilNodes.EvalWithTempExprNode(temp, result) - return result - else: - # PySet_New(it) is better than a generic Python call to set(it) + for temp in temps[::-1]: + result = UtilNodes.EvalWithTempExprNode(temp, result) + return result + else: + # PySet_New(it) is better than a generic Python call to set(it) return self.replace(node, ExprNodes.PythonCapiCallNode( - node.pos, "PySet_New", - self.PySet_New_func_type, - args=pos_args, - is_temp=node.is_temp, + node.pos, "PySet_New", + self.PySet_New_func_type, + args=pos_args, + is_temp=node.is_temp, py_name="set")) - - PyFrozenSet_New_func_type = PyrexTypes.CFuncType( - Builtin.frozenset_type, [ - PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None) - ]) - - def _handle_simple_function_frozenset(self, node, function, pos_args): - if not pos_args: - pos_args = [ExprNodes.NullNode(node.pos)] - elif len(pos_args) > 1: - return node - elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none(): - return pos_args[0] - # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it) - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_PyFrozenSet_New", - self.PyFrozenSet_New_func_type, - args=pos_args, - is_temp=node.is_temp, - utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'), - py_name="frozenset") - - PyObject_AsDouble_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_double_type, [ - PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), - ], - exception_value = "((double)-1)", - exception_check = True) - - def _handle_simple_function_float(self, node, function, pos_args): - """Transform float() into either a C type cast or a faster C - function call. - """ - # Note: this requires the float() function to be typed as - # returning a C 'double' - if len(pos_args) == 0: - return ExprNodes.FloatNode( - node, value="0.0", constant_result=0.0 - ).coerce_to(Builtin.float_type, self.current_env()) - elif len(pos_args) != 1: - self._error_wrong_arg_count('float', node, pos_args, '0 or 1') - return node - func_arg = pos_args[0] - if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): - func_arg = func_arg.arg - if func_arg.type is PyrexTypes.c_double_type: - return func_arg - elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric: - return ExprNodes.TypecastNode( - node.pos, operand=func_arg, type=node.type) - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_PyObject_AsDouble", - self.PyObject_AsDouble_func_type, - args = pos_args, - is_temp = node.is_temp, - utility_code = load_c_utility('pyobject_as_double'), - py_name = "float") - - PyNumber_Int_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None) - ]) - + + PyFrozenSet_New_func_type = PyrexTypes.CFuncType( + Builtin.frozenset_type, [ + PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None) + ]) + + def _handle_simple_function_frozenset(self, node, function, pos_args): + if not pos_args: + pos_args = [ExprNodes.NullNode(node.pos)] + elif len(pos_args) > 1: + return node + elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none(): + return pos_args[0] + # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it) + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_PyFrozenSet_New", + self.PyFrozenSet_New_func_type, + args=pos_args, + is_temp=node.is_temp, + utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'), + py_name="frozenset") + + PyObject_AsDouble_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_double_type, [ + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), + ], + exception_value = "((double)-1)", + exception_check = True) + + def _handle_simple_function_float(self, node, function, pos_args): + """Transform float() into either a C type cast or a faster C + function call. + """ + # Note: this requires the float() function to be typed as + # returning a C 'double' + if len(pos_args) == 0: + return ExprNodes.FloatNode( + node, value="0.0", constant_result=0.0 + ).coerce_to(Builtin.float_type, self.current_env()) + elif len(pos_args) != 1: + self._error_wrong_arg_count('float', node, pos_args, '0 or 1') + return node + func_arg = pos_args[0] + if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): + func_arg = func_arg.arg + if func_arg.type is PyrexTypes.c_double_type: + return func_arg + elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric: + return ExprNodes.TypecastNode( + node.pos, operand=func_arg, type=node.type) + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_PyObject_AsDouble", + self.PyObject_AsDouble_func_type, + args = pos_args, + is_temp = node.is_temp, + utility_code = load_c_utility('pyobject_as_double'), + py_name = "float") + + PyNumber_Int_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None) + ]) + PyInt_FromDouble_func_type = PyrexTypes.CFuncType( PyrexTypes.py_object_type, [ PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None) ]) - def _handle_simple_function_int(self, node, function, pos_args): - """Transform int() into a faster C function call. - """ - if len(pos_args) == 0: - return ExprNodes.IntNode(node.pos, value="0", constant_result=0, - type=PyrexTypes.py_object_type) - elif len(pos_args) != 1: - return node # int(x, base) - func_arg = pos_args[0] - if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): + def _handle_simple_function_int(self, node, function, pos_args): + """Transform int() into a faster C function call. + """ + if len(pos_args) == 0: + return ExprNodes.IntNode(node.pos, value="0", constant_result=0, + type=PyrexTypes.py_object_type) + elif len(pos_args) != 1: + return node # int(x, base) + func_arg = pos_args[0] + if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): if func_arg.arg.type.is_float: return ExprNodes.PythonCapiCallNode( node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type, @@ -2532,49 +2532,49 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c")) else: return node # handled in visit_CoerceFromPyTypeNode() - if func_arg.type.is_pyobject and node.type.is_pyobject: - return ExprNodes.PythonCapiCallNode( + if func_arg.type.is_pyobject and node.type.is_pyobject: + return ExprNodes.PythonCapiCallNode( node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type, args=pos_args, is_temp=True, py_name='int') - return node - - def _handle_simple_function_bool(self, node, function, pos_args): - """Transform bool(x) into a type coercion to a boolean. - """ - if len(pos_args) == 0: - return ExprNodes.BoolNode( - node.pos, value=False, constant_result=False - ).coerce_to(Builtin.bool_type, self.current_env()) - elif len(pos_args) != 1: - self._error_wrong_arg_count('bool', node, pos_args, '0 or 1') - return node - else: - # => !!<bint>(x) to make sure it's exactly 0 or 1 - operand = pos_args[0].coerce_to_boolean(self.current_env()) - operand = ExprNodes.NotNode(node.pos, operand = operand) - operand = ExprNodes.NotNode(node.pos, operand = operand) - # coerce back to Python object as that's the result we are expecting - return operand.coerce_to_pyobject(self.current_env()) - - ### builtin functions - - Pyx_strlen_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_size_t_type, [ + return node + + def _handle_simple_function_bool(self, node, function, pos_args): + """Transform bool(x) into a type coercion to a boolean. + """ + if len(pos_args) == 0: + return ExprNodes.BoolNode( + node.pos, value=False, constant_result=False + ).coerce_to(Builtin.bool_type, self.current_env()) + elif len(pos_args) != 1: + self._error_wrong_arg_count('bool', node, pos_args, '0 or 1') + return node + else: + # => !!<bint>(x) to make sure it's exactly 0 or 1 + operand = pos_args[0].coerce_to_boolean(self.current_env()) + operand = ExprNodes.NotNode(node.pos, operand = operand) + operand = ExprNodes.NotNode(node.pos, operand = operand) + # coerce back to Python object as that's the result we are expecting + return operand.coerce_to_pyobject(self.current_env()) + + ### builtin functions + + Pyx_strlen_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_size_t_type, [ PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None) ]) - - Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_size_t_type, [ + + Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_size_t_type, [ PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None) ]) - - PyObject_Size_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_py_ssize_t_type, [ - PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) + + PyObject_Size_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) ], - exception_value="-1") - - _map_to_capi_len_function = { + exception_value="-1") + + _map_to_capi_len_function = { Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH", Builtin.bytes_type: "PyBytes_GET_SIZE", Builtin.bytearray_type: 'PyByteArray_GET_SIZE', @@ -2584,31 +2584,31 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, Builtin.frozenset_type: "PySet_GET_SIZE", Builtin.dict_type: "PyDict_Size", }.get - - _ext_types_with_pysize = set(["cpython.array.array"]) - - def _handle_simple_function_len(self, node, function, pos_args): - """Replace len(char*) by the equivalent call to strlen(), - len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and - len(known_builtin_type) by an equivalent C-API call. - """ - if len(pos_args) != 1: - self._error_wrong_arg_count('len', node, pos_args, 1) - return node - arg = pos_args[0] - if isinstance(arg, ExprNodes.CoerceToPyTypeNode): - arg = arg.arg - if arg.type.is_string: - new_node = ExprNodes.PythonCapiCallNode( - node.pos, "strlen", self.Pyx_strlen_func_type, - args = [arg], - is_temp = node.is_temp, - utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c")) - elif arg.type.is_pyunicode_ptr: - new_node = ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type, - args = [arg], - is_temp = node.is_temp) + + _ext_types_with_pysize = set(["cpython.array.array"]) + + def _handle_simple_function_len(self, node, function, pos_args): + """Replace len(char*) by the equivalent call to strlen(), + len(Py_UNICODE) by the equivalent Py_UNICODE_strlen() and + len(known_builtin_type) by an equivalent C-API call. + """ + if len(pos_args) != 1: + self._error_wrong_arg_count('len', node, pos_args, 1) + return node + arg = pos_args[0] + if isinstance(arg, ExprNodes.CoerceToPyTypeNode): + arg = arg.arg + if arg.type.is_string: + new_node = ExprNodes.PythonCapiCallNode( + node.pos, "strlen", self.Pyx_strlen_func_type, + args = [arg], + is_temp = node.is_temp, + utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c")) + elif arg.type.is_pyunicode_ptr: + new_node = ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type, + args = [arg], + is_temp = node.is_temp) elif arg.type.is_memoryviewslice: func_type = PyrexTypes.CFuncType( PyrexTypes.c_size_t_type, [ @@ -2617,261 +2617,261 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, new_node = ExprNodes.PythonCapiCallNode( node.pos, "__Pyx_MemoryView_Len", func_type, args=[arg], is_temp=node.is_temp) - elif arg.type.is_pyobject: - cfunc_name = self._map_to_capi_len_function(arg.type) - if cfunc_name is None: - arg_type = arg.type - if ((arg_type.is_extension_type or arg_type.is_builtin_type) - and arg_type.entry.qualified_name in self._ext_types_with_pysize): - cfunc_name = 'Py_SIZE' - else: - return node - arg = arg.as_none_safe_node( - "object of type 'NoneType' has no len()") - new_node = ExprNodes.PythonCapiCallNode( - node.pos, cfunc_name, self.PyObject_Size_func_type, + elif arg.type.is_pyobject: + cfunc_name = self._map_to_capi_len_function(arg.type) + if cfunc_name is None: + arg_type = arg.type + if ((arg_type.is_extension_type or arg_type.is_builtin_type) + and arg_type.entry.qualified_name in self._ext_types_with_pysize): + cfunc_name = 'Py_SIZE' + else: + return node + arg = arg.as_none_safe_node( + "object of type 'NoneType' has no len()") + new_node = ExprNodes.PythonCapiCallNode( + node.pos, cfunc_name, self.PyObject_Size_func_type, args=[arg], is_temp=node.is_temp) - elif arg.type.is_unicode_char: - return ExprNodes.IntNode(node.pos, value='1', constant_result=1, - type=node.type) - else: - return node - if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type): - new_node = new_node.coerce_to(node.type, self.current_env()) - return new_node - - Pyx_Type_func_type = PyrexTypes.CFuncType( - Builtin.type_type, [ - PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None) - ]) - - def _handle_simple_function_type(self, node, function, pos_args): - """Replace type(o) by a macro call to Py_TYPE(o). - """ - if len(pos_args) != 1: - return node - node = ExprNodes.PythonCapiCallNode( - node.pos, "Py_TYPE", self.Pyx_Type_func_type, - args = pos_args, - is_temp = False) - return ExprNodes.CastNode(node, PyrexTypes.py_object_type) - - Py_type_check_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_bint_type, [ - PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None) - ]) - - def _handle_simple_function_isinstance(self, node, function, pos_args): - """Replace isinstance() checks against builtin types by the - corresponding C-API call. - """ - if len(pos_args) != 2: - return node - arg, types = pos_args + elif arg.type.is_unicode_char: + return ExprNodes.IntNode(node.pos, value='1', constant_result=1, + type=node.type) + else: + return node + if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type): + new_node = new_node.coerce_to(node.type, self.current_env()) + return new_node + + Pyx_Type_func_type = PyrexTypes.CFuncType( + Builtin.type_type, [ + PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None) + ]) + + def _handle_simple_function_type(self, node, function, pos_args): + """Replace type(o) by a macro call to Py_TYPE(o). + """ + if len(pos_args) != 1: + return node + node = ExprNodes.PythonCapiCallNode( + node.pos, "Py_TYPE", self.Pyx_Type_func_type, + args = pos_args, + is_temp = False) + return ExprNodes.CastNode(node, PyrexTypes.py_object_type) + + Py_type_check_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, [ + PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None) + ]) + + def _handle_simple_function_isinstance(self, node, function, pos_args): + """Replace isinstance() checks against builtin types by the + corresponding C-API call. + """ + if len(pos_args) != 2: + return node + arg, types = pos_args temps = [] - if isinstance(types, ExprNodes.TupleNode): - types = types.args + if isinstance(types, ExprNodes.TupleNode): + types = types.args if len(types) == 1 and not types[0].type is Builtin.type_type: return node # nothing to improve here - if arg.is_attribute or not arg.is_simple(): + if arg.is_attribute or not arg.is_simple(): arg = UtilNodes.ResultRefNode(arg) temps.append(arg) - elif types.type is Builtin.type_type: - types = [types] - else: - return node - - tests = [] - test_nodes = [] - env = self.current_env() - for test_type_node in types: - builtin_type = None - if test_type_node.is_name: - if test_type_node.entry: - entry = env.lookup(test_type_node.entry.name) - if entry and entry.type and entry.type.is_builtin_type: - builtin_type = entry.type - if builtin_type is Builtin.type_type: - # all types have type "type", but there's only one 'type' - if entry.name != 'type' or not ( - entry.scope and entry.scope.is_builtin_scope): - builtin_type = None - if builtin_type is not None: - type_check_function = entry.type.type_check_function(exact=False) - if type_check_function in tests: - continue - tests.append(type_check_function) - type_check_args = [arg] - elif test_type_node.type is Builtin.type_type: - type_check_function = '__Pyx_TypeCheck' - type_check_args = [arg, test_type_node] - else: + elif types.type is Builtin.type_type: + types = [types] + else: + return node + + tests = [] + test_nodes = [] + env = self.current_env() + for test_type_node in types: + builtin_type = None + if test_type_node.is_name: + if test_type_node.entry: + entry = env.lookup(test_type_node.entry.name) + if entry and entry.type and entry.type.is_builtin_type: + builtin_type = entry.type + if builtin_type is Builtin.type_type: + # all types have type "type", but there's only one 'type' + if entry.name != 'type' or not ( + entry.scope and entry.scope.is_builtin_scope): + builtin_type = None + if builtin_type is not None: + type_check_function = entry.type.type_check_function(exact=False) + if type_check_function in tests: + continue + tests.append(type_check_function) + type_check_args = [arg] + elif test_type_node.type is Builtin.type_type: + type_check_function = '__Pyx_TypeCheck' + type_check_args = [arg, test_type_node] + else: if not test_type_node.is_literal: test_type_node = UtilNodes.ResultRefNode(test_type_node) temps.append(test_type_node) type_check_function = 'PyObject_IsInstance' type_check_args = [arg, test_type_node] - test_nodes.append( - ExprNodes.PythonCapiCallNode( - test_type_node.pos, type_check_function, self.Py_type_check_func_type, + test_nodes.append( + ExprNodes.PythonCapiCallNode( + test_type_node.pos, type_check_function, self.Py_type_check_func_type, args=type_check_args, is_temp=True, )) - - def join_with_or(a, b, make_binop_node=ExprNodes.binop_node): - or_node = make_binop_node(node.pos, 'or', a, b) - or_node.type = PyrexTypes.c_bint_type - or_node.wrap_operands(env) - return or_node - - test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env) + + def join_with_or(a, b, make_binop_node=ExprNodes.binop_node): + or_node = make_binop_node(node.pos, 'or', a, b) + or_node.type = PyrexTypes.c_bint_type + or_node.wrap_operands(env) + return or_node + + test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env) for temp in temps[::-1]: - test_node = UtilNodes.EvalWithTempExprNode(temp, test_node) - return test_node - - def _handle_simple_function_ord(self, node, function, pos_args): - """Unpack ord(Py_UNICODE) and ord('X'). - """ - if len(pos_args) != 1: - return node - arg = pos_args[0] - if isinstance(arg, ExprNodes.CoerceToPyTypeNode): - if arg.arg.type.is_unicode_char: - return ExprNodes.TypecastNode( + test_node = UtilNodes.EvalWithTempExprNode(temp, test_node) + return test_node + + def _handle_simple_function_ord(self, node, function, pos_args): + """Unpack ord(Py_UNICODE) and ord('X'). + """ + if len(pos_args) != 1: + return node + arg = pos_args[0] + if isinstance(arg, ExprNodes.CoerceToPyTypeNode): + if arg.arg.type.is_unicode_char: + return ExprNodes.TypecastNode( arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type - ).coerce_to(node.type, self.current_env()) - elif isinstance(arg, ExprNodes.UnicodeNode): - if len(arg.value) == 1: - return ExprNodes.IntNode( - arg.pos, type=PyrexTypes.c_int_type, - value=str(ord(arg.value)), - constant_result=ord(arg.value) - ).coerce_to(node.type, self.current_env()) - elif isinstance(arg, ExprNodes.StringNode): - if arg.unicode_value and len(arg.unicode_value) == 1 \ - and ord(arg.unicode_value) <= 255: # Py2/3 portability - return ExprNodes.IntNode( - arg.pos, type=PyrexTypes.c_int_type, - value=str(ord(arg.unicode_value)), - constant_result=ord(arg.unicode_value) - ).coerce_to(node.type, self.current_env()) - return node - - ### special methods - - Pyx_tp_new_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None), - ]) - - Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None), - PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None), - ]) - - def _handle_any_slot__new__(self, node, function, args, - is_unbound_method, kwargs=None): - """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new() - """ - obj = function.obj - if not is_unbound_method or len(args) < 1: - return node - type_arg = args[0] - if not obj.is_name or not type_arg.is_name: - # play safe - return node - if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type: - # not a known type, play safe - return node - if not type_arg.type_entry or not obj.type_entry: - if obj.name != type_arg.name: - return node - # otherwise, we know it's a type and we know it's the same - # type for both - that should do - elif type_arg.type_entry != obj.type_entry: - # different types - may or may not lead to an error at runtime - return node - - args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:]) - args_tuple = args_tuple.analyse_types( - self.current_env(), skip_children=True) - - if type_arg.type_entry: - ext_type = type_arg.type_entry.type - if (ext_type.is_extension_type and ext_type.typeobj_cname and - ext_type.scope.global_scope() == self.current_env().global_scope()): - # known type in current module - tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__') - slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot) - if slot_func_cname: - cython_scope = self.context.cython_scope - PyTypeObjectPtr = PyrexTypes.CPtrType( - cython_scope.lookup('PyTypeObject').type) - pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType( + ).coerce_to(node.type, self.current_env()) + elif isinstance(arg, ExprNodes.UnicodeNode): + if len(arg.value) == 1: + return ExprNodes.IntNode( + arg.pos, type=PyrexTypes.c_int_type, + value=str(ord(arg.value)), + constant_result=ord(arg.value) + ).coerce_to(node.type, self.current_env()) + elif isinstance(arg, ExprNodes.StringNode): + if arg.unicode_value and len(arg.unicode_value) == 1 \ + and ord(arg.unicode_value) <= 255: # Py2/3 portability + return ExprNodes.IntNode( + arg.pos, type=PyrexTypes.c_int_type, + value=str(ord(arg.unicode_value)), + constant_result=ord(arg.unicode_value) + ).coerce_to(node.type, self.current_env()) + return node + + ### special methods + + Pyx_tp_new_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None), + ]) + + Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None), + PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None), + ]) + + def _handle_any_slot__new__(self, node, function, args, + is_unbound_method, kwargs=None): + """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new() + """ + obj = function.obj + if not is_unbound_method or len(args) < 1: + return node + type_arg = args[0] + if not obj.is_name or not type_arg.is_name: + # play safe + return node + if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type: + # not a known type, play safe + return node + if not type_arg.type_entry or not obj.type_entry: + if obj.name != type_arg.name: + return node + # otherwise, we know it's a type and we know it's the same + # type for both - that should do + elif type_arg.type_entry != obj.type_entry: + # different types - may or may not lead to an error at runtime + return node + + args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:]) + args_tuple = args_tuple.analyse_types( + self.current_env(), skip_children=True) + + if type_arg.type_entry: + ext_type = type_arg.type_entry.type + if (ext_type.is_extension_type and ext_type.typeobj_cname and + ext_type.scope.global_scope() == self.current_env().global_scope()): + # known type in current module + tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__') + slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot) + if slot_func_cname: + cython_scope = self.context.cython_scope + PyTypeObjectPtr = PyrexTypes.CPtrType( + cython_scope.lookup('PyTypeObject').type) + pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType( ext_type, [ - PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None), - PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None), - ]) - - type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr) - if not kwargs: - kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type) # hack? - return ExprNodes.PythonCapiCallNode( - node.pos, slot_func_cname, - pyx_tp_new_kwargs_func_type, - args=[type_arg, args_tuple, kwargs], + PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None), + PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None), + ]) + + type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr) + if not kwargs: + kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type) # hack? + return ExprNodes.PythonCapiCallNode( + node.pos, slot_func_cname, + pyx_tp_new_kwargs_func_type, + args=[type_arg, args_tuple, kwargs], may_return_none=False, - is_temp=True) - else: - # arbitrary variable, needs a None check for safety - type_arg = type_arg.as_none_safe_node( - "object.__new__(X): X is not a type object (NoneType)") - - utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c') - if kwargs: - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type, - args=[type_arg, args_tuple, kwargs], - utility_code=utility_code, - is_temp=node.is_temp - ) - else: - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type, - args=[type_arg, args_tuple], - utility_code=utility_code, - is_temp=node.is_temp - ) - - ### methods of builtin types - - PyObject_Append_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_returncode_type, [ - PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None), - ], - exception_value="-1") - - def _handle_simple_method_object_append(self, node, function, args, is_unbound_method): - """Optimistic optimisation as X.append() is almost always - referring to a list. - """ - if len(args) != 2 or node.result_is_used: - return node - - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type, - args=args, - may_return_none=False, - is_temp=node.is_temp, - result_is_used=False, - utility_code=load_c_utility('append') - ) - + is_temp=True) + else: + # arbitrary variable, needs a None check for safety + type_arg = type_arg.as_none_safe_node( + "object.__new__(X): X is not a type object (NoneType)") + + utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c') + if kwargs: + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type, + args=[type_arg, args_tuple, kwargs], + utility_code=utility_code, + is_temp=node.is_temp + ) + else: + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type, + args=[type_arg, args_tuple], + utility_code=utility_code, + is_temp=node.is_temp + ) + + ### methods of builtin types + + PyObject_Append_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_returncode_type, [ + PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None), + ], + exception_value="-1") + + def _handle_simple_method_object_append(self, node, function, args, is_unbound_method): + """Optimistic optimisation as X.append() is almost always + referring to a list. + """ + if len(args) != 2 or node.result_is_used: + return node + + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type, + args=args, + may_return_none=False, + is_temp=node.is_temp, + result_is_used=False, + utility_code=load_c_utility('append') + ) + def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method): """Replace list.extend([...]) for short sequence literals values by sequential appends to avoid creating an intermediate sequence argument. @@ -2935,211 +2935,211 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, new_node.result_is_used = node.result_is_used return new_node - PyByteArray_Append_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_returncode_type, [ - PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None), - ], - exception_value="-1") - - PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_returncode_type, [ - PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None), - ], - exception_value="-1") - - def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method): - if len(args) != 2: - return node - func_name = "__Pyx_PyByteArray_Append" - func_type = self.PyByteArray_Append_func_type - - value = unwrap_coerced_node(args[1]) - if value.type.is_int or isinstance(value, ExprNodes.IntNode): - value = value.coerce_to(PyrexTypes.c_int_type, self.current_env()) - utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") - elif value.is_string_literal: - if not value.can_coerce_to_char_literal(): - return node - value = value.coerce_to(PyrexTypes.c_char_type, self.current_env()) - utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") - elif value.type.is_pyobject: - func_name = "__Pyx_PyByteArray_AppendObject" - func_type = self.PyByteArray_AppendObject_func_type - utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c") - else: - return node - - new_node = ExprNodes.PythonCapiCallNode( - node.pos, func_name, func_type, - args=[args[0], value], - may_return_none=False, - is_temp=node.is_temp, - utility_code=utility_code, - ) - if node.result_is_used: - new_node = new_node.coerce_to(node.type, self.current_env()) - return new_node - - PyObject_Pop_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), - ]) - - PyObject_PopIndex_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), + PyByteArray_Append_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_returncode_type, [ + PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None), + ], + exception_value="-1") + + PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_returncode_type, [ + PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None), + ], + exception_value="-1") + + def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method): + if len(args) != 2: + return node + func_name = "__Pyx_PyByteArray_Append" + func_type = self.PyByteArray_Append_func_type + + value = unwrap_coerced_node(args[1]) + if value.type.is_int or isinstance(value, ExprNodes.IntNode): + value = value.coerce_to(PyrexTypes.c_int_type, self.current_env()) + utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") + elif value.is_string_literal: + if not value.can_coerce_to_char_literal(): + return node + value = value.coerce_to(PyrexTypes.c_char_type, self.current_env()) + utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c") + elif value.type.is_pyobject: + func_name = "__Pyx_PyByteArray_AppendObject" + func_type = self.PyByteArray_AppendObject_func_type + utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c") + else: + return node + + new_node = ExprNodes.PythonCapiCallNode( + node.pos, func_name, func_type, + args=[args[0], value], + may_return_none=False, + is_temp=node.is_temp, + utility_code=utility_code, + ) + if node.result_is_used: + new_node = new_node.coerce_to(node.type, self.current_env()) + return new_node + + PyObject_Pop_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), + ]) + + PyObject_PopIndex_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None), - ], - has_varargs=True) # to fake the additional macro args that lack a proper C type - - def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method): - return self._handle_simple_method_object_pop( - node, function, args, is_unbound_method, is_list=True) - - def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False): - """Optimistic optimisation as X.pop([n]) is almost always - referring to a list. - """ - if not args: - return node - obj = args[0] - if is_list: - type_name = 'List' - obj = obj.as_none_safe_node( + PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None), + ], + has_varargs=True) # to fake the additional macro args that lack a proper C type + + def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method): + return self._handle_simple_method_object_pop( + node, function, args, is_unbound_method, is_list=True) + + def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False): + """Optimistic optimisation as X.pop([n]) is almost always + referring to a list. + """ + if not args: + return node + obj = args[0] + if is_list: + type_name = 'List' + obj = obj.as_none_safe_node( "'NoneType' object has no attribute '%.30s'", - error="PyExc_AttributeError", - format_args=['pop']) - else: - type_name = 'Object' - if len(args) == 1: - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_Py%s_Pop" % type_name, - self.PyObject_Pop_func_type, - args=[obj], - may_return_none=True, - is_temp=node.is_temp, - utility_code=load_c_utility('pop'), - ) - elif len(args) == 2: - index = unwrap_coerced_node(args[1]) + error="PyExc_AttributeError", + format_args=['pop']) + else: + type_name = 'Object' + if len(args) == 1: + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_Py%s_Pop" % type_name, + self.PyObject_Pop_func_type, + args=[obj], + may_return_none=True, + is_temp=node.is_temp, + utility_code=load_c_utility('pop'), + ) + elif len(args) == 2: + index = unwrap_coerced_node(args[1]) py_index = ExprNodes.NoneNode(index.pos) - orig_index_type = index.type - if not index.type.is_int: + orig_index_type = index.type + if not index.type.is_int: if isinstance(index, ExprNodes.IntNode): py_index = index.coerce_to_pyobject(self.current_env()) - index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) elif is_list: if index.type.is_pyobject: py_index = index.coerce_to_simple(self.current_env()) index = ExprNodes.CloneNode(py_index) index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - else: - return node - elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type): - return node + else: + return node + elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type): + return node elif isinstance(index, ExprNodes.IntNode): py_index = index.coerce_to_pyobject(self.current_env()) - # real type might still be larger at runtime - if not orig_index_type.is_int: - orig_index_type = index.type - if not orig_index_type.create_to_py_utility_code(self.current_env()): - return node - convert_func = orig_index_type.to_py_function - conversion_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)]) - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_Py%s_PopIndex" % type_name, - self.PyObject_PopIndex_func_type, + # real type might still be larger at runtime + if not orig_index_type.is_int: + orig_index_type = index.type + if not orig_index_type.create_to_py_utility_code(self.current_env()): + return node + convert_func = orig_index_type.to_py_function + conversion_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)]) + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_Py%s_PopIndex" % type_name, + self.PyObject_PopIndex_func_type, args=[obj, py_index, index, - ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0), - constant_result=orig_index_type.signed and 1 or 0, - type=PyrexTypes.c_int_type), - ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type, + ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0), + constant_result=orig_index_type.signed and 1 or 0, + type=PyrexTypes.c_int_type), + ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type, orig_index_type.empty_declaration_code()), - ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)], - may_return_none=True, - is_temp=node.is_temp, - utility_code=load_c_utility("pop_index"), - ) - - return node - - single_param_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_returncode_type, [ - PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), - ], - exception_value = "-1") - - def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method): - """Call PyList_Sort() instead of the 0-argument l.sort(). - """ - if len(args) != 1: - return node - return self._substitute_method_call( - node, function, "PyList_Sort", self.single_param_func_type, - 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env) - - Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), - ]) - - def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method): - """Replace dict.get() by a call to PyDict_GetItem(). - """ - if len(args) == 2: - args.append(ExprNodes.NoneNode(node.pos)) - elif len(args) != 3: - self._error_wrong_arg_count('dict.get', node, args, "2 or 3") - return node - - return self._substitute_method_call( - node, function, - "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type, - 'get', is_unbound_method, args, - may_return_none = True, - utility_code = load_c_utility("dict_getitem_default")) - - Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None), - ]) - - def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method): - """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem(). - """ - if len(args) == 2: - args.append(ExprNodes.NoneNode(node.pos)) - elif len(args) != 3: - self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3") - return node - key_type = args[1].type - if key_type.is_builtin_type: - is_safe_type = int(key_type.name in - 'str bytes unicode float int long bool') - elif key_type is PyrexTypes.py_object_type: - is_safe_type = -1 # don't know - else: - is_safe_type = 0 # definitely not - args.append(ExprNodes.IntNode( - node.pos, value=str(is_safe_type), constant_result=is_safe_type)) - - return self._substitute_method_call( - node, function, - "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type, - 'setdefault', is_unbound_method, args, - may_return_none=True, - utility_code=load_c_utility('dict_setdefault')) - + ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)], + may_return_none=True, + is_temp=node.is_temp, + utility_code=load_c_utility("pop_index"), + ) + + return node + + single_param_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_returncode_type, [ + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), + ], + exception_value = "-1") + + def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method): + """Call PyList_Sort() instead of the 0-argument l.sort(). + """ + if len(args) != 1: + return node + return self._substitute_method_call( + node, function, "PyList_Sort", self.single_param_func_type, + 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env) + + Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), + ]) + + def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method): + """Replace dict.get() by a call to PyDict_GetItem(). + """ + if len(args) == 2: + args.append(ExprNodes.NoneNode(node.pos)) + elif len(args) != 3: + self._error_wrong_arg_count('dict.get', node, args, "2 or 3") + return node + + return self._substitute_method_call( + node, function, + "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type, + 'get', is_unbound_method, args, + may_return_none = True, + utility_code = load_c_utility("dict_getitem_default")) + + Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None), + ]) + + def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method): + """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem(). + """ + if len(args) == 2: + args.append(ExprNodes.NoneNode(node.pos)) + elif len(args) != 3: + self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3") + return node + key_type = args[1].type + if key_type.is_builtin_type: + is_safe_type = int(key_type.name in + 'str bytes unicode float int long bool') + elif key_type is PyrexTypes.py_object_type: + is_safe_type = -1 # don't know + else: + is_safe_type = 0 # definitely not + args.append(ExprNodes.IntNode( + node.pos, value=str(is_safe_type), constant_result=is_safe_type)) + + return self._substitute_method_call( + node, function, + "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type, + 'setdefault', is_unbound_method, args, + may_return_none=True, + utility_code=load_c_utility('dict_setdefault')) + PyDict_Pop_func_type = PyrexTypes.CFuncType( PyrexTypes.py_object_type, [ PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), @@ -3175,7 +3175,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type) for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type) ) - + def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method): return self._optimise_num_binop('Add', node, function, args, is_unbound_method) @@ -3341,121 +3341,121 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type) return call_node - ### unicode type methods - - PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_bint_type, [ - PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), - ]) - - def _inject_unicode_predicate(self, node, function, args, is_unbound_method): - if is_unbound_method or len(args) != 1: - return node - ustring = args[0] - if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \ - not ustring.arg.type.is_unicode_char: - return node - uchar = ustring.arg - method_name = function.attribute - if method_name == 'istitle': - # istitle() doesn't directly map to Py_UNICODE_ISTITLE() - utility_code = UtilityCode.load_cached( - "py_unicode_istitle", "StringTools.c") - function_name = '__Pyx_Py_UNICODE_ISTITLE' - else: - utility_code = None - function_name = 'Py_UNICODE_%s' % method_name.upper() - func_call = self._substitute_method_call( - node, function, - function_name, self.PyUnicode_uchar_predicate_func_type, - method_name, is_unbound_method, [uchar], - utility_code = utility_code) - if node.type.is_pyobject: - func_call = func_call.coerce_to_pyobject(self.current_env) - return func_call - - _handle_simple_method_unicode_isalnum = _inject_unicode_predicate - _handle_simple_method_unicode_isalpha = _inject_unicode_predicate - _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate - _handle_simple_method_unicode_isdigit = _inject_unicode_predicate - _handle_simple_method_unicode_islower = _inject_unicode_predicate - _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate - _handle_simple_method_unicode_isspace = _inject_unicode_predicate - _handle_simple_method_unicode_istitle = _inject_unicode_predicate - _handle_simple_method_unicode_isupper = _inject_unicode_predicate - - PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_py_ucs4_type, [ - PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), - ]) - - def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method): - if is_unbound_method or len(args) != 1: - return node - ustring = args[0] - if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \ - not ustring.arg.type.is_unicode_char: - return node - uchar = ustring.arg - method_name = function.attribute - function_name = 'Py_UNICODE_TO%s' % method_name.upper() - func_call = self._substitute_method_call( - node, function, - function_name, self.PyUnicode_uchar_conversion_func_type, - method_name, is_unbound_method, [uchar]) - if node.type.is_pyobject: - func_call = func_call.coerce_to_pyobject(self.current_env) - return func_call - - _handle_simple_method_unicode_lower = _inject_unicode_character_conversion - _handle_simple_method_unicode_upper = _inject_unicode_character_conversion - _handle_simple_method_unicode_title = _inject_unicode_character_conversion - - PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType( - Builtin.list_type, [ - PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None), - ]) - - def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method): - """Replace unicode.splitlines(...) by a direct call to the - corresponding C-API function. - """ - if len(args) not in (1,2): - self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2") - return node - self._inject_bint_default_argument(node, args, 1, False) - - return self._substitute_method_call( - node, function, - "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type, - 'splitlines', is_unbound_method, args) - - PyUnicode_Split_func_type = PyrexTypes.CFuncType( - Builtin.list_type, [ - PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None), - ] - ) - - def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method): - """Replace unicode.split(...) by a direct call to the - corresponding C-API function. - """ - if len(args) not in (1,2,3): - self._error_wrong_arg_count('unicode.split', node, args, "1-3") - return node - if len(args) < 2: - args.append(ExprNodes.NullNode(node.pos)) - self._inject_int_default_argument( - node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1") - - return self._substitute_method_call( - node, function, - "PyUnicode_Split", self.PyUnicode_Split_func_type, - 'split', is_unbound_method, args) - + ### unicode type methods + + PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, [ + PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), + ]) + + def _inject_unicode_predicate(self, node, function, args, is_unbound_method): + if is_unbound_method or len(args) != 1: + return node + ustring = args[0] + if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \ + not ustring.arg.type.is_unicode_char: + return node + uchar = ustring.arg + method_name = function.attribute + if method_name == 'istitle': + # istitle() doesn't directly map to Py_UNICODE_ISTITLE() + utility_code = UtilityCode.load_cached( + "py_unicode_istitle", "StringTools.c") + function_name = '__Pyx_Py_UNICODE_ISTITLE' + else: + utility_code = None + function_name = 'Py_UNICODE_%s' % method_name.upper() + func_call = self._substitute_method_call( + node, function, + function_name, self.PyUnicode_uchar_predicate_func_type, + method_name, is_unbound_method, [uchar], + utility_code = utility_code) + if node.type.is_pyobject: + func_call = func_call.coerce_to_pyobject(self.current_env) + return func_call + + _handle_simple_method_unicode_isalnum = _inject_unicode_predicate + _handle_simple_method_unicode_isalpha = _inject_unicode_predicate + _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate + _handle_simple_method_unicode_isdigit = _inject_unicode_predicate + _handle_simple_method_unicode_islower = _inject_unicode_predicate + _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate + _handle_simple_method_unicode_isspace = _inject_unicode_predicate + _handle_simple_method_unicode_istitle = _inject_unicode_predicate + _handle_simple_method_unicode_isupper = _inject_unicode_predicate + + PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ucs4_type, [ + PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None), + ]) + + def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method): + if is_unbound_method or len(args) != 1: + return node + ustring = args[0] + if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \ + not ustring.arg.type.is_unicode_char: + return node + uchar = ustring.arg + method_name = function.attribute + function_name = 'Py_UNICODE_TO%s' % method_name.upper() + func_call = self._substitute_method_call( + node, function, + function_name, self.PyUnicode_uchar_conversion_func_type, + method_name, is_unbound_method, [uchar]) + if node.type.is_pyobject: + func_call = func_call.coerce_to_pyobject(self.current_env) + return func_call + + _handle_simple_method_unicode_lower = _inject_unicode_character_conversion + _handle_simple_method_unicode_upper = _inject_unicode_character_conversion + _handle_simple_method_unicode_title = _inject_unicode_character_conversion + + PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType( + Builtin.list_type, [ + PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), + PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None), + ]) + + def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method): + """Replace unicode.splitlines(...) by a direct call to the + corresponding C-API function. + """ + if len(args) not in (1,2): + self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2") + return node + self._inject_bint_default_argument(node, args, 1, False) + + return self._substitute_method_call( + node, function, + "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type, + 'splitlines', is_unbound_method, args) + + PyUnicode_Split_func_type = PyrexTypes.CFuncType( + Builtin.list_type, [ + PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), + PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None), + ] + ) + + def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method): + """Replace unicode.split(...) by a direct call to the + corresponding C-API function. + """ + if len(args) not in (1,2,3): + self._error_wrong_arg_count('unicode.split', node, args, "1-3") + return node + if len(args) < 2: + args.append(ExprNodes.NullNode(node.pos)) + self._inject_int_default_argument( + node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1") + + return self._substitute_method_call( + node, function, + "PyUnicode_Split", self.PyUnicode_Split_func_type, + 'split', is_unbound_method, args) + PyUnicode_Join_func_type = PyrexTypes.CFuncType( Builtin.unicode_type, [ PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), @@ -3494,461 +3494,461 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, "PyUnicode_Join", self.PyUnicode_Join_func_type, 'join', is_unbound_method, args) - PyString_Tailmatch_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_bint_type, [ - PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode - PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), - ], - exception_value = '-1') - - def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'unicode', 'endswith', - unicode_tailmatch_utility_code, +1) - - def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'unicode', 'startswith', - unicode_tailmatch_utility_code, -1) - - def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name, - method_name, utility_code, direction): - """Replace unicode.startswith(...) and unicode.endswith(...) - by a direct call to the corresponding C-API function. - """ - if len(args) not in (2,3,4): - self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4") - return node - self._inject_int_default_argument( - node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") - self._inject_int_default_argument( - node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") - args.append(ExprNodes.IntNode( - node.pos, value=str(direction), type=PyrexTypes.c_int_type)) - - method_call = self._substitute_method_call( - node, function, - "__Pyx_Py%s_Tailmatch" % type_name.capitalize(), - self.PyString_Tailmatch_func_type, - method_name, is_unbound_method, args, - utility_code = utility_code) - return method_call.coerce_to(Builtin.bool_type, self.current_env()) - - PyUnicode_Find_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_py_ssize_t_type, [ - PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), - ], - exception_value = '-2') - - def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method): - return self._inject_unicode_find( - node, function, args, is_unbound_method, 'find', +1) - - def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method): - return self._inject_unicode_find( - node, function, args, is_unbound_method, 'rfind', -1) - - def _inject_unicode_find(self, node, function, args, is_unbound_method, - method_name, direction): - """Replace unicode.find(...) and unicode.rfind(...) by a - direct call to the corresponding C-API function. - """ - if len(args) not in (2,3,4): - self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4") - return node - self._inject_int_default_argument( - node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") - self._inject_int_default_argument( - node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") - args.append(ExprNodes.IntNode( - node.pos, value=str(direction), type=PyrexTypes.c_int_type)) - - method_call = self._substitute_method_call( - node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type, - method_name, is_unbound_method, args) - return method_call.coerce_to_pyobject(self.current_env()) - - PyUnicode_Count_func_type = PyrexTypes.CFuncType( - PyrexTypes.c_py_ssize_t_type, [ - PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), - ], - exception_value = '-1') - - def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method): - """Replace unicode.count(...) by a direct call to the - corresponding C-API function. - """ - if len(args) not in (2,3,4): - self._error_wrong_arg_count('unicode.count', node, args, "2-4") - return node - self._inject_int_default_argument( - node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") - self._inject_int_default_argument( - node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") - - method_call = self._substitute_method_call( - node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type, - 'count', is_unbound_method, args) - return method_call.coerce_to_pyobject(self.current_env()) - - PyUnicode_Replace_func_type = PyrexTypes.CFuncType( - Builtin.unicode_type, [ - PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None), - ]) - - def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method): - """Replace unicode.replace(...) by a direct call to the - corresponding C-API function. - """ - if len(args) not in (3,4): - self._error_wrong_arg_count('unicode.replace', node, args, "3-4") - return node - self._inject_int_default_argument( - node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1") - - return self._substitute_method_call( - node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type, - 'replace', is_unbound_method, args) - - PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType( - Builtin.bytes_type, [ - PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), + PyString_Tailmatch_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, [ + PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode + PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), + ], + exception_value = '-1') + + def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'unicode', 'endswith', + unicode_tailmatch_utility_code, +1) + + def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'unicode', 'startswith', + unicode_tailmatch_utility_code, -1) + + def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name, + method_name, utility_code, direction): + """Replace unicode.startswith(...) and unicode.endswith(...) + by a direct call to the corresponding C-API function. + """ + if len(args) not in (2,3,4): + self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4") + return node + self._inject_int_default_argument( + node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") + self._inject_int_default_argument( + node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") + args.append(ExprNodes.IntNode( + node.pos, value=str(direction), type=PyrexTypes.c_int_type)) + + method_call = self._substitute_method_call( + node, function, + "__Pyx_Py%s_Tailmatch" % type_name.capitalize(), + self.PyString_Tailmatch_func_type, + method_name, is_unbound_method, args, + utility_code = utility_code) + return method_call.coerce_to(Builtin.bool_type, self.current_env()) + + PyUnicode_Find_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), + PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None), + ], + exception_value = '-2') + + def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method): + return self._inject_unicode_find( + node, function, args, is_unbound_method, 'find', +1) + + def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method): + return self._inject_unicode_find( + node, function, args, is_unbound_method, 'rfind', -1) + + def _inject_unicode_find(self, node, function, args, is_unbound_method, + method_name, direction): + """Replace unicode.find(...) and unicode.rfind(...) by a + direct call to the corresponding C-API function. + """ + if len(args) not in (2,3,4): + self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4") + return node + self._inject_int_default_argument( + node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") + self._inject_int_default_argument( + node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") + args.append(ExprNodes.IntNode( + node.pos, value=str(direction), type=PyrexTypes.c_int_type)) + + method_call = self._substitute_method_call( + node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type, + method_name, is_unbound_method, args) + return method_call.coerce_to_pyobject(self.current_env()) + + PyUnicode_Count_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), + PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None), + ], + exception_value = '-1') + + def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method): + """Replace unicode.count(...) by a direct call to the + corresponding C-API function. + """ + if len(args) not in (2,3,4): + self._error_wrong_arg_count('unicode.count', node, args, "2-4") + return node + self._inject_int_default_argument( + node, args, 2, PyrexTypes.c_py_ssize_t_type, "0") + self._inject_int_default_argument( + node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX") + + method_call = self._substitute_method_call( + node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type, + 'count', is_unbound_method, args) + return method_call.coerce_to_pyobject(self.current_env()) + + PyUnicode_Replace_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ + PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), + PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None), + ]) + + def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method): + """Replace unicode.replace(...) by a direct call to the + corresponding C-API function. + """ + if len(args) not in (3,4): + self._error_wrong_arg_count('unicode.replace', node, args, "3-4") + return node + self._inject_int_default_argument( + node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1") + + return self._substitute_method_call( + node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type, + 'replace', is_unbound_method, args) + + PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType( + Builtin.bytes_type, [ + PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), - ]) - - PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType( - Builtin.bytes_type, [ - PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), - ]) - + ]) + + PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType( + Builtin.bytes_type, [ + PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), + ]) + _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII', - 'unicode_escape', 'raw_unicode_escape'] - - _special_codecs = [ (name, codecs.getencoder(name)) - for name in _special_encodings ] - - def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method): - """Replace unicode.encode(...) by a direct C-API call to the - corresponding codec. - """ - if len(args) < 1 or len(args) > 3: - self._error_wrong_arg_count('unicode.encode', node, args, '1-3') - return node - - string_node = args[0] - - if len(args) == 1: - null_node = ExprNodes.NullNode(node.pos) - return self._substitute_method_call( - node, function, "PyUnicode_AsEncodedString", - self.PyUnicode_AsEncodedString_func_type, - 'encode', is_unbound_method, [string_node, null_node, null_node]) - - parameters = self._unpack_encoding_and_error_mode(node.pos, args) - if parameters is None: - return node - encoding, encoding_node, error_handling, error_handling_node = parameters - - if encoding and isinstance(string_node, ExprNodes.UnicodeNode): - # constant, so try to do the encoding at compile time - try: - value = string_node.value.encode(encoding, error_handling) - except: - # well, looks like we can't - pass - else: + 'unicode_escape', 'raw_unicode_escape'] + + _special_codecs = [ (name, codecs.getencoder(name)) + for name in _special_encodings ] + + def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method): + """Replace unicode.encode(...) by a direct C-API call to the + corresponding codec. + """ + if len(args) < 1 or len(args) > 3: + self._error_wrong_arg_count('unicode.encode', node, args, '1-3') + return node + + string_node = args[0] + + if len(args) == 1: + null_node = ExprNodes.NullNode(node.pos) + return self._substitute_method_call( + node, function, "PyUnicode_AsEncodedString", + self.PyUnicode_AsEncodedString_func_type, + 'encode', is_unbound_method, [string_node, null_node, null_node]) + + parameters = self._unpack_encoding_and_error_mode(node.pos, args) + if parameters is None: + return node + encoding, encoding_node, error_handling, error_handling_node = parameters + + if encoding and isinstance(string_node, ExprNodes.UnicodeNode): + # constant, so try to do the encoding at compile time + try: + value = string_node.value.encode(encoding, error_handling) + except: + # well, looks like we can't + pass + else: value = bytes_literal(value, encoding) return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type) - - if encoding and error_handling == 'strict': - # try to find a specific encoder function - codec_name = self._find_special_codec_name(encoding) + + if encoding and error_handling == 'strict': + # try to find a specific encoder function + codec_name = self._find_special_codec_name(encoding) if codec_name is not None and '-' not in codec_name: - encode_function = "PyUnicode_As%sString" % codec_name - return self._substitute_method_call( - node, function, encode_function, - self.PyUnicode_AsXyzString_func_type, - 'encode', is_unbound_method, [string_node]) - - return self._substitute_method_call( - node, function, "PyUnicode_AsEncodedString", - self.PyUnicode_AsEncodedString_func_type, - 'encode', is_unbound_method, - [string_node, encoding_node, error_handling_node]) - - PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType( - Builtin.unicode_type, [ + encode_function = "PyUnicode_As%sString" % codec_name + return self._substitute_method_call( + node, function, encode_function, + self.PyUnicode_AsXyzString_func_type, + 'encode', is_unbound_method, [string_node]) + + return self._substitute_method_call( + node, function, "PyUnicode_AsEncodedString", + self.PyUnicode_AsEncodedString_func_type, + 'encode', is_unbound_method, + [string_node, encoding_node, error_handling_node]) + + PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType( + Builtin.unicode_type, [ PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), ])) - - _decode_c_string_func_type = PyrexTypes.CFuncType( - Builtin.unicode_type, [ + + _decode_c_string_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), + PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), ]) - - _decode_bytes_func_type = PyrexTypes.CFuncType( - Builtin.unicode_type, [ - PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + + _decode_bytes_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ + PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), + PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), ]) - - _decode_cpp_string_func_type = None # lazy init - - def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method): - """Replace char*.decode() by a direct C-API call to the - corresponding codec, possibly resolving a slice on the char*. - """ - if not (1 <= len(args) <= 3): - self._error_wrong_arg_count('bytes.decode', node, args, '1-3') - return node - - # normalise input nodes - string_node = args[0] - start = stop = None - if isinstance(string_node, ExprNodes.SliceIndexNode): - index_node = string_node - string_node = index_node.base - start, stop = index_node.start, index_node.stop - if not start or start.constant_result == 0: - start = None - if isinstance(string_node, ExprNodes.CoerceToPyTypeNode): - string_node = string_node.arg - - string_type = string_node.type - if string_type in (Builtin.bytes_type, Builtin.bytearray_type): - if is_unbound_method: - string_node = string_node.as_none_safe_node( - "descriptor '%s' requires a '%s' object but received a 'NoneType'", - format_args=['decode', string_type.name]) - else: - string_node = string_node.as_none_safe_node( + + _decode_cpp_string_func_type = None # lazy init + + def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method): + """Replace char*.decode() by a direct C-API call to the + corresponding codec, possibly resolving a slice on the char*. + """ + if not (1 <= len(args) <= 3): + self._error_wrong_arg_count('bytes.decode', node, args, '1-3') + return node + + # normalise input nodes + string_node = args[0] + start = stop = None + if isinstance(string_node, ExprNodes.SliceIndexNode): + index_node = string_node + string_node = index_node.base + start, stop = index_node.start, index_node.stop + if not start or start.constant_result == 0: + start = None + if isinstance(string_node, ExprNodes.CoerceToPyTypeNode): + string_node = string_node.arg + + string_type = string_node.type + if string_type in (Builtin.bytes_type, Builtin.bytearray_type): + if is_unbound_method: + string_node = string_node.as_none_safe_node( + "descriptor '%s' requires a '%s' object but received a 'NoneType'", + format_args=['decode', string_type.name]) + else: + string_node = string_node.as_none_safe_node( "'NoneType' object has no attribute '%.30s'", - error="PyExc_AttributeError", - format_args=['decode']) - elif not string_type.is_string and not string_type.is_cpp_string: - # nothing to optimise here - return node - - parameters = self._unpack_encoding_and_error_mode(node.pos, args) - if parameters is None: - return node - encoding, encoding_node, error_handling, error_handling_node = parameters - - if not start: - start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) - elif not start.type.is_int: - start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - if stop and not stop.type.is_int: - stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - - # try to find a specific encoder function - codec_name = None - if encoding is not None: - codec_name = self._find_special_codec_name(encoding) - if codec_name is not None: + error="PyExc_AttributeError", + format_args=['decode']) + elif not string_type.is_string and not string_type.is_cpp_string: + # nothing to optimise here + return node + + parameters = self._unpack_encoding_and_error_mode(node.pos, args) + if parameters is None: + return node + encoding, encoding_node, error_handling, error_handling_node = parameters + + if not start: + start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) + elif not start.type.is_int: + start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + if stop and not stop.type.is_int: + stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + + # try to find a specific encoder function + codec_name = None + if encoding is not None: + codec_name = self._find_special_codec_name(encoding) + if codec_name is not None: if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'): codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '') else: codec_cname = "PyUnicode_Decode%s" % codec_name - decode_function = ExprNodes.RawCNameExprNode( + decode_function = ExprNodes.RawCNameExprNode( node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname) - encoding_node = ExprNodes.NullNode(node.pos) - else: - decode_function = ExprNodes.NullNode(node.pos) - - # build the helper function call - temps = [] - if string_type.is_string: - # C string - if not stop: - # use strlen() to find the string length, just as CPython would - if not string_node.is_name: - string_node = UtilNodes.LetRefNode(string_node) # used twice - temps.append(string_node) - stop = ExprNodes.PythonCapiCallNode( - string_node.pos, "strlen", self.Pyx_strlen_func_type, - args=[string_node], - is_temp=False, - utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"), - ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - helper_func_type = self._decode_c_string_func_type - utility_code_name = 'decode_c_string' - elif string_type.is_cpp_string: - # C++ std::string - if not stop: - stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX', - constant_result=ExprNodes.not_a_constant) - if self._decode_cpp_string_func_type is None: - # lazy init to reuse the C++ string type - self._decode_cpp_string_func_type = PyrexTypes.CFuncType( - Builtin.unicode_type, [ - PyrexTypes.CFuncTypeArg("string", string_type, None), - PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), + encoding_node = ExprNodes.NullNode(node.pos) + else: + decode_function = ExprNodes.NullNode(node.pos) + + # build the helper function call + temps = [] + if string_type.is_string: + # C string + if not stop: + # use strlen() to find the string length, just as CPython would + if not string_node.is_name: + string_node = UtilNodes.LetRefNode(string_node) # used twice + temps.append(string_node) + stop = ExprNodes.PythonCapiCallNode( + string_node.pos, "strlen", self.Pyx_strlen_func_type, + args=[string_node], + is_temp=False, + utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"), + ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + helper_func_type = self._decode_c_string_func_type + utility_code_name = 'decode_c_string' + elif string_type.is_cpp_string: + # C++ std::string + if not stop: + stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX', + constant_result=ExprNodes.not_a_constant) + if self._decode_cpp_string_func_type is None: + # lazy init to reuse the C++ string type + self._decode_cpp_string_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ + PyrexTypes.CFuncTypeArg("string", string_type, None), + PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None), - ]) - helper_func_type = self._decode_cpp_string_func_type - utility_code_name = 'decode_cpp_string' - else: - # Python bytes/bytearray object - if not stop: - stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX', - constant_result=ExprNodes.not_a_constant) - helper_func_type = self._decode_bytes_func_type - if string_type is Builtin.bytes_type: - utility_code_name = 'decode_bytes' - else: - utility_code_name = 'decode_bytearray' - - node = ExprNodes.PythonCapiCallNode( - node.pos, '__Pyx_%s' % utility_code_name, helper_func_type, - args=[string_node, start, stop, encoding_node, error_handling_node, decode_function], - is_temp=node.is_temp, - utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'), - ) - - for temp in temps[::-1]: - node = UtilNodes.EvalWithTempExprNode(temp, node) - return node - - _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode - - def _find_special_codec_name(self, encoding): - try: - requested_codec = codecs.getencoder(encoding) - except LookupError: - return None - for name, codec in self._special_codecs: - if codec == requested_codec: - if '_' in name: - name = ''.join([s.capitalize() - for s in name.split('_')]) - return name - return None - - def _unpack_encoding_and_error_mode(self, pos, args): - null_node = ExprNodes.NullNode(pos) - - if len(args) >= 2: - encoding, encoding_node = self._unpack_string_and_cstring_node(args[1]) - if encoding_node is None: - return None - else: - encoding = None - encoding_node = null_node - - if len(args) == 3: - error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2]) - if error_handling_node is None: - return None - if error_handling == 'strict': - error_handling_node = null_node - else: - error_handling = 'strict' - error_handling_node = null_node - - return (encoding, encoding_node, error_handling, error_handling_node) - - def _unpack_string_and_cstring_node(self, node): - if isinstance(node, ExprNodes.CoerceToPyTypeNode): - node = node.arg - if isinstance(node, ExprNodes.UnicodeNode): - encoding = node.value - node = ExprNodes.BytesNode( + PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None), + ]) + helper_func_type = self._decode_cpp_string_func_type + utility_code_name = 'decode_cpp_string' + else: + # Python bytes/bytearray object + if not stop: + stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX', + constant_result=ExprNodes.not_a_constant) + helper_func_type = self._decode_bytes_func_type + if string_type is Builtin.bytes_type: + utility_code_name = 'decode_bytes' + else: + utility_code_name = 'decode_bytearray' + + node = ExprNodes.PythonCapiCallNode( + node.pos, '__Pyx_%s' % utility_code_name, helper_func_type, + args=[string_node, start, stop, encoding_node, error_handling_node, decode_function], + is_temp=node.is_temp, + utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'), + ) + + for temp in temps[::-1]: + node = UtilNodes.EvalWithTempExprNode(temp, node) + return node + + _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode + + def _find_special_codec_name(self, encoding): + try: + requested_codec = codecs.getencoder(encoding) + except LookupError: + return None + for name, codec in self._special_codecs: + if codec == requested_codec: + if '_' in name: + name = ''.join([s.capitalize() + for s in name.split('_')]) + return name + return None + + def _unpack_encoding_and_error_mode(self, pos, args): + null_node = ExprNodes.NullNode(pos) + + if len(args) >= 2: + encoding, encoding_node = self._unpack_string_and_cstring_node(args[1]) + if encoding_node is None: + return None + else: + encoding = None + encoding_node = null_node + + if len(args) == 3: + error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2]) + if error_handling_node is None: + return None + if error_handling == 'strict': + error_handling_node = null_node + else: + error_handling = 'strict' + error_handling_node = null_node + + return (encoding, encoding_node, error_handling, error_handling_node) + + def _unpack_string_and_cstring_node(self, node): + if isinstance(node, ExprNodes.CoerceToPyTypeNode): + node = node.arg + if isinstance(node, ExprNodes.UnicodeNode): + encoding = node.value + node = ExprNodes.BytesNode( node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type) - elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)): - encoding = node.value.decode('ISO-8859-1') - node = ExprNodes.BytesNode( + elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)): + encoding = node.value.decode('ISO-8859-1') + node = ExprNodes.BytesNode( node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type) - elif node.type is Builtin.bytes_type: - encoding = None + elif node.type is Builtin.bytes_type: + encoding = None node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env()) - elif node.type.is_string: - encoding = None - else: - encoding = node = None - return encoding, node - - def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'str', 'endswith', - str_tailmatch_utility_code, +1) - - def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'str', 'startswith', - str_tailmatch_utility_code, -1) - - def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'bytes', 'endswith', - bytes_tailmatch_utility_code, +1) - - def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'bytes', 'startswith', - bytes_tailmatch_utility_code, -1) - - ''' # disabled for now, enable when we consider it worth it (see StringTools.c) - def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'bytearray', 'endswith', - bytes_tailmatch_utility_code, +1) - - def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method): - return self._inject_tailmatch( - node, function, args, is_unbound_method, 'bytearray', 'startswith', - bytes_tailmatch_utility_code, -1) - ''' - - ### helpers - - def _substitute_method_call(self, node, function, name, func_type, - attr_name, is_unbound_method, args=(), - utility_code=None, is_temp=None, + elif node.type.is_string: + encoding = None + else: + encoding = node = None + return encoding, node + + def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'str', 'endswith', + str_tailmatch_utility_code, +1) + + def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'str', 'startswith', + str_tailmatch_utility_code, -1) + + def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'bytes', 'endswith', + bytes_tailmatch_utility_code, +1) + + def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'bytes', 'startswith', + bytes_tailmatch_utility_code, -1) + + ''' # disabled for now, enable when we consider it worth it (see StringTools.c) + def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'bytearray', 'endswith', + bytes_tailmatch_utility_code, +1) + + def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method): + return self._inject_tailmatch( + node, function, args, is_unbound_method, 'bytearray', 'startswith', + bytes_tailmatch_utility_code, -1) + ''' + + ### helpers + + def _substitute_method_call(self, node, function, name, func_type, + attr_name, is_unbound_method, args=(), + utility_code=None, is_temp=None, may_return_none=ExprNodes.PythonCapiCallNode.may_return_none, with_none_check=True): - args = list(args) + args = list(args) if with_none_check and args: args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name) - if is_temp is None: - is_temp = node.is_temp - return ExprNodes.PythonCapiCallNode( - node.pos, name, func_type, - args = args, - is_temp = is_temp, - utility_code = utility_code, - may_return_none = may_return_none, - result_is_used = node.result_is_used, - ) - + if is_temp is None: + is_temp = node.is_temp + return ExprNodes.PythonCapiCallNode( + node.pos, name, func_type, + args = args, + is_temp = is_temp, + utility_code = utility_code, + may_return_none = may_return_none, + result_is_used = node.result_is_used, + ) + def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name): if self_arg.is_literal: return self_arg @@ -3963,280 +3963,280 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, format_args=[attr_name]) return self_arg - def _inject_int_default_argument(self, node, args, arg_index, type, default_value): - assert len(args) >= arg_index - if len(args) == arg_index: - args.append(ExprNodes.IntNode(node.pos, value=str(default_value), - type=type, constant_result=default_value)) - else: - args[arg_index] = args[arg_index].coerce_to(type, self.current_env()) - - def _inject_bint_default_argument(self, node, args, arg_index, default_value): - assert len(args) >= arg_index - if len(args) == arg_index: - default_value = bool(default_value) - args.append(ExprNodes.BoolNode(node.pos, value=default_value, - constant_result=default_value)) - else: - args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env()) - - -unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c') -bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c') -str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c') - - -class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): - """Calculate the result of constant expressions to store it in - ``expr_node.constant_result``, and replace trivial cases by their - constant result. - - General rules: - - - We calculate float constants to make them available to the - compiler, but we do not aggregate them into a single literal - node to prevent any loss of precision. - - - We recursively calculate constants from non-literal nodes to - make them available to the compiler, but we only aggregate - literal nodes at each step. Non-literal nodes are never merged - into a single node. - """ - - def __init__(self, reevaluate=False): - """ - The reevaluate argument specifies whether constant values that were - previously computed should be recomputed. - """ - super(ConstantFolding, self).__init__() - self.reevaluate = reevaluate - - def _calculate_const(self, node): - if (not self.reevaluate and - node.constant_result is not ExprNodes.constant_value_not_set): - return - - # make sure we always set the value - not_a_constant = ExprNodes.not_a_constant - node.constant_result = not_a_constant - - # check if all children are constant - children = self.visitchildren(node) - for child_result in children.values(): - if type(child_result) is list: - for child in child_result: - if getattr(child, 'constant_result', not_a_constant) is not_a_constant: - return - elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant: - return - - # now try to calculate the real constant value - try: - node.calculate_constant_result() -# if node.constant_result is not ExprNodes.not_a_constant: -# print node.__class__.__name__, node.constant_result - except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError): - # ignore all 'normal' errors here => no constant result - pass - except Exception: - # this looks like a real error - import traceback, sys - traceback.print_exc(file=sys.stdout) - - NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode, - ExprNodes.IntNode, ExprNodes.FloatNode] - - def _widest_node_class(self, *nodes): - try: - return self.NODE_TYPE_ORDER[ - max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))] - except ValueError: - return None - - def _bool_node(self, node, value): - value = bool(value) - return ExprNodes.BoolNode(node.pos, value=value, constant_result=value) - - def visit_ExprNode(self, node): - self._calculate_const(node) - return node - - def visit_UnopNode(self, node): - self._calculate_const(node) - if not node.has_constant_result(): - if node.operator == '!': - return self._handle_NotNode(node) - return node - if not node.operand.is_literal: - return node - if node.operator == '!': - return self._bool_node(node, node.constant_result) - elif isinstance(node.operand, ExprNodes.BoolNode): - return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)), - type=PyrexTypes.c_int_type, - constant_result=int(node.constant_result)) - elif node.operator == '+': - return self._handle_UnaryPlusNode(node) - elif node.operator == '-': - return self._handle_UnaryMinusNode(node) - return node - - _negate_operator = { - 'in': 'not_in', - 'not_in': 'in', - 'is': 'is_not', - 'is_not': 'is' - }.get - - def _handle_NotNode(self, node): - operand = node.operand - if isinstance(operand, ExprNodes.PrimaryCmpNode): - operator = self._negate_operator(operand.operator) - if operator: - node = copy.copy(operand) - node.operator = operator - node = self.visit_PrimaryCmpNode(node) - return node - - def _handle_UnaryMinusNode(self, node): - def _negate(value): - if value.startswith('-'): - value = value[1:] - else: - value = '-' + value - return value - - node_type = node.operand.type - if isinstance(node.operand, ExprNodes.FloatNode): - # this is a safe operation - return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value), - type=node_type, - constant_result=node.constant_result) - if node_type.is_int and node_type.signed or \ - isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject: - return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value), - type=node_type, - longness=node.operand.longness, - constant_result=node.constant_result) - return node - - def _handle_UnaryPlusNode(self, node): - if (node.operand.has_constant_result() and - node.constant_result == node.operand.constant_result): - return node.operand - return node - - def visit_BoolBinopNode(self, node): - self._calculate_const(node) - if not node.operand1.has_constant_result(): - return node - if node.operand1.constant_result: - if node.operator == 'and': - return node.operand2 - else: - return node.operand1 - else: - if node.operator == 'and': - return node.operand1 - else: - return node.operand2 - - def visit_BinopNode(self, node): - self._calculate_const(node) - if node.constant_result is ExprNodes.not_a_constant: - return node - if isinstance(node.constant_result, float): - return node - operand1, operand2 = node.operand1, node.operand2 - if not operand1.is_literal or not operand2.is_literal: - return node - - # now inject a new constant node with the calculated value - try: - type1, type2 = operand1.type, operand2.type - if type1 is None or type2 is None: - return node - except AttributeError: - return node - - if type1.is_numeric and type2.is_numeric: - widest_type = PyrexTypes.widest_numeric_type(type1, type2) - else: - widest_type = PyrexTypes.py_object_type - - target_class = self._widest_node_class(operand1, operand2) - if target_class is None: - return node - elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>': - # C arithmetic results in at least an int type - target_class = ExprNodes.IntNode - elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^': - # C arithmetic results in at least an int type - target_class = ExprNodes.IntNode - - if target_class is ExprNodes.IntNode: - unsigned = getattr(operand1, 'unsigned', '') and \ - getattr(operand2, 'unsigned', '') - longness = "LL"[:max(len(getattr(operand1, 'longness', '')), - len(getattr(operand2, 'longness', '')))] - new_node = ExprNodes.IntNode(pos=node.pos, - unsigned=unsigned, longness=longness, - value=str(int(node.constant_result)), - constant_result=int(node.constant_result)) - # IntNode is smart about the type it chooses, so we just - # make sure we were not smarter this time - if widest_type.is_pyobject or new_node.type.is_pyobject: - new_node.type = PyrexTypes.py_object_type - else: - new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type) - else: - if target_class is ExprNodes.BoolNode: - node_value = node.constant_result - else: - node_value = str(node.constant_result) - new_node = target_class(pos=node.pos, type = widest_type, - value = node_value, - constant_result = node.constant_result) - return new_node - - def visit_AddNode(self, node): - self._calculate_const(node) - if node.constant_result is ExprNodes.not_a_constant: - return node - if node.operand1.is_string_literal and node.operand2.is_string_literal: - # some people combine string literals with a '+' - str1, str2 = node.operand1, node.operand2 - if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode): - bytes_value = None - if str1.bytes_value is not None and str2.bytes_value is not None: - if str1.bytes_value.encoding == str2.bytes_value.encoding: + def _inject_int_default_argument(self, node, args, arg_index, type, default_value): + assert len(args) >= arg_index + if len(args) == arg_index: + args.append(ExprNodes.IntNode(node.pos, value=str(default_value), + type=type, constant_result=default_value)) + else: + args[arg_index] = args[arg_index].coerce_to(type, self.current_env()) + + def _inject_bint_default_argument(self, node, args, arg_index, default_value): + assert len(args) >= arg_index + if len(args) == arg_index: + default_value = bool(default_value) + args.append(ExprNodes.BoolNode(node.pos, value=default_value, + constant_result=default_value)) + else: + args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env()) + + +unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c') +bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c') +str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c') + + +class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): + """Calculate the result of constant expressions to store it in + ``expr_node.constant_result``, and replace trivial cases by their + constant result. + + General rules: + + - We calculate float constants to make them available to the + compiler, but we do not aggregate them into a single literal + node to prevent any loss of precision. + + - We recursively calculate constants from non-literal nodes to + make them available to the compiler, but we only aggregate + literal nodes at each step. Non-literal nodes are never merged + into a single node. + """ + + def __init__(self, reevaluate=False): + """ + The reevaluate argument specifies whether constant values that were + previously computed should be recomputed. + """ + super(ConstantFolding, self).__init__() + self.reevaluate = reevaluate + + def _calculate_const(self, node): + if (not self.reevaluate and + node.constant_result is not ExprNodes.constant_value_not_set): + return + + # make sure we always set the value + not_a_constant = ExprNodes.not_a_constant + node.constant_result = not_a_constant + + # check if all children are constant + children = self.visitchildren(node) + for child_result in children.values(): + if type(child_result) is list: + for child in child_result: + if getattr(child, 'constant_result', not_a_constant) is not_a_constant: + return + elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant: + return + + # now try to calculate the real constant value + try: + node.calculate_constant_result() +# if node.constant_result is not ExprNodes.not_a_constant: +# print node.__class__.__name__, node.constant_result + except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError): + # ignore all 'normal' errors here => no constant result + pass + except Exception: + # this looks like a real error + import traceback, sys + traceback.print_exc(file=sys.stdout) + + NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode, + ExprNodes.IntNode, ExprNodes.FloatNode] + + def _widest_node_class(self, *nodes): + try: + return self.NODE_TYPE_ORDER[ + max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))] + except ValueError: + return None + + def _bool_node(self, node, value): + value = bool(value) + return ExprNodes.BoolNode(node.pos, value=value, constant_result=value) + + def visit_ExprNode(self, node): + self._calculate_const(node) + return node + + def visit_UnopNode(self, node): + self._calculate_const(node) + if not node.has_constant_result(): + if node.operator == '!': + return self._handle_NotNode(node) + return node + if not node.operand.is_literal: + return node + if node.operator == '!': + return self._bool_node(node, node.constant_result) + elif isinstance(node.operand, ExprNodes.BoolNode): + return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)), + type=PyrexTypes.c_int_type, + constant_result=int(node.constant_result)) + elif node.operator == '+': + return self._handle_UnaryPlusNode(node) + elif node.operator == '-': + return self._handle_UnaryMinusNode(node) + return node + + _negate_operator = { + 'in': 'not_in', + 'not_in': 'in', + 'is': 'is_not', + 'is_not': 'is' + }.get + + def _handle_NotNode(self, node): + operand = node.operand + if isinstance(operand, ExprNodes.PrimaryCmpNode): + operator = self._negate_operator(operand.operator) + if operator: + node = copy.copy(operand) + node.operator = operator + node = self.visit_PrimaryCmpNode(node) + return node + + def _handle_UnaryMinusNode(self, node): + def _negate(value): + if value.startswith('-'): + value = value[1:] + else: + value = '-' + value + return value + + node_type = node.operand.type + if isinstance(node.operand, ExprNodes.FloatNode): + # this is a safe operation + return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value), + type=node_type, + constant_result=node.constant_result) + if node_type.is_int and node_type.signed or \ + isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject: + return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value), + type=node_type, + longness=node.operand.longness, + constant_result=node.constant_result) + return node + + def _handle_UnaryPlusNode(self, node): + if (node.operand.has_constant_result() and + node.constant_result == node.operand.constant_result): + return node.operand + return node + + def visit_BoolBinopNode(self, node): + self._calculate_const(node) + if not node.operand1.has_constant_result(): + return node + if node.operand1.constant_result: + if node.operator == 'and': + return node.operand2 + else: + return node.operand1 + else: + if node.operator == 'and': + return node.operand1 + else: + return node.operand2 + + def visit_BinopNode(self, node): + self._calculate_const(node) + if node.constant_result is ExprNodes.not_a_constant: + return node + if isinstance(node.constant_result, float): + return node + operand1, operand2 = node.operand1, node.operand2 + if not operand1.is_literal or not operand2.is_literal: + return node + + # now inject a new constant node with the calculated value + try: + type1, type2 = operand1.type, operand2.type + if type1 is None or type2 is None: + return node + except AttributeError: + return node + + if type1.is_numeric and type2.is_numeric: + widest_type = PyrexTypes.widest_numeric_type(type1, type2) + else: + widest_type = PyrexTypes.py_object_type + + target_class = self._widest_node_class(operand1, operand2) + if target_class is None: + return node + elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>': + # C arithmetic results in at least an int type + target_class = ExprNodes.IntNode + elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^': + # C arithmetic results in at least an int type + target_class = ExprNodes.IntNode + + if target_class is ExprNodes.IntNode: + unsigned = getattr(operand1, 'unsigned', '') and \ + getattr(operand2, 'unsigned', '') + longness = "LL"[:max(len(getattr(operand1, 'longness', '')), + len(getattr(operand2, 'longness', '')))] + new_node = ExprNodes.IntNode(pos=node.pos, + unsigned=unsigned, longness=longness, + value=str(int(node.constant_result)), + constant_result=int(node.constant_result)) + # IntNode is smart about the type it chooses, so we just + # make sure we were not smarter this time + if widest_type.is_pyobject or new_node.type.is_pyobject: + new_node.type = PyrexTypes.py_object_type + else: + new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type) + else: + if target_class is ExprNodes.BoolNode: + node_value = node.constant_result + else: + node_value = str(node.constant_result) + new_node = target_class(pos=node.pos, type = widest_type, + value = node_value, + constant_result = node.constant_result) + return new_node + + def visit_AddNode(self, node): + self._calculate_const(node) + if node.constant_result is ExprNodes.not_a_constant: + return node + if node.operand1.is_string_literal and node.operand2.is_string_literal: + # some people combine string literals with a '+' + str1, str2 = node.operand1, node.operand2 + if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode): + bytes_value = None + if str1.bytes_value is not None and str2.bytes_value is not None: + if str1.bytes_value.encoding == str2.bytes_value.encoding: bytes_value = bytes_literal( str1.bytes_value + str2.bytes_value, str1.bytes_value.encoding) - string_value = EncodedString(node.constant_result) - return ExprNodes.UnicodeNode( - str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value) - elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode): - if str1.value.encoding == str2.value.encoding: + string_value = EncodedString(node.constant_result) + return ExprNodes.UnicodeNode( + str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value) + elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode): + if str1.value.encoding == str2.value.encoding: bytes_value = bytes_literal(node.constant_result, str1.value.encoding) - return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result) - # all other combinations are rather complicated - # to get right in Py2/3: encodings, unicode escapes, ... - return self.visit_BinopNode(node) - - def visit_MulNode(self, node): - self._calculate_const(node) - if node.operand1.is_sequence_constructor: - return self._calculate_constant_seq(node, node.operand1, node.operand2) - if isinstance(node.operand1, ExprNodes.IntNode) and \ - node.operand2.is_sequence_constructor: - return self._calculate_constant_seq(node, node.operand2, node.operand1) + return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result) + # all other combinations are rather complicated + # to get right in Py2/3: encodings, unicode escapes, ... + return self.visit_BinopNode(node) + + def visit_MulNode(self, node): + self._calculate_const(node) + if node.operand1.is_sequence_constructor: + return self._calculate_constant_seq(node, node.operand1, node.operand2) + if isinstance(node.operand1, ExprNodes.IntNode) and \ + node.operand2.is_sequence_constructor: + return self._calculate_constant_seq(node, node.operand2, node.operand1) if node.operand1.is_string_literal: return self._multiply_string(node, node.operand1, node.operand2) elif node.operand2.is_string_literal: return self._multiply_string(node, node.operand2, node.operand1) - return self.visit_BinopNode(node) - + return self.visit_BinopNode(node) + def _multiply_string(self, node, string_node, multiplier_node): multiplier = multiplier_node.constant_result if not isinstance(multiplier, _py_int_types): @@ -4273,25 +4273,25 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): string_node.constant_result = string_node.value return string_node - def _calculate_constant_seq(self, node, sequence_node, factor): - if factor.constant_result != 1 and sequence_node.args: + def _calculate_constant_seq(self, node, sequence_node, factor): + if factor.constant_result != 1 and sequence_node.args: if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0: - del sequence_node.args[:] - sequence_node.mult_factor = None - elif sequence_node.mult_factor is not None: + del sequence_node.args[:] + sequence_node.mult_factor = None + elif sequence_node.mult_factor is not None: if (isinstance(factor.constant_result, _py_int_types) and isinstance(sequence_node.mult_factor.constant_result, _py_int_types)): - value = sequence_node.mult_factor.constant_result * factor.constant_result - sequence_node.mult_factor = ExprNodes.IntNode( - sequence_node.mult_factor.pos, - value=str(value), constant_result=value) - else: - # don't know if we can combine the factors, so don't - return self.visit_BinopNode(node) - else: - sequence_node.mult_factor = factor - return sequence_node - + value = sequence_node.mult_factor.constant_result * factor.constant_result + sequence_node.mult_factor = ExprNodes.IntNode( + sequence_node.mult_factor.pos, + value=str(value), constant_result=value) + else: + # don't know if we can combine the factors, so don't + return self.visit_BinopNode(node) + else: + sequence_node.mult_factor = factor + return sequence_node + def visit_ModNode(self, node): self.visitchildren(node) if isinstance(node.operand1, ExprNodes.UnicodeNode) and isinstance(node.operand2, ExprNodes.TupleNode): @@ -4520,298 +4520,298 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): self._calculate_const(node) return node - def visit_PrimaryCmpNode(self, node): - # calculate constant partial results in the comparison cascade - self.visitchildren(node, ['operand1']) - left_node = node.operand1 - cmp_node = node - while cmp_node is not None: - self.visitchildren(cmp_node, ['operand2']) - right_node = cmp_node.operand2 - cmp_node.constant_result = not_a_constant - if left_node.has_constant_result() and right_node.has_constant_result(): - try: - cmp_node.calculate_cascaded_constant_result(left_node.constant_result) - except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError): - pass # ignore all 'normal' errors here => no constant result - left_node = right_node - cmp_node = cmp_node.cascade - - if not node.cascade: - if node.has_constant_result(): - return self._bool_node(node, node.constant_result) - return node - - # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...] - cascades = [[node.operand1]] - final_false_result = [] - - def split_cascades(cmp_node): - if cmp_node.has_constant_result(): - if not cmp_node.constant_result: - # False => short-circuit - final_false_result.append(self._bool_node(cmp_node, False)) - return - else: - # True => discard and start new cascade - cascades.append([cmp_node.operand2]) - else: - # not constant => append to current cascade - cascades[-1].append(cmp_node) - if cmp_node.cascade: - split_cascades(cmp_node.cascade) - - split_cascades(node) - - cmp_nodes = [] - for cascade in cascades: - if len(cascade) < 2: - continue - cmp_node = cascade[1] - pcmp_node = ExprNodes.PrimaryCmpNode( - cmp_node.pos, - operand1=cascade[0], - operator=cmp_node.operator, - operand2=cmp_node.operand2, - constant_result=not_a_constant) - cmp_nodes.append(pcmp_node) - - last_cmp_node = pcmp_node - for cmp_node in cascade[2:]: - last_cmp_node.cascade = cmp_node - last_cmp_node = cmp_node - last_cmp_node.cascade = None - - if final_false_result: - # last cascade was constant False - cmp_nodes.append(final_false_result[0]) - elif not cmp_nodes: - # only constants, but no False result - return self._bool_node(node, True) - node = cmp_nodes[0] - if len(cmp_nodes) == 1: - if node.has_constant_result(): - return self._bool_node(node, node.constant_result) - else: - for cmp_node in cmp_nodes[1:]: - node = ExprNodes.BoolBinopNode( - node.pos, - operand1=node, - operator='and', - operand2=cmp_node, - constant_result=not_a_constant) - return node - - def visit_CondExprNode(self, node): - self._calculate_const(node) - if not node.test.has_constant_result(): - return node - if node.test.constant_result: - return node.true_val - else: - return node.false_val - - def visit_IfStatNode(self, node): - self.visitchildren(node) - # eliminate dead code based on constant condition results - if_clauses = [] - for if_clause in node.if_clauses: - condition = if_clause.condition - if condition.has_constant_result(): - if condition.constant_result: - # always true => subsequent clauses can safely be dropped - node.else_clause = if_clause.body - break - # else: false => drop clause - else: - # unknown result => normal runtime evaluation - if_clauses.append(if_clause) - if if_clauses: - node.if_clauses = if_clauses - return node - elif node.else_clause: - return node.else_clause - else: - return Nodes.StatListNode(node.pos, stats=[]) - - def visit_SliceIndexNode(self, node): - self._calculate_const(node) - # normalise start/stop values - if node.start is None or node.start.constant_result is None: - start = node.start = None - else: - start = node.start.constant_result - if node.stop is None or node.stop.constant_result is None: - stop = node.stop = None - else: - stop = node.stop.constant_result - # cut down sliced constant sequences - if node.constant_result is not not_a_constant: - base = node.base - if base.is_sequence_constructor and base.mult_factor is None: - base.args = base.args[start:stop] - return base - elif base.is_string_literal: - base = base.as_sliced_node(start, stop) - if base is not None: - return base - return node - - def visit_ComprehensionNode(self, node): - self.visitchildren(node) - if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats: - # loop was pruned already => transform into literal - if node.type is Builtin.list_type: - return ExprNodes.ListNode( - node.pos, args=[], constant_result=[]) - elif node.type is Builtin.set_type: - return ExprNodes.SetNode( - node.pos, args=[], constant_result=set()) - elif node.type is Builtin.dict_type: - return ExprNodes.DictNode( - node.pos, key_value_pairs=[], constant_result={}) - return node - - def visit_ForInStatNode(self, node): - self.visitchildren(node) - sequence = node.iterator.sequence - if isinstance(sequence, ExprNodes.SequenceNode): - if not sequence.args: - if node.else_clause: - return node.else_clause - else: - # don't break list comprehensions - return Nodes.StatListNode(node.pos, stats=[]) - # iterating over a list literal? => tuples are more efficient - if isinstance(sequence, ExprNodes.ListNode): - node.iterator.sequence = sequence.as_tuple() - return node - - def visit_WhileStatNode(self, node): - self.visitchildren(node) - if node.condition and node.condition.has_constant_result(): - if node.condition.constant_result: - node.condition = None - node.else_clause = None - else: - return node.else_clause - return node - - def visit_ExprStatNode(self, node): - self.visitchildren(node) - if not isinstance(node.expr, ExprNodes.ExprNode): - # ParallelRangeTransform does this ... - return node - # drop unused constant expressions - if node.expr.has_constant_result(): - return None - return node - - # in the future, other nodes can have their own handler method here - # that can replace them with a constant result node - - visit_Node = Visitor.VisitorTransform.recurse_to_children - - + def visit_PrimaryCmpNode(self, node): + # calculate constant partial results in the comparison cascade + self.visitchildren(node, ['operand1']) + left_node = node.operand1 + cmp_node = node + while cmp_node is not None: + self.visitchildren(cmp_node, ['operand2']) + right_node = cmp_node.operand2 + cmp_node.constant_result = not_a_constant + if left_node.has_constant_result() and right_node.has_constant_result(): + try: + cmp_node.calculate_cascaded_constant_result(left_node.constant_result) + except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError): + pass # ignore all 'normal' errors here => no constant result + left_node = right_node + cmp_node = cmp_node.cascade + + if not node.cascade: + if node.has_constant_result(): + return self._bool_node(node, node.constant_result) + return node + + # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...] + cascades = [[node.operand1]] + final_false_result = [] + + def split_cascades(cmp_node): + if cmp_node.has_constant_result(): + if not cmp_node.constant_result: + # False => short-circuit + final_false_result.append(self._bool_node(cmp_node, False)) + return + else: + # True => discard and start new cascade + cascades.append([cmp_node.operand2]) + else: + # not constant => append to current cascade + cascades[-1].append(cmp_node) + if cmp_node.cascade: + split_cascades(cmp_node.cascade) + + split_cascades(node) + + cmp_nodes = [] + for cascade in cascades: + if len(cascade) < 2: + continue + cmp_node = cascade[1] + pcmp_node = ExprNodes.PrimaryCmpNode( + cmp_node.pos, + operand1=cascade[0], + operator=cmp_node.operator, + operand2=cmp_node.operand2, + constant_result=not_a_constant) + cmp_nodes.append(pcmp_node) + + last_cmp_node = pcmp_node + for cmp_node in cascade[2:]: + last_cmp_node.cascade = cmp_node + last_cmp_node = cmp_node + last_cmp_node.cascade = None + + if final_false_result: + # last cascade was constant False + cmp_nodes.append(final_false_result[0]) + elif not cmp_nodes: + # only constants, but no False result + return self._bool_node(node, True) + node = cmp_nodes[0] + if len(cmp_nodes) == 1: + if node.has_constant_result(): + return self._bool_node(node, node.constant_result) + else: + for cmp_node in cmp_nodes[1:]: + node = ExprNodes.BoolBinopNode( + node.pos, + operand1=node, + operator='and', + operand2=cmp_node, + constant_result=not_a_constant) + return node + + def visit_CondExprNode(self, node): + self._calculate_const(node) + if not node.test.has_constant_result(): + return node + if node.test.constant_result: + return node.true_val + else: + return node.false_val + + def visit_IfStatNode(self, node): + self.visitchildren(node) + # eliminate dead code based on constant condition results + if_clauses = [] + for if_clause in node.if_clauses: + condition = if_clause.condition + if condition.has_constant_result(): + if condition.constant_result: + # always true => subsequent clauses can safely be dropped + node.else_clause = if_clause.body + break + # else: false => drop clause + else: + # unknown result => normal runtime evaluation + if_clauses.append(if_clause) + if if_clauses: + node.if_clauses = if_clauses + return node + elif node.else_clause: + return node.else_clause + else: + return Nodes.StatListNode(node.pos, stats=[]) + + def visit_SliceIndexNode(self, node): + self._calculate_const(node) + # normalise start/stop values + if node.start is None or node.start.constant_result is None: + start = node.start = None + else: + start = node.start.constant_result + if node.stop is None or node.stop.constant_result is None: + stop = node.stop = None + else: + stop = node.stop.constant_result + # cut down sliced constant sequences + if node.constant_result is not not_a_constant: + base = node.base + if base.is_sequence_constructor and base.mult_factor is None: + base.args = base.args[start:stop] + return base + elif base.is_string_literal: + base = base.as_sliced_node(start, stop) + if base is not None: + return base + return node + + def visit_ComprehensionNode(self, node): + self.visitchildren(node) + if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats: + # loop was pruned already => transform into literal + if node.type is Builtin.list_type: + return ExprNodes.ListNode( + node.pos, args=[], constant_result=[]) + elif node.type is Builtin.set_type: + return ExprNodes.SetNode( + node.pos, args=[], constant_result=set()) + elif node.type is Builtin.dict_type: + return ExprNodes.DictNode( + node.pos, key_value_pairs=[], constant_result={}) + return node + + def visit_ForInStatNode(self, node): + self.visitchildren(node) + sequence = node.iterator.sequence + if isinstance(sequence, ExprNodes.SequenceNode): + if not sequence.args: + if node.else_clause: + return node.else_clause + else: + # don't break list comprehensions + return Nodes.StatListNode(node.pos, stats=[]) + # iterating over a list literal? => tuples are more efficient + if isinstance(sequence, ExprNodes.ListNode): + node.iterator.sequence = sequence.as_tuple() + return node + + def visit_WhileStatNode(self, node): + self.visitchildren(node) + if node.condition and node.condition.has_constant_result(): + if node.condition.constant_result: + node.condition = None + node.else_clause = None + else: + return node.else_clause + return node + + def visit_ExprStatNode(self, node): + self.visitchildren(node) + if not isinstance(node.expr, ExprNodes.ExprNode): + # ParallelRangeTransform does this ... + return node + # drop unused constant expressions + if node.expr.has_constant_result(): + return None + return node + + # in the future, other nodes can have their own handler method here + # that can replace them with a constant result node + + visit_Node = Visitor.VisitorTransform.recurse_to_children + + class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin): - """ - This visitor handles several commuting optimizations, and is run - just before the C code generation phase. - - The optimizations currently implemented in this class are: - - eliminate None assignment and refcounting for first assignment. - - isinstance -> typecheck for cdef types - - eliminate checks for None and/or types that became redundant after tree changes + """ + This visitor handles several commuting optimizations, and is run + just before the C code generation phase. + + The optimizations currently implemented in this class are: + - eliminate None assignment and refcounting for first assignment. + - isinstance -> typecheck for cdef types + - eliminate checks for None and/or types that became redundant after tree changes - eliminate useless string formatting steps - - replace Python function calls that look like method calls by a faster PyMethodCallNode - """ + - replace Python function calls that look like method calls by a faster PyMethodCallNode + """ in_loop = False - def visit_SingleAssignmentNode(self, node): - """Avoid redundant initialisation of local variables before their - first assignment. - """ - self.visitchildren(node) - if node.first: - lhs = node.lhs - lhs.lhs_of_first_assignment = True - return node - - def visit_SimpleCallNode(self, node): - """ - Replace generic calls to isinstance(x, type) by a more efficient type check. - Replace likely Python method calls by a specialised PyMethodCallNode. - """ - self.visitchildren(node) - function = node.function - if function.type.is_cfunction and function.is_name: - if function.name == 'isinstance' and len(node.args) == 2: - type_arg = node.args[1] - if type_arg.type.is_builtin_type and type_arg.type.name == 'type': - cython_scope = self.context.cython_scope - function.entry = cython_scope.lookup('PyObject_TypeCheck') - function.type = function.entry.type - PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type) - node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr) + def visit_SingleAssignmentNode(self, node): + """Avoid redundant initialisation of local variables before their + first assignment. + """ + self.visitchildren(node) + if node.first: + lhs = node.lhs + lhs.lhs_of_first_assignment = True + return node + + def visit_SimpleCallNode(self, node): + """ + Replace generic calls to isinstance(x, type) by a more efficient type check. + Replace likely Python method calls by a specialised PyMethodCallNode. + """ + self.visitchildren(node) + function = node.function + if function.type.is_cfunction and function.is_name: + if function.name == 'isinstance' and len(node.args) == 2: + type_arg = node.args[1] + if type_arg.type.is_builtin_type and type_arg.type.name == 'type': + cython_scope = self.context.cython_scope + function.entry = cython_scope.lookup('PyObject_TypeCheck') + function.type = function.entry.type + PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type) + node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr) elif (node.is_temp and function.type.is_pyobject and self.current_directives.get( "optimize.unpack_method_calls_in_pyinit" if not self.in_loop and self.current_env().is_module_scope else "optimize.unpack_method_calls")): - # optimise simple Python methods calls - if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not ( + # optimise simple Python methods calls + if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not ( node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)): - # simple call, now exclude calls to objects that are definitely not methods - may_be_a_method = True - if function.type is Builtin.type_type: - may_be_a_method = False + # simple call, now exclude calls to objects that are definitely not methods + may_be_a_method = True + if function.type is Builtin.type_type: + may_be_a_method = False elif function.is_attribute: if function.entry and function.entry.type.is_cfunction: # optimised builtin method may_be_a_method = False - elif function.is_name: + elif function.is_name: entry = function.entry if entry.is_builtin or entry.type.is_cfunction: - may_be_a_method = False + may_be_a_method = False elif entry.cf_assignments: - # local functions/classes are definitely not methods - non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode) - may_be_a_method = any( - assignment.rhs and not isinstance(assignment.rhs, non_method_nodes) + # local functions/classes are definitely not methods + non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode) + may_be_a_method = any( + assignment.rhs and not isinstance(assignment.rhs, non_method_nodes) for assignment in entry.cf_assignments) - if may_be_a_method: + if may_be_a_method: if (node.self and function.is_attribute and isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self): # function self object was moved into a CloneNode => undo function.obj = function.obj.arg - node = self.replace(node, ExprNodes.PyMethodCallNode.from_node( - node, function=function, arg_tuple=node.arg_tuple, type=node.type)) - return node - + node = self.replace(node, ExprNodes.PyMethodCallNode.from_node( + node, function=function, arg_tuple=node.arg_tuple, type=node.type)) + return node + def visit_NumPyMethodCallNode(self, node): # Exclude from replacement above. self.visitchildren(node) return node - def visit_PyTypeTestNode(self, node): - """Remove tests for alternatively allowed None values from - type tests when we know that the argument cannot be None - anyway. - """ - self.visitchildren(node) - if not node.notnone: - if not node.arg.may_be_none(): - node.notnone = True - return node - - def visit_NoneCheckNode(self, node): - """Remove None checks from expressions that definitely do not - carry a None value. - """ - self.visitchildren(node) - if not node.arg.may_be_none(): - return node.arg - return node - + def visit_PyTypeTestNode(self, node): + """Remove tests for alternatively allowed None values from + type tests when we know that the argument cannot be None + anyway. + """ + self.visitchildren(node) + if not node.notnone: + if not node.arg.may_be_none(): + node.notnone = True + return node + + def visit_NoneCheckNode(self, node): + """Remove None checks from expressions that definitely do not + carry a None value. + """ + self.visitchildren(node) + if not node.arg.may_be_none(): + return node.arg + return node + def visit_LoopNode(self, node): """Remember when we enter a loop as some expensive optimisations might still be worth it there. """ @@ -4822,36 +4822,36 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin): return node -class ConsolidateOverflowCheck(Visitor.CythonTransform): - """ - This class facilitates the sharing of overflow checking among all nodes - of a nested arithmetic expression. For example, given the expression - a*b + c, where a, b, and x are all possibly overflowing ints, the entire - sequence will be evaluated and the overflow bit checked only at the end. - """ - overflow_bit_node = None - - def visit_Node(self, node): - if self.overflow_bit_node is not None: - saved = self.overflow_bit_node - self.overflow_bit_node = None - self.visitchildren(node) - self.overflow_bit_node = saved - else: - self.visitchildren(node) - return node - - def visit_NumBinopNode(self, node): - if node.overflow_check and node.overflow_fold: - top_level_overflow = self.overflow_bit_node is None - if top_level_overflow: - self.overflow_bit_node = node - else: - node.overflow_bit_node = self.overflow_bit_node - node.overflow_check = False - self.visitchildren(node) - if top_level_overflow: - self.overflow_bit_node = None - else: - self.visitchildren(node) - return node +class ConsolidateOverflowCheck(Visitor.CythonTransform): + """ + This class facilitates the sharing of overflow checking among all nodes + of a nested arithmetic expression. For example, given the expression + a*b + c, where a, b, and x are all possibly overflowing ints, the entire + sequence will be evaluated and the overflow bit checked only at the end. + """ + overflow_bit_node = None + + def visit_Node(self, node): + if self.overflow_bit_node is not None: + saved = self.overflow_bit_node + self.overflow_bit_node = None + self.visitchildren(node) + self.overflow_bit_node = saved + else: + self.visitchildren(node) + return node + + def visit_NumBinopNode(self, node): + if node.overflow_check and node.overflow_fold: + top_level_overflow = self.overflow_bit_node is None + if top_level_overflow: + self.overflow_bit_node = node + else: + node.overflow_bit_node = self.overflow_bit_node + node.overflow_check = False + self.visitchildren(node) + if top_level_overflow: + self.overflow_bit_node = None + else: + self.visitchildren(node) + return node |