diff options
author | snaury <snaury@ydb.tech> | 2023-08-08 17:18:33 +0300 |
---|---|---|
committer | snaury <snaury@ydb.tech> | 2023-08-08 18:14:19 +0300 |
commit | e8f301ee51eb89ff10308d312557586ac0261c3d (patch) | |
tree | 5f8b923d708622d228d1f6e6387cd715154c6f03 /library/cpp | |
parent | f95fc3633ff92fc77fbc6220aa9b3653d0df412e (diff) | |
download | ydb-e8f301ee51eb89ff10308d312557586ac0261c3d.tar.gz |
Better C++ coroutine lifetime in actors KIKIMR-18962
Diffstat (limited to 'library/cpp')
20 files changed, 724 insertions, 438 deletions
diff --git a/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt b/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt index ecac0aa784..f27bdd46d9 100644 --- a/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt +++ b/library/cpp/actors/cppcoro/CMakeLists.darwin-x86_64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(benchmark) add_subdirectory(ut) add_library(cpp-actors-cppcoro) @@ -18,5 +19,6 @@ target_sources(cpp-actors-cppcoro PRIVATE ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp ) diff --git a/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt b/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt index ff385af6fa..a7c6669d6f 100644 --- a/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt +++ b/library/cpp/actors/cppcoro/CMakeLists.linux-aarch64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(benchmark) add_subdirectory(ut) add_library(cpp-actors-cppcoro) @@ -19,5 +20,6 @@ target_sources(cpp-actors-cppcoro PRIVATE ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp ) diff --git a/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt b/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt index ff385af6fa..a7c6669d6f 100644 --- a/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt +++ b/library/cpp/actors/cppcoro/CMakeLists.linux-x86_64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(benchmark) add_subdirectory(ut) add_library(cpp-actors-cppcoro) @@ -19,5 +20,6 @@ target_sources(cpp-actors-cppcoro PRIVATE ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp ) diff --git a/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt b/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt index ecac0aa784..f27bdd46d9 100644 --- a/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt +++ b/library/cpp/actors/cppcoro/CMakeLists.windows-x86_64.txt @@ -6,6 +6,7 @@ # original buildsystem will not be accepted. +add_subdirectory(benchmark) add_subdirectory(ut) add_library(cpp-actors-cppcoro) @@ -18,5 +19,6 @@ target_sources(cpp-actors-cppcoro PRIVATE ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/await_callback.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_actor.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_group.cpp + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task_result.cpp ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/task.cpp ) diff --git a/library/cpp/actors/cppcoro/await_callback.h b/library/cpp/actors/cppcoro/await_callback.h index 9f23d5e0db..fcb2eb78f9 100644 --- a/library/cpp/actors/cppcoro/await_callback.h +++ b/library/cpp/actors/cppcoro/await_callback.h @@ -5,6 +5,7 @@ namespace NActors { namespace NDetail { + template<class TAwaitable> decltype(auto) GetAwaiter(TAwaitable&& awaitable) { if constexpr (requires { ((TAwaitable&&) awaitable).operator co_await(); }) { @@ -23,31 +24,31 @@ namespace NActors { class TCallbackResult { public: TCallbackResult(TCallback& callback) - : Callback_(callback) + : Callback(callback) {} template<class TRealResult> void return_value(TRealResult&& result) noexcept { - Callback_(std::forward<TRealResult>(result)); + Callback(std::forward<TRealResult>(result)); } private: - TCallback& Callback_; + TCallback& Callback; }; template<class TCallback> class TCallbackResult<TCallback, void> { public: TCallbackResult(TCallback& callback) - : Callback_(callback) + : Callback(callback) {} void return_void() noexcept { - Callback_(); + Callback(); } private: - TCallback& Callback_; + TCallback& Callback; }; template<class TAwaitable, class TCallback> @@ -82,7 +83,8 @@ namespace NActors { TAwaitThenCallback(THandle) noexcept {} }; - } + + } // namespace NDetail /** * Awaits the awaitable and calls callback with the result. diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt new file mode 100644 index 0000000000..41a756eff5 --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.darwin-x86_64.txt @@ -0,0 +1,31 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_executable(benchmark) +target_link_libraries(benchmark PUBLIC + contrib-libs-cxxsupp + yutil + library-cpp-cpuid_check + testing-benchmark-main + cpp-actors-cppcoro +) +target_link_options(benchmark PRIVATE + -Wl,-platform_version,macos,11.0,11.0 + -fPIC + -fPIC + -framework + CoreFoundation +) +target_sources(benchmark PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp +) +target_allocator(benchmark + system_allocator +) +vcs_info(benchmark) diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt new file mode 100644 index 0000000000..1a5559813a --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-aarch64.txt @@ -0,0 +1,34 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_executable(benchmark) +target_link_libraries(benchmark PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + testing-benchmark-main + cpp-actors-cppcoro +) +target_link_options(benchmark PRIVATE + -ldl + -lrt + -Wl,--no-as-needed + -fPIC + -fPIC + -lpthread + -lrt + -ldl +) +target_sources(benchmark PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp +) +target_allocator(benchmark + cpp-malloc-jemalloc +) +vcs_info(benchmark) diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt new file mode 100644 index 0000000000..68e2fd4d6e --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.linux-x86_64.txt @@ -0,0 +1,36 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_executable(benchmark) +target_link_libraries(benchmark PUBLIC + contrib-libs-linux-headers + contrib-libs-cxxsupp + yutil + library-cpp-cpuid_check + testing-benchmark-main + cpp-actors-cppcoro +) +target_link_options(benchmark PRIVATE + -ldl + -lrt + -Wl,--no-as-needed + -fPIC + -fPIC + -lpthread + -lrt + -ldl +) +target_sources(benchmark PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp +) +target_allocator(benchmark + cpp-malloc-tcmalloc + libs-tcmalloc-no_percpu_cache +) +vcs_info(benchmark) diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.txt new file mode 100644 index 0000000000..f8b31df0c1 --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.txt @@ -0,0 +1,17 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + +if (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-aarch64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64") + include(CMakeLists.darwin-x86_64.txt) +elseif (WIN32 AND CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64" AND NOT HAVE_CUDA) + include(CMakeLists.windows-x86_64.txt) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" AND NOT HAVE_CUDA) + include(CMakeLists.linux-x86_64.txt) +endif() diff --git a/library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt b/library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt new file mode 100644 index 0000000000..731a59b67b --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/CMakeLists.windows-x86_64.txt @@ -0,0 +1,24 @@ + +# This file was generated by the build system used internally in the Yandex monorepo. +# Only simple modifications are allowed (adding source-files to targets, adding simple properties +# like target_include_directories). These modifications will be ported to original +# ya.make files by maintainers. Any complex modifications which can't be ported back to the +# original buildsystem will not be accepted. + + + +add_executable(benchmark) +target_link_libraries(benchmark PUBLIC + contrib-libs-cxxsupp + yutil + library-cpp-cpuid_check + testing-benchmark-main + cpp-actors-cppcoro +) +target_sources(benchmark PRIVATE + ${CMAKE_SOURCE_DIR}/library/cpp/actors/cppcoro/benchmark/main.cpp +) +target_allocator(benchmark + system_allocator +) +vcs_info(benchmark) diff --git a/library/cpp/actors/cppcoro/benchmark/main.cpp b/library/cpp/actors/cppcoro/benchmark/main.cpp new file mode 100644 index 0000000000..20b4d63243 --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/main.cpp @@ -0,0 +1,76 @@ +#include <library/cpp/actors/cppcoro/task.h> +#include <library/cpp/actors/cppcoro/await_callback.h> +#include <library/cpp/testing/benchmark/bench.h> + +using namespace NActors; + +namespace { + + int LastValue = 0; + + Y_NO_INLINE int NextFuncValue() { + return ++LastValue; + } + + Y_NO_INLINE void IterateFuncValues(size_t iterations) { + for (size_t i = 0; i < iterations; ++i) { + int value = NextFuncValue(); + Y_DO_NOT_OPTIMIZE_AWAY(value); + } + } + + Y_NO_INLINE TTask<int> NextTaskValue() { + co_return ++LastValue; + } + + Y_NO_INLINE TTask<void> IterateTaskValues(size_t iterations) { + for (size_t i = 0; i < iterations; ++i) { + int value = co_await NextTaskValue(); + Y_DO_NOT_OPTIMIZE_AWAY(value); + } + } + + std::coroutine_handle<> Paused; + + struct { + static bool await_ready() noexcept { + return false; + } + static void await_suspend(std::coroutine_handle<> h) noexcept { + Paused = h; + } + static int await_resume() noexcept { + return ++LastValue; + } + } Pause; + + Y_NO_INLINE TTask<void> IteratePauseValues(size_t iterations) { + for (size_t i = 0; i < iterations; ++i) { + int value = co_await Pause; + Y_DO_NOT_OPTIMIZE_AWAY(value); + } + } + +} // namespace + +Y_CPU_BENCHMARK(FuncCalls, iface) { + IterateFuncValues(iface.Iterations()); +} + +Y_CPU_BENCHMARK(TaskCalls, iface) { + bool finished = false; + AwaitThenCallback(IterateTaskValues(iface.Iterations()), [&]{ + finished = true; + }); + Y_VERIFY(finished); +} + +Y_CPU_BENCHMARK(CoroAwaits, iface) { + bool finished = false; + AwaitThenCallback(IteratePauseValues(iface.Iterations()), [&]{ + finished = true; + }); + while (!finished) { + std::exchange(Paused, {}).resume(); + } +} diff --git a/library/cpp/actors/cppcoro/benchmark/ya.make b/library/cpp/actors/cppcoro/benchmark/ya.make new file mode 100644 index 0000000000..ef5ad4135c --- /dev/null +++ b/library/cpp/actors/cppcoro/benchmark/ya.make @@ -0,0 +1,11 @@ +Y_BENCHMARK() + +PEERDIR( + library/cpp/actors/cppcoro +) + +SRCS( + main.cpp +) + +END() diff --git a/library/cpp/actors/cppcoro/task.h b/library/cpp/actors/cppcoro/task.h index dade638ddb..f02ec22008 100644 --- a/library/cpp/actors/cppcoro/task.h +++ b/library/cpp/actors/cppcoro/task.h @@ -1,153 +1,158 @@ #pragma once -#include <util/system/compiler.h> +#include "task_result.h" #include <util/system/yassert.h> #include <coroutine> -#include <exception> -#include <variant> namespace NActors { template<class T> class TTask; + /** + * This exception is commonly thrown when task is cancelled + */ + class TTaskCancelled : public std::exception { + public: + const char* what() const noexcept { + return "Task cancelled"; + } + }; + namespace NDetail { - class TTaskPromiseBase { + template<class T> + class TTaskPromise; + + template<class T> + using TTaskHandle = std::coroutine_handle<TTaskPromise<T>>; + + template<class T> + class TTaskAwaiter { public: - static auto initial_suspend() noexcept { - return std::suspend_always{}; + explicit TTaskAwaiter(TTaskHandle<T> handle) + : Handle(handle) + { + Y_VERIFY_DEBUG(Handle); } - struct TFinalSuspend { - static bool await_ready() noexcept { return false; } - static void await_resume() noexcept { std::terminate(); } - - template<class TPromise> - static std::coroutine_handle<> await_suspend(std::coroutine_handle<TPromise> h) noexcept { - TTaskPromiseBase& promise = h.promise(); - return std::exchange(promise.Continuation_, {}); - } - }; + TTaskAwaiter(TTaskAwaiter&& rhs) + : Handle(std::exchange(rhs.Handle, {})) + {} - static auto final_suspend() noexcept { - return TFinalSuspend{}; - } + TTaskAwaiter& operator=(const TTaskAwaiter&) = delete; + TTaskAwaiter& operator=(TTaskAwaiter&&) = delete; - bool HasStarted() const noexcept { - return Flags_ & 1; + ~TTaskAwaiter() noexcept { + if (Handle) { + Handle.destroy(); + } } - void SetStarted() noexcept { - Flags_ |= 1; - } + // We can only await a task that has not started yet + static bool await_ready() noexcept { return false; } - bool HasContinuation() const noexcept { - return Flags_ & 2; + // Some arbitrary continuation c suspended and awaits the task + TTaskHandle<T> await_suspend(std::coroutine_handle<> c) noexcept { + Y_VERIFY_DEBUG(Handle); + Handle.promise().SetContinuation(c); + return Handle; } - void SetContinuation(std::coroutine_handle<> continuation) noexcept { - Y_VERIFY_DEBUG(continuation, "Attempt to set an invalid continuation"); - Y_VERIFY_DEBUG(!HasContinuation(), "Attempt to set multiple continuations"); - Continuation_ = continuation; - Flags_ |= 2; + TTaskResult<T>&& await_resume() noexcept { + Y_VERIFY_DEBUG(Handle); + return std::move(Handle.promise().Result); } private: - // Default is used when task is resumed without a continuation - std::coroutine_handle<> Continuation_ = std::noop_coroutine(); - unsigned char Flags_ = 0; + TTaskHandle<T> Handle; }; template<class T> - class TTaskPromise; + class TTaskResultAwaiter final : public TTaskAwaiter<T> { + public: + using TTaskAwaiter<T>::TTaskAwaiter; - template<class T> - using TTaskHandle = std::coroutine_handle<TTaskPromise<T>>; + T&& await_resume() { + return TTaskAwaiter<T>::await_resume().Value(); + } + }; - template<class T> - class TTaskPromise final : public TTaskPromiseBase { + template<> + class TTaskResultAwaiter<void> final : public TTaskAwaiter<void> { public: - TTask<T> get_return_object() noexcept; + using TTaskAwaiter<void>::TTaskAwaiter; - std::coroutine_handle<> Start() noexcept { - if (Y_LIKELY(!HasStarted())) { - SetStarted(); - return TTaskHandle<T>::from_promise(*this); - } else { - // After coroutine starts is cannot be safely resumed, because - // it is waiting for something and must be resumed via its - // continuation. - return std::noop_coroutine(); - } + void await_resume() { + TTaskAwaiter<void>::await_resume().Value(); } + }; + template<class T> + class TTaskResultHandlerBase { + public: void unhandled_exception() noexcept { - Result_.template emplace<std::exception_ptr>(std::current_exception()); + Result.SetException(std::current_exception()); } - template<class TResult> - void return_value(TResult&& result) { - Result_.template emplace<T>(std::forward<TResult>(result)); - } + protected: + TTaskResult<T> Result; + }; - T ExtractResult() { - switch (Result_.index()) { - case 0: { - std::rethrow_exception(std::get<0>(std::move(Result_))); - } - case 1: { - return std::get<1>(std::move(Result_)); - } - } - std::terminate(); + template<class T> + class TTaskResultHandler : public TTaskResultHandlerBase<T> { + public: + template<class TResult> + void return_value(TResult&& value) { + this->Result.SetValue(std::forward<TResult>(value)); } - - private: - std::variant<std::exception_ptr, T> Result_; }; template<> - class TTaskPromise<void> final : public TTaskPromiseBase { + class TTaskResultHandler<void> : public TTaskResultHandlerBase<void> { public: - TTask<void> get_return_object() noexcept; - - std::coroutine_handle<> Start() noexcept { - if (Y_LIKELY(!HasStarted())) { - SetStarted(); - return TTaskHandle<void>::from_promise(*this); - } else { - // After coroutine starts is cannot be safely resumed, because - // it is waiting for something and must be resumed via its - // continuation. - return std::noop_coroutine(); - } + void return_void() noexcept { + this->Result.SetValue(); } + }; - void unhandled_exception() noexcept { - Exception_ = std::current_exception(); - } + template<class T> + class TTaskPromise final + : public TTaskResultHandler<T> + { + friend class TTaskAwaiter<T>; - void return_void() noexcept { - Exception_ = nullptr; - } + public: + TTask<T> get_return_object() noexcept; - void ExtractResult() { - if (Exception_) { - std::rethrow_exception(std::move(Exception_)); + static auto initial_suspend() noexcept { return std::suspend_always{}; } + + struct TFinalSuspend { + static bool await_ready() noexcept { return false; } + static void await_resume() noexcept { Y_FAIL("unexpected coroutine resume"); } + + static std::coroutine_handle<> await_suspend(std::coroutine_handle<TTaskPromise<T>> h) noexcept { + auto next = std::exchange(h.promise().Continuation, std::noop_coroutine()); + Y_VERIFY_DEBUG(next, "Task finished without a continuation"); + return next; } + }; + + static auto final_suspend() noexcept { return TFinalSuspend{}; } + + private: + void SetContinuation(std::coroutine_handle<> continuation) noexcept { + Y_VERIFY_DEBUG(!Continuation, "Task can only be awaited once"); + Continuation = continuation; } private: - std::exception_ptr Exception_; + std::coroutine_handle<> Continuation; }; } // namespace NDetail /** - * A bare-bones lazy task implementation - * - * This task is not thread safe and assumes external synchronization, e.g. - * races between destructor and await resume are not allowed and not safe. + * Represents a task that has not been started yet */ template<class T> class TTask final { @@ -159,23 +164,23 @@ namespace NActors { TTask() noexcept = default; explicit TTask(NDetail::TTaskHandle<T> handle) noexcept - : Handle_(handle) + : Handle(handle) {} TTask(TTask&& rhs) noexcept - : Handle_(std::exchange(rhs.Handle_, {})) + : Handle(std::exchange(rhs.Handle, {})) {} ~TTask() { - if (Handle_) { - Handle_.destroy(); + if (Handle) { + Handle.destroy(); } } TTask& operator=(TTask&& rhs) noexcept { if (Y_LIKELY(this != &rhs)) { - auto handle = std::exchange(Handle_, {}); - Handle_ = std::exchange(rhs.Handle_, {}); + auto handle = std::exchange(Handle, {}); + Handle = std::exchange(rhs.Handle, {}); if (handle) { handle.destroy(); } @@ -187,113 +192,27 @@ namespace NActors { * Returns true for a valid task object */ explicit operator bool() const noexcept { - return bool(Handle_); - } - - /** - * Returns true if the task finished executing (produced a result) - */ - bool Done() const { - Y_VERIFY_DEBUG(Handle_); - return Handle_.done(); - } - - /** - * Manually start the task, only possible once - */ - void Start() const { - Y_VERIFY_DEBUG(!Done()); - Y_VERIFY_DEBUG(!Promise().HasStarted()); - Y_VERIFY_DEBUG(!Promise().HasContinuation()); - Promise().Start().resume(); - } - - /** - * Implementation of awaiter for WhenDone - */ - class TWhenDoneAwaiter { - public: - TWhenDoneAwaiter(NDetail::TTaskHandle<T> handle) noexcept - : Handle_(handle) - {} - - bool await_ready() const noexcept { - return Handle_.done(); - } - - std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) const noexcept { - Handle_.promise().SetContinuation(continuation); - return Handle_.promise().Start(); - } - - void await_resume() const noexcept { - // nothing - } - - private: - NDetail::TTaskHandle<T> Handle_; - }; - - /** - * Returns an awaitable that completes when task finishes executing - * - * Note the result of the task is not consumed. - */ - auto WhenDone() const noexcept { - return TWhenDoneAwaiter(Handle_); + return bool(Handle); } /** - * Extracts result of the task + * Starts task and returns TTaskResult<T> when it completes */ - T ExtractResult() { - Y_VERIFY_DEBUG(Done()); - return Promise().ExtractResult(); + auto WhenDone() && noexcept { + Y_VERIFY_DEBUG(Handle, "Cannot await an empty task"); + return NDetail::TTaskAwaiter<T>(std::exchange(Handle, {})); } /** - * Implementation of awaiter for co_await - */ - class TAwaiter { - public: - TAwaiter(TTask&& task) noexcept - : Task_(std::move(task)) - {} - - bool await_ready() const noexcept { - return Task_.Done(); - } - - std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) const noexcept { - Task_.Promise().SetContinuation(continuation); - return Task_.Promise().Start(); - } - - T await_resume() { - // We destroy task state before we return - TTask task(std::move(Task_)); - return task.ExtractResult(); - } - - private: - TTask Task_; - }; - - /** - * Returns the task result when it finishes + * Starts task and returns its result when it completes */ auto operator co_await() && noexcept { - return TAwaiter(std::move(*this)); - } - - private: - NDetail::TTaskPromise<T>& Promise() const noexcept { - Y_VERIFY_DEBUG(Handle_); - return Handle_.promise(); + Y_VERIFY_DEBUG(Handle, "Cannot await an empty task"); + return NDetail::TTaskResultAwaiter<T>(std::exchange(Handle, {})); } private: - NDetail::TTaskHandle<T> Handle_; + NDetail::TTaskHandle<T> Handle; }; namespace NDetail { @@ -303,10 +222,6 @@ namespace NActors { return TTask<T>(TTaskHandle<T>::from_promise(*this)); } - inline TTask<void> TTaskPromise<void>::get_return_object() noexcept { - return TTask<void>(TTaskHandle<void>::from_promise(*this)); - } - } // namespace NDetail } // namespace NActors diff --git a/library/cpp/actors/cppcoro/task_actor.cpp b/library/cpp/actors/cppcoro/task_actor.cpp index 756d7e2f32..d55db4eb04 100644 --- a/library/cpp/actors/cppcoro/task_actor.cpp +++ b/library/cpp/actors/cppcoro/task_actor.cpp @@ -1,4 +1,5 @@ #include "task_actor.h" +#include "await_callback.h" #include <library/cpp/actors/core/actor.h> #include <library/cpp/actors/core/hfunc.h> @@ -25,20 +26,30 @@ namespace NActors { struct TEvResumeTask : public TEventLocal<TEvResumeTask, EvResumeTask> { std::coroutine_handle<> Handle; + TTaskResult<void>* Result; - explicit TEvResumeTask(std::coroutine_handle<> handle) noexcept + explicit TEvResumeTask(std::coroutine_handle<> handle, TTaskResult<void>* result) noexcept : Handle(handle) + , Result(result) {} ~TEvResumeTask() noexcept { - // TODO: actor may be dead already + if (Handle) { + Result->SetException(std::make_exception_ptr(TTaskCancelled())); + Handle.resume(); + } } }; + class TTaskActorResult final : public TAtomicRefCount<TTaskActorResult> { + public: + bool Finished = false; + }; + class TTaskActorImpl : public TActor<TTaskActorImpl> { friend class TTaskActor; - friend struct TAfterAwaiter; - friend struct TBindAwaiter; + friend class TAfterAwaiter; + friend class TBindAwaiter; public: TTaskActorImpl(TTask<void>&& task) @@ -48,6 +59,15 @@ namespace NActors { Y_VERIFY(Task); } + ~TTaskActorImpl() { + Stopped = true; + while (EventAwaiter) { + // Unblock event awaiter until task stops trying + TCurrentTaskActorGuard guard(this); + std::exchange(EventAwaiter, {}).resume(); + } + } + void Registered(TActorSystem* sys, const TActorId& parent) override { ParentId = parent; sys->Send(new IEventHandle(TEvents::TSystem::Bootstrap, 0, SelfId(), SelfId(), {}, 0)); @@ -57,7 +77,15 @@ namespace NActors { Y_VERIFY(ev->GetTypeRewrite() == TEvents::TSystem::Bootstrap, "Expected bootstrap event"); TCurrentTaskActorGuard guard(this); Become(&TThis::StateWork); - Task.Start(); + AwaitThenCallback(std::move(Task).WhenDone(), + [result = Result](TTaskResult<void>&& outcome) noexcept { + result->Finished = true; + try { + outcome.Value(); + } catch (TTaskCancelled&) { + // ignore + } + }); Check(); } @@ -66,46 +94,50 @@ namespace NActors { switch (ev->GetTypeRewrite()) { hFunc(TEvResumeTask, Handle); default: - Y_VERIFY(EventWaiter); + Y_VERIFY(EventAwaiter); Event.reset(ev.Release()); - std::exchange(EventWaiter, {}).resume(); + std::exchange(EventAwaiter, {}).resume(); } Check(); } void Handle(TEvResumeTask::TPtr& ev) { auto* msg = ev->Get(); + msg->Result->SetValue(); std::exchange(msg->Handle, {}).resume(); } bool Check() { - if (Task.Done()) { - Y_VERIFY(!EventWaiter, "Task terminated while waiting for the next event"); - Task.ExtractResult(); + if (Result->Finished) { + Y_VERIFY(!EventAwaiter, "Task terminated while waiting for the next event"); PassAway(); return false; } - Y_VERIFY(EventWaiter, "Task suspended without waiting for the next event"); - Event.reset(); + Y_VERIFY(EventAwaiter, "Task suspended without waiting for the next event"); return true; } void WaitForEvent(std::coroutine_handle<> h) noexcept { - Y_VERIFY(!EventWaiter, "Task cannot have multiple waiters for the next event"); - EventWaiter = h; + Y_VERIFY(!EventAwaiter, "Task cannot have multiple awaiters for the next event"); + EventAwaiter = h; } - std::unique_ptr<IEventHandle> FinishWaitForEvent() noexcept { + std::unique_ptr<IEventHandle> FinishWaitForEvent() { + if (Stopped) { + throw TTaskCancelled(); + } Y_VERIFY(Event, "Task does not have current event"); return std::move(Event); } private: + TIntrusivePtr<TTaskActorResult> Result = MakeIntrusive<TTaskActorResult>(); TTask<void> Task; TActorId ParentId; - std::coroutine_handle<> EventWaiter; + std::coroutine_handle<> EventAwaiter; std::unique_ptr<IEventHandle> Event; + bool Stopped = false; }; void TTaskActorNextEvent::await_suspend(std::coroutine_handle<> h) noexcept { @@ -113,7 +145,7 @@ namespace NActors { TlsCurrentTaskActor->WaitForEvent(h); } - std::unique_ptr<IEventHandle> TTaskActorNextEvent::await_resume() noexcept { + std::unique_ptr<IEventHandle> TTaskActorNextEvent::await_resume() { Y_VERIFY(TlsCurrentTaskActor, "Not in a task actor context"); return TlsCurrentTaskActor->FinishWaitForEvent(); } @@ -134,10 +166,7 @@ namespace NActors { void TAfterAwaiter::await_suspend(std::coroutine_handle<> h) noexcept { Y_VERIFY(TlsCurrentTaskActor, "Not in a task actor context"); - TlsCurrentTaskActor->Schedule(Duration, new TEvResumeTask(h)); - } - - void TAfterAwaiter::await_resume() { + TlsCurrentTaskActor->Schedule(Duration, new TEvResumeTask(h, &Result)); } bool TBindAwaiter::await_ready() noexcept { @@ -148,10 +177,7 @@ namespace NActors { } void TBindAwaiter::await_suspend(std::coroutine_handle<> h) noexcept { - Sys->Send(new IEventHandle(ActorId, ActorId, new TEvResumeTask(h))); - } - - void TBindAwaiter::await_resume() { + Sys->Send(new IEventHandle(ActorId, ActorId, new TEvResumeTask(h, &Result))); } } // namespace NActors diff --git a/library/cpp/actors/cppcoro/task_actor.h b/library/cpp/actors/cppcoro/task_actor.h index e4a1c9df3e..75d498a04e 100644 --- a/library/cpp/actors/cppcoro/task_actor.h +++ b/library/cpp/actors/cppcoro/task_actor.h @@ -8,28 +8,47 @@ namespace NActors { static void await_suspend(std::coroutine_handle<> h) noexcept; - static std::unique_ptr<IEventHandle> await_resume() noexcept; + static std::unique_ptr<IEventHandle> await_resume(); }; - struct TAfterAwaiter { - TDuration Duration; + class TAfterAwaiter { + public: + TAfterAwaiter(TDuration duration) + : Duration(duration) + {} static constexpr bool await_ready() noexcept { return false; } void await_suspend(std::coroutine_handle<> h) noexcept; - void await_resume(); + void await_resume() { + Result.Value(); + } + + private: + TDuration Duration; + TTaskResult<void> Result; }; - struct TBindAwaiter { - TActorSystem* Sys; - TActorId ActorId; + class TBindAwaiter { + public: + TBindAwaiter(TActorSystem* sys, const TActorId& actorId) + : Sys(sys) + , ActorId(actorId) + {} bool await_ready() noexcept; void await_suspend(std::coroutine_handle<> h) noexcept; - void await_resume(); + void await_resume() { + Result.Value(); + } + + private: + TActorSystem* Sys; + TActorId ActorId; + TTaskResult<void> Result; }; class TTaskActor { @@ -77,11 +96,10 @@ namespace NActors { */ template<class T> static TTask<T> Bind(TTask<T>&& task) { - // TODO: may run on non-actor thread, protect from unwind return [](TTask<T> task, TBindAwaiter bindTask) -> TTask<T> { - co_await task.WhenDone(); + auto result = co_await std::move(task).WhenDone(); co_await bindTask; - co_return task.ExtractResult(); + co_return std::move(result).Value(); }(std::move(task), Bind()); } }; diff --git a/library/cpp/actors/cppcoro/task_group.h b/library/cpp/actors/cppcoro/task_group.h index b57cf59529..b2496f57eb 100644 --- a/library/cpp/actors/cppcoro/task_group.h +++ b/library/cpp/actors/cppcoro/task_group.h @@ -1,10 +1,9 @@ #pragma once +#include "task_result.h" #include <util/generic/ptr.h> #include <util/system/compiler.h> #include <util/system/yassert.h> #include <coroutine> -#include <exception> -#include <variant> #include <atomic> #include <memory> @@ -13,50 +12,8 @@ namespace NActors { namespace NDetail { template<class T> - struct TTaskGroupResult final { + struct TTaskGroupResult final : public TTaskResult<T> { TTaskGroupResult* Next; - std::variant<std::exception_ptr, T> Result_; - - void SetException() { - Result_.template emplace<0>(std::current_exception()); - } - - template<class TResult> - void SetValue(TResult&& result) { - Result_.template emplace<1>(std::forward<TResult>(result)); - } - - T Extract() { - switch (Result_.index()) { - case 0: { - std::rethrow_exception(std::get<0>(std::move(Result_))); - } - case 1: { - return std::get<1>(std::move(Result_)); - } - } - std::terminate(); - } - }; - - template<> - struct TTaskGroupResult<void> final { - TTaskGroupResult* Next; - std::exception_ptr Exception_; - - void SetException() { - Exception_ = std::current_exception(); - } - - void SetValue() { - // nothing - } - - void Extract() { - if (Exception_) { - std::rethrow_exception(std::move(Exception_)); - } - } }; template<class T> @@ -70,6 +27,12 @@ namespace NActors { static constexpr uintptr_t MarkerAwaiting = 1; static constexpr uintptr_t MarkerDetached = 2; + ~TTaskGroupSink() noexcept { + if (!IsDetached()) { + Detach(); + } + } + std::coroutine_handle<> Push(std::unique_ptr<TTaskGroupResult<T>>&& result) noexcept { void* currentValue = LastReady.load(std::memory_order_acquire); for (;;) { @@ -78,8 +41,8 @@ namespace NActors { continue; } // We consume the awaiter - Y_VERIFY(ReadyQueue == nullptr, "TaskGroup is awaiting with non-empty ready queue"); - result->Next = nullptr; + Y_VERIFY_DEBUG(ReadyQueue == nullptr, "TaskGroup is awaiting with non-empty ready queue"); + result->Next = ReadyQueue; ReadyQueue = result.release(); return std::exchange(Continuation, {}); } @@ -103,7 +66,7 @@ namespace NActors { } Y_NO_INLINE std::coroutine_handle<> Suspend(std::coroutine_handle<> h) noexcept { - Y_VERIFY(ReadyQueue == nullptr, "Caller suspending with non-empty ready queue"); + Y_VERIFY_DEBUG(ReadyQueue == nullptr, "Caller suspending with non-empty ready queue"); Continuation = h; void* currentValue = LastReady.load(std::memory_order_acquire); for (;;) { @@ -151,6 +114,11 @@ namespace NActors { } } + bool IsDetached() const noexcept { + void* headValue = LastReady.load(std::memory_order_acquire); + return headValue == (void*)MarkerDetached; + } + void Detach() noexcept { // After this exchange all new results will be discarded void* headValue = LastReady.exchange((void*)MarkerDetached, std::memory_order_acq_rel); @@ -166,48 +134,38 @@ namespace NActors { }; template<class T> - class TTaskGroupPromiseBase { + class TTaskGroupResultHandler { public: - static auto initial_suspend() noexcept { return std::suspend_always{}; } - - class TFinalSuspend { - public: - TFinalSuspend(TTaskGroupPromiseBase& promise) - : Promise_(promise) - {} - - static bool await_ready() noexcept { return false; } - - Y_NO_INLINE std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept { - auto next = Promise_.Sink_->Push(std::move(Promise_.Result_)); - h.destroy(); - return next; - } - - static void await_resume() noexcept { std::terminate(); } + void unhandled_exception() noexcept { + Result->SetException(std::current_exception()); + } - private: - TTaskGroupPromiseBase& Promise_; - }; + template<class TResult> + void return_value(TResult&& result) { + Result->SetValue(std::forward<TResult>(result)); + } - auto final_suspend() noexcept { return TFinalSuspend(*this); } + protected: + std::unique_ptr<TTaskGroupResult<T>> Result = std::make_unique<TTaskGroupResult<T>>(); + }; + template<> + class TTaskGroupResultHandler<void> { + public: void unhandled_exception() noexcept { - Result_->SetException(); - Sink_->Push(std::move(Result_)); + Result->SetException(std::current_exception()); } - void SetSink(const TIntrusivePtr<TTaskGroupSink<T>>& sink) { - Sink_ = sink; + void return_void() noexcept { + Result->SetValue(); } protected: - std::unique_ptr<TTaskGroupResult<T>> Result_ = std::make_unique<TTaskGroupResult<T>>(); - TIntrusivePtr<TTaskGroupSink<T>> Sink_; + std::unique_ptr<TTaskGroupResult<void>> Result = std::make_unique<TTaskGroupResult<void>>(); }; template<class T> - class TTaskGroupPromise final : public TTaskGroupPromiseBase<T> { + class TTaskGroupPromise final : public TTaskGroupResultHandler<T> { public: using THandle = std::coroutine_handle<TTaskGroupPromise<T>>; @@ -215,24 +173,30 @@ namespace NActors { return THandle::from_promise(*this); } - template<class TResult> - void return_value(TResult&& result) { - this->Result_->SetValue(std::forward<TResult>(result)); - } - }; + static auto initial_suspend() noexcept { return std::suspend_always{}; } - template<> - class TTaskGroupPromise<void> final : public TTaskGroupPromiseBase<void> { - public: - using THandle = std::coroutine_handle<TTaskGroupPromise<void>>; + struct TFinalSuspend { + static bool await_ready() noexcept { return false; } + static void await_resume() noexcept { Y_FAIL("unexpected coroutine resume"); } - THandle get_return_object() noexcept { - return THandle::from_promise(*this); - } + Y_NO_INLINE + static std::coroutine_handle<> await_suspend(std::coroutine_handle<TTaskGroupPromise<T>> h) noexcept { + auto& promise = h.promise(); + auto sink = std::move(promise.Sink); + auto next = sink->Push(std::move(promise.Result)); + h.destroy(); + return next; + } + }; + + static auto final_suspend() noexcept { return TFinalSuspend{}; } - void return_void() { - this->Result_->SetValue(); + void SetSink(const TIntrusivePtr<TTaskGroupSink<T>>& sink) { + Sink = sink; } + + private: + TIntrusivePtr<TTaskGroupSink<T>> Sink; }; template<class T> @@ -244,16 +208,16 @@ namespace NActors { public: TTaskGroupTask(THandle handle) - : Handle_(handle) + : Handle(handle) {} void Start(const TIntrusivePtr<TTaskGroupSink<T>>& sink) { - Handle_.promise().SetSink(sink); - Handle_.resume(); + Handle.promise().SetSink(sink); + Handle.resume(); } private: - THandle Handle_; + THandle Handle; }; template<class T, class TAwaitable> @@ -273,6 +237,11 @@ namespace NActors { public: TTaskGroup() = default; + TTaskGroup(const TTaskGroup&) = delete; + TTaskGroup(TTaskGroup&&) = delete; + TTaskGroup& operator=(const TTaskGroup&) = delete; + TTaskGroup& operator=(TTaskGroup&&) = delete; + ~TTaskGroup() { Sink_->Detach(); } @@ -311,7 +280,7 @@ namespace NActors { } T await_resume() { - return TaskGroup_.Sink_->Resume()->Extract(); + return std::move(*TaskGroup_.Sink_->Resume()).Value(); } private: diff --git a/library/cpp/actors/cppcoro/task_result.cpp b/library/cpp/actors/cppcoro/task_result.cpp new file mode 100644 index 0000000000..bb1a1dc5ca --- /dev/null +++ b/library/cpp/actors/cppcoro/task_result.cpp @@ -0,0 +1 @@ +#include "task_result.h" diff --git a/library/cpp/actors/cppcoro/task_result.h b/library/cpp/actors/cppcoro/task_result.h new file mode 100644 index 0000000000..70176e64d5 --- /dev/null +++ b/library/cpp/actors/cppcoro/task_result.h @@ -0,0 +1,113 @@ +#pragma once +#include <util/system/yassert.h> +#include <exception> +#include <variant> + +namespace NActors { + + namespace NDetail { + + struct TVoid {}; + + template<class T> + struct TReplaceVoid { + using TType = T; + }; + + template<> + struct TReplaceVoid<void> { + using TType = TVoid; + }; + + template<class T> + struct TLValue { + using TType = T&; + }; + + template<> + struct TLValue<void> { + using TType = void; + }; + + template<class T> + struct TRValue { + using TType = T&&; + }; + + template<> + struct TRValue<void> { + using TType = void; + }; + + } // namespace NDetail + + /** + * Wrapper for the task result + */ + template<class T> + class TTaskResult { + public: + void SetValue() + requires (std::same_as<T, void>) + { + Result.template emplace<1>(); + } + + template<class TResult> + void SetValue(TResult&& result) + requires (!std::same_as<T, void>) + { + Result.template emplace<1>(std::forward<TResult>(result)); + } + + void SetException(std::exception_ptr&& e) noexcept { + Result.template emplace<2>(std::move(e)); + } + + typename NDetail::TLValue<T>::TType Value() & { + switch (Result.index()) { + case 0: { + Y_FAIL("Task result has no value"); + } + case 1: { + if constexpr (std::same_as<T, void>) { + return; + } else { + return std::get<1>(Result); + } + } + case 2: { + std::exception_ptr& e = std::get<2>(Result); + Y_VERIFY_DEBUG(e, "Task exception missing"); + std::rethrow_exception(e); + } + } + Y_FAIL("Task result has an invalid state"); + } + + typename NDetail::TRValue<T>::TType Value() && { + switch (Result.index()) { + case 0: { + Y_FAIL("Task result has no value"); + } + case 1: { + if constexpr (std::same_as<T, void>) { + return; + } else { + return std::get<1>(std::move(Result)); + } + } + case 2: { + std::exception_ptr& e = std::get<2>(Result); + Y_VERIFY_DEBUG(e, "Task exception missing"); + std::rethrow_exception(std::move(e)); + } + } + Y_FAIL("Task result has an invalid state"); + } + + private: + std::variant<std::monostate, typename NDetail::TReplaceVoid<T>::TType, std::exception_ptr> Result; + }; + +} // namespace NActors diff --git a/library/cpp/actors/cppcoro/task_ut.cpp b/library/cpp/actors/cppcoro/task_ut.cpp index 52c1b0e591..24ea9d0700 100644 --- a/library/cpp/actors/cppcoro/task_ut.cpp +++ b/library/cpp/actors/cppcoro/task_ut.cpp @@ -32,44 +32,26 @@ Y_UNIT_TEST_SUITE(Task) { UNIT_ASSERT_VALUES_EQUAL(*result, 42); } - Y_UNIT_TEST(DoneAndWhenDone) { - auto task = SimpleReturn42(); - UNIT_ASSERT(task); - UNIT_ASSERT(!task.Done()); - - bool whenDoneFinished = false; - AwaitThenCallback(task.WhenDone(), [&]() { - whenDoneFinished = true; - }); - UNIT_ASSERT(whenDoneFinished); - UNIT_ASSERT(task.Done()); - - // WhenDone can be used even when task is already done - whenDoneFinished = false; - AwaitThenCallback(task.WhenDone(), [&]() { - whenDoneFinished = true; - }); - UNIT_ASSERT(whenDoneFinished); - - std::optional<int> result; - AwaitThenCallback(std::move(task), [&](int value) { - result = value; + Y_UNIT_TEST(SimpleVoidWhenDone) { + std::optional<TTaskResult<void>> result; + AwaitThenCallback(SimpleReturnVoid().WhenDone(), [&](auto value) { + result = std::move(value); }); UNIT_ASSERT(result); - UNIT_ASSERT_VALUES_EQUAL(*result, 42); - UNIT_ASSERT(!task); + result->Value(); } - Y_UNIT_TEST(ManualStart) { - auto task = SimpleReturn42(); - UNIT_ASSERT(task && !task.Done()); - task.Start(); - UNIT_ASSERT(task.Done()); - UNIT_ASSERT_VALUES_EQUAL(task.ExtractResult(), 42); + Y_UNIT_TEST(SimpleIntWhenDone) { + std::optional<TTaskResult<int>> result; + AwaitThenCallback(SimpleReturn42().WhenDone(), [&](auto value) { + result = std::move(value); + }); + UNIT_ASSERT(result); + UNIT_ASSERT_VALUES_EQUAL(result->Value(), 42); } template<class TCallback> - TTask<int> CallTwice(TCallback&& callback) { + TTask<int> CallTwice(TCallback callback) { int a = co_await callback(); int b = co_await callback(); co_return a + b; @@ -79,6 +61,7 @@ Y_UNIT_TEST_SUITE(Task) { auto task = CallTwice([]{ return SimpleReturn42(); }); + UNIT_ASSERT(task); std::optional<int> result; AwaitThenCallback(std::move(task), [&](int value) { result = value; @@ -87,22 +70,37 @@ Y_UNIT_TEST_SUITE(Task) { UNIT_ASSERT_VALUES_EQUAL(*result, 84); } + template<class T> struct TPauseState { std::coroutine_handle<> Next; - int NextResult; + std::optional<T> NextResult; - auto Wait() { - struct TAwaiter { - TPauseState* State; + ~TPauseState() { + while (Next) { + NextResult.reset(); + std::exchange(Next, {}).resume(); + } + } - bool await_ready() const noexcept { return false; } - int await_resume() const noexcept { - return State->NextResult; - } - void await_suspend(std::coroutine_handle<> c) { - State->Next = c; + struct TAwaiter { + TPauseState* State; + + bool await_ready() const noexcept { return false; } + void await_suspend(std::coroutine_handle<> c) const noexcept { + State->Next = c; + } + T await_resume() const { + if (!State->NextResult) { + throw TTaskCancelled(); + } else { + T result = std::move(*State->NextResult); + State->NextResult.reset(); + return result; } - }; + } + }; + + auto Wait() { return TAwaiter{ this }; } @@ -110,19 +108,24 @@ Y_UNIT_TEST_SUITE(Task) { return bool(Next); } - void Resume(int result) { + void Resume(T result) { Y_VERIFY(Next && !Next.done()); NextResult = result; std::exchange(Next, {}).resume(); } + + void Cancel() { + Y_VERIFY(Next && !Next.done()); + NextResult.reset(); + std::exchange(Next, {}).resume(); + } }; - Y_UNIT_TEST(PausedAwait) { - TPauseState state; - auto callback = [&]{ + Y_UNIT_TEST(PauseResume) { + TPauseState<int> state; + auto task = CallTwice([&]{ return state.Wait(); - }; - auto task = CallTwice(callback); + }); std::optional<int> result; AwaitThenCallback(std::move(task), [&](int value) { result = value; @@ -137,63 +140,31 @@ Y_UNIT_TEST_SUITE(Task) { UNIT_ASSERT_VALUES_EQUAL(*result, 33); } - Y_UNIT_TEST(ManuallyStartThenWhenDone) { - TPauseState state; - auto next = [&]{ + Y_UNIT_TEST(PauseCancel) { + TPauseState<int> state; + auto task = CallTwice([&]{ return state.Wait(); - }; - - auto task = [](auto next) -> TTask<int> { - int value = co_await next(); - co_return value * 2; - }(next); - - UNIT_ASSERT(task && !task.Done()); - task.Start(); - UNIT_ASSERT(!task.Done() && state); - bool finished = false; - AwaitThenCallback(task.WhenDone(), [&]{ - finished = true; }); - UNIT_ASSERT(!finished && !task.Done()); - state.Resume(11); - UNIT_ASSERT(finished && task.Done()); - UNIT_ASSERT_VALUES_EQUAL(task.ExtractResult(), 22); - } - - Y_UNIT_TEST(ManuallyStartThenAwait) { - TPauseState state; - auto next = [&]{ - return state.Wait(); - }; - - auto task = [](auto next) -> TTask<int> { - int value = co_await next(); - co_return value * 2; - }(next); - - UNIT_ASSERT(task && !task.Done()); - task.Start(); - UNIT_ASSERT(!task.Done() && state); - - auto awaitTask = [](auto task) -> TTask<int> { - int value = co_await std::move(task); - co_return value * 3; - }(std::move(task)); - UNIT_ASSERT(awaitTask && !awaitTask.Done()); std::optional<int> result; - AwaitThenCallback(std::move(awaitTask), [&](int value) { - result = value; + AwaitThenCallback(std::move(task).WhenDone(), [&](TTaskResult<int>&& value) { + try { + result = value.Value(); + } catch (TTaskCancelled&) { + // nothing + } }); UNIT_ASSERT(!result); + UNIT_ASSERT(state); state.Resume(11); - UNIT_ASSERT(result); - UNIT_ASSERT_VALUES_EQUAL(*result, 66); + UNIT_ASSERT(!result); + UNIT_ASSERT(state); + state.Cancel(); + UNIT_ASSERT(!result); } Y_UNIT_TEST(GroupWithTwoSubTasks) { - TPauseState state1; - TPauseState state2; + TPauseState<int> state1; + TPauseState<int> state2; std::vector<int> results; auto task = [](auto& state1, auto& state2, auto& results) -> TTask<int> { @@ -227,8 +198,8 @@ Y_UNIT_TEST_SUITE(Task) { } Y_UNIT_TEST(GroupWithTwoSubTasksDetached) { - TPauseState state1; - TPauseState state2; + TPauseState<int> state1; + TPauseState<int> state2; std::vector<int> results; auto task = [](auto& state1, auto& state2, auto& results) -> TTask<int> { @@ -253,9 +224,40 @@ Y_UNIT_TEST_SUITE(Task) { UNIT_ASSERT_VALUES_EQUAL(results.at(0), 22); UNIT_ASSERT(result); UNIT_ASSERT_VALUES_EQUAL(*result, 22); + } - // We must resume the first state (otherwise memory leaks), but result is ignored + Y_UNIT_TEST(GroupWithTwoSubTasksOneCancelled) { + TPauseState<int> state1; + TPauseState<int> state2; + std::vector<int> results; + auto task = [](auto& state1, auto& state2, auto& results) -> TTask<void> { + TTaskGroup<int> group; + group.AddTask(state1.Wait()); + group.AddTask(state2.Wait()); + for (int i = 0; i < 2; ++i) { + try { + results.push_back(co_await group); + } catch (TTaskCancelled&) { + results.push_back(-1); + } + } + }(state1, state2, results); + + bool finished = false; + AwaitThenCallback(std::move(task), [&]() { + finished = true; + }); + + UNIT_ASSERT(state1); + UNIT_ASSERT(state2); + state2.Cancel(); + UNIT_ASSERT_VALUES_EQUAL(results.size(), 1u); + UNIT_ASSERT_VALUES_EQUAL(results.at(0), -1); + UNIT_ASSERT(!finished); state1.Resume(11); + UNIT_ASSERT_VALUES_EQUAL(results.size(), 2u); + UNIT_ASSERT_VALUES_EQUAL(results.at(1), 11); + UNIT_ASSERT(finished); } } // Y_UNIT_TEST_SUITE(Task) diff --git a/library/cpp/actors/cppcoro/ya.make b/library/cpp/actors/cppcoro/ya.make index 9890eccbee..c1d8c225aa 100644 --- a/library/cpp/actors/cppcoro/ya.make +++ b/library/cpp/actors/cppcoro/ya.make @@ -11,6 +11,8 @@ SRCS( task_actor.h task_group.cpp task_group.h + task_result.cpp + task_result.h task.cpp task.h ) @@ -18,5 +20,6 @@ SRCS( END() RECURSE_FOR_TESTS( + benchmark ut ) |