// cgit navigation: about | summary | refs | log | blame | commit | diff | stats
// path: root/library/cpp/threading/future/future_mt_ut.cpp
// blob: 4f390866c111614bfbe21c2dab10033c26dc1ab8 (plain) (tree)





















































































































































































































                                                                                                                                                  
#include "future.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/generic/noncopyable.h>
#include <util/generic/xrange.h>
#include <util/thread/pool.h>

#include <atomic>
#include <exception>

using NThreading::NewPromise;
using NThreading::TFuture;
using NThreading::TPromise;
using NThreading::TWaitPolicy;

namespace {
    // Wait* implementation without optimizations, to test TWaitGroup better
    template <class WaitPolicy, class TContainer>
    TFuture<void> WaitNoOpt(const TContainer& futures) {
        NThreading::TWaitGroup<WaitPolicy> wg;
        for (const auto& fut : futures) {
            wg.Add(fut);
        }

        return std::move(wg).Finish();
    }

    class TRelaxedBarrier {
    public:
        explicit TRelaxedBarrier(i64 size)
            : Waiting_{size} {
        }

        void Arrive() {
            // barrier is not for synchronization, just to ensure good timings, so
            // std::memory_order_relaxed is enough
            Waiting_.fetch_add(-1, std::memory_order_relaxed);

            while (Waiting_.load(std::memory_order_relaxed)) {
            }

            Y_ASSERT(Waiting_.load(std::memory_order_relaxed) >= 0);
        }

    private:
        std::atomic<i64> Waiting_;
    };

    THolder<TThreadPool> MakePool() {
        auto pool = MakeHolder<TThreadPool>(TThreadPool::TParams{}.SetBlocking(false).SetCatching(false));
        pool->Start(8);
        return pool;
    }

    template <class T>
    TVector<TFuture<T>> ToFutures(const TVector<TPromise<T>>& promises) {
        TVector<TFuture<void>> futures;

        for (auto&& p : promises) {
            futures.emplace_back(p);
        }

        return futures;
    }

    struct TStateSnapshot {
        i64 Started = -1;
        i64 StartedException = -1;
        const TVector<TFuture<void>>* Futures = nullptr;
    };

    // note: std::memory_order_relaxed should be enough everywhere, because TFuture::SetValue must provide the
    // needed synchronization
    template <class TFactory>
    void RunWaitTest(TFactory global) {
        auto pool = MakePool();

        const auto exception = std::make_exception_ptr(42);

        for (auto numPromises : xrange(1, 5)) {
            for (auto loopIter : xrange(1024 * 64)) {
                const auto numParticipants = numPromises + 1;

                TRelaxedBarrier barrier{numParticipants};

                std::atomic<i64> started = 0;
                std::atomic<i64> startedException = 0;
                std::atomic<i64> completed = 0;

                TVector<TPromise<void>> promises;
                for (auto i : xrange(numPromises)) {
                    Y_UNUSED(i);
                    promises.push_back(NewPromise());
                }

                const auto futures = ToFutures(promises);

                auto snapshotter = [&] {
                    return TStateSnapshot{
                        .Started = started.load(std::memory_order_relaxed),
                        .StartedException = startedException.load(std::memory_order_relaxed),
                        .Futures = &futures,
                    };
                };

                for (auto i : xrange(numPromises)) {
                    pool->SafeAddFunc([&, i] {
                        barrier.Arrive();

                        // subscribers must observe effects of this operation
                        // after .Set*
                        started.fetch_add(1, std::memory_order_relaxed);

                        if ((loopIter % 4 == 0) && i == 0) {
                            startedException.fetch_add(1, std::memory_order_relaxed);
                            promises[i].SetException(exception);
                        } else {
                            promises[i].SetValue();
                        }

                        completed.fetch_add(1, std::memory_order_release);
                    });
                }

                pool->SafeAddFunc([&] {
                    auto local = global(snapshotter);

                    barrier.Arrive();

                    local();

                    completed.fetch_add(1, std::memory_order_release);
                });

                while (completed.load() != numParticipants) {
                }
            }
        }
    }
}

Y_UNIT_TEST_SUITE(TFutureMultiThreadedTest) {
    // Each test passes RunWaitTest a factory. The factory receives a snapshotter
    // (returns the current TStateSnapshot) and returns the body that the waiter
    // task runs concurrently with the promise setters.

    Y_UNIT_TEST(WaitAll) {
        RunWaitTest(
            [](auto snapshotter) {
                return [snapshotter]() {
                    const auto* futures = snapshotter().Futures;

                    auto all = WaitNoOpt<TWaitPolicy::TAll>(*futures);

                    // tests safety part: checked by the subscriber once `all` is ready
                    all.Subscribe([snapshotter](const auto& ready) {
                        const TStateSnapshot snap = snapshotter();
                        const bool everyStarted = snap.Started == static_cast<i64>(snap.Futures->size());

                        // value safety: all has value => every future is set and none has exception
                        UNIT_ASSERT(!ready.HasValue() || (everyStarted && snap.StartedException == 0));

                        // exception safety: all has exception => every future is set and some has exception
                        UNIT_ASSERT(!ready.HasException() || (everyStarted && snap.StartedException > 0));
                    });

                    // test liveness
                    all.Wait();
                };
            });
    }

    Y_UNIT_TEST(WaitAny) {
        RunWaitTest(
            [](auto snapshotter) {
                return [snapshotter]() {
                    const auto* futures = snapshotter().Futures;

                    auto any = WaitNoOpt<TWaitPolicy::TAny>(*futures);

                    // safety: any is ready => some setter has already started
                    any.Subscribe([snapshotter](const auto&) {
                        UNIT_ASSERT(snapshotter().Started > 0);
                    });

                    // do we need better multithreaded liveness tests?
                    any.Wait();
                };
            });
    }

    Y_UNIT_TEST(WaitExceptionOrAll) {
        RunWaitTest(
            [](auto snapshotter) {
                return [snapshotter]() {
                    // the checks below are only set up once the library's own
                    // WaitExceptionOrAll over the same futures is ready
                    NThreading::WaitExceptionOrAll(*snapshotter().Futures)
                        .Subscribe([snapshotter](const auto&) {
                            const auto* futures = snapshotter().Futures;

                            auto exceptionOrAll = WaitNoOpt<TWaitPolicy::TExceptionOrAll>(*futures);

                            exceptionOrAll.Subscribe([snapshotter](const auto& ready) {
                                const TStateSnapshot snap = snapshotter();
                                const bool everyStarted = snap.Started == static_cast<i64>(snap.Futures->size());

                                // exception safety: exceptionOrAll has exception => some has exception
                                UNIT_ASSERT(!ready.HasException() || snap.StartedException > 0);

                                // value safety: exceptionOrAll has value <=> every future is set and none has exception
                                UNIT_ASSERT(ready.HasValue() == (everyStarted && snap.StartedException == 0));
                            });

                            // do we need better multithreaded liveness tests?
                            exceptionOrAll.Wait();
                        });
                };
            });
    }
}