#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace NThreading { // TThreadLocalValue // // Safe RAII-friendly thread local storage without dirty hacks from util/system/tls // // Example 1: // // THolder pool = CreateThreadPool(threads); // TThreadLocalValue tls; // for (ui32 i : xrange(threads)) { // pool->SafeAddFunc([&]) { // *tls->Get() = 1337; // } // } // // Example 2: // // class TNoisy { // public: // TNoisy(const char* name = "TNoisy") // : Name_{name} { // printf("%s::%s\n", Name_, Name_); // } // // ~TNoisy() { // printf("%s::~%s\n", Name_, Name_); // } // private: // const char* Name_; // }; // // class TWrapper { // public: // TWrapper() { // Println(__PRETTY_FUNCTION__); // } // // ~TWrapper() { // Println(__PRETTY_FUNCTION__); // } // // void DoWork() { // ThreadLocal_->Get(); // } // // private: // TNoisy Noisy_{"TWrapper"}; // TThreadLocalValue ThreadLocal_; // }; // // THolder pool = CreateThreadPool(3); // { // TWrapper wrapper; // for (ui32 i : xrange(3)) { // pool->SafeAddFunc([&] { // wrapper.DoWork(); // }); // } // } // // Will always print: // TWrapper::TWrapper() // TNoisy::TNoisy() // TNoisy::TNoisy() // TNoisy::TNoisy() // TNoisy::~TNoisy() // TNoisy::~TNoisy() // TNoisy::~TNoisy() // TWrapper::~TWrapper() // enum class EThreadLocalImpl { HotSwap, SkipList, ForwardList, StdThreadLocal // completely different implementation over thread_local keyword, correctly destroys objects on thread finish }; namespace NDetail { template class TThreadLocalValueImpl; } // namespace NDetail inline constexpr size_t DefaultNumShards = 3; template class TThreadLocalValue : private TNonCopyable { public: template T& GetRef(ConstructArgs&& ...args) const { return *Get(std::forward(args)...); } template T* Get(ConstructArgs&& ...args) const { TThread::TId tid = TThread::CurrentThreadId(); return Shards_[tid % NumShards].Get(tid, std::forward(args)...); } private: using TStorage = NDetail::TThreadLocalValueImpl; mutable std::array Shards_; }; template class TThreadLocalValue : private TNonCopyable { public: ~TThreadLocalValue() { StaticState_->AllValues.Destroy(ObjectId_); StaticState_->FlatIds.Put(ObjectId_); } template T& GetRef(ConstructArgs&& ...args) const { return *Get(std::forward(args)...); } template T* Get(ConstructArgs&& ...args) const { auto& values = Values_.Storage; with_lock (Values_.Lock) { if (ObjectId_ >= values.size()) { values.resize(ObjectId_ + 1); } } TMaybe& v = values[ObjectId_]; if (!v.Defined()) { v.ConstructInPlace(std::forward(args)...); } return &*v; } private: struct TStaticState; struct TPerThreadValues { TAdaptiveLock Lock; TDeque> Storage; public: explicit TPerThreadValues(TAtomicSharedPtr staticState) : StaticState_(std::move(staticState)) { StaticState_->AllValues.Add(this); } ~TPerThreadValues() { StaticState_->AllValues.Remove(this); } private: TAtomicSharedPtr StaticState_; }; class TFlatIdGenerator { public: ui32 Get() { with_lock (Lock_) { if (Free_.empty()) { return MaxUnused_++; } else { ui32 free = Free_.back(); Free_.pop_back(); return free; } } } void Put(ui32 id) { with_lock (Lock_) { Free_.push_back(id); } } private: TAdaptiveLock Lock_; TVector Free_; ui32 MaxUnused_ = 0; }; struct TAllValues { public: void Add(TPerThreadValues* v) { with_lock (Lock_) { Ptrs_.insert(v); } } void Remove(TPerThreadValues* v) { with_lock (Lock_) { Ptrs_.erase(v); } } void Destroy(ui32 objectId) { with_lock (Lock_) { for (auto* v : Ptrs_) { TMaybe* toDestroy = nullptr; with_lock (v->Lock) { if (objectId < v->Storage.size()) { toDestroy = &v->Storage[objectId]; } } if (toDestroy) { toDestroy->Clear(); } } } } private: TAdaptiveLock Lock_; THashSet Ptrs_; }; private: struct TStaticState { TFlatIdGenerator FlatIds; TAllValues AllValues; }; static TAtomicSharedPtr GetStaticState() { static TAtomicSharedPtr state = MakeAtomicShared(); return state; } private: TAtomicSharedPtr StaticState_ = GetStaticState(); const ui32 ObjectId_ = StaticState_->FlatIds.Get(); static inline thread_local TPerThreadValues Values_{GetStaticState()}; }; namespace NDetail { template class TThreadLocalValueImpl { private: class TStorage: public THashMap>, public TAtomicRefCount { }; public: TThreadLocalValueImpl() { Registered_.AtomicStore(new TStorage()); } template T* Get(TThread::TId tid, ConstructArgs&& ...args) { if (TIntrusivePtr state = Registered_.AtomicLoad(); TAtomicSharedPtr* result = state->FindPtr(tid)) { return result->Get(); } else { TAtomicSharedPtr value = MakeAtomicShared(std::forward(args)...); with_lock(RegisterLock_) { TIntrusivePtr oldState = Registered_.AtomicLoad(); THolder newState = MakeHolder(*oldState); (*newState)[tid] = value; Registered_.AtomicStore(newState.Release()); } return value.Get(); } } private: THotSwap Registered_; TMutex RegisterLock_; }; template class TThreadLocalValueImpl { private: struct TNode { TThread::TId Key; THolder Value; }; struct TCompare { int operator()(const TNode& lhs, const TNode& rhs) const { return ::NThreading::TCompare{}(lhs.Key, rhs.Key); } }; public: TThreadLocalValueImpl() : ListPool_{InitialPoolSize()} , SkipList_{ListPool_} {} template T* Get(TThread::TId tid, ConstructArgs&& ...args) { TNode key{tid, {}}; auto iterator = SkipList_.SeekTo(key); if (iterator.IsValid() && iterator.GetValue().Key == key.Key) { return iterator.GetValue().Value.Get(); } with_lock (RegisterLock_) { SkipList_.Insert({tid, MakeHolder(std::forward(args)...)}); } iterator = SkipList_.SeekTo(key); return iterator.GetValue().Value.Get(); } private: static size_t InitialPoolSize() { return std::thread::hardware_concurrency() * (sizeof(T) + sizeof(TThread::TId) + sizeof(void*)) / NumShards; } private: static inline constexpr size_t MaxHeight = 6; using TCustomSkipList = TSkipList; TMemoryPool ListPool_; TCustomSkipList SkipList_; TAdaptiveLock RegisterLock_; }; template class TThreadLocalValueImpl { private: struct TNode { TThread::TId Key = 0; T Value; TNode* Next = nullptr; }; public: TThreadLocalValueImpl() : Head_{nullptr} , Pool_{0} {} template T* Get(TThread::TId tid, ConstructArgs&& ...args) { TNode* head = Head_.load(std::memory_order_acquire); for (TNode* node = head; node; node = node->Next) { if (node->Key == tid) { return &node->Value; } } TNode* newNode = AllocateNode(tid, head, std::forward(args)...); while (!Head_.compare_exchange_weak(head, newNode, std::memory_order_release, std::memory_order_relaxed)) { newNode->Next = head; } return &newNode->Value; } template TNode* AllocateNode(TThread::TId tid, TNode* next, ConstructArgs&& ...args) { TNode* storage = nullptr; with_lock(PoolMutex_) { storage = Pool_.Allocate(); } new (storage) TNode{tid, T{std::forward(args)...}, next}; return storage; } ~TThreadLocalValueImpl() { if constexpr (!std::is_trivially_destructible_v) { TNode* next = nullptr; for (TNode* node = Head_.load(); node; node = next) { next = node->Next; node->~TNode(); } } } private: std::atomic Head_; TMemoryPool Pool_; TMutex PoolMutex_; }; } // namespace NDetail } // namespace NThreading