aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/containers/concurrent_hash/concurrent_hash.h
blob: f15a1c3d6ec27054deb6e502f1e7468c9fbb0fa1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once

#include <util/generic/hash.h>
#include <util/system/spinlock.h>

#include <array>

/// A hash map sharded into BucketCount buckets, each a THashMap guarded by its
/// own lock.
///
/// A key is routed to bucket THash<K>()(key) % BucketCount. Every member of
/// TConcurrentHashMap that touches a bucket's map (Insert*, Get, Remove, Has)
/// holds that bucket's lock via TGuard for the duration of the call only.
///
/// @tparam K            Key type; hashed with THash<K> and stored in THashMap<K, V>.
/// @tparam V            Mapped type.
/// @tparam BucketCount  Number of buckets; must be greater than zero.
/// @tparam L            Lock type used with TGuard (TAdaptiveLock by default).
///
/// NOTE: InsertIfAbsent and InsertIfAbsentWithInit return a V& that outlives the
/// bucket guard. Reading or writing through it while another thread runs
/// Insert/Remove on the same key is unsynchronized; callers that keep the
/// reference must lock GetBucketForKey(key).GetMutex() themselves.
template <typename K, typename V, size_t BucketCount = 64, typename L = TAdaptiveLock>
class TConcurrentHashMap {
    // BucketCount is used as a modulus in BucketIndex(); zero would be UB.
    static_assert(BucketCount > 0, "TConcurrentHashMap requires BucketCount > 0");

public:
    using TActualMap = THashMap<K, V>;
    using TLock = L;

    /// One shard: a THashMap plus the lock that protects it.
    /// None of the *Unsafe methods lock; the caller must hold GetMutex().
    struct TBucket {
        friend class TConcurrentHashMap;

    private:
        TActualMap Map;
        // mutable so that const accessors (Get, Has) can still lock the bucket.
        mutable TLock Mutex;

    public:
        /// The lock guarding this bucket's map.
        TLock& GetMutex() const {
            return Mutex;
        }

        /// Direct access to the underlying map; does not lock.
        TActualMap& GetMap() {
            return Map;
        }
        const TActualMap& GetMap() const {
            return Map;
        }

        /// Returns the value for @p key; fails via Y_VERIFY if it is absent. Does not lock.
        const V& GetUnsafe(const K& key) const {
            typename TActualMap::const_iterator it = Map.find(key);
            Y_VERIFY(it != Map.end(), "not found by key");
            return it->second;
        }

        /// Mutable overload of GetUnsafe; fails via Y_VERIFY if @p key is absent. Does not lock.
        V& GetUnsafe(const K& key) {
            typename TActualMap::iterator it = Map.find(key);
            Y_VERIFY(it != Map.end(), "not found by key");
            return it->second;
        }

        /// Erases @p key and returns its value (moved out); fails via Y_VERIFY
        /// if the key is absent. Does not lock.
        V RemoveUnsafe(const K& key) {
            typename TActualMap::iterator it = Map.find(key);
            Y_VERIFY(it != Map.end(), "removing non-existent key");
            V r = std::move(it->second);
            Map.erase(it);
            return r;
        }

        /// True if @p key is present. Does not lock.
        bool HasUnsafe(const K& key) const {
            return Map.find(key) != Map.end();
        }
    };

    /// All shards. Public for direct iteration; accessing a bucket's map through
    /// this member bypasses locking unless the caller locks GetMutex() itself.
    std::array<TBucket, BucketCount> Buckets;

public:
    /// The bucket that owns @p key (no locking is performed here).
    TBucket& GetBucketForKey(const K& key) {
        return Buckets[BucketIndex(key)];
    }

    const TBucket& GetBucketForKey(const K& key) const {
        return Buckets[BucketIndex(key)];
    }

    /// Sets key -> value, overwriting any existing value.
    /// Goes through operator[], so V must be default-constructible.
    void Insert(const K& key, const V& value) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        bucket.Map[key] = value;
    }

    /// Inserts key -> value; fails via Y_FAIL if @p key is already present.
    void InsertUnique(const K& key, const V& value) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        if (!bucket.Map.insert(std::make_pair(key, value)).second) {
            Y_FAIL("non-unique key");
        }
    }

    /// Inserts key -> value unless @p key exists. Returns a reference to the
    /// stored value (existing or new). The reference is used after the lock is
    /// released; see the class note.
    V& InsertIfAbsent(const K& key, const V& value) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.Map.insert(std::make_pair(key, value)).first->second;
    }

    /// Like InsertIfAbsent, but the value is produced by @p initFunc, which is
    /// called (under the bucket lock) only when @p key is absent. The returned
    /// reference outlives the lock; see the class note.
    template <typename Callable>
    V& InsertIfAbsentWithInit(const K& key, Callable initFunc) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        // Single lookup on the hit path (previously HasUnsafe + GetUnsafe).
        typename TActualMap::iterator it = bucket.Map.find(key);
        if (it != bucket.Map.end()) {
            return it->second;
        }

        return bucket.Map.insert(std::make_pair(key, initFunc())).first->second;
    }

    /// Returns a copy of the value for @p key; fails via Y_VERIFY if absent.
    V Get(const K& key) const {
        const TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.GetUnsafe(key);
    }

    /// Copies the value for @p key into @p result and returns true; returns
    /// false and leaves @p result untouched if the key is absent.
    bool Get(const K& key, V& result) const {
        const TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        // Single lookup instead of HasUnsafe followed by GetUnsafe.
        typename TActualMap::const_iterator it = bucket.Map.find(key);
        if (it == bucket.Map.end()) {
            return false;
        }
        result = it->second;
        return true;
    }

    /// Erases @p key and returns its value; fails via Y_VERIFY if absent.
    V Remove(const K& key) {
        TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.RemoveUnsafe(key);
    }

    /// True if @p key is present at the moment the bucket lock is held.
    bool Has(const K& key) const {
        const TBucket& bucket = GetBucketForKey(key);
        TGuard<TLock> guard(bucket.Mutex);
        return bucket.HasUnsafe(key);
    }

private:
    /// Index into Buckets for @p key; shared by both GetBucketForKey overloads.
    static size_t BucketIndex(const K& key) {
        return THash<K>()(key) % BucketCount;
    }
};