#pragma once

#include <util/generic/noncopyable.h>

#include <atomic>
#include <cstddef>
#include <utility>

//////////////////////////////
// Lock-free LIFO (last-in, first-out) stack
// Lock-free stack of T built on compare-and-swap over an atomic head pointer.
//
// Enqueue()/EnqueueAll() may be called from any number of threads.
//
// Dequeue()/DequeueAll() may also be called from several threads at once. A popped
// node cannot simply be deleted: another Dequeue() may have loaded the same node
// from Head and still be about to read its Next field (and recycling the address
// would expose the ABA problem). So a popped node is deleted right away only when
// no other Dequeue()/DequeueAll() is in flight, as counted by DequeueCount.
// Otherwise it is parked on the FreePtr list and reclaimed later by
// TryToFreeMemory() or by the destructor.
//
// DequeueSingleConsumer()/DequeueAllSingleConsumer() delete nodes immediately.
// They are only safe when a single thread ever dequeues from the stack.
template <class T>
class TLockFreeStack: TNonCopyable {
    struct TNode {
        T Value;
        // Initialized here so that every constructor, including the defaulted
        // one, leaves Next in a defined state (before C++20 a default-constructed
        // std::atomic holds an indeterminate value).
        std::atomic<TNode*> Next{nullptr};

        TNode() = default;

        template <class U>
        explicit TNode(U&& val)
            : Value(std::forward<U>(val))
        {
        }
    };

    std::atomic<TNode*> Head = nullptr;    // top of the stack
    std::atomic<TNode*> FreePtr = nullptr; // popped nodes awaiting deletion
    std::atomic<size_t> DequeueCount = 0;  // number of Dequeue()/DequeueAll() calls in flight

    // Called only from inside Dequeue()/DequeueAll(), i.e. while the caller itself
    // is counted in DequeueCount. If the caller is the only dequeuer, no other
    // thread can still hold a pointer into the free list, so the list is
    // detached and deleted.
    void TryToFreeMemory() {
        TNode* current = FreePtr.load(std::memory_order_acquire);
        if (!current)
            return;
        if (DequeueCount.load() == 1) {
            // node current is in free list, we are the last thread so try to cleanup.
            // The CAS fails (and we skip cleanup) if another thread has pushed onto
            // the free list since we loaded it.
            if (FreePtr.compare_exchange_strong(current, nullptr))
                EraseList(current);
        }
    }

    // Deletes every node of the null-terminated list starting at p.
    void EraseList(TNode* p) {
        while (p) {
            TNode* next = p->Next;
            delete p;
            p = next;
        }
    }

    // Atomically prepends the private chain first -> ... -> last to the free list.
    void AddToFreeList(TNode* first, TNode* last) {
        for (TNode* freePtr = FreePtr.load(std::memory_order_acquire);;) {
            last->Next.store(freePtr, std::memory_order_release);
            if (FreePtr.compare_exchange_weak(freePtr, first))
                break;
        }
    }

    // Atomically pushes the private chain head -> ... -> tail on top of the stack.
    void EnqueueImpl(TNode* head, TNode* tail) {
        auto headValue = Head.load(std::memory_order_acquire);
        for (;;) {
            tail->Next.store(headValue, std::memory_order_release);
            // NB. See https://en.cppreference.com/w/cpp/atomic/atomic/compare_exchange
            // The weak forms (1-2) of the functions are allowed to fail spuriously, that is,
            // act as if *this != expected even if they are equal.
            // When a compare-and-exchange is in a loop, the weak version will yield better
            // performance on some platforms.
            // On failure headValue is refreshed with the current Head, so we retry.
            if (Head.compare_exchange_weak(headValue, head))
                break;
        }
    }

    // Allocates a single node holding u and pushes it.
    template <class U>
    void EnqueueImpl(U&& u) {
        TNode* node = new TNode(std::forward<U>(u));
        EnqueueImpl(node, node);
    }

public:
    TLockFreeStack() = default;

    // Frees all nodes still on the stack and on the free list.
    // Must not run concurrently with any other member call.
    ~TLockFreeStack() {
        EraseList(Head.load());
        EraseList(FreePtr.load());
    }

    // Pushes a copy of t.
    void Enqueue(const T& t) {
        EnqueueImpl(t);
    }

    // Pushes t by moving it into a new node.
    void Enqueue(T&& t) {
        EnqueueImpl(std::move(t));
    }

    // Pushes every element of data, as if Enqueue() were called for each one in
    // iteration order, but publishes them with a single atomic operation.
    template <typename TCollection>
    void EnqueueAll(const TCollection& data) {
        EnqueueAll(data.begin(), data.end());
    }

    // Pushes every element of [dataBegin, dataEnd), as if Enqueue() were called for
    // each one in order (the last element ends up on top). The whole range becomes
    // visible to other threads at once. If copying an element throws, nothing is
    // pushed and the nodes allocated so far are freed before the exception propagates.
    template <typename TIter>
    void EnqueueAll(TIter dataBegin, TIter dataEnd) {
        if (dataBegin == dataEnd) {
            return;
        }
        TIter i = dataBegin;
        TNode* node = new TNode(*i);
        TNode* tail = node; // first element: bottom of the chain being built

        try {
            for (++i; i != dataEnd; ++i) {
                TNode* nextNode = node;
                node = new TNode(*i);
                node->Next.store(nextNode, std::memory_order_release);
            }
        } catch (...) {
            // The chain node -> ... -> tail -> nullptr is still private to this
            // thread; release it instead of leaking it.
            EraseList(node);
            throw;
        }
        EnqueueImpl(node, tail);
    }

    // Pops the top element into *res. Returns false if the stack was empty.
    // Safe to call from multiple consumer threads concurrently.
    bool Dequeue(T* res) {
        ++DequeueCount;
        for (TNode* current = Head.load(std::memory_order_acquire); current;) {
            if (Head.compare_exchange_weak(current, current->Next.load(std::memory_order_acquire))) {
                // NOTE(review): if T's move assignment throws here, DequeueCount is
                // never decremented and current is leaked.
                *res = std::move(current->Value);
                // delete current; // ABA problem
                // even more complex node deletion
                TryToFreeMemory();
                if (--DequeueCount == 0) {
                    // no other Dequeue()s, can safely reclaim memory
                    delete current;
                } else {
                    // Dequeue()s in progress, put node to free list
                    AddToFreeList(current, current);
                }
                return true;
            }
        }
        TryToFreeMemory();
        --DequeueCount;
        return false;
    }

    // add all elements to *res
    // elements are returned in order of dequeue (top to bottom; see example in unittest)
    // Safe to call from multiple consumer threads concurrently.
    template <typename TCollection>
    void DequeueAll(TCollection* res) {
        ++DequeueCount;
        for (TNode* current = Head.load(std::memory_order_acquire); current;) {
            if (Head.compare_exchange_weak(current, nullptr)) {
                for (TNode* x = current; x;) {
                    res->push_back(std::move(x->Value));
                    x = x->Next;
                }
                // EraseList(current); // ABA problem
                // even more complex node deletion
                TryToFreeMemory();
                if (--DequeueCount == 0) {
                    // no other Dequeue()s, can safely reclaim memory
                    EraseList(current);
                } else {
                    // Dequeue()s in progress, add nodes list to free list
                    TNode* currentLast = current;
                    while (currentLast->Next) {
                        currentLast = currentLast->Next;
                    }
                    AddToFreeList(current, currentLast);
                }
                return;
            }
        }
        TryToFreeMemory();
        --DequeueCount;
    }

    // Pops the top element into *res. Returns false if the stack was empty.
    // Only valid when exactly one thread ever dequeues from this stack.
    bool DequeueSingleConsumer(T* res) {
        for (TNode* current = Head.load(std::memory_order_acquire); current;) {
            if (Head.compare_exchange_weak(current, current->Next)) {
                *res = std::move(current->Value);
                delete current; // with single consumer thread ABA does not happen
                return true;
            }
        }
        return false;
    }

    // add all elements to *res
    // elements are returned in order of dequeue (top to bottom; see example in unittest)
    // Only valid when exactly one thread ever dequeues from this stack.
    template <typename TCollection>
    void DequeueAllSingleConsumer(TCollection* res) {
        for (TNode* head = Head.load(std::memory_order_acquire); head;) {
            if (Head.compare_exchange_weak(head, nullptr)) {
                for (TNode* x = head; x;) {
                    res->push_back(std::move(x->Value));
                    x = x->Next;
                }
                EraseList(head); // with single consumer thread ABA does not happen
                return;
            }
        }
    }

    // Returns true if the stack had no elements at the moment of the check.
    bool IsEmpty() const {
        return Head.load() == nullptr; // without lock, so result is approximate
    }
};