aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/threading/poor_man_openmp/thread_helper.h
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/threading/poor_man_openmp/thread_helper.h
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/threading/poor_man_openmp/thread_helper.h')
-rw-r--r--library/cpp/threading/poor_man_openmp/thread_helper.h105
1 files changed, 105 insertions, 0 deletions
diff --git a/library/cpp/threading/poor_man_openmp/thread_helper.h b/library/cpp/threading/poor_man_openmp/thread_helper.h
new file mode 100644
index 0000000000..0ecee0590b
--- /dev/null
+++ b/library/cpp/threading/poor_man_openmp/thread_helper.h
@@ -0,0 +1,105 @@
+#pragma once
+
+#include <util/thread/pool.h>
+#include <util/generic/utility.h>
+#include <util/generic/yexception.h>
+#include <util/system/info.h>
+#include <util/system/atomic.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+#include <util/stream/output.h>
+
+#include <functional>
+#include <cstdlib>
+
+// Owner of the thread pool used by the NYmp parallel-for helpers below.
+//
+// The constructor builds a pool sized to NSystemInfo::CachedNumberOfCpus();
+// SetThreadCount() replaces it with a pool of a different size. The class does
+// no locking of its own, so SetThreadCount() must not run concurrently with
+// Get()/GetThreadCount() or with code still using a pointer returned by Get().
+// Instance() is only declared here; its definition lives elsewhere.
+class TMtpQueueHelper {
+public:
+    // Creates the initial pool, sized to the cached CPU count.
+    TMtpQueueHelper() {
+        SetThreadCount(NSystemInfo::CachedNumberOfCpus());
+    }
+
+    // Returns a non-owning pointer to the current pool. The pool stays owned
+    // by this helper and is replaced by the next SetThreadCount() call.
+    IThreadPool* Get() {
+        return q.Get();
+    }
+
+    // Thread count the current pool was created with.
+    size_t GetThreadCount() const {
+        return ThreadCount;
+    }
+
+    // Replaces the current pool with one created for `threads` threads.
+    void SetThreadCount(size_t threads) {
+        // Create the pool first and record the count only afterwards: if
+        // CreateThreadPool() throws, the previous pool and the previous count
+        // still describe each other.
+        q = CreateThreadPool(threads);
+        ThreadCount = threads;
+    }
+
+    static TMtpQueueHelper& Instance();
+
+private:
+    size_t ThreadCount = 0;
+    TAutoPtr<IThreadPool> q;
+};
+
+namespace NYmp {
+ // Forwards to TMtpQueueHelper::Instance(), asking it to rebuild its pool
+ // for `threads` threads.
+ inline void SetThreadCount(size_t threads) {
+     TMtpQueueHelper& helper = TMtpQueueHelper::Instance();
+     helper.SetThreadCount(threads);
+ }
+
+ // Reports the thread count currently stored in TMtpQueueHelper::Instance().
+ inline size_t GetThreadCount() {
+     TMtpQueueHelper& helper = TMtpQueueHelper::Instance();
+     return helper.GetThreadCount();
+ }
+
+ // Calls func(val) for every val in [begin, end), spreading the calls over
+ // the TMtpQueueHelper pool, and returns only after every queued task is done.
+ //
+ // One task is queued per configured thread. Task i processes chunk i, then
+ // chunk i + threadCount, i + 2 * threadCount, ... where a chunk is
+ // `chunkSize` consecutive values (a chunkSize of 0 is treated as 1). The
+ // mapping of values to tasks is therefore fixed before any work starts.
+ //
+ // If func throws inside a task, the exception is stored and rethrown to the
+ // caller once all tasks have finished; when several are stored, the one
+ // written last wins. A failure to queue a task is reported the same way.
+ // With a configured thread count of 0 the range is processed inline on the
+ // calling thread.
+ template <typename T>
+ inline void ParallelForStaticChunk(T begin, T end, size_t chunkSize, std::function<void(T)> func) {
+     chunkSize = Max<size_t>(chunkSize, 1);
+
+     const size_t threadCount = TMtpQueueHelper::Instance().GetThreadCount();
+     if (threadCount == 0) {
+         // No task would be queued, so the whole range would be skipped
+         // silently. Do the work on the calling thread instead.
+         for (T val = begin; val < end; ++val) {
+             func(val);
+         }
+         return;
+     }
+
+     IThreadPool* queue = TMtpQueueHelper::Instance().Get();
+     TCondVar cv;
+     TMutex mutex;
+     size_t pending = threadCount; // tasks that have not finished yet; guarded by mutex
+     std::exception_ptr err;       // guarded by mutex
+
+     for (size_t i = 0; i < threadCount; ++i) {
+         try {
+             queue->SafeAddFunc([&cv, &pending, &mutex, &func, i, begin, end, chunkSize, threadCount, &err]() {
+                 try {
+                     T currentChunkStart = begin + static_cast<decltype(T() - T())>(i * chunkSize);
+
+                     while (currentChunkStart < end) {
+                         T currentChunkEnd = Min<T>(end, currentChunkStart + chunkSize);
+
+                         for (T val = currentChunkStart; val < currentChunkEnd; ++val) {
+                             func(val);
+                         }
+
+                         // Jump over the chunks that belong to the other tasks.
+                         currentChunkStart += chunkSize * threadCount;
+                     }
+                 } catch (...) {
+                     with_lock (mutex) {
+                         err = std::current_exception();
+                     }
+                 }
+
+                 with_lock (mutex) {
+                     if (--pending == 0) {
+                         // Last task to finish wakes the waiting caller.
+                         cv.Signal();
+                     }
+                 }
+             });
+         } catch (...) {
+             // Tasks queued so far reference locals of this frame, so we must
+             // not unwind past them. Remove the tasks that will never run
+             // (i .. threadCount - 1) from the count, remember the error and
+             // fall through to the wait below.
+             // NOTE(review): assumes SafeAddFunc throws only when task i was
+             // not accepted (and hence will not decrement `pending`) - confirm.
+             with_lock (mutex) {
+                 pending -= threadCount - i;
+                 err = std::current_exception();
+             }
+             break;
+         }
+     }
+
+     with_lock (mutex) {
+         while (pending > 0) {
+             cv.WaitI(mutex);
+         }
+     }
+
+     if (err) {
+         std::rethrow_exception(err);
+     }
+ }
+
+ // Calls func(val) for every val in [begin, end) through
+ // ParallelForStaticChunk, choosing the chunk size as
+ // ceil((end - begin) / threadCount) so that the range splits into at most
+ // one chunk per configured thread.
+ template <typename T>
+ inline void ParallelForStaticAutoChunk(T begin, T end, std::function<void(T)> func) {
+     const size_t taskSize = end - begin;
+     // Clamp the divisor: a configured thread count of 0 would otherwise make
+     // the chunk-size computation divide by zero.
+     const size_t threadCount = Max<size_t>(TMtpQueueHelper::Instance().GetThreadCount(), 1);
+
+     ParallelForStaticChunk(begin, end, (taskSize + threadCount - 1) / threadCount, func);
+ }
+}