aboutsummaryrefslogtreecommitdiffstats
path: root/library/python/par_apply/__init__.py
blob: 19b89ae8431fa88790695ed5f6c226b29400e59a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
import threading
import six

from six.moves import queue


def par_apply(seq, func, thr_num, join_polling=None):
    """Lazily yield ``func(x)`` for each ``x`` in ``seq``, preserving input order.

    When ``thr_num >= 2`` the calls run on ``thr_num`` worker threads;
    otherwise ``func`` is applied sequentially in the calling thread.
    ``seq`` is consumed lazily from the calling thread and at most roughly
    ``2 * thr_num`` items are submitted ahead of the result being yielded,
    so ``seq`` may be a long or unbounded iterator.

    Args:
        seq: iterable of input items.
        func: one-argument callable applied to every item.
        thr_num: number of worker threads; values below 2 disable threading.
        join_polling: if not None, a timeout in seconds used to join each
            worker in a polling loop during shutdown instead of a single
            blocking ``join()`` (presumably to keep the calling thread
            responsive to signals such as KeyboardInterrupt on Python 2).

    Yields:
        ``func(x)`` for each item of ``seq``, in the same order.

    Raises:
        The first exception raised by ``func`` (in input order) is re-raised
        in the consumer with its original traceback; exceptions from
        iterating ``seq`` propagate unchanged. Before the exception (or a
        ``close()`` of the generator) leaves, every started worker is sent a
        stop sentinel and joined; items already queued are still processed
        by the workers and their results discarded.
    """
    if thr_num < 2:
        for x in seq:
            yield func(x)

        return

    in_q = queue.Queue()   # (index, block) tasks, or None = "worker, stop"
    out_q = queue.Queue()  # (index, block, (result, exc_info)) in completion order

    def enumerate_blocks():
        # Each item is wrapped in a one-element list so that falsy items
        # (0, "", None, ...) are still treated as work; a bare None block
        # marks the end of seq.
        n = 0

        for b in seq:
            yield n, [b]
            n += 1

        yield n, None

    def iter_out():
        # Results arrive on out_q in completion order; buffer them until
        # the next index in submission order is available.
        n = 0
        pending = {}

        while True:
            if n in pending:
                r = pending.pop(n)
                n += 1

                yield r
            else:
                res = out_q.get()

                pending[res[0]] = res

    out_iter = iter_out()

    def wait_block():
        # iter_out() never finishes, so this always returns the next
        # in-order (index, block, result) triple, blocking if needed.
        return next(out_iter)

    def iter_compressed():
        p = 0  # index of the most recently received result

        for n, b in enumerate_blocks():
            in_q.put((n, b))

            # Back-pressure: keep at most about 2 * thr_num blocks in flight.
            while n > p + (thr_num * 2):
                p, b, c = wait_block()

                if not b:
                    return

                yield p, c

        # seq is exhausted (end marker queued); drain the remaining results
        # until the end marker comes back.
        while True:
            p, b, c = wait_block()

            if not b:
                return

            yield p, c

    def proc():
        while True:
            data = in_q.get()

            if data is None:
                return

            n, b = data

            if b:
                try:
                    res = (func(b[0]), None)
                except BaseException:
                    # BaseException, not just Exception: a worker killed by
                    # e.g. SystemExit from sys.exit() inside func would never
                    # post a result for block n, leaving the consumer blocked
                    # on out_q.get() forever. Ship it to the consumer instead,
                    # where it is re-raised just like in the sequential path.
                    res = (None, sys.exc_info())
            else:
                res = (None, None)

            out_q.put((n, b, res))

    thrs = []  # only threads that actually started; these must be stopped

    try:
        # Start workers inside the try: if a later start() fails (e.g.
        # "can't start new thread"), the already-running workers still get
        # their stop sentinel instead of blocking on in_q.get() forever.
        for _ in range(thr_num):
            t = threading.Thread(target=proc)
            t.start()
            thrs.append(t)

        for _, c in iter_compressed():
            res, err = c

            if err:
                six.reraise(*err)

            yield res
    finally:
        # One sentinel per started worker; it is queued behind any pending
        # blocks, so those are still processed before the worker exits.
        for _ in thrs:
            in_q.put(None)

        for t in thrs:
            if join_polling is not None:
                while True:
                    t.join(join_polling)
                    if not t.is_alive():
                        break
            else:
                t.join()