aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/threading/thread_local
diff options
context:
space:
mode:
authormonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
committermonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
commit06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/threading/thread_local
parent03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
downloadydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/threading/thread_local')
-rw-r--r--library/cpp/threading/thread_local/thread_local.cpp1
-rw-r--r--library/cpp/threading/thread_local/thread_local.h268
2 files changed, 269 insertions, 0 deletions
diff --git a/library/cpp/threading/thread_local/thread_local.cpp b/library/cpp/threading/thread_local/thread_local.cpp
new file mode 100644
index 0000000000..3acca47b76
--- /dev/null
+++ b/library/cpp/threading/thread_local/thread_local.cpp
@@ -0,0 +1 @@
+#include "thread_local.h"
diff --git a/library/cpp/threading/thread_local/thread_local.h b/library/cpp/threading/thread_local/thread_local.h
new file mode 100644
index 0000000000..1cc4642373
--- /dev/null
+++ b/library/cpp/threading/thread_local/thread_local.h
@@ -0,0 +1,268 @@
+#pragma once
+
+#include <util/generic/hash.h>
+#include <util/generic/maybe.h>
+#include <util/generic/ptr.h>
+#include <util/generic/vector.h>
+#include <util/memory/pool.h>
+#include <util/system/mutex.h>
+#include <util/system/thread.h>
+
+#include <library/cpp/threading/hot_swap/hot_swap.h>
+#include <library/cpp/threading/skip_list/skiplist.h>
+
+#include <array>
+#include <atomic>
+#include <thread>
+
+namespace NThreading {
+
+// TThreadLocalValue
+//
+// Safe RAII-friendly thread local storage without dirty hacks from util/system/tls
+//
+// Example 1:
+//
+// THolder<IThreadPool> pool = CreateThreadPool(threads);
+// TThreadLocalValue<ui32> tls;
+// for (ui32 i : xrange(threads)) {
+// pool->SafeAddFunc([&]) {
+// *tls->Get() = 1337;
+// }
+// }
+//
+// Example 2:
+//
+// class TNoisy {
+// public:
+// TNoisy(const char* name = "TNoisy")
+// : Name_{name} {
+// printf("%s::%s\n", Name_, Name_);
+// }
+//
+// ~TNoisy() {
+// printf("%s::~%s\n", Name_, Name_);
+// }
+// private:
+// const char* Name_;
+// };
+//
+// class TWrapper {
+// public:
+// TWrapper() {
+// Println(__PRETTY_FUNCTION__);
+// }
+//
+// ~TWrapper() {
+// Println(__PRETTY_FUNCTION__);
+// }
+//
+// void DoWork() {
+// ThreadLocal_->Get();
+// }
+//
+// private:
+// TNoisy Noisy_{"TWrapper"};
+// TThreadLocalValue<TNoisy> ThreadLocal_;
+// };
+//
+// THolder<IThreadPool> pool = CreateThreadPool(3);
+// {
+// TWrapper wrapper;
+// for (ui32 i : xrange(3)) {
+// pool->SafeAddFunc([&] {
+// wrapper.DoWork();
+// });
+// }
+// }
+//
+// Will always print:
+// TWrapper::TWrapper()
+// TNoisy::TNoisy()
+// TNoisy::TNoisy()
+// TNoisy::TNoisy()
+// TNoisy::~TNoisy()
+// TNoisy::~TNoisy()
+// TNoisy::~TNoisy()
+// TWrapper::~TWrapper()
+//
+
+enum class EThreadLocalImpl {
+ HotSwap,
+ SkipList,
+ ForwardList,
+};
+
+namespace NDetail {
+
+template <typename T, EThreadLocalImpl Impl, size_t NumShards>
+class TThreadLocalValueImpl;
+
+} // namespace NDetail
+
+inline constexpr size_t DefaultNumShards = 3;
+
+template <typename T, EThreadLocalImpl Impl = EThreadLocalImpl::SkipList, size_t NumShards = DefaultNumShards>
+class TThreadLocalValue : private TNonCopyable {
+public:
+ template <typename ...ConstructArgs>
+ T& GetRef(ConstructArgs&& ...args) const {
+ return *Get(std::forward<ConstructArgs>(args)...);
+ }
+
+ template <typename ...ConstructArgs>
+ T* Get(ConstructArgs&& ...args) const {
+ TThread::TId tid = TThread::CurrentThreadId();
+ return Shards_[tid % NumShards].Get(tid, std::forward<ConstructArgs>(args)...);
+ }
+
+private:
+ using TStorage = NDetail::TThreadLocalValueImpl<T, Impl, NumShards>;
+
+ mutable std::array<TStorage, NumShards> Shards_;
+};
+
+namespace NDetail {
+
+template <typename T, size_t NumShards>
+class TThreadLocalValueImpl<T, EThreadLocalImpl::HotSwap, NumShards> {
+private:
+ class TStorage: public THashMap<TThread::TId, TAtomicSharedPtr<T>>, public TAtomicRefCount<TStorage> {
+ };
+
+public:
+ TThreadLocalValueImpl() {
+ Registered_.AtomicStore(new TStorage());
+ }
+
+ template <typename ...ConstructArgs>
+ T* Get(TThread::TId tid, ConstructArgs&& ...args) {
+ if (TIntrusivePtr<TStorage> state = Registered_.AtomicLoad(); TAtomicSharedPtr<T>* result = state->FindPtr(tid)) {
+ return result->Get();
+ } else {
+ TAtomicSharedPtr<T> value = MakeAtomicShared<T>(std::forward<ConstructArgs>(args)...);
+ with_lock(RegisterLock_) {
+ TIntrusivePtr<TStorage> oldState = Registered_.AtomicLoad();
+ THolder<TStorage> newState = MakeHolder<TStorage>(*oldState);
+ (*newState)[tid] = value;
+ Registered_.AtomicStore(newState.Release());
+ }
+ return value.Get();
+ }
+ }
+
+private:
+ THotSwap<TStorage> Registered_;
+ TMutex RegisterLock_;
+};
+
+template <typename T, size_t NumShards>
+class TThreadLocalValueImpl<T, EThreadLocalImpl::SkipList, NumShards> {
+private:
+ struct TNode {
+ TThread::TId Key;
+ THolder<T> Value;
+ };
+
+ struct TCompare {
+ int operator()(const TNode& lhs, const TNode& rhs) const {
+ return ::NThreading::TCompare<TThread::TId>{}(lhs.Key, rhs.Key);
+ }
+ };
+
+public:
+ TThreadLocalValueImpl()
+ : ListPool_{InitialPoolSize()}
+ , SkipList_{ListPool_}
+ {}
+
+ template <typename ...ConstructArgs>
+ T* Get(TThread::TId tid, ConstructArgs&& ...args) {
+ TNode key{tid, {}};
+ auto iterator = SkipList_.SeekTo(key);
+ if (iterator.IsValid() && iterator.GetValue().Key == key.Key) {
+ return iterator.GetValue().Value.Get();
+ }
+
+ with_lock (RegisterLock_) {
+ SkipList_.Insert({tid, MakeHolder<T>(std::forward<ConstructArgs>(args)...)});
+ }
+ iterator = SkipList_.SeekTo(key);
+ return iterator.GetValue().Value.Get();
+ }
+
+private:
+ static size_t InitialPoolSize() {
+ return std::thread::hardware_concurrency() * (sizeof(T) + sizeof(TThread::TId) + sizeof(void*)) / NumShards;
+ }
+
+private:
+ static inline constexpr size_t MaxHeight = 6;
+ using TCustomSkipList = TSkipList<TNode, TCompare, TMemoryPool, TSizeCounter, MaxHeight>;
+
+ TMemoryPool ListPool_;
+ TCustomSkipList SkipList_;
+ TAdaptiveLock RegisterLock_;
+};
+
+template <typename T, size_t NumShards>
+class TThreadLocalValueImpl<T, EThreadLocalImpl::ForwardList, NumShards> {
+private:
+ struct TNode {
+ TThread::TId Key = 0;
+ T Value;
+ TNode* Next = nullptr;
+ };
+
+public:
+ TThreadLocalValueImpl()
+ : Head_{nullptr}
+ , Pool_{0}
+ {}
+
+ template <typename ...ConsturctArgs>
+ T* Get(TThread::TId tid, ConsturctArgs&& ...args) {
+ TNode* node = Head_.load(std::memory_order_relaxed);
+ for (; node; node = node->Next) {
+ if (node->Key == tid) {
+ return &node->Value;
+ }
+ }
+
+ TNode* newNode = AllocateNode(tid, node, std::forward<ConsturctArgs>(args)...);
+ while (!Head_.compare_exchange_weak(node, newNode, std::memory_order_release, std::memory_order_relaxed)) {
+ newNode->Next = node;
+ }
+
+ return &newNode->Value;
+ }
+
+ template <typename ...ConstructArgs>
+ TNode* AllocateNode(TThread::TId tid, TNode* next, ConstructArgs&& ...args) {
+ TNode* storage = nullptr;
+ with_lock(PoolMutex_) {
+ storage = Pool_.Allocate<TNode>();
+ }
+ new (storage) TNode{tid, T{std::forward<ConstructArgs>(args)...}, next};
+ return storage;
+ }
+
+ ~TThreadLocalValueImpl() {
+ if constexpr (!std::is_trivially_destructible_v<T>) {
+ TNode* next = nullptr;
+ for (TNode* node = Head_.load(); node; node = next) {
+ next = node->Next;
+ node->~TNode();
+ }
+ }
+ }
+
+private:
+ std::atomic<TNode*> Head_;
+ TMemoryPool Pool_;
+ TMutex PoolMutex_;
+};
+
+} // namespace NDetail
+
+} // namespace NThreading