aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/threading/skip_list/skiplist.h
blob: f60c7cc97c36e44e09ef9cdd0bf2e39faf885420 (plain) (tree)
1
2
3
4
5
6
7
8
9







                                     
                                                 
                      
                                                                                    
 



                                 
 

                                 
 

                      
 
                                                                                    
 
                        
            
                    
           
                      

         
                                
         


                                 
         
                                 
         
                      

         




                                                                                    
 












                                                                             
                                         






















                                                      
           


                                  
 




                               
 



                                                               
 


                                             
             
 



                                                          
 
                                                         
             





                                                                                
 

                               
 

                                       
 



                                        
 

                              
 

                         
 
                               
 
                                  
                                                                                             
 





                                                                                 
 

                        
 


                                            
         
                              


                                                                
             
                                              
                                                 
         













                                                                                 
 



                                                        
 

                                                
 


                                                                  
 

                                                                  
         





                                                                                           
         
 





                                                 
             
         







                                                                      

             




                                                                     
 
                                                    

                                                                  
                                                        
         
 

                                    
 

                                               
 




                                                    
 



                                

             


                                                                                           
 


                                                                       
 







                                                                           
                 


                                                      
 




                                

             


                                                                  
 










                                                                           
                 



                                

                                
         
 



                                                                                                               
                    
                                                 
             
                        
         
                                    







                                                              
             
                                                                 
                                     
 



                                                                     
         
      
 
 
#pragma once

#include "compare.h"

#include <util/generic/algorithm.h>
#include <util/generic/noncopyable.h>
#include <util/generic/typetraits.h>
#include <util/memory/pool.h>
#include <util/random/random.h>
#include <library/cpp/deprecated/atomic/atomic.h>

namespace NThreading {
    ////////////////////////////////////////////////////////////////////////////////

    class TNopCounter {
    protected:
        template <typename T>
        void OnInsert(const T&) {
        }

        template <typename T>
        void OnUpdate(const T&) {
        }

        void Reset() {
        }
    };

    ////////////////////////////////////////////////////////////////////////////////

    class TSizeCounter {
    private:
        size_t Size;

    public:
        TSizeCounter()
            : Size(0)
        {
        }

        size_t GetSize() const {
            return Size;
        }

    protected:
        template <typename T>
        void OnInsert(const T&) {
            ++Size;
        }

        template <typename T>
        void OnUpdate(const T&) {
        }

        void Reset() {
            Size = 0;
        }
    };

    ////////////////////////////////////////////////////////////////////////////////
    // Append-only concurrent skip-list
    //
    // Readers do not require any synchronization.
    // Writers should be externally synchronized.
    // Nodes will be allocated using TMemoryPool instance.

    template <
        typename T,
        typename TComparer = TCompare<T>,
        typename TAllocator = TMemoryPool,
        typename TCounter = TSizeCounter,
        int MaxHeight = 12,
        int Branching = 4>
    class TSkipList: public TCounter, private TNonCopyable {
        class TNode {
        private:
            T Value;       // should be immutable after insert
            TNode* Next[]; // variable-size array maximum of MaxHeight values

        public:
            TNode(T&& value)
                : Value(std::move(value))
            {
                Y_UNUSED(Next);
            }

            const T& GetValue() const {
                return Value;
            }

            T& GetValue() {
                return Value;
            }

            TNode* GetNext(int height) const {
                return AtomicGet(Next[height]);
            }

            void Link(int height, TNode** prev) {
                for (int i = 0; i < height; ++i) {
                    Next[i] = prev[i]->Next[i];
                    AtomicSet(prev[i]->Next[i], this);
                }
            }
        };

    public:
        class TIterator {
        private:
            const TSkipList* List;
            const TNode* Node;

        public:
            TIterator()
                : List(nullptr)
                , Node(nullptr)
            {
            }

            TIterator(const TSkipList* list, const TNode* node)
                : List(list)
                , Node(node)
            {
            }

            TIterator(const TIterator& other)
                : List(other.List)
                , Node(other.Node)
            {
            }

            TIterator& operator=(const TIterator& other) {
                List = other.List;
                Node = other.Node;
                return *this;
            }

            void Next() {
                Node = Node ? Node->GetNext(0) : nullptr;
            }

            // much less efficient than Next as our list is single-linked
            void Prev() {
                if (Node) {
                    TNode* node = List->FindLessThan(Node->GetValue(), nullptr);
                    Node = (node != List->Head ? node : nullptr);
                }
            }

            void Reset() {
                Node = nullptr;
            }

            bool IsValid() const {
                return Node != nullptr;
            }

            const T& GetValue() const {
                Y_ASSERT(IsValid());
                return Node->GetValue();
            }
        };

    private:
        TAllocator& Allocator;
        TComparer Comparer;

        TNode* Head;
        TAtomic Height;
        TCounter Counter;

        TNode* Prev[MaxHeight];

        template <typename TValue>
        using TComparerReturnType = std::invoke_result_t<TComparer, const T&, const TValue&>;

    public:
        TSkipList(TAllocator& allocator, const TComparer& comparer = TComparer())
            : Allocator(allocator)
            , Comparer(comparer)
        {
            Init();
        }

        ~TSkipList() {
            CallDtors();
        }

        void Clear() {
            CallDtors();
            Allocator.ClearKeepFirstChunk();
            Init();
        }

        bool Insert(T value) {
            TNode* node = PrepareInsert(value);
            if (Y_UNLIKELY(node && Compare(node, value) == 0)) {
                // we do not allow duplicates
                return false;
            }
            node = DoInsert(std::move(value));
            TCounter::OnInsert(node->GetValue());
            return true;
        }

        template <typename TInsertAction, typename TUpdateAction>
        bool Insert(const T& value, TInsertAction insert, TUpdateAction update) {
            TNode* node = PrepareInsert(value);
            if (Y_UNLIKELY(node && Compare(node, value) == 0)) {
                if (update(node->GetValue())) {
                    TCounter::OnUpdate(node->GetValue());
                    return true;
                }
                // we do not allow duplicates
                return false;
            }
            node = DoInsert(insert(value));
            TCounter::OnInsert(node->GetValue());
            return true;
        }

        template <typename TValue>
        bool Contains(const TValue& value) const {
            TNode* node = FindGreaterThanOrEqual(value);
            return node && Compare(node, value) == 0;
        }

        TIterator SeekToFirst() const {
            return TIterator(this, FindFirst());
        }

        TIterator SeekToLast() const {
            TNode* last = FindLast();
            return TIterator(this, last != Head ? last : nullptr);
        }

        template <typename TValue>
        TIterator SeekTo(const TValue& value) const {
            return TIterator(this, FindGreaterThanOrEqual(value));
        }

    private:
        static int RandomHeight() {
            int height = 1;
            while (height < MaxHeight && (RandomNumber<unsigned int>() % Branching) == 0) {
                ++height;
            }
            return height;
        }

        void Init() {
            Head = AllocateRootNode();
            Height = 1;
            TCounter::Reset();

            for (int i = 0; i < MaxHeight; ++i) {
                Prev[i] = Head;
            }
        }

        void CallDtors() {
            if (!TTypeTraits<T>::IsPod) {
                // we should explicitly call destructors for our nodes
                TNode* node = Head->GetNext(0);
                while (node) {
                    TNode* next = node->GetNext(0);
                    node->~TNode();
                    node = next;
                }
            }
        }

        TNode* AllocateRootNode() {
            size_t size = sizeof(TNode) + sizeof(TNode*) * MaxHeight;
            void* buffer = Allocator.Allocate(size);
            memset(buffer, 0, size);
            return static_cast<TNode*>(buffer);
        }

        TNode* AllocateNode(T&& value, int height) {
            size_t size = sizeof(TNode) + sizeof(TNode*) * height;
            void* buffer = Allocator.Allocate(size);
            memset(buffer, 0, size);
            return new (buffer) TNode(std::move(value));
        }

        TNode* FindFirst() const {
            return Head->GetNext(0);
        }

        TNode* FindLast() const {
            TNode* node = Head;
            int height = AtomicGet(Height) - 1;

            while (true) {
                TNode* next = node->GetNext(height);
                if (next) {
                    node = next;
                    continue;
                }

                if (height) {
                    --height;
                } else {
                    return node;
                }
            }
        }

        template <typename TValue>
        TComparerReturnType<TValue> Compare(const TNode* node, const TValue& value) const {
            return Comparer(node->GetValue(), value);
        }

        template <typename TValue>
        TNode* FindLessThan(const TValue& value, TNode** links) const {
            TNode* node = Head;
            int height = AtomicGet(Height) - 1;

            TNode* prev = nullptr;
            while (true) {
                TNode* next = node->GetNext(height);
                if (next && next != prev) {
                    TComparerReturnType<TValue> cmp = Compare(next, value);
                    if (cmp < 0) {
                        node = next;
                        continue;
                    }
                }

                if (links) {
                    // collect links from upper levels
                    links[height] = node;
                }

                if (height) {
                    prev = next;
                    --height;
                } else {
                    return node;
                }
            }
        }

        template <typename TValue>
        TNode* FindGreaterThanOrEqual(const TValue& value) const {
            TNode* node = Head;
            int height = AtomicGet(Height) - 1;

            TNode* prev = nullptr;
            while (true) {
                TNode* next = node->GetNext(height);
                if (next && next != prev) {
                    TComparerReturnType<TValue> cmp = Compare(next, value);
                    if (cmp < 0) {
                        node = next;
                        continue;
                    }
                    if (cmp == 0) {
                        return next;
                    }
                }

                if (height) {
                    prev = next;
                    --height;
                } else {
                    return next;
                }
            }
        }

        TNode* PrepareInsert(const T& value) {
            TNode* prev = Prev[0];
            TNode* next = prev->GetNext(0);
            if ((prev == Head || Compare(prev, value) < 0) && (next == nullptr || Compare(next, value) >= 0)) {
                // avoid seek in case of sequential insert
            } else {
                prev = FindLessThan(value, Prev);
                next = prev->GetNext(0);
            }
            return next;
        }

        TNode* DoInsert(T&& value) {
            // choose level to place new node
            int currentHeight = AtomicGet(Height);
            int height = RandomHeight();
            if (height > currentHeight) {
                for (int i = currentHeight; i < height; ++i) {
                    // head should link to all levels
                    Prev[i] = Head;
                }
                AtomicSet(Height, height);
            }

            TNode* node = AllocateNode(std::move(value), height);
            node->Link(height, Prev);

            // keep last inserted node to optimize sequential inserts
            for (int i = 0; i < height; i++) {
                Prev[i] = node;
            }
            return node;
        }
    };

}