// cgit navigation: about | summary | refs | log | blame | commit | diff | stats
// path: root/library/cpp/containers/concurrent_hash/concurrent_hash.h
// blob: b11deb62544db91417f4109ab5e8fc0e0492a1bb (plain) (tree)






















































                                                                                      








                                                                   































                                                                           












                                                              


                                                                
                                                  












                                                                                
                                                        















                                                     
#pragma once

#include <util/generic/hash.h>
#include <util/system/spinlock.h>

#include <array>

template <typename K, typename V, size_t BucketCount = 64, typename L = TAdaptiveLock>
class TConcurrentHashMap {
public:
    using TActualMap = THashMap<K, V>;
    using TLock = L;

    struct TBucket {
        friend class TConcurrentHashMap;

    private:
        TActualMap Map;
        mutable TLock Mutex;

    public:
        TLock& GetMutex() const {
            return Mutex;
        }

        TActualMap& GetMap() {
            return Map;
        }
        const TActualMap& GetMap() const {
            return Map;
        }

        const V& GetUnsafe(const K& key) const {
            typename TActualMap::const_iterator it = Map.find(key);
            Y_VERIFY(it != Map.end(), "not found by key");
            return it->second;
        }

        V& GetUnsafe(const K& key) {
            typename TActualMap::iterator it = Map.find(key);
            Y_VERIFY(it != Map.end(), "not found by key");
            return it->second;
        }

        V RemoveUnsafe(const K& key) {
            typename TActualMap::iterator it = Map.find(key);
            Y_VERIFY(it != Map.end(), "removing non-existent key");
            V r = std::move(it->second);
            Map.erase(it);
            return r;
        }

        bool HasUnsafe(const K& key) const {
            typename TActualMap::const_iterator it = Map.find(key);
            return (it != Map.end());
        }

        const V* TryGetUnsafe(const K& key) const {
            typename TActualMap::const_iterator it = Map.find(key);
            return it == Map.end() ? nullptr : &it->second;
        }

        V* TryGetUnsafe(const K& key) {
            typename TActualMap::iterator it = Map.find(key);
            return it == Map.end() ? nullptr : &it->second;
        }
    };

    std::array<TBucket, BucketCount> Buckets;

public:
    TBucket& GetBucketForKey(const K& key) {
        return Buckets[THash<K>()(key) % BucketCount];
    }

    const TBucket& GetBucketForKey(const K& key) const {
        return Buckets[THash<K>()(key) % BucketCount];
    }

    void Insert(const K& key, const V& value) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        bucket.Map[key] = value;
    }

    void InsertUnique(const K& key, const V& value) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        if (!bucket.Map.insert(std::make_pair(key, value)).second) {
            Y_FAIL("non-unique key");
        }
    }

    V& InsertIfAbsent(const K& key, const V& value) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.Map.insert(std::make_pair(key, value)).first->second;
    }

    template <typename TKey, typename... Args>
    V& EmplaceIfAbsent(TKey&& key, Args&&... args) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        if (V* value = bucket.TryGetUnsafe(key)) {
            return *value;
        }
        return bucket.Map.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(std::forward<TKey>(key)),
            std::forward_as_tuple(std::forward<Args>(args)...)
        ).first->second;
    }

    template <typename Callable>
    V& InsertIfAbsentWithInit(const K& key, Callable initFunc) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        if (V* value = bucket.TryGetUnsafe(key)) {
            return *value;
        }

        return bucket.Map.insert(std::make_pair(key, initFunc())).first->second;
    }

    V Get(const K& key) const {
        const TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.GetUnsafe(key);
    }

    bool Get(const K& key, V& result) const {
        const TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        if (const V* value = bucket.TryGetUnsafe(key)) {
            result = *value;
            return true;
        }
        return false;
    }

    V Remove(const K& key) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.RemoveUnsafe(key);
    }

    bool Has(const K& key) const {
        const TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.HasUnsafe(key);
    }
};