diff options
author | snaury <[email protected]> | 2025-06-19 13:24:45 +0300 |
---|---|---|
committer | snaury <[email protected]> | 2025-06-19 14:12:36 +0300 |
commit | 11434faf28fda9ddc192bf1453dfec7980f88357 (patch) | |
tree | 7294a14ef298897f6a66e114cfa5ce46db6ff75d | |
parent | e7daa2b33914fddad263236d2b0b8106e13791cd (diff) |
Fix object destruction order when using TFuture<T> coroutines
commit_hash:683c797584872e45e8df2ad7c663f1f1ebb253e3
-rw-r--r-- | library/cpp/threading/future/core/coroutine_traits.h | 56 | ||||
-rw-r--r-- | library/cpp/threading/future/core/future-inl.h | 58 | ||||
-rw-r--r-- | library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp | 85 |
3 files changed, 157 insertions, 42 deletions
diff --git a/library/cpp/threading/future/core/coroutine_traits.h b/library/cpp/threading/future/core/coroutine_traits.h index cdd3eeeff3e..de08c377b16 100644 --- a/library/cpp/threading/future/core/coroutine_traits.h +++ b/library/cpp/threading/future/core/coroutine_traits.h @@ -8,46 +8,74 @@ template <typename... Args> struct std::coroutine_traits<NThreading::TFuture<void>, Args...> { struct promise_type { - NThreading::TFuture<void> get_return_object() { - return Promise_.GetFuture(); + NThreading::TFuture<void> get_return_object() noexcept { + return NThreading::TFuture<void>(State_); } - std::suspend_never initial_suspend() { return {}; } - std::suspend_never final_suspend() noexcept { return {}; } + struct TFinalSuspend { + bool await_ready() noexcept { return false; } + void await_resume() noexcept { /* never called */ } + void await_suspend(std::coroutine_handle<promise_type> self) noexcept { + auto state = std::move(self.promise().State_); + // We must destroy the coroutine before running callbacks + // This will make sure argument copies are destroyed before the caller is resumed + self.destroy(); + state->RunCallbacks(); + } + }; + + std::suspend_never initial_suspend() noexcept { return {}; } + TFinalSuspend final_suspend() noexcept { return {}; } void unhandled_exception() { - Promise_.SetException(std::current_exception()); + bool success = State_->TrySetException(std::current_exception(), /* deferCallbacks */ true); + Y_ASSERT(success && "value already set"); } void return_void() { - Promise_.SetValue(); + bool success = State_->TrySetValue(/* deferCallbacks */ true); + Y_ASSERT(success && "value already set"); } private: - NThreading::TPromise<void> Promise_ = NThreading::NewPromise(); + TIntrusivePtr<NThreading::NImpl::TFutureState<void>> State_{new NThreading::NImpl::TFutureState<void>()}; }; }; template <typename T, typename... Args> struct std::coroutine_traits<NThreading::TFuture<T>, Args...> { struct promise_type { - NThreading::TFuture<T> get_return_object() { - return Promise_.GetFuture(); + NThreading::TFuture<T> get_return_object() noexcept { + return NThreading::TFuture<T>(State_); } - std::suspend_never initial_suspend() { return {}; } - std::suspend_never final_suspend() noexcept { return {}; } + struct TFinalSuspend { + bool await_ready() noexcept { return false; } + void await_resume() noexcept { /* never called */ } + void await_suspend(std::coroutine_handle<promise_type> self) noexcept { + auto state = std::move(self.promise().State_); + // We must destroy the coroutine before running callbacks + // This will make sure argument copies are destroyed before the caller is resumed + self.destroy(); + state->RunCallbacks(); + } + }; + + std::suspend_never initial_suspend() noexcept { return {}; } + TFinalSuspend final_suspend() noexcept { return {}; } void unhandled_exception() { - Promise_.SetException(std::current_exception()); + bool success = State_->TrySetException(std::current_exception(), /* deferCallbacks */ true); + Y_ASSERT(success && "value already set"); } void return_value(auto&& val) { - Promise_.SetValue(std::forward<decltype(val)>(val)); + bool success = State_->TrySetValue(std::forward<decltype(val)>(val), /* deferCallbacks */ true); + Y_ASSERT(success && "value already set"); } private: - NThreading::TPromise<T> Promise_ = NThreading::NewPromise<T>(); + TIntrusivePtr<NThreading::NImpl::TFutureState<T>> State_{new NThreading::NImpl::TFutureState<T>()}; }; }; diff --git a/library/cpp/threading/future/core/future-inl.h b/library/cpp/threading/future/core/future-inl.h index 1ce1cbd4e21..74d227f71c8 100644 --- a/library/cpp/threading/future/core/future-inl.h +++ b/library/cpp/threading/future/core/future-inl.h @@ -140,9 +140,8 @@ namespace NThreading { } template <typename TT> - bool TrySetValue(TT&& value) { + bool TrySetValue(TT&& value, bool deferCallbacks = false) { TSystemEvent* readyEvent = nullptr; - TCallbackList<T> callbacks; with_lock (StateLock) { TAtomicBase state = AtomicGet(State); @@ -153,7 +152,6 @@ namespace NThreading { new (&Value) T(std::forward<TT>(value)); readyEvent = ReadyEvent.Get(); - callbacks = std::move(Callbacks); AtomicSet(State, ValueSet); } @@ -162,11 +160,8 @@ namespace NThreading { readyEvent->Signal(); } - if (callbacks) { - TFuture<T> temp(this); - for (auto& callback : callbacks) { - callback(temp); - } + if (!deferCallbacks) { + RunCallbacks(); } return true; @@ -179,9 +174,8 @@ namespace NThreading { } } - bool TrySetException(std::exception_ptr e) { + bool TrySetException(std::exception_ptr e, bool deferCallbacks = false) { TSystemEvent* readyEvent; - TCallbackList<T> callbacks; with_lock (StateLock) { TAtomicBase state = AtomicGet(State); @@ -192,7 +186,6 @@ namespace NThreading { Exception = std::move(e); readyEvent = ReadyEvent.Get(); - callbacks = std::move(Callbacks); AtomicSet(State, ExceptionSet); } @@ -201,14 +194,22 @@ namespace NThreading { readyEvent->Signal(); } - if (callbacks) { + if (!deferCallbacks) { + RunCallbacks(); + } + + return true; + } + + void RunCallbacks() { + Y_ASSERT(AtomicGet(State) != NotReady); + if (!Callbacks.empty()) { + TCallbackList<T> callbacks = std::move(Callbacks); TFuture<T> temp(this); for (auto& callback : callbacks) { callback(temp); } } - - return true; } template <typename F> @@ -331,9 +332,8 @@ namespace NThreading { } } - bool TrySetValue() { + bool TrySetValue(bool deferCallbacks = false) { TSystemEvent* readyEvent = nullptr; - TCallbackList<void> callbacks; with_lock (StateLock) { TAtomicBase state = AtomicGet(State); @@ -342,7 +342,6 @@ namespace NThreading { } readyEvent = ReadyEvent.Get(); - callbacks = std::move(Callbacks); AtomicSet(State, ValueSet); } @@ -351,11 +350,8 @@ namespace NThreading { readyEvent->Signal(); } - if (callbacks) { - TFuture<void> temp(this); - for (auto& callback : callbacks) { - callback(temp); - } + if (!deferCallbacks) { + RunCallbacks(); } return true; @@ -368,9 +364,8 @@ namespace NThreading { } } - bool TrySetException(std::exception_ptr e) { + bool TrySetException(std::exception_ptr e, bool deferCallbacks = false) { TSystemEvent* readyEvent = nullptr; - TCallbackList<void> callbacks; with_lock (StateLock) { TAtomicBase state = AtomicGet(State); @@ -381,7 +376,6 @@ namespace NThreading { Exception = std::move(e); readyEvent = ReadyEvent.Get(); - callbacks = std::move(Callbacks); AtomicSet(State, ExceptionSet); } @@ -390,14 +384,22 @@ namespace NThreading { readyEvent->Signal(); } - if (callbacks) { + if (!deferCallbacks) { + RunCallbacks(); + } + + return true; + } + + void RunCallbacks() { + Y_ASSERT(AtomicGet(State) != NotReady); + if (!Callbacks.empty()) { + TCallbackList<void> callbacks = std::move(Callbacks); TFuture<void> temp(this); for (auto& callback : callbacks) { callback(temp); } } - - return true; } template <typename F> diff --git a/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp b/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp index 2daa1f5e47a..4b3c6135a53 100644 --- a/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp +++ b/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp @@ -188,6 +188,91 @@ TEST(TestFutureTraits, CrashOnExceptionInCoroutineHandlerResume) { ); } +TEST(TestFutureTraits, DestructorOrder) { + class TTrackedValue { + public: + TTrackedValue(TVector<TString>& result, TString name) + : Result(result) + , Name(std::move(name)) + { + Result.push_back(Name + " constructed"); + } + + TTrackedValue(TTrackedValue&& rhs) + : Result(rhs.Result) + , Name(std::move(rhs.Name)) + { + Result.push_back(Name + " moved"); + rhs.Name.clear(); + } + + ~TTrackedValue() { + if (!Name.empty()) { + Result.push_back(Name + " destroyed"); + } + } + + private: + TVector<TString>& Result; + TString Name; + }; + + TVector<TString> result; + NThreading::TPromise<void> promise = NThreading::NewPromise<void>(); + NThreading::TFuture<void> future = promise.GetFuture(); + + auto coroutine1 = [&](TTrackedValue arg) -> NThreading::TFuture<TString> { + TTrackedValue a(result, "local a"); + result.push_back("before co_await future"); + co_await future; + result.push_back("after co_await future"); + Y_UNUSED(arg); + co_return "42"; + }; + + auto coroutine2 = [&]() -> NThreading::TFuture<void> { + TTrackedValue b(result, "local b"); + result.push_back("before co_await coroutine1(...)"); + TString value = co_await coroutine1(TTrackedValue(result, "arg")); + result.push_back("after co_await coroutine1(...)"); + result.push_back("value = " + value); + }; + + result.push_back("before coroutine2()"); + auto future2 = coroutine2(); + result.push_back("after coroutine2()"); + EXPECT_FALSE(future2.HasValue() || future2.HasException()); + future2.Subscribe([&](const auto&) { + result.push_back("in coroutine2() callback"); + }); + + promise.SetValue(); + EXPECT_TRUE(future2.HasValue()); + + EXPECT_THAT( + result, + ::testing::ContainerEq( + TVector<TString>({ + "before coroutine2()", + "local b constructed", + "before co_await coroutine1(...)", + "arg constructed", + "arg moved", + "local a constructed", + "before co_await future", + "after coroutine2()", + "after co_await future", + "local a destroyed", + "arg destroyed", + "after co_await coroutine1(...)", + "value = 42", + "local b destroyed", + "in coroutine2() callback", + }) + ) + ); +} + TEST(ExtractingFutureAwaitable, Simple) { NThreading::TPromise<THolder<size_t>> suspendPromise = NThreading::NewPromise<THolder<size_t>>(); auto coro = [](NThreading::TFuture<THolder<size_t>> future) -> NThreading::TFuture<THolder<size_t>> { |