aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/python/pythran/pythran/analyses/argument_read_once.py
blob: a30fa5a098536a6699a103cef683ac755e44c2cc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
""" ArgumentReadOnce counts the usages of each argument of each function. """

from pythran.analyses.aliases import Aliases
from pythran.analyses.global_declarations import GlobalDeclarations
from pythran.passmanager import ModuleAnalysis
from pythran.tables import MODULES
import pythran.intrinsic as intrinsic

import gast as ast
from functools import reduce


class ArgumentReadOnce(ModuleAnalysis):

    """
    Counts the usages of each argument of each function.

    The visitor does not compute numbers directly: every ``visit_*`` method
    returns a callable taking a `Context` and returning an int. Those
    callables are stored per function and only evaluated in `run`, once all
    functions have been visited, through `recursive_weight`.

    Attributes
    ----------
    result : {FunctionEffects}, then {node: [int]}
        While the module is being visited, the set of every known
        FunctionEffects. `run` replaces it by a dict binding each function
        node to its list of per-argument read counts.
    node_to_functioneffect : {node: FunctionEffects}
        Function node (global declaration or Pythonic intrinsic) to function
        effect binding.
    """

    class FunctionEffects(object):
        """
        Read effects of one function on its arguments.

        Attributes
        ----------
        func : node
            The described function (FunctionDef, intrinsic or alias).
        dependencies : callable(Context) -> int
            Read count of the argument designated by the context. Constant 0
            until `visit_FunctionDef` installs the real one.
        read_effects : [int]
            One count per argument. -1 means "not computed yet" and is
            resolved by `recursive_weight`.
        """

        def __init__(self, node):
            self.func = node
            self.dependencies = lambda ctx: 0
            if isinstance(node, ast.FunctionDef):
                # User function: unknown for now, computed lazily.
                self.read_effects = [-1] * len(node.args.args)
            elif isinstance(node, intrinsic.Intrinsic):
                # Intrinsic: 1 for arguments flagged ReadOnceEffect,
                # 2 for any other argument effect.
                self.read_effects = [
                    1 if isinstance(x, intrinsic.ReadOnceEffect)
                    else 2 for x in node.argument_effects]
            elif isinstance(node, ast.alias):
                self.read_effects = []
            else:
                raise NotImplementedError

    class ConstructorEffects(object):
        """
        Effects holder with a single argument whose read count is 0.

        NOTE(review): not referenced by the rest of this class; kept as is
        for callers outside this block, if any.
        """

        def __init__(self, node):
            self.func = node
            self.dependencies = lambda ctx: 0
            self.read_effects = [0]

    class Context(object):
        """
        Evaluation context handed to the callables built by the visitor.

        Attributes
        ----------
        function : FunctionEffects
            Function whose argument is being evaluated.
        index : int
            Position of the evaluated argument in that function.
        path : set
            Functions already entered during the current evaluation; used by
            `recursive_weight` to detect cycles.
        global_dependencies : bool
            When False, calls do not contribute the callee's read count
            (see `visit_Call`).
        """

        def __init__(self, function, index, path, global_dependencies):
            self.function = function
            self.index = index
            self.path = path
            self.global_dependencies = global_dependencies

    def __init__(self):
        """ Basic initialiser for class attributes. """
        self.result = set()
        self.node_to_functioneffect = dict()
        super(ArgumentReadOnce, self).__init__(Aliases, GlobalDeclarations)

    def prepare(self, node):
        """
        Initialise arguments effects as this analysis is inter-procedural.

        Initialisation done for Pythonic functions and default values set for
        user defined functions.
        """
        super(ArgumentReadOnce, self).prepare(node)
        # global functions init
        for n in self.global_declarations.values():
            fe = ArgumentReadOnce.FunctionEffects(n)
            self.node_to_functioneffect[n] = fe
            self.result.add(fe)

        # Pythonic functions init
        def save_effect(module):
            """ Recursively save read once effect for Pythonic functions. """
            for intr in module.values():
                if isinstance(intr, dict):  # Submodule case
                    save_effect(intr)
                else:
                    fe = ArgumentReadOnce.FunctionEffects(intr)
                    self.node_to_functioneffect[intr] = fe
                    self.result.add(fe)
                    if isinstance(intr, intrinsic.Class):  # Class case
                        save_effect(intr.fields)

        for module in MODULES.values():
            save_effect(module)

    def run(self, node):
        """
        Visit the module, then resolve every pending argument count.

        Returns
        -------
        {node: [int]}
            Each function node bound to its per-argument read counts. The
            same dict is stored in ``self.result``.
        """
        # The parent `run` is expected to hand back the FunctionEffects
        # collection filled by `prepare` and the visit.
        result = super(ArgumentReadOnce, self).run(node)
        for fun in result:
            # Index-based on purpose: recursive_weight rewrites the slots.
            for i in range(len(fun.read_effects)):
                self.recursive_weight(fun, i, set())
        self.result = {f.func: f.read_effects for f in result}
        return self.result

    def recursive_weight(self, function, index, predecessors):
        """
        Return the read count of argument `index` of `function`.

        A pending count (-1) is computed from the function's `dependencies`
        and cached in ``function.read_effects``. `predecessors` is the set of
        functions already entered; it is updated in place.
        """
        # TODO : Find out why it happens in some cases
        if len(function.read_effects) <= index:
            return 0
        if function.read_effects[index] == -1:
            # In case of recursive/cyclic calls
            cycle = function in predecessors
            predecessors.add(function)
            if cycle:
                # Already on the path: do not follow calls again (that would
                # recurse forever) and double the local count instead.
                function.read_effects[index] = 2 * function.dependencies(
                    ArgumentReadOnce.Context(function, index,
                                             predecessors, False))
            else:
                function.read_effects[index] = function.dependencies(
                    ArgumentReadOnce.Context(function, index,
                                             predecessors, True))
        return function.read_effects[index]

    def argument_index(self, node):
        """
        Return the position of the current function's argument that `node`
        aliases, or -1 if it aliases none.

        Subscripts are stripped first, so ``a[i][j]`` is looked up as ``a``.
        """
        while isinstance(node, ast.Subscript):
            node = node.value
        if node in self.aliases:
            for n_alias in self.aliases[node]:
                try:
                    return self.current_function.func.args.args.index(n_alias)
                except ValueError:
                    # This alias is not a parameter, try the next one.
                    pass
        return -1

    def local_effect(self, node, effect):
        """
        Build a callable yielding `effect` when the evaluated argument is the
        one aliased by `node`, 0 otherwise.
        """
        # Resolved now, while `current_function` is the function being
        # visited; the returned callable is evaluated much later.
        index = self.argument_index(node)
        return lambda ctx: effect if index == ctx.index else 0

    def generic_visit(self, node):
        """ Sum the effects of all the children of `node`. """
        lambdas = [self.visit(child) for child in ast.iter_child_nodes(node)]
        return lambda ctx: sum(l(ctx) for l in lambdas)

    def visit_FunctionDef(self, node):
        """ Record the body effects as the function's `dependencies`. """
        self.current_function = self.node_to_functioneffect[node]
        assert self.current_function in self.result
        self.current_function.dependencies = self.generic_visit(node)

    def visit_Return(self, node):
        """ Returning a bare name adds 2 to the argument it aliases. """
        dep = self.generic_visit(node)
        if isinstance(node.value, ast.Name):
            local = self.local_effect(node.value, 2)
            return lambda ctx: dep(ctx) + local(ctx)
        else:
            return dep

    def visit_Assign(self, node):
        """ Each subscripted target adds 2 to the argument it aliases. """
        dep = self.generic_visit(node)
        local = [self.local_effect(t, 2) for t in node.targets
                 if isinstance(t, ast.Subscript)]
        return lambda ctx: dep(ctx) + sum(l(ctx) for l in local)

    def visit_AugAssign(self, node):
        """ An augmented assignment adds 2 to the argument it targets. """
        dep = self.generic_visit(node)
        local = self.local_effect(node.target, 2)
        return lambda ctx: dep(ctx) + local(ctx)

    def visit_For(self, node):
        """
        Iterating over an argument adds 1; effects of the loop body are
        doubled, effects of the ``else`` clause are counted once.
        """
        iter_local = self.local_effect(node.iter, 1)
        iter_deps = self.visit(node.iter)
        body_deps = [self.visit(stmt) for stmt in node.body]
        else_deps = [self.visit(stmt) for stmt in node.orelse]
        # Factor 2 on the body: presumably because the body may run several
        # times, so one read there already means "more than once".
        return lambda ctx: iter_local(ctx) + iter_deps(ctx) + 2 * sum(
            l(ctx) for l in body_deps) + sum(l(ctx) for l in else_deps)

    def visit_While(self, node):
        """
        Effects of the test are counted once, effects of the loop body are
        doubled, effects of the ``else`` clause are counted once.
        """
        test_deps = self.visit(node.test)
        body_deps = [self.visit(stmt) for stmt in node.body]
        else_deps = [self.visit(stmt) for stmt in node.orelse]
        return lambda ctx: test_deps(ctx) + 2 * sum(
            l(ctx) for l in body_deps) + sum(l(ctx) for l in else_deps)

    def visit_If(self, node):
        """
        Effects of the test plus the larger of the two branches' effects.
        """
        test_deps = self.visit(node.test)
        body_deps = [self.visit(stmt) for stmt in node.body]
        else_deps = [self.visit(stmt) for stmt in node.orelse]
        return lambda ctx: test_deps(ctx) + max(sum(
            l(ctx) for l in body_deps), sum(l(ctx) for l in else_deps))

    def _callee_effects(self, func_node):
        """
        Return the FunctionEffects used for a call through `func_node`, or
        None when none of its aliases is a known function.

        When several aliases are known functions, the last one in iteration
        order is returned.
        NOTE(review): the other candidates are ignored; confirm whether a
        combination (e.g. the maximum) over all of them is intended.
        """
        # An alias that is itself a parameter of the current function may be
        # any function: it is expanded to every known one.
        candidates = []
        for alias in self.aliases[func_node]:
            if isinstance(alias, ast.Name) and self.argument_index(alias) >= 0:
                candidates.extend(self.node_to_functioneffect)
            else:
                candidates.append(alias)

        effects = None
        for func_alias in candidates:
            # special hook for bound functions: the first argument of the
            # aliased call names the actual global function
            if isinstance(func_alias, ast.Call):
                bound_name = func_alias.args[0].id
                func_alias = self.global_declarations[bound_name]

            if func_alias is intrinsic.UnboundValue:
                continue

            if func_alias not in self.node_to_functioneffect:
                continue

            effects = self.node_to_functioneffect[func_alias]
        return effects

    def visit_Call(self, node):
        """
        Effects of the call's children, plus the callee's read count for each
        argument of the current function passed to the call.

        The callee part is only added when the evaluation context allows
        following global dependencies.
        """
        l0 = self.generic_visit(node)
        # current function's argument index -> position in this call
        index_corres = dict()
        func = None
        resolved = False
        for i, arg in enumerate(node.args):
            n = self.argument_index(arg)
            if n < 0:
                continue
            if not resolved:
                # The callee depends only on node.func, so it is resolved
                # once, and only if some call argument aliases a parameter.
                func = self._callee_effects(node.func)
                resolved = True
            if func is not None:
                index_corres[n] = i

        def merger(ctx):
            base = l0(ctx)
            if (ctx.index in index_corres) and ctx.global_dependencies:
                rec = self.recursive_weight(func, index_corres[ctx.index],
                                            ctx.path)
            else:
                rec = 0
            return base + rec

        return merger

    def visit_Subscript(self, node):
        """ Subscripting an argument adds 2 to its count. """
        dep = self.generic_visit(node)
        local = self.local_effect(node.value, 2)
        return lambda ctx: dep(ctx) + local(ctx)

    def visit_comprehension(self, node):
        """ Iterating over an argument in a comprehension adds 1. """
        dep = self.generic_visit(node)
        local = self.local_effect(node.iter, 1)
        return lambda ctx: dep(ctx) + local(ctx)