diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/threading/future/subscription | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/threading/future/subscription')
22 files changed, 1938 insertions, 0 deletions
diff --git a/library/cpp/threading/future/subscription/README.md b/library/cpp/threading/future/subscription/README.md new file mode 100644 index 0000000000..62c7e1303e --- /dev/null +++ b/library/cpp/threading/future/subscription/README.md @@ -0,0 +1,104 @@ +Subscriptions manager and wait primitives library +================================================= + +Wait primitives +--------------- + +All wait primitives are futures those being signaled when some or all of theirs dependencies are signaled. +Wait privimitives could be constructed either from an initializer_list or from a standard container of futures. + +1. WaitAll is signaled when all its dependencies are signaled: + + ```C++ + #include <library/cpp/threading/subscriptions/wait_all.h> + + auto w = NWait::WaitAll({ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for all futures + ``` + +2. WaitAny is signaled when any of its dependencies is signaled: + + ```C++ + #include <library/cpp/threading/subscriptions/wait_any.h> + + auto w = NWait::WaitAny(TVector<TFuture<T>>{ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for any future + ``` + +3. WaitAllOrException is signaled when all its dependencies are signaled with values or any dependency is signaled with an exception: + + ```C++ + #include <library/cpp/threading/subscriptions/wait_all_or_exception.h> + + auto w = NWait::WaitAllOrException(TVector<TFuture<T>>{ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for all values or for an exception + ``` + +Subscriptions manager +--------------------- + +The subscription manager can manage multiple links beetween futures and callbacks. Multiple managed subscriptions to a single future shares just a single underlying subscription to the future. That allows dynamic creation and deletion of subscriptions and efficient implementation of different wait primitives. +The subscription manager could be used in the following way: + +1. Subscribe to a single future: + + ```C++ + #include <library/cpp/threading/subscriptions/subscription.h> + + TFuture<int> LongOperation(); + + ... + auto future = LongRunnigOperation(); + auto m = MakeSubsriptionManager<int>(); + auto id = m->Subscribe(future, [](TFuture<int> const& f) { + try { + auto value = f.GetValue(); + ... + } catch (...) { + ... // handle exception + } + }); + if (id.has_value()) { + ... // Callback will run asynchronously + } else { + ... // Future has been signaled already. The callback has been invoked synchronously + } + ``` + + Note that a callback could be invoked synchronously during a Subscribe call. In this case the returned optional will have no value. + +2. Unsubscribe from a single future: + + ```C++ + // id holds the subscription id from a previous Subscribe call + m->Unsubscribe(id.value()); + ``` + + There is no need to call Unsubscribe if the callback has been called. In this case Unsubscribe will do nothing. And it is safe to call Unsubscribe with the same id multiple times. + +3. Subscribe a single callback to multiple futures: + + ```C++ + auto ids = m->Subscribe({ future1, future2, ..., futureN }, [](auto&& f) { ... }); + ... + ``` + + Futures could be passed to Subscribe method either via an initializer_list or via a standard container like vector or list. Subscribe method accept an optional boolean parameter revertOnSignaled. If the parameter is false (default) then all suscriptions will be performed regardless of the futures states and the returned vector will have a subscription id for each future (even if callback has been executed synchronously for some futures). Otherwise the method will stop on the first signaled future (the callback will be synchronously called for it), no suscriptions will be created and an empty vector will be returned. + +4. Unsubscribe multiple subscriptions: + + ```C++ + // ids is the vector or subscription ids + m->Unsubscribe(ids); + ``` + + The vector of IDs could be a result of a previous Subscribe call or an arbitrary set of IDs of previously created subscriptions. + +5. If you do not want to instantiate a new instance of the subscription manager it is possible to use the default instance: + + ```C++ + auto m = TSubscriptionManager<T>::Default(); + ``` diff --git a/library/cpp/threading/future/subscription/subscription-inl.h b/library/cpp/threading/future/subscription/subscription-inl.h new file mode 100644 index 0000000000..a45d8999d3 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription-inl.h @@ -0,0 +1,118 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H) +#error "you should never include subscription-inl.h directly" +#endif + +namespace NThreading { + +namespace NPrivate { + +template <typename T> +TFutureStateId CheckedStateId(TFuture<T> const& future) { + auto const id = future.StateId(); + if (id.Defined()) { + return *id; + } + ythrow TFutureException() << "Future state should be initialized"; +} + +} + +template <typename T, typename F, typename TCallbackExecutor> +inline TSubscriptionManager::TSubscription::TSubscription(TFuture<T> future, F&& callback, TCallbackExecutor&& executor) + : Callback( + [future = std::move(future), callback = std::forward<F>(callback), executor = std::forward<TCallbackExecutor>(executor)]() mutable { + executor(std::as_const(future), callback); + }) +{ +} + +template <typename T, typename F, typename TCallbackExecutor> +inline std::optional<TSubscriptionId> TSubscriptionManager::Subscribe(TFuture<T> const& future, F&& callback, TCallbackExecutor&& executor) { + auto stateId = NPrivate::CheckedStateId(future); + with_lock(Lock) { + auto const status = TrySubscribe(future, std::forward<F>(callback), stateId, std::forward<TCallbackExecutor>(executor)); + switch (status) { + case ECallbackStatus::Subscribed: + return TSubscriptionId(stateId, Revision); + case ECallbackStatus::ExecutedSynchronously: + return {}; + default: + Y_FAIL("Unexpected callback status"); + } + } +} + +template <typename TFutures, typename F, typename TCallbackExecutor> +inline TVector<TSubscriptionId> TSubscriptionManager::Subscribe(TFutures const& futures, F&& callback, bool revertOnSignaled + , TCallbackExecutor&& executor) +{ + return SubscribeImpl(futures, std::forward<F>(callback), revertOnSignaled, std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename F, typename TCallbackExecutor> +inline TVector<TSubscriptionId> TSubscriptionManager::Subscribe(std::initializer_list<TFuture<T> const> futures, F&& callback + , bool revertOnSignaled, TCallbackExecutor&& executor) +{ + return SubscribeImpl(futures, std::forward<F>(callback), revertOnSignaled, std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename F, typename TCallbackExecutor> +inline TSubscriptionManager::ECallbackStatus TSubscriptionManager::TrySubscribe(TFuture<T> const& future, F&& callback, TFutureStateId stateId + , TCallbackExecutor&& executor) +{ + TSubscription subscription(future, std::forward<F>(callback), std::forward<TCallbackExecutor>(executor)); + auto const it = Subscriptions.find(stateId); + auto const revision = ++Revision; + if (it == std::end(Subscriptions)) { + auto const success = Subscriptions.emplace(stateId, THashMap<ui64, TSubscription>{ { revision, std::move(subscription) } }).second; + Y_VERIFY(success); + auto self = TSubscriptionManagerPtr(this); + future.Subscribe([self, stateId](TFuture<T> const&) { self->OnCallback(stateId); }); + if (Subscriptions.find(stateId) == std::end(Subscriptions)) { + return ECallbackStatus::ExecutedSynchronously; + } + } else { + Y_VERIFY(it->second.emplace(revision, std::move(subscription)).second); + } + return ECallbackStatus::Subscribed; +} + +template <typename TFutures, typename F, typename TCallbackExecutor> +inline TVector<TSubscriptionId> TSubscriptionManager::SubscribeImpl(TFutures const& futures, F const& callback, bool revertOnSignaled + , TCallbackExecutor const& executor) +{ + TVector<TSubscriptionId> results; + results.reserve(std::size(futures)); + // resolve all state ids to minimize processing under the lock + for (auto const& f : futures) { + results.push_back(TSubscriptionId(NPrivate::CheckedStateId(f), 0)); + } + with_lock(Lock) { + size_t i = 0; + for (auto const& f : futures) { + auto& r = results[i]; + auto const status = TrySubscribe(f, callback, r.StateId(), executor); + switch (status) { + case ECallbackStatus::Subscribed: + break; + case ECallbackStatus::ExecutedSynchronously: + if (revertOnSignaled) { + // revert + results.crop(i); + UnsubscribeImpl(results); + return {}; + } + break; + default: + Y_FAIL("Unexpected callback status"); + } + r.SetSubId(Revision); + ++i; + } + } + return results; +} + +} diff --git a/library/cpp/threading/future/subscription/subscription.cpp b/library/cpp/threading/future/subscription/subscription.cpp new file mode 100644 index 0000000000..a98b4a4f03 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription.cpp @@ -0,0 +1,65 @@ +#include "subscription.h" + +namespace NThreading { + +bool operator==(TSubscriptionId const& l, TSubscriptionId const& r) noexcept { + return l.StateId() == r.StateId() && l.SubId() == r.SubId(); +} + +bool operator!=(TSubscriptionId const& l, TSubscriptionId const& r) noexcept { + return !(l == r); +} + +void TSubscriptionManager::TSubscription::operator()() { + Callback(); +} + +TSubscriptionManagerPtr TSubscriptionManager::NewInstance() { + return new TSubscriptionManager(); +} + +TSubscriptionManagerPtr TSubscriptionManager::Default() { + static auto instance = NewInstance(); + return instance; +} + +void TSubscriptionManager::Unsubscribe(TSubscriptionId id) { + with_lock(Lock) { + UnsubscribeImpl(id); + } +} + +void TSubscriptionManager::Unsubscribe(TVector<TSubscriptionId> const& ids) { + with_lock(Lock) { + UnsubscribeImpl(ids); + } +} + +void TSubscriptionManager::OnCallback(TFutureStateId stateId) noexcept { + THashMap<ui64, TSubscription> subscriptions; + with_lock(Lock) { + auto const it = Subscriptions.find(stateId); + Y_VERIFY(it != Subscriptions.end(), "The callback has been triggered more than once"); + subscriptions.swap(it->second); + Subscriptions.erase(it); + } + for (auto& [_, subscription] : subscriptions) { + subscription(); + } +} + +void TSubscriptionManager::UnsubscribeImpl(TSubscriptionId id) { + auto const it = Subscriptions.find(id.StateId()); + if (it == std::end(Subscriptions)) { + return; + } + it->second.erase(id.SubId()); +} + +void TSubscriptionManager::UnsubscribeImpl(TVector<TSubscriptionId> const& ids) { + for (auto const& id : ids) { + UnsubscribeImpl(id); + } +} + +} diff --git a/library/cpp/threading/future/subscription/subscription.h b/library/cpp/threading/future/subscription/subscription.h new file mode 100644 index 0000000000..afe5eda711 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription.h @@ -0,0 +1,186 @@ +#pragma once + +#include <library/cpp/threading/future/future.h> + +#include <util/generic/hash.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/system/mutex.h> + +#include <functional> +#include <optional> +#include <utility> + +namespace NThreading { + +namespace NPrivate { + +struct TNoexceptExecutor { + template <typename T, typename F> + void operator()(TFuture<T> const& future, F&& callee) const noexcept { + return callee(future); + } +}; + +} + +class TSubscriptionManager; + +using TSubscriptionManagerPtr = TIntrusivePtr<TSubscriptionManager>; + +//! A subscription id +class TSubscriptionId { +private: + TFutureStateId StateId_; + ui64 SubId_; // Secondary id to make the whole subscription id unique + + friend class TSubscriptionManager; + +public: + TFutureStateId StateId() const noexcept { + return StateId_; + } + + ui64 SubId() const noexcept { + return SubId_; + } + +private: + TSubscriptionId(TFutureStateId stateId, ui64 subId) + : StateId_(stateId) + , SubId_(subId) + { + } + + void SetSubId(ui64 subId) noexcept { + SubId_ = subId; + } +}; + +bool operator==(TSubscriptionId const& l, TSubscriptionId const& r) noexcept; +bool operator!=(TSubscriptionId const& l, TSubscriptionId const& r) noexcept; + +//! The subscription manager manages subscriptions to futures +/** It provides an ability to create (and drop) multiple subscriptions to any future + with just a single underlying subscription per future. + + When a future is signaled all its subscriptions are removed. + So, there no need to call Unsubscribe for subscriptions to already signaled futures. + + Warning!!! For correct operation this class imposes the following requirement to futures/promises: + Any used future must be signaled (value or exception set) before the future state destruction. + Otherwise subscriptions and futures may happen. + Current future design does not provide the required guarantee. But that should be fixed soon. +**/ +class TSubscriptionManager final : public TAtomicRefCount<TSubscriptionManager> { +private: + //! A single subscription + class TSubscription { + private: + std::function<void()> Callback; + + public: + template <typename T, typename F, typename TCallbackExecutor> + TSubscription(TFuture<T> future, F&& callback, TCallbackExecutor&& executor); + + void operator()(); + }; + + struct TFutureStateIdHash { + size_t operator()(TFutureStateId const id) const noexcept { + auto const value = id.Value(); + return ::hash<decltype(value)>()(value); + } + }; + +private: + THashMap<TFutureStateId, THashMap<ui64, TSubscription>, TFutureStateIdHash> Subscriptions; + ui64 Revision = 0; + TMutex Lock; + +public: + //! Creates a new subscription manager instance + static TSubscriptionManagerPtr NewInstance(); + + //! The default subscription manager instance + static TSubscriptionManagerPtr Default(); + + //! Attempts to subscribe the callback to the future + /** Subscription should succeed if the future is not signaled yet. + Otherwise the callback will be called synchronously and nullopt will be returned + + @param future - The future to subscribe to + @param callback - The callback to attach + @return The subscription id on success, nullopt if the future has been signaled already + **/ + template <typename T, typename F, typename TCallbackExecutor = NPrivate::TNoexceptExecutor> + std::optional<TSubscriptionId> Subscribe(TFuture<T> const& future, F&& callback + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Drops the subscription with the given id + /** @param id - The subscription id + **/ + void Unsubscribe(TSubscriptionId id); + + //! Attempts to subscribe the callback to the set of futures + /** @param futures - The futures to subscribe to + @param callback - The callback to attach + @param revertOnSignaled - Shows whether to stop and revert the subscription process if one of the futures is in signaled state + @return The vector of subscription ids if no revert happened or an empty vector otherwise + A subscription id will be valid even if a corresponding future has been signaled + **/ + template <typename TFutures, typename F, typename TCallbackExecutor = NPrivate::TNoexceptExecutor> + TVector<TSubscriptionId> Subscribe(TFutures const& futures, F&& callback, bool revertOnSignaled = false + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Attempts to subscribe the callback to the set of futures + /** @param futures - The futures to subscribe to + @param callback - The callback to attach + @param revertOnSignaled - Shows whether to stop and revert the subscription process if one of the futures is in signaled state + @return The vector of subscription ids if no revert happened or an empty vector otherwise + A subscription id will be valid even if a corresponding future has been signaled + **/ + template <typename T, typename F, typename TCallbackExecutor = NPrivate::TNoexceptExecutor> + TVector<TSubscriptionId> Subscribe(std::initializer_list<TFuture<T> const> futures, F&& callback, bool revertOnSignaled = false + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Drops the subscriptions with the given ids + /** @param ids - The subscription ids + **/ + void Unsubscribe(TVector<TSubscriptionId> const& ids); + +private: + enum class ECallbackStatus { + Subscribed, //! A subscription has been created. The callback will be called asynchronously. + ExecutedSynchronously //! A callback has been called synchronously. No subscription has been created + }; + +private: + //! .ctor + TSubscriptionManager() = default; + //! Processes a callback from a future + void OnCallback(TFutureStateId stateId) noexcept; + //! Attempts to create a subscription + /** This method should be called under the lock + **/ + template <typename T, typename F, typename TCallbackExecutor> + ECallbackStatus TrySubscribe(TFuture<T> const& future, F&& callback, TFutureStateId stateId, TCallbackExecutor&& executor); + //! Batch subscribe implementation + template <typename TFutures, typename F, typename TCallbackExecutor> + TVector<TSubscriptionId> SubscribeImpl(TFutures const& futures, F const& callback, bool revertOnSignaled + , TCallbackExecutor const& executor); + //! Unsubscribe implementation + /** This method should be called under the lock + **/ + void UnsubscribeImpl(TSubscriptionId id); + //! Batch unsubscribe implementation + /** This method should be called under the lock + **/ + void UnsubscribeImpl(TVector<TSubscriptionId> const& ids); +}; + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H +#include "subscription-inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H diff --git a/library/cpp/threading/future/subscription/subscription_ut.cpp b/library/cpp/threading/future/subscription/subscription_ut.cpp new file mode 100644 index 0000000000..d018ea15cc --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription_ut.cpp @@ -0,0 +1,432 @@ +#include "subscription.h" + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TSubscriptionManagerTest) { + + Y_UNIT_TEST(TestSubscribeUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestSubscribeSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto f = MakeFuture(); + + size_t callCount = 0; + auto id = m->Subscribe(f, [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(!id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsignaledAndSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 1); + + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount1, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(id.value()); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 0); + } + + Y_UNIT_TEST(TestSubscribeUnsignaledUnsubscribeSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestUnsubscribeTwice) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 0); + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 0); + } + + Y_UNIT_TEST(TestSubscribeOneUnsignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_UNEQUAL(id2.value(), id3.value()); + UNIT_ASSERT_UNEQUAL(id3.value(), id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeOneSignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto f = MakeFuture(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(f, [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(f, [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(!id1.has_value()); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeOneUnsignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + size_t callCount4 = 0; + auto id4 = m->Subscribe(p.GetFuture(), [&callCount4](auto&&) { ++callCount4; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT(id4.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + + m->Unsubscribe(id3.value()); + m->Unsubscribe(id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 1); + } + + Y_UNIT_TEST(TestSubscribeManyUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p1.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_UNEQUAL(id2.value(), id3.value()); + UNIT_ASSERT_UNEQUAL(id3.value(), id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + + p1.SetValue(33); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p2.SetValue(111); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeManySignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(f1, [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(f2, [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f2, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(!id1.has_value()); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeManyMixed) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p1.SetValue(45); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p2.SetValue(-7); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeMany) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto p3 = NewPromise<int>(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p3.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + size_t callCount4 = 0; + auto id4 = m->Subscribe(p2.GetFuture(), [&callCount4](auto&&) { ++callCount4; } ); + size_t callCount5 = 0; + auto id5 = m->Subscribe(p1.GetFuture(), [&callCount5](auto&&) { ++callCount5; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT(id4.has_value()); + UNIT_ASSERT(id5.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 0); + + m->Unsubscribe(id1.value()); + p1.SetValue(-1); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + + m->Unsubscribe(id4.value()); + p2.SetValue(23); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + + p3.SetValue(100500); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeManyUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), p2.GetFuture(), p1.GetFuture() }, [&callCount](auto&&) { ++callCount; }); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 0); + + p1.SetValue(33); + UNIT_ASSERT_EQUAL(callCount, 2); + + p2.SetValue(111); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManySignaledNoRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount = 0; + auto ids = m->Subscribe({ f1, f2, f1 }, [&callCount](auto&&) { ++callCount; }); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManySignaledRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount = 0; + auto ids = m->Subscribe({ f1, f2, f1 }, [&callCount](auto&&) { ++callCount; }, true); + + UNIT_ASSERT(ids.empty()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeManyMixedNoRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), p2.GetFuture(), f }, [&callCount](auto&&) { ++callCount; } ); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 1); + + p1.SetValue(45); + UNIT_ASSERT_EQUAL(callCount, 2); + + p2.SetValue(-7); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManyMixedRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), f, p2.GetFuture() }, [&callCount](auto&&) { ++callCount; }, true); + + UNIT_ASSERT(ids.empty()); + UNIT_ASSERT_EQUAL(callCount, 1); + + p1.SetValue(); + p2.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeUnsubscribeMany) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto p3 = NewPromise<int>(); + + size_t callCount = 0; + auto ids = m->Subscribe( + TVector<TFuture<int>>{ p1.GetFuture(), p2.GetFuture(), p3.GetFuture(), p2.GetFuture(), p1.GetFuture() } + , [&callCount](auto&&) { ++callCount; } ); + + UNIT_ASSERT_EQUAL(ids.size(), 5); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(TVector<TSubscriptionId>{ ids[0], ids[3] }); + UNIT_ASSERT_EQUAL(callCount, 0); + + p1.SetValue(-1); + UNIT_ASSERT_EQUAL(callCount, 1); + + p2.SetValue(23); + UNIT_ASSERT_EQUAL(callCount, 2); + + p3.SetValue(100500); + UNIT_ASSERT_EQUAL(callCount, 3); + } +} diff --git a/library/cpp/threading/future/subscription/ut/ya.make b/library/cpp/threading/future/subscription/ut/ya.make new file mode 100644 index 0000000000..45210f7bd7 --- /dev/null +++ b/library/cpp/threading/future/subscription/ut/ya.make @@ -0,0 +1,17 @@ +UNITTEST_FOR(library/cpp/threading/future/subscription) + +OWNER( + g:kwyt + g:rtmr + ishfb +) + +SRCS( + subscription_ut.cpp + wait_all_ut.cpp + wait_all_or_exception_ut.cpp + wait_any_ut.cpp + wait_ut_common.cpp +) + +END() diff --git a/library/cpp/threading/future/subscription/wait.h b/library/cpp/threading/future/subscription/wait.h new file mode 100644 index 0000000000..533bab9d8d --- /dev/null +++ b/library/cpp/threading/future/subscription/wait.h @@ -0,0 +1,119 @@ +#pragma once + +#include "subscription.h" + +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/system/spinlock.h> + + +#include <initializer_list> + +namespace NThreading::NPrivate { + +template <typename TDerived> +class TWait : public TThrRefBase { +private: + TSubscriptionManagerPtr Manager; + TVector<TSubscriptionId> Subscriptions; + bool Unsubscribed = false; + +protected: + TAdaptiveLock Lock; + TPromise<void> Promise; + +public: + template <typename TFutures, typename TCallbackExecutor> + static TFuture<void> Make(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + TIntrusivePtr<TDerived> w(new TDerived(std::move(manager))); + w->Subscribe(futures, std::forward<TCallbackExecutor>(executor)); + return w->Promise.GetFuture(); + } + +protected: + TWait(TSubscriptionManagerPtr manager) + : Manager(std::move(manager)) + , Subscriptions() + , Unsubscribed(false) + , Lock() + , Promise(NewPromise()) + { + Y_ENSURE(Manager != nullptr); + } + +protected: + //! Unsubscribes all existing subscriptions + /** Lock should be acquired! + **/ + void Unsubscribe() noexcept { + if (Unsubscribed) { + return; + } + Unsubscribe(Subscriptions); + Subscriptions.clear(); + } + +private: + //! Performs a subscription to the given futures + /** Lock should not be acquired! + @param future - The futures to subscribe to + @param callback - The callback to call for each future + **/ + template <typename TFutures, typename TCallbackExecutor> + void Subscribe(TFutures const& futures, TCallbackExecutor&& executor) { + auto self = TIntrusivePtr<TDerived>(static_cast<TDerived*>(this)); + self->BeforeSubscribe(futures); + auto callback = [self = std::move(self)](const auto& future) mutable { + self->Set(future); + }; + auto subscriptions = Manager->Subscribe(futures, callback, TDerived::RevertOnSignaled, std::forward<TCallbackExecutor>(executor)); + if (subscriptions.empty()) { + return; + } + with_lock (Lock) { + if (Unsubscribed) { + Unsubscribe(subscriptions); + } else { + Subscriptions = std::move(subscriptions); + } + } + } + + void Unsubscribe(TVector<TSubscriptionId>& subscriptions) noexcept { + Manager->Unsubscribe(subscriptions); + Unsubscribed = true; + } +}; + +template <typename TWaiter, typename TFutures, typename TCallbackExecutor> +TFuture<void> Wait(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + switch (std::size(futures)) { + case 0: + return MakeFuture(); + case 1: + return std::begin(futures)->IgnoreResult(); + default: + return TWaiter::Make(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); + } +} + +template <typename TWaiter, typename T, typename TCallbackExecutor> +TFuture<void> Wait(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + switch (std::size(futures)) { + case 0: + return MakeFuture(); + case 1: + return std::begin(futures)->IgnoreResult(); + default: + return TWaiter::Make(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); + } +} + + +template <typename TWaiter, typename T, typename TCallbackExecutor> +TFuture<void> Wait(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return TWaiter::Make(std::initializer_list<TFuture<T> const>({ future1, future2 }), std::move(manager) + , std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all.cpp b/library/cpp/threading/future/subscription/wait_all.cpp new file mode 100644 index 0000000000..10e7ee7598 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all.cpp @@ -0,0 +1 @@ +#include "wait_all.h" diff --git a/library/cpp/threading/future/subscription/wait_all.h b/library/cpp/threading/future/subscription/wait_all.h new file mode 100644 index 0000000000..5c0d2bb862 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all.h @@ -0,0 +1,23 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template <typename TFutures, typename TCallbackExecutor> +TFuture<void> WaitAll(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAll(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAll(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H +#include "wait_all_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H diff --git a/library/cpp/threading/future/subscription/wait_all_inl.h b/library/cpp/threading/future/subscription/wait_all_inl.h new file mode 100644 index 0000000000..a3b665f642 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_inl.h @@ -0,0 +1,80 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H) +#error "you should never include wait_all_inl.h directly" +#endif + +#include "subscription.h" + +#include <initializer_list> + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAll final : public NThreading::NPrivate::TWait<TWaitAll> { +private: + size_t Count = 0; + std::exception_ptr Exception; + + static constexpr bool RevertOnSignaled = false; + + using TBase = NThreading::NPrivate::TWait<TWaitAll>; + friend TBase; + +private: + TWaitAll(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + , Count(0) + , Exception() + { + } + + template <typename TFutures> + void BeforeSubscribe(TFutures const& futures) { + Count = std::size(futures); + Y_ENSURE(Count > 0, "It is meaningless to use this class with empty futures set"); + } + + template <typename T> + void Set(TFuture<T> const& future) { + with_lock (TBase::Lock) { + if (!Exception) { + try { + future.TryRethrow(); + } catch (...) { + Exception = std::current_exception(); + } + } + + if (--Count == 0) { + // there is no need to call Unsubscribe here since all futures are signaled + Y_ASSERT(!TBase::Promise.HasValue() && !TBase::Promise.HasException()); + if (Exception) { + TBase::Promise.SetException(std::move(Exception)); + } else { + TBase::Promise.SetValue(); + } + } + } + } +}; + +} + +template <typename TFutures, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAll(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAll>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAll(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAll>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAll(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAll>(future1, future2, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception.cpp b/library/cpp/threading/future/subscription/wait_all_or_exception.cpp new file mode 100644 index 0000000000..0c73ddeb84 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception.cpp @@ -0,0 +1 @@ +#include "wait_all_or_exception.h" diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception.h b/library/cpp/threading/future/subscription/wait_all_or_exception.h new file mode 100644 index 0000000000..e3e0caf2f8 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception.h @@ -0,0 +1,25 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template <typename TFutures, typename TCallbackExecutor> +TFuture<void> WaitAllOrException(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAllOrException(std::initializer_list<TFuture<T> const> futures + , TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAllOrException(TFuture<T> const& future1, TFuture<T> const& future2 + , TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H +#include "wait_all_or_exception_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h b/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h new file mode 100644 index 0000000000..fcd9782d54 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h @@ -0,0 +1,79 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H) +#error "you should never include wait_all_or_exception_inl.h directly" +#endif + +#include "subscription.h" + +#include <initializer_list> + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAllOrException final : public NThreading::NPrivate::TWait<TWaitAllOrException> +{ +private: + size_t Count = 0; + + static constexpr bool RevertOnSignaled = false; + + using TBase = NThreading::NPrivate::TWait<TWaitAllOrException>; + friend TBase; + +private: + TWaitAllOrException(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + , Count(0) + { + } + + template <typename TFutures> + void BeforeSubscribe(TFutures const& futures) { + Count = std::size(futures); + Y_ENSURE(Count > 0, "It is meaningless to use this class with empty futures set"); + } + + template <typename T> + void Set(TFuture<T> const& future) { + with_lock (TBase::Lock) { + try { + future.TryRethrow(); + if (--Count == 0) { + // there is no need to call Unsubscribe here since all futures are signaled + TBase::Promise.SetValue(); + } + } catch (...) { + Y_ASSERT(!TBase::Promise.HasValue()); + TBase::Unsubscribe(); + if (!TBase::Promise.HasException()) { + TBase::Promise.SetException(std::current_exception()); + } + } + } + } +}; + +} + +template <typename TFutures, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAllOrException(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAllOrException>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAllOrException(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager + , TCallbackExecutor&& executor) +{ + return NThreading::NPrivate::Wait<NPrivate::TWaitAllOrException>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAllOrException(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager + , TCallbackExecutor&& executor) +{ + return NThreading::NPrivate::Wait<NPrivate::TWaitAllOrException>(future1, future2, std::move(manager) + , std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp b/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp new file mode 100644 index 0000000000..34ae9edb4e --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp @@ -0,0 +1,167 @@ +#include "wait_all_or_exception.h" +#include "wait_ut_common.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/strbuf.h> + +#include <atomic> +#include <exception> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAllOrExceptionTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAllOrException(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAllOrException(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p2.SetValue(-11); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAllOrException(p.GetFuture(), f); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAllOrException(f, p.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 2"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAllOrException(std::initializer_list<TFuture<void> const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAllOrException(TVector<TFuture<int>>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise<int>(); + auto w = NWait::WaitAllOrException({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAllOrException(TVector<TFuture<void>>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAllOrException({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(-3); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestManyWithVector) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAllOrException(TVector<TFuture<int>>{ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p2.SetValue(34); + } + + Y_UNIT_TEST(TestManyWithVectorAndIntialError) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + constexpr TStringBuf message = "Test exception 5"; + auto f = MakeErrorFuture<void>(std::make_exception_ptr(yexception() << message)); + auto w = NWait::WaitAllOrException(TVector<TFuture<void>>{ p1.GetFuture(), p2.GetFuture(), f }); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(); + p2.SetValue(); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(22); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 6"); + std::atomic<size_t> index = 0; + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [e, &index](size_t size) { + auto exceptionIndex = size / 2; + index = 0; + return [e, exceptionIndex, &index](auto&& p) { + if (index++ == exceptionIndex) { + p.SetException(e); + } else { + p.SetValue(); + } + }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_all_ut.cpp b/library/cpp/threading/future/subscription/wait_all_ut.cpp new file mode 100644 index 0000000000..3bc9762671 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_ut.cpp @@ -0,0 +1,161 @@ +#include "wait_all.h" +#include "wait_ut_common.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/strbuf.h> + +#include <atomic> +#include <exception> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAllTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAll(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAll(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p2.SetValue(-11); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAll(p.GetFuture(), f); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAll(f, p.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 2"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAll(std::initializer_list<TFuture<void> const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAll(TVector<TFuture<int>>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise<int>(); + auto w = NWait::WaitAll({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAll(TVector<TFuture<void>>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAll({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(-3); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestManyWithVector) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAll(TVector<TFuture<int>>{ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p2.SetValue(34); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAll(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(42); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAll(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 5"); + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAll(futures); } + , [e](size_t) { + return [e](auto&& p) { p.SetException(e); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + e = std::make_exception_ptr(yexception() << "Test exception 6"); + std::atomic<size_t> index = 0; + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAll(futures); } + , [e, &index](size_t size) { + auto exceptionIndex = size / 2; + index = 0; + return [e, exceptionIndex, &index](auto&& p) { + if (index++ == exceptionIndex) { + p.SetException(e); + } else { + p.SetValue(index); + } + }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_any.cpp b/library/cpp/threading/future/subscription/wait_any.cpp new file mode 100644 index 0000000000..57cc1b2c25 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any.cpp @@ -0,0 +1 @@ +#include "wait_any.h" diff --git a/library/cpp/threading/future/subscription/wait_any.h b/library/cpp/threading/future/subscription/wait_any.h new file mode 100644 index 0000000000..e770d7b59e --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any.h @@ -0,0 +1,23 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template <typename TFutures, typename TCallbackExecutor> +TFuture<void> WaitAny(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAny(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAny(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H +#include "wait_any_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H diff --git a/library/cpp/threading/future/subscription/wait_any_inl.h b/library/cpp/threading/future/subscription/wait_any_inl.h new file mode 100644 index 0000000000..e80822bfc9 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any_inl.h @@ -0,0 +1,64 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H) +#error "you should never include wait_any_inl.h directly" +#endif + +#include "subscription.h" + +#include <initializer_list> + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAny final : public NThreading::NPrivate::TWait<TWaitAny> { +private: + static constexpr bool RevertOnSignaled = true; + + using TBase = NThreading::NPrivate::TWait<TWaitAny>; + friend TBase; + +private: + TWaitAny(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + { + } + + template <typename TFutures> + void BeforeSubscribe(TFutures const& futures) { + Y_ENSURE(std::size(futures) > 0, "Futures set cannot be empty"); + } + + template <typename T> + void Set(TFuture<T> const& future) { + with_lock (TBase::Lock) { + TBase::Unsubscribe(); + try { + future.TryRethrow(); + TBase::Promise.TrySetValue(); + } catch (...) { + TBase::Promise.TrySetException(std::current_exception()); + } + } + } +}; + +} + +template <typename TFutures, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAny(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAny>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAny(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAny>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAny(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAny>(future1, future2, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_any_ut.cpp b/library/cpp/threading/future/subscription/wait_any_ut.cpp new file mode 100644 index 0000000000..262080e8d1 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any_ut.cpp @@ -0,0 +1,166 @@ +#include "wait_any.h" +#include "wait_ut_common.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/strbuf.h> + +#include <exception> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAnyTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAny(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(w.HasValue()); + p2.SetValue(1); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAny(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p2.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(-11); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAny(p.GetFuture(), f); + UNIT_ASSERT(w.HasValue()); + + p.SetValue(); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + constexpr TStringBuf message = "Test exception 2"; + auto f = MakeErrorFuture<void>(std::make_exception_ptr(yexception() << message)); + auto w = NWait::WaitAny(f, p.GetFuture()); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p.SetValue(); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAny(std::initializer_list<TFuture<void> const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAny(TVector<TFuture<int>>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise<int>(); + auto w = NWait::WaitAny({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAny(TVector<TFuture<void>>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyUnsignaledWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto p3 = NewPromise<int>(); + auto w = NWait::WaitAny({ p1.GetFuture(), p2.GetFuture(), p3.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(42); + UNIT_ASSERT(w.HasValue()); + + p2.SetValue(-3); + p3.SetValue(12); + } + + Y_UNIT_TEST(TestManyMixedWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAny({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(w.HasValue()); + + p1.SetValue(10); + p2.SetValue(-3); + } + + + Y_UNIT_TEST(TestManyUnsignaledWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto p3 = NewPromise(); + auto w = NWait::WaitAny(TVector<TFuture<void>>{ p1.GetFuture(), p2.GetFuture(), p3.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p2.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(); + p3.SetValue(); + } + + + Y_UNIT_TEST(TestManyMixedWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAny(TVector<TFuture<void>>{ p1.GetFuture(), p2.GetFuture(), f }); + UNIT_ASSERT(w.HasValue()); + + p1.SetValue(); + p2.SetValue(); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAny(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAny(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(22); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 5"); + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAny(futures); } + , [e](size_t) { + return [e](auto&& p) { p.SetException(e); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_ut_common.cpp b/library/cpp/threading/future/subscription/wait_ut_common.cpp new file mode 100644 index 0000000000..9f961e7303 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_ut_common.cpp @@ -0,0 +1,26 @@ +#include "wait_ut_common.h" + +#include <util/random/shuffle.h> +#include <util/system/event.h> +#include <util/thread/pool.h> + +namespace NThreading::NTest::NPrivate { + +void ExecuteAndWait(TVector<std::function<void()>> jobs, TFuture<void> waiter, size_t threads) { + Y_ENSURE(threads > 0); + Shuffle(jobs.begin(), jobs.end()); + auto pool = CreateThreadPool(threads); + TManualEvent start; + for (auto& j : jobs) { + pool->SafeAddFunc( + [&start, job = std::move(j)]() { + start.WaitI(); + job(); + }); + } + start.Signal(); + waiter.Wait(); + pool->Stop(); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_ut_common.h b/library/cpp/threading/future/subscription/wait_ut_common.h new file mode 100644 index 0000000000..99530dd1f6 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_ut_common.h @@ -0,0 +1,56 @@ +#pragma once + +#include <library/cpp/threading/future/future.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/vector.h> + +#include <functional> +#include <type_traits> + +namespace NThreading::NTest { + +namespace NPrivate { + +void ExecuteAndWait(TVector<std::function<void()>> jobs, TFuture<void> waiter, size_t threads); + +template <typename TPromises, typename FSetter> +void SetConcurrentAndWait(TPromises&& promises, FSetter&& setter, TFuture<void> waiter, size_t threads = 8) { + TVector<std::function<void()>> jobs; + jobs.reserve(std::size(promises)); + for (auto& p : promises) { + jobs.push_back([p, setter]() mutable {setter(p); }); + } + ExecuteAndWait(std::move(jobs), std::move(waiter), threads); +} + +template <typename T> +auto MakePromise() { + if constexpr (std::is_same_v<T, void>) { + return NewPromise(); + } + return NewPromise<T>(); +} + +} + +template <typename T, typename FWaiterFactory, typename FSetterFactory, typename FChecker> +void TestManyStress(FWaiterFactory&& waiterFactory, FSetterFactory&& setterFactory, FChecker&& checker) { + for (size_t i : { 1, 2, 4, 8, 16, 32, 64, 128, 256 }) { + TVector<TPromise<T>> promises; + TVector<TFuture<T>> futures; + promises.reserve(i); + futures.reserve(i); + for (size_t j = 0; j < i; ++j) { + auto promise = NPrivate::MakePromise<T>(); + futures.push_back(promise.GetFuture()); + promises.push_back(std::move(promise)); + } + auto waiter = waiterFactory(futures); + NPrivate::SetConcurrentAndWait(std::move(promises), [valueSetter = setterFactory(i)](auto&& p) { valueSetter(p); } + , waiter); + checker(waiter); + } +} + +} diff --git a/library/cpp/threading/future/subscription/ya.make b/library/cpp/threading/future/subscription/ya.make new file mode 100644 index 0000000000..cb75731dbf --- /dev/null +++ b/library/cpp/threading/future/subscription/ya.make @@ -0,0 +1,24 @@ +OWNER( + g:kwyt + g:rtmr + ishfb +) + +LIBRARY() + +SRCS( + subscription.cpp + wait_all.cpp + wait_all_or_exception.cpp + wait_any.cpp +) + +PEERDIR( + library/cpp/threading/future +) + +END() + +RECURSE_FOR_TESTS( + ut +) |