summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorsnaury <[email protected]>2025-06-19 13:24:45 +0300
committersnaury <[email protected]>2025-06-19 14:12:36 +0300
commit11434faf28fda9ddc192bf1453dfec7980f88357 (patch)
tree7294a14ef298897f6a66e114cfa5ce46db6ff75d
parente7daa2b33914fddad263236d2b0b8106e13791cd (diff)
Fix object destruction order when using TFuture<T> coroutines
commit_hash:683c797584872e45e8df2ad7c663f1f1ebb253e3
-rw-r--r--library/cpp/threading/future/core/coroutine_traits.h56
-rw-r--r--library/cpp/threading/future/core/future-inl.h58
-rw-r--r--library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp85
3 files changed, 157 insertions, 42 deletions
diff --git a/library/cpp/threading/future/core/coroutine_traits.h b/library/cpp/threading/future/core/coroutine_traits.h
index cdd3eeeff3e..de08c377b16 100644
--- a/library/cpp/threading/future/core/coroutine_traits.h
+++ b/library/cpp/threading/future/core/coroutine_traits.h
@@ -8,46 +8,74 @@ template <typename... Args>
struct std::coroutine_traits<NThreading::TFuture<void>, Args...> {
struct promise_type {
- NThreading::TFuture<void> get_return_object() {
- return Promise_.GetFuture();
+ NThreading::TFuture<void> get_return_object() noexcept {
+ return NThreading::TFuture<void>(State_);
}
- std::suspend_never initial_suspend() { return {}; }
- std::suspend_never final_suspend() noexcept { return {}; }
+ struct TFinalSuspend {
+ bool await_ready() noexcept { return false; }
+ void await_resume() noexcept { /* never called */ }
+ void await_suspend(std::coroutine_handle<promise_type> self) noexcept {
+ auto state = std::move(self.promise().State_);
+ // We must destroy the coroutine before running callbacks
+ // This will make sure argument copies are destroyed before the caller is resumed
+ self.destroy();
+ state->RunCallbacks();
+ }
+ };
+
+ std::suspend_never initial_suspend() noexcept { return {}; }
+ TFinalSuspend final_suspend() noexcept { return {}; }
void unhandled_exception() {
- Promise_.SetException(std::current_exception());
+ bool success = State_->TrySetException(std::current_exception(), /* deferCallbacks */ true);
+ Y_ASSERT(success && "value already set");
}
void return_void() {
- Promise_.SetValue();
+ bool success = State_->TrySetValue(/* deferCallbacks */ true);
+ Y_ASSERT(success && "value already set");
}
private:
- NThreading::TPromise<void> Promise_ = NThreading::NewPromise();
+ TIntrusivePtr<NThreading::NImpl::TFutureState<void>> State_{new NThreading::NImpl::TFutureState<void>()};
};
};
template <typename T, typename... Args>
struct std::coroutine_traits<NThreading::TFuture<T>, Args...> {
struct promise_type {
- NThreading::TFuture<T> get_return_object() {
- return Promise_.GetFuture();
+ NThreading::TFuture<T> get_return_object() noexcept {
+ return NThreading::TFuture<T>(State_);
}
- std::suspend_never initial_suspend() { return {}; }
- std::suspend_never final_suspend() noexcept { return {}; }
+ struct TFinalSuspend {
+ bool await_ready() noexcept { return false; }
+ void await_resume() noexcept { /* never called */ }
+ void await_suspend(std::coroutine_handle<promise_type> self) noexcept {
+ auto state = std::move(self.promise().State_);
+ // We must destroy the coroutine before running callbacks
+ // This will make sure argument copies are destroyed before the caller is resumed
+ self.destroy();
+ state->RunCallbacks();
+ }
+ };
+
+ std::suspend_never initial_suspend() noexcept { return {}; }
+ TFinalSuspend final_suspend() noexcept { return {}; }
void unhandled_exception() {
- Promise_.SetException(std::current_exception());
+ bool success = State_->TrySetException(std::current_exception(), /* deferCallbacks */ true);
+ Y_ASSERT(success && "value already set");
}
void return_value(auto&& val) {
- Promise_.SetValue(std::forward<decltype(val)>(val));
+ bool success = State_->TrySetValue(std::forward<decltype(val)>(val), /* deferCallbacks */ true);
+ Y_ASSERT(success && "value already set");
}
private:
- NThreading::TPromise<T> Promise_ = NThreading::NewPromise<T>();
+ TIntrusivePtr<NThreading::NImpl::TFutureState<T>> State_{new NThreading::NImpl::TFutureState<T>()};
};
};
diff --git a/library/cpp/threading/future/core/future-inl.h b/library/cpp/threading/future/core/future-inl.h
index 1ce1cbd4e21..74d227f71c8 100644
--- a/library/cpp/threading/future/core/future-inl.h
+++ b/library/cpp/threading/future/core/future-inl.h
@@ -140,9 +140,8 @@ namespace NThreading {
}
template <typename TT>
- bool TrySetValue(TT&& value) {
+ bool TrySetValue(TT&& value, bool deferCallbacks = false) {
TSystemEvent* readyEvent = nullptr;
- TCallbackList<T> callbacks;
with_lock (StateLock) {
TAtomicBase state = AtomicGet(State);
@@ -153,7 +152,6 @@ namespace NThreading {
new (&Value) T(std::forward<TT>(value));
readyEvent = ReadyEvent.Get();
- callbacks = std::move(Callbacks);
AtomicSet(State, ValueSet);
}
@@ -162,11 +160,8 @@ namespace NThreading {
readyEvent->Signal();
}
- if (callbacks) {
- TFuture<T> temp(this);
- for (auto& callback : callbacks) {
- callback(temp);
- }
+ if (!deferCallbacks) {
+ RunCallbacks();
}
return true;
@@ -179,9 +174,8 @@ namespace NThreading {
}
}
- bool TrySetException(std::exception_ptr e) {
+ bool TrySetException(std::exception_ptr e, bool deferCallbacks = false) {
TSystemEvent* readyEvent;
- TCallbackList<T> callbacks;
with_lock (StateLock) {
TAtomicBase state = AtomicGet(State);
@@ -192,7 +186,6 @@ namespace NThreading {
Exception = std::move(e);
readyEvent = ReadyEvent.Get();
- callbacks = std::move(Callbacks);
AtomicSet(State, ExceptionSet);
}
@@ -201,14 +194,22 @@ namespace NThreading {
readyEvent->Signal();
}
- if (callbacks) {
+ if (!deferCallbacks) {
+ RunCallbacks();
+ }
+
+ return true;
+ }
+
+ void RunCallbacks() {
+ Y_ASSERT(AtomicGet(State) != NotReady);
+ if (!Callbacks.empty()) {
+ TCallbackList<T> callbacks = std::move(Callbacks);
TFuture<T> temp(this);
for (auto& callback : callbacks) {
callback(temp);
}
}
-
- return true;
}
template <typename F>
@@ -331,9 +332,8 @@ namespace NThreading {
}
}
- bool TrySetValue() {
+ bool TrySetValue(bool deferCallbacks = false) {
TSystemEvent* readyEvent = nullptr;
- TCallbackList<void> callbacks;
with_lock (StateLock) {
TAtomicBase state = AtomicGet(State);
@@ -342,7 +342,6 @@ namespace NThreading {
}
readyEvent = ReadyEvent.Get();
- callbacks = std::move(Callbacks);
AtomicSet(State, ValueSet);
}
@@ -351,11 +350,8 @@ namespace NThreading {
readyEvent->Signal();
}
- if (callbacks) {
- TFuture<void> temp(this);
- for (auto& callback : callbacks) {
- callback(temp);
- }
+ if (!deferCallbacks) {
+ RunCallbacks();
}
return true;
@@ -368,9 +364,8 @@ namespace NThreading {
}
}
- bool TrySetException(std::exception_ptr e) {
+ bool TrySetException(std::exception_ptr e, bool deferCallbacks = false) {
TSystemEvent* readyEvent = nullptr;
- TCallbackList<void> callbacks;
with_lock (StateLock) {
TAtomicBase state = AtomicGet(State);
@@ -381,7 +376,6 @@ namespace NThreading {
Exception = std::move(e);
readyEvent = ReadyEvent.Get();
- callbacks = std::move(Callbacks);
AtomicSet(State, ExceptionSet);
}
@@ -390,14 +384,22 @@ namespace NThreading {
readyEvent->Signal();
}
- if (callbacks) {
+ if (!deferCallbacks) {
+ RunCallbacks();
+ }
+
+ return true;
+ }
+
+ void RunCallbacks() {
+ Y_ASSERT(AtomicGet(State) != NotReady);
+ if (!Callbacks.empty()) {
+ TCallbackList<void> callbacks = std::move(Callbacks);
TFuture<void> temp(this);
for (auto& callback : callbacks) {
callback(temp);
}
}
-
- return true;
}
template <typename F>
diff --git a/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp b/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp
index 2daa1f5e47a..4b3c6135a53 100644
--- a/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp
+++ b/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp
@@ -188,6 +188,91 @@ TEST(TestFutureTraits, CrashOnExceptionInCoroutineHandlerResume) {
);
}
+TEST(TestFutureTraits, DestructorOrder) {
+ class TTrackedValue {
+ public:
+ TTrackedValue(TVector<TString>& result, TString name)
+ : Result(result)
+ , Name(std::move(name))
+ {
+ Result.push_back(Name + " constructed");
+ }
+
+ TTrackedValue(TTrackedValue&& rhs)
+ : Result(rhs.Result)
+ , Name(std::move(rhs.Name))
+ {
+ Result.push_back(Name + " moved");
+ rhs.Name.clear();
+ }
+
+ ~TTrackedValue() {
+ if (!Name.empty()) {
+ Result.push_back(Name + " destroyed");
+ }
+ }
+
+ private:
+ TVector<TString>& Result;
+ TString Name;
+ };
+
+ TVector<TString> result;
+ NThreading::TPromise<void> promise = NThreading::NewPromise<void>();
+ NThreading::TFuture<void> future = promise.GetFuture();
+
+ auto coroutine1 = [&](TTrackedValue arg) -> NThreading::TFuture<TString> {
+ TTrackedValue a(result, "local a");
+ result.push_back("before co_await future");
+ co_await future;
+ result.push_back("after co_await future");
+ Y_UNUSED(arg);
+ co_return "42";
+ };
+
+ auto coroutine2 = [&]() -> NThreading::TFuture<void> {
+ TTrackedValue b(result, "local b");
+ result.push_back("before co_await coroutine1(...)");
+ TString value = co_await coroutine1(TTrackedValue(result, "arg"));
+ result.push_back("after co_await coroutine1(...)");
+ result.push_back("value = " + value);
+ };
+
+ result.push_back("before coroutine2()");
+ auto future2 = coroutine2();
+ result.push_back("after coroutine2()");
+ EXPECT_FALSE(future2.HasValue() || future2.HasException());
+ future2.Subscribe([&](const auto&) {
+ result.push_back("in coroutine2() callback");
+ });
+
+ promise.SetValue();
+ EXPECT_TRUE(future2.HasValue());
+
+ EXPECT_THAT(
+ result,
+ ::testing::ContainerEq(
+ TVector<TString>({
+ "before coroutine2()",
+ "local b constructed",
+ "before co_await coroutine1(...)",
+ "arg constructed",
+ "arg moved",
+ "local a constructed",
+ "before co_await future",
+ "after coroutine2()",
+ "after co_await future",
+ "local a destroyed",
+ "arg destroyed",
+ "after co_await coroutine1(...)",
+ "value = 42",
+ "local b destroyed",
+ "in coroutine2() callback",
+ })
+ )
+ );
+}
+
TEST(ExtractingFutureAwaitable, Simple) {
NThreading::TPromise<THolder<size_t>> suspendPromise = NThreading::NewPromise<THolder<size_t>>();
auto coro = [](NThreading::TFuture<THolder<size_t>> future) -> NThreading::TFuture<THolder<size_t>> {