#pragma once #include <util/generic/hash.h> #include <util/generic/maybe.h> #include <util/generic/ptr.h> #include <util/generic/vector.h> #include <util/memory/pool.h> #include <util/system/mutex.h> #include <util/system/thread.h> #include <library/cpp/threading/hot_swap/hot_swap.h> #include <library/cpp/threading/skip_list/skiplist.h> #include <array> #include <atomic> #include <thread> namespace NThreading { // TThreadLocalValue // // Safe RAII-friendly thread local storage without dirty hacks from util/system/tls // // Example 1: // // THolder<IThreadPool> pool = CreateThreadPool(threads); // TThreadLocalValue<ui32> tls; // for (ui32 i : xrange(threads)) { // pool->SafeAddFunc([&]) { // *tls->Get() = 1337; // } // } // // Example 2: // // class TNoisy { // public: // TNoisy(const char* name = "TNoisy") // : Name_{name} { // printf("%s::%s\n", Name_, Name_); // } // // ~TNoisy() { // printf("%s::~%s\n", Name_, Name_); // } // private: // const char* Name_; // }; // // class TWrapper { // public: // TWrapper() { // Println(__PRETTY_FUNCTION__); // } // // ~TWrapper() { // Println(__PRETTY_FUNCTION__); // } // // void DoWork() { // ThreadLocal_->Get(); // } // // private: // TNoisy Noisy_{"TWrapper"}; // TThreadLocalValue<TNoisy> ThreadLocal_; // }; // // THolder<IThreadPool> pool = CreateThreadPool(3); // { // TWrapper wrapper; // for (ui32 i : xrange(3)) { // pool->SafeAddFunc([&] { // wrapper.DoWork(); // }); // } // } // // Will always print: // TWrapper::TWrapper() // TNoisy::TNoisy() // TNoisy::TNoisy() // TNoisy::TNoisy() // TNoisy::~TNoisy() // TNoisy::~TNoisy() // TNoisy::~TNoisy() // TWrapper::~TWrapper() // enum class EThreadLocalImpl { HotSwap, SkipList, ForwardList, }; namespace NDetail { template <typename T, EThreadLocalImpl Impl, size_t NumShards> class TThreadLocalValueImpl; } // namespace NDetail inline constexpr size_t DefaultNumShards = 3; template <typename T, EThreadLocalImpl Impl = EThreadLocalImpl::SkipList, size_t NumShards = DefaultNumShards> class TThreadLocalValue : private TNonCopyable { public: template <typename ...ConstructArgs> T& GetRef(ConstructArgs&& ...args) const { return *Get(std::forward<ConstructArgs>(args)...); } template <typename ...ConstructArgs> T* Get(ConstructArgs&& ...args) const { TThread::TId tid = TThread::CurrentThreadId(); return Shards_[tid % NumShards].Get(tid, std::forward<ConstructArgs>(args)...); } private: using TStorage = NDetail::TThreadLocalValueImpl<T, Impl, NumShards>; mutable std::array<TStorage, NumShards> Shards_; }; namespace NDetail { template <typename T, size_t NumShards> class TThreadLocalValueImpl<T, EThreadLocalImpl::HotSwap, NumShards> { private: class TStorage: public THashMap<TThread::TId, TAtomicSharedPtr<T>>, public TAtomicRefCount<TStorage> { }; public: TThreadLocalValueImpl() { Registered_.AtomicStore(new TStorage()); } template <typename ...ConstructArgs> T* Get(TThread::TId tid, ConstructArgs&& ...args) { if (TIntrusivePtr<TStorage> state = Registered_.AtomicLoad(); TAtomicSharedPtr<T>* result = state->FindPtr(tid)) { return result->Get(); } else { TAtomicSharedPtr<T> value = MakeAtomicShared<T>(std::forward<ConstructArgs>(args)...); with_lock(RegisterLock_) { TIntrusivePtr<TStorage> oldState = Registered_.AtomicLoad(); THolder<TStorage> newState = MakeHolder<TStorage>(*oldState); (*newState)[tid] = value; Registered_.AtomicStore(newState.Release()); } return value.Get(); } } private: THotSwap<TStorage> Registered_; TMutex RegisterLock_; }; template <typename T, size_t NumShards> class TThreadLocalValueImpl<T, EThreadLocalImpl::SkipList, NumShards> { private: struct TNode { TThread::TId Key; THolder<T> Value; }; struct TCompare { int operator()(const TNode& lhs, const TNode& rhs) const { return ::NThreading::TCompare<TThread::TId>{}(lhs.Key, rhs.Key); } }; public: TThreadLocalValueImpl() : ListPool_{InitialPoolSize()} , SkipList_{ListPool_} {} template <typename ...ConstructArgs> T* Get(TThread::TId tid, ConstructArgs&& ...args) { TNode key{tid, {}}; auto iterator = SkipList_.SeekTo(key); if (iterator.IsValid() && iterator.GetValue().Key == key.Key) { return iterator.GetValue().Value.Get(); } with_lock (RegisterLock_) { SkipList_.Insert({tid, MakeHolder<T>(std::forward<ConstructArgs>(args)...)}); } iterator = SkipList_.SeekTo(key); return iterator.GetValue().Value.Get(); } private: static size_t InitialPoolSize() { return std::thread::hardware_concurrency() * (sizeof(T) + sizeof(TThread::TId) + sizeof(void*)) / NumShards; } private: static inline constexpr size_t MaxHeight = 6; using TCustomSkipList = TSkipList<TNode, TCompare, TMemoryPool, TSizeCounter, MaxHeight>; TMemoryPool ListPool_; TCustomSkipList SkipList_; TAdaptiveLock RegisterLock_; }; template <typename T, size_t NumShards> class TThreadLocalValueImpl<T, EThreadLocalImpl::ForwardList, NumShards> { private: struct TNode { TThread::TId Key = 0; T Value; TNode* Next = nullptr; }; public: TThreadLocalValueImpl() : Head_{nullptr} , Pool_{0} {} template <typename ...ConsturctArgs> T* Get(TThread::TId tid, ConsturctArgs&& ...args) { TNode* node = Head_.load(std::memory_order_relaxed); for (; node; node = node->Next) { if (node->Key == tid) { return &node->Value; } } TNode* newNode = AllocateNode(tid, node, std::forward<ConsturctArgs>(args)...); while (!Head_.compare_exchange_weak(node, newNode, std::memory_order_release, std::memory_order_relaxed)) { newNode->Next = node; } return &newNode->Value; } template <typename ...ConstructArgs> TNode* AllocateNode(TThread::TId tid, TNode* next, ConstructArgs&& ...args) { TNode* storage = nullptr; with_lock(PoolMutex_) { storage = Pool_.Allocate<TNode>(); } new (storage) TNode{tid, T{std::forward<ConstructArgs>(args)...}, next}; return storage; } ~TThreadLocalValueImpl() { if constexpr (!std::is_trivially_destructible_v<T>) { TNode* next = nullptr; for (TNode* node = Head_.load(); node; node = next) { next = node->Next; node->~TNode(); } } } private: std::atomic<TNode*> Head_; TMemoryPool Pool_; TMutex PoolMutex_; }; } // namespace NDetail } // namespace NThreading