diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/threading | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/threading')
121 files changed, 12044 insertions, 0 deletions
diff --git a/library/cpp/threading/atomic/bool.cpp b/library/cpp/threading/atomic/bool.cpp new file mode 100644 index 0000000000..37917e01f1 --- /dev/null +++ b/library/cpp/threading/atomic/bool.cpp @@ -0,0 +1 @@ +#include "bool.h" diff --git a/library/cpp/threading/atomic/bool.h b/library/cpp/threading/atomic/bool.h new file mode 100644 index 0000000000..d52544e762 --- /dev/null +++ b/library/cpp/threading/atomic/bool.h @@ -0,0 +1,36 @@ +#pragma once + +#include <util/system/atomic.h> + +namespace NAtomic { + class TBool { + public: + TBool() noexcept = default; + + TBool(bool val) noexcept + : Val_(val) + { + } + + TBool(const TBool& src) noexcept { + AtomicSet(Val_, AtomicGet(src.Val_)); + } + + operator bool() const noexcept { + return AtomicGet(Val_); + } + + const TBool& operator=(bool val) noexcept { + AtomicSet(Val_, val); + return *this; + } + + const TBool& operator=(const TBool& src) noexcept { + AtomicSet(Val_, AtomicGet(src.Val_)); + return *this; + } + + private: + TAtomic Val_ = 0; + }; +} diff --git a/library/cpp/threading/atomic/bool_ut.cpp b/library/cpp/threading/atomic/bool_ut.cpp new file mode 100644 index 0000000000..9481f41d8d --- /dev/null +++ b/library/cpp/threading/atomic/bool_ut.cpp @@ -0,0 +1,31 @@ +#include "bool.h" + +#include <library/cpp/testing/unittest/registar.h> + +Y_UNIT_TEST_SUITE(AtomicBool) { + Y_UNIT_TEST(ReadWrite) { + NAtomic::TBool v; + + UNIT_ASSERT_VALUES_EQUAL((bool)v, false); + + v = true; + + UNIT_ASSERT_VALUES_EQUAL((bool)v, true); + + v = false; + + UNIT_ASSERT_VALUES_EQUAL((bool)v, false); + + NAtomic::TBool v2; + + UNIT_ASSERT(v == v2); + + v2 = true; + + UNIT_ASSERT(v != v2); + + v = v2; + + UNIT_ASSERT(v == v2); + } +} diff --git a/library/cpp/threading/atomic/ut/ya.make b/library/cpp/threading/atomic/ut/ya.make new file mode 100644 index 0000000000..3c555685df --- /dev/null +++ b/library/cpp/threading/atomic/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/threading/atomic) + +OWNER(vmordovin) + +SRCS( + bool_ut.cpp +) + +END() diff --git a/library/cpp/threading/atomic/ya.make b/library/cpp/threading/atomic/ya.make new file mode 100644 index 0000000000..c3a3ef8a76 --- /dev/null +++ b/library/cpp/threading/atomic/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(vmordovin) + +SRCS( + bool.cpp +) + +END() diff --git a/library/cpp/threading/chunk_queue/queue.cpp b/library/cpp/threading/chunk_queue/queue.cpp new file mode 100644 index 0000000000..4ebd3f3205 --- /dev/null +++ b/library/cpp/threading/chunk_queue/queue.cpp @@ -0,0 +1 @@ +#include "queue.h" diff --git a/library/cpp/threading/chunk_queue/queue.h b/library/cpp/threading/chunk_queue/queue.h new file mode 100644 index 0000000000..55859601a1 --- /dev/null +++ b/library/cpp/threading/chunk_queue/queue.h @@ -0,0 +1,568 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> +#include <util/generic/typetraits.h> +#include <util/generic/vector.h> +#include <util/generic/ylimits.h> +#include <util/system/atomic.h> +#include <util/system/guard.h> +#include <util/system/spinlock.h> +#include <util/system/yassert.h> + +#include <type_traits> +#include <utility> + +namespace NThreading { +//////////////////////////////////////////////////////////////////////////////// +// Platform helpers + +#if !defined(PLATFORM_CACHE_LINE) +#define PLATFORM_CACHE_LINE 64 +#endif + +#if !defined(PLATFORM_PAGE_SIZE) +#define PLATFORM_PAGE_SIZE 4 * 1024 +#endif + + template <typename T, size_t PadSize = PLATFORM_CACHE_LINE> + struct TPadded: public T { + char Pad[PadSize - sizeof(T) % PadSize]; + + TPadded() { + static_assert(sizeof(*this) % PadSize == 0, "padding does not work"); + Y_UNUSED(Pad); + } + + template<typename... Args> + TPadded(Args&&... args) + : T(std::forward<Args>(args)...) + { + static_assert(sizeof(*this) % PadSize == 0, "padding does not work"); + Y_UNUSED(Pad); + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Type helpers + + namespace NImpl { + template <typename T> + struct TPodTypeHelper { + template <typename TT> + static void Write(T* ptr, TT&& value) { + *ptr = value; + } + + static T Read(T* ptr) { + return *ptr; + } + + static void Destroy(T* ptr) { + Y_UNUSED(ptr); + } + }; + + template <typename T> + struct TNonPodTypeHelper { + template <typename TT> + static void Write(T* ptr, TT&& value) { + new (ptr) T(std::forward<TT>(value)); + } + + static T Read(T* ptr) { + return std::move(*ptr); + } + + static void Destroy(T* ptr) { + (void)ptr; /* Make MSVC happy. */ + ptr->~T(); + } + }; + + template <typename T> + using TTypeHelper = std::conditional_t< + TTypeTraits<T>::IsPod, + TPodTypeHelper<T>, + TNonPodTypeHelper<T>>; + + } + + //////////////////////////////////////////////////////////////////////////////// + // One producer/one consumer chunked queue. + + template <typename T, size_t ChunkSize = PLATFORM_PAGE_SIZE> + class TOneOneQueue: private TNonCopyable { + using TTypeHelper = NImpl::TTypeHelper<T>; + + struct TChunk; + + struct TChunkHeader { + size_t Count = 0; + TChunk* Next = nullptr; + }; + + struct TChunk: public TChunkHeader { + static constexpr size_t MaxCount = (ChunkSize - sizeof(TChunkHeader)) / sizeof(T); + + char Entries[MaxCount * sizeof(T)]; + + TChunk() { + Y_UNUSED(Entries); // uninitialized + } + + ~TChunk() { + for (size_t i = 0; i < this->Count; ++i) { + TTypeHelper::Destroy(GetPtr(i)); + } + } + + T* GetPtr(size_t i) { + return (T*)Entries + i; + } + }; + + struct TWriterState { + TChunk* Chunk = nullptr; + }; + + struct TReaderState { + TChunk* Chunk = nullptr; + size_t Count = 0; + }; + + private: + TPadded<TWriterState> Writer; + TPadded<TReaderState> Reader; + + public: + using TItem = T; + + TOneOneQueue() { + Writer.Chunk = Reader.Chunk = new TChunk(); + } + + ~TOneOneQueue() { + DeleteChunks(Reader.Chunk); + } + + template <typename TT> + void Enqueue(TT&& value) { + T* ptr = PrepareWrite(); + Y_ASSERT(ptr); + TTypeHelper::Write(ptr, std::forward<TT>(value)); + CompleteWrite(); + } + + bool Dequeue(T& value) { + if (T* ptr = PrepareRead()) { + value = TTypeHelper::Read(ptr); + CompleteRead(); + return true; + } + return false; + } + + bool IsEmpty() { + return !PrepareRead(); + } + + protected: + T* PrepareWrite() { + TChunk* chunk = Writer.Chunk; + Y_ASSERT(chunk && !chunk->Next); + + if (chunk->Count != TChunk::MaxCount) { + return chunk->GetPtr(chunk->Count); + } + + chunk = new TChunk(); + AtomicSet(Writer.Chunk->Next, chunk); + Writer.Chunk = chunk; + return chunk->GetPtr(0); + } + + void CompleteWrite() { + AtomicSet(Writer.Chunk->Count, Writer.Chunk->Count + 1); + } + + T* PrepareRead() { + TChunk* chunk = Reader.Chunk; + Y_ASSERT(chunk); + + for (;;) { + size_t writerCount = AtomicGet(chunk->Count); + if (Reader.Count != writerCount) { + return chunk->GetPtr(Reader.Count); + } + + if (writerCount != TChunk::MaxCount) { + return nullptr; + } + + chunk = AtomicGet(chunk->Next); + if (!chunk) { + return nullptr; + } + + delete Reader.Chunk; + Reader.Chunk = chunk; + Reader.Count = 0; + } + } + + void CompleteRead() { + ++Reader.Count; + } + + private: + static void DeleteChunks(TChunk* chunk) { + while (chunk) { + TChunk* next = chunk->Next; + delete chunk; + chunk = next; + } + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Multiple producers/single consumer partitioned queue. + // Provides FIFO guaranties for each producer. + + template <typename T, size_t Concurrency = 4, size_t ChunkSize = PLATFORM_PAGE_SIZE> + class TManyOneQueue: private TNonCopyable { + using TTypeHelper = NImpl::TTypeHelper<T>; + + struct TEntry { + T Value; + ui64 Tag; + }; + + struct TQueueType: public TOneOneQueue<TEntry, ChunkSize> { + TAtomic WriteLock = 0; + + using TOneOneQueue<TEntry, ChunkSize>::PrepareWrite; + using TOneOneQueue<TEntry, ChunkSize>::CompleteWrite; + + using TOneOneQueue<TEntry, ChunkSize>::PrepareRead; + using TOneOneQueue<TEntry, ChunkSize>::CompleteRead; + }; + + private: + union { + TAtomic WriteTag = 0; + char Pad[PLATFORM_CACHE_LINE]; + }; + + TQueueType Queues[Concurrency]; + + public: + using TItem = T; + + template <typename TT> + void Enqueue(TT&& value) { + ui64 tag = NextTag(); + while (!TryEnqueue(std::forward<TT>(value), tag)) { + SpinLockPause(); + } + } + + bool Dequeue(T& value) { + size_t index = 0; + if (TEntry* entry = PrepareRead(index)) { + value = TTypeHelper::Read(&entry->Value); + Queues[index].CompleteRead(); + return true; + } + return false; + } + + bool IsEmpty() { + for (size_t i = 0; i < Concurrency; ++i) { + if (!Queues[i].IsEmpty()) { + return false; + } + } + return true; + } + + private: + ui64 NextTag() { + // TODO: can we avoid synchronization here? it costs 1.5x performance penalty + // return GetCycleCount(); + return AtomicIncrement(WriteTag); + } + + template <typename TT> + bool TryEnqueue(TT&& value, ui64 tag) { + for (size_t i = 0; i < Concurrency; ++i) { + TQueueType& queue = Queues[i]; + if (AtomicTryAndTryLock(&queue.WriteLock)) { + TEntry* entry = queue.PrepareWrite(); + Y_ASSERT(entry); + TTypeHelper::Write(&entry->Value, std::forward<TT>(value)); + entry->Tag = tag; + queue.CompleteWrite(); + AtomicUnlock(&queue.WriteLock); + return true; + } + } + return false; + } + + TEntry* PrepareRead(size_t& index) { + TEntry* entry = nullptr; + ui64 tag = Max(); + + for (size_t i = 0; i < Concurrency; ++i) { + TEntry* e = Queues[i].PrepareRead(); + if (e && e->Tag < tag) { + index = i; + entry = e; + tag = e->Tag; + } + } + + if (entry) { + // need second pass to catch updates within already scanned range + size_t candidate = index; + for (size_t i = 0; i < candidate; ++i) { + TEntry* e = Queues[i].PrepareRead(); + if (e && e->Tag < tag) { + index = i; + entry = e; + tag = e->Tag; + } + } + } + + return entry; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Concurrent many-many queue with strong FIFO guaranties. + // Writers will not block readers (and vice versa), but will block each other. + + template <typename T, size_t ChunkSize = PLATFORM_PAGE_SIZE, typename TLock = TAdaptiveLock> + class TManyManyQueue: private TNonCopyable { + private: + TPadded<TLock> WriteLock; + TPadded<TLock> ReadLock; + + TOneOneQueue<T, ChunkSize> Queue; + + public: + using TItem = T; + + template <typename TT> + void Enqueue(TT&& value) { + with_lock (WriteLock) { + Queue.Enqueue(std::forward<TT>(value)); + } + } + + bool Dequeue(T& value) { + with_lock (ReadLock) { + return Queue.Dequeue(value); + } + } + + bool IsEmpty() { + with_lock (ReadLock) { + return Queue.IsEmpty(); + } + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Multiple producers/single consumer partitioned queue. + // Because of random partitioning reordering possible - FIFO not guaranteed! + + template <typename T, size_t Concurrency = 4, size_t ChunkSize = PLATFORM_PAGE_SIZE> + class TRelaxedManyOneQueue: private TNonCopyable { + struct TQueueType: public TOneOneQueue<T, ChunkSize> { + TAtomic WriteLock = 0; + }; + + private: + union { + size_t ReadPos = 0; + char Pad[PLATFORM_CACHE_LINE]; + }; + + TQueueType Queues[Concurrency]; + + public: + using TItem = T; + + template <typename TT> + void Enqueue(TT&& value) { + while (!TryEnqueue(std::forward<TT>(value))) { + SpinLockPause(); + } + } + + bool Dequeue(T& value) { + for (size_t i = 0; i < Concurrency; ++i) { + TQueueType& queue = Queues[ReadPos++ % Concurrency]; + if (queue.Dequeue(value)) { + return true; + } + } + return false; + } + + bool IsEmpty() { + for (size_t i = 0; i < Concurrency; ++i) { + if (!Queues[i].IsEmpty()) { + return false; + } + } + return true; + } + + private: + template <typename TT> + bool TryEnqueue(TT&& value) { + size_t writePos = GetCycleCount(); + for (size_t i = 0; i < Concurrency; ++i) { + TQueueType& queue = Queues[writePos++ % Concurrency]; + if (AtomicTryAndTryLock(&queue.WriteLock)) { + queue.Enqueue(std::forward<TT>(value)); + AtomicUnlock(&queue.WriteLock); + return true; + } + } + return false; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Concurrent many-many partitioned queue. + // Because of random partitioning reordering possible - FIFO not guaranteed! + + template <typename T, size_t Concurrency = 4, size_t ChunkSize = PLATFORM_PAGE_SIZE> + class TRelaxedManyManyQueue: private TNonCopyable { + struct TQueueType: public TOneOneQueue<T, ChunkSize> { + union { + TAtomic WriteLock = 0; + char Pad1[PLATFORM_CACHE_LINE]; + }; + union { + TAtomic ReadLock = 0; + char Pad2[PLATFORM_CACHE_LINE]; + }; + }; + + private: + TQueueType Queues[Concurrency]; + + public: + using TItem = T; + + template <typename TT> + void Enqueue(TT&& value) { + while (!TryEnqueue(std::forward<TT>(value))) { + SpinLockPause(); + } + } + + bool Dequeue(T& value) { + size_t readPos = GetCycleCount(); + for (size_t i = 0; i < Concurrency; ++i) { + TQueueType& queue = Queues[readPos++ % Concurrency]; + if (AtomicTryAndTryLock(&queue.ReadLock)) { + bool dequeued = queue.Dequeue(value); + AtomicUnlock(&queue.ReadLock); + if (dequeued) { + return true; + } + } + } + return false; + } + + bool IsEmpty() { + for (size_t i = 0; i < Concurrency; ++i) { + TQueueType& queue = Queues[i]; + if (AtomicTryAndTryLock(&queue.ReadLock)) { + bool empty = queue.IsEmpty(); + AtomicUnlock(&queue.ReadLock); + if (!empty) { + return false; + } + } + } + return true; + } + + private: + template <typename TT> + bool TryEnqueue(TT&& value) { + size_t writePos = GetCycleCount(); + for (size_t i = 0; i < Concurrency; ++i) { + TQueueType& queue = Queues[writePos++ % Concurrency]; + if (AtomicTryAndTryLock(&queue.WriteLock)) { + queue.Enqueue(std::forward<TT>(value)); + AtomicUnlock(&queue.WriteLock); + return true; + } + } + return false; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Simple wrapper to deal with AutoPtrs + + template <typename T, typename TImpl> + class TAutoQueueBase: private TNonCopyable { + private: + TImpl Impl; + + public: + using TItem = TAutoPtr<T>; + + ~TAutoQueueBase() { + TAutoPtr<T> value; + while (Dequeue(value)) { + // do nothing + } + } + + void Enqueue(TAutoPtr<T> value) { + Impl.Enqueue(value.Get()); + Y_UNUSED(value.Release()); + } + + bool Dequeue(TAutoPtr<T>& value) { + T* ptr = nullptr; + if (Impl.Dequeue(ptr)) { + value.Reset(ptr); + return true; + } + return false; + } + + bool IsEmpty() { + return Impl.IsEmpty(); + } + }; + + template <typename T, size_t ChunkSize = PLATFORM_PAGE_SIZE> + using TAutoOneOneQueue = TAutoQueueBase<T, TOneOneQueue<T*, ChunkSize>>; + + template <typename T, size_t Concurrency = 4, size_t ChunkSize = PLATFORM_PAGE_SIZE> + using TAutoManyOneQueue = TAutoQueueBase<T, TManyOneQueue<T*, Concurrency, ChunkSize>>; + + template <typename T, size_t ChunkSize = PLATFORM_PAGE_SIZE, typename TLock = TAdaptiveLock> + using TAutoManyManyQueue = TAutoQueueBase<T, TManyManyQueue<T*, ChunkSize, TLock>>; + + template <typename T, size_t Concurrency = 4, size_t ChunkSize = PLATFORM_PAGE_SIZE> + using TAutoRelaxedManyOneQueue = TAutoQueueBase<T, TRelaxedManyOneQueue<T*, Concurrency, ChunkSize>>; + + template <typename T, size_t Concurrency = 4, size_t ChunkSize = PLATFORM_PAGE_SIZE> + using TAutoRelaxedManyManyQueue = TAutoQueueBase<T, TRelaxedManyManyQueue<T*, Concurrency, ChunkSize>>; +} diff --git a/library/cpp/threading/chunk_queue/queue_ut.cpp b/library/cpp/threading/chunk_queue/queue_ut.cpp new file mode 100644 index 0000000000..8cb36d8dd1 --- /dev/null +++ b/library/cpp/threading/chunk_queue/queue_ut.cpp @@ -0,0 +1,205 @@ +#include "queue.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/set.h> + +namespace NThreading { + //////////////////////////////////////////////////////////////////////////////// + + Y_UNIT_TEST_SUITE(TOneOneQueueTest){ + Y_UNIT_TEST(ShouldBeEmptyAtStart){ + TOneOneQueue<int> queue; + + int result = 0; + UNIT_ASSERT(queue.IsEmpty()); + UNIT_ASSERT(!queue.Dequeue(result)); +} + +Y_UNIT_TEST(ShouldReturnEntries) { + TOneOneQueue<int> queue; + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + + int result = 0; + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 1); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 2); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 3); + + UNIT_ASSERT(queue.IsEmpty()); + UNIT_ASSERT(!queue.Dequeue(result)); +} + +Y_UNIT_TEST(ShouldStoreMultipleChunks) { + TOneOneQueue<int, 100> queue; + for (int i = 0; i < 1000; ++i) { + queue.Enqueue(i); + } + + for (int i = 0; i < 1000; ++i) { + int result = 0; + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, i); + } +} +} +; + +//////////////////////////////////////////////////////////////////////////////// + +Y_UNIT_TEST_SUITE(TManyOneQueueTest){ + Y_UNIT_TEST(ShouldBeEmptyAtStart){ + TManyOneQueue<int> queue; + +int result; +UNIT_ASSERT(queue.IsEmpty()); +UNIT_ASSERT(!queue.Dequeue(result)); +} + +Y_UNIT_TEST(ShouldReturnEntries) { + TManyOneQueue<int> queue; + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + + int result = 0; + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 1); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 2); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 3); + + UNIT_ASSERT(queue.IsEmpty()); + UNIT_ASSERT(!queue.Dequeue(result)); +} +} +; + +//////////////////////////////////////////////////////////////////////////////// + +Y_UNIT_TEST_SUITE(TManyManyQueueTest){ + Y_UNIT_TEST(ShouldBeEmptyAtStart){ + TManyManyQueue<int> queue; + +int result = 0; +UNIT_ASSERT(queue.IsEmpty()); +UNIT_ASSERT(!queue.Dequeue(result)); +} + +Y_UNIT_TEST(ShouldReturnEntries) { + TManyManyQueue<int> queue; + queue.Enqueue(1); + queue.Enqueue(2); + queue.Enqueue(3); + + int result = 0; + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 1); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 2); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT_EQUAL(result, 3); + + UNIT_ASSERT(queue.IsEmpty()); + UNIT_ASSERT(!queue.Dequeue(result)); +} +} +; + +//////////////////////////////////////////////////////////////////////////////// + +Y_UNIT_TEST_SUITE(TRelaxedManyOneQueueTest){ + Y_UNIT_TEST(ShouldBeEmptyAtStart){ + TRelaxedManyOneQueue<int> queue; + +int result; +UNIT_ASSERT(queue.IsEmpty()); +UNIT_ASSERT(!queue.Dequeue(result)); +} + +Y_UNIT_TEST(ShouldReturnEntries) { + TSet<int> items = {1, 2, 3}; + + TRelaxedManyOneQueue<int> queue; + for (int item : items) { + queue.Enqueue(item); + } + + int result = 0; + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT(items.erase(result)); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT(items.erase(result)); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT(items.erase(result)); + + UNIT_ASSERT(queue.IsEmpty()); + UNIT_ASSERT(!queue.Dequeue(result)); +} +} +; + +//////////////////////////////////////////////////////////////////////////////// + +Y_UNIT_TEST_SUITE(TRelaxedManyManyQueueTest){ + Y_UNIT_TEST(ShouldBeEmptyAtStart){ + TRelaxedManyManyQueue<int> queue; + +int result = 0; +UNIT_ASSERT(queue.IsEmpty()); +UNIT_ASSERT(!queue.Dequeue(result)); +} + +Y_UNIT_TEST(ShouldReturnEntries) { + TSet<int> items = {1, 2, 3}; + + TRelaxedManyManyQueue<int> queue; + for (int item : items) { + queue.Enqueue(item); + } + + int result = 0; + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT(items.erase(result)); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT(items.erase(result)); + + UNIT_ASSERT(!queue.IsEmpty()); + UNIT_ASSERT(queue.Dequeue(result)); + UNIT_ASSERT(items.erase(result)); + + UNIT_ASSERT(queue.IsEmpty()); + UNIT_ASSERT(!queue.Dequeue(result)); +} +} +; +} diff --git a/library/cpp/threading/chunk_queue/readme.txt b/library/cpp/threading/chunk_queue/readme.txt new file mode 100644 index 0000000000..7c9f046a86 --- /dev/null +++ b/library/cpp/threading/chunk_queue/readme.txt @@ -0,0 +1,60 @@ +vskipin@dev-kiwi09:~$ ./rtmr-queue-perf -w 4 -r 4 AdaptiveLock64 Mutex64 LFManyMany64 FastLFManyMany64 LFManyOne64 FastLFManyOne64 ManyMany64 ManyOne64 +2016-05-08T11:49:56.729254Z INFO: [-i] Iterations: 10000000 +2016-05-08T11:49:56.729319Z INFO: [-r] NumReaders: 4 +2016-05-08T11:49:56.729355Z INFO: [-w] NumWriters: 4 +2016-05-08T11:49:56.729502Z INFO: starting consumers... +2016-05-08T11:49:56.729621Z INFO: starting producers... +2016-05-08T11:49:56.729711Z INFO: wait for producers... +2016-05-08T11:50:14.650803Z INFO: wait for consumers... +2016-05-08T11:50:14.650859Z INFO: average producer time: 15.96846675 seconds +2016-05-08T11:50:14.650885Z INFO: average consumer time: 17.9209995 seconds +2016-05-08T11:50:14.650897Z INFO: test AdaptiveLock64 duration: 17.921395s (0.448034875us per iteration) +2016-05-08T11:50:14.650913Z INFO: starting consumers... +2016-05-08T11:50:14.651028Z INFO: starting producers... +2016-05-08T11:50:14.651122Z INFO: wait for producers... +2016-05-08T11:50:31.426378Z INFO: wait for consumers... +2016-05-08T11:50:31.426447Z INFO: average producer time: 15.58770475 seconds +2016-05-08T11:50:31.426491Z INFO: average consumer time: 16.775301 seconds +2016-05-08T11:50:31.426527Z INFO: test Mutex64 duration: 16.775614s (0.41939035us per iteration) +2016-05-08T11:50:31.426584Z INFO: starting consumers... +2016-05-08T11:50:31.426655Z INFO: starting producers... +2016-05-08T11:50:31.426749Z INFO: wait for producers... +2016-05-08T11:50:40.578425Z INFO: wait for consumers... +2016-05-08T11:50:40.578523Z INFO: average producer time: 8.69236075 seconds +2016-05-08T11:50:40.578577Z INFO: average consumer time: 9.15165125 seconds +2016-05-08T11:50:40.578617Z INFO: test LFManyMany64 duration: 9.152033s (0.228800825us per iteration) +2016-05-08T11:50:40.578670Z INFO: starting consumers... +2016-05-08T11:50:40.578742Z INFO: starting producers... +2016-05-08T11:50:40.578893Z INFO: wait for producers... +2016-05-08T11:50:47.447686Z INFO: wait for consumers... +2016-05-08T11:50:47.447758Z INFO: average producer time: 6.81136025 seconds +2016-05-08T11:50:47.447793Z INFO: average consumer time: 6.86875825 seconds +2016-05-08T11:50:47.447834Z INFO: test FastLFManyMany64 duration: 6.869165s (0.171729125us per iteration) +2016-05-08T11:50:47.447901Z INFO: starting consumers... +2016-05-08T11:50:47.447967Z INFO: starting producers... +2016-05-08T11:50:47.448058Z INFO: wait for producers... +2016-05-08T11:50:50.469710Z INFO: wait for consumers... +2016-05-08T11:50:50.469798Z INFO: average producer time: 2.9915505 seconds +2016-05-08T11:50:50.469848Z INFO: average consumer time: 3.02161675 seconds +2016-05-08T11:50:50.469883Z INFO: test LFManyOne64 duration: 3.021983s (0.075549575us per iteration) +2016-05-08T11:50:50.469947Z INFO: starting consumers... +2016-05-08T11:50:50.470012Z INFO: starting producers... +2016-05-08T11:50:50.470104Z INFO: wait for producers... +2016-05-08T11:50:53.139964Z INFO: wait for consumers... +2016-05-08T11:50:53.140050Z INFO: average producer time: 2.5656465 seconds +2016-05-08T11:50:53.140102Z INFO: average consumer time: 2.6697755 seconds +2016-05-08T11:50:53.140149Z INFO: test FastLFManyOne64 duration: 2.670202s (0.06675505us per iteration) +2016-05-08T11:50:53.140206Z INFO: starting consumers... +2016-05-08T11:50:53.140281Z INFO: starting producers... +2016-05-08T11:50:53.140371Z INFO: wait for producers... +2016-05-08T11:50:59.067812Z INFO: wait for consumers... +2016-05-08T11:50:59.067895Z INFO: average producer time: 5.8925505 seconds +2016-05-08T11:50:59.067946Z INFO: average consumer time: 5.9273365 seconds +2016-05-08T11:50:59.067978Z INFO: test ManyMany64 duration: 5.927773s (0.148194325us per iteration) +2016-05-08T11:50:59.068068Z INFO: starting consumers... +2016-05-08T11:50:59.068179Z INFO: starting producers... +2016-05-08T11:50:59.068288Z INFO: wait for producers... +2016-05-08T11:51:03.427416Z INFO: wait for consumers... +2016-05-08T11:51:03.427514Z INFO: average producer time: 4.1055505 seconds +2016-05-08T11:51:03.427560Z INFO: average consumer time: 4.35914975 seconds +2016-05-08T11:51:03.427596Z INFO: test ManyOne64 duration: 4.359529s (0.108988225us per iteration) diff --git a/library/cpp/threading/chunk_queue/ut/ya.make b/library/cpp/threading/chunk_queue/ut/ya.make new file mode 100644 index 0000000000..a35ed6bc4b --- /dev/null +++ b/library/cpp/threading/chunk_queue/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/threading/chunk_queue) + +OWNER(g:rtmr) + +SRCS( + queue_ut.cpp +) + +END() diff --git a/library/cpp/threading/chunk_queue/ya.make b/library/cpp/threading/chunk_queue/ya.make new file mode 100644 index 0000000000..2f883140ba --- /dev/null +++ b/library/cpp/threading/chunk_queue/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(g:rtmr) + +SRCS( + queue.cpp +) + +END() diff --git a/library/cpp/threading/equeue/equeue.cpp b/library/cpp/threading/equeue/equeue.cpp new file mode 100644 index 0000000000..54a848e912 --- /dev/null +++ b/library/cpp/threading/equeue/equeue.cpp @@ -0,0 +1,80 @@ +#include "equeue.h" + +TElasticQueue::TElasticQueue(THolder<IThreadPool> slaveQueue) + : SlaveQueue_(std::move(slaveQueue)) +{ +} + +size_t TElasticQueue::ObjectCount() const { + return (size_t)AtomicGet(ObjectCount_); +} + +bool TElasticQueue::TryIncCounter() { + if ((size_t)AtomicIncrement(GuardCount_) > MaxQueueSize_) { + AtomicDecrement(GuardCount_); + return false; + } + + return true; +} + + + +class TElasticQueue::TDecrementingWrapper: TNonCopyable, public IObjectInQueue { +public: + TDecrementingWrapper(IObjectInQueue* realObject, TElasticQueue* queue) + : RealObject_(realObject) + , Queue_(queue) + { + AtomicIncrement(Queue_->ObjectCount_); + } + + ~TDecrementingWrapper() override { + AtomicDecrement(Queue_->ObjectCount_); + AtomicDecrement(Queue_->GuardCount_); + } +private: + void Process(void *tsr) override { + THolder<TDecrementingWrapper> self(this); + RealObject_->Process(tsr); + } +private: + IObjectInQueue* const RealObject_; + TElasticQueue* const Queue_; +}; + + + +bool TElasticQueue::Add(IObjectInQueue* obj) { + if (!TryIncCounter()) { + return false; + } + + THolder<TDecrementingWrapper> wrapper; + try { + wrapper.Reset(new TDecrementingWrapper(obj, this)); + } catch (...) { + AtomicDecrement(GuardCount_); + throw; + } + + if (SlaveQueue_->Add(wrapper.Get())) { + Y_UNUSED(wrapper.Release()); + return true; + } else { + return false; + } +} + +void TElasticQueue::Start(size_t threadCount, size_t maxQueueSize) { + MaxQueueSize_ = maxQueueSize; + SlaveQueue_->Start(threadCount, maxQueueSize); +} + +void TElasticQueue::Stop() noexcept { + return SlaveQueue_->Stop(); +} + +size_t TElasticQueue::Size() const noexcept { + return SlaveQueue_->Size(); +} diff --git a/library/cpp/threading/equeue/equeue.h b/library/cpp/threading/equeue/equeue.h new file mode 100644 index 0000000000..40dd342585 --- /dev/null +++ b/library/cpp/threading/equeue/equeue.h @@ -0,0 +1,28 @@ +#pragma once + +#include <util/thread/pool.h> +#include <util/system/atomic.h> +#include <util/generic/ptr.h> + +//actual queue limit will be (maxQueueSize - numBusyThreads) or 0 +class TElasticQueue: public IThreadPool { +public: + explicit TElasticQueue(THolder<IThreadPool> slaveQueue); + + bool Add(IObjectInQueue* obj) override; + size_t Size() const noexcept override; + + void Start(size_t threadCount, size_t maxQueueSize) override; + void Stop() noexcept override; + + size_t ObjectCount() const; +private: + class TDecrementingWrapper; + + bool TryIncCounter(); +private: + THolder<IThreadPool> SlaveQueue_; + size_t MaxQueueSize_ = 0; + TAtomic ObjectCount_ = 0; + TAtomic GuardCount_ = 0; +}; diff --git a/library/cpp/threading/equeue/equeue_ut.cpp b/library/cpp/threading/equeue/equeue_ut.cpp new file mode 100644 index 0000000000..9cf2aced44 --- /dev/null +++ b/library/cpp/threading/equeue/equeue_ut.cpp @@ -0,0 +1,125 @@ +#include "equeue.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/system/event.h> +#include <util/datetime/base.h> +#include <util/generic/vector.h> + +Y_UNIT_TEST_SUITE(TElasticQueueTest) { + const size_t MaxQueueSize = 20; + const size_t ThreadCount = 10; + const size_t N = 100000; + + static THolder<TElasticQueue> Queue; + + struct TQueueSetup { + TQueueSetup() { + Queue.Reset(new TElasticQueue(MakeHolder<TSimpleThreadPool>())); + Queue->Start(ThreadCount, MaxQueueSize); + } + ~TQueueSetup() { + Queue->Stop(); + } + }; + + struct TCounters { + void Reset() { + Processed = Scheduled = Discarded = Total = 0; + } + + TAtomic Processed; + TAtomic Scheduled; + TAtomic Discarded; + TAtomic Total; + }; + static TCounters Counters; + +//fill test -- fill queue with "endless" jobs + TSystemEvent WaitEvent; + Y_UNIT_TEST(FillTest) { + Counters.Reset(); + + struct TWaitJob: public IObjectInQueue { + void Process(void*) override { + WaitEvent.Wait(); + AtomicIncrement(Counters.Processed); + } + } job; + + struct TLocalSetup: TQueueSetup { + ~TLocalSetup() { + WaitEvent.Signal(); + } + }; + + size_t enqueued = 0; + { + TLocalSetup setup; + while (Queue->Add(&job) && enqueued < MaxQueueSize + 100) { + ++enqueued; + } + + UNIT_ASSERT_VALUES_EQUAL(enqueued, MaxQueueSize); + UNIT_ASSERT_VALUES_EQUAL(enqueued, Queue->ObjectCount()); + } + + UNIT_ASSERT_VALUES_EQUAL(0u, Queue->ObjectCount()); + UNIT_ASSERT_VALUES_EQUAL(0u, Queue->Size()); + UNIT_ASSERT_VALUES_EQUAL((size_t)Counters.Processed, enqueued); + } + + +//concurrent test -- send many jobs from different threads + struct TJob: public IObjectInQueue { + void Process(void*) override { + AtomicIncrement(Counters.Processed); + }; + }; + static TJob Job; + + static bool TryAdd() { + AtomicIncrement(Counters.Total); + if (Queue->Add(&Job)) { + AtomicIncrement(Counters.Scheduled); + return true; + } else { + AtomicIncrement(Counters.Discarded); + return false; + } + } + + static size_t TryCounter; + + Y_UNIT_TEST(ConcurrentTest) { + Counters.Reset(); + TryCounter = 0; + + struct TSender: public IThreadFactory::IThreadAble { + void DoExecute() override { + while ((size_t)AtomicIncrement(TryCounter) <= N) { + if (!TryAdd()) { + Sleep(TDuration::MicroSeconds(50)); + } + } + } + } sender; + + { + TQueueSetup setup; + + TVector< TAutoPtr<IThreadFactory::IThread> > senders; + for (size_t i = 0; i < ThreadCount; ++i) { + senders.push_back(::SystemThreadFactory()->Run(&sender)); + } + + for (size_t i = 0; i < senders.size(); ++i) { + senders[i]->Join(); + } + } + + UNIT_ASSERT_VALUES_EQUAL((size_t)Counters.Total, N); + UNIT_ASSERT_VALUES_EQUAL(Counters.Processed, Counters.Scheduled); + UNIT_ASSERT_VALUES_EQUAL(Counters.Total, Counters.Scheduled + Counters.Discarded); + } +} diff --git a/library/cpp/threading/equeue/ut/ya.make b/library/cpp/threading/equeue/ut/ya.make new file mode 100644 index 0000000000..2f6293d47d --- /dev/null +++ b/library/cpp/threading/equeue/ut/ya.make @@ -0,0 +1,18 @@ +UNITTEST() + +OWNER( + g:base + g:middle +) + +PEERDIR( + ADDINCL library/cpp/threading/equeue +) + +SRCDIR(library/cpp/threading/equeue) + +SRCS( + equeue_ut.cpp +) + +END() diff --git a/library/cpp/threading/equeue/ya.make b/library/cpp/threading/equeue/ya.make new file mode 100644 index 0000000000..314f4d3c86 --- /dev/null +++ b/library/cpp/threading/equeue/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER( + g:base + g:middle + ironpeter + mvel +) + +SRCS( + equeue.h + equeue.cpp +) + +END() diff --git a/library/cpp/threading/future/async.cpp b/library/cpp/threading/future/async.cpp new file mode 100644 index 0000000000..ad9b21a2cf --- /dev/null +++ b/library/cpp/threading/future/async.cpp @@ -0,0 +1 @@ +#include "async.h" diff --git a/library/cpp/threading/future/async.h b/library/cpp/threading/future/async.h new file mode 100644 index 0000000000..8543fdd5c6 --- /dev/null +++ b/library/cpp/threading/future/async.h @@ -0,0 +1,31 @@ +#pragma once + +#include "future.h" + +#include <util/generic/function.h> +#include <util/thread/pool.h> + +namespace NThreading { + /** + * @brief Asynchronously executes @arg func in @arg queue returning a future for the result. + * + * @arg func should be a callable object with signature T(). + * @arg queue where @arg will be executed + * @returns For @arg func with signature T() the function returns TFuture<T> unless T is TFuture<U>. + * In this case the function returns TFuture<U>. + * + * If you want to use another queue for execution just write an overload, @see ExtensionExample + * unittest. + */ + template <typename Func> + TFuture<TFutureType<TFunctionResult<Func>>> Async(Func&& func, IThreadPool& queue) { + auto promise = NewPromise<TFutureType<TFunctionResult<Func>>>(); + auto lambda = [promise, func = std::forward<Func>(func)]() mutable { + NImpl::SetValue(promise, func); + }; + queue.SafeAddFunc(std::move(lambda)); + + return promise.GetFuture(); + } + +} diff --git a/library/cpp/threading/future/async_ut.cpp b/library/cpp/threading/future/async_ut.cpp new file mode 100644 index 0000000000..a3699744e4 --- /dev/null +++ b/library/cpp/threading/future/async_ut.cpp @@ -0,0 +1,57 @@ +#include "async.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/ptr.h> +#include <util/generic/vector.h> + +namespace { + struct TMySuperTaskQueue { + }; + +} + +namespace NThreading { + /* Here we provide an Async overload for TMySuperTaskQueue indide NThreading namespace + * so that we can call it in the way + * + * TMySuperTaskQueue queue; + * NThreading::Async([](){}, queue); + * + * See also ExtensionExample unittest. + */ + template <typename Func> + TFuture<TFunctionResult<Func>> Async(Func func, TMySuperTaskQueue&) { + return MakeFuture(func()); + } + +} + +Y_UNIT_TEST_SUITE(Async) { + Y_UNIT_TEST(ExtensionExample) { + TMySuperTaskQueue queue; + auto future = NThreading::Async([]() { return 5; }, queue); + future.Wait(); + UNIT_ASSERT_VALUES_EQUAL(future.GetValue(), 5); + } + + Y_UNIT_TEST(WorksWithIMtpQueue) { + auto queue = MakeHolder<TThreadPool>(); + queue->Start(1); + + auto future = NThreading::Async([]() { return 5; }, *queue); + future.Wait(); + UNIT_ASSERT_VALUES_EQUAL(future.GetValue(), 5); + } + + Y_UNIT_TEST(ProperlyDeducesFutureType) { + // Compileability test + auto queue = CreateThreadPool(1); + + NThreading::TFuture<void> f1 = NThreading::Async([]() {}, *queue); + NThreading::TFuture<int> f2 = NThreading::Async([]() { return 5; }, *queue); + NThreading::TFuture<double> f3 = NThreading::Async([]() { return 5.0; }, *queue); + NThreading::TFuture<TVector<int>> f4 = NThreading::Async([]() { return TVector<int>(); }, *queue); + NThreading::TFuture<int> f5 = NThreading::Async([]() { return NThreading::MakeFuture(5); }, *queue); + } +} diff --git a/library/cpp/threading/future/core/future-inl.h b/library/cpp/threading/future/core/future-inl.h new file mode 100644 index 0000000000..5fd4296a93 --- /dev/null +++ b/library/cpp/threading/future/core/future-inl.h @@ -0,0 +1,986 @@ +#pragma once + +#if !defined(INCLUDE_FUTURE_INL_H) +#error "you should never include future-inl.h directly" +#endif // INCLUDE_FUTURE_INL_H + +namespace NThreading { + namespace NImpl { + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + using TCallback = std::function<void(const TFuture<T>&)>; + + template <typename T> + using TCallbackList = TVector<TCallback<T>>; // TODO: small vector + + //////////////////////////////////////////////////////////////////////////////// + + enum class TError { + Error + }; + + template <typename T> + class TFutureState: public TAtomicRefCount<TFutureState<T>> { + enum { + NotReady, + ExceptionSet, + ValueMoved, // keep the ordering of this and following values + ValueSet, + ValueRead, + }; + + private: + mutable TAtomic State; + TAdaptiveLock StateLock; + + TCallbackList<T> Callbacks; + mutable THolder<TSystemEvent> ReadyEvent; + + std::exception_ptr Exception; + + union { + char NullValue; + T Value; + }; + + void AccessValue(TDuration timeout, int acquireState) const { + int state = AtomicGet(State); + if (Y_UNLIKELY(state == NotReady)) { + if (timeout == TDuration::Zero()) { + ythrow TFutureException() << "value not set"; + } + + if (!Wait(timeout)) { + ythrow TFutureException() << "wait timeout"; + } + + state = AtomicGet(State); + } + + TryRethrowWithState(state); + + switch (AtomicGetAndCas(&State, acquireState, ValueSet)) { + case ValueSet: + break; + case ValueRead: + if (acquireState != ValueRead) { + ythrow TFutureException() << "value being read"; + } + break; + case ValueMoved: + ythrow TFutureException() << "value was moved"; + default: + Y_ASSERT(state == ValueSet); + } + } + + public: + TFutureState() + : State(NotReady) + , NullValue(0) + { + } + + template <typename TT> + TFutureState(TT&& value) + : State(ValueSet) + , Value(std::forward<TT>(value)) + { + } + + TFutureState(std::exception_ptr exception, TError) + : State(ExceptionSet) + , Exception(std::move(exception)) + , NullValue(0) + { + } + + ~TFutureState() { + if (State >= ValueMoved) { // ValueMoved, ValueSet, ValueRead + Value.~T(); + } + } + + bool HasValue() const { + return AtomicGet(State) >= ValueMoved; // ValueMoved, ValueSet, ValueRead + } + + void TryRethrow() const { + int state = AtomicGet(State); + TryRethrowWithState(state); + } + + bool HasException() const { + return AtomicGet(State) == ExceptionSet; + } + + const T& GetValue(TDuration timeout = TDuration::Zero()) const { + AccessValue(timeout, ValueRead); + return Value; + } + + T ExtractValue(TDuration timeout = TDuration::Zero()) { + AccessValue(timeout, ValueMoved); + return std::move(Value); + } + + template <typename TT> + void SetValue(TT&& value) { + bool success = TrySetValue(std::forward<TT>(value)); + if (Y_UNLIKELY(!success)) { + ythrow TFutureException() << "value already set"; + } + } + + template <typename TT> + bool TrySetValue(TT&& value) { + TSystemEvent* readyEvent = nullptr; + TCallbackList<T> callbacks; + + with_lock (StateLock) { + int state = AtomicGet(State); + if (Y_UNLIKELY(state != NotReady)) { + return false; + } + + new (&Value) T(std::forward<TT>(value)); + + readyEvent = ReadyEvent.Get(); + callbacks = std::move(Callbacks); + + AtomicSet(State, ValueSet); + } + + if (readyEvent) { + readyEvent->Signal(); + } + + if (callbacks) { + TFuture<T> temp(this); + for (auto& callback : callbacks) { + callback(temp); + } + } + + return true; + } + + void SetException(std::exception_ptr e) { + bool success = TrySetException(std::move(e)); + if (Y_UNLIKELY(!success)) { + ythrow TFutureException() << "value already set"; + } + } + + bool TrySetException(std::exception_ptr e) { + TSystemEvent* readyEvent; + TCallbackList<T> callbacks; + + with_lock (StateLock) { + int state = AtomicGet(State); + if (Y_UNLIKELY(state != NotReady)) { + return false; + } + + Exception = std::move(e); + + readyEvent = ReadyEvent.Get(); + callbacks = std::move(Callbacks); + + AtomicSet(State, ExceptionSet); + } + + if (readyEvent) { + readyEvent->Signal(); + } + + if (callbacks) { + TFuture<T> temp(this); + for (auto& callback : callbacks) { + callback(temp); + } + } + + return true; + } + + template <typename F> + bool Subscribe(F&& func) { + with_lock (StateLock) { + int state = AtomicGet(State); + if (state == NotReady) { + Callbacks.emplace_back(std::forward<F>(func)); + return true; + } + } + return false; + } + + void Wait() const { + Wait(TInstant::Max()); + } + + bool Wait(TDuration timeout) const { + return Wait(timeout.ToDeadLine()); + } + + bool Wait(TInstant deadline) const { + TSystemEvent* readyEvent = nullptr; + + with_lock (StateLock) { + int state = AtomicGet(State); + if (state != NotReady) { + return true; + } + + if (!ReadyEvent) { + ReadyEvent.Reset(new TSystemEvent()); + } + readyEvent = ReadyEvent.Get(); + } + + Y_ASSERT(readyEvent); + return readyEvent->WaitD(deadline); + } + + void TryRethrowWithState(int state) const { + if (Y_UNLIKELY(state == ExceptionSet)) { + Y_ASSERT(Exception); + std::rethrow_exception(Exception); + } + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + template <> + class TFutureState<void>: public TAtomicRefCount<TFutureState<void>> { + enum { + NotReady, + ValueSet, + ExceptionSet, + }; + + private: + TAtomic State; + TAdaptiveLock StateLock; + + TCallbackList<void> Callbacks; + mutable THolder<TSystemEvent> ReadyEvent; + + std::exception_ptr Exception; + + public: + TFutureState(bool valueSet = false) + : State(valueSet ? ValueSet : NotReady) + { + } + + TFutureState(std::exception_ptr exception, TError) + : State(ExceptionSet) + , Exception(std::move(exception)) + { + } + + bool HasValue() const { + return AtomicGet(State) == ValueSet; + } + + void TryRethrow() const { + int state = AtomicGet(State); + TryRethrowWithState(state); + } + + bool HasException() const { + return AtomicGet(State) == ExceptionSet; + } + + void GetValue(TDuration timeout = TDuration::Zero()) const { + int state = AtomicGet(State); + if (Y_UNLIKELY(state == NotReady)) { + if (timeout == TDuration::Zero()) { + ythrow TFutureException() << "value not set"; + } + + if (!Wait(timeout)) { + ythrow TFutureException() << "wait timeout"; + } + + state = AtomicGet(State); + } + + TryRethrowWithState(state); + + Y_ASSERT(state == ValueSet); + } + + void SetValue() { + bool success = TrySetValue(); + if (Y_UNLIKELY(!success)) { + ythrow TFutureException() << "value already set"; + } + } + + bool TrySetValue() { + TSystemEvent* readyEvent = nullptr; + TCallbackList<void> callbacks; + + with_lock (StateLock) { + int state = AtomicGet(State); + if (Y_UNLIKELY(state != NotReady)) { + return false; + } + + readyEvent = ReadyEvent.Get(); + callbacks = std::move(Callbacks); + + AtomicSet(State, ValueSet); + } + + if (readyEvent) { + readyEvent->Signal(); + } + + if (callbacks) { + TFuture<void> temp(this); + for (auto& callback : callbacks) { + callback(temp); + } + } + + return true; + } + + void SetException(std::exception_ptr e) { + bool success = TrySetException(std::move(e)); + if (Y_UNLIKELY(!success)) { + ythrow TFutureException() << "value already set"; + } + } + + bool TrySetException(std::exception_ptr e) { + TSystemEvent* readyEvent = nullptr; + TCallbackList<void> callbacks; + + with_lock (StateLock) { + int state = AtomicGet(State); + if (Y_UNLIKELY(state != NotReady)) { + return false; + } + + Exception = std::move(e); + + readyEvent = ReadyEvent.Get(); + callbacks = std::move(Callbacks); + + AtomicSet(State, ExceptionSet); + } + + if (readyEvent) { + readyEvent->Signal(); + } + + if (callbacks) { + TFuture<void> temp(this); + for (auto& callback : callbacks) { + callback(temp); + } + } + + return true; + } + + template <typename F> + bool Subscribe(F&& func) { + with_lock (StateLock) { + int state = AtomicGet(State); + if (state == NotReady) { + Callbacks.emplace_back(std::forward<F>(func)); + return true; + } + } + return false; + } + + void Wait() const { + Wait(TInstant::Max()); + } + + bool Wait(TDuration timeout) const { + return Wait(timeout.ToDeadLine()); + } + + bool Wait(TInstant deadline) const { + TSystemEvent* readyEvent = nullptr; + + with_lock (StateLock) { + int state = AtomicGet(State); + if (state != NotReady) { + return true; + } + + if (!ReadyEvent) { + ReadyEvent.Reset(new TSystemEvent()); + } + readyEvent = ReadyEvent.Get(); + } + + Y_ASSERT(readyEvent); + return readyEvent->WaitD(deadline); + } + + void TryRethrowWithState(int state) const { + if (Y_UNLIKELY(state == ExceptionSet)) { + Y_ASSERT(Exception); + std::rethrow_exception(Exception); + } + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + inline void SetValueImpl(TPromise<T>& promise, const T& value) { + promise.SetValue(value); + } + + template <typename T> + inline void SetValueImpl(TPromise<T>& promise, T&& value) { + promise.SetValue(std::move(value)); + } + + template <typename T> + inline void SetValueImpl(TPromise<T>& promise, const TFuture<T>& future, + std::enable_if_t<!std::is_void<T>::value, bool> = false) { + future.Subscribe([=](const TFuture<T>& f) mutable { + T const* value; + try { + value = &f.GetValue(); + } catch (...) { + promise.SetException(std::current_exception()); + return; + } + promise.SetValue(*value); + }); + } + + template <typename T> + inline void SetValueImpl(TPromise<void>& promise, const TFuture<T>& future) { + future.Subscribe([=](const TFuture<T>& f) mutable { + try { + f.TryRethrow(); + } catch (...) { + promise.SetException(std::current_exception()); + return; + } + promise.SetValue(); + }); + } + + template <typename T, typename F> + inline void SetValue(TPromise<T>& promise, F&& func) { + try { + SetValueImpl(promise, func()); + } catch (...) { + const bool success = promise.TrySetException(std::current_exception()); + if (Y_UNLIKELY(!success)) { + throw; + } + } + } + + template <typename F> + inline void SetValue(TPromise<void>& promise, F&& func, + std::enable_if_t<std::is_void<TFunctionResult<F>>::value, bool> = false) { + try { + func(); + } catch (...) { + promise.SetException(std::current_exception()); + return; + } + promise.SetValue(); + } + + } + + //////////////////////////////////////////////////////////////////////////////// + + class TFutureStateId { + private: + const void* Id; + + public: + template <typename T> + explicit TFutureStateId(const NImpl::TFutureState<T>& state) + : Id(&state) + { + } + + const void* Value() const noexcept { + return Id; + } + }; + + inline bool operator==(const TFutureStateId& l, const TFutureStateId& r) { + return l.Value() == r.Value(); + } + + inline bool operator!=(const TFutureStateId& l, const TFutureStateId& r) { + return !(l == r); + } + + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + inline TFuture<T>::TFuture(const TIntrusivePtr<TFutureState>& state) noexcept + : State(state) + { + } + + template <typename T> + inline void TFuture<T>::Swap(TFuture<T>& other) { + State.Swap(other.State); + } + + template <typename T> + inline bool TFuture<T>::HasValue() const { + return State && State->HasValue(); + } + + template <typename T> + inline const T& TFuture<T>::GetValue(TDuration timeout) const { + EnsureInitialized(); + return State->GetValue(timeout); + } + + template <typename T> + inline T TFuture<T>::ExtractValue(TDuration timeout) { + EnsureInitialized(); + return State->ExtractValue(timeout); + } + + template <typename T> + inline const T& TFuture<T>::GetValueSync() const { + return GetValue(TDuration::Max()); + } + + template <typename T> + inline T TFuture<T>::ExtractValueSync() { + return ExtractValue(TDuration::Max()); + } + + template <typename T> + inline void TFuture<T>::TryRethrow() const { + if (State) { + State->TryRethrow(); + } + } + + template <typename T> + inline bool TFuture<T>::HasException() const { + return State && State->HasException(); + } + + template <typename T> + inline void TFuture<T>::Wait() const { + EnsureInitialized(); + return State->Wait(); + } + + template <typename T> + inline bool TFuture<T>::Wait(TDuration timeout) const { + EnsureInitialized(); + return State->Wait(timeout); + } + + template <typename T> + inline bool TFuture<T>::Wait(TInstant deadline) const { + EnsureInitialized(); + return State->Wait(deadline); + } + + template <typename T> + template <typename F> + inline const TFuture<T>& TFuture<T>::Subscribe(F&& func) const { + EnsureInitialized(); + if (!State->Subscribe(std::forward<F>(func))) { + func(*this); + } + return *this; + } + + template <typename T> + template <typename F> + inline const TFuture<T>& TFuture<T>::NoexceptSubscribe(F&& func) const noexcept { + return Subscribe(std::forward<F>(func)); + } + + + template <typename T> + template <typename F> + inline TFuture<TFutureType<TFutureCallResult<F, T>>> TFuture<T>::Apply(F&& func) const { + auto promise = NewPromise<TFutureType<TFutureCallResult<F, T>>>(); + Subscribe([promise, func = std::forward<F>(func)](const TFuture<T>& future) mutable { + NImpl::SetValue(promise, [&]() { return func(future); }); + }); + return promise; + } + + template <typename T> + inline TFuture<void> TFuture<T>::IgnoreResult() const { + auto promise = NewPromise(); + Subscribe([=](const TFuture<T>& future) mutable { + NImpl::SetValueImpl(promise, future); + }); + return promise; + } + + template <typename T> + inline bool TFuture<T>::Initialized() const { + return bool(State); + } + + template <typename T> + inline TMaybe<TFutureStateId> TFuture<T>::StateId() const noexcept { + return State != nullptr ? MakeMaybe<TFutureStateId>(*State) : Nothing(); + } + + template <typename T> + inline void TFuture<T>::EnsureInitialized() const { + if (!State) { + ythrow TFutureException() << "state not initialized"; + } + } + + //////////////////////////////////////////////////////////////////////////////// + + inline TFuture<void>::TFuture(const TIntrusivePtr<TFutureState>& state) noexcept + : State(state) + { + } + + inline void TFuture<void>::Swap(TFuture<void>& other) { + State.Swap(other.State); + } + + inline bool TFuture<void>::HasValue() const { + return State && State->HasValue(); + } + + inline void TFuture<void>::GetValue(TDuration timeout) const { + EnsureInitialized(); + State->GetValue(timeout); + } + + inline void TFuture<void>::GetValueSync() const { + GetValue(TDuration::Max()); + } + + inline void TFuture<void>::TryRethrow() const { + if (State) { + State->TryRethrow(); + } + } + + inline bool TFuture<void>::HasException() const { + return State && State->HasException(); + } + + inline void TFuture<void>::Wait() const { + EnsureInitialized(); + return State->Wait(); + } + + inline bool TFuture<void>::Wait(TDuration timeout) const { + EnsureInitialized(); + return State->Wait(timeout); + } + + inline bool TFuture<void>::Wait(TInstant deadline) const { + EnsureInitialized(); + return State->Wait(deadline); + } + + template <typename F> + inline const TFuture<void>& TFuture<void>::Subscribe(F&& func) const { + EnsureInitialized(); + if (!State->Subscribe(std::forward<F>(func))) { + func(*this); + } + return *this; + } + + template <typename F> + inline const TFuture<void>& TFuture<void>::NoexceptSubscribe(F&& func) const noexcept { + return Subscribe(std::forward<F>(func)); + } + + + template <typename F> + inline TFuture<TFutureType<TFutureCallResult<F, void>>> TFuture<void>::Apply(F&& func) const { + auto promise = NewPromise<TFutureType<TFutureCallResult<F, void>>>(); + Subscribe([promise, func = std::forward<F>(func)](const TFuture<void>& future) mutable { + NImpl::SetValue(promise, [&]() { return func(future); }); + }); + return promise; + } + + template <typename R> + inline TFuture<R> TFuture<void>::Return(const R& value) const { + auto promise = NewPromise<R>(); + Subscribe([=](const TFuture<void>& future) mutable { + try { + future.TryRethrow(); + } catch (...) { + promise.SetException(std::current_exception()); + return; + } + promise.SetValue(value); + }); + return promise; + } + + inline bool TFuture<void>::Initialized() const { + return bool(State); + } + + inline TMaybe<TFutureStateId> TFuture<void>::StateId() const noexcept { + return State != nullptr ? MakeMaybe<TFutureStateId>(*State) : Nothing(); + } + + inline void TFuture<void>::EnsureInitialized() const { + if (!State) { + ythrow TFutureException() << "state not initialized"; + } + } + + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + inline TPromise<T>::TPromise(const TIntrusivePtr<TFutureState>& state) noexcept + : State(state) + { + } + + template <typename T> + inline void TPromise<T>::Swap(TPromise<T>& other) { + State.Swap(other.State); + } + + template <typename T> + inline const T& TPromise<T>::GetValue() const { + EnsureInitialized(); + return State->GetValue(); + } + + template <typename T> + inline T TPromise<T>::ExtractValue() { + EnsureInitialized(); + return State->ExtractValue(); + } + + template <typename T> + inline bool TPromise<T>::HasValue() const { + return State && State->HasValue(); + } + + template <typename T> + inline void TPromise<T>::SetValue(const T& value) { + EnsureInitialized(); + State->SetValue(value); + } + + template <typename T> + inline void TPromise<T>::SetValue(T&& value) { + EnsureInitialized(); + State->SetValue(std::move(value)); + } + + template <typename T> + inline bool TPromise<T>::TrySetValue(const T& value) { + EnsureInitialized(); + return State->TrySetValue(value); + } + + template <typename T> + inline bool TPromise<T>::TrySetValue(T&& value) { + EnsureInitialized(); + return State->TrySetValue(std::move(value)); + } + + template <typename T> + inline void TPromise<T>::TryRethrow() const { + if (State) { + State->TryRethrow(); + } + } + + template <typename T> + inline bool TPromise<T>::HasException() const { + return State && State->HasException(); + } + + template <typename T> + inline void TPromise<T>::SetException(const TString& e) { + EnsureInitialized(); + State->SetException(std::make_exception_ptr(yexception() << e)); + } + + template <typename T> + inline void TPromise<T>::SetException(std::exception_ptr e) { + EnsureInitialized(); + State->SetException(std::move(e)); + } + + template <typename T> + inline bool TPromise<T>::TrySetException(std::exception_ptr e) { + EnsureInitialized(); + return State->TrySetException(std::move(e)); + } + + template <typename T> + inline TFuture<T> TPromise<T>::GetFuture() const { + EnsureInitialized(); + return TFuture<T>(State); + } + + template <typename T> + inline TPromise<T>::operator TFuture<T>() const { + return GetFuture(); + } + + template <typename T> + inline bool TPromise<T>::Initialized() const { + return bool(State); + } + + template <typename T> + inline void TPromise<T>::EnsureInitialized() const { + if (!State) { + ythrow TFutureException() << "state not initialized"; + } + } + + //////////////////////////////////////////////////////////////////////////////// + + inline TPromise<void>::TPromise(const TIntrusivePtr<TFutureState>& state) noexcept + : State(state) + { + } + + inline void TPromise<void>::Swap(TPromise<void>& other) { + State.Swap(other.State); + } + + inline void TPromise<void>::GetValue() const { + EnsureInitialized(); + State->GetValue(); + } + + inline bool TPromise<void>::HasValue() const { + return State && State->HasValue(); + } + + inline void TPromise<void>::SetValue() { + EnsureInitialized(); + State->SetValue(); + } + + inline bool TPromise<void>::TrySetValue() { + EnsureInitialized(); + return State->TrySetValue(); + } + + inline void TPromise<void>::TryRethrow() const { + if(State) { + State->TryRethrow(); + } + } + + inline bool TPromise<void>::HasException() const { + return State && State->HasException(); + } + + inline void TPromise<void>::SetException(const TString& e) { + EnsureInitialized(); + State->SetException(std::make_exception_ptr(yexception() << e)); + } + + inline void TPromise<void>::SetException(std::exception_ptr e) { + EnsureInitialized(); + State->SetException(std::move(e)); + } + + inline bool TPromise<void>::TrySetException(std::exception_ptr e) { + EnsureInitialized(); + return State->TrySetException(std::move(e)); + } + + inline TFuture<void> TPromise<void>::GetFuture() const { + EnsureInitialized(); + return TFuture<void>(State); + } + + inline TPromise<void>::operator TFuture<void>() const { + return GetFuture(); + } + + inline bool TPromise<void>::Initialized() const { + return bool(State); + } + + inline void TPromise<void>::EnsureInitialized() const { + if (!State) { + ythrow TFutureException() << "state not initialized"; + } + } + + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + inline TPromise<T> NewPromise() { + return {new NImpl::TFutureState<T>()}; + } + + inline TPromise<void> NewPromise() { + return {new NImpl::TFutureState<void>()}; + } + + template <typename T> + inline TFuture<T> MakeFuture(const T& value) { + return {new NImpl::TFutureState<T>(value)}; + } + + template <typename T> + inline TFuture<std::remove_reference_t<T>> MakeFuture(T&& value) { + return {new NImpl::TFutureState<std::remove_reference_t<T>>(std::forward<T>(value))}; + } + + template <typename T> + inline TFuture<T> MakeFuture() { + struct TCache { + TFuture<T> Instance{new NImpl::TFutureState<T>(Default<T>())}; + + TCache() { + // Immediately advance state from ValueSet to ValueRead. + // This should prevent corrupting shared value with an ExtractValue() call. + Y_UNUSED(Instance.GetValue()); + } + }; + return Singleton<TCache>()->Instance; + } + + template <typename T> + inline TFuture<T> MakeErrorFuture(std::exception_ptr exception) + { + return {new NImpl::TFutureState<T>(std::move(exception), NImpl::TError::Error)}; + } + + inline TFuture<void> MakeFuture() { + struct TCache { + TFuture<void> Instance{new NImpl::TFutureState<void>(true)}; + }; + return Singleton<TCache>()->Instance; + } +} diff --git a/library/cpp/threading/future/core/future.cpp b/library/cpp/threading/future/core/future.cpp new file mode 100644 index 0000000000..3243afcb40 --- /dev/null +++ b/library/cpp/threading/future/core/future.cpp @@ -0,0 +1 @@ +#include "future.h" diff --git a/library/cpp/threading/future/core/future.h b/library/cpp/threading/future/core/future.h new file mode 100644 index 0000000000..2e82bb953e --- /dev/null +++ b/library/cpp/threading/future/core/future.h @@ -0,0 +1,272 @@ +#pragma once + +#include "fwd.h" + +#include <util/datetime/base.h> +#include <util/generic/function.h> +#include <util/generic/maybe.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/system/event.h> +#include <util/system/spinlock.h> + +namespace NThreading { + //////////////////////////////////////////////////////////////////////////////// + + struct TFutureException: public yexception {}; + + // creates unset promise + template <typename T> + TPromise<T> NewPromise(); + TPromise<void> NewPromise(); + + // creates preset future + template <typename T> + TFuture<T> MakeFuture(const T& value); + template <typename T> + TFuture<std::remove_reference_t<T>> MakeFuture(T&& value); + template <typename T> + TFuture<T> MakeFuture(); + template <typename T> + TFuture<T> MakeErrorFuture(std::exception_ptr exception); + TFuture<void> MakeFuture(); + + //////////////////////////////////////////////////////////////////////////////// + + namespace NImpl { + template <typename T> + class TFutureState; + + template <typename T> + struct TFutureType { + using TType = T; + }; + + template <typename T> + struct TFutureType<TFuture<T>> { + using TType = typename TFutureType<T>::TType; + }; + + template <typename F, typename T> + struct TFutureCallResult { + // NOTE: separate class for msvc compatibility + using TType = decltype(std::declval<F&>()(std::declval<const TFuture<T>&>())); + }; + } + + template <typename F> + using TFutureType = typename NImpl::TFutureType<F>::TType; + + template <typename F, typename T> + using TFutureCallResult = typename NImpl::TFutureCallResult<F, T>::TType; + + //! Type of the future/promise state identifier + class TFutureStateId; + + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + class TFuture { + using TFutureState = NImpl::TFutureState<T>; + + private: + TIntrusivePtr<TFutureState> State; + + public: + using value_type = T; + + TFuture() noexcept = default; + TFuture(const TFuture<T>& other) noexcept = default; + TFuture(TFuture<T>&& other) noexcept = default; + TFuture(const TIntrusivePtr<TFutureState>& state) noexcept; + + TFuture<T>& operator=(const TFuture<T>& other) noexcept = default; + TFuture<T>& operator=(TFuture<T>&& other) noexcept = default; + void Swap(TFuture<T>& other); + + bool Initialized() const; + + bool HasValue() const; + const T& GetValue(TDuration timeout = TDuration::Zero()) const; + const T& GetValueSync() const; + T ExtractValue(TDuration timeout = TDuration::Zero()); + T ExtractValueSync(); + + void TryRethrow() const; + bool HasException() const; + + void Wait() const; + bool Wait(TDuration timeout) const; + bool Wait(TInstant deadline) const; + + template <typename F> + const TFuture<T>& Subscribe(F&& callback) const; + + // precondition: EnsureInitialized() passes + // postcondition: std::terminate is highly unlikely + template <typename F> + const TFuture<T>& NoexceptSubscribe(F&& callback) const noexcept; + + template <typename F> + TFuture<TFutureType<TFutureCallResult<F, T>>> Apply(F&& func) const; + + TFuture<void> IgnoreResult() const; + + //! If the future is initialized returns the future state identifier. Otherwise returns an empty optional + /** The state identifier is guaranteed to be unique during the future state lifetime and could be reused after its death + **/ + TMaybe<TFutureStateId> StateId() const noexcept; + + void EnsureInitialized() const; + }; + + //////////////////////////////////////////////////////////////////////////////// + + template <> + class TFuture<void> { + using TFutureState = NImpl::TFutureState<void>; + + private: + TIntrusivePtr<TFutureState> State = nullptr; + + public: + using value_type = void; + + TFuture() noexcept = default; + TFuture(const TFuture<void>& other) noexcept = default; + TFuture(TFuture<void>&& other) noexcept = default; + TFuture(const TIntrusivePtr<TFutureState>& state) noexcept; + + TFuture<void>& operator=(const TFuture<void>& other) noexcept = default; + TFuture<void>& operator=(TFuture<void>&& other) noexcept = default; + void Swap(TFuture<void>& other); + + bool Initialized() const; + + bool HasValue() const; + void GetValue(TDuration timeout = TDuration::Zero()) const; + void GetValueSync() const; + + void TryRethrow() const; + bool HasException() const; + + void Wait() const; + bool Wait(TDuration timeout) const; + bool Wait(TInstant deadline) const; + + template <typename F> + const TFuture<void>& Subscribe(F&& callback) const; + + // precondition: EnsureInitialized() passes + // postcondition: std::terminate is highly unlikely + template <typename F> + const TFuture<void>& NoexceptSubscribe(F&& callback) const noexcept; + + template <typename F> + TFuture<TFutureType<TFutureCallResult<F, void>>> Apply(F&& func) const; + + template <typename R> + TFuture<R> Return(const R& value) const; + + TFuture<void> IgnoreResult() const { + return *this; + } + + //! If the future is initialized returns the future state identifier. Otherwise returns an empty optional + /** The state identifier is guaranteed to be unique during the future state lifetime and could be reused after its death + **/ + TMaybe<TFutureStateId> StateId() const noexcept; + + void EnsureInitialized() const; + }; + + //////////////////////////////////////////////////////////////////////////////// + + template <typename T> + class TPromise { + using TFutureState = NImpl::TFutureState<T>; + + private: + TIntrusivePtr<TFutureState> State = nullptr; + + public: + TPromise() noexcept = default; + TPromise(const TPromise<T>& other) noexcept = default; + TPromise(TPromise<T>&& other) noexcept = default; + TPromise(const TIntrusivePtr<TFutureState>& state) noexcept; + + TPromise<T>& operator=(const TPromise<T>& other) noexcept = default; + TPromise<T>& operator=(TPromise<T>&& other) noexcept = default; + void Swap(TPromise<T>& other); + + bool Initialized() const; + + bool HasValue() const; + const T& GetValue() const; + T ExtractValue(); + + void SetValue(const T& value); + void SetValue(T&& value); + + bool TrySetValue(const T& value); + bool TrySetValue(T&& value); + + void TryRethrow() const; + bool HasException() const; + void SetException(const TString& e); + void SetException(std::exception_ptr e); + bool TrySetException(std::exception_ptr e); + + TFuture<T> GetFuture() const; + operator TFuture<T>() const; + + private: + void EnsureInitialized() const; + }; + + //////////////////////////////////////////////////////////////////////////////// + + template <> + class TPromise<void> { + using TFutureState = NImpl::TFutureState<void>; + + private: + TIntrusivePtr<TFutureState> State; + + public: + TPromise() noexcept = default; + TPromise(const TPromise<void>& other) noexcept = default; + TPromise(TPromise<void>&& other) noexcept = default; + TPromise(const TIntrusivePtr<TFutureState>& state) noexcept; + + TPromise<void>& operator=(const TPromise<void>& other) noexcept = default; + TPromise<void>& operator=(TPromise<void>&& other) noexcept = default; + void Swap(TPromise<void>& other); + + bool Initialized() const; + + bool HasValue() const; + void GetValue() const; + + void SetValue(); + bool TrySetValue(); + + void TryRethrow() const; + bool HasException() const; + void SetException(const TString& e); + void SetException(std::exception_ptr e); + bool TrySetException(std::exception_ptr e); + + TFuture<void> GetFuture() const; + operator TFuture<void>() const; + + private: + void EnsureInitialized() const; + }; + +} + +#define INCLUDE_FUTURE_INL_H +#include "future-inl.h" +#undef INCLUDE_FUTURE_INL_H diff --git a/library/cpp/threading/future/core/fwd.cpp b/library/cpp/threading/future/core/fwd.cpp new file mode 100644 index 0000000000..4214b6df83 --- /dev/null +++ b/library/cpp/threading/future/core/fwd.cpp @@ -0,0 +1 @@ +#include "fwd.h" diff --git a/library/cpp/threading/future/core/fwd.h b/library/cpp/threading/future/core/fwd.h new file mode 100644 index 0000000000..96eba9e6a3 --- /dev/null +++ b/library/cpp/threading/future/core/fwd.h @@ -0,0 +1,11 @@ +#pragma once + +namespace NThreading { + struct TFutureException; + + template <typename T> + class TFuture; + + template <typename T> + class TPromise; +} diff --git a/library/cpp/threading/future/future.h b/library/cpp/threading/future/future.h new file mode 100644 index 0000000000..35db9abbe2 --- /dev/null +++ b/library/cpp/threading/future/future.h @@ -0,0 +1,4 @@ +#pragma once + +#include "core/future.h" +#include "wait/wait.h" diff --git a/library/cpp/threading/future/future_mt_ut.cpp b/library/cpp/threading/future/future_mt_ut.cpp new file mode 100644 index 0000000000..4f390866c1 --- /dev/null +++ b/library/cpp/threading/future/future_mt_ut.cpp @@ -0,0 +1,215 @@ +#include "future.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/noncopyable.h> +#include <util/generic/xrange.h> +#include <util/thread/pool.h> + +#include <atomic> +#include <exception> + +using NThreading::NewPromise; +using NThreading::TFuture; +using NThreading::TPromise; +using NThreading::TWaitPolicy; + +namespace { + // Wait* implementation without optimizations, to test TWaitGroup better + template <class WaitPolicy, class TContainer> + TFuture<void> WaitNoOpt(const TContainer& futures) { + NThreading::TWaitGroup<WaitPolicy> wg; + for (const auto& fut : futures) { + wg.Add(fut); + } + + return std::move(wg).Finish(); + } + + class TRelaxedBarrier { + public: + explicit TRelaxedBarrier(i64 size) + : Waiting_{size} { + } + + void Arrive() { + // barrier is not for synchronization, just to ensure good timings, so + // std::memory_order_relaxed is enough + Waiting_.fetch_add(-1, std::memory_order_relaxed); + + while (Waiting_.load(std::memory_order_relaxed)) { + } + + Y_ASSERT(Waiting_.load(std::memory_order_relaxed) >= 0); + } + + private: + std::atomic<i64> Waiting_; + }; + + THolder<TThreadPool> MakePool() { + auto pool = MakeHolder<TThreadPool>(TThreadPool::TParams{}.SetBlocking(false).SetCatching(false)); + pool->Start(8); + return pool; + } + + template <class T> + TVector<TFuture<T>> ToFutures(const TVector<TPromise<T>>& promises) { + TVector<TFuture<void>> futures; + + for (auto&& p : promises) { + futures.emplace_back(p); + } + + return futures; + } + + struct TStateSnapshot { + i64 Started = -1; + i64 StartedException = -1; + const TVector<TFuture<void>>* Futures = nullptr; + }; + + // note: std::memory_order_relaxed should be enough everywhere, because TFuture::SetValue must provide the + // needed synchronization + template <class TFactory> + void RunWaitTest(TFactory global) { + auto pool = MakePool(); + + const auto exception = std::make_exception_ptr(42); + + for (auto numPromises : xrange(1, 5)) { + for (auto loopIter : xrange(1024 * 64)) { + const auto numParticipants = numPromises + 1; + + TRelaxedBarrier barrier{numParticipants}; + + std::atomic<i64> started = 0; + std::atomic<i64> startedException = 0; + std::atomic<i64> completed = 0; + + TVector<TPromise<void>> promises; + for (auto i : xrange(numPromises)) { + Y_UNUSED(i); + promises.push_back(NewPromise()); + } + + const auto futures = ToFutures(promises); + + auto snapshotter = [&] { + return TStateSnapshot{ + .Started = started.load(std::memory_order_relaxed), + .StartedException = startedException.load(std::memory_order_relaxed), + .Futures = &futures, + }; + }; + + for (auto i : xrange(numPromises)) { + pool->SafeAddFunc([&, i] { + barrier.Arrive(); + + // subscribers must observe effects of this operation + // after .Set* + started.fetch_add(1, std::memory_order_relaxed); + + if ((loopIter % 4 == 0) && i == 0) { + startedException.fetch_add(1, std::memory_order_relaxed); + promises[i].SetException(exception); + } else { + promises[i].SetValue(); + } + + completed.fetch_add(1, std::memory_order_release); + }); + } + + pool->SafeAddFunc([&] { + auto local = global(snapshotter); + + barrier.Arrive(); + + local(); + + completed.fetch_add(1, std::memory_order_release); + }); + + while (completed.load() != numParticipants) { + } + } + } + } +} + +Y_UNIT_TEST_SUITE(TFutureMultiThreadedTest) { + Y_UNIT_TEST(WaitAll) { + RunWaitTest( + [](auto snapshotter) { + return [=]() { + auto* futures = snapshotter().Futures; + + auto all = WaitNoOpt<TWaitPolicy::TAll>(*futures); + + // tests safety part + all.Subscribe([=] (auto&& all) { + TStateSnapshot snap = snapshotter(); + + // value safety: all is set => every future is set + UNIT_ASSERT(all.HasValue() <= ((snap.Started == (i64)snap.Futures->size()) && !snap.StartedException)); + + // safety for hasException: all is set => every future is set and some has exception + UNIT_ASSERT(all.HasException() <= ((snap.Started == (i64)snap.Futures->size()) && snap.StartedException > 0)); + }); + + // test liveness + all.Wait(); + }; + }); + } + + Y_UNIT_TEST(WaitAny) { + RunWaitTest( + [](auto snapshotter) { + return [=]() { + auto* futures = snapshotter().Futures; + + auto any = WaitNoOpt<TWaitPolicy::TAny>(*futures); + + // safety: any is ready => some f is ready + any.Subscribe([=](auto&&) { + UNIT_ASSERT(snapshotter().Started > 0); + }); + + // do we need better multithreaded liveness tests? + any.Wait(); + }; + }); + } + + Y_UNIT_TEST(WaitExceptionOrAll) { + RunWaitTest( + [](auto snapshotter) { + return [=]() { + NThreading::WaitExceptionOrAll(*snapshotter().Futures) + .Subscribe([=](auto&&) { + auto* futures = snapshotter().Futures; + + auto exceptionOrAll = WaitNoOpt<TWaitPolicy::TExceptionOrAll>(*futures); + + exceptionOrAll.Subscribe([snapshotter](auto&& exceptionOrAll) { + TStateSnapshot snap = snapshotter(); + + // safety for hasException: exceptionOrAll has exception => some has exception + UNIT_ASSERT(exceptionOrAll.HasException() ? snap.StartedException > 0 : true); + + // value safety: exceptionOrAll has value => all have value + UNIT_ASSERT(exceptionOrAll.HasValue() == ((snap.Started == (i64)snap.Futures->size()) && !snap.StartedException)); + }); + + // do we need better multithreaded liveness tests? + exceptionOrAll.Wait(); + }); + }; + }); + } +} + diff --git a/library/cpp/threading/future/future_ut.cpp b/library/cpp/threading/future/future_ut.cpp new file mode 100644 index 0000000000..05950a568d --- /dev/null +++ b/library/cpp/threading/future/future_ut.cpp @@ -0,0 +1,640 @@ +#include "future.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <list> +#include <type_traits> + +namespace NThreading { + +namespace { + + class TCopyCounter { + public: + TCopyCounter(size_t* numCopies) + : NumCopies(numCopies) + {} + + TCopyCounter(const TCopyCounter& that) + : NumCopies(that.NumCopies) + { + ++*NumCopies; + } + + TCopyCounter& operator=(const TCopyCounter& that) { + NumCopies = that.NumCopies; + ++*NumCopies; + return *this; + } + + TCopyCounter(TCopyCounter&& that) = default; + + TCopyCounter& operator=(TCopyCounter&& that) = default; + + private: + size_t* NumCopies = nullptr; + }; + + template <typename T> + auto MakePromise() { + if constexpr (std::is_same_v<T, void>) { + return NewPromise(); + } + return NewPromise<T>(); + } + + + template <typename T> + void TestFutureStateId() { + TFuture<T> empty; + UNIT_ASSERT(!empty.StateId().Defined()); + auto promise1 = MakePromise<T>(); + auto future11 = promise1.GetFuture(); + UNIT_ASSERT(future11.StateId().Defined()); + auto future12 = promise1.GetFuture(); + UNIT_ASSERT_EQUAL(future11.StateId(), future11.StateId()); // same result for subsequent invocations + UNIT_ASSERT_EQUAL(future11.StateId(), future12.StateId()); // same result for different futures with the same state + auto promise2 = MakePromise<T>(); + auto future2 = promise2.GetFuture(); + UNIT_ASSERT(future2.StateId().Defined()); + UNIT_ASSERT_UNEQUAL(future11.StateId(), future2.StateId()); // different results for futures with different states + } + +} + + //////////////////////////////////////////////////////////////////////////////// + + Y_UNIT_TEST_SUITE(TFutureTest) { + Y_UNIT_TEST(ShouldInitiallyHasNoValue) { + TPromise<int> promise; + UNIT_ASSERT(!promise.HasValue()); + + promise = NewPromise<int>(); + UNIT_ASSERT(!promise.HasValue()); + + TFuture<int> future; + UNIT_ASSERT(!future.HasValue()); + + future = promise.GetFuture(); + UNIT_ASSERT(!future.HasValue()); + } + + Y_UNIT_TEST(ShouldInitiallyHasNoValueVoid) { + TPromise<void> promise; + UNIT_ASSERT(!promise.HasValue()); + + promise = NewPromise(); + UNIT_ASSERT(!promise.HasValue()); + + TFuture<void> future; + UNIT_ASSERT(!future.HasValue()); + + future = promise.GetFuture(); + UNIT_ASSERT(!future.HasValue()); + } + + Y_UNIT_TEST(ShouldStoreValue) { + TPromise<int> promise = NewPromise<int>(); + promise.SetValue(123); + UNIT_ASSERT(promise.HasValue()); + UNIT_ASSERT_EQUAL(promise.GetValue(), 123); + + TFuture<int> future = promise.GetFuture(); + UNIT_ASSERT(future.HasValue()); + UNIT_ASSERT_EQUAL(future.GetValue(), 123); + + future = MakeFuture(345); + UNIT_ASSERT(future.HasValue()); + UNIT_ASSERT_EQUAL(future.GetValue(), 345); + } + + Y_UNIT_TEST(ShouldStoreValueVoid) { + TPromise<void> promise = NewPromise(); + promise.SetValue(); + UNIT_ASSERT(promise.HasValue()); + + TFuture<void> future = promise.GetFuture(); + UNIT_ASSERT(future.HasValue()); + + future = MakeFuture(); + UNIT_ASSERT(future.HasValue()); + } + + struct TTestCallback { + int Value; + + TTestCallback(int value) + : Value(value) + { + } + + void Callback(const TFuture<int>& future) { + Value += future.GetValue(); + } + + int Func(const TFuture<int>& future) { + return (Value += future.GetValue()); + } + + void VoidFunc(const TFuture<int>& future) { + future.GetValue(); + } + + TFuture<int> FutureFunc(const TFuture<int>& future) { + return MakeFuture(Value += future.GetValue()); + } + + TPromise<void> Signal = NewPromise(); + TFuture<void> FutureVoidFunc(const TFuture<int>& future) { + future.GetValue(); + return Signal; + } + }; + + Y_UNIT_TEST(ShouldInvokeCallback) { + TPromise<int> promise = NewPromise<int>(); + + TTestCallback callback(123); + TFuture<int> future = promise.GetFuture() + .Subscribe([&](const TFuture<int>& theFuture) { return callback.Callback(theFuture); }); + + promise.SetValue(456); + UNIT_ASSERT_EQUAL(future.GetValue(), 456); + UNIT_ASSERT_EQUAL(callback.Value, 123 + 456); + } + + Y_UNIT_TEST(ShouldApplyFunc) { + TPromise<int> promise = NewPromise<int>(); + + TTestCallback callback(123); + TFuture<int> future = promise.GetFuture() + .Apply([&](const auto& theFuture) { return callback.Func(theFuture); }); + + promise.SetValue(456); + UNIT_ASSERT_EQUAL(future.GetValue(), 123 + 456); + UNIT_ASSERT_EQUAL(callback.Value, 123 + 456); + } + + Y_UNIT_TEST(ShouldApplyVoidFunc) { + TPromise<int> promise = NewPromise<int>(); + + TTestCallback callback(123); + TFuture<void> future = promise.GetFuture() + .Apply([&](const auto& theFuture) { return callback.VoidFunc(theFuture); }); + + promise.SetValue(456); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldApplyFutureFunc) { + TPromise<int> promise = NewPromise<int>(); + + TTestCallback callback(123); + TFuture<int> future = promise.GetFuture() + .Apply([&](const auto& theFuture) { return callback.FutureFunc(theFuture); }); + + promise.SetValue(456); + UNIT_ASSERT_EQUAL(future.GetValue(), 123 + 456); + UNIT_ASSERT_EQUAL(callback.Value, 123 + 456); + } + + Y_UNIT_TEST(ShouldApplyFutureVoidFunc) { + TPromise<int> promise = NewPromise<int>(); + + TTestCallback callback(123); + TFuture<void> future = promise.GetFuture() + .Apply([&](const auto& theFuture) { return callback.FutureVoidFunc(theFuture); }); + + promise.SetValue(456); + UNIT_ASSERT(!future.HasValue()); + + callback.Signal.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldIgnoreResultIfAsked) { + TPromise<int> promise = NewPromise<int>(); + + TTestCallback callback(123); + TFuture<int> future = promise.GetFuture().IgnoreResult().Return(42); + + promise.SetValue(456); + UNIT_ASSERT_EQUAL(future.GetValue(), 42); + } + + class TCustomException: public yexception { + }; + + Y_UNIT_TEST(ShouldRethrowException) { + TPromise<int> promise = NewPromise<int>(); + try { + ythrow TCustomException(); + } catch (...) { + promise.SetException(std::current_exception()); + } + + UNIT_ASSERT(!promise.HasValue()); + UNIT_ASSERT(promise.HasException()); + UNIT_ASSERT_EXCEPTION(promise.GetValue(), TCustomException); + UNIT_ASSERT_EXCEPTION(promise.TryRethrow(), TCustomException); + } + + Y_UNIT_TEST(ShouldRethrowCallbackException) { + TPromise<int> promise = NewPromise<int>(); + TFuture<int> future = promise.GetFuture(); + future.Subscribe([](const TFuture<int>&) { + throw TCustomException(); + }); + + UNIT_ASSERT_EXCEPTION(promise.SetValue(123), TCustomException); + } + + Y_UNIT_TEST(ShouldRethrowCallbackExceptionIgnoreResult) { + TPromise<int> promise = NewPromise<int>(); + TFuture<void> future = promise.GetFuture().IgnoreResult(); + future.Subscribe([](const TFuture<void>&) { + throw TCustomException(); + }); + + UNIT_ASSERT_EXCEPTION(promise.SetValue(123), TCustomException); + } + + + Y_UNIT_TEST(ShouldWaitExceptionOrAll) { + TPromise<void> promise1 = NewPromise(); + TPromise<void> promise2 = NewPromise(); + + TFuture<void> future = WaitExceptionOrAll(promise1, promise2); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(); + UNIT_ASSERT(!future.HasValue()); + + promise2.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitExceptionOrAllVector) { + TPromise<void> promise1 = NewPromise(); + TPromise<void> promise2 = NewPromise(); + + TVector<TFuture<void>> promises; + promises.push_back(promise1); + promises.push_back(promise2); + + TFuture<void> future = WaitExceptionOrAll(promises); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(); + UNIT_ASSERT(!future.HasValue()); + + promise2.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitExceptionOrAllVectorWithValueType) { + TPromise<int> promise1 = NewPromise<int>(); + TPromise<int> promise2 = NewPromise<int>(); + + TVector<TFuture<int>> promises; + promises.push_back(promise1); + promises.push_back(promise2); + + TFuture<void> future = WaitExceptionOrAll(promises); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(0); + UNIT_ASSERT(!future.HasValue()); + + promise2.SetValue(0); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitExceptionOrAllList) { + TPromise<void> promise1 = NewPromise(); + TPromise<void> promise2 = NewPromise(); + + std::list<TFuture<void>> promises; + promises.push_back(promise1); + promises.push_back(promise2); + + TFuture<void> future = WaitExceptionOrAll(promises); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(); + UNIT_ASSERT(!future.HasValue()); + + promise2.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitExceptionOrAllVectorEmpty) { + TVector<TFuture<void>> promises; + + TFuture<void> future = WaitExceptionOrAll(promises); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitAnyVector) { + TPromise<void> promise1 = NewPromise(); + TPromise<void> promise2 = NewPromise(); + + TVector<TFuture<void>> promises; + promises.push_back(promise1); + promises.push_back(promise2); + + TFuture<void> future = WaitAny(promises); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(); + UNIT_ASSERT(future.HasValue()); + + promise2.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + + Y_UNIT_TEST(ShouldWaitAnyVectorWithValueType) { + TPromise<int> promise1 = NewPromise<int>(); + TPromise<int> promise2 = NewPromise<int>(); + + TVector<TFuture<int>> promises; + promises.push_back(promise1); + promises.push_back(promise2); + + TFuture<void> future = WaitAny(promises); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(0); + UNIT_ASSERT(future.HasValue()); + + promise2.SetValue(0); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitAnyList) { + TPromise<void> promise1 = NewPromise(); + TPromise<void> promise2 = NewPromise(); + + std::list<TFuture<void>> promises; + promises.push_back(promise1); + promises.push_back(promise2); + + TFuture<void> future = WaitAny(promises); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(); + UNIT_ASSERT(future.HasValue()); + + promise2.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitAnyVectorEmpty) { + TVector<TFuture<void>> promises; + + TFuture<void> future = WaitAny(promises); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldWaitAny) { + TPromise<void> promise1 = NewPromise(); + TPromise<void> promise2 = NewPromise(); + + TFuture<void> future = WaitAny(promise1, promise2); + UNIT_ASSERT(!future.HasValue()); + + promise1.SetValue(); + UNIT_ASSERT(future.HasValue()); + + promise2.SetValue(); + UNIT_ASSERT(future.HasValue()); + } + + Y_UNIT_TEST(ShouldStoreTypesWithoutDefaultConstructor) { + // compileability test + struct TRec { + explicit TRec(int) { + } + }; + + auto promise = NewPromise<TRec>(); + promise.SetValue(TRec(1)); + + auto future = MakeFuture(TRec(1)); + const auto& rec = future.GetValue(); + Y_UNUSED(rec); + } + + Y_UNIT_TEST(ShouldStoreMovableTypes) { + // compileability test + struct TRec : TMoveOnly { + explicit TRec(int) { + } + }; + + auto promise = NewPromise<TRec>(); + promise.SetValue(TRec(1)); + + auto future = MakeFuture(TRec(1)); + const auto& rec = future.GetValue(); + Y_UNUSED(rec); + } + + Y_UNIT_TEST(ShouldMoveMovableTypes) { + // compileability test + struct TRec : TMoveOnly { + explicit TRec(int) { + } + }; + + auto promise = NewPromise<TRec>(); + promise.SetValue(TRec(1)); + + auto future = MakeFuture(TRec(1)); + auto rec = future.ExtractValue(); + Y_UNUSED(rec); + } + + Y_UNIT_TEST(ShouldNotExtractAfterGet) { + TPromise<int> promise = NewPromise<int>(); + promise.SetValue(123); + UNIT_ASSERT(promise.HasValue()); + UNIT_ASSERT_EQUAL(promise.GetValue(), 123); + UNIT_CHECK_GENERATED_EXCEPTION(promise.ExtractValue(), TFutureException); + } + + Y_UNIT_TEST(ShouldNotGetAfterExtract) { + TPromise<int> promise = NewPromise<int>(); + promise.SetValue(123); + UNIT_ASSERT(promise.HasValue()); + UNIT_ASSERT_EQUAL(promise.ExtractValue(), 123); + UNIT_CHECK_GENERATED_EXCEPTION(promise.GetValue(), TFutureException); + } + + Y_UNIT_TEST(ShouldNotExtractAfterExtract) { + TPromise<int> promise = NewPromise<int>(); + promise.SetValue(123); + UNIT_ASSERT(promise.HasValue()); + UNIT_ASSERT_EQUAL(promise.ExtractValue(), 123); + UNIT_CHECK_GENERATED_EXCEPTION(promise.ExtractValue(), TFutureException); + } + + Y_UNIT_TEST(ShouldNotExtractFromSharedDefault) { + UNIT_CHECK_GENERATED_EXCEPTION(MakeFuture<int>().ExtractValue(), TFutureException); + + struct TStorage { + TString String = TString(100, 'a'); + }; + try { + TString s = MakeFuture<TStorage>().ExtractValue().String; + Y_UNUSED(s); + } catch (TFutureException) { + // pass + } + UNIT_ASSERT_VALUES_EQUAL(MakeFuture<TStorage>().GetValue().String, TString(100, 'a')); + } + + Y_UNIT_TEST(HandlingRepetitiveSet) { + TPromise<int> promise = NewPromise<int>(); + promise.SetValue(42); + UNIT_CHECK_GENERATED_EXCEPTION(promise.SetValue(42), TFutureException); + } + + Y_UNIT_TEST(HandlingRepetitiveTrySet) { + TPromise<int> promise = NewPromise<int>(); + UNIT_ASSERT(promise.TrySetValue(42)); + UNIT_ASSERT(!promise.TrySetValue(42)); + } + + Y_UNIT_TEST(HandlingRepetitiveSetException) { + TPromise<int> promise = NewPromise<int>(); + promise.SetException("test"); + UNIT_CHECK_GENERATED_EXCEPTION(promise.SetException("test"), TFutureException); + } + + Y_UNIT_TEST(HandlingRepetitiveTrySetException) { + TPromise<int> promise = NewPromise<int>(); + UNIT_ASSERT(promise.TrySetException(std::make_exception_ptr("test"))); + UNIT_ASSERT(!promise.TrySetException(std::make_exception_ptr("test"))); + } + + Y_UNIT_TEST(ShouldAllowToMakeFutureWithException) + { + auto future1 = MakeErrorFuture<void>(std::make_exception_ptr(TFutureException())); + UNIT_ASSERT(future1.HasException()); + UNIT_CHECK_GENERATED_EXCEPTION(future1.GetValue(), TFutureException); + + auto future2 = MakeErrorFuture<int>(std::make_exception_ptr(TFutureException())); + UNIT_ASSERT(future2.HasException()); + UNIT_CHECK_GENERATED_EXCEPTION(future2.GetValue(), TFutureException); + + auto future3 = MakeFuture<std::exception_ptr>(std::make_exception_ptr(TFutureException())); + UNIT_ASSERT(future3.HasValue()); + UNIT_CHECK_GENERATED_NO_EXCEPTION(future3.GetValue(), TFutureException); + + auto future4 = MakeFuture<std::unique_ptr<int>>(nullptr); + UNIT_ASSERT(future4.HasValue()); + UNIT_CHECK_GENERATED_NO_EXCEPTION(future4.GetValue(), TFutureException); + } + + Y_UNIT_TEST(WaitAllowsExtract) { + auto future = MakeFuture<int>(42); + TVector vec{future, future, future}; + WaitExceptionOrAll(vec).GetValue(); + WaitAny(vec).GetValue(); + + UNIT_ASSERT_EQUAL(future.ExtractValue(), 42); + } + + Y_UNIT_TEST(IgnoreAllowsExtract) { + auto future = MakeFuture<int>(42); + future.IgnoreResult().GetValue(); + + UNIT_ASSERT_EQUAL(future.ExtractValue(), 42); + } + + Y_UNIT_TEST(WaitExceptionOrAllException) { + auto promise1 = NewPromise(); + auto promise2 = NewPromise(); + auto future1 = promise1.GetFuture(); + auto future2 = promise2.GetFuture(); + auto wait = WaitExceptionOrAll(future1, future2); + promise2.SetException("foo-exception"); + wait.Wait(); + UNIT_ASSERT(future2.HasException()); + UNIT_ASSERT(!future1.HasValue() && !future1.HasException()); + } + + Y_UNIT_TEST(WaitAllException) { + auto promise1 = NewPromise(); + auto promise2 = NewPromise(); + auto future1 = promise1.GetFuture(); + auto future2 = promise2.GetFuture(); + auto wait = WaitAll(future1, future2); + promise2.SetException("foo-exception"); + UNIT_ASSERT(!wait.HasValue() && !wait.HasException()); + promise1.SetValue(); + UNIT_ASSERT_EXCEPTION_CONTAINS(wait.GetValueSync(), yexception, "foo-exception"); + } + + Y_UNIT_TEST(FutureStateId) { + TestFutureStateId<void>(); + TestFutureStateId<int>(); + } + + template <typename T> + void TestApplyNoRvalueCopyImpl() { + size_t numCopies = 0; + TCopyCounter copyCounter(&numCopies); + + auto promise = MakePromise<T>(); + + const auto future = promise.GetFuture().Apply( + [copyCounter = std::move(copyCounter)] (const auto&) {} + ); + + if constexpr (std::is_same_v<T, void>) { + promise.SetValue(); + } else { + promise.SetValue(T()); + } + + future.GetValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(numCopies, 0); + } + + Y_UNIT_TEST(ApplyNoRvalueCopy) { + TestApplyNoRvalueCopyImpl<void>(); + TestApplyNoRvalueCopyImpl<int>(); + } + + template <typename T> + void TestApplyLvalueCopyImpl() { + size_t numCopies = 0; + TCopyCounter copyCounter(&numCopies); + + auto promise = MakePromise<T>(); + + auto func = [copyCounter = std::move(copyCounter)] (const auto&) {}; + const auto future = promise.GetFuture().Apply(func); + + if constexpr (std::is_same_v<T, void>) { + promise.SetValue(); + } else { + promise.SetValue(T()); + } + + future.GetValueSync(); + + UNIT_ASSERT_VALUES_EQUAL(numCopies, 1); + } + + Y_UNIT_TEST(ApplyLvalueCopy) { + TestApplyLvalueCopyImpl<void>(); + TestApplyLvalueCopyImpl<int>(); + } + } + +} diff --git a/library/cpp/threading/future/fwd.cpp b/library/cpp/threading/future/fwd.cpp new file mode 100644 index 0000000000..4214b6df83 --- /dev/null +++ b/library/cpp/threading/future/fwd.cpp @@ -0,0 +1 @@ +#include "fwd.h" diff --git a/library/cpp/threading/future/fwd.h b/library/cpp/threading/future/fwd.h new file mode 100644 index 0000000000..0cd25dd288 --- /dev/null +++ b/library/cpp/threading/future/fwd.h @@ -0,0 +1,8 @@ +#pragma once + +#include "core/fwd.h" + +namespace NThreading { + template <typename TR = void, bool IgnoreException = false> + class TLegacyFuture; +} diff --git a/library/cpp/threading/future/legacy_future.h b/library/cpp/threading/future/legacy_future.h new file mode 100644 index 0000000000..6f1eabad73 --- /dev/null +++ b/library/cpp/threading/future/legacy_future.h @@ -0,0 +1,83 @@ +#pragma once + +#include "fwd.h" +#include "future.h" + +#include <util/thread/factory.h> + +#include <functional> + +namespace NThreading { + template <typename TR, bool IgnoreException> + class TLegacyFuture: public IThreadFactory::IThreadAble, TNonCopyable { + public: + typedef TR(TFunctionSignature)(); + using TFunctionObjectType = std::function<TFunctionSignature>; + using TResult = typename TFunctionObjectType::result_type; + + private: + TFunctionObjectType Func_; + TPromise<TResult> Result_; + THolder<IThreadFactory::IThread> Thread_; + + public: + inline TLegacyFuture(const TFunctionObjectType func, IThreadFactory* pool = SystemThreadFactory()) + : Func_(func) + , Result_(NewPromise<TResult>()) + , Thread_(pool->Run(this)) + { + } + + inline ~TLegacyFuture() override { + this->Join(); + } + + inline TResult Get() { + this->Join(); + return Result_.GetValue(); + } + + private: + inline void Join() { + if (Thread_) { + Thread_->Join(); + Thread_.Destroy(); + } + } + + template <typename Result, bool IgnoreException_> + struct TExecutor { + static void SetPromise(TPromise<Result>& promise, const TFunctionObjectType& func) { + if (IgnoreException_) { + try { + promise.SetValue(func()); + } catch (...) { + } + } else { + promise.SetValue(func()); + } + } + }; + + template <bool IgnoreException_> + struct TExecutor<void, IgnoreException_> { + static void SetPromise(TPromise<void>& promise, const TFunctionObjectType& func) { + if (IgnoreException_) { + try { + func(); + promise.SetValue(); + } catch (...) { + } + } else { + func(); + promise.SetValue(); + } + } + }; + + void DoExecute() override { + TExecutor<TResult, IgnoreException>::SetPromise(Result_, Func_); + } + }; + +} diff --git a/library/cpp/threading/future/legacy_future_ut.cpp b/library/cpp/threading/future/legacy_future_ut.cpp new file mode 100644 index 0000000000..ff63db1725 --- /dev/null +++ b/library/cpp/threading/future/legacy_future_ut.cpp @@ -0,0 +1,73 @@ +#include "legacy_future.h" + +#include <library/cpp/testing/unittest/registar.h> + +namespace NThreading { + Y_UNIT_TEST_SUITE(TLegacyFutureTest) { + int intf() { + return 17; + } + + Y_UNIT_TEST(TestIntFunction) { + TLegacyFuture<int> f((&intf)); + UNIT_ASSERT_VALUES_EQUAL(17, f.Get()); + } + + static int r; + + void voidf() { + r = 18; + } + + Y_UNIT_TEST(TestVoidFunction) { + r = 0; + TLegacyFuture<> f((&voidf)); + f.Get(); + UNIT_ASSERT_VALUES_EQUAL(18, r); + } + + struct TSampleClass { + int mValue; + + TSampleClass(int value) + : mValue(value) + { + } + + int Calc() { + return mValue + 1; + } + }; + + Y_UNIT_TEST(TestMethod) { + TLegacyFuture<int> f11(std::bind(&TSampleClass::Calc, TSampleClass(3))); + UNIT_ASSERT_VALUES_EQUAL(4, f11.Get()); + + TLegacyFuture<int> f12(std::bind(&TSampleClass::Calc, TSampleClass(3)), SystemThreadFactory()); + UNIT_ASSERT_VALUES_EQUAL(4, f12.Get()); + + TSampleClass c(5); + + TLegacyFuture<int> f21(std::bind(&TSampleClass::Calc, std::ref(c))); + UNIT_ASSERT_VALUES_EQUAL(6, f21.Get()); + + TLegacyFuture<int> f22(std::bind(&TSampleClass::Calc, std::ref(c)), SystemThreadFactory()); + UNIT_ASSERT_VALUES_EQUAL(6, f22.Get()); + } + + struct TSomeThreadPool: public IThreadFactory {}; + + Y_UNIT_TEST(TestFunction) { + std::function<int()> f((&intf)); + + UNIT_ASSERT_VALUES_EQUAL(17, TLegacyFuture<int>(f).Get()); + UNIT_ASSERT_VALUES_EQUAL(17, TLegacyFuture<int>(f, SystemThreadFactory()).Get()); + + if (false) { + TSomeThreadPool* q = nullptr; + TLegacyFuture<int>(f, q); // just check compiles, do not start + } + } + } + +} diff --git a/library/cpp/threading/future/mt_ut/ya.make b/library/cpp/threading/future/mt_ut/ya.make new file mode 100644 index 0000000000..288fe7b6bc --- /dev/null +++ b/library/cpp/threading/future/mt_ut/ya.make @@ -0,0 +1,20 @@ +UNITTEST_FOR(library/cpp/threading/future) + +OWNER( + g:util +) + +SRCS( + future_mt_ut.cpp +) + +IF(NOT SANITIZER_TYPE) +SIZE(SMALL) + +ELSE() +SIZE(MEDIUM) + +ENDIF() + + +END() diff --git a/library/cpp/threading/future/perf/main.cpp b/library/cpp/threading/future/perf/main.cpp new file mode 100644 index 0000000000..5a0690af47 --- /dev/null +++ b/library/cpp/threading/future/perf/main.cpp @@ -0,0 +1,50 @@ +#include <library/cpp/testing/benchmark/bench.h> +#include <library/cpp/threading/future/future.h> + +#include <util/generic/string.h> +#include <util/generic/xrange.h> + +using namespace NThreading; + +template <typename T> +void TestAllocPromise(const NBench::NCpu::TParams& iface) { + for (const auto it : xrange(iface.Iterations())) { + Y_UNUSED(it); + Y_DO_NOT_OPTIMIZE_AWAY(NewPromise<T>()); + } +} + +template <typename T> +TPromise<T> SetPromise(T value) { + auto promise = NewPromise<T>(); + promise.SetValue(value); + return promise; +} + +template <typename T> +void TestSetPromise(const NBench::NCpu::TParams& iface, T value) { + for (const auto it : xrange(iface.Iterations())) { + Y_UNUSED(it); + Y_DO_NOT_OPTIMIZE_AWAY(SetPromise(value)); + } +} + +Y_CPU_BENCHMARK(AllocPromiseVoid, iface) { + TestAllocPromise<void>(iface); +} + +Y_CPU_BENCHMARK(AllocPromiseUI64, iface) { + TestAllocPromise<ui64>(iface); +} + +Y_CPU_BENCHMARK(AllocPromiseStroka, iface) { + TestAllocPromise<TString>(iface); +} + +Y_CPU_BENCHMARK(SetPromiseUI64, iface) { + TestSetPromise<ui64>(iface, 1234567890ull); +} + +Y_CPU_BENCHMARK(SetPromiseStroka, iface) { + TestSetPromise<TString>(iface, "test test test"); +} diff --git a/library/cpp/threading/future/perf/ya.make b/library/cpp/threading/future/perf/ya.make new file mode 100644 index 0000000000..943d585d4b --- /dev/null +++ b/library/cpp/threading/future/perf/ya.make @@ -0,0 +1,16 @@ +Y_BENCHMARK(library-threading-future-perf) + +OWNER( + g:rtmr + ishfb +) + +SRCS( + main.cpp +) + +PEERDIR( + library/cpp/threading/future +) + +END() diff --git a/library/cpp/threading/future/subscription/README.md b/library/cpp/threading/future/subscription/README.md new file mode 100644 index 0000000000..62c7e1303e --- /dev/null +++ b/library/cpp/threading/future/subscription/README.md @@ -0,0 +1,104 @@ +Subscriptions manager and wait primitives library +================================================= + +Wait primitives +--------------- + +All wait primitives are futures those being signaled when some or all of theirs dependencies are signaled. +Wait privimitives could be constructed either from an initializer_list or from a standard container of futures. + +1. WaitAll is signaled when all its dependencies are signaled: + + ```C++ + #include <library/cpp/threading/subscriptions/wait_all.h> + + auto w = NWait::WaitAll({ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for all futures + ``` + +2. WaitAny is signaled when any of its dependencies is signaled: + + ```C++ + #include <library/cpp/threading/subscriptions/wait_any.h> + + auto w = NWait::WaitAny(TVector<TFuture<T>>{ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for any future + ``` + +3. WaitAllOrException is signaled when all its dependencies are signaled with values or any dependency is signaled with an exception: + + ```C++ + #include <library/cpp/threading/subscriptions/wait_all_or_exception.h> + + auto w = NWait::WaitAllOrException(TVector<TFuture<T>>{ future1, future2, ..., futureN }); + ... + w.Wait(); // wait for all values or for an exception + ``` + +Subscriptions manager +--------------------- + +The subscription manager can manage multiple links beetween futures and callbacks. Multiple managed subscriptions to a single future shares just a single underlying subscription to the future. That allows dynamic creation and deletion of subscriptions and efficient implementation of different wait primitives. +The subscription manager could be used in the following way: + +1. Subscribe to a single future: + + ```C++ + #include <library/cpp/threading/subscriptions/subscription.h> + + TFuture<int> LongOperation(); + + ... + auto future = LongRunnigOperation(); + auto m = MakeSubsriptionManager<int>(); + auto id = m->Subscribe(future, [](TFuture<int> const& f) { + try { + auto value = f.GetValue(); + ... + } catch (...) { + ... // handle exception + } + }); + if (id.has_value()) { + ... // Callback will run asynchronously + } else { + ... // Future has been signaled already. The callback has been invoked synchronously + } + ``` + + Note that a callback could be invoked synchronously during a Subscribe call. In this case the returned optional will have no value. + +2. Unsubscribe from a single future: + + ```C++ + // id holds the subscription id from a previous Subscribe call + m->Unsubscribe(id.value()); + ``` + + There is no need to call Unsubscribe if the callback has been called. In this case Unsubscribe will do nothing. And it is safe to call Unsubscribe with the same id multiple times. + +3. Subscribe a single callback to multiple futures: + + ```C++ + auto ids = m->Subscribe({ future1, future2, ..., futureN }, [](auto&& f) { ... }); + ... + ``` + + Futures could be passed to Subscribe method either via an initializer_list or via a standard container like vector or list. Subscribe method accept an optional boolean parameter revertOnSignaled. If the parameter is false (default) then all suscriptions will be performed regardless of the futures states and the returned vector will have a subscription id for each future (even if callback has been executed synchronously for some futures). Otherwise the method will stop on the first signaled future (the callback will be synchronously called for it), no suscriptions will be created and an empty vector will be returned. + +4. Unsubscribe multiple subscriptions: + + ```C++ + // ids is the vector or subscription ids + m->Unsubscribe(ids); + ``` + + The vector of IDs could be a result of a previous Subscribe call or an arbitrary set of IDs of previously created subscriptions. + +5. If you do not want to instantiate a new instance of the subscription manager it is possible to use the default instance: + + ```C++ + auto m = TSubscriptionManager<T>::Default(); + ``` diff --git a/library/cpp/threading/future/subscription/subscription-inl.h b/library/cpp/threading/future/subscription/subscription-inl.h new file mode 100644 index 0000000000..a45d8999d3 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription-inl.h @@ -0,0 +1,118 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H) +#error "you should never include subscription-inl.h directly" +#endif + +namespace NThreading { + +namespace NPrivate { + +template <typename T> +TFutureStateId CheckedStateId(TFuture<T> const& future) { + auto const id = future.StateId(); + if (id.Defined()) { + return *id; + } + ythrow TFutureException() << "Future state should be initialized"; +} + +} + +template <typename T, typename F, typename TCallbackExecutor> +inline TSubscriptionManager::TSubscription::TSubscription(TFuture<T> future, F&& callback, TCallbackExecutor&& executor) + : Callback( + [future = std::move(future), callback = std::forward<F>(callback), executor = std::forward<TCallbackExecutor>(executor)]() mutable { + executor(std::as_const(future), callback); + }) +{ +} + +template <typename T, typename F, typename TCallbackExecutor> +inline std::optional<TSubscriptionId> TSubscriptionManager::Subscribe(TFuture<T> const& future, F&& callback, TCallbackExecutor&& executor) { + auto stateId = NPrivate::CheckedStateId(future); + with_lock(Lock) { + auto const status = TrySubscribe(future, std::forward<F>(callback), stateId, std::forward<TCallbackExecutor>(executor)); + switch (status) { + case ECallbackStatus::Subscribed: + return TSubscriptionId(stateId, Revision); + case ECallbackStatus::ExecutedSynchronously: + return {}; + default: + Y_FAIL("Unexpected callback status"); + } + } +} + +template <typename TFutures, typename F, typename TCallbackExecutor> +inline TVector<TSubscriptionId> TSubscriptionManager::Subscribe(TFutures const& futures, F&& callback, bool revertOnSignaled + , TCallbackExecutor&& executor) +{ + return SubscribeImpl(futures, std::forward<F>(callback), revertOnSignaled, std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename F, typename TCallbackExecutor> +inline TVector<TSubscriptionId> TSubscriptionManager::Subscribe(std::initializer_list<TFuture<T> const> futures, F&& callback + , bool revertOnSignaled, TCallbackExecutor&& executor) +{ + return SubscribeImpl(futures, std::forward<F>(callback), revertOnSignaled, std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename F, typename TCallbackExecutor> +inline TSubscriptionManager::ECallbackStatus TSubscriptionManager::TrySubscribe(TFuture<T> const& future, F&& callback, TFutureStateId stateId + , TCallbackExecutor&& executor) +{ + TSubscription subscription(future, std::forward<F>(callback), std::forward<TCallbackExecutor>(executor)); + auto const it = Subscriptions.find(stateId); + auto const revision = ++Revision; + if (it == std::end(Subscriptions)) { + auto const success = Subscriptions.emplace(stateId, THashMap<ui64, TSubscription>{ { revision, std::move(subscription) } }).second; + Y_VERIFY(success); + auto self = TSubscriptionManagerPtr(this); + future.Subscribe([self, stateId](TFuture<T> const&) { self->OnCallback(stateId); }); + if (Subscriptions.find(stateId) == std::end(Subscriptions)) { + return ECallbackStatus::ExecutedSynchronously; + } + } else { + Y_VERIFY(it->second.emplace(revision, std::move(subscription)).second); + } + return ECallbackStatus::Subscribed; +} + +template <typename TFutures, typename F, typename TCallbackExecutor> +inline TVector<TSubscriptionId> TSubscriptionManager::SubscribeImpl(TFutures const& futures, F const& callback, bool revertOnSignaled + , TCallbackExecutor const& executor) +{ + TVector<TSubscriptionId> results; + results.reserve(std::size(futures)); + // resolve all state ids to minimize processing under the lock + for (auto const& f : futures) { + results.push_back(TSubscriptionId(NPrivate::CheckedStateId(f), 0)); + } + with_lock(Lock) { + size_t i = 0; + for (auto const& f : futures) { + auto& r = results[i]; + auto const status = TrySubscribe(f, callback, r.StateId(), executor); + switch (status) { + case ECallbackStatus::Subscribed: + break; + case ECallbackStatus::ExecutedSynchronously: + if (revertOnSignaled) { + // revert + results.crop(i); + UnsubscribeImpl(results); + return {}; + } + break; + default: + Y_FAIL("Unexpected callback status"); + } + r.SetSubId(Revision); + ++i; + } + } + return results; +} + +} diff --git a/library/cpp/threading/future/subscription/subscription.cpp b/library/cpp/threading/future/subscription/subscription.cpp new file mode 100644 index 0000000000..a98b4a4f03 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription.cpp @@ -0,0 +1,65 @@ +#include "subscription.h" + +namespace NThreading { + +bool operator==(TSubscriptionId const& l, TSubscriptionId const& r) noexcept { + return l.StateId() == r.StateId() && l.SubId() == r.SubId(); +} + +bool operator!=(TSubscriptionId const& l, TSubscriptionId const& r) noexcept { + return !(l == r); +} + +void TSubscriptionManager::TSubscription::operator()() { + Callback(); +} + +TSubscriptionManagerPtr TSubscriptionManager::NewInstance() { + return new TSubscriptionManager(); +} + +TSubscriptionManagerPtr TSubscriptionManager::Default() { + static auto instance = NewInstance(); + return instance; +} + +void TSubscriptionManager::Unsubscribe(TSubscriptionId id) { + with_lock(Lock) { + UnsubscribeImpl(id); + } +} + +void TSubscriptionManager::Unsubscribe(TVector<TSubscriptionId> const& ids) { + with_lock(Lock) { + UnsubscribeImpl(ids); + } +} + +void TSubscriptionManager::OnCallback(TFutureStateId stateId) noexcept { + THashMap<ui64, TSubscription> subscriptions; + with_lock(Lock) { + auto const it = Subscriptions.find(stateId); + Y_VERIFY(it != Subscriptions.end(), "The callback has been triggered more than once"); + subscriptions.swap(it->second); + Subscriptions.erase(it); + } + for (auto& [_, subscription] : subscriptions) { + subscription(); + } +} + +void TSubscriptionManager::UnsubscribeImpl(TSubscriptionId id) { + auto const it = Subscriptions.find(id.StateId()); + if (it == std::end(Subscriptions)) { + return; + } + it->second.erase(id.SubId()); +} + +void TSubscriptionManager::UnsubscribeImpl(TVector<TSubscriptionId> const& ids) { + for (auto const& id : ids) { + UnsubscribeImpl(id); + } +} + +} diff --git a/library/cpp/threading/future/subscription/subscription.h b/library/cpp/threading/future/subscription/subscription.h new file mode 100644 index 0000000000..afe5eda711 --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription.h @@ -0,0 +1,186 @@ +#pragma once + +#include <library/cpp/threading/future/future.h> + +#include <util/generic/hash.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/system/mutex.h> + +#include <functional> +#include <optional> +#include <utility> + +namespace NThreading { + +namespace NPrivate { + +struct TNoexceptExecutor { + template <typename T, typename F> + void operator()(TFuture<T> const& future, F&& callee) const noexcept { + return callee(future); + } +}; + +} + +class TSubscriptionManager; + +using TSubscriptionManagerPtr = TIntrusivePtr<TSubscriptionManager>; + +//! A subscription id +class TSubscriptionId { +private: + TFutureStateId StateId_; + ui64 SubId_; // Secondary id to make the whole subscription id unique + + friend class TSubscriptionManager; + +public: + TFutureStateId StateId() const noexcept { + return StateId_; + } + + ui64 SubId() const noexcept { + return SubId_; + } + +private: + TSubscriptionId(TFutureStateId stateId, ui64 subId) + : StateId_(stateId) + , SubId_(subId) + { + } + + void SetSubId(ui64 subId) noexcept { + SubId_ = subId; + } +}; + +bool operator==(TSubscriptionId const& l, TSubscriptionId const& r) noexcept; +bool operator!=(TSubscriptionId const& l, TSubscriptionId const& r) noexcept; + +//! The subscription manager manages subscriptions to futures +/** It provides an ability to create (and drop) multiple subscriptions to any future + with just a single underlying subscription per future. + + When a future is signaled all its subscriptions are removed. + So, there no need to call Unsubscribe for subscriptions to already signaled futures. + + Warning!!! For correct operation this class imposes the following requirement to futures/promises: + Any used future must be signaled (value or exception set) before the future state destruction. + Otherwise subscriptions and futures may happen. + Current future design does not provide the required guarantee. But that should be fixed soon. +**/ +class TSubscriptionManager final : public TAtomicRefCount<TSubscriptionManager> { +private: + //! A single subscription + class TSubscription { + private: + std::function<void()> Callback; + + public: + template <typename T, typename F, typename TCallbackExecutor> + TSubscription(TFuture<T> future, F&& callback, TCallbackExecutor&& executor); + + void operator()(); + }; + + struct TFutureStateIdHash { + size_t operator()(TFutureStateId const id) const noexcept { + auto const value = id.Value(); + return ::hash<decltype(value)>()(value); + } + }; + +private: + THashMap<TFutureStateId, THashMap<ui64, TSubscription>, TFutureStateIdHash> Subscriptions; + ui64 Revision = 0; + TMutex Lock; + +public: + //! Creates a new subscription manager instance + static TSubscriptionManagerPtr NewInstance(); + + //! The default subscription manager instance + static TSubscriptionManagerPtr Default(); + + //! Attempts to subscribe the callback to the future + /** Subscription should succeed if the future is not signaled yet. + Otherwise the callback will be called synchronously and nullopt will be returned + + @param future - The future to subscribe to + @param callback - The callback to attach + @return The subscription id on success, nullopt if the future has been signaled already + **/ + template <typename T, typename F, typename TCallbackExecutor = NPrivate::TNoexceptExecutor> + std::optional<TSubscriptionId> Subscribe(TFuture<T> const& future, F&& callback + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Drops the subscription with the given id + /** @param id - The subscription id + **/ + void Unsubscribe(TSubscriptionId id); + + //! Attempts to subscribe the callback to the set of futures + /** @param futures - The futures to subscribe to + @param callback - The callback to attach + @param revertOnSignaled - Shows whether to stop and revert the subscription process if one of the futures is in signaled state + @return The vector of subscription ids if no revert happened or an empty vector otherwise + A subscription id will be valid even if a corresponding future has been signaled + **/ + template <typename TFutures, typename F, typename TCallbackExecutor = NPrivate::TNoexceptExecutor> + TVector<TSubscriptionId> Subscribe(TFutures const& futures, F&& callback, bool revertOnSignaled = false + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Attempts to subscribe the callback to the set of futures + /** @param futures - The futures to subscribe to + @param callback - The callback to attach + @param revertOnSignaled - Shows whether to stop and revert the subscription process if one of the futures is in signaled state + @return The vector of subscription ids if no revert happened or an empty vector otherwise + A subscription id will be valid even if a corresponding future has been signaled + **/ + template <typename T, typename F, typename TCallbackExecutor = NPrivate::TNoexceptExecutor> + TVector<TSubscriptionId> Subscribe(std::initializer_list<TFuture<T> const> futures, F&& callback, bool revertOnSignaled = false + , TCallbackExecutor&& executor = NPrivate::TNoexceptExecutor()); + + //! Drops the subscriptions with the given ids + /** @param ids - The subscription ids + **/ + void Unsubscribe(TVector<TSubscriptionId> const& ids); + +private: + enum class ECallbackStatus { + Subscribed, //! A subscription has been created. The callback will be called asynchronously. + ExecutedSynchronously //! A callback has been called synchronously. No subscription has been created + }; + +private: + //! .ctor + TSubscriptionManager() = default; + //! Processes a callback from a future + void OnCallback(TFutureStateId stateId) noexcept; + //! Attempts to create a subscription + /** This method should be called under the lock + **/ + template <typename T, typename F, typename TCallbackExecutor> + ECallbackStatus TrySubscribe(TFuture<T> const& future, F&& callback, TFutureStateId stateId, TCallbackExecutor&& executor); + //! Batch subscribe implementation + template <typename TFutures, typename F, typename TCallbackExecutor> + TVector<TSubscriptionId> SubscribeImpl(TFutures const& futures, F const& callback, bool revertOnSignaled + , TCallbackExecutor const& executor); + //! Unsubscribe implementation + /** This method should be called under the lock + **/ + void UnsubscribeImpl(TSubscriptionId id); + //! Batch unsubscribe implementation + /** This method should be called under the lock + **/ + void UnsubscribeImpl(TVector<TSubscriptionId> const& ids); +}; + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H +#include "subscription-inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_SUBSCRIPTION_INL_H diff --git a/library/cpp/threading/future/subscription/subscription_ut.cpp b/library/cpp/threading/future/subscription/subscription_ut.cpp new file mode 100644 index 0000000000..d018ea15cc --- /dev/null +++ b/library/cpp/threading/future/subscription/subscription_ut.cpp @@ -0,0 +1,432 @@ +#include "subscription.h" + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TSubscriptionManagerTest) { + + Y_UNIT_TEST(TestSubscribeUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestSubscribeSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto f = MakeFuture(); + + size_t callCount = 0; + auto id = m->Subscribe(f, [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(!id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsignaledAndSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 1); + + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount1, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(id.value()); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 0); + } + + Y_UNIT_TEST(TestSubscribeUnsignaledUnsubscribeSignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestUnsubscribeTwice) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount = 0; + auto id = m->Subscribe(p.GetFuture(), [&callCount](auto&&) { ++callCount; } ); + UNIT_ASSERT(id.has_value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 0); + m->Unsubscribe(id.value()); + UNIT_ASSERT_EQUAL(callCount, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 0); + } + + Y_UNIT_TEST(TestSubscribeOneUnsignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_UNEQUAL(id2.value(), id3.value()); + UNIT_ASSERT_UNEQUAL(id3.value(), id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeOneSignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto f = MakeFuture(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(f, [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(f, [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(!id1.has_value()); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeOneUnsignaledManyTimes) { + auto m = TSubscriptionManager::NewInstance(); + auto p = NewPromise(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + size_t callCount4 = 0; + auto id4 = m->Subscribe(p.GetFuture(), [&callCount4](auto&&) { ++callCount4; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT(id4.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + + m->Unsubscribe(id3.value()); + m->Unsubscribe(id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + + p.SetValue(); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 1); + } + + Y_UNIT_TEST(TestSubscribeManyUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p1.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_UNEQUAL(id2.value(), id3.value()); + UNIT_ASSERT_UNEQUAL(id3.value(), id1.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + + p1.SetValue(33); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p2.SetValue(111); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeManySignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(f1, [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(f2, [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f2, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(!id1.has_value()); + UNIT_ASSERT(!id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeManyMixed) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(f, [&callCount3](auto&&) { ++callCount3; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(!id3.has_value()); + UNIT_ASSERT_UNEQUAL(id1.value(), id2.value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p1.SetValue(45); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 1); + + p2.SetValue(-7); + UNIT_ASSERT_EQUAL(callCount1, 1); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + } + + Y_UNIT_TEST(TestSubscribeUnsubscribeMany) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto p3 = NewPromise<int>(); + + size_t callCount1 = 0; + auto id1 = m->Subscribe(p1.GetFuture(), [&callCount1](auto&&) { ++callCount1; } ); + size_t callCount2 = 0; + auto id2 = m->Subscribe(p2.GetFuture(), [&callCount2](auto&&) { ++callCount2; } ); + size_t callCount3 = 0; + auto id3 = m->Subscribe(p3.GetFuture(), [&callCount3](auto&&) { ++callCount3; } ); + size_t callCount4 = 0; + auto id4 = m->Subscribe(p2.GetFuture(), [&callCount4](auto&&) { ++callCount4; } ); + size_t callCount5 = 0; + auto id5 = m->Subscribe(p1.GetFuture(), [&callCount5](auto&&) { ++callCount5; } ); + + UNIT_ASSERT(id1.has_value()); + UNIT_ASSERT(id2.has_value()); + UNIT_ASSERT(id3.has_value()); + UNIT_ASSERT(id4.has_value()); + UNIT_ASSERT(id5.has_value()); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 0); + + m->Unsubscribe(id1.value()); + p1.SetValue(-1); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 0); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + + m->Unsubscribe(id4.value()); + p2.SetValue(23); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 0); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + + p3.SetValue(100500); + UNIT_ASSERT_EQUAL(callCount1, 0); + UNIT_ASSERT_EQUAL(callCount2, 1); + UNIT_ASSERT_EQUAL(callCount3, 1); + UNIT_ASSERT_EQUAL(callCount4, 0); + UNIT_ASSERT_EQUAL(callCount5, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeManyUnsignaled) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), p2.GetFuture(), p1.GetFuture() }, [&callCount](auto&&) { ++callCount; }); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 0); + + p1.SetValue(33); + UNIT_ASSERT_EQUAL(callCount, 2); + + p2.SetValue(111); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManySignaledNoRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount = 0; + auto ids = m->Subscribe({ f1, f2, f1 }, [&callCount](auto&&) { ++callCount; }); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManySignaledRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto f1 = MakeFuture(0); + auto f2 = MakeFuture(1); + + size_t callCount = 0; + auto ids = m->Subscribe({ f1, f2, f1 }, [&callCount](auto&&) { ++callCount; }, true); + + UNIT_ASSERT(ids.empty()); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeManyMixedNoRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), p2.GetFuture(), f }, [&callCount](auto&&) { ++callCount; } ); + + UNIT_ASSERT_EQUAL(ids.size(), 3); + UNIT_ASSERT_UNEQUAL(ids[0], ids[1]); + UNIT_ASSERT_UNEQUAL(ids[1], ids[2]); + UNIT_ASSERT_UNEQUAL(ids[2], ids[0]); + UNIT_ASSERT_EQUAL(callCount, 1); + + p1.SetValue(45); + UNIT_ASSERT_EQUAL(callCount, 2); + + p2.SetValue(-7); + UNIT_ASSERT_EQUAL(callCount, 3); + } + + Y_UNIT_TEST(TestBulkSubscribeManyMixedRevert) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(); + + size_t callCount = 0; + auto ids = m->Subscribe({ p1.GetFuture(), f, p2.GetFuture() }, [&callCount](auto&&) { ++callCount; }, true); + + UNIT_ASSERT(ids.empty()); + UNIT_ASSERT_EQUAL(callCount, 1); + + p1.SetValue(); + p2.SetValue(); + UNIT_ASSERT_EQUAL(callCount, 1); + } + + Y_UNIT_TEST(TestBulkSubscribeUnsubscribeMany) { + auto m = TSubscriptionManager::NewInstance(); + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto p3 = NewPromise<int>(); + + size_t callCount = 0; + auto ids = m->Subscribe( + TVector<TFuture<int>>{ p1.GetFuture(), p2.GetFuture(), p3.GetFuture(), p2.GetFuture(), p1.GetFuture() } + , [&callCount](auto&&) { ++callCount; } ); + + UNIT_ASSERT_EQUAL(ids.size(), 5); + UNIT_ASSERT_EQUAL(callCount, 0); + + m->Unsubscribe(TVector<TSubscriptionId>{ ids[0], ids[3] }); + UNIT_ASSERT_EQUAL(callCount, 0); + + p1.SetValue(-1); + UNIT_ASSERT_EQUAL(callCount, 1); + + p2.SetValue(23); + UNIT_ASSERT_EQUAL(callCount, 2); + + p3.SetValue(100500); + UNIT_ASSERT_EQUAL(callCount, 3); + } +} diff --git a/library/cpp/threading/future/subscription/ut/ya.make b/library/cpp/threading/future/subscription/ut/ya.make new file mode 100644 index 0000000000..45210f7bd7 --- /dev/null +++ b/library/cpp/threading/future/subscription/ut/ya.make @@ -0,0 +1,17 @@ +UNITTEST_FOR(library/cpp/threading/future/subscription) + +OWNER( + g:kwyt + g:rtmr + ishfb +) + +SRCS( + subscription_ut.cpp + wait_all_ut.cpp + wait_all_or_exception_ut.cpp + wait_any_ut.cpp + wait_ut_common.cpp +) + +END() diff --git a/library/cpp/threading/future/subscription/wait.h b/library/cpp/threading/future/subscription/wait.h new file mode 100644 index 0000000000..533bab9d8d --- /dev/null +++ b/library/cpp/threading/future/subscription/wait.h @@ -0,0 +1,119 @@ +#pragma once + +#include "subscription.h" + +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/system/spinlock.h> + + +#include <initializer_list> + +namespace NThreading::NPrivate { + +template <typename TDerived> +class TWait : public TThrRefBase { +private: + TSubscriptionManagerPtr Manager; + TVector<TSubscriptionId> Subscriptions; + bool Unsubscribed = false; + +protected: + TAdaptiveLock Lock; + TPromise<void> Promise; + +public: + template <typename TFutures, typename TCallbackExecutor> + static TFuture<void> Make(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + TIntrusivePtr<TDerived> w(new TDerived(std::move(manager))); + w->Subscribe(futures, std::forward<TCallbackExecutor>(executor)); + return w->Promise.GetFuture(); + } + +protected: + TWait(TSubscriptionManagerPtr manager) + : Manager(std::move(manager)) + , Subscriptions() + , Unsubscribed(false) + , Lock() + , Promise(NewPromise()) + { + Y_ENSURE(Manager != nullptr); + } + +protected: + //! Unsubscribes all existing subscriptions + /** Lock should be acquired! + **/ + void Unsubscribe() noexcept { + if (Unsubscribed) { + return; + } + Unsubscribe(Subscriptions); + Subscriptions.clear(); + } + +private: + //! Performs a subscription to the given futures + /** Lock should not be acquired! + @param future - The futures to subscribe to + @param callback - The callback to call for each future + **/ + template <typename TFutures, typename TCallbackExecutor> + void Subscribe(TFutures const& futures, TCallbackExecutor&& executor) { + auto self = TIntrusivePtr<TDerived>(static_cast<TDerived*>(this)); + self->BeforeSubscribe(futures); + auto callback = [self = std::move(self)](const auto& future) mutable { + self->Set(future); + }; + auto subscriptions = Manager->Subscribe(futures, callback, TDerived::RevertOnSignaled, std::forward<TCallbackExecutor>(executor)); + if (subscriptions.empty()) { + return; + } + with_lock (Lock) { + if (Unsubscribed) { + Unsubscribe(subscriptions); + } else { + Subscriptions = std::move(subscriptions); + } + } + } + + void Unsubscribe(TVector<TSubscriptionId>& subscriptions) noexcept { + Manager->Unsubscribe(subscriptions); + Unsubscribed = true; + } +}; + +template <typename TWaiter, typename TFutures, typename TCallbackExecutor> +TFuture<void> Wait(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + switch (std::size(futures)) { + case 0: + return MakeFuture(); + case 1: + return std::begin(futures)->IgnoreResult(); + default: + return TWaiter::Make(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); + } +} + +template <typename TWaiter, typename T, typename TCallbackExecutor> +TFuture<void> Wait(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + switch (std::size(futures)) { + case 0: + return MakeFuture(); + case 1: + return std::begin(futures)->IgnoreResult(); + default: + return TWaiter::Make(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); + } +} + + +template <typename TWaiter, typename T, typename TCallbackExecutor> +TFuture<void> Wait(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return TWaiter::Make(std::initializer_list<TFuture<T> const>({ future1, future2 }), std::move(manager) + , std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all.cpp b/library/cpp/threading/future/subscription/wait_all.cpp new file mode 100644 index 0000000000..10e7ee7598 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all.cpp @@ -0,0 +1 @@ +#include "wait_all.h" diff --git a/library/cpp/threading/future/subscription/wait_all.h b/library/cpp/threading/future/subscription/wait_all.h new file mode 100644 index 0000000000..5c0d2bb862 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all.h @@ -0,0 +1,23 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template <typename TFutures, typename TCallbackExecutor> +TFuture<void> WaitAll(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAll(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAll(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H +#include "wait_all_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H diff --git a/library/cpp/threading/future/subscription/wait_all_inl.h b/library/cpp/threading/future/subscription/wait_all_inl.h new file mode 100644 index 0000000000..a3b665f642 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_inl.h @@ -0,0 +1,80 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_INL_H) +#error "you should never include wait_all_inl.h directly" +#endif + +#include "subscription.h" + +#include <initializer_list> + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAll final : public NThreading::NPrivate::TWait<TWaitAll> { +private: + size_t Count = 0; + std::exception_ptr Exception; + + static constexpr bool RevertOnSignaled = false; + + using TBase = NThreading::NPrivate::TWait<TWaitAll>; + friend TBase; + +private: + TWaitAll(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + , Count(0) + , Exception() + { + } + + template <typename TFutures> + void BeforeSubscribe(TFutures const& futures) { + Count = std::size(futures); + Y_ENSURE(Count > 0, "It is meaningless to use this class with empty futures set"); + } + + template <typename T> + void Set(TFuture<T> const& future) { + with_lock (TBase::Lock) { + if (!Exception) { + try { + future.TryRethrow(); + } catch (...) { + Exception = std::current_exception(); + } + } + + if (--Count == 0) { + // there is no need to call Unsubscribe here since all futures are signaled + Y_ASSERT(!TBase::Promise.HasValue() && !TBase::Promise.HasException()); + if (Exception) { + TBase::Promise.SetException(std::move(Exception)); + } else { + TBase::Promise.SetValue(); + } + } + } + } +}; + +} + +template <typename TFutures, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAll(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAll>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAll(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAll>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAll(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAll>(future1, future2, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception.cpp b/library/cpp/threading/future/subscription/wait_all_or_exception.cpp new file mode 100644 index 0000000000..0c73ddeb84 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception.cpp @@ -0,0 +1 @@ +#include "wait_all_or_exception.h" diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception.h b/library/cpp/threading/future/subscription/wait_all_or_exception.h new file mode 100644 index 0000000000..e3e0caf2f8 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception.h @@ -0,0 +1,25 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template <typename TFutures, typename TCallbackExecutor> +TFuture<void> WaitAllOrException(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAllOrException(std::initializer_list<TFuture<T> const> futures + , TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAllOrException(TFuture<T> const& future1, TFuture<T> const& future2 + , TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H +#include "wait_all_or_exception_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h b/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h new file mode 100644 index 0000000000..fcd9782d54 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception_inl.h @@ -0,0 +1,79 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ALL_OR_EXCEPTION_INL_H) +#error "you should never include wait_all_or_exception_inl.h directly" +#endif + +#include "subscription.h" + +#include <initializer_list> + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAllOrException final : public NThreading::NPrivate::TWait<TWaitAllOrException> +{ +private: + size_t Count = 0; + + static constexpr bool RevertOnSignaled = false; + + using TBase = NThreading::NPrivate::TWait<TWaitAllOrException>; + friend TBase; + +private: + TWaitAllOrException(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + , Count(0) + { + } + + template <typename TFutures> + void BeforeSubscribe(TFutures const& futures) { + Count = std::size(futures); + Y_ENSURE(Count > 0, "It is meaningless to use this class with empty futures set"); + } + + template <typename T> + void Set(TFuture<T> const& future) { + with_lock (TBase::Lock) { + try { + future.TryRethrow(); + if (--Count == 0) { + // there is no need to call Unsubscribe here since all futures are signaled + TBase::Promise.SetValue(); + } + } catch (...) { + Y_ASSERT(!TBase::Promise.HasValue()); + TBase::Unsubscribe(); + if (!TBase::Promise.HasException()) { + TBase::Promise.SetException(std::current_exception()); + } + } + } + } +}; + +} + +template <typename TFutures, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAllOrException(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAllOrException>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAllOrException(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager + , TCallbackExecutor&& executor) +{ + return NThreading::NPrivate::Wait<NPrivate::TWaitAllOrException>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAllOrException(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager + , TCallbackExecutor&& executor) +{ + return NThreading::NPrivate::Wait<NPrivate::TWaitAllOrException>(future1, future2, std::move(manager) + , std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp b/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp new file mode 100644 index 0000000000..34ae9edb4e --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_or_exception_ut.cpp @@ -0,0 +1,167 @@ +#include "wait_all_or_exception.h" +#include "wait_ut_common.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/strbuf.h> + +#include <atomic> +#include <exception> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAllOrExceptionTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAllOrException(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAllOrException(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p2.SetValue(-11); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAllOrException(p.GetFuture(), f); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAllOrException(f, p.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 2"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAllOrException(std::initializer_list<TFuture<void> const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAllOrException(TVector<TFuture<int>>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise<int>(); + auto w = NWait::WaitAllOrException({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAllOrException(TVector<TFuture<void>>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAllOrException({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(-3); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestManyWithVector) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAllOrException(TVector<TFuture<int>>{ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p2.SetValue(34); + } + + Y_UNIT_TEST(TestManyWithVectorAndIntialError) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + constexpr TStringBuf message = "Test exception 5"; + auto f = MakeErrorFuture<void>(std::make_exception_ptr(yexception() << message)); + auto w = NWait::WaitAllOrException(TVector<TFuture<void>>{ p1.GetFuture(), p2.GetFuture(), f }); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(); + p2.SetValue(); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(22); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 6"); + std::atomic<size_t> index = 0; + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAllOrException(futures); } + , [e, &index](size_t size) { + auto exceptionIndex = size / 2; + index = 0; + return [e, exceptionIndex, &index](auto&& p) { + if (index++ == exceptionIndex) { + p.SetException(e); + } else { + p.SetValue(); + } + }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_all_ut.cpp b/library/cpp/threading/future/subscription/wait_all_ut.cpp new file mode 100644 index 0000000000..3bc9762671 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_all_ut.cpp @@ -0,0 +1,161 @@ +#include "wait_all.h" +#include "wait_ut_common.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/strbuf.h> + +#include <atomic> +#include <exception> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAllTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAll(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAll(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p2.SetValue(-11); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAll(p.GetFuture(), f); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAll(f, p.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 2"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAll(std::initializer_list<TFuture<void> const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAll(TVector<TFuture<int>>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise<int>(); + auto w = NWait::WaitAll({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAll(TVector<TFuture<void>>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAll({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + p2.SetValue(-3); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestManyWithVector) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAll(TVector<TFuture<int>>{ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p1.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p2.SetValue(34); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAll(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(42); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAll(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 5"); + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAll(futures); } + , [e](size_t) { + return [e](auto&& p) { p.SetException(e); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + e = std::make_exception_ptr(yexception() << "Test exception 6"); + std::atomic<size_t> index = 0; + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAll(futures); } + , [e, &index](size_t size) { + auto exceptionIndex = size / 2; + index = 0; + return [e, exceptionIndex, &index](auto&& p) { + if (index++ == exceptionIndex) { + p.SetException(e); + } else { + p.SetValue(index); + } + }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_any.cpp b/library/cpp/threading/future/subscription/wait_any.cpp new file mode 100644 index 0000000000..57cc1b2c25 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any.cpp @@ -0,0 +1 @@ +#include "wait_any.h" diff --git a/library/cpp/threading/future/subscription/wait_any.h b/library/cpp/threading/future/subscription/wait_any.h new file mode 100644 index 0000000000..e770d7b59e --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any.h @@ -0,0 +1,23 @@ +#pragma once + +#include "wait.h" + +namespace NThreading::NWait { + +template <typename TFutures, typename TCallbackExecutor> +TFuture<void> WaitAny(TFutures const& futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAny(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +template <typename T, typename TCallbackExecutor> +TFuture<void> WaitAny(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager = TSubscriptionManager::Default() + , TCallbackExecutor&& executor = TCallbackExecutor()); + +} + +#define INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H +#include "wait_any_inl.h" +#undef INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H diff --git a/library/cpp/threading/future/subscription/wait_any_inl.h b/library/cpp/threading/future/subscription/wait_any_inl.h new file mode 100644 index 0000000000..e80822bfc9 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any_inl.h @@ -0,0 +1,64 @@ +#pragma once + +#if !defined(INCLUDE_LIBRARY_THREADING_FUTURE_WAIT_ANY_INL_H) +#error "you should never include wait_any_inl.h directly" +#endif + +#include "subscription.h" + +#include <initializer_list> + +namespace NThreading::NWait { + +namespace NPrivate { + +class TWaitAny final : public NThreading::NPrivate::TWait<TWaitAny> { +private: + static constexpr bool RevertOnSignaled = true; + + using TBase = NThreading::NPrivate::TWait<TWaitAny>; + friend TBase; + +private: + TWaitAny(TSubscriptionManagerPtr manager) + : TBase(std::move(manager)) + { + } + + template <typename TFutures> + void BeforeSubscribe(TFutures const& futures) { + Y_ENSURE(std::size(futures) > 0, "Futures set cannot be empty"); + } + + template <typename T> + void Set(TFuture<T> const& future) { + with_lock (TBase::Lock) { + TBase::Unsubscribe(); + try { + future.TryRethrow(); + TBase::Promise.TrySetValue(); + } catch (...) { + TBase::Promise.TrySetException(std::current_exception()); + } + } + } +}; + +} + +template <typename TFutures, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAny(TFutures const& futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAny>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAny(std::initializer_list<TFuture<T> const> futures, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAny>(futures, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +template <typename T, typename TCallbackExecutor = NThreading::NPrivate::TNoexceptExecutor> +TFuture<void> WaitAny(TFuture<T> const& future1, TFuture<T> const& future2, TSubscriptionManagerPtr manager, TCallbackExecutor&& executor) { + return NThreading::NPrivate::Wait<NPrivate::TWaitAny>(future1, future2, std::move(manager), std::forward<TCallbackExecutor>(executor)); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_any_ut.cpp b/library/cpp/threading/future/subscription/wait_any_ut.cpp new file mode 100644 index 0000000000..262080e8d1 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_any_ut.cpp @@ -0,0 +1,166 @@ +#include "wait_any.h" +#include "wait_ut_common.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/strbuf.h> + +#include <exception> + +using namespace NThreading; + +Y_UNIT_TEST_SUITE(TWaitAnyTest) { + + Y_UNIT_TEST(TestTwoUnsignaled) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAny(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(10); + UNIT_ASSERT(w.HasValue()); + p2.SetValue(1); + } + + Y_UNIT_TEST(TestTwoUnsignaledWithException) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto w = NWait::WaitAny(p1.GetFuture(), p2.GetFuture()); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception"; + p2.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(-11); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaled) { + auto p = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAny(p.GetFuture(), f); + UNIT_ASSERT(w.HasValue()); + + p.SetValue(); + } + + Y_UNIT_TEST(TestOneUnsignaledOneSignaledWithException) { + auto p = NewPromise(); + constexpr TStringBuf message = "Test exception 2"; + auto f = MakeErrorFuture<void>(std::make_exception_ptr(yexception() << message)); + auto w = NWait::WaitAny(f, p.GetFuture()); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p.SetValue(); + } + + Y_UNIT_TEST(TestEmptyInitializer) { + auto w = NWait::WaitAny(std::initializer_list<TFuture<void> const>({})); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestEmptyVector) { + auto w = NWait::WaitAny(TVector<TFuture<int>>()); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithInitializer) { + auto p = NewPromise<int>(); + auto w = NWait::WaitAny({ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p.SetValue(1); + UNIT_ASSERT(w.HasValue()); + } + + Y_UNIT_TEST(TestOneUnsignaledWithVector) { + auto p = NewPromise(); + auto w = NWait::WaitAny(TVector<TFuture<void>>{ p.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 3"; + p.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + } + + Y_UNIT_TEST(TestManyUnsignaledWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto p3 = NewPromise<int>(); + auto w = NWait::WaitAny({ p1.GetFuture(), p2.GetFuture(), p3.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + p1.SetValue(42); + UNIT_ASSERT(w.HasValue()); + + p2.SetValue(-3); + p3.SetValue(12); + } + + Y_UNIT_TEST(TestManyMixedWithInitializer) { + auto p1 = NewPromise<int>(); + auto p2 = NewPromise<int>(); + auto f = MakeFuture(42); + auto w = NWait::WaitAny({ p1.GetFuture(), f, p2.GetFuture() }); + UNIT_ASSERT(w.HasValue()); + + p1.SetValue(10); + p2.SetValue(-3); + } + + + Y_UNIT_TEST(TestManyUnsignaledWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto p3 = NewPromise(); + auto w = NWait::WaitAny(TVector<TFuture<void>>{ p1.GetFuture(), p2.GetFuture(), p3.GetFuture() }); + UNIT_ASSERT(!w.HasValue() && !w.HasException()); + + constexpr TStringBuf message = "Test exception 4"; + p2.SetException(std::make_exception_ptr(yexception() << message)); + UNIT_ASSERT_EXCEPTION_SATISFIES(w.TryRethrow(), yexception, [message](auto const& e) { + return message == e.what(); + }); + + p1.SetValue(); + p3.SetValue(); + } + + + Y_UNIT_TEST(TestManyMixedWithVector) { + auto p1 = NewPromise(); + auto p2 = NewPromise(); + auto f = MakeFuture(); + auto w = NWait::WaitAny(TVector<TFuture<void>>{ p1.GetFuture(), p2.GetFuture(), f }); + UNIT_ASSERT(w.HasValue()); + + p1.SetValue(); + p2.SetValue(); + } + + Y_UNIT_TEST(TestManyStress) { + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAny(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + + NTest::TestManyStress<int>([](auto&& futures) { return NWait::WaitAny(futures); } + , [](size_t) { + return [](auto&& p) { p.SetValue(22); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasValue()); }); + auto e = std::make_exception_ptr(yexception() << "Test exception 5"); + NTest::TestManyStress<void>([](auto&& futures) { return NWait::WaitAny(futures); } + , [e](size_t) { + return [e](auto&& p) { p.SetException(e); }; + } + , [](auto&& waiter) { UNIT_ASSERT(waiter.HasException()); }); + } + +} diff --git a/library/cpp/threading/future/subscription/wait_ut_common.cpp b/library/cpp/threading/future/subscription/wait_ut_common.cpp new file mode 100644 index 0000000000..9f961e7303 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_ut_common.cpp @@ -0,0 +1,26 @@ +#include "wait_ut_common.h" + +#include <util/random/shuffle.h> +#include <util/system/event.h> +#include <util/thread/pool.h> + +namespace NThreading::NTest::NPrivate { + +void ExecuteAndWait(TVector<std::function<void()>> jobs, TFuture<void> waiter, size_t threads) { + Y_ENSURE(threads > 0); + Shuffle(jobs.begin(), jobs.end()); + auto pool = CreateThreadPool(threads); + TManualEvent start; + for (auto& j : jobs) { + pool->SafeAddFunc( + [&start, job = std::move(j)]() { + start.WaitI(); + job(); + }); + } + start.Signal(); + waiter.Wait(); + pool->Stop(); +} + +} diff --git a/library/cpp/threading/future/subscription/wait_ut_common.h b/library/cpp/threading/future/subscription/wait_ut_common.h new file mode 100644 index 0000000000..99530dd1f6 --- /dev/null +++ b/library/cpp/threading/future/subscription/wait_ut_common.h @@ -0,0 +1,56 @@ +#pragma once + +#include <library/cpp/threading/future/future.h> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/vector.h> + +#include <functional> +#include <type_traits> + +namespace NThreading::NTest { + +namespace NPrivate { + +void ExecuteAndWait(TVector<std::function<void()>> jobs, TFuture<void> waiter, size_t threads); + +template <typename TPromises, typename FSetter> +void SetConcurrentAndWait(TPromises&& promises, FSetter&& setter, TFuture<void> waiter, size_t threads = 8) { + TVector<std::function<void()>> jobs; + jobs.reserve(std::size(promises)); + for (auto& p : promises) { + jobs.push_back([p, setter]() mutable {setter(p); }); + } + ExecuteAndWait(std::move(jobs), std::move(waiter), threads); +} + +template <typename T> +auto MakePromise() { + if constexpr (std::is_same_v<T, void>) { + return NewPromise(); + } + return NewPromise<T>(); +} + +} + +template <typename T, typename FWaiterFactory, typename FSetterFactory, typename FChecker> +void TestManyStress(FWaiterFactory&& waiterFactory, FSetterFactory&& setterFactory, FChecker&& checker) { + for (size_t i : { 1, 2, 4, 8, 16, 32, 64, 128, 256 }) { + TVector<TPromise<T>> promises; + TVector<TFuture<T>> futures; + promises.reserve(i); + futures.reserve(i); + for (size_t j = 0; j < i; ++j) { + auto promise = NPrivate::MakePromise<T>(); + futures.push_back(promise.GetFuture()); + promises.push_back(std::move(promise)); + } + auto waiter = waiterFactory(futures); + NPrivate::SetConcurrentAndWait(std::move(promises), [valueSetter = setterFactory(i)](auto&& p) { valueSetter(p); } + , waiter); + checker(waiter); + } +} + +} diff --git a/library/cpp/threading/future/subscription/ya.make b/library/cpp/threading/future/subscription/ya.make new file mode 100644 index 0000000000..cb75731dbf --- /dev/null +++ b/library/cpp/threading/future/subscription/ya.make @@ -0,0 +1,24 @@ +OWNER( + g:kwyt + g:rtmr + ishfb +) + +LIBRARY() + +SRCS( + subscription.cpp + wait_all.cpp + wait_all_or_exception.cpp + wait_any.cpp +) + +PEERDIR( + library/cpp/threading/future +) + +END() + +RECURSE_FOR_TESTS( + ut +) diff --git a/library/cpp/threading/future/ut/ya.make b/library/cpp/threading/future/ut/ya.make new file mode 100644 index 0000000000..566b622370 --- /dev/null +++ b/library/cpp/threading/future/ut/ya.make @@ -0,0 +1,14 @@ +UNITTEST_FOR(library/cpp/threading/future) + +OWNER( + g:rtmr + ishfb +) + +SRCS( + async_ut.cpp + future_ut.cpp + legacy_future_ut.cpp +) + +END() diff --git a/library/cpp/threading/future/wait/fwd.cpp b/library/cpp/threading/future/wait/fwd.cpp new file mode 100644 index 0000000000..4214b6df83 --- /dev/null +++ b/library/cpp/threading/future/wait/fwd.cpp @@ -0,0 +1 @@ +#include "fwd.h" diff --git a/library/cpp/threading/future/wait/fwd.h b/library/cpp/threading/future/wait/fwd.h new file mode 100644 index 0000000000..de3b1313d5 --- /dev/null +++ b/library/cpp/threading/future/wait/fwd.h @@ -0,0 +1 @@ +// empty (for now) diff --git a/library/cpp/threading/future/wait/wait-inl.h b/library/cpp/threading/future/wait/wait-inl.h new file mode 100644 index 0000000000..2753d5446c --- /dev/null +++ b/library/cpp/threading/future/wait/wait-inl.h @@ -0,0 +1,36 @@ +#pragma once + +#if !defined(INCLUDE_FUTURE_INL_H) +#error "you should never include wait-inl.h directly" +#endif // INCLUDE_FUTURE_INL_H + +namespace NThreading { + namespace NImpl { + template <typename TContainer> + TVector<TFuture<void>> ToVoidFutures(const TContainer& futures) { + TVector<TFuture<void>> voidFutures; + voidFutures.reserve(futures.size()); + + for (const auto& future: futures) { + voidFutures.push_back(future.IgnoreResult()); + } + + return voidFutures; + } + } + + template <typename TContainer> + [[nodiscard]] NImpl::EnableGenericWait<TContainer> WaitAll(const TContainer& futures) { + return WaitAll(NImpl::ToVoidFutures(futures)); + } + + template <typename TContainer> + [[nodiscard]] NImpl::EnableGenericWait<TContainer> WaitExceptionOrAll(const TContainer& futures) { + return WaitExceptionOrAll(NImpl::ToVoidFutures(futures)); + } + + template <typename TContainer> + [[nodiscard]] NImpl::EnableGenericWait<TContainer> WaitAny(const TContainer& futures) { + return WaitAny(NImpl::ToVoidFutures(futures)); + } +} diff --git a/library/cpp/threading/future/wait/wait.cpp b/library/cpp/threading/future/wait/wait.cpp new file mode 100644 index 0000000000..a173833a7f --- /dev/null +++ b/library/cpp/threading/future/wait/wait.cpp @@ -0,0 +1,82 @@ +#include "wait.h" + +#include "wait_group.h" +#include "wait_policy.h" + +namespace NThreading { + namespace { + template <class WaitPolicy> + TFuture<void> WaitGeneric(const TFuture<void>& f1) { + return f1; + } + + template <class WaitPolicy> + TFuture<void> WaitGeneric(const TFuture<void>& f1, const TFuture<void>& f2) { + TWaitGroup<WaitPolicy> wg; + + wg.Add(f1).Add(f2); + + return std::move(wg).Finish(); + } + + template <class WaitPolicy> + TFuture<void> WaitGeneric(TArrayRef<const TFuture<void>> futures) { + if (futures.empty()) { + return MakeFuture(); + } + if (futures.size() == 1) { + return futures.front(); + } + + TWaitGroup<WaitPolicy> wg; + for (const auto& fut : futures) { + wg.Add(fut); + } + + return std::move(wg).Finish(); + } + } + + //////////////////////////////////////////////////////////////////////////////// + + TFuture<void> WaitAll(const TFuture<void>& f1) { + return WaitGeneric<TWaitPolicy::TAll>(f1); + } + + TFuture<void> WaitAll(const TFuture<void>& f1, const TFuture<void>& f2) { + return WaitGeneric<TWaitPolicy::TAll>(f1, f2); + } + + TFuture<void> WaitAll(TArrayRef<const TFuture<void>> futures) { + return WaitGeneric<TWaitPolicy::TAll>(futures); + } + + + //////////////////////////////////////////////////////////////////////////////// + + TFuture<void> WaitExceptionOrAll(const TFuture<void>& f1) { + return WaitGeneric<TWaitPolicy::TExceptionOrAll>(f1); + } + + TFuture<void> WaitExceptionOrAll(const TFuture<void>& f1, const TFuture<void>& f2) { + return WaitGeneric<TWaitPolicy::TExceptionOrAll>(f1, f2); + } + + TFuture<void> WaitExceptionOrAll(TArrayRef<const TFuture<void>> futures) { + return WaitGeneric<TWaitPolicy::TExceptionOrAll>(futures); + } + + //////////////////////////////////////////////////////////////////////////////// + + TFuture<void> WaitAny(const TFuture<void>& f1) { + return WaitGeneric<TWaitPolicy::TAny>(f1); + } + + TFuture<void> WaitAny(const TFuture<void>& f1, const TFuture<void>& f2) { + return WaitGeneric<TWaitPolicy::TAny>(f1, f2); + } + + TFuture<void> WaitAny(TArrayRef<const TFuture<void>> futures) { + return WaitGeneric<TWaitPolicy::TAny>(futures); + } +} diff --git a/library/cpp/threading/future/wait/wait.h b/library/cpp/threading/future/wait/wait.h new file mode 100644 index 0000000000..6ff7d57baa --- /dev/null +++ b/library/cpp/threading/future/wait/wait.h @@ -0,0 +1,41 @@ +#pragma once + +#include "fwd.h" + +#include <library/cpp/threading/future/core/future.h> +#include <library/cpp/threading/future/wait/wait_group.h> + +#include <util/generic/array_ref.h> + +namespace NThreading { + namespace NImpl { + template <class TContainer> + using EnableGenericWait = std::enable_if_t< + !std::is_convertible_v<TContainer, TArrayRef<const TFuture<void>>>, + TFuture<void>>; + } + // waits for all futures + [[nodiscard]] TFuture<void> WaitAll(const TFuture<void>& f1); + [[nodiscard]] TFuture<void> WaitAll(const TFuture<void>& f1, const TFuture<void>& f2); + [[nodiscard]] TFuture<void> WaitAll(TArrayRef<const TFuture<void>> futures); + template <typename TContainer> + [[nodiscard]] NImpl::EnableGenericWait<TContainer> WaitAll(const TContainer& futures); + + // waits for the first exception or for all futures + [[nodiscard]] TFuture<void> WaitExceptionOrAll(const TFuture<void>& f1); + [[nodiscard]] TFuture<void> WaitExceptionOrAll(const TFuture<void>& f1, const TFuture<void>& f2); + [[nodiscard]] TFuture<void> WaitExceptionOrAll(TArrayRef<const TFuture<void>> futures); + template <typename TContainer> + [[nodiscard]] NImpl::EnableGenericWait<TContainer> WaitExceptionOrAll(const TContainer& futures); + + // waits for any future + [[nodiscard]] TFuture<void> WaitAny(const TFuture<void>& f1); + [[nodiscard]] TFuture<void> WaitAny(const TFuture<void>& f1, const TFuture<void>& f2); + [[nodiscard]] TFuture<void> WaitAny(TArrayRef<const TFuture<void>> futures); + template <typename TContainer> + [[nodiscard]] NImpl::EnableGenericWait<TContainer> WaitAny(const TContainer& futures); +} + +#define INCLUDE_FUTURE_INL_H +#include "wait-inl.h" +#undef INCLUDE_FUTURE_INL_H diff --git a/library/cpp/threading/future/wait/wait_group-inl.h b/library/cpp/threading/future/wait/wait_group-inl.h new file mode 100644 index 0000000000..a7da536f20 --- /dev/null +++ b/library/cpp/threading/future/wait/wait_group-inl.h @@ -0,0 +1,206 @@ +#pragma once + +#if !defined(INCLUDE_FUTURE_INL_H) +#error "you should never include wait_group-inl.h directly" +#endif // INCLUDE_FUTURE_INL_H + +#include "wait_policy.h" + +#include <util/generic/maybe.h> +#include <util/generic/ptr.h> + +#include <library/cpp/threading/future/core/future.h> + +#include <util/system/spinlock.h> + +#include <atomic> +#include <exception> + +namespace NThreading { + namespace NWaitGroup::NImpl { + template <class WaitPolicy> + struct TState final : TAtomicRefCount<TState<WaitPolicy>> { + template <class T> + void Add(const TFuture<T>& future); + TFuture<void> Finish(); + + void TryPublish(); + void Publish(); + + bool ShouldPublishByCount() const noexcept; + bool ShouldPublishByException() const noexcept; + + TStateRef<WaitPolicy> SharedFromThis() noexcept { + return TStateRef<WaitPolicy>{this}; + } + + enum class EPhase { + Initial, + Publishing, + }; + + // initially we have one imaginary discovered future which we + // use for synchronization with ::Finish + std::atomic<ui64> Discovered{1}; + + std::atomic<ui64> Finished{0}; + + std::atomic<EPhase> Phase{EPhase::Initial}; + + TPromise<void> Subscribers = NewPromise(); + + mutable TAdaptiveLock Mut; + std::exception_ptr ExceptionInFlight; + + void TrySetException(std::exception_ptr eptr) noexcept { + TGuard lock{Mut}; + if (!ExceptionInFlight) { + ExceptionInFlight = std::move(eptr); + } + } + + std::exception_ptr GetExceptionInFlight() const noexcept { + TGuard lock{Mut}; + return ExceptionInFlight; + } + }; + + template <class WaitPolicy> + inline TFuture<void> TState<WaitPolicy>::Finish() { + Finished.fetch_add(1); // complete the imaginary future + + // handle empty case explicitly: + if (Discovered.load() == 1) { + Y_ASSERT(Phase.load() == EPhase::Initial); + Publish(); + } else { + TryPublish(); + } + + return Subscribers; + } + + template <class WaitPolicy> + template <class T> + inline void TState<WaitPolicy>::Add(const TFuture<T>& future) { + future.EnsureInitialized(); + + Discovered.fetch_add(1); + + // NoexceptSubscribe is needed to make ::Add exception-safe + future.NoexceptSubscribe([self = SharedFromThis()](auto&& future) { + try { + future.TryRethrow(); + } catch (...) { + self->TrySetException(std::current_exception()); + } + + self->Finished.fetch_add(1); + self->TryPublish(); + }); + } + + // + // ============================ PublishByCount ================================== + // + + template <class WaitPolicy> + inline bool TState<WaitPolicy>::ShouldPublishByCount() const noexcept { + // - safety: a) If the future incremented ::Finished, and we observe the effect, then we will observe ::Discovered as incremented by its discovery later + // b) Every discovery of a future observes discovery of the imaginary future + // a, b => if finishedByNow == discoveredByNow, then every future discovered in [imaginary discovered, imaginary finished] is finished + // + // - liveness: a) TryPublish is called after each increment of ::Finished + // b) There is some last increment of ::Finished which follows all other operations with ::Finished and ::Discovered (provided that every future is eventually set) + // c) For each increment of ::Discovered there is an increment of ::Finished (provided that every future is eventually set) + // a, b c => some call to ShouldPublishByCount will always return true + // + // order of the following two operations is significant for the proof. + auto finishedByNow = Finished.load(); + auto discoveredByNow = Discovered.load(); + + return finishedByNow == discoveredByNow; + } + + template <> + inline bool TState<TWaitPolicy::TAny>::ShouldPublishByCount() const noexcept { + auto finishedByNow = Finished.load(); + + // note that the empty case is not handled here + return finishedByNow >= 2; // at least one non-imaginary + } + + // + // ============================ PublishByException ================================== + // + + template <> + inline bool TState<TWaitPolicy::TAny>::ShouldPublishByException() const noexcept { + // for TAny exceptions are handled by ShouldPublishByCount + return false; + } + + template <> + inline bool TState<TWaitPolicy::TAll>::ShouldPublishByException() const noexcept { + return false; + } + + template <> + inline bool TState<TWaitPolicy::TExceptionOrAll>::ShouldPublishByException() const noexcept { + return GetExceptionInFlight() != nullptr; + } + + // + // + // + + template <class WaitPolicy> + inline void TState<WaitPolicy>::TryPublish() { + // the order is insignificant (without proof) + bool shouldPublish = ShouldPublishByCount() || ShouldPublishByException(); + + if (shouldPublish) { + if (auto currentPhase = EPhase::Initial; + Phase.compare_exchange_strong(currentPhase, EPhase::Publishing)) { + Publish(); + } + } + } + + template <class WaitPolicy> + inline void TState<WaitPolicy>::Publish() { + auto eptr = GetExceptionInFlight(); + + // can potentially throw + if (eptr) { + Subscribers.SetException(std::move(eptr)); + } else { + Subscribers.SetValue(); + } + } + } + + template <class WaitPolicy> + inline TWaitGroup<WaitPolicy>::TWaitGroup() + : State_{MakeIntrusive<NWaitGroup::NImpl::TState<WaitPolicy>>()} + { + } + + template <class WaitPolicy> + template <class T> + inline TWaitGroup<WaitPolicy>& TWaitGroup<WaitPolicy>::Add(const TFuture<T>& future) { + State_->Add(future); + return *this; + } + + template <class WaitPolicy> + inline TFuture<void> TWaitGroup<WaitPolicy>::Finish() && { + auto res = State_->Finish(); + + // just to prevent nasty bugs from use-after-move + State_.Reset(); + + return res; + } +} + diff --git a/library/cpp/threading/future/wait/wait_group.cpp b/library/cpp/threading/future/wait/wait_group.cpp new file mode 100644 index 0000000000..4b9c7adb27 --- /dev/null +++ b/library/cpp/threading/future/wait/wait_group.cpp @@ -0,0 +1 @@ +#include "wait_group.h" diff --git a/library/cpp/threading/future/wait/wait_group.h b/library/cpp/threading/future/wait/wait_group.h new file mode 100644 index 0000000000..78d85594a2 --- /dev/null +++ b/library/cpp/threading/future/wait/wait_group.h @@ -0,0 +1,65 @@ +#pragma once + +#include <library/cpp/threading/future/core/future.h> + +#include <util/generic/ptr.h> + +namespace NThreading { + namespace NWaitGroup::NImpl { + template <class WaitPolicy> + struct TState; + + template <class WaitPolicy> + using TStateRef = TIntrusivePtr<TState<WaitPolicy>>; + } + + // a helper class which allows to + // wait for a set of futures which is + // not known beforehand. Might be useful, e.g., for graceful shutdown: + // while (!Stop()) { + // wg.Add( + // DoAsyncWork()); + // } + // std::move(wg).Finish() + // .GetValueSync(); + // + // + // the folowing are equivalent: + // { + // return WaitAll(futures); + // } + // { + // TWaitGroup<TWaitPolicy::TAll> wg; + // for (auto&& f: futures) { wg.Add(f); } + // return std::move(wg).Finish(); + // } + + template <class WaitPolicy> + class TWaitGroup { + public: + TWaitGroup(); + + // thread-safe, exception-safe + // + // adds the future to the set of futures to wait for + // + // if an exception is thrown during a call to ::Discover, the call has no effect + // + // accepts non-void T just for optimization + // (so that the caller does not have to use future.IgnoreResult()) + template <class T> + TWaitGroup& Add(const TFuture<T>& future); + + // finishes building phase + // and returns the future that combines the futures + // in the wait group according to WaitPolicy + [[nodiscard]] TFuture<void> Finish() &&; + + private: + NWaitGroup::NImpl::TStateRef<WaitPolicy> State_; + }; +} + +#define INCLUDE_FUTURE_INL_H +#include "wait_group-inl.h" +#undef INCLUDE_FUTURE_INL_H diff --git a/library/cpp/threading/future/wait/wait_policy.cpp b/library/cpp/threading/future/wait/wait_policy.cpp new file mode 100644 index 0000000000..dbebec4966 --- /dev/null +++ b/library/cpp/threading/future/wait/wait_policy.cpp @@ -0,0 +1 @@ +#include "wait_policy.h" diff --git a/library/cpp/threading/future/wait/wait_policy.h b/library/cpp/threading/future/wait/wait_policy.h new file mode 100644 index 0000000000..310b702f17 --- /dev/null +++ b/library/cpp/threading/future/wait/wait_policy.h @@ -0,0 +1,10 @@ +#pragma once + +namespace NThreading { + struct TWaitPolicy { + struct TAll {}; + struct TAny {}; + struct TExceptionOrAll {}; + }; +} + diff --git a/library/cpp/threading/future/ya.make b/library/cpp/threading/future/ya.make new file mode 100644 index 0000000000..6591031f46 --- /dev/null +++ b/library/cpp/threading/future/ya.make @@ -0,0 +1,22 @@ +OWNER( + g:rtmr +) + +LIBRARY() + +SRCS( + async.cpp + core/future.cpp + core/fwd.cpp + fwd.cpp + wait/fwd.cpp + wait/wait.cpp + wait/wait_group.cpp + wait/wait_policy.cpp +) + +END() + +RECURSE_FOR_TESTS( + mt_ut +) diff --git a/library/cpp/threading/light_rw_lock/bench/lightrwlock_test.cpp b/library/cpp/threading/light_rw_lock/bench/lightrwlock_test.cpp new file mode 100644 index 0000000000..c3027ea544 --- /dev/null +++ b/library/cpp/threading/light_rw_lock/bench/lightrwlock_test.cpp @@ -0,0 +1,188 @@ +#include <library/cpp/threading/light_rw_lock/lightrwlock.h> +#include <util/random/random.h> + +#ifdef _linux_ +// Light rw lock is implemented only for linux + +using namespace NS_LightRWLock; + +#include <pthread.h> +#include <stdlib.h> +#include <stdio.h> + +#define LIGHT + +#ifdef RWSPINLOCK +#include <library/cpp/lwtrace/rwspinlock.h> +#endif + +#define CHECK_LOGIC 1 +#define LOOPCOUNT 1000000 +#define RANRCOUNT 100 +#define THREADCOUNT 40 +#define WRITELOCKS 100 + +#if defined(_MSC_VER) +static int Y_FORCE_INLINE AtomicFetchAdd(volatile int& item, int value) { + return _InterlockedExchangeAdd((&item, value); +} +#elif defined(__GNUC__) +#else +#error unsupported platform +#endif + +class TPosixRWLock { +public: + TPosixRWLock() { + } + + ~TPosixRWLock() { + pthread_rwlock_destroy(&rwlock); + } + + TPosixRWLock(const TPosixRWLock&) = delete; + void operator=(const TPosixRWLock&) = delete; + +private: + pthread_rwlock_t rwlock = PTHREAD_RWLOCK_INITIALIZER; + friend class TPosixRWShareLocker; + friend class TPosixRWExclusiveLocker; +}; + +#if defined(LIGHT) +TLightRWLock __attribute__((aligned(64))) rwlock; +#elif defined(POSIX) +TPosixRWLock rwlock; +#elif defined(RWSPINLOCK) +TRWSpinLock __attribute__((aligned(64))) rwlock; +#else +#error "define lock type" +#endif + +volatile __attribute__((aligned(64))) int checkIt = 0; +volatile int checkExcl = 0; + +class TPosixRWShareLocker { +public: + TPosixRWShareLocker(TPosixRWLock& lock) + : LockP_(&lock) + { + pthread_rwlock_rdlock(&LockP_->rwlock); + } + + ~TPosixRWShareLocker() { + pthread_rwlock_unlock(&LockP_->rwlock); + } + + TPosixRWShareLocker(const TPosixRWShareLocker&) = delete; + void operator=(const TPosixRWShareLocker&) = delete; + +private: + TPosixRWLock* LockP_; +}; + +class TPosixRWExclusiveLocker { +public: + TPosixRWExclusiveLocker(TPosixRWLock& lock) + : LockP_(&lock) + { + pthread_rwlock_wrlock(&LockP_->rwlock); + } + + ~TPosixRWExclusiveLocker() { + pthread_rwlock_unlock(&LockP_->rwlock); + } + TPosixRWExclusiveLocker(const TPosixRWExclusiveLocker&) = delete; + void operator=(const TPosixRWExclusiveLocker&) = delete; + +private: + TPosixRWLock* LockP_; +}; + +template <typename TLocker, bool excl> +static Y_FORCE_INLINE void Run() { + TLocker lockIt(rwlock); + +#if defined(CHECK_LOGIC) && CHECK_LOGIC + if (!excl && checkExcl == 1) { + printf("there is a bug\n"); + } + + int result = AtomicFetchAdd(checkIt, 1); + if (excl) + checkExcl = 1; + + if (excl && result > 1) + printf("there is a bug\n"); +#endif + + for (unsigned w = 0; w < RANRCOUNT; ++w) + RandomNumber<ui32>(); + +#if defined(CHECK_LOGIC) && CHECK_LOGIC + if (excl) + checkExcl = 0; + + AtomicFetchAdd(checkIt, -1); +#endif +} + +#ifdef LIGHT +static void* fast_thread_start(__attribute__((unused)) void* arg) { + for (unsigned q = 0; q < LOOPCOUNT; ++q) { + char excl = (RandomNumber<ui32>() % WRITELOCKS) == 0; + if (excl) + Run<TLightWriteGuard, 1>(); + else + Run<TLightReadGuard, 0>(); + } + return NULL; +} +#endif + +#ifdef POSIX +static void* fast_thread_start(__attribute__((unused)) void* arg) { + for (unsigned q = 0; q < LOOPCOUNT; ++q) { + char excl = (RandomNumber<ui32>() % WRITELOCKS) == 0; + if (excl) + Run<TPosixRWExclusiveLocker, 1>(); + else + Run<TPosixRWShareLocker, 0>(); + } + return NULL; +} +#endif + +#ifdef RWSPINLOCK +static void* fast_thread_start(__attribute__((unused)) void* arg) { + for (unsigned q = 0; q < LOOPCOUNT; ++q) { + char excl = (RandomNumber<ui32>() % WRITELOCKS) == 0; + if (excl) + Run<TWriteSpinLockGuard, 1>(); + else + Run<TReadSpinLockGuard, 0>(); + } + return NULL; +} +#endif + +int main() { + pthread_t threads[THREADCOUNT]; + + for (unsigned q = 0; q < THREADCOUNT; ++q) { + pthread_create(&(threads[q]), NULL, &fast_thread_start, NULL); + } + + for (unsigned q = 0; q < THREADCOUNT; ++q) + pthread_join(threads[q], NULL); + + return 0; +} + +#else // !_linux_ + +int main() { + return 0; +} + +#endif diff --git a/library/cpp/threading/light_rw_lock/bench/ya.make b/library/cpp/threading/light_rw_lock/bench/ya.make new file mode 100644 index 0000000000..7969b52a50 --- /dev/null +++ b/library/cpp/threading/light_rw_lock/bench/ya.make @@ -0,0 +1,13 @@ +PROGRAM(lightrwlock_test) + +OWNER(agri) + +SRCS( + lightrwlock_test.cpp +) + +PEERDIR( + library/cpp/threading/light_rw_lock +) + +END() diff --git a/library/cpp/threading/light_rw_lock/lightrwlock.cpp b/library/cpp/threading/light_rw_lock/lightrwlock.cpp new file mode 100644 index 0000000000..fbb63fd47f --- /dev/null +++ b/library/cpp/threading/light_rw_lock/lightrwlock.cpp @@ -0,0 +1,113 @@ +#include "lightrwlock.h" +#include <util/system/spinlock.h> + +#if defined(_linux_) + +using namespace NS_LightRWLock; + +void TLightRWLock::WaitForUntrappedShared() { + for (;;) { + for (ui32 i = 0; i < SpinCount_; ++i) { + SpinLockPause(); + + if ((AtomicLoad(Counter_) & 0x7FFFFFFF) == 0) + return; + } + + SequenceStore(UnshareFutex_, 1); + if ((AtomicLoad(Counter_) & 0x7FFFFFFF) == 0) { + AtomicStore(UnshareFutex_, 0); + return; + } + FutexWait(UnshareFutex_, 1); + } +} + +void TLightRWLock::WaitForExclusiveAndUntrappedShared() { + for (;;) { + for (ui32 i = 0; i < SpinCount_; ++i) { + SpinLockPause(); + + if (AtomicLoad(Counter_) >= 0) + goto try_to_get_lock; + if (AtomicLoad(TrappedFutex_) == 1) + goto skip_store_trapped; + } + + SequenceStore(TrappedFutex_, 1); + skip_store_trapped: + + if (AtomicLoad(Counter_) < 0) { + FutexWait(TrappedFutex_, 1); + } + + try_to_get_lock: + if (!AtomicSetBit(Counter_, 31)) + break; + } + + for (ui32 j = 0;; ++j) { + for (ui32 i = 0; i < SpinCount_; ++i) { + if ((AtomicLoad(Counter_) & 0x7FFFFFFF) == 0) + return; + + SpinLockPause(); + } + + SequenceStore(UnshareFutex_, 1); + + if ((AtomicLoad(Counter_) & 0x7FFFFFFF) == 0) { + AtomicStore(UnshareFutex_, 0); + return; + } + + FutexWait(UnshareFutex_, 1); + } +} + +void TLightRWLock::WaitForUntrappedAndAcquireRead() { + if (AtomicFetchAdd(Counter_, -1) < 0) + goto skip_lock_try; + + for (;;) { + again: + if (Y_UNLIKELY(AtomicFetchAdd(Counter_, 1) >= 0)) { + return; + } else { + if (AtomicFetchAdd(Counter_, -1) >= 0) + goto again; + } + + skip_lock_try: + if (AtomicLoad(UnshareFutex_) && (AtomicLoad(Counter_) & 0x7FFFFFFF) == 0) { + SequenceStore(UnshareFutex_, 0); + FutexWake(UnshareFutex_, 1); + } + + for (;;) { + for (ui32 i = 0; i < SpinCount_; ++i) { + SpinLockPause(); + + if (AtomicLoad(Counter_) >= 0) + goto again; + if (AtomicLoad(TrappedFutex_) == 1) + goto skip_store_trapped; + } + + SequenceStore(TrappedFutex_, 1); + skip_store_trapped: + + if (AtomicLoad(Counter_) < 0) { + FutexWait(TrappedFutex_, 1); + if (AtomicLoad(Counter_) < 0) + goto again; + } else if (AtomicLoad(TrappedFutex_)) { + SequenceStore(TrappedFutex_, 0); + FutexWake(TrappedFutex_, 0x7fffffff); + } + break; + } + } +} + +#endif // _linux_ diff --git a/library/cpp/threading/light_rw_lock/lightrwlock.h b/library/cpp/threading/light_rw_lock/lightrwlock.h new file mode 100644 index 0000000000..931a1817bc --- /dev/null +++ b/library/cpp/threading/light_rw_lock/lightrwlock.h @@ -0,0 +1,220 @@ +#pragma once + +#include <util/system/rwlock.h> +#include <util/system/sanitizers.h> + +#if defined(_linux_) +/* TLightRWLock is optimized for read lock and very fast lock/unlock switching. + Read lock increments counter. + Write lock sets highest bit of counter (makes counter negative). + + Whenever a thread tries to acquire read lock that thread increments + the counter. If the thread gets negative value of the counter right just + after the increment that means write lock was acquired in another thread. + In that case the thread decrements the counter back, wakes one thread on + UnshareFutex, waits on the TrappedFutex and then tries acquire read lock + from the beginning. + If the thread gets positive value of the counter after the increment + then read lock was successfully acquired and + the thread can proceed execution. + + Whenever a thread tries to acquire write lock that thread set the highest bit + of the counter. If the thread determine that the bit was set previously then + write lock was acquired in another thread. In that case the thread waits on + the TrappedFutex and then tries again from the beginning. + If the highest bit was successfully set then thread check if any read lock + exists at the moment. If so the thread waits on UnshareFutex. If there is + no more read locks then write lock was successfully acquired and the thread + can proceed execution. +*/ + +#include <linux/futex.h> +#include <unistd.h> +#include <sys/syscall.h> +#include <errno.h> + +namespace NS_LightRWLock { + static int Y_FORCE_INLINE AtomicFetchAdd(volatile int& item, int value) { + return __atomic_fetch_add(&item, value, __ATOMIC_SEQ_CST); + } + +#if defined(_x86_64_) || defined(_i386_) + + static char Y_FORCE_INLINE AtomicSetBit(volatile int& item, unsigned bit) { + char ret; + __asm__ __volatile__( + "lock bts %2,%0\n" + "setc %1\n" + : "+m"(item), "=rm"(ret) + : "r"(bit) + : "cc"); + + // msan doesn't treat ret as initialized + NSan::Unpoison(&ret, sizeof(ret)); + + return ret; + } + + static char Y_FORCE_INLINE AtomicClearBit(volatile int& item, unsigned bit) { + char ret; + __asm__ __volatile__( + "lock btc %2,%0\n" + "setc %1\n" + : "+m"(item), "=rm"(ret) + : "r"(bit) + : "cc"); + + // msan doesn't treat ret as initialized + NSan::Unpoison(&ret, sizeof(ret)); + + return ret; + } + + +#else + + static char Y_FORCE_INLINE AtomicSetBit(volatile int& item, unsigned bit) { + int prev = __atomic_fetch_or(&item, 1 << bit, __ATOMIC_SEQ_CST); + return (prev & (1 << bit)) != 0 ? 1 : 0; + } + + static char Y_FORCE_INLINE + AtomicClearBit(volatile int& item, unsigned bit) { + int prev = __atomic_fetch_and(&item, ~(1 << bit), __ATOMIC_SEQ_CST); + return (prev & (1 << bit)) != 0 ? 1 : 0; + } +#endif + +#if defined(_x86_64_) || defined(_i386_) || defined (__aarch64__) || defined (__powerpc64__) + static bool AtomicLockHighByte(volatile int& item) { + union TA { + int x; + char y[4]; + }; + + volatile TA* ptr = reinterpret_cast<volatile TA*>(&item); + char zero = 0; + return __atomic_compare_exchange_n(&(ptr->y[3]), &zero, (char)128, true, + __ATOMIC_SEQ_CST, __ATOMIC_RELAXED); + } + +#endif + + template <typename TInt> + static void Y_FORCE_INLINE AtomicStore(volatile TInt& var, TInt value) { + __atomic_store_n(&var, value, __ATOMIC_RELEASE); + } + + template <typename TInt> + static void Y_FORCE_INLINE SequenceStore(volatile TInt& var, TInt value) { + __atomic_store_n(&var, value, __ATOMIC_SEQ_CST); + } + + template <typename TInt> + static TInt Y_FORCE_INLINE AtomicLoad(const volatile TInt& var) { + return __atomic_load_n(&var, __ATOMIC_ACQUIRE); + } + + static void Y_FORCE_INLINE FutexWait(volatile int& fvar, int value) { + for (;;) { + int result = + syscall(SYS_futex, &fvar, FUTEX_WAIT_PRIVATE, value, NULL, NULL, 0); + if (Y_UNLIKELY(result == -1)) { + if (errno == EWOULDBLOCK) + return; + if (errno == EINTR) + continue; + Y_FAIL("futex error"); + } + } + } + + static void Y_FORCE_INLINE FutexWake(volatile int& fvar, int amount) { + const int result = + syscall(SYS_futex, &fvar, FUTEX_WAKE_PRIVATE, amount, NULL, NULL, 0); + if (Y_UNLIKELY(result == -1)) + Y_FAIL("futex error"); + } + +} + +class alignas(64) TLightRWLock { +public: + TLightRWLock(ui32 spinCount = 10) + : Counter_(0) + , TrappedFutex_(0) + , UnshareFutex_(0) + , SpinCount_(spinCount) + { + } + + TLightRWLock(const TLightRWLock&) = delete; + void operator=(const TLightRWLock&) = delete; + + Y_FORCE_INLINE void AcquireWrite() { + using namespace NS_LightRWLock; + + if (AtomicLockHighByte(Counter_)) { + if ((AtomicLoad(Counter_) & 0x7FFFFFFF) == 0) + return; + return WaitForUntrappedShared(); + } + WaitForExclusiveAndUntrappedShared(); + } + + Y_FORCE_INLINE void AcquireRead() { + using namespace NS_LightRWLock; + + if (Y_LIKELY(AtomicFetchAdd(Counter_, 1) >= 0)) + return; + WaitForUntrappedAndAcquireRead(); + } + + Y_FORCE_INLINE void ReleaseWrite() { + using namespace NS_LightRWLock; + + AtomicClearBit(Counter_, 31); + if (AtomicLoad(TrappedFutex_)) { + SequenceStore(TrappedFutex_, 0); + FutexWake(TrappedFutex_, 0x7fffffff); + } + } + + Y_FORCE_INLINE void ReleaseRead() { + using namespace NS_LightRWLock; + + if (Y_LIKELY(AtomicFetchAdd(Counter_, -1) >= 0)) + return; + if (!AtomicLoad(UnshareFutex_)) + return; + if ((AtomicLoad(Counter_) & 0x7fffffff) == 0) { + SequenceStore(UnshareFutex_, 0); + FutexWake(UnshareFutex_, 1); + } + } + +private: + volatile int Counter_; + volatile int TrappedFutex_; + volatile int UnshareFutex_; + const ui32 SpinCount_; + + void WaitForUntrappedShared(); + void WaitForExclusiveAndUntrappedShared(); + void WaitForUntrappedAndAcquireRead(); +}; + +#else + +class TLightRWLock: public TRWMutex { +public: + TLightRWLock() { + } + TLightRWLock(ui32) { + } +}; + +#endif + +using TLightReadGuard = TReadGuardBase<TLightRWLock>; +using TLightWriteGuard = TWriteGuardBase<TLightRWLock>; diff --git a/library/cpp/threading/light_rw_lock/ut/rwlock_ut.cpp b/library/cpp/threading/light_rw_lock/ut/rwlock_ut.cpp new file mode 100644 index 0000000000..e82063d959 --- /dev/null +++ b/library/cpp/threading/light_rw_lock/ut/rwlock_ut.cpp @@ -0,0 +1,122 @@ +#include <library/cpp/threading/light_rw_lock/lightrwlock.h> +#include <library/cpp/testing/unittest/registar.h> +#include <util/random/random.h> +#include <util/system/atomic.h> +#include <util/thread/pool.h> + +class TRWMutexTest: public TTestBase { + UNIT_TEST_SUITE(TRWMutexTest); + UNIT_TEST(TestReaders) + UNIT_TEST(TestReadersWriters) + UNIT_TEST_SUITE_END(); + + struct TSharedData { + TSharedData() + : writersIn(0) + , readersIn(0) + , failed(false) + { + } + + TAtomic writersIn; + TAtomic readersIn; + + bool failed; + + TLightRWLock mutex; + }; + + class TThreadTask: public IObjectInQueue { + public: + using PFunc = void (TThreadTask::*)(void); + + TThreadTask(PFunc func, TSharedData& data, size_t id, size_t total) + : Func_(func) + , Data_(data) + , Id_(id) + , Total_(total) + { + } + + void Process(void*) override { + THolder<TThreadTask> This(this); + + (this->*Func_)(); + } + +#define FAIL_ASSERT(cond) \ + if (!(cond)) { \ + Data_.failed = true; \ + } + void RunReaders() { + Data_.mutex.AcquireRead(); + + AtomicIncrement(Data_.readersIn); + usleep(100); + FAIL_ASSERT(Data_.readersIn == long(Total_)); + usleep(100); + AtomicDecrement(Data_.readersIn); + + Data_.mutex.ReleaseRead(); + } + + void RunReadersWriters() { + if (Id_ % 2 == 0) { + for (size_t i = 0; i < 10; ++i) { + Data_.mutex.AcquireRead(); + + AtomicIncrement(Data_.readersIn); + FAIL_ASSERT(Data_.writersIn == 0); + usleep(RandomNumber<ui32>() % 5); + AtomicDecrement(Data_.readersIn); + + Data_.mutex.ReleaseRead(); + } + } else { + for (size_t i = 0; i < 10; ++i) { + Data_.mutex.AcquireWrite(); + + AtomicIncrement(Data_.writersIn); + FAIL_ASSERT(Data_.readersIn == 0 && Data_.writersIn == 1); + usleep(RandomNumber<ui32>() % 5); + AtomicDecrement(Data_.writersIn); + + Data_.mutex.ReleaseWrite(); + } + } + } +#undef FAIL_ASSERT + + private: + PFunc Func_; + TSharedData& Data_; + size_t Id_; + size_t Total_; + }; + +private: +#define RUN_CYCLE(what, count) \ + Q_.Start(count); \ + for (size_t i = 0; i < count; ++i) { \ + UNIT_ASSERT(Q_.Add(new TThreadTask(&TThreadTask::what, Data_, i, count))); \ + } \ + Q_.Stop(); \ + bool b = Data_.failed; \ + Data_.failed = false; \ + UNIT_ASSERT(!b); + + void TestReaders() { + RUN_CYCLE(RunReaders, 1); + } + + void TestReadersWriters() { + RUN_CYCLE(RunReadersWriters, 1); + } + +#undef RUN_CYCLE +private: + TSharedData Data_; + TThreadPool Q_; +}; + +UNIT_TEST_SUITE_REGISTRATION(TRWMutexTest) diff --git a/library/cpp/threading/light_rw_lock/ut/ya.make b/library/cpp/threading/light_rw_lock/ut/ya.make new file mode 100644 index 0000000000..92928b837c --- /dev/null +++ b/library/cpp/threading/light_rw_lock/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/threading/light_rw_lock) + +OWNER(agri) + +SRCS( + rwlock_ut.cpp +) + +END() diff --git a/library/cpp/threading/light_rw_lock/ya.make b/library/cpp/threading/light_rw_lock/ya.make new file mode 100644 index 0000000000..a196fb8588 --- /dev/null +++ b/library/cpp/threading/light_rw_lock/ya.make @@ -0,0 +1,10 @@ +LIBRARY() + +OWNER(agri) + +SRCS( + lightrwlock.cpp + lightrwlock.h +) + +END() diff --git a/library/cpp/threading/local_executor/README.md b/library/cpp/threading/local_executor/README.md new file mode 100644 index 0000000000..aaad2e2986 --- /dev/null +++ b/library/cpp/threading/local_executor/README.md @@ -0,0 +1,74 @@ +# Library for parallel task execution in thread pool + +This library allows easy parallelization of existing code and cycles. +It provides `NPar::TLocalExecutor` class and `NPar::LocalExecutor()` singleton accessor. +At start, `TLocalExecutor` has no threads in thread pool and all async tasks will be queued for later execution when extra threads appear. +All tasks should be `NPar::ILocallyExecutable` child class or function equal to `std::function<void(int)>` + +## TLocalExecutor methods + +`TLocalExecutor::Run(int threadcount)` - add threads to thread pool (**WARNING!** `Run(threadcount)` will *add* `threadcount` threads to pool) + +`void TLocalExecutor::Exec(TLocallyExecutableFunction exec, int id, int flags)` - run one task and pass id as task function input, flags - bitmask composition of: + +- `TLocalExecutor::HIGH_PRIORITY = 0` - put task in high priority queue +- `TLocalExecutor::MED_PRIORITY = 1` - put task in medium priority queue +- `TLocalExecutor::LOW_PRIORITY = 2` - put task in low priority queue +- `TLocalExecutor::WAIT_COMPLETE = 4` - wait for task completion + +`void TLocalExecutor::ExecRange(TLocallyExecutableFunction exec, TExecRangeParams blockParams, int flags);` - run range of tasks `[TExecRangeParams::FirstId, TExecRangeParams::LastId).` + +`flags` is the same as for `TLocalExecutor::Exec`. + +`TExecRangeParams` is a structure that describes the range. +By default each task is executed separately. Threads from thread pool are taking +the tasks in the manner first come first serve. + +It is also possible to partition range of tasks in consequtive blocks and execute each block as a bigger task. +`TExecRangeParams::SetBlockCountToThreadCount()` will result in thread count tasks, + where thread count is the count of threads in thread pool. + each thread will execute approximately equal count of tasks from range. + +`TExecRangeParams::SetBlockSize()` and `TExecRangeParams::SetBlockCount()` will partition +the range of tasks into consequtive blocks of approximately given size, or of size calculated + by partitioning the range into approximately equal size blocks of given count. + +## Examples + +### Simple task async exec with medium priority + +```cpp +using namespace NPar; + +LocalExecutor().Run(4); +TEvent event; +LocalExecutor().Exec([](int) { + SomeFunc(); + event.Signal(); +}, 0, TLocalExecutor::MED_PRIORITY); + +SomeOtherCode(); +event.WaitI(); +``` + +### Execute task range and wait completion + +```cpp +using namespace NPar; + +LocalExecutor().Run(4); +LocalExecutor().ExecRange([](int id) { + SomeFunc(id); +}, TExecRangeParams(0, 10), TLocalExecutor::WAIT_COMPLETE | TLocalExecutor::MED_PRIORITY); +``` + +### Exception handling + +By default if a not caught exception arise in a task which runs through the Local Executor, then std::terminate() will be called immediately. The exception will be printed to stderr before the termination. Best practice is to handle exception within a task, or avoid throwing exceptions at all for performance reasons. + +However, if you'd like to handle and/or rethrow exceptions outside of a range, you can use ExecRangeWithFuture(). +It returns vector [0 .. LastId-FirstId] elements, where i-th element is a TFuture corresponding to task with id = (FirstId + i). +Use method .HasValue() of the element to check in Async mode if the corresponding task is complete. +Use .GetValue() or .GetValueSync() to wait for completion of the corresponding task. GetValue() and GetValueSync() will also rethrow an exception if it appears during execution of the task. + +You may also use ExecRangeWithThrow() to just receive an exception from a range if it appears. It rethrows an exception from a task with minimal id if such an exception exists, and guarantees normal flow if no exception arise. diff --git a/library/cpp/threading/local_executor/local_executor.cpp b/library/cpp/threading/local_executor/local_executor.cpp new file mode 100644 index 0000000000..1d3fbb4bf4 --- /dev/null +++ b/library/cpp/threading/local_executor/local_executor.cpp @@ -0,0 +1,369 @@ +#include "local_executor.h" + +#include <library/cpp/threading/future/future.h> + +#include <util/generic/utility.h> +#include <util/system/atomic.h> +#include <util/system/event.h> +#include <util/system/thread.h> +#include <util/system/tls.h> +#include <util/system/yield.h> +#include <util/thread/lfqueue.h> + +#include <utility> + +#ifdef _win_ +static void RegularYield() { +} +#else +// unix actually has cooperative multitasking! :) +// without this function program runs slower and system lags for some magic reason +static void RegularYield() { + SchedYield(); +} +#endif + +namespace { + struct TFunctionWrapper : NPar::ILocallyExecutable { + NPar::TLocallyExecutableFunction Exec; + TFunctionWrapper(NPar::TLocallyExecutableFunction exec) + : Exec(std::move(exec)) + { + } + void LocalExec(int id) override { + Exec(id); + } + }; + + class TFunctionWrapperWithPromise: public NPar::ILocallyExecutable { + private: + NPar::TLocallyExecutableFunction Exec; + int FirstId, LastId; + TVector<NThreading::TPromise<void>> Promises; + + public: + TFunctionWrapperWithPromise(NPar::TLocallyExecutableFunction exec, int firstId, int lastId) + : Exec(std::move(exec)) + , FirstId(firstId) + , LastId(lastId) + { + Y_ASSERT(FirstId <= LastId); + const int rangeSize = LastId - FirstId; + Promises.resize(rangeSize, NThreading::NewPromise()); + for (auto& promise : Promises) { + promise = NThreading::NewPromise(); + } + } + + void LocalExec(int id) override { + Y_ASSERT(FirstId <= id && id < LastId); + NThreading::NImpl::SetValue(Promises[id - FirstId], [=] { Exec(id); }); + } + + TVector<NThreading::TFuture<void>> GetFutures() const { + TVector<NThreading::TFuture<void>> out; + out.reserve(Promises.ysize()); + for (auto& promise : Promises) { + out.push_back(promise.GetFuture()); + } + return out; + } + }; + + struct TSingleJob { + TIntrusivePtr<NPar::ILocallyExecutable> Exec; + int Id{0}; + + TSingleJob() = default; + TSingleJob(TIntrusivePtr<NPar::ILocallyExecutable> exec, int id) + : Exec(std::move(exec)) + , Id(id) + { + } + }; + + class TLocalRangeExecutor: public NPar::ILocallyExecutable { + TIntrusivePtr<NPar::ILocallyExecutable> Exec; + alignas(64) TAtomic Counter; + alignas(64) TAtomic WorkerCount; + int LastId; + + void LocalExec(int) override { + AtomicAdd(WorkerCount, 1); + for (;;) { + if (!DoSingleOp()) + break; + } + AtomicAdd(WorkerCount, -1); + } + + public: + TLocalRangeExecutor(TIntrusivePtr<ILocallyExecutable> exec, int firstId, int lastId) + : Exec(std::move(exec)) + , Counter(firstId) + , WorkerCount(0) + , LastId(lastId) + { + } + bool DoSingleOp() { + const int id = AtomicAdd(Counter, 1) - 1; + if (id >= LastId) + return false; + Exec->LocalExec(id); + RegularYield(); + return true; + } + void WaitComplete() { + while (AtomicGet(WorkerCount) > 0) + RegularYield(); + } + int GetRangeSize() const { + return Max<int>(LastId - Counter, 0); + } + }; + +} + +////////////////////////////////////////////////////////////////////////// +class NPar::TLocalExecutor::TImpl { +public: + TLockFreeQueue<TSingleJob> JobQueue; + TLockFreeQueue<TSingleJob> MedJobQueue; + TLockFreeQueue<TSingleJob> LowJobQueue; + alignas(64) TSystemEvent HasJob; + + TAtomic ThreadCount{0}; + alignas(64) TAtomic QueueSize{0}; + TAtomic MPQueueSize{0}; + TAtomic LPQueueSize{0}; + TAtomic ThreadId{0}; + + Y_THREAD(int) + CurrentTaskPriority; + Y_THREAD(int) + WorkerThreadId; + + static void* HostWorkerThread(void* p); + bool GetJob(TSingleJob* job); + void RunNewThread(); + void LaunchRange(TIntrusivePtr<TLocalRangeExecutor> execRange, int queueSizeLimit, + TAtomic* queueSize, TLockFreeQueue<TSingleJob>* jobQueue); + + TImpl() = default; + ~TImpl(); +}; + +NPar::TLocalExecutor::TImpl::~TImpl() { + AtomicAdd(QueueSize, 1); + JobQueue.Enqueue(TSingleJob(nullptr, 0)); + HasJob.Signal(); + while (AtomicGet(ThreadCount)) { + ThreadYield(); + } +} + +void* NPar::TLocalExecutor::TImpl::HostWorkerThread(void* p) { + static const int FAST_ITERATIONS = 200; + + auto* const ctx = (TImpl*)p; + TThread::SetCurrentThreadName("ParLocalExecutor"); + ctx->WorkerThreadId = AtomicAdd(ctx->ThreadId, 1); + for (bool cont = true; cont;) { + TSingleJob job; + bool gotJob = false; + for (int iter = 0; iter < FAST_ITERATIONS; ++iter) { + if (ctx->GetJob(&job)) { + gotJob = true; + break; + } + } + if (!gotJob) { + ctx->HasJob.Reset(); + if (!ctx->GetJob(&job)) { + ctx->HasJob.Wait(); + continue; + } + } + if (job.Exec.Get()) { + job.Exec->LocalExec(job.Id); + RegularYield(); + } else { + AtomicAdd(ctx->QueueSize, 1); + ctx->JobQueue.Enqueue(job); + ctx->HasJob.Signal(); + cont = false; + } + } + AtomicAdd(ctx->ThreadCount, -1); + return nullptr; +} + +bool NPar::TLocalExecutor::TImpl::GetJob(TSingleJob* job) { + if (JobQueue.Dequeue(job)) { + CurrentTaskPriority = TLocalExecutor::HIGH_PRIORITY; + AtomicAdd(QueueSize, -1); + return true; + } else if (MedJobQueue.Dequeue(job)) { + CurrentTaskPriority = TLocalExecutor::MED_PRIORITY; + AtomicAdd(MPQueueSize, -1); + return true; + } else if (LowJobQueue.Dequeue(job)) { + CurrentTaskPriority = TLocalExecutor::LOW_PRIORITY; + AtomicAdd(LPQueueSize, -1); + return true; + } + return false; +} + +void NPar::TLocalExecutor::TImpl::RunNewThread() { + AtomicAdd(ThreadCount, 1); + TThread thr(HostWorkerThread, this); + thr.Start(); + thr.Detach(); +} + +void NPar::TLocalExecutor::TImpl::LaunchRange(TIntrusivePtr<TLocalRangeExecutor> rangeExec, + int queueSizeLimit, + TAtomic* queueSize, + TLockFreeQueue<TSingleJob>* jobQueue) { + int count = Min<int>(ThreadCount + 1, rangeExec->GetRangeSize()); + if (queueSizeLimit >= 0 && AtomicGet(*queueSize) >= queueSizeLimit) { + return; + } + AtomicAdd(*queueSize, count); + jobQueue->EnqueueAll(TVector<TSingleJob>{size_t(count), TSingleJob(rangeExec, 0)}); + HasJob.Signal(); +} + +NPar::TLocalExecutor::TLocalExecutor() + : Impl_{MakeHolder<TImpl>()} { +} + +NPar::TLocalExecutor::~TLocalExecutor() = default; + +void NPar::TLocalExecutor::RunAdditionalThreads(int threadCount) { + for (int i = 0; i < threadCount; i++) + Impl_->RunNewThread(); +} + +void NPar::TLocalExecutor::Exec(TIntrusivePtr<ILocallyExecutable> exec, int id, int flags) { + Y_ASSERT((flags & WAIT_COMPLETE) == 0); // unsupported + int prior = Max<int>(Impl_->CurrentTaskPriority, flags & PRIORITY_MASK); + switch (prior) { + case HIGH_PRIORITY: + AtomicAdd(Impl_->QueueSize, 1); + Impl_->JobQueue.Enqueue(TSingleJob(std::move(exec), id)); + break; + case MED_PRIORITY: + AtomicAdd(Impl_->MPQueueSize, 1); + Impl_->MedJobQueue.Enqueue(TSingleJob(std::move(exec), id)); + break; + case LOW_PRIORITY: + AtomicAdd(Impl_->LPQueueSize, 1); + Impl_->LowJobQueue.Enqueue(TSingleJob(std::move(exec), id)); + break; + default: + Y_ASSERT(0); + break; + } + Impl_->HasJob.Signal(); +} + +void NPar::ILocalExecutor::Exec(TLocallyExecutableFunction exec, int id, int flags) { + Exec(new TFunctionWrapper(std::move(exec)), id, flags); +} + +void NPar::TLocalExecutor::ExecRange(TIntrusivePtr<ILocallyExecutable> exec, int firstId, int lastId, int flags) { + Y_ASSERT(lastId >= firstId); + if (TryExecRangeSequentially([=] (int id) { exec->LocalExec(id); }, firstId, lastId, flags)) { + return; + } + auto rangeExec = MakeIntrusive<TLocalRangeExecutor>(std::move(exec), firstId, lastId); + int queueSizeLimit = (flags & WAIT_COMPLETE) ? 10000 : -1; + int prior = Max<int>(Impl_->CurrentTaskPriority, flags & PRIORITY_MASK); + switch (prior) { + case HIGH_PRIORITY: + Impl_->LaunchRange(rangeExec, queueSizeLimit, &Impl_->QueueSize, &Impl_->JobQueue); + break; + case MED_PRIORITY: + Impl_->LaunchRange(rangeExec, queueSizeLimit, &Impl_->MPQueueSize, &Impl_->MedJobQueue); + break; + case LOW_PRIORITY: + Impl_->LaunchRange(rangeExec, queueSizeLimit, &Impl_->LPQueueSize, &Impl_->LowJobQueue); + break; + default: + Y_ASSERT(0); + break; + } + if (flags & WAIT_COMPLETE) { + int keepPrior = Impl_->CurrentTaskPriority; + Impl_->CurrentTaskPriority = prior; + while (rangeExec->DoSingleOp()) { + } + Impl_->CurrentTaskPriority = keepPrior; + rangeExec->WaitComplete(); + } +} + +void NPar::ILocalExecutor::ExecRange(TLocallyExecutableFunction exec, int firstId, int lastId, int flags) { + if (TryExecRangeSequentially(exec, firstId, lastId, flags)) { + return; + } + ExecRange(new TFunctionWrapper(exec), firstId, lastId, flags); +} + +void NPar::ILocalExecutor::ExecRangeWithThrow(TLocallyExecutableFunction exec, int firstId, int lastId, int flags) { + Y_VERIFY((flags & WAIT_COMPLETE) != 0, "ExecRangeWithThrow() requires WAIT_COMPLETE to wait if exceptions arise."); + if (TryExecRangeSequentially(exec, firstId, lastId, flags)) { + return; + } + TVector<NThreading::TFuture<void>> currentRun = ExecRangeWithFutures(exec, firstId, lastId, flags); + for (auto& result : currentRun) { + result.GetValueSync(); // Exception will be rethrown if exists. If several exception - only the one with minimal id is rethrown. + } +} + +TVector<NThreading::TFuture<void>> +NPar::ILocalExecutor::ExecRangeWithFutures(TLocallyExecutableFunction exec, int firstId, int lastId, int flags) { + TFunctionWrapperWithPromise* execWrapper = new TFunctionWrapperWithPromise(exec, firstId, lastId); + TVector<NThreading::TFuture<void>> out = execWrapper->GetFutures(); + ExecRange(execWrapper, firstId, lastId, flags); + return out; +} + +void NPar::TLocalExecutor::ClearLPQueue() { + for (bool cont = true; cont;) { + cont = false; + TSingleJob job; + while (Impl_->LowJobQueue.Dequeue(&job)) { + AtomicAdd(Impl_->LPQueueSize, -1); + cont = true; + } + while (Impl_->MedJobQueue.Dequeue(&job)) { + AtomicAdd(Impl_->MPQueueSize, -1); + cont = true; + } + } +} + +int NPar::TLocalExecutor::GetQueueSize() const noexcept { + return AtomicGet(Impl_->QueueSize); +} + +int NPar::TLocalExecutor::GetMPQueueSize() const noexcept { + return AtomicGet(Impl_->MPQueueSize); +} + +int NPar::TLocalExecutor::GetLPQueueSize() const noexcept { + return AtomicGet(Impl_->LPQueueSize); +} + +int NPar::TLocalExecutor::GetWorkerThreadId() const noexcept { + return Impl_->WorkerThreadId; +} + +int NPar::TLocalExecutor::GetThreadCount() const noexcept { + return AtomicGet(Impl_->ThreadCount); +} + +////////////////////////////////////////////////////////////////////////// diff --git a/library/cpp/threading/local_executor/local_executor.h b/library/cpp/threading/local_executor/local_executor.h new file mode 100644 index 0000000000..c1c824f67c --- /dev/null +++ b/library/cpp/threading/local_executor/local_executor.h @@ -0,0 +1,294 @@ +#pragma once + +#include <library/cpp/threading/future/future.h> + +#include <util/generic/cast.h> +#include <util/generic/fwd.h> +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> +#include <util/generic/singleton.h> +#include <util/generic/ymath.h> + +#include <functional> + +namespace NPar { + struct ILocallyExecutable : virtual public TThrRefBase { + // Must be implemented by the end user to define job that will be processed by one of + // executor threads. + // + // @param id Job parameter, typically an index pointing somewhere in array, or just + // some dummy value, e.g. `0`. + virtual void LocalExec(int id) = 0; + }; + + // Alternative and simpler way of describing a job for executor. Function argument has the + // same meaning as `id` in `ILocallyExecutable::LocalExec`. + // + using TLocallyExecutableFunction = std::function<void(int)>; + + class ILocalExecutor: public TNonCopyable { + public: + ILocalExecutor() = default; + virtual ~ILocalExecutor() = default; + + enum EFlags : int { + HIGH_PRIORITY = 0, + MED_PRIORITY = 1, + LOW_PRIORITY = 2, + PRIORITY_MASK = 3, + WAIT_COMPLETE = 4 + }; + + // Add task for further execution. + // + // @param exec Task description. + // @param id Task argument. + // @param flags Bitmask composed by `HIGH_PRIORITY`, `MED_PRIORITY`, `LOW_PRIORITY` + // and `WAIT_COMPLETE`. + virtual void Exec(TIntrusivePtr<ILocallyExecutable> exec, int id, int flags) = 0; + + // Add tasks range for further execution. + // + // @param exec Task description. + // @param firstId, lastId Task arguments [firstId, lastId) + // @param flags Same as for `Exec`. + virtual void ExecRange(TIntrusivePtr<ILocallyExecutable> exec, int firstId, int lastId, int flags) = 0; + + // 0-based ILocalExecutor worker thread identification + virtual int GetWorkerThreadId() const noexcept = 0; + virtual int GetThreadCount() const noexcept = 0; + + // Describes a range of tasks with parameters from integer range [FirstId, LastId). + // + class TExecRangeParams { + public: + template <typename TFirst, typename TLast> + TExecRangeParams(TFirst firstId, TLast lastId) + : FirstId(SafeIntegerCast<int>(firstId)) + , LastId(SafeIntegerCast<int>(lastId)) + { + Y_ASSERT(LastId >= FirstId); + SetBlockSize(1); + } + // Partition tasks into `blockCount` blocks of approximately equal size, each of which + // will be executed as a separate bigger task. + // + template <typename TBlockCount> + TExecRangeParams& SetBlockCount(TBlockCount blockCount) { + Y_ASSERT(SafeIntegerCast<int>(blockCount) > 0 || FirstId == LastId); + BlockSize = FirstId == LastId ? 0 : CeilDiv(LastId - FirstId, SafeIntegerCast<int>(blockCount)); + BlockCount = BlockSize == 0 ? 0 : CeilDiv(LastId - FirstId, BlockSize); + BlockEqualToThreads = false; + return *this; + } + // Partition tasks into blocks of approximately `blockSize` size, each of which will + // be executed as a separate bigger task. + // + template <typename TBlockSize> + TExecRangeParams& SetBlockSize(TBlockSize blockSize) { + Y_ASSERT(SafeIntegerCast<int>(blockSize) > 0 || FirstId == LastId); + BlockSize = SafeIntegerCast<int>(blockSize); + BlockCount = BlockSize == 0 ? 0 : CeilDiv(LastId - FirstId, BlockSize); + BlockEqualToThreads = false; + return *this; + } + // Partition tasks into thread count blocks of approximately equal size, each of which + // will be executed as a separate bigger task. + // + TExecRangeParams& SetBlockCountToThreadCount() { + BlockEqualToThreads = true; + return *this; + } + int GetBlockCount() const { + Y_ASSERT(!BlockEqualToThreads); + return BlockCount; + } + int GetBlockSize() const { + Y_ASSERT(!BlockEqualToThreads); + return BlockSize; + } + bool GetBlockEqualToThreads() { + return BlockEqualToThreads; + } + + const int FirstId = 0; + const int LastId = 0; + + private: + int BlockSize; + int BlockCount; + bool BlockEqualToThreads; + }; + + // `Exec` and `ExecRange` versions that accept functions. + // + void Exec(TLocallyExecutableFunction exec, int id, int flags); + void ExecRange(TLocallyExecutableFunction exec, int firstId, int lastId, int flags); + + // Version of `ExecRange` that throws exception from task with minimal id if at least one of + // task threw an exception. + // + void ExecRangeWithThrow(TLocallyExecutableFunction exec, int firstId, int lastId, int flags); + + // Version of `ExecRange` that returns vector of futures, thus allowing to retry any task if + // it fails. + // + TVector<NThreading::TFuture<void>> ExecRangeWithFutures(TLocallyExecutableFunction exec, int firstId, int lastId, int flags); + + template <typename TBody> + static inline auto BlockedLoopBody(const TExecRangeParams& params, const TBody& body) { + return [=](int blockId) { + const int blockFirstId = params.FirstId + blockId * params.GetBlockSize(); + const int blockLastId = Min(params.LastId, blockFirstId + params.GetBlockSize()); + for (int i = blockFirstId; i < blockLastId; ++i) { + body(i); + } + }; + } + + template <typename TBody> + inline void ExecRange(TBody&& body, TExecRangeParams params, int flags) { + if (TryExecRangeSequentially(body, params.FirstId, params.LastId, flags)) { + return; + } + if (params.GetBlockEqualToThreads()) { + params.SetBlockCount(GetThreadCount() + ((flags & WAIT_COMPLETE) != 0)); // ThreadCount or ThreadCount+1 depending on WaitFlag + } + ExecRange(BlockedLoopBody(params, body), 0, params.GetBlockCount(), flags); + } + + template <typename TBody> + inline void ExecRangeBlockedWithThrow(TBody&& body, int firstId, int lastId, int batchSizeOrZeroForAutoBatchSize, int flags) { + if (firstId >= lastId) { + return; + } + const int threadCount = Max(GetThreadCount(), 1); + const int batchSize = batchSizeOrZeroForAutoBatchSize + ? batchSizeOrZeroForAutoBatchSize + : (lastId - firstId + threadCount - 1) / threadCount; + const int batchCount = (lastId - firstId + batchSize - 1) / batchSize; + const int batchCountPerThread = (batchCount + threadCount - 1) / threadCount; + auto states = ExecRangeWithFutures( + [=](int threadId) { + for (int batchIdPerThread = 0; batchIdPerThread < batchCountPerThread; ++batchIdPerThread) { + int batchId = batchIdPerThread * threadCount + threadId; + int begin = firstId + batchId * batchSize; + int end = Min(begin + batchSize, lastId); + for (int i = begin; i < end; ++i) { + body(i); + } + } + }, + 0, threadCount, flags); + for (auto& state: states) { + state.GetValueSync(); // Re-throw exception if any. + } + } + + template <typename TBody> + static inline bool TryExecRangeSequentially(TBody&& body, int firstId, int lastId, int flags) { + if (lastId == firstId) { + return true; + } + if ((flags & WAIT_COMPLETE) && lastId - firstId == 1) { + body(firstId); + return true; + } + return false; + } + }; + + // `TLocalExecutor` provides facilities for easy parallelization of existing code and cycles. + // + // Examples: + // Execute one task with medium priority and wait for it completion. + // ``` + // LocalExecutor().Run(4); + // TEvent event; + // LocalExecutor().Exec([](int) { + // SomeFunc(); + // event.Signal(); + // }, 0, TLocalExecutor::MED_PRIORITY); + // + // SomeOtherCode(); + // event.WaitI(); + // ``` + // + // Execute range of tasks with medium priority. + // ``` + // LocalExecutor().Run(4); + // LocalExecutor().ExecRange([](int id) { + // SomeFunc(id); + // }, TExecRangeParams(0, 10), TLocalExecutor::WAIT_COMPLETE | TLocalExecutor::MED_PRIORITY); + // ``` + // + class TLocalExecutor final: public ILocalExecutor { + public: + using EFlags = ILocalExecutor::EFlags; + + // Creates executor without threads. You'll need to explicitly call `RunAdditionalThreads` + // to add threads to underlying thread pool. + // + TLocalExecutor(); + ~TLocalExecutor(); + + int GetQueueSize() const noexcept; + int GetMPQueueSize() const noexcept; + int GetLPQueueSize() const noexcept; + void ClearLPQueue(); + + // 0-based TLocalExecutor worker thread identification + int GetWorkerThreadId() const noexcept override; + int GetThreadCount() const noexcept override; + + // **Add** threads to underlying thread pool. + // + // @param threadCount Number of threads to add. + void RunAdditionalThreads(int threadCount); + + // Add task for further execution. + // + // @param exec Task description. + // @param id Task argument. + // @param flags Bitmask composed by `HIGH_PRIORITY`, `MED_PRIORITY`, `LOW_PRIORITY` + // and `WAIT_COMPLETE`. + void Exec(TIntrusivePtr<ILocallyExecutable> exec, int id, int flags) override; + + // Add tasks range for further execution. + // + // @param exec Task description. + // @param firstId, lastId Task arguments [firstId, lastId) + // @param flags Same as for `Exec`. + void ExecRange(TIntrusivePtr<ILocallyExecutable> exec, int firstId, int lastId, int flags) override; + + using ILocalExecutor::Exec; + using ILocalExecutor::ExecRange; + + private: + class TImpl; + THolder<TImpl> Impl_; + }; + + static inline TLocalExecutor& LocalExecutor() { + return *Singleton<TLocalExecutor>(); + } + + template <typename TBody> + inline void ParallelFor(ILocalExecutor& executor, ui32 from, ui32 to, TBody&& body) { + ILocalExecutor::TExecRangeParams params(from, to); + params.SetBlockCountToThreadCount(); + executor.ExecRange(std::forward<TBody>(body), params, TLocalExecutor::WAIT_COMPLETE); + } + + template <typename TBody> + inline void ParallelFor(ui32 from, ui32 to, TBody&& body) { + ParallelFor(LocalExecutor(), from, to, std::forward<TBody>(body)); + } + + template <typename TBody> + inline void AsyncParallelFor(ui32 from, ui32 to, TBody&& body) { + ILocalExecutor::TExecRangeParams params(from, to); + params.SetBlockCountToThreadCount(); + LocalExecutor().ExecRange(std::forward<TBody>(body), params, 0); + } +} diff --git a/library/cpp/threading/local_executor/tbb_local_executor.cpp b/library/cpp/threading/local_executor/tbb_local_executor.cpp new file mode 100644 index 0000000000..65d6659443 --- /dev/null +++ b/library/cpp/threading/local_executor/tbb_local_executor.cpp @@ -0,0 +1,53 @@ +#include "tbb_local_executor.h" + +template <bool RespectTls> +void NPar::TTbbLocalExecutor<RespectTls>::SubmitAsyncTasks(TLocallyExecutableFunction exec, int firstId, int lastId) { + for (int i = firstId; i < lastId; ++i) { + Group.run([=] { exec(i); }); + } +} + +template <bool RespectTls> +int NPar::TTbbLocalExecutor<RespectTls>::GetThreadCount() const noexcept { + return NumberOfTbbThreads - 1; +} + +template <bool RespectTls> +int NPar::TTbbLocalExecutor<RespectTls>::GetWorkerThreadId() const noexcept { + return TbbArena.execute([] { + return tbb::this_task_arena::current_thread_index(); + }); +} + +template <bool RespectTls> +void NPar::TTbbLocalExecutor<RespectTls>::Exec(TIntrusivePtr<ILocallyExecutable> exec, int id, int flags) { + if (flags & WAIT_COMPLETE) { + exec->LocalExec(id); + } else { + TbbArena.execute([=] { + SubmitAsyncTasks([=] (int id) { exec->LocalExec(id); }, id, id + 1); + }); + } +} + +template <bool RespectTls> +void NPar::TTbbLocalExecutor<RespectTls>::ExecRange(TIntrusivePtr<ILocallyExecutable> exec, int firstId, int lastId, int flags) { + if (flags & WAIT_COMPLETE) { + TbbArena.execute([=] { + if (RespectTls) { + tbb::this_task_arena::isolate([=]{ + tbb::parallel_for(firstId, lastId, [=] (int id) { exec->LocalExec(id); }); + }); + } else { + tbb::parallel_for(firstId, lastId, [=] (int id) { exec->LocalExec(id); }); + } + }); + } else { + TbbArena.execute([=] { + SubmitAsyncTasks([=] (int id) { exec->LocalExec(id); }, firstId, lastId); + }); + } +} + +template class NPar::TTbbLocalExecutor<true>; +template class NPar::TTbbLocalExecutor<false>; diff --git a/library/cpp/threading/local_executor/tbb_local_executor.h b/library/cpp/threading/local_executor/tbb_local_executor.h new file mode 100644 index 0000000000..8d790db18c --- /dev/null +++ b/library/cpp/threading/local_executor/tbb_local_executor.h @@ -0,0 +1,49 @@ +#pragma once + +#include "local_executor.h" +#define __TBB_TASK_ISOLATION 1 +#define __TBB_NO_IMPLICIT_LINKAGE 1 + +#include <contrib/libs/tbb/include/tbb/blocked_range.h> +#include <contrib/libs/tbb/include/tbb/parallel_for.h> +#include <contrib/libs/tbb/include/tbb/task_arena.h> +#include <contrib/libs/tbb/include/tbb/task_group.h> + +namespace NPar { + template <bool RespectTls = false> + class TTbbLocalExecutor final: public ILocalExecutor { + public: + TTbbLocalExecutor(int nThreads) + : ILocalExecutor() + , TbbArena(nThreads) + , NumberOfTbbThreads(nThreads) {} + ~TTbbLocalExecutor() noexcept override {} + + // 0-based ILocalExecutor worker thread identification + virtual int GetWorkerThreadId() const noexcept override; + virtual int GetThreadCount() const noexcept override; + + // Add task for further execution. + // + // @param exec Task description. + // @param id Task argument. + // @param flags Bitmask composed by `HIGH_PRIORITY`, `MED_PRIORITY`, `LOW_PRIORITY` + // and `WAIT_COMPLETE`. + virtual void Exec(TIntrusivePtr<ILocallyExecutable> exec, int id, int flags) override; + + // Add tasks range for further execution. + // + // @param exec Task description. + // @param firstId, lastId Task arguments [firstId, lastId) + // @param flags Same as for `Exec`. + virtual void ExecRange(TIntrusivePtr<ILocallyExecutable> exec, int firstId, int lastId, int flags) override; + + // Submit tasks for async run + void SubmitAsyncTasks(TLocallyExecutableFunction exec, int firstId, int lastId); + + private: + mutable tbb::task_arena TbbArena; + tbb::task_group Group; + int NumberOfTbbThreads; + }; +} diff --git a/library/cpp/threading/local_executor/ut/local_executor_ut.cpp b/library/cpp/threading/local_executor/ut/local_executor_ut.cpp new file mode 100644 index 0000000000..ac5737717c --- /dev/null +++ b/library/cpp/threading/local_executor/ut/local_executor_ut.cpp @@ -0,0 +1,371 @@ +#include <library/cpp/threading/local_executor/local_executor.h> +#include <library/cpp/threading/future/future.h> + +#include <library/cpp/testing/unittest/registar.h> +#include <util/system/mutex.h> +#include <util/system/rwlock.h> +#include <util/generic/algorithm.h> + +using namespace NPar; + +class TTestException: public yexception { +}; + +static const int DefaultThreadsCount = 41; +static const int DefaultRangeSize = 999; + +Y_UNIT_TEST_SUITE(ExecRangeWithFutures){ + bool AllOf(const TVector<int>& vec, int value){ + return AllOf(vec, [value](int element) { return value == element; }); +} + +void AsyncRunAndWaitFuturesReady(int rangeSize, int threads) { + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(threads); + TAtomic signal = 0; + TVector<int> data(rangeSize, 0); + TVector<NThreading::TFuture<void>> futures = localExecutor.ExecRangeWithFutures([&signal, &data](int i) { + UNIT_ASSERT(data[i] == 0); + while (AtomicGet(signal) == 0) + ; + data[i] += 1; + }, + 0, rangeSize, TLocalExecutor::HIGH_PRIORITY); + UNIT_ASSERT(AllOf(data, 0)); + for (auto& future : futures) + UNIT_ASSERT(!future.HasValue()); + AtomicSet(signal, 1); + for (auto& future : futures) { + future.GetValueSync(); + } + UNIT_ASSERT(AllOf(data, 1)); +} + +Y_UNIT_TEST(AsyncRunRangeAndWaitFuturesReady) { + AsyncRunAndWaitFuturesReady(DefaultRangeSize, DefaultThreadsCount); +} + +Y_UNIT_TEST(AsyncRunOneTaskAndWaitFuturesReady) { + AsyncRunAndWaitFuturesReady(1, DefaultThreadsCount); +} + +Y_UNIT_TEST(AsyncRunRangeAndWaitFuturesReadyOneExtraThread) { + AsyncRunAndWaitFuturesReady(DefaultRangeSize, 1); +} + +Y_UNIT_TEST(AsyncRunOneThreadAndWaitFuturesReadyOneExtraThread) { + AsyncRunAndWaitFuturesReady(1, 1); +} + +Y_UNIT_TEST(AsyncRunTwoRangesAndWaitFuturesReady) { + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(DefaultThreadsCount); + TAtomic signal = 0; + TVector<int> data1(DefaultRangeSize, 0); + TVector<NThreading::TFuture<void>> futures1 = localExecutor.ExecRangeWithFutures([&signal, &data1](int i) { + UNIT_ASSERT(data1[i] == 0); + while (AtomicGet(signal) == 0) + ; + data1[i] += 1; + }, + 0, DefaultRangeSize, TLocalExecutor::HIGH_PRIORITY); + TVector<int> data2(DefaultRangeSize, 0); + TVector<NThreading::TFuture<void>> futures2 = localExecutor.ExecRangeWithFutures([&signal, &data2](int i) { + UNIT_ASSERT(data2[i] == 0); + while (AtomicGet(signal) == 0) + ; + data2[i] += 2; + }, + 0, DefaultRangeSize, TLocalExecutor::HIGH_PRIORITY); + UNIT_ASSERT(AllOf(data1, 0)); + UNIT_ASSERT(AllOf(data2, 0)); + AtomicSet(signal, 1); + for (int i = 0; i < DefaultRangeSize; ++i) { + futures1[i].GetValueSync(); + futures2[i].GetValueSync(); + } + UNIT_ASSERT(AllOf(data1, 1)); + UNIT_ASSERT(AllOf(data2, 2)); +} + +void AsyncRunRangeAndWaitExceptions(int rangeSize, int threadsCount) { + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(threadsCount); + TAtomic signal = 0; + TVector<int> data(rangeSize, 0); + TVector<NThreading::TFuture<void>> futures = localExecutor.ExecRangeWithFutures([&signal, &data](int i) { + UNIT_ASSERT(data[i] == 0); + while (AtomicGet(signal) == 0) + ; + data[i] += 1; + throw 10000 + i; + }, + 0, rangeSize, TLocalExecutor::HIGH_PRIORITY); + UNIT_ASSERT(AllOf(data, 0)); + UNIT_ASSERT(futures.ysize() == rangeSize); + AtomicSet(signal, 1); + int exceptionsCaught = 0; + for (int i = 0; i < rangeSize; ++i) { + try { + futures[i].GetValueSync(); + } catch (int& e) { + if (e == 10000 + i) { + ++exceptionsCaught; + } + } + } + UNIT_ASSERT(exceptionsCaught == rangeSize); + UNIT_ASSERT(AllOf(data, 1)); +} + +Y_UNIT_TEST(AsyncRunRangeAndWaitExceptions) { + AsyncRunRangeAndWaitExceptions(DefaultRangeSize, DefaultThreadsCount); +} + +Y_UNIT_TEST(AsyncRunOneTaskAndWaitExceptions) { + AsyncRunRangeAndWaitExceptions(1, DefaultThreadsCount); +} + +Y_UNIT_TEST(AsyncRunRangeAndWaitExceptionsOneExtraThread) { + AsyncRunRangeAndWaitExceptions(DefaultRangeSize, 1); +} + +Y_UNIT_TEST(AsyncRunOneTaskAndWaitExceptionsOneExtraThread) { + AsyncRunRangeAndWaitExceptions(1, 1); +} + +Y_UNIT_TEST(AsyncRunTwoRangesAndWaitExceptions) { + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(DefaultThreadsCount); + TAtomic signal = 0; + TVector<int> data1(DefaultRangeSize, 0); + TVector<NThreading::TFuture<void>> futures1 = localExecutor.ExecRangeWithFutures([&signal, &data1](int i) { + UNIT_ASSERT(data1[i] == 0); + while (AtomicGet(signal) == 0) + ; + data1[i] += 1; + throw 15000 + i; + }, + 0, DefaultRangeSize, TLocalExecutor::LOW_PRIORITY); + TVector<int> data2(DefaultRangeSize, 0); + TVector<NThreading::TFuture<void>> futures2 = localExecutor.ExecRangeWithFutures([&signal, &data2](int i) { + UNIT_ASSERT(data2[i] == 0); + while (AtomicGet(signal) == 0) + ; + data2[i] += 2; + throw 16000 + i; + }, + 0, DefaultRangeSize, TLocalExecutor::HIGH_PRIORITY); + + UNIT_ASSERT(AllOf(data1, 0)); + UNIT_ASSERT(AllOf(data2, 0)); + UNIT_ASSERT(futures1.size() == DefaultRangeSize); + UNIT_ASSERT(futures2.size() == DefaultRangeSize); + AtomicSet(signal, 1); + int exceptionsCaught = 0; + for (int i = 0; i < DefaultRangeSize; ++i) { + try { + futures1[i].GetValueSync(); + } catch (int& e) { + if (e == 15000 + i) { + ++exceptionsCaught; + } + } + try { + futures2[i].GetValueSync(); + } catch (int& e) { + if (e == 16000 + i) { + ++exceptionsCaught; + } + } + } + UNIT_ASSERT(exceptionsCaught == 2 * DefaultRangeSize); + UNIT_ASSERT(AllOf(data1, 1)); + UNIT_ASSERT(AllOf(data2, 2)); +} + +void RunRangeAndCheckExceptionsWithWaitComplete(int rangeSize, int threadsCount) { + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(threadsCount); + TVector<int> data(rangeSize, 0); + TVector<NThreading::TFuture<void>> futures = localExecutor.ExecRangeWithFutures([&data](int i) { + UNIT_ASSERT(data[i] == 0); + data[i] += 1; + throw 30000 + i; + }, + 0, rangeSize, TLocalExecutor::EFlags::WAIT_COMPLETE); + UNIT_ASSERT(AllOf(data, 1)); + int exceptionsCaught = 0; + for (int i = 0; i < rangeSize; ++i) { + try { + futures[i].GetValueSync(); + } catch (int& e) { + if (e == 30000 + i) { + ++exceptionsCaught; + } + } + } + UNIT_ASSERT(exceptionsCaught == rangeSize); + UNIT_ASSERT(AllOf(data, 1)); +} + +Y_UNIT_TEST(RunRangeAndCheckExceptionsWithWaitComplete) { + RunRangeAndCheckExceptionsWithWaitComplete(DefaultRangeSize, DefaultThreadsCount); +} + +Y_UNIT_TEST(RunOneAndCheckExceptionsWithWaitComplete) { + RunRangeAndCheckExceptionsWithWaitComplete(1, DefaultThreadsCount); +} + +Y_UNIT_TEST(RunRangeAndCheckExceptionsWithWaitCompleteOneExtraThread) { + RunRangeAndCheckExceptionsWithWaitComplete(DefaultRangeSize, 1); +} + +Y_UNIT_TEST(RunOneAndCheckExceptionsWithWaitCompleteOneExtraThread) { + RunRangeAndCheckExceptionsWithWaitComplete(1, 1); +} + +Y_UNIT_TEST(RunRangeAndCheckExceptionsWithWaitCompleteZeroExtraThreads) { + RunRangeAndCheckExceptionsWithWaitComplete(DefaultRangeSize, 0); +} + +Y_UNIT_TEST(RunOneAndCheckExceptionsWithWaitCompleteZeroExtraThreads) { + RunRangeAndCheckExceptionsWithWaitComplete(1, 0); +} +} +; + +Y_UNIT_TEST_SUITE(ExecRangeWithThrow){ + void RunParallelWhichThrowsTTestException(int rangeStart, int rangeSize, int threadsCount, int flags, TAtomic& processed){ + AtomicSet(processed, 0); +TLocalExecutor localExecutor; +localExecutor.RunAdditionalThreads(threadsCount); +localExecutor.ExecRangeWithThrow([&processed](int) { + AtomicAdd(processed, 1); + throw TTestException(); +}, + rangeStart, rangeStart + rangeSize, flags); +} + +Y_UNIT_TEST(RunParallelWhichThrowsTTestException) { + TAtomic processed = 0; + UNIT_ASSERT_EXCEPTION( + RunParallelWhichThrowsTTestException(10, 40, DefaultThreadsCount, + TLocalExecutor::EFlags::WAIT_COMPLETE, processed), + TTestException); + UNIT_ASSERT(AtomicGet(processed) == 40); +} + +void ThrowAndCatchTTestException(int rangeSize, int threadsCount, int flags) { + TAtomic processed = 0; + UNIT_ASSERT_EXCEPTION( + RunParallelWhichThrowsTTestException(0, rangeSize, threadsCount, flags, processed), + TTestException); + UNIT_ASSERT(AtomicGet(processed) == rangeSize); +} + +Y_UNIT_TEST(ThrowAndCatchTTestExceptionLowPriority) { + ThrowAndCatchTTestException(DefaultRangeSize, DefaultThreadsCount, + TLocalExecutor::EFlags::WAIT_COMPLETE | TLocalExecutor::EFlags::LOW_PRIORITY); +} + +Y_UNIT_TEST(ThrowAndCatchTTestExceptionMedPriority) { + ThrowAndCatchTTestException(DefaultRangeSize, DefaultThreadsCount, + TLocalExecutor::EFlags::WAIT_COMPLETE | TLocalExecutor::EFlags::MED_PRIORITY); +} + +Y_UNIT_TEST(ThrowAndCatchTTestExceptionHighPriority) { + ThrowAndCatchTTestException(DefaultRangeSize, DefaultThreadsCount, + TLocalExecutor::EFlags::WAIT_COMPLETE | TLocalExecutor::EFlags::HIGH_PRIORITY); +} + +Y_UNIT_TEST(ThrowAndCatchTTestExceptionWaitComplete) { + ThrowAndCatchTTestException(DefaultRangeSize, DefaultThreadsCount, + TLocalExecutor::EFlags::WAIT_COMPLETE); +} + +Y_UNIT_TEST(RethrowExeptionSequentialWaitComplete) { + ThrowAndCatchTTestException(DefaultRangeSize, 0, + TLocalExecutor::EFlags::WAIT_COMPLETE); +} + +Y_UNIT_TEST(RethrowExeptionOneExtraThreadWaitComplete) { + ThrowAndCatchTTestException(DefaultRangeSize, 1, + TLocalExecutor::EFlags::WAIT_COMPLETE); +} + +void ThrowsTTestExceptionFromNested(TLocalExecutor& localExecutor) { + localExecutor.ExecRangeWithThrow([](int) { + throw TTestException(); + }, + 0, 10, TLocalExecutor::EFlags::WAIT_COMPLETE); +} + +void CatchTTestExceptionFromNested(TAtomic& processed1, TAtomic& processed2) { + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(DefaultThreadsCount); + localExecutor.ExecRangeWithThrow([&processed1, &processed2, &localExecutor](int) { + AtomicAdd(processed1, 1); + UNIT_ASSERT_EXCEPTION( + ThrowsTTestExceptionFromNested(localExecutor), + TTestException); + AtomicAdd(processed2, 1); + }, + 0, DefaultRangeSize, TLocalExecutor::EFlags::WAIT_COMPLETE); +} + +Y_UNIT_TEST(NestedParallelExceptionsDoNotLeak) { + TAtomic processed1 = 0; + TAtomic processed2 = 0; + UNIT_ASSERT_NO_EXCEPTION( + CatchTTestExceptionFromNested(processed1, processed2)); + UNIT_ASSERT_EQUAL(AtomicGet(processed1), DefaultRangeSize); + UNIT_ASSERT_EQUAL(AtomicGet(processed2), DefaultRangeSize); +} +} +; + +Y_UNIT_TEST_SUITE(ExecLargeRangeWithThrow){ + + constexpr int LARGE_COUNT = 128 * (1 << 20); + + static auto IsValue(char v) { + return [=](char c) { return c == v; }; + } + + Y_UNIT_TEST(ExecLargeRangeNoExceptions) { + TVector<char> tasks(LARGE_COUNT); + + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(DefaultThreadsCount); + + localExecutor.ExecRangeBlockedWithThrow([&tasks](int i) { + tasks[i] = 1; + }, 0, tasks.size(), 0, TLocalExecutor::EFlags::WAIT_COMPLETE); + UNIT_ASSERT(AllOf(tasks, IsValue(1))); + + + localExecutor.ExecRangeBlockedWithThrow([&tasks](int i) { + tasks[i] += 1; + }, 0, tasks.size(), 128, TLocalExecutor::EFlags::WAIT_COMPLETE); + UNIT_ASSERT(AllOf(tasks, IsValue(2))); + } + + Y_UNIT_TEST(ExecLargeRangeWithException) { + TVector<char> tasks(LARGE_COUNT); + + TLocalExecutor localExecutor; + localExecutor.RunAdditionalThreads(DefaultThreadsCount); + + Fill(tasks.begin(), tasks.end(), 0); + UNIT_ASSERT_EXCEPTION( + localExecutor.ExecRangeBlockedWithThrow([&tasks](int i) { + tasks[i] += 1; + if (i == LARGE_COUNT / 2) { + throw TTestException(); + } + }, 0, tasks.size(), 0, TLocalExecutor::EFlags::WAIT_COMPLETE), + TTestException + ); + } +}; diff --git a/library/cpp/threading/local_executor/ut/ya.make b/library/cpp/threading/local_executor/ut/ya.make new file mode 100644 index 0000000000..be579a5ca0 --- /dev/null +++ b/library/cpp/threading/local_executor/ut/ya.make @@ -0,0 +1,12 @@ +OWNER( + g:matrixnet + gulin +) + +UNITTEST_FOR(library/cpp/threading/local_executor) + +SRCS( + local_executor_ut.cpp +) + +END() diff --git a/library/cpp/threading/local_executor/ya.make b/library/cpp/threading/local_executor/ya.make new file mode 100644 index 0000000000..df210f92bb --- /dev/null +++ b/library/cpp/threading/local_executor/ya.make @@ -0,0 +1,20 @@ +OWNER( + g:matrixnet + gulin + kirillovs + espetrov +) + +LIBRARY() + +SRCS( + local_executor.cpp + tbb_local_executor.cpp +) + +PEERDIR( + contrib/libs/tbb + library/cpp/threading/future +) + +END() diff --git a/library/cpp/threading/poor_man_openmp/thread_helper.cpp b/library/cpp/threading/poor_man_openmp/thread_helper.cpp new file mode 100644 index 0000000000..34cb6507b9 --- /dev/null +++ b/library/cpp/threading/poor_man_openmp/thread_helper.cpp @@ -0,0 +1,7 @@ +#include "thread_helper.h" + +#include <util/generic/singleton.h> + +TMtpQueueHelper& TMtpQueueHelper::Instance() { + return *Singleton<TMtpQueueHelper>(); +} diff --git a/library/cpp/threading/poor_man_openmp/thread_helper.h b/library/cpp/threading/poor_man_openmp/thread_helper.h new file mode 100644 index 0000000000..0ecee0590b --- /dev/null +++ b/library/cpp/threading/poor_man_openmp/thread_helper.h @@ -0,0 +1,105 @@ +#pragma once + +#include <util/thread/pool.h> +#include <util/generic/utility.h> +#include <util/generic/yexception.h> +#include <util/system/info.h> +#include <util/system/atomic.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> +#include <util/stream/output.h> + +#include <functional> +#include <cstdlib> + +class TMtpQueueHelper { +public: + TMtpQueueHelper() { + SetThreadCount(NSystemInfo::CachedNumberOfCpus()); + } + IThreadPool* Get() { + return q.Get(); + } + size_t GetThreadCount() { + return ThreadCount; + } + void SetThreadCount(size_t threads) { + ThreadCount = threads; + q = CreateThreadPool(ThreadCount); + } + + static TMtpQueueHelper& Instance(); + +private: + size_t ThreadCount; + TAutoPtr<IThreadPool> q; +}; + +namespace NYmp { + inline void SetThreadCount(size_t threads) { + TMtpQueueHelper::Instance().SetThreadCount(threads); + } + + inline size_t GetThreadCount() { + return TMtpQueueHelper::Instance().GetThreadCount(); + } + + template <typename T> + inline void ParallelForStaticChunk(T begin, T end, size_t chunkSize, std::function<void(T)> func) { + chunkSize = Max<size_t>(chunkSize, 1); + + size_t threadCount = TMtpQueueHelper::Instance().GetThreadCount(); + IThreadPool* queue = TMtpQueueHelper::Instance().Get(); + TCondVar cv; + TMutex mutex; + TAtomic counter = threadCount; + std::exception_ptr err; + + for (size_t i = 0; i < threadCount; ++i) { + queue->SafeAddFunc([&cv, &counter, &mutex, &func, i, begin, end, chunkSize, threadCount, &err]() { + try { + T currentChunkStart = begin + static_cast<decltype(T() - T())>(i * chunkSize); + + while (currentChunkStart < end) { + T currentChunkEnd = Min<T>(end, currentChunkStart + chunkSize); + + for (T val = currentChunkStart; val < currentChunkEnd; ++val) { + func(val); + } + + currentChunkStart += chunkSize * threadCount; + } + } catch (...) { + with_lock (mutex) { + err = std::current_exception(); + } + } + + with_lock (mutex) { + if (AtomicDecrement(counter) == 0) { + //last one + cv.Signal(); + } + } + }); + } + + with_lock (mutex) { + while (AtomicGet(counter) > 0) { + cv.WaitI(mutex); + } + } + + if (err) { + std::rethrow_exception(err); + } + } + + template <typename T> + inline void ParallelForStaticAutoChunk(T begin, T end, std::function<void(T)> func) { + const size_t taskSize = end - begin; + const size_t threadCount = TMtpQueueHelper::Instance().GetThreadCount(); + + ParallelForStaticChunk(begin, end, (taskSize + threadCount - 1) / threadCount, func); + } +} diff --git a/library/cpp/threading/poor_man_openmp/thread_helper_ut.cpp b/library/cpp/threading/poor_man_openmp/thread_helper_ut.cpp new file mode 100644 index 0000000000..7417636864 --- /dev/null +++ b/library/cpp/threading/poor_man_openmp/thread_helper_ut.cpp @@ -0,0 +1,26 @@ +#include "thread_helper.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/string.h> +#include <util/generic/yexception.h> + +Y_UNIT_TEST_SUITE(TestMP) { + Y_UNIT_TEST(TestErr) { + std::function<void(int)> f = [](int x) { + if (x == 5) { + ythrow yexception() << "oops"; + } + }; + + TString s; + + try { + NYmp::ParallelForStaticAutoChunk(0, 10, f); + } catch (...) { + s = CurrentExceptionMessage(); + } + + UNIT_ASSERT(s.find("oops") > 0); + } +} diff --git a/library/cpp/threading/poor_man_openmp/ut/ya.make b/library/cpp/threading/poor_man_openmp/ut/ya.make new file mode 100644 index 0000000000..6d7aa123ed --- /dev/null +++ b/library/cpp/threading/poor_man_openmp/ut/ya.make @@ -0,0 +1,12 @@ +UNITTEST_FOR(library/cpp/threading/poor_man_openmp) + +OWNER( + pg + agorodilov +) + +SRCS( + thread_helper_ut.cpp +) + +END() diff --git a/library/cpp/threading/poor_man_openmp/ya.make b/library/cpp/threading/poor_man_openmp/ya.make new file mode 100644 index 0000000000..241b61dead --- /dev/null +++ b/library/cpp/threading/poor_man_openmp/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(agorodilov) + +SRCS( + thread_helper.cpp +) + +END() diff --git a/library/cpp/threading/queue/basic_ut.cpp b/library/cpp/threading/queue/basic_ut.cpp new file mode 100644 index 0000000000..5f56f8583e --- /dev/null +++ b/library/cpp/threading/queue/basic_ut.cpp @@ -0,0 +1,92 @@ +#include <library/cpp/testing/unittest/registar.h> +#include <util/generic/vector.h> +#include <util/system/thread.h> + +#include "ut_helpers.h" + +template <typename TQueueType> +class TQueueTestsInSingleThread: public TTestBase { +private: + using TSelf = TQueueTestsInSingleThread<TQueueType>; + using TLink = TIntrusiveLink; + + UNIT_TEST_SUITE_DEMANGLE(TSelf); + UNIT_TEST(OnePushOnePop) + UNIT_TEST(OnePushOnePop_Repeat1M) + UNIT_TEST(Threads8_Repeat1M_Push1Pop1) + UNIT_TEST_SUITE_END(); + +public: + void OnePushOnePop() { + TQueueType queue; + + auto popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, nullptr); + + TLink msg; + queue.Push(&msg); + popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(&msg, popped); + + popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, nullptr); + }; + + void OnePushOnePop_Repeat1M() { + TQueueType queue; + TLink msg; + + auto popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, nullptr); + + for (int i = 0; i < 1000000; ++i) { + queue.Push(&msg); + popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(&msg, popped); + + popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, nullptr); + } + } + + template <size_t NUMBER_OF_THREADS> + void RepeatPush1Pop1_InManyThreads() { + class TCycleThread: public ISimpleThread { + public: + void* ThreadProc() override { + TQueueType queue; + TLink msg; + auto popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, nullptr); + + for (size_t i = 0; i < 1000000; ++i) { + queue.Push(&msg); + popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, &msg); + + popped = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(popped, nullptr); + } + return nullptr; + } + }; + + TVector<TAutoPtr<TCycleThread>> cyclers; + + for (size_t i = 0; i < NUMBER_OF_THREADS; ++i) { + cyclers.emplace_back(new TCycleThread); + cyclers.back()->Start(); + } + + for (size_t i = 0; i < NUMBER_OF_THREADS; ++i) { + cyclers[i]->Join(); + } + } + + void Threads8_Repeat1M_Push1Pop1() { + RepeatPush1Pop1_InManyThreads<8>(); + } +}; + +REGISTER_TESTS_FOR_ALL_ORDERED_QUEUES(TQueueTestsInSingleThread); +REGISTER_TESTS_FOR_ALL_UNORDERED_QUEUES(TQueueTestsInSingleThread) diff --git a/library/cpp/threading/queue/mpmc_unordered_ring.cpp b/library/cpp/threading/queue/mpmc_unordered_ring.cpp new file mode 100644 index 0000000000..160547f594 --- /dev/null +++ b/library/cpp/threading/queue/mpmc_unordered_ring.cpp @@ -0,0 +1,74 @@ +#include "mpmc_unordered_ring.h" + +namespace NThreading { + TMPMCUnorderedRing::TMPMCUnorderedRing(size_t size) { + Y_VERIFY(size > 0); + RingSize = size; + RingBuffer.Reset(new void*[size]); + memset(&RingBuffer[0], 0, sizeof(void*) * size); + } + + bool TMPMCUnorderedRing::Push(void* msg, ui16 retryCount) noexcept { + if (retryCount == 0) { + StubbornPush(msg); + return true; + } + for (ui16 itry = retryCount; itry-- > 0;) { + if (WeakPush(msg)) { + return true; + } + } + return false; + } + + bool TMPMCUnorderedRing::WeakPush(void* msg) noexcept { + auto pawl = AtomicIncrement(WritePawl); + if (pawl - AtomicGet(ReadFront) >= RingSize) { + // Queue is full + AtomicDecrement(WritePawl); + return false; + } + + auto writeSlot = AtomicGetAndIncrement(WriteFront); + if (AtomicCas(&RingBuffer[writeSlot % RingSize], msg, nullptr)) { + return true; + } + // slot is occupied for some reason, retry + return false; + } + + void* TMPMCUnorderedRing::Pop() noexcept { + ui64 readSlot; + + for (ui16 itry = MAX_POP_TRIES; itry-- > 0;) { + auto pawl = AtomicIncrement(ReadPawl); + if (pawl > AtomicGet(WriteFront)) { + // Queue is empty + AtomicDecrement(ReadPawl); + return nullptr; + } + + readSlot = AtomicGetAndIncrement(ReadFront); + + auto msg = AtomicSwap(&RingBuffer[readSlot % RingSize], nullptr); + if (msg != nullptr) { + return msg; + } + } + + /* got no message in the slot, let's try to rollback readfront */ + AtomicCas(&ReadFront, readSlot - 1, readSlot); + return nullptr; + } + + void* TMPMCUnorderedRing::UnsafeScanningPop(ui64* last) noexcept { + for (; *last < RingSize;) { + auto msg = AtomicSwap(&RingBuffer[*last], nullptr); + ++*last; + if (msg != nullptr) { + return msg; + } + } + return nullptr; + } +} diff --git a/library/cpp/threading/queue/mpmc_unordered_ring.h b/library/cpp/threading/queue/mpmc_unordered_ring.h new file mode 100644 index 0000000000..5042f7528e --- /dev/null +++ b/library/cpp/threading/queue/mpmc_unordered_ring.h @@ -0,0 +1,42 @@ +#pragma once + +/* + It's not a general purpose queue. + No order guarantee, but it mostly ordered. + Items may stuck in almost empty queue. + Use UnsafeScanningPop to pop all stuck items. + Almost wait-free for producers and consumers. + */ + +#include <util/system/atomic.h> +#include <util/generic/ptr.h> + +namespace NThreading { + struct TMPMCUnorderedRing { + public: + static constexpr ui16 MAX_PUSH_TRIES = 4; + static constexpr ui16 MAX_POP_TRIES = 4; + + TMPMCUnorderedRing(size_t size); + + bool Push(void* msg, ui16 retryCount = MAX_PUSH_TRIES) noexcept; + void StubbornPush(void* msg) { + while (!WeakPush(msg)) { + } + } + + void* Pop() noexcept; + + void* UnsafeScanningPop(ui64* last) noexcept; + + private: + bool WeakPush(void* msg) noexcept; + + size_t RingSize; + TArrayPtr<void*> RingBuffer; + ui64 WritePawl = 0; + ui64 WriteFront = 0; + ui64 ReadPawl = 0; + ui64 ReadFront = 0; + }; +} diff --git a/library/cpp/threading/queue/mpsc_htswap.cpp b/library/cpp/threading/queue/mpsc_htswap.cpp new file mode 100644 index 0000000000..610c8f67f1 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_htswap.cpp @@ -0,0 +1 @@ +#include "mpsc_htswap.h" diff --git a/library/cpp/threading/queue/mpsc_htswap.h b/library/cpp/threading/queue/mpsc_htswap.h new file mode 100644 index 0000000000..c42caa7ac0 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_htswap.h @@ -0,0 +1,132 @@ +#pragma once + +/* + http://www.1024cores.net/home/lock-free-algorithms/queues/non-intrusive-mpsc-node-based-queue + + Simple semi-wait-free queue. Many producers - one consumer. + Tracking of allocated memory is not required. + No CAS. Only atomic swap (exchange) operations. + + WARNING: a sleeping producer can stop progress for consumer. + + WARNING: there is no wait¬ify mechanic for consumer, + consumer receives nullptr if queue was empty. + + WARNING: the algorithm itself is lock-free + but producers and consumer could be blocked by memory allocator + + Reference design: rtmapreduce/libs/threading/lfqueue.h + */ + +#include <util/generic/noncopyable.h> +#include <util/system/types.h> +#include <util/system/atomic.h> + +#include "tune.h" + +namespace NThreading { + namespace NHTSwapPrivate { + template <typename T, typename TTuneup> + struct TNode + : public TTuneup::TNodeBase, + public TTuneup::template TNodeLayout<TNode<T, TTuneup>, T> { + TNode(const T& item) { + this->Next = nullptr; + this->Item = item; + } + + TNode(T&& item) { + this->Next = nullptr; + this->Item = std::move(item); + } + }; + + struct TDefaultTuneup { + struct TNodeBase: private TNonCopyable { + }; + + template <typename TNode, typename T> + struct TNodeLayout { + TNode* Next; + T Item; + }; + + template <typename TNode> + struct TQueueLayout { + TNode* Head; + TNode* Tail; + }; + }; + + template <typename T, typename TTuneup> + class THTSwapQueueImpl + : protected TTuneup::template TQueueLayout<TNode<T, TTuneup>> { + protected: + using TTunedNode = TNode<T, TTuneup>; + + public: + using TItem = T; + + THTSwapQueueImpl() { + this->Head = new TTunedNode(T()); + this->Tail = this->Head; + } + + ~THTSwapQueueImpl() { + TTunedNode* node = this->Head; + while (node != nullptr) { + TTunedNode* next = node->Next; + delete node; + node = next; + } + } + + template <typename TT> + void Push(TT&& item) { + Enqueue(new TTunedNode(std::forward<TT>(item))); + } + + T Peek() { + TTunedNode* next = AtomicGet(this->Head->Next); + if (next == nullptr) { + return T(); + } + return next->Item; + } + + void Enqueue(TTunedNode* node) { + // our goal is to avoid expensive CAS here, + // but now consumer will be blocked until new tail linked. + // fortunately 'window of inconsistency' is extremely small. + TTunedNode* prev = AtomicSwap(&this->Tail, node); + AtomicSet(prev->Next, node); + } + + T Pop() { + TTunedNode* next = AtomicGet(this->Head->Next); + if (next == nullptr) { + return nullptr; + } + auto item = std::move(next->Item); + std::swap(this->Head, next); // no need atomic here + delete next; + return item; + } + + bool IsEmpty() const { + TTunedNode* next = AtomicGet(this->Head->Next); + return (next == nullptr); + } + }; + } + + DeclareTuneTypeParam(THTSwapNodeBase, TNodeBase); + DeclareTuneTypeParam(THTSwapNodeLayout, TNodeLayout); + DeclareTuneTypeParam(THTSwapQueueLayout, TQueueLayout); + + template <typename T = void*, typename... TParams> + class THTSwapQueue + : public NHTSwapPrivate::THTSwapQueueImpl<T, + TTune<NHTSwapPrivate::TDefaultTuneup, TParams...>> { + }; +} diff --git a/library/cpp/threading/queue/mpsc_intrusive_unordered.cpp b/library/cpp/threading/queue/mpsc_intrusive_unordered.cpp new file mode 100644 index 0000000000..3bb1a04f7e --- /dev/null +++ b/library/cpp/threading/queue/mpsc_intrusive_unordered.cpp @@ -0,0 +1,79 @@ +#include "mpsc_intrusive_unordered.h" +#include <util/system/atomic.h> + +namespace NThreading { + void TMPSCIntrusiveUnordered::Push(TIntrusiveNode* node) noexcept { + auto head = AtomicGet(HeadForCaS); + for (ui32 i = NUMBER_OF_TRIES_FOR_CAS; i-- > 0;) { + // no ABA here, because Next is exactly head + // it does not matter how many travels head was made/ + node->Next = head; + auto prev = AtomicGetAndCas(&HeadForCaS, node, head); + if (head == prev) { + return; + } + head = prev; + } + // boring of trying to do cas, let's just swap + + // no need for atomic here, because the next is atomic swap + node->Next = 0; + + head = AtomicSwap(&HeadForSwap, node); + if (head != nullptr) { + AtomicSet(node->Next, head); + } else { + // consumer must know if no other thread may access the memory, + // setting Next to node is a way to notify consumer + AtomicSet(node->Next, node); + } + } + + TIntrusiveNode* TMPSCIntrusiveUnordered::PopMany() noexcept { + if (NotReadyChain == nullptr) { + auto head = AtomicSwap(&HeadForSwap, nullptr); + NotReadyChain = head; + } + + if (NotReadyChain != nullptr) { + auto next = AtomicGet(NotReadyChain->Next); + if (next != nullptr) { + auto ready = NotReadyChain; + TIntrusiveNode* cut; + do { + cut = NotReadyChain; + NotReadyChain = next; + next = AtomicGet(NotReadyChain->Next); + if (next == NotReadyChain) { + cut = NotReadyChain; + NotReadyChain = nullptr; + break; + } + } while (next != nullptr); + cut->Next = nullptr; + return ready; + } + } + + if (AtomicGet(HeadForCaS) != nullptr) { + return AtomicSwap(&HeadForCaS, nullptr); + } + return nullptr; + } + + TIntrusiveNode* TMPSCIntrusiveUnordered::Pop() noexcept { + if (PopOneQueue != nullptr) { + auto head = PopOneQueue; + PopOneQueue = PopOneQueue->Next; + return head; + } + + PopOneQueue = PopMany(); + if (PopOneQueue != nullptr) { + auto head = PopOneQueue; + PopOneQueue = PopOneQueue->Next; + return head; + } + return nullptr; + } +} diff --git a/library/cpp/threading/queue/mpsc_intrusive_unordered.h b/library/cpp/threading/queue/mpsc_intrusive_unordered.h new file mode 100644 index 0000000000..6ac7537ae9 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_intrusive_unordered.h @@ -0,0 +1,35 @@ +#pragma once + +/* + Simple almost-wait-free unordered queue for low contention operations. + + It's wait-free for producers. + Hanging producer can hide some items from consumer. + */ + +#include <util/system/types.h> + +namespace NThreading { + struct TIntrusiveNode { + TIntrusiveNode* Next; + }; + + class TMPSCIntrusiveUnordered { + public: + static constexpr ui32 NUMBER_OF_TRIES_FOR_CAS = 3; + + void Push(TIntrusiveNode* node) noexcept; + TIntrusiveNode* PopMany() noexcept; + TIntrusiveNode* Pop() noexcept; + + void Push(void* node) noexcept { + Push(reinterpret_cast<TIntrusiveNode*>(node)); + } + + private: + TIntrusiveNode* HeadForCaS = nullptr; + TIntrusiveNode* HeadForSwap = nullptr; + TIntrusiveNode* NotReadyChain = nullptr; + TIntrusiveNode* PopOneQueue = nullptr; + }; +} diff --git a/library/cpp/threading/queue/mpsc_read_as_filled.cpp b/library/cpp/threading/queue/mpsc_read_as_filled.cpp new file mode 100644 index 0000000000..8b4664a6f3 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_read_as_filled.cpp @@ -0,0 +1 @@ +#include "mpsc_read_as_filled.h" diff --git a/library/cpp/threading/queue/mpsc_read_as_filled.h b/library/cpp/threading/queue/mpsc_read_as_filled.h new file mode 100644 index 0000000000..be33ba5a58 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_read_as_filled.h @@ -0,0 +1,611 @@ +#pragma once + +/* + Completely wait-free queue, multiple producers - one consumer. Strict order. + The queue algorithm is using concept of virtual infinite array. + + A producer takes a number from a counter and atomically increments the counter. + The number taken is a number of a slot for the producer to put a new message + into infinite array. + + Then producer constructs a virtual infinite array by bidirectional linked list + of blocks. Each block contains several slots. + + There is a hint pointer which optimistically points to the last block + of the list and never goes backward. + + Consumer exploits the property of the hint pointer always going forward + to free old blocks eventually. Consumer periodically read the hint pointer + and the counter and thus deduce producers which potentially holds the pointer + to a block. Consumer can free the block if all that producers filled their + slots and left the queue. + + No producer can stop the progress for other producers. + + Consumer can't stop the progress for producers. + Consumer can skip not-yet-filled slots and read them later. + Thus no producer can stop the progress for consumer. + The algorithm is virtually strictly ordered because it skips slots only + if it is really does not matter in which order the slots were produced and + consumed. + + WARNING: there is no wait¬ify mechanic for consumer, + consumer receives nullptr if queue was empty. + + WARNING: though the algorithm itself is completely wait-free + but producers and consumer could be blocked by memory allocator + + WARNING: copy constructors of the queue are not thread-safe + */ + +#include <util/generic/deque.h> +#include <util/generic/ptr.h> +#include <util/system/atomic.h> +#include <util/system/spinlock.h> + +#include "tune.h" + +namespace NThreading { + namespace NReadAsFilledPrivate { + typedef void* TMsgLink; + + static constexpr ui32 DEFAULT_BUNCH_SIZE = 251; + + struct TEmpty { + }; + + struct TEmptyAux { + TEmptyAux Retrieve() const { + return TEmptyAux(); + } + + void Store(TEmptyAux&) { + } + + static constexpr TEmptyAux Zero() { + return TEmptyAux(); + } + }; + + template <typename TAux> + struct TSlot { + TMsgLink volatile Msg; + TAux AuxiliaryData; + + inline void Store(TAux& aux) { + AuxiliaryData.Store(aux); + } + + inline TAux Retrieve() const { + return AuxiliaryData.Retrieve(); + } + + static TSlot<TAux> NullElem() { + return {nullptr, TAux::Zero()}; + } + + static TSlot<TAux> Pair(TMsgLink msg, TAux aux) { + return {msg, std::move(aux)}; + } + }; + + template <> + struct TSlot<TEmptyAux> { + TMsgLink volatile Msg; + + inline void Store(TEmptyAux&) { + } + + inline TEmptyAux Retrieve() const { + return TEmptyAux(); + } + + static TSlot<TEmptyAux> NullElem() { + return {nullptr}; + } + + static TSlot<TEmptyAux> Pair(TMsgLink msg, TEmptyAux) { + return {msg}; + } + }; + + enum TPushResult { + PUSH_RESULT_OK, + PUSH_RESULT_BACKWARD, + PUSH_RESULT_FORWARD, + }; + + template <ui32 BUNCH_SIZE = DEFAULT_BUNCH_SIZE, + typename TBase = TEmpty, + typename TAux = TEmptyAux> + struct TMsgBunch: public TBase { + static constexpr size_t RELEASE_SIZE = BUNCH_SIZE * 2; + + ui64 FirstSlot; + + TSlot<TAux> LinkArray[BUNCH_SIZE]; + + TMsgBunch* volatile NextBunch; + TMsgBunch* volatile BackLink; + + ui64 volatile Token; + TMsgBunch* volatile NextToken; + + /* this push can return PUSH_RESULT_BLOCKED */ + inline TPushResult Push(TMsgLink msg, ui64 slot, TAux auxiliary) { + if (Y_UNLIKELY(slot < FirstSlot)) { + return PUSH_RESULT_BACKWARD; + } + + if (Y_UNLIKELY(slot >= FirstSlot + BUNCH_SIZE)) { + return PUSH_RESULT_FORWARD; + } + + LinkArray[slot - FirstSlot].Store(auxiliary); + + AtomicSet(LinkArray[slot - FirstSlot].Msg, msg); + return PUSH_RESULT_OK; + } + + inline bool IsSlotHere(ui64 slot) { + return slot < FirstSlot + BUNCH_SIZE; + } + + inline TMsgLink GetSlot(ui64 slot) const { + return AtomicGet(LinkArray[slot - FirstSlot].Msg); + } + + inline TSlot<TAux> GetSlotAux(ui64 slot) const { + auto msg = GetSlot(slot); + auto aux = LinkArray[slot - FirstSlot].Retrieve(); + return TSlot<TAux>::Pair(msg, aux); + } + + inline TMsgBunch* GetNextBunch() const { + return AtomicGet(NextBunch); + } + + inline bool SetNextBunch(TMsgBunch* ptr) { + return AtomicCas(&NextBunch, ptr, nullptr); + } + + inline TMsgBunch* GetBackLink() const { + return AtomicGet(BackLink); + } + + inline TMsgBunch* GetToken(ui64 slot) { + return reinterpret_cast<TMsgBunch*>( + LinkArray[slot - FirstSlot].Msg); + } + + inline void IncrementToken() { + AtomicIncrement(Token); + } + + // the object could be destroyed after this method + inline void DecrementToken() { + if (Y_UNLIKELY(AtomicDecrement(Token) == RELEASE_SIZE)) { + Release(this); + AtomicGet(NextToken)->DecrementToken(); + // this could be invalid here + } + } + + // the object could be destroyed after this method + inline void SetNextToken(TMsgBunch* next) { + AtomicSet(NextToken, next); + if (Y_UNLIKELY(AtomicAdd(Token, RELEASE_SIZE) == RELEASE_SIZE)) { + Release(this); + next->DecrementToken(); + } + // this could be invalid here + } + + TMsgBunch(ui64 start, TMsgBunch* backLink) { + AtomicSet(FirstSlot, start); + memset(&LinkArray, 0, sizeof(LinkArray)); + AtomicSet(NextBunch, nullptr); + AtomicSet(BackLink, backLink); + + AtomicSet(Token, 1); + AtomicSet(NextToken, nullptr); + } + + static void Release(TMsgBunch* block) { + auto backLink = AtomicGet(block->BackLink); + if (backLink == nullptr) { + return; + } + AtomicSet(block->BackLink, nullptr); + + do { + auto bbackLink = backLink->BackLink; + delete backLink; + backLink = bbackLink; + } while (backLink != nullptr); + } + + void Destroy() { + for (auto tail = BackLink; tail != nullptr;) { + auto next = tail->BackLink; + delete tail; + tail = next; + } + + for (auto next = this; next != nullptr;) { + auto nnext = next->NextBunch; + delete next; + next = nnext; + } + } + }; + + template <ui32 BUNCH_SIZE = DEFAULT_BUNCH_SIZE, + typename TBunchBase = NReadAsFilledPrivate::TEmpty, + typename TAux = TEmptyAux> + class TWriteBucket { + public: + using TUsingAux = TAux; // for TReadBucket binding + using TBunch = TMsgBunch<BUNCH_SIZE, TBunchBase, TAux>; + + TWriteBucket(TBunch* bunch = new TBunch(0, nullptr)) { + AtomicSet(LastBunch, bunch); + AtomicSet(SlotCounter, 0); + } + + TWriteBucket(TWriteBucket&& move) + : LastBunch(move.LastBunch) + , SlotCounter(move.SlotCounter) + { + move.LastBunch = nullptr; + } + + ~TWriteBucket() { + if (LastBunch != nullptr) { + LastBunch->Destroy(); + } + } + + inline void Push(TMsgLink msg, TAux aux) { + ui64 pushSlot = AtomicGetAndIncrement(SlotCounter); + TBunch* hintBunch = GetLastBunch(); + + for (;;) { + auto hint = hintBunch->Push(msg, pushSlot, aux); + if (Y_LIKELY(hint == PUSH_RESULT_OK)) { + return; + } + HandleHint(hintBunch, hint); + } + } + + protected: + template <typename, template <typename, typename...> class> + friend class TReadBucket; + + TBunch* volatile LastBunch; // Hint + volatile ui64 SlotCounter; + + inline TBunch* GetLastBunch() const { + return AtomicGet(LastBunch); + } + + void HandleHint(TBunch*& hintBunch, TPushResult hint) { + if (Y_UNLIKELY(hint == PUSH_RESULT_BACKWARD)) { + hintBunch = hintBunch->GetBackLink(); + return; + } + + // PUSH_RESULT_FORWARD + auto nextBunch = hintBunch->GetNextBunch(); + + if (nextBunch == nullptr) { + auto first = hintBunch->FirstSlot + BUNCH_SIZE; + nextBunch = new TBunch(first, hintBunch); + if (Y_UNLIKELY(!hintBunch->SetNextBunch(nextBunch))) { + delete nextBunch; + nextBunch = hintBunch->GetNextBunch(); + } + } + + // hintBunch could not be freed here so it cannot be reused + // it's alright if this CAS was not succeeded, + // it means that other thread did that recently + AtomicCas(&LastBunch, nextBunch, hintBunch); + + hintBunch = nextBunch; + } + }; + + template <typename TWBucket = TWriteBucket<>, + template <typename, typename...> class TContainer = TDeque> + class TReadBucket { + public: + using TAux = typename TWBucket::TUsingAux; + using TBunch = typename TWBucket::TBunch; + + static constexpr int MAX_NUMBER_OF_TRIES_TO_READ = 5; + + TReadBucket(TWBucket* writer) + : Writer(writer) + , ReadBunch(writer->GetLastBunch()) + , LastKnownPushBunch(writer->GetLastBunch()) + { + ReadBunch->DecrementToken(); // no previous token + } + + TReadBucket(TReadBucket toCopy, TWBucket* writer) + : TReadBucket(std::move(toCopy)) + { + Writer = writer; + } + + ui64 ReadyCount() const { + return AtomicGet(Writer->SlotCounter) - ReadSlot; + } + + TMsgLink Pop() { + return PopAux().Msg; + } + + TMsgLink Peek() { + return PeekAux().Msg; + } + + TSlot<TAux> PopAux() { + for (;;) { + if (Y_UNLIKELY(ReadNow.size() != 0)) { + auto result = PopSkipped(); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + } + + if (Y_UNLIKELY(ReadSlot == LastKnownPushSlot)) { + if (Y_LIKELY(!RereadPushSlot())) { + return TSlot<TAux>::NullElem(); + } + continue; + } + + if (Y_UNLIKELY(!ReadBunch->IsSlotHere(ReadSlot))) { + if (Y_UNLIKELY(!SwitchToNextBunch())) { + return TSlot<TAux>::NullElem(); + } + } + + auto result = ReadBunch->GetSlotAux(ReadSlot); + if (Y_LIKELY(result.Msg != nullptr)) { + ++ReadSlot; + return result; + } + + result = StubbornPop(); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + } + } + + TSlot<TAux> PeekAux() { + for (;;) { + if (Y_UNLIKELY(ReadNow.size() != 0)) { + auto result = PeekSkipped(); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + } + + if (Y_UNLIKELY(ReadSlot == LastKnownPushSlot)) { + if (Y_LIKELY(!RereadPushSlot())) { + return TSlot<TAux>::NullElem(); + } + continue; + } + + if (Y_UNLIKELY(!ReadBunch->IsSlotHere(ReadSlot))) { + if (Y_UNLIKELY(!SwitchToNextBunch())) { + return TSlot<TAux>::NullElem(); + } + } + + auto result = ReadBunch->GetSlotAux(ReadSlot); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + + result = StubbornPeek(); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + } + } + + private: + TWBucket* Writer; + TBunch* ReadBunch; + ui64 ReadSlot = 0; + TBunch* LastKnownPushBunch; + ui64 LastKnownPushSlot = 0; + + struct TSkipItem { + TBunch* Bunch; + ui64 Slot; + TBunch* Token; + }; + + TContainer<TSkipItem> ReadNow; + TContainer<TSkipItem> ReadLater; + + void AddToReadLater() { + ReadLater.push_back({ReadBunch, ReadSlot, LastKnownPushBunch}); + LastKnownPushBunch->IncrementToken(); + ++ReadSlot; + } + + // MUST BE: ReadSlot == LastKnownPushSlot + bool RereadPushSlot() { + ReadNow = std::move(ReadLater); + ReadLater.clear(); + + auto oldSlot = LastKnownPushSlot; + + auto currentPushBunch = Writer->GetLastBunch(); + auto currentPushSlot = AtomicGet(Writer->SlotCounter); + + if (currentPushBunch != LastKnownPushBunch) { + // LastKnownPushBunch could be invalid after this line + LastKnownPushBunch->SetNextToken(currentPushBunch); + } + + LastKnownPushBunch = currentPushBunch; + LastKnownPushSlot = currentPushSlot; + + return oldSlot != LastKnownPushSlot; + } + + bool SwitchToNextBunch() { + for (int q = 0; q < MAX_NUMBER_OF_TRIES_TO_READ; ++q) { + auto next = ReadBunch->GetNextBunch(); + if (next != nullptr) { + ReadBunch = next; + return true; + } + SpinLockPause(); + } + return false; + } + + TSlot<TAux> StubbornPop() { + for (int q = 0; q < MAX_NUMBER_OF_TRIES_TO_READ; ++q) { + auto result = ReadBunch->GetSlotAux(ReadSlot); + if (Y_LIKELY(result.Msg != nullptr)) { + ++ReadSlot; + return result; + } + SpinLockPause(); + } + + AddToReadLater(); + return TSlot<TAux>::NullElem(); + } + + TSlot<TAux> StubbornPeek() { + for (int q = 0; q < MAX_NUMBER_OF_TRIES_TO_READ; ++q) { + auto result = ReadBunch->GetSlotAux(ReadSlot); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + SpinLockPause(); + } + + AddToReadLater(); + return TSlot<TAux>::NullElem(); + } + + TSlot<TAux> PopSkipped() { + do { + auto elem = ReadNow.front(); + ReadNow.pop_front(); + + auto result = elem.Bunch->GetSlotAux(elem.Slot); + if (Y_LIKELY(result.Msg != nullptr)) { + elem.Token->DecrementToken(); + return result; + } + + ReadLater.emplace_back(elem); + + } while (ReadNow.size() > 0); + + return TSlot<TAux>::NullElem(); + } + + TSlot<TAux> PeekSkipped() { + do { + auto elem = ReadNow.front(); + + auto result = elem.Bunch->GetSlotAux(elem.Slot); + if (Y_LIKELY(result.Msg != nullptr)) { + return result; + } + + ReadNow.pop_front(); + ReadLater.emplace_back(elem); + + } while (ReadNow.size() > 0); + + return TSlot<TAux>::NullElem(); + } + }; + + struct TDefaultParams { + static constexpr ui32 BUNCH_SIZE = DEFAULT_BUNCH_SIZE; + using TBunchBase = TEmpty; + + template <typename TElem, typename... TRest> + using TContainer = TDeque<TElem, TRest...>; + + static constexpr bool DeleteItems = true; + }; + + } //namespace NReadAsFilledPrivate + + DeclareTuneValueParam(TRaFQueueBunchSize, ui32, BUNCH_SIZE); + DeclareTuneTypeParam(TRaFQueueBunchBase, TBunchBase); + DeclareTuneContainer(TRaFQueueSkipContainer, TContainer); + DeclareTuneValueParam(TRaFQueueDeleteItems, bool, DeleteItems); + + template <typename TItem = void, typename... TParams> + class TReadAsFilledQueue { + private: + using TTuned = TTune<NReadAsFilledPrivate::TDefaultParams, TParams...>; + + static constexpr ui32 BUNCH_SIZE = TTuned::BUNCH_SIZE; + + using TBunchBase = typename TTuned::TBunchBase; + + template <typename TElem, typename... TRest> + using TContainer = + typename TTuned::template TContainer<TElem, TRest...>; + + using TWriteBucket = + NReadAsFilledPrivate::TWriteBucket<BUNCH_SIZE, TBunchBase>; + using TReadBucket = + NReadAsFilledPrivate::TReadBucket<TWriteBucket, TContainer>; + + public: + TReadAsFilledQueue() + : RBucket(&WBucket) + { + } + + ~TReadAsFilledQueue() { + if (TTuned::DeleteItems) { + for (;;) { + auto msg = Pop(); + if (msg == nullptr) { + break; + } + TDelete::Destroy(msg); + } + } + } + + void Push(TItem* msg) { + WBucket.Push((void*)msg, NReadAsFilledPrivate::TEmptyAux()); + } + + TItem* Pop() { + return (TItem*)RBucket.Pop(); + } + + TItem* Peek() { + return (TItem*)RBucket.Peek(); + } + + protected: + TWriteBucket WBucket; + TReadBucket RBucket; + }; +} diff --git a/library/cpp/threading/queue/mpsc_vinfarr_obstructive.cpp b/library/cpp/threading/queue/mpsc_vinfarr_obstructive.cpp new file mode 100644 index 0000000000..2bd0c29821 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_vinfarr_obstructive.cpp @@ -0,0 +1 @@ +#include "mpsc_vinfarr_obstructive.h" diff --git a/library/cpp/threading/queue/mpsc_vinfarr_obstructive.h b/library/cpp/threading/queue/mpsc_vinfarr_obstructive.h new file mode 100644 index 0000000000..5f91f1b5a8 --- /dev/null +++ b/library/cpp/threading/queue/mpsc_vinfarr_obstructive.h @@ -0,0 +1,528 @@ +#pragma once + +/* + Semi-wait-free queue, multiple producers - one consumer. Strict order. + The queue algorithm is using concept of virtual infinite array. + + A producer takes a number from a counter and atomicaly increments the counter. + The number taken is a number of a slot for the producer to put a new message + into infinite array. + + Then producer constructs a virtual infinite array by bidirectional linked list + of blocks. Each block contains several slots. + + There is a hint pointer which optimisticly points to the last block + of the list and never goes backward. + + Consumer exploits the property of the hint pointer always going forward + to free old blocks eventually. Consumer periodically read the hint pointer + and the counter and thus deduce producers which potentially holds the pointer + to a block. Consumer can free the block if all that producers filled their + slots and left the queue. + + No producer can stop the progress for other producers. + + Consumer can obstruct a slot of a delayed producer by putting special mark. + Thus no producer can stop the progress for consumer. + But a slow producer may be forced to retry unlimited number of times. + Though it's very unlikely for a non-preempted producer to be obstructed. + That's why the algorithm is semi-wait-free. + + WARNING: there is no wait¬ify mechanic for consumer, + consumer receives nullptr if queue was empty. + + WARNING: though the algorithm itself is lock-free + but producers and consumer could be blocked by memory allocator + + WARNING: copy constructers of the queue are not thread-safe + */ + +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> +#include <util/system/atomic.h> +#include <util/system/spinlock.h> + +#include "tune.h" + +namespace NThreading { + namespace NObstructiveQueuePrivate { + typedef void* TMsgLink; + + struct TEmpty { + }; + + struct TEmptyAux { + TEmptyAux Retrieve() const { + return TEmptyAux(); + } + void Store(TEmptyAux&) { + } + static constexpr TEmptyAux Zero() { + return TEmptyAux(); + } + }; + + template <typename TAux> + struct TSlot { + TMsgLink volatile Msg; + TAux AuxiliaryData; + + inline void Store(TAux& aux) { + AuxiliaryData.Store(aux); + } + + inline TAux Retrieve() const { + return AuxiliaryData.Retrieve(); + } + + static TSlot<TAux> NullElem() { + return {nullptr, TAux::Zero()}; + } + + static TSlot<TAux> Pair(TMsgLink msg, TAux aux) { + return {msg, std::move(aux)}; + } + }; + + template <> + struct TSlot<TEmptyAux> { + TMsgLink volatile Msg; + inline void Store(TEmptyAux&) { + } + inline TEmptyAux Retrieve() const { + return TEmptyAux(); + } + + static TSlot<TEmptyAux> NullElem() { + return {nullptr}; + } + + static TSlot<TEmptyAux> Pair(TMsgLink msg, TEmptyAux) { + return {msg}; + } + }; + + enum TPushResult { + PUSH_RESULT_OK, + PUSH_RESULT_BACKWARD, + PUSH_RESULT_FORWARD, + PUSH_RESULT_BLOCKED, + }; + + template <typename TAux, ui32 BUNCH_SIZE, typename TBase = TEmpty> + struct TMsgBunch: public TBase { + ui64 FirstSlot; + + TSlot<TAux> LinkArray[BUNCH_SIZE]; + + TMsgBunch* volatile NextBunch; + TMsgBunch* volatile BackLink; + + ui64 volatile Token; + TMsgBunch* volatile NextToken; + + /* this push can return PUSH_RESULT_BLOCKED */ + inline TPushResult Push(TMsgLink msg, ui64 slot, TAux auxiliary) { + if (Y_UNLIKELY(slot < FirstSlot)) { + return PUSH_RESULT_BACKWARD; + } + + if (Y_UNLIKELY(slot >= FirstSlot + BUNCH_SIZE)) { + return PUSH_RESULT_FORWARD; + } + + LinkArray[slot - FirstSlot].Store(auxiliary); + + auto oldValue = AtomicSwap(&LinkArray[slot - FirstSlot].Msg, msg); + + if (Y_LIKELY(oldValue == nullptr)) { + return PUSH_RESULT_OK; + } else { + LeaveBlocked(oldValue); + return PUSH_RESULT_BLOCKED; + } + } + + inline bool IsSlotHere(ui64 slot) { + return slot < FirstSlot + BUNCH_SIZE; + } + + inline TMsgLink GetSlot(ui64 slot) const { + return AtomicGet(LinkArray[slot - FirstSlot].Msg); + } + + inline TSlot<TAux> GetSlotAux(ui64 slot) const { + auto msg = GetSlot(slot); + auto aux = LinkArray[slot - FirstSlot].Retrieve(); + return TSlot<TAux>::Pair(msg, aux); + } + + void LeaveBlocked(ui64 slot) { + auto token = GetToken(slot); + token->DecrementToken(); + } + + void LeaveBlocked(TMsgLink msg) { + auto token = reinterpret_cast<TMsgBunch*>(msg); + token->DecrementToken(); + } + + TSlot<TAux> BlockSlotAux(ui64 slot, TMsgBunch* token) { + auto old = + AtomicSwap(&LinkArray[slot - FirstSlot].Msg, (TMsgLink)token); + if (old == nullptr) { + // It's valid to increment after AtomicCas + // because token will release data only after SetNextToken + token->IncrementToken(); + return TSlot<TAux>::NullElem(); + } + return TSlot<TAux>::Pair(old, LinkArray[slot - FirstSlot].Retrieve()); + } + + inline TMsgBunch* GetNextBunch() const { + return AtomicGet(NextBunch); + } + + inline bool SetNextBunch(TMsgBunch* ptr) { + return AtomicCas(&NextBunch, ptr, nullptr); + } + + inline TMsgBunch* GetBackLink() const { + return AtomicGet(BackLink); + } + + inline TMsgBunch* GetToken(ui64 slot) { + return reinterpret_cast<TMsgBunch*>(LinkArray[slot - FirstSlot].Msg); + } + + inline void IncrementToken() { + AtomicIncrement(Token); + } + + // the object could be destroyed after this method + inline void DecrementToken() { + if (Y_UNLIKELY(AtomicDecrement(Token) == BUNCH_SIZE)) { + Release(this); + AtomicGet(NextToken)->DecrementToken(); + // this could be invalid here + } + } + + // the object could be destroyed after this method + inline void SetNextToken(TMsgBunch* next) { + AtomicSet(NextToken, next); + if (Y_UNLIKELY(AtomicAdd(Token, BUNCH_SIZE) == BUNCH_SIZE)) { + Release(this); + next->DecrementToken(); + } + // this could be invalid here + } + + TMsgBunch(ui64 start, TMsgBunch* backLink) { + AtomicSet(FirstSlot, start); + memset(&LinkArray, 0, sizeof(LinkArray)); + AtomicSet(NextBunch, nullptr); + AtomicSet(BackLink, backLink); + + AtomicSet(Token, 1); + AtomicSet(NextToken, nullptr); + } + + static void Release(TMsgBunch* bunch) { + auto backLink = AtomicGet(bunch->BackLink); + if (backLink == nullptr) { + return; + } + AtomicSet(bunch->BackLink, nullptr); + + do { + auto bbackLink = backLink->BackLink; + delete backLink; + backLink = bbackLink; + } while (backLink != nullptr); + } + + void Destroy() { + for (auto tail = BackLink; tail != nullptr;) { + auto next = tail->BackLink; + delete tail; + tail = next; + } + + for (auto next = this; next != nullptr;) { + auto nnext = next->NextBunch; + delete next; + next = nnext; + } + } + }; + + template <typename TAux, ui32 BUNCH_SIZE, typename TBunchBase = TEmpty> + class TWriteBucket { + public: + static const ui64 GROSS_SIZE; + + using TBunch = TMsgBunch<TAux, BUNCH_SIZE, TBunchBase>; + + TWriteBucket(TBunch* bunch = new TBunch(0, nullptr)) + : LastBunch(bunch) + , SlotCounter(0) + { + } + + TWriteBucket(TWriteBucket&& move) + : LastBunch(move.LastBunch) + , SlotCounter(move.SlotCounter) + { + move.LastBunch = nullptr; + } + + ~TWriteBucket() { + if (LastBunch != nullptr) { + LastBunch->Destroy(); + } + } + + inline bool Push(TMsgLink msg, TAux aux) { + ui64 pushSlot = AtomicGetAndIncrement(SlotCounter); + TBunch* hintBunch = GetLastBunch(); + + for (;;) { + auto hint = hintBunch->Push(msg, pushSlot, aux); + if (Y_LIKELY(hint == PUSH_RESULT_OK)) { + return true; + } + bool hhResult = HandleHint(hintBunch, hint); + if (Y_UNLIKELY(!hhResult)) { + return false; + } + } + } + + protected: + template <typename, ui32, typename> + friend class TReadBucket; + + TBunch* volatile LastBunch; // Hint + volatile ui64 SlotCounter; + + inline TBunch* GetLastBunch() const { + return AtomicGet(LastBunch); + } + + bool HandleHint(TBunch*& hintBunch, TPushResult hint) { + if (Y_UNLIKELY(hint == PUSH_RESULT_BLOCKED)) { + return false; + } + + if (Y_UNLIKELY(hint == PUSH_RESULT_BACKWARD)) { + hintBunch = hintBunch->GetBackLink(); + return true; + } + + // PUSH_RESULT_FORWARD + auto nextBunch = hintBunch->GetNextBunch(); + + if (nextBunch == nullptr) { + auto first = hintBunch->FirstSlot + BUNCH_SIZE; + nextBunch = new TBunch(first, hintBunch); + if (Y_UNLIKELY(!hintBunch->SetNextBunch(nextBunch))) { + delete nextBunch; + nextBunch = hintBunch->GetNextBunch(); + } + } + + // hintBunch could not be freed here so it cannot be reused + // it's alright if this CAS was not succeeded, + // it means that other thread did that recently + AtomicCas(&LastBunch, nextBunch, hintBunch); + + hintBunch = nextBunch; + return true; + } + }; + + template <typename TAux, ui32 BUNCH_SIZE, typename TBunchBase> + class TReadBucket { + public: + static constexpr int MAX_NUMBER_OF_TRIES_TO_READ = 20; + + using TWBucket = TWriteBucket<TAux, BUNCH_SIZE, TBunchBase>; + using TBunch = TMsgBunch<TAux, BUNCH_SIZE, TBunchBase>; + + TReadBucket(TWBucket* writer) + : Writer(writer) + , ReadBunch(writer->GetLastBunch()) + , LastKnownPushBunch(writer->GetLastBunch()) + { + ReadBunch->DecrementToken(); // no previous token + } + + TReadBucket(TReadBucket toCopy, TWBucket* writer) + : TReadBucket(std::move(toCopy)) + { + Writer = writer; + } + + ui64 ReadyCount() const { + return AtomicGet(Writer->SlotCounter) - ReadSlot; + } + + inline TMsgLink Pop() { + return PopAux().Msg; + } + + inline TSlot<TAux> PopAux() { + for (;;) { + if (Y_UNLIKELY(ReadSlot == LastKnownPushSlot)) { + if (Y_LIKELY(!RereadPushSlot())) { + return TSlot<TAux>::NullElem(); + } + } + + if (Y_UNLIKELY(!ReadBunch->IsSlotHere(ReadSlot))) { + if (Y_UNLIKELY(!SwitchToNextBunch())) { + return TSlot<TAux>::NullElem(); + } + } + + auto result = ReadBunch->GetSlotAux(ReadSlot); + if (Y_LIKELY(result.Msg != nullptr)) { + ++ReadSlot; + return result; + } + + if (ReadSlot + 1 == AtomicGet(Writer->SlotCounter)) { + return TSlot<TAux>::NullElem(); + } + + result = StubbornPopAux(); + + if (result.Msg != nullptr) { + return result; + } + } + } + + private: + TWBucket* Writer; + TBunch* ReadBunch; + ui64 ReadSlot = 0; + TBunch* LastKnownPushBunch; + ui64 LastKnownPushSlot = 0; + + // MUST BE: ReadSlot == LastKnownPushSlot + bool RereadPushSlot() { + auto oldSlot = LastKnownPushSlot; + + auto currentPushBunch = Writer->GetLastBunch(); + auto currentPushSlot = AtomicGet(Writer->SlotCounter); + + if (currentPushBunch != LastKnownPushBunch) { + // LastKnownPushBunch could be invalid after this line + LastKnownPushBunch->SetNextToken(currentPushBunch); + } + + LastKnownPushBunch = currentPushBunch; + LastKnownPushSlot = currentPushSlot; + + return oldSlot != LastKnownPushSlot; + } + + bool SwitchToNextBunch() { + for (int q = 0; q < MAX_NUMBER_OF_TRIES_TO_READ; ++q) { + auto next = ReadBunch->GetNextBunch(); + if (next != nullptr) { + ReadBunch = next; + return true; + } + SpinLockPause(); + } + return false; + } + + TSlot<TAux> StubbornPopAux() { + for (int q = 0; q < MAX_NUMBER_OF_TRIES_TO_READ; ++q) { + auto result = ReadBunch->GetSlotAux(ReadSlot); + if (Y_LIKELY(result.Msg != nullptr)) { + ++ReadSlot; + return result; + } + SpinLockPause(); + } + + return ReadBunch->BlockSlotAux(ReadSlot++, LastKnownPushBunch); + } + }; + + struct TDefaultParams { + static constexpr bool DeleteItems = true; + using TAux = NObstructiveQueuePrivate::TEmptyAux; + using TBunchBase = NObstructiveQueuePrivate::TEmpty; + static constexpr ui32 BUNCH_SIZE = 251; + }; + + } //namespace NObstructiveQueuePrivate + + DeclareTuneValueParam(TObstructiveQueueBunchSize, ui32, BUNCH_SIZE); + DeclareTuneValueParam(TObstructiveQueueDeleteItems, bool, DeleteItems); + DeclareTuneTypeParam(TObstructiveQueueBunchBase, TBunchBase); + DeclareTuneTypeParam(TObstructiveQueueAux, TAux); + + template <typename TItem = void, typename... TParams> + class TObstructiveConsumerAuxQueue { + private: + using TTuned = + TTune<NObstructiveQueuePrivate::TDefaultParams, TParams...>; + + using TAux = typename TTuned::TAux; + using TSlot = NObstructiveQueuePrivate::TSlot<TAux>; + using TMsgLink = NObstructiveQueuePrivate::TMsgLink; + using TBunchBase = typename TTuned::TBunchBase; + static constexpr bool DeleteItems = TTuned::DeleteItems; + static constexpr ui32 BUNCH_SIZE = TTuned::BUNCH_SIZE; + + public: + TObstructiveConsumerAuxQueue() + : RBuckets(&WBucket) + { + } + + ~TObstructiveConsumerAuxQueue() { + if (DeleteItems) { + for (;;) { + auto msg = Pop(); + if (msg == nullptr) { + break; + } + TDelete::Destroy(msg); + } + } + } + + void Push(TItem* msg) { + while (!WBucket.Push(reinterpret_cast<TMsgLink>(msg), TAux())) { + } + } + + TItem* Pop() { + return reinterpret_cast<TItem*>(RBuckets.Pop()); + } + + TSlot PopAux() { + return RBuckets.PopAux(); + } + + private: + NObstructiveQueuePrivate::TWriteBucket<TAux, BUNCH_SIZE, TBunchBase> + WBucket; + NObstructiveQueuePrivate::TReadBucket<TAux, BUNCH_SIZE, TBunchBase> + RBuckets; + }; + + template <typename TItem = void, bool DeleteItems = true> + class TObstructiveConsumerQueue + : public TObstructiveConsumerAuxQueue<TItem, + TObstructiveQueueDeleteItems<DeleteItems>> { + }; +} diff --git a/library/cpp/threading/queue/queue_ut.cpp b/library/cpp/threading/queue/queue_ut.cpp new file mode 100644 index 0000000000..80eca147da --- /dev/null +++ b/library/cpp/threading/queue/queue_ut.cpp @@ -0,0 +1,242 @@ +#include <library/cpp/testing/unittest/registar.h> +#include <util/system/thread.h> + +#include "ut_helpers.h" + +typedef void* TMsgLink; + +template <typename TQueueType> +class TQueueTestProcs: public TTestBase { +private: + UNIT_TEST_SUITE_DEMANGLE(TQueueTestProcs<TQueueType>); + UNIT_TEST(Threads2_Push1M_Threads1_Pop2M) + UNIT_TEST(Threads4_Push1M_Threads1_Pop4M) + UNIT_TEST(Threads8_RndPush100K_Threads8_Queues) + /* + UNIT_TEST(Threads24_RndPush100K_Threads24_Queues) + UNIT_TEST(Threads24_RndPush100K_Threads8_Queues) + UNIT_TEST(Threads24_RndPush100K_Threads4_Queues) +*/ + UNIT_TEST_SUITE_END(); + +public: + void Push1M_Pop1M() { + TQueueType queue; + TMsgLink msg = &queue; + + auto pmsg = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + + for (int i = 0; i < 1000000; ++i) { + queue.Push((char*)msg + i); + } + + for (int i = 0; i < 1000000; ++i) { + auto popped = queue.Pop(); + UNIT_ASSERT_EQUAL((char*)msg + i, popped); + } + + pmsg = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + } + + void Threads2_Push1M_Threads1_Pop2M() { + TQueueType queue; + + class TPusherThread: public ISimpleThread { + public: + TPusherThread(TQueueType& theQueue, char* start) + : Queue(theQueue) + , Arg(start) + { + } + + TQueueType& Queue; + char* Arg; + + void* ThreadProc() override { + for (int i = 0; i < 1000000; ++i) { + Queue.Push(Arg + i); + } + return nullptr; + } + }; + + TPusherThread pusher1(queue, (char*)&queue); + TPusherThread pusher2(queue, (char*)&queue + 2000000); + + pusher1.Start(); + pusher2.Start(); + + for (int i = 0; i < 2000000; ++i) { + while (queue.Pop() == nullptr) { + SpinLockPause(); + } + } + + auto pmsg = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + } + + void Threads4_Push1M_Threads1_Pop4M() { + TQueueType queue; + + class TPusherThread: public ISimpleThread { + public: + TPusherThread(TQueueType& theQueue, char* start) + : Queue(theQueue) + , Arg(start) + { + } + + TQueueType& Queue; + char* Arg; + + void* ThreadProc() override { + for (int i = 0; i < 1000000; ++i) { + Queue.Push(Arg + i); + } + return nullptr; + } + }; + + TPusherThread pusher1(queue, (char*)&queue); + TPusherThread pusher2(queue, (char*)&queue + 2000000); + TPusherThread pusher3(queue, (char*)&queue + 4000000); + TPusherThread pusher4(queue, (char*)&queue + 6000000); + + pusher1.Start(); + pusher2.Start(); + pusher3.Start(); + pusher4.Start(); + + for (int i = 0; i < 4000000; ++i) { + while (queue.Pop() == nullptr) { + SpinLockPause(); + } + } + + auto pmsg = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + } + + template <size_t NUMBER_OF_PUSHERS, size_t NUMBER_OF_QUEUES> + void ManyRndPush100K_ManyQueues() { + TQueueType queue[NUMBER_OF_QUEUES]; + + class TPusherThread: public ISimpleThread { + public: + TPusherThread(TQueueType* queues, char* start) + : Queues(queues) + , Arg(start) + { + } + + TQueueType* Queues; + char* Arg; + + void* ThreadProc() override { + ui64 counters[NUMBER_OF_QUEUES]; + for (size_t i = 0; i < NUMBER_OF_QUEUES; ++i) { + counters[i] = 0; + } + + for (int i = 0; i < 100000; ++i) { + size_t rnd = GetCycleCount() % NUMBER_OF_QUEUES; + int cookie = counters[rnd]++; + Queues[rnd].Push(Arg + cookie); + } + + for (size_t i = 0; i < NUMBER_OF_QUEUES; ++i) { + Queues[i].Push((void*)2ULL); + } + + return nullptr; + } + }; + + class TPopperThread: public ISimpleThread { + public: + TPopperThread(TQueueType* theQueue, char* base) + : Queue(theQueue) + , Base(base) + { + } + + TQueueType* Queue; + char* Base; + + void* ThreadProc() override { + ui64 counters[NUMBER_OF_PUSHERS]; + for (size_t i = 0; i < NUMBER_OF_PUSHERS; ++i) { + counters[i] = 0; + } + + for (size_t fin = 0; fin < NUMBER_OF_PUSHERS;) { + auto msg = Queue->Pop(); + if (msg == nullptr) { + SpinLockPause(); + continue; + } + if (msg == (void*)2ULL) { + ++fin; + continue; + } + ui64 shift = (char*)msg - Base; + auto pusherNum = shift / 200000000ULL; + auto msgNum = shift % 200000000ULL; + + UNIT_ASSERT_EQUAL(counters[pusherNum], msgNum); + ++counters[pusherNum]; + } + + auto pmsg = Queue->Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + + return nullptr; + } + }; + + TVector<TAutoPtr<TPopperThread>> poppers; + TVector<TAutoPtr<TPusherThread>> pushers; + + for (size_t i = 0; i < NUMBER_OF_QUEUES; ++i) { + poppers.emplace_back(new TPopperThread(&queue[i], (char*)&queue)); + poppers.back()->Start(); + } + + for (size_t i = 0; i < NUMBER_OF_PUSHERS; ++i) { + pushers.emplace_back( + new TPusherThread(queue, (char*)&queue + 200000000ULL * i)); + pushers.back()->Start(); + } + + for (size_t i = 0; i < NUMBER_OF_QUEUES; ++i) { + poppers[i]->Join(); + } + + for (size_t i = 0; i < NUMBER_OF_PUSHERS; ++i) { + pushers[i]->Join(); + } + } + + void Threads8_RndPush100K_Threads8_Queues() { + ManyRndPush100K_ManyQueues<8, 8>(); + } + + /* + void Threads24_RndPush100K_Threads24_Queues() { + ManyRndPush100K_ManyQueues<24, 24>(); + } + + void Threads24_RndPush100K_Threads8_Queues() { + ManyRndPush100K_ManyQueues<24, 8>(); + } + + void Threads24_RndPush100K_Threads4_Queues() { + ManyRndPush100K_ManyQueues<24, 4>(); + } + */ +}; + +REGISTER_TESTS_FOR_ALL_ORDERED_QUEUES(TQueueTestProcs); diff --git a/library/cpp/threading/queue/tune.h b/library/cpp/threading/queue/tune.h new file mode 100644 index 0000000000..50fc3dc17c --- /dev/null +++ b/library/cpp/threading/queue/tune.h @@ -0,0 +1,125 @@ +#pragma once + +/* + Motivation: consider you have a template class with many parameters + with default associations + + template <typename A = TDefA, + typename B = TDefB, + typename C = TDefC, + typename D = TDefD> + class TExample { + }; + + consider you would like to provide easy to use interface to tune all + these parameters in position independed manner, + In that case TTune would be helpful for you. + + How to use: + First step: declare a struct with all default associations + + struct TDefaultTune { + using TStructA = TDefA; + using TStructB = TDefB; + using TStructC = TDefC; + using TStructD = TDefD; + }; + + Second step: declare helper names visible to a user + + DeclareTuneTypeParam(TTuneParamA, TStructA); + DeclareTuneTypeParam(TTuneParamB, TStructB); + DeclareTuneTypeParam(TTuneParamC, TStructC); + DeclareTuneTypeParam(TTuneParamD, TStructD); + + Third step: declare TExample this way: + + template <typename...TParams> + class TExample { + using TMyParams = TTune<TDefaultTune, TParams...>; + + using TActualA = TMyParams::TStructA; + using TActualB = TMyParams::TStructB; + ... + }; + + TTune<TDefaultTune, TParams...> is a struct with the default parameteres + taken from TDefaultTune and overridden from "TParams...". + + for example: "TTune<TDefaultTune, TTuneParamC<TUserClass>>" + will be virtually the same as: + + struct TTunedClass { + using TStructA = TDefA; + using TStructB = TDefB; + using TStructC = TUserClass; + using TStructD = TDefD; + }; + + From now on you can tune your TExample in the following manner: + + using TCustomClass = + TExample <TTuneParamA<TUserStruct1>, TTuneParamD<TUserStruct2>>; + + You can also tweak constant expressions in your TDefaultTune. + Consider you have: + + struct TDefaultTune { + static constexpr ui32 MySize = 42; + }; + + declare an interface to modify the parameter this way: + + DeclareTuneValueParam(TStructSize, ui32, MySize); + + and tweak your class: + + using TTwiceBigger = TExample<TStructSize<84>>; + + */ + +#define DeclareTuneTypeParam(TParamName, InternalName) \ + template <typename TNewType> \ + struct TParamName { \ + template <typename TBase> \ + struct TApply: public TBase { \ + using InternalName = TNewType; \ + }; \ + } + +#define DeclareTuneValueParam(TParamName, TValueType, InternalName) \ + template <TValueType NewValue> \ + struct TParamName { \ + template <typename TBase> \ + struct TApply: public TBase { \ + static constexpr TValueType InternalName = NewValue; \ + }; \ + } + +#define DeclareTuneContainer(TParamName, InternalName) \ + template <template <typename, typename...> class TNewContainer> \ + struct TParamName { \ + template <typename TBase> \ + struct TApply: public TBase { \ + template <typename TElem, typename... TRest> \ + using InternalName = TNewContainer<TElem, TRest...>; \ + }; \ + } + +namespace NTunePrivate { + template <typename TBase, typename... TParams> + struct TFold; + + template <typename TBase> + struct TFold<TBase>: public TBase { + }; + + template <typename TBase, typename TFirstArg, typename... TRest> + struct TFold<TBase, TFirstArg, TRest...> + : public TFold<typename TFirstArg::template TApply<TBase>, TRest...> { + }; +} + +template <typename TDefault, typename... TParams> +struct TTune: public NTunePrivate::TFold<TDefault, TParams...> { +}; diff --git a/library/cpp/threading/queue/tune_ut.cpp b/library/cpp/threading/queue/tune_ut.cpp new file mode 100644 index 0000000000..7e980d3e27 --- /dev/null +++ b/library/cpp/threading/queue/tune_ut.cpp @@ -0,0 +1,118 @@ +#include <library/cpp/testing/unittest/registar.h> +#include "tune.h" + +struct TDefaultStructA { +}; + +struct TDefaultStructB { +}; + +struct TDefaults { + using TStructA = TDefaultStructA; + using TStructB = TDefaultStructB; + static constexpr ui32 Param1 = 42; + static constexpr ui32 Param2 = 42; +}; + +DeclareTuneTypeParam(TweakStructA, TStructA); +DeclareTuneTypeParam(TweakStructB, TStructB); +DeclareTuneValueParam(TweakParam1, ui32, Param1); +DeclareTuneValueParam(TweakParam2, ui32, Param2); + +Y_UNIT_TEST_SUITE(TestTuning) { + Y_UNIT_TEST(Defaults) { + using TTuned = TTune<TDefaults>; + using TunedA = TTuned::TStructA; + using TunedB = TTuned::TStructB; + auto sameA = std::is_same<TDefaultStructA, TunedA>::value; + auto sameB = std::is_same<TDefaultStructB, TunedB>::value; + auto param1 = TTuned::Param1; + auto param2 = TTuned::Param2; + + UNIT_ASSERT(sameA); + UNIT_ASSERT(sameB); + UNIT_ASSERT_EQUAL(param1, 42); + UNIT_ASSERT_EQUAL(param2, 42); + } + + Y_UNIT_TEST(TuneStructA) { + struct TMyStruct { + }; + + using TTuned = TTune<TDefaults, TweakStructA<TMyStruct>>; + + using TunedA = TTuned::TStructA; + using TunedB = TTuned::TStructB; + //auto sameA = std::is_same<TDefaultStructA, TunedA>::value; + auto sameB = std::is_same<TDefaultStructB, TunedB>::value; + auto param1 = TTuned::Param1; + auto param2 = TTuned::Param2; + + auto sameA = std::is_same<TMyStruct, TunedA>::value; + + UNIT_ASSERT(sameA); + UNIT_ASSERT(sameB); + UNIT_ASSERT_EQUAL(param1, 42); + UNIT_ASSERT_EQUAL(param2, 42); + } + + Y_UNIT_TEST(TuneParam1) { + using TTuned = TTune<TDefaults, TweakParam1<24>>; + + using TunedA = TTuned::TStructA; + using TunedB = TTuned::TStructB; + auto sameA = std::is_same<TDefaultStructA, TunedA>::value; + auto sameB = std::is_same<TDefaultStructB, TunedB>::value; + auto param1 = TTuned::Param1; + auto param2 = TTuned::Param2; + + UNIT_ASSERT(sameA); + UNIT_ASSERT(sameB); + UNIT_ASSERT_EQUAL(param1, 24); + UNIT_ASSERT_EQUAL(param2, 42); + } + + Y_UNIT_TEST(TuneStructAAndParam1) { + struct TMyStruct { + }; + + using TTuned = + TTune<TDefaults, TweakStructA<TMyStruct>, TweakParam1<24>>; + + using TunedA = TTuned::TStructA; + using TunedB = TTuned::TStructB; + //auto sameA = std::is_same<TDefaultStructA, TunedA>::value; + auto sameB = std::is_same<TDefaultStructB, TunedB>::value; + auto param1 = TTuned::Param1; + auto param2 = TTuned::Param2; + + auto sameA = std::is_same<TMyStruct, TunedA>::value; + + UNIT_ASSERT(sameA); + UNIT_ASSERT(sameB); + UNIT_ASSERT_EQUAL(param1, 24); + UNIT_ASSERT_EQUAL(param2, 42); + } + + Y_UNIT_TEST(TuneParam1AndStructA) { + struct TMyStruct { + }; + + using TTuned = + TTune<TDefaults, TweakParam1<24>, TweakStructA<TMyStruct>>; + + using TunedA = TTuned::TStructA; + using TunedB = TTuned::TStructB; + //auto sameA = std::is_same<TDefaultStructA, TunedA>::value; + auto sameB = std::is_same<TDefaultStructB, TunedB>::value; + auto param1 = TTuned::Param1; + auto param2 = TTuned::Param2; + + auto sameA = std::is_same<TMyStruct, TunedA>::value; + + UNIT_ASSERT(sameA); + UNIT_ASSERT(sameB); + UNIT_ASSERT_EQUAL(param1, 24); + UNIT_ASSERT_EQUAL(param2, 42); + } +} diff --git a/library/cpp/threading/queue/unordered_ut.cpp b/library/cpp/threading/queue/unordered_ut.cpp new file mode 100644 index 0000000000..a43b7f520e --- /dev/null +++ b/library/cpp/threading/queue/unordered_ut.cpp @@ -0,0 +1,154 @@ +#include <library/cpp/testing/unittest/registar.h> +#include <util/system/thread.h> +#include <algorithm> +#include <util/generic/vector.h> +#include <util/random/fast.h> + +#include "ut_helpers.h" + +template <typename TQueueType> +class TTestUnorderedQueue: public TTestBase { +private: + using TLink = TIntrusiveLink; + + UNIT_TEST_SUITE_DEMANGLE(TTestUnorderedQueue<TQueueType>); + UNIT_TEST(Push1M_Pop1M_Unordered) + UNIT_TEST_SUITE_END(); + +public: + void Push1M_Pop1M_Unordered() { + constexpr int REPEAT = 1000000; + TQueueType queue; + TLink msg[REPEAT]; + + auto pmsg = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + + for (int i = 0; i < REPEAT; ++i) { + queue.Push(&msg[i]); + } + + TVector<TLink*> popped; + popped.reserve(REPEAT); + for (int i = 0; i < REPEAT; ++i) { + popped.push_back((TLink*)queue.Pop()); + } + + pmsg = queue.Pop(); + UNIT_ASSERT_VALUES_EQUAL(pmsg, nullptr); + + std::sort(popped.begin(), popped.end()); + for (int i = 0; i < REPEAT; ++i) { + UNIT_ASSERT_VALUES_EQUAL(&msg[i], popped[i]); + } + } +}; + +template <typename TQueueType> +class TTestWeakQueue: public TTestBase { +private: + UNIT_TEST_SUITE_DEMANGLE(TTestWeakQueue<TQueueType>); + UNIT_TEST(Threads8_Rnd_Exchange) + UNIT_TEST_SUITE_END(); + +public: + template <ui16 COUNT = 48, ui32 MSG_COUNT = 10000> + void ManyThreadsRndExchange() { + TQueueType queues[COUNT]; + + class TWorker: public ISimpleThread { + public: + TWorker( + TQueueType* queues_, + ui16 mine, + TAtomic* pushDone) + : Queues(queues_) + , MineQueue(mine) + , PushDone(pushDone) + { + } + + TQueueType* Queues; + ui16 MineQueue; + TVector<uintptr_t> Received; + TAtomic* PushDone; + + void* ThreadProc() override { + TReallyFastRng32 rng(GetCycleCount()); + Received.reserve(MSG_COUNT * 2); + + for (ui32 loop = 1; loop <= MSG_COUNT; ++loop) { + for (;;) { + auto msg = Queues[MineQueue].Pop(); + if (msg == nullptr) { + break; + } + + Received.push_back((uintptr_t)msg); + } + + ui16 rnd = rng.GenRand64() % COUNT; + ui64 msg = ((ui64)MineQueue << 32) + loop; + while (!Queues[rnd].Push((void*)msg)) { + } + } + + AtomicIncrement(*PushDone); + + for (;;) { + bool isItLast = AtomicGet(*PushDone) == COUNT; + auto msg = Queues[MineQueue].Pop(); + if (msg != nullptr) { + Received.push_back((uintptr_t)msg); + } else { + if (isItLast) { + break; + } + SpinLockPause(); + } + } + + for (ui64 last = 0;;) { + auto msg = Queues[MineQueue].UnsafeScanningPop(&last); + if (msg == nullptr) { + break; + } + Received.push_back((uintptr_t)msg); + } + + return nullptr; + } + }; + + TVector<TAutoPtr<TWorker>> workers; + TAtomic pushDone = 0; + + for (ui32 i = 0; i < COUNT; ++i) { + workers.emplace_back(new TWorker(&queues[0], i, &pushDone)); + workers.back()->Start(); + } + + TVector<uintptr_t> all; + for (ui32 i = 0; i < COUNT; ++i) { + workers[i]->Join(); + all.insert(all.begin(), + workers[i]->Received.begin(), workers[i]->Received.end()); + } + + std::sort(all.begin(), all.end()); + auto iter = all.begin(); + for (ui32 i = 0; i < COUNT; ++i) { + for (ui32 k = 1; k <= MSG_COUNT; ++k) { + UNIT_ASSERT_VALUES_EQUAL(((ui64)i << 32) + k, *iter); + ++iter; + } + } + } + + void Threads8_Rnd_Exchange() { + ManyThreadsRndExchange<8>(); + } +}; + +REGISTER_TESTS_FOR_ALL_UNORDERED_QUEUES(TTestUnorderedQueue); +UNIT_TEST_SUITE_REGISTRATION(TTestWeakQueue<TMPMCUnorderedRing>); diff --git a/library/cpp/threading/queue/ut/ya.make b/library/cpp/threading/queue/ut/ya.make new file mode 100644 index 0000000000..8883d9bf69 --- /dev/null +++ b/library/cpp/threading/queue/ut/ya.make @@ -0,0 +1,16 @@ +UNITTEST_FOR(library/cpp/threading/queue) + +OWNER(agri) + +ALLOCATOR(B) + +SRCS( + basic_ut.cpp + queue_ut.cpp + tune_ut.cpp + unordered_ut.cpp + ut_helpers.cpp + ut_helpers.h +) + +END() diff --git a/library/cpp/threading/queue/ut_helpers.cpp b/library/cpp/threading/queue/ut_helpers.cpp new file mode 100644 index 0000000000..aa3a831441 --- /dev/null +++ b/library/cpp/threading/queue/ut_helpers.cpp @@ -0,0 +1 @@ +#include "ut_helpers.h" diff --git a/library/cpp/threading/queue/ut_helpers.h b/library/cpp/threading/queue/ut_helpers.h new file mode 100644 index 0000000000..2756b52601 --- /dev/null +++ b/library/cpp/threading/queue/ut_helpers.h @@ -0,0 +1,40 @@ +#pragma once + +#include "mpsc_read_as_filled.h" +#include "mpsc_htswap.h" +#include "mpsc_vinfarr_obstructive.h" +#include "mpsc_intrusive_unordered.h" +#include "mpmc_unordered_ring.h" + +struct TBasicHTSwap: public NThreading::THTSwapQueue<> { +}; + +struct TBasicReadAsFilled: public NThreading::TReadAsFilledQueue<> { +}; + +struct TBasicObstructiveConsumer + : public NThreading::TObstructiveConsumerQueue<> { +}; + +struct TBasicMPSCIntrusiveUnordered + : public NThreading::TMPSCIntrusiveUnordered { +}; + +struct TIntrusiveLink: public NThreading::TIntrusiveNode { +}; + +struct TMPMCUnorderedRing: public NThreading::TMPMCUnorderedRing { + TMPMCUnorderedRing() + : NThreading::TMPMCUnorderedRing(10000000) + { + } +}; + +#define REGISTER_TESTS_FOR_ALL_ORDERED_QUEUES(TestTemplate) \ + UNIT_TEST_SUITE_REGISTRATION(TestTemplate<TBasicHTSwap>); \ + UNIT_TEST_SUITE_REGISTRATION(TestTemplate<TBasicReadAsFilled>); \ + UNIT_TEST_SUITE_REGISTRATION(TestTemplate<TBasicObstructiveConsumer>) + +#define REGISTER_TESTS_FOR_ALL_UNORDERED_QUEUES(TestTemplate) \ + UNIT_TEST_SUITE_REGISTRATION(TestTemplate<TBasicMPSCIntrusiveUnordered>); \ + UNIT_TEST_SUITE_REGISTRATION(TestTemplate<TMPMCUnorderedRing>); diff --git a/library/cpp/threading/queue/ya.make b/library/cpp/threading/queue/ya.make new file mode 100644 index 0000000000..6570b38ce5 --- /dev/null +++ b/library/cpp/threading/queue/ya.make @@ -0,0 +1,18 @@ +LIBRARY() + +OWNER(agri) + +SRCS( + mpmc_unordered_ring.cpp + mpmc_unordered_ring.h + mpsc_htswap.cpp + mpsc_htswap.h + mpsc_intrusive_unordered.cpp + mpsc_intrusive_unordered.h + mpsc_read_as_filled.cpp + mpsc_read_as_filled.h + mpsc_vinfarr_obstructive.cpp + mpsc_vinfarr_obstructive.h +) + +END() diff --git a/library/cpp/threading/skip_list/compare.h b/library/cpp/threading/skip_list/compare.h new file mode 100644 index 0000000000..ac98b3e1ce --- /dev/null +++ b/library/cpp/threading/skip_list/compare.h @@ -0,0 +1,77 @@ +#pragma once + +#include <util/generic/typetraits.h> +#include <util/str_stl.h> + +namespace NThreading { + namespace NImpl { + Y_HAS_MEMBER(compare); + Y_HAS_MEMBER(Compare); + + template <typename T> + inline int CompareImpl(const T& l, const T& r) { + if (l < r) { + return -1; + } else if (r < l) { + return +1; + } else { + return 0; + } + } + + template <bool val> + struct TSmallCompareSelector { + template <typename T> + static inline int Compare(const T& l, const T& r) { + return CompareImpl(l, r); + } + }; + + template <> + struct TSmallCompareSelector<true> { + template <typename T> + static inline int Compare(const T& l, const T& r) { + return l.compare(r); + } + }; + + template <bool val> + struct TBigCompareSelector { + template <typename T> + static inline int Compare(const T& l, const T& r) { + return TSmallCompareSelector<THascompare<T>::value>::Compare(l, r); + } + }; + + template <> + struct TBigCompareSelector<true> { + template <typename T> + static inline int Compare(const T& l, const T& r) { + return l.Compare(r); + } + }; + + template <typename T> + struct TCompareSelector: public TBigCompareSelector<THasCompare<T>::value> { + }; + } + + //////////////////////////////////////////////////////////////////////////////// + // Generic compare function + + template <typename T> + inline int Compare(const T& l, const T& r) { + return NImpl::TCompareSelector<T>::Compare(l, r); + } + + //////////////////////////////////////////////////////////////////////////////// + // Generic compare functor + + template <typename T> + struct TCompare { + inline int operator()(const T& l, const T& r) const { + return Compare(l, r); + } + }; + +} diff --git a/library/cpp/threading/skip_list/perf/main.cpp b/library/cpp/threading/skip_list/perf/main.cpp new file mode 100644 index 0000000000..4ad52049e7 --- /dev/null +++ b/library/cpp/threading/skip_list/perf/main.cpp @@ -0,0 +1,362 @@ +#include <library/cpp/threading/skip_list/skiplist.h> + +#include <library/cpp/getopt/small/last_getopt.h> + +#include <library/cpp/charset/ci_string.h> +#include <util/datetime/base.h> +#include <util/generic/map.h> +#include <util/generic/vector.h> +#include <functional> +#include <util/memory/pool.h> +#include <util/random/random.h> +#include <util/string/join.h> +#include <util/system/mutex.h> +#include <util/system/thread.h> + +namespace { + using namespace NThreading; + + //////////////////////////////////////////////////////////////////////////////// + + IOutputStream& LogInfo() { + return Cerr << TInstant::Now() << " INFO: "; + } + + IOutputStream& LogError() { + return Cerr << TInstant::Now() << " ERROR: "; + } + + //////////////////////////////////////////////////////////////////////////////// + + struct TListItem { + TStringBuf Key; + TStringBuf Value; + + TListItem(const TStringBuf& key, const TStringBuf& value) + : Key(key) + , Value(value) + { + } + + int Compare(const TListItem& other) const { + return Key.compare(other.Key); + } + }; + + using TListType = TSkipList<TListItem>; + + //////////////////////////////////////////////////////////////////////////////// + + class TRandomData { + private: + TVector<char> Buffer; + + public: + TRandomData() + : Buffer(1024 * 1024) + { + for (size_t i = 0; i < Buffer.size(); ++i) { + Buffer[i] = RandomNumber<char>(); + } + } + + TStringBuf GetString(size_t len) const { + size_t start = RandomNumber(Buffer.size() - len); + return TStringBuf(&Buffer[start], len); + } + + TStringBuf GetString(size_t min, size_t max) const { + return GetString(min + RandomNumber(max - min)); + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + class TWorkerThread: public ISimpleThread { + private: + std::function<void()> Func; + TDuration Time; + + public: + TWorkerThread(std::function<void()> func) + : Func(func) + { + } + + TDuration GetTime() const { + return Time; + } + + private: + void* ThreadProc() noexcept override { + TInstant started = TInstant::Now(); + Func(); + Time = TInstant::Now() - started; + return nullptr; + } + }; + + inline TAutoPtr<TWorkerThread> StartThread(std::function<void()> func) { + TAutoPtr<TWorkerThread> thread = new TWorkerThread(func); + thread->Start(); + return thread; + } + + //////////////////////////////////////////////////////////////////////////////// + + typedef std::function<void()> TTestFunc; + + struct TTest { + TString Name; + TTestFunc Func; + + TTest() { + } + + TTest(const TString& name, const TTestFunc& func) + : Name(name) + , Func(func) + { + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + class TTestSuite { + private: + size_t Iterations = 1000000; + size_t KeyLen = 10; + size_t ValueLen = 100; + size_t NumReaders = 4; + size_t NumWriters = 1; + size_t BatchSize = 20; + + TMemoryPool MemoryPool; + TListType List; + TMutex Mutex; + TRandomData Random; + + TMap<TCiString, TTest> AllTests; + TVector<TTest> Tests; + + public: + TTestSuite() + : MemoryPool(64 * 1024) + , List(MemoryPool) + { + } + + bool Init(int argc, const char* argv[]) { + TVector<TString> tests; + try { + NLastGetopt::TOpts opts; + opts.AddHelpOption(); + +#define OPTION(opt, x) \ + opts.AddLongOption(opt, #x) \ + .Optional() \ + .DefaultValue(ToString(x)) \ + .StoreResult(&x) // end of OPTION + + OPTION('i', Iterations); + OPTION('k', KeyLen); + OPTION('v', ValueLen); + OPTION('r', NumReaders); + OPTION('w', NumWriters); + OPTION('b', BatchSize); + +#undef OPTION + + NLastGetopt::TOptsParseResultException optsRes(&opts, argc, argv); + for (const auto& opt : opts.Opts_) { + const NLastGetopt::TOptParseResult* r = optsRes.FindOptParseResult(opt.Get(), true); + if (r) { + LogInfo() << "[-" << opt->GetChar() << "] " << opt->GetName() << ": " << r->Back() << Endl; + } + } + tests = optsRes.GetFreeArgs(); + } catch (...) { + LogError() << CurrentExceptionMessage() << Endl; + return false; + } + +#define TEST(type) \ + AddTest(#type, std::bind(&TTestSuite::Y_CAT(TEST_, type), this)) // end of TEST + + TEST(Clear); + TEST(InsertRandom); + TEST(InsertSequential); + TEST(InsertSequentialSimple); + TEST(LookupRandom); + TEST(Concurrent); + +#undef TEST + + if (tests.empty()) { + LogError() << "no tests specified, choose from: " << PrintTests() << Endl; + return false; + } + + for (size_t i = 0; i < tests.size(); ++i) { + if (!AllTests.contains(tests[i])) { + LogError() << "unknown test name: " << tests[i] << Endl; + return false; + } + Tests.push_back(AllTests[tests[i]]); + } + + return true; + } + + void Run() { +#if !defined(NDEBUG) + LogInfo() << "*** DEBUG build! ***" << Endl; +#endif + + for (const TTest& test : Tests) { + LogInfo() << "Starting test " << test.Name << Endl; + + TInstant started = TInstant::Now(); + try { + test.Func(); + } catch (...) { + LogError() << "test " << test.Name + << " failed: " << CurrentExceptionMessage() + << Endl; + } + + LogInfo() << "List size = " << List.GetSize() << Endl; + + TDuration duration = TInstant::Now() - started; + LogInfo() << "test " << test.Name + << " duration: " << duration + << " (" << (double)duration.MicroSeconds() / (Iterations * NumWriters) << "us per iteration)" + << Endl; + LogInfo() << "Finished test " << test.Name << Endl; + } + } + + private: + void AddTest(const char* name, TTestFunc func) { + AllTests[name] = TTest(name, func); + } + + TString PrintTests() const { + TVector<TString> names; + for (const auto& it : AllTests) { + names.push_back(it.first); + } + return JoinSeq(", ", names); + } + + void TEST_Clear() { + List.Clear(); + } + + void TEST_InsertRandom() { + for (size_t i = 0; i < Iterations; ++i) { + List.Insert(TListItem(Random.GetString(KeyLen), Random.GetString(ValueLen))); + } + } + + void TEST_InsertSequential() { + TString key; + for (size_t i = 0; i < Iterations;) { + key.assign(Random.GetString(KeyLen)); + size_t batch = BatchSize / 2 + RandomNumber(BatchSize); + for (size_t j = 0; j < batch; ++j, ++i) { + key.resize(KeyLen - 1); + key.append((char)j); + List.Insert(TListItem(key, Random.GetString(ValueLen))); + } + } + } + + void TEST_InsertSequentialSimple() { + for (size_t i = 0; i < Iterations; ++i) { + List.Insert(TListItem(Random.GetString(KeyLen), Random.GetString(ValueLen))); + } + } + + void TEST_LookupRandom() { + for (size_t i = 0; i < Iterations; ++i) { + List.SeekTo(TListItem(Random.GetString(KeyLen), TStringBuf())); + } + } + + void TEST_Concurrent() { + LogInfo() << "starting producers..." << Endl; + + TVector<TAutoPtr<TWorkerThread>> producers(NumWriters); + for (size_t i1 = 0; i1 < producers.size(); ++i1) { + producers[i1] = StartThread([&] { + TInstant started = TInstant::Now(); + for (size_t i2 = 0; i2 < Iterations; ++i2) { + { + TGuard<TMutex> guard(Mutex); + List.Insert(TListItem(Random.GetString(KeyLen), Random.GetString(ValueLen))); + } + } + TDuration duration = TInstant::Now() - started; + LogInfo() + << "Average time for producer = " + << (double)duration.MicroSeconds() / Iterations << "us per iteration" + << Endl; + }); + } + + LogInfo() << "starting consumers..." << Endl; + + TVector<TAutoPtr<TWorkerThread>> consumers(NumReaders); + for (size_t i1 = 0; i1 < consumers.size(); ++i1) { + consumers[i1] = StartThread([&] { + TInstant started = TInstant::Now(); + for (size_t i2 = 0; i2 < Iterations; ++i2) { + List.SeekTo(TListItem(Random.GetString(KeyLen), TStringBuf())); + } + TDuration duration = TInstant::Now() - started; + LogInfo() + << "Average time for consumer = " + << (double)duration.MicroSeconds() / Iterations << "us per iteration" + << Endl; + }); + } + + LogInfo() << "wait for producers..." << Endl; + + TDuration producerTime; + for (size_t i = 0; i < producers.size(); ++i) { + producers[i]->Join(); + producerTime += producers[i]->GetTime(); + } + + LogInfo() << "wait for consumers..." << Endl; + + TDuration consumerTime; + for (size_t i = 0; i < consumers.size(); ++i) { + consumers[i]->Join(); + consumerTime += consumers[i]->GetTime(); + } + + LogInfo() << "average producer time: " + << producerTime.SecondsFloat() / producers.size() << " seconds" + << Endl; + + LogInfo() << "average consumer time: " + << consumerTime.SecondsFloat() / consumers.size() << " seconds" + << Endl; + } + }; + +} + +//////////////////////////////////////////////////////////////////////////////// + +int main(int argc, const char* argv[]) { + TTestSuite suite; + if (!suite.Init(argc, argv)) { + return -1; + } + suite.Run(); + return 0; +} diff --git a/library/cpp/threading/skip_list/perf/ya.make b/library/cpp/threading/skip_list/perf/ya.make new file mode 100644 index 0000000000..01bfafa404 --- /dev/null +++ b/library/cpp/threading/skip_list/perf/ya.make @@ -0,0 +1,15 @@ +PROGRAM(skiplist-perf) + +OWNER(g:rtmr) + +PEERDIR( + library/cpp/charset + library/cpp/getopt/small + library/cpp/threading/skip_list +) + +SRCS( + main.cpp +) + +END() diff --git a/library/cpp/threading/skip_list/skiplist.cpp b/library/cpp/threading/skip_list/skiplist.cpp new file mode 100644 index 0000000000..c6e98816fb --- /dev/null +++ b/library/cpp/threading/skip_list/skiplist.cpp @@ -0,0 +1 @@ +#include "skiplist.h" diff --git a/library/cpp/threading/skip_list/skiplist.h b/library/cpp/threading/skip_list/skiplist.h new file mode 100644 index 0000000000..914a7c6ee7 --- /dev/null +++ b/library/cpp/threading/skip_list/skiplist.h @@ -0,0 +1,408 @@ +#pragma once + +#include "compare.h" + +#include <util/generic/algorithm.h> +#include <util/generic/noncopyable.h> +#include <util/generic/typetraits.h> +#include <util/memory/pool.h> +#include <util/random/random.h> +#include <util/system/atomic.h> + +namespace NThreading { + //////////////////////////////////////////////////////////////////////////////// + + class TNopCounter { + protected: + template <typename T> + void OnInsert(const T&) { + } + + template <typename T> + void OnUpdate(const T&) { + } + + void Reset() { + } + }; + + //////////////////////////////////////////////////////////////////////////////// + + class TSizeCounter { + private: + size_t Size; + + public: + TSizeCounter() + : Size(0) + { + } + + size_t GetSize() const { + return Size; + } + + protected: + template <typename T> + void OnInsert(const T&) { + ++Size; + } + + template <typename T> + void OnUpdate(const T&) { + } + + void Reset() { + Size = 0; + } + }; + + //////////////////////////////////////////////////////////////////////////////// + // Append-only concurrent skip-list + // + // Readers do not require any synchronization. + // Writers should be externally synchronized. + // Nodes will be allocated using TMemoryPool instance. + + template < + typename T, + typename TComparer = TCompare<T>, + typename TAllocator = TMemoryPool, + typename TCounter = TSizeCounter, + int MaxHeight = 12, + int Branching = 4> + class TSkipList: public TCounter, private TNonCopyable { + class TNode { + private: + T Value; // should be immutable after insert + TNode* Next[]; // variable-size array maximum of MaxHeight values + + public: + TNode(T&& value) + : Value(std::move(value)) + { + Y_UNUSED(Next); + } + + const T& GetValue() const { + return Value; + } + + T& GetValue() { + return Value; + } + + TNode* GetNext(int height) const { + return AtomicGet(Next[height]); + } + + void Link(int height, TNode** prev) { + for (int i = 0; i < height; ++i) { + Next[i] = prev[i]->Next[i]; + AtomicSet(prev[i]->Next[i], this); + } + } + }; + + public: + class TIterator { + private: + const TSkipList* List; + const TNode* Node; + + public: + TIterator() + : List(nullptr) + , Node(nullptr) + { + } + + TIterator(const TSkipList* list, const TNode* node) + : List(list) + , Node(node) + { + } + + TIterator(const TIterator& other) + : List(other.List) + , Node(other.Node) + { + } + + TIterator& operator=(const TIterator& other) { + List = other.List; + Node = other.Node; + return *this; + } + + void Next() { + Node = Node ? Node->GetNext(0) : nullptr; + } + + // much less efficient than Next as our list is single-linked + void Prev() { + if (Node) { + TNode* node = List->FindLessThan(Node->GetValue(), nullptr); + Node = (node != List->Head ? node : nullptr); + } + } + + void Reset() { + Node = nullptr; + } + + bool IsValid() const { + return Node != nullptr; + } + + const T& GetValue() const { + Y_ASSERT(IsValid()); + return Node->GetValue(); + } + }; + + private: + TAllocator& Allocator; + TComparer Comparer; + + TNode* Head; + TAtomic Height; + TCounter Counter; + + TNode* Prev[MaxHeight]; + + template <typename TValue> + using TComparerReturnType = std::invoke_result_t<TComparer, const T&, const TValue&>; + + public: + TSkipList(TAllocator& allocator, const TComparer& comparer = TComparer()) + : Allocator(allocator) + , Comparer(comparer) + { + Init(); + } + + ~TSkipList() { + CallDtors(); + } + + void Clear() { + CallDtors(); + Allocator.ClearKeepFirstChunk(); + Init(); + } + + bool Insert(T value) { + TNode* node = PrepareInsert(value); + if (Y_UNLIKELY(node && Compare(node, value) == 0)) { + // we do not allow duplicates + return false; + } + node = DoInsert(std::move(value)); + TCounter::OnInsert(node->GetValue()); + return true; + } + + template <typename TInsertAction, typename TUpdateAction> + bool Insert(const T& value, TInsertAction insert, TUpdateAction update) { + TNode* node = PrepareInsert(value); + if (Y_UNLIKELY(node && Compare(node, value) == 0)) { + if (update(node->GetValue())) { + TCounter::OnUpdate(node->GetValue()); + return true; + } + // we do not allow duplicates + return false; + } + node = DoInsert(insert(value)); + TCounter::OnInsert(node->GetValue()); + return true; + } + + template <typename TValue> + bool Contains(const TValue& value) const { + TNode* node = FindGreaterThanOrEqual(value); + return node && Compare(node, value) == 0; + } + + TIterator SeekToFirst() const { + return TIterator(this, FindFirst()); + } + + TIterator SeekToLast() const { + TNode* last = FindLast(); + return TIterator(this, last != Head ? last : nullptr); + } + + template <typename TValue> + TIterator SeekTo(const TValue& value) const { + return TIterator(this, FindGreaterThanOrEqual(value)); + } + + private: + static int RandomHeight() { + int height = 1; + while (height < MaxHeight && (RandomNumber<unsigned int>() % Branching) == 0) { + ++height; + } + return height; + } + + void Init() { + Head = AllocateRootNode(); + Height = 1; + TCounter::Reset(); + + for (int i = 0; i < MaxHeight; ++i) { + Prev[i] = Head; + } + } + + void CallDtors() { + if (!TTypeTraits<T>::IsPod) { + // we should explicitly call destructors for our nodes + TNode* node = Head->GetNext(0); + while (node) { + TNode* next = node->GetNext(0); + node->~TNode(); + node = next; + } + } + } + + TNode* AllocateRootNode() { + size_t size = sizeof(TNode) + sizeof(TNode*) * MaxHeight; + void* buffer = Allocator.Allocate(size); + memset(buffer, 0, size); + return static_cast<TNode*>(buffer); + } + + TNode* AllocateNode(T&& value, int height) { + size_t size = sizeof(TNode) + sizeof(TNode*) * height; + void* buffer = Allocator.Allocate(size); + memset(buffer, 0, size); + return new (buffer) TNode(std::move(value)); + } + + TNode* FindFirst() const { + return Head->GetNext(0); + } + + TNode* FindLast() const { + TNode* node = Head; + int height = AtomicGet(Height) - 1; + + while (true) { + TNode* next = node->GetNext(height); + if (next) { + node = next; + continue; + } + + if (height) { + --height; + } else { + return node; + } + } + } + + template <typename TValue> + TComparerReturnType<TValue> Compare(const TNode* node, const TValue& value) const { + return Comparer(node->GetValue(), value); + } + + template <typename TValue> + TNode* FindLessThan(const TValue& value, TNode** links) const { + TNode* node = Head; + int height = AtomicGet(Height) - 1; + + TNode* prev = nullptr; + while (true) { + TNode* next = node->GetNext(height); + if (next && next != prev) { + TComparerReturnType<TValue> cmp = Compare(next, value); + if (cmp < 0) { + node = next; + continue; + } + } + + if (links) { + // collect links from upper levels + links[height] = node; + } + + if (height) { + prev = next; + --height; + } else { + return node; + } + } + } + + template <typename TValue> + TNode* FindGreaterThanOrEqual(const TValue& value) const { + TNode* node = Head; + int height = AtomicGet(Height) - 1; + + TNode* prev = nullptr; + while (true) { + TNode* next = node->GetNext(height); + if (next && next != prev) { + TComparerReturnType<TValue> cmp = Compare(next, value); + if (cmp < 0) { + node = next; + continue; + } + if (cmp == 0) { + return next; + } + } + + if (height) { + prev = next; + --height; + } else { + return next; + } + } + } + + TNode* PrepareInsert(const T& value) { + TNode* prev = Prev[0]; + TNode* next = prev->GetNext(0); + if ((prev == Head || Compare(prev, value) < 0) && (next == nullptr || Compare(next, value) >= 0)) { + // avoid seek in case of sequential insert + } else { + prev = FindLessThan(value, Prev); + next = prev->GetNext(0); + } + return next; + } + + TNode* DoInsert(T&& value) { + // choose level to place new node + int currentHeight = AtomicGet(Height); + int height = RandomHeight(); + if (height > currentHeight) { + for (int i = currentHeight; i < height; ++i) { + // head should link to all levels + Prev[i] = Head; + } + AtomicSet(Height, height); + } + + TNode* node = AllocateNode(std::move(value), height); + node->Link(height, Prev); + + // keep last inserted node to optimize sequential inserts + for (int i = 0; i < height; i++) { + Prev[i] = node; + } + return node; + } + }; + +} diff --git a/library/cpp/threading/skip_list/skiplist_ut.cpp b/library/cpp/threading/skip_list/skiplist_ut.cpp new file mode 100644 index 0000000000..52fcffda66 --- /dev/null +++ b/library/cpp/threading/skip_list/skiplist_ut.cpp @@ -0,0 +1,185 @@ +#include "skiplist.h" + +#include <library/cpp/testing/unittest/registar.h> + +namespace NThreading { + namespace { + struct TTestObject { + static size_t Count; + int Tag; + + TTestObject(int tag) + : Tag(tag) + { + ++Count; + } + + TTestObject(const TTestObject& other) + : Tag(other.Tag) + { + ++Count; + } + + ~TTestObject() { + --Count; + } + + bool operator<(const TTestObject& other) const { + return Tag < other.Tag; + } + }; + + size_t TTestObject::Count = 0; + + } + + //////////////////////////////////////////////////////////////////////////////// + + Y_UNIT_TEST_SUITE(TSkipListTest) { + Y_UNIT_TEST(ShouldBeEmptyAfterCreation) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT_EQUAL(list.GetSize(), 0); + } + + Y_UNIT_TEST(ShouldAllowInsertion) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT(list.Insert(12345678)); + UNIT_ASSERT_EQUAL(list.GetSize(), 1); + } + + Y_UNIT_TEST(ShouldNotAllowDuplicates) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT(list.Insert(12345678)); + UNIT_ASSERT_EQUAL(list.GetSize(), 1); + + UNIT_ASSERT(!list.Insert(12345678)); + UNIT_ASSERT_EQUAL(list.GetSize(), 1); + } + + Y_UNIT_TEST(ShouldContainInsertedItem) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT(list.Insert(12345678)); + UNIT_ASSERT(list.Contains(12345678)); + } + + Y_UNIT_TEST(ShouldNotContainNotInsertedItem) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT(list.Insert(12345678)); + UNIT_ASSERT(!list.Contains(87654321)); + } + + Y_UNIT_TEST(ShouldIterateAllItems) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + for (int i = 8; i > 0; --i) { + UNIT_ASSERT(list.Insert(i)); + } + + TSkipList<int>::TIterator it = list.SeekToFirst(); + for (int i = 1; i <= 8; ++i) { + UNIT_ASSERT(it.IsValid()); + UNIT_ASSERT_EQUAL(it.GetValue(), i); + it.Next(); + } + UNIT_ASSERT(!it.IsValid()); + } + + Y_UNIT_TEST(ShouldIterateAllItemsInReverseDirection) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + for (int i = 8; i > 0; --i) { + UNIT_ASSERT(list.Insert(i)); + } + + TSkipList<int>::TIterator it = list.SeekToLast(); + for (int i = 8; i > 0; --i) { + UNIT_ASSERT(it.IsValid()); + UNIT_ASSERT_EQUAL(it.GetValue(), i); + it.Prev(); + } + UNIT_ASSERT(!it.IsValid()); + } + + Y_UNIT_TEST(ShouldSeekToFirstItem) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + for (int i = 1; i < 10; ++i) { + UNIT_ASSERT(list.Insert(i)); + } + + TSkipList<int>::TIterator it = list.SeekToFirst(); + UNIT_ASSERT(it.IsValid()); + UNIT_ASSERT_EQUAL(it.GetValue(), 1); + } + + Y_UNIT_TEST(ShouldSeekToLastItem) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + for (int i = 1; i < 10; ++i) { + UNIT_ASSERT(list.Insert(i)); + } + + TSkipList<int>::TIterator it = list.SeekToLast(); + UNIT_ASSERT(it.IsValid()); + UNIT_ASSERT_EQUAL(it.GetValue(), 9); + } + + Y_UNIT_TEST(ShouldSeekToExistingItem) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT(list.Insert(12345678)); + + TSkipList<int>::TIterator it = list.SeekTo(12345678); + UNIT_ASSERT(it.IsValid()); + } + + Y_UNIT_TEST(ShouldSeekAfterMissedItem) { + TMemoryPool pool(1024); + TSkipList<int> list(pool); + + UNIT_ASSERT(list.Insert(100)); + UNIT_ASSERT(list.Insert(300)); + + TSkipList<int>::TIterator it = list.SeekTo(200); + UNIT_ASSERT(it.IsValid()); + UNIT_ASSERT_EQUAL(it.GetValue(), 300); + + it.Prev(); + UNIT_ASSERT(it.IsValid()); + UNIT_ASSERT_EQUAL(it.GetValue(), 100); + } + + Y_UNIT_TEST(ShouldCallDtorsOfNonPodTypes) { + UNIT_ASSERT(!TTypeTraits<TTestObject>::IsPod); + UNIT_ASSERT_EQUAL(TTestObject::Count, 0); + + { + TMemoryPool pool(1024); + TSkipList<TTestObject> list(pool); + + UNIT_ASSERT(list.Insert(TTestObject(1))); + UNIT_ASSERT(list.Insert(TTestObject(2))); + + UNIT_ASSERT_EQUAL(TTestObject::Count, 2); + } + + UNIT_ASSERT_EQUAL(TTestObject::Count, 0); + } + } + +} diff --git a/library/cpp/threading/skip_list/ut/ya.make b/library/cpp/threading/skip_list/ut/ya.make new file mode 100644 index 0000000000..704a31e9a2 --- /dev/null +++ b/library/cpp/threading/skip_list/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/threading/skip_list) + +OWNER(g:rtmr) + +SRCS( + skiplist_ut.cpp +) + +END() diff --git a/library/cpp/threading/skip_list/ya.make b/library/cpp/threading/skip_list/ya.make new file mode 100644 index 0000000000..d338aeae2b --- /dev/null +++ b/library/cpp/threading/skip_list/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(g:rtmr) + +SRCS( + skiplist.cpp +) + +END() diff --git a/library/cpp/threading/task_scheduler/task_scheduler.cpp b/library/cpp/threading/task_scheduler/task_scheduler.cpp new file mode 100644 index 0000000000..174dde4bf7 --- /dev/null +++ b/library/cpp/threading/task_scheduler/task_scheduler.cpp @@ -0,0 +1,246 @@ +#include "task_scheduler.h" + +#include <util/system/thread.h> +#include <util/string/cast.h> +#include <util/stream/output.h> + +TTaskScheduler::ITask::~ITask() {} +TTaskScheduler::IRepeatedTask::~IRepeatedTask() {} + + + +class TTaskScheduler::TWorkerThread + : public ISimpleThread +{ +public: + TWorkerThread(TTaskScheduler& state) + : Scheduler_(state) + { + } + + TString DebugState = "?"; + TString DebugId = ""; +private: + void* ThreadProc() noexcept override { + Scheduler_.WorkerFunc(this); + return nullptr; + } +private: + TTaskScheduler& Scheduler_; +}; + + + +TTaskScheduler::TTaskScheduler(size_t threadCount, size_t maxTaskCount) + : MaxTaskCount_(maxTaskCount) +{ + for (size_t i = 0; i < threadCount; ++i) { + Workers_.push_back(new TWorkerThread(*this)); + Workers_.back()->DebugId = ToString(i); + } +} + +TTaskScheduler::~TTaskScheduler() { + try { + Stop(); + } catch (...) { + Cdbg << "task scheduled destruction error: " << CurrentExceptionMessage(); + } +} + +void TTaskScheduler::Start() { + for (auto& w : Workers_) { + w->Start(); + } +} + +void TTaskScheduler::Stop() { + with_lock (Lock_) { + IsStopped_ = true; + CondVar_.BroadCast(); + } + + for (auto& w: Workers_) { + w->Join(); + } + + Workers_.clear(); + Queue_.clear(); +} + +size_t TTaskScheduler::GetTaskCount() const { + return static_cast<size_t>(AtomicGet(TaskCounter_)); +} + +namespace { + class TTaskWrapper + : public TTaskScheduler::ITask + , TNonCopyable + { + public: + TTaskWrapper(TTaskScheduler::ITaskRef task, TAtomic& counter) + : Task_(task) + , Counter_(counter) + { + AtomicIncrement(Counter_); + } + + ~TTaskWrapper() override { + AtomicDecrement(Counter_); + } + private: + TInstant Process() override { + return Task_->Process(); + } + private: + TTaskScheduler::ITaskRef Task_; + TAtomic& Counter_; + }; +} + +bool TTaskScheduler::Add(ITaskRef task, TInstant expire) { + with_lock (Lock_) { + if (!IsStopped_ && Workers_.size() > 0 && GetTaskCount() + 1 <= MaxTaskCount_) { + ITaskRef newTask = new TTaskWrapper(task, TaskCounter_); + Queue_.insert(std::make_pair(expire, TTaskHolder(newTask))); + + if (!Queue_.begin()->second.WaitingWorker) { + CondVar_.Signal(); + } + return true; + } + } + + return false; +} + +namespace { + class TRepeatedTask + : public TTaskScheduler::ITask + { + public: + TRepeatedTask(TTaskScheduler::IRepeatedTaskRef task, TDuration period, TInstant deadline) + : Task_(task) + , Period_(period) + , Deadline_(deadline) + { + } + private: + TInstant Process() final { + Deadline_ += Period_; + if (Task_->Process()) { + return Deadline_; + } else { + return TInstant::Max(); + } + } + private: + TTaskScheduler::IRepeatedTaskRef Task_; + TDuration Period_; + TInstant Deadline_; + }; +} + +bool TTaskScheduler::Add(IRepeatedTaskRef task, TDuration period) { + const TInstant deadline = Now() + period; + ITaskRef t = new TRepeatedTask(task, period, deadline); + return Add(t, deadline); +} + + +const bool debugOutput = false; + +void TTaskScheduler::ChangeDebugState(TWorkerThread* thread, const TString& state) { + if (!debugOutput) { + Y_UNUSED(thread); + Y_UNUSED(state); + return; + } + + thread->DebugState = state; + + TStringStream ss; + ss << Now() << " " << thread->DebugId << ":\t"; + for (auto& w : Workers_) { + ss << w->DebugState << " "; + } + ss << " [" << Queue_.size() << "] [" << TaskCounter_ << "]" << Endl; + Cerr << ss.Str(); +} + +bool TTaskScheduler::Wait(TWorkerThread* thread, TQueueIterator& toWait) { + ChangeDebugState(thread, "w"); + toWait->second.WaitingWorker = thread; + return !CondVar_.WaitD(Lock_, toWait->first); +} + +void TTaskScheduler::ChooseFromQueue(TQueueIterator& toWait) { + for (TQueueIterator it = Queue_.begin(); it != Queue_.end(); ++it) { + if (!it->second.WaitingWorker) { + if (toWait == Queue_.end()) { + toWait = it; + } else if (it->first < toWait->first) { + toWait->second.WaitingWorker = nullptr; + toWait = it; + } + break; + } + } +} + +void TTaskScheduler::WorkerFunc(TWorkerThread* thread) { + TThread::SetCurrentThreadName("TaskSchedWorker"); + + TQueueIterator toWait = Queue_.end(); + ITaskRef toDo; + + for (;;) { + TInstant repeat = TInstant::Max(); + + if (!!toDo) { + try { + repeat = toDo->Process(); + } catch (...) { + Cdbg << "task scheduler error: " << CurrentExceptionMessage(); + } + } + + + with_lock (Lock_) { + ChangeDebugState(thread, "f"); + + if (IsStopped_) { + ChangeDebugState(thread, "s"); + return ; + } + + if (!!toDo) { + if (repeat < TInstant::Max()) { + Queue_.insert(std::make_pair(repeat, TTaskHolder(toDo))); + } + } + + toDo = nullptr; + + ChooseFromQueue(toWait); + + if (toWait != Queue_.end()) { + if (toWait->first <= Now() || Wait(thread, toWait)) { + + toDo = toWait->second.Task; + Queue_.erase(toWait); + toWait = Queue_.end(); + + if (!Queue_.empty() && !Queue_.begin()->second.WaitingWorker && Workers_.size() > 1) { + CondVar_.Signal(); + } + + ChangeDebugState(thread, "p"); + } + } else { + ChangeDebugState(thread, "e"); + CondVar_.WaitI(Lock_); + } + } + } +} diff --git a/library/cpp/threading/task_scheduler/task_scheduler.h b/library/cpp/threading/task_scheduler/task_scheduler.h new file mode 100644 index 0000000000..df4da941a8 --- /dev/null +++ b/library/cpp/threading/task_scheduler/task_scheduler.h @@ -0,0 +1,86 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/generic/ptr.h> +#include <util/generic/map.h> + +#include <util/datetime/base.h> + +#include <util/system/condvar.h> +#include <util/system/mutex.h> + +class TTaskScheduler { +public: + class ITask; + using ITaskRef = TIntrusivePtr<ITask>; + + class IRepeatedTask; + using IRepeatedTaskRef = TIntrusivePtr<IRepeatedTask>; +public: + explicit TTaskScheduler(size_t threadCount = 1, size_t maxTaskCount = Max<size_t>()); + ~TTaskScheduler(); + + void Start(); + void Stop(); + + bool Add(ITaskRef task, TInstant expire); + bool Add(IRepeatedTaskRef task, TDuration period); + + size_t GetTaskCount() const; +private: + class TWorkerThread; + + struct TTaskHolder { + explicit TTaskHolder(ITaskRef& task) + : Task(task) + { + } + public: + ITaskRef Task; + TWorkerThread* WaitingWorker = nullptr; + }; + + using TQueueType = TMultiMap<TInstant, TTaskHolder>; + using TQueueIterator = TQueueType::iterator; +private: + void ChangeDebugState(TWorkerThread* thread, const TString& state); + void ChooseFromQueue(TQueueIterator& toWait); + bool Wait(TWorkerThread* thread, TQueueIterator& toWait); + + void WorkerFunc(TWorkerThread* thread); +private: + bool IsStopped_ = false; + + TAtomic TaskCounter_ = 0; + TQueueType Queue_; + + TCondVar CondVar_; + TMutex Lock_; + + TVector<TAutoPtr<TWorkerThread>> Workers_; + + const size_t MaxTaskCount_; +}; + +class TTaskScheduler::ITask + : public TAtomicRefCount<ITask> +{ +public: + virtual ~ITask(); + + virtual TInstant Process() {//returns time to repeat this task + return TInstant::Max(); + } +}; + +class TTaskScheduler::IRepeatedTask + : public TAtomicRefCount<IRepeatedTask> +{ +public: + virtual ~IRepeatedTask(); + + virtual bool Process() {//returns if to repeat task again + return false; + } +}; + diff --git a/library/cpp/threading/task_scheduler/task_scheduler_ut.cpp b/library/cpp/threading/task_scheduler/task_scheduler_ut.cpp new file mode 100644 index 0000000000..3b5203194a --- /dev/null +++ b/library/cpp/threading/task_scheduler/task_scheduler_ut.cpp @@ -0,0 +1,86 @@ +#include <algorithm> +#include <library/cpp/testing/unittest/registar.h> + +#include <util/stream/output.h> +#include <util/system/atomic.h> +#include <util/generic/vector.h> + +#include "task_scheduler.h" + +class TTaskSchedulerTest: public TTestBase { + UNIT_TEST_SUITE(TTaskSchedulerTest); + UNIT_TEST(Test); + UNIT_TEST_SUITE_END(); + + class TCheckTask: public TTaskScheduler::IRepeatedTask { + public: + TCheckTask(const TDuration& delay) + : Start_(Now()) + , Delay_(delay) + { + AtomicIncrement(ScheduledTaskCounter_); + } + + ~TCheckTask() override { + } + + bool Process() override { + const TDuration delay = Now() - Start_; + + if (delay < Delay_) { + AtomicIncrement(BadTimeoutCounter_); + } + + AtomicIncrement(ExecutedTaskCounter_); + + return false; + } + + static bool AllTaskExecuted() { + return AtomicGet(ScheduledTaskCounter_) == AtomicGet(ExecutedTaskCounter_); + } + + static size_t BadTimeoutCount() { + return AtomicGet(BadTimeoutCounter_); + } + + private: + TInstant Start_; + TDuration Delay_; + static TAtomic BadTimeoutCounter_; + static TAtomic ScheduledTaskCounter_; + static TAtomic ExecutedTaskCounter_; + }; + + public: + inline void Test() { + ScheduleCheckTask(200); + ScheduleCheckTask(100); + ScheduleCheckTask(1000); + ScheduleCheckTask(10000); + ScheduleCheckTask(5000); + + Scheduler_.Start(); + + usleep(1000000); + + UNIT_ASSERT_EQUAL(TCheckTask::BadTimeoutCount(), 0); + UNIT_ASSERT(TCheckTask::AllTaskExecuted()); + } + + private: + void ScheduleCheckTask(size_t delay) { + TDuration d = TDuration::MicroSeconds(delay); + + Scheduler_.Add(new TCheckTask(d), d); + } + + private: + TTaskScheduler Scheduler_; +}; + +TAtomic TTaskSchedulerTest::TCheckTask::BadTimeoutCounter_ = 0; +TAtomic TTaskSchedulerTest::TCheckTask::ScheduledTaskCounter_ = 0; +TAtomic TTaskSchedulerTest::TCheckTask::ExecutedTaskCounter_ = 0; + +UNIT_TEST_SUITE_REGISTRATION(TTaskSchedulerTest); diff --git a/library/cpp/threading/task_scheduler/ut/ya.make b/library/cpp/threading/task_scheduler/ut/ya.make new file mode 100644 index 0000000000..07ee8b0877 --- /dev/null +++ b/library/cpp/threading/task_scheduler/ut/ya.make @@ -0,0 +1,9 @@ +UNITTEST_FOR(library/cpp/threading/task_scheduler) + +OWNER(g:middle) + +SRCS( + task_scheduler_ut.cpp +) + +END() diff --git a/library/cpp/threading/task_scheduler/ya.make b/library/cpp/threading/task_scheduler/ya.make new file mode 100644 index 0000000000..5b14c0aa63 --- /dev/null +++ b/library/cpp/threading/task_scheduler/ya.make @@ -0,0 +1,9 @@ +LIBRARY() + +OWNER(g:middle) + +SRCS( + task_scheduler.cpp +) + +END() diff --git a/library/cpp/threading/ya.make b/library/cpp/threading/ya.make new file mode 100644 index 0000000000..f4d850ee17 --- /dev/null +++ b/library/cpp/threading/ya.make @@ -0,0 +1,64 @@ +RECURSE( + algorithm + async_task_batch + async_task_batch/ut + atomic + atomic/ut + atomic_shared_ptr + atomic_shared_ptr/ut + blocking_counter + blocking_counter/ut + blocking_queue + cancellation + chunk_queue + chunk_queue/ut + cron + cron/example + equeue + equeue/ut + fair_lock + fair_lock/ut + future + future/perf + future/subscription + future/ut + hot_swap + hot_swap/ut + light_rw_lock + light_rw_lock/bench + light_rw_lock/ut + local_executor + local_executor/ut + mtp_tasks + mtp_tasks/ut + mux_event + mux_event/ut + named_lock + named_lock/ut + name_guard + name_guard/ut + periodically_updated + poor_man_openmp + poor_man_openmp/ut + queue + queue/ut + rcu + rcu/ut + serial_postprocess_queue + skip_list + skip_list/perf + skip_list/ut + synchronized + synchronized/ut + task_scheduler + task_scheduler/ut + thread_local + thread_local/benchmark + thread_local/ut + thread_namer + thread_namer/ut + ticket_lock + ticket_lock/ut + work_stealing + work_stealing/ut +) |