#pragma once

#include <Common/Exception.h>
#include <Common/ICachePolicy.h>
#include <Common/LRUCachePolicy.h>
#include <Common/SLRUCachePolicy.h>

#include <atomic>
#include <cassert>
#include <chrono>
#include <memory>
#include <mutex>
#include <unordered_map>

#include <base/defines.h>


namespace DB
{
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
}

/// Thread-safe cache that evicts entries using a special cache policy
/// (the default policy evicts entries which have not been used for a long time).
/// WeightFunction is a functor that takes a Mapped value as a parameter and returns its "weight"
/// (approximate size).
/// The cache starts to evict entries when their total weight exceeds max_size_in_bytes.
/// The weight of a value must not change after insertion.
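///
/// A usage sketch (illustrative only; the key/value types, sizes, and the `process` helper
/// are assumptions, not part of this header):
///
///     CacheBase<std::string, std::string> cache(/* max_size_in_bytes */ 1024 * 1024);
///     cache.set("key", std::make_shared<std::string>("value"));
///     if (auto value = cache.get("key"))  /// MappedPtr, i.e. std::shared_ptr<std::string>; nullptr on a miss
///         process(*value);                /// `process` is a hypothetical consumer of the cached value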
template <typename TKey, typename TMapped, typename HashFunction = std::hash<TKey>, typename WeightFunction = EqualWeightFunction<TMapped>>
class CacheBase
{
private:
    using CachePolicy = ICachePolicy<TKey, TMapped, HashFunction, WeightFunction>;

public:
    using Key = typename CachePolicy::Key;
    using Mapped = typename CachePolicy::Mapped;
    using MappedPtr = typename CachePolicy::MappedPtr;
    using KeyMapped = typename CachePolicy::KeyMapped;

    static constexpr auto NO_MAX_COUNT = size_t(0);
    static constexpr auto DEFAULT_SIZE_RATIO = 0.5l;

    /// Use this ctor if you only care about the cache size but not about internals like the cache policy.
    explicit CacheBase(size_t max_size_in_bytes, size_t max_count = NO_MAX_COUNT, double size_ratio = DEFAULT_SIZE_RATIO)
        : CacheBase("SLRU", max_size_in_bytes, max_count, size_ratio)
    {
    }

    /// Use this ctor if the user should be able to configure the cache policy and cache sizes via settings. Only the general-purpose policies LRU and SLRU are supported.
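    ///
    /// A construction sketch (illustrative only; the key/value types and numbers are assumptions,
    /// not taken from any real settings):
    ///
    ///     /// cache_policy_name is "LRU" or "SLRU" (empty selects the default, "SLRU");
    ///     /// size_ratio is only used by SLRU (roughly, the share of the total size reserved for the protected queue).
    ///     CacheBase<size_t, std::string> cache("SLRU", /* max_size_in_bytes */ 1 << 20,
    ///                                          /* max_count */ 0 /* = NO_MAX_COUNT */, /* size_ratio */ 0.5);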
    explicit CacheBase(std::string_view cache_policy_name, size_t max_size_in_bytes, size_t max_count, double size_ratio)
    {
        auto on_weight_loss_function = [&](size_t weight_loss) { onRemoveOverflowWeightLoss(weight_loss); };

        if (cache_policy_name.empty())
        {
            static constexpr auto default_cache_policy = "SLRU";
            cache_policy_name = default_cache_policy;
        }

        if (cache_policy_name == "LRU")
        {
            using LRUPolicy = LRUCachePolicy<TKey, TMapped, HashFunction, WeightFunction>;
            cache_policy = std::make_unique<LRUPolicy>(max_size_in_bytes, max_count, on_weight_loss_function);
        }
        else if (cache_policy_name == "SLRU")
        {
            using SLRUPolicy = SLRUCachePolicy<TKey, TMapped, HashFunction, WeightFunction>;
            cache_policy = std::make_unique<SLRUPolicy>(max_size_in_bytes, max_count, size_ratio, on_weight_loss_function);
        }
        else
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown cache policy name: {}", cache_policy_name);
    }

    /// Use this ctor to provide an arbitrary cache policy.
    explicit CacheBase(std::unique_ptr<ICachePolicy<TKey, TMapped, HashFunction, WeightFunction>> cache_policy_)
        : cache_policy(std::move(cache_policy_))
    {}

    MappedPtr get(const Key & key)
    {
        std::lock_guard lock(mutex);
        auto res = cache_policy->get(key);
        if (res)
            ++hits;
        else
            ++misses;
        return res;
    }

    std::optional<KeyMapped> getWithKey(const Key & key)
    {
        std::lock_guard lock(mutex);
        auto res = cache_policy->getWithKey(key);
        if (res.has_value())
            ++hits;
        else
            ++misses;
        return res;
    }

    void set(const Key & key, const MappedPtr & mapped)
    {
        std::lock_guard lock(mutex);
        cache_policy->set(key, mapped);
    }

    /// If the value for the key is in the cache, returns it. If it is not, calls load_func() to
    /// produce it, saves the result in the cache and returns it.
    /// Only one of several concurrent threads calling getOrSet() will call load_func();
    /// the others will wait for that call to complete and will use its result (this helps prevent cache stampede).
    /// Exceptions thrown by load_func are propagated to the caller; another thread from the
    /// set of concurrent threads will then try to call its own load_func, and so on.
    ///
    /// Returns a std::pair of the cached value and a bool indicating whether the value was produced during this call.
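    ///
    /// A usage sketch (illustrative only; `computeValue` is a hypothetical, possibly expensive producer):
    ///
    ///     auto [value, produced] = cache.getOrSet(key,
    ///         [&] { return std::make_shared<std::string>(computeValue(key)); });
    ///     /// `produced` tells whether this call produced and inserted the value; it is false on a
    ///     /// cache hit (including a value produced by a concurrent caller while we waited).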
    template <typename LoadFunc>
    std::pair<MappedPtr, bool> getOrSet(const Key & key, LoadFunc && load_func)
    {
        InsertTokenHolder token_holder;
        {
            std::lock_guard cache_lock(mutex);
            auto val = cache_policy->get(key);
            if (val)
            {
                ++hits;
                return std::make_pair(val, false);
            }

            auto & token = insert_tokens[key];
            if (!token)
                token = std::make_shared<InsertToken>(*this);

            token_holder.acquire(&key, token, cache_lock);
        }

        InsertToken * token = token_holder.token.get();

        std::lock_guard token_lock(token->mutex);

        token_holder.cleaned_up = token->cleaned_up;

        if (token->value)
        {
            /// Another thread already produced the value while we waited for token->mutex.
            ++hits;
            return std::make_pair(token->value, false);
        }

        ++misses;
        token->value = load_func();

        std::lock_guard cache_lock(mutex);

        /// Insert the new value only if the token is still present in insert_tokens.
        /// (The token may be absent because of a concurrent clear() call).
        bool result = false;
        auto token_it = insert_tokens.find(key);
        if (token_it != insert_tokens.end() && token_it->second.get() == token)
        {
            cache_policy->set(key, token->value);
            result = true;
        }

        if (!token->cleaned_up)
            token_holder.cleanup(token_lock, cache_lock);

        return std::make_pair(token->value, result);
    }

    void getStats(size_t & out_hits, size_t & out_misses) const
    {
        std::lock_guard lock(mutex);
        out_hits = hits;
        out_misses = misses;
    }

    std::vector<KeyMapped> dump() const
    {
        std::lock_guard lock(mutex);
        return cache_policy->dump();
    }

    void clear()
    {
        std::lock_guard lock(mutex);
        insert_tokens.clear();
        hits = 0;
        misses = 0;
        cache_policy->clear();
    }

    void remove(const Key & key)
    {
        std::lock_guard lock(mutex);
        cache_policy->remove(key);
    }

    size_t sizeInBytes() const
    {
        std::lock_guard lock(mutex);
        return cache_policy->sizeInBytes();
    }

    size_t count() const
    {
        std::lock_guard lock(mutex);
        return cache_policy->count();
    }

    size_t maxSizeInBytes() const
    {
        std::lock_guard lock(mutex);
        return cache_policy->maxSizeInBytes();
    }

    void setMaxCount(size_t max_count)
    {
        std::lock_guard lock(mutex);
        cache_policy->setMaxCount(max_count);
    }

    void setMaxSizeInBytes(size_t max_size_in_bytes)
    {
        std::lock_guard lock(mutex);
        cache_policy->setMaxSizeInBytes(max_size_in_bytes);
    }

    void setQuotaForUser(const String & user_name, size_t max_size_in_bytes, size_t max_entries)
    {
        std::lock_guard lock(mutex);
        cache_policy->setQuotaForUser(user_name, max_size_in_bytes, max_entries);
    }

    virtual ~CacheBase() = default;

protected:
    mutable std::mutex mutex;

private:
    std::unique_ptr<CachePolicy> cache_policy TSA_GUARDED_BY(mutex);

    std::atomic<size_t> hits{0};
    std::atomic<size_t> misses{0};

    /// Represents a pending insertion attempt.
    struct InsertToken
    {
        explicit InsertToken(CacheBase & cache_) : cache(cache_) {}

        std::mutex mutex;
        bool cleaned_up TSA_GUARDED_BY(mutex) = false;
        MappedPtr value TSA_GUARDED_BY(mutex);

        CacheBase & cache;
        size_t refcount = 0; /// Protected by the cache mutex
    };

    using InsertTokenById = std::unordered_map<Key, std::shared_ptr<InsertToken>, HashFunction>;

    /// This struct is responsible for removing used insert tokens from the insert_tokens map.
    /// Among several concurrent threads, the first successful one is responsible for the removal. But if they all
    /// fail, then the last one is responsible.
    struct InsertTokenHolder
    {
        const Key * key = nullptr;
        std::shared_ptr<InsertToken> token;
        bool cleaned_up = false;

        InsertTokenHolder() = default;

        void acquire(const Key * key_, const std::shared_ptr<InsertToken> & token_, std::lock_guard<std::mutex> & /* cache_lock */)
            TSA_NO_THREAD_SAFETY_ANALYSIS // disabled only because we can't reference the parent-level cache mutex from here
        {
            key = key_;
            token = token_;
            ++token->refcount;
        }

        void cleanup(std::lock_guard<std::mutex> & /* token_lock */, std::lock_guard<std::mutex> & /* cache_lock */)
            TSA_NO_THREAD_SAFETY_ANALYSIS // disabled only because we can't reference the parent-level cache mutex from here
        {
            token->cache.insert_tokens.erase(*key);
            token->cleaned_up = true;
            cleaned_up = true;
        }

        ~InsertTokenHolder()
        {
            if (!token)
                return;

            if (cleaned_up)
                return;

            std::lock_guard token_lock(token->mutex);

            if (token->cleaned_up)
                return;

            std::lock_guard cache_lock(token->cache.mutex);

            --token->refcount;
            if (token->refcount == 0)
                cleanup(token_lock, cache_lock);
        }
    };

    friend struct InsertTokenHolder;

    InsertTokenById insert_tokens TSA_GUARDED_BY(mutex);

    /// Override this method if you want to track how much weight was lost in the removeOverflow() method.
    virtual void onRemoveOverflowWeightLoss(size_t /*weight_loss*/) {}
};
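
/// A sketch of overriding the eviction hook (illustrative only; `EvictionTrackingCache` and its
/// counter are hypothetical and not part of this header):
///
///     template <typename K, typename V>
///     class EvictionTrackingCache : public CacheBase<K, V>
///     {
///     public:
///         using CacheBase<K, V>::CacheBase;
///         size_t evictedBytes() const { return evicted_bytes; }
///     private:
///         void onRemoveOverflowWeightLoss(size_t weight_loss) override { evicted_bytes += weight_loss; }
///         std::atomic<size_t> evicted_bytes{0};
///     };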


}