aboutsummaryrefslogblamecommitdiffstats
path: root/util/thread/lfstack.h
blob: effde7c706dc25409a0f23c3c45d1fac800590d2 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
            
 
                                     


                  

                              

                                    
                
                                 
 
                          
 
                          
                                         
                           
         
      

                                          
 
                            
                                                                 
                     
                                       
                                                                                     
                                                                  

                                   
                              
                   
                                  


                     
                                                              
                  





                                                                                             

                      
                             
                                                    
                                
 
       
                               
                       
                                  
     
 
                              
                       
     



                                  
                                   
                                              

                                             
                                                     
                                   
         
                            
                                    


                                      
                                                                  
                                
     
                          

                                                                                                     
                                                 

                                                  
                                          


                                                                     

                                                                                      





                                  
                       
                     
                               
                                                                                         
                                       

                                                                               
                                              
                                                        



                                                     
                                          


                                                                          
                                                 
                                               
                                                        
                     

                                                                                      





                                  
                       
     
                                        
                                                                               
                                                 






                                                                                         
                                   
                                                     

                                                                         
                                                        
                                
                                                                                   


                       
                    
                                                                                
     
#pragma once

#include <util/generic/noncopyable.h>

#include <atomic>
#include <cstddef>
#include <utility>

//////////////////////////////
// lock free lifo stack
template <class T>
class TLockFreeStack: TNonCopyable {
    struct TNode {
        T Value;
        std::atomic<TNode*> Next;

        TNode() = default;

        template <class U>
        explicit TNode(U&& val)
            : Value(std::forward<U>(val))
            , Next(nullptr)
        {
        }
    };

    std::atomic<TNode*> Head = nullptr;
    std::atomic<TNode*> FreePtr = nullptr;
    std::atomic<size_t> DequeueCount = 0;

    void TryToFreeMemory() {
        TNode* current = FreePtr.load(std::memory_order_acquire);
        if (!current)
            return;
        if (DequeueCount.load() == 1) {
            // node current is in free list, we are the last thread so try to cleanup
            if (FreePtr.compare_exchange_strong(current, nullptr))
                EraseList(current);
        }
    }
    void EraseList(TNode* p) {
        while (p) {
            TNode* next = p->Next;
            delete p;
            p = next;
        }
    }
    void EnqueueImpl(TNode* head, TNode* tail) {
        auto headValue = Head.load(std::memory_order_acquire);
        for (;;) {
            tail->Next.store(headValue, std::memory_order_release);
            // NB. See https://en.cppreference.com/w/cpp/atomic/atomic/compare_exchange
            // The weak forms (1-2) of the functions are allowed to fail spuriously, that is,
            // act as if *this != expected even if they are equal.
            // When a compare-and-exchange is in a loop, the weak version will yield better
            // performance on some platforms.
            if (Head.compare_exchange_weak(headValue, head))
                break;
        }
    }
    template <class U>
    void EnqueueImpl(U&& u) {
        TNode* node = new TNode(std::forward<U>(u));
        EnqueueImpl(node, node);
    }

public:
    TLockFreeStack() = default;

    ~TLockFreeStack() {
        EraseList(Head.load());
        EraseList(FreePtr.load());
    }

    void Enqueue(const T& t) {
        EnqueueImpl(t);
    }

    void Enqueue(T&& t) {
        EnqueueImpl(std::move(t));
    }

    template <typename TCollection>
    void EnqueueAll(const TCollection& data) {
        EnqueueAll(data.begin(), data.end());
    }
    template <typename TIter>
    void EnqueueAll(TIter dataBegin, TIter dataEnd) {
        if (dataBegin == dataEnd) {
            return;
        }
        TIter i = dataBegin;
        TNode* node = new TNode(*i);
        TNode* tail = node;

        for (++i; i != dataEnd; ++i) {
            TNode* nextNode = node;
            node = new TNode(*i);
            node->Next.store(nextNode, std::memory_order_release);
        }
        EnqueueImpl(node, tail);
    }
    bool Dequeue(T* res) {
        ++DequeueCount;
        for (TNode* current = Head.load(std::memory_order_acquire); current;) {
            if (Head.compare_exchange_weak(current, current->Next.load(std::memory_order_acquire))) {
                *res = std::move(current->Value);
                // delete current; // ABA problem
                // even more complex node deletion
                TryToFreeMemory();
                if (--DequeueCount == 0) {
                    // no other Dequeue()s, can safely reclaim memory
                    delete current;
                } else {
                    // Dequeue()s in progress, put node to free list
                    for (TNode* freePtr = FreePtr.load(std::memory_order_acquire);;) {
                        current->Next.store(freePtr, std::memory_order_release);
                        if (FreePtr.compare_exchange_weak(freePtr, current))
                            break;
                    }
                }
                return true;
            }
        }
        TryToFreeMemory();
        --DequeueCount;
        return false;
    }
    // add all elements to *res
    // elements are returned in order of dequeue (top to bottom; see example in unittest)
    template <typename TCollection>
    void DequeueAll(TCollection* res) {
        ++DequeueCount;
        for (TNode* current = Head.load(std::memory_order_acquire); current;) {
            if (Head.compare_exchange_weak(current, nullptr)) {
                for (TNode* x = current; x;) {
                    res->push_back(std::move(x->Value));
                    x = x->Next;
                }
                // EraseList(current); // ABA problem
                // even more complex node deletion
                TryToFreeMemory();
                if (--DequeueCount == 0) {
                    // no other Dequeue()s, can safely reclaim memory
                    EraseList(current);
                } else {
                    // Dequeue()s in progress, add nodes list to free list
                    TNode* currentLast = current;
                    while (currentLast->Next) {
                        currentLast = currentLast->Next;
                    }
                    for (TNode* freePtr = FreePtr.load(std::memory_order_acquire);;) {
                        currentLast->Next.store(freePtr, std::memory_order_release);
                        if (FreePtr.compare_exchange_weak(freePtr, current))
                            break;
                    }
                }
                return;
            }
        }
        TryToFreeMemory();
        --DequeueCount;
    }
    bool DequeueSingleConsumer(T* res) {
        for (TNode* current = Head.load(std::memory_order_acquire); current;) {
            if (Head.compare_exchange_weak(current, current->Next)) {
                *res = std::move(current->Value);
                delete current; // with single consumer thread ABA does not happen
                return true;
            }
        }
        return false;
    }
    // add all elements to *res
    // elements are returned in order of dequeue (top to bottom; see example in unittest)
    template <typename TCollection>
    void DequeueAllSingleConsumer(TCollection* res) {
        for (TNode* head = Head.load(std::memory_order_acquire); head;) {
            if (Head.compare_exchange_weak(head, nullptr)) {
                for (TNode* x = head; x;) {
                    res->push_back(std::move(x->Value));
                    x = x->Next;
                }
                EraseList(head); // with single consumer thread ABA does not happen
                return;
            }
        }
    }
    bool IsEmpty() {
        return Head.load() == nullptr; // without lock, so result is approximate
    }
};