aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/actors/cppcoro/task_ut.cpp
blob: 24ea9d0700911de8ecd9a18225b27a9415b9d114 (plain) (tree)
































                                                            


                                                                          
                            
                        
     





                                                                        

                             
                                              







                                    
                          






                                                           
                     
                                     
                                    
 




                                                 
 












                                                                          
                 


                     





                                        
                               


                                             




                                             
      

                                  
                                
           












                                                           

                                  
                                
           
                                  




                                                                                     
                             
                           
                         


                             

                                       
                                































                                                                                 
                                






















                                                                                 
     
 


























                                                                                  
                          

                                                     

                            
#include "task.h"
#include "task_group.h"
#include "await_callback.h"
#include <library/cpp/testing/unittest/registar.h>

using namespace NActors;

Y_UNIT_TEST_SUITE(Task) {

    // Minimal void coroutine used by the tests below: its body is a single
    // co_return, so it produces no value.
    TTask<void> SimpleReturnVoid() {
        co_return;
    }

    // Minimal value-returning coroutine used by the tests below: it produces
    // the constant 42 through co_return.
    TTask<int> SimpleReturn42() {
        const int answer = 42;
        co_return answer;
    }

    // Awaits a void coroutine through AwaitThenCallback and checks that the
    // callback has already run by the time the call returns.
    Y_UNIT_TEST(SimpleVoidCoroutine) {
        bool callbackInvoked = false;
        AwaitThenCallback(SimpleReturnVoid(), [&callbackInvoked] {
            callbackInvoked = true;
        });
        UNIT_ASSERT(callbackInvoked);
    }

    // Awaits an int coroutine through AwaitThenCallback and checks that the
    // callback received the value 42.
    Y_UNIT_TEST(SimpleIntCoroutine) {
        std::optional<int> received;
        AwaitThenCallback(SimpleReturn42(), [&received](int v) {
            received = v;
        });
        UNIT_ASSERT(received);
        UNIT_ASSERT_VALUES_EQUAL(*received, 42);
    }

    // Awaits WhenDone() of a void coroutine and checks that a TTaskResult<void>
    // is delivered to the callback.
    Y_UNIT_TEST(SimpleVoidWhenDone) {
        std::optional<TTaskResult<void>> outcome;
        AwaitThenCallback(SimpleReturnVoid().WhenDone(), [&outcome](auto taskResult) {
            outcome = std::move(taskResult);
        });
        UNIT_ASSERT(outcome);
        // NOTE(review): Value() presumably throws if the task did not finish
        // normally, so calling it acts as the success check - confirm.
        outcome->Value();
    }

    // Awaits WhenDone() of an int coroutine and checks that the delivered
    // TTaskResult<int> holds 42.
    Y_UNIT_TEST(SimpleIntWhenDone) {
        std::optional<TTaskResult<int>> outcome;
        AwaitThenCallback(SimpleReturn42().WhenDone(), [&outcome](auto taskResult) {
            outcome = std::move(taskResult);
        });
        UNIT_ASSERT(outcome);
        UNIT_ASSERT_VALUES_EQUAL(outcome->Value(), 42);
    }

    // Coroutine that invokes `callback` twice, awaiting each returned awaitable
    // in turn, and produces the sum of the two int results.
    template<class TCallback>
    TTask<int> CallTwice(TCallback callback) {
        int sum = 0;
        for (int attempt = 0; attempt < 2; ++attempt) {
            const int value = co_await callback();
            sum += value;
        }
        co_return sum;
    }

    // Runs CallTwice over SimpleReturn42, i.e. a coroutine awaiting other
    // coroutines, and expects 42 + 42 to reach the callback.
    Y_UNIT_TEST(NestedAwait) {
        auto outer = CallTwice([] {
            return SimpleReturn42();
        });
        UNIT_ASSERT(outer);
        std::optional<int> sum;
        AwaitThenCallback(std::move(outer), [&sum](int v) {
            sum = v;
        });
        UNIT_ASSERT(sum);
        UNIT_ASSERT_VALUES_EQUAL(*sum, 84);
    }

    // Test helper modelling a manually driven suspension point that hands a
    // value of type T to the coroutine awaiting on it.
    //
    // A coroutine that awaits Wait() is always suspended and its handle is
    // stored in Next. The test then either calls Resume(value), which makes the
    // await produce that value, or Cancel(), which makes the await throw
    // TTaskCancelled. A coroutine still parked when the state is destroyed is
    // resumed without a value, so it observes TTaskCancelled as well.
    template<class T>
    struct TPauseState {
        // Handle of the coroutine currently parked in Wait(); null when idle.
        std::coroutine_handle<> Next;
        // Value handed to the next await_resume(); empty means "cancelled".
        std::optional<T> NextResult;

        TPauseState() = default;

        // Awaiters hold a raw pointer back to this object and Next must be
        // resumed exactly once, so copying or moving the state is never valid.
        TPauseState(const TPauseState&) = delete;
        TPauseState& operator=(const TPauseState&) = delete;

        ~TPauseState() {
            // Loop because the resumed coroutine may park itself here again.
            while (Next) {
                NextResult.reset();
                std::exchange(Next, {}).resume();
            }
        }

        struct TAwaiter {
            TPauseState* State;

            // Never ready: awaiting always suspends until the test resumes it.
            bool await_ready() const noexcept { return false; }
            void await_suspend(std::coroutine_handle<> c) const noexcept {
                State->Next = c;
            }
            // Consumes the pending value, or throws TTaskCancelled if none was set.
            T await_resume() const {
                if (!State->NextResult) {
                    throw TTaskCancelled();
                }
                T result = std::move(*State->NextResult);
                State->NextResult.reset();
                return result;
            }
        };

        // Returns an awaiter bound to this state.
        auto Wait() {
            return TAwaiter{ this };
        }

        // True while a coroutine is parked on this state.
        explicit operator bool() const {
            return bool(Next);
        }

        // Resumes the parked coroutine so that its await produces `result`.
        void Resume(T result) {
            Y_VERIFY(Next && !Next.done());
            // `result` is a by-value sink: move it instead of copying, which
            // also allows move-only T.
            NextResult = std::move(result);
            // Clear Next before resuming so the coroutine can park here again.
            std::exchange(Next, {}).resume();
        }

        // Resumes the parked coroutine with no value, making its await throw
        // TTaskCancelled.
        void Cancel() {
            Y_VERIFY(Next && !Next.done());
            NextResult.reset();
            std::exchange(Next, {}).resume();
        }
    };

    // Drives CallTwice over a TPauseState: the task must stay suspended until
    // both awaits have been resumed, then report the sum of the two values.
    Y_UNIT_TEST(PauseResume) {
        TPauseState<int> pause;
        auto pending = CallTwice([&pause] {
            return pause.Wait();
        });
        std::optional<int> total;
        AwaitThenCallback(std::move(pending), [&total](int v) {
            total = v;
        });

        // Parked on the first await.
        UNIT_ASSERT(!total);
        UNIT_ASSERT(pause);
        pause.Resume(11);

        // Parked on the second await.
        UNIT_ASSERT(!total);
        UNIT_ASSERT(pause);
        pause.Resume(22);

        // Both awaits satisfied: 11 + 22.
        UNIT_ASSERT(total);
        UNIT_ASSERT_VALUES_EQUAL(*total, 33);
    }

    // Drives CallTwice over a TPauseState, resumes the first await with a value
    // and cancels the second. The WhenDone() callback must then run, and
    // Value() on its result must throw TTaskCancelled instead of yielding a
    // number.
    Y_UNIT_TEST(PauseCancel) {
        TPauseState<int> state;
        auto task = CallTwice([&]{
            return state.Wait();
        });
        std::optional<int> result;
        // Track that the callback ran and which branch it took; checking only
        // `!result` would also pass if the callback were never invoked.
        bool finished = false;
        bool cancelled = false;
        AwaitThenCallback(std::move(task).WhenDone(), [&](TTaskResult<int>&& value) {
            finished = true;
            try {
                result = value.Value();
            } catch (const TTaskCancelled&) {
                cancelled = true;
            }
        });
        UNIT_ASSERT(!finished);
        UNIT_ASSERT(!result);
        UNIT_ASSERT(state);
        state.Resume(11);
        UNIT_ASSERT(!finished);
        UNIT_ASSERT(!result);
        UNIT_ASSERT(state);
        state.Cancel();
        UNIT_ASSERT(finished);
        UNIT_ASSERT(cancelled);
        UNIT_ASSERT(!result);
    }

    // Adds two paused subtasks to a TTaskGroup and awaits the group twice.
    // Results must come back in the order the subtasks are resumed (second
    // state first here), and the parent task returns their sum.
    Y_UNIT_TEST(GroupWithTwoSubTasks) {
        TPauseState<int> pauseA;
        TPauseState<int> pauseB;

        std::vector<int> arrivals;
        auto parent = [](auto& first, auto& second, auto& out) -> TTask<int> {
            TTaskGroup<int> group;
            group.AddTask(first.Wait());
            group.AddTask(second.Wait());
            const int x = co_await group;
            out.push_back(x);
            const int y = co_await group;
            out.push_back(y);
            co_return x + y;
        }(pauseA, pauseB, arrivals);

        std::optional<int> total;
        AwaitThenCallback(std::move(parent), [&total](int v) {
            total = v;
        });

        // Both subtasks must be parked on their states.
        UNIT_ASSERT(pauseA);
        UNIT_ASSERT(pauseB);

        // Resume the second subtask first: its value arrives first.
        pauseB.Resume(22);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.size(), 1u);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.at(0), 22);
        UNIT_ASSERT(!total);

        // Resuming the remaining subtask completes the parent.
        pauseA.Resume(11);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.size(), 2u);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.at(1), 11);
        UNIT_ASSERT(total);
        UNIT_ASSERT_VALUES_EQUAL(*total, 33);
    }

    // Adds two paused subtasks to a TTaskGroup but awaits the group only once:
    // the parent finishes with the first available result while the other
    // subtask has never been resumed by the test.
    Y_UNIT_TEST(GroupWithTwoSubTasksDetached) {
        TPauseState<int> pauseA;
        TPauseState<int> pauseB;

        std::vector<int> arrivals;
        auto parent = [](auto& first, auto& second, auto& out) -> TTask<int> {
            TTaskGroup<int> group;
            group.AddTask(first.Wait());
            group.AddTask(second.Wait());
            const int winner = co_await group;
            out.push_back(winner);
            co_return winner;
        }(pauseA, pauseB, arrivals);

        std::optional<int> outcome;
        AwaitThenCallback(std::move(parent), [&outcome](int v) {
            outcome = v;
        });

        // Both subtasks must be parked on their states.
        UNIT_ASSERT(pauseA);
        UNIT_ASSERT(pauseB);

        // A single result is enough for the parent to complete.
        pauseB.Resume(22);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.size(), 1u);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.at(0), 22);
        UNIT_ASSERT(outcome);
        UNIT_ASSERT_VALUES_EQUAL(*outcome, 22);
    }

    // Adds two paused subtasks to a TTaskGroup, cancels one and resumes the
    // other. The cancelled subtask must surface as TTaskCancelled from
    // `co_await group` (recorded as -1), and the parent must still collect the
    // value of the remaining subtask.
    Y_UNIT_TEST(GroupWithTwoSubTasksOneCancelled) {
        TPauseState<int> pauseA;
        TPauseState<int> pauseB;
        std::vector<int> arrivals;
        auto parent = [](auto& first, auto& second, auto& out) -> TTask<void> {
            TTaskGroup<int> group;
            group.AddTask(first.Wait());
            group.AddTask(second.Wait());
            for (int round = 0; round < 2; ++round) {
                try {
                    out.push_back(co_await group);
                } catch (TTaskCancelled&) {
                    // Marker for a cancelled subtask.
                    out.push_back(-1);
                }
            }
        }(pauseA, pauseB, arrivals);

        bool completed = false;
        AwaitThenCallback(std::move(parent), [&completed] {
            completed = true;
        });

        // Both subtasks must be parked on their states.
        UNIT_ASSERT(pauseA);
        UNIT_ASSERT(pauseB);

        // Cancelling the second subtask is reported first, as -1.
        pauseB.Cancel();
        UNIT_ASSERT_VALUES_EQUAL(arrivals.size(), 1u);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.at(0), -1);
        UNIT_ASSERT(!completed);

        // The remaining subtask delivers its value and the parent finishes.
        pauseA.Resume(11);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.size(), 2u);
        UNIT_ASSERT_VALUES_EQUAL(arrivals.at(1), 11);
        UNIT_ASSERT(completed);
    }

} // Y_UNIT_TEST_SUITE(Task)