#include "pool.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/random/fast.h>
#include <util/system/spinlock.h>
#include <util/system/thread.h>
#include <util/system/mutex.h>
#include <util/system/condvar.h>

// Fixture shared by the thread pool tests. Tasks XOR pseudo-random values into
// R under Lock; TestAnyQueue pre-loads R with the XOR of the same sequence, so
// R is expected to be back at zero after all queued tasks have run.
struct TThreadPoolTest {
    TSpinLock Lock; // serializes the updates of R made from TTask::Process
    long R = -1;    // XOR accumulator; TestAnyQueue resets it before use

    // Task that XORs Value into Test->R and then destroys itself.
    struct TTask: public IObjectInQueue {
        TThreadPoolTest* Test = nullptr;
        long Value = 0;

        // `value` is declared long so that it matches Value and R. Taking it
        // as int would truncate the caller's long and sign-extend it back, so
        // where long is wider than int the bits XOR-ed by the task could
        // differ from the bits TestAnyQueue XOR-ed into R up front.
        TTask(TThreadPoolTest* test, long value)
            : Test(test)
            , Value(value)
        {
        }

        void Process(void*) override {
            // The task owns itself: it is deleted when Process returns.
            THolder<TTask> This(this);

            TGuard<TSpinLock> guard(Test->Lock);
            Test->R ^= Value;
        }
    };

    // Task that reports, through caller-provided flags, that it has been
    // processed and that it has been destroyed.
    struct TOwnedTask: public IObjectInQueue {
        bool& Processed;
        bool& Destructed;

        TOwnedTask(bool& processed, bool& destructed)
            : Processed(processed)
            , Destructed(destructed)
        {
        }

        ~TOwnedTask() override {
            Destructed = true;
        }

        void Process(void*) override {
            Processed = true;
        }
    };

    // Starts `queue` with 10 threads (forwarding `queueSize` as the second
    // argument of Start), adds 1000 TTask objects fed from a fixed-seed
    // generator, stops the queue and asserts that the checksum R is zero.
    inline void TestAnyQueue(IThreadPool* queue, size_t queueSize = 1000) {
        TReallyFastRng32 rand(17);
        const size_t cnt = 1000;

        R = 0;

        // Pre-load R with the XOR of the whole sequence.
        for (size_t i = 0; i < cnt; ++i) {
            R ^= static_cast<long>(rand.GenRand());
        }

        queue->Start(10, queueSize);
        // Re-seed so that the tasks receive exactly the same sequence.
        rand = TReallyFastRng32(17);

        for (size_t i = 0; i < cnt; ++i) {
            UNIT_ASSERT(queue->Add(new TTask(this, static_cast<long>(rand.GenRand()))));
        }

        queue->Stop();

        UNIT_ASSERT_EQUAL(0, R);
    }
};

// IThreadPool stub whose Add() rejects every object; Start() and Stop() do
// nothing and Size() always reports zero.
class TFailAddQueue: public IThreadPool {
public:
    void Start(size_t, size_t) override {
    }

    void Stop() noexcept override {
    }

    size_t Size() const noexcept override {
        return 0;
    }

    // Never accepts the object: the result is always false.
    bool Add(IObjectInQueue*) override Y_WARN_UNUSED_RESULT {
        return false;
    }
};

Y_UNIT_TEST_SUITE(TThreadPoolTest) {
    // Runs the shared checksum scenario on a default-constructed TThreadPool.
    Y_UNIT_TEST(TestTThreadPool) {
        TThreadPool pool;
        TThreadPoolTest fixture;
        fixture.TestAnyQueue(&pool);
    }

    // Runs the shared checksum scenario on a TThreadPool created with
    // SetBlocking(true), passing 100 as the queue size.
    Y_UNIT_TEST(TestTThreadPoolBlocking) {
        TThreadPool blockingPool(TThreadPool::TParams().SetBlocking(true));
        TThreadPoolTest fixture;
        fixture.TestAnyQueue(&blockingPool, 100);
    }

    // Disabled by pg@ a long time ago because the test was flapping.
    // Attempt to enable it: REVIEW:78772
    Y_UNIT_TEST(TestTAdaptiveThreadPool) {
        // The scenario stays compiled, but the constant condition keeps it
        // from ever being executed.
        if (false) {
            TAdaptiveThreadPool adaptivePool;
            TThreadPoolTest fixture;
            fixture.TestAnyQueue(&adaptivePool);
        }
    }

    // A task handed over with SafeAddAndOwn must have been both processed and
    // destroyed by the time Stop() returns.
    Y_UNIT_TEST(TestAddAndOwn) {
        bool wasProcessed = false;
        bool wasDestructed = false;

        TThreadPool pool;
        pool.Start(2);
        pool.SafeAddAndOwn(MakeHolder<TThreadPoolTest::TOwnedTask>(wasProcessed, wasDestructed));
        pool.Stop();

        UNIT_ASSERT_C(wasProcessed, "Not processed");
        UNIT_ASSERT_C(wasDestructed, "Not destructed");
    }

    // AddFunc on a queue whose Add() always fails must report false.
    Y_UNIT_TEST(TestAddFunc) {
        TFailAddQueue failingQueue;
        // An empty lambda is enough: it is never expected to be accepted.
        const bool accepted = failingQueue.AddFunc([]() {});
        UNIT_ASSERT_VALUES_EQUAL(accepted, false);
    }

    // SafeAddFunc on a queue whose Add() always fails must throw
    // TThreadPoolException.
    Y_UNIT_TEST(TestSafeAddFuncThrows) {
        TFailAddQueue failingQueue;
        UNIT_CHECK_GENERATED_EXCEPTION(failingQueue.SafeAddFunc([] {}), TThreadPoolException);
    }

    // Passes a functor whose captured state fails the test when it is
    // copy-constructed, so any copy made while the pool handles it is detected.
    Y_UNIT_TEST(TestFunctionNotCopied) {
        // Default construction and moving are silent; copying fails the test.
        struct TFailOnCopy {
            TFailOnCopy() {
            }

            TFailOnCopy(const TFailOnCopy&) {
                UNIT_FAIL("Don't copy std::function inside TThreadPool");
            }

            TFailOnCopy(TFailOnCopy&&) {
            }
        };

        TThreadPool pool(TThreadPool::TParams().SetBlocking(false).SetCatching(true));
        pool.Start(2);

        pool.SafeAddFunc([payload = TFailOnCopy()]() {});

        pool.Stop();
    }

    // After each Start(threads, queueSize) the informational getters must
    // report the values that were just passed in, including after a restart.
    Y_UNIT_TEST(TestInfoGetters) {
        TThreadPool pool;

        // First run: 2 threads, queue size 7.
        pool.Start(2, 7);

        UNIT_ASSERT_EQUAL(pool.GetMaxQueueSize(), 7);
        UNIT_ASSERT_EQUAL(pool.GetThreadCountExpected(), 2);
        UNIT_ASSERT_EQUAL(pool.GetThreadCountReal(), 2);

        pool.Stop();

        // Second run on the same object: 4 threads, queue size 1.
        pool.Start(4, 1);

        UNIT_ASSERT_EQUAL(pool.GetMaxQueueSize(), 1);
        UNIT_ASSERT_EQUAL(pool.GetThreadCountExpected(), 4);
        UNIT_ASSERT_EQUAL(pool.GetThreadCountReal(), 4);

        pool.Stop();
    }

    // Runs one job on a single-thread `pool`, records the name of the thread
    // that executed it and, when thread names can be queried, checks that it
    // equals `expectedName` while the calling thread's name does not.
    void TestFixedThreadName(IThreadPool& pool, const TString& expectedName) {
        TString workerName;

        pool.Start(1);
        pool.SafeAddFunc([&workerName]() {
            workerName = TThread::CurrentThreadName();
        });
        pool.Stop();

        // Nothing can be verified if thread names cannot be queried.
        if (!TThread::CanGetCurrentThreadName()) {
            return;
        }

        UNIT_ASSERT_EQUAL(workerName, expectedName);
        UNIT_ASSERT_UNEQUAL(TThread::CurrentThreadName(), expectedName);
    }

    // Checks the SetThreadName() parameter on both pool implementations.
    Y_UNIT_TEST(TestFixedThreadName) {
        const TString expectedName = "HelloWorld";

        // Fixed-size pool.
        {
            TThreadPool fixedPool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadName(expectedName));
            TestFixedThreadName(fixedPool, expectedName);
        }

        // Adaptive pool.
        {
            TAdaptiveThreadPool adaptivePool(TThreadPool::TParams().SetThreadName(expectedName));
            TestFixedThreadName(adaptivePool, expectedName);
        }
    }

    // Queues one job per expected name. Every job waits until all jobs have
    // started and only then records the name of the thread it runs on, so no
    // job can finish before all of them are running. When thread names can be
    // queried, the collected set must equal `expectedNames`.
    void TestEnumeratedThreadName(IThreadPool& pool, const THashSet<TString>& expectedNames) {
        const size_t jobCount = expectedNames.size();

        TMutex lock;
        TCondVar allReady;
        size_t readyCount = 0;
        THashSet<TString> seenNames;

        pool.Start(jobCount);
        for (size_t job = 0; job < jobCount; ++job) {
            pool.SafeAddFunc([&]() {
                with_lock (lock) {
                    ++readyCount;
                    if (readyCount == jobCount) {
                        // Last job to arrive releases all the waiting ones.
                        allReady.BroadCast();
                    }
                    // Skipped by the last job, whose counter is already full.
                    while (readyCount != jobCount) {
                        allReady.WaitI(lock);
                    }
                    seenNames.insert(TThread::CurrentThreadName());
                }
            });
        }
        pool.Stop();

        if (TThread::CanGetCurrentThreadName()) {
            UNIT_ASSERT_EQUAL(seenNames, expectedNames);
        }
    }

    // Checks the SetThreadNamePrefix() parameter on both pool implementations:
    // the eleven jobs are expected to run on threads named prefix + 0..10.
    Y_UNIT_TEST(TestEnumeratedThreadName) {
        const TString namePrefix = "HelloWorld";
        const THashSet<TString> expectedNames = {
            "HelloWorld0",
            "HelloWorld1",
            "HelloWorld2",
            "HelloWorld3",
            "HelloWorld4",
            "HelloWorld5",
            "HelloWorld6",
            "HelloWorld7",
            "HelloWorld8",
            "HelloWorld9",
            "HelloWorld10",
        };

        // Fixed-size pool.
        {
            TThreadPool fixedPool(TThreadPool::TParams().SetBlocking(true).SetCatching(false).SetThreadNamePrefix(namePrefix));
            TestEnumeratedThreadName(fixedPool, expectedNames);
        }

        // Adaptive pool.
        {
            TAdaptiveThreadPool adaptivePool(TThreadPool::TParams().SetThreadNamePrefix(namePrefix));
            TestEnumeratedThreadName(adaptivePool, expectedNames);
        }
    }
}