aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/threading/future/wait/wait_group-inl.h
blob: a7da536f20e025fc115a37c04d7aed789282bb91 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14












                                                           
                 




























































































































































































                                                                                                                                                                                           
#pragma once

#if !defined(INCLUDE_FUTURE_INL_H)
#error "you should never include wait_group-inl.h directly"
#endif // INCLUDE_FUTURE_INL_H

#include "wait_policy.h"

#include <util/generic/maybe.h>
#include <util/generic/ptr.h>

#include <library/cpp/threading/future/core/future.h>

#include <util/system/spinlock.h>

#include <atomic>
#include <exception>

namespace NThreading {
    namespace NWaitGroup::NImpl {
        template <class WaitPolicy>
        struct TState final : TAtomicRefCount<TState<WaitPolicy>> {
            template <class T>
            void Add(const TFuture<T>& future);
            TFuture<void> Finish();

            void TryPublish();
            void Publish();

            bool ShouldPublishByCount() const noexcept;
            bool ShouldPublishByException() const noexcept;

            TStateRef<WaitPolicy> SharedFromThis() noexcept {
                return TStateRef<WaitPolicy>{this};
            }

            enum class EPhase {
                Initial,
                Publishing,
            };

            // initially we have one imaginary discovered future which we
            // use for synchronization with ::Finish
            std::atomic<ui64> Discovered{1};

            std::atomic<ui64> Finished{0};

            std::atomic<EPhase> Phase{EPhase::Initial};

            TPromise<void> Subscribers = NewPromise();

            mutable TAdaptiveLock Mut;
            std::exception_ptr ExceptionInFlight;

            void TrySetException(std::exception_ptr eptr) noexcept {
                TGuard lock{Mut};
                if (!ExceptionInFlight) {
                    ExceptionInFlight = std::move(eptr);
                }
            }

            std::exception_ptr GetExceptionInFlight() const noexcept {
                TGuard lock{Mut};
                return ExceptionInFlight;
            }
        };

        template <class WaitPolicy>
        inline TFuture<void> TState<WaitPolicy>::Finish() {
            Finished.fetch_add(1); // complete the imaginary future

            // handle empty case explicitly:
            if (Discovered.load() == 1) {
                Y_ASSERT(Phase.load() == EPhase::Initial);
                Publish();
            } else {
                TryPublish();
            }

            return Subscribers;
        }

        template <class WaitPolicy>
        template <class T>
        inline void TState<WaitPolicy>::Add(const TFuture<T>& future) {
            future.EnsureInitialized();

            Discovered.fetch_add(1);

            // NoexceptSubscribe is needed to make ::Add exception-safe
            future.NoexceptSubscribe([self = SharedFromThis()](auto&& future) {
                try {
                    future.TryRethrow();
                } catch (...) {
                    self->TrySetException(std::current_exception());
                }

                self->Finished.fetch_add(1);
                self->TryPublish();
            });
        }

        //
        // ============================ PublishByCount ==================================
        //

        template <class WaitPolicy>
        inline bool TState<WaitPolicy>::ShouldPublishByCount() const noexcept {
            // - safety: a) If the future incremented ::Finished, and we observe the effect, then we will observe ::Discovered as incremented by its discovery later
            //           b) Every discovery of a future observes discovery of the imaginary future
            //          a, b => if finishedByNow == discoveredByNow, then every future discovered in [imaginary discovered, imaginary finished] is finished
            //
            // - liveness: a) TryPublish is called after each increment of ::Finished
            //             b) There is some last increment of ::Finished which follows all other operations with ::Finished and ::Discovered (provided that every future is eventually set)
            //             c) For each increment of ::Discovered there is an increment of ::Finished (provided that every future is eventually set)
            //          a, b c => some call to ShouldPublishByCount will always return true
            //
            // order of the following two operations is significant for the proof.
            auto finishedByNow = Finished.load();
            auto discoveredByNow = Discovered.load();

            return finishedByNow == discoveredByNow;
        }

        template <>
        inline bool TState<TWaitPolicy::TAny>::ShouldPublishByCount() const noexcept {
            auto finishedByNow = Finished.load();

            // note that the empty case is not handled here
            return finishedByNow >= 2; // at least one non-imaginary
        }

        //
        // ============================ PublishByException ==================================
        //

        template <>
        inline bool TState<TWaitPolicy::TAny>::ShouldPublishByException() const noexcept {
            // for TAny exceptions are handled by ShouldPublishByCount
            return false;
        }

        template <>
        inline bool TState<TWaitPolicy::TAll>::ShouldPublishByException() const noexcept {
            return false;
        }

        template <>
        inline bool TState<TWaitPolicy::TExceptionOrAll>::ShouldPublishByException() const noexcept {
            return GetExceptionInFlight() != nullptr;
        }

        //
        //
        //

        template <class WaitPolicy>
        inline void TState<WaitPolicy>::TryPublish() {
            // the order is insignificant (without proof)
            bool shouldPublish = ShouldPublishByCount() || ShouldPublishByException();

            if (shouldPublish) {
                if (auto currentPhase = EPhase::Initial;
                    Phase.compare_exchange_strong(currentPhase, EPhase::Publishing)) {
                    Publish();
                }
            }
        }

        template <class WaitPolicy>
        inline void TState<WaitPolicy>::Publish() {
            auto eptr = GetExceptionInFlight();

            // can potentially throw
            if (eptr) {
                Subscribers.SetException(std::move(eptr));
            } else {
                Subscribers.SetValue();
            }
        }
    }

    template <class WaitPolicy>
    inline TWaitGroup<WaitPolicy>::TWaitGroup()
        : State_{MakeIntrusive<NWaitGroup::NImpl::TState<WaitPolicy>>()}
    {
    }

    template <class WaitPolicy>
    template <class T>
    inline TWaitGroup<WaitPolicy>& TWaitGroup<WaitPolicy>::Add(const TFuture<T>& future) {
        State_->Add(future);
        return *this;
    }

    template <class WaitPolicy>
    inline TFuture<void> TWaitGroup<WaitPolicy>::Finish() && {
        auto res = State_->Finish();

        // just to prevent nasty bugs from use-after-move
        State_.Reset();

        return res;
    }
}