#include "future.h"
#include <library/cpp/testing/unittest/registar.h>
#include <list>
#include <type_traits>
namespace NThreading {
namespace {
// Test helper that records how many times it has been copied.
// Copy construction and copy assignment each increment the external counter
// supplied at construction; moves leave that counter untouched.
class TCopyCounter {
public:
    // numCopies is dereferenced on every copy, so it must point to a live
    // counter for as long as this object (or anything copied from it) is copied.
    // Explicit: a bare size_t* must never silently convert into a counter.
    explicit TCopyCounter(size_t* numCopies)
        : NumCopies(numCopies)
    {}

    // Copying shares the counter and bumps it by one.
    TCopyCounter(const TCopyCounter& that)
        : NumCopies(that.NumCopies)
    {
        ++*NumCopies;
    }

    // Copy assignment rebinds to the source's counter, then bumps it by one.
    TCopyCounter& operator=(const TCopyCounter& that) {
        NumCopies = that.NumCopies;
        ++*NumCopies;
        return *this;
    }

    // Moves only transfer the pointer and are not counted.
    TCopyCounter(TCopyCounter&& that) noexcept = default;
    TCopyCounter& operator=(TCopyCounter&& that) noexcept = default;

private:
    size_t* NumCopies = nullptr;
};
// Creates a new promise for T, using the non-template NewPromise() call
// when T is void and NewPromise<T>() otherwise.
template <typename T>
auto MakePromise() {
    if constexpr (std::is_same_v<T, void>) {
        return NewPromise();
    } else {
        // Lives in the else-branch so it is discarded for T = void: the
        // deduced return type then comes from exactly one return statement
        // and no unreachable code is instantiated.
        return NewPromise<T>();
    }
}
template <typename T>
void TestFutureStateId() {
TFuture<T> empty;
UNIT_ASSERT(!empty.StateId().Defined());
auto promise1 = MakePromise<T>();
auto future11 = promise1.GetFuture();
UNIT_ASSERT(future11.StateId().Defined());
auto future12 = promise1.GetFuture();
UNIT_ASSERT_EQUAL(future11.StateId(), future11.StateId()); // same result for subsequent invocations
UNIT_ASSERT_EQUAL(future11.StateId(), future12.StateId()); // same result for different futures with the same state
auto promise2 = MakePromise<T>();
auto future2 = promise2.GetFuture();
UNIT_ASSERT(future2.StateId().Defined());
UNIT_ASSERT_UNEQUAL(future11.StateId(), future2.StateId()); // different results for futures with different states
}
}
////////////////////////////////////////////////////////////////////////////////
Y_UNIT_TEST_SUITE(TFutureTest) {
// A promise/future of int holds no value both when default-constructed
// and right after being created via NewPromise / GetFuture.
Y_UNIT_TEST(ShouldInitiallyHasNoValue) {
    TPromise<int> intPromise;
    UNIT_ASSERT(!intPromise.HasValue());
    intPromise = NewPromise<int>();
    UNIT_ASSERT(!intPromise.HasValue());

    TFuture<int> intFuture;
    UNIT_ASSERT(!intFuture.HasValue());
    intFuture = intPromise.GetFuture();
    UNIT_ASSERT(!intFuture.HasValue());
}
// Same as ShouldInitiallyHasNoValue, but for the void specialisations.
Y_UNIT_TEST(ShouldInitiallyHasNoValueVoid) {
    TPromise<void> voidPromise;
    UNIT_ASSERT(!voidPromise.HasValue());
    voidPromise = NewPromise();
    UNIT_ASSERT(!voidPromise.HasValue());

    TFuture<void> voidFuture;
    UNIT_ASSERT(!voidFuture.HasValue());
    voidFuture = voidPromise.GetFuture();
    UNIT_ASSERT(!voidFuture.HasValue());
}
// A value set on the promise is visible through the promise and its future;
// MakeFuture(value) yields a future that is ready and holds that value.
Y_UNIT_TEST(ShouldStoreValue) {
    auto source = NewPromise<int>();
    source.SetValue(123);
    UNIT_ASSERT(source.HasValue());
    UNIT_ASSERT_EQUAL(source.GetValue(), 123);

    TFuture<int> result = source.GetFuture();
    UNIT_ASSERT(result.HasValue());
    UNIT_ASSERT_EQUAL(result.GetValue(), 123);

    // Rebind the same variable to a future created directly from a value.
    result = MakeFuture(345);
    UNIT_ASSERT(result.HasValue());
    UNIT_ASSERT(result.IsReady());
    UNIT_ASSERT_EQUAL(result.GetValue(), 345);
}
// Void counterpart of ShouldStoreValue: SetValue() marks the promise and
// its future as having a value; MakeFuture() yields a future with a value.
Y_UNIT_TEST(ShouldStoreValueVoid) {
    auto source = NewPromise();
    source.SetValue();
    UNIT_ASSERT(source.HasValue());

    TFuture<void> result = source.GetFuture();
    UNIT_ASSERT(result.HasValue());
    UNIT_ASSERT(result.IsReady());

    result = MakeFuture();
    UNIT_ASSERT(result.HasValue());
}
// Bundle of callbacks with different return types, used by the
// Subscribe/Apply tests below. Value accumulates the values read from
// the futures passed in; Signal backs the future returned by FutureVoidFunc.
struct TTestCallback {
    int Value;
    TPromise<void> Signal = NewPromise();

    TTestCallback(int value)
        : Value(value)
    {
    }

    // Adds the future's value to Value; returns nothing.
    void Callback(const TFuture<int>& future) {
        Value += future.GetValue();
    }

    // Adds the future's value to Value and returns the new total.
    int Func(const TFuture<int>& future) {
        Value += future.GetValue();
        return Value;
    }

    // Reads the future's value and discards it.
    void VoidFunc(const TFuture<int>& future) {
        future.GetValue();
    }

    // Adds the future's value to Value and returns the total wrapped in a future.
    TFuture<int> FutureFunc(const TFuture<int>& future) {
        Value += future.GetValue();
        return MakeFuture(Value);
    }

    // Reads the future's value, then returns a future tied to Signal.
    TFuture<void> FutureVoidFunc(const TFuture<int>& future) {
        future.GetValue();
        return Signal;
    }
};
// A subscribed callback runs once the promise is fulfilled, and the future
// returned from Subscribe still yields the promise's value.
Y_UNIT_TEST(ShouldInvokeCallback) {
    auto source = NewPromise<int>();
    TTestCallback handler(123);

    TFuture<int> subscribed = source.GetFuture().Subscribe(
        [&handler](const TFuture<int>& ready) { handler.Callback(ready); });

    source.SetValue(456);
    UNIT_ASSERT_EQUAL(subscribed.GetValue(), 456);
    UNIT_ASSERT_EQUAL(handler.Value, 123 + 456);
}
// Apply with an int-returning functor produces a TFuture<int> carrying
// the functor's result.
Y_UNIT_TEST(ShouldApplyFunc) {
    auto source = NewPromise<int>();
    TTestCallback handler(123);

    TFuture<int> applied = source.GetFuture().Apply(
        [&handler](const auto& ready) { return handler.Func(ready); });

    source.SetValue(456);
    UNIT_ASSERT_EQUAL(applied.GetValue(), 123 + 456);
    UNIT_ASSERT_EQUAL(handler.Value, 123 + 456);
}
// Apply with a void-returning functor produces a TFuture<void> that has a
// value once the source promise is fulfilled.
Y_UNIT_TEST(ShouldApplyVoidFunc) {
    auto source = NewPromise<int>();
    TTestCallback handler(123);

    TFuture<void> applied = source.GetFuture().Apply(
        [&handler](const auto& ready) { return handler.VoidFunc(ready); });

    source.SetValue(456);
    UNIT_ASSERT(applied.HasValue());
}
// Apply with a functor that itself returns TFuture<int> produces a
// TFuture<int> holding the inner future's value.
Y_UNIT_TEST(ShouldApplyFutureFunc) {
    auto source = NewPromise<int>();
    TTestCallback handler(123);

    TFuture<int> applied = source.GetFuture().Apply(
        [&handler](const auto& ready) { return handler.FutureFunc(ready); });

    source.SetValue(456);
    UNIT_ASSERT_EQUAL(applied.GetValue(), 123 + 456);
    UNIT_ASSERT_EQUAL(handler.Value, 123 + 456);
}
// Apply with a functor returning TFuture<void>: the resulting future gets
// its value only after the functor's own future (Signal) is fulfilled.
Y_UNIT_TEST(ShouldApplyFutureVoidFunc) {
    auto source = NewPromise<int>();
    TTestCallback handler(123);

    TFuture<void> applied = source.GetFuture().Apply(
        [&handler](const auto& ready) { return handler.FutureVoidFunc(ready); });

    // Fulfilling the source alone is not enough.
    source.SetValue(456);
    UNIT_ASSERT(!applied.HasValue());

    handler.Signal.SetValue();
    UNIT_ASSERT(applied.HasValue());
}
// IgnoreResult() drops the int result; Return(42) then yields a future
// whose value is 42 regardless of what the source promise was set to.
Y_UNIT_TEST(ShouldIgnoreResultIfAsked) {
    TPromise<int> promise = NewPromise<int>();

    TFuture<int> future = promise.GetFuture().IgnoreResult().Return(42);
    promise.SetValue(456);
    UNIT_ASSERT_EQUAL(future.GetValue(), 42);
}
// Distinct exception type so tests can check that exactly this exception
// is propagated.
struct TCustomException: yexception {
};
// An exception stored in the promise is reported by HasException() and
// rethrown with its original type from GetValue() and TryRethrow().
Y_UNIT_TEST(ShouldRethrowException) {
    auto failing = NewPromise<int>();

    // Capture a live TCustomException as an exception_ptr.
    try {
        ythrow TCustomException();
    } catch (...) {
        failing.SetException(std::current_exception());
    }

    UNIT_ASSERT(!failing.HasValue());
    UNIT_ASSERT(failing.HasException());
    UNIT_ASSERT_EXCEPTION(failing.GetValue(), TCustomException);
    UNIT_ASSERT_EXCEPTION(failing.TryRethrow(), TCustomException);
}
// An exception thrown by a subscriber surfaces from the SetValue() call
// that triggers it.
Y_UNIT_TEST(ShouldRethrowCallbackException) {
    auto source = NewPromise<int>();
    TFuture<int> observed = source.GetFuture();

    observed.Subscribe([](const TFuture<int>&) {
        throw TCustomException();
    });

    UNIT_ASSERT_EXCEPTION(source.SetValue(123), TCustomException);
}
// Same as ShouldRethrowCallbackException, but the subscriber is attached
// to the TFuture<void> obtained through IgnoreResult().
Y_UNIT_TEST(ShouldRethrowCallbackExceptionIgnoreResult) {
    auto source = NewPromise<int>();
    TFuture<void> observed = source.GetFuture().IgnoreResult();

    observed.Subscribe([](const TFuture<void>&) {
        throw TCustomException();
    });

    UNIT_ASSERT_EXCEPTION(source.SetValue(123), TCustomException);
}
// With two successful inputs, WaitExceptionOrAll gets a value only after
// both of them have been set.
Y_UNIT_TEST(ShouldWaitExceptionOrAll) {
    auto first = NewPromise();
    auto second = NewPromise();

    TFuture<void> combined = WaitExceptionOrAll(first, second);
    UNIT_ASSERT(!combined.HasValue());

    first.SetValue();
    UNIT_ASSERT(!combined.HasValue());

    second.SetValue();
    UNIT_ASSERT(combined.HasValue());
}
// WaitExceptionOrAll over a TVector of void futures: completes only after
// every element has a value.
Y_UNIT_TEST(ShouldWaitExceptionOrAllVector) {
    auto first = NewPromise();
    auto second = NewPromise();

    TVector<TFuture<void>> inputs;
    inputs.push_back(first);
    inputs.push_back(second);

    TFuture<void> combined = WaitExceptionOrAll(inputs);
    UNIT_ASSERT(!combined.HasValue());

    first.SetValue();
    UNIT_ASSERT(!combined.HasValue());

    second.SetValue();
    UNIT_ASSERT(combined.HasValue());
}
// WaitExceptionOrAll also accepts a TVector of TFuture<int> and still
// yields a TFuture<void> that completes after all inputs are set.
Y_UNIT_TEST(ShouldWaitExceptionOrAllVectorWithValueType) {
    auto first = NewPromise<int>();
    auto second = NewPromise<int>();

    TVector<TFuture<int>> inputs;
    inputs.push_back(first);
    inputs.push_back(second);

    TFuture<void> combined = WaitExceptionOrAll(inputs);
    UNIT_ASSERT(!combined.HasValue());

    first.SetValue(0);
    UNIT_ASSERT(!combined.HasValue());

    second.SetValue(0);
    UNIT_ASSERT(combined.HasValue());
}
// WaitExceptionOrAll works with a std::list container as well.
Y_UNIT_TEST(ShouldWaitExceptionOrAllList) {
    auto first = NewPromise();
    auto second = NewPromise();

    std::list<TFuture<void>> inputs;
    inputs.push_back(first);
    inputs.push_back(second);

    TFuture<void> combined = WaitExceptionOrAll(inputs);
    UNIT_ASSERT(!combined.HasValue());

    first.SetValue();
    UNIT_ASSERT(!combined.HasValue());

    second.SetValue();
    UNIT_ASSERT(combined.HasValue());
}
// With no inputs at all, WaitExceptionOrAll returns a future that already
// has a value.
Y_UNIT_TEST(ShouldWaitExceptionOrAllVectorEmpty) {
    TVector<TFuture<void>> noInputs;
    TFuture<void> combined = WaitExceptionOrAll(noInputs);
    UNIT_ASSERT(combined.HasValue());
}
// WaitAny over a TVector: completes as soon as the first input is set and
// stays completed when the second one follows.
Y_UNIT_TEST(ShouldWaitAnyVector) {
    auto first = NewPromise();
    auto second = NewPromise();

    TVector<TFuture<void>> inputs;
    inputs.push_back(first);
    inputs.push_back(second);

    TFuture<void> anyDone = WaitAny(inputs);
    UNIT_ASSERT(!anyDone.HasValue());

    first.SetValue();
    UNIT_ASSERT(anyDone.HasValue());

    second.SetValue();
    UNIT_ASSERT(anyDone.HasValue());
}
// WaitAny accepts a TVector of TFuture<int> and yields a TFuture<void>
// that completes with the first input.
Y_UNIT_TEST(ShouldWaitAnyVectorWithValueType) {
    auto first = NewPromise<int>();
    auto second = NewPromise<int>();

    TVector<TFuture<int>> inputs;
    inputs.push_back(first);
    inputs.push_back(second);

    TFuture<void> anyDone = WaitAny(inputs);
    UNIT_ASSERT(!anyDone.HasValue());

    first.SetValue(0);
    UNIT_ASSERT(anyDone.HasValue());

    second.SetValue(0);
    UNIT_ASSERT(anyDone.HasValue());
}
// WaitAny works with a std::list container as well.
Y_UNIT_TEST(ShouldWaitAnyList) {
    auto first = NewPromise();
    auto second = NewPromise();

    std::list<TFuture<void>> inputs;
    inputs.push_back(first);
    inputs.push_back(second);

    TFuture<void> anyDone = WaitAny(inputs);
    UNIT_ASSERT(!anyDone.HasValue());

    first.SetValue();
    UNIT_ASSERT(anyDone.HasValue());

    second.SetValue();
    UNIT_ASSERT(anyDone.HasValue());
}
// With no inputs at all, WaitAny returns a future that already has a value.
Y_UNIT_TEST(ShouldWaitAnyVectorEmpty) {
    TVector<TFuture<void>> noInputs;
    TFuture<void> anyDone = WaitAny(noInputs);
    UNIT_ASSERT(anyDone.HasValue());
}
// Two-argument WaitAny: completes with the first input and remains
// completed after the second.
Y_UNIT_TEST(ShouldWaitAny) {
    auto first = NewPromise();
    auto second = NewPromise();

    TFuture<void> anyDone = WaitAny(first, second);
    UNIT_ASSERT(!anyDone.HasValue());

    first.SetValue();
    UNIT_ASSERT(anyDone.HasValue());

    second.SetValue();
    UNIT_ASSERT(anyDone.HasValue());
}
// Compile-only check: promises and futures can hold a type that has no
// default constructor.
Y_UNIT_TEST(ShouldStoreTypesWithoutDefaultConstructor) {
    struct TRec {
        explicit TRec(int) {
        }
    };

    auto recPromise = NewPromise<TRec>();
    recPromise.SetValue(TRec(1));

    auto recFuture = MakeFuture(TRec(1));
    const auto& stored = recFuture.GetValue();
    Y_UNUSED(stored);
}
// Compile-only check: promises and futures can hold a type derived from
// TMoveOnly and expose it by reference through GetValue().
Y_UNIT_TEST(ShouldStoreMovableTypes) {
    struct TRec : TMoveOnly {
        explicit TRec(int) {
        }
    };

    auto recPromise = NewPromise<TRec>();
    recPromise.SetValue(TRec(1));

    auto recFuture = MakeFuture(TRec(1));
    const auto& stored = recFuture.GetValue();
    Y_UNUSED(stored);
}
// Compile-only check: a value of a type derived from TMoveOnly can be taken
// out of a future by value via ExtractValue().
Y_UNIT_TEST(ShouldMoveMovableTypes) {
    struct TRec : TMoveOnly {
        explicit TRec(int) {
        }
    };

    auto recPromise = NewPromise<TRec>();
    recPromise.SetValue(TRec(1));

    auto recFuture = MakeFuture(TRec(1));
    auto extracted = recFuture.ExtractValue();
    Y_UNUSED(extracted);
}
// Once GetValue() has been called, ExtractValue() must throw TFutureException.
Y_UNIT_TEST(ShouldNotExtractAfterGet) {
    auto source = NewPromise<int>();
    source.SetValue(123);
    UNIT_ASSERT(source.HasValue());

    UNIT_ASSERT_EQUAL(source.GetValue(), 123);
    UNIT_CHECK_GENERATED_EXCEPTION(source.ExtractValue(), TFutureException);
}
// Once ExtractValue() has been called, GetValue() must throw TFutureException.
Y_UNIT_TEST(ShouldNotGetAfterExtract) {
    auto source = NewPromise<int>();
    source.SetValue(123);
    UNIT_ASSERT(source.HasValue());

    UNIT_ASSERT_EQUAL(source.ExtractValue(), 123);
    UNIT_CHECK_GENERATED_EXCEPTION(source.GetValue(), TFutureException);
}
// A second ExtractValue() on the same state must throw TFutureException.
Y_UNIT_TEST(ShouldNotExtractAfterExtract) {
    auto source = NewPromise<int>();
    source.SetValue(123);
    UNIT_ASSERT(source.HasValue());

    UNIT_ASSERT_EQUAL(source.ExtractValue(), 123);
    UNIT_CHECK_GENERATED_EXCEPTION(source.ExtractValue(), TFutureException);
}
// Extracting from the future returned by argument-less MakeFuture<T>() must
// throw for int, and an extraction attempt on MakeFuture<TStorage>() must not
// change what a later MakeFuture<TStorage>() returns.
Y_UNIT_TEST(ShouldNotExtractFromSharedDefault) {
    UNIT_CHECK_GENERATED_EXCEPTION(MakeFuture<int>().ExtractValue(), TFutureException);

    struct TStorage {
        TString String = TString(100, 'a');
    };
    try {
        TString s = MakeFuture<TStorage>().ExtractValue().String;
        Y_UNUSED(s);
    } catch (const TFutureException&) {
        // pass: caught by const reference to avoid copying/slicing the exception
    }
    // Whatever happened above, a fresh default future still holds the intact string.
    UNIT_ASSERT_VALUES_EQUAL(MakeFuture<TStorage>().GetValue().String, TString(100, 'a'));
}
// Setting a value twice on the same promise throws TFutureException.
Y_UNIT_TEST(HandlingRepetitiveSet) {
    auto source = NewPromise<int>();
    source.SetValue(42);
    UNIT_CHECK_GENERATED_EXCEPTION(source.SetValue(42), TFutureException);
}
// TrySetValue reports success the first time and failure the second time.
Y_UNIT_TEST(HandlingRepetitiveTrySet) {
    auto source = NewPromise<int>();
    UNIT_ASSERT(source.TrySetValue(42));
    UNIT_ASSERT(!source.TrySetValue(42));
}
// Setting an exception twice on the same promise throws TFutureException.
Y_UNIT_TEST(HandlingRepetitiveSetException) {
    auto source = NewPromise<int>();
    source.SetException("test");
    UNIT_CHECK_GENERATED_EXCEPTION(source.SetException("test"), TFutureException);
}
// TrySetException reports success the first time and failure the second time.
Y_UNIT_TEST(HandlingRepetitiveTrySetException) {
    auto source = NewPromise<int>();
    UNIT_ASSERT(source.TrySetException(std::make_exception_ptr("test")));
    UNIT_ASSERT(!source.TrySetException(std::make_exception_ptr("test")));
}
// MakeErrorFuture<T>(exception_ptr) yields a failed future, while
// MakeFuture<T>(...) with an exception_ptr or nullptr payload yields an
// ordinary value.
Y_UNIT_TEST(ShouldAllowToMakeFutureWithException)
{
    // Failed void future.
    auto failedVoid = MakeErrorFuture<void>(std::make_exception_ptr(TFutureException()));
    UNIT_ASSERT(failedVoid.HasException());
    UNIT_ASSERT(failedVoid.IsReady());
    UNIT_CHECK_GENERATED_EXCEPTION(failedVoid.GetValue(), TFutureException);

    // Failed int future.
    auto failedInt = MakeErrorFuture<int>(std::make_exception_ptr(TFutureException()));
    UNIT_ASSERT(failedInt.HasException());
    UNIT_CHECK_GENERATED_EXCEPTION(failedInt.GetValue(), TFutureException);

    // An exception_ptr used as the value type is stored as a value, not thrown.
    auto ptrAsValue = MakeFuture<std::exception_ptr>(std::make_exception_ptr(TFutureException()));
    UNIT_ASSERT(ptrAsValue.HasValue());
    UNIT_CHECK_GENERATED_NO_EXCEPTION(ptrAsValue.GetValue(), TFutureException);

    // A null unique_ptr is likewise stored as a value.
    auto nullAsValue = MakeFuture<std::unique_ptr<int>>(nullptr);
    UNIT_ASSERT(nullAsValue.HasValue());
    UNIT_CHECK_GENERATED_NO_EXCEPTION(nullAsValue.GetValue(), TFutureException);
}
// Passing a future through WaitExceptionOrAll / WaitAny does not prevent a
// later ExtractValue() on it.
Y_UNIT_TEST(WaitAllowsExtract) {
    auto ready = MakeFuture<int>(42);
    TVector sameFutureThrice{ready, ready, ready};

    WaitExceptionOrAll(sameFutureThrice).GetValue();
    WaitAny(sameFutureThrice).GetValue();

    UNIT_ASSERT_EQUAL(ready.ExtractValue(), 42);
}
// Reading the IgnoreResult() future does not prevent a later ExtractValue()
// on the original future.
Y_UNIT_TEST(IgnoreAllowsExtract) {
    auto ready = MakeFuture<int>(42);
    ready.IgnoreResult().GetValue();
    UNIT_ASSERT_EQUAL(ready.ExtractValue(), 42);
}
// WaitExceptionOrAll finishes waiting once one input fails, even though
// the other input is still pending.
Y_UNIT_TEST(WaitExceptionOrAllException) {
    auto pendingPromise = NewPromise();
    auto failingPromise = NewPromise();
    auto pendingFuture = pendingPromise.GetFuture();
    auto failingFuture = failingPromise.GetFuture();

    auto combined = WaitExceptionOrAll(pendingFuture, failingFuture);
    failingPromise.SetException("foo-exception");
    combined.Wait();

    UNIT_ASSERT(failingFuture.HasException());
    // The other input was never touched.
    UNIT_ASSERT(!pendingFuture.IsReady());
    UNIT_ASSERT(!pendingFuture.HasValue() && !pendingFuture.HasException());
}
// WaitAll keeps waiting after one input fails and reports that failure
// only once every input has completed.
Y_UNIT_TEST(WaitAllException) {
    auto succeedingPromise = NewPromise();
    auto failingPromise = NewPromise();
    auto succeedingFuture = succeedingPromise.GetFuture();
    auto failingFuture = failingPromise.GetFuture();

    auto combined = WaitAll(succeedingFuture, failingFuture);

    failingPromise.SetException("foo-exception");
    // Still pending: one input has not completed yet.
    UNIT_ASSERT(!combined.HasValue() && !combined.HasException());

    succeedingPromise.SetValue();
    UNIT_ASSERT_EXCEPTION_CONTAINS(combined.GetValueSync(), yexception, "foo-exception");
}
// Runs the StateId() checks for both the void and a value specialisation.
Y_UNIT_TEST(FutureStateId) {
    TestFutureStateId<void>();
    TestFutureStateId<int>();
}
template <typename T>
void TestApplyNoRvalueCopyImpl() {
size_t numCopies = 0;
TCopyCounter copyCounter(&numCopies);
auto promise = MakePromise<T>();
const auto future = promise.GetFuture().Apply(
[copyCounter = std::move(copyCounter)] (const auto&) {}
);
if constexpr (std::is_same_v<T, void>) {
promise.SetValue();
} else {
promise.SetValue(T());
}
future.GetValueSync();
UNIT_ASSERT_VALUES_EQUAL(numCopies, 0);
}
// Runs the "temporary functor is not copied" check for void and int futures.
Y_UNIT_TEST(ApplyNoRvalueCopy) {
    TestApplyNoRvalueCopyImpl<void>();
    TestApplyNoRvalueCopyImpl<int>();
}
template <typename T>
void TestApplyLvalueCopyImpl() {
size_t numCopies = 0;
TCopyCounter copyCounter(&numCopies);
auto promise = MakePromise<T>();
auto func = [copyCounter = std::move(copyCounter)] (const auto&) {};
const auto future = promise.GetFuture().Apply(func);
if constexpr (std::is_same_v<T, void>) {
promise.SetValue();
} else {
promise.SetValue(T());
}
future.GetValueSync();
UNIT_ASSERT_VALUES_EQUAL(numCopies, 1);
}
// Runs the "lvalue functor is copied once" check for void and int futures.
Y_UNIT_TEST(ApplyLvalueCopy) {
    TestApplyLvalueCopyImpl<void>();
    TestApplyLvalueCopyImpl<int>();
}
// Return() accepts both lvalues and rvalues and yields TFuture<TString> in
// either case; passing an lvalue leaves the original string intact.
Y_UNIT_TEST(ReturnForwardingTypeDeduction) {
    const TString expected = TString(80, 'a');
    TString payload = TString(80, 'a');

    // Lvalue argument: the source string must remain unchanged afterwards.
    TFuture<TString> fromLvalue = MakeFuture().Return(payload);
    UNIT_ASSERT_VALUES_EQUAL(fromLvalue.GetValue(), expected);
    UNIT_ASSERT_VALUES_EQUAL(payload, expected);

    // Rvalue argument.
    TFuture<TString> fromRvalue = MakeFuture().Return(std::move(payload));
    UNIT_ASSERT_VALUES_EQUAL(fromRvalue.GetValue(), expected);
}
// A value moved into Return() and taken back out with ExtractValueSync()
// is never copied along the way.
Y_UNIT_TEST(ReturnForwardingCopiesCount) {
    size_t copiesMade = 0;
    TCopyCounter tracker(&copiesMade);

    auto roundTripped = MakeFuture().Return(std::move(tracker)).ExtractValueSync();
    Y_DO_NOT_OPTIMIZE_AWAY(roundTripped);

    UNIT_ASSERT_VALUES_EQUAL(copiesMade, 0);
}
}
}