author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/thread
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/thread')
-rw-r--r-- | util/thread/factory.cpp | 93
-rw-r--r-- | util/thread/factory.h | 65
-rw-r--r-- | util/thread/factory_ut.cpp | 57
-rw-r--r-- | util/thread/fwd.cpp | 1
-rw-r--r-- | util/thread/fwd.h | 30
-rw-r--r-- | util/thread/lfqueue.cpp | 1
-rw-r--r-- | util/thread/lfqueue.h | 406
-rw-r--r-- | util/thread/lfqueue_ut.cpp | 333
-rw-r--r-- | util/thread/lfstack.cpp | 1
-rw-r--r-- | util/thread/lfstack.h | 188
-rw-r--r-- | util/thread/lfstack_ut.cpp | 346
-rw-r--r-- | util/thread/pool.cpp | 772
-rw-r--r-- | util/thread/pool.h | 390
-rw-r--r-- | util/thread/pool_ut.cpp | 257
-rw-r--r-- | util/thread/singleton.cpp | 1
-rw-r--r-- | util/thread/singleton.h | 41
-rw-r--r-- | util/thread/singleton_ut.cpp | 21
-rw-r--r-- | util/thread/ut/ya.make | 18
-rw-r--r-- | util/thread/ya.make | 6
19 files changed, 3027 insertions, 0 deletions
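The headline additions are the IThreadFactory / IThreadPool interfaces (factory.h, pool.h) and the lock-free containers (lfqueue.h, lfstack.h). As a minimal sketch of how the pieces compose (illustrative only, not part of the commit; it assumes the headers listed above are on the include path):

    #include <util/thread/pool.h>
    #include <util/thread/lfqueue.h>
    #include <util/generic/vector.h>

    int main() {
        TLockFreeQueue<int> results; // multi-producer, multi-consumer FIFO
        TThreadPool pool;
        pool.Start(/*threadCount=*/4, /*queueSizeLimit=*/100);

        for (int i = 0; i < 10; ++i) {
            pool.SafeAddFunc([i, &results] {
                results.Enqueue(i * i); // safe to call from any worker thread
            });
        }

        pool.Stop(); // waits for all scheduled jobs to complete

        TVector<int> all;
        results.DequeueAll(&all); // drains the queue in FIFO order
        return all.size() == 10 ? 0 : 1;
    }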
diff --git a/util/thread/factory.cpp b/util/thread/factory.cpp new file mode 100644 index 0000000000..48e898f32d --- /dev/null +++ b/util/thread/factory.cpp @@ -0,0 +1,93 @@ +#include "factory.h" + +#include <util/system/thread.h> +#include <util/generic/singleton.h> + +using IThread = IThreadFactory::IThread; + +namespace { + class TSystemThreadFactory: public IThreadFactory { + public: + class TPoolThread: public IThread { + public: + ~TPoolThread() override { + if (Thr_) { + Thr_->Detach(); + } + } + + void DoRun(IThreadAble* func) override { + Thr_.Reset(new TThread(ThreadProc, func)); + + Thr_->Start(); + } + + void DoJoin() noexcept override { + if (!Thr_) { + return; + } + + Thr_->Join(); + Thr_.Destroy(); + } + + private: + static void* ThreadProc(void* func) { + ((IThreadAble*)(func))->Execute(); + + return nullptr; + } + + private: + THolder<TThread> Thr_; + }; + + inline TSystemThreadFactory() noexcept { + } + + IThread* DoCreate() override { + return new TPoolThread; + } + }; + + class TThreadFactoryFuncObj: public IThreadFactory::IThreadAble { + public: + TThreadFactoryFuncObj(const std::function<void()>& func) + : Func(func) + { + } + void DoExecute() override { + THolder<TThreadFactoryFuncObj> self(this); + Func(); + } + + private: + std::function<void()> Func; + }; +} + +THolder<IThread> IThreadFactory::Run(std::function<void()> func) { + THolder<IThread> ret(DoCreate()); + + ret->Run(new ::TThreadFactoryFuncObj(func)); + + return ret; +} + +static IThreadFactory* SystemThreadPoolImpl() { + return Singleton<TSystemThreadFactory>(); +} + +static IThreadFactory* systemPool = nullptr; + +IThreadFactory* SystemThreadFactory() { + if (systemPool) { + return systemPool; + } + + return SystemThreadPoolImpl(); +} + +void SetSystemThreadFactory(IThreadFactory* pool) { + systemPool = pool; +} diff --git a/util/thread/factory.h b/util/thread/factory.h new file mode 100644 index 0000000000..561fcbac88 --- /dev/null +++ b/util/thread/factory.h @@ -0,0 +1,65 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <functional> + +class IThreadFactory { +public: + class IThreadAble { + public: + inline IThreadAble() noexcept = default; + + virtual ~IThreadAble() = default; + + inline void Execute() { + DoExecute(); + } + + private: + virtual void DoExecute() = 0; + }; + + class IThread { + friend class IThreadFactory; + + public: + inline IThread() noexcept = default; + + virtual ~IThread() = default; + + inline void Join() noexcept { + DoJoin(); + } + + private: + inline void Run(IThreadAble* func) { + DoRun(func); + } + + private: + // it's actually DoStart + virtual void DoRun(IThreadAble* func) = 0; + virtual void DoJoin() noexcept = 0; + }; + + inline IThreadFactory() noexcept = default; + + virtual ~IThreadFactory() = default; + + // XXX: rename to Start + inline THolder<IThread> Run(IThreadAble* func) { + THolder<IThread> ret(DoCreate()); + + ret->Run(func); + + return ret; + } + + THolder<IThread> Run(std::function<void()> func); + +private: + virtual IThread* DoCreate() = 0; +}; + +IThreadFactory* SystemThreadFactory(); +void SetSystemThreadFactory(IThreadFactory* pool); diff --git a/util/thread/factory_ut.cpp b/util/thread/factory_ut.cpp new file mode 100644 index 0000000000..647d96c901 --- /dev/null +++ b/util/thread/factory_ut.cpp @@ -0,0 +1,57 @@ +#include "factory.h" +#include "pool.h" + +#include <library/cpp/testing/unittest/registar.h> + +class TThrPoolTest: public TTestBase { + UNIT_TEST_SUITE(TThrPoolTest); + UNIT_TEST(TestSystemPool) + UNIT_TEST(TestAdaptivePool) + 
UNIT_TEST_SUITE_END(); + + struct TRunAble: public IThreadFactory::IThreadAble { + inline TRunAble() + : done(false) + { + } + + ~TRunAble() override = default; + + void DoExecute() override { + done = true; + } + + bool done; + }; + +private: + inline void TestSystemPool() { + TRunAble r; + + { + THolder<IThreadFactory::IThread> thr = SystemThreadFactory()->Run(&r); + + thr->Join(); + } + + UNIT_ASSERT_EQUAL(r.done, true); + } + + inline void TestAdaptivePool() { + TRunAble r; + + { + TAdaptiveThreadPool pool; + + pool.Start(0); + + THolder<IThreadFactory::IThread> thr = pool.Run(&r); + + thr->Join(); + } + + UNIT_ASSERT_EQUAL(r.done, true); + } +}; + +UNIT_TEST_SUITE_REGISTRATION(TThrPoolTest); diff --git a/util/thread/fwd.cpp b/util/thread/fwd.cpp new file mode 100644 index 0000000000..4214b6df83 --- /dev/null +++ b/util/thread/fwd.cpp @@ -0,0 +1 @@ +#include "fwd.h" diff --git a/util/thread/fwd.h b/util/thread/fwd.h new file mode 100644 index 0000000000..6f1caed21c --- /dev/null +++ b/util/thread/fwd.h @@ -0,0 +1,30 @@ +#pragma once + +#include <stlfwd> + +struct TDefaultLFCounter; + +template <class T, class TCounter = TDefaultLFCounter> +class TLockFreeQueue; + +template <class T, class TCounter = TDefaultLFCounter> +class TAutoLockFreeQueue; + +template <class T> +class TLockFreeStack; + +class IThreadFactory; + +struct IObjectInQueue; +class TThreadFactoryHolder; + +using TThreadFunction = std::function<void()>; + +class IThreadPool; +class TFakeThreadPool; +class TThreadPool; +class TAdaptiveThreadPool; +class TSimpleThreadPool; + +template <class TQueueType, class TSlave> +class TThreadPoolBinder; diff --git a/util/thread/lfqueue.cpp b/util/thread/lfqueue.cpp new file mode 100644 index 0000000000..5861999b78 --- /dev/null +++ b/util/thread/lfqueue.cpp @@ -0,0 +1 @@ +#include "lfqueue.h" diff --git a/util/thread/lfqueue.h b/util/thread/lfqueue.h new file mode 100644 index 0000000000..ab523631e4 --- /dev/null +++ b/util/thread/lfqueue.h @@ -0,0 +1,406 @@ +#pragma once + +#include "fwd.h" + +#include <util/generic/ptr.h> +#include <util/system/atomic.h> +#include <util/system/yassert.h> +#include "lfstack.h" + +struct TDefaultLFCounter { + template <class T> + void IncCount(const T& data) { + (void)data; + } + template <class T> + void DecCount(const T& data) { + (void)data; + } +}; + +// @brief lockfree queue +// @tparam T - the queue element, should be movable +// @tparam TCounter, a observer class to count number of items in queue +// be carifull, IncCount and DecCount can be called on a moved object and +// it is TCounter class responsibility to check validity of passed object +template <class T, class TCounter> +class TLockFreeQueue: public TNonCopyable { + struct TListNode { + template <typename U> + TListNode(U&& u, TListNode* next) + : Next(next) + , Data(std::forward<U>(u)) + { + } + + template <typename U> + explicit TListNode(U&& u) + : Data(std::forward<U>(u)) + { + } + + TListNode* volatile Next; + T Data; + }; + + // using inheritance to be able to use 0 bytes for TCounter when we don't need one + struct TRootNode: public TCounter { + TListNode* volatile PushQueue; + TListNode* volatile PopQueue; + TListNode* volatile ToDelete; + TRootNode* volatile NextFree; + + TRootNode() + : PushQueue(nullptr) + , PopQueue(nullptr) + , ToDelete(nullptr) + , NextFree(nullptr) + { + } + void CopyCounter(TRootNode* x) { + *(TCounter*)this = *(TCounter*)x; + } + }; + + static void EraseList(TListNode* n) { + while (n) { + TListNode* keepNext = AtomicGet(n->Next); + delete n; + n = 
keepNext; + } + } + + alignas(64) TRootNode* volatile JobQueue; + alignas(64) volatile TAtomic FreememCounter; + alignas(64) volatile TAtomic FreeingTaskCounter; + alignas(64) TRootNode* volatile FreePtr; + + void TryToFreeAsyncMemory() { + TAtomic keepCounter = AtomicAdd(FreeingTaskCounter, 0); + TRootNode* current = AtomicGet(FreePtr); + if (current == nullptr) + return; + if (AtomicAdd(FreememCounter, 0) == 1) { + // we are the last thread, try to cleanup + // check if another thread have cleaned up + if (keepCounter != AtomicAdd(FreeingTaskCounter, 0)) { + return; + } + if (AtomicCas(&FreePtr, (TRootNode*)nullptr, current)) { + // free list + while (current) { + TRootNode* p = AtomicGet(current->NextFree); + EraseList(AtomicGet(current->ToDelete)); + delete current; + current = p; + } + AtomicAdd(FreeingTaskCounter, 1); + } + } + } + void AsyncRef() { + AtomicAdd(FreememCounter, 1); + } + void AsyncUnref() { + TryToFreeAsyncMemory(); + AtomicAdd(FreememCounter, -1); + } + void AsyncDel(TRootNode* toDelete, TListNode* lst) { + AtomicSet(toDelete->ToDelete, lst); + for (;;) { + AtomicSet(toDelete->NextFree, AtomicGet(FreePtr)); + if (AtomicCas(&FreePtr, toDelete, AtomicGet(toDelete->NextFree))) + break; + } + } + void AsyncUnref(TRootNode* toDelete, TListNode* lst) { + TryToFreeAsyncMemory(); + if (AtomicAdd(FreememCounter, -1) == 0) { + // no other operations in progress, can safely reclaim memory + EraseList(lst); + delete toDelete; + } else { + // Dequeue()s in progress, put node to free list + AsyncDel(toDelete, lst); + } + } + + struct TListInvertor { + TListNode* Copy; + TListNode* Tail; + TListNode* PrevFirst; + + TListInvertor() + : Copy(nullptr) + , Tail(nullptr) + , PrevFirst(nullptr) + { + } + ~TListInvertor() { + EraseList(Copy); + } + void CopyWasUsed() { + Copy = nullptr; + Tail = nullptr; + PrevFirst = nullptr; + } + void DoCopy(TListNode* ptr) { + TListNode* newFirst = ptr; + TListNode* newCopy = nullptr; + TListNode* newTail = nullptr; + while (ptr) { + if (ptr == PrevFirst) { + // short cut, we have copied this part already + AtomicSet(Tail->Next, newCopy); + newCopy = Copy; + Copy = nullptr; // do not destroy prev try + if (!newTail) + newTail = Tail; // tried to invert same list + break; + } + TListNode* newElem = new TListNode(ptr->Data, newCopy); + newCopy = newElem; + ptr = AtomicGet(ptr->Next); + if (!newTail) + newTail = newElem; + } + EraseList(Copy); // copy was useless + Copy = newCopy; + PrevFirst = newFirst; + Tail = newTail; + } + }; + + void EnqueueImpl(TListNode* head, TListNode* tail) { + TRootNode* newRoot = new TRootNode; + AsyncRef(); + AtomicSet(newRoot->PushQueue, head); + for (;;) { + TRootNode* curRoot = AtomicGet(JobQueue); + AtomicSet(tail->Next, AtomicGet(curRoot->PushQueue)); + AtomicSet(newRoot->PopQueue, AtomicGet(curRoot->PopQueue)); + newRoot->CopyCounter(curRoot); + + for (TListNode* node = head;; node = AtomicGet(node->Next)) { + newRoot->IncCount(node->Data); + if (node == tail) + break; + } + + if (AtomicCas(&JobQueue, newRoot, curRoot)) { + AsyncUnref(curRoot, nullptr); + break; + } + } + } + + template <typename TCollection> + static void FillCollection(TListNode* lst, TCollection* res) { + while (lst) { + res->emplace_back(std::move(lst->Data)); + lst = AtomicGet(lst->Next); + } + } + + /** Traverses a given list simultaneously creating its inversed version. + * After that, fills a collection with a reversed version and returns the last visited lst's node. 
+ */ + template <typename TCollection> + static TListNode* FillCollectionReverse(TListNode* lst, TCollection* res) { + if (!lst) { + return nullptr; + } + + TListNode* newCopy = nullptr; + do { + TListNode* newElem = new TListNode(std::move(lst->Data), newCopy); + newCopy = newElem; + lst = AtomicGet(lst->Next); + } while (lst); + + FillCollection(newCopy, res); + EraseList(newCopy); + + return lst; + } + +public: + TLockFreeQueue() + : JobQueue(new TRootNode) + , FreememCounter(0) + , FreeingTaskCounter(0) + , FreePtr(nullptr) + { + } + ~TLockFreeQueue() { + AsyncRef(); + AsyncUnref(); // should free FreeList + EraseList(JobQueue->PushQueue); + EraseList(JobQueue->PopQueue); + delete JobQueue; + } + template <typename U> + void Enqueue(U&& data) { + TListNode* newNode = new TListNode(std::forward<U>(data)); + EnqueueImpl(newNode, newNode); + } + void Enqueue(T&& data) { + TListNode* newNode = new TListNode(std::move(data)); + EnqueueImpl(newNode, newNode); + } + void Enqueue(const T& data) { + TListNode* newNode = new TListNode(data); + EnqueueImpl(newNode, newNode); + } + template <typename TCollection> + void EnqueueAll(const TCollection& data) { + EnqueueAll(data.begin(), data.end()); + } + template <typename TIter> + void EnqueueAll(TIter dataBegin, TIter dataEnd) { + if (dataBegin == dataEnd) + return; + + TIter i = dataBegin; + TListNode* volatile node = new TListNode(*i); + TListNode* volatile tail = node; + + for (++i; i != dataEnd; ++i) { + TListNode* nextNode = node; + node = new TListNode(*i, nextNode); + } + EnqueueImpl(node, tail); + } + bool Dequeue(T* data) { + TRootNode* newRoot = nullptr; + TListInvertor listInvertor; + AsyncRef(); + for (;;) { + TRootNode* curRoot = AtomicGet(JobQueue); + TListNode* tail = AtomicGet(curRoot->PopQueue); + if (tail) { + // has elems to pop + if (!newRoot) + newRoot = new TRootNode; + + AtomicSet(newRoot->PushQueue, AtomicGet(curRoot->PushQueue)); + AtomicSet(newRoot->PopQueue, AtomicGet(tail->Next)); + newRoot->CopyCounter(curRoot); + newRoot->DecCount(tail->Data); + Y_ASSERT(AtomicGet(curRoot->PopQueue) == tail); + if (AtomicCas(&JobQueue, newRoot, curRoot)) { + *data = std::move(tail->Data); + AtomicSet(tail->Next, nullptr); + AsyncUnref(curRoot, tail); + return true; + } + continue; + } + if (AtomicGet(curRoot->PushQueue) == nullptr) { + delete newRoot; + AsyncUnref(); + return false; // no elems to pop + } + + if (!newRoot) + newRoot = new TRootNode; + AtomicSet(newRoot->PushQueue, nullptr); + listInvertor.DoCopy(AtomicGet(curRoot->PushQueue)); + AtomicSet(newRoot->PopQueue, listInvertor.Copy); + newRoot->CopyCounter(curRoot); + Y_ASSERT(AtomicGet(curRoot->PopQueue) == nullptr); + if (AtomicCas(&JobQueue, newRoot, curRoot)) { + newRoot = nullptr; + listInvertor.CopyWasUsed(); + AsyncDel(curRoot, AtomicGet(curRoot->PushQueue)); + } else { + AtomicSet(newRoot->PopQueue, nullptr); + } + } + } + template <typename TCollection> + void DequeueAll(TCollection* res) { + AsyncRef(); + + TRootNode* newRoot = new TRootNode; + TRootNode* curRoot; + do { + curRoot = AtomicGet(JobQueue); + } while (!AtomicCas(&JobQueue, newRoot, curRoot)); + + FillCollection(curRoot->PopQueue, res); + + TListNode* toDeleteHead = curRoot->PushQueue; + TListNode* toDeleteTail = FillCollectionReverse(curRoot->PushQueue, res); + AtomicSet(curRoot->PushQueue, nullptr); + + if (toDeleteTail) { + toDeleteTail->Next = curRoot->PopQueue; + } else { + toDeleteTail = curRoot->PopQueue; + } + AtomicSet(curRoot->PopQueue, nullptr); + + AsyncUnref(curRoot, toDeleteHead); + } + 
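// Editorial usage note (illustrative, not part of the original commit):
// Dequeue() pops one element in FIFO order; DequeueAll() atomically swaps out
// the whole queue and drains it in one shot, e.g.:
//
//   TLockFreeQueue<int> queue;
//   queue.Enqueue(1);
//   queue.Enqueue(2);
//   TVector<int> items;
//   queue.DequeueAll(&items); // items == {1, 2}; the queue is now empty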
bool IsEmpty() { + AsyncRef(); + TRootNode* curRoot = AtomicGet(JobQueue); + bool res = AtomicGet(curRoot->PushQueue) == nullptr && AtomicGet(curRoot->PopQueue) == nullptr; + AsyncUnref(); + return res; + } + TCounter GetCounter() { + AsyncRef(); + TRootNode* curRoot = AtomicGet(JobQueue); + TCounter res = *(TCounter*)curRoot; + AsyncUnref(); + return res; + } +}; + +template <class T, class TCounter> +class TAutoLockFreeQueue { +public: + using TRef = THolder<T>; + + inline ~TAutoLockFreeQueue() { + TRef tmp; + + while (Dequeue(&tmp)) { + } + } + + inline bool Dequeue(TRef* t) { + T* res = nullptr; + + if (Queue.Dequeue(&res)) { + t->Reset(res); + + return true; + } + + return false; + } + + inline void Enqueue(TRef& t) { + Queue.Enqueue(t.Get()); + Y_UNUSED(t.Release()); + } + + inline void Enqueue(TRef&& t) { + Queue.Enqueue(t.Get()); + Y_UNUSED(t.Release()); + } + + inline bool IsEmpty() { + return Queue.IsEmpty(); + } + + inline TCounter GetCounter() { + return Queue.GetCounter(); + } + +private: + TLockFreeQueue<T*, TCounter> Queue; +}; diff --git a/util/thread/lfqueue_ut.cpp b/util/thread/lfqueue_ut.cpp new file mode 100644 index 0000000000..83bca100cf --- /dev/null +++ b/util/thread/lfqueue_ut.cpp @@ -0,0 +1,333 @@ +#include <library/cpp/threading/future/future.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/algorithm.h> +#include <util/generic/vector.h> +#include <util/generic/ptr.h> +#include <util/system/atomic.h> +#include <util/thread/pool.h> + +#include "lfqueue.h" + +class TMoveTest { +public: + TMoveTest(int marker = 0, int value = 0) + : Marker_(marker) + , Value_(value) + { + } + + TMoveTest(const TMoveTest& other) { + *this = other; + } + + TMoveTest(TMoveTest&& other) { + *this = std::move(other); + } + + TMoveTest& operator=(const TMoveTest& other) { + Value_ = other.Value_; + Marker_ = other.Marker_ + 1024; + return *this; + } + + TMoveTest& operator=(TMoveTest&& other) { + Value_ = other.Value_; + Marker_ = other.Marker_; + other.Marker_ = 0; + return *this; + } + + int Marker() const { + return Marker_; + } + + int Value() const { + return Value_; + } + +private: + int Marker_ = 0; + int Value_ = 0; +}; + +class TOperationsChecker { +public: + TOperationsChecker() { + ++DefaultCtor_; + } + + TOperationsChecker(TOperationsChecker&&) { + ++MoveCtor_; + } + + TOperationsChecker(const TOperationsChecker&) { + ++CopyCtor_; + } + + TOperationsChecker& operator=(TOperationsChecker&&) { + ++MoveAssign_; + return *this; + } + + TOperationsChecker& operator=(const TOperationsChecker&) { + ++CopyAssign_; + return *this; + } + + static void Check(int defaultCtor, int moveCtor, int copyCtor, int moveAssign, int copyAssign) { + UNIT_ASSERT_VALUES_EQUAL(defaultCtor, DefaultCtor_); + UNIT_ASSERT_VALUES_EQUAL(moveCtor, MoveCtor_); + UNIT_ASSERT_VALUES_EQUAL(copyCtor, CopyCtor_); + UNIT_ASSERT_VALUES_EQUAL(moveAssign, MoveAssign_); + UNIT_ASSERT_VALUES_EQUAL(copyAssign, CopyAssign_); + Clear(); + } + +private: + static void Clear() { + DefaultCtor_ = MoveCtor_ = CopyCtor_ = MoveAssign_ = CopyAssign_ = 0; + } + + static int DefaultCtor_; + static int MoveCtor_; + static int CopyCtor_; + static int MoveAssign_; + static int CopyAssign_; +}; + +int TOperationsChecker::DefaultCtor_ = 0; +int TOperationsChecker::MoveCtor_ = 0; +int TOperationsChecker::CopyCtor_ = 0; +int TOperationsChecker::MoveAssign_ = 0; +int TOperationsChecker::CopyAssign_ = 0; + +Y_UNIT_TEST_SUITE(TLockFreeQueueTests) { + Y_UNIT_TEST(TestMoveEnqueue) { + TMoveTest value(0xFF, 0xAA); + 
TMoveTest tmp; + + TLockFreeQueue<TMoveTest> queue; + + queue.Enqueue(value); + UNIT_ASSERT_VALUES_EQUAL(value.Marker(), 0xFF); + UNIT_ASSERT(queue.Dequeue(&tmp)); + UNIT_ASSERT_VALUES_UNEQUAL(tmp.Marker(), 0xFF); + UNIT_ASSERT_VALUES_EQUAL(tmp.Value(), 0xAA); + + queue.Enqueue(std::move(value)); + UNIT_ASSERT_VALUES_EQUAL(value.Marker(), 0); + UNIT_ASSERT(queue.Dequeue(&tmp)); + UNIT_ASSERT_VALUES_EQUAL(tmp.Value(), 0xAA); + } + + Y_UNIT_TEST(TestSimpleEnqueueDequeue) { + TLockFreeQueue<int> queue; + + int i = -1; + + UNIT_ASSERT(!queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(i, -1); + + queue.Enqueue(10); + queue.Enqueue(11); + queue.Enqueue(12); + + UNIT_ASSERT(queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(10, i); + UNIT_ASSERT(queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(11, i); + + queue.Enqueue(13); + + UNIT_ASSERT(queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(12, i); + UNIT_ASSERT(queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(13, i); + + UNIT_ASSERT(!queue.Dequeue(&i)); + + const int tmp = 100; + queue.Enqueue(tmp); + UNIT_ASSERT(queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(i, tmp); + } + + Y_UNIT_TEST(TestSimpleEnqueueAllDequeue) { + TLockFreeQueue<int> queue; + + int i = -1; + + UNIT_ASSERT(!queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(i, -1); + + TVector<int> v; + v.push_back(20); + v.push_back(21); + + queue.EnqueueAll(v); + + v.clear(); + v.push_back(22); + v.push_back(23); + v.push_back(24); + + queue.EnqueueAll(v); + + v.clear(); + queue.EnqueueAll(v); + + v.clear(); + v.push_back(25); + + queue.EnqueueAll(v); + + for (int j = 20; j <= 25; ++j) { + UNIT_ASSERT(queue.Dequeue(&i)); + UNIT_ASSERT_VALUES_EQUAL(j, i); + } + + UNIT_ASSERT(!queue.Dequeue(&i)); + } + + void DequeueAllRunner(TLockFreeQueue<int>& queue, bool singleConsumer) { + size_t threadsNum = 4; + size_t enqueuesPerThread = 10'000; + TThreadPool p; + p.Start(threadsNum, 0); + + TVector<NThreading::TFuture<void>> futures; + + for (size_t i = 0; i < threadsNum; ++i) { + NThreading::TPromise<void> promise = NThreading::NewPromise(); + futures.emplace_back(promise.GetFuture()); + + p.SafeAddFunc([enqueuesPerThread, &queue, promise]() mutable { + for (size_t i = 0; i != enqueuesPerThread; ++i) { + queue.Enqueue(i); + } + + promise.SetValue(); + }); + } + + TAtomic elementsLeft; + AtomicSet(elementsLeft, threadsNum * enqueuesPerThread); + + ui64 numOfConsumers = singleConsumer ? 
1 : threadsNum; + + TVector<TVector<int>> dataBuckets(numOfConsumers); + + for (size_t i = 0; i < numOfConsumers; ++i) { + NThreading::TPromise<void> promise = NThreading::NewPromise(); + futures.emplace_back(promise.GetFuture()); + + p.SafeAddFunc([&queue, &elementsLeft, promise, consumerData{&dataBuckets[i]}]() mutable { + TVector<int> vec; + while (static_cast<i64>(AtomicGet(elementsLeft)) > 0) { + for (size_t i = 0; i != 100; ++i) { + vec.clear(); + queue.DequeueAll(&vec); + + AtomicSub(elementsLeft, vec.size()); + consumerData->insert(consumerData->end(), vec.begin(), vec.end()); + } + } + + promise.SetValue(); + }); + } + + NThreading::WaitExceptionOrAll(futures).GetValueSync(); + p.Stop(); + + TVector<int> left; + queue.DequeueAll(&left); + + UNIT_ASSERT(left.empty()); + + TVector<int> data; + for (auto& dataBucket : dataBuckets) { + data.insert(data.end(), dataBucket.begin(), dataBucket.end()); + } + + UNIT_ASSERT_EQUAL(data.size(), threadsNum * enqueuesPerThread); + + size_t threadIdx = 0; + size_t cntValue = 0; + + Sort(data.begin(), data.end()); + for (size_t i = 0; i != data.size(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(cntValue, data[i]); + ++threadIdx; + + if (threadIdx == threadsNum) { + ++cntValue; + threadIdx = 0; + } + } + } + + Y_UNIT_TEST(TestDequeueAllSingleConsumer) { + TLockFreeQueue<int> queue; + DequeueAllRunner(queue, true); + } + + Y_UNIT_TEST(TestDequeueAllMultipleConsumers) { + TLockFreeQueue<int> queue; + DequeueAllRunner(queue, false); + } + + Y_UNIT_TEST(TestDequeueAllEmptyQueue) { + TLockFreeQueue<int> queue; + TVector<int> vec; + + queue.DequeueAll(&vec); + + UNIT_ASSERT(vec.empty()); + } + + Y_UNIT_TEST(TestDequeueAllQueueOrder) { + TLockFreeQueue<int> queue; + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + + TVector<int> v; + queue.DequeueAll(&v); + + UNIT_ASSERT_VALUES_EQUAL(v.size(), 3); + UNIT_ASSERT_VALUES_EQUAL(v[0], 1); + UNIT_ASSERT_VALUES_EQUAL(v[1], 2); + UNIT_ASSERT_VALUES_EQUAL(v[2], 3); + } + + Y_UNIT_TEST(CleanInDestructor) { + TSimpleSharedPtr<bool> p(new bool); + UNIT_ASSERT_VALUES_EQUAL(1u, p.RefCount()); + + { + TLockFreeQueue<TSimpleSharedPtr<bool>> stack; + + stack.Enqueue(p); + stack.Enqueue(p); + + UNIT_ASSERT_VALUES_EQUAL(3u, p.RefCount()); + } + + UNIT_ASSERT_VALUES_EQUAL(1, p.RefCount()); + } + + Y_UNIT_TEST(CheckOperationsCount) { + TOperationsChecker o; + o.Check(1, 0, 0, 0, 0); + TLockFreeQueue<TOperationsChecker> queue; + o.Check(0, 0, 0, 0, 0); + queue.Enqueue(std::move(o)); + o.Check(0, 1, 0, 0, 0); + queue.Enqueue(o); + o.Check(0, 0, 1, 0, 0); + queue.Dequeue(&o); + o.Check(0, 0, 2, 1, 0); + } +} diff --git a/util/thread/lfstack.cpp b/util/thread/lfstack.cpp new file mode 100644 index 0000000000..be8b3bdf37 --- /dev/null +++ b/util/thread/lfstack.cpp @@ -0,0 +1 @@ +#include "lfstack.h" diff --git a/util/thread/lfstack.h b/util/thread/lfstack.h new file mode 100644 index 0000000000..ca3d95f3c3 --- /dev/null +++ b/util/thread/lfstack.h @@ -0,0 +1,188 @@ +#pragma once + +#include <util/generic/noncopyable.h> +#include <util/system/atomic.h> + +////////////////////////////// +// lock free lifo stack +template <class T> +class TLockFreeStack: TNonCopyable { + struct TNode { + T Value; + TNode* Next; + + TNode() = default; + + template <class U> + explicit TNode(U&& val) + : Value(std::forward<U>(val)) + , Next(nullptr) + { + } + }; + + TNode* Head; + TNode* FreePtr; + TAtomic DequeueCount; + + void TryToFreeMemory() { + TNode* current = AtomicGet(FreePtr); + if (!current) + return; + if (AtomicAdd(DequeueCount, 0) == 1) 
{ + // node current is in free list, we are the last thread so try to cleanup + if (AtomicCas(&FreePtr, (TNode*)nullptr, current)) + EraseList(current); + } + } + void EraseList(TNode* volatile p) { + while (p) { + TNode* next = p->Next; + delete p; + p = next; + } + } + void EnqueueImpl(TNode* volatile head, TNode* volatile tail) { + for (;;) { + tail->Next = AtomicGet(Head); + if (AtomicCas(&Head, head, tail->Next)) + break; + } + } + template <class U> + void EnqueueImpl(U&& u) { + TNode* volatile node = new TNode(std::forward<U>(u)); + EnqueueImpl(node, node); + } + +public: + TLockFreeStack() + : Head(nullptr) + , FreePtr(nullptr) + , DequeueCount(0) + { + } + ~TLockFreeStack() { + EraseList(Head); + EraseList(FreePtr); + } + + void Enqueue(const T& t) { + EnqueueImpl(t); + } + + void Enqueue(T&& t) { + EnqueueImpl(std::move(t)); + } + + template <typename TCollection> + void EnqueueAll(const TCollection& data) { + EnqueueAll(data.begin(), data.end()); + } + template <typename TIter> + void EnqueueAll(TIter dataBegin, TIter dataEnd) { + if (dataBegin == dataEnd) { + return; + } + TIter i = dataBegin; + TNode* volatile node = new TNode(*i); + TNode* volatile tail = node; + + for (++i; i != dataEnd; ++i) { + TNode* nextNode = node; + node = new TNode(*i); + node->Next = nextNode; + } + EnqueueImpl(node, tail); + } + bool Dequeue(T* res) { + AtomicAdd(DequeueCount, 1); + for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) { + if (AtomicCas(&Head, AtomicGet(current->Next), current)) { + *res = std::move(current->Value); + // delete current; // ABA problem + // even more complex node deletion + TryToFreeMemory(); + if (AtomicAdd(DequeueCount, -1) == 0) { + // no other Dequeue()s, can safely reclaim memory + delete current; + } else { + // Dequeue()s in progress, put node to free list + for (;;) { + AtomicSet(current->Next, AtomicGet(FreePtr)); + if (AtomicCas(&FreePtr, current, current->Next)) + break; + } + } + return true; + } + } + TryToFreeMemory(); + AtomicAdd(DequeueCount, -1); + return false; + } + // add all elements to *res + // elements are returned in order of dequeue (top to bottom; see example in unittest) + template <typename TCollection> + void DequeueAll(TCollection* res) { + AtomicAdd(DequeueCount, 1); + for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) { + if (AtomicCas(&Head, (TNode*)nullptr, current)) { + for (TNode* x = current; x;) { + res->push_back(std::move(x->Value)); + x = x->Next; + } + // EraseList(current); // ABA problem + // even more complex node deletion + TryToFreeMemory(); + if (AtomicAdd(DequeueCount, -1) == 0) { + // no other Dequeue()s, can safely reclaim memory + EraseList(current); + } else { + // Dequeue()s in progress, add nodes list to free list + TNode* currentLast = current; + while (currentLast->Next) { + currentLast = currentLast->Next; + } + for (;;) { + AtomicSet(currentLast->Next, AtomicGet(FreePtr)); + if (AtomicCas(&FreePtr, current, currentLast->Next)) + break; + } + } + return; + } + } + TryToFreeMemory(); + AtomicAdd(DequeueCount, -1); + } + bool DequeueSingleConsumer(T* res) { + for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) { + if (AtomicCas(&Head, current->Next, current)) { + *res = std::move(current->Value); + delete current; // with single consumer thread ABA does not happen + return true; + } + } + return false; + } + // add all elements to *res + // elements are returned in order of dequeue (top to bottom; see example in unittest) + template <typename 
TCollection> + void DequeueAllSingleConsumer(TCollection* res) { + for (TNode* current = AtomicGet(Head); current; current = AtomicGet(Head)) { + if (AtomicCas(&Head, (TNode*)nullptr, current)) { + for (TNode* x = current; x;) { + res->push_back(std::move(x->Value)); + x = x->Next; + } + EraseList(current); // with single consumer thread ABA does not happen + return; + } + } + } + bool IsEmpty() { + AtomicAdd(DequeueCount, 0); // mem barrier + return AtomicGet(Head) == nullptr; // without lock, so result is approximate + } +}; diff --git a/util/thread/lfstack_ut.cpp b/util/thread/lfstack_ut.cpp new file mode 100644 index 0000000000..e20a838f95 --- /dev/null +++ b/util/thread/lfstack_ut.cpp @@ -0,0 +1,346 @@ + +#include <util/system/atomic.h> +#include <util/system/event.h> +#include <util/generic/deque.h> +#include <library/cpp/threading/future/legacy_future.h> + +#include <library/cpp/testing/unittest/registar.h> + +#include "lfstack.h" + +Y_UNIT_TEST_SUITE(TLockFreeStackTests) { + class TCountDownLatch { + private: + TAtomic Current_; + TSystemEvent EventObject_; + + public: + TCountDownLatch(unsigned initial) + : Current_(initial) + { + } + + void CountDown() { + if (AtomicDecrement(Current_) == 0) { + EventObject_.Signal(); + } + } + + void Await() { + EventObject_.Wait(); + } + + bool Await(TDuration timeout) { + return EventObject_.WaitT(timeout); + } + }; + + template <bool SingleConsumer> + struct TDequeueAllTester { + size_t EnqueueThreads; + size_t DequeueThreads; + + size_t EnqueuesPerThread; + TAtomic LeftToDequeue; + + TCountDownLatch StartLatch; + TLockFreeStack<int> Stack; + + TDequeueAllTester() + : EnqueueThreads(4) + , DequeueThreads(SingleConsumer ? 1 : 3) + , EnqueuesPerThread(100000) + , LeftToDequeue(EnqueueThreads * EnqueuesPerThread) + , StartLatch(EnqueueThreads + DequeueThreads) + { + } + + void Enqueuer() { + StartLatch.CountDown(); + StartLatch.Await(); + + for (size_t i = 0; i < EnqueuesPerThread; ++i) { + Stack.Enqueue(i); + } + } + + void DequeuerAll() { + StartLatch.CountDown(); + StartLatch.Await(); + + TVector<int> temp; + while (AtomicGet(LeftToDequeue) > 0) { + size_t dequeued = 0; + for (size_t i = 0; i < 100; ++i) { + temp.clear(); + if (SingleConsumer) { + Stack.DequeueAllSingleConsumer(&temp); + } else { + Stack.DequeueAll(&temp); + } + dequeued += temp.size(); + } + AtomicAdd(LeftToDequeue, -dequeued); + } + } + + void Run() { + TVector<TSimpleSharedPtr<NThreading::TLegacyFuture<>>> futures; + + for (size_t i = 0; i < EnqueueThreads; ++i) { + futures.push_back(new NThreading::TLegacyFuture<>(std::bind(&TDequeueAllTester<SingleConsumer>::Enqueuer, this))); + } + + for (size_t i = 0; i < DequeueThreads; ++i) { + futures.push_back(new NThreading::TLegacyFuture<>(std::bind(&TDequeueAllTester<SingleConsumer>::DequeuerAll, this))); + } + + // effectively join + futures.clear(); + + UNIT_ASSERT_VALUES_EQUAL(0, int(AtomicGet(LeftToDequeue))); + + TVector<int> left; + Stack.DequeueAll(&left); + UNIT_ASSERT(left.empty()); + } + }; + + Y_UNIT_TEST(TestDequeueAll) { + TDequeueAllTester<false>().Run(); + } + + Y_UNIT_TEST(TestDequeueAllSingleConsumer) { + TDequeueAllTester<true>().Run(); + } + + Y_UNIT_TEST(TestDequeueAllEmptyStack) { + TLockFreeStack<int> stack; + + TVector<int> r; + stack.DequeueAll(&r); + + UNIT_ASSERT(r.empty()); + } + + Y_UNIT_TEST(TestDequeueAllReturnsInReverseOrder) { + TLockFreeStack<int> stack; + + stack.Enqueue(17); + stack.Enqueue(19); + stack.Enqueue(23); + + TVector<int> r; + + stack.DequeueAll(&r); + + 
UNIT_ASSERT_VALUES_EQUAL(size_t(3), r.size()); + UNIT_ASSERT_VALUES_EQUAL(23, r.at(0)); + UNIT_ASSERT_VALUES_EQUAL(19, r.at(1)); + UNIT_ASSERT_VALUES_EQUAL(17, r.at(2)); + } + + Y_UNIT_TEST(TestEnqueueAll) { + TLockFreeStack<int> stack; + + TVector<int> v; + TVector<int> expected; + + stack.EnqueueAll(v); // add empty + + v.push_back(2); + v.push_back(3); + v.push_back(5); + expected.insert(expected.end(), v.begin(), v.end()); + stack.EnqueueAll(v); + + v.clear(); + + stack.EnqueueAll(v); // add empty + + v.push_back(7); + v.push_back(11); + v.push_back(13); + v.push_back(17); + expected.insert(expected.end(), v.begin(), v.end()); + stack.EnqueueAll(v); + + TVector<int> actual; + stack.DequeueAll(&actual); + + UNIT_ASSERT_VALUES_EQUAL(expected.size(), actual.size()); + for (size_t i = 0; i < actual.size(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(expected.at(expected.size() - i - 1), actual.at(i)); + } + } + + Y_UNIT_TEST(CleanInDestructor) { + TSimpleSharedPtr<bool> p(new bool); + UNIT_ASSERT_VALUES_EQUAL(1u, p.RefCount()); + + { + TLockFreeStack<TSimpleSharedPtr<bool>> stack; + + stack.Enqueue(p); + stack.Enqueue(p); + + UNIT_ASSERT_VALUES_EQUAL(3u, p.RefCount()); + } + + UNIT_ASSERT_VALUES_EQUAL(1, p.RefCount()); + } + + Y_UNIT_TEST(NoCopyTest) { + static unsigned copied = 0; + struct TCopyCount { + TCopyCount(int) { + } + TCopyCount(const TCopyCount&) { + ++copied; + } + + TCopyCount(TCopyCount&&) { + } + + TCopyCount& operator=(const TCopyCount&) { + ++copied; + return *this; + } + + TCopyCount& operator=(TCopyCount&&) { + return *this; + } + }; + + TLockFreeStack<TCopyCount> stack; + stack.Enqueue(TCopyCount(1)); + TCopyCount val(0); + stack.Dequeue(&val); + UNIT_ASSERT_VALUES_EQUAL(0, copied); + } + + Y_UNIT_TEST(MoveOnlyTest) { + TLockFreeStack<THolder<bool>> stack; + stack.Enqueue(MakeHolder<bool>(true)); + THolder<bool> val; + stack.Dequeue(&val); + UNIT_ASSERT(val); + UNIT_ASSERT_VALUES_EQUAL(true, *val); + } + + template <class TTest> + struct TMultiThreadTester { + using ThisType = TMultiThreadTester<TTest>; + + size_t Threads; + size_t OperationsPerThread; + + TCountDownLatch StartLatch; + TLockFreeStack<typename TTest::ValueType> Stack; + + TMultiThreadTester() + : Threads(10) + , OperationsPerThread(100000) + , StartLatch(Threads) + { + } + + void Worker() { + StartLatch.CountDown(); + StartLatch.Await(); + + TVector<typename TTest::ValueType> unused; + for (size_t i = 0; i < OperationsPerThread; ++i) { + switch (GetCycleCount() % 4) { + case 0: { + TTest::Enqueue(Stack, i); + break; + } + case 1: { + TTest::Dequeue(Stack); + break; + } + case 2: { + TTest::EnqueueAll(Stack); + break; + } + case 3: { + TTest::DequeueAll(Stack); + break; + } + } + } + } + + void Run() { + TDeque<NThreading::TLegacyFuture<>> futures; + + for (size_t i = 0; i < Threads; ++i) { + futures.emplace_back(std::bind(&ThisType::Worker, this)); + } + futures.clear(); + TTest::DequeueAll(Stack); + } + }; + + struct TFreeListTest { + using ValueType = int; + + static void Enqueue(TLockFreeStack<int>& stack, size_t i) { + stack.Enqueue(static_cast<int>(i)); + } + + static void Dequeue(TLockFreeStack<int>& stack) { + int value; + stack.Dequeue(&value); + } + + static void EnqueueAll(TLockFreeStack<int>& stack) { + TVector<int> values(5); + stack.EnqueueAll(values); + } + + static void DequeueAll(TLockFreeStack<int>& stack) { + TVector<int> value; + stack.DequeueAll(&value); + } + }; + + // Test for catching thread sanitizer problems + Y_UNIT_TEST(TestFreeList) { + TMultiThreadTester<TFreeListTest>().Run(); + } + 
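    // Editorial sketch (not part of the original commit): with a single
    // consumer thread, DequeueSingleConsumer() can skip the free-list
    // machinery entirely, because the ABA problem cannot occur without
    // concurrent dequeuers.
    Y_UNIT_TEST(TestDequeueSingleConsumerSketch) {
        TLockFreeStack<int> stack;

        stack.Enqueue(1);
        stack.Enqueue(2);

        int value = 0;
        UNIT_ASSERT(stack.DequeueSingleConsumer(&value));
        UNIT_ASSERT_VALUES_EQUAL(2, value); // LIFO: last enqueued pops first
        UNIT_ASSERT(stack.DequeueSingleConsumer(&value));
        UNIT_ASSERT_VALUES_EQUAL(1, value);
        UNIT_ASSERT(!stack.DequeueSingleConsumer(&value)); // empty now
    }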
+ struct TMoveTest { + using ValueType = THolder<int>; + + static void Enqueue(TLockFreeStack<ValueType>& stack, size_t i) { + stack.Enqueue(MakeHolder<int>(static_cast<int>(i))); + } + + static void Dequeue(TLockFreeStack<ValueType>& stack) { + ValueType value; + if (stack.Dequeue(&value)) { + UNIT_ASSERT(value); + } + } + + static void EnqueueAll(TLockFreeStack<ValueType>& stack) { + // there is no enqueAll with moving signature in LockFreeStack + Enqueue(stack, 0); + } + + static void DequeueAll(TLockFreeStack<ValueType>& stack) { + TVector<ValueType> values; + stack.DequeueAll(&values); + for (auto& v : values) { + UNIT_ASSERT(v); + } + } + }; + + // Test for catching thread sanitizer problems + Y_UNIT_TEST(TesMultiThreadMove) { + TMultiThreadTester<TMoveTest>().Run(); + } +} diff --git a/util/thread/pool.cpp b/util/thread/pool.cpp new file mode 100644 index 0000000000..05fad02e9b --- /dev/null +++ b/util/thread/pool.cpp @@ -0,0 +1,772 @@ +#include <atomic> + +#include <util/system/defaults.h> + +#if defined(_unix_) + #include <pthread.h> +#endif + +#include <util/generic/vector.h> +#include <util/generic/intrlist.h> +#include <util/generic/yexception.h> +#include <util/generic/ylimits.h> +#include <util/generic/singleton.h> +#include <util/generic/fastqueue.h> + +#include <util/stream/output.h> +#include <util/string/builder.h> + +#include <util/system/event.h> +#include <util/system/mutex.h> +#include <util/system/atomic.h> +#include <util/system/condvar.h> +#include <util/system/thread.h> + +#include <util/datetime/base.h> + +#include "factory.h" +#include "pool.h" + +namespace { + class TThreadNamer { + public: + TThreadNamer(const IThreadPool::TParams& params) + : ThreadName(params.ThreadName_) + , EnumerateThreads(params.EnumerateThreads_) + { + } + + explicit operator bool() const { + return !ThreadName.empty(); + } + + void SetCurrentThreadName() { + if (EnumerateThreads) { + Set(TStringBuilder() << ThreadName << (Index++)); + } else { + Set(ThreadName); + } + } + + private: + void Set(const TString& name) { + TThread::SetCurrentThreadName(name.c_str()); + } + + private: + TString ThreadName; + bool EnumerateThreads = false; + std::atomic<ui64> Index{0}; + }; +} + +TThreadFactoryHolder::TThreadFactoryHolder() noexcept + : Pool_(SystemThreadFactory()) +{ +} + +class TThreadPool::TImpl: public TIntrusiveListItem<TImpl>, public IThreadFactory::IThreadAble { + using TTsr = IThreadPool::TTsr; + using TJobQueue = TFastQueue<IObjectInQueue*>; + using TThreadRef = THolder<IThreadFactory::IThread>; + +public: + inline TImpl(TThreadPool* parent, size_t thrnum, size_t maxqueue, const TParams& params) + : Parent_(parent) + , Blocking(params.Blocking_) + , Catching(params.Catching_) + , Namer(params) + , ShouldTerminate(1) + , MaxQueueSize(0) + , ThreadCountExpected(0) + , ThreadCountReal(0) + , Forked(false) + { + TAtforkQueueRestarter::Get().RegisterObject(this); + Start(thrnum, maxqueue); + } + + inline ~TImpl() override { + try { + Stop(); + } catch (...) 
{ + // ¯\_(ツ)_/¯ + } + + TAtforkQueueRestarter::Get().UnregisterObject(this); + Y_ASSERT(Tharr.empty()); + } + + inline bool Add(IObjectInQueue* obj) { + if (AtomicGet(ShouldTerminate)) { + return false; + } + + if (Tharr.empty()) { + TTsr tsr(Parent_); + obj->Process(tsr); + + return true; + } + + with_lock (QueueMutex) { + while (MaxQueueSize > 0 && Queue.Size() >= MaxQueueSize && !AtomicGet(ShouldTerminate)) { + if (!Blocking) { + return false; + } + QueuePopCond.Wait(QueueMutex); + } + + if (AtomicGet(ShouldTerminate)) { + return false; + } + + Queue.Push(obj); + } + + QueuePushCond.Signal(); + + return true; + } + + inline size_t Size() const noexcept { + auto guard = Guard(QueueMutex); + + return Queue.Size(); + } + + inline size_t GetMaxQueueSize() const noexcept { + return MaxQueueSize; + } + + inline size_t GetThreadCountExpected() const noexcept { + return ThreadCountExpected; + } + + inline size_t GetThreadCountReal() const noexcept { + return ThreadCountReal; + } + + inline void AtforkAction() noexcept Y_NO_SANITIZE("thread") { + Forked = true; + } + + inline bool NeedRestart() const noexcept { + return Forked; + } + +private: + inline void Start(size_t num, size_t maxque) { + AtomicSet(ShouldTerminate, 0); + MaxQueueSize = maxque; + ThreadCountExpected = num; + + try { + for (size_t i = 0; i < num; ++i) { + Tharr.push_back(Parent_->Pool()->Run(this)); + ++ThreadCountReal; + } + } catch (...) { + Stop(); + + throw; + } + } + + inline void Stop() { + AtomicSet(ShouldTerminate, 1); + + with_lock (QueueMutex) { + QueuePopCond.BroadCast(); + } + + if (!NeedRestart()) { + WaitForComplete(); + } + + Tharr.clear(); + ThreadCountExpected = 0; + MaxQueueSize = 0; + } + + inline void WaitForComplete() noexcept { + with_lock (StopMutex) { + while (ThreadCountReal) { + with_lock (QueueMutex) { + QueuePushCond.Signal(); + } + + StopCond.Wait(StopMutex); + } + } + } + + void DoExecute() override { + THolder<TTsr> tsr(new TTsr(Parent_)); + + if (Namer) { + Namer.SetCurrentThreadName(); + } + + while (true) { + IObjectInQueue* job = nullptr; + + with_lock (QueueMutex) { + while (Queue.Empty() && !AtomicGet(ShouldTerminate)) { + QueuePushCond.Wait(QueueMutex); + } + + if (AtomicGet(ShouldTerminate) && Queue.Empty()) { + tsr.Destroy(); + + break; + } + + job = Queue.Pop(); + } + + QueuePopCond.Signal(); + + if (Catching) { + try { + try { + job->Process(*tsr); + } catch (...) { + Cdbg << "[mtp queue] " << CurrentExceptionMessage() << Endl; + } + } catch (...) 
{ + // ¯\_(ツ)_/¯ + } + } else { + job->Process(*tsr); + } + } + + FinishOneThread(); + } + + inline void FinishOneThread() noexcept { + auto guard = Guard(StopMutex); + + --ThreadCountReal; + StopCond.Signal(); + } + +private: + TThreadPool* Parent_; + const bool Blocking; + const bool Catching; + TThreadNamer Namer; + mutable TMutex QueueMutex; + mutable TMutex StopMutex; + TCondVar QueuePushCond; + TCondVar QueuePopCond; + TCondVar StopCond; + TJobQueue Queue; + TVector<TThreadRef> Tharr; + TAtomic ShouldTerminate; + size_t MaxQueueSize; + size_t ThreadCountExpected; + size_t ThreadCountReal; + bool Forked; + + class TAtforkQueueRestarter { + public: + static TAtforkQueueRestarter& Get() { + return *SingletonWithPriority<TAtforkQueueRestarter, 256>(); + } + + inline void RegisterObject(TImpl* obj) { + auto guard = Guard(ActionMutex); + + RegisteredObjects.PushBack(obj); + } + + inline void UnregisterObject(TImpl* obj) { + auto guard = Guard(ActionMutex); + + obj->Unlink(); + } + + private: + void ChildAction() { + with_lock (ActionMutex) { + for (auto it = RegisteredObjects.Begin(); it != RegisteredObjects.End(); ++it) { + it->AtforkAction(); + } + } + } + + static void ProcessChildAction() { + Get().ChildAction(); + } + + TIntrusiveList<TImpl> RegisteredObjects; + TMutex ActionMutex; + + public: + inline TAtforkQueueRestarter() { +#if defined(_bionic_) +//no pthread_atfork on android libc +#elif defined(_unix_) + pthread_atfork(nullptr, nullptr, ProcessChildAction); +#endif + } + }; +}; + +TThreadPool::~TThreadPool() = default; + +size_t TThreadPool::Size() const noexcept { + if (!Impl_.Get()) { + return 0; + } + + return Impl_->Size(); +} + +size_t TThreadPool::GetThreadCountExpected() const noexcept { + if (!Impl_.Get()) { + return 0; + } + + return Impl_->GetThreadCountExpected(); +} + +size_t TThreadPool::GetThreadCountReal() const noexcept { + if (!Impl_.Get()) { + return 0; + } + + return Impl_->GetThreadCountReal(); +} + +size_t TThreadPool::GetMaxQueueSize() const noexcept { + if (!Impl_.Get()) { + return 0; + } + + return Impl_->GetMaxQueueSize(); +} + +bool TThreadPool::Add(IObjectInQueue* obj) { + Y_ENSURE_EX(Impl_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started")); + + if (Impl_->NeedRestart()) { + Start(Impl_->GetThreadCountExpected(), Impl_->GetMaxQueueSize()); + } + + return Impl_->Add(obj); +} + +void TThreadPool::Start(size_t thrnum, size_t maxque) { + Impl_.Reset(new TImpl(this, thrnum, maxque, Params)); +} + +void TThreadPool::Stop() noexcept { + Impl_.Destroy(); +} + +static TAtomic mtp_queue_counter = 0; + +class TAdaptiveThreadPool::TImpl { +public: + class TThread: public IThreadFactory::IThreadAble { + public: + inline TThread(TImpl* parent) + : Impl_(parent) + , Thread_(Impl_->Parent_->Pool()->Run(this)) + { + } + + inline ~TThread() override { + Impl_->DecThreadCount(); + } + + private: + void DoExecute() noexcept override { + THolder<TThread> This(this); + + if (Impl_->Namer) { + Impl_->Namer.SetCurrentThreadName(); + } + + { + TTsr tsr(Impl_->Parent_); + IObjectInQueue* obj; + + while ((obj = Impl_->WaitForJob()) != nullptr) { + if (Impl_->Catching) { + try { + try { + obj->Process(tsr); + } catch (...) { + Cdbg << Impl_->Name() << " " << CurrentExceptionMessage() << Endl; + } + } catch (...) 
{ + // ¯\_(ツ)_/¯ + } + } else { + obj->Process(tsr); + } + } + } + } + + private: + TImpl* Impl_; + THolder<IThreadFactory::IThread> Thread_; + }; + + inline TImpl(TAdaptiveThreadPool* parent, const TParams& params) + : Parent_(parent) + , Catching(params.Catching_) + , Namer(params) + , ThrCount_(0) + , AllDone_(false) + , Obj_(nullptr) + , Free_(0) + , IdleTime_(TDuration::Max()) + { + sprintf(Name_, "[mtp queue %ld]", (long)AtomicAdd(mtp_queue_counter, 1)); + } + + inline ~TImpl() { + Stop(); + } + + inline void SetMaxIdleTime(TDuration idleTime) { + IdleTime_ = idleTime; + } + + inline const char* Name() const noexcept { + return Name_; + } + + inline void Add(IObjectInQueue* obj) { + with_lock (Mutex_) { + while (Obj_ != nullptr) { + CondFree_.Wait(Mutex_); + } + + if (Free_ == 0) { + AddThreadNoLock(); + } + + Obj_ = obj; + + Y_ENSURE_EX(!AllDone_, TThreadPoolException() << TStringBuf("adding to a stopped queue")); + } + + CondReady_.Signal(); + } + + inline void AddThreads(size_t n) { + with_lock (Mutex_) { + while (n) { + AddThreadNoLock(); + + --n; + } + } + } + + inline size_t Size() const noexcept { + return (size_t)ThrCount_; + } + +private: + inline void IncThreadCount() noexcept { + AtomicAdd(ThrCount_, 1); + } + + inline void DecThreadCount() noexcept { + AtomicAdd(ThrCount_, -1); + } + + inline void AddThreadNoLock() { + IncThreadCount(); + + try { + new TThread(this); + } catch (...) { + DecThreadCount(); + + throw; + } + } + + inline void Stop() noexcept { + Mutex_.Acquire(); + + AllDone_ = true; + + while (AtomicGet(ThrCount_)) { + Mutex_.Release(); + CondReady_.Signal(); + Mutex_.Acquire(); + } + + Mutex_.Release(); + } + + inline IObjectInQueue* WaitForJob() noexcept { + Mutex_.Acquire(); + + ++Free_; + + while (!Obj_ && !AllDone_) { + if (!CondReady_.WaitT(Mutex_, IdleTime_)) { + break; + } + } + + IObjectInQueue* ret = Obj_; + Obj_ = nullptr; + + --Free_; + + Mutex_.Release(); + CondFree_.Signal(); + + return ret; + } + +private: + TAdaptiveThreadPool* Parent_; + const bool Catching; + TThreadNamer Namer; + TAtomic ThrCount_; + TMutex Mutex_; + TCondVar CondReady_; + TCondVar CondFree_; + bool AllDone_; + IObjectInQueue* Obj_; + size_t Free_; + char Name_[64]; + TDuration IdleTime_; +}; + +TThreadPoolBase::TThreadPoolBase(const TParams& params) + : TThreadFactoryHolder(params.Factory_) + , Params(params) +{ +} + +#define DEFINE_THREAD_POOL_CTORS(type) \ + type::type(const TParams& params) \ + : TThreadPoolBase(params) \ + { \ + } + +DEFINE_THREAD_POOL_CTORS(TThreadPool) +DEFINE_THREAD_POOL_CTORS(TAdaptiveThreadPool) +DEFINE_THREAD_POOL_CTORS(TSimpleThreadPool) + +TAdaptiveThreadPool::~TAdaptiveThreadPool() = default; + +bool TAdaptiveThreadPool::Add(IObjectInQueue* obj) { + Y_ENSURE_EX(Impl_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started")); + + Impl_->Add(obj); + + return true; +} + +void TAdaptiveThreadPool::Start(size_t, size_t) { + Impl_.Reset(new TImpl(this, Params)); +} + +void TAdaptiveThreadPool::Stop() noexcept { + Impl_.Destroy(); +} + +size_t TAdaptiveThreadPool::Size() const noexcept { + if (Impl_.Get()) { + return Impl_->Size(); + } + + return 0; +} + +void TAdaptiveThreadPool::SetMaxIdleTime(TDuration interval) { + Y_ENSURE_EX(Impl_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started")); + + Impl_->SetMaxIdleTime(interval); +} + +TSimpleThreadPool::~TSimpleThreadPool() { + try { + Stop(); + } catch (...) 
{ + // ¯\_(ツ)_/¯ + } +} + +bool TSimpleThreadPool::Add(IObjectInQueue* obj) { + Y_ENSURE_EX(Slave_.Get(), TThreadPoolException() << TStringBuf("mtp queue not started")); + + return Slave_->Add(obj); +} + +void TSimpleThreadPool::Start(size_t thrnum, size_t maxque) { + THolder<IThreadPool> tmp; + TAdaptiveThreadPool* adaptive(nullptr); + + if (thrnum) { + tmp.Reset(new TThreadPoolBinder<TThreadPool, TSimpleThreadPool>(this, Params)); + } else { + adaptive = new TThreadPoolBinder<TAdaptiveThreadPool, TSimpleThreadPool>(this, Params); + tmp.Reset(adaptive); + } + + tmp->Start(thrnum, maxque); + + if (adaptive) { + adaptive->SetMaxIdleTime(TDuration::Seconds(100)); + } + + Slave_.Swap(tmp); +} + +void TSimpleThreadPool::Stop() noexcept { + Slave_.Destroy(); +} + +size_t TSimpleThreadPool::Size() const noexcept { + if (Slave_.Get()) { + return Slave_->Size(); + } + + return 0; +} + +namespace { + class TOwnedObjectInQueue: public IObjectInQueue { + private: + THolder<IObjectInQueue> Owned; + + public: + TOwnedObjectInQueue(THolder<IObjectInQueue> owned) + : Owned(std::move(owned)) + { + } + + void Process(void* data) override { + THolder<TOwnedObjectInQueue> self(this); + Owned->Process(data); + } + }; +} + +void IThreadPool::SafeAdd(IObjectInQueue* obj) { + Y_ENSURE_EX(Add(obj), TThreadPoolException() << TStringBuf("can not add object to queue")); +} + +void IThreadPool::SafeAddAndOwn(THolder<IObjectInQueue> obj) { + Y_ENSURE_EX(AddAndOwn(std::move(obj)), TThreadPoolException() << TStringBuf("can not add to queue and own")); +} + +bool IThreadPool::AddAndOwn(THolder<IObjectInQueue> obj) { + auto owner = MakeHolder<TOwnedObjectInQueue>(std::move(obj)); + bool added = Add(owner.Get()); + if (added) { + Y_UNUSED(owner.Release()); + } + return added; +} + +using IThread = IThreadFactory::IThread; +using IThreadAble = IThreadFactory::IThreadAble; + +namespace { + class TPoolThread: public IThread { + class TThreadImpl: public IObjectInQueue, public TAtomicRefCount<TThreadImpl> { + public: + inline TThreadImpl(IThreadAble* func) + : Func_(func) + { + } + + ~TThreadImpl() override = default; + + inline void WaitForStart() noexcept { + StartEvent_.Wait(); + } + + inline void WaitForComplete() noexcept { + CompleteEvent_.Wait(); + } + + private: + void Process(void* /*tsr*/) override { + TThreadImplRef This(this); + + { + StartEvent_.Signal(); + + try { + Func_->Execute(); + } catch (...) 
{ + // ¯\_(ツ)_/¯ + } + + CompleteEvent_.Signal(); + } + } + + private: + IThreadAble* Func_; + TSystemEvent CompleteEvent_; + TSystemEvent StartEvent_; + }; + + using TThreadImplRef = TIntrusivePtr<TThreadImpl>; + + public: + inline TPoolThread(IThreadPool* parent) + : Parent_(parent) + { + } + + ~TPoolThread() override { + if (Impl_) { + Impl_->WaitForStart(); + } + } + + private: + void DoRun(IThreadAble* func) override { + TThreadImplRef impl(new TThreadImpl(func)); + + Parent_->SafeAdd(impl.Get()); + Impl_.Swap(impl); + } + + void DoJoin() noexcept override { + if (Impl_) { + Impl_->WaitForComplete(); + Impl_ = nullptr; + } + } + + private: + IThreadPool* Parent_; + TThreadImplRef Impl_; + }; +} + +IThread* IThreadPool::DoCreate() { + return new TPoolThread(this); +} + +THolder<IThreadPool> CreateThreadPool(size_t threadsCount, size_t queueSizeLimit, const TThreadPoolParams& params) { + THolder<IThreadPool> queue; + if (threadsCount > 1) { + queue.Reset(new TThreadPool(params)); + } else { + queue.Reset(new TFakeThreadPool()); + } + queue->Start(threadsCount, queueSizeLimit); + return queue; +} diff --git a/util/thread/pool.h b/util/thread/pool.h new file mode 100644 index 0000000000..d1ea3a67cb --- /dev/null +++ b/util/thread/pool.h @@ -0,0 +1,390 @@ +#pragma once + +#include "fwd.h" +#include "factory.h" + +#include <util/system/yassert.h> +#include <util/system/defaults.h> +#include <util/generic/yexception.h> +#include <util/generic/ptr.h> +#include <util/generic/noncopyable.h> +#include <functional> + +class TDuration; + +struct IObjectInQueue { + virtual ~IObjectInQueue() = default; + + /** + * Supposed to be implemented by user, to define jobs processed + * in multiple threads. + * + * @param threadSpecificResource is nullptr by default. But if you override + * IThreadPool::CreateThreadSpecificResource, then result of + * IThreadPool::CreateThreadSpecificResource is passed as threadSpecificResource + * parameter. + */ + virtual void Process(void* threadSpecificResource) = 0; +}; + +/** + * Mighty class to add 'Pool' method to derived classes. + * Useful only for creators of new queue classes. 
+ */ +class TThreadFactoryHolder { +public: + TThreadFactoryHolder() noexcept; + + inline TThreadFactoryHolder(IThreadFactory* pool) noexcept + : Pool_(pool) + { + } + + inline ~TThreadFactoryHolder() = default; + + inline IThreadFactory* Pool() const noexcept { + return Pool_; + } + +private: + IThreadFactory* Pool_; +}; + +class TThreadPoolException: public yexception { +}; + +template <class T> +class TThrFuncObj: public IObjectInQueue { +public: + TThrFuncObj(const T& func) + : Func(func) + { + } + + TThrFuncObj(T&& func) + : Func(std::move(func)) + { + } + + void Process(void*) override { + THolder<TThrFuncObj> self(this); + Func(); + } + +private: + T Func; +}; + +template <class T> +IObjectInQueue* MakeThrFuncObj(T&& func) { + return new TThrFuncObj<std::remove_cv_t<std::remove_reference_t<T>>>(std::forward<T>(func)); +} + +struct TThreadPoolParams { + bool Catching_ = true; + bool Blocking_ = false; + IThreadFactory* Factory_ = SystemThreadFactory(); + TString ThreadName_; + bool EnumerateThreads_ = false; + + using TSelf = TThreadPoolParams; + + TThreadPoolParams() { + } + + TThreadPoolParams(IThreadFactory* factory) + : Factory_(factory) + { + } + + TThreadPoolParams(const TString& name) { + SetThreadName(name); + } + + TThreadPoolParams(const char* name) { + SetThreadName(name); + } + + TSelf& SetCatching(bool val) { + Catching_ = val; + return *this; + } + + TSelf& SetBlocking(bool val) { + Blocking_ = val; + return *this; + } + + TSelf& SetFactory(IThreadFactory* factory) { + Factory_ = factory; + return *this; + } + + TSelf& SetThreadName(const TString& name) { + ThreadName_ = name; + EnumerateThreads_ = false; + return *this; + } + + TSelf& SetThreadNamePrefix(const TString& prefix) { + ThreadName_ = prefix; + EnumerateThreads_ = true; + return *this; + } +}; + +/** + * A queue processed simultaneously by several threads + */ +class IThreadPool: public IThreadFactory, public TNonCopyable { +public: + using TParams = TThreadPoolParams; + + ~IThreadPool() override = default; + + /** + * Safe versions of Add*() functions. Behave exactly like as non-safe + * version of Add*(), but use exceptions instead returning false + */ + void SafeAdd(IObjectInQueue* obj); + + template <class T> + void SafeAddFunc(T&& func) { + Y_ENSURE_EX(AddFunc(std::forward<T>(func)), TThreadPoolException() << TStringBuf("can not add function to queue")); + } + + void SafeAddAndOwn(THolder<IObjectInQueue> obj); + + /** + * Add object to queue, run ojb->Proccess in other threads. + * Obj is not deleted after execution + * @return true of obj is successfully added to queue + * @return false if queue is full or shutting down + */ + virtual bool Add(IObjectInQueue* obj) Y_WARN_UNUSED_RESULT = 0; + + template <class T> + Y_WARN_UNUSED_RESULT bool AddFunc(T&& func) { + THolder<IObjectInQueue> wrapper(MakeThrFuncObj(std::forward<T>(func))); + bool added = Add(wrapper.Get()); + if (added) { + Y_UNUSED(wrapper.Release()); + } + return added; + } + + bool AddAndOwn(THolder<IObjectInQueue> obj) Y_WARN_UNUSED_RESULT; + virtual void Start(size_t threadCount, size_t queueSizeLimit = 0) = 0; + /** Wait for completion of all scheduled objects, and then exit */ + virtual void Stop() noexcept = 0; + /** Number of tasks currently in queue */ + virtual size_t Size() const noexcept = 0; + +public: + /** + * RAII wrapper for Create/DestroyThreadSpecificResource. + * Useful only for implementers of new IThreadPool queues. 
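 * Example (editorial sketch, not from the original commit; NextJob() is a
 * hypothetical fetch function): a worker loop in a custom queue would create
 * one TTsr per thread and hand it to every job:
 *
 *   TTsr tsr(this);                // calls CreateThreadSpecificResource()
 *   while (IObjectInQueue* job = NextJob()) {
 *       job->Process(tsr);         // TTsr converts to void*
 *   }                              // ~TTsr() destroys the resource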
+ */ + class TTsr { + public: + inline TTsr(IThreadPool* q) + : Q_(q) + , Data_(Q_->CreateThreadSpecificResource()) + { + } + + inline ~TTsr() { + try { + Q_->DestroyThreadSpecificResource(Data_); + } catch (...) { + // ¯\_(ツ)_/¯ + } + } + + inline operator void*() noexcept { + return Data_; + } + + private: + IThreadPool* Q_; + void* Data_; + }; + + /** + * CreateThreadSpecificResource and DestroyThreadSpecificResource + * called from internals of (TAdaptiveThreadPool, TThreadPool, ...) implementation, + * not by user of IThreadPool interface. + * Created resource is passed to IObjectInQueue::Proccess function. + */ + virtual void* CreateThreadSpecificResource() { + return nullptr; + } + + virtual void DestroyThreadSpecificResource(void* resource) { + if (resource != nullptr) { + Y_ASSERT(resource == nullptr); + } + } + +private: + IThread* DoCreate() override; +}; + +/** + * Single-threaded implementation of IThreadPool, process tasks in same thread when + * added. + * Can be used to remove multithreading. + */ +class TFakeThreadPool: public IThreadPool { +public: + bool Add(IObjectInQueue* pObj) override Y_WARN_UNUSED_RESULT { + TTsr tsr(this); + pObj->Process(tsr); + + return true; + } + + void Start(size_t, size_t = 0) override { + } + + void Stop() noexcept override { + } + + size_t Size() const noexcept override { + return 0; + } +}; + +class TThreadPoolBase: public IThreadPool, public TThreadFactoryHolder { +public: + TThreadPoolBase(const TParams& params); + +protected: + TParams Params; +}; + +/** queue processed by fixed size thread pool */ +class TThreadPool: public TThreadPoolBase { +public: + TThreadPool(const TParams& params = {}); + ~TThreadPool() override; + + bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT; + /** + * @param queueSizeLimit means "unlimited" when = 0 + * @param threadCount means "single thread" when = 0 + */ + void Start(size_t threadCount, size_t queueSizeLimit = 0) override; + void Stop() noexcept override; + size_t Size() const noexcept override; + size_t GetThreadCountExpected() const noexcept; + size_t GetThreadCountReal() const noexcept; + size_t GetMaxQueueSize() const noexcept; + +private: + class TImpl; + THolder<TImpl> Impl_; +}; + +/** + * Always create new thread for new task, when all existing threads are busy. + * Maybe dangerous, number of threads is not limited. + */ +class TAdaptiveThreadPool: public TThreadPoolBase { +public: + TAdaptiveThreadPool(const TParams& params = {}); + ~TAdaptiveThreadPool() override; + + /** + * If working thread waits task too long (more then interval parameter), + * then the thread would be killed. Default value - infinity, all created threads + * waits for new task forever, before Stop. + */ + void SetMaxIdleTime(TDuration interval); + + bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT; + /** @param thrnum, @param maxque are ignored */ + void Start(size_t thrnum = 0, size_t maxque = 0) override; + void Stop() noexcept override; + size_t Size() const noexcept override; + +private: + class TImpl; + THolder<TImpl> Impl_; +}; + +/** Behave like TThreadPool or TAdaptiveThreadPool, choosen by thrnum parameter of Start() */ +class TSimpleThreadPool: public TThreadPoolBase { +public: + TSimpleThreadPool(const TParams& params = {}); + ~TSimpleThreadPool() override; + + bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT; + /** + * @parameter thrnum. If thrnum is 0, use TAdaptiveThreadPool with small + * SetMaxIdleTime interval parameter. 
+/** Behaves like TThreadPool or TAdaptiveThreadPool, chosen by the thrnum parameter of Start() */
+class TSimpleThreadPool: public TThreadPoolBase {
+public:
+    TSimpleThreadPool(const TParams& params = {});
+    ~TSimpleThreadPool() override;
+
+    bool Add(IObjectInQueue* obj) override Y_WARN_UNUSED_RESULT;
+    /**
+     * @param thrnum If 0, uses TAdaptiveThreadPool with a small
+     * SetMaxIdleTime interval; otherwise uses a non-blocking TThreadPool.
+     */
+    void Start(size_t thrnum, size_t maxque = 0) override;
+    void Stop() noexcept override;
+    size_t Size() const noexcept override;
+
+private:
+    THolder<IThreadPool> Slave_;
+};
+
+/**
+ * Helper that overrides the virtual functions Create/DestroyThreadSpecificResource
+ * from IThreadPool and implements them by forwarding to the functions of the
+ * same name on a TSlave pointer.
+ */
+template <class TQueueType, class TSlave>
+class TThreadPoolBinder: public TQueueType {
+public:
+    inline TThreadPoolBinder(TSlave* slave)
+        : Slave_(slave)
+    {
+    }
+
+    template <class... Args>
+    inline TThreadPoolBinder(TSlave* slave, Args&&... args)
+        : TQueueType(std::forward<Args>(args)...)
+        , Slave_(slave)
+    {
+    }
+
+    inline TThreadPoolBinder(TSlave& slave)
+        : Slave_(&slave)
+    {
+    }
+
+    ~TThreadPoolBinder() override {
+        try {
+            this->Stop();
+        } catch (...) {
+            // ¯\_(ツ)_/¯
+        }
+    }
+
+    void* CreateThreadSpecificResource() override {
+        return Slave_->CreateThreadSpecificResource();
+    }
+
+    void DestroyThreadSpecificResource(void* resource) override {
+        Slave_->DestroyThreadSpecificResource(resource);
+    }
+
+private:
+    TSlave* Slave_;
+};
+
+inline void Delete(THolder<IThreadPool> q) {
+    if (q.Get()) {
+        q->Stop();
+    }
+}
+
+/**
+ * Creates and starts a TThreadPool if threadCount > 1, or a TFakeThreadPool otherwise.
+ * Blocking and catching modes can be specified for TThreadPool only.
+ */
+THolder<IThreadPool> CreateThreadPool(size_t threadCount, size_t queueSizeLimit = 0, const IThreadPool::TParams& params = {});
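+
+// Sketch of the CreateThreadPool() helper above (illustrative only):
+//
+//     THolder<IThreadPool> pool = CreateThreadPool(8); // TThreadPool: 8 threads, unlimited queue
+//     pool->SafeAddFunc([] { /* work */ });
+//     pool->Stop();
+//
+//     // threadCount <= 1 yields a TFakeThreadPool: tasks run inline in Add()
+//     THolder<IThreadPool> inlinePool = CreateThreadPool(1);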
diff --git a/util/thread/pool_ut.cpp b/util/thread/pool_ut.cpp
new file mode 100644
index 0000000000..893770d0c4
--- /dev/null
+++ b/util/thread/pool_ut.cpp
@@ -0,0 +1,257 @@
+#include "pool.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/stream/output.h>
+#include <util/random/fast.h>
+#include <util/system/spinlock.h>
+#include <util/system/thread.h>
+#include <util/system/mutex.h>
+#include <util/system/condvar.h>
+
+struct TThreadPoolTest {
+    TSpinLock Lock;
+    long R = -1;
+
+    struct TTask: public IObjectInQueue {
+        TThreadPoolTest* Test = nullptr;
+        long Value = 0;
+
+        TTask(TThreadPoolTest* test, int value)
+            : Test(test)
+            , Value(value)
+        {
+        }
+
+        void Process(void*) override {
+            THolder<TTask> This(this);
+
+            TGuard<TSpinLock> guard(Test->Lock);
+            Test->R ^= Value;
+        }
+    };
+
+    struct TOwnedTask: public IObjectInQueue {
+        bool& Processed;
+        bool& Destructed;
+
+        TOwnedTask(bool& processed, bool& destructed)
+            : Processed(processed)
+            , Destructed(destructed)
+        {
+        }
+
+        ~TOwnedTask() override {
+            Destructed = true;
+        }
+
+        void Process(void*) override {
+            Processed = true;
+        }
+    };
+
+    inline void TestAnyQueue(IThreadPool* queue, size_t queueSize = 1000) {
+        TReallyFastRng32 rand(17);
+        const size_t cnt = 1000;
+
+        R = 0;
+
+        for (size_t i = 0; i < cnt; ++i) {
+            R ^= (long)rand.GenRand();
+        }
+
+        queue->Start(10, queueSize);
+        rand = TReallyFastRng32(17);
+
+        for (size_t i = 0; i < cnt; ++i) {
+            UNIT_ASSERT(queue->Add(new TTask(this, (long)rand.GenRand())));
+        }
+
+        queue->Stop();
+
+        UNIT_ASSERT_EQUAL(0, R);
+    }
+};
+
+class TFailAddQueue: public IThreadPool {
+public:
+    bool Add(IObjectInQueue* /*obj*/) override Y_WARN_UNUSED_RESULT {
+        return false;
+    }
+
+    void Start(size_t, size_t) override {
+    }
+
+    void Stop() noexcept override {
+    }
+
+    size_t Size() const noexcept override {
+        return 0;
+    }
+};
+
+Y_UNIT_TEST_SUITE(TThreadPoolTest) {
+    Y_UNIT_TEST(TestTThreadPool) {
+        TThreadPoolTest t;
+        TThreadPool q;
+
+        t.TestAnyQueue(&q);
+    }
+
+    Y_UNIT_TEST(TestTThreadPoolBlocking) {
+        TThreadPoolTest t;
+        TThreadPool q(TThreadPool::TParams().SetBlocking(true));
+        t.TestAnyQueue(&q, 100);
+    }
+
+    // disabled by pg@ a long time ago due to test flaps
+    // Tried to enable: REVIEW:78772
+    Y_UNIT_TEST(TestTAdaptiveThreadPool) {
+        if (false) {
+            TThreadPoolTest t;
+            TAdaptiveThreadPool q;
+            t.TestAnyQueue(&q);
+        }
+    }
+
+    Y_UNIT_TEST(TestAddAndOwn) {
+        TThreadPool q;
+        q.Start(2);
+        bool processed = false;
+        bool destructed = false;
+        q.SafeAddAndOwn(MakeHolder<TThreadPoolTest::TOwnedTask>(processed, destructed));
+        q.Stop();
+
+        UNIT_ASSERT_C(processed, "Not processed");
+        UNIT_ASSERT_C(destructed, "Not destructed");
+    }
+
+    Y_UNIT_TEST(TestAddFunc) {
+        TFailAddQueue queue;
+        bool added = queue.AddFunc(
+            []() {} // Lambda, I call him 'Lambda'!
+        );
+        UNIT_ASSERT_VALUES_EQUAL(added, false);
+    }
+
+    Y_UNIT_TEST(TestSafeAddFuncThrows) {
+        TFailAddQueue queue;
+        UNIT_CHECK_GENERATED_EXCEPTION(queue.SafeAddFunc([] {}), TThreadPoolException);
+    }
+
+    Y_UNIT_TEST(TestFunctionNotCopied) {
+        struct TFailOnCopy {
+            TFailOnCopy() {
+            }
+
+            TFailOnCopy(TFailOnCopy&&) {
+            }
+
+            TFailOnCopy(const TFailOnCopy&) {
+                UNIT_FAIL("Don't copy std::function inside TThreadPool");
+            }
+        };
+
+        TThreadPool queue(TThreadPool::TParams().SetBlocking(false).SetCatching(true));
+        queue.Start(2);
+
+        queue.SafeAddFunc([data = TFailOnCopy()]() {});
+
+        queue.Stop();
+    }
+
+    Y_UNIT_TEST(TestInfoGetters) {
+        TThreadPool queue;
+
+        queue.Start(2, 7);
+
+        UNIT_ASSERT_EQUAL(queue.GetThreadCountExpected(), 2);
+        UNIT_ASSERT_EQUAL(queue.GetThreadCountReal(), 2);
+        UNIT_ASSERT_EQUAL(queue.GetMaxQueueSize(), 7);
+
+        queue.Stop();
+
+        queue.Start(4, 1);
+
+        UNIT_ASSERT_EQUAL(queue.GetThreadCountExpected(), 4);
+        UNIT_ASSERT_EQUAL(queue.GetThreadCountReal(), 4);
+        UNIT_ASSERT_EQUAL(queue.GetMaxQueueSize(), 1);
+
+        queue.Stop();
+    }
+
+    void TestFixedThreadName(IThreadPool& pool, const TString& expectedName) {
+        pool.Start(1);
+        TString name;
+        pool.SafeAddFunc([&name]() {
+            name = TThread::CurrentThreadName();
+        });
+        pool.Stop();
+        if (TThread::CanGetCurrentThreadName()) {
+            UNIT_ASSERT_EQUAL(name, expectedName);
+            UNIT_ASSERT_UNEQUAL(TThread::CurrentThreadName(), expectedName);
+        }
+    }
+
+    Y_UNIT_TEST(TestFixedThreadName) {
+        const TString expectedName = "HelloWorld";
+        {
+            TThreadPool pool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadName(expectedName));
+            TestFixedThreadName(pool, expectedName);
+        }
+        {
+            TAdaptiveThreadPool pool(TThreadPool::TParams().SetThreadName(expectedName));
+            TestFixedThreadName(pool, expectedName);
+        }
+    }
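+
+    // SetThreadName (tested above) gives every worker the same name, while
+    // SetThreadNamePrefix (tested below) appends a per-thread index.
+    // Illustrative sketch:
+    //
+    //     TThreadPool pool(TThreadPool::TParams().SetThreadNamePrefix("Worker"));
+    //     pool.Start(3); // workers are named Worker0, Worker1, Worker2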
"HelloWorld7", + "HelloWorld8", + "HelloWorld9", + "HelloWorld10", + }; + { + TThreadPool pool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadNamePrefix(namePrefix)); + TestEnumeratedThreadName(pool, expectedNames); + } + { + TAdaptiveThreadPool pool(TThreadPool::TParams().SetThreadNamePrefix(namePrefix)); + TestEnumeratedThreadName(pool, expectedNames); + } + } +} diff --git a/util/thread/singleton.cpp b/util/thread/singleton.cpp new file mode 100644 index 0000000000..a898bdc9d4 --- /dev/null +++ b/util/thread/singleton.cpp @@ -0,0 +1 @@ +#include "singleton.h" diff --git a/util/thread/singleton.h b/util/thread/singleton.h new file mode 100644 index 0000000000..4a1e05aea0 --- /dev/null +++ b/util/thread/singleton.h @@ -0,0 +1,41 @@ +#pragma once + +#include <util/system/tls.h> +#include <util/generic/singleton.h> +#include <util/generic/ptr.h> + +namespace NPrivate { + template <class T, size_t Priority> + struct TFastThreadSingletonHelper { + static inline T* GetSlow() { + return SingletonWithPriority<NTls::TValue<T>, Priority>()->GetPtr(); + } + + static inline T* Get() { +#if defined(Y_HAVE_FAST_POD_TLS) + Y_POD_STATIC_THREAD(T*) fast(nullptr); + + if (Y_UNLIKELY(!fast)) { + fast = GetSlow(); + } + + return fast; +#else + return GetSlow(); +#endif + } + }; +} + +template <class T, size_t Priority> +static inline T* FastTlsSingletonWithPriority() { + return ::NPrivate::TFastThreadSingletonHelper<T, Priority>::Get(); +} + +// NB: the singleton is the same for all modules that use +// FastTlsSingleton with the same type parameter. If unique singleton +// required, use unique types. +template <class T> +static inline T* FastTlsSingleton() { + return FastTlsSingletonWithPriority<T, TSingletonTraits<T>::Priority>(); +} diff --git a/util/thread/singleton_ut.cpp b/util/thread/singleton_ut.cpp new file mode 100644 index 0000000000..164b1cc184 --- /dev/null +++ b/util/thread/singleton_ut.cpp @@ -0,0 +1,21 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "singleton.h" + +namespace { + struct TFoo { + int i; + TFoo() + : i(0) + { + } + }; +} + +Y_UNIT_TEST_SUITE(Tls) { + Y_UNIT_TEST(FastThread) { + UNIT_ASSERT_VALUES_EQUAL(0, FastTlsSingleton<TFoo>()->i); + FastTlsSingleton<TFoo>()->i += 3; + UNIT_ASSERT_VALUES_EQUAL(3, FastTlsSingleton<TFoo>()->i); + } +} diff --git a/util/thread/ut/ya.make b/util/thread/ut/ya.make new file mode 100644 index 0000000000..93198bfaf1 --- /dev/null +++ b/util/thread/ut/ya.make @@ -0,0 +1,18 @@ +UNITTEST_FOR(util) + +OWNER(g:util) +SUBSCRIBER(g:util-subscribers) + +SRCS( + thread/factory_ut.cpp + thread/lfqueue_ut.cpp + thread/lfstack_ut.cpp + thread/pool_ut.cpp + thread/singleton_ut.cpp +) + +PEERDIR( + library/cpp/threading/future +) + +END() diff --git a/util/thread/ya.make b/util/thread/ya.make new file mode 100644 index 0000000000..79c9498ddd --- /dev/null +++ b/util/thread/ya.make @@ -0,0 +1,6 @@ +OWNER(g:util) +SUBSCRIBER(g:util-subscribers) + +RECURSE_FOR_TESTS( + ut +) |