aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/threading/future/core/coroutine_traits.h
blob: cdd3eeeff3e7d1e63231cf69cb336aeba9e04d60 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#pragma once

#include <library/cpp/threading/future/future.h>

#include <coroutine>

// Coroutine traits that allow a function returning NThreading::TFuture<void>
// to be implemented as a coroutine.
//
// The promise object owns a TPromise<void>. The future given to the caller is
// taken from that promise, which is completed either by return_void() or by
// unhandled_exception().
template <typename... Args>
struct std::coroutine_traits<NThreading::TFuture<void>, Args...> {
    struct promise_type {
        // No suspension at the start of the coroutine body.
        std::suspend_never initial_suspend() {
            return std::suspend_never{};
        }

        // No suspension after the coroutine body has finished.
        std::suspend_never final_suspend() noexcept {
            return std::suspend_never{};
        }

        // Produces the object returned to the caller of the coroutine.
        NThreading::TFuture<void> get_return_object() {
            return Promise_.GetFuture();
        }

        // Called on `co_return;` (or on reaching the end of the body): marks the promise as set.
        void return_void() {
            Promise_.SetValue();
        }

        // Called when an exception escapes the coroutine body: hands it over to the promise.
        void unhandled_exception() {
            Promise_.SetException(std::current_exception());
        }

    private:
        // Created together with the coroutine promise object.
        NThreading::TPromise<void> Promise_ = NThreading::NewPromise();
    };
};

// Coroutine traits that allow a function returning NThreading::TFuture<T>
// to be implemented as a coroutine.
//
// The promise object owns a TPromise<T>. The future given to the caller is
// taken from that promise, which is completed either by return_value() or by
// unhandled_exception().
template <typename T, typename... Args>
struct std::coroutine_traits<NThreading::TFuture<T>, Args...> {
    struct promise_type {
        // No suspension at the start of the coroutine body.
        std::suspend_never initial_suspend() {
            return std::suspend_never{};
        }

        // No suspension after the coroutine body has finished.
        std::suspend_never final_suspend() noexcept {
            return std::suspend_never{};
        }

        // Produces the object returned to the caller of the coroutine.
        NThreading::TFuture<T> get_return_object() {
            return Promise_.GetFuture();
        }

        // Called on `co_return expr;`: perfectly forwards the operand into the promise.
        void return_value(auto&& val) {
            using TValue = decltype(val);
            Promise_.SetValue(std::forward<TValue>(val));
        }

        // Called when an exception escapes the coroutine body: hands it over to the promise.
        void unhandled_exception() {
            Promise_.SetException(std::current_exception());
        }

    private:
        // Created together with the coroutine promise object.
        NThreading::TPromise<T> Promise_ = NThreading::NewPromise<T>();
    };
};

namespace NThreading {

    /*
    * Awaitable wrapper around TFuture<T> implementing the await_ready / await_suspend / await_resume protocol.
    *
    * T          - value type of the wrapped future.
    * Extracting - selects how await_resume() obtains the result: when true (and T is not void)
    *              it calls Future.ExtractValue(), otherwise it calls Future.GetValue().
    *
    * The constructors are intentionally non-explicit: the co_await operators and AsAwaitable()
    * rely on class template argument deduction from them.
    */
    template <typename T, bool Extracting = false>
    struct TFutureAwaitable {
        // The awaited future, held by value (copied or moved in by the constructors below).
        NThreading::TFuture<T> Future;

        // Copies the future. Only available for the non-extracting flavour (see the requires-clause).
        TFutureAwaitable(const NThreading::TFuture<T>& future) noexcept requires (!Extracting)
            : Future{future}
        {
        }

        // Takes the future over from an rvalue. Available for both flavours.
        TFutureAwaitable(NThreading::TFuture<T>&& future) noexcept
            : Future{std::move(future)}
        {
        }

        // The awaiting coroutine is not suspended when this returns true.
        bool await_ready() const noexcept {
            return Future.IsReady();
        }

        // h is the handle of the awaiting coroutine; invoking it (h()) resumes that coroutine.
        // NOTE(review): presumably NoexceptSubscribe runs the callback once the future becomes ready - confirm in future.h.
        void await_suspend(auto h) noexcept {
            /*
            * This library assumes that resume never throws an exception.
            * This assumption is made due to the fact that the users of this library in most cases do not need to write their own coroutine handlers,
            * and all coroutine handlers provided by the library do not throw exceptions from resume.
            *
            * WARNING: do not change subscribe to apply or something other here, creating an extra future state degrades performance.
            */
            Future.NoexceptSubscribe(
                [h](auto) mutable noexcept {
                    h();
                }
            );
        }

        // Result of the co_await expression. decltype(auto) keeps exactly the type (including
        // reference-ness, or void) returned by the selected TFuture accessor.
        decltype(auto) await_resume() {
            if constexpr (Extracting && !std::is_same_v<T, void>) {  // TFuture<void> has only GetValue()
                return Future.ExtractValue();
            } else {
                return Future.GetValue();
            }
        }
    };

    // Shorthand for the extracting flavour of TFutureAwaitable (Extracting == true):
    // constructible only from an rvalue future; await_resume() calls ExtractValue() for non-void T.
    template <typename T>
    using TExtractingFutureAwaitable = TFutureAwaitable<T, true>;

} // namespace NThreading

// Enables `co_await future` for an lvalue TFuture<T>.
// The future is copied into a non-extracting awaitable, so the caller's future stays usable.
template <typename T>
auto operator co_await(const NThreading::TFuture<T>& future) noexcept {
    using TAwaitable = NThreading::TFutureAwaitable<T>;
    return TAwaitable{future};
}

// Enables `co_await` on a temporary (rvalue) TFuture<T>.
template <typename T>
auto operator co_await(NThreading::TFuture<T>&& future) noexcept {
    // The non-extracting awaitable is used on purpose, NOT TExtractingFutureAwaitable:
    // TFuture works like std::shared_future, so the temporary may be just one of several copies.
    // Example:
    //     auto value = co_await GetCachedFuture();
    // If GetCachedFuture keeps a future in some cache and hands out copies of it,
    // extracting here would make subsequent co_awaits return a moved-from value.
    using TAwaitable = NThreading::TFutureAwaitable<T>;
    return TAwaitable{std::move(future)};
}

namespace NThreading {

    // Wraps a copy of fut into the non-extracting awaitable (TFutureAwaitable<T>),
    // for call sites that want to name the awaitable explicitly.
    template <typename T>
    auto AsAwaitable(const NThreading::TFuture<T>& fut) noexcept {
        return TFutureAwaitable<T>(fut);
    }

    // Moves fut into the extracting awaitable: for non-void T, co_await on the result
    // yields Future.ExtractValue() instead of Future.GetValue().
    // Accepts rvalues only, so the caller has to give the future up explicitly.
    template <typename T>
    auto AsExtractingAwaitable(NThreading::TFuture<T>&& fut) noexcept {
        return TFutureAwaitable<T, true>(std::move(fut));
    }

} // namespace NThreading