aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp
diff options
context:
space:
mode:
authorsnaury <snaury@ydb.tech>2023-08-08 17:18:33 +0300
committersnaury <snaury@ydb.tech>2023-08-08 18:14:19 +0300
commite8f301ee51eb89ff10308d312557586ac0261c3d (patch)
tree5f8b923d708622d228d1f6e6387cd715154c6f03 /library/cpp
parentf95fc3633ff92fc77fbc6220aa9b3653d0df412e (diff)
downloadydb-e8f301ee51eb89ff10308d312557586ac0261c3d.tar.gz
Better C++ coroutine lifetime in actors KIKIMR-18962
Diffstat (limited to 'library/cpp')
-rw-r--r--library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt2
-rw-r--r--library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt2
-rw-r--r--library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt2
-rw-r--r--library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt2
-rw-r--r--library/cpp/actors/cppcoro/await_callback.h16
-rw-r--r--library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt31
-rw-r--r--library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt34
-rw-r--r--library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt36
-rw-r--r--library/cpp/actors/cppcoro/benchmark/CMakeLists.txt17
-rw-r--r--library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt24
-rw-r--r--library/cpp/actors/cppcoro/benchmark/main.cpp76
-rw-r--r--library/cpp/actors/cppcoro/benchmark/ya.make11
-rw-r--r--library/cpp/actors/cppcoro/task.h315
-rw-r--r--library/cpp/actors/cppcoro/task_actor.cpp76
-rw-r--r--library/cpp/actors/cppcoro/task_actor.h40
-rw-r--r--library/cpp/actors/cppcoro/task_group.h161
-rw-r--r--library/cpp/actors/cppcoro/task_result.cpp1
-rw-r--r--library/cpp/actors/cppcoro/task_result.h113
-rw-r--r--library/cpp/actors/cppcoro/task_ut.cpp200
-rw-r--r--library/cpp/actors/cppcoro/ya.make3
20 files changed, 724 insertions, 438 deletions
diff --git a/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt b/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt
index ecac0aa784..f27bdd46d9 100644
--- a/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt
+++ b/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt
@@ -6,6 +6,7 @@
# original buildsystem will not be accepted.
+add_subdirectory(benchmark)
add_subdirectory(ut)
add_library(cpp-actors-cppcoro)
@@ -18,5 +19,6 @@ target_sources(cpp-actors-cppcoro PRIVATE
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp
)
diff --git a/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt b/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt
index ff385af6fa..a7c6669d6f 100644
--- a/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt
+++ b/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt
@@ -6,6 +6,7 @@
# original buildsystem will not be accepted.
+add_subdirectory(benchmark)
add_subdirectory(ut)
add_library(cpp-actors-cppcoro)
@@ -19,5 +20,6 @@ target_sources(cpp-actors-cppcoro PRIVATE
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp
)
diff --git a/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt b/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt
index ff385af6fa..a7c6669d6f 100644
--- a/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt
+++ b/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt
@@ -6,6 +6,7 @@
# original buildsystem will not be accepted.
+add_subdirectory(benchmark)
add_subdirectory(ut)
add_library(cpp-actors-cppcoro)
@@ -19,5 +20,6 @@ target_sources(cpp-actors-cppcoro PRIVATE
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp
)
diff --git a/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt b/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt
index ecac0aa784..f27bdd46d9 100644
--- a/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt
+++ b/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt
@@ -6,6 +6,7 @@
# original buildsystem will not be accepted.
+add_subdirectory(benchmark)
add_subdirectory(ut)
add_library(cpp-actors-cppcoro)
@@ -18,5 +19,6 @@ target_sources(cpp-actors-cppcoro PRIVATE
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp
${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp
)
diff --git a/library/cpp/actors/cppcoro/await_callback.h b/library/cpp/actors/cppcoro/await_callback.h
index 9f23d5e0db..fcb2eb78f9 100644
--- a/library/cpp/actors/cppcoro/await_callback.h
+++ b/library/cpp/actors/cppcoro/await_callback.h
@@ -5,6 +5,7 @@
namespace NActors {
namespace NDetail {
+
template<class TAwaitable>
decltype(auto) GetAwaiter(TAwaitable&& awaitable) {
if constexpr (requires { ((TAwaitable&&) awaitable).operator co_await(); }) {
@@ -23,31 +24,31 @@ namespace NActors {
class TCallbackResult {
public:
TCallbackResult(TCallback& callback)
- : Callback_(callback)
+ : Callback(callback)
{}
template<class TRealResult>
void return_value(TRealResult&& result) noexcept {
- Callback_(std::forward<TRealResult>(result));
+ Callback(std::forward<TRealResult>(result));
}
private:
- TCallback& Callback_;
+ TCallback& Callback;
};
template<class TCallback>
class TCallbackResult<TCallback, void> {
public:
TCallbackResult(TCallback& callback)
- : Callback_(callback)
+ : Callback(callback)
{}
void return_void() noexcept {
- Callback_();
+ Callback();
}
private:
- TCallback& Callback_;
+ TCallback& Callback;
};
template<class TAwaitable, class TCallback>
@@ -82,7 +83,8 @@ namespace NActors {
TAwaitThenCallback(THandle) noexcept {}
};
- }
+
+ } // namespace NDetail
/**
* Awaits the awaitable and calls callback with the result.
diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt
new file mode 100644
index 0000000000..41a756eff5
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt
@@ -0,0 +1,31 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+
+add_executable(benchmark)
+target_link_libraries(benchmark PUBLIC
+ contrib-libs-cxxsupp
+ yutil
+ library-cpp-cpuid_check
+ testing-benchmark-main
+ cpp-actors-cppcoro
+)
+target_link_options(benchmark PRIVATE
+ -Wl,-platform_version,macos,11.0,11.0
+ -fPIC
+ -fPIC
+ -framework
+ CoreFoundation
+)
+target_sources(benchmark PRIVATE
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp
+)
+target_allocator(benchmark
+ system_allocator
+)
+vcs_info(benchmark)
diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt
new file mode 100644
index 0000000000..1a5559813a
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt
@@ -0,0 +1,34 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+
+add_executable(benchmark)
+target_link_libraries(benchmark PUBLIC
+ contrib-libs-linux-headers
+ contrib-libs-cxxsupp
+ yutil
+ testing-benchmark-main
+ cpp-actors-cppcoro
+)
+target_link_options(benchmark PRIVATE
+ -ldl
+ -lrt
+ -Wl,--no-as-needed
+ -fPIC
+ -fPIC
+ -lpthread
+ -lrt
+ -ldl
+)
+target_sources(benchmark PRIVATE
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp
+)
+target_allocator(benchmark
+ cpp-malloc-jemalloc
+)
+vcs_info(benchmark)
diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt
new file mode 100644
index 0000000000..68e2fd4d6e
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt
@@ -0,0 +1,36 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+
+add_executable(benchmark)
+target_link_libraries(benchmark PUBLIC
+ contrib-libs-linux-headers
+ contrib-libs-cxxsupp
+ yutil
+ library-cpp-cpuid_check
+ testing-benchmark-main
+ cpp-actors-cppcoro
+)
+target_link_options(benchmark PRIVATE
+ -ldl
+ -lrt
+ -Wl,--no-as-needed
+ -fPIC
+ -fPIC
+ -lpthread
+ -lrt
+ -ldl
+)
+target_sources(benchmark PRIVATE
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp
+)
+target_allocator(benchmark
+ cpp-malloc-tcmalloc
+ libs-tcmalloc-no_percpu_cache
+)
+vcs_info(benchmark)
diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.txt
new file mode 100644
index 0000000000..f8b31df0c1
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.txt
@@ -0,0 +1,17 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA)
+ include(CMakeLists.linux-aarch64.txt)
+elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64")
+ include(CMakeLists.darwin-x86_64.txt)
+elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA)
+ include(CMakeLists.windows-x86_64.txt)
+elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA)
+ include(CMakeLists.linux-x86_64.txt)
+endif()
diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt
new file mode 100644
index 0000000000..731a59b67b
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt
@@ -0,0 +1,24 @@
+
+# This file was generated by the build system used internally in the Yandex monorepo.
+# Only simple modifications are allowed (adding source-files to targets, adding simple properties
+# like target_include_directories). These modifications will be ported to original
+# ya.make files by maintainers. Any complex modifications which can't be ported back to the
+# original buildsystem will not be accepted.
+
+
+
+add_executable(benchmark)
+target_link_libraries(benchmark PUBLIC
+ contrib-libs-cxxsupp
+ yutil
+ library-cpp-cpuid_check
+ testing-benchmark-main
+ cpp-actors-cppcoro
+)
+target_sources(benchmark PRIVATE
+ ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp
+)
+target_allocator(benchmark
+ system_allocator
+)
+vcs_info(benchmark)
diff --git a/library/cpp/actors/cppcoro/benchmark/main.cpp b/library/cpp/actors/cppcoro/benchmark/main.cpp
new file mode 100644
index 0000000000..20b4d63243
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/main.cpp
@@ -0,0 +1,76 @@
+#include <library/cpp/actors/cppcoro/task.h>
+#include <library/cpp/actors/cppcoro/await_callback.h>
+#include <library/cpp/testing/benchmark/bench.h>
+
+using namespace NActors;
+
+namespace {
+
+ int LastValue = 0;
+
+ Y_NO_INLINE int NextFuncValue() {
+ return ++LastValue;
+ }
+
+ Y_NO_INLINE void IterateFuncValues(size_t iterations) {
+ for (size_t i = 0; i < iterations; ++i) {
+ int value = NextFuncValue();
+ Y_DO_NOT_OPTIMIZE_AWAY(value);
+ }
+ }
+
+ Y_NO_INLINE TTask<int> NextTaskValue() {
+ co_return ++LastValue;
+ }
+
+ Y_NO_INLINE TTask<void> IterateTaskValues(size_t iterations) {
+ for (size_t i = 0; i < iterations; ++i) {
+ int value = co_await NextTaskValue();
+ Y_DO_NOT_OPTIMIZE_AWAY(value);
+ }
+ }
+
+ std::coroutine_handle<> Paused;
+
+ struct {
+ static bool await_ready() noexcept {
+ return false;
+ }
+ static void await_suspend(std::coroutine_handle<> h) noexcept {
+ Paused = h;
+ }
+ static int await_resume() noexcept {
+ return ++LastValue;
+ }
+ } Pause;
+
+ Y_NO_INLINE TTask<void> IteratePauseValues(size_t iterations) {
+ for (size_t i = 0; i < iterations; ++i) {
+ int value = co_await Pause;
+ Y_DO_NOT_OPTIMIZE_AWAY(value);
+ }
+ }
+
+} // namespace
+
+Y_CPU_BENCHMARK(FuncCalls, iface) {
+ IterateFuncValues(iface.Iterations());
+}
+
+Y_CPU_BENCHMARK(TaskCalls, iface) {
+ bool finished = false;
+ AwaitThenCallback(IterateTaskValues(iface.Iterations()), [&]{
+ finished = true;
+ });
+ Y_VERIFY(finished);
+}
+
+Y_CPU_BENCHMARK(CoroAwaits, iface) {
+ bool finished = false;
+ AwaitThenCallback(IteratePauseValues(iface.Iterations()), [&]{
+ finished = true;
+ });
+ while (!finished) {
+ std::exchange(Paused, {}).resume();
+ }
+}
diff --git a/library/cpp/actors/cppcoro/benchmark/ya.make b/library/cpp/actors/cppcoro/benchmark/ya.make
new file mode 100644
index 0000000000..ef5ad4135c
--- /dev/null
+++ b/library/cpp/actors/cppcoro/benchmark/ya.make
@@ -0,0 +1,11 @@
+Y_BENCHMARK()
+
+PEERDIR(
+ library/cpp/actors/cppcoro
+)
+
+SRCS(
+ main.cpp
+)
+
+END()
diff --git a/library/cpp/actors/cppcoro/task.h b/library/cpp/actors/cppcoro/task.h
index dade638ddb..f02ec22008 100644
--- a/library/cpp/actors/cppcoro/task.h
+++ b/library/cpp/actors/cppcoro/task.h
@@ -1,153 +1,158 @@
#pragma once
-#include <util/system/compiler.h>
+#include "task_result.h"
#include <util/system/yassert.h>
#include <coroutine>
-#include <exception>
-#include <variant>
namespace NActors {
template<class T>
class TTask;
+ /**
+ * This exception is commonly thrown when task is cancelled
+ */
+ class TTaskCancelled : public std::exception {
+ public:
+ const char* what() const noexcept {
+ return "Task cancelled";
+ }
+ };
+
namespace NDetail {
- class TTaskPromiseBase {
+ template<class T>
+ class TTaskPromise;
+
+ template<class T>
+ using TTaskHandle = std::coroutine_handle<TTaskPromise<T>>;
+
+ template<class T>
+ class TTaskAwaiter {
public:
- static auto initial_suspend() noexcept {
- return std::suspend_always{};
+ explicit TTaskAwaiter(TTaskHandle<T> handle)
+ : Handle(handle)
+ {
+ Y_VERIFY_DEBUG(Handle);
}
- struct TFinalSuspend {
- static bool await_ready() noexcept { return false; }
- static void await_resume() noexcept { std::terminate(); }
-
- template<class TPromise>
- static std::coroutine_handle<> await_suspend(std::coroutine_handle<TPromise> h) noexcept {
- TTaskPromiseBase& promise = h.promise();
- return std::exchange(promise.Continuation_, {});
- }
- };
+ TTaskAwaiter(TTaskAwaiter&& rhs)
+ : Handle(std::exchange(rhs.Handle, {}))
+ {}
- static auto final_suspend() noexcept {
- return TFinalSuspend{};
- }
+ TTaskAwaiter& operator=(const TTaskAwaiter&) = delete;
+ TTaskAwaiter& operator=(TTaskAwaiter&&) = delete;
- bool HasStarted() const noexcept {
- return Flags_ & 1;
+ ~TTaskAwaiter() noexcept {
+ if (Handle) {
+ Handle.destroy();
+ }
}
- void SetStarted() noexcept {
- Flags_ |= 1;
- }
+ // We can only await a task that has not started yet
+ static bool await_ready() noexcept { return false; }
- bool HasContinuation() const noexcept {
- return Flags_ & 2;
+ // Some arbitrary continuation c suspended and awaits the task
+ TTaskHandle<T> await_suspend(std::coroutine_handle<> c) noexcept {
+ Y_VERIFY_DEBUG(Handle);
+ Handle.promise().SetContinuation(c);
+ return Handle;
}
- void SetContinuation(std::coroutine_handle<> continuation) noexcept {
- Y_VERIFY_DEBUG(continuation, "Attempt to set an invalid continuation");
- Y_VERIFY_DEBUG(!HasContinuation(), "Attempt to set multiple continuations");
- Continuation_ = continuation;
- Flags_ |= 2;
+ TTaskResult<T>&& await_resume() noexcept {
+ Y_VERIFY_DEBUG(Handle);
+ return std::move(Handle.promise().Result);
}
private:
- // Default is used when task is resumed without a continuation
- std::coroutine_handle<> Continuation_ = std::noop_coroutine();
- unsigned char Flags_ = 0;
+ TTaskHandle<T> Handle;
};
template<class T>
- class TTaskPromise;
+ class TTaskResultAwaiter final : public TTaskAwaiter<T> {
+ public:
+ using TTaskAwaiter<T>::TTaskAwaiter;
- template<class T>
- using TTaskHandle = std::coroutine_handle<TTaskPromise<T>>;
+ T&& await_resume() {
+ return TTaskAwaiter<T>::await_resume().Value();
+ }
+ };
- template<class T>
- class TTaskPromise final : public TTaskPromiseBase {
+ template<>
+ class TTaskResultAwaiter<void> final : public TTaskAwaiter<void> {
public:
- TTask<T> get_return_object() noexcept;
+ using TTaskAwaiter<void>::TTaskAwaiter;
- std::coroutine_handle<> Start() noexcept {
- if (Y_LIKELY(!HasStarted())) {
- SetStarted();
- return TTaskHandle<T>::from_promise(*this);
- } else {
- // After coroutine starts is cannot be safely resumed, because
- // it is waiting for something and must be resumed via its
- // continuation.
- return std::noop_coroutine();
- }
+ void await_resume() {
+ TTaskAwaiter<void>::await_resume().Value();
}
+ };
+ template<class T>
+ class TTaskResultHandlerBase {
+ public:
void unhandled_exception() noexcept {
- Result_.template emplace<std::exception_ptr>(std::current_exception());
+ Result.SetException(std::current_exception());
}
- template<class TResult>
- void return_value(TResult&& result) {
- Result_.template emplace<T>(std::forward<TResult>(result));
- }
+ protected:
+ TTaskResult<T> Result;
+ };
- T ExtractResult() {
- switch (Result_.index()) {
- case 0: {
- std::rethrow_exception(std::get<0>(std::move(Result_)));
- }
- case 1: {
- return std::get<1>(std::move(Result_));
- }
- }
- std::terminate();
+ template<class T>
+ class TTaskResultHandler : public TTaskResultHandlerBase<T> {
+ public:
+ template<class TResult>
+ void return_value(TResult&& value) {
+ this->Result.SetValue(std::forward<TResult>(value));
}
-
- private:
- std::variant<std::exception_ptr, T> Result_;
};
template<>
- class TTaskPromise<void> final : public TTaskPromiseBase {
+ class TTaskResultHandler<void> : public TTaskResultHandlerBase<void> {
public:
- TTask<void> get_return_object() noexcept;
-
- std::coroutine_handle<> Start() noexcept {
- if (Y_LIKELY(!HasStarted())) {
- SetStarted();
- return TTaskHandle<void>::from_promise(*this);
- } else {
- // After coroutine starts is cannot be safely resumed, because
- // it is waiting for something and must be resumed via its
- // continuation.
- return std::noop_coroutine();
- }
+ void return_void() noexcept {
+ this->Result.SetValue();
}
+ };
- void unhandled_exception() noexcept {
- Exception_ = std::current_exception();
- }
+ template<class T>
+ class TTaskPromise final
+ : public TTaskResultHandler<T>
+ {
+ friend class TTaskAwaiter<T>;
- void return_void() noexcept {
- Exception_ = nullptr;
- }
+ public:
+ TTask<T> get_return_object() noexcept;
- void ExtractResult() {
- if (Exception_) {
- std::rethrow_exception(std::move(Exception_));
+ static auto initial_suspend() noexcept { return std::suspend_always{}; }
+
+ struct TFinalSuspend {
+ static bool await_ready() noexcept { return false; }
+ static void await_resume() noexcept { Y_FAIL("unexpected coroutine resume"); }
+
+ static std::coroutine_handle<> await_suspend(std::coroutine_handle<TTaskPromise<T>> h) noexcept {
+ auto next = std::exchange(h.promise().Continuation, std::noop_coroutine());
+ Y_VERIFY_DEBUG(next, "Task finished without a continuation");
+ return next;
}
+ };
+
+ static auto final_suspend() noexcept { return TFinalSuspend{}; }
+
+ private:
+ void SetContinuation(std::coroutine_handle<> continuation) noexcept {
+ Y_VERIFY_DEBUG(!Continuation, "Task can only be awaited once");
+ Continuation = continuation;
}
private:
- std::exception_ptr Exception_;
+ std::coroutine_handle<> Continuation;
};
} // namespace NDetail
/**
- * A bare-bones lazy task implementation
- *
- * This task is not thread safe and assumes external synchronization, e.g.
- * races between destructor and await resume are not allowed and not safe.
+ * Represents a task that has not been started yet
*/
template<class T>
class TTask final {
@@ -159,23 +164,23 @@ namespace NActors {
TTask() noexcept = default;
explicit TTask(NDetail::TTaskHandle<T> handle) noexcept
- : Handle_(handle)
+ : Handle(handle)
{}
TTask(TTask&& rhs) noexcept
- : Handle_(std::exchange(rhs.Handle_, {}))
+ : Handle(std::exchange(rhs.Handle, {}))
{}
~TTask() {
- if (Handle_) {
- Handle_.destroy();
+ if (Handle) {
+ Handle.destroy();
}
}
TTask& operator=(TTask&& rhs) noexcept {
if (Y_LIKELY(this != &rhs)) {
- auto handle = std::exchange(Handle_, {});
- Handle_ = std::exchange(rhs.Handle_, {});
+ auto handle = std::exchange(Handle, {});
+ Handle = std::exchange(rhs.Handle, {});
if (handle) {
handle.destroy();
}
@@ -187,113 +192,27 @@ namespace NActors {
* Returns true for a valid task object
*/
explicit operator bool() const noexcept {
- return bool(Handle_);
- }
-
- /**
- * Returns true if the task finished executing (produced a result)
- */
- bool Done() const {
- Y_VERIFY_DEBUG(Handle_);
- return Handle_.done();
- }
-
- /**
- * Manually start the task, only possible once
- */
- void Start() const {
- Y_VERIFY_DEBUG(!Done());
- Y_VERIFY_DEBUG(!Promise().HasStarted());
- Y_VERIFY_DEBUG(!Promise().HasContinuation());
- Promise().Start().resume();
- }
-
- /**
- * Implementation of awaiter for WhenDone
- */
- class TWhenDoneAwaiter {
- public:
- TWhenDoneAwaiter(NDetail::TTaskHandle<T> handle) noexcept
- : Handle_(handle)
- {}
-
- bool await_ready() const noexcept {
- return Handle_.done();
- }
-
- std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) const noexcept {
- Handle_.promise().SetContinuation(continuation);
- return Handle_.promise().Start();
- }
-
- void await_resume() const noexcept {
- // nothing
- }
-
- private:
- NDetail::TTaskHandle<T> Handle_;
- };
-
- /**
- * Returns an awaitable that completes when task finishes executing
- *
- * Note the result of the task is not consumed.
- */
- auto WhenDone() const noexcept {
- return TWhenDoneAwaiter(Handle_);
+ return bool(Handle);
}
/**
- * Extracts result of the task
+ * Starts task and returns TTaskResult<T> when it completes
*/
- T ExtractResult() {
- Y_VERIFY_DEBUG(Done());
- return Promise().ExtractResult();
+ auto WhenDone() && noexcept {
+ Y_VERIFY_DEBUG(Handle, "Cannot await an empty task");
+ return NDetail::TTaskAwaiter<T>(std::exchange(Handle, {}));
}
/**
- * Implementation of awaiter for co_await
- */
- class TAwaiter {
- public:
- TAwaiter(TTask&& task) noexcept
- : Task_(std::move(task))
- {}
-
- bool await_ready() const noexcept {
- return Task_.Done();
- }
-
- std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) const noexcept {
- Task_.Promise().SetContinuation(continuation);
- return Task_.Promise().Start();
- }
-
- T await_resume() {
- // We destroy task state before we return
- TTask task(std::move(Task_));
- return task.ExtractResult();
- }
-
- private:
- TTask Task_;
- };
-
- /**
- * Returns the task result when it finishes
+ * Starts task and returns its result when it completes
*/
auto operator co_await() && noexcept {
- return TAwaiter(std::move(*this));
- }
-
- private:
- NDetail::TTaskPromise<T>& Promise() const noexcept {
- Y_VERIFY_DEBUG(Handle_);
- return Handle_.promise();
+ Y_VERIFY_DEBUG(Handle, "Cannot await an empty task");
+ return NDetail::TTaskResultAwaiter<T>(std::exchange(Handle, {}));
}
private:
- NDetail::TTaskHandle<T> Handle_;
+ NDetail::TTaskHandle<T> Handle;
};
namespace NDetail {
@@ -303,10 +222,6 @@ namespace NActors {
return TTask<T>(TTaskHandle<T>::from_promise(*this));
}
- inline TTask<void> TTaskPromise<void>::get_return_object() noexcept {
- return TTask<void>(TTaskHandle<void>::from_promise(*this));
- }
-
} // namespace NDetail
} // namespace NActors
diff --git a/library/cpp/actors/cppcoro/task_actor.cpp b/library/cpp/actors/cppcoro/task_actor.cpp
index 756d7e2f32..d55db4eb04 100644
--- a/library/cpp/actors/cppcoro/task_actor.cpp
+++ b/library/cpp/actors/cppcoro/task_actor.cpp
@@ -1,4 +1,5 @@
#include "task_actor.h"
+#include "await_callback.h"
#include <library/cpp/actors/core/actor.h>
#include <library/cpp/actors/core/hfunc.h>
@@ -25,20 +26,30 @@ namespace NActors {
struct TEvResumeTask : public TEventLocal<TEvResumeTask, EvResumeTask> {
std::coroutine_handle<> Handle;
+ TTaskResult<void>* Result;
- explicit TEvResumeTask(std::coroutine_handle<> handle) noexcept
+ explicit TEvResumeTask(std::coroutine_handle<> handle, TTaskResult<void>* result) noexcept
: Handle(handle)
+ , Result(result)
{}
~TEvResumeTask() noexcept {
- // TODO: actor may be dead already
+ if (Handle) {
+ Result->SetException(std::make_exception_ptr(TTaskCancelled()));
+ Handle.resume();
+ }
}
};
+ class TTaskActorResult final : public TAtomicRefCount<TTaskActorResult> {
+ public:
+ bool Finished = false;
+ };
+
class TTaskActorImpl : public TActor<TTaskActorImpl> {
friend class TTaskActor;
- friend struct TAfterAwaiter;
- friend struct TBindAwaiter;
+ friend class TAfterAwaiter;
+ friend class TBindAwaiter;
public:
TTaskActorImpl(TTask<void>&& task)
@@ -48,6 +59,15 @@ namespace NActors {
Y_VERIFY(Task);
}
+ ~TTaskActorImpl() {
+ Stopped = true;
+ while (EventAwaiter) {
+ // Unblock event awaiter until task stops trying
+ TCurrentTaskActorGuard guard(this);
+ std::exchange(EventAwaiter, {}).resume();
+ }
+ }
+
void Registered(TActorSystem* sys, const TActorId& parent) override {
ParentId = parent;
sys->Send(new IEventHandle(TEvents::TSystem::Bootstrap, 0, SelfId(), SelfId(), {}, 0));
@@ -57,7 +77,15 @@ namespace NActors {
Y_VERIFY(ev->GetTypeRewrite() == TEvents::TSystem::Bootstrap, "Expected bootstrap event");
TCurrentTaskActorGuard guard(this);
Become(&TThis::StateWork);
- Task.Start();
+ AwaitThenCallback(std::move(Task).WhenDone(),
+ [result = Result](TTaskResult<void>&& outcome) noexcept {
+ result->Finished = true;
+ try {
+ outcome.Value();
+ } catch (TTaskCancelled&) {
+ // ignore
+ }
+ });
Check();
}
@@ -66,46 +94,50 @@ namespace NActors {
switch (ev->GetTypeRewrite()) {
hFunc(TEvResumeTask, Handle);
default:
- Y_VERIFY(EventWaiter);
+ Y_VERIFY(EventAwaiter);
Event.reset(ev.Release());
- std::exchange(EventWaiter, {}).resume();
+ std::exchange(EventAwaiter, {}).resume();
}
Check();
}
void Handle(TEvResumeTask::TPtr& ev) {
auto* msg = ev->Get();
+ msg->Result->SetValue();
std::exchange(msg->Handle, {}).resume();
}
bool Check() {
- if (Task.Done()) {
- Y_VERIFY(!EventWaiter, "Task terminated while waiting for the next event");
- Task.ExtractResult();
+ if (Result->Finished) {
+ Y_VERIFY(!EventAwaiter, "Task terminated while waiting for the next event");
PassAway();
return false;
}
- Y_VERIFY(EventWaiter, "Task suspended without waiting for the next event");
- Event.reset();
+ Y_VERIFY(EventAwaiter, "Task suspended without waiting for the next event");
return true;
}
void WaitForEvent(std::coroutine_handle<> h) noexcept {
- Y_VERIFY(!EventWaiter, "Task cannot have multiple waiters for the next event");
- EventWaiter = h;
+ Y_VERIFY(!EventAwaiter, "Task cannot have multiple awaiters for the next event");
+ EventAwaiter = h;
}
- std::unique_ptr<IEventHandle> FinishWaitForEvent() noexcept {
+ std::unique_ptr<IEventHandle> FinishWaitForEvent() {
+ if (Stopped) {
+ throw TTaskCancelled();
+ }
Y_VERIFY(Event, "Task does not have current event");
return std::move(Event);
}
private:
+ TIntrusivePtr<TTaskActorResult> Result = MakeIntrusive<TTaskActorResult>();
TTask<void> Task;
TActorId ParentId;
- std::coroutine_handle<> EventWaiter;
+ std::coroutine_handle<> EventAwaiter;
std::unique_ptr<IEventHandle> Event;
+ bool Stopped = false;
};
void TTaskActorNextEvent::await_suspend(std::coroutine_handle<> h) noexcept {
@@ -113,7 +145,7 @@ namespace NActors {
TlsCurrentTaskActor->WaitForEvent(h);
}
- std::unique_ptr<IEventHandle> TTaskActorNextEvent::await_resume() noexcept {
+ std::unique_ptr<IEventHandle> TTaskActorNextEvent::await_resume() {
Y_VERIFY(TlsCurrentTaskActor, "Not in a task actor context");
return TlsCurrentTaskActor->FinishWaitForEvent();
}
@@ -134,10 +166,7 @@ namespace NActors {
void TAfterAwaiter::await_suspend(std::coroutine_handle<> h) noexcept {
Y_VERIFY(TlsCurrentTaskActor, "Not in a task actor context");
- TlsCurrentTaskActor->Schedule(Duration, new TEvResumeTask(h));
- }
-
- void TAfterAwaiter::await_resume() {
+ TlsCurrentTaskActor->Schedule(Duration, new TEvResumeTask(h, &Result));
}
bool TBindAwaiter::await_ready() noexcept {
@@ -148,10 +177,7 @@ namespace NActors {
}
void TBindAwaiter::await_suspend(std::coroutine_handle<> h) noexcept {
- Sys->Send(new IEventHandle(ActorId, ActorId, new TEvResumeTask(h)));
- }
-
- void TBindAwaiter::await_resume() {
+ Sys->Send(new IEventHandle(ActorId, ActorId, new TEvResumeTask(h, &Result)));
}
} // namespace NActors
diff --git a/library/cpp/actors/cppcoro/task_actor.h b/library/cpp/actors/cppcoro/task_actor.h
index e4a1c9df3e..75d498a04e 100644
--- a/library/cpp/actors/cppcoro/task_actor.h
+++ b/library/cpp/actors/cppcoro/task_actor.h
@@ -8,28 +8,47 @@ namespace NActors {
static void await_suspend(std::coroutine_handle<> h) noexcept;
- static std::unique_ptr<IEventHandle> await_resume() noexcept;
+ static std::unique_ptr<IEventHandle> await_resume();
};
- struct TAfterAwaiter {
- TDuration Duration;
+ class TAfterAwaiter {
+ public:
+ TAfterAwaiter(TDuration duration)
+ : Duration(duration)
+ {}
static constexpr bool await_ready() noexcept { return false; }
void await_suspend(std::coroutine_handle<> h) noexcept;
- void await_resume();
+ void await_resume() {
+ Result.Value();
+ }
+
+ private:
+ TDuration Duration;
+ TTaskResult<void> Result;
};
- struct TBindAwaiter {
- TActorSystem* Sys;
- TActorId ActorId;
+ class TBindAwaiter {
+ public:
+ TBindAwaiter(TActorSystem* sys, const TActorId& actorId)
+ : Sys(sys)
+ , ActorId(actorId)
+ {}
bool await_ready() noexcept;
void await_suspend(std::coroutine_handle<> h) noexcept;
- void await_resume();
+ void await_resume() {
+ Result.Value();
+ }
+
+ private:
+ TActorSystem* Sys;
+ TActorId ActorId;
+ TTaskResult<void> Result;
};
class TTaskActor {
@@ -77,11 +96,10 @@ namespace NActors {
*/
template<class T>
static TTask<T> Bind(TTask<T>&& task) {
- // TODO: may run on non-actor thread, protect from unwind
return [](TTask<T> task, TBindAwaiter bindTask) -> TTask<T> {
- co_await task.WhenDone();
+ auto result = co_await std::move(task).WhenDone();
co_await bindTask;
- co_return task.ExtractResult();
+ co_return std::move(result).Value();
}(std::move(task), Bind());
}
};
diff --git a/library/cpp/actors/cppcoro/task_group.h b/library/cpp/actors/cppcoro/task_group.h
index b57cf59529..b2496f57eb 100644
--- a/library/cpp/actors/cppcoro/task_group.h
+++ b/library/cpp/actors/cppcoro/task_group.h
@@ -1,10 +1,9 @@
#pragma once
+#include "task_result.h"
#include <util/generic/ptr.h>
#include <util/system/compiler.h>
#include <util/system/yassert.h>
#include <coroutine>
-#include <exception>
-#include <variant>
#include <atomic>
#include <memory>
@@ -13,50 +12,8 @@ namespace NActors {
namespace NDetail {
template<class T>
- struct TTaskGroupResult final {
+ struct TTaskGroupResult final : public TTaskResult<T> {
TTaskGroupResult* Next;
- std::variant<std::exception_ptr, T> Result_;
-
- void SetException() {
- Result_.template emplace<0>(std::current_exception());
- }
-
- template<class TResult>
- void SetValue(TResult&& result) {
- Result_.template emplace<1>(std::forward<TResult>(result));
- }
-
- T Extract() {
- switch (Result_.index()) {
- case 0: {
- std::rethrow_exception(std::get<0>(std::move(Result_)));
- }
- case 1: {
- return std::get<1>(std::move(Result_));
- }
- }
- std::terminate();
- }
- };
-
- template<>
- struct TTaskGroupResult<void> final {
- TTaskGroupResult* Next;
- std::exception_ptr Exception_;
-
- void SetException() {
- Exception_ = std::current_exception();
- }
-
- void SetValue() {
- // nothing
- }
-
- void Extract() {
- if (Exception_) {
- std::rethrow_exception(std::move(Exception_));
- }
- }
};
template<class T>
@@ -70,6 +27,12 @@ namespace NActors {
static constexpr uintptr_t MarkerAwaiting = 1;
static constexpr uintptr_t MarkerDetached = 2;
+ ~TTaskGroupSink() noexcept {
+ if (!IsDetached()) {
+ Detach();
+ }
+ }
+
std::coroutine_handle<> Push(std::unique_ptr<TTaskGroupResult<T>>&& result) noexcept {
void* currentValue = LastReady.load(std::memory_order_acquire);
for (;;) {
@@ -78,8 +41,8 @@ namespace NActors {
continue;
}
// We consume the awaiter
- Y_VERIFY(ReadyQueue == nullptr, "TaskGroup is awaiting with non-empty ready queue");
- result->Next = nullptr;
+ Y_VERIFY_DEBUG(ReadyQueue == nullptr, "TaskGroup is awaiting with non-empty ready queue");
+ result->Next = ReadyQueue;
ReadyQueue = result.release();
return std::exchange(Continuation, {});
}
@@ -103,7 +66,7 @@ namespace NActors {
}
Y_NO_INLINE std::coroutine_handle<> Suspend(std::coroutine_handle<> h) noexcept {
- Y_VERIFY(ReadyQueue == nullptr, "Caller suspending with non-empty ready queue");
+ Y_VERIFY_DEBUG(ReadyQueue == nullptr, "Caller suspending with non-empty ready queue");
Continuation = h;
void* currentValue = LastReady.load(std::memory_order_acquire);
for (;;) {
@@ -151,6 +114,11 @@ namespace NActors {
}
}
+ bool IsDetached() const noexcept {
+ void* headValue = LastReady.load(std::memory_order_acquire);
+ return headValue == (void*)MarkerDetached;
+ }
+
void Detach() noexcept {
// After this exchange all new results will be discarded
void* headValue = LastReady.exchange((void*)MarkerDetached, std::memory_order_acq_rel);
@@ -166,48 +134,38 @@ namespace NActors {
};
template<class T>
- class TTaskGroupPromiseBase {
+ class TTaskGroupResultHandler {
public:
- static auto initial_suspend() noexcept { return std::suspend_always{}; }
-
- class TFinalSuspend {
- public:
- TFinalSuspend(TTaskGroupPromiseBase& promise)
- : Promise_(promise)
- {}
-
- static bool await_ready() noexcept { return false; }
-
- Y_NO_INLINE std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept {
- auto next = Promise_.Sink_->Push(std::move(Promise_.Result_));
- h.destroy();
- return next;
- }
-
- static void await_resume() noexcept { std::terminate(); }
+ void unhandled_exception() noexcept {
+ Result->SetException(std::current_exception());
+ }
- private:
- TTaskGroupPromiseBase& Promise_;
- };
+ template<class TResult>
+ void return_value(TResult&& result) {
+ Result->SetValue(std::forward<TResult>(result));
+ }
- auto final_suspend() noexcept { return TFinalSuspend(*this); }
+ protected:
+ std::unique_ptr<TTaskGroupResult<T>> Result = std::make_unique<TTaskGroupResult<T>>();
+ };
+ template<>
+ class TTaskGroupResultHandler<void> {
+ public:
void unhandled_exception() noexcept {
- Result_->SetException();
- Sink_->Push(std::move(Result_));
+ Result->SetException(std::current_exception());
}
- void SetSink(const TIntrusivePtr<TTaskGroupSink<T>>& sink) {
- Sink_ = sink;
+ void return_void() noexcept {
+ Result->SetValue();
}
protected:
- std::unique_ptr<TTaskGroupResult<T>> Result_ = std::make_unique<TTaskGroupResult<T>>();
- TIntrusivePtr<TTaskGroupSink<T>> Sink_;
+ std::unique_ptr<TTaskGroupResult<void>> Result = std::make_unique<TTaskGroupResult<void>>();
};
template<class T>
- class TTaskGroupPromise final : public TTaskGroupPromiseBase<T> {
+ class TTaskGroupPromise final : public TTaskGroupResultHandler<T> {
public:
using THandle = std::coroutine_handle<TTaskGroupPromise<T>>;
@@ -215,24 +173,30 @@ namespace NActors {
return THandle::from_promise(*this);
}
- template<class TResult>
- void return_value(TResult&& result) {
- this->Result_->SetValue(std::forward<TResult>(result));
- }
- };
+ static auto initial_suspend() noexcept { return std::suspend_always{}; }
- template<>
- class TTaskGroupPromise<void> final : public TTaskGroupPromiseBase<void> {
- public:
- using THandle = std::coroutine_handle<TTaskGroupPromise<void>>;
+ struct TFinalSuspend {
+ static bool await_ready() noexcept { return false; }
+ static void await_resume() noexcept { Y_FAIL("unexpected coroutine resume"); }
- THandle get_return_object() noexcept {
- return THandle::from_promise(*this);
- }
+ Y_NO_INLINE
+ static std::coroutine_handle<> await_suspend(std::coroutine_handle<TTaskGroupPromise<T>> h) noexcept {
+ auto& promise = h.promise();
+ auto sink = std::move(promise.Sink);
+ auto next = sink->Push(std::move(promise.Result));
+ h.destroy();
+ return next;
+ }
+ };
+
+ static auto final_suspend() noexcept { return TFinalSuspend{}; }
- void return_void() {
- this->Result_->SetValue();
+ void SetSink(const TIntrusivePtr<TTaskGroupSink<T>>& sink) {
+ Sink = sink;
}
+
+ private:
+ TIntrusivePtr<TTaskGroupSink<T>> Sink;
};
template<class T>
@@ -244,16 +208,16 @@ namespace NActors {
public:
TTaskGroupTask(THandle handle)
- : Handle_(handle)
+ : Handle(handle)
{}
void Start(const TIntrusivePtr<TTaskGroupSink<T>>& sink) {
- Handle_.promise().SetSink(sink);
- Handle_.resume();
+ Handle.promise().SetSink(sink);
+ Handle.resume();
}
private:
- THandle Handle_;
+ THandle Handle;
};
template<class T, class TAwaitable>
@@ -273,6 +237,11 @@ namespace NActors {
public:
TTaskGroup() = default;
+ TTaskGroup(const TTaskGroup&) = delete;
+ TTaskGroup(TTaskGroup&&) = delete;
+ TTaskGroup& operator=(const TTaskGroup&) = delete;
+ TTaskGroup& operator=(TTaskGroup&&) = delete;
+
~TTaskGroup() {
Sink_->Detach();
}
@@ -311,7 +280,7 @@ namespace NActors {
}
T await_resume() {
- return TaskGroup_.Sink_->Resume()->Extract();
+ return std::move(*TaskGroup_.Sink_->Resume()).Value();
}
private:
diff --git a/library/cpp/actors/cppcoro/task_result.cpp b/library/cpp/actors/cppcoro/task_result.cpp
new file mode 100644
index 0000000000..bb1a1dc5ca
--- /dev/null
+++ b/library/cpp/actors/cppcoro/task_result.cpp
@@ -0,0 +1 @@
+#include "task_result.h"
diff --git a/library/cpp/actors/cppcoro/task_result.h b/library/cpp/actors/cppcoro/task_result.h
new file mode 100644
index 0000000000..70176e64d5
--- /dev/null
+++ b/library/cpp/actors/cppcoro/task_result.h
@@ -0,0 +1,113 @@
+#pragma once
+#include <util/system/yassert.h>
+#include <exception>
+#include <variant>
+
+namespace NActors {
+
+ namespace NDetail {
+
+ struct TVoid {};
+
+ template<class T>
+ struct TReplaceVoid {
+ using TType = T;
+ };
+
+ template<>
+ struct TReplaceVoid<void> {
+ using TType = TVoid;
+ };
+
+ template<class T>
+ struct TLValue {
+ using TType = T&;
+ };
+
+ template<>
+ struct TLValue<void> {
+ using TType = void;
+ };
+
+ template<class T>
+ struct TRValue {
+ using TType = T&&;
+ };
+
+ template<>
+ struct TRValue<void> {
+ using TType = void;
+ };
+
+ } // namespace NDetail
+
+ /**
+ * Wrapper for the task result
+ */
+ template<class T>
+ class TTaskResult {
+ public:
+ void SetValue()
+ requires (std::same_as<T, void>)
+ {
+ Result.template emplace<1>();
+ }
+
+ template<class TResult>
+ void SetValue(TResult&& result)
+ requires (!std::same_as<T, void>)
+ {
+ Result.template emplace<1>(std::forward<TResult>(result));
+ }
+
+ void SetException(std::exception_ptr&& e) noexcept {
+ Result.template emplace<2>(std::move(e));
+ }
+
+ typename NDetail::TLValue<T>::TType Value() & {
+ switch (Result.index()) {
+ case 0: {
+ Y_FAIL("Task result has no value");
+ }
+ case 1: {
+ if constexpr (std::same_as<T, void>) {
+ return;
+ } else {
+ return std::get<1>(Result);
+ }
+ }
+ case 2: {
+ std::exception_ptr& e = std::get<2>(Result);
+ Y_VERIFY_DEBUG(e, "Task exception missing");
+ std::rethrow_exception(e);
+ }
+ }
+ Y_FAIL("Task result has an invalid state");
+ }
+
+ typename NDetail::TRValue<T>::TType Value() && {
+ switch (Result.index()) {
+ case 0: {
+ Y_FAIL("Task result has no value");
+ }
+ case 1: {
+ if constexpr (std::same_as<T, void>) {
+ return;
+ } else {
+ return std::get<1>(std::move(Result));
+ }
+ }
+ case 2: {
+ std::exception_ptr& e = std::get<2>(Result);
+ Y_VERIFY_DEBUG(e, "Task exception missing");
+ std::rethrow_exception(std::move(e));
+ }
+ }
+ Y_FAIL("Task result has an invalid state");
+ }
+
+ private:
+ std::variant<std::monostate, typename NDetail::TReplaceVoid<T>::TType, std::exception_ptr> Result;
+ };
+
+} // namespace NActors
diff --git a/library/cpp/actors/cppcoro/task_ut.cpp b/library/cpp/actors/cppcoro/task_ut.cpp
index 52c1b0e591..24ea9d0700 100644
--- a/library/cpp/actors/cppcoro/task_ut.cpp
+++ b/library/cpp/actors/cppcoro/task_ut.cpp
@@ -32,44 +32,26 @@ Y_UNIT_TEST_SUITE(Task) {
UNIT_ASSERT_VALUES_EQUAL(*result, 42);
}
- Y_UNIT_TEST(DoneAndWhenDone) {
- auto task = SimpleReturn42();
- UNIT_ASSERT(task);
- UNIT_ASSERT(!task.Done());
-
- bool whenDoneFinished = false;
- AwaitThenCallback(task.WhenDone(), [&]() {
- whenDoneFinished = true;
- });
- UNIT_ASSERT(whenDoneFinished);
- UNIT_ASSERT(task.Done());
-
- // WhenDone can be used even when task is already done
- whenDoneFinished = false;
- AwaitThenCallback(task.WhenDone(), [&]() {
- whenDoneFinished = true;
- });
- UNIT_ASSERT(whenDoneFinished);
-
- std::optional<int> result;
- AwaitThenCallback(std::move(task), [&](int value) {
- result = value;
+ Y_UNIT_TEST(SimpleVoidWhenDone) {
+ std::optional<TTaskResult<void>> result;
+ AwaitThenCallback(SimpleReturnVoid().WhenDone(), [&](auto value) {
+ result = std::move(value);
});
UNIT_ASSERT(result);
- UNIT_ASSERT_VALUES_EQUAL(*result, 42);
- UNIT_ASSERT(!task);
+ result->Value();
}
- Y_UNIT_TEST(ManualStart) {
- auto task = SimpleReturn42();
- UNIT_ASSERT(task && !task.Done());
- task.Start();
- UNIT_ASSERT(task.Done());
- UNIT_ASSERT_VALUES_EQUAL(task.ExtractResult(), 42);
+ Y_UNIT_TEST(SimpleIntWhenDone) {
+ std::optional<TTaskResult<int>> result;
+ AwaitThenCallback(SimpleReturn42().WhenDone(), [&](auto value) {
+ result = std::move(value);
+ });
+ UNIT_ASSERT(result);
+ UNIT_ASSERT_VALUES_EQUAL(result->Value(), 42);
}
template<class TCallback>
- TTask<int> CallTwice(TCallback&& callback) {
+ TTask<int> CallTwice(TCallback callback) {
int a = co_await callback();
int b = co_await callback();
co_return a + b;
@@ -79,6 +61,7 @@ Y_UNIT_TEST_SUITE(Task) {
auto task = CallTwice([]{
return SimpleReturn42();
});
+ UNIT_ASSERT(task);
std::optional<int> result;
AwaitThenCallback(std::move(task), [&](int value) {
result = value;
@@ -87,22 +70,37 @@ Y_UNIT_TEST_SUITE(Task) {
UNIT_ASSERT_VALUES_EQUAL(*result, 84);
}
+ template<class T>
struct TPauseState {
std::coroutine_handle<> Next;
- int NextResult;
+ std::optional<T> NextResult;
- auto Wait() {
- struct TAwaiter {
- TPauseState* State;
+ ~TPauseState() {
+ while (Next) {
+ NextResult.reset();
+ std::exchange(Next, {}).resume();
+ }
+ }
- bool await_ready() const noexcept { return false; }
- int await_resume() const noexcept {
- return State->NextResult;
- }
- void await_suspend(std::coroutine_handle<> c) {
- State->Next = c;
+ struct TAwaiter {
+ TPauseState* State;
+
+ bool await_ready() const noexcept { return false; }
+ void await_suspend(std::coroutine_handle<> c) const noexcept {
+ State->Next = c;
+ }
+ T await_resume() const {
+ if (!State->NextResult) {
+ throw TTaskCancelled();
+ } else {
+ T result = std::move(*State->NextResult);
+ State->NextResult.reset();
+ return result;
}
- };
+ }
+ };
+
+ auto Wait() {
return TAwaiter{ this };
}
@@ -110,19 +108,24 @@ Y_UNIT_TEST_SUITE(Task) {
return bool(Next);
}
- void Resume(int result) {
+ void Resume(T result) {
Y_VERIFY(Next && !Next.done());
NextResult = result;
std::exchange(Next, {}).resume();
}
+
+ void Cancel() {
+ Y_VERIFY(Next && !Next.done());
+ NextResult.reset();
+ std::exchange(Next, {}).resume();
+ }
};
- Y_UNIT_TEST(PausedAwait) {
- TPauseState state;
- auto callback = [&]{
+ Y_UNIT_TEST(PauseResume) {
+ TPauseState<int> state;
+ auto task = CallTwice([&]{
return state.Wait();
- };
- auto task = CallTwice(callback);
+ });
std::optional<int> result;
AwaitThenCallback(std::move(task), [&](int value) {
result = value;
@@ -137,63 +140,31 @@ Y_UNIT_TEST_SUITE(Task) {
UNIT_ASSERT_VALUES_EQUAL(*result, 33);
}
- Y_UNIT_TEST(ManuallyStartThenWhenDone) {
- TPauseState state;
- auto next = [&]{
+ Y_UNIT_TEST(PauseCancel) {
+ TPauseState<int> state;
+ auto task = CallTwice([&]{
return state.Wait();
- };
-
- auto task = [](auto next) -> TTask<int> {
- int value = co_await next();
- co_return value * 2;
- }(next);
-
- UNIT_ASSERT(task && !task.Done());
- task.Start();
- UNIT_ASSERT(!task.Done() && state);
- bool finished = false;
- AwaitThenCallback(task.WhenDone(), [&]{
- finished = true;
});
- UNIT_ASSERT(!finished && !task.Done());
- state.Resume(11);
- UNIT_ASSERT(finished && task.Done());
- UNIT_ASSERT_VALUES_EQUAL(task.ExtractResult(), 22);
- }
-
- Y_UNIT_TEST(ManuallyStartThenAwait) {
- TPauseState state;
- auto next = [&]{
- return state.Wait();
- };
-
- auto task = [](auto next) -> TTask<int> {
- int value = co_await next();
- co_return value * 2;
- }(next);
-
- UNIT_ASSERT(task && !task.Done());
- task.Start();
- UNIT_ASSERT(!task.Done() && state);
-
- auto awaitTask = [](auto task) -> TTask<int> {
- int value = co_await std::move(task);
- co_return value * 3;
- }(std::move(task));
- UNIT_ASSERT(awaitTask && !awaitTask.Done());
std::optional<int> result;
- AwaitThenCallback(std::move(awaitTask), [&](int value) {
- result = value;
+ AwaitThenCallback(std::move(task).WhenDone(), [&](TTaskResult<int>&& value) {
+ try {
+ result = value.Value();
+ } catch (TTaskCancelled&) {
+ // nothing
+ }
});
UNIT_ASSERT(!result);
+ UNIT_ASSERT(state);
state.Resume(11);
- UNIT_ASSERT(result);
- UNIT_ASSERT_VALUES_EQUAL(*result, 66);
+ UNIT_ASSERT(!result);
+ UNIT_ASSERT(state);
+ state.Cancel();
+ UNIT_ASSERT(!result);
}
Y_UNIT_TEST(GroupWithTwoSubTasks) {
- TPauseState state1;
- TPauseState state2;
+ TPauseState<int> state1;
+ TPauseState<int> state2;
std::vector<int> results;
auto task = [](auto& state1, auto& state2, auto& results) -> TTask<int> {
@@ -227,8 +198,8 @@ Y_UNIT_TEST_SUITE(Task) {
}
Y_UNIT_TEST(GroupWithTwoSubTasksDetached) {
- TPauseState state1;
- TPauseState state2;
+ TPauseState<int> state1;
+ TPauseState<int> state2;
std::vector<int> results;
auto task = [](auto& state1, auto& state2, auto& results) -> TTask<int> {
@@ -253,9 +224,40 @@ Y_UNIT_TEST_SUITE(Task) {
UNIT_ASSERT_VALUES_EQUAL(results.at(0), 22);
UNIT_ASSERT(result);
UNIT_ASSERT_VALUES_EQUAL(*result, 22);
+ }
- // We must resume the first state (otherwise memory leaks), but result is ignored
+ Y_UNIT_TEST(GroupWithTwoSubTasksOneCancelled) {
+ TPauseState<int> state1;
+ TPauseState<int> state2;
+ std::vector<int> results;
+ auto task = [](auto& state1, auto& state2, auto& results) -> TTask<void> {
+ TTaskGroup<int> group;
+ group.AddTask(state1.Wait());
+ group.AddTask(state2.Wait());
+ for (int i = 0; i < 2; ++i) {
+ try {
+ results.push_back(co_await group);
+ } catch (TTaskCancelled&) {
+ results.push_back(-1);
+ }
+ }
+ }(state1, state2, results);
+
+ bool finished = false;
+ AwaitThenCallback(std::move(task), [&]() {
+ finished = true;
+ });
+
+ UNIT_ASSERT(state1);
+ UNIT_ASSERT(state2);
+ state2.Cancel();
+ UNIT_ASSERT_VALUES_EQUAL(results.size(), 1u);
+ UNIT_ASSERT_VALUES_EQUAL(results.at(0), -1);
+ UNIT_ASSERT(!finished);
state1.Resume(11);
+ UNIT_ASSERT_VALUES_EQUAL(results.size(), 2u);
+ UNIT_ASSERT_VALUES_EQUAL(results.at(1), 11);
+ UNIT_ASSERT(finished);
}
} // Y_UNIT_TEST_SUITE(Task)
diff --git a/library/cpp/actors/cppcoro/ya.make b/library/cpp/actors/cppcoro/ya.make
index 9890eccbee..c1d8c225aa 100644
--- a/library/cpp/actors/cppcoro/ya.make
+++ b/library/cpp/actors/cppcoro/ya.make
@@ -11,6 +11,8 @@ SRCS(
task_actor.h
task_group.cpp
task_group.h
+ task_result.cpp
+ task_result.h
task.cpp
task.h
)
@@ -18,5 +20,6 @@ SRCS(
END()
RECURSE_FOR_TESTS(
+ benchmark
ut
)