aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/threading/future/wait/wait_group-inl.h
blob: 407e0b563091fc4a63077b27e1fba70789007c29 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#pragma once 
 
#if !defined(INCLUDE_FUTURE_INL_H) 
#error "you should never include wait_group-inl.h directly" 
#endif // INCLUDE_FUTURE_INL_H 
 
#include "wait_policy.h" 
 
#include <util/generic/maybe.h> 
#include <util/generic/ptr.h> 
 
#include <library/cpp/threading/future/core/future.h> 
 
#include <util/system/spinlock.h> 

#include <atomic>
#include <exception> 
 
namespace NThreading { 
    namespace NWaitGroup::NImpl { 
        // Shared state behind TWaitGroup<WaitPolicy>.
        //
        // The state is reference-counted (TAtomicRefCount). Every subscription made in ::Add
        // captures its own reference via SharedFromThis(), so the state stays alive until
        // the last registered future has run its callback. This holds even after the owning
        // wait group has dropped its reference in TWaitGroup::Finish.
        template <class WaitPolicy> 
        struct TState final : TAtomicRefCount<TState<WaitPolicy>> { 
            // Registers `future`: counts it in ::Discovered and subscribes a callback that
            // records its exception (if any), counts it in ::Finished and calls ::TryPublish.
            template <class T> 
            void Add(const TFuture<T>& future); 
            // Completes the imaginary future and returns ::Subscribers as a TFuture<void>.
            TFuture<void> Finish(); 
 
            // Calls ::Publish if the policy predicates hold and this caller wins the
            // EPhase::Initial -> EPhase::Publishing transition.
            void TryPublish(); 
            // Sets ::Subscribers: with the recorded exception if there is one, otherwise with a value.
            void Publish(); 
 
            // Policy predicates consulted by ::TryPublish; the behaviour for individual
            // policies is provided by specializations further down in this file.
            bool ShouldPublishByCount() const noexcept; 
            bool ShouldPublishByException() const noexcept; 
 
            // Wraps `this` in a TStateRef (declared outside this file; presumably an
            // intrusive pointer that takes a new reference — TODO confirm in wait_group.h).
            TStateRef<WaitPolicy> SharedFromThis() noexcept { 
                return TStateRef<WaitPolicy>{this}; 
            } 
 
            // ::TryPublish only calls ::Publish after a successful compare-exchange
            // Initial -> Publishing, so at most one TryPublish caller publishes.
            enum class EPhase { 
                Initial, 
                Publishing, 
            }; 
 
            // initially we have one imaginary discovered future which we 
            // use for synchronization with ::Finish 
            std::atomic<ui64> Discovered{1}; 
 
            // number of completed futures; ::Finish adds one more for the imaginary future
            std::atomic<ui64> Finished{0}; 
 
            std::atomic<EPhase> Phase{EPhase::Initial}; 
 
            // promise behind the future returned from ::Finish
            TPromise<void> Subscribers = NewPromise(); 
 
            // protects ExceptionInFlight
            mutable TAdaptiveLock Mut; 
            // first exception recorded from an added future; null while none has failed
            std::exception_ptr ExceptionInFlight; 
 
            // Records `eptr` unless an exception has already been recorded (first one wins).
            void TrySetException(std::exception_ptr eptr) noexcept { 
                TGuard lock{Mut}; 
                if (!ExceptionInFlight) { 
                    ExceptionInFlight = std::move(eptr); 
                } 
            } 
 
            // Returns a copy of the recorded exception (null if none), read under ::Mut.
            std::exception_ptr GetExceptionInFlight() const noexcept { 
                TGuard lock{Mut}; 
                return ExceptionInFlight; 
            } 
        }; 
 
        // Ends discovery for this state.
        //
        // Completes the imaginary future that ::Discovered has counted from the start, then
        // publishes if possible. Returns ::Subscribers as the group's TFuture<void>.
        // NOTE(review): assumes ::Add is not called concurrently with ::Finish on the same
        // state. The empty-group check reads ::Discovered once — TODO confirm with callers.
        template <class WaitPolicy>
        inline TFuture<void> TState<WaitPolicy>::Finish() {
            Finished.fetch_add(1); // complete the imaginary future

            if (Discovered.load() == 1) {
                // Empty group: no real future was added, so no callback will ever call
                // ::TryPublish. Some count predicates (TAny needs ::Finished >= 2) would
                // never fire either. Publish directly, but still record the
                // Initial -> Publishing transition so that ::Phase reflects publication
                // exactly as it does on the ::TryPublish path.
                [[maybe_unused]] const EPhase previous = Phase.exchange(EPhase::Publishing);
                Y_ASSERT(previous == EPhase::Initial);
                Publish();
            } else {
                TryPublish();
            }

            return Subscribers;
        }
 
        // Adds `future` to the group.
        //
        // ::Discovered is bumped first. The subscribed callback later records a failure
        // (first one wins), bumps ::Finished and re-evaluates whether to publish.
        template <class WaitPolicy>
        template <class T>
        inline void TState<WaitPolicy>::Add(const TFuture<T>& future) {
            // Checked before any counter is touched (presumably throws for an
            // uninitialized future — TODO confirm in future.h).
            future.EnsureInitialized();

            Discovered.fetch_add(1);

            // The callback holds its own reference to the state, so the state survives
            // TWaitGroup::Finish dropping its reference while futures are still pending.
            auto onReady = [self = SharedFromThis()](auto&& ready) {
                try {
                    ready.TryRethrow();
                } catch (...) {
                    self->TrySetException(std::current_exception());
                }

                self->Finished.fetch_add(1);
                self->TryPublish();
            };

            // NoexceptSubscribe keeps ::Add exception-safe now that ::Discovered has
            // already been incremented.
            future.NoexceptSubscribe(std::move(onReady));
        }
 
        // 
        // ============================ PublishByCount ================================== 
        // 
 
        // Count-based publication predicate for policies that wait for every future
        // (the generic case; TAny has its own specialization).
        template <class WaitPolicy> 
        inline bool TState<WaitPolicy>::ShouldPublishByCount() const noexcept { 
            // - safety: a) If a future has incremented ::Finished and we observe that effect, then our subsequent load of ::Discovered
            //              also observes the increment made when that future was discovered (::Add bumps ::Discovered before subscribing)
            //           b) Every discovery of a future observes the discovery of the imaginary future (::Discovered starts at 1)
            //          a, b => if finishedByNow == discoveredByNow, then every future discovered between the imaginary discovery
            //                  and the imaginary finish (in ::Finish) is finished
            // 
            // - liveness: a) TryPublish is called after each increment of ::Finished 
            //             b) There is some last increment of ::Finished which follows all other operations with ::Finished and ::Discovered (provided that every future is eventually set) 
            //             c) For each increment of ::Discovered there is an increment of ::Finished (provided that every future is eventually set) 
            //          a, b, c => some call to ShouldPublishByCount eventually returns true
            // 
            // The argument assumes the default (seq_cst) ordering used by these loads and by the fetch_adds in ::Add / ::Finish.
            // The order of the following two operations is significant for the proof: ::Finished must be read first.
            auto finishedByNow = Finished.load(); 
            auto discoveredByNow = Discovered.load(); 
 
            return finishedByNow == discoveredByNow; 
        } 
 
        // TAny: publish once at least one real future has completed (successfully or not).
        template <>
        inline bool TState<TWaitPolicy::TAny>::ShouldPublishByCount() const noexcept {
            // At most one increment of ::Finished comes from the imaginary future (::Finish),
            // so a count above one implies that some real future has completed.
            // The empty group never exceeds one and is published by ::Finish directly.
            return Finished.load() > 1;
        }
 
        // 
        // ============================ PublishByException ================================== 
        // 
 
        // TAny: failed completions are counted in ::Finished like successful ones (see ::Add),
        // so the count predicate alone decides; failures need no separate trigger.
        template <>
        inline bool TState<TWaitPolicy::TAny>::ShouldPublishByException() const noexcept {
            return false;
        }

        // TAll: a failure never short-circuits the wait; the count predicate waits for every
        // future, and ::Publish then forwards the first recorded exception, if any.
        template <>
        inline bool TState<TWaitPolicy::TAll>::ShouldPublishByException() const noexcept {
            return false;
        }

        // TExceptionOrAll: publish as soon as any added future has failed.
        template <>
        inline bool TState<TWaitPolicy::TExceptionOrAll>::ShouldPublishByException() const noexcept {
            const bool hasFailure = static_cast<bool>(GetExceptionInFlight());
            return hasFailure;
        }
 
        // 
        // 
        // 
 
        // Publishes the group's result if the policy allows it and nobody has published yet.
        template <class WaitPolicy>
        inline void TState<WaitPolicy>::TryPublish() {
            // Evaluation order of the two predicates is believed not to matter (no proof).
            const bool ready = ShouldPublishByCount() || ShouldPublishByException();
            if (!ready) {
                return;
            }

            // Only the caller that moves ::Phase from Initial to Publishing may publish.
            EPhase expected = EPhase::Initial;
            if (!Phase.compare_exchange_strong(expected, EPhase::Publishing)) {
                return;
            }

            Publish();
        }
 
        // Completes ::Subscribers. A recorded exception takes precedence over a plain value.
        //
        // Reached via the Initial -> Publishing transition in ::TryPublish, or directly from
        // ::Finish for an empty group. SetException / SetValue can potentially throw
        // (presumably when the promise is already set — TODO confirm in future.h).
        template <class WaitPolicy>
        inline void TState<WaitPolicy>::Publish() {
            if (std::exception_ptr failure = GetExceptionInFlight()) {
                Subscribers.SetException(std::move(failure));
                return;
            }

            Subscribers.SetValue();
        }
    } 
 
    // Creates a wait group backed by a freshly allocated, reference-counted shared state.
    template <class WaitPolicy>
    inline TWaitGroup<WaitPolicy>::TWaitGroup()
        : State_(MakeIntrusive<NWaitGroup::NImpl::TState<WaitPolicy>>())
    {
    }

    // Registers `future` with the shared state; returns *this so that calls can be chained.
    template <class WaitPolicy>
    template <class T>
    inline TWaitGroup<WaitPolicy>& TWaitGroup<WaitPolicy>::Add(const TFuture<T>& future) {
        State_->Add(future);
        return *this;
    }

    // Ends discovery and returns the future that is completed according to WaitPolicy.
    // Callable only on an rvalue. Afterwards this group no longer references the state;
    // callbacks of still-pending futures hold their own references to it.
    template <class WaitPolicy>
    inline TFuture<void> TWaitGroup<WaitPolicy>::Finish() && {
        TFuture<void> groupFuture = State_->Finish();

        // Drop our reference so that accidental use-after-move of this group hits a null
        // State_ instead of the live shared state.
        State_.Reset();

        return groupFuture;
    }
}