diff options
author | orivej <[email protected]> | 2022-02-10 16:45:01 +0300 |
---|---|---|
committer | Daniil Cherednik <[email protected]> | 2022-02-10 16:45:01 +0300 |
commit | 2d37894b1b037cf24231090eda8589bbb44fb6fc (patch) | |
tree | be835aa92c6248212e705f25388ebafcf84bc7a1 /contrib/tools/cython/Cython/Compiler/Optimize.py | |
parent | 718c552901d703c502ccbefdfc3c9028d608b947 (diff) |
Restoring authorship annotation for <[email protected]>. Commit 2 of 2.
Diffstat (limited to 'contrib/tools/cython/Cython/Compiler/Optimize.py')
-rw-r--r-- | contrib/tools/cython/Cython/Compiler/Optimize.py | 1740 |
1 files changed, 870 insertions, 870 deletions
diff --git a/contrib/tools/cython/Cython/Compiler/Optimize.py b/contrib/tools/cython/Cython/Compiler/Optimize.py index da4556a9f63..3cb77efe2c5 100644 --- a/contrib/tools/cython/Cython/Compiler/Optimize.py +++ b/contrib/tools/cython/Cython/Compiler/Optimize.py @@ -1,25 +1,25 @@ from __future__ import absolute_import import re -import sys -import copy -import codecs -import itertools - +import sys +import copy +import codecs +import itertools + from . import TypeSlots from .ExprNodes import not_a_constant import cython cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object, Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, - UtilNodes=object, _py_int_types=object) + UtilNodes=object, _py_int_types=object) -if sys.version_info[0] >= 3: - _py_int_types = int +if sys.version_info[0] >= 3: + _py_int_types = int _py_string_types = (bytes, str) -else: - _py_int_types = (int, long) +else: + _py_int_types = (int, long) _py_string_types = (bytes, unicode) - + from . import Nodes from . import ExprNodes from . import PyrexTypes @@ -28,7 +28,7 @@ from . import Builtin from . import UtilNodes from . import Options -from .Code import UtilityCode, TempitaUtilityCode +from .Code import UtilityCode, TempitaUtilityCode from .StringEncoding import EncodedString, bytes_literal, encoded_string from .Errors import error, warning from .ParseTreeTransforms import SkipDeclarations @@ -43,23 +43,23 @@ try: except ImportError: basestring = str # Python 3 - + def load_c_utility(name): return UtilityCode.load_cached(name, "Optimize.c") - + def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): if isinstance(node, coercion_nodes): return node.arg return node - + def unwrap_node(node): while isinstance(node, UtilNodes.ResultRefNode): node = node.expression return node - + def is_common_value(a, b): a = unwrap_node(a) b = unwrap_node(b) @@ -69,66 +69,66 @@ def is_common_value(a, b): return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute return False - + def filter_none_node(node): if node is not None and node.constant_result is None: return None return node - -class _YieldNodeCollector(Visitor.TreeVisitor): - """ - YieldExprNode finder for generator expressions. - """ - def __init__(self): - Visitor.TreeVisitor.__init__(self) - self.yield_stat_nodes = {} - self.yield_nodes = [] - - visit_Node = Visitor.TreeVisitor.visitchildren - - def visit_YieldExprNode(self, node): - self.yield_nodes.append(node) - self.visitchildren(node) - - def visit_ExprStatNode(self, node): - self.visitchildren(node) - if node.expr in self.yield_nodes: - self.yield_stat_nodes[node.expr] = node - - # everything below these nodes is out of scope: - - def visit_GeneratorExpressionNode(self, node): - pass - - def visit_LambdaNode(self, node): - pass - - def visit_FuncDefNode(self, node): - pass - - -def _find_single_yield_expression(node): - yield_statements = _find_yield_statements(node) - if len(yield_statements) != 1: - return None, None - return yield_statements[0] - - -def _find_yield_statements(node): - collector = _YieldNodeCollector() - collector.visitchildren(node) - try: - yield_statements = [ - (yield_node.arg, collector.yield_stat_nodes[yield_node]) - for yield_node in collector.yield_nodes - ] - except KeyError: - # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield') - yield_statements = [] - return yield_statements - - + +class _YieldNodeCollector(Visitor.TreeVisitor): + """ + YieldExprNode finder for generator expressions. + """ + def __init__(self): + Visitor.TreeVisitor.__init__(self) + self.yield_stat_nodes = {} + self.yield_nodes = [] + + visit_Node = Visitor.TreeVisitor.visitchildren + + def visit_YieldExprNode(self, node): + self.yield_nodes.append(node) + self.visitchildren(node) + + def visit_ExprStatNode(self, node): + self.visitchildren(node) + if node.expr in self.yield_nodes: + self.yield_stat_nodes[node.expr] = node + + # everything below these nodes is out of scope: + + def visit_GeneratorExpressionNode(self, node): + pass + + def visit_LambdaNode(self, node): + pass + + def visit_FuncDefNode(self, node): + pass + + +def _find_single_yield_expression(node): + yield_statements = _find_yield_statements(node) + if len(yield_statements) != 1: + return None, None + return yield_statements[0] + + +def _find_yield_statements(node): + collector = _YieldNodeCollector() + collector.visitchildren(node) + try: + yield_statements = [ + (yield_node.arg, collector.yield_stat_nodes[yield_node]) + for yield_node in collector.yield_nodes + ] + except KeyError: + # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield') + yield_statements = [] + return yield_statements + + class IterationTransform(Visitor.EnvTransform): """Transform some common for-in loop patterns into efficient C loops: @@ -148,7 +148,7 @@ class IterationTransform(Visitor.EnvTransform): pos = node.pos result_ref = UtilNodes.ResultRefNode(node) - if node.operand2.is_subscript: + if node.operand2.is_subscript: base_type = node.operand2.base.type.base_type else: base_type = node.operand2.type.base_type @@ -250,7 +250,7 @@ class IterationTransform(Visitor.EnvTransform): if not is_safe_iter and method in ('keys', 'values', 'items'): # try to reduce this to the corresponding .iter*() methods - if isinstance(base_obj, ExprNodes.CallNode): + if isinstance(base_obj, ExprNodes.CallNode): inner_function = base_obj.function if (inner_function.is_name and inner_function.name == 'dict' and inner_function.entry @@ -391,7 +391,7 @@ class IterationTransform(Visitor.EnvTransform): if slice_node.is_literal: # try to reduce to byte iteration for plain Latin-1 strings try: - bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1') + bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1') except UnicodeEncodeError: pass else: @@ -400,8 +400,8 @@ class IterationTransform(Visitor.EnvTransform): base=ExprNodes.BytesNode( slice_node.pos, value=bytes_value, constant_result=bytes_value, - type=PyrexTypes.c_const_char_ptr_type).coerce_to( - PyrexTypes.c_const_uchar_ptr_type, self.current_env()), + type=PyrexTypes.c_const_char_ptr_type).coerce_to( + PyrexTypes.c_const_uchar_ptr_type, self.current_env()), start=None, stop=ExprNodes.IntNode( slice_node.pos, value=str(len(bytes_value)), @@ -491,7 +491,7 @@ class IterationTransform(Visitor.EnvTransform): error(slice_node.pos, "C array iteration requires known end index") return node - elif slice_node.is_subscript: + elif slice_node.is_subscript: assert isinstance(slice_node.index, ExprNodes.SliceNode) slice_base = slice_node.base index = slice_node.index @@ -499,7 +499,7 @@ class IterationTransform(Visitor.EnvTransform): stop = filter_none_node(index.stop) step = filter_none_node(index.step) if step: - if not isinstance(step.constant_result, _py_int_types) \ + if not isinstance(step.constant_result, _py_int_types) \ or step.constant_result == 0 \ or step.constant_result > 0 and not stop \ or step.constant_result < 0 and not start: @@ -733,19 +733,19 @@ class IterationTransform(Visitor.EnvTransform): if len(args) < 3: step_pos = range_function.pos step_value = 1 - step = ExprNodes.IntNode(step_pos, value='1', constant_result=1) + step = ExprNodes.IntNode(step_pos, value='1', constant_result=1) else: step = args[2] step_pos = step.pos - if not isinstance(step.constant_result, _py_int_types): + if not isinstance(step.constant_result, _py_int_types): # cannot determine step direction return node step_value = step.constant_result if step_value == 0: # will lead to an error elsewhere return node - step = ExprNodes.IntNode(step_pos, value=str(step_value), - constant_result=step_value) + step = ExprNodes.IntNode(step_pos, value=str(step_value), + constant_result=step_value) if len(args) == 1: bound1 = ExprNodes.IntNode(range_function.pos, value='0', @@ -757,34 +757,34 @@ class IterationTransform(Visitor.EnvTransform): relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed) - bound2_ref_node = None + bound2_ref_node = None if reversed: bound1, bound2 = bound2, bound1 - abs_step = abs(step_value) - if abs_step != 1: - if (isinstance(bound1.constant_result, _py_int_types) and - isinstance(bound2.constant_result, _py_int_types)): - # calculate final bounds now - if step_value < 0: - begin_value = bound2.constant_result - end_value = bound1.constant_result - bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1 - else: - begin_value = bound1.constant_result - end_value = bound2.constant_result - bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1 - - bound1 = ExprNodes.IntNode( - bound1.pos, value=str(bound1_value), constant_result=bound1_value, - type=PyrexTypes.spanning_type(bound1.type, bound2.type)) - else: - # evaluate the same expression as above at runtime - bound2_ref_node = UtilNodes.LetRefNode(bound2) - bound1 = self._build_range_step_calculation( - bound1, bound2_ref_node, step, step_value) - - if step_value < 0: - step_value = -step_value + abs_step = abs(step_value) + if abs_step != 1: + if (isinstance(bound1.constant_result, _py_int_types) and + isinstance(bound2.constant_result, _py_int_types)): + # calculate final bounds now + if step_value < 0: + begin_value = bound2.constant_result + end_value = bound1.constant_result + bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1 + else: + begin_value = bound1.constant_result + end_value = bound2.constant_result + bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1 + + bound1 = ExprNodes.IntNode( + bound1.pos, value=str(bound1_value), constant_result=bound1_value, + type=PyrexTypes.spanning_type(bound1.type, bound2.type)) + else: + # evaluate the same expression as above at runtime + bound2_ref_node = UtilNodes.LetRefNode(bound2) + bound1 = self._build_range_step_calculation( + bound1, bound2_ref_node, step, step_value) + + if step_value < 0: + step_value = -step_value step.value = str(step_value) step.constant_result = step_value step = step.coerce_to_integer(self.current_env()) @@ -792,7 +792,7 @@ class IterationTransform(Visitor.EnvTransform): if not bound2.is_literal: # stop bound must be immutable => keep it in a temp var bound2_is_temp = True - bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2) + bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2) else: bound2_is_temp = False @@ -811,70 +811,70 @@ class IterationTransform(Visitor.EnvTransform): return for_node - def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value): - abs_step = abs(step_value) - spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type) - if step.type.is_int and abs_step < 0x7FFF: - # Avoid loss of integer precision warnings. - spanning_step_type = PyrexTypes.spanning_type(spanning_type, PyrexTypes.c_int_type) - else: - spanning_step_type = PyrexTypes.spanning_type(spanning_type, step.type) - if step_value < 0: - begin_value = bound2_ref_node - end_value = bound1 - final_op = '-' - else: - begin_value = bound1 - end_value = bound2_ref_node - final_op = '+' - - step_calculation_node = ExprNodes.binop_node( - bound1.pos, - operand1=ExprNodes.binop_node( - bound1.pos, - operand1=bound2_ref_node, - operator=final_op, # +/- - operand2=ExprNodes.MulNode( - bound1.pos, - operand1=ExprNodes.IntNode( - bound1.pos, - value=str(abs_step), - constant_result=abs_step, - type=spanning_step_type), - operator='*', - operand2=ExprNodes.DivNode( - bound1.pos, - operand1=ExprNodes.SubNode( - bound1.pos, - operand1=ExprNodes.SubNode( - bound1.pos, - operand1=begin_value, - operator='-', - operand2=end_value, - type=spanning_type), - operator='-', - operand2=ExprNodes.IntNode( - bound1.pos, - value='1', - constant_result=1), - type=spanning_step_type), - operator='//', - operand2=ExprNodes.IntNode( - bound1.pos, - value=str(abs_step), - constant_result=abs_step, - type=spanning_step_type), - type=spanning_step_type), - type=spanning_step_type), - type=spanning_step_type), - operator=final_op, # +/- - operand2=ExprNodes.IntNode( - bound1.pos, - value='1', - constant_result=1), - type=spanning_type) - return step_calculation_node - + def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value): + abs_step = abs(step_value) + spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type) + if step.type.is_int and abs_step < 0x7FFF: + # Avoid loss of integer precision warnings. + spanning_step_type = PyrexTypes.spanning_type(spanning_type, PyrexTypes.c_int_type) + else: + spanning_step_type = PyrexTypes.spanning_type(spanning_type, step.type) + if step_value < 0: + begin_value = bound2_ref_node + end_value = bound1 + final_op = '-' + else: + begin_value = bound1 + end_value = bound2_ref_node + final_op = '+' + + step_calculation_node = ExprNodes.binop_node( + bound1.pos, + operand1=ExprNodes.binop_node( + bound1.pos, + operand1=bound2_ref_node, + operator=final_op, # +/- + operand2=ExprNodes.MulNode( + bound1.pos, + operand1=ExprNodes.IntNode( + bound1.pos, + value=str(abs_step), + constant_result=abs_step, + type=spanning_step_type), + operator='*', + operand2=ExprNodes.DivNode( + bound1.pos, + operand1=ExprNodes.SubNode( + bound1.pos, + operand1=ExprNodes.SubNode( + bound1.pos, + operand1=begin_value, + operator='-', + operand2=end_value, + type=spanning_type), + operator='-', + operand2=ExprNodes.IntNode( + bound1.pos, + value='1', + constant_result=1), + type=spanning_step_type), + operator='//', + operand2=ExprNodes.IntNode( + bound1.pos, + value=str(abs_step), + constant_result=abs_step, + type=spanning_step_type), + type=spanning_step_type), + type=spanning_step_type), + type=spanning_step_type), + operator=final_op, # +/- + operand2=ExprNodes.IntNode( + bound1.pos, + value='1', + constant_result=1), + type=spanning_type) + return step_calculation_node + def _transform_dict_iteration(self, node, dict_obj, method, keys, values): temps = [] temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) @@ -1192,9 +1192,9 @@ class SwitchTransform(Visitor.EnvTransform): if common_var is None: self.visitchildren(node) return node - cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos, - conditions=conditions, - body=if_clause.body)) + cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos, + conditions=conditions, + body=if_clause.body)) condition_values = [ cond for case in cases for cond in case.conditions] @@ -1205,16 +1205,16 @@ class SwitchTransform(Visitor.EnvTransform): self.visitchildren(node) return node - # Recurse into body subtrees that we left untouched so far. - self.visitchildren(node, 'else_clause') - for case in cases: - self.visitchildren(case, 'body') - + # Recurse into body subtrees that we left untouched so far. + self.visitchildren(node, 'else_clause') + for case in cases: + self.visitchildren(case, 'body') + common_var = unwrap_node(common_var) - switch_node = Nodes.SwitchStatNode(pos=node.pos, - test=common_var, - cases=cases, - else_clause=node.else_clause) + switch_node = Nodes.SwitchStatNode(pos=node.pos, + test=common_var, + cases=cases, + else_clause=node.else_clause) return switch_node def visit_CondExprNode(self, node): @@ -1225,11 +1225,11 @@ class SwitchTransform(Visitor.EnvTransform): not_in, common_var, conditions = self.extract_common_conditions( None, node.test, True) if common_var is None \ - or len(conditions) < 2 \ - or self.has_duplicate_values(conditions): + or len(conditions) < 2 \ + or self.has_duplicate_values(conditions): self.visitchildren(node) return node - + return self.build_simple_switch_statement( node, common_var, conditions, not_in, node.true_val, node.false_val) @@ -1242,8 +1242,8 @@ class SwitchTransform(Visitor.EnvTransform): not_in, common_var, conditions = self.extract_common_conditions( None, node, True) if common_var is None \ - or len(conditions) < 2 \ - or self.has_duplicate_values(conditions): + or len(conditions) < 2 \ + or self.has_duplicate_values(conditions): self.visitchildren(node) node.wrap_operands(self.current_env()) # in case we changed the operands return node @@ -1261,8 +1261,8 @@ class SwitchTransform(Visitor.EnvTransform): not_in, common_var, conditions = self.extract_common_conditions( None, node, True) if common_var is None \ - or len(conditions) < 2 \ - or self.has_duplicate_values(conditions): + or len(conditions) < 2 \ + or self.has_duplicate_values(conditions): self.visitchildren(node) return node @@ -1477,20 +1477,20 @@ class DropRefcountingTransform(Visitor.VisitorTransform): node = node.arg name_path = [] obj_node = node - while obj_node.is_attribute: + while obj_node.is_attribute: if obj_node.is_py_attr: return False name_path.append(obj_node.member) obj_node = obj_node.obj - if obj_node.is_name: + if obj_node.is_name: name_path.append(obj_node.name) names.append( ('.'.join(name_path[::-1]), node) ) - elif node.is_subscript: + elif node.is_subscript: if node.base.type != Builtin.list_type: return False if not node.index.type.is_int: return False - if not node.base.is_name: + if not node.base.is_name: return False indices.append(node) else: @@ -1618,60 +1618,60 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): stop=stop, step=step or ExprNodes.NoneNode(node.pos)) - def _handle_simple_function_ord(self, node, pos_args): - """Unpack ord('X'). - """ - if len(pos_args) != 1: - return node - arg = pos_args[0] - if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)): - if len(arg.value) == 1: - return ExprNodes.IntNode( - arg.pos, type=PyrexTypes.c_long_type, - value=str(ord(arg.value)), - constant_result=ord(arg.value) - ) - elif isinstance(arg, ExprNodes.StringNode): - if arg.unicode_value and len(arg.unicode_value) == 1 \ - and ord(arg.unicode_value) <= 255: # Py2/3 portability - return ExprNodes.IntNode( - arg.pos, type=PyrexTypes.c_int_type, - value=str(ord(arg.unicode_value)), - constant_result=ord(arg.unicode_value) - ) - return node - - # sequence processing + def _handle_simple_function_ord(self, node, pos_args): + """Unpack ord('X'). + """ + if len(pos_args) != 1: + return node + arg = pos_args[0] + if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)): + if len(arg.value) == 1: + return ExprNodes.IntNode( + arg.pos, type=PyrexTypes.c_long_type, + value=str(ord(arg.value)), + constant_result=ord(arg.value) + ) + elif isinstance(arg, ExprNodes.StringNode): + if arg.unicode_value and len(arg.unicode_value) == 1 \ + and ord(arg.unicode_value) <= 255: # Py2/3 portability + return ExprNodes.IntNode( + arg.pos, type=PyrexTypes.c_int_type, + value=str(ord(arg.unicode_value)), + constant_result=ord(arg.unicode_value) + ) + return node + + # sequence processing def _handle_simple_function_all(self, node, pos_args): """Transform - _result = all(p(x) for L in LL for x in L) + _result = all(p(x) for L in LL for x in L) into for L in LL: for x in L: - if not p(x): - return False + if not p(x): + return False else: - return True + return True """ return self._transform_any_all(node, pos_args, False) def _handle_simple_function_any(self, node, pos_args): """Transform - _result = any(p(x) for L in LL for x in L) + _result = any(p(x) for L in LL for x in L) into for L in LL: for x in L: - if p(x): - return True + if p(x): + return True else: - return False + return False """ return self._transform_any_all(node, pos_args, True) @@ -1681,40 +1681,40 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): return node gen_expr_node = pos_args[0] - generator_body = gen_expr_node.def_node.gbody - loop_node = generator_body.body - yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) + generator_body = gen_expr_node.def_node.gbody + loop_node = generator_body.body + yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) if yield_expression is None: return node if is_any: condition = yield_expression else: - condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression) + condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression) test_node = Nodes.IfStatNode( - yield_expression.pos, else_clause=None, if_clauses=[ - Nodes.IfClauseNode( - yield_expression.pos, - condition=condition, - body=Nodes.ReturnStatNode( - node.pos, - value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any)) - )] - ) - loop_node.else_clause = Nodes.ReturnStatNode( + yield_expression.pos, else_clause=None, if_clauses=[ + Nodes.IfClauseNode( + yield_expression.pos, + condition=condition, + body=Nodes.ReturnStatNode( + node.pos, + value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any)) + )] + ) + loop_node.else_clause = Nodes.ReturnStatNode( node.pos, - value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any)) + value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any)) - Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node) + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node) return ExprNodes.InlinedGeneratorExpressionNode( - gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all') + gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all') + + PySequence_List_func_type = PyrexTypes.CFuncType( + Builtin.list_type, + [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)]) - PySequence_List_func_type = PyrexTypes.CFuncType( - Builtin.list_type, - [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)]) - def _handle_simple_function_sorted(self, node, pos_args): """Transform sorted(genexpr) and sorted([listcomp]) into [listcomp].sort(). CPython just reads the iterable into a @@ -1724,62 +1724,62 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): """ if len(pos_args) != 1: return node - - arg = pos_args[0] - if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type: - list_node = pos_args[0] - loop_node = list_node.loop - - elif isinstance(arg, ExprNodes.GeneratorExpressionNode): - gen_expr_node = arg + + arg = pos_args[0] + if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type: + list_node = pos_args[0] + loop_node = list_node.loop + + elif isinstance(arg, ExprNodes.GeneratorExpressionNode): + gen_expr_node = arg loop_node = gen_expr_node.loop - yield_statements = _find_yield_statements(loop_node) - if not yield_statements: + yield_statements = _find_yield_statements(loop_node) + if not yield_statements: return node - list_node = ExprNodes.InlinedGeneratorExpressionNode( - node.pos, gen_expr_node, orig_func='sorted', - comprehension_type=Builtin.list_type) - - for yield_expression, yield_stat_node in yield_statements: - append_node = ExprNodes.ComprehensionAppendNode( - yield_expression.pos, - expr=yield_expression, - target=list_node.target) - Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) - - elif arg.is_sequence_constructor: - # sorted([a, b, c]) or sorted((a, b, c)). The result is always a list, - # so starting off with a fresh one is more efficient. - list_node = loop_node = arg.as_list() - + list_node = ExprNodes.InlinedGeneratorExpressionNode( + node.pos, gen_expr_node, orig_func='sorted', + comprehension_type=Builtin.list_type) + + for yield_expression, yield_stat_node in yield_statements: + append_node = ExprNodes.ComprehensionAppendNode( + yield_expression.pos, + expr=yield_expression, + target=list_node.target) + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) + + elif arg.is_sequence_constructor: + # sorted([a, b, c]) or sorted((a, b, c)). The result is always a list, + # so starting off with a fresh one is more efficient. + list_node = loop_node = arg.as_list() + else: - # Interestingly, PySequence_List works on a lot of non-sequence - # things as well. - list_node = loop_node = ExprNodes.PythonCapiCallNode( - node.pos, "PySequence_List", self.PySequence_List_func_type, - args=pos_args, is_temp=True) + # Interestingly, PySequence_List works on a lot of non-sequence + # things as well. + list_node = loop_node = ExprNodes.PythonCapiCallNode( + node.pos, "PySequence_List", self.PySequence_List_func_type, + args=pos_args, is_temp=True) result_node = UtilNodes.ResultRefNode( - pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False) - list_assign_node = Nodes.SingleAssignmentNode( - node.pos, lhs=result_node, rhs=list_node, first=True) + pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False) + list_assign_node = Nodes.SingleAssignmentNode( + node.pos, lhs=result_node, rhs=list_node, first=True) sort_method = ExprNodes.AttributeNode( - node.pos, obj=result_node, attribute=EncodedString('sort'), + node.pos, obj=result_node, attribute=EncodedString('sort'), # entry ? type ? - needs_none_check=False) + needs_none_check=False) sort_node = Nodes.ExprStatNode( - node.pos, expr=ExprNodes.SimpleCallNode( - node.pos, function=sort_method, args=[])) + node.pos, expr=ExprNodes.SimpleCallNode( + node.pos, function=sort_method, args=[])) sort_node.analyse_declarations(self.current_env()) return UtilNodes.TempResultFromStatNode( result_node, - Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node])) + Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node])) - def __handle_simple_function_sum(self, node, pos_args): + def __handle_simple_function_sum(self, node, pos_args): """Transform sum(genexpr) into an equivalent inlined aggregation loop. """ if len(pos_args) not in (1,2): @@ -1791,12 +1791,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): loop_node = gen_expr_node.loop if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode): - yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) - # FIXME: currently nonfunctional - yield_expression = None + yield_expression, yield_stat_node = _find_single_yield_expression(loop_node) + # FIXME: currently nonfunctional + yield_expression = None if yield_expression is None: return node - else: # ComprehensionNode + else: # ComprehensionNode yield_stat_node = gen_expr_node.append yield_expression = yield_stat_node.expr try: @@ -1819,7 +1819,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) ) - Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node) + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node) exec_code = Nodes.StatListNode( node.pos, @@ -1849,7 +1849,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if len(args) <= 1: if len(args) == 1 and args[0].is_sequence_constructor: args = args[0].args - if len(args) <= 1: + if len(args) <= 1: # leave this to Python return node @@ -1876,8 +1876,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): return last_result - # builtin type creation - + # builtin type creation + def _DISABLED_handle_simple_function_tuple(self, node, pos_args): if not pos_args: return ExprNodes.TupleNode(node.pos, args=[], constant_result=()) @@ -1915,7 +1915,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type) def _transform_list_set_genexpr(self, node, pos_args, target_type): - """Replace set(genexpr) and list(genexpr) by an inlined comprehension. + """Replace set(genexpr) and list(genexpr) by an inlined comprehension. """ if len(pos_args) > 1: return node @@ -1924,26 +1924,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node = pos_args[0] loop_node = gen_expr_node.loop - yield_statements = _find_yield_statements(loop_node) - if not yield_statements: + yield_statements = _find_yield_statements(loop_node) + if not yield_statements: return node - result_node = ExprNodes.InlinedGeneratorExpressionNode( - node.pos, gen_expr_node, - orig_func='set' if target_type is Builtin.set_type else 'list', - comprehension_type=target_type) - - for yield_expression, yield_stat_node in yield_statements: - append_node = ExprNodes.ComprehensionAppendNode( - yield_expression.pos, - expr=yield_expression, - target=result_node.target) - Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) + result_node = ExprNodes.InlinedGeneratorExpressionNode( + node.pos, gen_expr_node, + orig_func='set' if target_type is Builtin.set_type else 'list', + comprehension_type=target_type) - return result_node + for yield_expression, yield_stat_node in yield_statements: + append_node = ExprNodes.ComprehensionAppendNode( + yield_expression.pos, + expr=yield_expression, + target=result_node.target) + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) + + return result_node def _handle_simple_function_dict(self, node, pos_args): - """Replace dict( (a,b) for ... ) by an inlined { a:b for ... } + """Replace dict( (a,b) for ... ) by an inlined { a:b for ... } """ if len(pos_args) == 0: return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={}) @@ -1954,29 +1954,29 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node = pos_args[0] loop_node = gen_expr_node.loop - yield_statements = _find_yield_statements(loop_node) - if not yield_statements: + yield_statements = _find_yield_statements(loop_node) + if not yield_statements: return node - for yield_expression, _ in yield_statements: - if not isinstance(yield_expression, ExprNodes.TupleNode): - return node - if len(yield_expression.args) != 2: - return node + for yield_expression, _ in yield_statements: + if not isinstance(yield_expression, ExprNodes.TupleNode): + return node + if len(yield_expression.args) != 2: + return node + + result_node = ExprNodes.InlinedGeneratorExpressionNode( + node.pos, gen_expr_node, orig_func='dict', + comprehension_type=Builtin.dict_type) - result_node = ExprNodes.InlinedGeneratorExpressionNode( - node.pos, gen_expr_node, orig_func='dict', - comprehension_type=Builtin.dict_type) - - for yield_expression, yield_stat_node in yield_statements: - append_node = ExprNodes.DictComprehensionAppendNode( - yield_expression.pos, - key_expr=yield_expression.args[0], - value_expr=yield_expression.args[1], - target=result_node.target) - Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) + for yield_expression, yield_stat_node in yield_statements: + append_node = ExprNodes.DictComprehensionAppendNode( + yield_expression.pos, + key_expr=yield_expression.args[0], + value_expr=yield_expression.args[1], + target=result_node.target) + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) - return result_node + return result_node # specific handlers for general call nodes @@ -2024,8 +2024,8 @@ class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform): return node -class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, - Visitor.MethodDispatcherTransform): +class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, + Visitor.MethodDispatcherTransform): """Optimize some common methods calls and instantiation patterns for builtin types *after* the type analysis phase. @@ -2080,33 +2080,33 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return arg.arg.coerce_to_boolean(self.current_env()) return node - PyNumber_Float_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None) - ]) - - def visit_CoerceToPyTypeNode(self, node): - """Drop redundant conversion nodes after tree changes.""" - self.visitchildren(node) - arg = node.arg - if isinstance(arg, ExprNodes.CoerceFromPyTypeNode): - arg = arg.arg - if isinstance(arg, ExprNodes.PythonCapiCallNode): - if arg.function.name == 'float' and len(arg.args) == 1: - # undo redundant Py->C->Py coercion - func_arg = arg.args[0] - if func_arg.type is Builtin.float_type: - return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'") - elif func_arg.type.is_pyobject: - return ExprNodes.PythonCapiCallNode( - node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type, - args=[func_arg], - py_name='float', - is_temp=node.is_temp, - result_is_used=node.result_is_used, - ).coerce_to(node.type, self.current_env()) - return node - + PyNumber_Float_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None) + ]) + + def visit_CoerceToPyTypeNode(self, node): + """Drop redundant conversion nodes after tree changes.""" + self.visitchildren(node) + arg = node.arg + if isinstance(arg, ExprNodes.CoerceFromPyTypeNode): + arg = arg.arg + if isinstance(arg, ExprNodes.PythonCapiCallNode): + if arg.function.name == 'float' and len(arg.args) == 1: + # undo redundant Py->C->Py coercion + func_arg = arg.args[0] + if func_arg.type is Builtin.float_type: + return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'") + elif func_arg.type.is_pyobject: + return ExprNodes.PythonCapiCallNode( + node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type, + args=[func_arg], + py_name='float', + is_temp=node.is_temp, + result_is_used=node.result_is_used, + ).coerce_to(node.type, self.current_env()) + return node + def visit_CoerceFromPyTypeNode(self, node): """Drop redundant conversion nodes after tree changes. @@ -2118,9 +2118,9 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, arg = node.arg if not arg.type.is_pyobject: # no Python conversion left at all, just do a C coercion instead - if node.type != arg.type: - arg = arg.coerce_to(node.type, self.current_env()) - return arg + if node.type != arg.type: + arg = arg.coerce_to(node.type, self.current_env()) + return arg if isinstance(arg, ExprNodes.PyTypeTestNode): arg = arg.arg if arg.is_literal: @@ -2133,13 +2133,13 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if node.type.assignable_from(arg.arg.type): # completely redundant C->Py->C coercion return arg.arg.coerce_to(node.type, self.current_env()) - elif arg.type is Builtin.unicode_type: - if arg.arg.type.is_unicode_char and node.type.is_unicode_char: - return arg.arg.coerce_to(node.type, self.current_env()) + elif arg.type is Builtin.unicode_type: + if arg.arg.type.is_unicode_char and node.type.is_unicode_char: + return arg.arg.coerce_to(node.type, self.current_env()) elif isinstance(arg, ExprNodes.SimpleCallNode): if node.type.is_int or node.type.is_float: return self._optimise_numeric_cast_call(node, arg) - elif arg.is_subscript: + elif arg.is_subscript: index_node = arg.index if isinstance(index_node, ExprNodes.CoerceToPyTypeNode): index_node = index_node.arg @@ -2181,51 +2181,51 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return node return coerce_node - float_float_func_types = dict( - (float_type, PyrexTypes.CFuncType( - float_type, [ - PyrexTypes.CFuncTypeArg("arg", float_type, None) - ])) - for float_type in (PyrexTypes.c_float_type, PyrexTypes.c_double_type, PyrexTypes.c_longdouble_type)) - + float_float_func_types = dict( + (float_type, PyrexTypes.CFuncType( + float_type, [ + PyrexTypes.CFuncTypeArg("arg", float_type, None) + ])) + for float_type in (PyrexTypes.c_float_type, PyrexTypes.c_double_type, PyrexTypes.c_longdouble_type)) + def _optimise_numeric_cast_call(self, node, arg): function = arg.function - args = None - if isinstance(arg, ExprNodes.PythonCapiCallNode): - args = arg.args - elif isinstance(function, ExprNodes.NameNode): - if function.type.is_builtin_type and isinstance(arg.arg_tuple, ExprNodes.TupleNode): - args = arg.arg_tuple.args - - if args is None or len(args) != 1: + args = None + if isinstance(arg, ExprNodes.PythonCapiCallNode): + args = arg.args + elif isinstance(function, ExprNodes.NameNode): + if function.type.is_builtin_type and isinstance(arg.arg_tuple, ExprNodes.TupleNode): + args = arg.arg_tuple.args + + if args is None or len(args) != 1: return node func_arg = args[0] if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): func_arg = func_arg.arg elif func_arg.type.is_pyobject: - # play it safe: Python conversion might work on all sorts of things + # play it safe: Python conversion might work on all sorts of things return node - + if function.name == 'int': if func_arg.type.is_int or node.type.is_int: if func_arg.type == node.type: return func_arg elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: - return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type) - elif func_arg.type.is_float and node.type.is_numeric: - if func_arg.type.math_h_modifier == 'l': - # Work around missing Cygwin definition. - truncl = '__Pyx_truncl' - else: - truncl = 'trunc' + func_arg.type.math_h_modifier - return ExprNodes.PythonCapiCallNode( - node.pos, truncl, - func_type=self.float_float_func_types[func_arg.type], - args=[func_arg], - py_name='int', - is_temp=node.is_temp, - result_is_used=node.result_is_used, - ).coerce_to(node.type, self.current_env()) + return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type) + elif func_arg.type.is_float and node.type.is_numeric: + if func_arg.type.math_h_modifier == 'l': + # Work around missing Cygwin definition. + truncl = '__Pyx_truncl' + else: + truncl = 'trunc' + func_arg.type.math_h_modifier + return ExprNodes.PythonCapiCallNode( + node.pos, truncl, + func_type=self.float_float_func_types[func_arg.type], + args=[func_arg], + py_name='int', + is_temp=node.is_temp, + result_is_used=node.result_is_used, + ).coerce_to(node.type, self.current_env()) elif function.name == 'float': if func_arg.type.is_float or node.type.is_float: if func_arg.type == node.type: @@ -2281,7 +2281,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, entry=type_entry, type=type_entry.type), attribute=attr_name, - is_called=True).analyse_as_type_attribute(self.current_env()) + is_called=True).analyse_as_type_attribute(self.current_env()) if method is None: return self._optimise_generic_builtin_method_call( node, attr_name, function, arg_list, is_unbound_method) @@ -2376,41 +2376,41 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ) return node - PySequence_List_func_type = PyrexTypes.CFuncType( - Builtin.list_type, - [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)]) - - def _handle_simple_function_list(self, node, function, pos_args): - """Turn list(ob) into PySequence_List(ob). - """ - if len(pos_args) != 1: - return node - arg = pos_args[0] - return ExprNodes.PythonCapiCallNode( - node.pos, "PySequence_List", self.PySequence_List_func_type, - args=pos_args, is_temp=node.is_temp) - + PySequence_List_func_type = PyrexTypes.CFuncType( + Builtin.list_type, + [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)]) + + def _handle_simple_function_list(self, node, function, pos_args): + """Turn list(ob) into PySequence_List(ob). + """ + if len(pos_args) != 1: + return node + arg = pos_args[0] + return ExprNodes.PythonCapiCallNode( + node.pos, "PySequence_List", self.PySequence_List_func_type, + args=pos_args, is_temp=node.is_temp) + PyList_AsTuple_func_type = PyrexTypes.CFuncType( Builtin.tuple_type, [ PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None) ]) def _handle_simple_function_tuple(self, node, function, pos_args): - """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple. + """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple. """ if len(pos_args) != 1 or not node.is_temp: return node arg = pos_args[0] if arg.type is Builtin.tuple_type and not arg.may_be_none(): return arg - if arg.type is Builtin.list_type: - pos_args[0] = arg.as_none_safe_node( - "'NoneType' object is not iterable") - - return ExprNodes.PythonCapiCallNode( - node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type, - args=pos_args, is_temp=node.is_temp) - else: + if arg.type is Builtin.list_type: + pos_args[0] = arg.as_none_safe_node( + "'NoneType' object is not iterable") + + return ExprNodes.PythonCapiCallNode( + node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type, + args=pos_args, is_temp=node.is_temp) + else: return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type) PySet_New_func_type = PyrexTypes.CFuncType( @@ -2435,18 +2435,18 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, temps.append(arg) args.append(arg) result = ExprNodes.SetNode(node.pos, is_temp=1, args=args) - self.replace(node, result) + self.replace(node, result) for temp in temps[::-1]: result = UtilNodes.EvalWithTempExprNode(temp, result) return result else: # PySet_New(it) is better than a generic Python call to set(it) - return self.replace(node, ExprNodes.PythonCapiCallNode( + return self.replace(node, ExprNodes.PythonCapiCallNode( node.pos, "PySet_New", self.PySet_New_func_type, args=pos_args, is_temp=node.is_temp, - py_name="set")) + py_name="set")) PyFrozenSet_New_func_type = PyrexTypes.CFuncType( Builtin.frozenset_type, [ @@ -2510,11 +2510,11 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None) ]) - PyInt_FromDouble_func_type = PyrexTypes.CFuncType( - PyrexTypes.py_object_type, [ - PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None) - ]) - + PyInt_FromDouble_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None) + ]) + def _handle_simple_function_int(self, node, function, pos_args): """Transform int() into a faster C function call. """ @@ -2525,17 +2525,17 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return node # int(x, base) func_arg = pos_args[0] if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode): - if func_arg.arg.type.is_float: - return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type, - args=[func_arg.arg], is_temp=True, py_name='int', - utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c")) - else: - return node # handled in visit_CoerceFromPyTypeNode() + if func_arg.arg.type.is_float: + return ExprNodes.PythonCapiCallNode( + node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type, + args=[func_arg.arg], is_temp=True, py_name='int', + utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c")) + else: + return node # handled in visit_CoerceFromPyTypeNode() if func_arg.type.is_pyobject and node.type.is_pyobject: return ExprNodes.PythonCapiCallNode( - node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type, - args=pos_args, is_temp=True, py_name='int') + node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type, + args=pos_args, is_temp=True, py_name='int') return node def _handle_simple_function_bool(self, node, function, pos_args): @@ -2560,30 +2560,30 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, Pyx_strlen_func_type = PyrexTypes.CFuncType( PyrexTypes.c_size_t_type, [ - PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None) - ]) + PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None) + ]) Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType( PyrexTypes.c_size_t_type, [ - PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None) - ]) + PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None) + ]) PyObject_Size_func_type = PyrexTypes.CFuncType( PyrexTypes.c_py_ssize_t_type, [ PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None) - ], + ], exception_value="-1") _map_to_capi_len_function = { - Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH", - Builtin.bytes_type: "PyBytes_GET_SIZE", + Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH", + Builtin.bytes_type: "PyBytes_GET_SIZE", Builtin.bytearray_type: 'PyByteArray_GET_SIZE', - Builtin.list_type: "PyList_GET_SIZE", - Builtin.tuple_type: "PyTuple_GET_SIZE", - Builtin.set_type: "PySet_GET_SIZE", - Builtin.frozenset_type: "PySet_GET_SIZE", - Builtin.dict_type: "PyDict_Size", - }.get + Builtin.list_type: "PyList_GET_SIZE", + Builtin.tuple_type: "PyTuple_GET_SIZE", + Builtin.set_type: "PySet_GET_SIZE", + Builtin.frozenset_type: "PySet_GET_SIZE", + Builtin.dict_type: "PyDict_Size", + }.get _ext_types_with_pysize = set(["cpython.array.array"]) @@ -2668,14 +2668,14 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if len(pos_args) != 2: return node arg, types = pos_args - temps = [] + temps = [] if isinstance(types, ExprNodes.TupleNode): types = types.args - if len(types) == 1 and not types[0].type is Builtin.type_type: - return node # nothing to improve here + if len(types) == 1 and not types[0].type is Builtin.type_type: + return node # nothing to improve here if arg.is_attribute or not arg.is_simple(): - arg = UtilNodes.ResultRefNode(arg) - temps.append(arg) + arg = UtilNodes.ResultRefNode(arg) + temps.append(arg) elif types.type is Builtin.type_type: types = [types] else: @@ -2706,17 +2706,17 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, type_check_function = '__Pyx_TypeCheck' type_check_args = [arg, test_type_node] else: - if not test_type_node.is_literal: - test_type_node = UtilNodes.ResultRefNode(test_type_node) - temps.append(test_type_node) - type_check_function = 'PyObject_IsInstance' - type_check_args = [arg, test_type_node] + if not test_type_node.is_literal: + test_type_node = UtilNodes.ResultRefNode(test_type_node) + temps.append(test_type_node) + type_check_function = 'PyObject_IsInstance' + type_check_args = [arg, test_type_node] test_nodes.append( ExprNodes.PythonCapiCallNode( test_type_node.pos, type_check_function, self.Py_type_check_func_type, - args=type_check_args, - is_temp=True, - )) + args=type_check_args, + is_temp=True, + )) def join_with_or(a, b, make_binop_node=ExprNodes.binop_node): or_node = make_binop_node(node.pos, 'or', a, b) @@ -2725,7 +2725,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return or_node test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env) - for temp in temps[::-1]: + for temp in temps[::-1]: test_node = UtilNodes.EvalWithTempExprNode(temp, test_node) return test_node @@ -2738,7 +2738,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if isinstance(arg, ExprNodes.CoerceToPyTypeNode): if arg.arg.type.is_unicode_char: return ExprNodes.TypecastNode( - arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type + arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type ).coerce_to(node.type, self.current_env()) elif isinstance(arg, ExprNodes.UnicodeNode): if len(arg.value) == 1: @@ -2990,8 +2990,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, PyObject_PopIndex_func_type = PyrexTypes.CFuncType( PyrexTypes.py_object_type, [ PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None), ], has_varargs=True) # to fake the additional macro args that lack a proper C type @@ -3026,23 +3026,23 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ) elif len(args) == 2: index = unwrap_coerced_node(args[1]) - py_index = ExprNodes.NoneNode(index.pos) + py_index = ExprNodes.NoneNode(index.pos) orig_index_type = index.type if not index.type.is_int: - if isinstance(index, ExprNodes.IntNode): - py_index = index.coerce_to_pyobject(self.current_env()) + if isinstance(index, ExprNodes.IntNode): + py_index = index.coerce_to_pyobject(self.current_env()) + index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) + elif is_list: + if index.type.is_pyobject: + py_index = index.coerce_to_simple(self.current_env()) + index = ExprNodes.CloneNode(py_index) index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) - elif is_list: - if index.type.is_pyobject: - py_index = index.coerce_to_simple(self.current_env()) - index = ExprNodes.CloneNode(py_index) - index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env()) else: return node elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type): return node - elif isinstance(index, ExprNodes.IntNode): - py_index = index.coerce_to_pyobject(self.current_env()) + elif isinstance(index, ExprNodes.IntNode): + py_index = index.coerce_to_pyobject(self.current_env()) # real type might still be larger at runtime if not orig_index_type.is_int: orig_index_type = index.type @@ -3054,12 +3054,12 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, return ExprNodes.PythonCapiCallNode( node.pos, "__Pyx_Py%s_PopIndex" % type_name, self.PyObject_PopIndex_func_type, - args=[obj, py_index, index, + args=[obj, py_index, index, ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0), constant_result=orig_index_type.signed and 1 or 0, type=PyrexTypes.c_int_type), ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type, - orig_index_type.empty_declaration_code()), + orig_index_type.empty_declaration_code()), ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)], may_return_none=True, is_temp=node.is_temp, @@ -3163,184 +3163,184 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, may_return_none=True, utility_code=load_c_utility('py_dict_pop')) - Pyx_BinopInt_func_types = dict( - ((ctype, ret_type), PyrexTypes.CFuncType( - ret_type, [ - PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None), - PyrexTypes.CFuncTypeArg("cval", ctype, None), - PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None), - PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None), - ], exception_value=None if ret_type.is_pyobject else ret_type.exception_value)) - for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type) - for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type) - ) - - def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Add', node, function, args, is_unbound_method) - - def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) - - def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) - - def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Ne', node, function, args, is_unbound_method) - - def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('And', node, function, args, is_unbound_method) - - def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Or', node, function, args, is_unbound_method) - - def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Xor', node, function, args, is_unbound_method) - - def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method): - if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode): - return node - if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63): - return node - return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method) - - def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method): - if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode): - return node - if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63): - return node - return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method) - - def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method): - return self._optimise_num_div('Remainder', node, function, args, is_unbound_method) - - def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method): - return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method) - - def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method): - return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method) - - def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method): - return self._optimise_num_div('Divide', node, function, args, is_unbound_method) - - def _optimise_num_div(self, operator, node, function, args, is_unbound_method): - if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0: - return node - if isinstance(args[1], ExprNodes.IntNode): - if not (-2**30 <= args[1].constant_result <= 2**30): - return node - elif isinstance(args[1], ExprNodes.FloatNode): - if not (-2**53 <= args[1].constant_result <= 2**53): - return node - else: - return node - return self._optimise_num_binop(operator, node, function, args, is_unbound_method) - - def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Add', node, function, args, is_unbound_method) - - def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) - - def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method) - - def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Divide', node, function, args, is_unbound_method) - - def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method) - - def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) - - def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method): - return self._optimise_num_binop('Ne', node, function, args, is_unbound_method) - - def _optimise_num_binop(self, operator, node, function, args, is_unbound_method): - """ - Optimise math operators for (likely) float or small integer operations. - """ - if len(args) != 2: - return node - - if node.type.is_pyobject: - ret_type = PyrexTypes.py_object_type - elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'): - ret_type = PyrexTypes.c_bint_type - else: - return node - - # When adding IntNode/FloatNode to something else, assume other operand is also numeric. - # Prefer constants on RHS as they allows better size control for some operators. - num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode) - if isinstance(args[1], num_nodes): - if args[0].type is not PyrexTypes.py_object_type: - return node - numval = args[1] - arg_order = 'ObjC' - elif isinstance(args[0], num_nodes): - if args[1].type is not PyrexTypes.py_object_type: - return node - numval = args[0] - arg_order = 'CObj' - else: - return node - - if not numval.has_constant_result(): - return node - - is_float = isinstance(numval, ExprNodes.FloatNode) - num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type - if is_float: - if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'): - return node - elif operator == 'Divide': - # mixed old-/new-style division is not currently optimised for integers - return node - elif abs(numval.constant_result) > 2**30: - # Cut off at an integer border that is still safe for all operations. - return node - - if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'): - if args[1].constant_result == 0: - # Don't optimise division by 0. :) - return node - - args = list(args) - args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)( - numval.pos, value=numval.value, constant_result=numval.constant_result, - type=num_type)) - inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False - args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace)) - if is_float or operator not in ('Eq', 'Ne'): - # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument. - zerodivision_check = arg_order == 'CObj' and ( - not node.cdivision if isinstance(node, ExprNodes.DivNode) else False) - args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check)) - - utility_code = TempitaUtilityCode.load_cached( - "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop", - "Optimize.c", - context=dict(op=operator, order=arg_order, ret_type=ret_type)) - - call_node = self._substitute_method_call( - node, function, - "__Pyx_Py%s_%s%s%s" % ( - 'Float' if is_float else 'Int', - '' if ret_type.is_pyobject else 'Bool', - operator, - arg_order), - self.Pyx_BinopInt_func_types[(num_type, ret_type)], - '__%s__' % operator[:3].lower(), is_unbound_method, args, - may_return_none=True, - with_none_check=False, - utility_code=utility_code) - - if node.type.is_pyobject and not ret_type.is_pyobject: - call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type) - return call_node - + Pyx_BinopInt_func_types = dict( + ((ctype, ret_type), PyrexTypes.CFuncType( + ret_type, [ + PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("cval", ctype, None), + PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None), + PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None), + ], exception_value=None if ret_type.is_pyobject else ret_type.exception_value)) + for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type) + for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type) + ) + + def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Add', node, function, args, is_unbound_method) + + def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) + + def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) + + def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Ne', node, function, args, is_unbound_method) + + def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('And', node, function, args, is_unbound_method) + + def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Or', node, function, args, is_unbound_method) + + def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Xor', node, function, args, is_unbound_method) + + def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method): + if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode): + return node + if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63): + return node + return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method) + + def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method): + if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode): + return node + if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63): + return node + return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method) + + def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method): + return self._optimise_num_div('Remainder', node, function, args, is_unbound_method) + + def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method): + return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method) + + def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method): + return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method) + + def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method): + return self._optimise_num_div('Divide', node, function, args, is_unbound_method) + + def _optimise_num_div(self, operator, node, function, args, is_unbound_method): + if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0: + return node + if isinstance(args[1], ExprNodes.IntNode): + if not (-2**30 <= args[1].constant_result <= 2**30): + return node + elif isinstance(args[1], ExprNodes.FloatNode): + if not (-2**53 <= args[1].constant_result <= 2**53): + return node + else: + return node + return self._optimise_num_binop(operator, node, function, args, is_unbound_method) + + def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Add', node, function, args, is_unbound_method) + + def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method) + + def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method) + + def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Divide', node, function, args, is_unbound_method) + + def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method) + + def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Eq', node, function, args, is_unbound_method) + + def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method): + return self._optimise_num_binop('Ne', node, function, args, is_unbound_method) + + def _optimise_num_binop(self, operator, node, function, args, is_unbound_method): + """ + Optimise math operators for (likely) float or small integer operations. + """ + if len(args) != 2: + return node + + if node.type.is_pyobject: + ret_type = PyrexTypes.py_object_type + elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'): + ret_type = PyrexTypes.c_bint_type + else: + return node + + # When adding IntNode/FloatNode to something else, assume other operand is also numeric. + # Prefer constants on RHS as they allows better size control for some operators. + num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode) + if isinstance(args[1], num_nodes): + if args[0].type is not PyrexTypes.py_object_type: + return node + numval = args[1] + arg_order = 'ObjC' + elif isinstance(args[0], num_nodes): + if args[1].type is not PyrexTypes.py_object_type: + return node + numval = args[0] + arg_order = 'CObj' + else: + return node + + if not numval.has_constant_result(): + return node + + is_float = isinstance(numval, ExprNodes.FloatNode) + num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type + if is_float: + if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'): + return node + elif operator == 'Divide': + # mixed old-/new-style division is not currently optimised for integers + return node + elif abs(numval.constant_result) > 2**30: + # Cut off at an integer border that is still safe for all operations. + return node + + if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'): + if args[1].constant_result == 0: + # Don't optimise division by 0. :) + return node + + args = list(args) + args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)( + numval.pos, value=numval.value, constant_result=numval.constant_result, + type=num_type)) + inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False + args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace)) + if is_float or operator not in ('Eq', 'Ne'): + # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument. + zerodivision_check = arg_order == 'CObj' and ( + not node.cdivision if isinstance(node, ExprNodes.DivNode) else False) + args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check)) + + utility_code = TempitaUtilityCode.load_cached( + "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop", + "Optimize.c", + context=dict(op=operator, order=arg_order, ret_type=ret_type)) + + call_node = self._substitute_method_call( + node, function, + "__Pyx_Py%s_%s%s%s" % ( + 'Float' if is_float else 'Int', + '' if ret_type.is_pyobject else 'Bool', + operator, + arg_order), + self.Pyx_BinopInt_func_types[(num_type, ret_type)], + '__%s__' % operator[:3].lower(), is_unbound_method, args, + may_return_none=True, + with_none_check=False, + utility_code=utility_code) + + if node.type.is_pyobject and not ret_type.is_pyobject: + call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type) + return call_node + ### unicode type methods PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType( @@ -3456,44 +3456,44 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, "PyUnicode_Split", self.PyUnicode_Split_func_type, 'split', is_unbound_method, args) - PyUnicode_Join_func_type = PyrexTypes.CFuncType( - Builtin.unicode_type, [ - PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None), - ]) - - def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method): - """ - unicode.join() builds a list first => see if we can do this more efficiently - """ - if len(args) != 2: - self._error_wrong_arg_count('unicode.join', node, args, "2") - return node - if isinstance(args[1], ExprNodes.GeneratorExpressionNode): - gen_expr_node = args[1] - loop_node = gen_expr_node.loop - - yield_statements = _find_yield_statements(loop_node) - if yield_statements: - inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode( - node.pos, gen_expr_node, orig_func='list', - comprehension_type=Builtin.list_type) - - for yield_expression, yield_stat_node in yield_statements: - append_node = ExprNodes.ComprehensionAppendNode( - yield_expression.pos, - expr=yield_expression, - target=inlined_genexpr.target) - - Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) - - args[1] = inlined_genexpr - - return self._substitute_method_call( - node, function, - "PyUnicode_Join", self.PyUnicode_Join_func_type, - 'join', is_unbound_method, args) - + PyUnicode_Join_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ + PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None), + PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None), + ]) + + def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method): + """ + unicode.join() builds a list first => see if we can do this more efficiently + """ + if len(args) != 2: + self._error_wrong_arg_count('unicode.join', node, args, "2") + return node + if isinstance(args[1], ExprNodes.GeneratorExpressionNode): + gen_expr_node = args[1] + loop_node = gen_expr_node.loop + + yield_statements = _find_yield_statements(loop_node) + if yield_statements: + inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode( + node.pos, gen_expr_node, orig_func='list', + comprehension_type=Builtin.list_type) + + for yield_expression, yield_stat_node in yield_statements: + append_node = ExprNodes.ComprehensionAppendNode( + yield_expression.pos, + expr=yield_expression, + target=inlined_genexpr.target) + + Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node) + + args[1] = inlined_genexpr + + return self._substitute_method_call( + node, function, + "PyUnicode_Join", self.PyUnicode_Join_func_type, + 'join', is_unbound_method, args) + PyString_Tailmatch_func_type = PyrexTypes.CFuncType( PyrexTypes.c_bint_type, [ PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode @@ -3626,8 +3626,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType( Builtin.bytes_type, [ PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), - PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), ]) PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType( @@ -3671,8 +3671,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, # well, looks like we can't pass else: - value = bytes_literal(value, encoding) - return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type) + value = bytes_literal(value, encoding) + return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type) if encoding and error_handling == 'strict': # try to find a specific encoder function @@ -3692,30 +3692,30 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType( Builtin.unicode_type, [ - PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), - ])) + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), + ])) _decode_c_string_func_type = PyrexTypes.CFuncType( Builtin.unicode_type, [ - PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), - ]) + ]) _decode_bytes_func_type = PyrexTypes.CFuncType( Builtin.unicode_type, [ PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None), PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None), - ]) + ]) _decode_cpp_string_func_type = None # lazy init @@ -3810,8 +3810,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, PyrexTypes.CFuncTypeArg("string", string_type, None), PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None), PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None), - PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), - PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None), PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None), ]) helper_func_type = self._decode_cpp_string_func_type @@ -3882,14 +3882,14 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, if isinstance(node, ExprNodes.UnicodeNode): encoding = node.value node = ExprNodes.BytesNode( - node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type) + node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type) elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)): encoding = node.value.decode('ISO-8859-1') node = ExprNodes.BytesNode( - node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type) + node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type) elif node.type is Builtin.bytes_type: encoding = None - node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env()) + node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env()) elif node.type.is_string: encoding = None else: @@ -3933,8 +3933,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, def _substitute_method_call(self, node, function, name, func_type, attr_name, is_unbound_method, args=(), utility_code=None, is_temp=None, - may_return_none=ExprNodes.PythonCapiCallNode.may_return_none, - with_none_check=True): + may_return_none=ExprNodes.PythonCapiCallNode.may_return_none, + with_none_check=True): args = list(args) if with_none_check and args: args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name) @@ -4210,15 +4210,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): bytes_value = None if str1.bytes_value is not None and str2.bytes_value is not None: if str1.bytes_value.encoding == str2.bytes_value.encoding: - bytes_value = bytes_literal( - str1.bytes_value + str2.bytes_value, - str1.bytes_value.encoding) + bytes_value = bytes_literal( + str1.bytes_value + str2.bytes_value, + str1.bytes_value.encoding) string_value = EncodedString(node.constant_result) return ExprNodes.UnicodeNode( str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value) elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode): if str1.value.encoding == str2.value.encoding: - bytes_value = bytes_literal(node.constant_result, str1.value.encoding) + bytes_value = bytes_literal(node.constant_result, str1.value.encoding) return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result) # all other combinations are rather complicated # to get right in Py2/3: encodings, unicode escapes, ... @@ -4275,12 +4275,12 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): def _calculate_constant_seq(self, node, sequence_node, factor): if factor.constant_result != 1 and sequence_node.args: - if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0: + if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0: del sequence_node.args[:] sequence_node.mult_factor = None elif sequence_node.mult_factor is not None: - if (isinstance(factor.constant_result, _py_int_types) and - isinstance(sequence_node.mult_factor.constant_result, _py_int_types)): + if (isinstance(factor.constant_result, _py_int_types) and + isinstance(sequence_node.mult_factor.constant_result, _py_int_types)): value = sequence_node.mult_factor.constant_result * factor.constant_result sequence_node.mult_factor = ExprNodes.IntNode( sequence_node.mult_factor.pos, @@ -4332,16 +4332,16 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): warning(pos, "Too few arguments for format placeholders", level=1) can_be_optimised = False break - if arg.is_starred: - can_be_optimised = False - break - if format_type in u'asrfdoxX': + if arg.is_starred: + can_be_optimised = False + break + if format_type in u'asrfdoxX': format_spec = s[1:] conversion_char = None if format_type in u'doxX' and u'.' in format_spec: # Precision is not allowed for integers in format(), but ok in %-formatting. can_be_optimised = False - elif format_type in u'ars': + elif format_type in u'ars': format_spec = format_spec[:-1] conversion_char = format_type if format_spec.startswith('0'): @@ -4363,7 +4363,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): else: # keep it simple for now ... can_be_optimised = False - break + break if not can_be_optimised: # Print all warnings we can find before finally giving up here. @@ -4379,11 +4379,11 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): node = ExprNodes.JoinedStrNode(pos, values=substrings) return self.visit_JoinedStrNode(node) - def visit_FormattedValueNode(self, node): - self.visitchildren(node) + def visit_FormattedValueNode(self, node): + self.visitchildren(node) conversion_char = node.conversion_char or 's' - if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value: - node.format_spec = None + if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value: + node.format_spec = None if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode): value = EncodedString(node.value.value) if value.isdigit(): @@ -4396,130 +4396,130 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): value = node.value.unicode_value if value is not None: return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value) - return node - - def visit_JoinedStrNode(self, node): - """ - Clean up after the parser by discarding empty Unicode strings and merging - substring sequences. Empty or single-value join lists are not uncommon - because f-string format specs are always parsed into JoinedStrNodes. - """ - self.visitchildren(node) - unicode_node = ExprNodes.UnicodeNode - - values = [] - for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)): - if is_unode_group: - substrings = list(substrings) - unode = substrings[0] - if len(substrings) > 1: + return node + + def visit_JoinedStrNode(self, node): + """ + Clean up after the parser by discarding empty Unicode strings and merging + substring sequences. Empty or single-value join lists are not uncommon + because f-string format specs are always parsed into JoinedStrNodes. + """ + self.visitchildren(node) + unicode_node = ExprNodes.UnicodeNode + + values = [] + for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)): + if is_unode_group: + substrings = list(substrings) + unode = substrings[0] + if len(substrings) > 1: value = EncodedString(u''.join(value.value for value in substrings)) unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value) - # ignore empty Unicode strings - if unode.value: - values.append(unode) - else: - values.extend(substrings) - - if not values: + # ignore empty Unicode strings + if unode.value: + values.append(unode) + else: + values.extend(substrings) + + if not values: value = EncodedString('') node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value) - elif len(values) == 1: - node = values[0] - elif len(values) == 2: - # reduce to string concatenation - node = ExprNodes.binop_node(node.pos, '+', *values) - else: - node.values = values - return node - - def visit_MergedDictNode(self, node): - """Unpack **args in place if we can.""" - self.visitchildren(node) - args = [] - items = [] - - def add(arg): - if arg.is_dict_literal: - if items: - items[0].key_value_pairs.extend(arg.key_value_pairs) - else: - items.append(arg) - elif isinstance(arg, ExprNodes.MergedDictNode): - for child_arg in arg.keyword_args: - add(child_arg) - else: - if items: - args.append(items[0]) - del items[:] - args.append(arg) - - for arg in node.keyword_args: - add(arg) - if items: - args.append(items[0]) - - if len(args) == 1: - arg = args[0] - if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode): - return arg - node.keyword_args[:] = args - self._calculate_const(node) - return node - - def visit_MergedSequenceNode(self, node): - """Unpack *args in place if we can.""" - self.visitchildren(node) - - is_set = node.type is Builtin.set_type - args = [] - values = [] - - def add(arg): - if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor): - if values: - values[0].args.extend(arg.args) - else: - values.append(arg) - elif isinstance(arg, ExprNodes.MergedSequenceNode): - for child_arg in arg.args: - add(child_arg) - else: - if values: - args.append(values[0]) - del values[:] - args.append(arg) - - for arg in node.args: - add(arg) - if values: - args.append(values[0]) - - if len(args) == 1: - arg = args[0] - if ((is_set and arg.is_set_literal) or - (arg.is_sequence_constructor and arg.type is node.type) or - isinstance(arg, ExprNodes.MergedSequenceNode)): - return arg - node.args[:] = args - self._calculate_const(node) - return node - - def visit_SequenceNode(self, node): - """Unpack *args in place if we can.""" - self.visitchildren(node) - args = [] - for arg in node.args: - if not arg.is_starred: - args.append(arg) - elif arg.target.is_sequence_constructor and not arg.target.mult_factor: - args.extend(arg.target.args) - else: - args.append(arg) - node.args[:] = args - self._calculate_const(node) - return node - + elif len(values) == 1: + node = values[0] + elif len(values) == 2: + # reduce to string concatenation + node = ExprNodes.binop_node(node.pos, '+', *values) + else: + node.values = values + return node + + def visit_MergedDictNode(self, node): + """Unpack **args in place if we can.""" + self.visitchildren(node) + args = [] + items = [] + + def add(arg): + if arg.is_dict_literal: + if items: + items[0].key_value_pairs.extend(arg.key_value_pairs) + else: + items.append(arg) + elif isinstance(arg, ExprNodes.MergedDictNode): + for child_arg in arg.keyword_args: + add(child_arg) + else: + if items: + args.append(items[0]) + del items[:] + args.append(arg) + + for arg in node.keyword_args: + add(arg) + if items: + args.append(items[0]) + + if len(args) == 1: + arg = args[0] + if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode): + return arg + node.keyword_args[:] = args + self._calculate_const(node) + return node + + def visit_MergedSequenceNode(self, node): + """Unpack *args in place if we can.""" + self.visitchildren(node) + + is_set = node.type is Builtin.set_type + args = [] + values = [] + + def add(arg): + if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor): + if values: + values[0].args.extend(arg.args) + else: + values.append(arg) + elif isinstance(arg, ExprNodes.MergedSequenceNode): + for child_arg in arg.args: + add(child_arg) + else: + if values: + args.append(values[0]) + del values[:] + args.append(arg) + + for arg in node.args: + add(arg) + if values: + args.append(values[0]) + + if len(args) == 1: + arg = args[0] + if ((is_set and arg.is_set_literal) or + (arg.is_sequence_constructor and arg.type is node.type) or + isinstance(arg, ExprNodes.MergedSequenceNode)): + return arg + node.args[:] = args + self._calculate_const(node) + return node + + def visit_SequenceNode(self, node): + """Unpack *args in place if we can.""" + self.visitchildren(node) + args = [] + for arg in node.args: + if not arg.is_starred: + args.append(arg) + elif arg.target.is_sequence_constructor and not arg.target.mult_factor: + args.extend(arg.target.args) + else: + args.append(arg) + node.args[:] = args + self._calculate_const(node) + return node + def visit_PrimaryCmpNode(self, node): # calculate constant partial results in the comparison cascade self.visitchildren(node, ['operand1']) @@ -4759,30 +4759,30 @@ class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin): else "optimize.unpack_method_calls")): # optimise simple Python methods calls if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not ( - node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)): + node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)): # simple call, now exclude calls to objects that are definitely not methods may_be_a_method = True if function.type is Builtin.type_type: may_be_a_method = False - elif function.is_attribute: - if function.entry and function.entry.type.is_cfunction: - # optimised builtin method - may_be_a_method = False + elif function.is_attribute: + if function.entry and function.entry.type.is_cfunction: + # optimised builtin method + may_be_a_method = False elif function.is_name: - entry = function.entry - if entry.is_builtin or entry.type.is_cfunction: + entry = function.entry + if entry.is_builtin or entry.type.is_cfunction: may_be_a_method = False - elif entry.cf_assignments: + elif entry.cf_assignments: # local functions/classes are definitely not methods non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode) may_be_a_method = any( assignment.rhs and not isinstance(assignment.rhs, non_method_nodes) - for assignment in entry.cf_assignments) + for assignment in entry.cf_assignments) if may_be_a_method: - if (node.self and function.is_attribute and - isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self): - # function self object was moved into a CloneNode => undo - function.obj = function.obj.arg + if (node.self and function.is_attribute and + isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self): + # function self object was moved into a CloneNode => undo + function.obj = function.obj.arg node = self.replace(node, ExprNodes.PyMethodCallNode.from_node( node, function=function, arg_tuple=node.arg_tuple, type=node.type)) return node |