// path: util/system/condvar_ut.cpp
// blob: c6cce1c0a7ac6787270b2f2dce19fed5c7731d0e


                    
                                                  
 
                             
 
                 
                                      
                                  
                             

                          


                          
                                              
 
                                    
 
                                       
 
                            
      
                                              
                                                  
 





                                                                              
 
                                      


                                            
 

                              
     
                          
                                     
 
                                                  
                                                   
                                                                                        


                                             
                                              











                                                                                               
                                                    


                                         
                                              





                                                  

                                                  

                                                   
                                               
                                                         
                                                                                            


                                                 

                                                    
                                               
                                                          
                                                                                            
                                                 
                     
             
 
                                                                     



                                                  

                                                   
                                               
                                   
                                                                                                   
                                                            


                                             

                                                    

                                                                                                   
                                                             



                                             
                                                              
                  

                    
                         
                           
      








                                                                                    
 

                                
 


                                             

                                    


                                                 
                      
                   

                                           
#include "mutex.h"
#include "guard.h"
#include "condvar.h"

#include <library/cpp/testing/unittest/registar.h>

#include <util/thread/pool.h>

#include <atomic>

class TCondVarTest: public TTestBase {
    UNIT_TEST_SUITE(TCondVarTest);
    UNIT_TEST(TestBasics)
    UNIT_TEST(TestSyncronize)
    UNIT_TEST_SUITE_END();

    struct TSharedData {
        TMutex mutex;
        TCondVar condVar1;
        TCondVar condVar2;

        std::atomic<bool> stopWaiting = false;

        std::atomic<size_t> in = 0;
        std::atomic<size_t> out = 0;

        std::atomic<size_t> waited = 0;

        bool failed = false;
    };

    class TThreadTask: public IObjectInQueue {
    public:
        using PFunc = void (TThreadTask::*)(void);

        TThreadTask(PFunc func, size_t id, size_t totalIds, TSharedData& data)
            : Func_(func)
            , Id_(id)
            , TotalIds_(totalIds)
            , Data_(data)
        {
        }

        void Process(void*) override {
            THolder<TThreadTask> This(this);

            (this->*Func_)();
        }

#define FAIL_ASSERT(cond)    \
    if (!(cond)) {           \
        Data_.failed = true; \
    }
        void RunBasics() {
            Y_ASSERT(TotalIds_ == 3);

            if (Id_ < 2) {
                TGuard<TMutex> guard(Data_.mutex);
                while (!Data_.stopWaiting.load()) {
                    bool res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1));
                    FAIL_ASSERT(res == true);
                }
            } else {
                usleep(100000);
                Data_.stopWaiting.store(true);

                TGuard<TMutex> guard(Data_.mutex);
                Data_.condVar1.Signal();
                Data_.condVar1.Signal();
            }
        }

        void RunBasicsWithPredicate() {
            Y_ASSERT(TotalIds_ == 3);

            if (Id_ < 2) {
                TGuard<TMutex> guard(Data_.mutex);
                const auto res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
                    return Data_.stopWaiting.load();
                });
                FAIL_ASSERT(res == true);
            } else {
                usleep(100000);
                Data_.stopWaiting.store(true);

                TGuard<TMutex> guard(Data_.mutex);
                Data_.condVar1.Signal();
                Data_.condVar1.Signal();
            }
        }

        void RunSyncronize() {
            for (size_t i = 0; i < 10; ++i) {
                TGuard<TMutex> guard(Data_.mutex);
                ++Data_.in;
                if (Data_.in.load() == TotalIds_) {
                    Data_.out.store(0);
                    Data_.condVar1.BroadCast();
                } else {
                    ++Data_.waited;
                    while (Data_.in.load() < TotalIds_) {
                        bool res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1));
                        FAIL_ASSERT(res == true);
                    }
                }

                ++Data_.out;
                if (Data_.out.load() == TotalIds_) {
                    Data_.in.store(0);
                    Data_.condVar2.BroadCast();
                } else {
                    while (Data_.out.load() < TotalIds_) {
                        bool res = Data_.condVar2.WaitT(Data_.mutex, TDuration::Seconds(1));
                        FAIL_ASSERT(res == true);
                    }
                }
            }

            FAIL_ASSERT(Data_.waited.load() == (TotalIds_ - 1) * 10);
        }

        void RunSyncronizeWithPredicate() {
            for (size_t i = 0; i < 10; ++i) {
                TGuard<TMutex> guard(Data_.mutex);
                ++Data_.in;
                if (Data_.in.load() == TotalIds_) {
                    Data_.out.store(0);
                    Data_.condVar1.BroadCast();
                } else {
                    ++Data_.waited;
                    const auto res = Data_.condVar1.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
                        return Data_.in.load() >= TotalIds_;
                    });
                    FAIL_ASSERT(res == true);
                }

                ++Data_.out;
                if (Data_.out.load() == TotalIds_) {
                    Data_.in.store(0);
                    Data_.condVar2.BroadCast();
                } else {
                    const auto res = Data_.condVar2.WaitT(Data_.mutex, TDuration::Seconds(1), [&] {
                        return Data_.out.load() >= TotalIds_;
                    });
                    FAIL_ASSERT(res == true);
                }
            }

            FAIL_ASSERT(Data_.waited == (TotalIds_ - 1) * 10);
        }
#undef FAIL_ASSERT

    private:
        PFunc Func_;
        size_t Id_;
        size_t TotalIds_;
        TSharedData& Data_;
    };

private:
#define RUN_CYCLE(what, count)                                                     \
    Q_.Start(count);                                                               \
    for (size_t i = 0; i < count; ++i) {                                           \
        UNIT_ASSERT(Q_.Add(new TThreadTask(&TThreadTask::what, i, count, Data_))); \
    }                                                                              \
    Q_.Stop();                                                                     \
    bool b = Data_.failed;                                                         \
    Data_.failed = false;                                                          \
    UNIT_ASSERT(!b);

    inline void TestBasics() {
        RUN_CYCLE(RunBasics, 3);
    }

    inline void TestBasicsWithPredicate() {
        RUN_CYCLE(RunBasicsWithPredicate, 3);
    }

    inline void TestSyncronize() {
        RUN_CYCLE(RunSyncronize, 6);
    }

    inline void TestSyncronizeWithPredicate() {
        RUN_CYCLE(RunSyncronizeWithPredicate, 6);
    }
#undef RUN_CYCLE
    TSharedData Data_;
    TThreadPool Q_;
};

// Registers the TCondVarTest suite with the unittest framework so its tests are run.
UNIT_TEST_SUITE_REGISTRATION(TCondVarTest);