aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/threading/future/future_ut.cpp
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/threading/future/future_ut.cpp
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/threading/future/future_ut.cpp')
-rw-r--r--library/cpp/threading/future/future_ut.cpp640
1 files changed, 640 insertions, 0 deletions
diff --git a/library/cpp/threading/future/future_ut.cpp b/library/cpp/threading/future/future_ut.cpp
new file mode 100644
index 0000000000..05950a568d
--- /dev/null
+++ b/library/cpp/threading/future/future_ut.cpp
@@ -0,0 +1,640 @@
+#include "future.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <list>
+#include <type_traits>
+
+namespace NThreading {
+
+namespace {
+
+ class TCopyCounter {
+ public:
+ TCopyCounter(size_t* numCopies)
+ : NumCopies(numCopies)
+ {}
+
+ TCopyCounter(const TCopyCounter& that)
+ : NumCopies(that.NumCopies)
+ {
+ ++*NumCopies;
+ }
+
+ TCopyCounter& operator=(const TCopyCounter& that) {
+ NumCopies = that.NumCopies;
+ ++*NumCopies;
+ return *this;
+ }
+
+ TCopyCounter(TCopyCounter&& that) = default;
+
+ TCopyCounter& operator=(TCopyCounter&& that) = default;
+
+ private:
+ size_t* NumCopies = nullptr;
+ };
+
+ template <typename T>
+ auto MakePromise() {
+ if constexpr (std::is_same_v<T, void>) {
+ return NewPromise();
+ }
+ return NewPromise<T>();
+ }
+
+
+ template <typename T>
+ void TestFutureStateId() {
+ TFuture<T> empty;
+ UNIT_ASSERT(!empty.StateId().Defined());
+ auto promise1 = MakePromise<T>();
+ auto future11 = promise1.GetFuture();
+ UNIT_ASSERT(future11.StateId().Defined());
+ auto future12 = promise1.GetFuture();
+ UNIT_ASSERT_EQUAL(future11.StateId(), future11.StateId()); // same result for subsequent invocations
+ UNIT_ASSERT_EQUAL(future11.StateId(), future12.StateId()); // same result for different futures with the same state
+ auto promise2 = MakePromise<T>();
+ auto future2 = promise2.GetFuture();
+ UNIT_ASSERT(future2.StateId().Defined());
+ UNIT_ASSERT_UNEQUAL(future11.StateId(), future2.StateId()); // different results for futures with different states
+ }
+
+}
+
+ ////////////////////////////////////////////////////////////////////////////////
+
+ Y_UNIT_TEST_SUITE(TFutureTest) {
+ Y_UNIT_TEST(ShouldInitiallyHasNoValue) {
+ TPromise<int> promise;
+ UNIT_ASSERT(!promise.HasValue());
+
+ promise = NewPromise<int>();
+ UNIT_ASSERT(!promise.HasValue());
+
+ TFuture<int> future;
+ UNIT_ASSERT(!future.HasValue());
+
+ future = promise.GetFuture();
+ UNIT_ASSERT(!future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldInitiallyHasNoValueVoid) {
+ TPromise<void> promise;
+ UNIT_ASSERT(!promise.HasValue());
+
+ promise = NewPromise();
+ UNIT_ASSERT(!promise.HasValue());
+
+ TFuture<void> future;
+ UNIT_ASSERT(!future.HasValue());
+
+ future = promise.GetFuture();
+ UNIT_ASSERT(!future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldStoreValue) {
+ TPromise<int> promise = NewPromise<int>();
+ promise.SetValue(123);
+ UNIT_ASSERT(promise.HasValue());
+ UNIT_ASSERT_EQUAL(promise.GetValue(), 123);
+
+ TFuture<int> future = promise.GetFuture();
+ UNIT_ASSERT(future.HasValue());
+ UNIT_ASSERT_EQUAL(future.GetValue(), 123);
+
+ future = MakeFuture(345);
+ UNIT_ASSERT(future.HasValue());
+ UNIT_ASSERT_EQUAL(future.GetValue(), 345);
+ }
+
+ Y_UNIT_TEST(ShouldStoreValueVoid) {
+ TPromise<void> promise = NewPromise();
+ promise.SetValue();
+ UNIT_ASSERT(promise.HasValue());
+
+ TFuture<void> future = promise.GetFuture();
+ UNIT_ASSERT(future.HasValue());
+
+ future = MakeFuture();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ struct TTestCallback {
+ int Value;
+
+ TTestCallback(int value)
+ : Value(value)
+ {
+ }
+
+ void Callback(const TFuture<int>& future) {
+ Value += future.GetValue();
+ }
+
+ int Func(const TFuture<int>& future) {
+ return (Value += future.GetValue());
+ }
+
+ void VoidFunc(const TFuture<int>& future) {
+ future.GetValue();
+ }
+
+ TFuture<int> FutureFunc(const TFuture<int>& future) {
+ return MakeFuture(Value += future.GetValue());
+ }
+
+ TPromise<void> Signal = NewPromise();
+ TFuture<void> FutureVoidFunc(const TFuture<int>& future) {
+ future.GetValue();
+ return Signal;
+ }
+ };
+
+ Y_UNIT_TEST(ShouldInvokeCallback) {
+ TPromise<int> promise = NewPromise<int>();
+
+ TTestCallback callback(123);
+ TFuture<int> future = promise.GetFuture()
+ .Subscribe([&](const TFuture<int>& theFuture) { return callback.Callback(theFuture); });
+
+ promise.SetValue(456);
+ UNIT_ASSERT_EQUAL(future.GetValue(), 456);
+ UNIT_ASSERT_EQUAL(callback.Value, 123 + 456);
+ }
+
+ Y_UNIT_TEST(ShouldApplyFunc) {
+ TPromise<int> promise = NewPromise<int>();
+
+ TTestCallback callback(123);
+ TFuture<int> future = promise.GetFuture()
+ .Apply([&](const auto& theFuture) { return callback.Func(theFuture); });
+
+ promise.SetValue(456);
+ UNIT_ASSERT_EQUAL(future.GetValue(), 123 + 456);
+ UNIT_ASSERT_EQUAL(callback.Value, 123 + 456);
+ }
+
+ Y_UNIT_TEST(ShouldApplyVoidFunc) {
+ TPromise<int> promise = NewPromise<int>();
+
+ TTestCallback callback(123);
+ TFuture<void> future = promise.GetFuture()
+ .Apply([&](const auto& theFuture) { return callback.VoidFunc(theFuture); });
+
+ promise.SetValue(456);
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldApplyFutureFunc) {
+ TPromise<int> promise = NewPromise<int>();
+
+ TTestCallback callback(123);
+ TFuture<int> future = promise.GetFuture()
+ .Apply([&](const auto& theFuture) { return callback.FutureFunc(theFuture); });
+
+ promise.SetValue(456);
+ UNIT_ASSERT_EQUAL(future.GetValue(), 123 + 456);
+ UNIT_ASSERT_EQUAL(callback.Value, 123 + 456);
+ }
+
+ Y_UNIT_TEST(ShouldApplyFutureVoidFunc) {
+ TPromise<int> promise = NewPromise<int>();
+
+ TTestCallback callback(123);
+ TFuture<void> future = promise.GetFuture()
+ .Apply([&](const auto& theFuture) { return callback.FutureVoidFunc(theFuture); });
+
+ promise.SetValue(456);
+ UNIT_ASSERT(!future.HasValue());
+
+ callback.Signal.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldIgnoreResultIfAsked) {
+ TPromise<int> promise = NewPromise<int>();
+
+ TTestCallback callback(123);
+ TFuture<int> future = promise.GetFuture().IgnoreResult().Return(42);
+
+ promise.SetValue(456);
+ UNIT_ASSERT_EQUAL(future.GetValue(), 42);
+ }
+
+ class TCustomException: public yexception {
+ };
+
+ Y_UNIT_TEST(ShouldRethrowException) {
+ TPromise<int> promise = NewPromise<int>();
+ try {
+ ythrow TCustomException();
+ } catch (...) {
+ promise.SetException(std::current_exception());
+ }
+
+ UNIT_ASSERT(!promise.HasValue());
+ UNIT_ASSERT(promise.HasException());
+ UNIT_ASSERT_EXCEPTION(promise.GetValue(), TCustomException);
+ UNIT_ASSERT_EXCEPTION(promise.TryRethrow(), TCustomException);
+ }
+
+ Y_UNIT_TEST(ShouldRethrowCallbackException) {
+ TPromise<int> promise = NewPromise<int>();
+ TFuture<int> future = promise.GetFuture();
+ future.Subscribe([](const TFuture<int>&) {
+ throw TCustomException();
+ });
+
+ UNIT_ASSERT_EXCEPTION(promise.SetValue(123), TCustomException);
+ }
+
+ Y_UNIT_TEST(ShouldRethrowCallbackExceptionIgnoreResult) {
+ TPromise<int> promise = NewPromise<int>();
+ TFuture<void> future = promise.GetFuture().IgnoreResult();
+ future.Subscribe([](const TFuture<void>&) {
+ throw TCustomException();
+ });
+
+ UNIT_ASSERT_EXCEPTION(promise.SetValue(123), TCustomException);
+ }
+
+
+ Y_UNIT_TEST(ShouldWaitExceptionOrAll) {
+ TPromise<void> promise1 = NewPromise();
+ TPromise<void> promise2 = NewPromise();
+
+ TFuture<void> future = WaitExceptionOrAll(promise1, promise2);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue();
+ UNIT_ASSERT(!future.HasValue());
+
+ promise2.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitExceptionOrAllVector) {
+ TPromise<void> promise1 = NewPromise();
+ TPromise<void> promise2 = NewPromise();
+
+ TVector<TFuture<void>> promises;
+ promises.push_back(promise1);
+ promises.push_back(promise2);
+
+ TFuture<void> future = WaitExceptionOrAll(promises);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue();
+ UNIT_ASSERT(!future.HasValue());
+
+ promise2.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitExceptionOrAllVectorWithValueType) {
+ TPromise<int> promise1 = NewPromise<int>();
+ TPromise<int> promise2 = NewPromise<int>();
+
+ TVector<TFuture<int>> promises;
+ promises.push_back(promise1);
+ promises.push_back(promise2);
+
+ TFuture<void> future = WaitExceptionOrAll(promises);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue(0);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise2.SetValue(0);
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitExceptionOrAllList) {
+ TPromise<void> promise1 = NewPromise();
+ TPromise<void> promise2 = NewPromise();
+
+ std::list<TFuture<void>> promises;
+ promises.push_back(promise1);
+ promises.push_back(promise2);
+
+ TFuture<void> future = WaitExceptionOrAll(promises);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue();
+ UNIT_ASSERT(!future.HasValue());
+
+ promise2.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitExceptionOrAllVectorEmpty) {
+ TVector<TFuture<void>> promises;
+
+ TFuture<void> future = WaitExceptionOrAll(promises);
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitAnyVector) {
+ TPromise<void> promise1 = NewPromise();
+ TPromise<void> promise2 = NewPromise();
+
+ TVector<TFuture<void>> promises;
+ promises.push_back(promise1);
+ promises.push_back(promise2);
+
+ TFuture<void> future = WaitAny(promises);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue();
+ UNIT_ASSERT(future.HasValue());
+
+ promise2.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+
+ Y_UNIT_TEST(ShouldWaitAnyVectorWithValueType) {
+ TPromise<int> promise1 = NewPromise<int>();
+ TPromise<int> promise2 = NewPromise<int>();
+
+ TVector<TFuture<int>> promises;
+ promises.push_back(promise1);
+ promises.push_back(promise2);
+
+ TFuture<void> future = WaitAny(promises);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue(0);
+ UNIT_ASSERT(future.HasValue());
+
+ promise2.SetValue(0);
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitAnyList) {
+ TPromise<void> promise1 = NewPromise();
+ TPromise<void> promise2 = NewPromise();
+
+ std::list<TFuture<void>> promises;
+ promises.push_back(promise1);
+ promises.push_back(promise2);
+
+ TFuture<void> future = WaitAny(promises);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue();
+ UNIT_ASSERT(future.HasValue());
+
+ promise2.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitAnyVectorEmpty) {
+ TVector<TFuture<void>> promises;
+
+ TFuture<void> future = WaitAny(promises);
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldWaitAny) {
+ TPromise<void> promise1 = NewPromise();
+ TPromise<void> promise2 = NewPromise();
+
+ TFuture<void> future = WaitAny(promise1, promise2);
+ UNIT_ASSERT(!future.HasValue());
+
+ promise1.SetValue();
+ UNIT_ASSERT(future.HasValue());
+
+ promise2.SetValue();
+ UNIT_ASSERT(future.HasValue());
+ }
+
+ Y_UNIT_TEST(ShouldStoreTypesWithoutDefaultConstructor) {
+ // compileability test
+ struct TRec {
+ explicit TRec(int) {
+ }
+ };
+
+ auto promise = NewPromise<TRec>();
+ promise.SetValue(TRec(1));
+
+ auto future = MakeFuture(TRec(1));
+ const auto& rec = future.GetValue();
+ Y_UNUSED(rec);
+ }
+
+ Y_UNIT_TEST(ShouldStoreMovableTypes) {
+ // compileability test
+ struct TRec : TMoveOnly {
+ explicit TRec(int) {
+ }
+ };
+
+ auto promise = NewPromise<TRec>();
+ promise.SetValue(TRec(1));
+
+ auto future = MakeFuture(TRec(1));
+ const auto& rec = future.GetValue();
+ Y_UNUSED(rec);
+ }
+
+ Y_UNIT_TEST(ShouldMoveMovableTypes) {
+ // compileability test
+ struct TRec : TMoveOnly {
+ explicit TRec(int) {
+ }
+ };
+
+ auto promise = NewPromise<TRec>();
+ promise.SetValue(TRec(1));
+
+ auto future = MakeFuture(TRec(1));
+ auto rec = future.ExtractValue();
+ Y_UNUSED(rec);
+ }
+
+ Y_UNIT_TEST(ShouldNotExtractAfterGet) {
+ TPromise<int> promise = NewPromise<int>();
+ promise.SetValue(123);
+ UNIT_ASSERT(promise.HasValue());
+ UNIT_ASSERT_EQUAL(promise.GetValue(), 123);
+ UNIT_CHECK_GENERATED_EXCEPTION(promise.ExtractValue(), TFutureException);
+ }
+
+ Y_UNIT_TEST(ShouldNotGetAfterExtract) {
+ TPromise<int> promise = NewPromise<int>();
+ promise.SetValue(123);
+ UNIT_ASSERT(promise.HasValue());
+ UNIT_ASSERT_EQUAL(promise.ExtractValue(), 123);
+ UNIT_CHECK_GENERATED_EXCEPTION(promise.GetValue(), TFutureException);
+ }
+
+ Y_UNIT_TEST(ShouldNotExtractAfterExtract) {
+ TPromise<int> promise = NewPromise<int>();
+ promise.SetValue(123);
+ UNIT_ASSERT(promise.HasValue());
+ UNIT_ASSERT_EQUAL(promise.ExtractValue(), 123);
+ UNIT_CHECK_GENERATED_EXCEPTION(promise.ExtractValue(), TFutureException);
+ }
+
+ Y_UNIT_TEST(ShouldNotExtractFromSharedDefault) {
+ UNIT_CHECK_GENERATED_EXCEPTION(MakeFuture<int>().ExtractValue(), TFutureException);
+
+ struct TStorage {
+ TString String = TString(100, 'a');
+ };
+ try {
+ TString s = MakeFuture<TStorage>().ExtractValue().String;
+ Y_UNUSED(s);
+ } catch (TFutureException) {
+ // pass
+ }
+ UNIT_ASSERT_VALUES_EQUAL(MakeFuture<TStorage>().GetValue().String, TString(100, 'a'));
+ }
+
+ Y_UNIT_TEST(HandlingRepetitiveSet) {
+ TPromise<int> promise = NewPromise<int>();
+ promise.SetValue(42);
+ UNIT_CHECK_GENERATED_EXCEPTION(promise.SetValue(42), TFutureException);
+ }
+
+ Y_UNIT_TEST(HandlingRepetitiveTrySet) {
+ TPromise<int> promise = NewPromise<int>();
+ UNIT_ASSERT(promise.TrySetValue(42));
+ UNIT_ASSERT(!promise.TrySetValue(42));
+ }
+
+ Y_UNIT_TEST(HandlingRepetitiveSetException) {
+ TPromise<int> promise = NewPromise<int>();
+ promise.SetException("test");
+ UNIT_CHECK_GENERATED_EXCEPTION(promise.SetException("test"), TFutureException);
+ }
+
+ Y_UNIT_TEST(HandlingRepetitiveTrySetException) {
+ TPromise<int> promise = NewPromise<int>();
+ UNIT_ASSERT(promise.TrySetException(std::make_exception_ptr("test")));
+ UNIT_ASSERT(!promise.TrySetException(std::make_exception_ptr("test")));
+ }
+
+ Y_UNIT_TEST(ShouldAllowToMakeFutureWithException)
+ {
+ auto future1 = MakeErrorFuture<void>(std::make_exception_ptr(TFutureException()));
+ UNIT_ASSERT(future1.HasException());
+ UNIT_CHECK_GENERATED_EXCEPTION(future1.GetValue(), TFutureException);
+
+ auto future2 = MakeErrorFuture<int>(std::make_exception_ptr(TFutureException()));
+ UNIT_ASSERT(future2.HasException());
+ UNIT_CHECK_GENERATED_EXCEPTION(future2.GetValue(), TFutureException);
+
+ auto future3 = MakeFuture<std::exception_ptr>(std::make_exception_ptr(TFutureException()));
+ UNIT_ASSERT(future3.HasValue());
+ UNIT_CHECK_GENERATED_NO_EXCEPTION(future3.GetValue(), TFutureException);
+
+ auto future4 = MakeFuture<std::unique_ptr<int>>(nullptr);
+ UNIT_ASSERT(future4.HasValue());
+ UNIT_CHECK_GENERATED_NO_EXCEPTION(future4.GetValue(), TFutureException);
+ }
+
+ Y_UNIT_TEST(WaitAllowsExtract) {
+ auto future = MakeFuture<int>(42);
+ TVector vec{future, future, future};
+ WaitExceptionOrAll(vec).GetValue();
+ WaitAny(vec).GetValue();
+
+ UNIT_ASSERT_EQUAL(future.ExtractValue(), 42);
+ }
+
+ Y_UNIT_TEST(IgnoreAllowsExtract) {
+ auto future = MakeFuture<int>(42);
+ future.IgnoreResult().GetValue();
+
+ UNIT_ASSERT_EQUAL(future.ExtractValue(), 42);
+ }
+
+ Y_UNIT_TEST(WaitExceptionOrAllException) {
+ auto promise1 = NewPromise();
+ auto promise2 = NewPromise();
+ auto future1 = promise1.GetFuture();
+ auto future2 = promise2.GetFuture();
+ auto wait = WaitExceptionOrAll(future1, future2);
+ promise2.SetException("foo-exception");
+ wait.Wait();
+ UNIT_ASSERT(future2.HasException());
+ UNIT_ASSERT(!future1.HasValue() && !future1.HasException());
+ }
+
+ Y_UNIT_TEST(WaitAllException) {
+ auto promise1 = NewPromise();
+ auto promise2 = NewPromise();
+ auto future1 = promise1.GetFuture();
+ auto future2 = promise2.GetFuture();
+ auto wait = WaitAll(future1, future2);
+ promise2.SetException("foo-exception");
+ UNIT_ASSERT(!wait.HasValue() && !wait.HasException());
+ promise1.SetValue();
+ UNIT_ASSERT_EXCEPTION_CONTAINS(wait.GetValueSync(), yexception, "foo-exception");
+ }
+
+ Y_UNIT_TEST(FutureStateId) {
+ TestFutureStateId<void>();
+ TestFutureStateId<int>();
+ }
+
+ template <typename T>
+ void TestApplyNoRvalueCopyImpl() {
+ size_t numCopies = 0;
+ TCopyCounter copyCounter(&numCopies);
+
+ auto promise = MakePromise<T>();
+
+ const auto future = promise.GetFuture().Apply(
+ [copyCounter = std::move(copyCounter)] (const auto&) {}
+ );
+
+ if constexpr (std::is_same_v<T, void>) {
+ promise.SetValue();
+ } else {
+ promise.SetValue(T());
+ }
+
+ future.GetValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL(numCopies, 0);
+ }
+
+ Y_UNIT_TEST(ApplyNoRvalueCopy) {
+ TestApplyNoRvalueCopyImpl<void>();
+ TestApplyNoRvalueCopyImpl<int>();
+ }
+
+ template <typename T>
+ void TestApplyLvalueCopyImpl() {
+ size_t numCopies = 0;
+ TCopyCounter copyCounter(&numCopies);
+
+ auto promise = MakePromise<T>();
+
+ auto func = [copyCounter = std::move(copyCounter)] (const auto&) {};
+ const auto future = promise.GetFuture().Apply(func);
+
+ if constexpr (std::is_same_v<T, void>) {
+ promise.SetValue();
+ } else {
+ promise.SetValue(T());
+ }
+
+ future.GetValueSync();
+
+ UNIT_ASSERT_VALUES_EQUAL(numCopies, 1);
+ }
+
+ Y_UNIT_TEST(ApplyLvalueCopy) {
+ TestApplyLvalueCopyImpl<void>();
+ TestApplyLvalueCopyImpl<int>();
+ }
+ }
+
+}