aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/actors/cppcoro/await_callback.h
blob: fcb2eb78f9821f62166066794fcc26fce9332e67 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#pragma once

#include <concepts>
#include <coroutine>
#include <exception>
#include <utility>

namespace NActors {

    namespace NDetail {

        /**
         * Obtains the awaiter object for an awaitable.
         *
         * Tries, in order:
         *  1. a member `operator co_await()`;
         *  2. a non-member `operator co_await(awaitable)`;
         *  3. otherwise the awaitable itself is used as the awaiter.
         *
         * This approximates how `co_await` obtains an awaiter, without
         * `await_transform` support. The argument is forwarded with its original
         * value category. In the fallback case `decltype(auto)` makes the function
         * return a reference to the argument, so no copy is made.
         */
        template<class TAwaitable>
        decltype(auto) GetAwaiter(TAwaitable&& awaitable) {
            if constexpr (requires { std::forward<TAwaitable>(awaitable).operator co_await(); }) {
                return std::forward<TAwaitable>(awaitable).operator co_await();
            } else if constexpr (requires { operator co_await(std::forward<TAwaitable>(awaitable)); }) {
                return operator co_await(std::forward<TAwaitable>(awaitable));
            } else {
                // std::forward returns TAwaitable&&, so decltype(auto) yields a reference here.
                return std::forward<TAwaitable>(awaitable);
            }
        }

        /**
         * Type produced by awaiting a `TAwaitable&&` expression: the return type of
         * `await_resume()` on the awaiter selected by GetAwaiter.
         *
         * NOTE(review): await_resume() is looked up on the (possibly rvalue) result of
         * GetAwaiter. Awaiters with ref-qualified await_resume() overloads could resolve
         * differently than in an actual co_await. Confirm if such awaiters are used.
         */
        template<class TAwaitable>
        using TAwaitResult = decltype(GetAwaiter(std::declval<TAwaitable>()).await_resume());

        /**
         * Promise mix-in that forwards a coroutine's co_returned value to a callback.
         *
         * Stores only a reference to the callback; the referenced callback object
         * must outlive this object. return_value() is noexcept, so an exception
         * escaping the callback calls std::terminate.
         *
         * @tparam TCallback invocable type called with the co_returned value
         * @tparam TResult   result type of the coroutine (void is specialized below)
         */
        template<class TCallback, class TResult>
        class TCallbackResult {
        public:
            explicit TCallbackResult(TCallback& callback)
                : Callback(callback)
            {}

            /// Calls the callback with the co_returned value, preserving its value category.
            template<class TRealResult>
            void return_value(TRealResult&& result) noexcept {
                Callback(std::forward<TRealResult>(result));
            }

        private:
            TCallback& Callback;
        };

        /**
         * Specialization for coroutines whose result type is void. It provides
         * return_void() instead of return_value(), and the callback is called
         * with no arguments.
         */
        template<class TCallback>
        class TCallbackResult<TCallback, void> {
        public:
            explicit TCallbackResult(TCallback& callback)
                : Callback(callback)
            {}

            /// Calls the callback without arguments; terminates if it throws (noexcept).
            void return_void() noexcept {
                Callback();
            }

        private:
            TCallback& Callback;
        };

        /**
         * Promise type of AwaitThenCallback coroutines.
         *
         * Both suspension points are std::suspend_never. The coroutine body starts
         * executing as soon as the coroutine is called, and the frame is destroyed
         * automatically once the body completes.
         *
         * The constructor receives lvalues referring to the coroutine's own copies
         * of its arguments. The callback reference handed to the base therefore
         * points into the coroutine frame and stays valid until co_return.
         *
         * Any exception escaping the coroutine body terminates the program.
         */
        template<class TAwaitable, class TCallback>
        class TAwaitThenCallbackPromise
            : public TCallbackResult<TCallback, TAwaitResult<TAwaitable>>
        {
            using TResultBase = TCallbackResult<TCallback, TAwaitResult<TAwaitable>>;

        public:
            using THandle = std::coroutine_handle<TAwaitThenCallbackPromise>;

            TAwaitThenCallbackPromise(TAwaitable& /* awaitable */, TCallback& callback)
                : TResultBase(callback)
            {}

            THandle get_return_object() noexcept {
                return THandle::from_promise(*this);
            }

            // Start eagerly: no suspension before the body runs.
            static std::suspend_never initial_suspend() noexcept {
                return {};
            }

            // Self-destroy on completion: nobody resumes or destroys the handle later.
            static std::suspend_never final_suspend() noexcept {
                return {};
            }

            void unhandled_exception() noexcept {
                std::terminate();
            }
        };

        template<class TAwaitable, class TCallback>
        class TAwaitThenCallback {
        public:
            using promise_type = TAwaitThenCallbackPromise<TAwaitable, TCallback>;

            using THandle = typename promise_type::THandle;

            TAwaitThenCallback(THandle) noexcept {}
        };

    } // namespace NDetail

    /**
     * Awaits the awaitable and calls callback with the result.
     *
     * Note: program terminates if awaitable or callback throw an exception.
     */
    template<class TAwaitable, class TCallback>
    NDetail::TAwaitThenCallback<TAwaitable, TCallback> AwaitThenCallback(TAwaitable awaitable, TCallback) {
        // Note: underlying promise takes callback argument address and calls it when we return
        co_return co_await std::move(awaitable);
    }

} // namespace NActors