#pragma once
#include "task_result.h"
#include <util/generic/ptr.h>
#include <util/system/compiler.h>
#include <util/system/yassert.h>
#include <coroutine>
#include <atomic>
#include <memory>

namespace NActors {

    namespace NDetail {

        template<class T>
        struct TTaskGroupResult final : public TTaskResult<T> {
            TTaskGroupResult* Next;
        };

        template<class T>
        struct TTaskGroupSink final
            : public TAtomicRefCount<TTaskGroupSink<T>>
        {
            std::atomic<void*> LastReady{ nullptr };
            TTaskGroupResult<T>* ReadyQueue = nullptr;
            std::coroutine_handle<> Continuation;

            static constexpr uintptr_t MarkerAwaiting = 1;
            static constexpr uintptr_t MarkerDetached = 2;

            ~TTaskGroupSink() noexcept {
                if (!IsDetached()) {
                    Detach();
                }
            }

            std::coroutine_handle<> Push(std::unique_ptr<TTaskGroupResult<T>>&& result) noexcept {
                void* currentValue = LastReady.load(std::memory_order_acquire);
                for (;;) {
                    if (currentValue == (void*)MarkerAwaiting) {
                        if (Y_UNLIKELY(!LastReady.compare_exchange_weak(currentValue, nullptr, std::memory_order_acquire))) {
                            continue;
                        }
                        // We consume the awaiter
                        Y_VERIFY_DEBUG(ReadyQueue == nullptr, "TaskGroup is awaiting with non-empty ready queue");
                        result->Next = ReadyQueue;
                        ReadyQueue = result.release();
                        return std::exchange(Continuation, {});
                    }
                    if (currentValue == (void*)MarkerDetached) {
                        // Task group is detached, discard the result
                        return std::noop_coroutine();
                    }
                    TTaskGroupResult<T>* current = reinterpret_cast<TTaskGroupResult<T>*>(currentValue);
                    result->Next = current;
                    void* nextValue = result.get();
                    if (Y_LIKELY(LastReady.compare_exchange_weak(currentValue, nextValue, std::memory_order_acq_rel))) {
                        // Result successfully added
                        result.release();
                        return std::noop_coroutine();
                    }
                }
            }

            bool Ready() const noexcept {
                return ReadyQueue != nullptr || LastReady.load(std::memory_order_acquire) != nullptr;
            }

            Y_NO_INLINE std::coroutine_handle<> Suspend(std::coroutine_handle<> h) noexcept {
                Y_VERIFY_DEBUG(ReadyQueue == nullptr, "Caller suspending with non-empty ready queue");
                Continuation = h;
                void* currentValue = LastReady.load(std::memory_order_acquire);
                for (;;) {
                    if (currentValue == nullptr) {
                        if (Y_UNLIKELY(!LastReady.compare_exchange_weak(currentValue, (void*)MarkerAwaiting, std::memory_order_release))) {
                            continue;
                        }
                        // Continuation may wake up on another thread
                        return std::noop_coroutine();
                    }
                    Y_VERIFY(currentValue != (void*)MarkerAwaiting, "TaskGroup is suspending with an awaiting marker");
                    Y_VERIFY(currentValue != (void*)MarkerDetached, "TaskGroup is suspending with a detached marker");
                    // Race: ready queue is not actually empty
                    Continuation = {};
                    return h;
                }
            }

            std::unique_ptr<TTaskGroupResult<T>> Resume() noexcept {
                std::unique_ptr<TTaskGroupResult<T>> result;
                if (ReadyQueue == nullptr) {
                    void* headValue = LastReady.exchange(nullptr, std::memory_order_acq_rel);
                    Y_VERIFY(headValue != (void*)MarkerAwaiting, "TaskGroup is resuming with an awaiting marker");
                    Y_VERIFY(headValue != (void*)MarkerDetached, "TaskGroup is resuming with a detached marker");
                    Y_VERIFY(headValue, "TaskGroup is resuming with an empty queue");
                    TTaskGroupResult<T>* head = reinterpret_cast<TTaskGroupResult<T>*>(headValue);
                    while (head) {
                        auto* next = std::exchange(head->Next, nullptr);
                        head->Next = ReadyQueue;
                        ReadyQueue = head;
                        head = next;
                    }
                }
                Y_VERIFY(ReadyQueue != nullptr);
                result.reset(ReadyQueue);
                ReadyQueue = std::exchange(result->Next, nullptr);
                return result;
            }

            static void Dispose(TTaskGroupResult<T>* head) noexcept {
                while (head) {
                    auto* next = std::exchange(head->Next, nullptr);
                    std::unique_ptr<TTaskGroupResult<T>> ptr(head);
                    head = next;
                }
            }

            bool IsDetached() const noexcept {
                void* headValue = LastReady.load(std::memory_order_acquire);
                return headValue == (void*)MarkerDetached;
            }

            void Detach() noexcept {
                // After this exchange all new results will be discarded
                void* headValue = LastReady.exchange((void*)MarkerDetached, std::memory_order_acq_rel);
                Y_VERIFY(headValue != (void*)MarkerAwaiting, "TaskGroup is detaching with an awaiting marker");
                Y_VERIFY(headValue != (void*)MarkerDetached, "TaskGroup is detaching with a detached marker");
                if (headValue) {
                    Dispose(reinterpret_cast<TTaskGroupResult<T>*>(headValue));
                }
                if (ReadyQueue) {
                    Dispose(std::exchange(ReadyQueue, nullptr));
                }
            }
        };

        template<class T>
        class TTaskGroupResultHandler {
        public:
            void unhandled_exception() noexcept {
                Result->SetException(std::current_exception());
            }

            template<class TResult>
            void return_value(TResult&& result) {
                Result->SetValue(std::forward<TResult>(result));
            }

        protected:
            std::unique_ptr<TTaskGroupResult<T>> Result = std::make_unique<TTaskGroupResult<T>>();
        };

        template<>
        class TTaskGroupResultHandler<void> {
        public:
            void unhandled_exception() noexcept {
                Result->SetException(std::current_exception());
            }

            void return_void() noexcept {
                Result->SetValue();
            }

        protected:
            std::unique_ptr<TTaskGroupResult<void>> Result = std::make_unique<TTaskGroupResult<void>>();
        };

        template<class T>
        class TTaskGroupPromise final : public TTaskGroupResultHandler<T> {
        public:
            using THandle = std::coroutine_handle<TTaskGroupPromise<T>>;

            THandle get_return_object() noexcept {
                return THandle::from_promise(*this);
            }

            static auto initial_suspend() noexcept { return std::suspend_always{}; }

            struct TFinalSuspend {
                static bool await_ready() noexcept { return false; }
                static void await_resume() noexcept { Y_FAIL("unexpected coroutine resume"); }

                Y_NO_INLINE
                static std::coroutine_handle<> await_suspend(std::coroutine_handle<TTaskGroupPromise<T>> h) noexcept {
                    auto& promise = h.promise();
                    auto sink = std::move(promise.Sink);
                    auto next = sink->Push(std::move(promise.Result));
                    h.destroy();
                    return next;
                }
            };

            static auto final_suspend() noexcept { return TFinalSuspend{}; }

            void SetSink(const TIntrusivePtr<TTaskGroupSink<T>>& sink) {
                Sink = sink;
            }

        private:
            TIntrusivePtr<TTaskGroupSink<T>> Sink;
        };

        template<class T>
        class TTaskGroupTask final {
        public:
            using THandle = std::coroutine_handle<TTaskGroupPromise<T>>;
            using promise_type = TTaskGroupPromise<T>;
            using value_type = T;

        public:
            TTaskGroupTask(THandle handle)
                : Handle(handle)
            {}

            void Start(const TIntrusivePtr<TTaskGroupSink<T>>& sink) {
                Handle.promise().SetSink(sink);
                Handle.resume();
            }

        private:
            THandle Handle;
        };

        template<class T, class TAwaitable>
        TTaskGroupTask<T> CreateTaskGroupTask(TAwaitable awaitable) {
            co_return co_await std::move(awaitable);
        }

    } // namespace NDetail

    /**
     * A task group allows starting multiple subtasks of the same result type
     * and awaiting them in a structured way. When task group is destroyed
     * all subtasks are detached in a thread-safe way.
     */
    template<class T>
    class TTaskGroup {
    public:
        TTaskGroup() = default;

        TTaskGroup(const TTaskGroup&) = delete;
        TTaskGroup(TTaskGroup&&) = delete;
        TTaskGroup& operator=(const TTaskGroup&) = delete;
        TTaskGroup& operator=(TTaskGroup&&) = delete;

        ~TTaskGroup() {
            Sink_->Detach();
        }

        /**
         * Add task to the group that will await the result of awaitable
         */
        template<class TAwaitable>
        void AddTask(TAwaitable&& awaitable) {
            auto task = NDetail::CreateTaskGroupTask<T>(std::forward<TAwaitable>(awaitable));
            task.Start(Sink_);
            ++TaskCount_;
        }

        /**
         * Returns the number of tasks left unawaited
         */
        size_t TaskCount() const {
            return TaskCount_;
        }

        class TAwaiter {
        public:
            explicit TAwaiter(TTaskGroup& taskGroup) noexcept
                : TaskGroup_(taskGroup)
            {}

            bool await_ready() const noexcept {
                Y_VERIFY(TaskGroup_.TaskCount_ > 0, "Not enough tasks to await");
                --TaskGroup_.TaskCount_;
                return TaskGroup_.Sink_->Ready();
            }

            std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept {
                return TaskGroup_.Sink_->Suspend(h);
            }

            T await_resume() {
                return std::move(*TaskGroup_.Sink_->Resume()).Value();
            }

        private:
            TTaskGroup& TaskGroup_;
        };

        /**
         * Await result of the next task in the task group
         */
        TAwaiter operator co_await() noexcept {
            return TAwaiter(*this);
        }

    private:
        TIntrusivePtr<NDetail::TTaskGroupSink<T>> Sink_ = MakeIntrusive<NDetail::TTaskGroupSink<T>>();
        size_t TaskCount_ = 0;
    };

} // namespace NActors