#pragma once
#include <util/generic/hash.h>
#include <util/generic/maybe.h>
#include <util/generic/ptr.h>
#include <util/generic/vector.h>
#include <util/memory/pool.h>
#include <util/system/mutex.h>
#include <util/system/thread.h>
#include <library/cpp/threading/hot_swap/hot_swap.h>
#include <library/cpp/threading/skip_list/skiplist.h>
#include <array>
#include <atomic>
#include <thread>
namespace NThreading {
// TThreadLocalValue
//
// Safe RAII-friendly thread local storage without dirty hacks from util/system/tls
//
// Example 1:
//
// THolder<IThreadPool> pool = CreateThreadPool(threads);
// TThreadLocalValue<ui32> tls;
// for (ui32 i : xrange(threads)) {
// pool->SafeAddFunc([&] {
// *tls.Get() = 1337;
// });
// }
//
// Example 2:
//
// class TNoisy {
// public:
// TNoisy(const char* name = "TNoisy")
// : Name_{name} {
// printf("%s::%s\n", Name_, Name_);
// }
//
// ~TNoisy() {
// printf("%s::~%s\n", Name_, Name_);
// }
// private:
// const char* Name_;
// };
//
// class TWrapper {
// public:
// TWrapper() {
// Println(__PRETTY_FUNCTION__);
// }
//
// ~TWrapper() {
// Println(__PRETTY_FUNCTION__);
// }
//
// void DoWork() {
// ThreadLocal_.Get();
// }
//
// private:
// TNoisy Noisy_{"TWrapper"};
// TThreadLocalValue<TNoisy> ThreadLocal_;
// };
//
// THolder<IThreadPool> pool = CreateThreadPool(3);
// {
// TWrapper wrapper;
// for (ui32 i : xrange(3)) {
// pool->SafeAddFunc([&] {
// wrapper.DoWork();
// });
// }
// }
//
// Will always print:
// TWrapper::TWrapper()
// TNoisy::TNoisy()
// TNoisy::TNoisy()
// TNoisy::TNoisy()
// TNoisy::~TNoisy()
// TNoisy::~TNoisy()
// TNoisy::~TNoisy()
// TWrapper::~TWrapper()
//
// Selects which NDetail::TThreadLocalValueImpl specialization backs each shard
// of a TThreadLocalValue.
enum class EThreadLocalImpl {
    // Shard keeps a THashMap snapshot inside a THotSwap; registering a new
    // thread copies the map under a TMutex and publishes the copy.
    HotSwap = 0,
    // Shard keeps a TSkipList allocated from a TMemoryPool; inserts are
    // guarded by a TAdaptiveLock.
    SkipList = 1,
    // Shard keeps a singly linked list with an atomic head; nodes are
    // allocated from a TMemoryPool.
    ForwardList = 2,
};
namespace NDetail {
// Per-shard storage backend of TThreadLocalValue. Only declared here; the
// definitions are partial specializations on the EThreadLocalImpl value.
template <typename T, EThreadLocalImpl Impl, size_t NumShards>
class TThreadLocalValueImpl;
} // namespace NDetail
// Default value of TThreadLocalValue's NumShards template parameter, i.e. the
// number of independent storages the thread ids are spread over.
inline constexpr size_t DefaultNumShards = 3;
// Container holding one lazily constructed T per calling thread.
//
// Storage is split into NumShards independent backends (selected by Impl);
// the shard for a call is picked as `current thread id % NumShards`.
template <typename T, EThreadLocalImpl Impl = EThreadLocalImpl::SkipList, size_t NumShards = DefaultNumShards>
class TThreadLocalValue : private TNonCopyable {
private:
    using TStorage = NDetail::TThreadLocalValueImpl<T, Impl, NumShards>;

public:
    // Returns a pointer to the calling thread's value. The arguments are
    // forwarded to the shard's Get(), which uses them to construct the value
    // when the thread has none yet.
    template <typename ...ConstructArgs>
    T* Get(ConstructArgs&& ...args) const {
        TThread::TId self = TThread::CurrentThreadId();
        TStorage& shard = Shards_[self % NumShards];
        return shard.Get(self, std::forward<ConstructArgs>(args)...);
    }

    // Same as Get(), but yields a reference instead of a pointer.
    template <typename ...ConstructArgs>
    T& GetRef(ConstructArgs&& ...args) const {
        T* const slot = Get(std::forward<ConstructArgs>(args)...);
        return *slot;
    }

private:
    // Mutable so that the const accessors above can register new threads.
    mutable std::array<TStorage, NumShards> Shards_;
};
namespace NDetail {
template <typename T, size_t NumShards>
class TThreadLocalValueImpl<T, EThreadLocalImpl::HotSwap, NumShards> {
private:
class TStorage: public THashMap<TThread::TId, TAtomicSharedPtr<T>>, public TAtomicRefCount<TStorage> {
};
public:
TThreadLocalValueImpl() {
Registered_.AtomicStore(new TStorage());
}
template <typename ...ConstructArgs>
T* Get(TThread::TId tid, ConstructArgs&& ...args) {
if (TIntrusivePtr<TStorage> state = Registered_.AtomicLoad(); TAtomicSharedPtr<T>* result = state->FindPtr(tid)) {
return result->Get();
} else {
TAtomicSharedPtr<T> value = MakeAtomicShared<T>(std::forward<ConstructArgs>(args)...);
with_lock(RegisterLock_) {
TIntrusivePtr<TStorage> oldState = Registered_.AtomicLoad();
THolder<TStorage> newState = MakeHolder<TStorage>(*oldState);
(*newState)[tid] = value;
Registered_.AtomicStore(newState.Release());
}
return value.Get();
}
}
private:
THotSwap<TStorage> Registered_;
TMutex RegisterLock_;
};
// Storage backend for EThreadLocalImpl::SkipList.
//
// The shard owns a TSkipList of (thread id, owned value) nodes ordered by
// thread id and allocated from ListPool_. Lookups call SeekTo() without
// taking RegisterLock_; inserts are performed under RegisterLock_.
template <typename T, size_t NumShards>
class TThreadLocalValueImpl<T, EThreadLocalImpl::SkipList, NumShards> {
private:
    struct TNode {
        TThread::TId Key;
        THolder<T> Value;
    };

    // Orders nodes by thread id only; the stored value takes no part.
    struct TCompare {
        int operator()(const TNode& lhs, const TNode& rhs) const {
            return ::NThreading::TCompare<TThread::TId>{}(lhs.Key, rhs.Key);
        }
    };

public:
    TThreadLocalValueImpl()
        : ListPool_{InitialPoolSize()}
        , SkipList_{ListPool_}
    {}

    // Returns the value stored for `tid`. If no node with that key is found,
    // a value is constructed from `args`, inserted, and looked up again.
    template <typename ...ConstructArgs>
    T* Get(TThread::TId tid, ConstructArgs&& ...args) {
        // Search key: only Key is compared, so Value stays empty.
        TNode probe{tid, {}};
        auto cursor = SkipList_.SeekTo(probe);
        const bool alreadyRegistered = cursor.IsValid() && cursor.GetValue().Key == probe.Key;
        if (!alreadyRegistered) {
            with_lock (RegisterLock_) {
                SkipList_.Insert({tid, MakeHolder<T>(std::forward<ConstructArgs>(args)...)});
            }
            cursor = SkipList_.SeekTo(probe);
        }
        return cursor.GetValue().Value.Get();
    }

private:
    // Initial pool size: hardware thread count times a per-entry byte
    // estimate (value + key + one pointer), divided among the shards.
    static size_t InitialPoolSize() {
        const size_t bytesPerEntry = sizeof(T) + sizeof(TThread::TId) + sizeof(void*);
        return std::thread::hardware_concurrency() * bytesPerEntry / NumShards;
    }

private:
    static inline constexpr size_t MaxHeight = 6;
    using TCustomSkipList = TSkipList<TNode, TCompare, TMemoryPool, TSizeCounter, MaxHeight>;

    // Declared before SkipList_ because SkipList_ is constructed from it.
    TMemoryPool ListPool_;
    TCustomSkipList SkipList_;
    TAdaptiveLock RegisterLock_;
};
// Storage backend for EThreadLocalImpl::ForwardList.
//
// The shard owns a singly linked list of (thread id, value) nodes. New nodes
// are pushed at the head with a compare-and-swap and are not unlinked until
// the storage is destroyed, so lookups walk the list without taking a lock.
// Only the raw allocation from Pool_ is serialized by PoolMutex_.
template <typename T, size_t NumShards>
class TThreadLocalValueImpl<T, EThreadLocalImpl::ForwardList, NumShards> {
private:
    struct TNode {
        TThread::TId Key = 0;
        T Value;
        TNode* Next = nullptr;
    };

public:
    TThreadLocalValueImpl()
        : Head_{nullptr}
        , Pool_{0}
    {}

    // Returns the value stored for `tid`. If the list has no node with that
    // key, a node is constructed from `args` and pushed at the head.
    template <typename ...ConstructArgs>
    T* Get(TThread::TId tid, ConstructArgs&& ...args) {
        // Acquire pairs with the release CAS below: nodes reachable from the
        // loaded head were constructed by other threads, and their fields
        // must be visible before they are dereferenced here.
        TNode* head = Head_.load(std::memory_order_acquire);
        for (TNode* node = head; node; node = node->Next) {
            if (node->Key == tid) {
                return &node->Value;
            }
        }

        // Link the new node in front of the head that was actually observed,
        // so the first CAS can succeed when nobody raced with us.
        TNode* newNode = AllocateNode(tid, head, std::forward<ConstructArgs>(args)...);
        // On failure `head` is refreshed with the current value; it is only
        // stored into Next (never dereferenced), so relaxed is enough there.
        while (!Head_.compare_exchange_weak(head, newNode, std::memory_order_release, std::memory_order_relaxed)) {
            newNode->Next = head;
        }
        return &newNode->Value;
    }

    // Takes raw storage for one node from Pool_ (under PoolMutex_) and
    // constructs the node in place, outside the lock, with `next` as its
    // successor. The returned node is not yet linked into the list.
    template <typename ...ConstructArgs>
    TNode* AllocateNode(TThread::TId tid, TNode* next, ConstructArgs&& ...args) {
        TNode* storage = nullptr;
        with_lock(PoolMutex_) {
            storage = Pool_.Allocate<TNode>();
        }
        new (storage) TNode{tid, T{std::forward<ConstructArgs>(args)...}, next};
        return storage;
    }

    // Nodes were placement-constructed in pool memory, so their destructors
    // are invoked explicitly here; skipped when T is trivially destructible.
    ~TThreadLocalValueImpl() {
        if constexpr (!std::is_trivially_destructible_v<T>) {
            TNode* next = nullptr;
            for (TNode* node = Head_.load(); node; node = next) {
                // Read the successor before the node is destroyed.
                next = node->Next;
                node->~TNode();
            }
        }
    }

private:
    std::atomic<TNode*> Head_;
    TMemoryPool Pool_;
    TMutex PoolMutex_;
};
} // namespace NDetail
} // namespace NThreading