aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/threading/future/ut_gtest/coroutine_traits_ut.cpp
blob: 2daa1f5e47accbfba77c63616c2a0cbe2ad60530 (plain) (tree)



























































































































































































                                                                                                                                  
















                                                                                                         
#include <library/cpp/testing/gtest/gtest.h>
#include <library/cpp/testing/gtest_extensions/gtest_extensions.h>
#include <library/cpp/threading/future/future.h>
#include <library/cpp/threading/future/core/coroutine_traits.h>

#include <util/generic/vector.h>
#include <util/generic/scope.h>
#include <util/system/platform.h>


// Exercises co_await on NThreading::TFuture from inside TFuture-returning coroutines and
// checks, through a textual trace, the exact order in which the coroutine bodies run
// before and after a promise controlled by the test is fulfilled.
TEST(TestFutureTraits, Simple) {
    TVector<TString> trace;

    // Has no await points of its own: logs one entry and produces 1.
    auto first = [&trace]() -> NThreading::TFuture<size_t> {
        trace.push_back("coroutine1");
        co_return 1;
    };

    // Awaits a future that the test completes by hand further below.
    auto gate = NThreading::NewPromise<size_t>();
    auto second = [&trace, gateFuture = gate.GetFuture()]() -> NThreading::TFuture<size_t> {
        trace.push_back("coroutine2");

        trace.push_back("pre_coroutine2_suspend_future");
        const size_t gateValue = co_await gateFuture;
        trace.push_back("post_coroutine2_suspend_future");

        co_return 2 + gateValue;
    };

    // Awaits both coroutines above in sequence and sums their results.
    auto outer = [&]() -> NThreading::TFuture<size_t> {
        // Logged when the coroutine body's scope is left.
        Y_DEFER {
            trace.push_back("coroutine_all_destroy");
        };

        trace.push_back("pre_coroutine1");
        const size_t firstValue = co_await first();
        trace.push_back("post_coroutine1");

        trace.push_back("pre_coroutine2");
        const size_t secondValue = co_await second();
        trace.push_back("post_coroutine2");

        co_return firstValue + secondValue;
    };

    // Phase 1: the gate is still closed, so the outer future must not be ready and the
    // trace must stop right before the awaited gate future.
    auto outerFuture = outer();
    EXPECT_FALSE(outerFuture.HasValue());
    EXPECT_FALSE(outerFuture.HasException());

    const TVector<TString> expectedWhileSuspended{
        "pre_coroutine1",
        "coroutine1",
        "post_coroutine1",

        "pre_coroutine2",
        "coroutine2",
        "pre_coroutine2_suspend_future"
    };
    EXPECT_THAT(trace, ::testing::ContainerEq(expectedWhileSuspended));

    // Phase 2: opening the gate must drive everything to completion: 1 + (2 + 3) == 6.
    gate.SetValue(3u);
    EXPECT_TRUE(outerFuture.HasValue());
    EXPECT_EQ(outerFuture.GetValue(), 6u);

    const TVector<TString> expectedAfterResume{
        "pre_coroutine1",
        "coroutine1",
        "post_coroutine1",

        "pre_coroutine2",
        "coroutine2",
        "pre_coroutine2_suspend_future",
        "post_coroutine2_suspend_future",
        "post_coroutine2",

        "coroutine_all_destroy"
    };
    EXPECT_THAT(trace, ::testing::ContainerEq(expectedAfterResume));
}

// Verifies that an exception thrown by an awaited TFuture-returning coroutine surfaces
// from the awaiting coroutine's future, that the statements after the failing co_await
// are skipped, and that the awaiting coroutine's deferred cleanup still runs.
TEST(TestFutureTraits, Exception) {
    TVector<TString> result;

    // Has no await points of its own: logs one entry and produces 1.
    auto coroutine1 = [&result]() -> NThreading::TFuture<size_t> {
        result.push_back("coroutine1");
        co_return 1;
    };

    auto coroutine2 = [&result]() -> NThreading::TFuture<size_t> {
        result.push_back("coroutine2");
        ythrow yexception() << "coroutine2 exception";
        // Never reached. A function is a coroutine only if its body contains
        // co_await / co_return / co_yield; without this statement the lambda would be an
        // ordinary function that throws straight into the caller, and the test would
        // never cover an exception carried by the awaited future itself.
        co_return 2;
    };

    auto coroutineAll = [&]() -> NThreading::TFuture<size_t> {
        // Logged when the coroutine body's scope is left, including by an exception.
        Y_DEFER {
            result.push_back("coroutine_all_destroy");
        };

        result.push_back("pre_coroutine1");
        size_t coroutine1Res = co_await coroutine1();
        result.push_back("post_coroutine1");

        result.push_back("pre_coroutine2");
        // Expected to throw here, so "post_coroutine2" must never be logged.
        size_t coroutine2Res = co_await coroutine2();
        result.push_back("post_coroutine2");

        co_return coroutine1Res + coroutine2Res;
    };

    // The exception must come out of the outer future when its value is requested.
    EXPECT_THROW_MESSAGE_HAS_SUBSTR(
        coroutineAll().GetValueSync(),
        yexception,
        "coroutine2 exception"
    );
    EXPECT_THAT(
        result,
        ::testing::ContainerEq(
            TVector<TString>({
                "pre_coroutine1",
                "coroutine1",
                "post_coroutine1",

                "pre_coroutine2",
                "coroutine2",

                "coroutine_all_destroy"
            })
        )
    );
}

// Death test: a coroutine whose promise rethrows from unhandled_exception() awaits a
// TFuture; once the future is fulfilled and the coroutine body throws, the process is
// expected to die instead of letting the exception travel through the future machinery.
TEST(TestFutureTraits, CrashOnExceptionInCoroutineHandlerResume) {
    // The pattern is selected outside of EXPECT_DEATH on purpose: preprocessor
    // directives inside the argument list of a function-like macro are undefined
    // behavior, so the #if must not be placed between the macro's parentheses.
    // NOTE(review): Windows accepts any output; presumably the exception text is not
    // reliably present in the captured output there - confirm.
#if defined(_win_)
    const char* const expectedDeathOutput = ".*";
#else
    const char* const expectedDeathOutput = "bad coroutine with future await exception";
#endif

    EXPECT_DEATH(
        {
            struct TBadPromise;
            // Minimal coroutine return type: just a handle to the promise below.
            struct TBadCoroutine : std::coroutine_handle<TBadPromise> {
                using promise_type = TBadPromise;
            };

            // Promise that never suspends and lets body exceptions escape.
            struct TBadPromise {
                TBadCoroutine get_return_object() {
                    return {TBadCoroutine::from_promise(*this)};
                }

                std::suspend_never initial_suspend() noexcept {
                    return {};
                }
                std::suspend_never final_suspend() noexcept {
                    return {};
                }
                void return_void() {
                }
                void unhandled_exception() {
                    // Rethrow into whoever started or resumed the coroutine.
                    throw;
                }
            };

            // NOTE(review): this lambda contains no co_await / co_return / co_yield, so
            // it is an ordinary function and its exception reaches the caller without
            // passing through TBadPromise::unhandled_exception() - confirm whether the
            // sanity check below was meant to exercise a real coroutine.
            auto badCoroutine = []() -> TBadCoroutine {
                ythrow yexception() << "bad coroutine exception";
            };
            // Sanity check
            EXPECT_THROW_MESSAGE_HAS_SUBSTR(
                badCoroutine(),
                yexception,
                "bad coroutine exception"
            );

            NThreading::TPromise<void> promise = NThreading::NewPromise<void>();
            // A real coroutine: it awaits the future first and throws afterwards.
            auto badCoroutineWithFutureAwait = [future = promise.GetFuture()]() -> TBadCoroutine {
                co_await future;
                ythrow yexception() << "bad coroutine with future await exception";
            };

            badCoroutineWithFutureAwait();
            // Fulfilling the promise lets the coroutine continue to the throw; this is
            // the point where the process is expected to die.
            promise.SetValue();
        },
        expectedDeathOutput
    );
}

// Passes a THolder<size_t> through a coroutine that awaits its input future via
// NThreading::AsExtractingAwaitable and checks that the held value arrives intact.
// NOTE(review): presumably the extracting awaitable moves the value out of the future
// instead of copying it, which a move-only payload requires - confirm in coroutine_traits.h.
TEST(ExtractingFutureAwaitable, Simple) {
    using THolderFuture = NThreading::TFuture<THolder<size_t>>;

    auto source = NThreading::NewPromise<THolder<size_t>>();

    // Awaits the input future and hands its payload on as the coroutine's own result.
    auto forward = [](THolderFuture input) -> THolderFuture {
        auto extracted = co_await NThreading::AsExtractingAwaitable(std::move(input));
        co_return extracted;
    };

    // Nothing has been set yet, so the forwarded future must still be pending.
    THolderFuture forwarded = forward(source.GetFuture());
    EXPECT_FALSE(forwarded.HasValue());
    EXPECT_FALSE(forwarded.HasException());

    source.SetValue(MakeHolder<size_t>(42));
    EXPECT_TRUE(forwarded.HasValue());

    // The payload must be non-null and carry the original number.
    auto payload = forwarded.ExtractValue();
    ASSERT_NE(payload, nullptr);
    EXPECT_EQ(*payload, 42u);
}