author     Devtools Arcadia <arcadia-devtools@yandex-team.ru>           2022-02-07 18:08:42 +0300
committer  Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> 2022-02-07 18:08:42 +0300
commit     1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
tree       e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/thread
download   ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/thread')
-rw-r--r--   util/thread/factory.cpp        93
-rw-r--r--   util/thread/factory.h          65
-rw-r--r--   util/thread/factory_ut.cpp     57
-rw-r--r--   util/thread/fwd.cpp             1
-rw-r--r--   util/thread/fwd.h              30
-rw-r--r--   util/thread/lfqueue.cpp         1
-rw-r--r--   util/thread/lfqueue.h         406
-rw-r--r--   util/thread/lfqueue_ut.cpp    333
-rw-r--r--   util/thread/lfstack.cpp         1
-rw-r--r--   util/thread/lfstack.h         188
-rw-r--r--   util/thread/lfstack_ut.cpp    346
-rw-r--r--   util/thread/pool.cpp          772
-rw-r--r--   util/thread/pool.h            390
-rw-r--r--   util/thread/pool_ut.cpp       257
-rw-r--r--   util/thread/singleton.cpp       1
-rw-r--r--   util/thread/singleton.h        41
-rw-r--r--   util/thread/singleton_ut.cpp   21
-rw-r--r--   util/thread/ut/ya.make         18
-rw-r--r--   util/thread/ya.make             6
19 files changed, 3027 insertions, 0 deletions
diff --git a/util/thread/factory.cpp b/util/thread/factory.cpp
new file mode 100644
index 0000000000..48e898f32d
--- /dev/null
+++ b/util/thread/factory.cpp
@@ -0,0 +1,93 @@
+#include "factory.h"
+
+#include <util/system/thread.h>
+#include <util/generic/singleton.h>
+
+using IThread = IThreadFactory::IThread;
+
+namespace {
+ class TSystemThreadFactory: public IThreadFactory {
+ public:
+ class TPoolThread: public IThread {
+ public:
+ ~TPoolThread() override {
+ if (Thr_) {
+ Thr_->Detach();
+ }
+ }
+
+ void DoRun(IThreadAble* func) override {
+ Thr_.Reset(new TThread(ThreadProc, func));
+
+ Thr_->Start();
+ }
+
+ void DoJoin() noexcept override {
+ if (!Thr_) {
+ return;
+ }
+
+ Thr_->Join();
+ Thr_.Destroy();
+ }
+
+ private:
+ static void* ThreadProc(void* func) {
+ ((IThreadAble*)(func))->Execute();
+
+ return nullptr;
+ }
+
+ private:
+ THolder<TThread> Thr_;
+ };
+
+ inline TSystemThreadFactory() noexcept {
+ }
+
+ IThread* DoCreate() override {
+ return new TPoolThread;
+ }
+ };
+
+ class TThreadFactoryFuncObj: public IThreadFactory::IThreadAble {
+ public:
+ TThreadFactoryFuncObj(const std::function<void()>& func)
+ : Func(func)
+ {
+ }
+ void DoExecute() override {
+ THolder<TThreadFactoryFuncObj> self(this);
+ Func();
+ }
+
+ private:
+ std::function<void()> Func;
+ };
+}
+
+THolder<IThread> IThreadFactory::Run(std::function<void()> func) {
+ THolder<IThread> ret(DoCreate());
+
+ ret->Run(new ::TThreadFactoryFuncObj(func));
+
+ return ret;
+}
+
+static IThreadFactory* SystemThreadPoolImpl() {
+ return Singleton<TSystemThreadFactory>();
+}
+
+static IThreadFactory* systemPool = nullptr;
+
+IThreadFactory* SystemThreadFactory() {
+ if (systemPool) {
+ return systemPool;
+ }
+
+ return SystemThreadPoolImpl();
+}
+
+void SetSystemThreadFactory(IThreadFactory* pool) {
+ systemPool = pool;
+}
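
A minimal sketch of how the SetSystemThreadFactory hook above can be used (illustration only, not part of the commit; any IThreadFactory implementation works). Note that systemPool is a plain unsynchronized pointer, so the swap should happen during single-threaded startup:

#include <util/thread/factory.h>

void InstallCustomFactory(IThreadFactory* customFactory) {
    // Install before any worker threads are started: the hook is a
    // plain pointer with no synchronization.
    SetSystemThreadFactory(customFactory);
    // Subsequent SystemThreadFactory() calls now return customFactory.
}
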
diff --git a/util/thread/factory.h b/util/thread/factory.h
new file mode 100644
index 0000000000..561fcbac88
--- /dev/null
+++ b/util/thread/factory.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+#include <functional>
+
+class IThreadFactory {
+public:
+ class IThreadAble {
+ public:
+ inline IThreadAble() noexcept = default;
+
+ virtual ~IThreadAble() = default;
+
+ inline void Execute() {
+ DoExecute();
+ }
+
+ private:
+ virtual void DoExecute() = 0;
+ };
+
+ class IThread {
+ friend class IThreadFactory;
+
+ public:
+ inline IThread() noexcept = default;
+
+ virtual ~IThread() = default;
+
+ inline void Join() noexcept {
+ DoJoin();
+ }
+
+ private:
+ inline void Run(IThreadAble* func) {
+ DoRun(func);
+ }
+
+ private:
+ // it's actually DoStart
+ virtual void DoRun(IThreadAble* func) = 0;
+ virtual void DoJoin() noexcept = 0;
+ };
+
+ inline IThreadFactory() noexcept = default;
+
+ virtual ~IThreadFactory() = default;
+
+ // XXX: rename to Start
+ inline THolder<IThread> Run(IThreadAble* func) {
+ THolder<IThread> ret(DoCreate());
+
+ ret->Run(func);
+
+ return ret;
+ }
+
+ THolder<IThread> Run(std::function<void()> func);
+
+private:
+ virtual IThread* DoCreate() = 0;
+};
+
+IThreadFactory* SystemThreadFactory();
+void SetSystemThreadFactory(IThreadFactory* pool);
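
A minimal usage sketch for the factory interface above (illustration only, not part of the commit): run a lambda on a fresh system thread via the std::function overload of Run, then join it.

#include <util/thread/factory.h>
#include <util/stream/output.h>

int main() {
    // Run() wraps the lambda in a self-deleting IThreadAble and starts a thread.
    THolder<IThreadFactory::IThread> thread = SystemThreadFactory()->Run([]() {
        Cout << "running in a separate thread" << Endl;
    });

    thread->Join(); // blocks until the lambda returns
    return 0;
}
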
diff --git a/util/thread/factory_ut.cpp b/util/thread/factory_ut.cpp
new file mode 100644
index 0000000000..647d96c901
--- /dev/null
+++ b/util/thread/factory_ut.cpp
@@ -0,0 +1,57 @@
+#include "factory.h"
+#include "pool.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+class TThrPoolTest: public TTestBase {
+ UNIT_TEST_SUITE(TThrPoolTest);
+ UNIT_TEST(TestSystemPool)
+ UNIT_TEST(TestAdaptivePool)
+ UNIT_TEST_SUITE_END();
+
+ struct TRunAble: public IThreadFactory::IThreadAble {
+ inline TRunAble()
+ : done(false)
+ {
+ }
+
+ ~TRunAble() override = default;
+
+ void DoExecute() override {
+ done = true;
+ }
+
+ bool done;
+ };
+
+private:
+ inline void TestSystemPool() {
+ TRunAble r;
+
+ {
+ THolder<IThreadFactory::IThread> thr = SystemThreadFactory()->Run(&r);
+
+ thr->Join();
+ }
+
+ UNIT_ASSERT_EQUAL(r.done, true);
+ }
+
+ inline void TestAdaptivePool() {
+ TRunAble r;
+
+ {
+ TAdaptiveThreadPool pool;
+
+ pool.Start(0);
+
+ THolder<IThreadFactory::IThread> thr = pool.Run(&r);
+
+ thr->Join();
+ }
+
+ UNIT_ASSERT_EQUAL(r.done, true);
+ }
+};
+
+UNIT_TEST_SUITE_REGISTRATION(TThrPoolTest);
diff --git a/util/thread/fwd.cpp b/util/thread/fwd.cpp
new file mode 100644
index 0000000000..4214b6df83
--- /dev/null
+++ b/util/thread/fwd.cpp
@@ -0,0 +1 @@
+#include "fwd.h"
diff --git a/util/thread/fwd.h b/util/thread/fwd.h
new file mode 100644
index 0000000000..6f1caed21c
--- /dev/null
+++ b/util/thread/fwd.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <stlfwd>
+
+struct TDefaultLFCounter;
+
+template <class T, class TCounter = TDefaultLFCounter>
+class TLockFreeQueue;
+
+template <class T, class TCounter = TDefaultLFCounter>
+class TAutoLockFreeQueue;
+
+template <class T>
+class TLockFreeStack;
+
+class IThreadFactory;
+
+struct IObjectInQueue;
+class TThreadFactoryHolder;
+
+using TThreadFunction = std::function<void()>;
+
+class IThreadPool;
+class TFakeThreadPool;
+class TThreadPool;
+class TAdaptiveThreadPool;
+class TSimpleThreadPool;
+
+template <class TQueueType, class TSlave>
+class TThreadPoolBinder;
diff --git a/util/thread/lfqueue.cpp b/util/thread/lfqueue.cpp
new file mode 100644
index 0000000000..5861999b78
--- /dev/null
+++ b/util/thread/lfqueue.cpp
@@ -0,0 +1 @@
+#include "lfqueue.h"
diff --git a/util/thread/lfqueue.h b/util/thread/lfqueue.h
new file mode 100644
index 0000000000..ab523631e4
--- /dev/null
+++ b/util/thread/lfqueue.h
@@ -0,0 +1,406 @@
+#pragma once
+
+#include "fwd.h"
+
+#include <util/generic/ptr.h>
+#include <util/system/atomic.h>
+#include <util/system/yassert.h>
+#include "lfstack.h"
+
+struct TDefaultLFCounter {
+ template <class T>
+ void IncCount(const T& data) {
+ (void)data;
+ }
+ template <class T>
+ void DecCount(const T& data) {
+ (void)data;
+ }
+};
+
+// @brief lock-free queue
+// @tparam T - the queue element type; should be movable
+// @tparam TCounter - an observer class to count the number of items in the queue;
+// be careful: IncCount and DecCount can be called on a moved-from object, and
+// it is the TCounter class' responsibility to check the validity of the passed object
+template <class T, class TCounter>
+class TLockFreeQueue: public TNonCopyable {
+ struct TListNode {
+ template <typename U>
+ TListNode(U&& u, TListNode* next)
+ : Next(next)
+ , Data(std::forward<U>(u))
+ {
+ }
+
+ template <typename U>
+ explicit TListNode(U&& u)
+ : Data(std::forward<U>(u))
+ {
+ }
+
+ TListNode* volatile Next;
+ T Data;
+ };
+
+ // using inheritance to be able to use 0 bytes for TCounter when we don't need one
+ struct TRootNode: public TCounter {
+ TListNode* volatile PushQueue;
+ TListNode* volatile PopQueue;
+ TListNode* volatile ToDelete;
+ TRootNode* volatile NextFree;
+
+ TRootNode()
+ : PushQueue(nullptr)
+ , PopQueue(nullptr)
+ , ToDelete(nullptr)
+ , NextFree(nullptr)
+ {
+ }
+ void CopyCounter(TRootNode* x) {
+ *(TCounter*)this = *(TCounter*)x;
+ }
+ };
+
+ static void EraseList(TListNode* n) {
+ while (n) {
+ TListNode* keepNext = AtomicGet(n->Next);
+ delete n;
+ n = keepNext;
+ }
+ }
+
+ alignas(64) TRootNode* volatile JobQueue;
+ alignas(64) volatile TAtomic FreememCounter;
+ alignas(64) volatile TAtomic FreeingTaskCounter;
+ alignas(64) TRootNode* volatile FreePtr;
+
+ void TryToFreeAsyncMemory() {
+ TAtomic keepCounter = AtomicAdd(FreeingTaskCounter, 0);
+ TRootNode* current = AtomicGet(FreePtr);
+ if (current == nullptr)
+ return;
+ if (AtomicAdd(FreememCounter, 0) == 1) {
+ // we are the last thread, try to cleanup
+ // check if another thread has cleaned up
+ if (keepCounter != AtomicAdd(FreeingTaskCounter, 0)) {
+ return;
+ }
+ if (AtomicCas(&FreePtr, (TRootNode*)nullptr, current)) {
+ // free list
+ while (current) {
+ TRootNode* p = AtomicGet(current->NextFree);
+ EraseList(AtomicGet(current->ToDelete));
+ delete current;
+ current = p;
+ }
+ AtomicAdd(FreeingTaskCounter, 1);
+ }
+ }
+ }
+ void AsyncRef() {
+ AtomicAdd(FreememCounter, 1);
+ }
+ void AsyncUnref() {
+ TryToFreeAsyncMemory();
+ AtomicAdd(FreememCounter, -1);
+ }
+ void AsyncDel(TRootNode* toDelete, TListNode* lst) {
+ AtomicSet(toDelete->ToDelete, lst);
+ for (;;) {
+ AtomicSet(toDelete->NextFree, AtomicGet(FreePtr));
+ if (AtomicCas(&FreePtr, toDelete, AtomicGet(toDelete->NextFree)))
+ break;
+ }
+ }
+ void AsyncUnref(TRootNode* toDelete, TListNode* lst) {
+ TryToFreeAsyncMemory();
+ if (AtomicAdd(FreememCounter, -1) == 0) {
+ // no other operations in progress, can safely reclaim memory
+ EraseList(lst);
+ delete toDelete;
+ } else {
+ // Dequeue()s in progress, put node to free list
+ AsyncDel(toDelete, lst);
+ }
+ }
+
+ struct TListInvertor {
+ TListNode* Copy;
+ TListNode* Tail;
+ TListNode* PrevFirst;
+
+ TListInvertor()
+ : Copy(nullptr)
+ , Tail(nullptr)
+ , PrevFirst(nullptr)
+ {
+ }
+ ~TListInvertor() {
+ EraseList(Copy);
+ }
+ void CopyWasUsed() {
+ Copy = nullptr;
+ Tail = nullptr;
+ PrevFirst = nullptr;
+ }
+ void DoCopy(TListNode* ptr) {
+ TListNode* newFirst = ptr;
+ TListNode* newCopy = nullptr;
+ TListNode* newTail = nullptr;
+ while (ptr) {
+ if (ptr == PrevFirst) {
+ // shortcut: we have copied this part already
+ AtomicSet(Tail->Next, newCopy);
+ newCopy = Copy;
+ Copy = nullptr; // do not destroy prev try
+ if (!newTail)
+ newTail = Tail; // tried to invert same list
+ break;
+ }
+ TListNode* newElem = new TListNode(ptr->Data, newCopy);
+ newCopy = newElem;
+ ptr = AtomicGet(ptr->Next);
+ if (!newTail)
+ newTail = newElem;
+ }
+ EraseList(Copy); // copy was useless
+ Copy = newCopy;
+ PrevFirst = newFirst;
+ Tail = newTail;
+ }
+ };
+
+ void EnqueueImpl(TListNode* head, TListNode* tail) {
+ TRootNode* newRoot = new TRootNode;
+ AsyncRef();
+ AtomicSet(newRoot->PushQueue, head);
+ for (;;) {
+ TRootNode* curRoot = AtomicGet(JobQueue);
+ AtomicSet(tail->Next, AtomicGet(curRoot->PushQueue));
+ AtomicSet(newRoot->PopQueue, AtomicGet(curRoot->PopQueue));
+ newRoot->CopyCounter(curRoot);
+
+ for (TListNode* node = head;; node = AtomicGet(node->Next)) {
+ newRoot->IncCount(node->Data);
+ if (node == tail)
+ break;
+ }
+
+ if (AtomicCas(&JobQueue, newRoot, curRoot)) {
+ AsyncUnref(curRoot, nullptr);
+ break;
+ }
+ }
+ }
+
+ template <typename TCollection>
+ static void FillCollection(TListNode* lst, TCollection* res) {
+ while (lst) {
+ res->emplace_back(std::move(lst->Data));
+ lst = AtomicGet(lst->Next);
+ }
+ }
+
+ /** Traverses the given list while simultaneously creating its inverted copy.
+ * After that, fills the collection from the inverted copy and returns the last visited node of lst.
+ */
+ template <typename TCollection>
+ static TListNode* FillCollectionReverse(TListNode* lst, TCollection* res) {
+ if (!lst) {
+ return nullptr;
+ }
+
+ TListNode* newCopy = nullptr;
+ do {
+ TListNode* newElem = new TListNode(std::move(lst->Data), newCopy);
+ newCopy = newElem;
+ lst = AtomicGet(lst->Next);
+ } while (lst);
+
+ FillCollection(newCopy, res);
+ EraseList(newCopy);
+
+ return lst;
+ }
+
+public:
+ TLockFreeQueue()
+ : JobQueue(new TRootNode)
+ , FreememCounter(0)
+ , FreeingTaskCounter(0)
+ , FreePtr(nullptr)
+ {
+ }
+ ~TLockFreeQueue() {
+ AsyncRef();
+ AsyncUnref(); // should free the free list (the FreePtr chain)
+ EraseList(JobQueue->PushQueue);
+ EraseList(JobQueue->PopQueue);
+ delete JobQueue;
+ }
+ template <typename U>
+ void Enqueue(U&& data) {
+ TListNode* newNode = new TListNode(std::forward<U>(data));
+ EnqueueImpl(newNode, newNode);
+ }
+ void Enqueue(T&& data) {
+ TListNode* newNode = new TListNode(std::move(data));
+ EnqueueImpl(newNode, newNode);
+ }
+ void Enqueue(const T& data) {
+ TListNode* newNode = new TListNode(data);
+ EnqueueImpl(newNode, newNode);
+ }
+ template <typename TCollection>
+ void EnqueueAll(const TCollection& data) {
+ EnqueueAll(data.begin(), data.end());
+ }
+ template <typename TIter>
+ void EnqueueAll(TIter dataBegin, TIter dataEnd) {
+ if (dataBegin == dataEnd)
+ return;
+
+ TIter i = dataBegin;
+ TListNode* volatile node = new TListNode(*i);
+ TListNode* volatile tail = node;
+
+ for (++i; i != dataEnd; ++i) {
+ TListNode* nextNode = node;
+ node = new TListNode(*i, nextNode);
+ }
+ EnqueueImpl(node, tail);
+ }
+ bool Dequeue(T* data) {
+ TRootNode* newRoot = nullptr;
+ TListInvertor listInvertor;
+ AsyncRef();
+ for (;;) {
+ TRootNode* curRoot = AtomicGet(JobQueue);
+ TListNode* tail = AtomicGet(curRoot->PopQueue);
+ if (tail) {
+ // has elems to pop
+ if (!newRoot)
+ newRoot = new TRootNode;
+
+ AtomicSet(newRoot->PushQueue, AtomicGet(curRoot->PushQueue));
+ AtomicSet(newRoot->PopQueue, AtomicGet(tail->Next));
+ newRoot->CopyCounter(curRoot);
+ newRoot->DecCount(tail->Data);
+ Y_ASSERT(AtomicGet(curRoot->PopQueue) == tail);
+ if (AtomicCas(&JobQueue, newRoot, curRoot)) {
+ *data = std::move(tail->Data);
+ AtomicSet(tail->Next, nullptr);
+ AsyncUnref(curRoot, tail);
+ return true;
+ }
+ continue;
+ }
+ if (AtomicGet(curRoot->PushQueue) == nullptr) {
+ delete newRoot;
+ AsyncUnref();
+ return false; // no elems to pop
+ }
+
+ if (!newRoot)
+ newRoot = new TRootNode;
+ AtomicSet(newRoot->PushQueue, nullptr);
+ listInvertor.DoCopy(AtomicGet(curRoot->PushQueue));
+ AtomicSet(newRoot->PopQueue, listInvertor.Copy);
+ newRoot->CopyCounter(curRoot);
+ Y_ASSERT(AtomicGet(curRoot->PopQueue) == nullptr);
+ if (AtomicCas(&JobQueue, newRoot, curRoot)) {
+ newRoot = nullptr;
+ listInvertor.CopyWasUsed();
+ AsyncDel(curRoot, AtomicGet(curRoot->PushQueue));
+ } else {
+ AtomicSet(newRoot->PopQueue, nullptr);
+ }
+ }
+ }
+ template <typename TCollection>
+ void DequeueAll(TCollection* res) {
+ AsyncRef();
+
+ TRootNode* newRoot = new TRootNode;
+ TRootNode* curRoot;
+ do {
+ curRoot = AtomicGet(JobQueue);
+ } while (!AtomicCas(&JobQueue, newRoot, curRoot));
+
+ FillCollection(curRoot->PopQueue, res);
+
+ TListNode* toDeleteHead = curRoot->PushQueue;
+ TListNode* toDeleteTail = FillCollectionReverse(curRoot->PushQueue, res);
+ AtomicSet(curRoot->PushQueue, nullptr);
+
+ if (toDeleteTail) {
+ toDeleteTail->Next = curRoot->PopQueue;
+ } else {
+ toDeleteTail = curRoot->PopQueue;
+ }
+ AtomicSet(curRoot->PopQueue, nullptr);
+
+ AsyncUnref(curRoot, toDeleteHead);
+ }
+ bool IsEmpty() {
+ AsyncRef();
+ TRootNode* curRoot = AtomicGet(JobQueue);
+ bool res = AtomicGet(curRoot->PushQueue) == nullptr && AtomicGet(curRoot->PopQueue) == nullptr;
+ AsyncUnref();
+ return res;
+ }
+ TCounter GetCounter() {
+ AsyncRef();
+ TRootNode* curRoot = AtomicGet(JobQueue);
+ TCounter res = *(TCounter*)curRoot;
+ AsyncUnref();
+ return res;
+ }
+};
+
+template <class T, class TCounter>
+class TAutoLockFreeQueue {
+public:
+ using TRef = THolder<T>;
+
+ inline ~TAutoLockFreeQueue() {
+ TRef tmp;
+
+ while (Dequeue(&tmp)) {
+ }
+ }
+
+ inline bool Dequeue(TRef* t) {
+ T* res = nullptr;
+
+ if (Queue.Dequeue(&res)) {
+ t->Reset(res);
+
+ return true;
+ }
+
+ return false;
+ }
+
+ inline void Enqueue(TRef& t) {
+ Queue.Enqueue(t.Get());
+ Y_UNUSED(t.Release());
+ }
+
+ inline void Enqueue(TRef&& t) {
+ Queue.Enqueue(t.Get());
+ Y_UNUSED(t.Release());
+ }
+
+ inline bool IsEmpty() {
+ return Queue.IsEmpty();
+ }
+
+ inline TCounter GetCounter() {
+ return Queue.GetCounter();
+ }
+
+private:
+ TLockFreeQueue<T*, TCounter> Queue;
+};
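
A minimal sketch of a custom TCounter for the queue above (illustration only, not part of the commit). IncCount/DecCount run on a thread-private copy of the root node before it is published by CAS, so plain arithmetic suffices; per the warning in the header, the element may be moved-from, so the counter never inspects it.

#include <util/thread/lfqueue.h>
#include <util/system/defaults.h>

struct TSizeCounter {
    i64 Size = 0;

    template <class T>
    void IncCount(const T& /*data*/) {
        ++Size; // called on a private root-node copy before the CAS publish
    }
    template <class T>
    void DecCount(const T& /*data*/) {
        --Size;
    }
};

// Usage sketch:
//   TLockFreeQueue<int, TSizeCounter> queue;
//   queue.Enqueue(1);
//   queue.Enqueue(2);
//   i64 inQueue = queue.GetCounter().Size; // == 2
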
diff --git a/util/thread/lfqueue_ut.cpp b/util/thread/lfqueue_ut.cpp
new file mode 100644
index 0000000000..83bca100cf
--- /dev/null
+++ b/util/thread/lfqueue_ut.cpp
@@ -0,0 +1,333 @@
+#include <library/cpp/threading/future/future.h>
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/generic/algorithm.h>
+#include <util/generic/vector.h>
+#include <util/generic/ptr.h>
+#include <util/system/atomic.h>
+#include <util/thread/pool.h>
+
+#include "lfqueue.h"
+
+class TMoveTest {
+public:
+ TMoveTest(int marker = 0, int value = 0)
+ : Marker_(marker)
+ , Value_(value)
+ {
+ }
+
+ TMoveTest(const TMoveTest& other) {
+ *this = other;
+ }
+
+ TMoveTest(TMoveTest&& other) {
+ *this = std::move(other);
+ }
+
+ TMoveTest& operator=(const TMoveTest& other) {
+ Value_ = other.Value_;
+ Marker_ = other.Marker_ + 1024;
+ return *this;
+ }
+
+ TMoveTest& operator=(TMoveTest&& other) {
+ Value_ = other.Value_;
+ Marker_ = other.Marker_;
+ other.Marker_ = 0;
+ return *this;
+ }
+
+ int Marker() const {
+ return Marker_;
+ }
+
+ int Value() const {
+ return Value_;
+ }
+
+private:
+ int Marker_ = 0;
+ int Value_ = 0;
+};
+
+class TOperationsChecker {
+public:
+ TOperationsChecker() {
+ ++DefaultCtor_;
+ }
+
+ TOperationsChecker(TOperationsChecker&&) {
+ ++MoveCtor_;
+ }
+
+ TOperationsChecker(const TOperationsChecker&) {
+ ++CopyCtor_;
+ }
+
+ TOperationsChecker& operator=(TOperationsChecker&&) {
+ ++MoveAssign_;
+ return *this;
+ }
+
+ TOperationsChecker& operator=(const TOperationsChecker&) {
+ ++CopyAssign_;
+ return *this;
+ }
+
+ static void Check(int defaultCtor, int moveCtor, int copyCtor, int moveAssign, int copyAssign) {
+ UNIT_ASSERT_VALUES_EQUAL(defaultCtor, DefaultCtor_);
+ UNIT_ASSERT_VALUES_EQUAL(moveCtor, MoveCtor_);
+ UNIT_ASSERT_VALUES_EQUAL(copyCtor, CopyCtor_);
+ UNIT_ASSERT_VALUES_EQUAL(moveAssign, MoveAssign_);
+ UNIT_ASSERT_VALUES_EQUAL(copyAssign, CopyAssign_);
+ Clear();
+ }
+
+private:
+ static void Clear() {
+ DefaultCtor_ = MoveCtor_ = CopyCtor_ = MoveAssign_ = CopyAssign_ = 0;
+ }
+
+ static int DefaultCtor_;
+ static int MoveCtor_;
+ static int CopyCtor_;
+ static int MoveAssign_;
+ static int CopyAssign_;
+};
+
+int TOperationsChecker::DefaultCtor_ = 0;
+int TOperationsChecker::MoveCtor_ = 0;
+int TOperationsChecker::CopyCtor_ = 0;
+int TOperationsChecker::MoveAssign_ = 0;
+int TOperationsChecker::CopyAssign_ = 0;
+
+Y_UNIT_TEST_SUITE(TLockFreeQueueTests) {
+ Y_UNIT_TEST(TestMoveEnqueue) {
+ TMoveTest value(0xFF, 0xAA);
+ TMoveTest tmp;
+
+ TLockFreeQueue<TMoveTest> queue;
+
+ queue.Enqueue(value);
+ UNIT_ASSERT_VALUES_EQUAL(value.Marker(), 0xFF);
+ UNIT_ASSERT(queue.Dequeue(&tmp));
+ UNIT_ASSERT_VALUES_UNEQUAL(tmp.Marker(), 0xFF);
+ UNIT_ASSERT_VALUES_EQUAL(tmp.Value(), 0xAA);
+
+ queue.Enqueue(std::move(value));
+ UNIT_ASSERT_VALUES_EQUAL(value.Marker(), 0);
+ UNIT_ASSERT(queue.Dequeue(&tmp));
+ UNIT_ASSERT_VALUES_EQUAL(tmp.Value(), 0xAA);
+ }
+
+ Y_UNIT_TEST(TestSimpleEnqueueDequeue) {
+ TLockFreeQueue<int> queue;
+
+ int i = -1;
+
+ UNIT_ASSERT(!queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(i, -1);
+
+ queue.Enqueue(10);
+ queue.Enqueue(11);
+ queue.Enqueue(12);
+
+ UNIT_ASSERT(queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(10, i);
+ UNIT_ASSERT(queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(11, i);
+
+ queue.Enqueue(13);
+
+ UNIT_ASSERT(queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(12, i);
+ UNIT_ASSERT(queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(13, i);
+
+ UNIT_ASSERT(!queue.Dequeue(&i));
+
+ const int tmp = 100;
+ queue.Enqueue(tmp);
+ UNIT_ASSERT(queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(i, tmp);
+ }
+
+ Y_UNIT_TEST(TestSimpleEnqueueAllDequeue) {
+ TLockFreeQueue<int> queue;
+
+ int i = -1;
+
+ UNIT_ASSERT(!queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(i, -1);
+
+ TVector<int> v;
+ v.push_back(20);
+ v.push_back(21);
+
+ queue.EnqueueAll(v);
+
+ v.clear();
+ v.push_back(22);
+ v.push_back(23);
+ v.push_back(24);
+
+ queue.EnqueueAll(v);
+
+ v.clear();
+ queue.EnqueueAll(v);
+
+ v.clear();
+ v.push_back(25);
+
+ queue.EnqueueAll(v);
+
+ for (int j = 20; j <= 25; ++j) {
+ UNIT_ASSERT(queue.Dequeue(&i));
+ UNIT_ASSERT_VALUES_EQUAL(j, i);
+ }
+
+ UNIT_ASSERT(!queue.Dequeue(&i));
+ }
+
+ void DequeueAllRunner(TLockFreeQueue<int>& queue, bool singleConsumer) {
+ size_t threadsNum = 4;
+ size_t enqueuesPerThread = 10'000;
+ TThreadPool p;
+ p.Start(threadsNum, 0);
+
+ TVector<NThreading::TFuture<void>> futures;
+
+ for (size_t i = 0; i < threadsNum; ++i) {
+ NThreading::TPromise<void> promise = NThreading::NewPromise();
+ futures.emplace_back(promise.GetFuture());
+
+ p.SafeAddFunc([enqueuesPerThread, &queue, promise]() mutable {
+ for (size_t i = 0; i != enqueuesPerThread; ++i) {
+ queue.Enqueue(i);
+ }
+
+ promise.SetValue();
+ });
+ }
+
+ TAtomic elementsLeft;
+ AtomicSet(elementsLeft, threadsNum * enqueuesPerThread);
+
+ ui64 numOfConsumers = singleConsumer ? 1 : threadsNum;
+
+ TVector<TVector<int>> dataBuckets(numOfConsumers);
+
+ for (size_t i = 0; i < numOfConsumers; ++i) {
+ NThreading::TPromise<void> promise = NThreading::NewPromise();
+ futures.emplace_back(promise.GetFuture());
+
+ p.SafeAddFunc([&queue, &elementsLeft, promise, consumerData{&dataBuckets[i]}]() mutable {
+ TVector<int> vec;
+ while (static_cast<i64>(AtomicGet(elementsLeft)) > 0) {
+ for (size_t i = 0; i != 100; ++i) {
+ vec.clear();
+ queue.DequeueAll(&vec);
+
+ AtomicSub(elementsLeft, vec.size());
+ consumerData->insert(consumerData->end(), vec.begin(), vec.end());
+ }
+ }
+
+ promise.SetValue();
+ });
+ }
+
+ NThreading::WaitExceptionOrAll(futures).GetValueSync();
+ p.Stop();
+
+ TVector<int> left;
+ queue.DequeueAll(&left);
+
+ UNIT_ASSERT(left.empty());
+
+ TVector<int> data;
+ for (auto& dataBucket : dataBuckets) {
+ data.insert(data.end(), dataBucket.begin(), dataBucket.end());
+ }
+
+ UNIT_ASSERT_EQUAL(data.size(), threadsNum * enqueuesPerThread);
+
+ size_t threadIdx = 0;
+ size_t cntValue = 0;
+
+ Sort(data.begin(), data.end());
+ for (size_t i = 0; i != data.size(); ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(cntValue, data[i]);
+ ++threadIdx;
+
+ if (threadIdx == threadsNum) {
+ ++cntValue;
+ threadIdx = 0;
+ }
+ }
+ }
+
+ Y_UNIT_TEST(TestDequeueAllSingleConsumer) {
+ TLockFreeQueue<int> queue;
+ DequeueAllRunner(queue, true);
+ }
+
+ Y_UNIT_TEST(TestDequeueAllMultipleConsumers) {
+ TLockFreeQueue<int> queue;
+ DequeueAllRunner(queue, false);
+ }
+
+ Y_UNIT_TEST(TestDequeueAllEmptyQueue) {
+ TLockFreeQueue<int> queue;
+ TVector<int> vec;
+
+ queue.DequeueAll(&vec);
+
+ UNIT_ASSERT(vec.empty());
+ }
+
+ Y_UNIT_TEST(TestDequeueAllQueueOrder) {
+ TLockFreeQueue<int> queue;
+ queue.Enqueue(1);
+ queue.Enqueue(2);
+ queue.Enqueue(3);
+
+ TVector<int> v;
+ queue.DequeueAll(&v);
+
+ UNIT_ASSERT_VALUES_EQUAL(v.size(), 3);
+ UNIT_ASSERT_VALUES_EQUAL(v[0], 1);
+ UNIT_ASSERT_VALUES_EQUAL(v[1], 2);
+ UNIT_ASSERT_VALUES_EQUAL(v[2], 3);
+ }
+
+ Y_UNIT_TEST(CleanInDestructor) {
+ TSimpleSharedPtr<bool> p(new bool);
+ UNIT_ASSERT_VALUES_EQUAL(1u, p.RefCount());
+
+ {
+ TLockFreeQueue<TSimpleSharedPtr<bool>> stack;
+
+ stack.Enqueue(p);
+ stack.Enqueue(p);
+
+ UNIT_ASSERT_VALUES_EQUAL(3u, p.RefCount());
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(1, p.RefCount());
+ }
+
+ Y_UNIT_TEST(CheckOperationsCount) {
+ TOperationsChecker o;
+ o.Check(1, 0, 0, 0, 0);
+ TLockFreeQueue<TOperationsChecker> queue;
+ o.Check(0, 0, 0, 0, 0);
+ queue.Enqueue(std::move(o));
+ o.Check(0, 1, 0, 0, 0);
+ queue.Enqueue(o);
+ o.Check(0, 0, 1, 0, 0);
+ queue.Dequeue(&o);
+ o.Check(0, 0, 2, 1, 0);
+ }
+}
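
For the owning wrapper TAutoLockFreeQueue declared in lfqueue.h above, a minimal sketch (illustration only, not part of the commit): the queue stores raw pointers internally but exchanges ownership through THolder, and deletes anything still queued in its destructor.

#include <util/thread/lfqueue.h>
#include <util/generic/ptr.h>

int main() {
    TAutoLockFreeQueue<int> queue;
    queue.Enqueue(MakeHolder<int>(42)); // queue takes ownership

    THolder<int> item;
    if (queue.Dequeue(&item)) {
        // *item == 42; ownership is back in the caller's hands
    }
    return 0; // leftover elements, if any, would be freed here
}
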
diff --git a/util/thread/lfstack.cpp b/util/thread/lfstack.cpp
new file mode 100644
index 0000000000..be8b3bdf37
--- /dev/null
+++ b/util/thread/lfstack.cpp
@@ -0,0 +1 @@
+#include "lfstack.h"
diff --git a/util/thread/lfstack.h b/util/thread/lfstack.h
new file mode 100644
index 0000000000..ca3d95f3c3
--- /dev/null
+++ b/util/thread/lfstack.h
@@ -0,0 +1,188 @@
+#pragma once
+
+#include <util/generic/noncopyable.h>
+#include <util/system/atomic.h>
+
+//////////////////////////////
+// lock-free LIFO stack
+template <class T>
+class TLockFreeStack: TNonCopyable {
+ struct TNode {
+ T Value;
+ TNode* Next;
+
+ TNode() = default;
+
+ template <class U>
+ explicit TNode(U&& val)
+ : Value(std::forward<U>(val))
+ , Next(nullptr)
+ {
+ }
+ };
+
+ TNode* Head;
+ TNode* FreePtr;
+ TAtomic DequeueCount;
+
+ void TryToFreeMemory() {
+ TNode* current = AtomicGet(FreePtr);
+ if (!current)
+ return;
+ if (AtomicAdd(DequeueCount, 0) == 1) {
+ // node current is in the free list; we are the last thread, so try to clean up
+ if (AtomicCas(&FreePtr, (TNode*)nullptr, current))
+ EraseList(current);
+ }
+ }
+ void EraseList(TNode* volatile p) {
+ while (p) {
+ TNode* next = p->Next;
+ delete p;
+ p = next;
+ }
+ }
+ void EnqueueImpl(TNode* volatile head, TNode* volatile tail) {
+ for (;;) {
+ tail->Next = AtomicGet(Head);
+ if (AtomicCas(&Head, head, tail->Next))
+ break;
+ }
+ }
+ template <class U>
+ void EnqueueImpl(U&& u) {
+ TNode* volatile node = new TNode(std::forward<U>(u));
+ EnqueueImpl(node, node);
+ }
+
+public:
+ TLockFreeStack()
+ : Head(nullptr)
+ , FreePtr(nullptr)
+ , DequeueCount(0)
+ {
+ }
+ ~TLockFreeStack() {
+ EraseList(Head);
+ EraseList(FreePtr);
+ }
+
+ void Enqueue(const T& t) {
+ EnqueueImpl(t);
+ }
+
+ void Enqueue(T&& t) {
+ EnqueueImpl(std::move(t));
+ }
+
+ template <typename TCollection>
+ void EnqueueAll(const TCollection& data) {
+ EnqueueAll(data.begin(), data.end());
+ }
+ template <typename TIter>
+ void EnqueueAll(TIter dataBegin, TIter dataEnd) {
+ if (dataBegin == dataEnd) {
+ return;
+ }
+ TIter i = dataBegin;
+ TNode* volatile node = new TNode(*i);
+ TNode* volatile tail = node;
+
+ for (++i; i != dataEnd; ++i) {
+ TNode* nextNode = node;
+ node = new TNode(*i);
+ node->Next = nextNode;
+ }
+ EnqueueImpl(node, tail);
+ }
+ bool Dequeue(T* res) {
+ AtomicAdd(DequeueCount, 1);
+ for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) {
+ if (AtomicCas(&Head, AtomicGet(current->Next), current)) {
+ *res = std::move(current->Value);
+ // delete current; // ABA problem
+ // even more complex node deletion
+ TryToFreeMemory();
+ if (AtomicAdd(DequeueCount, -1) == 0) {
+ // no other Dequeue()s, can safely reclaim memory
+ delete current;
+ } else {
+ // Dequeue()s in progress, put node to free list
+ for (;;) {
+ AtomicSet(current->Next, AtomicGet(FreePtr));
+ if (AtomicCas(&FreePtr, current, current->Next))
+ break;
+ }
+ }
+ return true;
+ }
+ }
+ TryToFreeMemory();
+ AtomicAdd(DequeueCount, -1);
+ return false;
+ }
+ // add all elements to *res
+ // elements are returned in order of dequeue (top to bottom; see example in unittest)
+ template <typename TCollection>
+ void DequeueAll(TCollection* res) {
+ AtomicAdd(DequeueCount, 1);
+ for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) {
+ if (AtomicCas(&Head, (TNode*)nullptr, current)) {
+ for (TNode* x = current; x;) {
+ res->push_back(std::move(x->Value));
+ x = x->Next;
+ }
+ // EraseList(current); // ABA problem
+ // even more complex node deletion
+ TryToFreeMemory();
+ if (AtomicAdd(DequeueCount, -1) == 0) {
+ // no other Dequeue()s, can safely reclaim memory
+ EraseList(current);
+ } else {
+ // Dequeue()s in progress, add nodes list to free list
+ TNode* currentLast = current;
+ while (currentLast->Next) {
+ currentLast = currentLast->Next;
+ }
+ for (;;) {
+ AtomicSet(currentLast->Next, AtomicGet(FreePtr));
+ if (AtomicCas(&FreePtr, current, currentLast->Next))
+ break;
+ }
+ }
+ return;
+ }
+ }
+ TryToFreeMemory();
+ AtomicAdd(DequeueCount, -1);
+ }
+ bool DequeueSingleConsumer(T* res) {
+ for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) {
+ if (AtomicCas(&Head, current->Next, current)) {
+ *res = std::move(current->Value);
+ delete current; // with single consumer thread ABA does not happen
+ return true;
+ }
+ }
+ return false;
+ }
+ // add all elements to *res
+ // elements are returned in order of dequeue (top to bottom; see example in unittest)
+ template <typename TCollection>
+ void DequeueAllSingleConsumer(TCollection* res) {
+ for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) {
+ if (AtomicCas(&Head, (TNode*)nullptr, current)) {
+ for (TNode* x = current; x;) {
+ res->push_back(std::move(x->Value));
+ x = x->Next;
+ }
+ EraseList(current); // with single consumer thread ABA does not happen
+ return;
+ }
+ }
+ }
+ bool IsEmpty() {
+ AtomicAdd(DequeueCount, 0); // mem barrier
+ return AtomicGet(Head) == nullptr; // without lock, so result is approximate
+ }
+};
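
A minimal usage sketch for the stack above (illustration only, not part of the commit): since the container is LIFO, DequeueAll yields elements in reverse order of insertion, as the unit test below also demonstrates.

#include <util/thread/lfstack.h>
#include <util/generic/vector.h>

int main() {
    TLockFreeStack<int> stack;
    stack.Enqueue(1);
    stack.Enqueue(2);
    stack.Enqueue(3);

    TVector<int> items;
    stack.DequeueAll(&items); // items == {3, 2, 1}

    int top = 0;
    bool popped = stack.Dequeue(&top); // false: the stack is empty now
    (void)popped;
    return 0;
}
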
diff --git a/util/thread/lfstack_ut.cpp b/util/thread/lfstack_ut.cpp
new file mode 100644
index 0000000000..e20a838f95
--- /dev/null
+++ b/util/thread/lfstack_ut.cpp
@@ -0,0 +1,346 @@
+
+#include <util/system/atomic.h>
+#include <util/system/event.h>
+#include <util/generic/deque.h>
+#include <library/cpp/threading/future/legacy_future.h>
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "lfstack.h"
+
+Y_UNIT_TEST_SUITE(TLockFreeStackTests) {
+ class TCountDownLatch {
+ private:
+ TAtomic Current_;
+ TSystemEvent EventObject_;
+
+ public:
+ TCountDownLatch(unsigned initial)
+ : Current_(initial)
+ {
+ }
+
+ void CountDown() {
+ if (AtomicDecrement(Current_) == 0) {
+ EventObject_.Signal();
+ }
+ }
+
+ void Await() {
+ EventObject_.Wait();
+ }
+
+ bool Await(TDuration timeout) {
+ return EventObject_.WaitT(timeout);
+ }
+ };
+
+ template <bool SingleConsumer>
+ struct TDequeueAllTester {
+ size_t EnqueueThreads;
+ size_t DequeueThreads;
+
+ size_t EnqueuesPerThread;
+ TAtomic LeftToDequeue;
+
+ TCountDownLatch StartLatch;
+ TLockFreeStack<int> Stack;
+
+ TDequeueAllTester()
+ : EnqueueThreads(4)
+ , DequeueThreads(SingleConsumer ? 1 : 3)
+ , EnqueuesPerThread(100000)
+ , LeftToDequeue(EnqueueThreads * EnqueuesPerThread)
+ , StartLatch(EnqueueThreads + DequeueThreads)
+ {
+ }
+
+ void Enqueuer() {
+ StartLatch.CountDown();
+ StartLatch.Await();
+
+ for (size_t i = 0; i < EnqueuesPerThread; ++i) {
+ Stack.Enqueue(i);
+ }
+ }
+
+ void DequeuerAll() {
+ StartLatch.CountDown();
+ StartLatch.Await();
+
+ TVector<int> temp;
+ while (AtomicGet(LeftToDequeue) > 0) {
+ size_t dequeued = 0;
+ for (size_t i = 0; i < 100; ++i) {
+ temp.clear();
+ if (SingleConsumer) {
+ Stack.DequeueAllSingleConsumer(&temp);
+ } else {
+ Stack.DequeueAll(&temp);
+ }
+ dequeued += temp.size();
+ }
+ AtomicAdd(LeftToDequeue, -dequeued);
+ }
+ }
+
+ void Run() {
+ TVector<TSimpleSharedPtr<NThreading::TLegacyFuture<>>> futures;
+
+ for (size_t i = 0; i < EnqueueThreads; ++i) {
+ futures.push_back(new NThreading::TLegacyFuture<>(std::bind(&TDequeueAllTester<SingleConsumer>::Enqueuer, this)));
+ }
+
+ for (size_t i = 0; i < DequeueThreads; ++i) {
+ futures.push_back(new NThreading::TLegacyFuture<>(std::bind(&TDequeueAllTester<SingleConsumer>::DequeuerAll, this)));
+ }
+
+ // effectively join
+ futures.clear();
+
+ UNIT_ASSERT_VALUES_EQUAL(0, int(AtomicGet(LeftToDequeue)));
+
+ TVector<int> left;
+ Stack.DequeueAll(&left);
+ UNIT_ASSERT(left.empty());
+ }
+ };
+
+ Y_UNIT_TEST(TestDequeueAll) {
+ TDequeueAllTester<false>().Run();
+ }
+
+ Y_UNIT_TEST(TestDequeueAllSingleConsumer) {
+ TDequeueAllTester<true>().Run();
+ }
+
+ Y_UNIT_TEST(TestDequeueAllEmptyStack) {
+ TLockFreeStack<int> stack;
+
+ TVector<int> r;
+ stack.DequeueAll(&r);
+
+ UNIT_ASSERT(r.empty());
+ }
+
+ Y_UNIT_TEST(TestDequeueAllReturnsInReverseOrder) {
+ TLockFreeStack<int> stack;
+
+ stack.Enqueue(17);
+ stack.Enqueue(19);
+ stack.Enqueue(23);
+
+ TVector<int> r;
+
+ stack.DequeueAll(&r);
+
+ UNIT_ASSERT_VALUES_EQUAL(size_t(3), r.size());
+ UNIT_ASSERT_VALUES_EQUAL(23, r.at(0));
+ UNIT_ASSERT_VALUES_EQUAL(19, r.at(1));
+ UNIT_ASSERT_VALUES_EQUAL(17, r.at(2));
+ }
+
+ Y_UNIT_TEST(TestEnqueueAll) {
+ TLockFreeStack<int> stack;
+
+ TVector<int> v;
+ TVector<int> expected;
+
+ stack.EnqueueAll(v); // add empty
+
+ v.push_back(2);
+ v.push_back(3);
+ v.push_back(5);
+ expected.insert(expected.end(), v.begin(), v.end());
+ stack.EnqueueAll(v);
+
+ v.clear();
+
+ stack.EnqueueAll(v); // add empty
+
+ v.push_back(7);
+ v.push_back(11);
+ v.push_back(13);
+ v.push_back(17);
+ expected.insert(expected.end(), v.begin(), v.end());
+ stack.EnqueueAll(v);
+
+ TVector<int> actual;
+ stack.DequeueAll(&actual);
+
+ UNIT_ASSERT_VALUES_EQUAL(expected.size(), actual.size());
+ for (size_t i = 0; i < actual.size(); ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(expected.at(expected.size() - i - 1), actual.at(i));
+ }
+ }
+
+ Y_UNIT_TEST(CleanInDestructor) {
+ TSimpleSharedPtr<bool> p(new bool);
+ UNIT_ASSERT_VALUES_EQUAL(1u, p.RefCount());
+
+ {
+ TLockFreeStack<TSimpleSharedPtr<bool>> stack;
+
+ stack.Enqueue(p);
+ stack.Enqueue(p);
+
+ UNIT_ASSERT_VALUES_EQUAL(3u, p.RefCount());
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(1, p.RefCount());
+ }
+
+ Y_UNIT_TEST(NoCopyTest) {
+ static unsigned copied = 0;
+ struct TCopyCount {
+ TCopyCount(int) {
+ }
+ TCopyCount(const TCopyCount&) {
+ ++copied;
+ }
+
+ TCopyCount(TCopyCount&&) {
+ }
+
+ TCopyCount& operator=(const TCopyCount&) {
+ ++copied;
+ return *this;
+ }
+
+ TCopyCount& operator=(TCopyCount&&) {
+ return *this;
+ }
+ };
+
+ TLockFreeStack<TCopyCount> stack;
+ stack.Enqueue(TCopyCount(1));
+ TCopyCount val(0);
+ stack.Dequeue(&val);
+ UNIT_ASSERT_VALUES_EQUAL(0, copied);
+ }
+
+ Y_UNIT_TEST(MoveOnlyTest) {
+ TLockFreeStack<THolder<bool>> stack;
+ stack.Enqueue(MakeHolder<bool>(true));
+ THolder<bool> val;
+ stack.Dequeue(&val);
+ UNIT_ASSERT(val);
+ UNIT_ASSERT_VALUES_EQUAL(true, *val);
+ }
+
+ template <class TTest>
+ struct TMultiThreadTester {
+ using ThisType = TMultiThreadTester<TTest>;
+
+ size_t Threads;
+ size_t OperationsPerThread;
+
+ TCountDownLatch StartLatch;
+ TLockFreeStack<typename TTest::ValueType> Stack;
+
+ TMultiThreadTester()
+ : Threads(10)
+ , OperationsPerThread(100000)
+ , StartLatch(Threads)
+ {
+ }
+
+ void Worker() {
+ StartLatch.CountDown();
+ StartLatch.Await();
+
+ TVector<typename TTest::ValueType> unused;
+ for (size_t i = 0; i < OperationsPerThread; ++i) {
+ switch (GetCycleCount() % 4) {
+ case 0: {
+ TTest::Enqueue(Stack, i);
+ break;
+ }
+ case 1: {
+ TTest::Dequeue(Stack);
+ break;
+ }
+ case 2: {
+ TTest::EnqueueAll(Stack);
+ break;
+ }
+ case 3: {
+ TTest::DequeueAll(Stack);
+ break;
+ }
+ }
+ }
+ }
+
+ void Run() {
+ TDeque<NThreading::TLegacyFuture<>> futures;
+
+ for (size_t i = 0; i < Threads; ++i) {
+ futures.emplace_back(std::bind(&ThisType::Worker, this));
+ }
+ futures.clear();
+ TTest::DequeueAll(Stack);
+ }
+ };
+
+ struct TFreeListTest {
+ using ValueType = int;
+
+ static void Enqueue(TLockFreeStack<int>& stack, size_t i) {
+ stack.Enqueue(static_cast<int>(i));
+ }
+
+ static void Dequeue(TLockFreeStack<int>& stack) {
+ int value;
+ stack.Dequeue(&value);
+ }
+
+ static void EnqueueAll(TLockFreeStack<int>& stack) {
+ TVector<int> values(5);
+ stack.EnqueueAll(values);
+ }
+
+ static void DequeueAll(TLockFreeStack<int>& stack) {
+ TVector<int> value;
+ stack.DequeueAll(&value);
+ }
+ };
+
+ // Test for catching thread sanitizer problems
+ Y_UNIT_TEST(TestFreeList) {
+ TMultiThreadTester<TFreeListTest>().Run();
+ }
+
+ struct TMoveTest {
+ using ValueType = THolder<int>;
+
+ static void Enqueue(TLockFreeStack<ValueType>& stack, size_t i) {
+ stack.Enqueue(MakeHolder<int>(static_cast<int>(i)));
+ }
+
+ static void Dequeue(TLockFreeStack<ValueType>& stack) {
+ ValueType value;
+ if (stack.Dequeue(&value)) {
+ UNIT_ASSERT(value);
+ }
+ }
+
+ static void EnqueueAll(TLockFreeStack<ValueType>& stack) {
+ // there is no EnqueueAll with a moving signature in TLockFreeStack
+ Enqueue(stack, 0);
+ }
+
+ static void DequeueAll(TLockFreeStack<ValueType>& stack) {
+ TVector<ValueType> values;
+ stack.DequeueAll(&values);
+ for (auto& v : values) {
+ UNIT_ASSERT(v);
+ }
+ }
+ };
+
+ // Test for catching thread sanitizer problems
+ Y_UNIT_TEST(TestMultiThreadMove) {
+ TMultiThreadTester<TMoveTest>().Run();
+ }
+}
diff --git a/util/thread/pool.cpp b/util/thread/pool.cpp
new file mode 100644
index 0000000000..05fad02e9b
--- /dev/null
+++ b/util/thread/pool.cpp
@@ -0,0 +1,772 @@
+#include <atomic>
+
+#include <util/system/defaults.h>
+
+#if defined(_unix_)
+ #include <pthread.h>
+#endif
+
+#include <util/generic/vector.h>
+#include <util/generic/intrlist.h>
+#include <util/generic/yexception.h>
+#include <util/generic/ylimits.h>
+#include <util/generic/singleton.h>
+#include <util/generic/fastqueue.h>
+
+#include <util/stream/output.h>
+#include <util/string/builder.h>
+
+#include <util/system/event.h>
+#include <util/system/mutex.h>
+#include <util/system/atomic.h>
+#include <util/system/condvar.h>
+#include <util/system/thread.h>
+
+#include <util/datetime/base.h>
+
+#include "factory.h"
+#include "pool.h"
+
+namespace {
+ class TThreadNamer {
+ public:
+ TThreadNamer(const IThreadPool::TParams& params)
+ : ThreadName(params.ThreadName_)
+ , EnumerateThreads(params.EnumerateThreads_)
+ {
+ }
+
+ explicit operator bool() const {
+ return !ThreadName.empty();
+ }
+
+ void SetCurrentThreadName() {
+ if (EnumerateThreads) {
+ Set(TStringBuilder() << ThreadName << (Index++));
+ } else {
+ Set(ThreadName);
+ }
+ }
+
+ private:
+ void Set(const TString& name) {
+ TThread::SetCurrentThreadName(name.c_str());
+ }
+
+ private:
+ TString ThreadName;
+ bool EnumerateThreads = false;
+ std::atomic<ui64> Index{0};
+ };
+}
+
+TThreadFactoryHolder::TThreadFactoryHolder() noexcept
+ : Pool_(SystemThreadFactory())
+{
+}
+
+class TThreadPool::TImpl: public TIntrusiveListItem<TImpl>, public IThreadFactory::IThreadAble {
+ using TTsr = IThreadPool::TTsr;
+ using TJobQueue = TFastQueue<IObjectInQueue*>;
+ using TThreadRef = THolder<IThreadFactory::IThread>;
+
+public:
+ inline TImpl(TThreadPool* parent, size_t thrnum, size_t maxqueue, const TParams& params)
+ : Parent_(parent)
+ , Blocking(params.Blocking_)
+ , Catching(params.Catching_)
+ , Namer(params)
+ , ShouldTerminate(1)
+ , MaxQueueSize(0)
+ , ThreadCountExpected(0)
+ , ThreadCountReal(0)
+ , Forked(false)
+ {
+ TAtforkQueueRestarter::Get().RegisterObject(this);
+ Start(thrnum, maxqueue);
+ }
+
+ inline ~TImpl() override {
+ try {
+ Stop();
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+
+ TAtforkQueueRestarter::Get().UnregisterObject(this);
+ Y_ASSERT(Tharr.empty());
+ }
+
+ inline bool Add(IObjectInQueue* obj) {
+ if (AtomicGet(ShouldTerminate)) {
+ return false;
+ }
+
+ if (Tharr.empty()) {
+ TTsr tsr(Parent_);
+ obj->Process(tsr);
+
+ return true;
+ }
+
+ with_lock (QueueMutex) {
+ while (MaxQueueSize > 0 && Queue.Size() >= MaxQueueSize && !AtomicGet(ShouldTerminate)) {
+ if (!Blocking) {
+ return false;
+ }
+ QueuePopCond.Wait(QueueMutex);
+ }
+
+ if (AtomicGet(ShouldTerminate)) {
+ return false;
+ }
+
+ Queue.Push(obj);
+ }
+
+ QueuePushCond.Signal();
+
+ return true;
+ }
+
+ inline size_t Size() const noexcept {
+ auto guard = Guard(QueueMutex);
+
+ return Queue.Size();
+ }
+
+ inline size_t GetMaxQueueSize() const noexcept {
+ return MaxQueueSize;
+ }
+
+ inline size_t GetThreadCountExpected() const noexcept {
+ return ThreadCountExpected;
+ }
+
+ inline size_t GetThreadCountReal() const noexcept {
+ return ThreadCountReal;
+ }
+
+ inline void AtforkAction() noexcept Y_NO_SANITIZE("thread") {
+ Forked = true;
+ }
+
+ inline bool NeedRestart() const noexcept {
+ return Forked;
+ }
+
+private:
+ inline void Start(size_t num, size_t maxque) {
+ AtomicSet(ShouldTerminate, 0);
+ MaxQueueSize = maxque;
+ ThreadCountExpected = num;
+
+ try {
+ for (size_t i = 0; i < num; ++i) {
+ Tharr.push_back(Parent_->Pool()->Run(this));
+ ++ThreadCountReal;
+ }
+ } catch (...) {
+ Stop();
+
+ throw;
+ }
+ }
+
+ inline void Stop() {
+ AtomicSet(ShouldTerminate, 1);
+
+ with_lock (QueueMutex) {
+ QueuePopCond.BroadCast();
+ }
+
+ if (!NeedRestart()) {
+ WaitForComplete();
+ }
+
+ Tharr.clear();
+ ThreadCountExpected = 0;
+ MaxQueueSize = 0;
+ }
+
+ inline void WaitForComplete() noexcept {
+ with_lock (StopMutex) {
+ while (ThreadCountReal) {
+ with_lock (QueueMutex) {
+ QueuePushCond.Signal();
+ }
+
+ StopCond.Wait(StopMutex);
+ }
+ }
+ }
+
+ void DoExecute() override {
+ THolder<TTsr> tsr(new TTsr(Parent_));
+
+ if (Namer) {
+ Namer.SetCurrentThreadName();
+ }
+
+ while (true) {
+ IObjectInQueue* job = nullptr;
+
+ with_lock (QueueMutex) {
+ while (Queue.Empty() && !AtomicGet(ShouldTerminate)) {
+ QueuePushCond.Wait(QueueMutex);
+ }
+
+ if (AtomicGet(ShouldTerminate) && Queue.Empty()) {
+ tsr.Destroy();
+
+ break;
+ }
+
+ job = Queue.Pop();
+ }
+
+ QueuePopCond.Signal();
+
+ if (Catching) {
+ try {
+ try {
+ job->Process(*tsr);
+ } catch (...) {
+ Cdbg << "[mtp queue] " << CurrentExceptionMessage() << Endl;
+ }
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+ } else {
+ job->Process(*tsr);
+ }
+ }
+
+ FinishOneThread();
+ }
+
+ inline void FinishOneThread() noexcept {
+ auto guard = Guard(StopMutex);
+
+ --ThreadCountReal;
+ StopCond.Signal();
+ }
+
+private:
+ TThreadPool* Parent_;
+ const bool Blocking;
+ const bool Catching;
+ TThreadNamer Namer;
+ mutable TMutex QueueMutex;
+ mutable TMutex StopMutex;
+ TCondVar QueuePushCond;
+ TCondVar QueuePopCond;
+ TCondVar StopCond;
+ TJobQueue Queue;
+ TVector<TThreadRef> Tharr;
+ TAtomic ShouldTerminate;
+ size_t MaxQueueSize;
+ size_t ThreadCountExpected;
+ size_t ThreadCountReal;
+ bool Forked;
+
+ class TAtforkQueueRestarter {
+ public:
+ static TAtforkQueueRestarter& Get() {
+ return *SingletonWithPriority<TAtforkQueueRestarter, 256>();
+ }
+
+ inline void RegisterObject(TImpl* obj) {
+ auto guard = Guard(ActionMutex);
+
+ RegisteredObjects.PushBack(obj);
+ }
+
+ inline void UnregisterObject(TImpl* obj) {
+ auto guard = Guard(ActionMutex);
+
+ obj->Unlink();
+ }
+
+ private:
+ void ChildAction() {
+ with_lock (ActionMutex) {
+ for (auto it = RegisteredObjects.Begin(); it != RegisteredObjects.End(); ++it) {
+ it->AtforkAction();
+ }
+ }
+ }
+
+ static void ProcessChildAction() {
+ Get().ChildAction();
+ }
+
+ TIntrusiveList<TImpl> RegisteredObjects;
+ TMutex ActionMutex;
+
+ public:
+ inline TAtforkQueueRestarter() {
+#if defined(_bionic_)
+//no pthread_atfork on android libc
+#elif defined(_unix_)
+ pthread_atfork(nullptr, nullptr, ProcessChildAction);
+#endif
+ }
+ };
+};
+
+TThreadPool::~TThreadPool() = default;
+
+size_t TThreadPool::Size() const noexcept {
+ if (!Impl_.Get()) {
+ return 0;
+ }
+
+ return Impl_->Size();
+}
+
+size_t TThreadPool::GetThreadCountExpected() const noexcept {
+ if (!Impl_.Get()) {
+ return 0;
+ }
+
+ return Impl_->GetThreadCountExpected();
+}
+
+size_t TThreadPool::GetThreadCountReal() const noexcept {
+ if (!Impl_.Get()) {
+ return 0;
+ }
+
+ return Impl_->GetThreadCountReal();
+}
+
+size_t TThreadPool::GetMaxQueueSize() const noexcept {
+ if (!Impl_.Get()) {
+ return 0;
+ }
+
+ return Impl_->GetMaxQueueSize();
+}
+
+bool TThreadPool::Add(IObjectInQueue* obj) {
+ Y_ENSURE_EX(Impl_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started"));
+
+ if (Impl_->NeedRestart()) {
+ Start(Impl_->GetThreadCountExpected(), Impl_->GetMaxQueueSize());
+ }
+
+ return Impl_->Add(obj);
+}
+
+void TThreadPool::Start(size_t thrnum, size_t maxque) {
+ Impl_.Reset(new TImpl(this, thrnum, maxque, Params));
+}
+
+void TThreadPool::Stop() noexcept {
+ Impl_.Destroy();
+}
+
+static TAtomic mtp_queue_counter = 0;
+
+class TAdaptiveThreadPool::TImpl {
+public:
+ class TThread: public IThreadFactory::IThreadAble {
+ public:
+ inline TThread(TImpl* parent)
+ : Impl_(parent)
+ , Thread_(Impl_->Parent_->Pool()->Run(this))
+ {
+ }
+
+ inline ~TThread() override {
+ Impl_->DecThreadCount();
+ }
+
+ private:
+ void DoExecute() noexcept override {
+ THolder<TThread> This(this);
+
+ if (Impl_->Namer) {
+ Impl_->Namer.SetCurrentThreadName();
+ }
+
+ {
+ TTsr tsr(Impl_->Parent_);
+ IObjectInQueue* obj;
+
+ while ((obj = Impl_->WaitForJob()) != nullptr) {
+ if (Impl_->Catching) {
+ try {
+ try {
+ obj->Process(tsr);
+ } catch (...) {
+ Cdbg << Impl_->Name() << " " << CurrentExceptionMessage() << Endl;
+ }
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+ } else {
+ obj->Process(tsr);
+ }
+ }
+ }
+ }
+
+ private:
+ TImpl* Impl_;
+ THolder<IThreadFactory::IThread> Thread_;
+ };
+
+ inline TImpl(TAdaptiveThreadPool* parent, const TParams& params)
+ : Parent_(parent)
+ , Catching(params.Catching_)
+ , Namer(params)
+ , ThrCount_(0)
+ , AllDone_(false)
+ , Obj_(nullptr)
+ , Free_(0)
+ , IdleTime_(TDuration::Max())
+ {
+ sprintf(Name_, "[mtp queue %ld]", (long)AtomicAdd(mtp_queue_counter, 1));
+ }
+
+ inline ~TImpl() {
+ Stop();
+ }
+
+ inline void SetMaxIdleTime(TDuration idleTime) {
+ IdleTime_ = idleTime;
+ }
+
+ inline const char* Name() const noexcept {
+ return Name_;
+ }
+
+ inline void Add(IObjectInQueue* obj) {
+ with_lock (Mutex_) {
+ while (Obj_ != nullptr) {
+ CondFree_.Wait(Mutex_);
+ }
+
+ if (Free_ == 0) {
+ AddThreadNoLock();
+ }
+
+ Obj_ = obj;
+
+ Y_ENSURE_EX(!AllDone_, TThreadPoolException() << TStringBuf("adding to a stopped queue"));
+ }
+
+ CondReady_.Signal();
+ }
+
+ inline void AddThreads(size_t n) {
+ with_lock (Mutex_) {
+ while (n) {
+ AddThreadNoLock();
+
+ --n;
+ }
+ }
+ }
+
+ inline size_t Size() const noexcept {
+ return (size_t)ThrCount_;
+ }
+
+private:
+ inline void IncThreadCount() noexcept {
+ AtomicAdd(ThrCount_, 1);
+ }
+
+ inline void DecThreadCount() noexcept {
+ AtomicAdd(ThrCount_, -1);
+ }
+
+ inline void AddThreadNoLock() {
+ IncThreadCount();
+
+ try {
+ new TThread(this);
+ } catch (...) {
+ DecThreadCount();
+
+ throw;
+ }
+ }
+
+ inline void Stop() noexcept {
+ Mutex_.Acquire();
+
+ AllDone_ = true;
+
+ while (AtomicGet(ThrCount_)) {
+ Mutex_.Release();
+ CondReady_.Signal();
+ Mutex_.Acquire();
+ }
+
+ Mutex_.Release();
+ }
+
+ inline IObjectInQueue* WaitForJob() noexcept {
+ Mutex_.Acquire();
+
+ ++Free_;
+
+ while (!Obj_ && !AllDone_) {
+ if (!CondReady_.WaitT(Mutex_, IdleTime_)) {
+ break;
+ }
+ }
+
+ IObjectInQueue* ret = Obj_;
+ Obj_ = nullptr;
+
+ --Free_;
+
+ Mutex_.Release();
+ CondFree_.Signal();
+
+ return ret;
+ }
+
+private:
+ TAdaptiveThreadPool* Parent_;
+ const bool Catching;
+ TThreadNamer Namer;
+ TAtomic ThrCount_;
+ TMutex Mutex_;
+ TCondVar CondReady_;
+ TCondVar CondFree_;
+ bool AllDone_;
+ IObjectInQueue* Obj_;
+ size_t Free_;
+ char Name_[64];
+ TDuration IdleTime_;
+};
+
+TThreadPoolBase::TThreadPoolBase(const TParams& params)
+ : TThreadFactoryHolder(params.Factory_)
+ , Params(params)
+{
+}
+
+#define DEFINE_THREAD_POOL_CTORS(type) \
+ type::type(const TParams& params) \
+ : TThreadPoolBase(params) \
+ { \
+ }
+
+DEFINE_THREAD_POOL_CTORS(TThreadPool)
+DEFINE_THREAD_POOL_CTORS(TAdaptiveThreadPool)
+DEFINE_THREAD_POOL_CTORS(TSimpleThreadPool)
+
+TAdaptiveThreadPool::~TAdaptiveThreadPool() = default;
+
+bool TAdaptiveThreadPool::Add(IObjectInQueue* obj) {
+ Y_ENSURE_EX(Impl_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started"));
+
+ Impl_->Add(obj);
+
+ return true;
+}
+
+void TAdaptiveThreadPool::Start(size_t, size_t) {
+ Impl_.Reset(new TImpl(this, Params));
+}
+
+void TAdaptiveThreadPool::Stop() noexcept {
+ Impl_.Destroy();
+}
+
+size_t TAdaptiveThreadPool::Size() const noexcept {
+ if (Impl_.Get()) {
+ return Impl_->Size();
+ }
+
+ return 0;
+}
+
+void TAdaptiveThreadPool::SetMaxIdleTime(TDuration interval) {
+ Y_ENSURE_EX(Impl_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started"));
+
+ Impl_->SetMaxIdleTime(interval);
+}
+
+TSimpleThreadPool::~TSimpleThreadPool() {
+ try {
+ Stop();
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+}
+
+bool TSimpleThreadPool::Add(IObjectInQueue* obj) {
+ Y_ENSURE_EX(Slave_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started"));
+
+ return Slave_->Add(obj);
+}
+
+void TSimpleThreadPool::Start(size_t thrnum, size_t maxque) {
+ THolder<IThreadPool> tmp;
+ TAdaptiveThreadPool* adaptive(nullptr);
+
+ if (thrnum) {
+ tmp.Reset(new TThreadPoolBinder<TThreadPool, TSimpleThreadPool>(this, Params));
+ } else {
+ adaptive = new TThreadPoolBinder<TAdaptiveThreadPool, TSimpleThreadPool>(this, Params);
+ tmp.Reset(adaptive);
+ }
+
+ tmp->Start(thrnum, maxque);
+
+ if (adaptive) {
+ adaptive->SetMaxIdleTime(TDuration::Seconds(100));
+ }
+
+ Slave_.Swap(tmp);
+}
+
+void TSimpleThreadPool::Stop() noexcept {
+ Slave_.Destroy();
+}
+
+size_t TSimpleThreadPool::Size() const noexcept {
+ if (Slave_.Get()) {
+ return Slave_->Size();
+ }
+
+ return 0;
+}
+
+namespace {
+ class TOwnedObjectInQueue: public IObjectInQueue {
+ private:
+ THolder<IObjectInQueue> Owned;
+
+ public:
+ TOwnedObjectInQueue(THolder<IObjectInQueue> owned)
+ : Owned(std::move(owned))
+ {
+ }
+
+ void Process(void* data) override {
+ THolder<TOwnedObjectInQueue> self(this);
+ Owned->Process(data);
+ }
+ };
+}
+
+void IThreadPool::SafeAdd(IObjectInQueue* obj) {
+ Y_ENSURE_EX(Add(obj), TThreadPoolException() << TStringBuf("can not add object to queue"));
+}
+
+void IThreadPool::SafeAddAndOwn(THolder<IObjectInQueue> obj) {
+ Y_ENSURE_EX(AddAndOwn(std::move(obj)), TThreadPoolException() << TStringBuf("can not add to queue and own"));
+}
+
+bool IThreadPool::AddAndOwn(THolder<IObjectInQueue> obj) {
+ auto owner = MakeHolder<TOwnedObjectInQueue>(std::move(obj));
+ bool added = Add(owner.Get());
+ if (added) {
+ Y_UNUSED(owner.Release());
+ }
+ return added;
+}
+
+using IThread = IThreadFactory::IThread;
+using IThreadAble = IThreadFactory::IThreadAble;
+
+namespace {
+ class TPoolThread: public IThread {
+ class TThreadImpl: public IObjectInQueue, public TAtomicRefCount<TThreadImpl> {
+ public:
+ inline TThreadImpl(IThreadAble* func)
+ : Func_(func)
+ {
+ }
+
+ ~TThreadImpl() override = default;
+
+ inline void WaitForStart() noexcept {
+ StartEvent_.Wait();
+ }
+
+ inline void WaitForComplete() noexcept {
+ CompleteEvent_.Wait();
+ }
+
+ private:
+ void Process(void* /*tsr*/) override {
+ TThreadImplRef This(this);
+
+ {
+ StartEvent_.Signal();
+
+ try {
+ Func_->Execute();
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+
+ CompleteEvent_.Signal();
+ }
+ }
+
+ private:
+ IThreadAble* Func_;
+ TSystemEvent CompleteEvent_;
+ TSystemEvent StartEvent_;
+ };
+
+ using TThreadImplRef = TIntrusivePtr<TThreadImpl>;
+
+ public:
+ inline TPoolThread(IThreadPool* parent)
+ : Parent_(parent)
+ {
+ }
+
+ ~TPoolThread() override {
+ if (Impl_) {
+ Impl_->WaitForStart();
+ }
+ }
+
+ private:
+ void DoRun(IThreadAble* func) override {
+ TThreadImplRef impl(new TThreadImpl(func));
+
+ Parent_->SafeAdd(impl.Get());
+ Impl_.Swap(impl);
+ }
+
+ void DoJoin() noexcept override {
+ if (Impl_) {
+ Impl_->WaitForComplete();
+ Impl_ = nullptr;
+ }
+ }
+
+ private:
+ IThreadPool* Parent_;
+ TThreadImplRef Impl_;
+ };
+}
+
+IThread* IThreadPool::DoCreate() {
+ return new TPoolThread(this);
+}
+
+THolder<IThreadPool> CreateThreadPool(size_t threadsCount, size_t queueSizeLimit, const TThreadPoolParams& params) {
+ THolder<IThreadPool> queue;
+ if (threadsCount > 1) {
+ queue.Reset(new TThreadPool(params));
+ } else {
+ queue.Reset(new TFakeThreadPool());
+ }
+ queue->Start(threadsCount, queueSizeLimit);
+ return queue;
+}
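
A minimal usage sketch of the fixed-size pool defined above (illustration only, not part of the commit): start worker threads, schedule lambdas, and rely on Stop() to drain the queue before returning.

#include <util/thread/pool.h>
#include <util/system/atomic.h>

int main() {
    TThreadPool pool;
    pool.Start(4 /* threads */, 0 /* unbounded queue */);

    TAtomic done = 0;
    for (int i = 0; i < 100; ++i) {
        // SafeAddFunc throws TThreadPoolException instead of returning false.
        pool.SafeAddFunc([&done]() {
            AtomicAdd(done, 1);
        });
    }

    pool.Stop(); // waits for all scheduled tasks to complete
    // AtomicGet(done) == 100 at this point
    return 0;
}
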
diff --git a/util/thread/pool.h b/util/thread/pool.h
new file mode 100644
index 0000000000..d1ea3a67cb
--- /dev/null
+++ b/util/thread/pool.h
@@ -0,0 +1,390 @@
+#pragma once
+
+#include "fwd.h"
+#include "factory.h"
+
+#include <util/system/yassert.h>
+#include <util/system/defaults.h>
+#include <util/generic/yexception.h>
+#include <util/generic/ptr.h>
+#include <util/generic/noncopyable.h>
+#include <functional>
+
+class TDuration;
+
+struct IObjectInQueue {
+ virtual ~IObjectInQueue() = default;
+
+ /**
+ * Supposed to be implemented by the user, to define jobs processed
+ * in multiple threads.
+ *
+ * @param threadSpecificResource is nullptr by default. But if you override
+ * IThreadPool::CreateThreadSpecificResource, then the result of
+ * IThreadPool::CreateThreadSpecificResource is passed as the threadSpecificResource
+ * parameter.
+ */
+ virtual void Process(void* threadSpecificResource) = 0;
+};
+
+/**
+ * Mighty class to add 'Pool' method to derived classes.
+ * Useful only for creators of new queue classes.
+ */
+class TThreadFactoryHolder {
+public:
+ TThreadFactoryHolder() noexcept;
+
+ inline TThreadFactoryHolder(IThreadFactory* pool) noexcept
+ : Pool_(pool)
+ {
+ }
+
+ inline ~TThreadFactoryHolder() = default;
+
+ inline IThreadFactory* Pool() const noexcept {
+ return Pool_;
+ }
+
+private:
+ IThreadFactory* Pool_;
+};
+
+class TThreadPoolException: public yexception {
+};
+
+template <class T>
+class TThrFuncObj: public IObjectInQueue {
+public:
+ TThrFuncObj(const T& func)
+ : Func(func)
+ {
+ }
+
+ TThrFuncObj(T&& func)
+ : Func(std::move(func))
+ {
+ }
+
+ void Process(void*) override {
+ THolder<TThrFuncObj> self(this);
+ Func();
+ }
+
+private:
+ T Func;
+};
+
+template <class T>
+IObjectInQueue* MakeThrFuncObj(T&& func) {
+ return new TThrFuncObj<std::remove_cv_t<std::remove_reference_t<T>>>(std::forward<T>(func));
+}
+
+struct TThreadPoolParams {
+ bool Catching_ = true;
+ bool Blocking_ = false;
+ IThreadFactory* Factory_ = SystemThreadFactory();
+ TString ThreadName_;
+ bool EnumerateThreads_ = false;
+
+ using TSelf = TThreadPoolParams;
+
+ TThreadPoolParams() {
+ }
+
+ TThreadPoolParams(IThreadFactory* factory)
+ : Factory_(factory)
+ {
+ }
+
+ TThreadPoolParams(const TString& name) {
+ SetThreadName(name);
+ }
+
+ TThreadPoolParams(const char* name) {
+ SetThreadName(name);
+ }
+
+ TSelf& SetCatching(bool val) {
+ Catching_ = val;
+ return *this;
+ }
+
+ TSelf& SetBlocking(bool val) {
+ Blocking_ = val;
+ return *this;
+ }
+
+ TSelf& SetFactory(IThreadFactory* factory) {
+ Factory_ = factory;
+ return *this;
+ }
+
+ TSelf& SetThreadName(const TString& name) {
+ ThreadName_ = name;
+ EnumerateThreads_ = false;
+ return *this;
+ }
+
+ TSelf& SetThreadNamePrefix(const TString& prefix) {
+ ThreadName_ = prefix;
+ EnumerateThreads_ = true;
+ return *this;
+ }
+};
+
+/**
+ * A queue processed simultaneously by several threads
+ */
+class IThreadPool: public IThreadFactory, public TNonCopyable {
+public:
+ using TParams = TThreadPoolParams;
+
+ ~IThreadPool() override = default;
+
+ /**
+ * Safe versions of the Add*() functions. They behave exactly like the non-safe
+ * versions of Add*(), but throw exceptions instead of returning false.
+ */
+ void SafeAdd(IObjectInQueue* obj);
+
+ template <class T>
+ void SafeAddFunc(T&& func) {
+ Y_ENSURE_EX(AddFunc(std::forward<T>(func)), TThreadPoolException() << TStringBuf("can not add function to queue"));
+ }
+
+ void SafeAddAndOwn(THolder<IObjectInQueue> obj);
+
+ /**
+ * Add object to queue, run obj->Process in other threads.
+ * Obj is not deleted after execution.
+ * @return true if obj is successfully added to the queue
+ * @return false if the queue is full or shutting down
+ */
+ virtual bool Add(IObjectInQueue* obj) Y_WARN_UNUSED_RESULT = 0;
+
+ template <class T>
+ Y_WARN_UNUSED_RESULT bool AddFunc(T&& func) {
+ THolder<IObjectInQueue> wrapper(MakeThrFuncObj(std::forward<T>(func)));
+ bool added = Add(wrapper.Get());
+ if (added) {
+ Y_UNUSED(wrapper.Release());
+ }
+ return added;
+ }
+
+ bool AddAndOwn(THolder<IObjectInQueue> obj) Y_WARN_UNUSED_RESULT;
+ virtual void Start(size_t threadCount, size_t queueSizeLimit = 0) = 0;
+ /** Wait for completion of all scheduled objects, and then exit */
+ virtual void Stop() noexcept = 0;
+ /** Number of tasks currently in queue */
+ virtual size_t Size() const noexcept = 0;
+
+public:
+ /**
+ * RAII wrapper for Create/DestroyThreadSpecificResource.
+ * Useful only for implementers of new IThreadPool queues.
+ */
+ class TTsr {
+ public:
+ inline TTsr(IThreadPool* q)
+ : Q_(q)
+ , Data_(Q_->CreateThreadSpecificResource())
+ {
+ }
+
+ inline ~TTsr() {
+ try {
+ Q_->DestroyThreadSpecificResource(Data_);
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+ }
+
+ inline operator void*() noexcept {
+ return Data_;
+ }
+
+ private:
+ IThreadPool* Q_;
+ void* Data_;
+ };
+
+ /**
+ * CreateThreadSpecificResource and DestroyThreadSpecificResource
+ * are called from the internals of the (TAdaptiveThreadPool, TThreadPool, ...) implementations,
+ * not by the user of the IThreadPool interface.
+ * The created resource is passed to the IObjectInQueue::Process function.
+ */
+ virtual void* CreateThreadSpecificResource() {
+ return nullptr;
+ }
+
+ virtual void DestroyThreadSpecificResource(void* resource) {
+ if (resource != nullptr) {
+ Y_ASSERT(resource == nullptr);
+ }
+ }
+
+private:
+ IThread* DoCreate() override;
+};
+
+/**
+ * Single-threaded implementation of IThreadPool; processes tasks in the same
+ * thread they are added from.
+ * Can be used to remove multithreading.
+ */
+class TFakeThreadPool: public IThreadPool {
+public:
+ bool Add(IObjectInQueue* pObj) override Y_WARN_UNUSED_RESULT {
+ TTsr tsr(this);
+ pObj->Process(tsr);
+
+ return true;
+ }
+
+ void Start(size_t, size_t = 0) override {
+ }
+
+ void Stop() noexcept override {
+ }
+
+ size_t Size() const noexcept override {
+ return 0;
+ }
+};
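+
+/**
+ * Example: a minimal sketch of using TFakeThreadPool in tests so that tasks
+ * run synchronously in the calling thread.
+ *
+ *     TFakeThreadPool pool;
+ *     bool done = false;
+ *     Y_ENSURE(pool.AddFunc([&done] { done = true; }));
+ *     // done is already true here: Add() ran the task inline
+ */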
+
+class TThreadPoolBase: public IThreadPool, public TThreadFactoryHolder {
+public:
+ TThreadPoolBase(const TParams& params);
+
+protected:
+ TParams Params;
+};
+
+/** A queue processed by a fixed-size thread pool */
+class TThreadPool: public TThreadPoolBase {
+public:
+ TThreadPool(const TParams& params = {});
+ ~TThreadPool() override;
+
+ bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT;
+ /**
+ * @param threadCount 0 means "single thread"
+ * @param queueSizeLimit 0 means "unlimited"
+ */
+ void Start(size_t threadCount, size_t queueSizeLimit = 0) override;
+ void Stop() noexcept override;
+ size_t Size() const noexcept override;
+ size_t GetThreadCountExpected() const noexcept;
+ size_t GetThreadCountReal() const noexcept;
+ size_t GetMaxQueueSize() const noexcept;
+
+private:
+ class TImpl;
+ THolder<TImpl> Impl_;
+};
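+
+/**
+ * Example: a minimal sketch of a fixed-size pool with a bounded queue; the
+ * sizes are illustrative.
+ *
+ *     TThreadPool pool;
+ *     pool.Start(4, 100); // 4 threads, at most 100 queued tasks
+ *     Y_ASSERT(pool.GetThreadCountReal() == 4);
+ *     pool.Stop(); // drains the queue, then joins the threads
+ */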
+
+/**
+ * Always creates a new thread for a new task when all existing threads are busy.
+ * Potentially dangerous: the number of threads is not limited.
+ */
+class TAdaptiveThreadPool: public TThreadPoolBase {
+public:
+ TAdaptiveThreadPool(const TParams& params = {});
+ ~TAdaptiveThreadPool() override;
+
+ /**
+ * If a worker thread waits for a task longer than the given interval, the
+ * thread is terminated. The default is infinity: all created threads wait
+ * for new tasks forever, until Stop().
+ */
+ void SetMaxIdleTime(TDuration interval);
+
+ bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT;
+ /** The thrnum and maxque parameters are ignored */
+ void Start(size_t thrnum = 0, size_t maxque = 0) override;
+ void Stop() noexcept override;
+ size_t Size() const noexcept override;
+
+private:
+ class TImpl;
+ THolder<TImpl> Impl_;
+};
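+
+/**
+ * Example: a minimal sketch of an adaptive pool that reclaims idle threads;
+ * the idle interval is illustrative.
+ *
+ *     TAdaptiveThreadPool pool;
+ *     pool.SetMaxIdleTime(TDuration::Seconds(5));
+ *     pool.Start(); // both parameters are ignored; threads are created on demand
+ */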
+
+/** Behaves like TThreadPool or TAdaptiveThreadPool, chosen by the thrnum parameter of Start() */
+class TSimpleThreadPool: public TThreadPoolBase {
+public:
+ TSimpleThreadPool(const TParams& params = {});
+ ~TSimpleThreadPool() override;
+
+ bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT;
+ /**
+ * @param thrnum If thrnum is 0, uses TAdaptiveThreadPool with a small
+ * SetMaxIdleTime interval; if thrnum is not 0, uses a non-blocking TThreadPool
+ */
+ void Start(size_t thrnum, size_t maxque = 0) override;
+ void Stop() noexcept override;
+ size_t Size() const noexcept override;
+
+private:
+ THolder<IThreadPool> Slave_;
+};
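+
+/**
+ * Example: a minimal sketch; the thrnum values are illustrative.
+ *
+ *     TSimpleThreadPool pool;
+ *     pool.Start(4); // delegates to a non-blocking TThreadPool
+ *     // pool.Start(0) would delegate to a TAdaptiveThreadPool instead
+ */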
+
+/**
+ * Helper that overrides the virtual functions Create/DestroyThreadSpecificResource
+ * from IThreadPool and implements them by forwarding to the same-named
+ * functions of a TSlave pointer.
+ */
+template <class TQueueType, class TSlave>
+class TThreadPoolBinder: public TQueueType {
+public:
+ inline TThreadPoolBinder(TSlave* slave)
+ : Slave_(slave)
+ {
+ }
+
+ template <class... Args>
+ inline TThreadPoolBinder(TSlave* slave, Args&&... args)
+ : TQueueType(std::forward<Args>(args)...)
+ , Slave_(slave)
+ {
+ }
+
+ inline TThreadPoolBinder(TSlave& slave)
+ : Slave_(&slave)
+ {
+ }
+
+ ~TThreadPoolBinder() override {
+ try {
+ this->Stop();
+ } catch (...) {
+ // ¯\_(ツ)_/¯
+ }
+ }
+
+ void* CreateThreadSpecificResource() override {
+ return Slave_->CreateThreadSpecificResource();
+ }
+
+ void DestroyThreadSpecificResource(void* resource) override {
+ Slave_->DestroyThreadSpecificResource(resource);
+ }
+
+private:
+ TSlave* Slave_;
+};
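+
+/**
+ * Example: a minimal sketch of forwarding thread-specific resources to an
+ * owner object; TMyOwner is a hypothetical type that provides
+ * CreateThreadSpecificResource() and DestroyThreadSpecificResource().
+ *
+ *     TMyOwner owner;
+ *     TThreadPoolBinder<TThreadPool, TMyOwner> pool(&owner);
+ *     pool.Start(2); // each task receives owner's resource in Process()
+ */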
+
+inline void Delete(THolder<IThreadPool> q) {
+ if (q.Get()) {
+ q->Stop();
+ }
+}
+
+/**
+ * Creates and starts a TThreadPool if threadCount > 1, or a TFakeThreadPool otherwise.
+ * Blocking and catching modes can be specified for TThreadPool only.
+ */
+THolder<IThreadPool> CreateThreadPool(size_t threadCount, size_t queueSizeLimit = 0, const IThreadPool::TParams& params = {});
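+
+/**
+ * Example: a minimal sketch; the sizes are illustrative and DoWork is a
+ * hypothetical function.
+ *
+ *     auto pool = CreateThreadPool(8, 1000);
+ *     Y_ENSURE(pool->AddFunc([] { DoWork(); }));
+ *     pool->Stop();
+ */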
diff --git a/util/thread/pool_ut.cpp b/util/thread/pool_ut.cpp
new file mode 100644
index 0000000000..893770d0c4
--- /dev/null
+++ b/util/thread/pool_ut.cpp
@@ -0,0 +1,257 @@
+#include "pool.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/stream/output.h>
+#include <util/random/fast.h>
+#include <util/system/spinlock.h>
+#include <util/system/thread.h>
+#include <util/system/mutex.h>
+#include <util/system/condvar.h>
+
+struct TThreadPoolTest {
+ TSpinLock Lock;
+ long R = -1;
+
+ struct TTask: public IObjectInQueue {
+ TThreadPoolTest* Test = nullptr;
+ long Value = 0;
+
+ TTask(TThreadPoolTest* test, int value)
+ : Test(test)
+ , Value(value)
+ {
+ }
+
+ void Process(void*) override {
+ THolder<TTask> This(this);
+
+ TGuard<TSpinLock> guard(Test->Lock);
+ Test->R ^= Value;
+ }
+ };
+
+ struct TOwnedTask: public IObjectInQueue {
+ bool& Processed;
+ bool& Destructed;
+
+ TOwnedTask(bool& processed, bool& destructed)
+ : Processed(processed)
+ , Destructed(destructed)
+ {
+ }
+
+ ~TOwnedTask() override {
+ Destructed = true;
+ }
+
+ void Process(void*) override {
+ Processed = true;
+ }
+ };
+
+ inline void TestAnyQueue(IThreadPool* queue, size_t queueSize = 1000) {
+ TReallyFastRng32 rand(17);
+ const size_t cnt = 1000;
+
+ R = 0;
+
+ for (size_t i = 0; i < cnt; ++i) {
+ R ^= (long)rand.GenRand();
+ }
+
+ queue->Start(10, queueSize);
+ rand = TReallyFastRng32(17);
+
+ for (size_t i = 0; i < cnt; ++i) {
+ UNIT_ASSERT(queue->Add(new TTask(this, (long)rand.GenRand())));
+ }
+
+ queue->Stop();
+
+ UNIT_ASSERT_EQUAL(0, R);
+ }
+};
+
+class TFailAddQueue: public IThreadPool {
+public:
+ bool Add(IObjectInQueue* /*obj*/) override Y_WARN_UNUSED_RESULT {
+ return false;
+ }
+
+ void Start(size_t, size_t) override {
+ }
+
+ void Stop() noexcept override {
+ }
+
+ size_t Size() const noexcept override {
+ return 0;
+ }
+};
+
+Y_UNIT_TEST_SUITE(TThreadPoolTest) {
+ Y_UNIT_TEST(TestTThreadPool) {
+ TThreadPoolTest t;
+ TThreadPool q;
+ t.TestAnyQueue(&q);
+ }
+
+ Y_UNIT_TEST(TestTThreadPoolBlocking) {
+ TThreadPoolTest t;
+ TThreadPool q(TThreadPool::TParams().SetBlocking(true));
+ t.TestAnyQueue(&q, 100);
+ }
+
+ // disabled by pg@ a long time ago due to test flaps
+ // Tried to enable: REVIEW:78772
+ Y_UNIT_TEST(TestTAdaptiveThreadPool) {
+ if (false) {
+ TThreadPoolTest t;
+ TAdaptiveThreadPool q;
+ t.TestAnyQueue(&q);
+ }
+ }
+
+ Y_UNIT_TEST(TestAddAndOwn) {
+ TThreadPool q;
+ q.Start(2);
+ bool processed = false;
+ bool destructed = false;
+ q.SafeAddAndOwn(MakeHolder<TThreadPoolTest::TOwnedTask>(processed, destructed));
+ q.Stop();
+
+ UNIT_ASSERT_C(processed, "Not processed");
+ UNIT_ASSERT_C(destructed, "Not destructed");
+ }
+
+ Y_UNIT_TEST(TestAddFunc) {
+ TFailAddQueue queue;
+ bool added = queue.AddFunc(
+ []() {} // Lambda, I call him 'Lambda'!
+ );
+ UNIT_ASSERT_VALUES_EQUAL(added, false);
+ }
+
+ Y_UNIT_TEST(TestSafeAddFuncThrows) {
+ TFailAddQueue queue;
+ UNIT_CHECK_GENERATED_EXCEPTION(queue.SafeAddFunc([] {}), TThreadPoolException);
+ }
+
+ Y_UNIT_TEST(TestFunctionNotCopied) {
+ struct TFailOnCopy {
+ TFailOnCopy() {
+ }
+
+ TFailOnCopy(TFailOnCopy&&) {
+ }
+
+ TFailOnCopy(const TFailOnCopy&) {
+ UNIT_FAIL("Don't copy std::function inside TThreadPool");
+ }
+ };
+
+ TThreadPool queue(TThreadPool::TParams().SetBlocking(false).SetCatching(true));
+ queue.Start(2);
+
+ queue.SafeAddFunc([data = TFailOnCopy()]() {});
+
+ queue.Stop();
+ }
+
+ Y_UNIT_TEST(TestInfoGetters) {
+ TThreadPool queue;
+
+ queue.Start(2, 7);
+
+ UNIT_ASSERT_EQUAL(queue.GetThreadCountExpected(), 2);
+ UNIT_ASSERT_EQUAL(queue.GetThreadCountReal(), 2);
+ UNIT_ASSERT_EQUAL(queue.GetMaxQueueSize(), 7);
+
+ queue.Stop();
+
+ queue.Start(4, 1);
+
+ UNIT_ASSERT_EQUAL(queue.GetThreadCountExpected(), 4);
+ UNIT_ASSERT_EQUAL(queue.GetThreadCountReal(), 4);
+ UNIT_ASSERT_EQUAL(queue.GetMaxQueueSize(), 1);
+
+ queue.Stop();
+ }
+
+ void TestFixedThreadName(IThreadPool& pool, const TString& expectedName) {
+ pool.Start(1);
+ TString name;
+ pool.SafeAddFunc([&name]() {
+ name = TThread::CurrentThreadName();
+ });
+ pool.Stop();
+ if (TThread::CanGetCurrentThreadName()) {
+ UNIT_ASSERT_EQUAL(name, expectedName);
+ UNIT_ASSERT_UNEQUAL(TThread::CurrentThreadName(), expectedName);
+ }
+ }
+
+ Y_UNIT_TEST(TestFixedThreadName) {
+ const TString expectedName = "HelloWorld";
+ {
+ TThreadPool pool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadName(expectedName));
+ TestFixedThreadName(pool, expectedName);
+ }
+ {
+ TAdaptiveThreadPool pool(TThreadPool::TParams().SetThreadName(expectedName));
+ TestFixedThreadName(pool, expectedName);
+ }
+ }
+
+ void TestEnumeratedThreadName(IThreadPool& pool, const THashSet<TString>& expectedNames) {
+ pool.Start(expectedNames.size());
+ TMutex lock;
+ TCondVar allReady;
+ size_t readyCount = 0;
+ THashSet<TString> names;
+ for (size_t i = 0; i < expectedNames.size(); ++i) {
+ pool.SafeAddFunc([&]() {
+ with_lock (lock) {
+ if (++readyCount == expectedNames.size()) {
+ allReady.BroadCast();
+ } else {
+ while (readyCount != expectedNames.size()) {
+ allReady.WaitI(lock);
+ }
+ }
+ names.insert(TThread::CurrentThreadName());
+ }
+ });
+ }
+ pool.Stop();
+ if (TThread::CanGetCurrentThreadName()) {
+ UNIT_ASSERT_EQUAL(names, expectedNames);
+ }
+ }
+
+ Y_UNIT_TEST(TestEnumeratedThreadName) {
+ const TString namePrefix = "HelloWorld";
+ const THashSet<TString> expectedNames = {
+ "HelloWorld0",
+ "HelloWorld1",
+ "HelloWorld2",
+ "HelloWorld3",
+ "HelloWorld4",
+ "HelloWorld5",
+ "HelloWorld6",
+ "HelloWorld7",
+ "HelloWorld8",
+ "HelloWorld9",
+ "HelloWorld10",
+ };
+ {
+ TThreadPool pool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadNamePrefix(namePrefix));
+ TestEnumeratedThreadName(pool, expectedNames);
+ }
+ {
+ TAdaptiveThreadPool pool(TThreadPool::TParams().SetThreadNamePrefix(namePrefix));
+ TestEnumeratedThreadName(pool, expectedNames);
+ }
+ }
+}
diff --git a/util/thread/singleton.cpp b/util/thread/singleton.cpp
new file mode 100644
index 0000000000..a898bdc9d4
--- /dev/null
+++ b/util/thread/singleton.cpp
@@ -0,0 +1 @@
+#include "singleton.h"
diff --git a/util/thread/singleton.h b/util/thread/singleton.h
new file mode 100644
index 0000000000..4a1e05aea0
--- /dev/null
+++ b/util/thread/singleton.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include <util/system/tls.h>
+#include <util/generic/singleton.h>
+#include <util/generic/ptr.h>
+
+namespace NPrivate {
+ template <class T, size_t Priority>
+ struct TFastThreadSingletonHelper {
+ static inline T* GetSlow() {
+ return SingletonWithPriority<NTls::TValue<T>, Priority>()->GetPtr();
+ }
+
+ static inline T* Get() {
+#if defined(Y_HAVE_FAST_POD_TLS)
+ Y_POD_STATIC_THREAD(T*) fast(nullptr);
+
+ if (Y_UNLIKELY(!fast)) {
+ fast = GetSlow();
+ }
+
+ return fast;
+#else
+ return GetSlow();
+#endif
+ }
+ };
+}
+
+template <class T, size_t Priority>
+static inline T* FastTlsSingletonWithPriority() {
+ return ::NPrivate::TFastThreadSingletonHelper<T, Priority>::Get();
+}
+
+// NB: the singleton is the same for all modules that use
+// FastTlsSingleton with the same type parameter. If a unique singleton
+// is required, use unique types.
+template <class T>
+static inline T* FastTlsSingleton() {
+ return FastTlsSingletonWithPriority<T, TSingletonTraits<T>::Priority>();
+}
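+
+// Example: a minimal sketch of a per-thread counter; TCounter is illustrative.
+//
+//     struct TCounter {
+//         int Value = 0;
+//     };
+//     ++FastTlsSingleton<TCounter>()->Value; // each thread sees its own Value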
diff --git a/util/thread/singleton_ut.cpp b/util/thread/singleton_ut.cpp
new file mode 100644
index 0000000000..164b1cc184
--- /dev/null
+++ b/util/thread/singleton_ut.cpp
@@ -0,0 +1,21 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "singleton.h"
+
+namespace {
+ struct TFoo {
+ int i;
+ TFoo()
+ : i(0)
+ {
+ }
+ };
+}
+
+Y_UNIT_TEST_SUITE(Tls) {
+ Y_UNIT_TEST(FastThread) {
+ UNIT_ASSERT_VALUES_EQUAL(0, FastTlsSingleton<TFoo>()->i);
+ FastTlsSingleton<TFoo>()->i += 3;
+ UNIT_ASSERT_VALUES_EQUAL(3, FastTlsSingleton<TFoo>()->i);
+ }
+}
diff --git a/util/thread/ut/ya.make b/util/thread/ut/ya.make
new file mode 100644
index 0000000000..93198bfaf1
--- /dev/null
+++ b/util/thread/ut/ya.make
@@ -0,0 +1,18 @@
+UNITTEST_FOR(util)
+
+OWNER(g:util)
+SUBSCRIBER(g:util-subscribers)
+
+SRCS(
+ thread/factory_ut.cpp
+ thread/lfqueue_ut.cpp
+ thread/lfstack_ut.cpp
+ thread/pool_ut.cpp
+ thread/singleton_ut.cpp
+)
+
+PEERDIR(
+ library/cpp/threading/future
+)
+
+END()
diff --git a/util/thread/ya.make b/util/thread/ya.make
new file mode 100644
index 0000000000..79c9498ddd
--- /dev/null
+++ b/util/thread/ya.make
@@ -0,0 +1,6 @@
+OWNER(g:util)
+SUBSCRIBER(g:util-subscribers)
+
+RECURSE_FOR_TESTS(
+ ut
+)