diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/messagebus | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/messagebus')
265 files changed, 23011 insertions, 0 deletions
diff --git a/library/cpp/messagebus/acceptor.cpp b/library/cpp/messagebus/acceptor.cpp new file mode 100644 index 0000000000..64a38619c2 --- /dev/null +++ b/library/cpp/messagebus/acceptor.cpp @@ -0,0 +1,127 @@ +#include "acceptor.h" + +#include "key_value_printer.h" +#include "mb_lwtrace.h" +#include "network.h" + +#include <util/network/init.h> +#include <util/system/defaults.h> +#include <util/system/error.h> +#include <util/system/yassert.h> + +LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER) + +using namespace NActor; +using namespace NBus; +using namespace NBus::NPrivate; + +TAcceptor::TAcceptor(TBusSessionImpl* session, ui64 acceptorId, SOCKET socket, const TNetAddr& addr) + : TActor<TAcceptor>(session->Queue->WorkQueue.Get()) + , AcceptorId(acceptorId) + , Session(session) + , GranStatus(session->Config.Secret.StatusFlushPeriod) +{ + SetNonBlock(socket, true); + + Channel = Session->ReadEventLoop.Register(socket, this); + Channel->EnableRead(); + + Stats.AcceptorId = acceptorId; + Stats.Fd = socket; + Stats.ListenAddr = addr; + + SendStatus(TInstant::Now()); +} + +void TAcceptor::Act(TDefaultTag) { + EShutdownState state = ShutdownState.State.Get(); + + if (state == SS_SHUTDOWN_COMPLETE) { + return; + } + + TInstant now = TInstant::Now(); + + if (state == SS_SHUTDOWN_COMMAND) { + if (!!Channel) { + Channel->Unregister(); + Channel.Drop(); + Stats.Fd = INVALID_SOCKET; + } + + SendStatus(now); + + Session->GetDeadAcceptorStatusQueue()->EnqueueAndSchedule(Stats); + Stats.ResetIncremental(); + + ShutdownState.CompleteShutdown(); + return; + } + + THolder<TOpaqueAddr> addr(new TOpaqueAddr()); + SOCKET acceptedSocket = accept(Channel->GetSocket(), addr->MutableAddr(), addr->LenPtr()); + + int acceptErrno = LastSystemError(); + + if (acceptedSocket == INVALID_SOCKET) { + if (LastSystemError() != EWOULDBLOCK) { + Stats.LastAcceptErrorErrno = acceptErrno; + Stats.LastAcceptErrorInstant = now; + ++Stats.AcceptErrorCount; + } + } else { + TSocketHolder s(acceptedSocket); + try { + SetKeepAlive(s, true); + SetNoDelay(s, Session->Config.TcpNoDelay); + SetSockOptTcpCork(s, Session->Config.TcpCork); + SetCloseOnExec(s, true); + SetNonBlock(s, true); + if (Session->Config.SocketToS >= 0) { + SetSocketToS(s, addr.Get(), Session->Config.SocketToS); + } + } catch (...) { + // It means that connection was reset just now + // TODO: do something better + goto skipAccept; + } + + { + TOnAccept onAccept; + onAccept.s = s.Release(); + onAccept.addr = TNetAddr(addr.Release()); + onAccept.now = now; + + LWPROBE(Accepted, ToString(onAccept.addr)); + + Session->GetOnAcceptQueue()->EnqueueAndSchedule(onAccept); + + Stats.LastAcceptSuccessInstant = now; + ++Stats.AcceptSuccessCount; + } + + skipAccept:; + } + + Channel->EnableRead(); + + SendStatus(now); +} + +void TAcceptor::SendStatus(TInstant now) { + GranStatus.Listen.Update(Stats, now); +} + +void TAcceptor::HandleEvent(SOCKET socket, void* cookie) { + Y_UNUSED(socket); + Y_UNUSED(cookie); + + GetActor()->Schedule(); +} + +void TAcceptor::Shutdown() { + ShutdownState.ShutdownCommand(); + GetActor()->Schedule(); + + ShutdownState.ShutdownComplete.WaitI(); +} diff --git a/library/cpp/messagebus/acceptor.h b/library/cpp/messagebus/acceptor.h new file mode 100644 index 0000000000..57cb010bf2 --- /dev/null +++ b/library/cpp/messagebus/acceptor.h @@ -0,0 +1,60 @@ +#pragma once + +#include "acceptor_status.h" +#include "defs.h" +#include "event_loop.h" +#include "netaddr.h" +#include "session_impl.h" +#include "shutdown_state.h" + +#include <library/cpp/messagebus/actor/actor.h> + +#include <util/system/event.h> + +namespace NBus { + namespace NPrivate { + class TAcceptor + : public NEventLoop::IEventHandler, + private ::NActor::TActor<TAcceptor> { + friend struct TBusSessionImpl; + friend class ::NActor::TActor<TAcceptor>; + + public: + TAcceptor(TBusSessionImpl* session, ui64 acceptorId, SOCKET socket, const TNetAddr& addr); + + void HandleEvent(SOCKET socket, void* cookie) override; + + void Shutdown(); + + inline ::NActor::TActor<TAcceptor>* GetActor() { + return this; + } + + private: + void SendStatus(TInstant now); + void Act(::NActor::TDefaultTag); + + private: + const ui64 AcceptorId; + + TBusSessionImpl* const Session; + NEventLoop::TChannelPtr Channel; + + TAcceptorStatus Stats; + + TAtomicShutdownState ShutdownState; + + struct TGranStatus { + TGranStatus(TDuration gran) + : Listen(gran) + { + } + + TGranUp<TAcceptorStatus> Listen; + }; + + TGranStatus GranStatus; + }; + + } +} diff --git a/library/cpp/messagebus/acceptor_status.cpp b/library/cpp/messagebus/acceptor_status.cpp new file mode 100644 index 0000000000..5006ff68ae --- /dev/null +++ b/library/cpp/messagebus/acceptor_status.cpp @@ -0,0 +1,68 @@ +#include "acceptor_status.h" + +#include "key_value_printer.h" + +#include <util/stream/format.h> +#include <util/stream/output.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TAcceptorStatus::TAcceptorStatus() + : Summary(false) + , AcceptorId(0) + , Fd(INVALID_SOCKET) +{ + ResetIncremental(); +} + +void TAcceptorStatus::ResetIncremental() { + AcceptSuccessCount = 0; + AcceptErrorCount = 0; + LastAcceptErrorErrno = 0; + LastAcceptErrorInstant = TInstant(); + LastAcceptSuccessInstant = TInstant(); +} + +TAcceptorStatus& TAcceptorStatus::operator+=(const TAcceptorStatus& that) { + Y_ASSERT(Summary); + Y_ASSERT(AcceptorId == 0); + + AcceptSuccessCount += that.AcceptSuccessCount; + LastAcceptSuccessInstant = Max(LastAcceptSuccessInstant, that.LastAcceptSuccessInstant); + + AcceptErrorCount += that.AcceptErrorCount; + if (that.LastAcceptErrorInstant > LastAcceptErrorInstant) { + LastAcceptErrorInstant = that.LastAcceptErrorInstant; + LastAcceptErrorErrno = that.LastAcceptErrorErrno; + } + + return *this; +} + +TString TAcceptorStatus::PrintToString() const { + TStringStream ss; + + if (!Summary) { + ss << "acceptor (" << AcceptorId << "), fd=" << Fd << ", addr=" << ListenAddr << Endl; + } + + TKeyValuePrinter p; + + p.AddRow("accept error count", LeftPad(AcceptErrorCount, 4)); + + if (AcceptErrorCount > 0) { + p.AddRow("last accept error", + TString() + LastSystemErrorText(LastAcceptErrorErrno) + " at " + LastAcceptErrorInstant.ToString()); + } + + p.AddRow("accept success count", LeftPad(AcceptSuccessCount, 4)); + if (AcceptSuccessCount > 0) { + p.AddRow("last accept success", + TString() + "at " + LastAcceptSuccessInstant.ToString()); + } + + ss << p.PrintToString(); + + return ss.Str(); +} diff --git a/library/cpp/messagebus/acceptor_status.h b/library/cpp/messagebus/acceptor_status.h new file mode 100644 index 0000000000..6aa1404f4d --- /dev/null +++ b/library/cpp/messagebus/acceptor_status.h @@ -0,0 +1,35 @@ +#pragma once + +#include "netaddr.h" + +#include <util/network/init.h> + +namespace NBus { + namespace NPrivate { + struct TAcceptorStatus { + bool Summary; + + ui64 AcceptorId; + + SOCKET Fd; + + TNetAddr ListenAddr; + + unsigned AcceptSuccessCount; + TInstant LastAcceptSuccessInstant; + + unsigned AcceptErrorCount; + TInstant LastAcceptErrorInstant; + int LastAcceptErrorErrno; + + void ResetIncremental(); + + TAcceptorStatus(); + + TAcceptorStatus& operator+=(const TAcceptorStatus& that); + + TString PrintToString() const; + }; + + } +} diff --git a/library/cpp/messagebus/actor/actor.h b/library/cpp/messagebus/actor/actor.h new file mode 100644 index 0000000000..9b8f20298a --- /dev/null +++ b/library/cpp/messagebus/actor/actor.h @@ -0,0 +1,144 @@ +#pragma once + +#include "executor.h" +#include "tasks.h" +#include "what_thread_does.h" + +#include <util/system/yassert.h> + +namespace NActor { + class IActor: protected IWorkItem { + public: + // TODO: make private + TTasks Tasks; + + public: + virtual void ScheduleHereV() = 0; + virtual void ScheduleV() = 0; + virtual void ScheduleHereAtMostOnceV() = 0; + + // TODO: make private + virtual void RefV() = 0; + virtual void UnRefV() = 0; + + // mute warnings + ~IActor() override { + } + }; + + struct TDefaultTag {}; + + template <typename TThis, typename TTag = TDefaultTag> + class TActor: public IActor { + private: + TExecutor* const Executor; + + public: + TActor(TExecutor* executor) + : Executor(executor) + { + } + + void AddTaskFromActorLoop() { + bool schedule = Tasks.AddTask(); + // TODO: check thread id + Y_ASSERT(!schedule); + } + + /** + * Schedule actor. + * + * If actor is sleeping, then actor will be executed right now. + * If actor is executing right now, it will be executed one more time. + * If this method is called multiple time, actor will be re-executed no more than one more time. + */ + void Schedule() { + if (Tasks.AddTask()) { + EnqueueWork(); + } + } + + /** + * Schedule actor, execute it in current thread. + * + * If actor is running, continue executing where it is executing. + * If actor is sleeping, execute it in current thread. + * + * Operation is useful for tasks that are likely to complete quickly. + */ + void ScheduleHere() { + if (Tasks.AddTask()) { + Loop(); + } + } + + /** + * Schedule actor, execute in current thread no more than once. + * + * If actor is running, continue executing where it is executing. + * If actor is sleeping, execute one iteration here, and if actor got new tasks, + * reschedule it in worker pool. + */ + void ScheduleHereAtMostOnce() { + if (Tasks.AddTask()) { + bool fetched = Tasks.FetchTask(); + Y_VERIFY(fetched, "happens"); + + DoAct(); + + // if someone added more tasks, schedule them + if (Tasks.FetchTask()) { + bool added = Tasks.AddTask(); + Y_VERIFY(!added, "happens"); + EnqueueWork(); + } + } + } + + void ScheduleHereV() override { + ScheduleHere(); + } + void ScheduleV() override { + Schedule(); + } + void ScheduleHereAtMostOnceV() override { + ScheduleHereAtMostOnce(); + } + void RefV() override { + GetThis()->Ref(); + } + void UnRefV() override { + GetThis()->UnRef(); + } + + private: + TThis* GetThis() { + return static_cast<TThis*>(this); + } + + void EnqueueWork() { + GetThis()->Ref(); + Executor->EnqueueWork({this}); + } + + void DoAct() { + WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC(); + + GetThis()->Act(TTag()); + } + + void Loop() { + // TODO: limit number of iterations + while (Tasks.FetchTask()) { + DoAct(); + } + } + + void DoWork() override { + Y_ASSERT(GetThis()->RefCount() >= 1); + Loop(); + GetThis()->UnRef(); + } + }; + +} diff --git a/library/cpp/messagebus/actor/actor_ut.cpp b/library/cpp/messagebus/actor/actor_ut.cpp new file mode 100644 index 0000000000..b76ab55bfa --- /dev/null +++ b/library/cpp/messagebus/actor/actor_ut.cpp @@ -0,0 +1,157 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "actor.h" +#include "queue_in_actor.h" + +#include <library/cpp/messagebus/misc/test_sync.h> + +#include <util/generic/object_counter.h> +#include <util/system/event.h> + +using namespace NActor; + +template <typename TThis> +struct TTestActorBase: public TAtomicRefCount<TThis>, public TActor<TThis> { + TTestSync Started; + TTestSync Acted; + + TTestActorBase(TExecutor* executor) + : TActor<TThis>(executor) + { + } + + void Act(TDefaultTag) { + Started.Inc(); + static_cast<TThis*>(this)->Act2(); + Acted.Inc(); + } +}; + +struct TNopActor: public TTestActorBase<TNopActor> { + TObjectCounter<TNopActor> AllocCounter; + + TNopActor(TExecutor* executor) + : TTestActorBase<TNopActor>(executor) + { + } + + void Act2() { + } +}; + +struct TWaitForSignalActor: public TTestActorBase<TWaitForSignalActor> { + TWaitForSignalActor(TExecutor* executor) + : TTestActorBase<TWaitForSignalActor>(executor) + { + } + + TSystemEvent WaitFor; + + void Act2() { + WaitFor.Wait(); + } +}; + +struct TDecrementAndSendActor: public TTestActorBase<TDecrementAndSendActor>, public TQueueInActor<TDecrementAndSendActor, int> { + TSystemEvent Done; + + TDecrementAndSendActor* Next; + + TDecrementAndSendActor(TExecutor* executor) + : TTestActorBase<TDecrementAndSendActor>(executor) + , Next(nullptr) + { + } + + void ProcessItem(TDefaultTag, TDefaultTag, int n) { + if (n == 0) { + Done.Signal(); + } else { + Next->EnqueueAndSchedule(n - 1); + } + } + + void Act(TDefaultTag) { + DequeueAll(); + } +}; + +struct TObjectCountChecker { + TObjectCountChecker() { + CheckCounts(); + } + + ~TObjectCountChecker() { + CheckCounts(); + } + + void CheckCounts() { + UNIT_ASSERT_VALUES_EQUAL(TAtomicBase(0), TObjectCounter<TNopActor>::ObjectCount()); + UNIT_ASSERT_VALUES_EQUAL(TAtomicBase(0), TObjectCounter<TWaitForSignalActor>::ObjectCount()); + UNIT_ASSERT_VALUES_EQUAL(TAtomicBase(0), TObjectCounter<TDecrementAndSendActor>::ObjectCount()); + } +}; + +Y_UNIT_TEST_SUITE(TActor) { + Y_UNIT_TEST(Simple) { + TObjectCountChecker objectCountChecker; + + TExecutor executor(4); + + TIntrusivePtr<TNopActor> actor(new TNopActor(&executor)); + + actor->Schedule(); + + actor->Acted.WaitFor(1u); + } + + Y_UNIT_TEST(ScheduleAfterStart) { + TObjectCountChecker objectCountChecker; + + TExecutor executor(4); + + TIntrusivePtr<TWaitForSignalActor> actor(new TWaitForSignalActor(&executor)); + + actor->Schedule(); + + actor->Started.WaitFor(1); + + actor->Schedule(); + + actor->WaitFor.Signal(); + + // make sure Act is called second time + actor->Acted.WaitFor(2u); + } + + void ComplexImpl(int queueSize, int actorCount) { + TObjectCountChecker objectCountChecker; + + TExecutor executor(queueSize); + + TVector<TIntrusivePtr<TDecrementAndSendActor>> actors; + for (int i = 0; i < actorCount; ++i) { + actors.push_back(new TDecrementAndSendActor(&executor)); + } + + for (int i = 0; i < actorCount; ++i) { + actors.at(i)->Next = &*actors.at((i + 1) % actorCount); + } + + for (int i = 0; i < actorCount; ++i) { + actors.at(i)->EnqueueAndSchedule(10000); + } + + for (int i = 0; i < actorCount; ++i) { + actors.at(i)->Done.WaitI(); + } + } + + Y_UNIT_TEST(ComplexContention) { + ComplexImpl(4, 6); + } + + Y_UNIT_TEST(ComplexNoContention) { + ComplexImpl(6, 4); + } +} diff --git a/library/cpp/messagebus/actor/executor.cpp b/library/cpp/messagebus/actor/executor.cpp new file mode 100644 index 0000000000..7a2227a458 --- /dev/null +++ b/library/cpp/messagebus/actor/executor.cpp @@ -0,0 +1,338 @@ +#include "executor.h" + +#include "thread_extra.h" +#include "what_thread_does.h" +#include "what_thread_does_guard.h" + +#include <util/generic/utility.h> +#include <util/random/random.h> +#include <util/stream/str.h> +#include <util/system/tls.h> +#include <util/system/yassert.h> + +#include <array> + +using namespace NActor; +using namespace NActor::NPrivate; + +namespace { + struct THistoryInternal { + struct TRecord { + TAtomic MaxQueueSize; + + TRecord() + : MaxQueueSize() + { + } + + TExecutorHistory::THistoryRecord Capture() { + TExecutorHistory::THistoryRecord r; + r.MaxQueueSize = AtomicGet(MaxQueueSize); + return r; + } + }; + + ui64 Start; + ui64 LastTime; + + std::array<TRecord, 3600> Records; + + THistoryInternal() { + Start = TInstant::Now().Seconds(); + LastTime = Start - 1; + } + + TRecord& GetRecordForTime(ui64 time) { + return Records[time % Records.size()]; + } + + TRecord& GetNowRecord(ui64 now) { + for (ui64 t = LastTime + 1; t <= now; ++t) { + GetRecordForTime(t) = TRecord(); + } + LastTime = now; + return GetRecordForTime(now); + } + + TExecutorHistory Capture() { + TExecutorHistory history; + ui64 now = TInstant::Now().Seconds(); + ui64 lastHistoryRecord = now - 1; + ui32 historySize = Min<ui32>(lastHistoryRecord - Start, Records.size() - 1); + history.HistoryRecords.resize(historySize); + for (ui32 i = 0; i < historySize; ++i) { + history.HistoryRecords[i] = GetRecordForTime(lastHistoryRecord - historySize + i).Capture(); + } + history.LastHistoryRecordSecond = lastHistoryRecord; + return history; + } + }; + +} + +Y_POD_STATIC_THREAD(TExecutor*) +ThreadCurrentExecutor; + +static const char* NoLocation = "nowhere"; + +struct TExecutorWorkerThreadLocalData { + ui32 MaxQueueSize; +}; + +static TExecutorWorkerThreadLocalData WorkerNoThreadLocalData; +Y_POD_STATIC_THREAD(TExecutorWorkerThreadLocalData) +WorkerThreadLocalData; + +namespace NActor { + struct TExecutorWorker { + TExecutor* const Executor; + TThread Thread; + const char** WhatThreadDoesLocation; + TExecutorWorkerThreadLocalData* ThreadLocalData; + + TExecutorWorker(TExecutor* executor) + : Executor(executor) + , Thread(RunThreadProc, this) + , WhatThreadDoesLocation(&NoLocation) + , ThreadLocalData(&::WorkerNoThreadLocalData) + { + Thread.Start(); + } + + void Run() { + WhatThreadDoesLocation = ::WhatThreadDoesLocation(); + AtomicSet(ThreadLocalData, &::WorkerThreadLocalData); + WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC(); + Executor->RunWorker(); + } + + static void* RunThreadProc(void* thiz0) { + TExecutorWorker* thiz = (TExecutorWorker*)thiz0; + thiz->Run(); + return nullptr; + } + }; + + struct TExecutor::TImpl { + TExecutor* const Executor; + THistoryInternal History; + + TSystemEvent HelperStopSignal; + TThread HelperThread; + + TImpl(TExecutor* executor) + : Executor(executor) + , HelperThread(HelperThreadProc, this) + { + } + + void RunHelper() { + ui64 nowSeconds = TInstant::Now().Seconds(); + for (;;) { + TInstant nextStop = TInstant::Seconds(nowSeconds + 1) + TDuration::MilliSeconds(RandomNumber<ui32>(1000)); + + if (HelperStopSignal.WaitD(nextStop)) { + return; + } + + nowSeconds = nextStop.Seconds(); + + THistoryInternal::TRecord& record = History.GetNowRecord(nowSeconds); + + ui32 maxQueueSize = Executor->GetMaxQueueSizeAndClear(); + if (maxQueueSize > record.MaxQueueSize) { + AtomicSet(record.MaxQueueSize, maxQueueSize); + } + } + } + + static void* HelperThreadProc(void* impl0) { + TImpl* impl = (TImpl*)impl0; + impl->RunHelper(); + return nullptr; + } + }; + +} + +static TExecutor::TConfig MakeConfig(unsigned workerCount) { + TExecutor::TConfig config; + config.WorkerCount = workerCount; + return config; +} + +TExecutor::TExecutor(size_t workerCount) + : Config(MakeConfig(workerCount)) +{ + Init(); +} + +TExecutor::TExecutor(const TExecutor::TConfig& config) + : Config(config) +{ + Init(); +} + +void TExecutor::Init() { + Impl.Reset(new TImpl(this)); + + AtomicSet(ExitWorkers, 0); + + Y_VERIFY(Config.WorkerCount > 0); + + for (size_t i = 0; i < Config.WorkerCount; i++) { + WorkerThreads.push_back(new TExecutorWorker(this)); + } + + Impl->HelperThread.Start(); +} + +TExecutor::~TExecutor() { + Stop(); +} + +void TExecutor::Stop() { + AtomicSet(ExitWorkers, 1); + + Impl->HelperStopSignal.Signal(); + Impl->HelperThread.Join(); + + { + TWhatThreadDoesAcquireGuard<TMutex> guard(WorkMutex, "executor: acquiring lock for Stop"); + WorkAvailable.BroadCast(); + } + + for (size_t i = 0; i < WorkerThreads.size(); i++) { + WorkerThreads[i]->Thread.Join(); + } + + // TODO: make queue empty at this point + ProcessWorkQueueHere(); +} + +void TExecutor::EnqueueWork(TArrayRef<IWorkItem* const> wis) { + if (wis.empty()) + return; + + if (Y_UNLIKELY(AtomicGet(ExitWorkers) != 0)) { + Y_VERIFY(WorkItems.Empty(), "executor %s: cannot add tasks after queue shutdown", Config.Name); + } + + TWhatThreadDoesPushPop pp("executor: EnqueueWork"); + + WorkItems.PushAll(wis); + + { + if (wis.size() == 1) { + TWhatThreadDoesAcquireGuard<TMutex> g(WorkMutex, "executor: acquiring lock for EnqueueWork"); + WorkAvailable.Signal(); + } else { + TWhatThreadDoesAcquireGuard<TMutex> g(WorkMutex, "executor: acquiring lock for EnqueueWork"); + WorkAvailable.BroadCast(); + } + } +} + +size_t TExecutor::GetWorkQueueSize() const { + return WorkItems.Size(); +} + +using namespace NTSAN; + +ui32 TExecutor::GetMaxQueueSizeAndClear() const { + ui32 max = 0; + for (unsigned i = 0; i < WorkerThreads.size(); ++i) { + TExecutorWorkerThreadLocalData* wtls = RelaxedLoad(&WorkerThreads[i]->ThreadLocalData); + max = Max<ui32>(max, RelaxedLoad(&wtls->MaxQueueSize)); + RelaxedStore<ui32>(&wtls->MaxQueueSize, 0); + } + return max; +} + +TString TExecutor::GetStatus() const { + return GetStatusRecordInternal().Status; +} + +TString TExecutor::GetStatusSingleLine() const { + TStringStream ss; + ss << "work items: " << GetWorkQueueSize(); + return ss.Str(); +} + +TExecutorStatus TExecutor::GetStatusRecordInternal() const { + TExecutorStatus r; + + r.WorkQueueSize = GetWorkQueueSize(); + + { + TStringStream ss; + ss << "work items: " << GetWorkQueueSize() << "\n"; + ss << "workers:\n"; + for (unsigned i = 0; i < WorkerThreads.size(); ++i) { + ss << "-- " << AtomicGet(*AtomicGet(WorkerThreads[i]->WhatThreadDoesLocation)) << "\n"; + } + r.Status = ss.Str(); + } + + r.History = Impl->History.Capture(); + + return r; +} + +bool TExecutor::IsInExecutorThread() const { + return ThreadCurrentExecutor == this; +} + +TAutoPtr<IWorkItem> TExecutor::DequeueWork() { + IWorkItem* wi = reinterpret_cast<IWorkItem*>(1); + size_t queueSize = Max<size_t>(); + if (!WorkItems.TryPop(&wi, &queueSize)) { + TWhatThreadDoesAcquireGuard<TMutex> g(WorkMutex, "executor: acquiring lock for DequeueWork"); + while (!WorkItems.TryPop(&wi, &queueSize)) { + if (AtomicGet(ExitWorkers) != 0) + return nullptr; + + TWhatThreadDoesPushPop pp("waiting for work on condvar"); + WorkAvailable.Wait(WorkMutex); + } + } + + auto& wtls = TlsRef(WorkerThreadLocalData); + + if (queueSize > RelaxedLoad(&wtls.MaxQueueSize)) { + RelaxedStore<ui32>(&wtls.MaxQueueSize, queueSize); + } + + return wi; +} + +void TExecutor::RunWorkItem(TAutoPtr<IWorkItem> wi) { + WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC(); + wi.Release()->DoWork(); +} + +void TExecutor::ProcessWorkQueueHere() { + IWorkItem* wi; + while (WorkItems.TryPop(&wi)) { + RunWorkItem(wi); + } +} + +void TExecutor::RunWorker() { + Y_VERIFY(!ThreadCurrentExecutor, "state check"); + ThreadCurrentExecutor = this; + + SetCurrentThreadName("wrkr"); + + for (;;) { + TAutoPtr<IWorkItem> wi = DequeueWork(); + if (!wi) { + break; + } + // Note for messagebus users: make sure program crashes + // on uncaught exception in thread, otherewise messagebus may just hang on error. + RunWorkItem(wi); + } + + ThreadCurrentExecutor = (TExecutor*)nullptr; +} diff --git a/library/cpp/messagebus/actor/executor.h b/library/cpp/messagebus/actor/executor.h new file mode 100644 index 0000000000..7292d8be53 --- /dev/null +++ b/library/cpp/messagebus/actor/executor.h @@ -0,0 +1,105 @@ +#pragma once + +#include "ring_buffer_with_spin_lock.h" + +#include <util/generic/array_ref.h> +#include <util/generic/vector.h> +#include <util/system/condvar.h> +#include <util/system/event.h> +#include <util/system/mutex.h> +#include <util/system/thread.h> +#include <util/thread/lfqueue.h> + +namespace NActor { + namespace NPrivate { + struct TExecutorHistory { + struct THistoryRecord { + ui32 MaxQueueSize; + }; + TVector<THistoryRecord> HistoryRecords; + ui64 LastHistoryRecordSecond; + + ui64 FirstHistoryRecordSecond() const { + return LastHistoryRecordSecond - HistoryRecords.size() + 1; + } + }; + + struct TExecutorStatus { + size_t WorkQueueSize = 0; + TExecutorHistory History; + TString Status; + }; + } + + class IWorkItem { + public: + virtual ~IWorkItem() { + } + virtual void DoWork(/* must release this */) = 0; + }; + + struct TExecutorWorker; + + class TExecutor: public TAtomicRefCount<TExecutor> { + friend struct TExecutorWorker; + + public: + struct TConfig { + size_t WorkerCount; + const char* Name; + + TConfig() + : WorkerCount(1) + , Name() + { + } + }; + + private: + struct TImpl; + THolder<TImpl> Impl; + + const TConfig Config; + + TAtomic ExitWorkers; + + TVector<TAutoPtr<TExecutorWorker>> WorkerThreads; + + TRingBufferWithSpinLock<IWorkItem*> WorkItems; + + TMutex WorkMutex; + TCondVar WorkAvailable; + + public: + explicit TExecutor(size_t workerCount); + TExecutor(const TConfig& config); + ~TExecutor(); + + void Stop(); + + void EnqueueWork(TArrayRef<IWorkItem* const> w); + + size_t GetWorkQueueSize() const; + TString GetStatus() const; + TString GetStatusSingleLine() const; + NPrivate::TExecutorStatus GetStatusRecordInternal() const; + + bool IsInExecutorThread() const; + + private: + void Init(); + + TAutoPtr<IWorkItem> DequeueWork(); + + void ProcessWorkQueueHere(); + + inline void RunWorkItem(TAutoPtr<IWorkItem>); + + void RunWorker(); + + ui32 GetMaxQueueSizeAndClear() const; + }; + + using TExecutorPtr = TIntrusivePtr<TExecutor>; + +} diff --git a/library/cpp/messagebus/actor/queue_for_actor.h b/library/cpp/messagebus/actor/queue_for_actor.h new file mode 100644 index 0000000000..40fa536b82 --- /dev/null +++ b/library/cpp/messagebus/actor/queue_for_actor.h @@ -0,0 +1,74 @@ +#pragma once + +#include <util/generic/vector.h> +#include <util/system/yassert.h> +#include <util/thread/lfstack.h> +#include <util/thread/singleton.h> + +// TODO: include from correct directory +#include "temp_tls_vector.h" + +namespace NActor { + namespace NPrivate { + struct TTagForTl {}; + + } + + template <typename T> + class TQueueForActor { + private: + TLockFreeStack<T> Queue; + + public: + ~TQueueForActor() { + Y_VERIFY(Queue.IsEmpty()); + } + + bool IsEmpty() { + return Queue.IsEmpty(); + } + + void Enqueue(const T& value) { + Queue.Enqueue(value); + } + + template <typename TCollection> + void EnqueueAll(const TCollection& all) { + Queue.EnqueueAll(all); + } + + void Clear() { + TVector<T> tmp; + Queue.DequeueAll(&tmp); + } + + template <typename TFunc> + void DequeueAll(const TFunc& func + // TODO: , std::enable_if_t<TFunctionParamCount<TFunc>::Value == 1>* = 0 + ) { + TTempTlsVector<T> temp; + + Queue.DequeueAllSingleConsumer(temp.GetVector()); + + for (typename TVector<T>::reverse_iterator i = temp.GetVector()->rbegin(); i != temp.GetVector()->rend(); ++i) { + func(*i); + } + + temp.Clear(); + + if (temp.Capacity() * sizeof(T) > 64 * 1024) { + temp.Shrink(); + } + } + + template <typename TFunc> + void DequeueAllLikelyEmpty(const TFunc& func) { + if (Y_LIKELY(IsEmpty())) { + return; + } + + DequeueAll(func); + } + }; + +} diff --git a/library/cpp/messagebus/actor/queue_in_actor.h b/library/cpp/messagebus/actor/queue_in_actor.h new file mode 100644 index 0000000000..9865996532 --- /dev/null +++ b/library/cpp/messagebus/actor/queue_in_actor.h @@ -0,0 +1,80 @@ +#pragma once + +#include "actor.h" +#include "queue_for_actor.h" + +#include <functional> + +namespace NActor { + template <typename TItem> + class IQueueInActor { + public: + virtual void EnqueueAndScheduleV(const TItem& item) = 0; + virtual void DequeueAllV() = 0; + virtual void DequeueAllLikelyEmptyV() = 0; + + virtual ~IQueueInActor() { + } + }; + + template <typename TThis, typename TItem, typename TActorTag = TDefaultTag, typename TQueueTag = TDefaultTag> + class TQueueInActor: public IQueueInActor<TItem> { + typedef TQueueInActor<TThis, TItem, TActorTag, TQueueTag> TSelf; + + public: + // TODO: make protected + TQueueForActor<TItem> QueueInActor; + + private: + TActor<TThis, TActorTag>* GetActor() { + return GetThis(); + } + + TThis* GetThis() { + return static_cast<TThis*>(this); + } + + void ProcessItem(const TItem& item) { + GetThis()->ProcessItem(TActorTag(), TQueueTag(), item); + } + + public: + void EnqueueAndNoSchedule(const TItem& item) { + QueueInActor.Enqueue(item); + } + + void EnqueueAndSchedule(const TItem& item) { + EnqueueAndNoSchedule(item); + GetActor()->Schedule(); + } + + void EnqueueAndScheduleV(const TItem& item) override { + EnqueueAndSchedule(item); + } + + void Clear() { + QueueInActor.Clear(); + } + + void DequeueAll() { + QueueInActor.DequeueAll(std::bind(&TSelf::ProcessItem, this, std::placeholders::_1)); + } + + void DequeueAllV() override { + return DequeueAll(); + } + + void DequeueAllLikelyEmpty() { + QueueInActor.DequeueAllLikelyEmpty(std::bind(&TSelf::ProcessItem, this, std::placeholders::_1)); + } + + void DequeueAllLikelyEmptyV() override { + return DequeueAllLikelyEmpty(); + } + + bool IsEmpty() { + return QueueInActor.IsEmpty(); + } + }; + +} diff --git a/library/cpp/messagebus/actor/ring_buffer.h b/library/cpp/messagebus/actor/ring_buffer.h new file mode 100644 index 0000000000..ec5706f7c7 --- /dev/null +++ b/library/cpp/messagebus/actor/ring_buffer.h @@ -0,0 +1,135 @@ +#pragma once + +#include <util/generic/array_ref.h> +#include <util/generic/maybe.h> +#include <util/generic/utility.h> +#include <util/generic/vector.h> +#include <util/system/yassert.h> + +template <typename T> +struct TRingBuffer { +private: + ui32 CapacityPow; + ui32 CapacityMask; + ui32 Capacity; + ui32 WritePos; + ui32 ReadPos; + TVector<T> Data; + + void StateCheck() const { + Y_ASSERT(Capacity == Data.size()); + Y_ASSERT(Capacity == (1u << CapacityPow)); + Y_ASSERT((Capacity & CapacityMask) == 0u); + Y_ASSERT(Capacity - CapacityMask == 1u); + Y_ASSERT(WritePos < Capacity); + Y_ASSERT(ReadPos < Capacity); + } + + size_t Writable() const { + return (Capacity + ReadPos - WritePos - 1) & CapacityMask; + } + + void ReserveWritable(ui32 sz) { + if (sz <= Writable()) + return; + + ui32 newCapacityPow = CapacityPow; + while ((1u << newCapacityPow) < sz + ui32(Size()) + 1u) { + ++newCapacityPow; + } + ui32 newCapacity = 1u << newCapacityPow; + ui32 newCapacityMask = newCapacity - 1u; + TVector<T> newData(newCapacity); + ui32 oldSize = Size(); + // Copy old elements + for (size_t i = 0; i < oldSize; ++i) { + newData[i] = Get(i); + } + + CapacityPow = newCapacityPow; + Capacity = newCapacity; + CapacityMask = newCapacityMask; + Data.swap(newData); + ReadPos = 0; + WritePos = oldSize; + + StateCheck(); + } + + const T& Get(ui32 i) const { + return Data[(ReadPos + i) & CapacityMask]; + } + +public: + TRingBuffer() + : CapacityPow(0) + , CapacityMask(0) + , Capacity(1 << CapacityPow) + , WritePos(0) + , ReadPos(0) + , Data(Capacity) + { + StateCheck(); + } + + size_t Size() const { + return (Capacity + WritePos - ReadPos) & CapacityMask; + } + + bool Empty() const { + return WritePos == ReadPos; + } + + void PushAll(TArrayRef<const T> value) { + ReserveWritable(value.size()); + + ui32 secondSize; + ui32 firstSize; + + if (WritePos + value.size() <= Capacity) { + firstSize = value.size(); + secondSize = 0; + } else { + firstSize = Capacity - WritePos; + secondSize = value.size() - firstSize; + } + + for (size_t i = 0; i < firstSize; ++i) { + Data[WritePos + i] = value[i]; + } + + for (size_t i = 0; i < secondSize; ++i) { + Data[i] = value[firstSize + i]; + } + + WritePos = (WritePos + value.size()) & CapacityMask; + StateCheck(); + } + + void Push(const T& t) { + PushAll(MakeArrayRef(&t, 1)); + } + + bool TryPop(T* r) { + StateCheck(); + if (Empty()) { + return false; + } + *r = Data[ReadPos]; + ReadPos = (ReadPos + 1) & CapacityMask; + return true; + } + + TMaybe<T> TryPop() { + T tmp; + if (TryPop(&tmp)) { + return tmp; + } else { + return TMaybe<T>(); + } + } + + T Pop() { + return *TryPop(); + } +}; diff --git a/library/cpp/messagebus/actor/ring_buffer_ut.cpp b/library/cpp/messagebus/actor/ring_buffer_ut.cpp new file mode 100644 index 0000000000..bdb379b3a9 --- /dev/null +++ b/library/cpp/messagebus/actor/ring_buffer_ut.cpp @@ -0,0 +1,60 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "ring_buffer.h" + +#include <util/random/random.h> + +Y_UNIT_TEST_SUITE(RingBuffer) { + struct TRingBufferTester { + TRingBuffer<unsigned> RingBuffer; + + unsigned NextPush; + unsigned NextPop; + + TRingBufferTester() + : NextPush() + , NextPop() + { + } + + void Push() { + //Cerr << "push " << NextPush << "\n"; + RingBuffer.Push(NextPush); + NextPush += 1; + } + + void Pop() { + //Cerr << "pop " << NextPop << "\n"; + unsigned popped = RingBuffer.Pop(); + UNIT_ASSERT_VALUES_EQUAL(NextPop, popped); + NextPop += 1; + } + + bool Empty() const { + UNIT_ASSERT_VALUES_EQUAL(RingBuffer.Size(), NextPush - NextPop); + UNIT_ASSERT_VALUES_EQUAL(RingBuffer.Empty(), RingBuffer.Size() == 0); + return RingBuffer.Empty(); + } + }; + + void Iter() { + TRingBufferTester rb; + + while (rb.NextPush < 1000) { + rb.Push(); + while (!rb.Empty() && RandomNumber<bool>()) { + rb.Pop(); + } + } + + while (!rb.Empty()) { + rb.Pop(); + } + } + + Y_UNIT_TEST(Random) { + for (unsigned i = 0; i < 100; ++i) { + Iter(); + } + } +} diff --git a/library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h b/library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h new file mode 100644 index 0000000000..f0b7cd90e4 --- /dev/null +++ b/library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h @@ -0,0 +1,91 @@ +#pragma once + +#include "ring_buffer.h" + +#include <util/system/spinlock.h> + +template <typename T> +class TRingBufferWithSpinLock { +private: + TRingBuffer<T> RingBuffer; + TSpinLock SpinLock; + TAtomic CachedSize; + +public: + TRingBufferWithSpinLock() + : CachedSize(0) + { + } + + void Push(const T& t) { + PushAll(t); + } + + void PushAll(TArrayRef<const T> collection) { + if (collection.empty()) { + return; + } + + TGuard<TSpinLock> Guard(SpinLock); + RingBuffer.PushAll(collection); + AtomicSet(CachedSize, RingBuffer.Size()); + } + + bool TryPop(T* r, size_t* sizePtr = nullptr) { + if (AtomicGet(CachedSize) == 0) { + return false; + } + + bool ok; + size_t size; + { + TGuard<TSpinLock> Guard(SpinLock); + ok = RingBuffer.TryPop(r); + size = RingBuffer.Size(); + AtomicSet(CachedSize, size); + } + if (!!sizePtr) { + *sizePtr = size; + } + return ok; + } + + TMaybe<T> TryPop() { + T tmp; + if (TryPop(&tmp)) { + return tmp; + } else { + return TMaybe<T>(); + } + } + + bool PushAllAndTryPop(TArrayRef<const T> collection, T* r) { + if (collection.size() == 0) { + return TryPop(r); + } else { + if (AtomicGet(CachedSize) == 0) { + *r = collection[0]; + if (collection.size() > 1) { + TGuard<TSpinLock> guard(SpinLock); + RingBuffer.PushAll(MakeArrayRef(collection.data() + 1, collection.size() - 1)); + AtomicSet(CachedSize, RingBuffer.Size()); + } + } else { + TGuard<TSpinLock> guard(SpinLock); + RingBuffer.PushAll(collection); + *r = RingBuffer.Pop(); + AtomicSet(CachedSize, RingBuffer.Size()); + } + return true; + } + } + + bool Empty() const { + return AtomicGet(CachedSize) == 0; + } + + size_t Size() const { + TGuard<TSpinLock> Guard(SpinLock); + return RingBuffer.Size(); + } +}; diff --git a/library/cpp/messagebus/actor/tasks.h b/library/cpp/messagebus/actor/tasks.h new file mode 100644 index 0000000000..31d35931d2 --- /dev/null +++ b/library/cpp/messagebus/actor/tasks.h @@ -0,0 +1,48 @@ +#pragma once + +#include <util/system/atomic.h> +#include <util/system/yassert.h> + +namespace NActor { + class TTasks { + enum { + // order of values is important + E_WAITING, + E_RUNNING_NO_TASKS, + E_RUNNING_GOT_TASKS, + }; + + private: + TAtomic State; + + public: + TTasks() + : State(E_WAITING) + { + } + + // @return true iff caller have to either schedule task or execute it + bool AddTask() { + // High contention case optimization: AtomicGet is cheaper than AtomicSwap. + if (E_RUNNING_GOT_TASKS == AtomicGet(State)) { + return false; + } + + TAtomicBase oldState = AtomicSwap(&State, E_RUNNING_GOT_TASKS); + return oldState == E_WAITING; + } + + // called by executor + // @return true iff we have to recheck queues + bool FetchTask() { + TAtomicBase newState = AtomicDecrement(State); + if (newState == E_RUNNING_NO_TASKS) { + return true; + } else if (newState == E_WAITING) { + return false; + } + Y_FAIL("unknown"); + } + }; + +} diff --git a/library/cpp/messagebus/actor/tasks_ut.cpp b/library/cpp/messagebus/actor/tasks_ut.cpp new file mode 100644 index 0000000000..d80e8451a5 --- /dev/null +++ b/library/cpp/messagebus/actor/tasks_ut.cpp @@ -0,0 +1,37 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "tasks.h" + +using namespace NActor; + +Y_UNIT_TEST_SUITE(TTasks) { + Y_UNIT_TEST(AddTask_FetchTask_Simple) { + TTasks tasks; + + UNIT_ASSERT(tasks.AddTask()); + UNIT_ASSERT(!tasks.AddTask()); + UNIT_ASSERT(!tasks.AddTask()); + + UNIT_ASSERT(tasks.FetchTask()); + UNIT_ASSERT(!tasks.FetchTask()); + + UNIT_ASSERT(tasks.AddTask()); + } + + Y_UNIT_TEST(AddTask_FetchTask_AddTask) { + TTasks tasks; + + UNIT_ASSERT(tasks.AddTask()); + UNIT_ASSERT(!tasks.AddTask()); + + UNIT_ASSERT(tasks.FetchTask()); + UNIT_ASSERT(!tasks.AddTask()); + UNIT_ASSERT(tasks.FetchTask()); + UNIT_ASSERT(!tasks.AddTask()); + UNIT_ASSERT(!tasks.AddTask()); + UNIT_ASSERT(tasks.FetchTask()); + UNIT_ASSERT(!tasks.FetchTask()); + + UNIT_ASSERT(tasks.AddTask()); + } +} diff --git a/library/cpp/messagebus/actor/temp_tls_vector.h b/library/cpp/messagebus/actor/temp_tls_vector.h new file mode 100644 index 0000000000..675d92f5b0 --- /dev/null +++ b/library/cpp/messagebus/actor/temp_tls_vector.h @@ -0,0 +1,40 @@ +#pragma once + +#include "thread_extra.h" + +#include <util/generic/vector.h> +#include <util/system/yassert.h> + +template <typename T, typename TTag = void, template <typename, class> class TVectorType = TVector> +class TTempTlsVector { +private: + struct TTagForTls {}; + + TVectorType<T, std::allocator<T>>* Vector; + +public: + TVectorType<T, std::allocator<T>>* GetVector() { + return Vector; + } + + TTempTlsVector() { + Vector = FastTlsSingletonWithTag<TVectorType<T, std::allocator<T>>, TTagForTls>(); + Y_ASSERT(Vector->empty()); + } + + ~TTempTlsVector() { + Clear(); + } + + void Clear() { + Vector->clear(); + } + + size_t Capacity() const noexcept { + return Vector->capacity(); + } + + void Shrink() { + Vector->shrink_to_fit(); + } +}; diff --git a/library/cpp/messagebus/actor/thread_extra.cpp b/library/cpp/messagebus/actor/thread_extra.cpp new file mode 100644 index 0000000000..048480f255 --- /dev/null +++ b/library/cpp/messagebus/actor/thread_extra.cpp @@ -0,0 +1,30 @@ +#include "thread_extra.h" + +#include <util/stream/str.h> +#include <util/system/execpath.h> +#include <util/system/platform.h> +#include <util/system/thread.h> + +namespace { +#ifdef _linux_ + TString GetExecName() { + TString execPath = GetExecPath(); + size_t lastSlash = execPath.find_last_of('/'); + if (lastSlash == TString::npos) { + return execPath; + } else { + return execPath.substr(lastSlash + 1); + } + } +#endif +} + +void SetCurrentThreadName(const char* name) { +#ifdef _linux_ + TStringStream linuxName; + linuxName << GetExecName() << "." << name; + TThread::SetCurrentThreadName(linuxName.Str().data()); +#else + TThread::SetCurrentThreadName(name); +#endif +} diff --git a/library/cpp/messagebus/actor/thread_extra.h b/library/cpp/messagebus/actor/thread_extra.h new file mode 100644 index 0000000000..b5aa151618 --- /dev/null +++ b/library/cpp/messagebus/actor/thread_extra.h @@ -0,0 +1,41 @@ +#pragma once + +#include <util/thread/singleton.h> + +namespace NTSAN { + template <typename T> + inline void RelaxedStore(volatile T* a, T x) { + static_assert(std::is_integral<T>::value || std::is_pointer<T>::value, "expect std::is_integral<T>::value || std::is_pointer<T>::value"); +#ifdef _win_ + *a = x; +#else + __atomic_store_n(a, x, __ATOMIC_RELAXED); +#endif + } + + template <typename T> + inline T RelaxedLoad(volatile T* a) { +#ifdef _win_ + return *a; +#else + return __atomic_load_n(a, __ATOMIC_RELAXED); +#endif + } + +} + +void SetCurrentThreadName(const char* name); + +namespace NThreadExtra { + namespace NPrivate { + template <typename TValue, typename TTag> + struct TValueHolder { + TValue Value; + }; + } +} + +template <typename TValue, typename TTag> +static inline TValue* FastTlsSingletonWithTag() { + return &FastTlsSingleton< ::NThreadExtra::NPrivate::TValueHolder<TValue, TTag>>()->Value; +} diff --git a/library/cpp/messagebus/actor/what_thread_does.cpp b/library/cpp/messagebus/actor/what_thread_does.cpp new file mode 100644 index 0000000000..bebb6a888c --- /dev/null +++ b/library/cpp/messagebus/actor/what_thread_does.cpp @@ -0,0 +1,22 @@ +#include "what_thread_does.h" + +#include "thread_extra.h" + +#include <util/system/tls.h> + +Y_POD_STATIC_THREAD(const char*) +WhatThreadDoes; + +const char* PushWhatThreadDoes(const char* what) { + const char* r = NTSAN::RelaxedLoad(&WhatThreadDoes); + NTSAN::RelaxedStore(&WhatThreadDoes, what); + return r; +} + +void PopWhatThreadDoes(const char* prev) { + NTSAN::RelaxedStore(&WhatThreadDoes, prev); +} + +const char** WhatThreadDoesLocation() { + return &WhatThreadDoes; +} diff --git a/library/cpp/messagebus/actor/what_thread_does.h b/library/cpp/messagebus/actor/what_thread_does.h new file mode 100644 index 0000000000..235d2c3700 --- /dev/null +++ b/library/cpp/messagebus/actor/what_thread_does.h @@ -0,0 +1,28 @@ +#pragma once + +const char* PushWhatThreadDoes(const char* what); +void PopWhatThreadDoes(const char* prev); +const char** WhatThreadDoesLocation(); + +struct TWhatThreadDoesPushPop { +private: + const char* Prev; + +public: + TWhatThreadDoesPushPop(const char* what) { + Prev = PushWhatThreadDoes(what); + } + + ~TWhatThreadDoesPushPop() { + PopWhatThreadDoes(Prev); + } +}; + +#ifdef __GNUC__ +#define WHAT_THREAD_DOES_FUNCTION __PRETTY_FUNCTION__ +#else +#define WHAT_THREAD_DOES_FUNCTION __FUNCTION__ +#endif + +#define WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC() \ + TWhatThreadDoesPushPop whatThreadDoesPushPopCurrentFunc(WHAT_THREAD_DOES_FUNCTION) diff --git a/library/cpp/messagebus/actor/what_thread_does_guard.h b/library/cpp/messagebus/actor/what_thread_does_guard.h new file mode 100644 index 0000000000..f104e9e173 --- /dev/null +++ b/library/cpp/messagebus/actor/what_thread_does_guard.h @@ -0,0 +1,40 @@ +#pragma once + +#include "what_thread_does.h" + +template <class T> +class TWhatThreadDoesAcquireGuard: public TNonCopyable { +public: + inline TWhatThreadDoesAcquireGuard(const T& t, const char* acquire) noexcept { + Init(&t, acquire); + } + + inline TWhatThreadDoesAcquireGuard(const T* t, const char* acquire) noexcept { + Init(t, acquire); + } + + inline ~TWhatThreadDoesAcquireGuard() { + Release(); + } + + inline void Release() noexcept { + if (WasAcquired()) { + const_cast<T*>(T_)->Release(); + T_ = nullptr; + } + } + + inline bool WasAcquired() const noexcept { + return T_ != nullptr; + } + +private: + inline void Init(const T* t, const char* acquire) noexcept { + T_ = const_cast<T*>(t); + TWhatThreadDoesPushPop pp(acquire); + T_->Acquire(); + } + +private: + T* T_; +}; diff --git a/library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp b/library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp new file mode 100644 index 0000000000..e4b218a7ca --- /dev/null +++ b/library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp @@ -0,0 +1,13 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "what_thread_does_guard.h" + +#include <util/system/mutex.h> + +Y_UNIT_TEST_SUITE(WhatThreadDoesGuard) { + Y_UNIT_TEST(Simple) { + TMutex mutex; + + TWhatThreadDoesAcquireGuard<TMutex> guard(mutex, "acquiring my mutex"); + } +} diff --git a/library/cpp/messagebus/actor/ya.make b/library/cpp/messagebus/actor/ya.make new file mode 100644 index 0000000000..59bd1b0b99 --- /dev/null +++ b/library/cpp/messagebus/actor/ya.make @@ -0,0 +1,11 @@ +LIBRARY(messagebus_actor) + +OWNER(g:messagebus) + +SRCS( + executor.cpp + thread_extra.cpp + what_thread_does.cpp +) + +END() diff --git a/library/cpp/messagebus/all.lwt b/library/cpp/messagebus/all.lwt new file mode 100644 index 0000000000..0f04be4b2c --- /dev/null +++ b/library/cpp/messagebus/all.lwt @@ -0,0 +1,8 @@ +Blocks { + ProbeDesc { + Group: "MessagebusRare" + } + Action { + PrintToStderrAction {} + } +} diff --git a/library/cpp/messagebus/all/ya.make b/library/cpp/messagebus/all/ya.make new file mode 100644 index 0000000000..ffa2dbfabc --- /dev/null +++ b/library/cpp/messagebus/all/ya.make @@ -0,0 +1,10 @@ +OWNER(g:messagebus) + +RECURSE_ROOT_RELATIVE( + library/python/messagebus + library/cpp/messagebus/debug_receiver + library/cpp/messagebus/oldmodule + library/cpp/messagebus/rain_check + library/cpp/messagebus/test + library/cpp/messagebus/www +) diff --git a/library/cpp/messagebus/async_result.h b/library/cpp/messagebus/async_result.h new file mode 100644 index 0000000000..d24dde284a --- /dev/null +++ b/library/cpp/messagebus/async_result.h @@ -0,0 +1,54 @@ +#pragma once + +#include <util/generic/maybe.h> +#include <util/generic/noncopyable.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> +#include <util/system/yassert.h> + +#include <functional> + +// probably this thing should have been called TFuture +template <typename T> +class TAsyncResult : TNonCopyable { +private: + TMutex Mutex; + TCondVar CondVar; + + TMaybe<T> Result; + + typedef void TOnResult(const T&); + + std::function<TOnResult> OnResult; + +public: + void SetResult(const T& result) { + TGuard<TMutex> guard(Mutex); + Y_VERIFY(!Result, "cannot set result twice"); + Result = result; + CondVar.BroadCast(); + + if (!!OnResult) { + OnResult(result); + } + } + + const T& GetResult() { + TGuard<TMutex> guard(Mutex); + while (!Result) { + CondVar.Wait(Mutex); + } + return *Result; + } + + template <typename TFunc> + void AndThen(const TFunc& onResult) { + TGuard<TMutex> guard(Mutex); + if (!!Result) { + onResult(*Result); + } else { + Y_ASSERT(!OnResult); + OnResult = std::function<TOnResult>(onResult); + } + } +}; diff --git a/library/cpp/messagebus/async_result_ut.cpp b/library/cpp/messagebus/async_result_ut.cpp new file mode 100644 index 0000000000..2e96492afd --- /dev/null +++ b/library/cpp/messagebus/async_result_ut.cpp @@ -0,0 +1,37 @@ + +#include <library/cpp/testing/unittest/registar.h> + +#include "async_result.h" + +namespace { + void SetValue(int* location, const int& value) { + *location = value; + } + +} + +Y_UNIT_TEST_SUITE(TAsyncResult) { + Y_UNIT_TEST(AndThen_Here) { + TAsyncResult<int> r; + + int var = 1; + + r.SetResult(17); + + r.AndThen(std::bind(&SetValue, &var, std::placeholders::_1)); + + UNIT_ASSERT_VALUES_EQUAL(17, var); + } + + Y_UNIT_TEST(AndThen_Later) { + TAsyncResult<int> r; + + int var = 1; + + r.AndThen(std::bind(&SetValue, &var, std::placeholders::_1)); + + r.SetResult(17); + + UNIT_ASSERT_VALUES_EQUAL(17, var); + } +} diff --git a/library/cpp/messagebus/base.h b/library/cpp/messagebus/base.h new file mode 100644 index 0000000000..79fccc312e --- /dev/null +++ b/library/cpp/messagebus/base.h @@ -0,0 +1,11 @@ +#pragma once + +#include <util/system/defaults.h> + +namespace NBus { + /// millis since epoch + using TBusInstant = ui64; + /// returns time in milliseconds + TBusInstant Now(); + +} diff --git a/library/cpp/messagebus/cc_semaphore.h b/library/cpp/messagebus/cc_semaphore.h new file mode 100644 index 0000000000..0df8a3d664 --- /dev/null +++ b/library/cpp/messagebus/cc_semaphore.h @@ -0,0 +1,36 @@ +#pragma once + +#include "latch.h" + +template <typename TThis> +class TComplexConditionSemaphore { +private: + TLatch Latch; + +public: + void Updated() { + if (GetThis()->TryWait()) { + Latch.Unlock(); + } + } + + void Wait() { + while (!GetThis()->TryWait()) { + Latch.Lock(); + if (GetThis()->TryWait()) { + Latch.Unlock(); + return; + } + Latch.Wait(); + } + } + + bool IsLocked() { + return Latch.IsLocked(); + } + +private: + TThis* GetThis() { + return static_cast<TThis*>(this); + } +}; diff --git a/library/cpp/messagebus/cc_semaphore_ut.cpp b/library/cpp/messagebus/cc_semaphore_ut.cpp new file mode 100644 index 0000000000..206bb7c96a --- /dev/null +++ b/library/cpp/messagebus/cc_semaphore_ut.cpp @@ -0,0 +1,45 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "cc_semaphore.h" + +#include <util/system/atomic.h> + +namespace { + struct TTestSemaphore: public TComplexConditionSemaphore<TTestSemaphore> { + TAtomic Current; + + TTestSemaphore() + : Current(0) + { + } + + bool TryWait() { + return AtomicGet(Current) > 0; + } + + void Aquire() { + Wait(); + AtomicDecrement(Current); + } + + void Release() { + AtomicIncrement(Current); + Updated(); + } + }; +} + +Y_UNIT_TEST_SUITE(TComplexConditionSemaphore) { + Y_UNIT_TEST(Simple) { + TTestSemaphore sema; + UNIT_ASSERT(!sema.TryWait()); + sema.Release(); + UNIT_ASSERT(sema.TryWait()); + sema.Release(); + UNIT_ASSERT(sema.TryWait()); + sema.Aquire(); + UNIT_ASSERT(sema.TryWait()); + sema.Aquire(); + UNIT_ASSERT(!sema.TryWait()); + } +} diff --git a/library/cpp/messagebus/codegen.h b/library/cpp/messagebus/codegen.h new file mode 100644 index 0000000000..83e969e811 --- /dev/null +++ b/library/cpp/messagebus/codegen.h @@ -0,0 +1,4 @@ +#pragma once + +#include <library/cpp/messagebus/config/codegen.h> + diff --git a/library/cpp/messagebus/config/codegen.h b/library/cpp/messagebus/config/codegen.h new file mode 100644 index 0000000000..97ddada005 --- /dev/null +++ b/library/cpp/messagebus/config/codegen.h @@ -0,0 +1,10 @@ +#pragma once + +#define COMMA , + +#define STRUCT_FIELD_GEN(name, type, ...) type name; + +#define STRUCT_FIELD_INIT(name, type, defa) name(defa) +#define STRUCT_FIELD_INIT_DEFAULT(name, type, ...) name() + +#define STRUCT_FIELD_PRINT(name, ...) ss << #name << "=" << name << "\n"; diff --git a/library/cpp/messagebus/config/defs.h b/library/cpp/messagebus/config/defs.h new file mode 100644 index 0000000000..92b1df9969 --- /dev/null +++ b/library/cpp/messagebus/config/defs.h @@ -0,0 +1,82 @@ +#pragma once + +// unique tag to fix pragma once gcc glueing: ./library/cpp/messagebus/defs.h + +#include "codegen.h" +#include "netaddr.h" + +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + +#include <util/generic/list.h> + +#include <utility> + +// For historical reasons TCrawlerModule need to access +// APIs that should be private. +class TCrawlerModule; + +struct TDebugReceiverHandler; + +namespace NBus { + namespace NPrivate { + class TAcceptor; + struct TBusSessionImpl; + class TRemoteServerSession; + class TRemoteClientSession; + class TRemoteConnection; + class TRemoteServerConnection; + class TRemoteClientConnection; + class TBusSyncSourceSessionImpl; + + struct TBusMessagePtrAndHeader; + + struct TSessionDumpStatus; + + struct TClientRequestImpl; + + } + + class TBusSession; + struct TBusServerSession; + struct TBusClientSession; + class TBusProtocol; + class TBusMessage; + class TBusMessageConnection; + class TBusMessageQueue; + class TBusLocator; + struct TBusQueueConfig; + struct TBusSessionConfig; + struct TBusHeader; + + class IThreadHandler; + + using TBusKey = ui64; + using TBusMessageList = TList<TBusMessage*>; + using TBusKeyVec = TVector<std::pair<TBusKey, TBusKey>>; + + using TBusMessageQueuePtr = TIntrusivePtr<TBusMessageQueue>; + + class TBusModule; + + using TBusData = TString; + using TBusService = const char*; + +#define YBUS_KEYMIN TBusKey(0L) +#define YBUS_KEYMAX TBusKey(-1L) +#define YBUS_KEYLOCAL TBusKey(7L) +#define YBUS_KEYINVALID TBusKey(99999999L) + + // Check that generated id is valid for remote message + inline bool IsBusKeyValid(TBusKey key) { + return key != YBUS_KEYINVALID && key != YBUS_KEYMAX && key > YBUS_KEYLOCAL; + } + +#define YBUS_VERSION 0 + +#define YBUS_INFINITE (1u << 30u) + +#define YBUS_STATUS_BASIC 0x0000 +#define YBUS_STATUS_CONNS 0x0001 +#define YBUS_STATUS_INFLIGHT 0x0002 + +} diff --git a/library/cpp/messagebus/config/netaddr.cpp b/library/cpp/messagebus/config/netaddr.cpp new file mode 100644 index 0000000000..962ac538e2 --- /dev/null +++ b/library/cpp/messagebus/config/netaddr.cpp @@ -0,0 +1,183 @@ +#include "netaddr.h" + +#include <util/network/address.h> + +#include <cstdlib> + +namespace NBus { + const char* ToCString(EIpVersion ipVersion) { + switch (ipVersion) { + case EIP_VERSION_ANY: + return "EIP_VERSION_ANY"; + case EIP_VERSION_4: + return "EIP_VERSION_4"; + case EIP_VERSION_6: + return "EIP_VERSION_6"; + } + Y_FAIL(); + } + + int ToAddrFamily(EIpVersion ipVersion) { + switch (ipVersion) { + case EIP_VERSION_ANY: + return AF_UNSPEC; + case EIP_VERSION_4: + return AF_INET; + case EIP_VERSION_6: + return AF_INET6; + } + Y_FAIL(); + } + + class TNetworkAddressRef: private TNetworkAddress, public TAddrInfo { + public: + TNetworkAddressRef(const TNetworkAddress& na, const TAddrInfo& ai) + : TNetworkAddress(na) + , TAddrInfo(ai) + { + } + }; + + static bool Compare(const IRemoteAddr& l, const IRemoteAddr& r) noexcept { + if (l.Addr()->sa_family != r.Addr()->sa_family) { + return false; + } + + switch (l.Addr()->sa_family) { + case AF_INET: { + return memcmp(&(((const sockaddr_in*)l.Addr())->sin_addr), &(((const sockaddr_in*)r.Addr())->sin_addr), sizeof(in_addr)) == 0 && + ((const sockaddr_in*)l.Addr())->sin_port == ((const sockaddr_in*)r.Addr())->sin_port; + } + + case AF_INET6: { + return memcmp(&(((const sockaddr_in6*)l.Addr())->sin6_addr), &(((const sockaddr_in6*)r.Addr())->sin6_addr), sizeof(in6_addr)) == 0 && + ((const sockaddr_in6*)l.Addr())->sin6_port == ((const sockaddr_in6*)r.Addr())->sin6_port; + } + } + + return memcmp(l.Addr(), r.Addr(), Min<size_t>(l.Len(), r.Len())) == 0; + } + + TNetAddr::TNetAddr() + : Ptr(new TOpaqueAddr) + { + } + + TNetAddr::TNetAddr(TAutoPtr<IRemoteAddr> addr) + : Ptr(addr) + { + Y_VERIFY(!!Ptr); + } + + namespace { + using namespace NAddr; + + const char* Describe(EIpVersion version) { + switch (version) { + case EIP_VERSION_4: + return "ipv4 address"; + case EIP_VERSION_6: + return "ipv6 address"; + case EIP_VERSION_ANY: + return "any address"; + default: + Y_FAIL("unreachable"); + } + } + + TAutoPtr<IRemoteAddr> MakeAddress(const TNetworkAddress& na, EIpVersion requireVersion, EIpVersion preferVersion) { + TAutoPtr<IRemoteAddr> addr; + for (TNetworkAddress::TIterator it = na.Begin(); it != na.End(); ++it) { + if (IsFamilyAllowed(it->ai_family, requireVersion)) { + if (IsFamilyAllowed(it->ai_family, preferVersion)) { + return new TNetworkAddressRef(na, &*it); + } else if (!addr) { + addr.Reset(new TNetworkAddressRef(na, &*it)); + } + } + } + return addr; + } + TAutoPtr<IRemoteAddr> MakeAddress(TStringBuf host, int port, EIpVersion requireVersion, EIpVersion preferVersion) { + TString hostString(host); + TNetworkAddress na(hostString, port); + return MakeAddress(na, requireVersion, preferVersion); + } + TAutoPtr<IRemoteAddr> MakeAddress(const char* hostPort, EIpVersion requireVersion, EIpVersion preferVersion) { + const char* portStr = strchr(hostPort, ':'); + if (!portStr) { + ythrow TNetAddr::TError() << "port not specified in " << hostPort; + } + int port = atoi(portStr + 1); + TNetworkAddress na(TString(hostPort, portStr), port); + return MakeAddress(na, requireVersion, preferVersion); + } + } + + TNetAddr::TNetAddr(const char* hostPort, EIpVersion requireVersion /*= EIP_VERSION_ANY*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) + : Ptr(MakeAddress(hostPort, requireVersion, preferVersion)) + { + if (!Ptr) { + ythrow TNetAddr::TError() << "cannot resolve " << hostPort << " into " << Describe(requireVersion); + } + } + + TNetAddr::TNetAddr(TStringBuf host, int port, EIpVersion requireVersion /*= EIP_VERSION_ANY*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) + : Ptr(MakeAddress(host, port, requireVersion, preferVersion)) + { + if (!Ptr) { + ythrow TNetAddr::TError() << "cannot resolve " << host << ":" << port << " into " << Describe(requireVersion); + } + } + + TNetAddr::TNetAddr(const TNetworkAddress& na, EIpVersion requireVersion /*= EIP_VERSION_ANY*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) + : Ptr(MakeAddress(na, requireVersion, preferVersion)) + { + if (!Ptr) { + ythrow TNetAddr::TError() << "cannot resolve into " << Describe(requireVersion); + } + } + + TNetAddr::TNetAddr(const TNetworkAddress& na, const TAddrInfo& ai) + : Ptr(new TNetworkAddressRef(na, ai)) + { + } + + const sockaddr* TNetAddr::Addr() const { + return Ptr->Addr(); + } + + socklen_t TNetAddr::Len() const { + return Ptr->Len(); + } + + int TNetAddr::GetPort() const { + switch (Ptr->Addr()->sa_family) { + case AF_INET: + return InetToHost(((sockaddr_in*)Ptr->Addr())->sin_port); + case AF_INET6: + return InetToHost(((sockaddr_in6*)Ptr->Addr())->sin6_port); + default: + Y_FAIL("unknown AF: %d", (int)Ptr->Addr()->sa_family); + throw 1; + } + } + + bool TNetAddr::IsIpv4() const { + return Ptr->Addr()->sa_family == AF_INET; + } + + bool TNetAddr::IsIpv6() const { + return Ptr->Addr()->sa_family == AF_INET6; + } + + bool TNetAddr::operator==(const TNetAddr& rhs) const { + return Ptr == rhs.Ptr || Compare(*Ptr, *rhs.Ptr); + } + +} + +template <> +void Out<NBus::TNetAddr>(IOutputStream& out, const NBus::TNetAddr& addr) { + Out<NAddr::IRemoteAddr>(out, addr); +} diff --git a/library/cpp/messagebus/config/netaddr.h b/library/cpp/messagebus/config/netaddr.h new file mode 100644 index 0000000000..b79c0cc355 --- /dev/null +++ b/library/cpp/messagebus/config/netaddr.h @@ -0,0 +1,86 @@ +#pragma once + +#include <util/digest/numeric.h> +#include <util/generic/hash.h> +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> +#include <util/generic/vector.h> +#include <util/network/address.h> + +namespace NBus { + using namespace NAddr; + + /// IP protocol version. + enum EIpVersion { + EIP_VERSION_4 = 1, + EIP_VERSION_6 = 2, + EIP_VERSION_ANY = EIP_VERSION_4 | EIP_VERSION_6, + }; + + inline bool IsFamilyAllowed(ui16 sa_family, EIpVersion ipVersion) { + if (ipVersion == EIP_VERSION_4 && sa_family != AF_INET) { + return false; + } + if (ipVersion == EIP_VERSION_6 && sa_family != AF_INET6) { + return false; + } + return true; + } + + const char* ToCString(EIpVersion); + int ToAddrFamily(EIpVersion); + + /// Hold referenced pointer to address description structure (ex. sockaddr_storage) + /// It's make possible to work with IPv4 / IPv6 addresses simultaneously + class TNetAddr: public IRemoteAddr { + public: + class TError: public yexception { + }; + + TNetAddr(); + TNetAddr(TAutoPtr<IRemoteAddr> addr); + TNetAddr(const char* hostPort, EIpVersion requireVersion = EIP_VERSION_ANY, EIpVersion preferVersion = EIP_VERSION_ANY); + TNetAddr(TStringBuf host, int port, EIpVersion requireVersion = EIP_VERSION_ANY, EIpVersion preferVersion = EIP_VERSION_ANY); + TNetAddr(const TNetworkAddress& na, EIpVersion requireVersion = EIP_VERSION_ANY, EIpVersion preferVersion = EIP_VERSION_ANY); + TNetAddr(const TNetworkAddress& na, const TAddrInfo& ai); + + bool operator==(const TNetAddr&) const; + bool operator!=(const TNetAddr& other) const { + return !(*this == other); + } + inline explicit operator bool() const noexcept { + return !!Ptr; + } + + const sockaddr* Addr() const override; + socklen_t Len() const override; + + bool IsIpv4() const; + bool IsIpv6() const; + int GetPort() const; + + private: + TAtomicSharedPtr<IRemoteAddr> Ptr; + }; + + using TSockAddrInVector = TVector<TNetAddr>; + + struct TNetAddrHostPortHash { + inline size_t operator()(const TNetAddr& a) const { + const sockaddr* s = a.Addr(); + const sockaddr_in* const sa = reinterpret_cast<const sockaddr_in*>(s); + const sockaddr_in6* const sa6 = reinterpret_cast<const sockaddr_in6*>(s); + + switch (s->sa_family) { + case AF_INET: + return CombineHashes<size_t>(ComputeHash(TStringBuf(reinterpret_cast<const char*>(&sa->sin_addr), sizeof(sa->sin_addr))), IntHashImpl(sa->sin_port)); + + case AF_INET6: + return CombineHashes<size_t>(ComputeHash(TStringBuf(reinterpret_cast<const char*>(&sa6->sin6_addr), sizeof(sa6->sin6_addr))), IntHashImpl(sa6->sin6_port)); + } + + return ComputeHash(TStringBuf(reinterpret_cast<const char*>(s), a.Len())); + } + }; + +} diff --git a/library/cpp/messagebus/config/session_config.cpp b/library/cpp/messagebus/config/session_config.cpp new file mode 100644 index 0000000000..fbbbb106c9 --- /dev/null +++ b/library/cpp/messagebus/config/session_config.cpp @@ -0,0 +1,157 @@ +#include "session_config.h" + +#include <util/generic/strbuf.h> +#include <util/string/hex.h> + +using namespace NBus; + +TBusSessionConfig::TSecret::TSecret() + : TimeoutPeriod(TDuration::Seconds(1)) + , StatusFlushPeriod(TDuration::MilliSeconds(400)) +{ +} + +TBusSessionConfig::TBusSessionConfig() + : BUS_SESSION_CONFIG_MAP(STRUCT_FIELD_INIT, COMMA) +{ +} + +TString TBusSessionConfig::PrintToString() const { + TStringStream ss; + BUS_SESSION_CONFIG_MAP(STRUCT_FIELD_PRINT, ) + return ss.Str(); +} + +static int ParseDurationForMessageBus(const char* option) { + return TDuration::Parse(option).MilliSeconds(); +} + +static int ParseToSForMessageBus(const char* option) { + int tos; + TStringBuf str(option); + if (str.StartsWith("0x")) { + str = str.Tail(2); + Y_VERIFY(str.length() == 2, "ToS must be a number between 0x00 and 0xFF"); + tos = String2Byte(str.data()); + } else { + tos = FromString<int>(option); + } + Y_VERIFY(tos >= 0 && tos <= 255, "ToS must be between 0x00 and 0xFF"); + return tos; +} + +template <class T> +static T ParseWithKmgSuffixT(const char* option) { + TStringBuf str(option); + T multiplier = 1; + if (str.EndsWith('k')) { + multiplier = 1024; + str = str.Head(str.size() - 1); + } else if (str.EndsWith('m')) { + multiplier = 1024 * 1024; + str = str.Head(str.size() - 1); + } else if (str.EndsWith('g')) { + multiplier = 1024 * 1024 * 1024; + str = str.Head(str.size() - 1); + } + return FromString<T>(str) * multiplier; +} + +static ui64 ParseWithKmgSuffix(const char* option) { + return ParseWithKmgSuffixT<ui64>(option); +} + +static i64 ParseWithKmgSuffixS(const char* option) { + return ParseWithKmgSuffixT<i64>(option); +} + +void TBusSessionConfig::ConfigureLastGetopt(NLastGetopt::TOpts& opts, + const TString& prefix) { + opts.AddLongOption(prefix + "total-timeout") + .RequiredArgument("MILLISECONDS") + .DefaultValue(ToString(TotalTimeout)) + .StoreMappedResultT<const char*>(&TotalTimeout, + &ParseDurationForMessageBus); + opts.AddLongOption(prefix + "connect-timeout") + .RequiredArgument("MILLISECONDS") + .DefaultValue(ToString(ConnectTimeout)) + .StoreMappedResultT<const char*>(&ConnectTimeout, + &ParseDurationForMessageBus); + opts.AddLongOption(prefix + "send-timeout") + .RequiredArgument("MILLISECONDS") + .DefaultValue(ToString(SendTimeout)) + .StoreMappedResultT<const char*>(&SendTimeout, + &ParseDurationForMessageBus); + opts.AddLongOption(prefix + "send-threshold") + .RequiredArgument("BYTES") + .DefaultValue(ToString(SendThreshold)) + .StoreMappedResultT<const char*>(&SendThreshold, &ParseWithKmgSuffix); + + opts.AddLongOption(prefix + "max-in-flight") + .RequiredArgument("COUNT") + .DefaultValue(ToString(MaxInFlight)) + .StoreMappedResultT<const char*>(&MaxInFlight, &ParseWithKmgSuffix); + opts.AddLongOption(prefix + "max-in-flight-by-size") + .RequiredArgument("BYTES") + .DefaultValue( + ToString(MaxInFlightBySize)) + .StoreMappedResultT<const char*>(&MaxInFlightBySize, &ParseWithKmgSuffixS); + opts.AddLongOption(prefix + "per-con-max-in-flight") + .RequiredArgument("COUNT") + .DefaultValue(ToString(PerConnectionMaxInFlight)) + .StoreMappedResultT<const char*>(&PerConnectionMaxInFlight, + &ParseWithKmgSuffix); + opts.AddLongOption(prefix + "per-con-max-in-flight-by-size") + .RequiredArgument("BYTES") + .DefaultValue( + ToString(PerConnectionMaxInFlightBySize)) + .StoreMappedResultT<const char*>(&PerConnectionMaxInFlightBySize, + &ParseWithKmgSuffix); + + opts.AddLongOption(prefix + "default-buffer-size") + .RequiredArgument("BYTES") + .DefaultValue(ToString(DefaultBufferSize)) + .StoreMappedResultT<const char*>(&DefaultBufferSize, + &ParseWithKmgSuffix); + opts.AddLongOption(prefix + "max-buffer-size") + .RequiredArgument("BYTES") + .DefaultValue(ToString(MaxBufferSize)) + .StoreMappedResultT<const char*>(&MaxBufferSize, &ParseWithKmgSuffix); + opts.AddLongOption(prefix + "max-message-size") + .RequiredArgument("BYTES") + .DefaultValue(ToString(MaxMessageSize)) + .StoreMappedResultT<const char*>(&MaxMessageSize, &ParseWithKmgSuffix); + opts.AddLongOption(prefix + "socket-recv-buffer-size") + .RequiredArgument("BYTES") + .DefaultValue(ToString(SocketRecvBufferSize)) + .StoreMappedResultT<const char*>(&SocketRecvBufferSize, + &ParseWithKmgSuffix); + opts.AddLongOption(prefix + "socket-send-buffer-size") + .RequiredArgument("BYTES") + .DefaultValue(ToString(SocketSendBufferSize)) + .StoreMappedResultT<const char*>(&SocketSendBufferSize, + &ParseWithKmgSuffix); + + opts.AddLongOption(prefix + "socket-tos") + .RequiredArgument("[0x00, 0xFF]") + .StoreMappedResultT<const char*>(&SocketToS, &ParseToSForMessageBus); + ; + opts.AddLongOption(prefix + "tcp-cork") + .RequiredArgument("BOOL") + .DefaultValue(ToString(TcpCork)) + .StoreResult(&TcpCork); + opts.AddLongOption(prefix + "cork") + .RequiredArgument("SECONDS") + .DefaultValue( + ToString(Cork.Seconds())) + .StoreMappedResultT<const char*>(&Cork, &TDuration::Parse); + + opts.AddLongOption(prefix + "on-message-in-pool") + .RequiredArgument("BOOL") + .DefaultValue(ToString(ExecuteOnMessageInWorkerPool)) + .StoreResult(&ExecuteOnMessageInWorkerPool); + opts.AddLongOption(prefix + "on-reply-in-pool") + .RequiredArgument("BOOL") + .DefaultValue(ToString(ExecuteOnReplyInWorkerPool)) + .StoreResult(&ExecuteOnReplyInWorkerPool); +} diff --git a/library/cpp/messagebus/config/session_config.h b/library/cpp/messagebus/config/session_config.h new file mode 100644 index 0000000000..84753350a9 --- /dev/null +++ b/library/cpp/messagebus/config/session_config.h @@ -0,0 +1,65 @@ +#pragma once + +#include "codegen.h" +#include "defs.h" + +#include <library/cpp/getopt/last_getopt.h> + +#include <util/generic/string.h> + +namespace NBus { +#define BUS_SESSION_CONFIG_MAP(XX, comma) \ + XX(Name, TString, "") \ + comma \ + XX(NumRetries, int, 0) comma \ + XX(RetryInterval, int, 1000) comma \ + XX(ReconnectWhenIdle, bool, false) comma \ + XX(MaxInFlight, i64, 1000) comma \ + XX(PerConnectionMaxInFlight, unsigned, 0) comma \ + XX(PerConnectionMaxInFlightBySize, unsigned, 0) comma \ + XX(MaxInFlightBySize, i64, -1) comma \ + XX(TotalTimeout, i64, 0) comma \ + XX(SendTimeout, i64, 0) comma \ + XX(ConnectTimeout, i64, 0) comma \ + XX(DefaultBufferSize, size_t, 10 * 1024) comma \ + XX(MaxBufferSize, size_t, 1024 * 1024) comma \ + XX(SocketRecvBufferSize, unsigned, 0) comma \ + XX(SocketSendBufferSize, unsigned, 0) comma \ + XX(SocketToS, int, -1) comma \ + XX(SendThreshold, size_t, 10 * 1024) comma \ + XX(Cork, TDuration, TDuration::Zero()) comma \ + XX(MaxMessageSize, unsigned, 26 << 20) comma \ + XX(TcpNoDelay, bool, false) comma \ + XX(TcpCork, bool, false) comma \ + XX(ExecuteOnMessageInWorkerPool, bool, true) comma \ + XX(ExecuteOnReplyInWorkerPool, bool, true) comma \ + XX(ReusePort, bool, false) comma \ + XX(ListenPort, unsigned, 0) /* TODO: server only */ + + //////////////////////////////////////////////////////////////////// + /// \brief Configuration for client and server session + struct TBusSessionConfig { + BUS_SESSION_CONFIG_MAP(STRUCT_FIELD_GEN, ) + + struct TSecret { + TDuration TimeoutPeriod; + TDuration StatusFlushPeriod; + + TSecret(); + }; + + // secret options are available, but you shouldn't probably use them + TSecret Secret; + + /// initialized with default settings + TBusSessionConfig(); + + TString PrintToString() const; + + void ConfigureLastGetopt(NLastGetopt::TOpts&, const TString& prefix = "mb-"); + }; + + using TBusClientSessionConfig = TBusSessionConfig; + using TBusServerSessionConfig = TBusSessionConfig; + +} // NBus diff --git a/library/cpp/messagebus/config/ya.make b/library/cpp/messagebus/config/ya.make new file mode 100644 index 0000000000..20c7dfed19 --- /dev/null +++ b/library/cpp/messagebus/config/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/getopt + library/cpp/deprecated/enum_codegen +) + +SRCS( + netaddr.cpp + session_config.cpp +) + +END() diff --git a/library/cpp/messagebus/connection.cpp b/library/cpp/messagebus/connection.cpp new file mode 100644 index 0000000000..07580ce18a --- /dev/null +++ b/library/cpp/messagebus/connection.cpp @@ -0,0 +1,16 @@ +#include "connection.h" + +#include "remote_client_connection.h" + +#include <util/generic/cast.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +void TBusClientConnectionPtrOps::Ref(TBusClientConnection* c) { + return CheckedCast<TRemoteClientConnection*>(c)->Ref(); +} + +void TBusClientConnectionPtrOps::UnRef(TBusClientConnection* c) { + return CheckedCast<TRemoteClientConnection*>(c)->UnRef(); +} diff --git a/library/cpp/messagebus/connection.h b/library/cpp/messagebus/connection.h new file mode 100644 index 0000000000..b1df64ddc1 --- /dev/null +++ b/library/cpp/messagebus/connection.h @@ -0,0 +1,61 @@ +#pragma once + +#include "defs.h" +#include "message.h" + +#include <util/generic/ptr.h> + +namespace NBus { + struct TBusClientConnection { + /// if you want to open connection early + virtual void OpenConnection() = 0; + + /// Send message to the destination + /// If addr is set then use it as destination. + /// Takes ownership of addr (see ClearState method). + virtual EMessageStatus SendMessage(TBusMessage* pMes, bool wait = false) = 0; + + virtual EMessageStatus SendMessageOneWay(TBusMessage* pMes, bool wait = false) = 0; + + /// Like SendMessage but cares about message + template <typename T /* <: TBusMessage */> + EMessageStatus SendMessageAutoPtr(const TAutoPtr<T>& mes, bool wait = false) { + EMessageStatus status = SendMessage(mes.Get(), wait); + if (status == MESSAGE_OK) + Y_UNUSED(mes.Release()); + return status; + } + + /// Like SendMessageOneWay but cares about message + template <typename T /* <: TBusMessage */> + EMessageStatus SendMessageOneWayAutoPtr(const TAutoPtr<T>& mes, bool wait = false) { + EMessageStatus status = SendMessageOneWay(mes.Get(), wait); + if (status == MESSAGE_OK) + Y_UNUSED(mes.Release()); + return status; + } + + EMessageStatus SendMessageMove(TBusMessageAutoPtr message, bool wait = false) { + return SendMessageAutoPtr(message, wait); + } + + EMessageStatus SendMessageOneWayMove(TBusMessageAutoPtr message, bool wait = false) { + return SendMessageOneWayAutoPtr(message, wait); + } + + // TODO: implement similar one-way methods + + virtual ~TBusClientConnection() { + } + }; + + namespace NPrivate { + struct TBusClientConnectionPtrOps { + static void Ref(TBusClientConnection*); + static void UnRef(TBusClientConnection*); + }; + } + + using TBusClientConnectionPtr = TIntrusivePtr<TBusClientConnection, NPrivate::TBusClientConnectionPtrOps>; + +} diff --git a/library/cpp/messagebus/coreconn.cpp b/library/cpp/messagebus/coreconn.cpp new file mode 100644 index 0000000000..d9411bb5db --- /dev/null +++ b/library/cpp/messagebus/coreconn.cpp @@ -0,0 +1,30 @@ +#include "coreconn.h" + +#include "remote_connection.h" + +#include <util/datetime/base.h> +#include <util/generic/yexception.h> +#include <util/network/socket.h> +#include <util/string/util.h> +#include <util/system/thread.h> + +namespace NBus { + TBusInstant Now() { + return millisec(); + } + + EIpVersion MakeIpVersion(bool allowIpv4, bool allowIpv6) { + if (allowIpv4) { + if (allowIpv6) { + return EIP_VERSION_ANY; + } else { + return EIP_VERSION_4; + } + } else if (allowIpv6) { + return EIP_VERSION_6; + } + + ythrow yexception() << "Neither of IPv4/IPv6 is allowed."; + } + +} diff --git a/library/cpp/messagebus/coreconn.h b/library/cpp/messagebus/coreconn.h new file mode 100644 index 0000000000..fca228d82e --- /dev/null +++ b/library/cpp/messagebus/coreconn.h @@ -0,0 +1,67 @@ +#pragma once + +////////////////////////////////////////////////////////////// +/// \file +/// \brief Definitions for asynchonous connection queue + +#include "base.h" +#include "event_loop.h" +#include "netaddr.h" + +#include <util/datetime/base.h> +#include <util/generic/algorithm.h> +#include <util/generic/list.h> +#include <util/generic/map.h> +#include <util/generic/set.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/network/address.h> +#include <util/network/ip.h> +#include <util/network/poller.h> +#include <util/string/util.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> +#include <util/system/thread.h> +#include <util/thread/lfqueue.h> + +#include <deque> +#include <utility> + +#ifdef NO_ERROR +#undef NO_ERROR +#endif + +#define BUS_WORKER_CONDVAR +//#define BUS_WORKER_MIXED + +namespace NBus { + class TBusConnection; + class TBusConnectionFactory; + class TBusServerFactory; + + using TBusConnectionList = TList<TBusConnection*>; + + /// @throw yexception + EIpVersion MakeIpVersion(bool allowIpv4, bool allowIpv6); + + inline bool WouldBlock() { + int syserr = LastSystemError(); + return syserr == EAGAIN || syserr == EINPROGRESS || syserr == EWOULDBLOCK || syserr == EINTR; + } + + class TBusSession; + + struct TMaxConnectedException: public yexception { + TMaxConnectedException(unsigned maxConnect) { + yexception& exc = *this; + exc << TStringBuf("Exceeded maximum number of outstanding connections: "); + exc << maxConnect; + } + }; + + enum EPollType { + POLL_READ, + POLL_WRITE + }; + +} diff --git a/library/cpp/messagebus/coreconn_ut.cpp b/library/cpp/messagebus/coreconn_ut.cpp new file mode 100644 index 0000000000..beb6850f26 --- /dev/null +++ b/library/cpp/messagebus/coreconn_ut.cpp @@ -0,0 +1,25 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "coreconn.h" + +#include <util/generic/yexception.h> + +Y_UNIT_TEST_SUITE(TMakeIpVersionTest) { + using namespace NBus; + + Y_UNIT_TEST(IpV4Allowed) { + UNIT_ASSERT_EQUAL(MakeIpVersion(true, false), EIP_VERSION_4); + } + + Y_UNIT_TEST(IpV6Allowed) { + UNIT_ASSERT_EQUAL(MakeIpVersion(false, true), EIP_VERSION_6); + } + + Y_UNIT_TEST(AllAllowed) { + UNIT_ASSERT_EQUAL(MakeIpVersion(true, true), EIP_VERSION_ANY); + } + + Y_UNIT_TEST(NothingAllowed) { + UNIT_ASSERT_EXCEPTION(MakeIpVersion(false, false), yexception); + } +} diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver.cpp b/library/cpp/messagebus/debug_receiver/debug_receiver.cpp new file mode 100644 index 0000000000..23b02d1003 --- /dev/null +++ b/library/cpp/messagebus/debug_receiver/debug_receiver.cpp @@ -0,0 +1,42 @@ +#include "debug_receiver_handler.h" +#include "debug_receiver_proto.h" + +#include <library/cpp/messagebus/ybus.h> + +#include <library/cpp/getopt/last_getopt.h> +#include <library/cpp/lwtrace/all.h> + +using namespace NBus; + +int main(int argc, char** argv) { + NLWTrace::StartLwtraceFromEnv(); + + TBusQueueConfig queueConfig; + TBusServerSessionConfig sessionConfig; + + NLastGetopt::TOpts opts; + + queueConfig.ConfigureLastGetopt(opts); + sessionConfig.ConfigureLastGetopt(opts); + + opts.AddLongOption("port").Required().RequiredArgument("PORT").StoreResult(&sessionConfig.ListenPort); + + opts.SetFreeArgsMax(0); + + NLastGetopt::TOptsParseResult r(&opts, argc, argv); + + TBusMessageQueuePtr q(CreateMessageQueue(queueConfig)); + + TDebugReceiverProtocol proto; + TDebugReceiverHandler handler; + + TBusServerSessionPtr serverSession = TBusServerSession::Create(&proto, &handler, sessionConfig, q); + // TODO: race is here + handler.ServerSession = serverSession.Get(); + + for (;;) { + Sleep(TDuration::Hours(17)); + } + + return 0; +} diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp new file mode 100644 index 0000000000..05f99e94ca --- /dev/null +++ b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp @@ -0,0 +1,20 @@ +#include "debug_receiver_handler.h" + +#include "debug_receiver_proto.h" + +#include <util/generic/cast.h> +#include <util/string/printf.h> + +void TDebugReceiverHandler::OnError(TAutoPtr<NBus::TBusMessage>, NBus::EMessageStatus status) { + Cerr << "error " << status << "\n"; +} + +void TDebugReceiverHandler::OnMessage(NBus::TOnMessageContext& message) { + TDebugReceiverMessage* typedMessage = VerifyDynamicCast<TDebugReceiverMessage*>(message.GetMessage()); + Cerr << "type=" << typedMessage->GetHeader()->Type + << " size=" << typedMessage->GetHeader()->Size + << " flags=" << Sprintf("0x%04x", (int)typedMessage->GetHeader()->FlagsInternal) + << "\n"; + + message.ForgetRequest(); +} diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_handler.h b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.h new file mode 100644 index 0000000000..0aed6b9984 --- /dev/null +++ b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.h @@ -0,0 +1,10 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> + +struct TDebugReceiverHandler: public NBus::IBusServerHandler { + NBus::TBusServerSession* ServerSession; + + void OnError(TAutoPtr<NBus::TBusMessage> pMessage, NBus::EMessageStatus status) override; + void OnMessage(NBus::TOnMessageContext& message) override; +}; diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp new file mode 100644 index 0000000000..0c74f9ecc3 --- /dev/null +++ b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp @@ -0,0 +1,20 @@ +#include "debug_receiver_proto.h" + +using namespace NBus; + +TDebugReceiverProtocol::TDebugReceiverProtocol() + : TBusProtocol("debug receiver", 0) +{ +} + +void TDebugReceiverProtocol::Serialize(const NBus::TBusMessage*, TBuffer&) { + Y_FAIL("it is receiver only"); +} + +TAutoPtr<NBus::TBusMessage> TDebugReceiverProtocol::Deserialize(ui16, TArrayRef<const char> payload) { + THolder<TDebugReceiverMessage> r(new TDebugReceiverMessage(ECreateUninitialized())); + + r->Payload.Append(payload.data(), payload.size()); + + return r.Release(); +} diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_proto.h b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.h new file mode 100644 index 0000000000..d34710dcf7 --- /dev/null +++ b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.h @@ -0,0 +1,27 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> + +struct TDebugReceiverMessage: public NBus::TBusMessage { + /// constructor to create messages on sending end + TDebugReceiverMessage(ui16 type) + : NBus::TBusMessage(type) + { + } + + /// constructor with serialzed data to examine the header + TDebugReceiverMessage(NBus::ECreateUninitialized) + : NBus::TBusMessage(NBus::ECreateUninitialized()) + { + } + + TBuffer Payload; +}; + +struct TDebugReceiverProtocol: public NBus::TBusProtocol { + TDebugReceiverProtocol(); + + void Serialize(const NBus::TBusMessage* mess, TBuffer& data) override; + + TAutoPtr<NBus::TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override; +}; diff --git a/library/cpp/messagebus/debug_receiver/ya.make b/library/cpp/messagebus/debug_receiver/ya.make new file mode 100644 index 0000000000..f1b14d35bb --- /dev/null +++ b/library/cpp/messagebus/debug_receiver/ya.make @@ -0,0 +1,17 @@ +PROGRAM(messagebus_debug_receiver) + +OWNER(g:messagebus) + +SRCS( + debug_receiver.cpp + debug_receiver_proto.cpp + debug_receiver_handler.cpp +) + +PEERDIR( + library/cpp/getopt + library/cpp/lwtrace + library/cpp/messagebus +) + +END() diff --git a/library/cpp/messagebus/defs.h b/library/cpp/messagebus/defs.h new file mode 100644 index 0000000000..cb553acc45 --- /dev/null +++ b/library/cpp/messagebus/defs.h @@ -0,0 +1,4 @@ +#pragma once + +#include <library/cpp/messagebus/config/defs.h> + diff --git a/library/cpp/messagebus/dummy_debugger.h b/library/cpp/messagebus/dummy_debugger.h new file mode 100644 index 0000000000..89a4e18716 --- /dev/null +++ b/library/cpp/messagebus/dummy_debugger.h @@ -0,0 +1,9 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/stream/output.h> + +#define MB_TRACE() \ + do { \ + Cerr << TInstant::Now() << " " << __FILE__ << ":" << __LINE__ << " " << __FUNCTION__ << Endl; \ + } while (false) diff --git a/library/cpp/messagebus/duration_histogram.cpp b/library/cpp/messagebus/duration_histogram.cpp new file mode 100644 index 0000000000..32a0001d41 --- /dev/null +++ b/library/cpp/messagebus/duration_histogram.cpp @@ -0,0 +1,74 @@ +#include "duration_histogram.h" + +#include <util/generic/singleton.h> +#include <util/stream/str.h> + +namespace { + ui64 SecondsRound(TDuration d) { + if (d.MilliSeconds() % 1000 >= 500) { + return d.Seconds() + 1; + } else { + return d.Seconds(); + } + } + + ui64 MilliSecondsRound(TDuration d) { + if (d.MicroSeconds() % 1000 >= 500) { + return d.MilliSeconds() + 1; + } else { + return d.MilliSeconds(); + } + } + + ui64 MinutesRound(TDuration d) { + if (d.Seconds() % 60 >= 30) { + return d.Minutes() + 1; + } else { + return d.Minutes(); + } + } + +} + +namespace { + struct TMarks { + std::array<TDuration, TDurationHistogram::Buckets> Marks; + + TMarks() { + Marks[0] = TDuration::Zero(); + for (unsigned i = 1; i < TDurationHistogram::Buckets; ++i) { + if (i >= TDurationHistogram::SecondBoundary) { + Marks[i] = TDuration::Seconds(1) * (1 << (i - TDurationHistogram::SecondBoundary)); + } else { + Marks[i] = TDuration::Seconds(1) / (1 << (TDurationHistogram::SecondBoundary - i)); + } + } + } + }; +} + +TString TDurationHistogram::LabelBefore(unsigned i) { + Y_VERIFY(i < Buckets); + + TDuration d = Singleton<TMarks>()->Marks[i]; + + TStringStream ss; + if (d == TDuration::Zero()) { + ss << "0"; + } else if (d < TDuration::Seconds(1)) { + ss << MilliSecondsRound(d) << "ms"; + } else if (d < TDuration::Minutes(1)) { + ss << SecondsRound(d) << "s"; + } else { + ss << MinutesRound(d) << "m"; + } + return ss.Str(); +} + +TString TDurationHistogram::PrintToString() const { + TStringStream ss; + for (auto time : Times) { + ss << time << "\n"; + } + return ss.Str(); +} diff --git a/library/cpp/messagebus/duration_histogram.h b/library/cpp/messagebus/duration_histogram.h new file mode 100644 index 0000000000..ed060b0101 --- /dev/null +++ b/library/cpp/messagebus/duration_histogram.h @@ -0,0 +1,45 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/generic/bitops.h> +#include <util/generic/string.h> + +#include <array> + +struct TDurationHistogram { + static const unsigned Buckets = 20; + std::array<ui64, Buckets> Times; + + static const unsigned SecondBoundary = 11; + + TDurationHistogram() { + Times.fill(0); + } + + static unsigned BucketFor(TDuration d) { + ui64 units = d.MicroSeconds() * (1 << SecondBoundary) / 1000000; + if (units == 0) { + return 0; + } + unsigned bucket = GetValueBitCount(units) - 1; + if (bucket >= Buckets) { + bucket = Buckets - 1; + } + return bucket; + } + + void AddTime(TDuration d) { + Times[BucketFor(d)] += 1; + } + + TDurationHistogram& operator+=(const TDurationHistogram& that) { + for (unsigned i = 0; i < Times.size(); ++i) { + Times[i] += that.Times[i]; + } + return *this; + } + + static TString LabelBefore(unsigned i); + + TString PrintToString() const; +}; diff --git a/library/cpp/messagebus/duration_histogram_ut.cpp b/library/cpp/messagebus/duration_histogram_ut.cpp new file mode 100644 index 0000000000..01bcc095e9 --- /dev/null +++ b/library/cpp/messagebus/duration_histogram_ut.cpp @@ -0,0 +1,38 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "duration_histogram.h" + +Y_UNIT_TEST_SUITE(TDurationHistogramTest) { + Y_UNIT_TEST(BucketFor) { + UNIT_ASSERT_VALUES_EQUAL(0u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(0))); + UNIT_ASSERT_VALUES_EQUAL(0u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(1))); + UNIT_ASSERT_VALUES_EQUAL(0u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(900))); + UNIT_ASSERT_VALUES_EQUAL(1u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(1500))); + UNIT_ASSERT_VALUES_EQUAL(2u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(2500))); + + unsigned sb = TDurationHistogram::SecondBoundary; + + UNIT_ASSERT_VALUES_EQUAL(sb - 1, TDurationHistogram::BucketFor(TDuration::MilliSeconds(999))); + UNIT_ASSERT_VALUES_EQUAL(sb, TDurationHistogram::BucketFor(TDuration::MilliSeconds(1000))); + UNIT_ASSERT_VALUES_EQUAL(sb, TDurationHistogram::BucketFor(TDuration::MilliSeconds(1001))); + + UNIT_ASSERT_VALUES_EQUAL(TDurationHistogram::Buckets - 1, TDurationHistogram::BucketFor(TDuration::Hours(1))); + } + + Y_UNIT_TEST(Simple) { + TDurationHistogram h1; + h1.AddTime(TDuration::MicroSeconds(1)); + UNIT_ASSERT_VALUES_EQUAL(1u, h1.Times.front()); + + TDurationHistogram h2; + h1.AddTime(TDuration::Hours(1)); + UNIT_ASSERT_VALUES_EQUAL(1u, h1.Times.back()); + } + + Y_UNIT_TEST(LabelFor) { + for (unsigned i = 0; i < TDurationHistogram::Buckets; ++i) { + TDurationHistogram::LabelBefore(i); + //Cerr << TDurationHistogram::LabelBefore(i) << "\n"; + } + } +} diff --git a/library/cpp/messagebus/event_loop.cpp b/library/cpp/messagebus/event_loop.cpp new file mode 100644 index 0000000000..f685135bed --- /dev/null +++ b/library/cpp/messagebus/event_loop.cpp @@ -0,0 +1,370 @@ +#include "event_loop.h" + +#include "network.h" +#include "thread_extra.h" + +#include <util/generic/hash.h> +#include <util/network/pair.h> +#include <util/network/poller.h> +#include <util/system/event.h> +#include <util/system/mutex.h> +#include <util/system/thread.h> +#include <util/system/yassert.h> +#include <util/thread/lfqueue.h> + +#include <errno.h> + +using namespace NEventLoop; + +namespace { + enum ERunningState { + EVENT_LOOP_CREATED, + EVENT_LOOP_RUNNING, + EVENT_LOOP_STOPPED, + }; + + enum EOperation { + OP_READ = 1, + OP_WRITE = 2, + OP_READ_WRITE = OP_READ | OP_WRITE, + }; +} + +class TChannel::TImpl { +public: + TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr, void* cookie); + ~TImpl(); + + void EnableRead(); + void DisableRead(); + void EnableWrite(); + void DisableWrite(); + + void Unregister(); + + SOCKET GetSocket() const; + TSocket GetSocketPtr() const; + + void Update(int pollerFlags, bool enable); + void CallHandler(); + + TEventLoop::TImpl* EventLoop; + TSocket Socket; + TEventHandlerPtr EventHandler; + void* Cookie; + + TMutex Mutex; + + int CurrentFlags; + bool Close; +}; + +class TEventLoop::TImpl { +public: + TImpl(const char* name); + + void Run(); + void Wakeup(); + void Stop(); + + TChannelPtr Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie); + void Unregister(SOCKET socket); + + typedef THashMap<SOCKET, TChannelPtr> TData; + + void AddToPoller(SOCKET socket, void* cookie, int flags); + + TMutex Mutex; + + const char* Name; + + TAtomic RunningState; + TAtomic StopSignal; + TSystemEvent StoppedEvent; + TData Data; + + TLockFreeQueue<SOCKET> SocketsToRemove; + + TSocketPoller Poller; + TSocketHolder WakeupReadSocket; + TSocketHolder WakeupWriteSocket; +}; + +TChannel::~TChannel() { +} + +void TChannel::EnableRead() { + Impl->EnableRead(); +} + +void TChannel::DisableRead() { + Impl->DisableRead(); +} + +void TChannel::EnableWrite() { + Impl->EnableWrite(); +} + +void TChannel::DisableWrite() { + Impl->DisableWrite(); +} + +void TChannel::Unregister() { + Impl->Unregister(); +} + +SOCKET TChannel::GetSocket() const { + return Impl->GetSocket(); +} + +TSocket TChannel::GetSocketPtr() const { + return Impl->GetSocketPtr(); +} + +TChannel::TChannel(TImpl* impl) + : Impl(impl) +{ +} + +TEventLoop::TEventLoop(const char* name) + : Impl(new TImpl(name)) +{ +} + +TEventLoop::~TEventLoop() { +} + +void TEventLoop::Run() { + Impl->Run(); +} + +void TEventLoop::Stop() { + Impl->Stop(); +} + +bool TEventLoop::IsRunning() { + return AtomicGet(Impl->RunningState) == EVENT_LOOP_RUNNING; +} + +TChannelPtr TEventLoop::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) { + return Impl->Register(socket, eventHandler, cookie); +} + +TChannel::TImpl::TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr eventHandler, void* cookie) + : EventLoop(eventLoop) + , Socket(socket) + , EventHandler(eventHandler) + , Cookie(cookie) + , CurrentFlags(0) + , Close(false) +{ +} + +TChannel::TImpl::~TImpl() { + Y_ASSERT(Close); +} + +void TChannel::TImpl::EnableRead() { + Update(OP_READ, true); +} + +void TChannel::TImpl::DisableRead() { + Update(OP_READ, false); +} + +void TChannel::TImpl::EnableWrite() { + Update(OP_WRITE, true); +} + +void TChannel::TImpl::DisableWrite() { + Update(OP_WRITE, false); +} + +void TChannel::TImpl::Unregister() { + TGuard<TMutex> guard(Mutex); + + if (Close) { + return; + } + + Close = true; + if (CurrentFlags != 0) { + EventLoop->Poller.Unwait(Socket); + CurrentFlags = 0; + } + EventHandler.Drop(); + + EventLoop->SocketsToRemove.Enqueue(Socket); + EventLoop->Wakeup(); +} + +void TChannel::TImpl::Update(int flags, bool enable) { + TGuard<TMutex> guard(Mutex); + + if (Close) { + return; + } + + int newFlags = enable ? (CurrentFlags | flags) : (CurrentFlags & ~flags); + + if (CurrentFlags == newFlags) { + return; + } + + if (!newFlags) { + EventLoop->Poller.Unwait(Socket); + } else { + void* cookie = reinterpret_cast<void*>(this); + EventLoop->AddToPoller(Socket, cookie, newFlags); + } + + CurrentFlags = newFlags; +} + +SOCKET TChannel::TImpl::GetSocket() const { + return Socket; +} + +TSocket TChannel::TImpl::GetSocketPtr() const { + return Socket; +} + +void TChannel::TImpl::CallHandler() { + TEventHandlerPtr handler; + + { + TGuard<TMutex> guard(Mutex); + + // other thread may have re-added socket to epoll + // so even if CurrentFlags is 0, epoll may fire again + // so please use non-blocking operations + CurrentFlags = 0; + + if (Close) { + return; + } + + handler = EventHandler; + } + + if (!!handler) { + handler->HandleEvent(Socket, Cookie); + } +} + +TEventLoop::TImpl::TImpl(const char* name) + : Name(name) + , RunningState(EVENT_LOOP_CREATED) + , StopSignal(0) +{ + SOCKET wakeupSockets[2]; + + if (SocketPair(wakeupSockets) < 0) { + Y_FAIL("failed to create socket pair for wakeup sockets: %s", LastSystemErrorText()); + } + + TSocketHolder wakeupReadSocket(wakeupSockets[0]); + TSocketHolder wakeupWriteSocket(wakeupSockets[1]); + + WakeupReadSocket.Swap(wakeupReadSocket); + WakeupWriteSocket.Swap(wakeupWriteSocket); + + SetNonBlock(WakeupWriteSocket, true); + SetNonBlock(WakeupReadSocket, true); + + Poller.WaitRead(WakeupReadSocket, + reinterpret_cast<void*>(this)); +} + +void TEventLoop::TImpl::Run() { + bool res = AtomicCas(&RunningState, EVENT_LOOP_RUNNING, EVENT_LOOP_CREATED); + Y_VERIFY(res, "Invalid mbus event loop state"); + + if (!!Name) { + SetCurrentThreadName(Name); + } + + while (AtomicGet(StopSignal) == 0) { + void* cookies[1024]; + const size_t count = Poller.WaitI(cookies, Y_ARRAY_SIZE(cookies)); + + void** end = cookies + count; + for (void** c = cookies; c != end; ++c) { + TChannel::TImpl* s = reinterpret_cast<TChannel::TImpl*>(*c); + + if (*c == this) { + char buf[0x1000]; + if (NBus::NPrivate::SocketRecv(WakeupReadSocket, buf) < 0) { + Y_FAIL("failed to recv from wakeup socket: %s", LastSystemErrorText()); + } + continue; + } + + s->CallHandler(); + } + + SOCKET socket = -1; + while (SocketsToRemove.Dequeue(&socket)) { + TGuard<TMutex> guard(Mutex); + Y_VERIFY(Data.erase(socket) == 1, "must be removed once"); + } + } + + { + TGuard<TMutex> guard(Mutex); + for (auto& it : Data) { + it.second->Unregister(); + } + + // release file descriptors + Data.clear(); + } + + res = AtomicCas(&RunningState, EVENT_LOOP_STOPPED, EVENT_LOOP_RUNNING); + + Y_VERIFY(res); + + StoppedEvent.Signal(); +} + +void TEventLoop::TImpl::Stop() { + AtomicSet(StopSignal, 1); + + if (AtomicGet(RunningState) == EVENT_LOOP_RUNNING) { + Wakeup(); + + StoppedEvent.WaitI(); + } +} + +TChannelPtr TEventLoop::TImpl::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) { + Y_VERIFY(socket != INVALID_SOCKET, "must be a valid socket"); + + TChannelPtr channel = new TChannel(new TChannel::TImpl(this, socket, eventHandler, cookie)); + + TGuard<TMutex> guard(Mutex); + + Y_VERIFY(Data.insert(std::make_pair(socket, channel)).second, "must not be already inserted"); + + return channel; +} + +void TEventLoop::TImpl::Wakeup() { + if (NBus::NPrivate::SocketSend(WakeupWriteSocket, TArrayRef<const char>("", 1)) < 0) { + if (LastSystemError() != EAGAIN) { + Y_FAIL("failed to send to wakeup socket: %s", LastSystemErrorText()); + } + } +} + +void TEventLoop::TImpl::AddToPoller(SOCKET socket, void* cookie, int flags) { + if (flags == OP_READ) { + Poller.WaitReadOneShot(socket, cookie); + } else if (flags == OP_WRITE) { + Poller.WaitWriteOneShot(socket, cookie); + } else if (flags == OP_READ_WRITE) { + Poller.WaitReadWriteOneShot(socket, cookie); + } else { + Y_FAIL("Wrong flags: %d", int(flags)); + } +} diff --git a/library/cpp/messagebus/event_loop.h b/library/cpp/messagebus/event_loop.h new file mode 100644 index 0000000000..d5b0a53b0c --- /dev/null +++ b/library/cpp/messagebus/event_loop.h @@ -0,0 +1,72 @@ +#pragma once + +#include <util/generic/object_counter.h> +#include <util/generic/ptr.h> +#include <util/network/init.h> +#include <util/network/socket.h> + +namespace NEventLoop { + struct IEventHandler + : public TAtomicRefCount<IEventHandler> { + virtual void HandleEvent(SOCKET socket, void* cookie) = 0; + virtual ~IEventHandler() { + } + }; + + typedef TIntrusivePtr<IEventHandler> TEventHandlerPtr; + + class TEventLoop; + + // TODO: make TChannel itself a pointer + // to avoid confusion with Drop and Unregister + class TChannel + : public TAtomicRefCount<TChannel> { + public: + ~TChannel(); + + void EnableRead(); + void DisableRead(); + void EnableWrite(); + void DisableWrite(); + + void Unregister(); + + SOCKET GetSocket() const; + TSocket GetSocketPtr() const; + + private: + class TImpl; + friend class TEventLoop; + + TObjectCounter<TChannel> ObjectCounter; + + TChannel(TImpl*); + + private: + THolder<TImpl> Impl; + }; + + typedef TIntrusivePtr<TChannel> TChannelPtr; + + class TEventLoop { + public: + TEventLoop(const char* name = nullptr); + ~TEventLoop(); + + void Run(); + void Stop(); + bool IsRunning(); + + TChannelPtr Register(TSocket socket, TEventHandlerPtr, void* cookie = nullptr); + + private: + class TImpl; + friend class TChannel; + + TObjectCounter<TEventLoop> ObjectCounter; + + private: + THolder<TImpl> Impl; + }; + +} diff --git a/library/cpp/messagebus/extra_ref.h b/library/cpp/messagebus/extra_ref.h new file mode 100644 index 0000000000..2927123266 --- /dev/null +++ b/library/cpp/messagebus/extra_ref.h @@ -0,0 +1,36 @@ +#pragma once + +#include <util/system/yassert.h> + +class TExtraRef { + TAtomic Holds; + +public: + TExtraRef() + : Holds(false) + { + } + ~TExtraRef() { + Y_VERIFY(!Holds); + } + + template <typename TThis> + void Retain(TThis* thiz) { + if (AtomicGet(Holds)) { + return; + } + if (AtomicCas(&Holds, 1, 0)) { + thiz->Ref(); + } + } + + template <typename TThis> + void Release(TThis* thiz) { + if (!AtomicGet(Holds)) { + return; + } + if (AtomicCas(&Holds, 0, 1)) { + thiz->UnRef(); + } + } +}; diff --git a/library/cpp/messagebus/futex_like.cpp b/library/cpp/messagebus/futex_like.cpp new file mode 100644 index 0000000000..7f965126db --- /dev/null +++ b/library/cpp/messagebus/futex_like.cpp @@ -0,0 +1,55 @@ +#include <util/system/platform.h> + +#ifdef _linux_ +#include <sys/syscall.h> +#include <linux/futex.h> + +#if !defined(SYS_futex) +#define SYS_futex __NR_futex +#endif +#endif + +#include <errno.h> + +#include <util/system/yassert.h> + +#include "futex_like.h" + +#ifdef _linux_ +namespace { + int futex(int* uaddr, int op, int val, const struct timespec* timeout, + int* uaddr2, int val3) { + return syscall(SYS_futex, uaddr, op, val, timeout, uaddr2, val3); + } +} +#endif + +void TFutexLike::Wake(size_t count) { + Y_ASSERT(count > 0); +#ifdef _linux_ + if (count > unsigned(Max<int>())) { + count = Max<int>(); + } + int r = futex(&Value, FUTEX_WAKE, count, nullptr, nullptr, 0); + Y_VERIFY(r >= 0, "futex_wake failed: %s", strerror(errno)); +#else + TGuard<TMutex> guard(Mutex); + if (count == 1) { + CondVar.Signal(); + } else { + CondVar.BroadCast(); + } +#endif +} + +void TFutexLike::Wait(int expected) { +#ifdef _linux_ + int r = futex(&Value, FUTEX_WAIT, expected, nullptr, nullptr, 0); + Y_VERIFY(r >= 0 || errno == EWOULDBLOCK, "futex_wait failed: %s", strerror(errno)); +#else + TGuard<TMutex> guard(Mutex); + if (expected == Get()) { + CondVar.WaitI(Mutex); + } +#endif +} diff --git a/library/cpp/messagebus/futex_like.h b/library/cpp/messagebus/futex_like.h new file mode 100644 index 0000000000..31d60c60f1 --- /dev/null +++ b/library/cpp/messagebus/futex_like.h @@ -0,0 +1,86 @@ +#pragma once + +#include <util/system/condvar.h> +#include <util/system/mutex.h> +#include <util/system/platform.h> + +class TFutexLike { +private: +#ifdef _linux_ + int Value; +#else + TAtomic Value; + TMutex Mutex; + TCondVar CondVar; +#endif + +public: + TFutexLike() + : Value(0) + { + } + + int AddAndGet(int add) { +#ifdef _linux_ + //return __atomic_add_fetch(&Value, add, __ATOMIC_SEQ_CST); + return __sync_add_and_fetch(&Value, add); +#else + return AtomicAdd(Value, add); +#endif + } + + int GetAndAdd(int add) { + return AddAndGet(add) - add; + } + +// until we have modern GCC +#if 0 + int GetAndSet(int newValue) { +#ifdef _linux_ + return __atomic_exchange_n(&Value, newValue, __ATOMIC_SEQ_CST); +#else + return AtomicSwap(&Value, newValue); +#endif + } +#endif + + int Get() { +#ifdef _linux_ + //return __atomic_load_n(&Value, __ATOMIC_SEQ_CST); + __sync_synchronize(); + return Value; +#else + return AtomicGet(Value); +#endif + } + + void Set(int newValue) { +#ifdef _linux_ + //__atomic_store_n(&Value, newValue, __ATOMIC_SEQ_CST); + Value = newValue; + __sync_synchronize(); +#else + AtomicSet(Value, newValue); +#endif + } + + int GetAndIncrement() { + return AddAndGet(1) - 1; + } + + int IncrementAndGet() { + return AddAndGet(1); + } + + int GetAndDecrement() { + return AddAndGet(-1) + 1; + } + + int DecrementAndGet() { + return AddAndGet(-1); + } + + void Wake(size_t count = Max<size_t>()); + + void Wait(int expected); +}; diff --git a/library/cpp/messagebus/handler.cpp b/library/cpp/messagebus/handler.cpp new file mode 100644 index 0000000000..333bd52934 --- /dev/null +++ b/library/cpp/messagebus/handler.cpp @@ -0,0 +1,36 @@ +#include "handler.h" + +#include "remote_server_connection.h" +#include "ybus.h" + +using namespace NBus; +using namespace NBus::NPrivate; + +void IBusErrorHandler::OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status) { + Y_UNUSED(pMessage); + Y_UNUSED(status); +} +void IBusServerHandler::OnSent(TAutoPtr<TBusMessage> pMessage) { + Y_UNUSED(pMessage); +} +void IBusClientHandler::OnMessageSent(TBusMessage* pMessage) { + Y_UNUSED(pMessage); +} +void IBusClientHandler::OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) { + Y_UNUSED(pMessage); +} + +void IBusClientHandler::OnClientConnectionEvent(const TClientConnectionEvent&) { +} + +void TOnMessageContext::ForgetRequest() { + Session->ForgetRequest(Ident); +} + +TNetAddr TOnMessageContext::GetPeerAddrNetAddr() const { + return Ident.GetNetAddr(); +} + +bool TOnMessageContext::IsConnectionAlive() const { + return !!Ident.Connection && Ident.Connection->IsAlive(); +} diff --git a/library/cpp/messagebus/handler.h b/library/cpp/messagebus/handler.h new file mode 100644 index 0000000000..60002c68a6 --- /dev/null +++ b/library/cpp/messagebus/handler.h @@ -0,0 +1,135 @@ +#pragma once + +#include "defs.h" +#include "message.h" +#include "message_status.h" +#include "use_after_free_checker.h" +#include "use_count_checker.h" + +#include <util/generic/noncopyable.h> + +namespace NBus { + ///////////////////////////////////////////////////////////////// + /// \brief Interface to message bus handler + + struct IBusErrorHandler { + friend struct ::NBus::NPrivate::TBusSessionImpl; + + private: + TUseAfterFreeChecker UseAfterFreeChecker; + TUseCountChecker UseCountChecker; + + public: + /// called when message or reply can't be delivered + virtual void OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status); + + virtual ~IBusErrorHandler() { + } + }; + + class TClientConnectionEvent { + public: + enum EType { + CONNECTED, + DISCONNECTED, + }; + + private: + EType Type; + ui64 Id; + TNetAddr Addr; + + public: + TClientConnectionEvent(EType type, ui64 id, TNetAddr addr) + : Type(type) + , Id(id) + , Addr(addr) + { + } + + EType GetType() const { + return Type; + } + ui64 GetId() const { + return Id; + } + TNetAddr GetAddr() const { + return Addr; + } + }; + + class TOnMessageContext : TNonCopyable { + private: + THolder<TBusMessage> Message; + TBusIdentity Ident; + // TODO: we don't need to store session, we have connection in ident + TBusServerSession* Session; + + public: + TOnMessageContext() + : Session() + { + } + TOnMessageContext(TAutoPtr<TBusMessage> message, TBusIdentity& ident, TBusServerSession* session) + : Message(message) + , Session(session) + { + Ident.Swap(ident); + } + + bool IsInWork() const { + return Ident.IsInWork(); + } + + bool operator!() const { + return !IsInWork(); + } + + TBusMessage* GetMessage() { + return Message.Get(); + } + + TBusMessage* ReleaseMessage() { + return Message.Release(); + } + + TBusServerSession* GetSession() { + return Session; + } + + template <typename U /* <: TBusMessage */> + EMessageStatus SendReplyAutoPtr(TAutoPtr<U>& rep); + + EMessageStatus SendReplyMove(TBusMessageAutoPtr response); + + void AckMessage(TBusIdentity& ident); + + void ForgetRequest(); + + void Swap(TOnMessageContext& that) { + DoSwap(Message, that.Message); + Ident.Swap(that.Ident); + DoSwap(Session, that.Session); + } + + TNetAddr GetPeerAddrNetAddr() const; + + bool IsConnectionAlive() const; + }; + + struct IBusServerHandler: public IBusErrorHandler { + virtual void OnMessage(TOnMessageContext& onMessage) = 0; + /// called when reply has been sent from destination + virtual void OnSent(TAutoPtr<TBusMessage> pMessage); + }; + + struct IBusClientHandler: public IBusErrorHandler { + /// called on source when reply arrives from destination + virtual void OnReply(TAutoPtr<TBusMessage> pMessage, TAutoPtr<TBusMessage> pReply) = 0; + /// called when client side message has gone into wire, place to call AckMessage() + virtual void OnMessageSent(TBusMessage* pMessage); + virtual void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage); + virtual void OnClientConnectionEvent(const TClientConnectionEvent&); + }; + +} diff --git a/library/cpp/messagebus/handler_impl.h b/library/cpp/messagebus/handler_impl.h new file mode 100644 index 0000000000..6593f04cc3 --- /dev/null +++ b/library/cpp/messagebus/handler_impl.h @@ -0,0 +1,23 @@ +#pragma once + +#include "handler.h" +#include "local_flags.h" +#include "session.h" + +namespace NBus { + template <typename U /* <: TBusMessage */> + EMessageStatus TOnMessageContext::SendReplyAutoPtr(TAutoPtr<U>& response) { + return Session->SendReplyAutoPtr(Ident, response); + } + + inline EMessageStatus TOnMessageContext::SendReplyMove(TBusMessageAutoPtr response) { + return SendReplyAutoPtr(response); + } + + inline void TOnMessageContext::AckMessage(TBusIdentity& ident) { + Y_VERIFY(Ident.LocalFlags == NPrivate::MESSAGE_IN_WORK); + Y_VERIFY(ident.LocalFlags == 0); + Ident.Swap(ident); + } + +} diff --git a/library/cpp/messagebus/hash.h b/library/cpp/messagebus/hash.h new file mode 100644 index 0000000000..cc1b136a86 --- /dev/null +++ b/library/cpp/messagebus/hash.h @@ -0,0 +1,19 @@ +#pragma once + +#include <util/str_stl.h> +#include <util/digest/numeric.h> + +namespace NBus { + namespace NPrivate { + template <typename T> + size_t Hash(const T& val) { + return THash<T>()(val); + } + + template <typename T, typename U> + size_t HashValues(const T& a, const U& b) { + return CombineHashes(Hash(a), Hash(b)); + } + + } +} diff --git a/library/cpp/messagebus/key_value_printer.cpp b/library/cpp/messagebus/key_value_printer.cpp new file mode 100644 index 0000000000..c8592145c7 --- /dev/null +++ b/library/cpp/messagebus/key_value_printer.cpp @@ -0,0 +1,46 @@ +#include "key_value_printer.h" + +#include <util/stream/format.h> + +TKeyValuePrinter::TKeyValuePrinter(const TString& sep) + : Sep(sep) +{ +} + +TKeyValuePrinter::~TKeyValuePrinter() { +} + +void TKeyValuePrinter::AddRowImpl(const TString& key, const TString& value, bool alignLeft) { + Keys.push_back(key); + Values.push_back(value); + AlignLefts.push_back(alignLeft); +} + +TString TKeyValuePrinter::PrintToString() const { + if (Keys.empty()) { + return TString(); + } + + size_t keyWidth = 0; + size_t valueWidth = 0; + + for (size_t i = 0; i < Keys.size(); ++i) { + keyWidth = Max(keyWidth, Keys.at(i).size()); + valueWidth = Max(valueWidth, Values.at(i).size()); + } + + TStringStream ss; + + for (size_t i = 0; i < Keys.size(); ++i) { + ss << RightPad(Keys.at(i), keyWidth); + ss << Sep; + if (AlignLefts.at(i)) { + ss << Values.at(i); + } else { + ss << LeftPad(Values.at(i), valueWidth); + } + ss << Endl; + } + + return ss.Str(); +} diff --git a/library/cpp/messagebus/key_value_printer.h b/library/cpp/messagebus/key_value_printer.h new file mode 100644 index 0000000000..bca1fde50e --- /dev/null +++ b/library/cpp/messagebus/key_value_printer.h @@ -0,0 +1,28 @@ +#pragma once + +#include <util/generic/string.h> +#include <util/generic/typetraits.h> +#include <util/generic/vector.h> +#include <util/string/cast.h> + +class TKeyValuePrinter { +private: + TString Sep; + TVector<TString> Keys; + TVector<TString> Values; + TVector<bool> AlignLefts; + +public: + TKeyValuePrinter(const TString& sep = TString(": ")); + ~TKeyValuePrinter(); + + template <typename TKey, typename TValue> + void AddRow(const TKey& key, const TValue& value, bool leftAlign = !std::is_integral<TValue>::value) { + return AddRowImpl(ToString(key), ToString(value), leftAlign); + } + + TString PrintToString() const; + +private: + void AddRowImpl(const TString& key, const TString& value, bool leftAlign); +}; diff --git a/library/cpp/messagebus/latch.h b/library/cpp/messagebus/latch.h new file mode 100644 index 0000000000..373f4c0e13 --- /dev/null +++ b/library/cpp/messagebus/latch.h @@ -0,0 +1,53 @@ +#pragma once + +#include <util/system/condvar.h> +#include <util/system/mutex.h> + +class TLatch { +private: + // 0 for unlocked, 1 for locked + TAtomic Locked; + TMutex Mutex; + TCondVar CondVar; + +public: + TLatch() + : Locked(0) + { + } + + void Wait() { + // optimistic path + if (AtomicGet(Locked) == 0) { + return; + } + + TGuard<TMutex> guard(Mutex); + while (AtomicGet(Locked) == 1) { + CondVar.WaitI(Mutex); + } + } + + bool TryWait() { + return AtomicGet(Locked) == 0; + } + + void Unlock() { + // optimistic path + if (AtomicGet(Locked) == 0) { + return; + } + + TGuard<TMutex> guard(Mutex); + AtomicSet(Locked, 0); + CondVar.BroadCast(); + } + + void Lock() { + AtomicSet(Locked, 1); + } + + bool IsLocked() { + return AtomicGet(Locked); + } +}; diff --git a/library/cpp/messagebus/latch_ut.cpp b/library/cpp/messagebus/latch_ut.cpp new file mode 100644 index 0000000000..bfab04f527 --- /dev/null +++ b/library/cpp/messagebus/latch_ut.cpp @@ -0,0 +1,20 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "latch.h" + +Y_UNIT_TEST_SUITE(TLatch) { + Y_UNIT_TEST(Simple) { + TLatch latch; + UNIT_ASSERT(latch.TryWait()); + latch.Lock(); + UNIT_ASSERT(!latch.TryWait()); + latch.Lock(); + latch.Lock(); + UNIT_ASSERT(!latch.TryWait()); + latch.Unlock(); + UNIT_ASSERT(latch.TryWait()); + latch.Unlock(); + latch.Unlock(); + UNIT_ASSERT(latch.TryWait()); + } +} diff --git a/library/cpp/messagebus/left_right_buffer.h b/library/cpp/messagebus/left_right_buffer.h new file mode 100644 index 0000000000..f937cefad0 --- /dev/null +++ b/library/cpp/messagebus/left_right_buffer.h @@ -0,0 +1,78 @@ +#pragma once + +#include <util/generic/buffer.h> +#include <util/generic/noncopyable.h> +#include <util/system/yassert.h> + +namespace NBus { + namespace NPrivate { + class TLeftRightBuffer : TNonCopyable { + private: + TBuffer Buffer; + size_t Left; + + void CheckInvariant() { + Y_ASSERT(Left <= Buffer.Size()); + } + + public: + TLeftRightBuffer() + : Left(0) + { + } + + TBuffer& GetBuffer() { + return Buffer; + } + + size_t Capacity() { + return Buffer.Capacity(); + } + + void Clear() { + Buffer.Clear(); + Left = 0; + } + + void Reset() { + Buffer.Reset(); + Left = 0; + } + + void Compact() { + Buffer.ChopHead(Left); + Left = 0; + } + + char* LeftPos() { + return Buffer.Data() + Left; + } + + size_t LeftSize() { + return Left; + } + + void LeftProceed(size_t count) { + Y_ASSERT(count <= Size()); + Left += count; + } + + size_t Size() { + return Buffer.Size() - Left; + } + + bool Empty() { + return Size() == 0; + } + + char* RightPos() { + return Buffer.Data() + Buffer.Size(); + } + + size_t Avail() { + return Buffer.Avail(); + } + }; + + } +} diff --git a/library/cpp/messagebus/lfqueue_batch.h b/library/cpp/messagebus/lfqueue_batch.h new file mode 100644 index 0000000000..8128d3154d --- /dev/null +++ b/library/cpp/messagebus/lfqueue_batch.h @@ -0,0 +1,36 @@ +#pragma once + +#include <library/cpp/messagebus/actor/temp_tls_vector.h> + +#include <util/generic/vector.h> +#include <util/thread/lfstack.h> + +template <typename T, template <typename, class> class TVectorType = TVector> +class TLockFreeQueueBatch { +private: + TLockFreeStack<TVectorType<T, std::allocator<T>>*> Stack; + +public: + bool IsEmpty() { + return Stack.IsEmpty(); + } + + void EnqueueAll(TAutoPtr<TVectorType<T, std::allocator<T>>> vec) { + Stack.Enqueue(vec.Release()); + } + + void DequeueAllSingleConsumer(TVectorType<T, std::allocator<T>>* r) { + TTempTlsVector<TVectorType<T, std::allocator<T>>*> vs; + Stack.DequeueAllSingleConsumer(vs.GetVector()); + + for (typename TVector<TVectorType<T, std::allocator<T>>*>::reverse_iterator i = vs.GetVector()->rbegin(); + i != vs.GetVector()->rend(); ++i) { + if (i == vs.GetVector()->rend()) { + r->swap(**i); + } else { + r->insert(r->end(), (*i)->begin(), (*i)->end()); + } + delete *i; + } + } +}; diff --git a/library/cpp/messagebus/lfqueue_batch_ut.cpp b/library/cpp/messagebus/lfqueue_batch_ut.cpp new file mode 100644 index 0000000000..f80434c0d4 --- /dev/null +++ b/library/cpp/messagebus/lfqueue_batch_ut.cpp @@ -0,0 +1,56 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "lfqueue_batch.h" + +Y_UNIT_TEST_SUITE(TLockFreeQueueBatch) { + Y_UNIT_TEST(Order1) { + TLockFreeQueueBatch<unsigned> q; + { + TAutoPtr<TVector<unsigned>> v(new TVector<unsigned>); + v->push_back(0); + v->push_back(1); + q.EnqueueAll(v); + } + + TVector<unsigned> r; + q.DequeueAllSingleConsumer(&r); + + UNIT_ASSERT_VALUES_EQUAL(2u, r.size()); + for (unsigned i = 0; i < 2; ++i) { + UNIT_ASSERT_VALUES_EQUAL(i, r[i]); + } + + r.clear(); + q.DequeueAllSingleConsumer(&r); + UNIT_ASSERT_VALUES_EQUAL(0u, r.size()); + } + + Y_UNIT_TEST(Order2) { + TLockFreeQueueBatch<unsigned> q; + { + TAutoPtr<TVector<unsigned>> v(new TVector<unsigned>); + v->push_back(0); + v->push_back(1); + q.EnqueueAll(v); + } + { + TAutoPtr<TVector<unsigned>> v(new TVector<unsigned>); + v->push_back(2); + v->push_back(3); + v->push_back(4); + q.EnqueueAll(v); + } + + TVector<unsigned> r; + q.DequeueAllSingleConsumer(&r); + + UNIT_ASSERT_VALUES_EQUAL(5u, r.size()); + for (unsigned i = 0; i < 5; ++i) { + UNIT_ASSERT_VALUES_EQUAL(i, r[i]); + } + + r.clear(); + q.DequeueAllSingleConsumer(&r); + UNIT_ASSERT_VALUES_EQUAL(0u, r.size()); + } +} diff --git a/library/cpp/messagebus/local_flags.cpp b/library/cpp/messagebus/local_flags.cpp new file mode 100644 index 0000000000..877e533f76 --- /dev/null +++ b/library/cpp/messagebus/local_flags.cpp @@ -0,0 +1,32 @@ +#include "local_flags.h" + +#include <util/stream/str.h> +#include <util/string/printf.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TString NBus::NPrivate::LocalFlagSetToString(ui32 flags0) { + if (flags0 == 0) { + return "0"; + } + + ui32 flags = flags0; + + TStringStream ss; +#define P(name, value, ...) \ + do \ + if (flags & value) { \ + if (!ss.Str().empty()) { \ + ss << "|"; \ + } \ + ss << #name; \ + flags &= ~name; \ + } \ + while (false); + MESSAGE_LOCAL_FLAGS_MAP(P) + if (flags != 0) { + return Sprintf("0x%x", unsigned(flags0)); + } + return ss.Str(); +} diff --git a/library/cpp/messagebus/local_flags.h b/library/cpp/messagebus/local_flags.h new file mode 100644 index 0000000000..f589283188 --- /dev/null +++ b/library/cpp/messagebus/local_flags.h @@ -0,0 +1,26 @@ +#pragma once + +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + +#include <util/generic/string.h> +#include <util/stream/output.h> + +namespace NBus { + namespace NPrivate { +#define MESSAGE_LOCAL_FLAGS_MAP(XX) \ + XX(MESSAGE_REPLY_INTERNAL, 0x0001) \ + XX(MESSAGE_IN_WORK, 0x0002) \ + XX(MESSAGE_IN_FLIGHT_ON_CLIENT, 0x0004) \ + XX(MESSAGE_REPLY_IS_BEGING_SENT, 0x0008) \ + XX(MESSAGE_ONE_WAY_INTERNAL, 0x0010) \ + /**/ + + enum EMessageLocalFlags { + MESSAGE_LOCAL_FLAGS_MAP(ENUM_VALUE_GEN) + }; + + ENUM_TO_STRING(EMessageLocalFlags, MESSAGE_LOCAL_FLAGS_MAP) + + TString LocalFlagSetToString(ui32); + } +} diff --git a/library/cpp/messagebus/local_flags_ut.cpp b/library/cpp/messagebus/local_flags_ut.cpp new file mode 100644 index 0000000000..189d73eb0f --- /dev/null +++ b/library/cpp/messagebus/local_flags_ut.cpp @@ -0,0 +1,18 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "local_flags.h" + +using namespace NBus; +using namespace NBus::NPrivate; + +Y_UNIT_TEST_SUITE(EMessageLocalFlags) { + Y_UNIT_TEST(TestLocalFlagSetToString) { + UNIT_ASSERT_VALUES_EQUAL("0", LocalFlagSetToString(0)); + UNIT_ASSERT_VALUES_EQUAL("MESSAGE_REPLY_INTERNAL", + LocalFlagSetToString(MESSAGE_REPLY_INTERNAL)); + UNIT_ASSERT_VALUES_EQUAL("MESSAGE_IN_WORK|MESSAGE_IN_FLIGHT_ON_CLIENT", + LocalFlagSetToString(MESSAGE_IN_WORK | MESSAGE_IN_FLIGHT_ON_CLIENT)); + UNIT_ASSERT_VALUES_EQUAL("0xff3456", + LocalFlagSetToString(0xff3456)); + } +} diff --git a/library/cpp/messagebus/local_tasks.h b/library/cpp/messagebus/local_tasks.h new file mode 100644 index 0000000000..d8e801a457 --- /dev/null +++ b/library/cpp/messagebus/local_tasks.h @@ -0,0 +1,23 @@ +#pragma once + +#include <util/system/atomic.h> + +class TLocalTasks { +private: + TAtomic GotTasks; + +public: + TLocalTasks() + : GotTasks(0) + { + } + + void AddTask() { + AtomicSet(GotTasks, 1); + } + + bool FetchTask() { + bool gotTasks = AtomicCas(&GotTasks, 0, 1); + return gotTasks; + } +}; diff --git a/library/cpp/messagebus/locator.cpp b/library/cpp/messagebus/locator.cpp new file mode 100644 index 0000000000..e38a35c426 --- /dev/null +++ b/library/cpp/messagebus/locator.cpp @@ -0,0 +1,427 @@ +//////////////////////////////////////////////////////////////////////////// +/// \file +/// \brief Implementation of locator service + +#include "locator.h" + +#include "ybus.h" + +#include <util/generic/hash_set.h> +#include <util/system/hostname.h> + +namespace NBus { + using namespace NAddr; + + static TIpPort GetAddrPort(const IRemoteAddr& addr) { + switch (addr.Addr()->sa_family) { + case AF_INET: { + return ntohs(((const sockaddr_in*)addr.Addr())->sin_port); + } + + case AF_INET6: { + return ntohs(((const sockaddr_in6*)addr.Addr())->sin6_port); + } + + default: { + ythrow yexception() << "not implemented"; + break; + } + } + } + + static inline bool GetIp6AddressFromVector(const TVector<TNetAddr>& addrs, TNetAddr* addr) { + for (size_t i = 1; i < addrs.size(); ++i) { + if (addrs[i - 1].Addr()->sa_family == addrs[i].Addr()->sa_family) { + return false; + } + + if (GetAddrPort(addrs[i - 1]) != GetAddrPort(addrs[i])) { + return false; + } + } + + for (size_t i = 0; i < addrs.size(); ++i) { + if (addrs[i].Addr()->sa_family == AF_INET6) { + *addr = addrs[i]; + return true; + } + } + + return false; + } + + EMessageStatus TBusProtocol::GetDestination(const TBusClientSession*, TBusMessage* mess, TBusLocator* locator, TNetAddr* addr) { + TBusService service = GetService(); + TBusKey key = GetKey(mess); + TVector<TNetAddr> addrs; + + /// check for special local key + if (key == YBUS_KEYLOCAL) { + locator->GetLocalAddresses(service, addrs); + } else { + /// lookup address/port in the locator table + locator->LocateAll(service, key, addrs); + } + + if (addrs.size() == 0) { + return MESSAGE_SERVICE_UNKNOWN; + } else if (addrs.size() == 1) { + *addr = addrs[0]; + } else { + if (!GetIp6AddressFromVector(addrs, addr)) { + /// default policy can't make choice for you here, overide GetDestination() function + /// to implement custom routing strategy for your service. + return MESSAGE_SERVICE_TOOMANY; + } + } + + return MESSAGE_OK; + } + + static const sockaddr_in* SockAddrIpV4(const IRemoteAddr& a) { + return (const sockaddr_in*)a.Addr(); + } + + static const sockaddr_in6* SockAddrIpV6(const IRemoteAddr& a) { + return (const sockaddr_in6*)a.Addr(); + } + + static bool IsAddressEqual(const IRemoteAddr& a1, const IRemoteAddr& a2) { + if (a1.Addr()->sa_family == a2.Addr()->sa_family) { + if (a1.Addr()->sa_family == AF_INET) { + return memcmp(&SockAddrIpV4(a1)->sin_addr, &SockAddrIpV4(a2)->sin_addr, sizeof(in_addr)) == 0; + } else { + return memcmp(&SockAddrIpV6(a1)->sin6_addr, &SockAddrIpV6(a2)->sin6_addr, sizeof(in6_addr)) == 0; + } + } + return false; + } + + TBusLocator::TBusLocator() + : MyInterfaces(GetNetworkInterfaces()) + { + } + + bool TBusLocator::TItem::operator<(const TItem& y) const { + const TItem& x = *this; + + if (x.ServiceId == y.ServiceId) { + return (x.End < y.End) || ((x.End == y.End) && CompareByHost(x.Addr, y.Addr) < 0); + } + return x.ServiceId < y.ServiceId; + } + + bool TBusLocator::TItem::operator==(const TItem& y) const { + return ServiceId == y.ServiceId && Start == y.Start && End == y.End && Addr == y.Addr; + } + + TBusLocator::TItem::TItem(TServiceId serviceId, TBusKey start, TBusKey end, const TNetAddr& addr) + : ServiceId(serviceId) + , Start(start) + , End(end) + , Addr(addr) + { + } + + bool TBusLocator::IsLocal(const TNetAddr& addr) { + for (const auto& myInterface : MyInterfaces) { + if (IsAddressEqual(addr, *myInterface.Address)) { + return true; + } + } + + return false; + } + + TBusLocator::TServiceId TBusLocator::GetServiceId(const char* name) { + const char* c = ServiceIdSet.insert(name).first->c_str(); + return (ui64)c; + } + + int TBusLocator::RegisterBreak(TBusService service, const TVector<TBusKey>& starts, const TNetAddr& addr) { + TGuard<TMutex> G(Lock); + + TServiceId serviceId = GetServiceId(service); + for (size_t i = 0; i < starts.size(); ++i) { + RegisterBreak(serviceId, starts[i], addr); + } + return 0; + } + + int TBusLocator::RegisterBreak(TServiceId serviceId, const TBusKey start, const TNetAddr& addr) { + TItems::const_iterator it = Items.lower_bound(TItem(serviceId, 0, start, addr)); + TItems::const_iterator service_it = + Items.lower_bound(TItem(serviceId, 0, 0, TNetAddr())); + + THolder<TItem> left; + THolder<TItem> right; + if ((it != Items.end() || Items.begin() != Items.end()) && service_it != Items.end() && service_it->ServiceId == serviceId) { + if (it == Items.end()) { + --it; + } + const TItem& item = *it; + left.Reset(new TItem(serviceId, item.Start, + Max<TBusKey>(item.Start, start - 1), item.Addr)); + right.Reset(new TItem(serviceId, start, item.End, addr)); + Items.erase(*it); + } else { + left.Reset(new TItem(serviceId, YBUS_KEYMIN, start, addr)); + if (start < YBUS_KEYMAX) { + right.Reset(new TItem(serviceId, start + 1, YBUS_KEYMAX, addr)); + } + } + Items.insert(*left); + Items.insert(*right); + NormalizeBreaks(serviceId); + return 0; + } + + int TBusLocator::UnregisterBreak(TBusService service, const TNetAddr& addr) { + TGuard<TMutex> G(Lock); + + TServiceId serviceId = GetServiceId(service); + return UnregisterBreak(serviceId, addr); + } + + int TBusLocator::UnregisterBreak(TServiceId serviceId, const TNetAddr& addr) { + int deleted = 0; + TItems::iterator it = Items.begin(); + while (it != Items.end()) { + const TItem& item = *it; + if (item.ServiceId != serviceId) { + ++it; + continue; + } + TItems::iterator itErase = it++; + if (item.ServiceId == serviceId && item.Addr == addr) { + Items.erase(itErase); + deleted += 1; + } + } + + if (Items.begin() == Items.end()) { + return deleted; + } + TBusKey keyItem = YBUS_KEYMAX; + it = Items.end(); + TItems::iterator first = it; + do { + --it; + // item.Start is not used in set comparison function + // so you can't violate set sort order by changing it + // hence const_cast() + TItem& item = const_cast<TItem&>(*it); + if (item.ServiceId != serviceId) { + continue; + } + first = it; + if (item.End < keyItem) { + item.End = keyItem; + } + keyItem = item.Start - 1; + } while (it != Items.begin()); + + if (first != Items.end() && first->Start != 0) { + TItem item(serviceId, YBUS_KEYMIN, first->Start - 1, first->Addr); + Items.insert(item); + } + + NormalizeBreaks(serviceId); + return deleted; + } + + void TBusLocator::NormalizeBreaks(TServiceId serviceId) { + TItems::const_iterator first = Items.lower_bound(TItem(serviceId, YBUS_KEYMIN, YBUS_KEYMIN, TNetAddr())); + TItems::const_iterator last = Items.end(); + + if ((Items.end() != first) && (first->ServiceId == serviceId)) { + if (serviceId != Max<TServiceId>()) { + last = Items.lower_bound(TItem(serviceId + 1, YBUS_KEYMIN, YBUS_KEYMIN, TNetAddr())); + } + + --last; + Y_ASSERT(Items.end() != last); + Y_ASSERT(last->ServiceId == serviceId); + + TItem& beg = const_cast<TItem&>(*first); + beg.Addr = last->Addr; + } + } + + int TBusLocator::LocateAll(TBusService service, TBusKey key, TVector<TNetAddr>& addrs) { + TGuard<TMutex> G(Lock); + Y_VERIFY(addrs.empty(), "Non emtpy addresses"); + + TServiceId serviceId = GetServiceId(service); + TItems::const_iterator it; + + for (it = Items.lower_bound(TItem(serviceId, 0, key, TNetAddr())); + it != Items.end() && it->ServiceId == serviceId && it->Start <= key && key <= it->End; + ++it) { + const TItem& item = *it; + addrs.push_back(item.Addr); + } + + if (addrs.size() == 0) { + return -1; + } + return (int)addrs.size(); + } + + int TBusLocator::Locate(TBusService service, TBusKey key, TNetAddr* addr) { + TGuard<TMutex> G(Lock); + + TServiceId serviceId = GetServiceId(service); + TItems::const_iterator it; + + it = Items.lower_bound(TItem(serviceId, 0, key, TNetAddr())); + + if (it != Items.end()) { + const TItem& item = *it; + if (item.ServiceId == serviceId && item.Start <= key && key < item.End) { + *addr = item.Addr; + + return 0; + } + } + + return -1; + } + + int TBusLocator::GetLocalPort(TBusService service) { + TGuard<TMutex> G(Lock); + TServiceId serviceId = GetServiceId(service); + TItems::const_iterator it; + int port = 0; + + for (it = Items.lower_bound(TItem(serviceId, 0, 0, TNetAddr())); it != Items.end(); ++it) { + const TItem& item = *it; + if (item.ServiceId != serviceId) { + break; + } + + if (IsLocal(item.Addr)) { + if (port != 0 && port != GetAddrPort(item.Addr)) { + Y_ASSERT(0 && "Can't decide which port to use."); + return 0; + } + port = GetAddrPort(item.Addr); + } + } + + return port; + } + + int TBusLocator::GetLocalAddresses(TBusService service, TVector<TNetAddr>& addrs) { + TGuard<TMutex> G(Lock); + TServiceId serviceId = GetServiceId(service); + TItems::const_iterator it; + + for (it = Items.lower_bound(TItem(serviceId, 0, 0, TNetAddr())); it != Items.end(); ++it) { + const TItem& item = *it; + if (item.ServiceId != serviceId) { + break; + } + + if (IsLocal(item.Addr)) { + addrs.push_back(item.Addr); + } + } + + if (addrs.size() == 0) { + return -1; + } + + return (int)addrs.size(); + } + + int TBusLocator::LocateHost(TBusService service, TBusKey key, TString* host, int* port, bool* isLocal) { + int ret; + TNetAddr addr; + ret = Locate(service, key, &addr); + if (ret != 0) { + return ret; + } + + { + TGuard<TMutex> G(Lock); + THostAddrMap::const_iterator it = HostAddrMap.find(addr); + if (it == HostAddrMap.end()) { + return -1; + } + *host = it->second; + } + + *port = GetAddrPort(addr); + if (isLocal != nullptr) { + *isLocal = IsLocal(addr); + } + return 0; + } + + int TBusLocator::LocateKeys(TBusService service, TBusKeyVec& keys, bool onlyLocal) { + TGuard<TMutex> G(Lock); + Y_VERIFY(keys.empty(), "Non empty keys"); + + TServiceId serviceId = GetServiceId(service); + TItems::const_iterator it; + for (it = Items.begin(); it != Items.end(); ++it) { + const TItem& item = *it; + if (item.ServiceId != serviceId) { + continue; + } + if (onlyLocal && !IsLocal(item.Addr)) { + continue; + } + keys.push_back(std::make_pair(item.Start, item.End)); + } + return (int)keys.size(); + } + + int TBusLocator::Register(TBusService service, const char* hostName, int port, TBusKey start /*= YBUS_KEYMIN*/, TBusKey end /*= YBUS_KEYMAX*/, EIpVersion requireVersion /*= EIP_VERSION_4*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) { + TNetAddr addr(hostName, port, requireVersion, preferVersion); // throws + { + TGuard<TMutex> G(Lock); + HostAddrMap[addr] = hostName; + } + Register(service, start, end, addr); + return 0; + } + + int TBusLocator::Register(TBusService service, TBusKey start, TBusKey end, const TNetworkAddress& na, EIpVersion requireVersion /*= EIP_VERSION_4*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) { + TNetAddr addr(na, requireVersion, preferVersion); // throws + Register(service, start, end, addr); + return 0; + } + + int TBusLocator::Register(TBusService service, TBusKey start, TBusKey end, const TNetAddr& addr) { + TGuard<TMutex> G(Lock); + + TServiceId serviceId = GetServiceId(service); + TItems::const_iterator it; + + TItem itemToReg(serviceId, start, end, addr); + for (it = Items.lower_bound(TItem(serviceId, 0, start, TNetAddr())); + it != Items.end() && it->ServiceId == serviceId; + ++it) { + const TItem& item = *it; + if (item == itemToReg) { + return 0; + } + if ((item.Start < start && start < item.End) || (item.Start < end && end < item.End)) { + Y_FAIL("Overlap in registered keys with non-identical range"); + } + } + + Items.insert(itemToReg); + return 0; + } + + int TBusLocator::Unregister(TBusService service, TBusKey start, TBusKey end) { + TGuard<TMutex> G(Lock); + TServiceId serviceId = GetServiceId(service); + Items.erase(TItem(serviceId, start, end, TNetAddr())); + return 0; + } + +} diff --git a/library/cpp/messagebus/locator.h b/library/cpp/messagebus/locator.h new file mode 100644 index 0000000000..f8556a3fce --- /dev/null +++ b/library/cpp/messagebus/locator.h @@ -0,0 +1,93 @@ +#pragma once + +#include "defs.h" + +#include <util/generic/hash.h> +#include <util/generic/map.h> +#include <util/generic/set.h> +#include <util/generic/string.h> +#include <util/network/interface.h> +#include <util/system/mutex.h> + +namespace NBus { + /////////////////////////////////////////////// + /// \brief Client interface to locator service + + /// This interface abstracts clustering/location service that + /// allows clients find servers (address, port) using "name" and "key". + /// The instance lives in TBusMessageQueue-object, but can be shared by different queues. + class TBusLocator: public TAtomicRefCount<TBusLocator>, public TNonCopyable { + private: + typedef ui64 TServiceId; + typedef TSet<TString> TServiceIdSet; + TServiceIdSet ServiceIdSet; + TServiceId GetServiceId(const char* name); + + typedef TMap<TNetAddr, TString> THostAddrMap; + THostAddrMap HostAddrMap; + + TNetworkInterfaceList MyInterfaces; + + struct TItem { + TServiceId ServiceId; + TBusKey Start; + TBusKey End; + TNetAddr Addr; + + bool operator<(const TItem& y) const; + + bool operator==(const TItem& y) const; + + TItem(TServiceId serviceId, TBusKey start, TBusKey end, const TNetAddr& addr); + }; + + typedef TMultiSet<TItem> TItems; + TItems Items; + TMutex Lock; + + int RegisterBreak(TServiceId serviceId, const TBusKey start, const TNetAddr& addr); + int UnregisterBreak(TServiceId serviceId, const TNetAddr& addr); + + void NormalizeBreaks(TServiceId serviceId); + + private: + int Register(TBusService service, TBusKey start, TBusKey end, const TNetAddr& addr); + + public: + /// creates instance that obtains location table from locator server (not implemented) + TBusLocator(); + + /// returns true if this address is on the same node for YBUS_KEYLOCAL + bool IsLocal(const TNetAddr& addr); + + /// returns first address for service and key + int Locate(TBusService service, TBusKey key, TNetAddr* addr); + + /// returns all addresses mathing service and key + int LocateAll(TBusService service, TBusKey key, TVector<TNetAddr>& addrs); + + /// returns actual host name for service and key + int LocateHost(TBusService service, TBusKey key, TString* host, int* port, bool* isLocal = nullptr); + + /// returns all key ranges for the given service + int LocateKeys(TBusService service, TBusKeyVec& keys, bool onlyLocal = false); + + /// returns port on the local node for the service + int GetLocalPort(TBusService service); + + /// returns addresses of the local node for the service + int GetLocalAddresses(TBusService service, TVector<TNetAddr>& addrs); + + /// register service instance + int Register(TBusService service, TBusKey start, TBusKey end, const TNetworkAddress& addr, EIpVersion requireVersion = EIP_VERSION_4, EIpVersion preferVersion = EIP_VERSION_ANY); + /// @throws yexception + int Register(TBusService service, const char* host, int port, TBusKey start = YBUS_KEYMIN, TBusKey end = YBUS_KEYMAX, EIpVersion requireVersion = EIP_VERSION_4, EIpVersion preferVersion = EIP_VERSION_ANY); + + /// unregister service instance + int Unregister(TBusService service, TBusKey start, TBusKey end); + + int RegisterBreak(TBusService service, const TVector<TBusKey>& starts, const TNetAddr& addr); + int UnregisterBreak(TBusService service, const TNetAddr& addr); + }; + +} diff --git a/library/cpp/messagebus/mb_lwtrace.cpp b/library/cpp/messagebus/mb_lwtrace.cpp new file mode 100644 index 0000000000..c54cd5ab71 --- /dev/null +++ b/library/cpp/messagebus/mb_lwtrace.cpp @@ -0,0 +1,12 @@ +#include "mb_lwtrace.h" + +#include <library/cpp/lwtrace/all.h> + +#include <util/generic/singleton.h> + +LWTRACE_DEFINE_PROVIDER(LWTRACE_MESSAGEBUS_PROVIDER) + +void NBus::InitBusLwtrace() { + // Function is nop, and needed only to make sure TBusLwtraceInit loaded. + // It won't be necessary when pg@ implements GLOBAL in arc. +} diff --git a/library/cpp/messagebus/mb_lwtrace.h b/library/cpp/messagebus/mb_lwtrace.h new file mode 100644 index 0000000000..e62728b265 --- /dev/null +++ b/library/cpp/messagebus/mb_lwtrace.h @@ -0,0 +1,19 @@ +#pragma once + +#include <library/cpp/lwtrace/all.h> + +#include <util/generic/string.h> + +#define LWTRACE_MESSAGEBUS_PROVIDER(PROBE, EVENT, GROUPS, TYPES, NAMES) \ + PROBE(Error, GROUPS("MessagebusRare"), TYPES(TString, TString, TString), NAMES("status", "address", "misc")) \ + PROBE(ServerUnknownVersion, GROUPS("MessagebusRare"), TYPES(TString, ui32), NAMES("address", "version")) \ + PROBE(Accepted, GROUPS("MessagebusRare"), TYPES(TString), NAMES("address")) \ + PROBE(Disconnected, GROUPS("MessagebusRare"), TYPES(TString), NAMES("address")) \ + PROBE(Read, GROUPS(), TYPES(ui32), NAMES("size")) \ + /**/ + +LWTRACE_DECLARE_PROVIDER(LWTRACE_MESSAGEBUS_PROVIDER) + +namespace NBus { + void InitBusLwtrace(); +} diff --git a/library/cpp/messagebus/memory.h b/library/cpp/messagebus/memory.h new file mode 100644 index 0000000000..b2c0544491 --- /dev/null +++ b/library/cpp/messagebus/memory.h @@ -0,0 +1,42 @@ +#pragma once + +#ifndef CACHE_LINE_SIZE +#define CACHE_LINE_SIZE 64 +#endif + +#define CONCAT(a, b) a##b +#define LABEL(a) CONCAT(UniqueName_, a) +#define UNIQUE_NAME LABEL(__LINE__) + +#define CACHE_LINE_PADDING char UNIQUE_NAME[CACHE_LINE_SIZE]; + +static inline void* MallocAligned(size_t size, size_t alignment) { + void** ptr = (void**)malloc(size + alignment + sizeof(size_t*)); + if (!ptr) { + return nullptr; + } + + size_t mask = ~(alignment - 1); + intptr_t roundedDown = intptr_t(ptr) & mask; + void** alignedPtr = (void**)(roundedDown + alignment); + alignedPtr[-1] = ptr; + return alignedPtr; +} + +static inline void FreeAligned(void* ptr) { + if (!ptr) { + return; + } + + void** typedPtr = (void**)ptr; + void* originalPtr = typedPtr[-1]; + free(originalPtr); +} + +static inline void* MallocCacheAligned(size_t size) { + return MallocAligned(size, CACHE_LINE_SIZE); +} + +static inline void FreeCacheAligned(void* ptr) { + return FreeAligned(ptr); +} diff --git a/library/cpp/messagebus/memory_ut.cpp b/library/cpp/messagebus/memory_ut.cpp new file mode 100644 index 0000000000..00654f28a1 --- /dev/null +++ b/library/cpp/messagebus/memory_ut.cpp @@ -0,0 +1,13 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "memory.h" + +Y_UNIT_TEST_SUITE(MallocAligned) { + Y_UNIT_TEST(Test) { + for (size_t size = 0; size < 1000; ++size) { + void* ptr = MallocAligned(size, 128); + UNIT_ASSERT(uintptr_t(ptr) % 128 == 0); + FreeAligned(ptr); + } + } +} diff --git a/library/cpp/messagebus/message.cpp b/library/cpp/messagebus/message.cpp new file mode 100644 index 0000000000..bfa7ed8e9b --- /dev/null +++ b/library/cpp/messagebus/message.cpp @@ -0,0 +1,198 @@ +#include "remote_server_connection.h" +#include "ybus.h" + +#include <util/random/random.h> +#include <util/string/printf.h> +#include <util/system/atomic.h> + +#include <string.h> + +using namespace NBus; + +namespace NBus { + using namespace NBus::NPrivate; + + TBusIdentity::TBusIdentity() + : MessageId(0) + , Size(0) + , Flags(0) + , LocalFlags(0) + { + } + + TBusIdentity::~TBusIdentity() { + // TODO: print local flags +#ifndef NDEBUG + Y_VERIFY(LocalFlags == 0, "local flags must be zero at this point; message type is %s", + MessageType.value_or("unknown").c_str()); +#else + Y_VERIFY(LocalFlags == 0, "local flags must be zero at this point"); +#endif + } + + TNetAddr TBusIdentity::GetNetAddr() const { + if (!!Connection) { + return Connection->GetAddr(); + } else { + Y_FAIL(); + } + } + + void TBusIdentity::Pack(char* dest) { + memcpy(dest, this, sizeof(TBusIdentity)); + LocalFlags = 0; + + // prevent decref + new (&Connection) TIntrusivePtr<TRemoteServerConnection>; + } + + void TBusIdentity::Unpack(const char* src) { + Y_VERIFY(LocalFlags == 0); + Y_VERIFY(!Connection); + + memcpy(this, src, sizeof(TBusIdentity)); + } + + void TBusHeader::GenerateId() { + for (;;) { + Id = RandomNumber<TBusKey>(); + // Skip reserved ids + if (IsBusKeyValid(Id)) + return; + } + } + + TBusMessage::TBusMessage(ui16 type, int approxsize) + //: TCtr("BusMessage") + : TRefCounted<TBusMessage, TAtomicCounter, TDelete>(1) + , LocalFlags(0) + , RequestSize(0) + , Data(nullptr) + { + Y_UNUSED(approxsize); + GetHeader()->Type = type; + DoReset(); + } + + TBusMessage::TBusMessage(ECreateUninitialized) + //: TCtr("BusMessage") + : TRefCounted<TBusMessage, TAtomicCounter, TDelete>(1) + , LocalFlags(0) + , Data(nullptr) + { + } + + TString TBusMessage::Describe() const { + return Sprintf("object type: %s, message type: %d", TypeName(*this).data(), int(GetHeader()->Type)); + } + + TBusMessage::~TBusMessage() { +#ifndef NDEBUG + Y_VERIFY(GetHeader()->Id != YBUS_KEYINVALID, "must not be invalid key, message type: %d, ", int(Type)); + GetHeader()->Id = YBUS_KEYINVALID; + Data = (void*)17; + CheckClean(); +#endif + } + + void TBusMessage::DoReset() { + GetHeader()->SendTime = 0; + GetHeader()->Size = 0; + GetHeader()->FlagsInternal = 0; + GetHeader()->GenerateId(); + GetHeader()->SetVersionInternal(); + } + + void TBusMessage::Reset() { + CheckClean(); + DoReset(); + } + + void TBusMessage::CheckClean() const { + if (Y_UNLIKELY(LocalFlags != 0)) { + TString describe = Describe(); + TString localFlags = LocalFlagSetToString(LocalFlags); + Y_FAIL("message local flags must be zero, got: %s, message: %s", localFlags.data(), describe.data()); + } + } + + /////////////////////////////////////////////////////// + /// \brief Unpacks header from network order + + /// \todo ntoh instead of memcpy + int TBusHeader::ReadHeader(TArrayRef<const char> data) { + Y_ASSERT(data.size() >= sizeof(TBusHeader)); + memcpy(this, data.data(), sizeof(TBusHeader)); + return sizeof(TBusHeader); + } + + /////////////////////////////////////////////////////// + /// \brief Packs header to network order + + ////////////////////////////////////////////////////////// + /// \brief serialize message identity to be used to construct reply message + + /// function stores messageid, flags and connection reply address into the buffer + /// that can later be used to construct a reply to the message + void TBusMessage::GetIdentity(TBusIdentity& data) const { + data.MessageId = GetHeader()->Id; + data.Size = GetHeader()->Size; + data.Flags = GetHeader()->FlagsInternal; + //data.LocalFlags = LocalFlags; + } + + //////////////////////////////////////////////////////////// + /// \brief set message identity from serialized form + + /// function restores messageid, flags and connection reply address from the buffer + /// into the reply message + void TBusMessage::SetIdentity(const TBusIdentity& data) { + // TODO: wrong assertion: YBUS_KEYMIN is 0 + Y_ASSERT(data.MessageId != 0); + bool compressed = IsCompressed(); + GetHeader()->Id = data.MessageId; + GetHeader()->FlagsInternal = data.Flags; + LocalFlags = data.LocalFlags & ~MESSAGE_IN_WORK; + ReplyTo = data.Connection->PeerAddrSocketAddr; + SetCompressed(compressed || IsCompressedResponse()); + } + + void TBusMessage::SetCompressed(bool v) { + if (v) { + GetHeader()->FlagsInternal |= MESSAGE_COMPRESS_INTERNAL; + } else { + GetHeader()->FlagsInternal &= ~(MESSAGE_COMPRESS_INTERNAL); + } + } + + void TBusMessage::SetCompressedResponse(bool v) { + if (v) { + GetHeader()->FlagsInternal |= MESSAGE_COMPRESS_RESPONSE; + } else { + GetHeader()->FlagsInternal &= ~(MESSAGE_COMPRESS_RESPONSE); + } + } + + TString TBusIdentity::ToString() const { + TStringStream ss; + ss << "msg-id=" << MessageId + << " size=" << Size; + if (!!Connection) { + ss << " conn=" << Connection->GetAddr(); + } + ss + << " flags=" << Flags + << " local-flags=" << LocalFlags +#ifndef NDEBUG + << " msg-type= " << MessageType.value_or("unknown").c_str() +#endif + ; + return ss.Str(); + } + +} + +template <> +void Out<TBusIdentity>(IOutputStream& os, TTypeTraits<TBusIdentity>::TFuncParam ident) { + os << ident.ToString(); +} diff --git a/library/cpp/messagebus/message.h b/library/cpp/messagebus/message.h new file mode 100644 index 0000000000..005ca10c65 --- /dev/null +++ b/library/cpp/messagebus/message.h @@ -0,0 +1,272 @@ +#pragma once + +#include "base.h" +#include "local_flags.h" +#include "message_status.h" +#include "netaddr.h" +#include "socket_addr.h" + +#include <util/generic/array_ref.h> +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/system/defaults.h> +#include <util/system/type_name.h> +#include <util/system/yassert.h> + +#include <optional> +#include <typeinfo> + +namespace NBus { + /////////////////////////////////////////////////////////////////// + /// \brief Structure to preserve identity from message to reply + struct TBusIdentity : TNonCopyable { + friend class TBusMessage; + friend class NPrivate::TRemoteServerSession; + friend struct NPrivate::TClientRequestImpl; + friend class TOnMessageContext; + + // TODO: make private + TBusKey MessageId; + + private: + ui32 Size; + TIntrusivePtr<NPrivate::TRemoteServerConnection> Connection; + ui16 Flags; + ui32 LocalFlags; + TInstant RecvTime; + +#ifndef NDEBUG + std::optional<TString> MessageType; +#endif + + private: + // TODO: drop + TNetAddr GetNetAddr() const; + + public: + void Pack(char* dest); + void Unpack(const char* src); + + bool IsInWork() const { + return LocalFlags & NPrivate::MESSAGE_IN_WORK; + } + + // for internal use only + void BeginWork() { + SetInWork(true); + } + + // for internal use only + void EndWork() { + SetInWork(false); + } + + TBusIdentity(); + ~TBusIdentity(); + + void Swap(TBusIdentity& that) { + DoSwap(MessageId, that.MessageId); + DoSwap(Size, that.Size); + DoSwap(Connection, that.Connection); + DoSwap(Flags, that.Flags); + DoSwap(LocalFlags, that.LocalFlags); + DoSwap(RecvTime, that.RecvTime); +#ifndef NDEBUG + DoSwap(MessageType, that.MessageType); +#endif + } + + TString ToString() const; + + private: + void SetInWork(bool inWork) { + if (LocalFlags == 0 && inWork) { + LocalFlags = NPrivate::MESSAGE_IN_WORK; + } else if (LocalFlags == NPrivate::MESSAGE_IN_WORK && !inWork) { + LocalFlags = 0; + } else { + Y_FAIL("impossible combination of flag and parameter: %s %d", + inWork ? "true" : "false", unsigned(LocalFlags)); + } + } + + void SetMessageType(const std::type_info& messageTypeInfo) { +#ifndef NDEBUG + Y_VERIFY(!MessageType, "state check"); + MessageType = TypeName(messageTypeInfo); +#else + Y_UNUSED(messageTypeInfo); +#endif + } + }; + + static const size_t BUS_IDENTITY_PACKED_SIZE = sizeof(TBusIdentity); + + /////////////////////////////////////////////////////////////// + /// \brief Message flags in TBusHeader.Flags + enum EMessageFlags { + MESSAGE_COMPRESS_INTERNAL = 0x8000, ///< message is compressed + MESSAGE_COMPRESS_RESPONSE = 0x4000, ///< message prefers compressed response + MESSAGE_VERSION_INTERNAL = 0x00F0, ///< these bits are used as version + }; + +////////////////////////////////////////////////////////// +/// \brief Message header present in all message send and received + +/// This header is send into the wire. +/// \todo fix for low/high end, 32/64bit some day +#pragma pack(1) + struct TBusHeader { + friend class TBusMessage; + + TBusKey Id = 0; ///< unique message ID + ui32 Size = 0; ///< total size of the message + TBusInstant SendTime = 0; ///< time the message was sent + ui16 FlagsInternal = 0; ///< TRACE is one of the flags + ui16 Type = 0; ///< to be used by TBusProtocol + + int GetVersionInternal() { + return (FlagsInternal & MESSAGE_VERSION_INTERNAL) >> 4; + } + void SetVersionInternal(unsigned ver = YBUS_VERSION) { + FlagsInternal |= (ver << 4); + } + + public: + TBusHeader() { + } + TBusHeader(TArrayRef<const char> data) { + ReadHeader(data); + } + + private: + /// function for serialization/deserialization of the header + /// returns number of bytes written/read + int ReadHeader(TArrayRef<const char> data); + + void GenerateId(); + }; +#pragma pack() + +#define TBUSMAX_MESSAGE 26 * 1024 * 1024 + sizeof(NBus::TBusHeader) ///< is't it enough? +#define TBUSMIN_MESSAGE sizeof(NBus::TBusHeader) ///< can't be less then header + + inline bool IsVersionNegotiation(const NBus::TBusHeader& header) { + return header.Id == 0 && header.Size == sizeof(TBusHeader); + } + + ////////////////////////////////////////////////////////// + /// \brief Base class for all messages passed in the system + + enum ECreateUninitialized { + MESSAGE_CREATE_UNINITIALIZED, + }; + + class TBusMessage + : protected TBusHeader, + public TRefCounted<TBusMessage, TAtomicCounter, TDelete>, + private TNonCopyable { + friend class TLocalSession; + friend struct ::NBus::NPrivate::TBusSessionImpl; + friend class ::NBus::NPrivate::TRemoteServerSession; + friend class ::NBus::NPrivate::TRemoteClientSession; + friend class ::NBus::NPrivate::TRemoteConnection; + friend class ::NBus::NPrivate::TRemoteClientConnection; + friend class ::NBus::NPrivate::TRemoteServerConnection; + friend struct ::NBus::NPrivate::TBusMessagePtrAndHeader; + + private: + ui32 LocalFlags; + + /// connection identity for reply set by PushMessage() + NPrivate::TBusSocketAddr ReplyTo; + // server-side response only, hack + ui32 RequestSize; + + TInstant RecvTime; + + public: + /// constructor to create messages on sending end + TBusMessage(ui16 type, int approxsize = sizeof(TBusHeader)); + + /// constructor with serialzed data to examine the header + TBusMessage(ECreateUninitialized); + + // slow, for diagnostics only + virtual TString Describe() const; + + // must be called if this message object needs to be reused + void Reset(); + + void CheckClean() const; + + void SetCompressed(bool); + void SetCompressedResponse(bool); + + private: + bool IsCompressed() const { + return FlagsInternal & MESSAGE_COMPRESS_INTERNAL; + } + bool IsCompressedResponse() const { + return FlagsInternal & MESSAGE_COMPRESS_RESPONSE; + } + + public: + /// can have private data to destroy + virtual ~TBusMessage(); + + /// returns header of the message + TBusHeader* GetHeader() { + return this; + } + const TBusHeader* GetHeader() const { + return this; + } + + /// helper to return type for protocol object to unpack object + static ui16 GetType(TArrayRef<const char> data) { + return TBusHeader(data).Type; + } + + /// returns payload data + static TArrayRef<const char> GetPayload(TArrayRef<const char> data) { + return data.Slice(sizeof(TBusHeader)); + } + + private: + void DoReset(); + + /// serialize message identity to be used to construct reply message + void GetIdentity(TBusIdentity& ident) const; + + /// set message identity from serialized form + void SetIdentity(const TBusIdentity& ident); + + public: + TNetAddr GetReplyTo() const { + return ReplyTo.ToNetAddr(); + } + + /// store of application specific data, never serialized into wire + void* Data; + }; + + class TBusMessageAutoPtr: public TAutoPtr<TBusMessage> { + public: + TBusMessageAutoPtr() { + } + + TBusMessageAutoPtr(TBusMessage* message) + : TAutoPtr<TBusMessage>(message) + { + } + + template <typename T1> + TBusMessageAutoPtr(const TAutoPtr<T1>& that) + : TAutoPtr<TBusMessage>(that.Release()) + { + } + }; + +} diff --git a/library/cpp/messagebus/message_counter.cpp b/library/cpp/messagebus/message_counter.cpp new file mode 100644 index 0000000000..04d9343f6a --- /dev/null +++ b/library/cpp/messagebus/message_counter.cpp @@ -0,0 +1,46 @@ +#include "message_counter.h" + +#include <util/stream/str.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TMessageCounter::TMessageCounter() + : BytesData(0) + , BytesNetwork(0) + , Count(0) + , CountCompressed(0) + , CountCompressionRequests(0) +{ +} + +TMessageCounter& TMessageCounter::operator+=(const TMessageCounter& that) { + BytesData += that.BytesData; + BytesNetwork += that.BytesNetwork; + Count += that.Count; + CountCompressed += that.CountCompressed; + CountCompressionRequests += that.CountCompressionRequests; + return *this; +} + +TString TMessageCounter::ToString(bool reader) const { + if (reader) { + Y_ASSERT(CountCompressionRequests == 0); + } + + TStringStream readValue; + readValue << Count; + if (CountCompressionRequests != 0 || CountCompressed != 0) { + readValue << " (" << CountCompressed << " compr"; + if (!reader) { + readValue << ", " << CountCompressionRequests << " compr reqs"; + } + readValue << ")"; + } + readValue << ", "; + readValue << BytesData << "b"; + if (BytesNetwork != BytesData) { + readValue << " (" << BytesNetwork << "b network)"; + } + return readValue.Str(); +} diff --git a/library/cpp/messagebus/message_counter.h b/library/cpp/messagebus/message_counter.h new file mode 100644 index 0000000000..e4be1180b0 --- /dev/null +++ b/library/cpp/messagebus/message_counter.h @@ -0,0 +1,36 @@ +#pragma once + +#include <util/generic/string.h> + +#include <cstddef> + +namespace NBus { + namespace NPrivate { + struct TMessageCounter { + size_t BytesData; + size_t BytesNetwork; + size_t Count; + size_t CountCompressed; + size_t CountCompressionRequests; // reader only + + void AddMessage(size_t bytesData, size_t bytesCompressed, bool Compressed, bool compressionRequested) { + BytesData += bytesData; + BytesNetwork += bytesCompressed; + Count += 1; + if (Compressed) { + CountCompressed += 1; + } + if (compressionRequested) { + CountCompressionRequests += 1; + } + } + + TMessageCounter& operator+=(const TMessageCounter& that); + + TString ToString(bool reader) const; + + TMessageCounter(); + }; + + } +} diff --git a/library/cpp/messagebus/message_ptr_and_header.h b/library/cpp/messagebus/message_ptr_and_header.h new file mode 100644 index 0000000000..9b4e2fd270 --- /dev/null +++ b/library/cpp/messagebus/message_ptr_and_header.h @@ -0,0 +1,36 @@ +#pragma once + +#include "message.h" +#include "nondestroying_holder.h" + +#include <util/generic/noncopyable.h> +#include <util/generic/utility.h> + +namespace NBus { + namespace NPrivate { + struct TBusMessagePtrAndHeader : TNonCopyable { + TNonDestroyingHolder<TBusMessage> MessagePtr; + TBusHeader Header; + ui32 LocalFlags; + + TBusMessagePtrAndHeader() + : LocalFlags() + { + } + + explicit TBusMessagePtrAndHeader(TBusMessage* messagePtr) + : MessagePtr(messagePtr) + , Header(*MessagePtr->GetHeader()) + , LocalFlags(MessagePtr->LocalFlags) + { + } + + void Swap(TBusMessagePtrAndHeader& that) { + DoSwap(MessagePtr, that.MessagePtr); + DoSwap(Header, that.Header); + DoSwap(LocalFlags, that.LocalFlags); + } + }; + + } +} diff --git a/library/cpp/messagebus/message_status.cpp b/library/cpp/messagebus/message_status.cpp new file mode 100644 index 0000000000..41ad62b73f --- /dev/null +++ b/library/cpp/messagebus/message_status.cpp @@ -0,0 +1,13 @@ +#include "message_status.h" + +using namespace NBus; + +const char* NBus::MessageStatusDescription(EMessageStatus messageStatus) { +#define MESSAGE_STATUS_DESCRIPTION_GEN(name, description, ...) \ + if (messageStatus == name) \ + return description; + + MESSAGE_STATUS_MAP(MESSAGE_STATUS_DESCRIPTION_GEN) + + return "Unknown"; +} diff --git a/library/cpp/messagebus/message_status.h b/library/cpp/messagebus/message_status.h new file mode 100644 index 0000000000..e1878960b3 --- /dev/null +++ b/library/cpp/messagebus/message_status.h @@ -0,0 +1,57 @@ +#pragma once + +#include "codegen.h" +#include "defs.h" + +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + +namespace NBus { +//////////////////////////////////////////////////////////////// +/// \brief Status of message communication + +#define MESSAGE_STATUS_MAP(XX) \ + XX(MESSAGE_OK, "OK") \ + XX(MESSAGE_CONNECT_FAILED, "Connect failed") \ + XX(MESSAGE_TIMEOUT, "Message timed out") \ + XX(MESSAGE_SERVICE_UNKNOWN, "Locator hasn't found address for key") \ + XX(MESSAGE_BUSY, "Too many messages in flight") \ + XX(MESSAGE_UNKNOWN, "Request not found by id, usually it means that message is timed out") \ + XX(MESSAGE_DESERIALIZE_ERROR, "Deserialize by TBusProtocol failed") \ + XX(MESSAGE_HEADER_CORRUPTED, "Header corrupted") \ + XX(MESSAGE_DECOMPRESS_ERROR, "Failed to decompress") \ + XX(MESSAGE_MESSAGE_TOO_LARGE, "Message too large") \ + XX(MESSAGE_REPLY_FAILED, "Unused by messagebus, used by other code") \ + XX(MESSAGE_DELIVERY_FAILED, "Message delivery failed because connection is closed") \ + XX(MESSAGE_INVALID_VERSION, "Protocol error: invalid version") \ + XX(MESSAGE_SERVICE_TOOMANY, "Locator failed to resolve address") \ + XX(MESSAGE_SHUTDOWN, "Failure because of either session or connection shutdown") \ + XX(MESSAGE_DONT_ASK, "Internal error code used by modules") + + enum EMessageStatus { + MESSAGE_STATUS_MAP(ENUM_VALUE_GEN_NO_VALUE) + MESSAGE_STATUS_COUNT + }; + + ENUM_TO_STRING(EMessageStatus, MESSAGE_STATUS_MAP) + + const char* MessageStatusDescription(EMessageStatus); + + static inline const char* GetMessageStatus(EMessageStatus status) { + return ToCString(status); + } + + // For lwtrace + struct TMessageStatusField { + typedef int TStoreType; + typedef int TFuncParam; + + static void ToString(int value, TString* out) { + *out = GetMessageStatus((NBus::EMessageStatus)value); + } + + static int ToStoreType(int value) { + return value; + } + }; + +} // ns diff --git a/library/cpp/messagebus/message_status_counter.cpp b/library/cpp/messagebus/message_status_counter.cpp new file mode 100644 index 0000000000..891c8f5bb2 --- /dev/null +++ b/library/cpp/messagebus/message_status_counter.cpp @@ -0,0 +1,71 @@ +#include "message_status_counter.h" + +#include "key_value_printer.h" +#include "text_utils.h" + +#include <library/cpp/messagebus/monitoring/mon_proto.pb.h> + +#include <util/stream/str.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TMessageStatusCounter::TMessageStatusCounter() { + Zero(Counts); +} + +TMessageStatusCounter& TMessageStatusCounter::operator+=(const TMessageStatusCounter& that) { + for (size_t i = 0; i < MESSAGE_STATUS_COUNT; ++i) { + Counts[i] += that.Counts[i]; + } + return *this; +} + +TString TMessageStatusCounter::PrintToString() const { + TStringStream ss; + TKeyValuePrinter p; + bool hasNonZeros = false; + bool hasZeros = false; + for (size_t i = 0; i < MESSAGE_STATUS_COUNT; ++i) { + if (i == MESSAGE_OK) { + Y_VERIFY(Counts[i] == 0); + continue; + } + if (Counts[i] != 0) { + p.AddRow(EMessageStatus(i), Counts[i]); + const char* description = MessageStatusDescription(EMessageStatus(i)); + // TODO: add third column + Y_UNUSED(description); + + hasNonZeros = true; + } else { + hasZeros = true; + } + } + if (!hasNonZeros) { + ss << "message status counts are zeros\n"; + } else { + if (hasZeros) { + ss << "message status counts are zeros, except:\n"; + } else { + ss << "message status counts:\n"; + } + ss << IndentText(p.PrintToString()); + } + return ss.Str(); +} + +void TMessageStatusCounter::FillErrorsProtobuf(TConnectionStatusMonRecord* status) const { + status->clear_errorcountbystatus(); + for (size_t i = 0; i < MESSAGE_STATUS_COUNT; ++i) { + if (i == MESSAGE_OK) { + Y_VERIFY(Counts[i] == 0); + continue; + } + if (Counts[i] != 0) { + TMessageStatusRecord* description = status->add_errorcountbystatus(); + description->SetStatus(TMessageStatusCounter::MessageStatusToProtobuf((EMessageStatus)i)); + description->SetCount(Counts[i]); + } + } +} diff --git a/library/cpp/messagebus/message_status_counter.h b/library/cpp/messagebus/message_status_counter.h new file mode 100644 index 0000000000..e8ba2fdd31 --- /dev/null +++ b/library/cpp/messagebus/message_status_counter.h @@ -0,0 +1,36 @@ +#pragma once + +#include "message_status.h" + +#include <library/cpp/messagebus/monitoring/mon_proto.pb.h> + +#include <util/generic/string.h> + +#include <array> + +namespace NBus { + namespace NPrivate { + struct TMessageStatusCounter { + static TMessageStatusRecord::EMessageStatus MessageStatusToProtobuf(EMessageStatus status) { + return (TMessageStatusRecord::EMessageStatus)status; + } + + std::array<unsigned, MESSAGE_STATUS_COUNT> Counts; + + unsigned& operator[](EMessageStatus index) { + return Counts[index]; + } + const unsigned& operator[](EMessageStatus index) const { + return Counts[index]; + } + + TMessageStatusCounter(); + + TMessageStatusCounter& operator+=(const TMessageStatusCounter&); + + TString PrintToString() const; + void FillErrorsProtobuf(TConnectionStatusMonRecord*) const; + }; + + } +} diff --git a/library/cpp/messagebus/message_status_counter_ut.cpp b/library/cpp/messagebus/message_status_counter_ut.cpp new file mode 100644 index 0000000000..9598651329 --- /dev/null +++ b/library/cpp/messagebus/message_status_counter_ut.cpp @@ -0,0 +1,23 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "message_status_counter.h" + +#include <library/cpp/messagebus/monitoring/mon_proto.pb.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +Y_UNIT_TEST_SUITE(MessageStatusCounter) { + Y_UNIT_TEST(MessageStatusConversion) { + const ::google::protobuf::EnumDescriptor* descriptor = + TMessageStatusRecord_EMessageStatus_descriptor(); + + for (int i = 0; i < MESSAGE_STATUS_COUNT; i++) { + const ::google::protobuf::EnumValueDescriptor* valueDescriptor = + descriptor->FindValueByName(ToString((EMessageStatus)i)); + UNIT_ASSERT_UNEQUAL(valueDescriptor, nullptr); + UNIT_ASSERT_EQUAL(valueDescriptor->number(), i); + } + UNIT_ASSERT_EQUAL(MESSAGE_STATUS_COUNT, descriptor->value_count()); + } +} diff --git a/library/cpp/messagebus/messqueue.cpp b/library/cpp/messagebus/messqueue.cpp new file mode 100644 index 0000000000..3474d62705 --- /dev/null +++ b/library/cpp/messagebus/messqueue.cpp @@ -0,0 +1,198 @@ +#include "key_value_printer.h" +#include "mb_lwtrace.h" +#include "remote_client_session.h" +#include "remote_server_session.h" +#include "ybus.h" + +#include <util/generic/singleton.h> + +using namespace NBus; +using namespace NBus::NPrivate; +using namespace NActor; + +TBusMessageQueuePtr NBus::CreateMessageQueue(const TBusQueueConfig& config, TExecutorPtr executor, TBusLocator* locator, const char* name) { + return new TBusMessageQueue(config, executor, locator, name); +} + +TBusMessageQueuePtr NBus::CreateMessageQueue(const TBusQueueConfig& config, TBusLocator* locator, const char* name) { + TExecutor::TConfig executorConfig; + executorConfig.WorkerCount = config.NumWorkers; + executorConfig.Name = name; + TExecutorPtr executor = new TExecutor(executorConfig); + return CreateMessageQueue(config, executor, locator, name); +} + +TBusMessageQueuePtr NBus::CreateMessageQueue(const TBusQueueConfig& config, const char* name) { + return CreateMessageQueue(config, new TBusLocator, name); +} + +TBusMessageQueuePtr NBus::CreateMessageQueue(TExecutorPtr executor, const char* name) { + return CreateMessageQueue(TBusQueueConfig(), executor, new TBusLocator, name); +} + +TBusMessageQueuePtr NBus::CreateMessageQueue(const char* name) { + TBusQueueConfig config; + return CreateMessageQueue(config, name); +} + +namespace { + TBusQueueConfig QueueConfigFillDefaults(const TBusQueueConfig& orig, const TString& name) { + TBusQueueConfig patched = orig; + if (!patched.Name) { + patched.Name = name; + } + return patched; + } +} + +TBusMessageQueue::TBusMessageQueue(const TBusQueueConfig& config, TExecutorPtr executor, TBusLocator* locator, const char* name) + : Config(QueueConfigFillDefaults(config, name)) + , Locator(locator) + , WorkQueue(executor) + , Running(1) +{ + InitBusLwtrace(); + InitNetworkSubSystem(); +} + +TBusMessageQueue::~TBusMessageQueue() { + Stop(); +} + +void TBusMessageQueue::Stop() { + if (!AtomicCas(&Running, 0, 1)) { + ShutdownComplete.WaitI(); + return; + } + + Scheduler.Stop(); + + DestroyAllSessions(); + + WorkQueue->Stop(); + + ShutdownComplete.Signal(); +} + +bool TBusMessageQueue::IsRunning() { + return AtomicGet(Running); +} + +TBusMessageQueueStatus TBusMessageQueue::GetStatusRecordInternal() const { + TBusMessageQueueStatus r; + r.ExecutorStatus = WorkQueue->GetStatusRecordInternal(); + r.Config = Config; + return r; +} + +TString TBusMessageQueue::GetStatusSelf() const { + return GetStatusRecordInternal().PrintToString(); +} + +TString TBusMessageQueue::GetStatusSingleLine() const { + return WorkQueue->GetStatusSingleLine(); +} + +TString TBusMessageQueue::GetStatus(ui16 flags) const { + TStringStream ss; + + ss << GetStatusSelf(); + + TList<TIntrusivePtr<TBusSessionImpl>> sessions; + { + TGuard<TMutex> scope(Lock); + sessions = Sessions; + } + + for (TList<TIntrusivePtr<TBusSessionImpl>>::const_iterator session = sessions.begin(); + session != sessions.end(); ++session) { + ss << Endl; + ss << (*session)->GetStatus(flags); + } + + ss << Endl; + ss << "object counts (not necessarily owned by this message queue):" << Endl; + TKeyValuePrinter p; + p.AddRow("TRemoteClientConnection", TObjectCounter<TRemoteClientConnection>::ObjectCount(), false); + p.AddRow("TRemoteServerConnection", TObjectCounter<TRemoteServerConnection>::ObjectCount(), false); + p.AddRow("TRemoteClientSession", TObjectCounter<TRemoteClientSession>::ObjectCount(), false); + p.AddRow("TRemoteServerSession", TObjectCounter<TRemoteServerSession>::ObjectCount(), false); + p.AddRow("NEventLoop::TEventLoop", TObjectCounter<NEventLoop::TEventLoop>::ObjectCount(), false); + p.AddRow("NEventLoop::TChannel", TObjectCounter<NEventLoop::TChannel>::ObjectCount(), false); + ss << p.PrintToString(); + + return ss.Str(); +} + +TBusClientSessionPtr TBusMessageQueue::CreateSource(TBusProtocol* proto, IBusClientHandler* handler, const TBusClientSessionConfig& config, const TString& name) { + TRemoteClientSessionPtr session(new TRemoteClientSession(this, proto, handler, config, name)); + Add(session.Get()); + return session.Get(); +} + +TBusServerSessionPtr TBusMessageQueue::CreateDestination(TBusProtocol* proto, IBusServerHandler* handler, const TBusClientSessionConfig& config, const TString& name) { + TRemoteServerSessionPtr session(new TRemoteServerSession(this, proto, handler, config, name)); + try { + int port = config.ListenPort; + if (port == 0) { + port = Locator->GetLocalPort(proto->GetService()); + } + if (port == 0) { + port = proto->GetPort(); + } + + session->Listen(port, this); + + Add(session.Get()); + return session.Release(); + } catch (...) { + Y_FAIL("create destination failure: %s", CurrentExceptionMessage().c_str()); + } +} + +TBusServerSessionPtr TBusMessageQueue::CreateDestination(TBusProtocol* proto, IBusServerHandler* handler, const TBusServerSessionConfig& config, const TVector<TBindResult>& bindTo, const TString& name) { + TRemoteServerSessionPtr session(new TRemoteServerSession(this, proto, handler, config, name)); + try { + session->Listen(bindTo, this); + Add(session.Get()); + return session.Release(); + } catch (...) { + Y_FAIL("create destination failure: %s", CurrentExceptionMessage().c_str()); + } +} + +void TBusMessageQueue::Add(TIntrusivePtr<TBusSessionImpl> session) { + TGuard<TMutex> scope(Lock); + Sessions.push_back(session); +} + +void TBusMessageQueue::Remove(TBusSession* session) { + TGuard<TMutex> scope(Lock); + TList<TIntrusivePtr<TBusSessionImpl>>::iterator it = std::find(Sessions.begin(), Sessions.end(), session); + Y_VERIFY(it != Sessions.end(), "do not destroy session twice"); + Sessions.erase(it); +} + +void TBusMessageQueue::Destroy(TBusSession* session) { + session->Shutdown(); +} + +void TBusMessageQueue::DestroyAllSessions() { + TList<TIntrusivePtr<TBusSessionImpl>> sessions; + { + TGuard<TMutex> scope(Lock); + sessions = Sessions; + } + + for (auto& session : sessions) { + Y_VERIFY(session->IsDown(), "Session must be shut down prior to queue shutdown"); + } +} + +void TBusMessageQueue::Schedule(IScheduleItemAutoPtr i) { + Scheduler.Schedule(i); +} + +TString TBusMessageQueue::GetNameInternal() const { + return Config.Name; +} diff --git a/library/cpp/messagebus/misc/atomic_box.h b/library/cpp/messagebus/misc/atomic_box.h new file mode 100644 index 0000000000..401621f933 --- /dev/null +++ b/library/cpp/messagebus/misc/atomic_box.h @@ -0,0 +1,34 @@ +#pragma once + +#include <util/system/atomic.h> + +// TAtomic with human interface +template <typename T> +class TAtomicBox { +private: + union { + TAtomic Value; + // when T is enum, it is convenient to inspect its content in gdb + T ValueForDebugger; + }; + + static_assert(sizeof(T) <= sizeof(TAtomic), "expect sizeof(T) <= sizeof(TAtomic)"); + +public: + TAtomicBox(T value = T()) + : Value(value) + { + } + + void Set(T value) { + AtomicSet(Value, (TAtomic)value); + } + + T Get() const { + return (T)AtomicGet(Value); + } + + bool CompareAndSet(T expected, T set) { + return AtomicCas(&Value, (TAtomicBase)set, (TAtomicBase)expected); + } +}; diff --git a/library/cpp/messagebus/misc/granup.h b/library/cpp/messagebus/misc/granup.h new file mode 100644 index 0000000000..36ecfebc93 --- /dev/null +++ b/library/cpp/messagebus/misc/granup.h @@ -0,0 +1,50 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/system/guard.h> +#include <util/system/mutex.h> +#include <util/system/spinlock.h> + +namespace NBus { + template <typename TItem, typename TLocker = TSpinLock> + class TGranUp { + public: + TGranUp(TDuration gran) + : Gran(gran) + , Next(TInstant::MicroSeconds(0)) + { + } + + template <typename TFunctor> + void Update(TFunctor functor, TInstant now, bool force = false) { + if (force || now > Next) + Set(functor(), now); + } + + void Update(const TItem& item, TInstant now, bool force = false) { + if (force || now > Next) + Set(item, now); + } + + TItem Get() const noexcept { + TGuard<TLocker> guard(Lock); + + return Item; + } + + protected: + void Set(const TItem& item, TInstant now) { + TGuard<TLocker> guard(Lock); + + Item = item; + + Next = now + Gran; + } + + private: + const TDuration Gran; + TLocker Lock; + TItem Item; + TInstant Next; + }; +} diff --git a/library/cpp/messagebus/misc/test_sync.h b/library/cpp/messagebus/misc/test_sync.h new file mode 100644 index 0000000000..be3f4f20b8 --- /dev/null +++ b/library/cpp/messagebus/misc/test_sync.h @@ -0,0 +1,75 @@ +#pragma once + +#include <util/system/condvar.h> +#include <util/system/mutex.h> + +class TTestSync { +private: + unsigned Current; + + TMutex Mutex; + TCondVar CondVar; + +public: + TTestSync() + : Current(0) + { + } + + void Inc() { + TGuard<TMutex> guard(Mutex); + + DoInc(); + CondVar.BroadCast(); + } + + unsigned Get() { + TGuard<TMutex> guard(Mutex); + + return Current; + } + + void WaitFor(unsigned n) { + TGuard<TMutex> guard(Mutex); + + Y_VERIFY(Current <= n, "too late, waiting for %d, already %d", n, Current); + + while (n > Current) { + CondVar.WaitI(Mutex); + } + } + + void WaitForAndIncrement(unsigned n) { + TGuard<TMutex> guard(Mutex); + + Y_VERIFY(Current <= n, "too late, waiting for %d, already %d", n, Current); + + while (n > Current) { + CondVar.WaitI(Mutex); + } + + DoInc(); + CondVar.BroadCast(); + } + + void CheckAndIncrement(unsigned n) { + TGuard<TMutex> guard(Mutex); + + Y_VERIFY(Current == n, "must be %d, currently %d", n, Current); + + DoInc(); + CondVar.BroadCast(); + } + + void Check(unsigned n) { + TGuard<TMutex> guard(Mutex); + + Y_VERIFY(Current == n, "must be %d, currently %d", n, Current); + } + +private: + void DoInc() { + unsigned r = ++Current; + Y_UNUSED(r); + } +}; diff --git a/library/cpp/messagebus/misc/tokenquota.h b/library/cpp/messagebus/misc/tokenquota.h new file mode 100644 index 0000000000..190547fa54 --- /dev/null +++ b/library/cpp/messagebus/misc/tokenquota.h @@ -0,0 +1,83 @@ +#pragma once + +#include <util/system/atomic.h> + +namespace NBus { + /* Consumer and feeder quota model impl. + + Consumer thread only calls: + Acquire(), fetches tokens for usage from bucket; + Consume(), eats given amount of tokens, must not + be greater than Value() items; + + Other threads (feeders) calls: + Return(), put used tokens back to bucket; + */ + + class TTokenQuota { + public: + TTokenQuota(bool enabled, size_t tokens, size_t wake) + : Enabled(tokens > 0 ? enabled : false) + , Acquired(0) + , WakeLev(wake < 1 ? Max<size_t>(1, tokens / 2) : 0) + , Tokens_(tokens) + { + Y_UNUSED(padd_); + } + + bool Acquire(TAtomic level = 1, bool force = false) { + level = Max(TAtomicBase(level), TAtomicBase(1)); + + if (Enabled && (Acquired < level || force)) { + Acquired += AtomicSwap(&Tokens_, 0); + } + + return !Enabled || Acquired >= level; + } + + void Consume(size_t items) { + if (Enabled) { + Y_ASSERT(Acquired >= TAtomicBase(items)); + + Acquired -= items; + } + } + + bool Return(size_t items_) noexcept { + if (!Enabled || items_ == 0) + return false; + + const TAtomic items = items_; + const TAtomic value = AtomicAdd(Tokens_, items); + + return (value - items < WakeLev && value >= WakeLev); + } + + bool IsEnabled() const noexcept { + return Enabled; + } + + bool IsAboveWake() const noexcept { + return !Enabled || (WakeLev <= AtomicGet(Tokens_)); + } + + size_t Tokens() const noexcept { + return Acquired + AtomicGet(Tokens_); + } + + size_t Check(const TAtomic level) const noexcept { + return !Enabled || level <= Acquired; + } + + private: + bool Enabled; + TAtomicBase Acquired; + const TAtomicBase WakeLev; + TAtomic Tokens_; + + /* This padd requires for align Tokens_ member on its own + CPU cacheline. */ + + ui64 padd_; + }; +} diff --git a/library/cpp/messagebus/misc/weak_ptr.h b/library/cpp/messagebus/misc/weak_ptr.h new file mode 100644 index 0000000000..70fdeb0e2a --- /dev/null +++ b/library/cpp/messagebus/misc/weak_ptr.h @@ -0,0 +1,99 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/system/mutex.h> + +template <typename T> +struct TWeakPtr; + +template <typename TSelf> +struct TWeakRefCounted { + template <typename> + friend struct TWeakPtr; + +private: + struct TRef: public TAtomicRefCount<TRef> { + TMutex Mutex; + TSelf* Outer; + + TRef(TSelf* outer) + : Outer(outer) + { + } + + void Release() { + TGuard<TMutex> g(Mutex); + Y_ASSERT(!!Outer); + Outer = nullptr; + } + + TIntrusivePtr<TSelf> Get() { + TGuard<TMutex> g(Mutex); + Y_ASSERT(!Outer || Outer->RefCount() > 0); + return Outer; + } + }; + + TAtomicCounter Counter; + TIntrusivePtr<TRef> RefPtr; + +public: + TWeakRefCounted() + : RefPtr(new TRef(static_cast<TSelf*>(this))) + { + } + + void Ref() { + Counter.Inc(); + } + + void UnRef() { + if (Counter.Dec() == 0) { + RefPtr->Release(); + + // drop is to prevent dtor from reading it + RefPtr.Drop(); + + delete static_cast<TSelf*>(this); + } + } + + void DecRef() { + Counter.Dec(); + } + + unsigned RefCount() const { + return Counter.Val(); + } +}; + +template <typename T> +struct TWeakPtr { +private: + typedef TIntrusivePtr<typename T::TRef> TRefPtr; + TRefPtr RefPtr; + +public: + TWeakPtr() { + } + + TWeakPtr(T* t) { + if (!!t) { + RefPtr = t->RefPtr; + } + } + + TWeakPtr(TIntrusivePtr<T> t) { + if (!!t) { + RefPtr = t->RefPtr; + } + } + + TIntrusivePtr<T> Get() { + if (!RefPtr) { + return nullptr; + } else { + return RefPtr->Get(); + } + } +}; diff --git a/library/cpp/messagebus/misc/weak_ptr_ut.cpp b/library/cpp/messagebus/misc/weak_ptr_ut.cpp new file mode 100644 index 0000000000..5a325278db --- /dev/null +++ b/library/cpp/messagebus/misc/weak_ptr_ut.cpp @@ -0,0 +1,46 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "weak_ptr.h" + +Y_UNIT_TEST_SUITE(TWeakPtrTest) { + struct TWeakPtrTester: public TWeakRefCounted<TWeakPtrTester> { + int* const CounterPtr; + + TWeakPtrTester(int* counterPtr) + : CounterPtr(counterPtr) + { + } + ~TWeakPtrTester() { + ++*CounterPtr; + } + }; + + Y_UNIT_TEST(Simple) { + int destroyCount = 0; + + TIntrusivePtr<TWeakPtrTester> p(new TWeakPtrTester(&destroyCount)); + + UNIT_ASSERT(!!p); + UNIT_ASSERT_VALUES_EQUAL(1u, p->RefCount()); + + TWeakPtr<TWeakPtrTester> p2(p); + + UNIT_ASSERT_VALUES_EQUAL(1u, p->RefCount()); + + { + TIntrusivePtr<TWeakPtrTester> p3 = p2.Get(); + UNIT_ASSERT(!!p3); + UNIT_ASSERT_VALUES_EQUAL(2u, p->RefCount()); + } + + p.Drop(); + UNIT_ASSERT_VALUES_EQUAL(1, destroyCount); + + { + TIntrusivePtr<TWeakPtrTester> p3 = p2.Get(); + UNIT_ASSERT(!p3); + } + + UNIT_ASSERT_VALUES_EQUAL(1, destroyCount); + } +} diff --git a/library/cpp/messagebus/monitoring/mon_proto.proto b/library/cpp/messagebus/monitoring/mon_proto.proto new file mode 100644 index 0000000000..73b6614481 --- /dev/null +++ b/library/cpp/messagebus/monitoring/mon_proto.proto @@ -0,0 +1,55 @@ +import "library/cpp/monlib/encode/legacy_protobuf/protos/metric_meta.proto"; + +package NBus; + +option java_package = "ru.yandex.messagebus.monitoring.proto"; + +message TMessageStatusRecord { + enum EMessageStatus { + MESSAGE_OK = 0; + MESSAGE_CONNECT_FAILED = 1; + MESSAGE_TIMEOUT = 2; + MESSAGE_SERVICE_UNKNOWN = 3; + MESSAGE_BUSY = 4; + MESSAGE_UNKNOWN = 5; + MESSAGE_DESERIALIZE_ERROR = 6; + MESSAGE_HEADER_CORRUPTED = 7; + MESSAGE_DECOMPRESS_ERROR = 8; + MESSAGE_MESSAGE_TOO_LARGE = 9; + MESSAGE_REPLY_FAILED = 10; + MESSAGE_DELIVERY_FAILED = 11; + MESSAGE_INVALID_VERSION = 12; + MESSAGE_SERVICE_TOOMANY = 13; + MESSAGE_SHUTDOWN = 14; + MESSAGE_DONT_ASK = 15; + } + + optional EMessageStatus Status = 1; + optional uint32 Count = 2; +} + +message TConnectionStatusMonRecord { + optional uint32 SendQueueSize = 1 [ (NMonProto.Metric).Type = GAUGE ]; + // client only + optional uint32 AckMessagesSize = 2 [ (NMonProto.Metric).Type = GAUGE ]; + optional uint32 ErrorCount = 3 [ (NMonProto.Metric).Type = RATE ]; + + optional uint64 WriteBytes = 10 [ (NMonProto.Metric).Type = RATE ]; + optional uint64 WriteBytesCompressed = 11; + optional uint64 WriteMessages = 12 [ (NMonProto.Metric).Type = RATE ]; + optional uint64 WriteSyscalls = 13; + optional uint64 WriteActs = 14; + optional uint64 ReadBytes = 20 [ (NMonProto.Metric).Type = RATE ]; + optional uint64 ReadBytesCompressed = 21; + optional uint64 ReadMessages = 22 [ (NMonProto.Metric).Type = RATE ]; + optional uint64 ReadSyscalls = 23; + optional uint64 ReadActs = 24; + + repeated TMessageStatusRecord ErrorCountByStatus = 25; +} + +message TSessionStatusMonRecord { + optional uint32 InFlight = 1 [ (NMonProto.Metric).Type = GAUGE ]; + optional uint32 ConnectionCount = 2 [ (NMonProto.Metric).Type = GAUGE ]; + optional uint32 ConnectCount = 3 [ (NMonProto.Metric).Type = RATE ]; +} diff --git a/library/cpp/messagebus/monitoring/ya.make b/library/cpp/messagebus/monitoring/ya.make new file mode 100644 index 0000000000..25782492b1 --- /dev/null +++ b/library/cpp/messagebus/monitoring/ya.make @@ -0,0 +1,15 @@ +PROTO_LIBRARY() + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/monlib/encode/legacy_protobuf/protos +) + +SRCS( + mon_proto.proto +) + +EXCLUDE_TAGS(GO_PROTO) + +END() diff --git a/library/cpp/messagebus/moved.h b/library/cpp/messagebus/moved.h new file mode 100644 index 0000000000..ede8dcd244 --- /dev/null +++ b/library/cpp/messagebus/moved.h @@ -0,0 +1,39 @@ +#pragma once + +#include <util/generic/utility.h> + +template <typename T> +class TMoved { +private: + mutable T Value; + +public: + TMoved() { + } + TMoved(const TMoved<T>& that) { + DoSwap(Value, that.Value); + } + TMoved(const T& that) { + DoSwap(Value, const_cast<T&>(that)); + } + + void swap(TMoved& that) { + DoSwap(Value, that.Value); + } + + T& operator*() { + return Value; + } + + const T& operator*() const { + return Value; + } + + T* operator->() { + return &Value; + } + + const T* operator->() const { + return &Value; + } +}; diff --git a/library/cpp/messagebus/moved_ut.cpp b/library/cpp/messagebus/moved_ut.cpp new file mode 100644 index 0000000000..c1a07cce7e --- /dev/null +++ b/library/cpp/messagebus/moved_ut.cpp @@ -0,0 +1,22 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "moved.h" + +Y_UNIT_TEST_SUITE(TMovedTest) { + Y_UNIT_TEST(Simple) { + TMoved<THolder<int>> h1(MakeHolder<int>(10)); + TMoved<THolder<int>> h2 = h1; + UNIT_ASSERT(!*h1); + UNIT_ASSERT(!!*h2); + UNIT_ASSERT_VALUES_EQUAL(10, **h2); + } + + void Foo(TMoved<THolder<int>> h) { + UNIT_ASSERT_VALUES_EQUAL(11, **h); + } + + Y_UNIT_TEST(PassToFunction) { + THolder<int> h(new int(11)); + Foo(h); + } +} diff --git a/library/cpp/messagebus/netaddr.h b/library/cpp/messagebus/netaddr.h new file mode 100644 index 0000000000..f915c8c574 --- /dev/null +++ b/library/cpp/messagebus/netaddr.h @@ -0,0 +1,4 @@ +#pragma once + +#include <library/cpp/messagebus/config/netaddr.h> + diff --git a/library/cpp/messagebus/netaddr_ut.cpp b/library/cpp/messagebus/netaddr_ut.cpp new file mode 100644 index 0000000000..e5c68bf402 --- /dev/null +++ b/library/cpp/messagebus/netaddr_ut.cpp @@ -0,0 +1,21 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "netaddr.h" +#include "test_utils.h" + +using namespace NBus; + +Y_UNIT_TEST_SUITE(TNetAddr) { + Y_UNIT_TEST(ResolveIpv4) { + ASSUME_IP_V4_ENABLED; + UNIT_ASSERT(TNetAddr("ns1.yandex.ru", 80, EIP_VERSION_4).IsIpv4()); + } + + Y_UNIT_TEST(ResolveIpv6) { + UNIT_ASSERT(TNetAddr("ns1.yandex.ru", 80, EIP_VERSION_6).IsIpv6()); + } + + Y_UNIT_TEST(ResolveAny) { + TNetAddr("ns1.yandex.ru", 80, EIP_VERSION_ANY); + } +} diff --git a/library/cpp/messagebus/network.cpp b/library/cpp/messagebus/network.cpp new file mode 100644 index 0000000000..304bedae5a --- /dev/null +++ b/library/cpp/messagebus/network.cpp @@ -0,0 +1,156 @@ +#include "network.h" + +#include <util/generic/maybe.h> +#include <util/generic/ptr.h> +#include <util/network/init.h> +#include <util/network/socket.h> +#include <util/system/platform.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +namespace { + TBindResult BindOnPortProto(int port, int af, bool reusePort) { + Y_VERIFY(af == AF_INET || af == AF_INET6, "wrong af"); + + SOCKET fd = ::socket(af, SOCK_STREAM, 0); + if (fd == INVALID_SOCKET) { + ythrow TSystemError() << "failed to create a socket"; + } + + int one = 1; + int r1 = SetSockOpt(fd, SOL_SOCKET, SO_REUSEADDR, one); + if (r1 < 0) { + ythrow TSystemError() << "failed to setsockopt SO_REUSEADDR"; + } + +#ifdef SO_REUSEPORT + if (reusePort) { + int r = SetSockOpt(fd, SOL_SOCKET, SO_REUSEPORT, one); + if (r < 0) { + ythrow TSystemError() << "failed to setsockopt SO_REUSEPORT"; + } + } +#else + Y_UNUSED(reusePort); +#endif + + THolder<TOpaqueAddr> addr(new TOpaqueAddr); + sockaddr* sa = addr->MutableAddr(); + sa->sa_family = af; + socklen_t len; + if (af == AF_INET) { + len = sizeof(sockaddr_in); + ((sockaddr_in*)sa)->sin_port = HostToInet((ui16)port); + ((sockaddr_in*)sa)->sin_addr.s_addr = INADDR_ANY; + } else { + len = sizeof(sockaddr_in6); + ((sockaddr_in6*)sa)->sin6_port = HostToInet((ui16)port); + } + + if (af == AF_INET6) { + FixIPv6ListenSocket(fd); + } + + int r2 = ::bind(fd, sa, len); + if (r2 < 0) { + ythrow TSystemError() << "failed to bind on port " << port; + } + + int rsn = ::getsockname(fd, addr->MutableAddr(), addr->LenPtr()); + if (rsn < 0) { + ythrow TSystemError() << "failed to getsockname"; + } + + int r3 = ::listen(fd, 50); + if (r3 < 0) { + ythrow TSystemError() << "listen failed"; + } + + TBindResult r; + r.Socket.Reset(new TSocketHolder(fd)); + r.Addr = TNetAddr(addr.Release()); + return r; + } + + TMaybe<TBindResult> TryBindOnPortProto(int port, int af, bool reusePort) { + try { + return {BindOnPortProto(port, af, reusePort)}; + } catch (const TSystemError&) { + return {}; + } + } + + std::pair<unsigned, TVector<TBindResult>> AggregateBindResults(TBindResult&& r1, TBindResult&& r2) { + Y_VERIFY(r1.Addr.GetPort() == r2.Addr.GetPort(), "internal"); + std::pair<unsigned, TVector<TBindResult>> r; + r.second.reserve(2); + + r.first = r1.Addr.GetPort(); + r.second.emplace_back(std::move(r1)); + r.second.emplace_back(std::move(r2)); + return r; + } +} + +std::pair<unsigned, TVector<TBindResult>> NBus::BindOnPort(int port, bool reusePort) { + std::pair<unsigned, TVector<TBindResult>> r; + r.second.reserve(2); + + if (port != 0) { + return AggregateBindResults(BindOnPortProto(port, AF_INET, reusePort), + BindOnPortProto(port, AF_INET6, reusePort)); + } + + // use nothrow versions in cycle + for (int i = 0; i < 1000; ++i) { + TMaybe<TBindResult> in4 = TryBindOnPortProto(0, AF_INET, reusePort); + if (!in4) { + continue; + } + + TMaybe<TBindResult> in6 = TryBindOnPortProto(in4->Addr.GetPort(), AF_INET6, reusePort); + if (!in6) { + continue; + } + + return AggregateBindResults(std::move(*in4), std::move(*in6)); + } + + TBindResult in4 = BindOnPortProto(0, AF_INET, reusePort); + TBindResult in6 = BindOnPortProto(in4.Addr.GetPort(), AF_INET6, reusePort); + return AggregateBindResults(std::move(in4), std::move(in6)); +} + +void NBus::NPrivate::SetSockOptTcpCork(SOCKET s, bool value) { +#ifdef _linux_ + CheckedSetSockOpt(s, IPPROTO_TCP, TCP_CORK, (int)value, "TCP_CORK"); +#else + Y_UNUSED(s); + Y_UNUSED(value); +#endif +} + +ssize_t NBus::NPrivate::SocketSend(SOCKET s, TArrayRef<const char> data) { + int flags = 0; +#if defined(_linux_) || defined(_freebsd_) + flags |= MSG_NOSIGNAL; +#endif + ssize_t r = ::send(s, data.data(), data.size(), flags); + if (r < 0) { + Y_VERIFY(LastSystemError() != EBADF, "bad fd"); + } + return r; +} + +ssize_t NBus::NPrivate::SocketRecv(SOCKET s, TArrayRef<char> buffer) { + int flags = 0; +#if defined(_linux_) || defined(_freebsd_) + flags |= MSG_NOSIGNAL; +#endif + ssize_t r = ::recv(s, buffer.data(), buffer.size(), flags); + if (r < 0) { + Y_VERIFY(LastSystemError() != EBADF, "bad fd"); + } + return r; +} diff --git a/library/cpp/messagebus/network.h b/library/cpp/messagebus/network.h new file mode 100644 index 0000000000..cc4bd76ea3 --- /dev/null +++ b/library/cpp/messagebus/network.h @@ -0,0 +1,28 @@ +#pragma once + +#include "netaddr.h" + +#include <util/generic/array_ref.h> +#include <util/generic/ptr.h> +#include <util/network/socket.h> + +#include <utility> + +namespace NBus { + namespace NPrivate { + void SetSockOptTcpCork(SOCKET s, bool value); + + [[nodiscard]] ssize_t SocketSend(SOCKET s, TArrayRef<const char> data); + + [[nodiscard]] ssize_t SocketRecv(SOCKET s, TArrayRef<char> buffer); + + } + + struct TBindResult { + TSimpleSharedPtr<TSocketHolder> Socket; + TNetAddr Addr; + }; + + std::pair<unsigned, TVector<TBindResult>> BindOnPort(int port, bool reusePort); + +} diff --git a/library/cpp/messagebus/network_ut.cpp b/library/cpp/messagebus/network_ut.cpp new file mode 100644 index 0000000000..f1798419db --- /dev/null +++ b/library/cpp/messagebus/network_ut.cpp @@ -0,0 +1,65 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "network.h" + +#include <library/cpp/messagebus/test/helper/fixed_port.h> + +using namespace NBus; +using namespace NBus::NPrivate; +using namespace NBus::NTest; + +namespace { + int GetSockPort(SOCKET socket) { + sockaddr_storage addr; + Zero(addr); + + socklen_t len = sizeof(addr); + + int r = ::getsockname(socket, (sockaddr*)&addr, &len); + UNIT_ASSERT(r >= 0); + + if (addr.ss_family == AF_INET) { + sockaddr_in* addr_in = (sockaddr_in*)&addr; + return InetToHost(addr_in->sin_port); + } else if (addr.ss_family == AF_INET6) { + sockaddr_in6* addr_in6 = (sockaddr_in6*)&addr; + return InetToHost(addr_in6->sin6_port); + } else { + UNIT_FAIL("unknown AF"); + throw 1; + } + } +} + +Y_UNIT_TEST_SUITE(Network) { + Y_UNIT_TEST(BindOnPortConcrete) { + if (!IsFixedPortTestAllowed()) { + return; + } + + TVector<TBindResult> r = BindOnPort(FixedPort, false).second; + UNIT_ASSERT_VALUES_EQUAL(size_t(2), r.size()); + + for (TVector<TBindResult>::iterator i = r.begin(); i != r.end(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(i->Addr.GetPort(), GetSockPort(i->Socket->operator SOCKET())); + } + } + + Y_UNIT_TEST(BindOnPortRandom) { + TVector<TBindResult> r = BindOnPort(0, false).second; + UNIT_ASSERT_VALUES_EQUAL(size_t(2), r.size()); + + for (TVector<TBindResult>::iterator i = r.begin(); i != r.end(); ++i) { + UNIT_ASSERT_VALUES_EQUAL(i->Addr.GetPort(), GetSockPort(i->Socket->operator SOCKET())); + UNIT_ASSERT(i->Addr.GetPort() > 0); + } + + UNIT_ASSERT_VALUES_EQUAL(r.at(0).Addr.GetPort(), r.at(1).Addr.GetPort()); + } + + Y_UNIT_TEST(BindOnBusyPort) { + auto r = BindOnPort(0, false); + + UNIT_ASSERT_EXCEPTION_CONTAINS(BindOnPort(r.first, false), TSystemError, "failed to bind on port " + ToString(r.first)); + } +} diff --git a/library/cpp/messagebus/nondestroying_holder.h b/library/cpp/messagebus/nondestroying_holder.h new file mode 100644 index 0000000000..f4725d696f --- /dev/null +++ b/library/cpp/messagebus/nondestroying_holder.h @@ -0,0 +1,39 @@ +#pragma once + +#include <util/generic/ptr.h> + +template <typename T> +class TNonDestroyingHolder: public THolder<T> { +public: + TNonDestroyingHolder(T* t = nullptr) noexcept + : THolder<T>(t) + { + } + + TNonDestroyingHolder(TAutoPtr<T> t) noexcept + : THolder<T>(t) + { + } + + ~TNonDestroyingHolder() { + Y_VERIFY(!*this, "stored object must be explicitly released"); + } +}; + +template <class T> +class TNonDestroyingAutoPtr: public TAutoPtr<T> { +public: + inline TNonDestroyingAutoPtr(T* t = 0) noexcept + : TAutoPtr<T>(t) + { + } + + inline TNonDestroyingAutoPtr(const TAutoPtr<T>& t) noexcept + : TAutoPtr<T>(t.Release()) + { + } + + inline ~TNonDestroyingAutoPtr() { + Y_VERIFY(!*this, "stored object must be explicitly released"); + } +}; diff --git a/library/cpp/messagebus/nondestroying_holder_ut.cpp b/library/cpp/messagebus/nondestroying_holder_ut.cpp new file mode 100644 index 0000000000..208042a2ba --- /dev/null +++ b/library/cpp/messagebus/nondestroying_holder_ut.cpp @@ -0,0 +1,12 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "nondestroying_holder.h" + +Y_UNIT_TEST_SUITE(TNonDestroyingHolder) { + Y_UNIT_TEST(ToAutoPtr) { + TNonDestroyingHolder<int> h(new int(11)); + TAutoPtr<int> i(h); + UNIT_ASSERT_VALUES_EQUAL(11, *i); + UNIT_ASSERT(!h); + } +} diff --git a/library/cpp/messagebus/oldmodule/module.cpp b/library/cpp/messagebus/oldmodule/module.cpp new file mode 100644 index 0000000000..24bd778799 --- /dev/null +++ b/library/cpp/messagebus/oldmodule/module.cpp @@ -0,0 +1,881 @@ +#include "module.h" + +#include <library/cpp/messagebus/scheduler_actor.h> +#include <library/cpp/messagebus/thread_extra.h> +#include <library/cpp/messagebus/actor/actor.h> +#include <library/cpp/messagebus/actor/queue_in_actor.h> +#include <library/cpp/messagebus/actor/what_thread_does.h> +#include <library/cpp/messagebus/actor/what_thread_does_guard.h> + +#include <util/generic/singleton.h> +#include <util/string/printf.h> +#include <util/system/event.h> + +using namespace NActor; +using namespace NBus; +using namespace NBus::NPrivate; + +namespace { + Y_POD_STATIC_THREAD(TBusJob*) + ThreadCurrentJob; + + struct TThreadCurrentJobGuard { + TBusJob* Prev; + + TThreadCurrentJobGuard(TBusJob* job) + : Prev(ThreadCurrentJob) + { + Y_ASSERT(!ThreadCurrentJob || ThreadCurrentJob == job); + ThreadCurrentJob = job; + } + ~TThreadCurrentJobGuard() { + ThreadCurrentJob = Prev; + } + }; + + void ClearState(NBus::TJobState* state) { + /// skip sendbacks handlers + if (state->Message != state->Reply) { + if (state->Message) { + delete state->Message; + state->Message = nullptr; + } + + if (state->Reply) { + delete state->Reply; + state->Reply = nullptr; + } + } + } + + void ClearJobStateVector(NBus::TJobStateVec* vec) { + Y_ASSERT(vec); + + for (auto& call : *vec) { + ClearState(&call); + } + + vec->clear(); + } + +} + +namespace NBus { + namespace NPrivate { + class TJobStorage { + }; + + struct TModuleClientHandler + : public IBusClientHandler { + TModuleClientHandler(TBusModuleImpl* module) + : Module(module) + { + } + + void OnReply(TAutoPtr<TBusMessage> req, TAutoPtr<TBusMessage> reply) override; + void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) override; + void OnError(TAutoPtr<TBusMessage> msg, EMessageStatus status) override; + void OnClientConnectionEvent(const TClientConnectionEvent& event) override; + + TBusModuleImpl* const Module; + }; + + struct TModuleServerHandler + : public IBusServerHandler { + TModuleServerHandler(TBusModuleImpl* module) + : Module(module) + { + } + + void OnMessage(TOnMessageContext& msg) override; + + TBusModuleImpl* const Module; + }; + + struct TBusModuleImpl: public TBusModuleInternal { + TBusModule* const Module; + + TBusMessageQueue* Queue; + + TScheduler Scheduler; + + const char* const Name; + + typedef TList<TJobRunner*> TBusJobList; + /// jobs currently in-flight on this module + TBusJobList Jobs; + /// module level mutex + TMutex Lock; + TCondVar ShutdownCondVar; + TAtomic JobCount; + + enum EState { + CREATED, + RUNNING, + STOPPED, + }; + + TAtomic State; + TBusModuleConfig ModuleConfig; + TBusServerSessionPtr ExternalSession; + /// protocol for local proxy session + THolder<IBusClientHandler> ModuleClientHandler; + THolder<IBusServerHandler> ModuleServerHandler; + TVector<TSimpleSharedPtr<TBusStarter>> Starters; + + // Sessions must be destroyed before + // ModuleClientHandler / ModuleServerHandler + TVector<TBusClientSessionPtr> ClientSessions; + TVector<TBusServerSessionPtr> ServerSessions; + + TBusModuleImpl(TBusModule* module, const char* name) + : Module(module) + , Queue() + , Name(name) + , JobCount(0) + , State(CREATED) + , ExternalSession(nullptr) + , ModuleClientHandler(new TModuleClientHandler(this)) + , ModuleServerHandler(new TModuleServerHandler(this)) + { + } + + ~TBusModuleImpl() override { + // Shutdown cannot be called from destructor, + // because module has virtual methods. + Y_VERIFY(State != RUNNING, "if running, must explicitly call Shutdown() before destructor"); + + Scheduler.Stop(); + + while (!Jobs.empty()) { + DestroyJob(Jobs.front()); + } + Y_VERIFY(JobCount == 0, "state check"); + } + + void OnMessageReceived(TAutoPtr<TBusMessage> msg, TOnMessageContext&); + + void AddJob(TJobRunner* jobRunner); + + void DestroyJob(TJobRunner* job); + + /// terminate job on this message + void CancelJob(TBusJob* job, EMessageStatus status); + /// prints statuses of jobs + TString GetStatus(unsigned flags); + + size_t Size() const { + return AtomicGet(JobCount); + } + + void Shutdown(); + + TVector<TBusClientSessionPtr> GetClientSessionsInternal() override { + return ClientSessions; + } + + TVector<TBusServerSessionPtr> GetServerSessionsInternal() override { + return ServerSessions; + } + + TBusMessageQueue* GetQueue() override { + return Queue; + } + + TString GetNameInternal() override { + return Name; + } + + TString GetStatusSingleLine() override { + TStringStream ss; + ss << "jobs: " << Size(); + return ss.Str(); + } + + void OnClientConnectionEvent(const TClientConnectionEvent& event) { + Module->OnClientConnectionEvent(event); + } + }; + + struct TJobResponseMessage { + TBusMessage* Request; + TBusMessage* Response; + EMessageStatus Status; + + TJobResponseMessage(TBusMessage* request, TBusMessage* response, EMessageStatus status) + : Request(request) + , Response(response) + , Status(status) + { + } + }; + + struct TJobRunner: public TAtomicRefCount<TJobRunner>, + public NActor::TActor<TJobRunner>, + public NActor::TQueueInActor<TJobRunner, TJobResponseMessage>, + public TScheduleActor<TJobRunner> { + THolder<TBusJob> Job; + + TList<TJobRunner*>::iterator JobStorageIterator; + + TJobRunner(TAutoPtr<TBusJob> job) + : NActor::TActor<TJobRunner>(job->ModuleImpl->Queue->GetExecutor()) + , TScheduleActor<TJobRunner>(&job->ModuleImpl->Scheduler) + , Job(job.Release()) + , JobStorageIterator() + { + Job->Runner = this; + } + + ~TJobRunner() override { + Y_ASSERT(JobStorageIterator == TList<TJobRunner*>::iterator()); + } + + void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, const TJobResponseMessage& message) { + Job->CallReplyHandler(message.Status, message.Request, message.Response); + } + + void Destroy() { + if (!!Job->OnMessageContext) { + if (!Job->ReplySent) { + Job->OnMessageContext.ForgetRequest(); + } + } + Job->ModuleImpl->DestroyJob(this); + } + + void Act(NActor::TDefaultTag) { + if (JobStorageIterator == TList<TJobRunner*>::iterator()) { + return; + } + + if (Job->SleepUntil != 0) { + if (AtomicGet(Job->ModuleImpl->State) == TBusModuleImpl::STOPPED) { + Destroy(); + return; + } + } + + TThreadCurrentJobGuard g(Job.Get()); + + NActor::TQueueInActor<TJobRunner, TJobResponseMessage>::DequeueAll(); + + if (Alarm.FetchTask()) { + if (Job->AnyPendingToSend()) { + Y_ASSERT(Job->SleepUntil == 0); + Job->SendPending(); + if (Job->AnyPendingToSend()) { + } + } else { + // regular alarm + Y_ASSERT(Job->Pending.empty()); + Y_ASSERT(Job->SleepUntil != 0); + Job->SleepUntil = 0; + } + } + + for (;;) { + if (Job->Pending.empty() && !!Job->Handler && Job->Status == MESSAGE_OK) { + TWhatThreadDoesPushPop pp("do call job handler (do not confuse with reply handler)"); + + Job->Handler = Job->Handler(Job->Module, Job.Get(), Job->Message); + } + + if (Job->SleepUntil != 0) { + ScheduleAt(TInstant::MilliSeconds(Job->SleepUntil)); + return; + } + + Job->SendPending(); + + if (Job->AnyPendingToSend()) { + ScheduleAt(TInstant::Now() + TDuration::Seconds(1)); + return; + } + + if (!Job->Pending.empty()) { + // waiting replies + return; + } + + if (Job->IsDone()) { + Destroy(); + return; + } + } + } + }; + + } + + static inline TJobRunner* GetJob(TBusMessage* message) { + return (TJobRunner*)message->Data; + } + + static inline void SetJob(TBusMessage* message, TJobRunner* job) { + message->Data = job; + } + + TBusJob::TBusJob(TBusModule* module, TBusMessage* message) + : Status(MESSAGE_OK) + , Runner() + , Message(message) + , ReplySent(false) + , Module(module) + , ModuleImpl(module->Impl.Get()) + , SleepUntil(0) + { + Handler = TJobHandler(&TBusModule::Start); + } + + TBusJob::~TBusJob() { + Y_ASSERT(Pending.size() == 0); + //Y_ASSERT(SleepUntil == 0); + + ClearAllMessageStates(); + } + + TNetAddr TBusJob::GetPeerAddrNetAddr() const { + Y_VERIFY(!!OnMessageContext); + return OnMessageContext.GetPeerAddrNetAddr(); + } + + void TBusJob::CheckThreadCurrentJob() { + Y_ASSERT(ThreadCurrentJob == this); + } + + ///////////////////////////////////////////////////////// + /// \brief Send messages in pending list + + /// If at least one message is gone return true + /// If message has not been send, move it to Finished with appropriate error code + bool TBusJob::SendPending() { + // Iterator type must be size_t, not vector::iterator, + // because `DoCallReplyHandler` may call `Send` that modifies `Pending` vector, + // that in turn invalidates iterator. + // Implementation assumes that `DoCallReplyHandler` only pushes back to `Pending` + // (not erases, and not inserts) so iteration by index is valid. + size_t it = 0; + while (it != Pending.size()) { + TJobState& call = Pending[it]; + + if (call.Status == MESSAGE_DONT_ASK) { + EMessageStatus getAddressStatus = MESSAGE_OK; + TNetAddr addr; + if (call.UseAddr) { + addr = call.Addr; + } else { + getAddressStatus = const_cast<TBusProtocol*>(call.Session->GetProto())->GetDestination(call.Session, call.Message, call.Session->GetQueue()->GetLocator(), &addr); + } + + if (getAddressStatus == MESSAGE_OK) { + // hold extra reference for each request in flight + Runner->Ref(); + + if (call.OneWay) { + call.Status = call.Session->SendMessageOneWay(call.Message, &addr); + } else { + call.Status = call.Session->SendMessage(call.Message, &addr); + } + + if (call.Status != MESSAGE_OK) { + Runner->UnRef(); + } + + } else { + call.Status = getAddressStatus; + } + } + + if (call.Status == MESSAGE_OK) { + ++it; // keep pending list until we get reply + } else if (call.Status == MESSAGE_BUSY) { + Y_FAIL("MESSAGE_BUSY is prohibited in modules. Please increase MaxInFlight"); + } else if (call.Status == MESSAGE_CONNECT_FAILED && call.NumRetries < call.MaxRetries) { + ++it; // try up to call.MaxRetries times to send message + call.NumRetries++; + DoCallReplyHandler(call); + call.Status = MESSAGE_DONT_ASK; + call.Message->Reset(); // generate new Id + } else { + Finished.push_back(call); + DoCallReplyHandler(call); + Pending.erase(Pending.begin() + it); + } + } + return Pending.size() > 0; + } + + bool TBusJob::AnyPendingToSend() { + for (unsigned i = 0; i < Pending.size(); ++i) { + if (Pending[i].Status == MESSAGE_DONT_ASK) { + return true; + } + } + + return false; + } + + bool TBusJob::IsDone() { + bool r = (SleepUntil == 0 && Pending.size() == 0 && (Handler == nullptr || Status != MESSAGE_OK)); + return r; + } + + void TBusJob::CallJobHandlerOnly() { + TThreadCurrentJobGuard threadCurrentJobGuard(this); + TWhatThreadDoesPushPop pp("do call job handler (do not confuse with reply handler)"); + + Handler = Handler(ModuleImpl->Module, this, Message); + } + + bool TBusJob::CallJobHandler() { + /// go on as far as we can go without waiting + while (!IsDone()) { + /// call the handler + CallJobHandlerOnly(); + + /// quit if job is canceled + if (Status != MESSAGE_OK) { + break; + } + + /// there are messages to send and wait for reply + SendPending(); + + if (!Pending.empty()) { + break; + } + + /// asked to sleep + if (SleepUntil) { + break; + } + } + + Y_VERIFY(!(Pending.size() == 0 && Handler == nullptr && Status == MESSAGE_OK && !ReplySent), + "Handler returned NULL without Cancel() or SendReply() for message=%016" PRIx64 " type=%d", + Message->GetHeader()->Id, Message->GetHeader()->Type); + + return IsDone(); + } + + void TBusJob::DoCallReplyHandler(TJobState& call) { + if (call.Handler) { + TWhatThreadDoesPushPop pp("do call reply handler (do not confuse with job handler)"); + + TThreadCurrentJobGuard threadCurrentJobGuard(this); + (Module->*(call.Handler))(this, call.Status, call.Message, call.Reply); + } + } + + int TBusJob::CallReplyHandler(EMessageStatus status, TBusMessage* mess, TBusMessage* reply) { + /// find handler for given message and update it's status + size_t i = 0; + for (; i < Pending.size(); ++i) { + TJobState& call = Pending[i]; + if (call.Message == mess) { + break; + } + } + + /// if not found, report error + if (i == Pending.size()) { + Y_FAIL("must not happen"); + } + + /// fill in response into job state + TJobState& call = Pending[i]; + call.Status = status; + Y_ASSERT(call.Message == mess); + call.Reply = reply; + + if ((status == MESSAGE_TIMEOUT || status == MESSAGE_DELIVERY_FAILED) && call.NumRetries < call.MaxRetries) { + call.NumRetries++; + call.Status = MESSAGE_DONT_ASK; + call.Message->Reset(); // generate new Id + DoCallReplyHandler(call); + return 0; + } + + /// call the handler if provided + DoCallReplyHandler(call); + + /// move job state into the finished stack + Finished.push_back(Pending[i]); + Pending.erase(Pending.begin() + i); + + return 0; + } + + /////////////////////////////////////////////////////////////// + /// send message to any other session or application + void TBusJob::Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler, size_t maxRetries) { + CheckThreadCurrentJob(); + + SetJob(mess.Get(), Runner); + Pending.push_back(TJobState(rhandler, MESSAGE_DONT_ASK, mess.Release(), session, nullptr, maxRetries, nullptr, false)); + } + + void TBusJob::Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler, size_t maxRetries, const TNetAddr& addr) { + CheckThreadCurrentJob(); + + SetJob(mess.Get(), Runner); + Pending.push_back(TJobState(rhandler, MESSAGE_DONT_ASK, mess.Release(), session, nullptr, maxRetries, &addr, false)); + } + + void TBusJob::SendOneWayTo(TBusMessageAutoPtr req, TBusClientSession* session, const TNetAddr& addr) { + CheckThreadCurrentJob(); + + SetJob(req.Get(), Runner); + Pending.push_back(TJobState(nullptr, MESSAGE_DONT_ASK, req.Release(), session, nullptr, 0, &addr, true)); + } + + void TBusJob::SendOneWayWithLocator(TBusMessageAutoPtr req, TBusClientSession* session) { + CheckThreadCurrentJob(); + + SetJob(req.Get(), Runner); + Pending.push_back(TJobState(nullptr, MESSAGE_DONT_ASK, req.Release(), session, nullptr, 0, nullptr, true)); + } + + /////////////////////////////////////////////////////////////// + /// send reply to the starter message + void TBusJob::SendReply(TBusMessageAutoPtr reply) { + CheckThreadCurrentJob(); + + Y_VERIFY(!ReplySent, "cannot call SendReply twice"); + ReplySent = true; + if (!OnMessageContext) + return; + + EMessageStatus ok = OnMessageContext.SendReplyMove(reply); + if (ok != MESSAGE_OK) { + // TODO: count errors + } + } + + /// set the flag to terminate job at the earliest convenience + void TBusJob::Cancel(EMessageStatus status) { + CheckThreadCurrentJob(); + + Status = status; + } + + void TBusJob::ClearState(TJobState& call) { + TJobStateVec::iterator it; + for (it = Finished.begin(); it != Finished.end(); ++it) { + TJobState& state = *it; + if (&call == &state) { + ::ClearState(&call); + Finished.erase(it); + return; + } + } + Y_ASSERT(0); + } + + void TBusJob::ClearAllMessageStates() { + ClearJobStateVector(&Finished); + ClearJobStateVector(&Pending); + } + + void TBusJob::Sleep(int milliSeconds) { + CheckThreadCurrentJob(); + + Y_VERIFY(Pending.empty(), "sleep is not allowed when there are pending job"); + Y_VERIFY(SleepUntil == 0, "must not override sleep"); + + SleepUntil = Now() + milliSeconds; + } + + TString TBusJob::GetStatus(unsigned flags) { + TString strReturn; + strReturn += Sprintf(" job=%016" PRIx64 " type=%d sent=%d pending=%d (%d) %s\n", + Message->GetHeader()->Id, + (int)Message->GetHeader()->Type, + (int)(Now() - Message->GetHeader()->SendTime) / 1000, + (int)Pending.size(), + (int)Finished.size(), + Status != MESSAGE_OK ? ToString(Status).data() : ""); + + TJobStateVec::iterator it; + for (it = Pending.begin(); it != Pending.end(); ++it) { + TJobState& call = *it; + strReturn += call.GetStatus(flags); + } + return strReturn; + } + + TString TJobState::GetStatus(unsigned flags) { + Y_UNUSED(flags); + TString strReturn; + strReturn += Sprintf(" pending=%016" PRIx64 " type=%d (%s) sent=%d %s\n", + Message->GetHeader()->Id, + (int)Message->GetHeader()->Type, + Session->GetProto()->GetService(), + (int)(Now() - Message->GetHeader()->SendTime) / 1000, + ToString(Status).data()); + return strReturn; + } + + ////////////////////////////////////////////////////////////////////// + + void TBusModuleImpl::CancelJob(TBusJob* job, EMessageStatus status) { + TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for CancelJob"); + if (job) { + job->Cancel(status); + } + } + + TString TBusModuleImpl::GetStatus(unsigned flags) { + Y_UNUSED(flags); + TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for GetStatus"); + TString strReturn = Sprintf("JobsInFlight=%d\n", (int)Jobs.size()); + for (auto job : Jobs) { + //strReturn += job->Job->GetStatus(flags); + Y_UNUSED(job); + strReturn += "TODO\n"; + } + return strReturn; + } + + TBusModuleConfig::TBusModuleConfig() + : StarterMaxInFlight(1000) + { + } + + TBusModuleConfig::TSecret::TSecret() + : SchedulePeriod(TDuration::Seconds(1)) + { + } + + TBusModule::TBusModule(const char* name) + : Impl(new TBusModuleImpl(this, name)) + { + } + + TBusModule::~TBusModule() { + } + + const char* TBusModule::GetName() const { + return Impl->Name; + } + + void TBusModule::SetConfig(const TBusModuleConfig& config) { + Impl->ModuleConfig = config; + } + + bool TBusModule::StartInput() { + Y_VERIFY(Impl->State == TBusModuleImpl::CREATED, "state check"); + Y_VERIFY(!!Impl->Queue, "state check"); + Impl->State = TBusModuleImpl::RUNNING; + + Y_ASSERT(!Impl->ExternalSession); + TBusServerSessionPtr extSession = CreateExtSession(*Impl->Queue); + if (extSession != nullptr) { + Impl->ExternalSession = extSession; + } + + return true; + } + + bool TBusModule::Shutdown() { + Impl->Shutdown(); + + return true; + } + + TBusJob* TBusModule::CreateJobInstance(TBusMessage* message) { + TBusJob* job = new TBusJob(this, message); + return job; + } + + /** +Example for external session creation: + +TBusSession* TMyModule::CreateExtSession(TBusMessageQueue& queue) { + TBusSession* session = CreateDefaultDestination(queue, &ExternalProto, ExternalConfig); + session->RegisterService(hostname, begin, end); + return session; +*/ + + bool TBusModule::CreatePrivateSessions(TBusMessageQueue* queue) { + Impl->Queue = queue; + return true; + } + + int TBusModule::GetModuleSessionInFlight() const { + return Impl->Size(); + } + + TIntrusivePtr<TBusModuleInternal> TBusModule::GetInternal() { + return Impl.Get(); + } + + TBusServerSessionPtr TBusModule::CreateDefaultDestination( + TBusMessageQueue& queue, TBusProtocol* proto, const TBusServerSessionConfig& config, const TString& name) { + TBusServerSessionConfig patchedConfig = config; + patchedConfig.ExecuteOnMessageInWorkerPool = false; + if (!patchedConfig.Name) { + patchedConfig.Name = name; + } + if (!patchedConfig.Name) { + patchedConfig.Name = Impl->Name; + } + TBusServerSessionPtr session = + TBusServerSession::Create(proto, Impl->ModuleServerHandler.Get(), patchedConfig, &queue); + Impl->ServerSessions.push_back(session); + return session; + } + + TBusClientSessionPtr TBusModule::CreateDefaultSource( + TBusMessageQueue& queue, TBusProtocol* proto, const TBusClientSessionConfig& config, const TString& name) { + TBusClientSessionConfig patchedConfig = config; + patchedConfig.ExecuteOnReplyInWorkerPool = false; + if (!patchedConfig.Name) { + patchedConfig.Name = name; + } + if (!patchedConfig.Name) { + patchedConfig.Name = Impl->Name; + } + TBusClientSessionPtr session = + TBusClientSession::Create(proto, Impl->ModuleClientHandler.Get(), patchedConfig, &queue); + Impl->ClientSessions.push_back(session); + return session; + } + + TBusStarter* TBusModule::CreateDefaultStarter(TBusMessageQueue&, const TBusSessionConfig& config) { + TBusStarter* session = new TBusStarter(this, config); + Impl->Starters.push_back(session); + return session; + } + + void TBusModule::OnClientConnectionEvent(const TClientConnectionEvent& event) { + Y_UNUSED(event); + } + + TString TBusModule::GetStatus(unsigned flags) { + TString strReturn = Sprintf("%s\n", Impl->Name); + strReturn += Impl->GetStatus(flags); + return strReturn; + } + +} + +void TBusModuleImpl::AddJob(TJobRunner* jobRunner) { + TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for AddJob"); + Jobs.push_back(jobRunner); + jobRunner->JobStorageIterator = Jobs.end(); + --jobRunner->JobStorageIterator; +} + +void TBusModuleImpl::DestroyJob(TJobRunner* job) { + Y_ASSERT(job->JobStorageIterator != TList<TJobRunner*>::iterator()); + + { + TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for DestroyJob"); + int jobCount = AtomicDecrement(JobCount); + Y_VERIFY(jobCount >= 0, "decremented too much"); + Jobs.erase(job->JobStorageIterator); + + if (AtomicGet(State) == STOPPED) { + if (jobCount == 0) { + ShutdownCondVar.BroadCast(); + } + } + } + + job->JobStorageIterator = TList<TJobRunner*>::iterator(); +} + +void TBusModuleImpl::OnMessageReceived(TAutoPtr<TBusMessage> msg0, TOnMessageContext& context) { + TBusMessage* msg = !!msg0 ? msg0.Get() : context.GetMessage(); + Y_VERIFY(!!msg); + + THolder<TJobRunner> jobRunner(new TJobRunner(Module->CreateJobInstance(msg))); + jobRunner->Job->MessageHolder.Reset(msg0.Release()); + jobRunner->Job->OnMessageContext.Swap(context); + SetJob(jobRunner->Job->Message, jobRunner.Get()); + + AtomicIncrement(JobCount); + + AddJob(jobRunner.Get()); + + jobRunner.Release()->Schedule(); +} + +void TBusModuleImpl::Shutdown() { + if (AtomicGet(State) != TBusModuleImpl::RUNNING) { + AtomicSet(State, TBusModuleImpl::STOPPED); + return; + } + AtomicSet(State, TBusModuleImpl::STOPPED); + + for (auto& clientSession : ClientSessions) { + clientSession->Shutdown(); + } + for (auto& serverSession : ServerSessions) { + serverSession->Shutdown(); + } + + for (size_t starter = 0; starter < Starters.size(); ++starter) { + Starters[starter]->Shutdown(); + } + + { + TWhatThreadDoesAcquireGuard<TMutex> guard(Lock, "modules: acquiring lock for Shutdown"); + for (auto& Job : Jobs) { + Job->Schedule(); + } + + while (!Jobs.empty()) { + ShutdownCondVar.WaitI(Lock); + } + } +} + +EMessageStatus TBusModule::StartJob(TAutoPtr<TBusMessage> message) { + Y_VERIFY(Impl->State == TBusModuleImpl::RUNNING); + Y_VERIFY(!!Impl->Queue); + + if ((unsigned)AtomicGet(Impl->JobCount) >= Impl->ModuleConfig.StarterMaxInFlight) { + return MESSAGE_BUSY; + } + + TOnMessageContext dummy; + Impl->OnMessageReceived(message.Release(), dummy); + + return MESSAGE_OK; +} + +void TModuleServerHandler::OnMessage(TOnMessageContext& msg) { + Module->OnMessageReceived(nullptr, msg); +} + +void TModuleClientHandler::OnReply(TAutoPtr<TBusMessage> req, TAutoPtr<TBusMessage> resp) { + TJobRunner* job = GetJob(req.Get()); + Y_ASSERT(job); + Y_ASSERT(job->Job->Message != req.Get()); + job->EnqueueAndSchedule(TJobResponseMessage(req.Release(), resp.Release(), MESSAGE_OK)); + job->UnRef(); +} + +void TModuleClientHandler::OnMessageSentOneWay(TAutoPtr<TBusMessage> req) { + TJobRunner* job = GetJob(req.Get()); + Y_ASSERT(job); + Y_ASSERT(job->Job->Message != req.Get()); + job->EnqueueAndSchedule(TJobResponseMessage(req.Release(), nullptr, MESSAGE_OK)); + job->UnRef(); +} + +void TModuleClientHandler::OnError(TAutoPtr<TBusMessage> msg, EMessageStatus status) { + TJobRunner* job = GetJob(msg.Get()); + if (job) { + Y_ASSERT(job->Job->Message != msg.Get()); + job->EnqueueAndSchedule(TJobResponseMessage(msg.Release(), nullptr, status)); + job->UnRef(); + } +} + +void TModuleClientHandler::OnClientConnectionEvent(const TClientConnectionEvent& event) { + Module->OnClientConnectionEvent(event); +} diff --git a/library/cpp/messagebus/oldmodule/module.h b/library/cpp/messagebus/oldmodule/module.h new file mode 100644 index 0000000000..8d1c4a5d52 --- /dev/null +++ b/library/cpp/messagebus/oldmodule/module.h @@ -0,0 +1,410 @@ +#pragma once + +/////////////////////////////////////////////////////////////////////////// +/// \file +/// \brief Application interface for modules + +/// NBus::TBusModule provides foundation for implementation of asynchnous +/// modules that communicate with multiple external or local sessions +/// NBus::TBusSession. + +/// To implement the module some virtual functions needs to be overridden: + +/// NBus::TBusModule::CreateExtSession() creates and registers an +/// external session that receives incoming messages as input for module +/// processing. + +/// When new incoming message arrives the new NBus::TBusJob is created. +/// NBus::TBusJob is somewhat similar to a thread, it maintains all the state +/// during processing of one incoming message. Default implementation of +/// NBus::TBusJob will maintain all send and received messages during +/// lifetime of this job. Each message, status and reply can be found +/// within NBus::TJobState using NBus::TBusJob::GetState(). If your module +/// needs to maintain an additional information during lifetime of the job +/// you can derive your own class from NBus::TBusJob and override job +/// factory method NBus::IJobFactory::CreateJobInstance() to create your instances. + +/// Processing of a given message starts with a call to NBus::TBusModule::Start() +/// handler that should be overridden in the module implementation. Within +/// the callback handler module can perform any computation and access any +/// datastore tables that it needs. The handler can also access any module +/// variables. However, same handler can be called from multiple threads so, +/// it is recommended that handler only access read-only module level variables. + +/// Handler should use NBus::TBusJob::Send() to send messages to other client +/// sessions and it can use NBus::TBusJob::Reply() to send reply to the main +/// job message. When handler is done, it returns the pointer to the next handler to call +/// when all pending messages have cleared. If handler +/// returns pointer to itself the module will reschedule execution of this handler +/// for a later time. This should be done in case NBus::TBusJob::Send() returns +/// error (not MESSAGE_OK) + +#include "startsession.h" + +#include <library/cpp/messagebus/ybus.h> + +#include <util/generic/noncopyable.h> +#include <util/generic/object_counter.h> + +namespace NBus { + class TBusJob; + class TBusModule; + + namespace NPrivate { + struct TCallJobHandlerWorkItem; + struct TBusModuleImpl; + struct TModuleServerHandler; + struct TModuleClientHandler; + struct TJobRunner; + } + + class TJobHandler { + protected: + typedef TJobHandler (TBusModule::*TBusHandlerPtr)(TBusJob* job, TBusMessage* mess); + TBusHandlerPtr MyPtr; + + public: + template <class B> + TJobHandler(TJobHandler (B::*fptr)(TBusJob* job, TBusMessage* mess)) { + MyPtr = static_cast<TBusHandlerPtr>(fptr); + } + TJobHandler(TBusHandlerPtr fptr = nullptr) { + MyPtr = fptr; + } + TJobHandler(const TJobHandler&) = default; + TJobHandler& operator =(const TJobHandler&) = default; + bool operator==(TJobHandler h) const { + return MyPtr == h.MyPtr; + } + bool operator!=(TJobHandler h) const { + return MyPtr != h.MyPtr; + } + bool operator!() const { + return !MyPtr; + } + TJobHandler operator()(TBusModule* b, TBusJob* job, TBusMessage* mess) { + return (b->*MyPtr)(job, mess); + } + }; + + typedef void (TBusModule::*TReplyHandler)(TBusJob* job, EMessageStatus status, TBusMessage* mess, TBusMessage* reply); + + //////////////////////////////////////////////////// + /// \brief Pending message state + + struct TJobState { + friend class TBusJob; + friend class ::TCrawlerModule; + + TReplyHandler Handler; + EMessageStatus Status; + TBusMessage* Message; + TBusMessage* Reply; + TBusClientSession* Session; + size_t NumRetries; + size_t MaxRetries; + // If != NULL then use it as destination. + TNetAddr Addr; + bool UseAddr; + bool OneWay; + + private: + TJobState(TReplyHandler handler, + EMessageStatus status, + TBusMessage* mess, TBusClientSession* session, TBusMessage* reply, size_t maxRetries = 0, + const TNetAddr* addr = nullptr, bool oneWay = false) + : Handler(handler) + , Status(status) + , Message(mess) + , Reply(reply) + , Session(session) + , NumRetries(0) + , MaxRetries(maxRetries) + , OneWay(oneWay) + { + if (!!addr) { + Addr = *addr; + } + UseAddr = !!addr; + } + + public: + TString GetStatus(unsigned flags); + }; + + using TJobStateVec = TVector<TJobState>; + + ///////////////////////////////////////////////////////// + /// \brief Execution item = thread + + /// Maintains internal state of document in computation + class TBusJob { + TObjectCounter<TBusJob> ObjectCounter; + + private: + void CheckThreadCurrentJob(); + + public: + /// given a module and starter message + TBusJob(TBusModule* module, TBusMessage* message); + + /// destructor will free all the message that were send and received + virtual ~TBusJob(); + + TBusMessage* GetMessage() const { + return Message; + } + + TNetAddr GetPeerAddrNetAddr() const; + + /// send message to any other session or application + /// If addr is set then use it as destination. + void Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler, size_t maxRetries, const TNetAddr& addr); + void Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler = nullptr, size_t maxRetries = 0); + + void SendOneWayTo(TBusMessageAutoPtr req, TBusClientSession* session, const TNetAddr& addr); + void SendOneWayWithLocator(TBusMessageAutoPtr req, TBusClientSession* session); + + /// send reply to the starter message + virtual void SendReply(TBusMessageAutoPtr reply); + + /// set the flag to terminate job at the earliest convenience + void Cancel(EMessageStatus status); + + /// helper to put item on finished list of states + /// It should not be a part of public API, + /// so prohibit it for all except current users. + private: + friend class ::TCrawlerModule; + void PutState(const TJobState& state) { + Finished.push_back(state); + } + + public: + /// retrieve all pending messages + void GetPending(TJobStateVec* stateVec) { + Y_ASSERT(stateVec); + *stateVec = Pending; + } + + /// helper function to find state of previously sent messages + template <class MessageType> + TJobState* GetState(int* startFrom = nullptr) { + for (int i = startFrom ? *startFrom : 0; i < int(Finished.size()); i++) { + TJobState* call = &Finished[i]; + if (call->Reply != nullptr && dynamic_cast<MessageType*>(call->Reply)) { + if (startFrom) { + *startFrom = i; + } + return call; + } + if (call->Message != nullptr && dynamic_cast<MessageType*>(call->Message)) { + if (startFrom) { + *startFrom = i; + } + return call; + } + } + return nullptr; + } + + /// helper function to find response for previously sent messages + template <class MessageType> + MessageType* Get(int* startFrom = nullptr) { + for (int i = startFrom ? *startFrom : 0; i < int(Finished.size()); i++) { + TJobState& call = Finished[i]; + if (call.Reply != nullptr && dynamic_cast<MessageType*>(call.Reply)) { + if (startFrom) { + *startFrom = i; + } + return static_cast<MessageType*>(call.Reply); + } + if (call.Message != nullptr && dynamic_cast<MessageType*>(call.Message)) { + if (startFrom) { + *startFrom = i; + } + return static_cast<MessageType*>(call.Message); + } + } + return nullptr; + } + + /// helper function to find status for previously sent message + template <class MessageType> + EMessageStatus GetStatus(int* startFrom = nullptr) { + for (int i = startFrom ? *startFrom : 0; i < int(Finished.size()); i++) { + TJobState& call = Finished[i]; + if (call.Message != nullptr && dynamic_cast<MessageType*>(call.Message)) { + if (startFrom) { + *startFrom = i; + } + return call.Status; + } + } + return MESSAGE_UNKNOWN; + } + + /// helper function to clear state of previosly sent messages + template <class MessageType> + void Clear() { + for (size_t i = 0; i < Finished.size();) { + // `Finished.size() - i` decreases with each iteration + // we either increment i, or remove element from Finished. + TJobState& call = Finished[i]; + if (call.Message != nullptr && dynamic_cast<MessageType*>(call.Message)) { + ClearState(call); + } else { + ++i; + } + } + } + + /// helper function to clear state in order to try again + void ClearState(TJobState& state); + + /// clears all message states + void ClearAllMessageStates(); + + /// returns true if job is done + bool IsDone(); + + /// return human reabable status of this job + virtual TString GetStatus(unsigned flags); + + /// set sleep time for job + void Sleep(int milliSeconds); + + void CallJobHandlerOnly(); + + private: + bool CallJobHandler(); + void DoCallReplyHandler(TJobState&); + /// send out all Pending jobs, failed sends will be migrated to Finished + bool SendPending(); + bool AnyPendingToSend(); + + public: + /// helper to call from OnReply() and OnError() + int CallReplyHandler(EMessageStatus status, TBusMessage* mess, TBusMessage* reply); + + public: + TJobHandler Handler; ///< job handler to be executed within next CallJobHandler() + EMessageStatus Status; ///< set != MESSAGE_OK if job should terminate asap + private: + NPrivate::TJobRunner* Runner; + TBusMessage* Message; + THolder<TBusMessage> MessageHolder; + TOnMessageContext OnMessageContext; // starter + public: + bool ReplySent; + + private: + friend class TBusModule; + friend struct NPrivate::TBusModuleImpl; + friend struct NPrivate::TCallJobHandlerWorkItem; + friend struct NPrivate::TModuleServerHandler; + friend struct NPrivate::TModuleClientHandler; + friend struct NPrivate::TJobRunner; + + TJobStateVec Pending; ///< messages currently outstanding via Send() + TJobStateVec Finished; ///< messages that were replied to + TBusModule* Module; + NPrivate::TBusModuleImpl* ModuleImpl; ///< module which created the job + TBusInstant SleepUntil; ///< time to wakeup, 0 if no sleep + }; + + //////////////////////////////////////////////////////////////////// + /// \brief Classes to implement basic module functionality + + class IJobFactory { + protected: + virtual ~IJobFactory() { + } + + public: + /// job factory method, override to create custom jobs + virtual TBusJob* CreateJobInstance(TBusMessage* message) = 0; + }; + + struct TBusModuleConfig { + unsigned StarterMaxInFlight; + + struct TSecret { + TDuration SchedulePeriod; + + TSecret(); + }; + TSecret Secret; + + TBusModuleConfig(); + }; + + namespace NPrivate { + struct TBusModuleInternal: public TAtomicRefCount<TBusModuleInternal> { + virtual TVector<TBusClientSessionPtr> GetClientSessionsInternal() = 0; + virtual TVector<TBusServerSessionPtr> GetServerSessionsInternal() = 0; + virtual TBusMessageQueue* GetQueue() = 0; + + virtual TString GetNameInternal() = 0; + + virtual TString GetStatusSingleLine() = 0; + + virtual ~TBusModuleInternal() { + } + }; + } + + class TBusModule: public IJobFactory, TNonCopyable { + friend class TBusJob; + + TObjectCounter<TBusModule> ObjectCounter; + + TIntrusivePtr<NPrivate::TBusModuleImpl> Impl; + + public: + /// Each module should have a name which is used as protocol service + TBusModule(const char* name); + ~TBusModule() override; + + const char* GetName() const; + + void SetConfig(const TBusModuleConfig& config); + + /// get status of all jobs in flight + TString GetStatus(unsigned flags = 0); + + /// called when application is about to start + virtual bool StartInput(); + /// called when application is about to exit + virtual bool Shutdown(); + + // this default implementation just creates TBusJob object + TBusJob* CreateJobInstance(TBusMessage* message) override; + + EMessageStatus StartJob(TAutoPtr<TBusMessage> message); + + /// creates private sessions, calls CreateExtSession(), should be called before StartInput() + bool CreatePrivateSessions(TBusMessageQueue* queue); + + virtual void OnClientConnectionEvent(const TClientConnectionEvent& event); + + public: + /// entry point into module, first function to call + virtual TJobHandler Start(TBusJob* job, TBusMessage* mess) = 0; + + protected: + /// override this function to create destination session + virtual TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) = 0; + + public: + int GetModuleSessionInFlight() const; + + TIntrusivePtr<NPrivate::TBusModuleInternal> GetInternal(); + + protected: + TBusServerSessionPtr CreateDefaultDestination(TBusMessageQueue& queue, TBusProtocol* proto, const TBusServerSessionConfig& config, const TString& name = TString()); + TBusClientSessionPtr CreateDefaultSource(TBusMessageQueue& queue, TBusProtocol* proto, const TBusClientSessionConfig& config, const TString& name = TString()); + TBusStarter* CreateDefaultStarter(TBusMessageQueue& unused, const TBusSessionConfig& config); + }; + +} diff --git a/library/cpp/messagebus/oldmodule/startsession.cpp b/library/cpp/messagebus/oldmodule/startsession.cpp new file mode 100644 index 0000000000..7c38801d62 --- /dev/null +++ b/library/cpp/messagebus/oldmodule/startsession.cpp @@ -0,0 +1,65 @@ +/////////////////////////////////////////////////////////// +/// \file +/// \brief Starter session implementation + +/// Starter session will generate emtpy message to insert +/// into local session that are registered under same protocol + +/// Starter (will one day) automatically adjust number +/// of message inflight to make sure that at least one of source +/// sessions within message queue is at the limit (bottle neck) + +/// Maximum number of messages that starter will instert into +/// the pipeline is configured by NBus::TBusSessionConfig::MaxInFlight + +#include "startsession.h" + +#include "module.h" + +#include <library/cpp/messagebus/ybus.h> + +namespace NBus { + void* TBusStarter::_starter(void* data) { + TBusStarter* pThis = static_cast<TBusStarter*>(data); + pThis->Starter(); + return nullptr; + } + + TBusStarter::TBusStarter(TBusModule* module, const TBusSessionConfig& config) + : Module(module) + , Config(config) + , StartThread(_starter, this) + , Exiting(false) + { + StartThread.Start(); + } + + TBusStarter::~TBusStarter() { + Shutdown(); + } + + void TBusStarter::Shutdown() { + { + TGuard<TMutex> g(ExitLock); + Exiting = true; + ExitSignal.Signal(); + } + StartThread.Join(); + } + + void TBusStarter::Starter() { + TGuard<TMutex> g(ExitLock); + while (!Exiting) { + TAutoPtr<TBusMessage> empty(new TBusMessage(0)); + + EMessageStatus status = Module->StartJob(empty); + + if (Config.SendTimeout > 0) { + ExitSignal.WaitT(ExitLock, TDuration::MilliSeconds(Config.SendTimeout)); + } else { + ExitSignal.WaitT(ExitLock, (status == MESSAGE_BUSY) ? TDuration::MilliSeconds(1) : TDuration::Zero()); + } + } + } + +} diff --git a/library/cpp/messagebus/oldmodule/startsession.h b/library/cpp/messagebus/oldmodule/startsession.h new file mode 100644 index 0000000000..5e26e7e1e5 --- /dev/null +++ b/library/cpp/messagebus/oldmodule/startsession.h @@ -0,0 +1,34 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> + +#include <util/system/thread.h> + +namespace NBus { + class TBusModule; + + class TBusStarter { + private: + TBusModule* Module; + TBusSessionConfig Config; + TThread StartThread; + bool Exiting; + TCondVar ExitSignal; + TMutex ExitLock; + + static void* _starter(void* data); + + void Starter(); + + TString GetStatus(ui16 /*flags=YBUS_STATUS_CONNS*/) { + return ""; + } + + public: + TBusStarter(TBusModule* module, const TBusSessionConfig& config); + ~TBusStarter(); + + void Shutdown(); + }; + +} diff --git a/library/cpp/messagebus/oldmodule/ya.make b/library/cpp/messagebus/oldmodule/ya.make new file mode 100644 index 0000000000..ca5eae74f0 --- /dev/null +++ b/library/cpp/messagebus/oldmodule/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus + library/cpp/messagebus/actor +) + +SRCS( + module.cpp + startsession.cpp +) + +END() diff --git a/library/cpp/messagebus/protobuf/ya.make b/library/cpp/messagebus/protobuf/ya.make new file mode 100644 index 0000000000..64ff240b51 --- /dev/null +++ b/library/cpp/messagebus/protobuf/ya.make @@ -0,0 +1,15 @@ +LIBRARY(messagebus_protobuf) + +OWNER(g:messagebus) + +SRCS( + ybusbuf.cpp +) + +PEERDIR( + contrib/libs/protobuf + library/cpp/messagebus + library/cpp/messagebus/actor +) + +END() diff --git a/library/cpp/messagebus/protobuf/ybusbuf.cpp b/library/cpp/messagebus/protobuf/ybusbuf.cpp new file mode 100644 index 0000000000..63415b3737 --- /dev/null +++ b/library/cpp/messagebus/protobuf/ybusbuf.cpp @@ -0,0 +1,88 @@ +#include "ybusbuf.h" + +#include <library/cpp/messagebus/actor/what_thread_does.h> + +#include <google/protobuf/io/coded_stream.h> + +using namespace NBus; + +TBusBufferProtocol::TBusBufferProtocol(TBusService name, int port) + : TBusProtocol(name, port) +{ +} + +TBusBufferProtocol::~TBusBufferProtocol() { + for (auto& type : Types) { + delete type; + } +} + +TBusBufferBase* TBusBufferProtocol::FindType(int type) { + for (unsigned i = 0; i < Types.size(); i++) { + if (Types[i]->GetHeader()->Type == type) { + return Types[i]; + } + } + return nullptr; +} + +bool TBusBufferProtocol::IsRegisteredType(unsigned type) { + return TypeMask[type >> 5] & (1 << (type & ((1 << 5) - 1))); +} + +void TBusBufferProtocol::RegisterType(TAutoPtr<TBusBufferBase> mess) { + ui32 type = mess->GetHeader()->Type; + TypeMask[type >> 5] |= 1 << (type & ((1 << 5) - 1)); + + Types.push_back(mess.Release()); +} + +TArrayRef<TBusBufferBase* const> TBusBufferProtocol::GetTypes() const { + return Types; +} + +void TBusBufferProtocol::Serialize(const TBusMessage* mess, TBuffer& data) { + TWhatThreadDoesPushPop pp("serialize protobuf message"); + + const TBusHeader* header = mess->GetHeader(); + + if (!IsRegisteredType(header->Type)) { + Y_FAIL("unknown message type: %d", int(header->Type)); + return; + } + + // cast the base from real message + const TBusBufferBase* bmess = CheckedCast<const TBusBufferBase*>(mess); + + unsigned size = bmess->GetRecord()->ByteSize(); + data.Reserve(data.Size() + size); + + char* after = (char*)bmess->GetRecord()->SerializeWithCachedSizesToArray((ui8*)data.Pos()); + Y_VERIFY(after - data.Pos() == size); + + data.Advance(size); +} + +TAutoPtr<TBusMessage> TBusBufferProtocol::Deserialize(ui16 messageType, TArrayRef<const char> payload) { + TWhatThreadDoesPushPop pp("deserialize protobuf message"); + + TBusBufferBase* messageTemplate = FindType(messageType); + if (messageTemplate == nullptr) { + return nullptr; + //Y_FAIL("unknown message type: %d", unsigned(messageType)); + } + + // clone the base + TAutoPtr<TBusBufferBase> bmess = messageTemplate->New(); + + // Need to override protobuf message size limit + // NOTE: the payload size has already been checked against session MaxMessageSize + google::protobuf::io::CodedInputStream input(reinterpret_cast<const ui8*>(payload.data()), payload.size()); + input.SetTotalBytesLimit(payload.size()); + + bool ok = bmess->GetRecord()->ParseFromCodedStream(&input) && input.ConsumedEntireMessage(); + if (!ok) { + return nullptr; + } + return bmess.Release(); +} diff --git a/library/cpp/messagebus/protobuf/ybusbuf.h b/library/cpp/messagebus/protobuf/ybusbuf.h new file mode 100644 index 0000000000..57b4267ea5 --- /dev/null +++ b/library/cpp/messagebus/protobuf/ybusbuf.h @@ -0,0 +1,233 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> + +#include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> + +#include <util/generic/cast.h> +#include <util/generic/vector.h> +#include <util/stream/mem.h> + +#include <array> + +namespace NBus { + using TBusBufferRecord = ::google::protobuf::Message; + + template <class TBufferMessage> + class TBusBufferMessagePtr; + template <class TBufferMessage> + class TBusBufferMessageAutoPtr; + + class TBusBufferBase: public TBusMessage { + public: + TBusBufferBase(int type) + : TBusMessage((ui16)type) + { + } + TBusBufferBase(ECreateUninitialized) + : TBusMessage(MESSAGE_CREATE_UNINITIALIZED) + { + } + + ui16 GetType() const { + return GetHeader()->Type; + } + + virtual TBusBufferRecord* GetRecord() const = 0; + virtual TBusBufferBase* New() = 0; + }; + + /////////////////////////////////////////////////////////////////// + /// \brief Template for all messages that have protobuf description + + /// @param TBufferRecord is record described in .proto file with namespace + /// @param MessageFile is offset for .proto file message ids + + /// \attention If you want one protocol NBus::TBusBufferProtocol to handle + /// messageges described in different .proto files, make sure that they have + /// unique values for MessageFile + + template <class TBufferRecord, int MType> + class TBusBufferMessage: public TBusBufferBase { + public: + static const int MessageType = MType; + + typedef TBusBufferMessagePtr<TBusBufferMessage<TBufferRecord, MType>> TPtr; + typedef TBusBufferMessageAutoPtr<TBusBufferMessage<TBufferRecord, MType>> TAutoPtr; + + public: + typedef TBufferRecord RecordType; + TBufferRecord Record; + + public: + TBusBufferMessage() + : TBusBufferBase(MessageType) + { + } + TBusBufferMessage(ECreateUninitialized) + : TBusBufferBase(MESSAGE_CREATE_UNINITIALIZED) + { + } + explicit TBusBufferMessage(const TBufferRecord& record) + : TBusBufferBase(MessageType) + , Record(record) + { + } + explicit TBusBufferMessage(TBufferRecord&& record) + : TBusBufferBase(MessageType) + , Record(std::move(record)) + { + } + + public: + TBusBufferRecord* GetRecord() const override { + return (TBusBufferRecord*)&Record; + } + TBusBufferBase* New() override { + return new TBusBufferMessage<TBufferRecord, MessageType>(); + } + }; + + template <class TSelf, class TBufferMessage> + class TBusBufferMessagePtrBase { + public: + typedef typename TBufferMessage::RecordType RecordType; + + private: + TSelf* GetSelf() { + return static_cast<TSelf*>(this); + } + const TSelf* GetSelf() const { + return static_cast<const TSelf*>(this); + } + + public: + RecordType* operator->() { + Y_ASSERT(GetSelf()->Get()); + return &(GetSelf()->Get()->Record); + } + const RecordType* operator->() const { + Y_ASSERT(GetSelf()->Get()); + return &(GetSelf()->Get()->Record); + } + RecordType& operator*() { + Y_ASSERT(GetSelf()->Get()); + return GetSelf()->Get()->Record; + } + const RecordType& operator*() const { + Y_ASSERT(GetSelf()->Get()); + return GetSelf()->Get()->Record; + } + + TBusHeader* GetHeader() { + return GetSelf()->Get()->GetHeader(); + } + const TBusHeader* GetHeader() const { + return GetSelf()->Get()->GetHeader(); + } + }; + + template <class TBufferMessage> + class TBusBufferMessagePtr: public TBusBufferMessagePtrBase<TBusBufferMessagePtr<TBufferMessage>, TBufferMessage> { + protected: + TBufferMessage* Holder; + + public: + TBusBufferMessagePtr(TBufferMessage* mess) + : Holder(mess) + { + } + static TBusBufferMessagePtr<TBufferMessage> DynamicCast(TBusMessage* message) { + return dynamic_cast<TBufferMessage*>(message); + } + TBufferMessage* Get() { + return Holder; + } + const TBufferMessage* Get() const { + return Holder; + } + + operator TBufferMessage*() { + return Holder; + } + operator const TBufferMessage*() const { + return Holder; + } + + operator TAutoPtr<TBusMessage>() { + TAutoPtr<TBusMessage> r(Holder); + Holder = 0; + return r; + } + operator TBusMessageAutoPtr() { + TBusMessageAutoPtr r(Holder); + Holder = nullptr; + return r; + } + }; + + template <class TBufferMessage> + class TBusBufferMessageAutoPtr: public TBusBufferMessagePtrBase<TBusBufferMessageAutoPtr<TBufferMessage>, TBufferMessage> { + public: + TAutoPtr<TBufferMessage> AutoPtr; + + public: + TBusBufferMessageAutoPtr() { + } + TBusBufferMessageAutoPtr(TBufferMessage* message) + : AutoPtr(message) + { + } + + TBufferMessage* Get() { + return AutoPtr.Get(); + } + const TBufferMessage* Get() const { + return AutoPtr.Get(); + } + + TBufferMessage* Release() const { + return AutoPtr.Release(); + } + + operator TAutoPtr<TBusMessage>() { + return AutoPtr.Release(); + } + operator TBusMessageAutoPtr() { + return AutoPtr.Release(); + } + }; + + ///////////////////////////////////////////// + /// \brief Generic protocol object for messages descibed with protobuf + + /// \attention If you mix messages in the same protocol from more than + /// .proto file make sure that they have different MessageFile parameter + /// in the NBus::TBusBufferMessage template + + class TBusBufferProtocol: public TBusProtocol { + private: + TVector<TBusBufferBase*> Types; + std::array<ui32, ((1 << 16) >> 5)> TypeMask; + + TBusBufferBase* FindType(int type); + bool IsRegisteredType(unsigned type); + + public: + TBusBufferProtocol(TBusService name, int port); + + ~TBusBufferProtocol() override; + + /// register all the message that this protocol should handle + void RegisterType(TAutoPtr<TBusBufferBase> mess); + + TArrayRef<TBusBufferBase* const> GetTypes() const; + + /// serialized protocol specific data into TBusData + void Serialize(const TBusMessage* mess, TBuffer& data) override; + + TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override; + }; + +} diff --git a/library/cpp/messagebus/queue_config.cpp b/library/cpp/messagebus/queue_config.cpp new file mode 100644 index 0000000000..78fb52ee49 --- /dev/null +++ b/library/cpp/messagebus/queue_config.cpp @@ -0,0 +1,22 @@ +#include "queue_config.h" + +using namespace NBus; + +TBusQueueConfig::TBusQueueConfig() { + // workers and listeners configuratioin + NumWorkers = 1; +} + +void TBusQueueConfig::ConfigureLastGetopt( + NLastGetopt::TOpts& opts, const TString& prefix) { + opts.AddLongOption(prefix + "worker-count") + .RequiredArgument("COUNT") + .DefaultValue(ToString(NumWorkers)) + .StoreResult(&NumWorkers); +} + +TString TBusQueueConfig::PrintToString() const { + TStringStream ss; + ss << "NumWorkers=" << NumWorkers << "\n"; + return ss.Str(); +} diff --git a/library/cpp/messagebus/queue_config.h b/library/cpp/messagebus/queue_config.h new file mode 100644 index 0000000000..a9955f0c70 --- /dev/null +++ b/library/cpp/messagebus/queue_config.h @@ -0,0 +1,19 @@ +#pragma once + +#include <library/cpp/getopt/last_getopt.h> + +namespace NBus { + ////////////////////////////////////////////////////////////////// + /// \brief Configuration for message queue + struct TBusQueueConfig { + TString Name; + int NumWorkers; ///< number of threads calling OnMessage(), OnReply() handlers + + TBusQueueConfig(); ///< initializes with default settings + + void ConfigureLastGetopt(NLastGetopt::TOpts&, const TString& prefix = "mb-"); + + TString PrintToString() const; + }; + +} diff --git a/library/cpp/messagebus/rain_check/core/coro.cpp b/library/cpp/messagebus/rain_check/core/coro.cpp new file mode 100644 index 0000000000..500841dd5b --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/coro.cpp @@ -0,0 +1,60 @@ +#include "coro.h" + +#include "coro_stack.h" + +#include <util/system/tls.h> +#include <util/system/yassert.h> + +using namespace NRainCheck; + +TContClosure TCoroTaskRunner::ContClosure(TCoroTaskRunner* runner, TArrayRef<char> memRegion) { + TContClosure contClosure; + contClosure.TrampoLine = runner; + contClosure.Stack = memRegion; + return contClosure; +} + +TCoroTaskRunner::TCoroTaskRunner(IEnv* env, ISubtaskListener* parent, TAutoPtr<ICoroTask> impl) + : TTaskRunnerBase(env, parent, impl.Release()) + , Stack(GetImpl()->StackSize) + , ContMachineContext(ContClosure(this, Stack.MemRegion())) + , CoroDone(false) +{ +} + +TCoroTaskRunner::~TCoroTaskRunner() { + Y_ASSERT(CoroDone); +} + +Y_POD_STATIC_THREAD(TContMachineContext*) +CallerContext; +Y_POD_STATIC_THREAD(TCoroTaskRunner*) +Task; + +bool TCoroTaskRunner::ReplyReceived() { + Y_ASSERT(!CoroDone); + + TContMachineContext me; + + CallerContext = &me; + Task = this; + + me.SwitchTo(&ContMachineContext); + + Stack.VerifyNoStackOverflow(); + + Y_ASSERT(CallerContext == &me); + Y_ASSERT(Task == this); + + return !CoroDone; +} + +void NRainCheck::TCoroTaskRunner::DoRun() { + GetImpl()->Run(); + CoroDone = true; + ContMachineContext.SwitchTo(CallerContext); +} + +void NRainCheck::ICoroTask::WaitForSubtasks() { + Task->ContMachineContext.SwitchTo(CallerContext); +} diff --git a/library/cpp/messagebus/rain_check/core/coro.h b/library/cpp/messagebus/rain_check/core/coro.h new file mode 100644 index 0000000000..95e2a30f9b --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/coro.h @@ -0,0 +1,58 @@ +#pragma once + +#include "coro_stack.h" +#include "task.h" + +#include <util/generic/ptr.h> +#include <util/memory/alloc.h> +#include <util/system/align.h> +#include <util/system/context.h> +#include <util/system/valgrind.h> + +namespace NRainCheck { + class ICoroTask; + + class TCoroTaskRunner: public TTaskRunnerBase, private ITrampoLine { + friend class ICoroTask; + + private: + NPrivate::TCoroStack Stack; + TContMachineContext ContMachineContext; + bool CoroDone; + + public: + TCoroTaskRunner(IEnv* env, ISubtaskListener* parent, TAutoPtr<ICoroTask> impl); + ~TCoroTaskRunner() override; + + private: + static TContClosure ContClosure(TCoroTaskRunner* runner, TArrayRef<char> memRegion); + + bool ReplyReceived() override /* override */; + + void DoRun() override /* override */; + + ICoroTask* GetImpl() { + return (ICoroTask*)GetImplBase(); + } + }; + + class ICoroTask: public ITaskBase { + friend class TCoroTaskRunner; + + private: + size_t StackSize; + + public: + typedef TCoroTaskRunner TTaskRunner; + typedef ICoroTask ITask; + + ICoroTask(size_t stackSize = 0x2000) + : StackSize(stackSize) + { + } + + virtual void Run() = 0; + static void WaitForSubtasks(); + }; + +} diff --git a/library/cpp/messagebus/rain_check/core/coro_stack.cpp b/library/cpp/messagebus/rain_check/core/coro_stack.cpp new file mode 100644 index 0000000000..83b984ca6e --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/coro_stack.cpp @@ -0,0 +1,41 @@ +#include "coro_stack.h" + +#include <util/generic/singleton.h> +#include <util/system/valgrind.h> + +#include <cstdlib> +#include <stdio.h> + +using namespace NRainCheck; +using namespace NRainCheck::NPrivate; + +TCoroStack::TCoroStack(size_t size) + : SizeValue(size) +{ + Y_VERIFY(size % sizeof(ui32) == 0); + Y_VERIFY(size >= 0x1000); + + DataHolder.Reset(malloc(size)); + + // register in valgrind + + *MagicNumberLocation() = MAGIC_NUMBER; + +#if defined(WITH_VALGRIND) + ValgrindStackId = VALGRIND_STACK_REGISTER(Data(), (char*)Data() + Size()); +#endif +} + +TCoroStack::~TCoroStack() { +#if defined(WITH_VALGRIND) + VALGRIND_STACK_DEREGISTER(ValgrindStackId); +#endif + + VerifyNoStackOverflow(); +} + +void TCoroStack::FailStackOverflow() { + static const char message[] = "stack overflow\n"; + fputs(message, stderr); + abort(); +} diff --git a/library/cpp/messagebus/rain_check/core/coro_stack.h b/library/cpp/messagebus/rain_check/core/coro_stack.h new file mode 100644 index 0000000000..2f3520e6e4 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/coro_stack.h @@ -0,0 +1,54 @@ +#pragma once + +#include <util/generic/array_ref.h> +#include <util/generic/ptr.h> +#include <util/system/valgrind.h> + +namespace NRainCheck { + namespace NPrivate { + struct TCoroStack { + THolder<void, TFree> DataHolder; + size_t SizeValue; + +#if defined(WITH_VALGRIND) + size_t ValgrindStackId; +#endif + + TCoroStack(size_t size); + ~TCoroStack(); + + void* Data() { + return DataHolder.Get(); + } + + size_t Size() { + return SizeValue; + } + + TArrayRef<char> MemRegion() { + return TArrayRef((char*)Data(), Size()); + } + + ui32* MagicNumberLocation() { +#if STACK_GROW_DOWN == 1 + return (ui32*)Data(); +#elif STACK_GROW_DOWN == 0 + return ((ui32*)(((char*)Data()) + Size())) - 1; +#else +#error "unknown" +#endif + } + + static void FailStackOverflow(); + + inline void VerifyNoStackOverflow() noexcept { + if (Y_UNLIKELY(*MagicNumberLocation() != MAGIC_NUMBER)) { + FailStackOverflow(); + } + } + + static const ui32 MAGIC_NUMBER = 0xAB4D15FE; + }; + + } +} diff --git a/library/cpp/messagebus/rain_check/core/coro_ut.cpp b/library/cpp/messagebus/rain_check/core/coro_ut.cpp new file mode 100644 index 0000000000..61a33584a5 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/coro_ut.cpp @@ -0,0 +1,106 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "coro.h" +#include "spawn.h" + +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +using namespace NRainCheck; + +Y_UNIT_TEST_SUITE(RainCheckCoro) { + struct TSimpleCoroTask : ICoroTask { + TTestSync* const TestSync; + + TSimpleCoroTask(TTestEnv*, TTestSync* testSync) + : TestSync(testSync) + { + } + + void Run() override { + TestSync->WaitForAndIncrement(0); + } + }; + + Y_UNIT_TEST(Simple) { + TTestSync testSync; + + TTestEnv env; + + TIntrusivePtr<TCoroTaskRunner> task = env.SpawnTask<TSimpleCoroTask>(&testSync); + testSync.WaitForAndIncrement(1); + } + + struct TSleepCoroTask : ICoroTask { + TTestEnv* const Env; + TTestSync* const TestSync; + + TSleepCoroTask(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + TSubtaskCompletion SleepCompletion; + + void Run() override { + Env->SleepService.Sleep(&SleepCompletion, TDuration::MilliSeconds(1)); + WaitForSubtasks(); + TestSync->WaitForAndIncrement(0); + } + }; + + Y_UNIT_TEST(Sleep) { + TTestSync testSync; + + TTestEnv env; + + TIntrusivePtr<TCoroTaskRunner> task = env.SpawnTask<TSleepCoroTask>(&testSync); + + testSync.WaitForAndIncrement(1); + } + + struct TSubtask : ICoroTask { + TTestEnv* const Env; + TTestSync* const TestSync; + + TSubtask(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + void Run() override { + TestSync->CheckAndIncrement(1); + } + }; + + struct TSpawnCoroTask : ICoroTask { + TTestEnv* const Env; + TTestSync* const TestSync; + + TSpawnCoroTask(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + TSubtaskCompletion SubtaskCompletion; + + void Run() override { + TestSync->CheckAndIncrement(0); + SpawnSubtask<TSubtask>(Env, &SubtaskCompletion, TestSync); + WaitForSubtasks(); + TestSync->CheckAndIncrement(2); + } + }; + + Y_UNIT_TEST(Spawn) { + TTestSync testSync; + + TTestEnv env; + + TIntrusivePtr<TCoroTaskRunner> task = env.SpawnTask<TSpawnCoroTask>(&testSync); + + testSync.WaitForAndIncrement(3); + } +} diff --git a/library/cpp/messagebus/rain_check/core/env.cpp b/library/cpp/messagebus/rain_check/core/env.cpp new file mode 100644 index 0000000000..fdc0000dbd --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/env.cpp @@ -0,0 +1,3 @@ +#include "env.h" + +using namespace NRainCheck; diff --git a/library/cpp/messagebus/rain_check/core/env.h b/library/cpp/messagebus/rain_check/core/env.h new file mode 100644 index 0000000000..f6dd7fceb6 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/env.h @@ -0,0 +1,47 @@ +#pragma once + +#include "sleep.h" +#include "spawn.h" + +#include <library/cpp/messagebus/actor/executor.h> + +#include <util/generic/ptr.h> + +namespace NRainCheck { + struct IEnv { + virtual ::NActor::TExecutor* GetExecutor() = 0; + virtual ~IEnv() { + } + }; + + template <typename TSelf> + struct TEnvTemplate: public IEnv { + template <typename TTask, typename TParam> + TIntrusivePtr<typename TTask::TTaskRunner> SpawnTask(TParam param) { + return ::NRainCheck::SpawnTask<TTask, TSelf>((TSelf*)this, param); + } + }; + + template <typename TSelf> + struct TSimpleEnvTemplate: public TEnvTemplate<TSelf> { + ::NActor::TExecutorPtr Executor; + TSleepService SleepService; + + TSimpleEnvTemplate(unsigned threadCount = 0) + : Executor(new ::NActor::TExecutor(threadCount != 0 ? threadCount : 4)) + { + } + + ::NActor::TExecutor* GetExecutor() override { + return Executor.Get(); + } + }; + + struct TSimpleEnv: public TSimpleEnvTemplate<TSimpleEnv> { + TSimpleEnv(unsigned threadCount = 0) + : TSimpleEnvTemplate<TSimpleEnv>(threadCount) + { + } + }; + +} diff --git a/library/cpp/messagebus/rain_check/core/fwd.h b/library/cpp/messagebus/rain_check/core/fwd.h new file mode 100644 index 0000000000..b43ff8c17c --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/fwd.h @@ -0,0 +1,18 @@ +#pragma once + +namespace NRainCheck { + namespace NPrivate { + } + + class ITaskBase; + class ISimpleTask; + class ICoroTask; + + struct ISubtaskListener; + + class TTaskRunnerBase; + + class TSubtaskCompletion; + struct IEnv; + +} diff --git a/library/cpp/messagebus/rain_check/core/rain_check.cpp b/library/cpp/messagebus/rain_check/core/rain_check.cpp new file mode 100644 index 0000000000..2ea1f9e21b --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/rain_check.cpp @@ -0,0 +1 @@ +#include "rain_check.h" diff --git a/library/cpp/messagebus/rain_check/core/rain_check.h b/library/cpp/messagebus/rain_check/core/rain_check.h new file mode 100644 index 0000000000..0f289717a2 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/rain_check.h @@ -0,0 +1,8 @@ +#pragma once + +#include "coro.h" +#include "env.h" +#include "simple.h" +#include "sleep.h" +#include "spawn.h" +#include "task.h" diff --git a/library/cpp/messagebus/rain_check/core/simple.cpp b/library/cpp/messagebus/rain_check/core/simple.cpp new file mode 100644 index 0000000000..70182b2f93 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/simple.cpp @@ -0,0 +1,18 @@ +#include "simple.h" + +using namespace NRainCheck; + +TSimpleTaskRunner::TSimpleTaskRunner(IEnv* env, ISubtaskListener* parentTask, TAutoPtr<ISimpleTask> impl) + : TTaskRunnerBase(env, parentTask, impl.Release()) + , ContinueFunc(&ISimpleTask::Start) +{ +} + +TSimpleTaskRunner::~TSimpleTaskRunner() { + Y_ASSERT(!ContinueFunc); +} + +bool TSimpleTaskRunner::ReplyReceived() { + ContinueFunc = (GetImpl()->*(ContinueFunc.Func))(); + return !!ContinueFunc; +} diff --git a/library/cpp/messagebus/rain_check/core/simple.h b/library/cpp/messagebus/rain_check/core/simple.h new file mode 100644 index 0000000000..20e1bf19f5 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/simple.h @@ -0,0 +1,62 @@ +#pragma once + +#include "task.h" + +namespace NRainCheck { + class ISimpleTask; + + // Function called on continue + class TContinueFunc { + friend class TSimpleTaskRunner; + + typedef TContinueFunc (ISimpleTask::*TFunc)(); + TFunc Func; + + public: + TContinueFunc() + : Func(nullptr) + { + } + + TContinueFunc(void*) + : Func(nullptr) + { + } + + template <typename TTask> + TContinueFunc(TContinueFunc (TTask::*func)()) + : Func((TFunc)func) + { + static_assert((std::is_base_of<ISimpleTask, TTask>::value), "expect (std::is_base_of<ISimpleTask, TTask>::value)"); + } + + bool operator!() const { + return !Func; + } + }; + + class TSimpleTaskRunner: public TTaskRunnerBase { + public: + TSimpleTaskRunner(IEnv* env, ISubtaskListener* parentTask, TAutoPtr<ISimpleTask>); + ~TSimpleTaskRunner() override; + + private: + // Function to be called on completion of all pending tasks. + TContinueFunc ContinueFunc; + + bool ReplyReceived() override /* override */; + + ISimpleTask* GetImpl() { + return (ISimpleTask*)GetImplBase(); + } + }; + + class ISimpleTask: public ITaskBase { + public: + typedef TSimpleTaskRunner TTaskRunner; + typedef ISimpleTask ITask; + + virtual TContinueFunc Start() = 0; + }; + +} diff --git a/library/cpp/messagebus/rain_check/core/simple_ut.cpp b/library/cpp/messagebus/rain_check/core/simple_ut.cpp new file mode 100644 index 0000000000..d4545e05aa --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/simple_ut.cpp @@ -0,0 +1,59 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +#include <library/cpp/messagebus/latch.h> + +#include <util/system/event.h> + +using namespace NRainCheck; + +Y_UNIT_TEST_SUITE(RainCheckSimple) { + struct TTaskWithCompletionCallback: public ISimpleTask { + TTestEnv* const Env; + TTestSync* const TestSync; + + TTaskWithCompletionCallback(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + TSubtaskCompletion SleepCompletion; + + TContinueFunc Start() override { + TestSync->CheckAndIncrement(0); + + Env->SleepService.Sleep(&SleepCompletion, TDuration::MilliSeconds(1)); + SleepCompletion.SetCompletionCallback(&TTaskWithCompletionCallback::SleepCompletionCallback); + + return &TTaskWithCompletionCallback::Last; + } + + void SleepCompletionCallback(TSubtaskCompletion* completion) { + Y_VERIFY(completion == &SleepCompletion); + TestSync->CheckAndIncrement(1); + + Env->SleepService.Sleep(&SleepCompletion, TDuration::MilliSeconds(1)); + SleepCompletion.SetCompletionCallback(&TTaskWithCompletionCallback::NextSleepCompletionCallback); + } + + void NextSleepCompletionCallback(TSubtaskCompletion*) { + TestSync->CheckAndIncrement(2); + } + + TContinueFunc Last() { + TestSync->CheckAndIncrement(3); + return nullptr; + } + }; + + Y_UNIT_TEST(CompletionCallback) { + TTestEnv env; + TTestSync testSync; + + env.SpawnTask<TTaskWithCompletionCallback>(&testSync); + + testSync.WaitForAndIncrement(4); + } +} diff --git a/library/cpp/messagebus/rain_check/core/sleep.cpp b/library/cpp/messagebus/rain_check/core/sleep.cpp new file mode 100644 index 0000000000..f5d0b4cac9 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/sleep.cpp @@ -0,0 +1,47 @@ +#include "rain_check.h" + +#include <util/system/yassert.h> + +using namespace NRainCheck; +using namespace NRainCheck::NPrivate; +using namespace NBus; +using namespace NBus::NPrivate; + +TSleepService::TSleepService(::NBus::NPrivate::TScheduler* scheduler) + : Scheduler(scheduler) +{ +} + +NRainCheck::TSleepService::TSleepService() + : SchedulerHolder(new TScheduler) + , Scheduler(SchedulerHolder.Get()) +{ +} + +NRainCheck::TSleepService::~TSleepService() { + if (!!SchedulerHolder) { + Scheduler->Stop(); + } +} + +namespace { + struct TSleepServiceScheduleItem: public IScheduleItem { + ISubtaskListener* const Parent; + + TSleepServiceScheduleItem(ISubtaskListener* parent, TInstant time) + : IScheduleItem(time) + , Parent(parent) + { + } + + void Do() override { + Parent->SetDone(); + } + }; +} + +void TSleepService::Sleep(TSubtaskCompletion* r, TDuration duration) { + TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask(); + r->SetRunning(current); + Scheduler->Schedule(new TSleepServiceScheduleItem(r, TInstant::Now() + duration)); +} diff --git a/library/cpp/messagebus/rain_check/core/sleep.h b/library/cpp/messagebus/rain_check/core/sleep.h new file mode 100644 index 0000000000..1a7a1f8674 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/sleep.h @@ -0,0 +1,24 @@ +#pragma once + +#include "fwd.h" + +#include <library/cpp/messagebus/scheduler/scheduler.h> + +#include <util/datetime/base.h> + +namespace NRainCheck { + class TSleepService { + private: + THolder< ::NBus::NPrivate::TScheduler> SchedulerHolder; + ::NBus::NPrivate::TScheduler* const Scheduler; + + public: + TSleepService(::NBus::NPrivate::TScheduler*); + TSleepService(); + ~TSleepService(); + + // Wake up a task after given duration. + void Sleep(TSubtaskCompletion* r, TDuration); + }; + +} diff --git a/library/cpp/messagebus/rain_check/core/sleep_ut.cpp b/library/cpp/messagebus/rain_check/core/sleep_ut.cpp new file mode 100644 index 0000000000..2ae85a87b1 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/sleep_ut.cpp @@ -0,0 +1,46 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +#include <util/system/event.h> + +using namespace NRainCheck; +using namespace NActor; + +Y_UNIT_TEST_SUITE(Sleep) { + struct TTestTask: public ISimpleTask { + TSimpleEnv* const Env; + TTestSync* const TestSync; + + TTestTask(TSimpleEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + TSubtaskCompletion Sleep; + + TContinueFunc Start() override { + Env->SleepService.Sleep(&Sleep, TDuration::MilliSeconds(1)); + + TestSync->CheckAndIncrement(0); + + return &TTestTask::Continue; + } + + TContinueFunc Continue() { + TestSync->CheckAndIncrement(1); + return nullptr; + } + }; + + Y_UNIT_TEST(Test) { + TTestSync testSync; + + TSimpleEnv env; + + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TTestTask>(&testSync); + + testSync.WaitForAndIncrement(2); + } +} diff --git a/library/cpp/messagebus/rain_check/core/spawn.cpp b/library/cpp/messagebus/rain_check/core/spawn.cpp new file mode 100644 index 0000000000..c570355fbe --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/spawn.cpp @@ -0,0 +1,5 @@ +#include "spawn.h" + +void NRainCheck::NPrivate::SpawnTaskImpl(TTaskRunnerBase* task) { + task->Schedule(); +} diff --git a/library/cpp/messagebus/rain_check/core/spawn.h b/library/cpp/messagebus/rain_check/core/spawn.h new file mode 100644 index 0000000000..f2b146bf29 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/spawn.h @@ -0,0 +1,50 @@ +#pragma once + +#include "coro.h" +#include "simple.h" +#include "task.h" + +namespace NRainCheck { + namespace NPrivate { + void SpawnTaskImpl(TTaskRunnerBase* task); + + template <typename TTask, typename ITask, typename TRunner, typename TEnv, typename TParam> + TIntrusivePtr<TRunner> SpawnTaskWithRunner(TEnv* env, TParam param1, ISubtaskListener* subtaskListener) { + static_assert((std::is_base_of<ITask, TTask>::value), "expect (std::is_base_of<ITask, TTask>::value)"); + TIntrusivePtr<TRunner> task(new TRunner(env, subtaskListener, new TTask(env, param1))); + NPrivate::SpawnTaskImpl(task.Get()); + return task; + } + + template <typename TTask, typename ITask, typename TRunner, typename TEnv> + void SpawnSubtaskWithRunner(TEnv* env, TSubtaskCompletion* completion) { + static_assert((std::is_base_of<ITask, TTask>::value), "expect (std::is_base_of<ITask, TTask>::value)"); + TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask(); + completion->SetRunning(current); + NPrivate::SpawnTaskImpl(new TRunner(env, completion, new TTask(env))); + } + + template <typename TTask, typename ITask, typename TRunner, typename TEnv, typename TParam> + void SpawnSubtaskWithRunner(TEnv* env, TSubtaskCompletion* completion, TParam param) { + static_assert((std::is_base_of<ITask, TTask>::value), "expect (std::is_base_of<ITask, TTask>::value)"); + TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask(); + completion->SetRunning(current); + NPrivate::SpawnTaskImpl(new TRunner(env, completion, new TTask(env, param))); + } + + } + + // Instantiate and start a task with given parameter. + template <typename TTask, typename TEnv, typename TParam> + TIntrusivePtr<typename TTask::TTaskRunner> SpawnTask(TEnv* env, TParam param1, ISubtaskListener* subtaskListener = &TNopSubtaskListener::Instance) { + return NPrivate::SpawnTaskWithRunner< + TTask, typename TTask::ITask, typename TTask::TTaskRunner, TEnv, TParam>(env, param1, subtaskListener); + } + + // Instantiate and start subtask of given task. + template <typename TTask, typename TEnv, typename TParam> + void SpawnSubtask(TEnv* env, TSubtaskCompletion* completion, TParam param) { + return NPrivate::SpawnSubtaskWithRunner<TTask, typename TTask::ITask, typename TTask::TTaskRunner>(env, completion, param); + } + +} diff --git a/library/cpp/messagebus/rain_check/core/spawn_ut.cpp b/library/cpp/messagebus/rain_check/core/spawn_ut.cpp new file mode 100644 index 0000000000..ba5a5e41cf --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/spawn_ut.cpp @@ -0,0 +1,145 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/rain_check/test/helper/misc.h> +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +#include <library/cpp/messagebus/latch.h> + +#include <util/system/event.h> + +#include <array> + +using namespace NRainCheck; +using namespace NActor; + +Y_UNIT_TEST_SUITE(Spawn) { + struct TTestTask: public ISimpleTask { + TTestSync* const TestSync; + + TTestTask(TSimpleEnv*, TTestSync* testSync) + : TestSync(testSync) + , I(0) + { + } + + TSystemEvent Started; + + unsigned I; + + TContinueFunc Start() override { + if (I < 4) { + I += 1; + return &TTestTask::Start; + } + TestSync->CheckAndIncrement(0); + return &TTestTask::Continue; + } + + TContinueFunc Continue() { + TestSync->CheckAndIncrement(1); + + Started.Signal(); + return nullptr; + } + }; + + Y_UNIT_TEST(Continuation) { + TTestSync testSync; + + TSimpleEnv env; + + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TTestTask>(&testSync); + + testSync.WaitForAndIncrement(2); + } + + struct TSubtask: public ISimpleTask { + TTestEnv* const Env; + TTestSync* const TestSync; + + TSubtask(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + TContinueFunc Start() override { + Sleep(TDuration::MilliSeconds(1)); + TestSync->CheckAndIncrement(1); + return nullptr; + } + }; + + struct TSpawnTask: public ISimpleTask { + TTestEnv* const Env; + TTestSync* const TestSync; + + TSpawnTask(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + { + } + + TSubtaskCompletion SubtaskCompletion; + + TContinueFunc Start() override { + TestSync->CheckAndIncrement(0); + SpawnSubtask<TSubtask>(Env, &SubtaskCompletion, TestSync); + return &TSpawnTask::Continue; + } + + TContinueFunc Continue() { + TestSync->CheckAndIncrement(2); + return nullptr; + } + }; + + Y_UNIT_TEST(Subtask) { + TTestSync testSync; + + TTestEnv env; + + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TSpawnTask>(&testSync); + + testSync.WaitForAndIncrement(3); + } + + struct TSpawnLongTask: public ISimpleTask { + TTestEnv* const Env; + TTestSync* const TestSync; + unsigned I; + + TSpawnLongTask(TTestEnv* env, TTestSync* testSync) + : Env(env) + , TestSync(testSync) + , I(0) + { + } + + std::array<TSubtaskCompletion, 3> Subtasks; + + TContinueFunc Start() override { + if (I == 1000) { + TestSync->CheckAndIncrement(0); + return nullptr; + } + + for (auto& subtask : Subtasks) { + SpawnSubtask<TNopSimpleTask>(Env, &subtask, ""); + } + + ++I; + return &TSpawnLongTask::Start; + } + }; + + Y_UNIT_TEST(SubtaskLong) { + TTestSync testSync; + + TTestEnv env; + + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TSpawnLongTask>(&testSync); + + testSync.WaitForAndIncrement(1); + } +} diff --git a/library/cpp/messagebus/rain_check/core/task.cpp b/library/cpp/messagebus/rain_check/core/task.cpp new file mode 100644 index 0000000000..a098437d53 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/task.cpp @@ -0,0 +1,216 @@ +#include "rain_check.h" + +#include <library/cpp/messagebus/actor/temp_tls_vector.h> + +#include <util/system/type_name.h> +#include <util/system/tls.h> + +using namespace NRainCheck; +using namespace NRainCheck::NPrivate; + +using namespace NActor; + +namespace { + Y_POD_STATIC_THREAD(TTaskRunnerBase*) + ThreadCurrentTask; +} + +void TNopSubtaskListener::SetDone() { +} + +TNopSubtaskListener TNopSubtaskListener::Instance; + +TTaskRunnerBase::TTaskRunnerBase(IEnv* env, ISubtaskListener* parentTask, TAutoPtr<ITaskBase> impl) + : TActor<TTaskRunnerBase>(env->GetExecutor()) + , Impl(impl) + , ParentTask(parentTask) + //, HoldsSelfReference(false) + , Done(false) + , SetDoneCalled(false) +{ +} + +TTaskRunnerBase::~TTaskRunnerBase() { + Y_ASSERT(Done); +} + +namespace { + struct TRunningInThisThreadGuard { + TTaskRunnerBase* const Task; + TRunningInThisThreadGuard(TTaskRunnerBase* task) + : Task(task) + { + Y_ASSERT(!ThreadCurrentTask); + ThreadCurrentTask = task; + } + + ~TRunningInThisThreadGuard() { + Y_ASSERT(ThreadCurrentTask == Task); + ThreadCurrentTask = nullptr; + } + }; +} + +void NRainCheck::TTaskRunnerBase::Act(NActor::TDefaultTag) { + Y_ASSERT(RefCount() > 0); + + TRunningInThisThreadGuard g(this); + + //RetainRef(); + + for (;;) { + TTempTlsVector<TSubtaskCompletion*> temp; + + temp.GetVector()->swap(Pending); + + for (auto& pending : *temp.GetVector()) { + if (pending->IsComplete()) { + pending->FireCompletionCallback(GetImplBase()); + } else { + Pending.push_back(pending); + } + } + + if (!Pending.empty()) { + return; + } + + if (!Done) { + Done = !ReplyReceived(); + } else { + if (Pending.empty()) { + if (!SetDoneCalled) { + ParentTask->SetDone(); + SetDoneCalled = true; + } + //ReleaseRef(); + return; + } + } + } +} + +bool TTaskRunnerBase::IsRunningInThisThread() const { + return ThreadCurrentTask == this; +} + +TSubtaskCompletion::~TSubtaskCompletion() { + ESubtaskState state = State.Get(); + Y_ASSERT(state == CREATED || state == DONE || state == CANCELED); +} + +void TSubtaskCompletion::FireCompletionCallback(ITaskBase* task) { + Y_ASSERT(IsComplete()); + + if (!!CompletionFunc) { + TSubtaskCompletionFunc temp = CompletionFunc; + // completion func must be reset before calling it, + // because function may set it back + CompletionFunc = TSubtaskCompletionFunc(); + (task->*(temp.Func))(this); + } +} + +void NRainCheck::TSubtaskCompletion::Cancel() { + for (;;) { + ESubtaskState state = State.Get(); + if (state == CREATED && State.CompareAndSet(CREATED, CANCELED)) { + return; + } + if (state == RUNNING && State.CompareAndSet(RUNNING, CANCEL_REQUESTED)) { + return; + } + if (state == DONE && State.CompareAndSet(DONE, CANCELED)) { + return; + } + if (state == CANCEL_REQUESTED || state == CANCELED) { + return; + } + } +} + +void TSubtaskCompletion::SetRunning(TTaskRunnerBase* parent) { + Y_ASSERT(!TaskRunner); + Y_ASSERT(!!parent); + + TaskRunner = parent; + + parent->Pending.push_back(this); + + parent->RefV(); + + for (;;) { + ESubtaskState current = State.Get(); + if (current != CREATED && current != DONE) { + Y_FAIL("current state should be CREATED or DONE: %s", ToCString(current)); + } + if (State.CompareAndSet(current, RUNNING)) { + return; + } + } +} + +void TSubtaskCompletion::SetDone() { + Y_ASSERT(!!TaskRunner); + TTaskRunnerBase* temp = TaskRunner; + TaskRunner = nullptr; + + for (;;) { + ESubtaskState state = State.Get(); + if (state == RUNNING) { + if (State.CompareAndSet(RUNNING, DONE)) { + break; + } + } else if (state == CANCEL_REQUESTED) { + if (State.CompareAndSet(CANCEL_REQUESTED, CANCELED)) { + break; + } + } else { + Y_FAIL("cannot SetDone: unknown state: %s", ToCString(state)); + } + } + + temp->ScheduleV(); + temp->UnRefV(); +} + +#if 0 +void NRainCheck::TTaskRunnerBase::RetainRef() +{ + if (HoldsSelfReference) { + return; + } + HoldsSelfReference = true; + Ref(); +} + +void NRainCheck::TTaskRunnerBase::ReleaseRef() +{ + if (!HoldsSelfReference) { + return; + } + HoldsSelfReference = false; + DecRef(); +} +#endif + +void TTaskRunnerBase::AssertInThisThread() const { + Y_ASSERT(IsRunningInThisThread()); +} + +TTaskRunnerBase* TTaskRunnerBase::CurrentTask() { + Y_VERIFY(!!ThreadCurrentTask); + return ThreadCurrentTask; +} + +ITaskBase* TTaskRunnerBase::CurrentTaskImpl() { + return CurrentTask()->GetImplBase(); +} + +TString TTaskRunnerBase::GetStatusSingleLine() { + return TypeName(*Impl); +} + +bool NRainCheck::AreWeInsideTask() { + return ThreadCurrentTask != nullptr; +} diff --git a/library/cpp/messagebus/rain_check/core/task.h b/library/cpp/messagebus/rain_check/core/task.h new file mode 100644 index 0000000000..7d8778bcda --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/task.h @@ -0,0 +1,184 @@ +#pragma once + +#include "fwd.h" + +#include <library/cpp/messagebus/actor/actor.h> +#include <library/cpp/messagebus/misc/atomic_box.h> + +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> +#include <util/thread/lfstack.h> + +namespace NRainCheck { + struct ISubtaskListener { + virtual void SetDone() = 0; + virtual ~ISubtaskListener() { + } + }; + + struct TNopSubtaskListener: public ISubtaskListener { + void SetDone() override; + + static TNopSubtaskListener Instance; + }; + + class TSubtaskCompletionFunc { + friend class TSubtaskCompletion; + + typedef void (ITaskBase::*TFunc)(TSubtaskCompletion*); + TFunc Func; + + public: + TSubtaskCompletionFunc() + : Func(nullptr) + { + } + + TSubtaskCompletionFunc(void*) + : Func(nullptr) + { + } + + template <typename TTask> + TSubtaskCompletionFunc(void (TTask::*func)(TSubtaskCompletion*)) + : Func((TFunc)func) + { + static_assert((std::is_base_of<ITaskBase, TTask>::value), "expect (std::is_base_of<ITaskBase, TTask>::value)"); + } + + bool operator!() const { + return !Func; + } + }; + + template <typename T> + class TTaskFuture; + +#define SUBTASK_STATE_MAP(XX) \ + XX(CREATED, "Initial") \ + XX(RUNNING, "Running") \ + XX(DONE, "Completed") \ + XX(CANCEL_REQUESTED, "Cancel requested, but still executing") \ + XX(CANCELED, "Canceled") \ + /**/ + + enum ESubtaskState { + SUBTASK_STATE_MAP(ENUM_VALUE_GEN_NO_VALUE) + }; + + ENUM_TO_STRING(ESubtaskState, SUBTASK_STATE_MAP) + + class TSubtaskCompletion : TNonCopyable, public ISubtaskListener { + friend struct TTaskAccessor; + + private: + TAtomicBox<ESubtaskState> State; + TTaskRunnerBase* volatile TaskRunner; + TSubtaskCompletionFunc CompletionFunc; + + public: + TSubtaskCompletion() + : State(CREATED) + , TaskRunner() + { + } + ~TSubtaskCompletion() override; + + // Either done or cancel requested or cancelled + bool IsComplete() const { + ESubtaskState state = State.Get(); + switch (state) { + case RUNNING: + return false; + case DONE: + return true; + case CANCEL_REQUESTED: + return false; + case CANCELED: + return true; + case CREATED: + Y_FAIL("not started"); + default: + Y_FAIL("unknown value: %u", (unsigned)state); + } + } + + void FireCompletionCallback(ITaskBase*); + + void SetCompletionCallback(TSubtaskCompletionFunc func) { + CompletionFunc = func; + } + + // Completed, but not cancelled + bool IsDone() const { + return State.Get() == DONE; + } + + // Request cancel by actor + // Does nothing but marks task cancelled, + // and allows proceeding to next callback + void Cancel(); + + // called by service provider implementations + // must not be called by actor + void SetRunning(TTaskRunnerBase* parent); + void SetDone() override; + }; + + // See ISimpleTask, ICoroTask + class TTaskRunnerBase: public TAtomicRefCount<TTaskRunnerBase>, public NActor::TActor<TTaskRunnerBase> { + friend class NActor::TActor<TTaskRunnerBase>; + friend class TContinueFunc; + friend struct TTaskAccessor; + friend class TSubtaskCompletion; + + private: + THolder<ITaskBase> Impl; + + ISubtaskListener* const ParentTask; + // While task is running, it holds extra reference to self. + //bool HoldsSelfReference; + bool Done; + bool SetDoneCalled; + + // Subtasks currently executed. + TVector<TSubtaskCompletion*> Pending; + + void Act(NActor::TDefaultTag); + + public: + // Construct task. Task is not automatically started. + TTaskRunnerBase(IEnv*, ISubtaskListener* parent, TAutoPtr<ITaskBase> impl); + ~TTaskRunnerBase() override; + + bool IsRunningInThisThread() const; + void AssertInThisThread() const; + static TTaskRunnerBase* CurrentTask(); + static ITaskBase* CurrentTaskImpl(); + + TString GetStatusSingleLine(); + + protected: + //void RetainRef(); + //void ReleaseRef(); + ITaskBase* GetImplBase() { + return Impl.Get(); + } + + private: + // true if need to call again + virtual bool ReplyReceived() = 0; + }; + + class ITaskBase { + public: + virtual ~ITaskBase() { + } + }; + + // Check that current method executed inside some task. + bool AreWeInsideTask(); + +} diff --git a/library/cpp/messagebus/rain_check/core/track.cpp b/library/cpp/messagebus/rain_check/core/track.cpp new file mode 100644 index 0000000000..092a51a214 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/track.cpp @@ -0,0 +1,66 @@ +#include "track.h" + +using namespace NRainCheck; +using namespace NRainCheck::NPrivate; + +void TTaskTrackerReceipt::SetDone() { + TaskTracker->GetQueue<TTaskTrackerReceipt*>()->EnqueueAndSchedule(this); +} + +TString TTaskTrackerReceipt::GetStatusSingleLine() { + return Task->GetStatusSingleLine(); +} + +TTaskTracker::TTaskTracker(NActor::TExecutor* executor) + : NActor::TActor<TTaskTracker>(executor) +{ +} + +TTaskTracker::~TTaskTracker() { + Y_ASSERT(Tasks.Empty()); +} + +void TTaskTracker::Shutdown() { + ShutdownFlag.Set(true); + Schedule(); + ShutdownEvent.WaitI(); +} + +void TTaskTracker::ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, ITaskFactory* taskFactory) { + THolder<ITaskFactory> holder(taskFactory); + + THolder<TTaskTrackerReceipt> receipt(new TTaskTrackerReceipt(this)); + receipt->Task = taskFactory->NewTask(receipt.Get()); + + Tasks.PushBack(receipt.Release()); +} + +void TTaskTracker::ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, TTaskTrackerReceipt* receipt) { + Y_ASSERT(!receipt->Empty()); + receipt->Unlink(); + delete receipt; +} + +void TTaskTracker::ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, TAsyncResult<TTaskTrackerStatus>* status) { + TTaskTrackerStatus s; + s.Size = Tasks.Size(); + status->SetResult(s); +} + +void TTaskTracker::Act(NActor::TDefaultTag) { + GetQueue<TAsyncResult<TTaskTrackerStatus>*>()->DequeueAll(); + GetQueue<ITaskFactory*>()->DequeueAll(); + GetQueue<TTaskTrackerReceipt*>()->DequeueAll(); + + if (ShutdownFlag.Get()) { + if (Tasks.Empty()) { + ShutdownEvent.Signal(); + } + } +} + +ui32 TTaskTracker::Size() { + TAsyncResult<TTaskTrackerStatus> r; + GetQueue<TAsyncResult<TTaskTrackerStatus>*>()->EnqueueAndSchedule(&r); + return r.GetResult().Size; +} diff --git a/library/cpp/messagebus/rain_check/core/track.h b/library/cpp/messagebus/rain_check/core/track.h new file mode 100644 index 0000000000..d387de7574 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/track.h @@ -0,0 +1,97 @@ +#pragma once + +#include "spawn.h" +#include "task.h" + +#include <library/cpp/messagebus/async_result.h> +#include <library/cpp/messagebus/actor/queue_in_actor.h> +#include <library/cpp/messagebus/misc/atomic_box.h> + +#include <util/generic/intrlist.h> +#include <util/system/event.h> + +namespace NRainCheck { + class TTaskTracker; + + namespace NPrivate { + struct ITaskFactory { + virtual TIntrusivePtr<TTaskRunnerBase> NewTask(ISubtaskListener*) = 0; + virtual ~ITaskFactory() { + } + }; + + struct TTaskTrackerReceipt: public ISubtaskListener, public TIntrusiveListItem<TTaskTrackerReceipt> { + TTaskTracker* const TaskTracker; + TIntrusivePtr<TTaskRunnerBase> Task; + + TTaskTrackerReceipt(TTaskTracker* taskTracker) + : TaskTracker(taskTracker) + { + } + + void SetDone() override; + + TString GetStatusSingleLine(); + }; + + struct TTaskTrackerStatus { + ui32 Size; + }; + + } + + class TTaskTracker + : public TAtomicRefCount<TTaskTracker>, + public NActor::TActor<TTaskTracker>, + public NActor::TQueueInActor<TTaskTracker, NPrivate::ITaskFactory*>, + public NActor::TQueueInActor<TTaskTracker, NPrivate::TTaskTrackerReceipt*>, + public NActor::TQueueInActor<TTaskTracker, TAsyncResult<NPrivate::TTaskTrackerStatus>*> { + friend struct NPrivate::TTaskTrackerReceipt; + + private: + TAtomicBox<bool> ShutdownFlag; + TSystemEvent ShutdownEvent; + + TIntrusiveList<NPrivate::TTaskTrackerReceipt> Tasks; + + template <typename TItem> + NActor::TQueueInActor<TTaskTracker, TItem>* GetQueue() { + return this; + } + + public: + TTaskTracker(NActor::TExecutor* executor); + ~TTaskTracker() override; + + void Shutdown(); + + void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, NPrivate::ITaskFactory*); + void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, NPrivate::TTaskTrackerReceipt*); + void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, TAsyncResult<NPrivate::TTaskTrackerStatus>*); + + void Act(NActor::TDefaultTag); + + template <typename TTask, typename TEnv, typename TParam> + void Spawn(TEnv* env, TParam param) { + struct TTaskFactory: public NPrivate::ITaskFactory { + TEnv* const Env; + TParam Param; + + TTaskFactory(TEnv* env, TParam param) + : Env(env) + , Param(param) + { + } + + TIntrusivePtr<TTaskRunnerBase> NewTask(ISubtaskListener* subtaskListener) override { + return NRainCheck::SpawnTask<TTask>(Env, Param, subtaskListener).Get(); + } + }; + + GetQueue<NPrivate::ITaskFactory*>()->EnqueueAndSchedule(new TTaskFactory(env, param)); + } + + ui32 Size(); + }; + +} diff --git a/library/cpp/messagebus/rain_check/core/track_ut.cpp b/library/cpp/messagebus/rain_check/core/track_ut.cpp new file mode 100644 index 0000000000..05f7de1319 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/track_ut.cpp @@ -0,0 +1,45 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "track.h" + +#include <library/cpp/messagebus/rain_check/test/helper/misc.h> +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +using namespace NRainCheck; + +Y_UNIT_TEST_SUITE(TaskTracker) { + struct TTaskForTracker: public ISimpleTask { + TTestSync* const TestSync; + + TTaskForTracker(TTestEnv*, TTestSync* testSync) + : TestSync(testSync) + { + } + + TContinueFunc Start() override { + TestSync->WaitForAndIncrement(0); + TestSync->WaitForAndIncrement(2); + return nullptr; + } + }; + + Y_UNIT_TEST(Simple) { + TTestEnv env; + + TIntrusivePtr<TTaskTracker> tracker(new TTaskTracker(env.GetExecutor())); + + TTestSync testSync; + + tracker->Spawn<TTaskForTracker>(&env, &testSync); + + testSync.WaitFor(1); + + UNIT_ASSERT_VALUES_EQUAL(1u, tracker->Size()); + + testSync.CheckAndIncrement(1); + + testSync.WaitForAndIncrement(3); + + tracker->Shutdown(); + } +} diff --git a/library/cpp/messagebus/rain_check/core/ya.make b/library/cpp/messagebus/rain_check/core/ya.make new file mode 100644 index 0000000000..c6fb5640d4 --- /dev/null +++ b/library/cpp/messagebus/rain_check/core/ya.make @@ -0,0 +1,25 @@ +LIBRARY() + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/coroutine/engine + library/cpp/deprecated/enum_codegen + library/cpp/messagebus + library/cpp/messagebus/actor + library/cpp/messagebus/scheduler +) + +SRCS( + coro.cpp + coro_stack.cpp + env.cpp + rain_check.cpp + simple.cpp + sleep.cpp + spawn.cpp + task.cpp + track.cpp +) + +END() diff --git a/library/cpp/messagebus/rain_check/http/client.cpp b/library/cpp/messagebus/rain_check/http/client.cpp new file mode 100644 index 0000000000..5ef5ceeece --- /dev/null +++ b/library/cpp/messagebus/rain_check/http/client.cpp @@ -0,0 +1,154 @@ +#include "client.h" + +#include "http_code_extractor.h" + +#include <library/cpp/http/io/stream.h> +#include <library/cpp/neh/factory.h> +#include <library/cpp/neh/http_common.h> +#include <library/cpp/neh/location.h> +#include <library/cpp/neh/neh.h> + +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> +#include <util/network/socket.h> +#include <util/stream/str.h> + +namespace NRainCheck { + class THttpCallback: public NNeh::IOnRecv { + public: + THttpCallback(NRainCheck::THttpFuture* future) + : Future(future) + { + Y_VERIFY(!!future, "future is NULL"); + } + + void OnRecv(NNeh::THandle& handle) override { + THolder<THttpCallback> self(this); + NNeh::TResponseRef response = handle.Get(); + Future->SetDoneAndSchedule(response); + } + + private: + NRainCheck::THttpFuture* const Future; + }; + + THttpFuture::THttpFuture() + : Task(nullptr) + , ErrorCode(THttpFuture::NoError) + { + } + + THttpFuture::~THttpFuture() { + } + + bool THttpFuture::HasError() const { + return (ErrorCode != THttpFuture::NoError); + } + + THttpFuture::EError THttpFuture::GetErrorCode() const { + return ErrorCode; + } + + TString THttpFuture::GetErrorDescription() const { + return ErrorDescription; + } + + THttpClientService::THttpClientService() + : GetProtocol(NNeh::ProtocolFactory()->Protocol("http")) + , FullProtocol(NNeh::ProtocolFactory()->Protocol("full")) + { + Y_VERIFY(!!GetProtocol, "GET protocol is NULL."); + Y_VERIFY(!!FullProtocol, "POST protocol is NULL."); + } + + THttpClientService::~THttpClientService() { + } + + void THttpClientService::SendPost(TString addr, const TString& data, const THttpHeaders& headers, THttpFuture* future) { + Y_VERIFY(!!future, "future is NULL."); + + TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask(); + future->SetRunning(current); + future->Task = current; + + THolder<THttpCallback> callback(new THttpCallback(future)); + NNeh::TServiceStatRef stat; + try { + NNeh::TMessage msg(addr.replace(0, NNeh::TParsedLocation(addr).Scheme.size(), "post"), data); + TStringStream headersText; + headers.OutTo(&headersText); + NNeh::NHttp::MakeFullRequest(msg, headersText.Str(), TString()); + FullProtocol->ScheduleRequest(msg, callback.Get(), stat); + Y_UNUSED(callback.Release()); + } catch (const TNetworkResolutionError& err) { + future->SetFail(THttpFuture::CantResolveNameError, err.AsStrBuf()); + } catch (const yexception& err) { + future->SetFail(THttpFuture::OtherError, err.AsStrBuf()); + } + } + + void THttpClientService::Send(const TString& request, THttpFuture* future) { + Y_VERIFY(!!future, "future is NULL."); + + TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask(); + future->SetRunning(current); + future->Task = current; + + THolder<THttpCallback> callback(new THttpCallback(future)); + NNeh::TServiceStatRef stat; + try { + GetProtocol->ScheduleRequest(NNeh::TMessage::FromString(request), + callback.Get(), + stat); + Y_UNUSED(callback.Release()); + } catch (const TNetworkResolutionError& err) { + future->SetFail(THttpFuture::CantResolveNameError, err.AsStrBuf()); + } catch (const yexception& err) { + future->SetFail(THttpFuture::OtherError, err.AsStrBuf()); + } + } + + bool THttpFuture::HasHttpCode() const { + return !!HttpCode; + } + + bool THttpFuture::HasResponseBody() const { + return !!Response; + } + + ui32 THttpFuture::GetHttpCode() const { + Y_ASSERT(IsDone()); + Y_ASSERT(HasHttpCode()); + + return static_cast<ui32>(*HttpCode); + } + + TString THttpFuture::GetResponseBody() const { + Y_ASSERT(IsDone()); + Y_ASSERT(HasResponseBody()); + + return Response->Data; + } + + void THttpFuture::SetDoneAndSchedule(TAutoPtr<NNeh::TResponse> response) { + if (!response->IsError()) { + ErrorCode = THttpFuture::NoError; + HttpCode = HttpCodes::HTTP_OK; + } else { + ErrorCode = THttpFuture::BadHttpCodeError; + ErrorDescription = response->GetErrorText(); + + HttpCode = TryGetHttpCodeFromErrorDescription(ErrorDescription); + } + Response.Reset(response); + SetDone(); + } + + void THttpFuture::SetFail(THttpFuture::EError errorCode, const TStringBuf& errorDescription) { + ErrorCode = errorCode; + ErrorDescription = errorDescription; + Response.Destroy(); + SetDone(); + } + +} diff --git a/library/cpp/messagebus/rain_check/http/client.h b/library/cpp/messagebus/rain_check/http/client.h new file mode 100644 index 0000000000..d4199c4c98 --- /dev/null +++ b/library/cpp/messagebus/rain_check/http/client.h @@ -0,0 +1,78 @@ +#pragma once + +#include <library/cpp/messagebus/rain_check/core/task.h> + +#include <library/cpp/http/misc/httpcodes.h> + +#include <util/generic/maybe.h> +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/system/defaults.h> +#include <util/system/yassert.h> + +class THttpHeaders; + +namespace NNeh { + class IProtocol; + struct TResponse; +} + +namespace NRainCheck { + class THttpCallback; + class THttpClientService; + + class THttpFuture: public TSubtaskCompletion { + public: + enum EError { + NoError = 0, + + CantResolveNameError = 1, + BadHttpCodeError = 2, + + OtherError = 100 + }; + + private: + friend class THttpCallback; + friend class THttpClientService; + + public: + THttpFuture(); + ~THttpFuture() override; + + bool HasHttpCode() const; + bool HasResponseBody() const; + + ui32 GetHttpCode() const; + TString GetResponseBody() const; + + bool HasError() const; + EError GetErrorCode() const; + TString GetErrorDescription() const; + + private: + void SetDoneAndSchedule(TAutoPtr<NNeh::TResponse> response); + void SetFail(EError errorCode, const TStringBuf& errorDescription); + + private: + TTaskRunnerBase* Task; + TMaybe<HttpCodes> HttpCode; + THolder<NNeh::TResponse> Response; + EError ErrorCode; + TString ErrorDescription; + }; + + class THttpClientService { + public: + THttpClientService(); + virtual ~THttpClientService(); + + void Send(const TString& request, THttpFuture* future); + void SendPost(TString addr, const TString& data, const THttpHeaders& headers, THttpFuture* future); + + private: + NNeh::IProtocol* const GetProtocol; + NNeh::IProtocol* const FullProtocol; + }; + +} diff --git a/library/cpp/messagebus/rain_check/http/client_ut.cpp b/library/cpp/messagebus/rain_check/http/client_ut.cpp new file mode 100644 index 0000000000..1628114391 --- /dev/null +++ b/library/cpp/messagebus/rain_check/http/client_ut.cpp @@ -0,0 +1,205 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "client.h" +#include "http_code_extractor.h" + +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +#include <library/cpp/messagebus/test/helper/fixed_port.h> + +#include <library/cpp/http/io/stream.h> +#include <library/cpp/neh/rpc.h> + +#include <util/generic/cast.h> +#include <util/generic/ptr.h> +#include <util/generic/strbuf.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/network/ip.h> +#include <util/stream/str.h> +#include <util/string/printf.h> +#include <util/system/defaults.h> +#include <util/system/yassert.h> + +#include <cstdlib> +#include <utility> + +using namespace NRainCheck; +using namespace NBus::NTest; + +namespace { + class THttpClientEnv: public TTestEnvTemplate<THttpClientEnv> { + public: + THttpClientService HttpClientService; + }; + + const TString TEST_SERVICE = "test-service"; + const TString TEST_GET_PARAMS = "p=GET"; + const TString TEST_POST_PARAMS = "p=POST"; + const TString TEST_POST_HEADERS = "Content-Type: application/json\r\n"; + const TString TEST_GET_RECV = "GET was ok."; + const TString TEST_POST_RECV = "POST was ok."; + + TString BuildServiceLocation(ui32 port) { + return Sprintf("http://*:%" PRIu32 "/%s", port, TEST_SERVICE.data()); + } + + TString BuildPostServiceLocation(ui32 port) { + return Sprintf("post://*:%" PRIu32 "/%s", port + 1, TEST_SERVICE.data()); + } + + TString BuildGetTestRequest(ui32 port) { + return BuildServiceLocation(port) + "?" + TEST_GET_PARAMS; + } + + class TSimpleServer { + public: + inline void ServeRequest(const NNeh::IRequestRef& req) { + NNeh::TData response; + if (req->Data() == TEST_GET_PARAMS) { + response.assign(TEST_GET_RECV.begin(), TEST_GET_RECV.end()); + } else { + response.assign(TEST_POST_RECV.begin(), TEST_POST_RECV.end()); + } + req->SendReply(response); + } + }; + + NNeh::IServicesRef RunServer(ui32 port, TSimpleServer& server) { + NNeh::IServicesRef runner = NNeh::CreateLoop(); + runner->Add(BuildServiceLocation(port), server); + runner->Add(BuildPostServiceLocation(port), server); + + try { + const int THR_POOL_SIZE = 2; + runner->ForkLoop(THR_POOL_SIZE); + } catch (...) { + Y_FAIL("Can't run server: %s", CurrentExceptionMessage().data()); + } + + return runner; + } + enum ERequestType { + RT_HTTP_GET = 0, + RT_HTTP_POST = 1 + }; + + using TTaskParam = std::pair<TIpPort, ERequestType>; + + class THttpClientTask: public ISimpleTask { + public: + THttpClientTask(THttpClientEnv* env, TTaskParam param) + : Env(env) + , ServerPort(param.first) + , ReqType(param.second) + { + } + + TContinueFunc Start() override { + switch (ReqType) { + case RT_HTTP_GET: { + TString getRequest = BuildGetTestRequest(ServerPort); + for (size_t i = 0; i < 3; ++i) { + Requests.push_back(new THttpFuture()); + Env->HttpClientService.Send(getRequest, Requests[i].Get()); + } + break; + } + case RT_HTTP_POST: { + TString servicePath = BuildPostServiceLocation(ServerPort); + TStringInput headersText(TEST_POST_HEADERS); + THttpHeaders headers(&headersText); + for (size_t i = 0; i < 3; ++i) { + Requests.push_back(new THttpFuture()); + Env->HttpClientService.SendPost(servicePath, TEST_POST_PARAMS, headers, Requests[i].Get()); + } + break; + } + } + + return &THttpClientTask::GotReplies; + } + + TContinueFunc GotReplies() { + const TString& TEST_OK_RECV = (ReqType == RT_HTTP_GET) ? TEST_GET_RECV : TEST_POST_RECV; + for (size_t i = 0; i < Requests.size(); ++i) { + UNIT_ASSERT_EQUAL(Requests[i]->GetHttpCode(), 200); + UNIT_ASSERT_EQUAL(Requests[i]->GetResponseBody(), TEST_OK_RECV); + } + + Env->TestSync.CheckAndIncrement(0); + + return nullptr; + } + + THttpClientEnv* const Env; + const TIpPort ServerPort; + const ERequestType ReqType; + + TVector<TSimpleSharedPtr<THttpFuture>> Requests; + }; + +} // anonymous namespace + +Y_UNIT_TEST_SUITE(RainCheckHttpClient) { + static const TIpPort SERVER_PORT = 4000; + + Y_UNIT_TEST(Simple) { + // TODO: randomize port + if (!IsFixedPortTestAllowed()) { + return; + } + + TSimpleServer server; + NNeh::IServicesRef runner = RunServer(SERVER_PORT, server); + + THttpClientEnv env; + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<THttpClientTask>(TTaskParam(SERVER_PORT, RT_HTTP_GET)); + + env.TestSync.WaitForAndIncrement(1); + } + + Y_UNIT_TEST(SimplePost) { + // TODO: randomize port + if (!IsFixedPortTestAllowed()) { + return; + } + + TSimpleServer server; + NNeh::IServicesRef runner = RunServer(SERVER_PORT, server); + + THttpClientEnv env; + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<THttpClientTask>(TTaskParam(SERVER_PORT, RT_HTTP_POST)); + + env.TestSync.WaitForAndIncrement(1); + } + + Y_UNIT_TEST(HttpCodeExtraction) { + // Find "request failed(" string, then copy len("HTTP/1.X NNN") chars and try to convert NNN to HTTP code. + +#define CHECK_VALID_LINE(line, code) \ + UNIT_ASSERT_NO_EXCEPTION(TryGetHttpCodeFromErrorDescription(line)); \ + UNIT_ASSERT(!!TryGetHttpCodeFromErrorDescription(line)); \ + UNIT_ASSERT_EQUAL(*TryGetHttpCodeFromErrorDescription(line), code) + + CHECK_VALID_LINE(TStringBuf("library/cpp/neh/http.cpp:<LINE>: request failed(HTTP/1.0 200 Some random message"), 200); + CHECK_VALID_LINE(TStringBuf("library/cpp/neh/http.cpp:<LINE>: request failed(HTTP/1.0 404 Some random message"), 404); + CHECK_VALID_LINE(TStringBuf("request failed(HTTP/1.0 100 Some random message"), 100); + CHECK_VALID_LINE(TStringBuf("request failed(HTTP/1.0 105)"), 105); + CHECK_VALID_LINE(TStringBuf("request failed(HTTP/1.1 2004 Some random message"), 200); +#undef CHECK_VALID_LINE + +#define CHECK_INVALID_LINE(line) \ + UNIT_ASSERT_NO_EXCEPTION(TryGetHttpCodeFromErrorDescription(line)); \ + UNIT_ASSERT(!TryGetHttpCodeFromErrorDescription(line)) + + CHECK_INVALID_LINE(TStringBuf("library/cpp/neh/http.cpp:<LINE>: request failed(HTTP/1.1 1 Some random message")); + CHECK_INVALID_LINE(TStringBuf("request failed(HTTP/1.0 asdf Some random message")); + CHECK_INVALID_LINE(TStringBuf("HTTP/1.0 200 Some random message")); + CHECK_INVALID_LINE(TStringBuf("request failed(HTTP/1.0 2x00 Some random message")); + CHECK_INVALID_LINE(TStringBuf("HTTP/1.0 200 Some random message")); + CHECK_INVALID_LINE(TStringBuf("HTTP/1.0 200")); + CHECK_INVALID_LINE(TStringBuf("request failed(HTTP/1.1 3334 Some random message")); +#undef CHECK_INVALID_LINE + } +} diff --git a/library/cpp/messagebus/rain_check/http/http_code_extractor.cpp b/library/cpp/messagebus/rain_check/http/http_code_extractor.cpp new file mode 100644 index 0000000000..51d75762f6 --- /dev/null +++ b/library/cpp/messagebus/rain_check/http/http_code_extractor.cpp @@ -0,0 +1,39 @@ +#include "http_code_extractor.h" + +#include <library/cpp/http/io/stream.h> +#include <library/cpp/http/misc/httpcodes.h> + +#include <util/generic/maybe.h> +#include <util/generic/strbuf.h> +#include <util/string/cast.h> + +namespace NRainCheck { + TMaybe<HttpCodes> TryGetHttpCodeFromErrorDescription(const TStringBuf& errorMessage) { + // Try to get HttpCode from library/cpp/neh response. + // If response has HttpCode and it is not 200 OK, library/cpp/neh will send a message + // "library/cpp/neh/http.cpp:<LINE>: request failed(<FIRST-HTTP-RESPONSE-LINE>)" + // (see library/cpp/neh/http.cpp:625). So, we will try to parse this message and + // find out HttpCode in it. It is bad temporary solution, but we have no choice. + const TStringBuf SUBSTR = "request failed("; + const size_t SUBSTR_LEN = SUBSTR.size(); + const size_t FIRST_LINE_LEN = TStringBuf("HTTP/1.X NNN").size(); + + TMaybe<HttpCodes> httpCode; + + const size_t substrPos = errorMessage.find(SUBSTR); + if (substrPos != TStringBuf::npos) { + const TStringBuf firstLineStart = errorMessage.SubStr(substrPos + SUBSTR_LEN, FIRST_LINE_LEN); + try { + httpCode = static_cast<HttpCodes>(ParseHttpRetCode(firstLineStart)); + if (*httpCode < HTTP_CONTINUE || *httpCode >= HTTP_CODE_MAX) { + httpCode = Nothing(); + } + } catch (const TFromStringException& ex) { + // Can't parse HttpCode: it is OK, because ErrorDescription can be random string. + } + } + + return httpCode; + } + +} diff --git a/library/cpp/messagebus/rain_check/http/http_code_extractor.h b/library/cpp/messagebus/rain_check/http/http_code_extractor.h new file mode 100644 index 0000000000..33b565fa1c --- /dev/null +++ b/library/cpp/messagebus/rain_check/http/http_code_extractor.h @@ -0,0 +1,16 @@ +#pragma once + +#include <library/cpp/http/misc/httpcodes.h> + +#include <util/generic/maybe.h> +#include <util/generic/strbuf.h> + +namespace NRainCheck { + // Try to get HttpCode from library/cpp/neh response. + // If response has HttpCode and it is not 200 OK, library/cpp/neh will send a message + // "library/cpp/neh/http.cpp:<LINE>: request failed(<FIRST-HTTP-RESPONSE-LINE>)" + // (see library/cpp/neh/http.cpp:625). So, we will try to parse this message and + // find out HttpCode in it. It is bad temporary solution, but we have no choice. + TMaybe<HttpCodes> TryGetHttpCodeFromErrorDescription(const TStringBuf& errorMessage); + +} diff --git a/library/cpp/messagebus/rain_check/http/ya.make b/library/cpp/messagebus/rain_check/http/ya.make new file mode 100644 index 0000000000..ef13329df3 --- /dev/null +++ b/library/cpp/messagebus/rain_check/http/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +OWNER(g:messagebus) + +SRCS( + client.cpp + http_code_extractor.cpp +) + +PEERDIR( + library/cpp/messagebus/rain_check/core + library/cpp/neh + library/cpp/http/misc + library/cpp/http/io +) + +END() diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp new file mode 100644 index 0000000000..daac8d9a99 --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp @@ -0,0 +1,98 @@ +#include "messagebus_client.h" + +using namespace NRainCheck; +using namespace NBus; + +TBusClientService::TBusClientService( + const NBus::TBusSessionConfig& config, + NBus::TBusProtocol* proto, + NBus::TBusMessageQueue* queue) { + Session = queue->CreateSource(proto, this, config); +} + +TBusClientService::~TBusClientService() { + Session->Shutdown(); +} + +void TBusClientService::SendCommon(NBus::TBusMessage* message, const NBus::TNetAddr&, TBusFuture* future) { + TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask(); + + future->SetRunning(current); + + future->Task = current; + + // after this statement message is owned by both messagebus and future + future->Request.Reset(message); + + // TODO: allow cookie in messagebus + message->Data = future; +} + +void TBusClientService::ProcessResultCommon(NBus::TBusMessageAutoPtr message, + const NBus::TNetAddr&, TBusFuture* future, + NBus::EMessageStatus status) { + Y_UNUSED(message.Release()); + + if (status == NBus::MESSAGE_OK) { + return; + } + + future->SetDoneAndSchedule(status, nullptr); +} + +void TBusClientService::SendOneWay( + NBus::TBusMessageAutoPtr message, const NBus::TNetAddr& addr, + TBusFuture* future) { + SendCommon(message.Get(), addr, future); + + EMessageStatus ok = Session->SendMessageOneWay(message.Get(), &addr, false); + ProcessResultCommon(message, addr, future, ok); +} + +NBus::TBusClientSessionPtr TBusClientService::GetSessionForMonitoring() const { + return Session; +} + +void TBusClientService::Send( + TBusMessageAutoPtr message, const TNetAddr& addr, + TBusFuture* future) { + SendCommon(message.Get(), addr, future); + + EMessageStatus ok = Session->SendMessage(message.Get(), &addr, false); + ProcessResultCommon(message, addr, future, ok); +} + +void TBusClientService::OnReply( + TAutoPtr<TBusMessage> request, + TAutoPtr<TBusMessage> response) { + TBusFuture* future = (TBusFuture*)request->Data; + Y_ASSERT(future->Request.Get() == request.Get()); + Y_UNUSED(request.Release()); + future->SetDoneAndSchedule(MESSAGE_OK, response); +} + +void NRainCheck::TBusClientService::OnMessageSentOneWay( + TAutoPtr<NBus::TBusMessage> request) { + TBusFuture* future = (TBusFuture*)request->Data; + Y_ASSERT(future->Request.Get() == request.Get()); + Y_UNUSED(request.Release()); + future->SetDoneAndSchedule(MESSAGE_OK, nullptr); +} + +void TBusClientService::OnError( + TAutoPtr<TBusMessage> message, NBus::EMessageStatus status) { + if (message->Data == nullptr) { + return; + } + + TBusFuture* future = (TBusFuture*)message->Data; + Y_ASSERT(future->Request.Get() == message.Get()); + Y_UNUSED(message.Release()); + future->SetDoneAndSchedule(status, nullptr); +} + +void TBusFuture::SetDoneAndSchedule(EMessageStatus status, TAutoPtr<TBusMessage> response) { + Status = status; + Response.Reset(response.Release()); + SetDone(); +} diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_client.h b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.h new file mode 100644 index 0000000000..0a291cdea6 --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.h @@ -0,0 +1,67 @@ +#pragma once + +#include <library/cpp/messagebus/rain_check/core/task.h> + +#include <library/cpp/messagebus/ybus.h> + +namespace NRainCheck { + class TBusFuture: public TSubtaskCompletion { + friend class TBusClientService; + + private: + THolder<NBus::TBusMessage> Request; + THolder<NBus::TBusMessage> Response; + NBus::EMessageStatus Status; + + private: + TTaskRunnerBase* Task; + + void SetDoneAndSchedule(NBus::EMessageStatus, TAutoPtr<NBus::TBusMessage>); + + public: + // TODO: add MESSAGE_UNDEFINED + TBusFuture() + : Status(NBus::MESSAGE_DONT_ASK) + , Task(nullptr) + { + } + + NBus::TBusMessage* GetRequest() const { + return Request.Get(); + } + + NBus::TBusMessage* GetResponse() const { + Y_ASSERT(IsDone()); + return Response.Get(); + } + + NBus::EMessageStatus GetStatus() const { + Y_ASSERT(IsDone()); + return Status; + } + }; + + class TBusClientService: private NBus::IBusClientHandler { + private: + NBus::TBusClientSessionPtr Session; + + public: + TBusClientService(const NBus::TBusSessionConfig&, NBus::TBusProtocol*, NBus::TBusMessageQueue*); + ~TBusClientService() override; + + void Send(NBus::TBusMessageAutoPtr, const NBus::TNetAddr&, TBusFuture* future); + void SendOneWay(NBus::TBusMessageAutoPtr, const NBus::TNetAddr&, TBusFuture* future); + + // Use it only for monitoring + NBus::TBusClientSessionPtr GetSessionForMonitoring() const; + + private: + void SendCommon(NBus::TBusMessage*, const NBus::TNetAddr&, TBusFuture* future); + void ProcessResultCommon(NBus::TBusMessageAutoPtr, const NBus::TNetAddr&, TBusFuture* future, NBus::EMessageStatus); + + void OnReply(TAutoPtr<NBus::TBusMessage> pMessage, TAutoPtr<NBus::TBusMessage> pReply) override; + void OnError(TAutoPtr<NBus::TBusMessage> pMessage, NBus::EMessageStatus status) override; + void OnMessageSentOneWay(TAutoPtr<NBus::TBusMessage>) override; + }; + +} diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp new file mode 100644 index 0000000000..1b3618558b --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp @@ -0,0 +1,146 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "messagebus_client.h" + +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> + +#include <util/generic/cast.h> + +using namespace NBus; +using namespace NBus::NTest; +using namespace NRainCheck; + +struct TMessageBusClientEnv: public TTestEnvTemplate<TMessageBusClientEnv> { + // TODO: use same thread pool + TBusMessageQueuePtr Queue; + TExampleProtocol Proto; + TBusClientService BusClientService; + + static TBusQueueConfig QueueConfig() { + TBusQueueConfig r; + r.NumWorkers = 4; + return r; + } + + TMessageBusClientEnv() + : Queue(CreateMessageQueue(GetExecutor())) + , BusClientService(TBusSessionConfig(), &Proto, Queue.Get()) + { + } +}; + +Y_UNIT_TEST_SUITE(RainCheckMessageBusClient) { + struct TSimpleTask: public ISimpleTask { + TMessageBusClientEnv* const Env; + + const unsigned ServerPort; + + TSimpleTask(TMessageBusClientEnv* env, unsigned serverPort) + : Env(env) + , ServerPort(serverPort) + { + } + + TVector<TSimpleSharedPtr<TBusFuture>> Requests; + + TContinueFunc Start() override { + for (unsigned i = 0; i < 3; ++i) { + Requests.push_back(new TBusFuture); + TNetAddr addr("localhost", ServerPort); + Env->BusClientService.Send(new TExampleRequest(&Env->Proto.RequestCount), addr, Requests[i].Get()); + } + + return TContinueFunc(&TSimpleTask::GotReplies); + } + + TContinueFunc GotReplies() { + for (unsigned i = 0; i < Requests.size(); ++i) { + Y_VERIFY(Requests[i]->GetStatus() == MESSAGE_OK); + VerifyDynamicCast<TExampleResponse*>(Requests[i]->GetResponse()); + } + Env->TestSync.CheckAndIncrement(0); + return nullptr; + } + }; + + Y_UNIT_TEST(Simple) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + TMessageBusClientEnv env; + + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TSimpleTask>(server.GetActualListenPort()); + + env.TestSync.WaitForAndIncrement(1); + } + + struct TOneWayServer: public NBus::IBusServerHandler { + TTestSync* const TestSync; + TExampleProtocol Proto; + NBus::TBusMessageQueuePtr Queue; + NBus::TBusServerSessionPtr Session; + + TOneWayServer(TTestSync* testSync) + : TestSync(testSync) + { + Queue = CreateMessageQueue(); + Session = Queue->CreateDestination(&Proto, this, NBus::TBusSessionConfig()); + } + + void OnMessage(NBus::TOnMessageContext& context) override { + TestSync->CheckAndIncrement(1); + context.ForgetRequest(); + } + }; + + struct TOneWayTask: public ISimpleTask { + TMessageBusClientEnv* const Env; + + const unsigned ServerPort; + + TOneWayTask(TMessageBusClientEnv* env, unsigned serverPort) + : Env(env) + , ServerPort(serverPort) + { + } + + TVector<TSimpleSharedPtr<TBusFuture>> Requests; + + TContinueFunc Start() override { + Env->TestSync.CheckAndIncrement(0); + + for (unsigned i = 0; i < 1; ++i) { + Requests.push_back(new TBusFuture); + TNetAddr addr("localhost", ServerPort); + Env->BusClientService.SendOneWay(new TExampleRequest(&Env->Proto.RequestCount), addr, Requests[i].Get()); + } + + return TContinueFunc(&TOneWayTask::GotReplies); + } + + TContinueFunc GotReplies() { + for (unsigned i = 0; i < Requests.size(); ++i) { + Y_VERIFY(Requests[i]->GetStatus() == MESSAGE_OK); + Y_VERIFY(!Requests[i]->GetResponse()); + } + Env->TestSync.WaitForAndIncrement(2); + return nullptr; + } + }; + + Y_UNIT_TEST(OneWay) { + TObjectCountCheck objectCountCheck; + + TMessageBusClientEnv env; + + TOneWayServer server(&env.TestSync); + + TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TOneWayTask>(server.Session->GetActualListenPort()); + + env.TestSync.WaitForAndIncrement(3); + } +} diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp new file mode 100644 index 0000000000..5d4b13d664 --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp @@ -0,0 +1,17 @@ +#include "messagebus_server.h" + +#include <library/cpp/messagebus/rain_check/core/spawn.h> + +using namespace NRainCheck; + +TBusTaskStarter::TBusTaskStarter(TAutoPtr<ITaskFactory> taskFactory) + : TaskFactory(taskFactory) +{ +} + +void TBusTaskStarter::OnMessage(NBus::TOnMessageContext& onMessage) { + TaskFactory->NewTask(onMessage); +} + +TBusTaskStarter::~TBusTaskStarter() { +} diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_server.h b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.h new file mode 100644 index 0000000000..1334f05fe4 --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.h @@ -0,0 +1,46 @@ +#pragma once + +#include <library/cpp/messagebus/rain_check/core/spawn.h> +#include <library/cpp/messagebus/rain_check/core/task.h> + +#include <library/cpp/messagebus/ybus.h> + +#include <util/system/yassert.h> + +namespace NRainCheck { + class TBusTaskStarter: public NBus::IBusServerHandler { + private: + struct ITaskFactory { + virtual void NewTask(NBus::TOnMessageContext&) = 0; + virtual ~ITaskFactory() { + } + }; + + THolder<ITaskFactory> TaskFactory; + + void OnMessage(NBus::TOnMessageContext&) override; + + public: + TBusTaskStarter(TAutoPtr<ITaskFactory>); + ~TBusTaskStarter() override; + + public: + template <typename TTask, typename TEnv> + static TAutoPtr<TBusTaskStarter> NewStarter(TEnv* env) { + struct TTaskFactory: public ITaskFactory { + TEnv* const Env; + + TTaskFactory(TEnv* env) + : Env(env) + { + } + + void NewTask(NBus::TOnMessageContext& context) override { + SpawnTask<TTask, TEnv, NBus::TOnMessageContext&>(Env, context); + } + }; + + return new TBusTaskStarter(new TTaskFactory(env)); + } + }; +} diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp new file mode 100644 index 0000000000..7c11399f1b --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp @@ -0,0 +1,51 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "messagebus_server.h" + +#include <library/cpp/messagebus/rain_check/test/ut/test.h> + +#include <library/cpp/messagebus/test/helper/example.h> + +using namespace NBus; +using namespace NBus::NTest; +using namespace NRainCheck; + +struct TMessageBusServerEnv: public TTestEnvTemplate<TMessageBusServerEnv> { + TExampleProtocol Proto; +}; + +Y_UNIT_TEST_SUITE(RainCheckMessageBusServer) { + struct TSimpleServerTask: public ISimpleTask { + private: + TMessageBusServerEnv* const Env; + TOnMessageContext MessageContext; + + public: + TSimpleServerTask(TMessageBusServerEnv* env, TOnMessageContext& messageContext) + : Env(env) + { + MessageContext.Swap(messageContext); + } + + TContinueFunc Start() override { + MessageContext.SendReplyMove(new TExampleResponse(&Env->Proto.ResponseCount)); + return nullptr; + } + }; + + Y_UNIT_TEST(Simple) { + TMessageBusServerEnv env; + + THolder<TBusTaskStarter> starter(TBusTaskStarter::NewStarter<TSimpleServerTask>(&env)); + + TBusMessageQueuePtr queue(CreateMessageQueue(env.GetExecutor())); + + TExampleProtocol proto; + + TBusServerSessionPtr session = queue->CreateDestination(&env.Proto, starter.Get(), TBusSessionConfig()); + + TExampleClient client; + + client.SendMessagesWaitReplies(1, TNetAddr("localhost", session->GetActualListenPort())); + } +} diff --git a/library/cpp/messagebus/rain_check/messagebus/ya.make b/library/cpp/messagebus/rain_check/messagebus/ya.make new file mode 100644 index 0000000000..defdac9a61 --- /dev/null +++ b/library/cpp/messagebus/rain_check/messagebus/ya.make @@ -0,0 +1,15 @@ +LIBRARY() + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus + library/cpp/messagebus/rain_check/core +) + +SRCS( + messagebus_client.cpp + messagebus_server.cpp +) + +END() diff --git a/library/cpp/messagebus/rain_check/test/TestRainCheck.py b/library/cpp/messagebus/rain_check/test/TestRainCheck.py new file mode 100644 index 0000000000..92ed727b62 --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/TestRainCheck.py @@ -0,0 +1,8 @@ +from devtools.fleur.ytest import group, constraint +from devtools.fleur.ytest.integration import UnitTestGroup + +@group +@constraint('library.messagebus') +class TestMessageBus3(UnitTestGroup): + def __init__(self, context): + UnitTestGroup.__init__(self, context, 'MessageBus', 'library-messagebus-rain_check-test-ut') diff --git a/library/cpp/messagebus/rain_check/test/helper/misc.cpp b/library/cpp/messagebus/rain_check/test/helper/misc.cpp new file mode 100644 index 0000000000..c0fcb27252 --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/helper/misc.cpp @@ -0,0 +1,27 @@ +#include "misc.h" + +#include <util/system/yassert.h> + +using namespace NRainCheck; + +void TSpawnNopTasksCoroTask::Run() { + Y_VERIFY(Count <= Completion.size()); + for (unsigned i = 0; i < Count; ++i) { + SpawnSubtask<TNopCoroTask>(Env, &Completion[i], ""); + } + + WaitForSubtasks(); +} + +TContinueFunc TSpawnNopTasksSimpleTask::Start() { + Y_VERIFY(Count <= Completion.size()); + for (unsigned i = 0; i < Count; ++i) { + SpawnSubtask<TNopSimpleTask>(Env, &Completion[i], ""); + } + + return &TSpawnNopTasksSimpleTask::Join; +} + +TContinueFunc TSpawnNopTasksSimpleTask::Join() { + return nullptr; +} diff --git a/library/cpp/messagebus/rain_check/test/helper/misc.h b/library/cpp/messagebus/rain_check/test/helper/misc.h new file mode 100644 index 0000000000..9150be4d2f --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/helper/misc.h @@ -0,0 +1,57 @@ +#pragma once + +#include <library/cpp/messagebus/rain_check/core/rain_check.h> + +#include <array> + +namespace NRainCheck { + struct TNopSimpleTask: public ISimpleTask { + TNopSimpleTask(IEnv*, const void*) { + } + + TContinueFunc Start() override { + return nullptr; + } + }; + + struct TNopCoroTask: public ICoroTask { + TNopCoroTask(IEnv*, const void*) { + } + + void Run() override { + } + }; + + struct TSpawnNopTasksCoroTask: public ICoroTask { + IEnv* const Env; + unsigned const Count; + + TSpawnNopTasksCoroTask(IEnv* env, unsigned count) + : Env(env) + , Count(count) + { + } + + std::array<TSubtaskCompletion, 2> Completion; + + void Run() override; + }; + + struct TSpawnNopTasksSimpleTask: public ISimpleTask { + IEnv* const Env; + unsigned const Count; + + TSpawnNopTasksSimpleTask(IEnv* env, unsigned count) + : Env(env) + , Count(count) + { + } + + std::array<TSubtaskCompletion, 2> Completion; + + TContinueFunc Start() override; + + TContinueFunc Join(); + }; + +} diff --git a/library/cpp/messagebus/rain_check/test/helper/ya.make b/library/cpp/messagebus/rain_check/test/helper/ya.make new file mode 100644 index 0000000000..aa9e4e6d81 --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/helper/ya.make @@ -0,0 +1,13 @@ +LIBRARY(messagebus-rain_check-test-helper) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus/rain_check/core +) + +SRCS( + misc.cpp +) + +END() diff --git a/library/cpp/messagebus/rain_check/test/perftest/perftest.cpp b/library/cpp/messagebus/rain_check/test/perftest/perftest.cpp new file mode 100644 index 0000000000..22edbd8c6b --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/perftest/perftest.cpp @@ -0,0 +1,154 @@ +#include <library/cpp/messagebus/rain_check/test/helper/misc.h> + +#include <library/cpp/messagebus/rain_check/core/rain_check.h> + +#include <util/datetime/base.h> + +#include <array> + +using namespace NRainCheck; + +static const unsigned SUBTASKS = 2; + +struct TRainCheckPerftestEnv: public TSimpleEnvTemplate<TRainCheckPerftestEnv> { + unsigned SubtasksPerTask; + + TRainCheckPerftestEnv() + : TSimpleEnvTemplate<TRainCheckPerftestEnv>(4) + , SubtasksPerTask(1000) + { + } +}; + +struct TCoroOuter: public ICoroTask { + TRainCheckPerftestEnv* const Env; + + TCoroOuter(TRainCheckPerftestEnv* env) + : Env(env) + { + } + + void Run() override { + for (;;) { + TInstant start = TInstant::Now(); + + unsigned count = 0; + + unsigned current = 1000; + + do { + for (unsigned i = 0; i < current; ++i) { + std::array<TSubtaskCompletion, SUBTASKS> completion; + + for (unsigned j = 0; j < SUBTASKS; ++j) { + //SpawnSubtask<TNopSimpleTask>(Env, &completion[j]); + //SpawnSubtask<TSpawnNopTasksCoroTask>(Env, &completion[j], SUBTASKS); + SpawnSubtask<TSpawnNopTasksSimpleTask>(Env, &completion[j], SUBTASKS); + } + + WaitForSubtasks(); + } + + count += current; + current *= 2; + } while (TInstant::Now() - start < TDuration::Seconds(1)); + + TDuration d = TInstant::Now() - start; + unsigned dns = d.NanoSeconds() / count; + Cerr << dns << "ns per spawn/join\n"; + } + } +}; + +struct TSimpleOuter: public ISimpleTask { + TRainCheckPerftestEnv* const Env; + + TSimpleOuter(TRainCheckPerftestEnv* env, const void*) + : Env(env) + { + } + + TInstant StartInstant; + unsigned Count; + unsigned Current; + unsigned I; + + TContinueFunc Start() override { + StartInstant = TInstant::Now(); + Count = 0; + Current = 1000; + I = 0; + + return &TSimpleOuter::Spawn; + } + + std::array<TSubtaskCompletion, SUBTASKS> Completion; + + TContinueFunc Spawn() { + for (unsigned j = 0; j < SUBTASKS; ++j) { + //SpawnSubtask<TNopSimpleTask>(Env, &Completion[j]); + //SpawnSubtask<TSpawnNopTasksCoroTask>(Env, &Completion[j], SUBTASKS); + SpawnSubtask<TSpawnNopTasksSimpleTask>(Env, &Completion[j], SUBTASKS); + } + + return &TSimpleOuter::Join; + } + + TContinueFunc Join() { + I += 1; + if (I != Current) { + return &TSimpleOuter::Spawn; + } + + I = 0; + Count += Current; + Current *= 2; + + TDuration d = TInstant::Now() - StartInstant; + if (d < TDuration::Seconds(1)) { + return &TSimpleOuter::Spawn; + } + + unsigned dns = d.NanoSeconds() / Count; + Cerr << dns << "ns per spawn/join\n"; + + return &TSimpleOuter::Start; + } +}; + +struct TReproduceCrashTask: public ISimpleTask { + TRainCheckPerftestEnv* const Env; + + TReproduceCrashTask(TRainCheckPerftestEnv* env) + : Env(env) + { + } + + std::array<TSubtaskCompletion, SUBTASKS> Completion; + + TContinueFunc Start() override { + for (unsigned j = 0; j < 2; ++j) { + //SpawnSubtask<TNopSimpleTask>(Env, &Completion[j]); + SpawnSubtask<TSpawnNopTasksSimpleTask>(Env, &Completion[j], SUBTASKS); + } + + return &TReproduceCrashTask::Start; + } +}; + +int main(int argc, char** argv) { + Y_UNUSED(argc); + Y_UNUSED(argv); + + TRainCheckPerftestEnv env; + + env.SpawnTask<TSimpleOuter>(""); + //env.SpawnTask<TCoroOuter>(); + //env.SpawnTask<TReproduceCrashTask>(); + + for (;;) { + Sleep(TDuration::Hours(1)); + } + + return 0; +} diff --git a/library/cpp/messagebus/rain_check/test/perftest/ya.make b/library/cpp/messagebus/rain_check/test/perftest/ya.make new file mode 100644 index 0000000000..7330a71700 --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/perftest/ya.make @@ -0,0 +1,14 @@ +PROGRAM(messagebus_rain_check_perftest) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus/rain_check/core + library/cpp/messagebus/rain_check/test/helper +) + +SRCS( + perftest.cpp +) + +END() diff --git a/library/cpp/messagebus/rain_check/test/ut/test.h b/library/cpp/messagebus/rain_check/test/ut/test.h new file mode 100644 index 0000000000..724f6b7530 --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/ut/test.h @@ -0,0 +1,13 @@ +#pragma once + +#include <library/cpp/messagebus/rain_check/core/rain_check.h> + +#include <library/cpp/messagebus/misc/test_sync.h> + +template <typename TSelf> +struct TTestEnvTemplate: public NRainCheck::TSimpleEnvTemplate<TSelf> { + TTestSync TestSync; +}; + +struct TTestEnv: public TTestEnvTemplate<TTestEnv> { +}; diff --git a/library/cpp/messagebus/rain_check/test/ut/ya.make b/library/cpp/messagebus/rain_check/test/ut/ya.make new file mode 100644 index 0000000000..9f7a93417a --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/ut/ya.make @@ -0,0 +1,24 @@ +PROGRAM(library-messagebus-rain_check-test-ut) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/testing/unittest_main + library/cpp/messagebus/rain_check/core + library/cpp/messagebus/rain_check/http + library/cpp/messagebus/rain_check/messagebus + library/cpp/messagebus/test/helper +) + +SRCS( + ../../core/coro_ut.cpp + ../../core/simple_ut.cpp + ../../core/sleep_ut.cpp + ../../core/spawn_ut.cpp + ../../core/track_ut.cpp + ../../http/client_ut.cpp + ../../messagebus/messagebus_client_ut.cpp + ../../messagebus/messagebus_server_ut.cpp +) + +END() diff --git a/library/cpp/messagebus/rain_check/test/ya.make b/library/cpp/messagebus/rain_check/test/ya.make new file mode 100644 index 0000000000..4c1d6f8161 --- /dev/null +++ b/library/cpp/messagebus/rain_check/test/ya.make @@ -0,0 +1,6 @@ +OWNER(g:messagebus) + +RECURSE( + perftest + ut +) diff --git a/library/cpp/messagebus/rain_check/ya.make b/library/cpp/messagebus/rain_check/ya.make new file mode 100644 index 0000000000..966d54c232 --- /dev/null +++ b/library/cpp/messagebus/rain_check/ya.make @@ -0,0 +1,8 @@ +OWNER(g:messagebus) + +RECURSE( + core + http + messagebus + test +) diff --git a/library/cpp/messagebus/ref_counted.h b/library/cpp/messagebus/ref_counted.h new file mode 100644 index 0000000000..29b87764e3 --- /dev/null +++ b/library/cpp/messagebus/ref_counted.h @@ -0,0 +1,6 @@ +#pragma once + +class TAtomicRefCountedObject: public TAtomicRefCount<TAtomicRefCountedObject> { + virtual ~TAtomicRefCountedObject() { + } +}; diff --git a/library/cpp/messagebus/remote_client_connection.cpp b/library/cpp/messagebus/remote_client_connection.cpp new file mode 100644 index 0000000000..8c7a6db3a8 --- /dev/null +++ b/library/cpp/messagebus/remote_client_connection.cpp @@ -0,0 +1,343 @@ +#include "remote_client_connection.h" + +#include "mb_lwtrace.h" +#include "network.h" +#include "remote_client_session.h" + +#include <library/cpp/messagebus/actor/executor.h> +#include <library/cpp/messagebus/actor/temp_tls_vector.h> + +#include <util/generic/cast.h> +#include <util/thread/singleton.h> + +LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER) + +using namespace NActor; +using namespace NBus; +using namespace NBus::NPrivate; + +TRemoteClientConnection::TRemoteClientConnection(TRemoteClientSessionPtr session, ui64 id, TNetAddr addr) + : TRemoteConnection(session.Get(), id, addr) + , ClientHandler(GetSession()->ClientHandler) +{ + Y_VERIFY(addr.GetPort() > 0, "must connect to non-zero port"); + + ScheduleWrite(); +} + +TRemoteClientSession* TRemoteClientConnection::GetSession() { + return CheckedCast<TRemoteClientSession*>(Session.Get()); +} + +TBusMessage* TRemoteClientConnection::PopAck(TBusKey id) { + return AckMessages.Pop(id); +} + +SOCKET TRemoteClientConnection::CreateSocket(const TNetAddr& addr) { + SOCKET handle = socket(addr.Addr()->sa_family, SOCK_STREAM, 0); + Y_VERIFY(handle != INVALID_SOCKET, "failed to create socket: %s", LastSystemErrorText()); + + TSocketHolder s(handle); + + SetNonBlock(s, true); + SetNoDelay(s, Config.TcpNoDelay); + SetSockOptTcpCork(s, Config.TcpCork); + SetCloseOnExec(s, true); + SetKeepAlive(s, true); + if (Config.SocketRecvBufferSize != 0) { + SetInputBuffer(s, Config.SocketRecvBufferSize); + } + if (Config.SocketSendBufferSize != 0) { + SetOutputBuffer(s, Config.SocketSendBufferSize); + } + if (Config.SocketToS >= 0) { + SetSocketToS(s, &addr, Config.SocketToS); + } + + return s.Release(); +} + +void TRemoteClientConnection::TryConnect() { + if (AtomicGet(WriterData.Down)) { + return; + } + Y_VERIFY(!WriterData.Status.Connected); + + TInstant now = TInstant::Now(); + + if (!WriterData.Channel) { + if ((now - LastConnectAttempt) < TDuration::MilliSeconds(Config.RetryInterval)) { + DropEnqueuedData(MESSAGE_CONNECT_FAILED, MESSAGE_CONNECT_FAILED); + return; + } + LastConnectAttempt = now; + + TSocket connectSocket(CreateSocket(PeerAddr)); + WriterData.SetChannel(Session->WriteEventLoop.Register(connectSocket, this, WriteCookie)); + } + + if (BeforeSendQueue.IsEmpty() && WriterData.SendQueue.Empty() && !Config.ReconnectWhenIdle) { + // TryConnect is called from Writer::Act, which is called in cycle + // from session's ScheduleTimeoutMessages via Cron. This prevent these excessive connects. + return; + } + + ++WriterData.Status.ConnectSyscalls; + + int ret = connect(WriterData.Channel->GetSocket(), PeerAddr.Addr(), PeerAddr.Len()); + int err = ret ? LastSystemError() : 0; + + if (!ret || (ret && err == EISCONN)) { + WriterData.Status.ConnectTime = now; + ++WriterData.SocketVersion; + + WriterData.Channel->DisableWrite(); + WriterData.Status.Connected = true; + AtomicSet(ReturnConnectFailedImmediately, false); + + WriterData.Status.MyAddr = TNetAddr(GetSockAddr(WriterData.Channel->GetSocket())); + + TSocket readSocket = WriterData.Channel->GetSocketPtr(); + + ReaderGetSocketQueue()->EnqueueAndSchedule(TWriterToReaderSocketMessage(readSocket, WriterData.SocketVersion)); + + FireClientConnectionEvent(TClientConnectionEvent::CONNECTED); + + ScheduleWrite(); + } else { + if (WouldBlock() || err == EALREADY) { + WriterData.Channel->EnableWrite(); + } else { + WriterData.DropChannel(); + WriterData.Status.MyAddr = TNetAddr(); + WriterData.Status.Connected = false; + WriterData.Status.ConnectError = err; + + DropEnqueuedData(MESSAGE_CONNECT_FAILED, MESSAGE_CONNECT_FAILED); + } + } +} + +void TRemoteClientConnection::HandleEvent(SOCKET socket, void* cookie) { + Y_UNUSED(socket); + Y_ASSERT(cookie == WriteCookie || cookie == ReadCookie); + if (cookie == ReadCookie) { + ScheduleRead(); + } else { + ScheduleWrite(); + } +} + +void TRemoteClientConnection::WriterFillStatus() { + TRemoteConnection::WriterFillStatus(); + WriterData.Status.AckMessagesSize = AckMessages.Size(); +} + +void TRemoteClientConnection::BeforeTryWrite() { + ProcessReplyQueue(); + TimeoutMessages(); +} + +namespace NBus { + namespace NPrivate { + class TInvokeOnReply: public IWorkItem { + private: + TRemoteClientSession* RemoteClientSession; + TNonDestroyingHolder<TBusMessage> Request; + TBusMessagePtrAndHeader Response; + + public: + TInvokeOnReply(TRemoteClientSession* session, + TNonDestroyingAutoPtr<TBusMessage> request, TBusMessagePtrAndHeader& response) + : RemoteClientSession(session) + , Request(request) + { + Response.Swap(response); + } + + void DoWork() override { + THolder<TInvokeOnReply> holder(this); + RemoteClientSession->ReleaseInFlightAndCallOnReply(Request.Release(), Response); + // TODO: TRemoteClientSessionSemaphore should be enough + RemoteClientSession->JobCount.Decrement(); + } + }; + + } +} + +void TRemoteClientConnection::ProcessReplyQueue() { + if (AtomicGet(WriterData.Down)) { + return; + } + + bool executeInWorkerPool = Session->Config.ExecuteOnReplyInWorkerPool; + + TTempTlsVector<TBusMessagePtrAndHeader, void, TVectorSwaps> replyQueueTemp; + TTempTlsVector< ::NActor::IWorkItem*> workQueueTemp; + + ReplyQueue.DequeueAllSingleConsumer(replyQueueTemp.GetVector()); + if (executeInWorkerPool) { + workQueueTemp.GetVector()->reserve(replyQueueTemp.GetVector()->size()); + } + + for (auto& resp : *replyQueueTemp.GetVector()) { + TBusMessage* req = PopAck(resp.Header.Id); + + if (!req) { + WriterErrorMessage(resp.MessagePtr.Release(), MESSAGE_UNKNOWN); + continue; + } + + if (executeInWorkerPool) { + workQueueTemp.GetVector()->push_back(new TInvokeOnReply(GetSession(), req, resp)); + } else { + GetSession()->ReleaseInFlightAndCallOnReply(req, resp); + } + } + + if (executeInWorkerPool) { + Session->JobCount.Add(workQueueTemp.GetVector()->size()); + Session->Queue->EnqueueWork(*workQueueTemp.GetVector()); + } +} + +void TRemoteClientConnection::TimeoutMessages() { + if (!TimeToTimeoutMessages.FetchTask()) { + return; + } + + TMessagesPtrs timedOutMessages; + + TInstant sendDeadline; + TInstant ackDeadline; + if (IsReturnConnectFailedImmediately()) { + sendDeadline = TInstant::Max(); + ackDeadline = TInstant::Max(); + } else { + TInstant now = TInstant::Now(); + sendDeadline = now - TDuration::MilliSeconds(Session->Config.SendTimeout); + ackDeadline = now - TDuration::MilliSeconds(Session->Config.TotalTimeout); + } + + { + TMessagesPtrs temp; + WriterData.SendQueue.Timeout(sendDeadline, &temp); + timedOutMessages.insert(timedOutMessages.end(), temp.begin(), temp.end()); + } + + // Ignores message that is being written currently (that is stored + // in WriteMessage). It is not a big problem, because after written + // to the network, message will be placed to the AckMessages queue, + // and timed out on the next iteration of this procedure. + + { + TMessagesPtrs temp; + AckMessages.Timeout(ackDeadline, &temp); + timedOutMessages.insert(timedOutMessages.end(), temp.begin(), temp.end()); + } + + ResetOneWayFlag(timedOutMessages); + + GetSession()->ReleaseInFlight(timedOutMessages); + WriterErrorMessages(timedOutMessages, MESSAGE_TIMEOUT); +} + +void TRemoteClientConnection::ScheduleTimeoutMessages() { + TimeToTimeoutMessages.AddTask(); + ScheduleWrite(); +} + +void TRemoteClientConnection::ReaderProcessMessageUnknownVersion(TArrayRef<const char>) { + LWPROBE(Error, ToString(MESSAGE_INVALID_VERSION), ToString(PeerAddr), ""); + ReaderData.Status.Incremental.StatusCounter[MESSAGE_INVALID_VERSION] += 1; + // TODO: close connection + Y_FAIL("unknown message"); +} + +void TRemoteClientConnection::ClearOutgoingQueue(TMessagesPtrs& result, bool reconnect) { + Y_ASSERT(result.empty()); + + TRemoteConnection::ClearOutgoingQueue(result, reconnect); + AckMessages.Clear(&result); + + ResetOneWayFlag(result); + GetSession()->ReleaseInFlight(result); +} + +void TRemoteClientConnection::MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) { + for (auto& message : messages) { + bool oneWay = message.LocalFlags & MESSAGE_ONE_WAY_INTERNAL; + + if (oneWay) { + message.MessagePtr->LocalFlags &= ~MESSAGE_ONE_WAY_INTERNAL; + + TBusMessage* ackMsg = this->PopAck(message.Header.Id); + if (!ackMsg) { + // TODO: expired? + } + + if (ackMsg != message.MessagePtr.Get()) { + // TODO: non-unique id? + } + + GetSession()->ReleaseInFlight({message.MessagePtr.Get()}); + ClientHandler->OnMessageSentOneWay(message.MessagePtr.Release()); + } else { + ClientHandler->OnMessageSent(message.MessagePtr.Get()); + AckMessages.Push(message); + } + } +} + +EMessageStatus TRemoteClientConnection::SendMessage(TBusMessage* req, bool wait) { + return SendMessageImpl(req, wait, false); +} + +EMessageStatus TRemoteClientConnection::SendMessageOneWay(TBusMessage* req, bool wait) { + return SendMessageImpl(req, wait, true); +} + +EMessageStatus TRemoteClientConnection::SendMessageImpl(TBusMessage* msg, bool wait, bool oneWay) { + msg->CheckClean(); + + if (Session->IsDown()) { + return MESSAGE_SHUTDOWN; + } + + if (wait) { + Y_VERIFY(!Session->Queue->GetExecutor()->IsInExecutorThread()); + GetSession()->ClientRemoteInFlight.Wait(); + } else { + if (!GetSession()->ClientRemoteInFlight.TryWait()) { + return MESSAGE_BUSY; + } + } + + GetSession()->AcquireInFlight({msg}); + + EMessageStatus ret = MESSAGE_OK; + + if (oneWay) { + msg->LocalFlags |= MESSAGE_ONE_WAY_INTERNAL; + } + + msg->GetHeader()->SendTime = Now(); + + if (IsReturnConnectFailedImmediately()) { + ret = MESSAGE_CONNECT_FAILED; + goto clean; + } + + Send(msg); + + return MESSAGE_OK; +clean: + msg->LocalFlags &= ~MESSAGE_ONE_WAY_INTERNAL; + GetSession()->ReleaseInFlight({msg}); + return ret; +} + +void TRemoteClientConnection::OpenConnection() { + // TODO +} diff --git a/library/cpp/messagebus/remote_client_connection.h b/library/cpp/messagebus/remote_client_connection.h new file mode 100644 index 0000000000..fe80b7d2f9 --- /dev/null +++ b/library/cpp/messagebus/remote_client_connection.h @@ -0,0 +1,65 @@ +#pragma once + +#include "connection.h" +#include "local_tasks.h" +#include "remote_client_session.h" +#include "remote_connection.h" + +#include <util/generic/object_counter.h> + +namespace NBus { + namespace NPrivate { + class TRemoteClientConnection: public TRemoteConnection, public TBusClientConnection { + friend class TRemoteConnection; + friend struct TBusSessionImpl; + friend class TRemoteClientSession; + + private: + TObjectCounter<TRemoteClientConnection> ObjectCounter; + + TSyncAckMessages AckMessages; + + TLocalTasks TimeToTimeoutMessages; + + IBusClientHandler* const ClientHandler; + + public: + TRemoteClientConnection(TRemoteClientSessionPtr session, ui64 id, TNetAddr addr); + + inline TRemoteClientSession* GetSession(); + + SOCKET CreateSocket(const TNetAddr& addr); + + void TryConnect() override; + + void HandleEvent(SOCKET socket, void* cookie) override; + + TBusMessage* PopAck(TBusKey id); + + void WriterFillStatus() override; + + void ClearOutgoingQueue(TMessagesPtrs& result, bool reconnect) override; + + void BeforeTryWrite() override; + + void ProcessReplyQueue(); + + void MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) override; + + void TimeoutMessages(); + + void ScheduleTimeoutMessages(); + + void ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) override; + + EMessageStatus SendMessage(TBusMessage* pMes, bool wait) override; + + EMessageStatus SendMessageOneWay(TBusMessage* pMes, bool wait) override; + + EMessageStatus SendMessageImpl(TBusMessage*, bool wait, bool oneWay); + + void OpenConnection() override; + }; + + } +} diff --git a/library/cpp/messagebus/remote_client_session.cpp b/library/cpp/messagebus/remote_client_session.cpp new file mode 100644 index 0000000000..3bc421944f --- /dev/null +++ b/library/cpp/messagebus/remote_client_session.cpp @@ -0,0 +1,127 @@ +#include "remote_client_session.h" + +#include "mb_lwtrace.h" +#include "remote_client_connection.h" + +#include <library/cpp/messagebus/scheduler/scheduler.h> + +#include <util/generic/cast.h> +#include <util/system/defaults.h> + +LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER) + +using namespace NBus; +using namespace NBus::NPrivate; + +TRemoteClientSession::TRemoteClientSession(TBusMessageQueue* queue, + TBusProtocol* proto, IBusClientHandler* handler, + const TBusClientSessionConfig& config, const TString& name) + : TBusSessionImpl(true, queue, proto, handler, config, name) + , ClientRemoteInFlight(config.MaxInFlight, "ClientRemoteInFlight") + , ClientHandler(handler) +{ +} + +TRemoteClientSession::~TRemoteClientSession() { + //Cerr << "~TRemoteClientSession" << Endl; +} + +void TRemoteClientSession::OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) { + TAutoPtr<TVectorSwaps<TBusMessagePtrAndHeader>> temp(new TVectorSwaps<TBusMessagePtrAndHeader>); + temp->swap(newMsg); + c->ReplyQueue.EnqueueAll(temp); + c->ScheduleWrite(); +} + +EMessageStatus TRemoteClientSession::SendMessageImpl(TBusMessage* msg, const TNetAddr* addr, bool wait, bool oneWay) { + if (Y_UNLIKELY(IsDown())) { + return MESSAGE_SHUTDOWN; + } + + TBusSocketAddr resolvedAddr; + EMessageStatus ret = GetMessageDestination(msg, addr, &resolvedAddr); + if (ret != MESSAGE_OK) { + return ret; + } + + msg->ReplyTo = resolvedAddr; + + TRemoteConnectionPtr c = ((TBusSessionImpl*)this)->GetConnection(resolvedAddr, true); + Y_ASSERT(!!c); + + return CheckedCast<TRemoteClientConnection*>(c.Get())->SendMessageImpl(msg, wait, oneWay); +} + +EMessageStatus TRemoteClientSession::SendMessage(TBusMessage* msg, const TNetAddr* addr, bool wait) { + return SendMessageImpl(msg, addr, wait, false); +} + +EMessageStatus TRemoteClientSession::SendMessageOneWay(TBusMessage* pMes, const TNetAddr* addr, bool wait) { + return SendMessageImpl(pMes, addr, wait, true); +} + +int TRemoteClientSession::GetInFlight() const noexcept { + return ClientRemoteInFlight.GetCurrent(); +} + +void TRemoteClientSession::FillStatus() { + TBusSessionImpl::FillStatus(); + + StatusData.Status.InFlightCount = ClientRemoteInFlight.GetCurrent(); + StatusData.Status.InputPaused = false; +} + +void TRemoteClientSession::AcquireInFlight(TArrayRef<TBusMessage* const> messages) { + for (auto message : messages) { + Y_ASSERT(!(message->LocalFlags & MESSAGE_IN_FLIGHT_ON_CLIENT)); + message->LocalFlags |= MESSAGE_IN_FLIGHT_ON_CLIENT; + } + ClientRemoteInFlight.IncrementMultiple(messages.size()); +} + +void TRemoteClientSession::ReleaseInFlight(TArrayRef<TBusMessage* const> messages) { + for (auto message : messages) { + Y_ASSERT(message->LocalFlags & MESSAGE_IN_FLIGHT_ON_CLIENT); + message->LocalFlags &= ~MESSAGE_IN_FLIGHT_ON_CLIENT; + } + ClientRemoteInFlight.ReleaseMultiple(messages.size()); +} + +void TRemoteClientSession::ReleaseInFlightAndCallOnReply(TNonDestroyingAutoPtr<TBusMessage> request, TBusMessagePtrAndHeader& response) { + ReleaseInFlight({request.Get()}); + if (Y_UNLIKELY(AtomicGet(Down))) { + InvokeOnError(request, MESSAGE_SHUTDOWN); + InvokeOnError(response.MessagePtr.Release(), MESSAGE_SHUTDOWN); + + TRemoteConnectionReaderIncrementalStatus counter; + LWPROBE(Error, ToString(MESSAGE_SHUTDOWN), "", ""); + counter.StatusCounter[MESSAGE_SHUTDOWN] += 1; + GetDeadConnectionReaderStatusQueue()->EnqueueAndSchedule(counter); + } else { + TWhatThreadDoesPushPop pp("OnReply"); + ClientHandler->OnReply(request, response.MessagePtr.Release()); + } +} + +EMessageStatus TRemoteClientSession::GetMessageDestination(TBusMessage* mess, const TNetAddr* addrp, TBusSocketAddr* dest) { + if (addrp) { + *dest = *addrp; + } else { + TNetAddr tmp; + EMessageStatus ret = const_cast<TBusProtocol*>(GetProto())->GetDestination(this, mess, GetQueue()->GetLocator(), &tmp); + if (ret != MESSAGE_OK) { + return ret; + } + *dest = tmp; + } + return MESSAGE_OK; +} + +void TRemoteClientSession::OpenConnection(const TNetAddr& addr) { + GetConnection(addr)->OpenConnection(); +} + +TBusClientConnectionPtr TRemoteClientSession::GetConnection(const TNetAddr& addr) { + // TODO: GetConnection should not open + return CheckedCast<TRemoteClientConnection*>(((TBusSessionImpl*)this)->GetConnection(addr, true).Get()); +} diff --git a/library/cpp/messagebus/remote_client_session.h b/library/cpp/messagebus/remote_client_session.h new file mode 100644 index 0000000000..7160d0dae9 --- /dev/null +++ b/library/cpp/messagebus/remote_client_session.h @@ -0,0 +1,59 @@ +#pragma once + +#include "remote_client_session_semaphore.h" +#include "session_impl.h" + +#include <util/generic/array_ref.h> +#include <util/generic/object_counter.h> + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4250) // 'NBus::NPrivate::TRemoteClientSession' : inherits 'NBus::NPrivate::TBusSessionImpl::NBus::NPrivate::TBusSessionImpl::GetConfig' via dominance +#endif + +namespace NBus { + namespace NPrivate { + using TRemoteClientSessionPtr = TIntrusivePtr<TRemoteClientSession>; + + class TRemoteClientSession: public TBusClientSession, public TBusSessionImpl { + friend class TRemoteClientConnection; + friend class TInvokeOnReply; + + public: + TObjectCounter<TRemoteClientSession> ObjectCounter; + + TRemoteClientSessionSemaphore ClientRemoteInFlight; + IBusClientHandler* const ClientHandler; + + public: + TRemoteClientSession(TBusMessageQueue* queue, TBusProtocol* proto, + IBusClientHandler* handler, + const TBusSessionConfig& config, const TString& name); + + ~TRemoteClientSession() override; + + void OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) override; + + EMessageStatus SendMessageImpl(TBusMessage* msg, const TNetAddr* addr, bool wait, bool oneWay); + EMessageStatus SendMessage(TBusMessage* msg, const TNetAddr* addr = nullptr, bool wait = false) override; + EMessageStatus SendMessageOneWay(TBusMessage* msg, const TNetAddr* addr = nullptr, bool wait = false) override; + + int GetInFlight() const noexcept override; + void FillStatus() override; + void AcquireInFlight(TArrayRef<TBusMessage* const> messages); + void ReleaseInFlight(TArrayRef<TBusMessage* const> messages); + void ReleaseInFlightAndCallOnReply(TNonDestroyingAutoPtr<TBusMessage> request, TBusMessagePtrAndHeader& response); + + EMessageStatus GetMessageDestination(TBusMessage* mess, const TNetAddr* addrp, TBusSocketAddr* dest); + + void OpenConnection(const TNetAddr&) override; + + TBusClientConnectionPtr GetConnection(const TNetAddr&) override; + }; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + } +} diff --git a/library/cpp/messagebus/remote_client_session_semaphore.cpp b/library/cpp/messagebus/remote_client_session_semaphore.cpp new file mode 100644 index 0000000000..f877ed4257 --- /dev/null +++ b/library/cpp/messagebus/remote_client_session_semaphore.cpp @@ -0,0 +1,67 @@ +#include "remote_client_session_semaphore.h" + +#include <util/stream/output.h> +#include <util/system/yassert.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TRemoteClientSessionSemaphore::TRemoteClientSessionSemaphore(TAtomicBase limit, const char* name) + : Name(name) + , Limit(limit) + , Current(0) + , StopSignal(0) +{ + Y_VERIFY(limit > 0, "limit must be > 0"); + Y_UNUSED(Name); +} + +TRemoteClientSessionSemaphore::~TRemoteClientSessionSemaphore() { + Y_VERIFY(AtomicGet(Current) == 0); +} + +bool TRemoteClientSessionSemaphore::TryAcquire() { + if (!TryWait()) { + return false; + } + + AtomicIncrement(Current); + return true; +} + +bool TRemoteClientSessionSemaphore::TryWait() { + if (AtomicGet(Current) < Limit) + return true; + if (Y_UNLIKELY(AtomicGet(StopSignal))) + return true; + return false; +} + +void TRemoteClientSessionSemaphore::Acquire() { + Wait(); + + Increment(); +} + +void TRemoteClientSessionSemaphore::Increment() { + IncrementMultiple(1); +} + +void TRemoteClientSessionSemaphore::IncrementMultiple(TAtomicBase count) { + AtomicAdd(Current, count); + Updated(); +} + +void TRemoteClientSessionSemaphore::Release() { + ReleaseMultiple(1); +} + +void TRemoteClientSessionSemaphore::ReleaseMultiple(TAtomicBase count) { + AtomicSub(Current, count); + Updated(); +} + +void TRemoteClientSessionSemaphore::Stop() { + AtomicSet(StopSignal, 1); + Updated(); +} diff --git a/library/cpp/messagebus/remote_client_session_semaphore.h b/library/cpp/messagebus/remote_client_session_semaphore.h new file mode 100644 index 0000000000..286ca3c86f --- /dev/null +++ b/library/cpp/messagebus/remote_client_session_semaphore.h @@ -0,0 +1,42 @@ +#pragma once + +#include "cc_semaphore.h" + +#include <util/generic/noncopyable.h> +#include <util/system/atomic.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> + +namespace NBus { + namespace NPrivate { + class TRemoteClientSessionSemaphore: public TComplexConditionSemaphore<TRemoteClientSessionSemaphore> { + private: + const char* const Name; + + TAtomicBase const Limit; + TAtomic Current; + TAtomic StopSignal; + + public: + TRemoteClientSessionSemaphore(TAtomicBase limit, const char* name = "unnamed"); + ~TRemoteClientSessionSemaphore(); + + TAtomicBase GetCurrent() const { + return AtomicGet(Current); + } + + void Acquire(); + bool TryAcquire(); + void Increment(); + void IncrementMultiple(TAtomicBase count); + bool TryWait(); + void Release(); + void ReleaseMultiple(TAtomicBase count); + void Stop(); + + private: + void CheckNeedToUnlock(); + }; + + } +} diff --git a/library/cpp/messagebus/remote_connection.cpp b/library/cpp/messagebus/remote_connection.cpp new file mode 100644 index 0000000000..22932569db --- /dev/null +++ b/library/cpp/messagebus/remote_connection.cpp @@ -0,0 +1,974 @@ +#include "remote_connection.h" + +#include "key_value_printer.h" +#include "mb_lwtrace.h" +#include "network.h" +#include "remote_client_connection.h" +#include "remote_client_session.h" +#include "remote_server_session.h" +#include "session_impl.h" + +#include <library/cpp/messagebus/actor/what_thread_does.h> + +#include <util/generic/cast.h> +#include <util/network/init.h> +#include <util/system/atomic.h> + +LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER) + +using namespace NActor; +using namespace NBus; +using namespace NBus::NPrivate; + +namespace NBus { + namespace NPrivate { + TRemoteConnection::TRemoteConnection(TRemoteSessionPtr session, ui64 connectionId, TNetAddr addr) + : TActor<TRemoteConnection, TWriterTag>(session->Queue->WorkQueue.Get()) + , TActor<TRemoteConnection, TReaderTag>(session->Queue->WorkQueue.Get()) + , TScheduleActor<TRemoteConnection, TWriterTag>(&session->Queue->Scheduler) + , Session(session) + , Proto(session->Proto) + , Config(session->Config) + , RemovedFromSession(false) + , ConnectionId(connectionId) + , PeerAddr(addr) + , PeerAddrSocketAddr(addr) + , CreatedTime(TInstant::Now()) + , ReturnConnectFailedImmediately(false) + , GranStatus(Config.Secret.StatusFlushPeriod) + , QuotaMsg(!Session->IsSource_, Config.PerConnectionMaxInFlight, 0) + , QuotaBytes(!Session->IsSource_, Config.PerConnectionMaxInFlightBySize, 0) + , MaxBufferSize(session->Config.MaxBufferSize) + , ShutdownReason(MESSAGE_OK) + { + WriterData.Status.ConnectionId = connectionId; + WriterData.Status.PeerAddr = PeerAddr; + ReaderData.Status.ConnectionId = connectionId; + + const TInstant now = TInstant::Now(); + + WriterFillStatus(); + + GranStatus.Writer.Update(WriterData.Status, now, true); + GranStatus.Reader.Update(ReaderData.Status, now, true); + } + + TRemoteConnection::~TRemoteConnection() { + Y_VERIFY(ReplyQueue.IsEmpty()); + } + + TRemoteConnection::TWriterData::TWriterData() + : Down(0) + , SocketVersion(0) + , InFlight(0) + , AwakeFlags(0) + , State(WRITER_FILLING) + { + } + + TRemoteConnection::TWriterData::~TWriterData() { + Y_VERIFY(AtomicGet(Down)); + Y_VERIFY(SendQueue.Empty()); + } + + bool TRemoteConnection::TReaderData::HasBytesInBuf(size_t bytes) noexcept { + size_t left = Buffer.Size() - Offset; + + return (MoreBytes = left >= bytes ? 0 : bytes - left) == 0; + } + + void TRemoteConnection::TWriterData::SetChannel(NEventLoop::TChannelPtr channel) { + Y_VERIFY(!Channel, "must not have channel"); + Y_VERIFY(Buffer.GetBuffer().Empty() && Buffer.LeftSize() == 0, "buffer must be empty"); + Y_VERIFY(State == WRITER_FILLING, "state must be initial"); + Channel = channel; + } + + void TRemoteConnection::TReaderData::SetChannel(NEventLoop::TChannelPtr channel) { + Y_VERIFY(!Channel, "must not have channel"); + Y_VERIFY(Buffer.Empty(), "buffer must be empty"); + Channel = channel; + } + + void TRemoteConnection::TWriterData::DropChannel() { + if (!!Channel) { + Channel->Unregister(); + Channel.Drop(); + } + + Buffer.Reset(); + State = WRITER_FILLING; + } + + void TRemoteConnection::TReaderData::DropChannel() { + // TODO: make Drop call Unregister + if (!!Channel) { + Channel->Unregister(); + Channel.Drop(); + } + Buffer.Reset(); + Offset = 0; + } + + TRemoteConnection::TReaderData::TReaderData() + : Down(0) + , SocketVersion(0) + , Offset(0) + , MoreBytes(0) + { + } + + TRemoteConnection::TReaderData::~TReaderData() { + Y_VERIFY(AtomicGet(Down)); + } + + void TRemoteConnection::Send(TNonDestroyingAutoPtr<TBusMessage> msg) { + BeforeSendQueue.Enqueue(msg.Release()); + AtomicIncrement(WriterData.InFlight); + ScheduleWrite(); + } + + void TRemoteConnection::ClearOutgoingQueue(TMessagesPtrs& result, bool reconnect) { + if (!reconnect) { + // Do not clear send queue if reconnecting + WriterData.SendQueue.Clear(&result); + } + } + + void TRemoteConnection::Shutdown(EMessageStatus status) { + ScheduleShutdown(status); + + ReaderData.ShutdownComplete.WaitI(); + WriterData.ShutdownComplete.WaitI(); + } + + void TRemoteConnection::TryConnect() { + Y_FAIL("TryConnect is client connection only operation"); + } + + void TRemoteConnection::ScheduleRead() { + GetReaderActor()->Schedule(); + } + + void TRemoteConnection::ScheduleWrite() { + GetWriterActor()->Schedule(); + } + + void TRemoteConnection::WriterRotateCounters() { + if (!WriterData.TimeToRotateCounters.FetchTask()) { + return; + } + + WriterData.Status.DurationCounterPrev = WriterData.Status.DurationCounter; + Reset(WriterData.Status.DurationCounter); + } + + void TRemoteConnection::WriterSendStatus(TInstant now, bool force) { + GranStatus.Writer.Update(std::bind(&TRemoteConnection::WriterGetStatus, this), now, force); + } + + void TRemoteConnection::ReaderSendStatus(TInstant now, bool force) { + GranStatus.Reader.Update(std::bind(&TRemoteConnection::ReaderFillStatus, this), now, force); + } + + const TRemoteConnectionReaderStatus& TRemoteConnection::ReaderFillStatus() { + ReaderData.Status.BufferSize = ReaderData.Buffer.Capacity(); + ReaderData.Status.QuotaMsg = QuotaMsg.Tokens(); + ReaderData.Status.QuotaBytes = QuotaBytes.Tokens(); + + return ReaderData.Status; + } + + void TRemoteConnection::ProcessItem(TReaderTag, ::NActor::TDefaultTag, TWriterToReaderSocketMessage readSocket) { + if (AtomicGet(ReaderData.Down)) { + ReaderData.Status.Fd = INVALID_SOCKET; + return; + } + + ReaderData.DropChannel(); + + ReaderData.Status.Fd = readSocket.Socket; + ReaderData.SocketVersion = readSocket.SocketVersion; + + if (readSocket.Socket != INVALID_SOCKET) { + ReaderData.SetChannel(Session->ReadEventLoop.Register(readSocket.Socket, this, ReadCookie)); + ReaderData.Channel->EnableRead(); + } + } + + void TRemoteConnection::ProcessItem(TWriterTag, TReconnectTag, ui32 socketVersion) { + Y_VERIFY(socketVersion <= WriterData.SocketVersion, "something weird"); + + if (WriterData.SocketVersion != socketVersion) { + return; + } + Y_VERIFY(WriterData.Status.Connected, "must be connected at this point"); + Y_VERIFY(!!WriterData.Channel, "must have channel at this point"); + + WriterData.Status.Connected = false; + WriterData.DropChannel(); + WriterData.Status.MyAddr = TNetAddr(); + ++WriterData.SocketVersion; + LastConnectAttempt = TInstant(); + + TMessagesPtrs cleared; + ClearOutgoingQueue(cleared, true); + WriterErrorMessages(cleared, MESSAGE_DELIVERY_FAILED); + + FireClientConnectionEvent(TClientConnectionEvent::DISCONNECTED); + + ReaderGetSocketQueue()->EnqueueAndSchedule(TWriterToReaderSocketMessage(INVALID_SOCKET, WriterData.SocketVersion)); + } + + void TRemoteConnection::ProcessItem(TWriterTag, TWakeReaderTag, ui32 awakeFlags) { + WriterData.AwakeFlags |= awakeFlags; + + ReadQuotaWakeup(); + } + + void TRemoteConnection::Act(TReaderTag) { + TInstant now = TInstant::Now(); + + ReaderData.Status.Acts += 1; + + ReaderGetSocketQueue()->DequeueAllLikelyEmpty(); + + if (AtomicGet(ReaderData.Down)) { + ReaderData.DropChannel(); + + ReaderProcessStatusDown(); + ReaderData.ShutdownComplete.Signal(); + + } else if (!!ReaderData.Channel) { + Y_ASSERT(ReaderData.ReadMessages.empty()); + + for (int i = 0;; ++i) { + if (i == 100) { + // perform other tasks + GetReaderActor()->AddTaskFromActorLoop(); + break; + } + + if (NeedInterruptRead()) { + ReaderData.Channel->EnableRead(); + break; + } + + if (!ReaderFillBuffer()) + break; + + if (!ReaderProcessBuffer()) + break; + } + + ReaderFlushMessages(); + } + + ReaderSendStatus(now); + } + + bool TRemoteConnection::QuotaAcquire(size_t msg, size_t bytes) { + ui32 wakeFlags = 0; + + if (!QuotaMsg.Acquire(msg)) + wakeFlags |= WAKE_QUOTA_MSG; + + else if (!QuotaBytes.Acquire(bytes)) + wakeFlags |= WAKE_QUOTA_BYTES; + + if (wakeFlags) { + ReaderData.Status.QuotaExhausted++; + + WriterGetWakeQueue()->EnqueueAndSchedule(wakeFlags); + } + + return wakeFlags == 0; + } + + void TRemoteConnection::QuotaConsume(size_t msg, size_t bytes) { + QuotaMsg.Consume(msg); + QuotaBytes.Consume(bytes); + } + + void TRemoteConnection::QuotaReturnSelf(size_t items, size_t bytes) { + if (QuotaReturnValues(items, bytes)) + ReadQuotaWakeup(); + } + + void TRemoteConnection::QuotaReturnAside(size_t items, size_t bytes) { + if (QuotaReturnValues(items, bytes) && !AtomicGet(WriterData.Down)) + WriterGetWakeQueue()->EnqueueAndSchedule(0x0); + } + + bool TRemoteConnection::QuotaReturnValues(size_t items, size_t bytes) { + bool rMsg = QuotaMsg.Return(items); + bool rBytes = QuotaBytes.Return(bytes); + + return rMsg || rBytes; + } + + void TRemoteConnection::ReadQuotaWakeup() { + const ui32 mask = WriterData.AwakeFlags & WriteWakeFlags(); + + if (mask && mask == WriterData.AwakeFlags) { + WriterData.Status.ReaderWakeups++; + WriterData.AwakeFlags = 0; + + ScheduleRead(); + } + } + + ui32 TRemoteConnection::WriteWakeFlags() const { + ui32 awakeFlags = 0; + + if (QuotaMsg.IsAboveWake()) + awakeFlags |= WAKE_QUOTA_MSG; + + if (QuotaBytes.IsAboveWake()) + awakeFlags |= WAKE_QUOTA_BYTES; + + return awakeFlags; + } + + bool TRemoteConnection::ReaderProcessBuffer() { + TInstant now = TInstant::Now(); + + for (;;) { + if (!ReaderData.HasBytesInBuf(sizeof(TBusHeader))) { + break; + } + + TBusHeader header(MakeArrayRef(ReaderData.Buffer.Data() + ReaderData.Offset, ReaderData.Buffer.Size() - ReaderData.Offset)); + + if (header.Size < sizeof(TBusHeader)) { + LWPROBE(Error, ToString(MESSAGE_HEADER_CORRUPTED), ToString(PeerAddr), ToString(header.Size)); + ReaderData.Status.Incremental.StatusCounter[MESSAGE_HEADER_CORRUPTED] += 1; + ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_HEADER_CORRUPTED, false); + return false; + } + + if (!IsVersionNegotiation(header) && !IsBusKeyValid(header.Id)) { + LWPROBE(Error, ToString(MESSAGE_HEADER_CORRUPTED), ToString(PeerAddr), ToString(header.Size)); + ReaderData.Status.Incremental.StatusCounter[MESSAGE_HEADER_CORRUPTED] += 1; + ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_HEADER_CORRUPTED, false); + return false; + } + + if (header.Size > Config.MaxMessageSize) { + LWPROBE(Error, ToString(MESSAGE_MESSAGE_TOO_LARGE), ToString(PeerAddr), ToString(header.Size)); + ReaderData.Status.Incremental.StatusCounter[MESSAGE_MESSAGE_TOO_LARGE] += 1; + ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_MESSAGE_TOO_LARGE, false); + return false; + } + + if (!ReaderData.HasBytesInBuf(header.Size)) { + if (ReaderData.Offset == 0) { + ReaderData.Buffer.Reserve(header.Size); + } + break; + } + + if (!QuotaAcquire(1, header.Size)) + return false; + + if (!MessageRead(MakeArrayRef(ReaderData.Buffer.Data() + ReaderData.Offset, header.Size), now)) { + return false; + } + + ReaderData.Offset += header.Size; + } + + ReaderData.Buffer.ChopHead(ReaderData.Offset); + ReaderData.Offset = 0; + + if (ReaderData.Buffer.Capacity() > MaxBufferSize && ReaderData.Buffer.Size() <= MaxBufferSize) { + ReaderData.Status.Incremental.BufferDrops += 1; + + TBuffer temp; + // probably should use another constant + temp.Reserve(Config.DefaultBufferSize); + temp.Append(ReaderData.Buffer.Data(), ReaderData.Buffer.Size()); + + ReaderData.Buffer.Swap(temp); + } + + return true; + } + + bool TRemoteConnection::ReaderFillBuffer() { + if (!ReaderData.BufferMore()) + return true; + + if (ReaderData.Buffer.Avail() == 0) { + if (ReaderData.Buffer.Size() == 0) { + ReaderData.Buffer.Reserve(Config.DefaultBufferSize); + } else { + ReaderData.Buffer.Reserve(ReaderData.Buffer.Size() * 2); + } + } + + Y_ASSERT(ReaderData.Buffer.Avail() > 0); + + ssize_t bytes; + { + TWhatThreadDoesPushPop pp("recv syscall"); + bytes = SocketRecv(ReaderData.Channel->GetSocket(), TArrayRef<char>(ReaderData.Buffer.Pos(), ReaderData.Buffer.Avail())); + } + + if (bytes < 0) { + if (WouldBlock()) { + ReaderData.Channel->EnableRead(); + return false; + } else { + ReaderData.Channel->DisableRead(); + ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_DELIVERY_FAILED, false); + return false; + } + } + + if (bytes == 0) { + ReaderData.Channel->DisableRead(); + // TODO: incorrect: it is possible that only input is shutdown, and output is available + ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_DELIVERY_FAILED, false); + return false; + } + + ReaderData.Status.Incremental.NetworkOps += 1; + + ReaderData.Buffer.Advance(bytes); + ReaderData.MoreBytes = 0; + return true; + } + + void TRemoteConnection::ClearBeforeSendQueue(EMessageStatus reason) { + BeforeSendQueue.DequeueAll(std::bind(&TRemoteConnection::WriterBeforeWriteErrorMessage, this, std::placeholders::_1, reason)); + } + + void TRemoteConnection::ClearReplyQueue(EMessageStatus reason) { + TVectorSwaps<TBusMessagePtrAndHeader> replyQueueTemp; + Y_ASSERT(replyQueueTemp.empty()); + ReplyQueue.DequeueAllSingleConsumer(&replyQueueTemp); + + TVector<TBusMessage*> messages; + for (TVectorSwaps<TBusMessagePtrAndHeader>::reverse_iterator message = replyQueueTemp.rbegin(); + message != replyQueueTemp.rend(); ++message) { + messages.push_back(message->MessagePtr.Release()); + } + + WriterErrorMessages(messages, reason); + + replyQueueTemp.clear(); + } + + void TRemoteConnection::ProcessBeforeSendQueueMessage(TBusMessage* message, TInstant now) { + // legacy clients expect this field to be set + if (!Session->IsSource_) { + message->SendTime = now.MilliSeconds(); + } + + WriterData.SendQueue.PushBack(message); + } + + void TRemoteConnection::ProcessBeforeSendQueue(TInstant now) { + BeforeSendQueue.DequeueAll(std::bind(&TRemoteConnection::ProcessBeforeSendQueueMessage, this, std::placeholders::_1, now)); + } + + void TRemoteConnection::WriterFillInFlight() { + // this is hack for TLoadBalancedProtocol + WriterFillStatus(); + AtomicSet(WriterData.InFlight, WriterData.Status.GetInFlight()); + } + + const TRemoteConnectionWriterStatus& TRemoteConnection::WriterGetStatus() { + WriterRotateCounters(); + WriterFillStatus(); + + return WriterData.Status; + } + + void TRemoteConnection::WriterFillStatus() { + if (!!WriterData.Channel) { + WriterData.Status.Fd = WriterData.Channel->GetSocket(); + } else { + WriterData.Status.Fd = INVALID_SOCKET; + } + WriterData.Status.BufferSize = WriterData.Buffer.Capacity(); + WriterData.Status.SendQueueSize = WriterData.SendQueue.Size(); + WriterData.Status.State = WriterData.State; + } + + void TRemoteConnection::WriterProcessStatusDown() { + Session->GetDeadConnectionWriterStatusQueue()->EnqueueAndSchedule(WriterData.Status.Incremental); + Reset(WriterData.Status.Incremental); + } + + void TRemoteConnection::ReaderProcessStatusDown() { + Session->GetDeadConnectionReaderStatusQueue()->EnqueueAndSchedule(ReaderData.Status.Incremental); + Reset(ReaderData.Status.Incremental); + } + + void TRemoteConnection::ProcessWriterDown() { + if (!RemovedFromSession) { + Session->GetRemoveConnectionQueue()->EnqueueAndSchedule(this); + + if (Session->IsSource_) { + if (WriterData.Status.Connected) { + FireClientConnectionEvent(TClientConnectionEvent::DISCONNECTED); + } + } + + LWPROBE(Disconnected, ToString(PeerAddr)); + RemovedFromSession = true; + } + + WriterData.DropChannel(); + + DropEnqueuedData(ShutdownReason, MESSAGE_SHUTDOWN); + + WriterProcessStatusDown(); + + WriterData.ShutdownComplete.Signal(); + } + + void TRemoteConnection::DropEnqueuedData(EMessageStatus reason, EMessageStatus reasonForQueues) { + ClearReplyQueue(reasonForQueues); + ClearBeforeSendQueue(reasonForQueues); + WriterGetReconnectQueue()->Clear(); + WriterGetWakeQueue()->Clear(); + + TMessagesPtrs cleared; + ClearOutgoingQueue(cleared, false); + + if (!Session->IsSource_) { + for (auto& i : cleared) { + TBusMessagePtrAndHeader h(i); + CheckedCast<TRemoteServerSession*>(Session.Get())->ReleaseInWorkResponses(MakeArrayRef(&h, 1)); + // assignment back is weird + i = h.MessagePtr.Release(); + // and this part is not batch + } + } + + WriterErrorMessages(cleared, reason); + } + + void TRemoteConnection::BeforeTryWrite() { + } + + void TRemoteConnection::Act(TWriterTag) { + TInstant now = TInstant::Now(); + + WriterData.Status.Acts += 1; + + if (Y_UNLIKELY(AtomicGet(WriterData.Down))) { + // dump status must work even if WriterDown + WriterSendStatus(now, true); + ProcessWriterDown(); + return; + } + + ProcessBeforeSendQueue(now); + + BeforeTryWrite(); + + WriterFillInFlight(); + + WriterGetReconnectQueue()->DequeueAllLikelyEmpty(); + + if (!WriterData.Status.Connected) { + TryConnect(); + } else { + for (int i = 0;; ++i) { + if (i == 100) { + // perform other tasks + GetWriterActor()->AddTaskFromActorLoop(); + break; + } + + if (WriterData.State == WRITER_FILLING) { + WriterFillBuffer(); + + if (WriterData.State == WRITER_FILLING) { + WriterData.Channel->DisableWrite(); + break; + } + + Y_ASSERT(!WriterData.Buffer.Empty()); + } + + if (WriterData.State == WRITER_FLUSHING) { + WriterFlushBuffer(); + + if (WriterData.State == WRITER_FLUSHING) { + break; + } + } + } + } + + WriterGetWakeQueue()->DequeueAllLikelyEmpty(); + + WriterSendStatus(now); + } + + void TRemoteConnection::WriterFlushBuffer() { + Y_ASSERT(WriterData.State == WRITER_FLUSHING); + Y_ASSERT(!WriterData.Buffer.Empty()); + + WriterData.CorkUntil = TInstant::Zero(); + + while (!WriterData.Buffer.Empty()) { + ssize_t bytes; + { + TWhatThreadDoesPushPop pp("send syscall"); + bytes = SocketSend(WriterData.Channel->GetSocket(), TArrayRef<const char>(WriterData.Buffer.LeftPos(), WriterData.Buffer.Size())); + } + + if (bytes < 0) { + if (WouldBlock()) { + WriterData.Channel->EnableWrite(); + return; + } else { + WriterData.Channel->DisableWrite(); + ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_DELIVERY_FAILED, true); + return; + } + } + + WriterData.Status.Incremental.NetworkOps += 1; + + WriterData.Buffer.LeftProceed(bytes); + } + + WriterData.Buffer.Clear(); + if (WriterData.Buffer.Capacity() > MaxBufferSize) { + WriterData.Status.Incremental.BufferDrops += 1; + WriterData.Buffer.Reset(); + } + + WriterData.State = WRITER_FILLING; + } + + void TRemoteConnection::ScheduleShutdownOnServerOrReconnectOnClient(EMessageStatus status, bool writer) { + if (Session->IsSource_) { + WriterGetReconnectQueue()->EnqueueAndSchedule(writer ? WriterData.SocketVersion : ReaderData.SocketVersion); + } else { + ScheduleShutdown(status); + } + } + + void TRemoteConnection::ScheduleShutdown(EMessageStatus status) { + ShutdownReason = status; + + AtomicSet(ReaderData.Down, 1); + ScheduleRead(); + + AtomicSet(WriterData.Down, 1); + ScheduleWrite(); + } + + void TRemoteConnection::CallSerialize(TBusMessage* msg, TBuffer& buffer) const { + size_t posForAssertion = buffer.Size(); + Proto->Serialize(msg, buffer); + Y_VERIFY(buffer.Size() >= posForAssertion, + "incorrect Serialize implementation, pos before serialize: %d, pos after serialize: %d", + int(posForAssertion), int(buffer.Size())); + } + + namespace { + inline void WriteHeader(const TBusHeader& header, TBuffer& data) { + data.Reserve(data.Size() + sizeof(TBusHeader)); + /// \todo hton instead of memcpy + memcpy(data.Data() + data.Size(), &header, sizeof(TBusHeader)); + data.Advance(sizeof(TBusHeader)); + } + + inline void WriteDummyHeader(TBuffer& data) { + data.Resize(data.Size() + sizeof(TBusHeader)); + } + + } + + void TRemoteConnection::SerializeMessage(TBusMessage* msg, TBuffer* data, TMessageCounter* counter) const { + size_t pos = data->Size(); + + size_t dataSize; + + bool compressionRequested = msg->IsCompressed(); + + if (compressionRequested) { + TBuffer compdata; + TBuffer plaindata; + CallSerialize(msg, plaindata); + + dataSize = sizeof(TBusHeader) + plaindata.Size(); + + NCodecs::TCodecPtr c = Proto->GetTransportCodec(); + c->Encode(TStringBuf{plaindata.data(), plaindata.size()}, compdata); + + if (compdata.Size() < plaindata.Size()) { + plaindata.Clear(); + msg->GetHeader()->Size = sizeof(TBusHeader) + compdata.Size(); + WriteHeader(*msg->GetHeader(), *data); + data->Append(compdata.Data(), compdata.Size()); + } else { + compdata.Clear(); + msg->SetCompressed(false); + msg->GetHeader()->Size = sizeof(TBusHeader) + plaindata.Size(); + WriteHeader(*msg->GetHeader(), *data); + data->Append(plaindata.Data(), plaindata.Size()); + } + } else { + WriteDummyHeader(*data); + CallSerialize(msg, *data); + + dataSize = msg->GetHeader()->Size = data->Size() - pos; + + data->Proceed(pos); + WriteHeader(*msg->GetHeader(), *data); + data->Proceed(pos + msg->GetHeader()->Size); + } + + Y_ASSERT(msg->GetHeader()->Size == data->Size() - pos); + counter->AddMessage(dataSize, data->Size() - pos, msg->IsCompressed(), compressionRequested); + } + + TBusMessage* TRemoteConnection::DeserializeMessage(TArrayRef<const char> dataRef, const TBusHeader* header, TMessageCounter* messageCounter, EMessageStatus* status) const { + size_t dataSize; + + TBusMessage* message; + if (header->FlagsInternal & MESSAGE_COMPRESS_INTERNAL) { + TBuffer msg; + { + TBuffer plaindata; + NCodecs::TCodecPtr c = Proto->GetTransportCodec(); + try { + TArrayRef<const char> payload = TBusMessage::GetPayload(dataRef); + c->Decode(TStringBuf{payload.data(), payload.size()}, plaindata); + } catch (...) { + // catch all, because + // http://nga.at.yandex-team.ru/replies.xml?item_no=3884 + *status = MESSAGE_DECOMPRESS_ERROR; + return nullptr; + } + + msg.Append(dataRef.data(), sizeof(TBusHeader)); + msg.Append(plaindata.Data(), plaindata.Size()); + } + TArrayRef<const char> msgRef(msg.Data(), msg.Size()); + dataSize = sizeof(TBusHeader) + msgRef.size(); + // TODO: track error types + message = Proto->Deserialize(header->Type, msgRef.Slice(sizeof(TBusHeader))).Release(); + if (!message) { + *status = MESSAGE_DESERIALIZE_ERROR; + return nullptr; + } + *message->GetHeader() = *header; + message->SetCompressed(true); + } else { + dataSize = dataRef.size(); + message = Proto->Deserialize(header->Type, dataRef.Slice(sizeof(TBusHeader))).Release(); + if (!message) { + *status = MESSAGE_DESERIALIZE_ERROR; + return nullptr; + } + *message->GetHeader() = *header; + } + + messageCounter->AddMessage(dataSize, dataRef.size(), header->FlagsInternal & MESSAGE_COMPRESS_INTERNAL, false); + + return message; + } + + void TRemoteConnection::ResetOneWayFlag(TArrayRef<TBusMessage*> messages) { + for (auto message : messages) { + message->LocalFlags &= ~MESSAGE_ONE_WAY_INTERNAL; + } + } + + void TRemoteConnection::ReaderFlushMessages() { + if (!ReaderData.ReadMessages.empty()) { + Session->OnMessageReceived(this, ReaderData.ReadMessages); + ReaderData.ReadMessages.clear(); + } + } + + // @return false if actor should break + bool TRemoteConnection::MessageRead(TArrayRef<const char> readDataRef, TInstant now) { + TBusHeader header(readDataRef); + + Y_ASSERT(readDataRef.size() == header.Size); + + if (header.GetVersionInternal() != YBUS_VERSION) { + ReaderProcessMessageUnknownVersion(readDataRef); + return true; + } + + EMessageStatus deserializeFailureStatus = MESSAGE_OK; + TBusMessage* r = DeserializeMessage(readDataRef, &header, &ReaderData.Status.Incremental.MessageCounter, &deserializeFailureStatus); + + if (!r) { + Y_VERIFY(deserializeFailureStatus != MESSAGE_OK, "state check"); + LWPROBE(Error, ToString(deserializeFailureStatus), ToString(PeerAddr), ""); + ReaderData.Status.Incremental.StatusCounter[deserializeFailureStatus] += 1; + ScheduleShutdownOnServerOrReconnectOnClient(deserializeFailureStatus, false); + return false; + } + + LWPROBE(Read, r->GetHeader()->Size); + + r->ReplyTo = PeerAddrSocketAddr; + + TBusMessagePtrAndHeader h(r); + r->RecvTime = now; + + QuotaConsume(1, header.Size); + + ReaderData.ReadMessages.push_back(h); + if (ReaderData.ReadMessages.size() >= 100) { + ReaderFlushMessages(); + } + + return true; + } + + void TRemoteConnection::WriterFillBuffer() { + Y_ASSERT(WriterData.State == WRITER_FILLING); + + Y_ASSERT(WriterData.Buffer.LeftSize() == 0); + + if (Y_UNLIKELY(!WrongVersionRequests.IsEmpty())) { + TVector<TBusHeader> headers; + WrongVersionRequests.DequeueAllSingleConsumer(&headers); + for (TVector<TBusHeader>::reverse_iterator header = headers.rbegin(); + header != headers.rend(); ++header) { + TBusHeader response = *header; + response.SendTime = NBus::Now(); + response.Size = sizeof(TBusHeader); + response.FlagsInternal = 0; + response.SetVersionInternal(YBUS_VERSION); + WriteHeader(response, WriterData.Buffer.GetBuffer()); + } + + Y_ASSERT(!WriterData.Buffer.Empty()); + WriterData.State = WRITER_FLUSHING; + return; + } + + TTempTlsVector<TBusMessagePtrAndHeader, void, TVectorSwaps> writeMessages; + + for (;;) { + THolder<TBusMessage> writeMessage(WriterData.SendQueue.PopFront()); + if (!writeMessage) { + break; + } + + if (Config.Cork != TDuration::Zero()) { + if (WriterData.CorkUntil == TInstant::Zero()) { + WriterData.CorkUntil = TInstant::Now() + Config.Cork; + } + } + + size_t sizeBeforeSerialize = WriterData.Buffer.Size(); + + TMessageCounter messageCounter = WriterData.Status.Incremental.MessageCounter; + + SerializeMessage(writeMessage.Get(), &WriterData.Buffer.GetBuffer(), &messageCounter); + + size_t written = WriterData.Buffer.Size() - sizeBeforeSerialize; + if (written > Config.MaxMessageSize) { + WriterData.Buffer.GetBuffer().EraseBack(written); + WriterBeforeWriteErrorMessage(writeMessage.Release(), MESSAGE_MESSAGE_TOO_LARGE); + continue; + } + + WriterData.Status.Incremental.MessageCounter = messageCounter; + + TBusMessagePtrAndHeader h(writeMessage.Release()); + writeMessages.GetVector()->push_back(h); + + Y_ASSERT(!WriterData.Buffer.Empty()); + if (WriterData.Buffer.Size() >= Config.SendThreshold) { + break; + } + } + + if (!WriterData.Buffer.Empty()) { + if (WriterData.Buffer.Size() >= Config.SendThreshold) { + WriterData.State = WRITER_FLUSHING; + } else if (WriterData.CorkUntil == TInstant::Zero()) { + WriterData.State = WRITER_FLUSHING; + } else if (TInstant::Now() >= WriterData.CorkUntil) { + WriterData.State = WRITER_FLUSHING; + } else { + // keep filling + Y_ASSERT(WriterData.State == WRITER_FILLING); + GetWriterSchedulerActor()->ScheduleAt(WriterData.CorkUntil); + } + } else { + // keep filling + Y_ASSERT(WriterData.State == WRITER_FILLING); + } + + size_t bytes = MessageSize(*writeMessages.GetVector()); + + QuotaReturnSelf(writeMessages.GetVector()->size(), bytes); + + // This is called before `send` syscall inducing latency + MessageSent(*writeMessages.GetVector()); + } + + size_t TRemoteConnection::MessageSize(TArrayRef<TBusMessagePtrAndHeader> messages) { + size_t size = 0; + for (const auto& message : messages) + size += message.MessagePtr->RequestSize; + + return size; + } + + size_t TRemoteConnection::GetInFlight() { + return AtomicGet(WriterData.InFlight); + } + + size_t TRemoteConnection::GetConnectSyscallsNumForTest() { + return WriterData.Status.ConnectSyscalls; + } + + void TRemoteConnection::WriterBeforeWriteErrorMessage(TBusMessage* message, EMessageStatus status) { + if (Session->IsSource_) { + CheckedCast<TRemoteClientSession*>(Session.Get())->ReleaseInFlight({message}); + WriterErrorMessage(message, status); + } else { + TBusMessagePtrAndHeader h(message); + CheckedCast<TRemoteServerSession*>(Session.Get())->ReleaseInWorkResponses(MakeArrayRef(&h, 1)); + WriterErrorMessage(h.MessagePtr.Release(), status); + } + } + + void TRemoteConnection::WriterErrorMessage(TNonDestroyingAutoPtr<TBusMessage> m, EMessageStatus status) { + TBusMessage* released = m.Release(); + WriterErrorMessages(MakeArrayRef(&released, 1), status); + } + + void TRemoteConnection::WriterErrorMessages(const TArrayRef<TBusMessage*> ms, EMessageStatus status) { + ResetOneWayFlag(ms); + + WriterData.Status.Incremental.StatusCounter[status] += ms.size(); + for (auto m : ms) { + Session->InvokeOnError(m, status); + } + } + + void TRemoteConnection::FireClientConnectionEvent(TClientConnectionEvent::EType type) { + Y_VERIFY(Session->IsSource_, "state check"); + TClientConnectionEvent event(type, ConnectionId, PeerAddr); + TRemoteClientSession* session = CheckedCast<TRemoteClientSession*>(Session.Get()); + session->ClientHandler->OnClientConnectionEvent(event); + } + + bool TRemoteConnection::IsAlive() const { + return !AtomicGet(WriterData.Down); + } + + } +} diff --git a/library/cpp/messagebus/remote_connection.h b/library/cpp/messagebus/remote_connection.h new file mode 100644 index 0000000000..4538947368 --- /dev/null +++ b/library/cpp/messagebus/remote_connection.h @@ -0,0 +1,294 @@ +#pragma once + +#include "async_result.h" +#include "defs.h" +#include "event_loop.h" +#include "left_right_buffer.h" +#include "lfqueue_batch.h" +#include "message_ptr_and_header.h" +#include "nondestroying_holder.h" +#include "remote_connection_status.h" +#include "scheduler_actor.h" +#include "socket_addr.h" +#include "storage.h" +#include "vector_swaps.h" +#include "ybus.h" +#include "misc/granup.h" +#include "misc/tokenquota.h" + +#include <library/cpp/messagebus/actor/actor.h> +#include <library/cpp/messagebus/actor/executor.h> +#include <library/cpp/messagebus/actor/queue_for_actor.h> +#include <library/cpp/messagebus/actor/queue_in_actor.h> +#include <library/cpp/messagebus/scheduler/scheduler.h> + +#include <util/system/atomic.h> +#include <util/system/event.h> +#include <util/thread/lfstack.h> + +namespace NBus { + namespace NPrivate { + class TRemoteConnection; + + typedef TIntrusivePtr<TRemoteConnection> TRemoteConnectionPtr; + typedef TIntrusivePtr<TBusSessionImpl> TRemoteSessionPtr; + + static void* const WriteCookie = (void*)1; + static void* const ReadCookie = (void*)2; + + enum { + WAKE_QUOTA_MSG = 0x01, + WAKE_QUOTA_BYTES = 0x02 + }; + + struct TWriterTag {}; + struct TReaderTag {}; + struct TReconnectTag {}; + struct TWakeReaderTag {}; + + struct TWriterToReaderSocketMessage { + TSocket Socket; + ui32 SocketVersion; + + TWriterToReaderSocketMessage(TSocket socket, ui32 socketVersion) + : Socket(socket) + , SocketVersion(socketVersion) + { + } + }; + + class TRemoteConnection + : public NEventLoop::IEventHandler, + public ::NActor::TActor<TRemoteConnection, TWriterTag>, + public ::NActor::TActor<TRemoteConnection, TReaderTag>, + private ::NActor::TQueueInActor<TRemoteConnection, TWriterToReaderSocketMessage, TReaderTag>, + private ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TReconnectTag>, + private ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TWakeReaderTag>, + public TScheduleActor<TRemoteConnection, TWriterTag> { + friend struct TBusSessionImpl; + friend class TRemoteClientSession; + friend class TRemoteServerSession; + friend class ::NActor::TQueueInActor<TRemoteConnection, TWriterToReaderSocketMessage, TReaderTag>; + friend class ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TReconnectTag>; + friend class ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TWakeReaderTag>; + + protected: + ::NActor::TQueueInActor<TRemoteConnection, TWriterToReaderSocketMessage, TReaderTag>* ReaderGetSocketQueue() { + return this; + } + + ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TReconnectTag>* WriterGetReconnectQueue() { + return this; + } + + ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TWakeReaderTag>* WriterGetWakeQueue() { + return this; + } + + protected: + TRemoteConnection(TRemoteSessionPtr session, ui64 connectionId, TNetAddr addr); + ~TRemoteConnection() override; + + virtual void ClearOutgoingQueue(TMessagesPtrs&, bool reconnect /* or shutdown */); + + public: + void Send(TNonDestroyingAutoPtr<TBusMessage> msg); + void Shutdown(EMessageStatus status); + + inline const TNetAddr& GetAddr() const noexcept; + + private: + friend class TScheduleConnect; + friend class TWorkIO; + + protected: + static size_t MessageSize(TArrayRef<TBusMessagePtrAndHeader>); + bool QuotaAcquire(size_t msg, size_t bytes); + void QuotaConsume(size_t msg, size_t bytes); + void QuotaReturnSelf(size_t items, size_t bytes); + bool QuotaReturnValues(size_t items, size_t bytes); + + bool ReaderProcessBuffer(); + bool ReaderFillBuffer(); + void ReaderFlushMessages(); + + void ReadQuotaWakeup(); + ui32 WriteWakeFlags() const; + + virtual bool NeedInterruptRead() { + return false; + } + + public: + virtual void TryConnect(); + void ProcessItem(TReaderTag, ::NActor::TDefaultTag, TWriterToReaderSocketMessage); + void ProcessItem(TWriterTag, TReconnectTag, ui32 socketVersion); + void ProcessItem(TWriterTag, TWakeReaderTag, ui32 awakeFlags); + void Act(TReaderTag); + inline void WriterBeforeWriteErrorMessage(TBusMessage*, EMessageStatus); + void ClearBeforeSendQueue(EMessageStatus reasonForQueues); + void ClearReplyQueue(EMessageStatus reasonForQueues); + inline void ProcessBeforeSendQueueMessage(TBusMessage*, TInstant now); + void ProcessBeforeSendQueue(TInstant now); + void WriterProcessStatusDown(); + void ReaderProcessStatusDown(); + void ProcessWriterDown(); + void DropEnqueuedData(EMessageStatus reason, EMessageStatus reasonForQueues); + const TRemoteConnectionWriterStatus& WriterGetStatus(); + virtual void WriterFillStatus(); + void WriterFillInFlight(); + virtual void BeforeTryWrite(); + void Act(TWriterTag); + void ScheduleRead(); + void ScheduleWrite(); + void ScheduleShutdownOnServerOrReconnectOnClient(EMessageStatus status, bool writer); + void ScheduleShutdown(EMessageStatus status); + void WriterFlushBuffer(); + void WriterFillBuffer(); + void ReaderSendStatus(TInstant now, bool force = false); + const TRemoteConnectionReaderStatus& ReaderFillStatus(); + void WriterRotateCounters(); + void WriterSendStatus(TInstant now, bool force = false); + void WriterSendStatusIfNecessary(TInstant now); + void QuotaReturnAside(size_t items, size_t bytes); + virtual void ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) = 0; + bool MessageRead(TArrayRef<const char> dataRef, TInstant now); + virtual void MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) = 0; + + void CallSerialize(TBusMessage* msg, TBuffer& buffer) const; + void SerializeMessage(TBusMessage* msg, TBuffer* data, TMessageCounter* counter) const; + TBusMessage* DeserializeMessage(TArrayRef<const char> dataRef, const TBusHeader* header, TMessageCounter* messageCounter, EMessageStatus* status) const; + + void ResetOneWayFlag(TArrayRef<TBusMessage*>); + + inline ::NActor::TActor<TRemoteConnection, TWriterTag>* GetWriterActor() { + return this; + } + inline ::NActor::TActor<TRemoteConnection, TReaderTag>* GetReaderActor() { + return this; + } + inline TScheduleActor<TRemoteConnection, TWriterTag>* GetWriterSchedulerActor() { + return this; + } + + void WriterErrorMessage(TNonDestroyingAutoPtr<TBusMessage> m, EMessageStatus status); + // takes ownership of ms + void WriterErrorMessages(const TArrayRef<TBusMessage*> ms, EMessageStatus status); + + void FireClientConnectionEvent(TClientConnectionEvent::EType); + + size_t GetInFlight(); + size_t GetConnectSyscallsNumForTest(); + + bool IsReturnConnectFailedImmediately() { + return (bool)AtomicGet(ReturnConnectFailedImmediately); + } + + bool IsAlive() const; + + TRemoteSessionPtr Session; + TBusProtocol* const Proto; + TBusSessionConfig const Config; + bool RemovedFromSession; + const ui64 ConnectionId; + const TNetAddr PeerAddr; + const TBusSocketAddr PeerAddrSocketAddr; + + const TInstant CreatedTime; + TInstant LastConnectAttempt; + TAtomic ReturnConnectFailedImmediately; + + protected: + ::NActor::TQueueForActor<TBusMessage*> BeforeSendQueue; + TLockFreeStack<TBusHeader> WrongVersionRequests; + + struct TWriterData { + TAtomic Down; + + NEventLoop::TChannelPtr Channel; + ui32 SocketVersion; + + TRemoteConnectionWriterStatus Status; + TInstant StatusLastSendTime; + + TLocalTasks TimeToRotateCounters; + + TAtomic InFlight; + + TTimedMessages SendQueue; + ui32 AwakeFlags; + EWriterState State; + TLeftRightBuffer Buffer; + TInstant CorkUntil; + + TSystemEvent ShutdownComplete; + + void SetChannel(NEventLoop::TChannelPtr channel); + void DropChannel(); + + TWriterData(); + ~TWriterData(); + }; + + struct TReaderData { + TAtomic Down; + + NEventLoop::TChannelPtr Channel; + ui32 SocketVersion; + + TRemoteConnectionReaderStatus Status; + TInstant StatusLastSendTime; + + TBuffer Buffer; + size_t Offset; /* offset in read buffer */ + size_t MoreBytes; /* more bytes required from socket */ + TVectorSwaps<TBusMessagePtrAndHeader> ReadMessages; + + TSystemEvent ShutdownComplete; + + bool BufferMore() const noexcept { + return MoreBytes > 0; + } + + bool HasBytesInBuf(size_t bytes) noexcept; + void SetChannel(NEventLoop::TChannelPtr channel); + void DropChannel(); + + TReaderData(); + ~TReaderData(); + }; + + // owned by session status actor + struct TGranStatus { + TGranStatus(TDuration gran) + : Writer(gran) + , Reader(gran) + { + } + + TGranUp<TRemoteConnectionWriterStatus> Writer; + TGranUp<TRemoteConnectionReaderStatus> Reader; + }; + + TWriterData WriterData; + TReaderData ReaderData; + TGranStatus GranStatus; + TTokenQuota QuotaMsg; + TTokenQuota QuotaBytes; + + size_t MaxBufferSize; + + // client connection only + TLockFreeQueueBatch<TBusMessagePtrAndHeader, TVectorSwaps> ReplyQueue; + + EMessageStatus ShutdownReason; + }; + + inline const TNetAddr& TRemoteConnection::GetAddr() const noexcept { + return PeerAddr; + } + + typedef TIntrusivePtr<TRemoteConnection> TRemoteConnectionPtr; + + } +} diff --git a/library/cpp/messagebus/remote_connection_status.cpp b/library/cpp/messagebus/remote_connection_status.cpp new file mode 100644 index 0000000000..2c48b2a287 --- /dev/null +++ b/library/cpp/messagebus/remote_connection_status.cpp @@ -0,0 +1,265 @@ +#include "remote_connection_status.h" + +#include "key_value_printer.h" + +#include <library/cpp/messagebus/monitoring/mon_proto.pb.h> + +#include <util/stream/format.h> +#include <util/stream/output.h> +#include <util/system/yassert.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +template <typename T> +static void Add(T& thiz, const T& that) { + thiz += that; +} + +template <typename T> +static void Max(T& thiz, const T& that) { + if (that > thiz) { + thiz = that; + } +} + +template <typename T> +static void AssertZero(T& thiz, const T& that) { + Y_ASSERT(thiz == T()); + Y_UNUSED(that); +} + +TDurationCounter::TDurationCounter() + : DURATION_COUNTER_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TDuration TDurationCounter::AvgDuration() const { + if (Count == 0) { + return TDuration::Zero(); + } else { + return SumDuration / Count; + } +} + +TDurationCounter& TDurationCounter::operator+=(const TDurationCounter& that) { + DURATION_COUNTER_MAP(STRUCT_FIELD_ADD, ) + return *this; +} + +TString TDurationCounter::ToString() const { + if (Count == 0) { + return "0"; + } else { + TStringStream ss; + ss << "avg: " << AvgDuration() << ", max: " << MaxDuration << ", count: " << Count; + return ss.Str(); + } +} + +TRemoteConnectionStatusBase::TRemoteConnectionStatusBase() + : REMOTE_CONNECTION_STATUS_BASE_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TRemoteConnectionStatusBase& TRemoteConnectionStatusBase ::operator+=(const TRemoteConnectionStatusBase& that) { + REMOTE_CONNECTION_STATUS_BASE_MAP(STRUCT_FIELD_ADD, ) + return *this; +} + +TRemoteConnectionIncrementalStatusBase::TRemoteConnectionIncrementalStatusBase() + : REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TRemoteConnectionIncrementalStatusBase& TRemoteConnectionIncrementalStatusBase::operator+=( + const TRemoteConnectionIncrementalStatusBase& that) { + REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(STRUCT_FIELD_ADD, ) + return *this; +} + +TRemoteConnectionReaderIncrementalStatus::TRemoteConnectionReaderIncrementalStatus() + : REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TRemoteConnectionReaderIncrementalStatus& TRemoteConnectionReaderIncrementalStatus::operator+=( + const TRemoteConnectionReaderIncrementalStatus& that) { + TRemoteConnectionIncrementalStatusBase::operator+=(that); + REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(STRUCT_FIELD_ADD, ) + return *this; +} + +TRemoteConnectionReaderStatus::TRemoteConnectionReaderStatus() + : REMOTE_CONNECTION_READER_STATUS_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TRemoteConnectionReaderStatus& TRemoteConnectionReaderStatus::operator+=(const TRemoteConnectionReaderStatus& that) { + TRemoteConnectionStatusBase::operator+=(that); + REMOTE_CONNECTION_READER_STATUS_MAP(STRUCT_FIELD_ADD, ) + return *this; +} + +TRemoteConnectionWriterIncrementalStatus::TRemoteConnectionWriterIncrementalStatus() + : REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TRemoteConnectionWriterIncrementalStatus& TRemoteConnectionWriterIncrementalStatus::operator+=( + const TRemoteConnectionWriterIncrementalStatus& that) { + TRemoteConnectionIncrementalStatusBase::operator+=(that); + REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(STRUCT_FIELD_ADD, ) + return *this; +} + +TRemoteConnectionWriterStatus::TRemoteConnectionWriterStatus() + : REMOTE_CONNECTION_WRITER_STATUS(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TRemoteConnectionWriterStatus& TRemoteConnectionWriterStatus::operator+=(const TRemoteConnectionWriterStatus& that) { + TRemoteConnectionStatusBase::operator+=(that); + REMOTE_CONNECTION_WRITER_STATUS(STRUCT_FIELD_ADD, ) + return *this; +} + +size_t TRemoteConnectionWriterStatus::GetInFlight() const { + return SendQueueSize + AckMessagesSize; +} + +TConnectionStatusMonRecord TRemoteConnectionStatus::GetStatusProtobuf() const { + TConnectionStatusMonRecord status; + + // TODO: fill unfilled fields + status.SetSendQueueSize(WriterStatus.SendQueueSize); + status.SetAckMessagesSize(WriterStatus.AckMessagesSize); + // status.SetErrorCount(); + // status.SetWriteBytes(); + // status.SetWriteBytesCompressed(); + // status.SetWriteMessages(); + status.SetWriteSyscalls(WriterStatus.Incremental.NetworkOps); + status.SetWriteActs(WriterStatus.Acts); + // status.SetReadBytes(); + // status.SetReadBytesCompressed(); + // status.SetReadMessages(); + status.SetReadSyscalls(ReaderStatus.Incremental.NetworkOps); + status.SetReadActs(ReaderStatus.Acts); + + TMessageStatusCounter sumStatusCounter; + sumStatusCounter += WriterStatus.Incremental.StatusCounter; + sumStatusCounter += ReaderStatus.Incremental.StatusCounter; + sumStatusCounter.FillErrorsProtobuf(&status); + + return status; +} + +TString TRemoteConnectionStatus::PrintToString() const { + TStringStream ss; + + TKeyValuePrinter p; + + if (!Summary) { + // TODO: print MyAddr too, but only if it is set + ss << WriterStatus.PeerAddr << " (" << WriterStatus.ConnectionId << ")" + << ", writefd=" << WriterStatus.Fd + << ", readfd=" << ReaderStatus.Fd + << Endl; + if (WriterStatus.Connected) { + p.AddRow("connect time", WriterStatus.ConnectTime.ToString()); + p.AddRow("writer state", ToCString(WriterStatus.State)); + } else { + ss << "not connected"; + if (WriterStatus.ConnectError != 0) { + ss << ", last connect error: " << LastSystemErrorText(WriterStatus.ConnectError); + } + ss << Endl; + } + } + if (!Server) { + p.AddRow("connect syscalls", WriterStatus.ConnectSyscalls); + } + + p.AddRow("send queue", LeftPad(WriterStatus.SendQueueSize, 6)); + + if (Server) { + p.AddRow("quota msg", LeftPad(ReaderStatus.QuotaMsg, 6)); + p.AddRow("quota bytes", LeftPad(ReaderStatus.QuotaBytes, 6)); + p.AddRow("quota exhausted", LeftPad(ReaderStatus.QuotaExhausted, 6)); + p.AddRow("reader wakeups", LeftPad(WriterStatus.ReaderWakeups, 6)); + } else { + p.AddRow("ack messages", LeftPad(WriterStatus.AckMessagesSize, 6)); + } + + p.AddRow("written", WriterStatus.Incremental.MessageCounter.ToString(false)); + p.AddRow("read", ReaderStatus.Incremental.MessageCounter.ToString(true)); + + p.AddRow("write syscalls", LeftPad(WriterStatus.Incremental.NetworkOps, 12)); + p.AddRow("read syscalls", LeftPad(ReaderStatus.Incremental.NetworkOps, 12)); + + p.AddRow("write acts", LeftPad(WriterStatus.Acts, 12)); + p.AddRow("read acts", LeftPad(ReaderStatus.Acts, 12)); + + p.AddRow("write buffer cap", LeftPad(WriterStatus.BufferSize, 12)); + p.AddRow("read buffer cap", LeftPad(ReaderStatus.BufferSize, 12)); + + p.AddRow("write buffer drops", LeftPad(WriterStatus.Incremental.BufferDrops, 10)); + p.AddRow("read buffer drops", LeftPad(ReaderStatus.Incremental.BufferDrops, 10)); + + if (Server) { + p.AddRow("process dur", WriterStatus.DurationCounterPrev.ToString()); + } + + ss << p.PrintToString(); + + if (false && Server) { + ss << "time histogram:\n"; + ss << WriterStatus.Incremental.ProcessDurationHistogram.PrintToString(); + } + + TMessageStatusCounter sumStatusCounter; + sumStatusCounter += WriterStatus.Incremental.StatusCounter; + sumStatusCounter += ReaderStatus.Incremental.StatusCounter; + + ss << sumStatusCounter.PrintToString(); + + return ss.Str(); +} + +TRemoteConnectionStatus::TRemoteConnectionStatus() + : REMOTE_CONNECTION_STATUS_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA) +{ +} + +TString TSessionDumpStatus::PrintToString() const { + if (Shutdown) { + return "shutdown"; + } + + TStringStream ss; + ss << Head; + if (ConnectionStatusSummary.Server) { + ss << "\n"; + ss << Acceptors; + } + ss << "\n"; + ss << "connections summary:" << Endl; + ss << ConnectionsSummary; + if (!!Connections) { + ss << "\n"; + ss << Connections; + } + ss << "\n"; + ss << Config.PrintToString(); + return ss.Str(); +} + +TString TBusMessageQueueStatus::PrintToString() const { + TStringStream ss; + ss << "work queue:\n"; + ss << ExecutorStatus.Status; + ss << "\n"; + ss << "queue config:\n"; + ss << Config.PrintToString(); + return ss.Str(); +} diff --git a/library/cpp/messagebus/remote_connection_status.h b/library/cpp/messagebus/remote_connection_status.h new file mode 100644 index 0000000000..5db10e51ea --- /dev/null +++ b/library/cpp/messagebus/remote_connection_status.h @@ -0,0 +1,214 @@ +#pragma once + +#include "codegen.h" +#include "duration_histogram.h" +#include "message_counter.h" +#include "message_status_counter.h" +#include "queue_config.h" +#include "session_config.h" + +#include <library/cpp/messagebus/actor/executor.h> + +#include <library/cpp/deprecated/enum_codegen/enum_codegen.h> + +namespace NBus { + class TConnectionStatusMonRecord; +} + +namespace NBus { + namespace NPrivate { +#define WRITER_STATE_MAP(XX) \ + XX(WRITER_UNKNOWN) \ + XX(WRITER_FILLING) \ + XX(WRITER_FLUSHING) \ + /**/ + + // TODO: move elsewhere + enum EWriterState { + WRITER_STATE_MAP(ENUM_VALUE_GEN_NO_VALUE) + }; + + ENUM_TO_STRING(EWriterState, WRITER_STATE_MAP) + +#define STRUCT_FIELD_ADD(name, type, func) func(name, that.name); + + template <typename T> + void Reset(T& t) { + t.~T(); + new (&t) T(); + } + +#define DURATION_COUNTER_MAP(XX, comma) \ + XX(Count, unsigned, Add) \ + comma \ + XX(SumDuration, TDuration, Add) comma \ + XX(MaxDuration, TDuration, Max) /**/ + + struct TDurationCounter { + DURATION_COUNTER_MAP(STRUCT_FIELD_GEN, ) + + TDuration AvgDuration() const; + + TDurationCounter(); + + void AddDuration(TDuration d) { + Count += 1; + SumDuration += d; + if (d > MaxDuration) { + MaxDuration = d; + } + } + + TDurationCounter& operator+=(const TDurationCounter&); + + TString ToString() const; + }; + +#define REMOTE_CONNECTION_STATUS_BASE_MAP(XX, comma) \ + XX(ConnectionId, ui64, AssertZero) \ + comma \ + XX(Fd, SOCKET, AssertZero) comma \ + XX(Acts, ui64, Add) comma \ + XX(BufferSize, ui64, Add) /**/ + + struct TRemoteConnectionStatusBase { + REMOTE_CONNECTION_STATUS_BASE_MAP(STRUCT_FIELD_GEN, ) + + TRemoteConnectionStatusBase& operator+=(const TRemoteConnectionStatusBase&); + + TRemoteConnectionStatusBase(); + }; + +#define REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(XX, comma) \ + XX(BufferDrops, unsigned, Add) \ + comma \ + XX(NetworkOps, unsigned, Add) /**/ + + struct TRemoteConnectionIncrementalStatusBase { + REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(STRUCT_FIELD_GEN, ) + + TRemoteConnectionIncrementalStatusBase& operator+=(const TRemoteConnectionIncrementalStatusBase&); + + TRemoteConnectionIncrementalStatusBase(); + }; + +#define REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(XX, comma) \ + XX(MessageCounter, TMessageCounter, Add) \ + comma \ + XX(StatusCounter, TMessageStatusCounter, Add) /**/ + + struct TRemoteConnectionReaderIncrementalStatus: public TRemoteConnectionIncrementalStatusBase { + REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(STRUCT_FIELD_GEN, ) + + TRemoteConnectionReaderIncrementalStatus& operator+=(const TRemoteConnectionReaderIncrementalStatus&); + + TRemoteConnectionReaderIncrementalStatus(); + }; + +#define REMOTE_CONNECTION_READER_STATUS_MAP(XX, comma) \ + XX(QuotaMsg, size_t, Add) \ + comma \ + XX(QuotaBytes, size_t, Add) comma \ + XX(QuotaExhausted, size_t, Add) comma \ + XX(Incremental, TRemoteConnectionReaderIncrementalStatus, Add) /**/ + + struct TRemoteConnectionReaderStatus: public TRemoteConnectionStatusBase { + REMOTE_CONNECTION_READER_STATUS_MAP(STRUCT_FIELD_GEN, ) + + TRemoteConnectionReaderStatus& operator+=(const TRemoteConnectionReaderStatus&); + + TRemoteConnectionReaderStatus(); + }; + +#define REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(XX, comma) \ + XX(MessageCounter, TMessageCounter, Add) \ + comma \ + XX(StatusCounter, TMessageStatusCounter, Add) comma \ + XX(ProcessDurationHistogram, TDurationHistogram, Add) /**/ + + struct TRemoteConnectionWriterIncrementalStatus: public TRemoteConnectionIncrementalStatusBase { + REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(STRUCT_FIELD_GEN, ) + + TRemoteConnectionWriterIncrementalStatus& operator+=(const TRemoteConnectionWriterIncrementalStatus&); + + TRemoteConnectionWriterIncrementalStatus(); + }; + +#define REMOTE_CONNECTION_WRITER_STATUS(XX, comma) \ + XX(Connected, bool, AssertZero) \ + comma \ + XX(ConnectTime, TInstant, AssertZero) comma /* either connect time on client or accept time on server */ \ + XX(ConnectError, int, AssertZero) comma \ + XX(ConnectSyscalls, unsigned, Add) comma \ + XX(PeerAddr, TNetAddr, AssertZero) comma \ + XX(MyAddr, TNetAddr, AssertZero) comma \ + XX(State, EWriterState, AssertZero) comma \ + XX(SendQueueSize, size_t, Add) comma \ + XX(AckMessagesSize, size_t, Add) comma /* client only */ \ + XX(DurationCounter, TDurationCounter, Add) comma /* server only */ \ + XX(DurationCounterPrev, TDurationCounter, Add) comma /* server only */ \ + XX(Incremental, TRemoteConnectionWriterIncrementalStatus, Add) comma \ + XX(ReaderWakeups, size_t, Add) /**/ + + struct TRemoteConnectionWriterStatus: public TRemoteConnectionStatusBase { + REMOTE_CONNECTION_WRITER_STATUS(STRUCT_FIELD_GEN, ) + + TRemoteConnectionWriterStatus(); + + TRemoteConnectionWriterStatus& operator+=(const TRemoteConnectionWriterStatus&); + + size_t GetInFlight() const; + }; + +#define REMOTE_CONNECTION_STATUS_MAP(XX, comma) \ + XX(Summary, bool) \ + comma \ + XX(Server, bool) /**/ + + struct TRemoteConnectionStatus { + REMOTE_CONNECTION_STATUS_MAP(STRUCT_FIELD_GEN, ) + + TRemoteConnectionReaderStatus ReaderStatus; + TRemoteConnectionWriterStatus WriterStatus; + + TRemoteConnectionStatus(); + + TString PrintToString() const; + TConnectionStatusMonRecord GetStatusProtobuf() const; + }; + + struct TBusSessionStatus { + size_t InFlightCount; + size_t InFlightSize; + bool InputPaused; + + TBusSessionStatus(); + }; + + struct TSessionDumpStatus { + bool Shutdown; + TString Head; + TString Acceptors; + TString ConnectionsSummary; + TString Connections; + TBusSessionStatus Status; + TRemoteConnectionStatus ConnectionStatusSummary; + TBusSessionConfig Config; + + TSessionDumpStatus() + : Shutdown(false) + { + } + + TString PrintToString() const; + }; + + // without sessions + struct TBusMessageQueueStatus { + NActor::NPrivate::TExecutorStatus ExecutorStatus; + TBusQueueConfig Config; + + TString PrintToString() const; + }; + } +} diff --git a/library/cpp/messagebus/remote_server_connection.cpp b/library/cpp/messagebus/remote_server_connection.cpp new file mode 100644 index 0000000000..74be34ded9 --- /dev/null +++ b/library/cpp/messagebus/remote_server_connection.cpp @@ -0,0 +1,73 @@ +#include "remote_server_connection.h" + +#include "mb_lwtrace.h" +#include "remote_server_session.h" + +#include <util/generic/cast.h> + +LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER) + +using namespace NBus; +using namespace NBus::NPrivate; + +TRemoteServerConnection::TRemoteServerConnection(TRemoteServerSessionPtr session, ui64 id, TNetAddr addr) + : TRemoteConnection(session.Get(), id, addr) +{ +} + +void TRemoteServerConnection::Init(SOCKET socket, TInstant now) { + WriterData.Status.ConnectTime = now; + WriterData.Status.Connected = true; + + Y_VERIFY(socket != INVALID_SOCKET, "must be a valid socket"); + + TSocket readSocket(socket); + TSocket writeSocket = readSocket; + + // this must not be done in constructor, because if event loop is stopped, + // this is deleted + WriterData.SetChannel(Session->WriteEventLoop.Register(writeSocket, this, WriteCookie)); + WriterData.SocketVersion = 1; + + ReaderGetSocketQueue()->EnqueueAndSchedule(TWriterToReaderSocketMessage(readSocket, WriterData.SocketVersion)); +} + +TRemoteServerSession* TRemoteServerConnection::GetSession() { + return CheckedCast<TRemoteServerSession*>(Session.Get()); +} + +void TRemoteServerConnection::HandleEvent(SOCKET socket, void* cookie) { + Y_UNUSED(socket); + Y_ASSERT(cookie == ReadCookie || cookie == WriteCookie); + if (cookie == ReadCookie) { + GetSession()->ServerOwnedMessages.Wait(); + ScheduleRead(); + } else { + ScheduleWrite(); + } +} + +bool TRemoteServerConnection::NeedInterruptRead() { + return !GetSession()->ServerOwnedMessages.TryWait(); +} + +void TRemoteServerConnection::MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) { + TInstant now = TInstant::Now(); + + GetSession()->ReleaseInWorkResponses(messages); + for (auto& message : messages) { + TInstant recvTime = message.MessagePtr->RecvTime; + GetSession()->ServerHandler->OnSent(message.MessagePtr.Release()); + TDuration d = now - recvTime; + WriterData.Status.DurationCounter.AddDuration(d); + WriterData.Status.Incremental.ProcessDurationHistogram.AddTime(d); + } +} + +void TRemoteServerConnection::ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) { + TBusHeader header(dataRef); + // TODO: full version hex + LWPROBE(ServerUnknownVersion, ToString(PeerAddr), header.GetVersionInternal()); + WrongVersionRequests.Enqueue(header); + GetWriterActor()->Schedule(); +} diff --git a/library/cpp/messagebus/remote_server_connection.h b/library/cpp/messagebus/remote_server_connection.h new file mode 100644 index 0000000000..63d7f20646 --- /dev/null +++ b/library/cpp/messagebus/remote_server_connection.h @@ -0,0 +1,32 @@ +#pragma once + +#include "session_impl.h" + +#include <util/generic/object_counter.h> + +namespace NBus { + namespace NPrivate { + class TRemoteServerConnection: public TRemoteConnection { + friend struct TBusSessionImpl; + friend class TRemoteServerSession; + + TObjectCounter<TRemoteServerConnection> ObjectCounter; + + public: + TRemoteServerConnection(TRemoteServerSessionPtr session, ui64 id, TNetAddr addr); + + void Init(SOCKET socket, TInstant now); + + inline TRemoteServerSession* GetSession(); + + void HandleEvent(SOCKET socket, void* cookie) override; + + bool NeedInterruptRead() override; + + void MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) override; + + void ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) override; + }; + + } +} diff --git a/library/cpp/messagebus/remote_server_session.cpp b/library/cpp/messagebus/remote_server_session.cpp new file mode 100644 index 0000000000..6abbf88a60 --- /dev/null +++ b/library/cpp/messagebus/remote_server_session.cpp @@ -0,0 +1,206 @@ +#include "remote_server_session.h" + +#include "remote_connection.h" +#include "remote_server_connection.h" + +#include <library/cpp/messagebus/actor/temp_tls_vector.h> + +#include <util/generic/cast.h> +#include <util/stream/output.h> +#include <util/system/yassert.h> + +#include <typeinfo> + +using namespace NActor; +using namespace NBus; +using namespace NBus::NPrivate; + +TRemoteServerSession::TRemoteServerSession(TBusMessageQueue* queue, + TBusProtocol* proto, IBusServerHandler* handler, + const TBusServerSessionConfig& config, const TString& name) + : TBusSessionImpl(false, queue, proto, handler, config, name) + , ServerOwnedMessages(config.MaxInFlight, config.MaxInFlightBySize, "ServerOwnedMessages") + , ServerHandler(handler) +{ + if (config.PerConnectionMaxInFlightBySize > 0) { + if (config.PerConnectionMaxInFlightBySize < config.MaxMessageSize) + ythrow yexception() + << "too low PerConnectionMaxInFlightBySize value"; + } +} + +namespace NBus { + namespace NPrivate { + class TInvokeOnMessage: public IWorkItem { + private: + TRemoteServerSession* RemoteServerSession; + TBusMessagePtrAndHeader Request; + TIntrusivePtr<TRemoteServerConnection> Connection; + + public: + TInvokeOnMessage(TRemoteServerSession* session, TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& connection) + : RemoteServerSession(session) + { + Y_ASSERT(!!connection); + Connection.Swap(connection); + + Request.Swap(request); + } + + void DoWork() override { + THolder<TInvokeOnMessage> holder(this); + RemoteServerSession->InvokeOnMessage(Request, Connection); + // TODO: TRemoteServerSessionSemaphore should be enough + RemoteServerSession->JobCount.Decrement(); + } + }; + + } +} + +void TRemoteServerSession::OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& messages) { + AcquireInWorkRequests(messages); + + bool executeInPool = Config.ExecuteOnMessageInWorkerPool; + + TTempTlsVector< ::IWorkItem*> workQueueTemp; + + if (executeInPool) { + workQueueTemp.GetVector()->reserve(messages.size()); + } + + for (auto& message : messages) { + // TODO: incref once + TIntrusivePtr<TRemoteServerConnection> connection(CheckedCast<TRemoteServerConnection*>(c)); + if (executeInPool) { + workQueueTemp.GetVector()->push_back(new TInvokeOnMessage(this, message, connection)); + } else { + InvokeOnMessage(message, connection); + } + } + + if (executeInPool) { + JobCount.Add(workQueueTemp.GetVector()->size()); + Queue->EnqueueWork(*workQueueTemp.GetVector()); + } +} + +void TRemoteServerSession::InvokeOnMessage(TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& conn) { + if (Y_UNLIKELY(AtomicGet(Down))) { + ReleaseInWorkRequests(*conn.Get(), request.MessagePtr.Get()); + InvokeOnError(request.MessagePtr.Release(), MESSAGE_SHUTDOWN); + } else { + TWhatThreadDoesPushPop pp("OnMessage"); + + TBusIdentity ident; + + ident.Connection.Swap(conn); + request.MessagePtr->GetIdentity(ident); + + Y_ASSERT(request.MessagePtr->LocalFlags & MESSAGE_IN_WORK); + DoSwap(request.MessagePtr->LocalFlags, ident.LocalFlags); + + ident.RecvTime = request.MessagePtr->RecvTime; + +#ifndef NDEBUG + auto& message = *request.MessagePtr; + ident.SetMessageType(typeid(message)); +#endif + + TOnMessageContext context(request.MessagePtr.Release(), ident, this); + ServerHandler->OnMessage(context); + } +} + +EMessageStatus TRemoteServerSession::ForgetRequest(const TBusIdentity& ident) { + ReleaseInWork(const_cast<TBusIdentity&>(ident)); + + return MESSAGE_OK; +} + +EMessageStatus TRemoteServerSession::SendReply(const TBusIdentity& ident, TBusMessage* reply) { + reply->CheckClean(); + + ConvertInWork(const_cast<TBusIdentity&>(ident), reply); + + reply->RecvTime = ident.RecvTime; + + ident.Connection->Send(reply); + + return MESSAGE_OK; +} + +int TRemoteServerSession::GetInFlight() const noexcept { + return ServerOwnedMessages.GetCurrentCount(); +} + +void TRemoteServerSession::FillStatus() { + TBusSessionImpl::FillStatus(); + + // TODO: weird + StatusData.Status.InFlightCount = ServerOwnedMessages.GetCurrentCount(); + StatusData.Status.InFlightSize = ServerOwnedMessages.GetCurrentSize(); + StatusData.Status.InputPaused = ServerOwnedMessages.IsLocked(); +} + +void TRemoteServerSession::AcquireInWorkRequests(TArrayRef<const TBusMessagePtrAndHeader> messages) { + TAtomicBase size = 0; + for (auto message = messages.begin(); message != messages.end(); ++message) { + Y_ASSERT(!(message->MessagePtr->LocalFlags & MESSAGE_IN_WORK)); + message->MessagePtr->LocalFlags |= MESSAGE_IN_WORK; + size += message->MessagePtr->GetHeader()->Size; + } + + ServerOwnedMessages.IncrementMultiple(messages.size(), size); +} + +void TRemoteServerSession::ReleaseInWorkResponses(TArrayRef<const TBusMessagePtrAndHeader> responses) { + TAtomicBase size = 0; + for (auto response = responses.begin(); response != responses.end(); ++response) { + Y_ASSERT((response->MessagePtr->LocalFlags & MESSAGE_REPLY_IS_BEGING_SENT)); + response->MessagePtr->LocalFlags &= ~MESSAGE_REPLY_IS_BEGING_SENT; + size += response->MessagePtr->RequestSize; + } + + ServerOwnedMessages.ReleaseMultiple(responses.size(), size); +} + +void TRemoteServerSession::ReleaseInWorkRequests(TRemoteConnection& con, TBusMessage* request) { + Y_ASSERT((request->LocalFlags & MESSAGE_IN_WORK)); + request->LocalFlags &= ~MESSAGE_IN_WORK; + + const size_t size = request->GetHeader()->Size; + + con.QuotaReturnAside(1, size); + ServerOwnedMessages.ReleaseMultiple(1, size); +} + +void TRemoteServerSession::ReleaseInWork(TBusIdentity& ident) { + ident.SetInWork(false); + ident.Connection->QuotaReturnAside(1, ident.Size); + + ServerOwnedMessages.ReleaseMultiple(1, ident.Size); +} + +void TRemoteServerSession::ConvertInWork(TBusIdentity& req, TBusMessage* reply) { + reply->SetIdentity(req); + + req.SetInWork(false); + Y_ASSERT(!(reply->LocalFlags & MESSAGE_REPLY_IS_BEGING_SENT)); + reply->LocalFlags |= MESSAGE_REPLY_IS_BEGING_SENT; + reply->RequestSize = req.Size; +} + +void TRemoteServerSession::Shutdown() { + ServerOwnedMessages.Stop(); + TBusSessionImpl::Shutdown(); +} + +void TRemoteServerSession::PauseInput(bool pause) { + ServerOwnedMessages.PauseByUsed(pause); +} + +unsigned TRemoteServerSession::GetActualListenPort() { + Y_VERIFY(Config.ListenPort > 0, "state check"); + return Config.ListenPort; +} diff --git a/library/cpp/messagebus/remote_server_session.h b/library/cpp/messagebus/remote_server_session.h new file mode 100644 index 0000000000..f5c266a7f7 --- /dev/null +++ b/library/cpp/messagebus/remote_server_session.h @@ -0,0 +1,54 @@ +#pragma once + +#include "remote_server_session_semaphore.h" +#include "session_impl.h" + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4250) // 'NBus::NPrivate::TRemoteClientSession' : inherits 'NBus::NPrivate::TBusSessionImpl::NBus::NPrivate::TBusSessionImpl::GetConfig' via dominance +#endif + +namespace NBus { + namespace NPrivate { + class TRemoteServerSession: public TBusServerSession, public TBusSessionImpl { + friend class TRemoteServerConnection; + + private: + TObjectCounter<TRemoteServerSession> ObjectCounter; + + TRemoteServerSessionSemaphore ServerOwnedMessages; + IBusServerHandler* const ServerHandler; + + public: + TRemoteServerSession(TBusMessageQueue* queue, TBusProtocol* proto, + IBusServerHandler* handler, + const TBusSessionConfig& config, const TString& name); + + void OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) override; + void InvokeOnMessage(TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& conn); + + EMessageStatus SendReply(const TBusIdentity& ident, TBusMessage* pRep) override; + + EMessageStatus ForgetRequest(const TBusIdentity& ident) override; + + int GetInFlight() const noexcept override; + void FillStatus() override; + + void Shutdown() override; + + void PauseInput(bool pause) override; + unsigned GetActualListenPort() override; + + void AcquireInWorkRequests(TArrayRef<const TBusMessagePtrAndHeader> requests); + void ReleaseInWorkResponses(TArrayRef<const TBusMessagePtrAndHeader> responses); + void ReleaseInWorkRequests(TRemoteConnection&, TBusMessage*); + void ReleaseInWork(TBusIdentity&); + void ConvertInWork(TBusIdentity& req, TBusMessage* reply); + }; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + + } +} diff --git a/library/cpp/messagebus/remote_server_session_semaphore.cpp b/library/cpp/messagebus/remote_server_session_semaphore.cpp new file mode 100644 index 0000000000..6094a3586e --- /dev/null +++ b/library/cpp/messagebus/remote_server_session_semaphore.cpp @@ -0,0 +1,59 @@ +#include "remote_server_session_semaphore.h" + +#include <util/stream/output.h> +#include <util/system/yassert.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TRemoteServerSessionSemaphore::TRemoteServerSessionSemaphore( + TAtomicBase limitCount, TAtomicBase limitSize, const char* name) + : Name(name) + , LimitCount(limitCount) + , LimitSize(limitSize) + , CurrentCount(0) + , CurrentSize(0) + , PausedByUser(0) + , StopSignal(0) +{ + Y_VERIFY(limitCount > 0, "limit must be > 0"); + Y_UNUSED(Name); +} + +TRemoteServerSessionSemaphore::~TRemoteServerSessionSemaphore() { + Y_VERIFY(AtomicGet(CurrentCount) == 0); + // TODO: fix spider and enable + //Y_VERIFY(AtomicGet(CurrentSize) == 0); +} + +bool TRemoteServerSessionSemaphore::TryWait() { + if (Y_UNLIKELY(AtomicGet(StopSignal))) + return true; + if (AtomicGet(PausedByUser)) + return false; + if (AtomicGet(CurrentCount) < LimitCount && (LimitSize < 0 || AtomicGet(CurrentSize) < LimitSize)) + return true; + return false; +} + +void TRemoteServerSessionSemaphore::IncrementMultiple(TAtomicBase count, TAtomicBase size) { + AtomicAdd(CurrentCount, count); + AtomicAdd(CurrentSize, size); + Updated(); +} + +void TRemoteServerSessionSemaphore::ReleaseMultiple(TAtomicBase count, TAtomicBase size) { + AtomicSub(CurrentCount, count); + AtomicSub(CurrentSize, size); + Updated(); +} + +void TRemoteServerSessionSemaphore::Stop() { + AtomicSet(StopSignal, 1); + Updated(); +} + +void TRemoteServerSessionSemaphore::PauseByUsed(bool pause) { + AtomicSet(PausedByUser, pause); + Updated(); +} diff --git a/library/cpp/messagebus/remote_server_session_semaphore.h b/library/cpp/messagebus/remote_server_session_semaphore.h new file mode 100644 index 0000000000..de714fd342 --- /dev/null +++ b/library/cpp/messagebus/remote_server_session_semaphore.h @@ -0,0 +1,42 @@ +#pragma once + +#include "cc_semaphore.h" + +#include <util/generic/noncopyable.h> + +namespace NBus { + namespace NPrivate { + class TRemoteServerSessionSemaphore: public TComplexConditionSemaphore<TRemoteServerSessionSemaphore> { + private: + const char* const Name; + + TAtomicBase const LimitCount; + TAtomicBase const LimitSize; + TAtomic CurrentCount; + TAtomic CurrentSize; + TAtomic PausedByUser; + TAtomic StopSignal; + + public: + TRemoteServerSessionSemaphore(TAtomicBase limitCount, TAtomicBase limitSize, const char* name = "unnamed"); + ~TRemoteServerSessionSemaphore(); + + TAtomicBase GetCurrentCount() const { + return AtomicGet(CurrentCount); + } + TAtomicBase GetCurrentSize() const { + return AtomicGet(CurrentSize); + } + + void IncrementMultiple(TAtomicBase count, TAtomicBase size); + bool TryWait(); + void ReleaseMultiple(TAtomicBase count, TAtomicBase size); + void Stop(); + void PauseByUsed(bool pause); + + private: + void CheckNeedToUnlock(); + }; + + } +} diff --git a/library/cpp/messagebus/scheduler/scheduler.cpp b/library/cpp/messagebus/scheduler/scheduler.cpp new file mode 100644 index 0000000000..5a5fe52894 --- /dev/null +++ b/library/cpp/messagebus/scheduler/scheduler.cpp @@ -0,0 +1,119 @@ +#include "scheduler.h" + +#include <util/datetime/base.h> +#include <util/generic/algorithm.h> +#include <util/generic/yexception.h> + +//#include "dummy_debugger.h" + +using namespace NBus; +using namespace NBus::NPrivate; + +class TScheduleDeadlineCompare { +public: + bool operator()(const IScheduleItemAutoPtr& i1, const IScheduleItemAutoPtr& i2) const noexcept { + return i1->GetScheduleTime() > i2->GetScheduleTime(); + } +}; + +TScheduler::TScheduler() + : StopThread(false) + , Thread([&] { this->SchedulerThread(); }) +{ +} + +TScheduler::~TScheduler() { + Y_VERIFY(StopThread, "state check"); +} + +size_t TScheduler::Size() const { + TGuard<TLock> guard(Lock); + return Items.size() + (!!NextItem ? 1 : 0); +} + +void TScheduler::Stop() { + { + TGuard<TLock> guard(Lock); + Y_VERIFY(!StopThread, "Scheduler already stopped"); + StopThread = true; + CondVar.Signal(); + } + Thread.Get(); + + if (!!NextItem) { + NextItem.Destroy(); + } + + for (auto& item : Items) { + item.Destroy(); + } +} + +void TScheduler::Schedule(TAutoPtr<IScheduleItem> i) { + TGuard<TLock> lock(Lock); + if (StopThread) + return; + + if (!!NextItem) { + if (i->GetScheduleTime() < NextItem->GetScheduleTime()) { + DoSwap(i, NextItem); + } + } + + Items.push_back(i); + PushHeap(Items.begin(), Items.end(), TScheduleDeadlineCompare()); + + FillNextItem(); + + CondVar.Signal(); +} + +void TScheduler::FillNextItem() { + if (!NextItem && !Items.empty()) { + PopHeap(Items.begin(), Items.end(), TScheduleDeadlineCompare()); + NextItem = Items.back(); + Items.erase(Items.end() - 1); + } +} + +void TScheduler::SchedulerThread() { + for (;;) { + IScheduleItemAutoPtr current; + + { + TGuard<TLock> guard(Lock); + + if (StopThread) { + break; + } + + if (!!NextItem) { + CondVar.WaitD(Lock, NextItem->GetScheduleTime()); + } else { + CondVar.WaitI(Lock); + } + + if (StopThread) { + break; + } + + // signal comes if either scheduler is to be stopped of there's work to do + Y_VERIFY(!!NextItem, "state check"); + + if (TInstant::Now() < NextItem->GetScheduleTime()) { + // NextItem is updated since WaitD + continue; + } + + current = NextItem.Release(); + } + + current->Do(); + current.Destroy(); + + { + TGuard<TLock> guard(Lock); + FillNextItem(); + } + } +} diff --git a/library/cpp/messagebus/scheduler/scheduler.h b/library/cpp/messagebus/scheduler/scheduler.h new file mode 100644 index 0000000000..afcc0de55d --- /dev/null +++ b/library/cpp/messagebus/scheduler/scheduler.h @@ -0,0 +1,68 @@ +#pragma once + +#include <library/cpp/threading/future/legacy_future.h> + +#include <util/datetime/base.h> +#include <util/generic/object_counter.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/system/atomic.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> +#include <util/system/thread.h> + +namespace NBus { + namespace NPrivate { + class IScheduleItem { + public: + inline IScheduleItem(TInstant scheduleTime) noexcept; + virtual ~IScheduleItem() { + } + + virtual void Do() = 0; + inline TInstant GetScheduleTime() const noexcept; + + private: + TInstant ScheduleTime; + }; + + using IScheduleItemAutoPtr = TAutoPtr<IScheduleItem>; + + class TScheduler { + public: + TScheduler(); + ~TScheduler(); + void Stop(); + void Schedule(TAutoPtr<IScheduleItem> i); + + size_t Size() const; + + private: + void SchedulerThread(); + + void FillNextItem(); + + private: + TVector<IScheduleItemAutoPtr> Items; + IScheduleItemAutoPtr NextItem; + typedef TMutex TLock; + TLock Lock; + TCondVar CondVar; + + TObjectCounter<TScheduler> ObjectCounter; + + bool StopThread; + NThreading::TLegacyFuture<> Thread; + }; + + inline IScheduleItem::IScheduleItem(TInstant scheduleTime) noexcept + : ScheduleTime(scheduleTime) + { + } + + inline TInstant IScheduleItem::GetScheduleTime() const noexcept { + return ScheduleTime; + } + + } +} diff --git a/library/cpp/messagebus/scheduler/scheduler_ut.cpp b/library/cpp/messagebus/scheduler/scheduler_ut.cpp new file mode 100644 index 0000000000..a5ea641c10 --- /dev/null +++ b/library/cpp/messagebus/scheduler/scheduler_ut.cpp @@ -0,0 +1,36 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "scheduler.h" + +#include <library/cpp/messagebus/misc/test_sync.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +Y_UNIT_TEST_SUITE(TSchedulerTests) { + struct TSimpleScheduleItem: public IScheduleItem { + TTestSync* const TestSync; + + TSimpleScheduleItem(TTestSync* testSync) + : IScheduleItem((TInstant::Now() + TDuration::MilliSeconds(1))) + , TestSync(testSync) + { + } + + void Do() override { + TestSync->WaitForAndIncrement(0); + } + }; + + Y_UNIT_TEST(Simple) { + TTestSync testSync; + + TScheduler scheduler; + + scheduler.Schedule(new TSimpleScheduleItem(&testSync)); + + testSync.WaitForAndIncrement(1); + + scheduler.Stop(); + } +} diff --git a/library/cpp/messagebus/scheduler/ya.make b/library/cpp/messagebus/scheduler/ya.make new file mode 100644 index 0000000000..dcb7408a20 --- /dev/null +++ b/library/cpp/messagebus/scheduler/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/threading/future +) + +SRCS( + scheduler.cpp +) + +END() diff --git a/library/cpp/messagebus/scheduler_actor.h b/library/cpp/messagebus/scheduler_actor.h new file mode 100644 index 0000000000..d0c23c94c4 --- /dev/null +++ b/library/cpp/messagebus/scheduler_actor.h @@ -0,0 +1,85 @@ +#pragma once + +#include "local_tasks.h" + +#include <library/cpp/messagebus/actor/actor.h> +#include <library/cpp/messagebus/actor/what_thread_does_guard.h> +#include <library/cpp/messagebus/scheduler/scheduler.h> + +#include <util/system/mutex.h> + +namespace NBus { + namespace NPrivate { + template <typename TThis, typename TTag = NActor::TDefaultTag> + class TScheduleActor { + typedef NActor::TActor<TThis, TTag> TActorForMe; + + private: + TScheduler* const Scheduler; + + TMutex Mutex; + + TInstant ScheduleTime; + + public: + TLocalTasks Alarm; + + private: + struct TScheduleItemImpl: public IScheduleItem { + TIntrusivePtr<TThis> Thiz; + + TScheduleItemImpl(TIntrusivePtr<TThis> thiz, TInstant when) + : IScheduleItem(when) + , Thiz(thiz) + { + } + + void Do() override { + { + TWhatThreadDoesAcquireGuard<TMutex> guard(Thiz->Mutex, "scheduler actor: acquiring lock for Do"); + + if (Thiz->ScheduleTime == TInstant::Max()) { + // was already fired + return; + } + + Thiz->ScheduleTime = TInstant::Max(); + } + + Thiz->Alarm.AddTask(); + Thiz->GetActorForMe()->Schedule(); + } + }; + + public: + TScheduleActor(TScheduler* scheduler) + : Scheduler(scheduler) + , ScheduleTime(TInstant::Max()) + { + } + + /// call Act(TTag) at specified time, unless it is already scheduled at earlier time. + void ScheduleAt(TInstant when) { + TWhatThreadDoesAcquireGuard<TMutex> guard(Mutex, "scheduler: acquiring lock for ScheduleAt"); + + if (when > ScheduleTime) { + // already scheduled + return; + } + + ScheduleTime = when; + Scheduler->Schedule(new TScheduleItemImpl(GetThis(), when)); + } + + private: + TThis* GetThis() { + return static_cast<TThis*>(this); + } + + TActorForMe* GetActorForMe() { + return static_cast<TActorForMe*>(GetThis()); + } + }; + + } +} diff --git a/library/cpp/messagebus/scheduler_actor_ut.cpp b/library/cpp/messagebus/scheduler_actor_ut.cpp new file mode 100644 index 0000000000..e81ffd3186 --- /dev/null +++ b/library/cpp/messagebus/scheduler_actor_ut.cpp @@ -0,0 +1,48 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "scheduler_actor.h" +#include "misc/test_sync.h" + +using namespace NBus; +using namespace NBus::NPrivate; +using namespace NActor; + +Y_UNIT_TEST_SUITE(TSchedulerActorTests) { + struct TMyActor: public TAtomicRefCount<TMyActor>, public TActor<TMyActor>, public TScheduleActor<TMyActor> { + TTestSync TestSync; + + TMyActor(TExecutor* executor, TScheduler* scheduler) + : TActor<TMyActor>(executor) + , TScheduleActor<TMyActor>(scheduler) + , Iteration(0) + { + } + + unsigned Iteration; + + void Act(TDefaultTag) { + if (!Alarm.FetchTask()) { + Y_FAIL("must not have no spurious wakeups in test"); + } + + TestSync.WaitForAndIncrement(Iteration++); + if (Iteration <= 5) { + ScheduleAt(TInstant::Now() + TDuration::MilliSeconds(Iteration)); + } + } + }; + + Y_UNIT_TEST(Simple) { + TExecutor executor(1); + TScheduler scheduler; + + TIntrusivePtr<TMyActor> actor(new TMyActor(&executor, &scheduler)); + + actor->ScheduleAt(TInstant::Now() + TDuration::MilliSeconds(1)); + + actor->TestSync.WaitForAndIncrement(6); + + // TODO: stop in destructor + scheduler.Stop(); + } +} diff --git a/library/cpp/messagebus/session.cpp b/library/cpp/messagebus/session.cpp new file mode 100644 index 0000000000..46a7ece6a8 --- /dev/null +++ b/library/cpp/messagebus/session.cpp @@ -0,0 +1,130 @@ +#include "ybus.h" + +#include <util/generic/cast.h> + +using namespace NBus; + +namespace NBus { + TBusSession::TBusSession() { + } + + //////////////////////////////////////////////////////////////////// + /// \brief Adds peer of connection into connection list + + int CompareByHost(const IRemoteAddr& l, const IRemoteAddr& r) noexcept { + if (l.Addr()->sa_family != r.Addr()->sa_family) { + return l.Addr()->sa_family < r.Addr()->sa_family ? -1 : +1; + } + + switch (l.Addr()->sa_family) { + case AF_INET: { + return memcmp(&(((const sockaddr_in*)l.Addr())->sin_addr), &(((const sockaddr_in*)r.Addr())->sin_addr), sizeof(in_addr)); + } + + case AF_INET6: { + return memcmp(&(((const sockaddr_in6*)l.Addr())->sin6_addr), &(((const sockaddr_in6*)r.Addr())->sin6_addr), sizeof(in6_addr)); + } + } + + return memcmp(l.Addr(), r.Addr(), Min<size_t>(l.Len(), r.Len())); + } + + bool operator<(const TNetAddr& a1, const TNetAddr& a2) { + return CompareByHost(a1, a2) < 0; + } + + size_t TBusSession::GetInFlight(const TNetAddr& addr) const { + size_t r; + GetInFlightBulk({addr}, MakeArrayRef(&r, 1)); + return r; + } + + size_t TBusSession::GetConnectSyscallsNumForTest(const TNetAddr& addr) const { + size_t r; + GetConnectSyscallsNumBulkForTest({addr}, MakeArrayRef(&r, 1)); + return r; + } + + // Split 'host' into name and port taking into account that host can be specified + // as ipv6 address ('[<ipv6 address]:port' notion). + bool SplitHost(const TString& host, TString* hostName, TString* portNum) { + hostName->clear(); + portNum->clear(); + + // Simple check that we have to deal with ipv6 address specification or + // just host name or ipv4 address. + if (!host.empty() && (host[0] == '[')) { + size_t pos = host.find(']'); + if (pos < 2 || pos == TString::npos) { + // '[]' and '[<address>' are errors. + return false; + } + + *hostName = host.substr(1, pos - 1); + + pos++; + if (pos != host.length()) { + if (host[pos] != ':') { + // Do not allow '[...]a' but '[...]:' is ok (as for ipv4 before + return false; + } + + *portNum = host.substr(pos + 1); + } + } else { + size_t pos = host.find(':'); + if (pos != TString::npos) { + if (pos == 0) { + // Treat ':<port>' as errors but allow or '<host>:' for compatibility. + return false; + } + + *portNum = host.substr(pos + 1); + } + + *hostName = host.substr(0, pos); + } + + return true; + } + + /// registers external session on host:port with locator service + int TBusSession::RegisterService(const char* host, TBusKey start /*= YBUS_KEYMIN*/, TBusKey end /*= YBUS_KEYMAX*/, EIpVersion ipVersion) { + TString hostName; + TString port; + int portNum; + + if (!SplitHost(host, &hostName, &port)) { + hostName = host; + } + + if (port.empty()) { + portNum = GetProto()->GetPort(); + } else { + try { + portNum = FromString<int>(port); + } catch (const TFromStringException&) { + return -1; + } + } + + TBusService service = GetProto()->GetService(); + return GetQueue()->GetLocator()->Register(service, hostName.data(), portNum, start, end, ipVersion); + } + + TBusSession::~TBusSession() { + } + +} + +TBusClientSessionPtr TBusClientSession::Create(TBusProtocol* proto, IBusClientHandler* handler, const TBusClientSessionConfig& config, TBusMessageQueuePtr queue) { + return queue->CreateSource(proto, handler, config); +} + +TBusServerSessionPtr TBusServerSession::Create(TBusProtocol* proto, IBusServerHandler* handler, const TBusServerSessionConfig& config, TBusMessageQueuePtr queue) { + return queue->CreateDestination(proto, handler, config); +} + +TBusServerSessionPtr TBusServerSession::Create(TBusProtocol* proto, IBusServerHandler* handler, const TBusServerSessionConfig& config, TBusMessageQueuePtr queue, const TVector<TBindResult>& bindTo) { + return queue->CreateDestination(proto, handler, config, bindTo); +} diff --git a/library/cpp/messagebus/session.h b/library/cpp/messagebus/session.h new file mode 100644 index 0000000000..fb12ab7c22 --- /dev/null +++ b/library/cpp/messagebus/session.h @@ -0,0 +1,225 @@ +#pragma once + +#include "connection.h" +#include "defs.h" +#include "handler.h" +#include "message.h" +#include "netaddr.h" +#include "network.h" +#include "session_config.h" +#include "misc/weak_ptr.h" + +#include <library/cpp/messagebus/monitoring/mon_proto.pb.h> + +#include <util/generic/array_ref.h> +#include <util/generic/ptr.h> + +namespace NBus { + template <typename TBusSessionSubclass> + class TBusSessionPtr; + using TBusClientSessionPtr = TBusSessionPtr<TBusClientSession>; + using TBusServerSessionPtr = TBusSessionPtr<TBusServerSession>; + + /////////////////////////////////////////////////////////////////// + /// \brief Interface of session object. + + /// Each client and server + /// should instantiate session object to be able to communicate via bus + /// client: sess = queue->CreateSource(protocol, handler); + /// server: sess = queue->CreateDestination(protocol, handler); + + class TBusSession: public TWeakRefCounted<TBusSession> { + public: + size_t GetInFlight(const TNetAddr& addr) const; + size_t GetConnectSyscallsNumForTest(const TNetAddr& addr) const; + + virtual void GetInFlightBulk(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const = 0; + virtual void GetConnectSyscallsNumBulkForTest(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const = 0; + + virtual int GetInFlight() const noexcept = 0; + /// monitoring status of current session and it's connections + virtual TString GetStatus(ui16 flags = YBUS_STATUS_CONNS) = 0; + virtual TConnectionStatusMonRecord GetStatusProtobuf() = 0; + virtual NPrivate::TSessionDumpStatus GetStatusRecordInternal() = 0; + virtual TString GetStatusSingleLine() = 0; + /// return session config + virtual const TBusSessionConfig* GetConfig() const noexcept = 0; + /// return session protocol + virtual const TBusProtocol* GetProto() const noexcept = 0; + virtual TBusMessageQueue* GetQueue() const noexcept = 0; + + /// registers external session on host:port with locator service + int RegisterService(const char* hostname, TBusKey start = YBUS_KEYMIN, TBusKey end = YBUS_KEYMAX, EIpVersion ipVersion = EIP_VERSION_4); + + protected: + TBusSession(); + + public: + virtual TString GetNameInternal() = 0; + + virtual void Shutdown() = 0; + + virtual ~TBusSession(); + }; + + struct TBusClientSession: public virtual TBusSession { + typedef ::NBus::NPrivate::TRemoteClientSession TImpl; + + static TBusClientSessionPtr Create( + TBusProtocol* proto, + IBusClientHandler* handler, + const TBusClientSessionConfig& config, + TBusMessageQueuePtr queue); + + virtual TBusClientConnectionPtr GetConnection(const TNetAddr&) = 0; + + /// if you want to open connection early + virtual void OpenConnection(const TNetAddr&) = 0; + + /// Send message to the destination + /// If addr is set then use it as destination. + /// Takes ownership of addr (see ClearState method). + virtual EMessageStatus SendMessage(TBusMessage* pMes, const TNetAddr* addr = nullptr, bool wait = false) = 0; + + virtual EMessageStatus SendMessageOneWay(TBusMessage* pMes, const TNetAddr* addr = nullptr, bool wait = false) = 0; + + /// Like SendMessage but cares about message + template <typename T /* <: TBusMessage */> + EMessageStatus SendMessageAutoPtr(const TAutoPtr<T>& mes, const TNetAddr* addr = nullptr, bool wait = false) { + EMessageStatus status = SendMessage(mes.Get(), addr, wait); + if (status == MESSAGE_OK) + Y_UNUSED(mes.Release()); + return status; + } + + /// Like SendMessageOneWay but cares about message + template <typename T /* <: TBusMessage */> + EMessageStatus SendMessageOneWayAutoPtr(const TAutoPtr<T>& mes, const TNetAddr* addr = nullptr, bool wait = false) { + EMessageStatus status = SendMessageOneWay(mes.Get(), addr, wait); + if (status == MESSAGE_OK) + Y_UNUSED(mes.Release()); + return status; + } + + EMessageStatus SendMessageMove(TBusMessageAutoPtr message, const TNetAddr* addr = nullptr, bool wait = false) { + return SendMessageAutoPtr(message, addr, wait); + } + + EMessageStatus SendMessageOneWayMove(TBusMessageAutoPtr message, const TNetAddr* addr = nullptr, bool wait = false) { + return SendMessageOneWayAutoPtr(message, addr, wait); + } + + // TODO: implement similar one-way methods + }; + + struct TBusServerSession: public virtual TBusSession { + typedef ::NBus::NPrivate::TRemoteServerSession TImpl; + + static TBusServerSessionPtr Create( + TBusProtocol* proto, + IBusServerHandler* handler, + const TBusServerSessionConfig& config, + TBusMessageQueuePtr queue); + + static TBusServerSessionPtr Create( + TBusProtocol* proto, + IBusServerHandler* handler, + const TBusServerSessionConfig& config, + TBusMessageQueuePtr queue, + const TVector<TBindResult>& bindTo); + + // TODO: make parameter non-const + virtual EMessageStatus SendReply(const TBusIdentity& ident, TBusMessage* pRep) = 0; + + // TODO: make parameter non-const + virtual EMessageStatus ForgetRequest(const TBusIdentity& ident) = 0; + + template <typename U /* <: TBusMessage */> + EMessageStatus SendReplyAutoPtr(TBusIdentity& ident, TAutoPtr<U>& resp) { + EMessageStatus status = SendReply(const_cast<const TBusIdentity&>(ident), resp.Get()); + if (status == MESSAGE_OK) { + Y_UNUSED(resp.Release()); + } + return status; + } + + EMessageStatus SendReplyMove(TBusIdentity& ident, TBusMessageAutoPtr resp) { + return SendReplyAutoPtr(ident, resp); + } + + /// Pause input from the network. + /// It is valid to call this method in parallel. + /// TODO: pull this method up to TBusSession. + virtual void PauseInput(bool pause) = 0; + virtual unsigned GetActualListenPort() = 0; + }; + + namespace NPrivate { + template <typename TBusSessionSubclass> + class TBusOwnerSessionPtr: public TAtomicRefCount<TBusOwnerSessionPtr<TBusSessionSubclass>> { + private: + TIntrusivePtr<TBusSessionSubclass> Ptr; + + public: + TBusOwnerSessionPtr(TBusSessionSubclass* session) + : Ptr(session) + { + Y_ASSERT(!!Ptr); + } + + ~TBusOwnerSessionPtr() { + Ptr->Shutdown(); + } + + TBusSessionSubclass* Get() const { + return reinterpret_cast<TBusSessionSubclass*>(Ptr.Get()); + } + }; + + } + + template <typename TBusSessionSubclass> + class TBusSessionPtr { + private: + TIntrusivePtr<NPrivate::TBusOwnerSessionPtr<TBusSessionSubclass>> SmartPtr; + TBusSessionSubclass* Ptr; + + public: + TBusSessionPtr() + : Ptr() + { + } + TBusSessionPtr(TBusSessionSubclass* session) + : SmartPtr(!!session ? new NPrivate::TBusOwnerSessionPtr<TBusSessionSubclass>(session) : nullptr) + , Ptr(session) + { + } + + TBusSessionSubclass* Get() const { + return Ptr; + } + operator TBusSessionSubclass*() { + return Get(); + } + TBusSessionSubclass& operator*() const { + return *Get(); + } + TBusSessionSubclass* operator->() const { + return Get(); + } + + bool operator!() const { + return !Ptr; + } + + void Swap(TBusSessionPtr& t) noexcept { + DoSwap(SmartPtr, t.SmartPtr); + DoSwap(Ptr, t.Ptr); + } + + void Drop() { + TBusSessionPtr().Swap(*this); + } + }; + +} diff --git a/library/cpp/messagebus/session_config.h b/library/cpp/messagebus/session_config.h new file mode 100644 index 0000000000..37df97e986 --- /dev/null +++ b/library/cpp/messagebus/session_config.h @@ -0,0 +1,4 @@ +#pragma once + +#include <library/cpp/messagebus/config/session_config.h> + diff --git a/library/cpp/messagebus/session_impl.cpp b/library/cpp/messagebus/session_impl.cpp new file mode 100644 index 0000000000..ddf9f360c4 --- /dev/null +++ b/library/cpp/messagebus/session_impl.cpp @@ -0,0 +1,650 @@ +#include "session_impl.h" + +#include "acceptor.h" +#include "network.h" +#include "remote_client_connection.h" +#include "remote_client_session.h" +#include "remote_server_connection.h" +#include "remote_server_session.h" +#include "misc/weak_ptr.h" + +#include <util/generic/cast.h> + +using namespace NActor; +using namespace NBus; +using namespace NBus::NPrivate; +using namespace NEventLoop; + +namespace { + class TScheduleSession: public IScheduleItem { + public: + TScheduleSession(TBusSessionImpl* session, TInstant deadline) + : IScheduleItem(deadline) + , Session(session) + , SessionImpl(session) + { + } + + void Do() override { + TIntrusivePtr<TBusSession> session = Session.Get(); + if (!!session) { + SessionImpl->Cron(); + } + } + + private: + TWeakPtr<TBusSession> Session; + // Work around TWeakPtr limitation + TBusSessionImpl* SessionImpl; + }; +} + +TConnectionsAcceptorsSnapshot::TConnectionsAcceptorsSnapshot() + : LastConnectionId(0) + , LastAcceptorId(0) +{ +} + +struct TBusSessionImpl::TImpl { + TRemoteConnectionWriterIncrementalStatus DeadConnectionWriterStatusSummary; + TRemoteConnectionReaderIncrementalStatus DeadConnectionReaderStatusSummary; + TAcceptorStatus DeadAcceptorStatusSummary; +}; + +namespace { + TBusSessionConfig SessionConfigFillDefaults(const TBusSessionConfig& config, const TString& name) { + TBusSessionConfig copy = config; + if (copy.TotalTimeout == 0 && copy.SendTimeout == 0) { + copy.TotalTimeout = TDuration::Seconds(60).MilliSeconds(); + copy.SendTimeout = TDuration::Seconds(15).MilliSeconds(); + } else if (copy.TotalTimeout == 0) { + Y_ASSERT(copy.SendTimeout != 0); + copy.TotalTimeout = config.SendTimeout + TDuration::MilliSeconds(10).MilliSeconds(); + } else if (copy.SendTimeout == 0) { + Y_ASSERT(copy.TotalTimeout != 0); + if ((ui64)copy.TotalTimeout > (ui64)TDuration::MilliSeconds(10).MilliSeconds()) { + copy.SendTimeout = copy.TotalTimeout - TDuration::MilliSeconds(10).MilliSeconds(); + } else { + copy.SendTimeout = copy.TotalTimeout; + } + } else { + Y_ASSERT(copy.TotalTimeout != 0); + Y_ASSERT(copy.SendTimeout != 0); + } + + if (copy.ConnectTimeout == 0) { + copy.ConnectTimeout = copy.SendTimeout; + } + + Y_VERIFY(copy.SendTimeout > 0, "SendTimeout must be > 0"); + Y_VERIFY(copy.TotalTimeout > 0, "TotalTimeout must be > 0"); + Y_VERIFY(copy.ConnectTimeout > 0, "ConnectTimeout must be > 0"); + Y_VERIFY(copy.TotalTimeout >= copy.SendTimeout, "TotalTimeout must be >= SendTimeout"); + + if (!copy.Name) { + copy.Name = name; + } + + return copy; + } +} + +TBusSessionImpl::TBusSessionImpl(bool isSource, TBusMessageQueue* queue, TBusProtocol* proto, + IBusErrorHandler* handler, + const TBusSessionConfig& config, const TString& name) + : TActor<TBusSessionImpl, TStatusTag>(queue->WorkQueue.Get()) + , TActor<TBusSessionImpl, TConnectionTag>(queue->WorkQueue.Get()) + , Impl(new TImpl) + , IsSource_(isSource) + , Queue(queue) + , Proto(proto) + , ProtoName(Proto->GetService()) + , ErrorHandler(handler) + , HandlerUseCountHolder(&handler->UseCountChecker) + , Config(SessionConfigFillDefaults(config, name)) + , WriteEventLoop("wr-el") + , ReadEventLoop("rd-el") + , LastAcceptorId(0) + , LastConnectionId(0) + , Down(0) +{ + Impl->DeadAcceptorStatusSummary.Summary = true; + + ReadEventLoopThread.Reset(new NThreading::TLegacyFuture<void, false>(std::bind(&TEventLoop::Run, std::ref(ReadEventLoop)))); + WriteEventLoopThread.Reset(new NThreading::TLegacyFuture<void, false>(std::bind(&TEventLoop::Run, std::ref(WriteEventLoop)))); + + Queue->Schedule(IScheduleItemAutoPtr(new TScheduleSession(this, TInstant::Now() + Config.Secret.TimeoutPeriod))); +} + +TBusSessionImpl::~TBusSessionImpl() { + Y_VERIFY(Down); + Y_VERIFY(ShutdownCompleteEvent.WaitT(TDuration::Zero())); + Y_VERIFY(!WriteEventLoop.IsRunning()); + Y_VERIFY(!ReadEventLoop.IsRunning()); +} + +TBusSessionStatus::TBusSessionStatus() + : InFlightCount(0) + , InFlightSize(0) + , InputPaused(false) +{ +} + +void TBusSessionImpl::Shutdown() { + if (!AtomicCas(&Down, 1, 0)) { + ShutdownCompleteEvent.WaitI(); + return; + } + + Y_VERIFY(Queue->IsRunning(), "Session must be shut down prior to queue shutdown"); + + TUseAfterFreeCheckerGuard handlerAliveCheckedGuard(ErrorHandler->UseAfterFreeChecker); + + // For legacy clients that don't use smart pointers + TIntrusivePtr<TBusSessionImpl> thiz(this); + + Queue->Remove(this); + + // shutdown event loops first, so they won't send more events + // to acceptors and connections + ReadEventLoop.Stop(); + WriteEventLoop.Stop(); + ReadEventLoopThread->Get(); + WriteEventLoopThread->Get(); + + // shutdown acceptors before connections + // so they won't create more connections + TVector<TAcceptorPtr> acceptors; + GetAcceptors(&acceptors); + { + TGuard<TMutex> guard(ConnectionsLock); + Acceptors.clear(); + } + + for (auto& acceptor : acceptors) { + acceptor->Shutdown(); + } + + // shutdown connections + TVector<TRemoteConnectionPtr> cs; + GetConnections(&cs); + + for (auto& c : cs) { + c->Shutdown(MESSAGE_SHUTDOWN); + } + + // shutdown connections actor + // must shutdown after connections destroyed + ConnectionsData.ShutdownState.ShutdownCommand(); + GetConnectionsActor()->Schedule(); + ConnectionsData.ShutdownState.ShutdownComplete.WaitI(); + + // finally shutdown status actor + StatusData.ShutdownState.ShutdownCommand(); + GetStatusActor()->Schedule(); + StatusData.ShutdownState.ShutdownComplete.WaitI(); + + // Make sure no one references IMessageHandler after Shutdown() + JobCount.WaitForZero(); + HandlerUseCountHolder.Reset(); + + ShutdownCompleteEvent.Signal(); +} + +bool TBusSessionImpl::IsDown() { + return static_cast<bool>(AtomicGet(Down)); +} + +size_t TBusSessionImpl::GetInFlightImpl(const TNetAddr& addr) const { + TRemoteConnectionPtr conn = const_cast<TBusSessionImpl*>(this)->GetConnection(addr, false); + if (!!conn) { + return conn->GetInFlight(); + } else { + return 0; + } +} + +void TBusSessionImpl::GetInFlightBulk(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const { + Y_VERIFY(addrs.size() == results.size(), "input.size != output.size"); + for (size_t i = 0; i < addrs.size(); ++i) { + results[i] = GetInFlightImpl(addrs[i]); + } +} + +size_t TBusSessionImpl::GetConnectSyscallsNumForTestImpl(const TNetAddr& addr) const { + TRemoteConnectionPtr conn = const_cast<TBusSessionImpl*>(this)->GetConnection(addr, false); + if (!!conn) { + return conn->GetConnectSyscallsNumForTest(); + } else { + return 0; + } +} + +void TBusSessionImpl::GetConnectSyscallsNumBulkForTest(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const { + Y_VERIFY(addrs.size() == results.size(), "input.size != output.size"); + for (size_t i = 0; i < addrs.size(); ++i) { + results[i] = GetConnectSyscallsNumForTestImpl(addrs[i]); + } +} + +void TBusSessionImpl::FillStatus() { +} + +TSessionDumpStatus TBusSessionImpl::GetStatusRecordInternal() { + // Probably useless, because it returns cached info now + Y_VERIFY(!Queue->GetExecutor()->IsInExecutorThread(), + "GetStatus must not be called from executor thread"); + + TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex); + // TODO: returns zeros for a second after start + // (until first cron) + return StatusData.StatusDumpCached; +} + +TString TBusSessionImpl::GetStatus(ui16 flags) { + Y_UNUSED(flags); + + return GetStatusRecordInternal().PrintToString(); +} + +TConnectionStatusMonRecord TBusSessionImpl::GetStatusProtobuf() { + Y_VERIFY(!Queue->GetExecutor()->IsInExecutorThread(), + "GetStatus must not be called from executor thread"); + + TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex); + + return StatusData.StatusDumpCached.ConnectionStatusSummary.GetStatusProtobuf(); +} + +TString TBusSessionImpl::GetStatusSingleLine() { + TSessionDumpStatus status = GetStatusRecordInternal(); + + TStringStream ss; + ss << "in-flight: " << status.Status.InFlightCount; + if (IsSource_) { + ss << " ack: " << status.ConnectionStatusSummary.WriterStatus.AckMessagesSize; + } + ss << " send-q: " << status.ConnectionStatusSummary.WriterStatus.SendQueueSize; + return ss.Str(); +} + +void TBusSessionImpl::ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionWriterIncrementalStatus& connectionStatus) { + Impl->DeadConnectionWriterStatusSummary += connectionStatus; +} + +void TBusSessionImpl::ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionReaderIncrementalStatus& connectionStatus) { + Impl->DeadConnectionReaderStatusSummary += connectionStatus; +} + +void TBusSessionImpl::ProcessItem(TStatusTag, TDeadConnectionTag, const TAcceptorStatus& acceptorStatus) { + Impl->DeadAcceptorStatusSummary += acceptorStatus; +} + +void TBusSessionImpl::ProcessItem(TConnectionTag, ::NActor::TDefaultTag, const TOnAccept& onAccept) { + TSocketHolder socket(onAccept.s); + + if (AtomicGet(Down)) { + // do not create connections after shutdown initiated + return; + } + + //if (Connections.find(addr) != Connections.end()) { + // TODO: it is possible + // won't be a problem after socket address replaced with id + //} + + TRemoteConnectionPtr c(new TRemoteServerConnection(VerifyDynamicCast<TRemoteServerSession*>(this), ++LastConnectionId, onAccept.addr)); + + VerifyDynamicCast<TRemoteServerConnection*>(c.Get())->Init(socket.Release(), onAccept.now); + + InsertConnectionLockAcquired(c.Get()); +} + +void TBusSessionImpl::ProcessItem(TConnectionTag, TRemoveTag, TRemoteConnectionPtr c) { + TAddrRemoteConnections::iterator it1 = Connections.find(c->PeerAddrSocketAddr); + if (it1 != Connections.end()) { + if (it1->second.Get() == c.Get()) { + Connections.erase(it1); + } + } + + THashMap<ui64, TRemoteConnectionPtr>::iterator it2 = ConnectionsById.find(c->ConnectionId); + if (it2 != ConnectionsById.end()) { + ConnectionsById.erase(it2); + } + + SendSnapshotToStatusActor(); +} + +void TBusSessionImpl::ProcessConnectionsAcceptorsShapshotQueueItem(TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> snapshot) { + for (TVector<TRemoteConnectionPtr>::const_iterator connection = snapshot->Connections.begin(); + connection != snapshot->Connections.end(); ++connection) { + Y_ASSERT((*connection)->ConnectionId <= snapshot->LastConnectionId); + } + + for (TVector<TAcceptorPtr>::const_iterator acceptor = snapshot->Acceptors.begin(); + acceptor != snapshot->Acceptors.end(); ++acceptor) { + Y_ASSERT((*acceptor)->AcceptorId <= snapshot->LastAcceptorId); + } + + StatusData.ConnectionsAcceptorsSnapshot = snapshot; +} + +void TBusSessionImpl::StatusUpdateCachedDumpIfNecessary(TInstant now) { + if (now - StatusData.StatusDumpCachedLastUpdate > Config.Secret.StatusFlushPeriod) { + StatusUpdateCachedDump(); + StatusData.StatusDumpCachedLastUpdate = now; + } +} + +void TBusSessionImpl::StatusUpdateCachedDump() { + TSessionDumpStatus r; + + if (AtomicGet(Down)) { + r.Shutdown = true; + TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex); + StatusData.StatusDumpCached = r; + return; + } + + // TODO: make thread-safe + FillStatus(); + + r.Status = StatusData.Status; + + { + TStringStream ss; + + TString name = Config.Name; + if (!name) { + name = "unnamed"; + } + + ss << (IsSource_ ? "client" : "server") << " session " << name << ", proto " << Proto->GetService() << Endl; + ss << "in flight: " << r.Status.InFlightCount; + if (!IsSource_) { + ss << ", " << r.Status.InFlightSize << "b"; + } + if (r.Status.InputPaused) { + ss << " (input paused)"; + } + ss << "\n"; + + r.Head = ss.Str(); + } + + TVector<TRemoteConnectionPtr>& connections = StatusData.ConnectionsAcceptorsSnapshot->Connections; + TVector<TAcceptorPtr>& acceptors = StatusData.ConnectionsAcceptorsSnapshot->Acceptors; + + r.ConnectionStatusSummary = TRemoteConnectionStatus(); + r.ConnectionStatusSummary.Summary = true; + r.ConnectionStatusSummary.Server = !IsSource_; + r.ConnectionStatusSummary.WriterStatus.Incremental = Impl->DeadConnectionWriterStatusSummary; + r.ConnectionStatusSummary.ReaderStatus.Incremental = Impl->DeadConnectionReaderStatusSummary; + + TAcceptorStatus acceptorStatusSummary = Impl->DeadAcceptorStatusSummary; + + { + TStringStream ss; + + for (TVector<TAcceptorPtr>::const_iterator acceptor = acceptors.begin(); + acceptor != acceptors.end(); ++acceptor) { + const TAcceptorStatus status = (*acceptor)->GranStatus.Listen.Get(); + + acceptorStatusSummary += status; + + if (acceptor != acceptors.begin()) { + ss << "\n"; + } + ss << status.PrintToString(); + } + + r.Acceptors = ss.Str(); + } + + { + TStringStream ss; + + for (TVector<TRemoteConnectionPtr>::const_iterator connection = connections.begin(); + connection != connections.end(); ++connection) { + if (connection != connections.begin()) { + ss << "\n"; + } + + TRemoteConnectionStatus status; + status.Server = !IsSource_; + status.ReaderStatus = (*connection)->GranStatus.Reader.Get(); + status.WriterStatus = (*connection)->GranStatus.Writer.Get(); + + ss << status.PrintToString(); + + r.ConnectionStatusSummary.ReaderStatus += status.ReaderStatus; + r.ConnectionStatusSummary.WriterStatus += status.WriterStatus; + } + + r.ConnectionsSummary = r.ConnectionStatusSummary.PrintToString(); + r.Connections = ss.Str(); + } + + r.Config = Config; + + TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex); + StatusData.StatusDumpCached = r; +} + +TBusSessionImpl::TStatusData::TStatusData() + : ConnectionsAcceptorsSnapshot(new TConnectionsAcceptorsSnapshot) +{ +} + +void TBusSessionImpl::Act(TStatusTag) { + TInstant now = TInstant::Now(); + + EShutdownState shutdownState = StatusData.ShutdownState.State.Get(); + + StatusData.ConnectionsAcceptorsSnapshotsQueue.DequeueAllLikelyEmpty(std::bind(&TBusSessionImpl::ProcessConnectionsAcceptorsShapshotQueueItem, this, std::placeholders::_1)); + + GetDeadConnectionWriterStatusQueue()->DequeueAllLikelyEmpty(); + GetDeadConnectionReaderStatusQueue()->DequeueAllLikelyEmpty(); + GetDeadAcceptorStatusQueue()->DequeueAllLikelyEmpty(); + + // TODO: check queues are empty if already stopped + + if (shutdownState != SS_RUNNING) { + // important to beak cyclic link session -> connection -> session + StatusData.ConnectionsAcceptorsSnapshot->Connections.clear(); + StatusData.ConnectionsAcceptorsSnapshot->Acceptors.clear(); + } + + if (shutdownState == SS_SHUTDOWN_COMMAND) { + StatusData.ShutdownState.CompleteShutdown(); + } + + StatusUpdateCachedDumpIfNecessary(now); +} + +TBusSessionImpl::TConnectionsData::TConnectionsData() { +} + +void TBusSessionImpl::Act(TConnectionTag) { + TConnectionsGuard guard(ConnectionsLock); + + EShutdownState shutdownState = ConnectionsData.ShutdownState.State.Get(); + if (shutdownState == SS_SHUTDOWN_COMPLETE) { + Y_VERIFY(GetRemoveConnectionQueue()->IsEmpty()); + Y_VERIFY(GetOnAcceptQueue()->IsEmpty()); + } + + GetRemoveConnectionQueue()->DequeueAllLikelyEmpty(); + GetOnAcceptQueue()->DequeueAllLikelyEmpty(); + + if (shutdownState == SS_SHUTDOWN_COMMAND) { + ConnectionsData.ShutdownState.CompleteShutdown(); + } +} + +void TBusSessionImpl::Listen(int port, TBusMessageQueue* q) { + Listen(BindOnPort(port, Config.ReusePort).second, q); +} + +void TBusSessionImpl::Listen(const TVector<TBindResult>& bindTo, TBusMessageQueue* q) { + Y_ASSERT(q == Queue); + int actualPort = -1; + + for (const TBindResult& br : bindTo) { + if (actualPort == -1) { + actualPort = br.Addr.GetPort(); + } else { + Y_VERIFY(actualPort == br.Addr.GetPort(), "state check"); + } + if (Config.SocketToS >= 0) { + SetSocketToS(*br.Socket, &(br.Addr), Config.SocketToS); + } + + TAcceptorPtr acceptor(new TAcceptor(this, ++LastAcceptorId, br.Socket->Release(), br.Addr)); + + TConnectionsGuard guard(ConnectionsLock); + InsertAcceptorLockAcquired(acceptor.Get()); + } + + Config.ListenPort = actualPort; +} + +void TBusSessionImpl::SendSnapshotToStatusActor() { + //Y_ASSERT(ConnectionsLock.IsLocked()); + + TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> snapshot(new TConnectionsAcceptorsSnapshot); + GetAcceptorsLockAquired(&snapshot->Acceptors); + GetConnectionsLockAquired(&snapshot->Connections); + snapshot->LastAcceptorId = LastAcceptorId; + snapshot->LastConnectionId = LastConnectionId; + StatusData.ConnectionsAcceptorsSnapshotsQueue.Enqueue(snapshot); + GetStatusActor()->Schedule(); +} + +void TBusSessionImpl::InsertConnectionLockAcquired(TRemoteConnection* connection) { + //Y_ASSERT(ConnectionsLock.IsLocked()); + + Connections.insert(std::make_pair(connection->PeerAddrSocketAddr, connection)); + // connection for given adds may already exist at this point + // (so we overwrite old connection) + // after reconnect, if previous connections wasn't shutdown yet + + bool inserted2 = ConnectionsById.insert(std::make_pair(connection->ConnectionId, connection)).second; + Y_VERIFY(inserted2, "state check: must be inserted (2)"); + + SendSnapshotToStatusActor(); +} + +void TBusSessionImpl::InsertAcceptorLockAcquired(TAcceptor* acceptor) { + //Y_ASSERT(ConnectionsLock.IsLocked()); + + Acceptors.push_back(acceptor); + + SendSnapshotToStatusActor(); +} + +void TBusSessionImpl::GetConnections(TVector<TRemoteConnectionPtr>* r) { + TConnectionsGuard guard(ConnectionsLock); + GetConnectionsLockAquired(r); +} + +void TBusSessionImpl::GetAcceptors(TVector<TAcceptorPtr>* r) { + TConnectionsGuard guard(ConnectionsLock); + GetAcceptorsLockAquired(r); +} + +void TBusSessionImpl::GetConnectionsLockAquired(TVector<TRemoteConnectionPtr>* r) { + //Y_ASSERT(ConnectionsLock.IsLocked()); + + r->reserve(Connections.size()); + + for (auto& connection : Connections) { + r->push_back(connection.second); + } +} + +void TBusSessionImpl::GetAcceptorsLockAquired(TVector<TAcceptorPtr>* r) { + //Y_ASSERT(ConnectionsLock.IsLocked()); + + r->reserve(Acceptors.size()); + + for (auto& acceptor : Acceptors) { + r->push_back(acceptor); + } +} + +TRemoteConnectionPtr TBusSessionImpl::GetConnectionById(ui64 id) { + TConnectionsGuard guard(ConnectionsLock); + + THashMap<ui64, TRemoteConnectionPtr>::const_iterator it = ConnectionsById.find(id); + if (it == ConnectionsById.end()) { + return nullptr; + } else { + return it->second; + } +} + +TAcceptorPtr TBusSessionImpl::GetAcceptorById(ui64 id) { + TGuard<TMutex> guard(ConnectionsLock); + + for (const auto& Acceptor : Acceptors) { + if (Acceptor->AcceptorId == id) { + return Acceptor; + } + } + + return nullptr; +} + +void TBusSessionImpl::InvokeOnError(TNonDestroyingAutoPtr<TBusMessage> message, EMessageStatus status) { + message->CheckClean(); + ErrorHandler->OnError(message, status); +} + +TRemoteConnectionPtr TBusSessionImpl::GetConnection(const TBusSocketAddr& addr, bool create) { + TConnectionsGuard guard(ConnectionsLock); + + TAddrRemoteConnections::const_iterator it = Connections.find(addr); + if (it != Connections.end()) { + return it->second; + } + + if (!create) { + return TRemoteConnectionPtr(); + } + + Y_VERIFY(IsSource_, "must be source"); + + TRemoteConnectionPtr c(new TRemoteClientConnection(VerifyDynamicCast<TRemoteClientSession*>(this), ++LastConnectionId, addr.ToNetAddr())); + InsertConnectionLockAcquired(c.Get()); + + return c; +} + +void TBusSessionImpl::Cron() { + TVector<TRemoteConnectionPtr> connections; + GetConnections(&connections); + + for (const auto& it : connections) { + TRemoteConnection* connection = it.Get(); + if (IsSource_) { + VerifyDynamicCast<TRemoteClientConnection*>(connection)->ScheduleTimeoutMessages(); + } else { + VerifyDynamicCast<TRemoteServerConnection*>(connection)->WriterData.TimeToRotateCounters.AddTask(); + // no schedule: do not rotate if there's no traffic + } + } + + // status updates are sent without scheduling + GetStatusActor()->Schedule(); + + Queue->Schedule(IScheduleItemAutoPtr(new TScheduleSession(this, TInstant::Now() + Config.Secret.TimeoutPeriod))); +} + +TString TBusSessionImpl::GetNameInternal() { + if (!!Config.Name) { + return Config.Name; + } + return ProtoName; +} diff --git a/library/cpp/messagebus/session_impl.h b/library/cpp/messagebus/session_impl.h new file mode 100644 index 0000000000..90ef246ff8 --- /dev/null +++ b/library/cpp/messagebus/session_impl.h @@ -0,0 +1,259 @@ +#pragma once + +#include "acceptor_status.h" +#include "async_result.h" +#include "event_loop.h" +#include "netaddr.h" +#include "remote_connection.h" +#include "remote_connection_status.h" +#include "session_job_count.h" +#include "shutdown_state.h" +#include "ybus.h" + +#include <library/cpp/messagebus/actor/actor.h> +#include <library/cpp/messagebus/actor/queue_in_actor.h> +#include <library/cpp/messagebus/monitoring/mon_proto.pb.h> + +#include <library/cpp/threading/future/legacy_future.h> + +#include <util/generic/array_ref.h> +#include <util/generic/string.h> + +namespace NBus { + namespace NPrivate { + typedef TIntrusivePtr<TRemoteClientConnection> TRemoteClientConnectionPtr; + typedef TIntrusivePtr<TRemoteServerConnection> TRemoteServerConnectionPtr; + + typedef TIntrusivePtr<TRemoteServerSession> TRemoteServerSessionPtr; + + typedef TIntrusivePtr<TAcceptor> TAcceptorPtr; + typedef TVector<TAcceptorPtr> TAcceptorsPtrs; + + struct TConnectionsAcceptorsSnapshot { + TVector<TRemoteConnectionPtr> Connections; + TVector<TAcceptorPtr> Acceptors; + ui64 LastConnectionId; + ui64 LastAcceptorId; + + TConnectionsAcceptorsSnapshot(); + }; + + typedef TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> TConnectionsAcceptorsSnapshotPtr; + + struct TOnAccept { + SOCKET s; + TNetAddr addr; + TInstant now; + }; + + struct TStatusTag {}; + struct TConnectionTag {}; + + struct TDeadConnectionTag {}; + struct TRemoveTag {}; + + struct TBusSessionImpl + : public virtual TBusSession, + private ::NActor::TActor<TBusSessionImpl, TStatusTag>, + private ::NActor::TActor<TBusSessionImpl, TConnectionTag> + + , + private ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionWriterIncrementalStatus, TStatusTag, TDeadConnectionTag>, + private ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionReaderIncrementalStatus, TStatusTag, TDeadConnectionTag>, + private ::NActor::TQueueInActor<TBusSessionImpl, TAcceptorStatus, TStatusTag, TDeadConnectionTag> + + , + private ::NActor::TQueueInActor<TBusSessionImpl, TOnAccept, TConnectionTag>, + private ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionPtr, TConnectionTag, TRemoveTag> { + friend class TAcceptor; + friend class TRemoteConnection; + friend class TRemoteServerConnection; + friend class ::NActor::TActor<TBusSessionImpl, TStatusTag>; + friend class ::NActor::TActor<TBusSessionImpl, TConnectionTag>; + friend class ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionWriterIncrementalStatus, TStatusTag, TDeadConnectionTag>; + friend class ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionReaderIncrementalStatus, TStatusTag, TDeadConnectionTag>; + friend class ::NActor::TQueueInActor<TBusSessionImpl, TAcceptorStatus, TStatusTag, TDeadConnectionTag>; + friend class ::NActor::TQueueInActor<TBusSessionImpl, TOnAccept, TConnectionTag>; + friend class ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionPtr, TConnectionTag, TRemoveTag>; + + public: + ::NActor::TQueueInActor<TBusSessionImpl, TOnAccept, TConnectionTag>* GetOnAcceptQueue() { + return this; + } + + ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionPtr, TConnectionTag, TRemoveTag>* GetRemoveConnectionQueue() { + return this; + } + + ::NActor::TActor<TBusSessionImpl, TConnectionTag>* GetConnectionActor() { + return this; + } + + typedef TGuard<TMutex> TConnectionsGuard; + + TBusSessionImpl(bool isSource, TBusMessageQueue* queue, TBusProtocol* proto, + IBusErrorHandler* handler, + const TBusSessionConfig& config, const TString& name); + + ~TBusSessionImpl() override; + + void Shutdown() override; + bool IsDown(); + + size_t GetInFlightImpl(const TNetAddr& addr) const; + size_t GetConnectSyscallsNumForTestImpl(const TNetAddr& addr) const; + + void GetInFlightBulk(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const override; + void GetConnectSyscallsNumBulkForTest(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const override; + + virtual void FillStatus(); + TSessionDumpStatus GetStatusRecordInternal() override; + TString GetStatus(ui16 flags = YBUS_STATUS_CONNS) override; + TConnectionStatusMonRecord GetStatusProtobuf() override; + TString GetStatusSingleLine() override; + + void ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionWriterIncrementalStatus&); + void ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionReaderIncrementalStatus&); + void ProcessItem(TStatusTag, TDeadConnectionTag, const TAcceptorStatus&); + void ProcessItem(TStatusTag, ::NActor::TDefaultTag, const TAcceptorStatus&); + void ProcessItem(TConnectionTag, ::NActor::TDefaultTag, const TOnAccept&); + void ProcessItem(TConnectionTag, TRemoveTag, TRemoteConnectionPtr); + void ProcessConnectionsAcceptorsShapshotQueueItem(TAtomicSharedPtr<TConnectionsAcceptorsSnapshot>); + void StatusUpdateCachedDump(); + void StatusUpdateCachedDumpIfNecessary(TInstant now); + void Act(TStatusTag); + void Act(TConnectionTag); + + TBusProtocol* GetProto() const noexcept override; + const TBusSessionConfig* GetConfig() const noexcept override; + TBusMessageQueue* GetQueue() const noexcept override; + TString GetNameInternal() override; + + virtual void OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) = 0; + + void Listen(int port, TBusMessageQueue* q); + void Listen(const TVector<TBindResult>& bindTo, TBusMessageQueue* q); + TBusConnection* Accept(SOCKET listen); + + inline ::NActor::TActor<TBusSessionImpl, TStatusTag>* GetStatusActor() { + return this; + } + inline ::NActor::TActor<TBusSessionImpl, TConnectionTag>* GetConnectionsActor() { + return this; + } + + typedef THashMap<TBusSocketAddr, TRemoteConnectionPtr> TAddrRemoteConnections; + + void SendSnapshotToStatusActor(); + + void InsertConnectionLockAcquired(TRemoteConnection* connection); + void InsertAcceptorLockAcquired(TAcceptor* acceptor); + + void GetConnections(TVector<TRemoteConnectionPtr>*); + void GetAcceptors(TVector<TAcceptorPtr>*); + void GetConnectionsLockAquired(TVector<TRemoteConnectionPtr>*); + void GetAcceptorsLockAquired(TVector<TAcceptorPtr>*); + + TRemoteConnectionPtr GetConnection(const TBusSocketAddr& addr, bool create); + TRemoteConnectionPtr GetConnectionById(ui64 id); + TAcceptorPtr GetAcceptorById(ui64 id); + + void InvokeOnError(TNonDestroyingAutoPtr<TBusMessage>, EMessageStatus); + + void Cron(); + + TBusSessionJobCount JobCount; + + // TODO: replace with actor + TMutex ConnectionsLock; + + struct TImpl; + THolder<TImpl> Impl; + + const bool IsSource_; + + TBusMessageQueue* const Queue; + TBusProtocol* const Proto; + // copied to be available after Proto dies + const TString ProtoName; + + IBusErrorHandler* const ErrorHandler; + TUseCountHolder HandlerUseCountHolder; + TBusSessionConfig Config; // TODO: make const + + NEventLoop::TEventLoop WriteEventLoop; + NEventLoop::TEventLoop ReadEventLoop; + THolder<NThreading::TLegacyFuture<void, false>> ReadEventLoopThread; + THolder<NThreading::TLegacyFuture<void, false>> WriteEventLoopThread; + + THashMap<ui64, TRemoteConnectionPtr> ConnectionsById; + TAddrRemoteConnections Connections; + TAcceptorsPtrs Acceptors; + + struct TStatusData { + TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> ConnectionsAcceptorsSnapshot; + ::NActor::TQueueForActor<TAtomicSharedPtr<TConnectionsAcceptorsSnapshot>> ConnectionsAcceptorsSnapshotsQueue; + + TAtomicShutdownState ShutdownState; + + TBusSessionStatus Status; + + TSessionDumpStatus StatusDumpCached; + TMutex StatusDumpCachedMutex; + TInstant StatusDumpCachedLastUpdate; + + TStatusData(); + }; + TStatusData StatusData; + + struct TConnectionsData { + TAtomicShutdownState ShutdownState; + + TConnectionsData(); + }; + TConnectionsData ConnectionsData; + + ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionWriterIncrementalStatus, + TStatusTag, TDeadConnectionTag>* + GetDeadConnectionWriterStatusQueue() { + return this; + } + + ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionReaderIncrementalStatus, + TStatusTag, TDeadConnectionTag>* + GetDeadConnectionReaderStatusQueue() { + return this; + } + + ::NActor::TQueueInActor<TBusSessionImpl, TAcceptorStatus, + TStatusTag, TDeadConnectionTag>* + GetDeadAcceptorStatusQueue() { + return this; + } + + template <typename TItem> + ::NActor::IQueueInActor<TItem>* GetQueue() { + return this; + } + + ui64 LastAcceptorId; + ui64 LastConnectionId; + + TAtomic Down; + TSystemEvent ShutdownCompleteEvent; + }; + + inline TBusProtocol* TBusSessionImpl::GetProto() const noexcept { + return Proto; + } + + inline const TBusSessionConfig* TBusSessionImpl::GetConfig() const noexcept { + return &Config; + } + + inline TBusMessageQueue* TBusSessionImpl::GetQueue() const noexcept { + return Queue; + } + + } +} diff --git a/library/cpp/messagebus/session_job_count.cpp b/library/cpp/messagebus/session_job_count.cpp new file mode 100644 index 0000000000..33322b1910 --- /dev/null +++ b/library/cpp/messagebus/session_job_count.cpp @@ -0,0 +1,22 @@ +#include "session_job_count.h" + +#include <util/system/yassert.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +TBusSessionJobCount::TBusSessionJobCount() + : JobCount(0) +{ +} + +TBusSessionJobCount::~TBusSessionJobCount() { + Y_VERIFY(JobCount == 0, "must be 0 job count to destroy job"); +} + +void TBusSessionJobCount::WaitForZero() { + TGuard<TMutex> guard(Mutex); + while (AtomicGet(JobCount) > 0) { + CondVar.WaitI(Mutex); + } +} diff --git a/library/cpp/messagebus/session_job_count.h b/library/cpp/messagebus/session_job_count.h new file mode 100644 index 0000000000..23aca618b1 --- /dev/null +++ b/library/cpp/messagebus/session_job_count.h @@ -0,0 +1,39 @@ +#pragma once + +#include <util/system/atomic.h> +#include <util/system/condvar.h> +#include <util/system/mutex.h> + +namespace NBus { + namespace NPrivate { + class TBusSessionJobCount { + private: + TAtomic JobCount; + + TMutex Mutex; + TCondVar CondVar; + + public: + TBusSessionJobCount(); + ~TBusSessionJobCount(); + + void Add(unsigned delta) { + AtomicAdd(JobCount, delta); + } + + void Increment() { + Add(1); + } + + void Decrement() { + if (AtomicDecrement(JobCount) == 0) { + TGuard<TMutex> guard(Mutex); + CondVar.BroadCast(); + } + } + + void WaitForZero(); + }; + + } +} diff --git a/library/cpp/messagebus/shutdown_state.cpp b/library/cpp/messagebus/shutdown_state.cpp new file mode 100644 index 0000000000..a4e2bfa8b2 --- /dev/null +++ b/library/cpp/messagebus/shutdown_state.cpp @@ -0,0 +1,20 @@ +#include "shutdown_state.h" + +#include <util/system/yassert.h> + +void TAtomicShutdownState::ShutdownCommand() { + Y_VERIFY(State.CompareAndSet(SS_RUNNING, SS_SHUTDOWN_COMMAND)); +} + +void TAtomicShutdownState::CompleteShutdown() { + Y_VERIFY(State.CompareAndSet(SS_SHUTDOWN_COMMAND, SS_SHUTDOWN_COMPLETE)); + ShutdownComplete.Signal(); +} + +bool TAtomicShutdownState::IsRunning() { + return State.Get() == SS_RUNNING; +} + +TAtomicShutdownState::~TAtomicShutdownState() { + Y_VERIFY(SS_SHUTDOWN_COMPLETE == State.Get()); +} diff --git a/library/cpp/messagebus/shutdown_state.h b/library/cpp/messagebus/shutdown_state.h new file mode 100644 index 0000000000..86bd7110ae --- /dev/null +++ b/library/cpp/messagebus/shutdown_state.h @@ -0,0 +1,22 @@ +#pragma once + +#include "misc/atomic_box.h" + +#include <util/system/event.h> + +enum EShutdownState { + SS_RUNNING, + SS_SHUTDOWN_COMMAND, + SS_SHUTDOWN_COMPLETE, +}; + +struct TAtomicShutdownState { + TAtomicBox<EShutdownState> State; + TSystemEvent ShutdownComplete; + + void ShutdownCommand(); + void CompleteShutdown(); + bool IsRunning(); + + ~TAtomicShutdownState(); +}; diff --git a/library/cpp/messagebus/socket_addr.cpp b/library/cpp/messagebus/socket_addr.cpp new file mode 100644 index 0000000000..c1b3a28fbe --- /dev/null +++ b/library/cpp/messagebus/socket_addr.cpp @@ -0,0 +1,79 @@ +#include "socket_addr.h" + +#include "netaddr.h" + +#include <util/network/address.h> +#include <util/network/init.h> +#include <util/system/yassert.h> + +using namespace NAddr; + +using namespace NBus; +using namespace NBus::NPrivate; + +static_assert(ADDR_UNSPEC == 0, "expect ADDR_UNSPEC == 0"); + +NBus::NPrivate::TBusSocketAddr::TBusSocketAddr(const NAddr::IRemoteAddr* addr) + : IPv6ScopeID(0) +{ + const sockaddr* sa = addr->Addr(); + + switch ((EAddrFamily)sa->sa_family) { + case AF_UNSPEC: { + IpAddr.Clear(); + Port = 0; + break; + } + case AF_INET: { + IpAddr.SetInAddr(((const sockaddr_in*)sa)->sin_addr); + Port = InetToHost(((const sockaddr_in*)sa)->sin_port); + break; + } + case AF_INET6: { + IpAddr.SetIn6Addr(((const sockaddr_in6*)sa)->sin6_addr); + Port = InetToHost(((const sockaddr_in*)sa)->sin_port); + IPv6ScopeID = InetToHost(((const sockaddr_in6*)sa)->sin6_scope_id); + break; + } + default: + Y_FAIL("unknown address family"); + } +} + +NBus::NPrivate::TBusSocketAddr::TBusSocketAddr(TStringBuf host, unsigned port) { + *this = TNetAddr(host, port); +} + +NBus::NPrivate::TBusSocketAddr::TBusSocketAddr(const TNetAddr& addr) { + *this = TBusSocketAddr(&addr); +} + +TNetAddr NBus::NPrivate::TBusSocketAddr::ToNetAddr() const { + sockaddr_storage storage; + Zero(storage); + + storage.ss_family = (ui16)IpAddr.GetAddrFamily(); + + switch (IpAddr.GetAddrFamily()) { + case ADDR_UNSPEC: + return TNetAddr(); + case ADDR_IPV4: { + ((sockaddr_in*)&storage)->sin_addr = IpAddr.GetInAddr(); + ((sockaddr_in*)&storage)->sin_port = HostToInet(Port); + break; + } + case ADDR_IPV6: { + ((sockaddr_in6*)&storage)->sin6_addr = IpAddr.GetIn6Addr(); + ((sockaddr_in6*)&storage)->sin6_port = HostToInet(Port); + ((sockaddr_in6*)&storage)->sin6_scope_id = HostToInet(IPv6ScopeID); + break; + } + } + + return TNetAddr(new TOpaqueAddr((sockaddr*)&storage)); +} + +template <> +void Out<TBusSocketAddr>(IOutputStream& out, const TBusSocketAddr& addr) { + out << addr.ToNetAddr(); +} diff --git a/library/cpp/messagebus/socket_addr.h b/library/cpp/messagebus/socket_addr.h new file mode 100644 index 0000000000..959eafe689 --- /dev/null +++ b/library/cpp/messagebus/socket_addr.h @@ -0,0 +1,113 @@ +#pragma once + +#include "hash.h" + +#include <util/generic/hash.h> +#include <util/generic/utility.h> +#include <util/network/address.h> +#include <util/network/init.h> + +#include <string.h> + +namespace NBus { + class TNetAddr; +} + +namespace NBus { + namespace NPrivate { + enum EAddrFamily { + ADDR_UNSPEC = AF_UNSPEC, + ADDR_IPV4 = AF_INET, + ADDR_IPV6 = AF_INET6, + }; + + class TBusIpAddr { + private: + EAddrFamily Af; + + union { + in_addr In4; + in6_addr In6; + }; + + public: + TBusIpAddr() { + Clear(); + } + + EAddrFamily GetAddrFamily() const { + return Af; + } + + void Clear() { + Zero(*this); + } + + in_addr GetInAddr() const { + Y_ASSERT(Af == ADDR_IPV4); + return In4; + } + + void SetInAddr(const in_addr& in4) { + Clear(); + Af = ADDR_IPV4; + In4 = in4; + } + + in6_addr GetIn6Addr() const { + Y_ASSERT(Af == ADDR_IPV6); + return In6; + } + + void SetIn6Addr(const in6_addr& in6) { + Clear(); + Af = ADDR_IPV6; + In6 = in6; + } + + bool operator==(const TBusIpAddr& that) const { + return memcmp(this, &that, sizeof(that)) == 0; + } + }; + + class TBusSocketAddr { + public: + TBusIpAddr IpAddr; + ui16 Port; + + //Only makes sense for IPv6 link-local addresses + ui32 IPv6ScopeID; + + TBusSocketAddr() + : Port(0) + , IPv6ScopeID(0) + { + } + + TBusSocketAddr(const NAddr::IRemoteAddr*); + TBusSocketAddr(const TNetAddr&); + TBusSocketAddr(TStringBuf host, unsigned port); + + TNetAddr ToNetAddr() const; + + bool operator==(const TBusSocketAddr& that) const { + return IpAddr == that.IpAddr && Port == that.Port; + } + }; + + } +} + +template <> +struct THash<NBus::NPrivate::TBusIpAddr> { + inline size_t operator()(const NBus::NPrivate::TBusIpAddr& a) const { + return ComputeHash(TStringBuf((const char*)&a, sizeof(a))); + } +}; + +template <> +struct THash<NBus::NPrivate::TBusSocketAddr> { + inline size_t operator()(const NBus::NPrivate::TBusSocketAddr& a) const { + return HashValues(a.IpAddr, a.Port); + } +}; diff --git a/library/cpp/messagebus/socket_addr_ut.cpp b/library/cpp/messagebus/socket_addr_ut.cpp new file mode 100644 index 0000000000..783bb62a86 --- /dev/null +++ b/library/cpp/messagebus/socket_addr_ut.cpp @@ -0,0 +1,15 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "netaddr.h" +#include "socket_addr.h" + +#include <util/string/cast.h> + +using namespace NBus; +using namespace NBus::NPrivate; + +Y_UNIT_TEST_SUITE(TBusSocketAddr) { + Y_UNIT_TEST(Simple) { + UNIT_ASSERT_VALUES_EQUAL(TString("127.0.0.1:80"), ToString(TBusSocketAddr("127.0.0.1", 80))); + } +} diff --git a/library/cpp/messagebus/storage.cpp b/library/cpp/messagebus/storage.cpp new file mode 100644 index 0000000000..efefc87340 --- /dev/null +++ b/library/cpp/messagebus/storage.cpp @@ -0,0 +1,161 @@ +#include "storage.h" + +#include <typeinfo> + +namespace NBus { + namespace NPrivate { + TTimedMessages::TTimedMessages() { + } + + TTimedMessages::~TTimedMessages() { + Y_VERIFY(Items.empty()); + } + + void TTimedMessages::PushBack(TNonDestroyingAutoPtr<TBusMessage> m) { + TItem i; + i.Message.Reset(m.Release()); + Items.push_back(i); + } + + TNonDestroyingAutoPtr<TBusMessage> TTimedMessages::PopFront() { + TBusMessage* r = nullptr; + if (!Items.empty()) { + r = Items.front()->Message.Release(); + Items.pop_front(); + } + return r; + } + + bool TTimedMessages::Empty() const { + return Items.empty(); + } + + size_t TTimedMessages::Size() const { + return Items.size(); + } + + void TTimedMessages::Timeout(TInstant before, TMessagesPtrs* r) { + // shortcut + if (before == TInstant::Max()) { + Clear(r); + return; + } + + while (!Items.empty()) { + TItem& i = *Items.front(); + if (TInstant::MilliSeconds(i.Message->GetHeader()->SendTime) > before) { + break; + } + r->push_back(i.Message.Release()); + Items.pop_front(); + } + } + + void TTimedMessages::Clear(TMessagesPtrs* r) { + while (!Items.empty()) { + r->push_back(Items.front()->Message.Release()); + Items.pop_front(); + } + } + + TSyncAckMessages::TSyncAckMessages() { + KeyToMessage.set_empty_key(0); + KeyToMessage.set_deleted_key(1); + } + + TSyncAckMessages::~TSyncAckMessages() { + Y_VERIFY(KeyToMessage.empty()); + Y_VERIFY(TimedItems.empty()); + } + + void TSyncAckMessages::Push(TBusMessagePtrAndHeader& m) { + // Perform garbage collection if `TimedMessages` contain too many junk data + if (TimedItems.size() > 1000 && TimedItems.size() > KeyToMessage.size() * 4) { + Gc(); + } + + TValue value = {m.MessagePtr.Release()}; + + std::pair<TKeyToMessage::iterator, bool> p = KeyToMessage.insert(TKeyToMessage::value_type(m.Header.Id, value)); + Y_VERIFY(p.second, "non-unique id; %s", value.Message->Describe().data()); + + TTimedItem item = {m.Header.Id, m.Header.SendTime}; + TimedItems.push_back(item); + } + + TBusMessage* TSyncAckMessages::Pop(TBusKey id) { + TKeyToMessage::iterator it = KeyToMessage.find(id); + if (it == KeyToMessage.end()) { + return nullptr; + } + TValue v = it->second; + KeyToMessage.erase(it); + + // `TimedMessages` still contain record about this message + + return v.Message; + } + + void TSyncAckMessages::Timeout(TInstant before, TMessagesPtrs* r) { + // shortcut + if (before == TInstant::Max()) { + Clear(r); + return; + } + + Y_ASSERT(r->empty()); + + while (!TimedItems.empty()) { + TTimedItem i = TimedItems.front(); + if (TInstant::MilliSeconds(i.SendTime) > before) { + break; + } + + TKeyToMessage::iterator itMessage = KeyToMessage.find(i.Key); + + if (itMessage != KeyToMessage.end()) { + r->push_back(itMessage->second.Message); + KeyToMessage.erase(itMessage); + } + + TimedItems.pop_front(); + } + } + + void TSyncAckMessages::Clear(TMessagesPtrs* r) { + for (TKeyToMessage::const_iterator i = KeyToMessage.begin(); i != KeyToMessage.end(); ++i) { + r->push_back(i->second.Message); + } + + KeyToMessage.clear(); + TimedItems.clear(); + } + + void TSyncAckMessages::Gc() { + TDeque<TTimedItem> tmp; + + for (auto& timedItem : TimedItems) { + if (KeyToMessage.find(timedItem.Key) == KeyToMessage.end()) { + continue; + } + tmp.push_back(timedItem); + } + + TimedItems.swap(tmp); + } + + void TSyncAckMessages::RemoveAll(const TMessagesPtrs& messages) { + for (auto message : messages) { + TKeyToMessage::iterator it = KeyToMessage.find(message->GetHeader()->Id); + Y_VERIFY(it != KeyToMessage.end(), "delete non-existent message"); + KeyToMessage.erase(it); + } + } + + void TSyncAckMessages::DumpState() { + Cerr << TimedItems.size() << Endl; + Cerr << KeyToMessage.size() << Endl; + } + + } +} diff --git a/library/cpp/messagebus/storage.h b/library/cpp/messagebus/storage.h new file mode 100644 index 0000000000..7d168844ed --- /dev/null +++ b/library/cpp/messagebus/storage.h @@ -0,0 +1,94 @@ +#pragma once + +#include "message_ptr_and_header.h" +#include "moved.h" +#include "ybus.h" + +#include <contrib/libs/sparsehash/src/sparsehash/dense_hash_map> + +#include <util/generic/deque.h> +#include <util/generic/noncopyable.h> +#include <util/generic/utility.h> + +namespace NBus { + namespace NPrivate { + typedef TVector<TBusMessage*> TMessagesPtrs; + + class TTimedMessages { + public: + TTimedMessages(); + ~TTimedMessages(); + + struct TItem { + THolder<TBusMessage> Message; + + void Swap(TItem& that) { + DoSwap(Message, that.Message); + } + }; + + typedef TDeque<TMoved<TItem>> TItems; + + void PushBack(TNonDestroyingAutoPtr<TBusMessage> m); + TNonDestroyingAutoPtr<TBusMessage> PopFront(); + bool Empty() const; + size_t Size() const; + + void Timeout(TInstant before, TMessagesPtrs* r); + void Clear(TMessagesPtrs* r); + + private: + TItems Items; + }; + + class TSyncAckMessages : TNonCopyable { + public: + TSyncAckMessages(); + ~TSyncAckMessages(); + + void Push(TBusMessagePtrAndHeader& m); + TBusMessage* Pop(TBusKey id); + + void Timeout(TInstant before, TMessagesPtrs* r); + + void Clear(TMessagesPtrs* r); + + size_t Size() const { + return KeyToMessage.size(); + } + + void RemoveAll(const TMessagesPtrs&); + + void Gc(); + + void DumpState(); + + private: + struct TTimedItem { + TBusKey Key; + TBusInstant SendTime; + }; + + typedef TDeque<TTimedItem> TTimedItems; + typedef TDeque<TTimedItem>::iterator TTimedIterator; + + TTimedItems TimedItems; + + struct TValue { + TBusMessage* Message; + }; + + // keys are already random, no need to hash them further + struct TIdHash { + size_t operator()(TBusKey value) const { + return value; + } + }; + + typedef google::dense_hash_map<TBusKey, TValue, TIdHash> TKeyToMessage; + + TKeyToMessage KeyToMessage; + }; + + } +} diff --git a/library/cpp/messagebus/synchandler.cpp b/library/cpp/messagebus/synchandler.cpp new file mode 100644 index 0000000000..8e891d66b3 --- /dev/null +++ b/library/cpp/messagebus/synchandler.cpp @@ -0,0 +1,198 @@ +#include "remote_client_session.h" +#include "remote_connection.h" +#include "ybus.h" + +using namespace NBus; +using namespace NBus::NPrivate; + +///////////////////////////////////////////////////////////////// +/// Object that encapsulates all messgae data required for sending +/// a message synchronously and receiving a reply. It includes: +/// 1. ConditionVariable to wait on message reply +/// 2. Lock used by condition variable +/// 3. Message reply +/// 4. Reply status +struct TBusSyncMessageData { + TCondVar ReplyEvent; + TMutex ReplyLock; + TBusMessage* Reply; + EMessageStatus ReplyStatus; + + TBusSyncMessageData() + : Reply(nullptr) + , ReplyStatus(MESSAGE_DONT_ASK) + { + } +}; + +class TSyncHandler: public IBusClientHandler { +public: + TSyncHandler(bool expectReply = true) + : ExpectReply(expectReply) + , Session(nullptr) + { + } + ~TSyncHandler() override { + } + + void OnReply(TAutoPtr<TBusMessage> pMessage0, TAutoPtr<TBusMessage> pReply0) override { + TBusMessage* pMessage = pMessage0.Release(); + TBusMessage* pReply = pReply0.Release(); + + if (!ExpectReply) { // Maybe need VERIFY, but it will be better to support backward compatibility here. + return; + } + + TBusSyncMessageData* data = static_cast<TBusSyncMessageData*>(pMessage->Data); + SignalResult(data, pReply, MESSAGE_OK); + } + + void OnError(TAutoPtr<TBusMessage> pMessage0, EMessageStatus status) override { + TBusMessage* pMessage = pMessage0.Release(); + TBusSyncMessageData* data = static_cast<TBusSyncMessageData*>(pMessage->Data); + if (!data) { + return; + } + + SignalResult(data, /*pReply=*/nullptr, status); + } + + void OnMessageSent(TBusMessage* pMessage) override { + Y_UNUSED(pMessage); + Y_ASSERT(ExpectReply); + } + + void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) override { + Y_ASSERT(!ExpectReply); + TBusSyncMessageData* data = static_cast<TBusSyncMessageData*>(pMessage.Release()->Data); + SignalResult(data, /*pReply=*/nullptr, MESSAGE_OK); + } + + void SetSession(TRemoteClientSession* session) { + if (!ExpectReply) { + Session = session; + } + } + +private: + void SignalResult(TBusSyncMessageData* data, TBusMessage* pReply, EMessageStatus status) const { + Y_VERIFY(data, "Message data is set to NULL."); + TGuard<TMutex> G(data->ReplyLock); + data->Reply = pReply; + data->ReplyStatus = status; + data->ReplyEvent.Signal(); + } + +private: + // This is weird, because in regular client one-way-ness is selected per call, not per session. + bool ExpectReply; + TRemoteClientSession* Session; +}; + +namespace NBus { + namespace NPrivate { +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4250) // 'NBus::NPrivate::TRemoteClientSession' : inherits 'NBus::NPrivate::TBusSessionImpl::NBus::NPrivate::TBusSessionImpl::GetConfig' via dominance +#endif + + /////////////////////////////////////////////////////////////////////////// + class TBusSyncSourceSessionImpl + : private TSyncHandler + // TODO: do not extend TRemoteClientSession + , + public TRemoteClientSession { + private: + bool NeedReply; + + public: + TBusSyncSourceSessionImpl(TBusMessageQueue* queue, TBusProtocol* proto, const TBusClientSessionConfig& config, bool needReply, const TString& name) + : TSyncHandler(needReply) + , TRemoteClientSession(queue, proto, this, config, name) + , NeedReply(needReply) + { + SetSession(this); + } + + TBusMessage* SendSyncMessage(TBusMessage* pMessage, EMessageStatus& status, const TNetAddr* addr = nullptr) { + Y_VERIFY(!Queue->GetExecutor()->IsInExecutorThread(), + "SendSyncMessage must not be called from executor thread"); + + TBusMessage* reply = nullptr; + THolder<TBusSyncMessageData> data(new TBusSyncMessageData()); + + pMessage->Data = data.Get(); + + { + TGuard<TMutex> G(data->ReplyLock); + if (NeedReply) { + status = SendMessage(pMessage, addr, false); // probably should be true + } else { + status = SendMessageOneWay(pMessage, addr); + } + + if (status == MESSAGE_OK) { + data->ReplyEvent.Wait(data->ReplyLock); + TBusSyncMessageData* rdata = static_cast<TBusSyncMessageData*>(pMessage->Data); + Y_VERIFY(rdata == data.Get(), "Message data pointer should not be modified."); + reply = rdata->Reply; + status = rdata->ReplyStatus; + } + } + + // deletion of message and reply is a job of application. + pMessage->Data = nullptr; + + return reply; + } + }; + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + } +} + +TBusSyncSourceSession::TBusSyncSourceSession(TIntrusivePtr< ::NBus::NPrivate::TBusSyncSourceSessionImpl> session) + : Session(session) +{ +} + +TBusSyncSourceSession::~TBusSyncSourceSession() { + Shutdown(); +} + +void TBusSyncSourceSession::Shutdown() { + Session->Shutdown(); +} + +TBusMessage* TBusSyncSourceSession::SendSyncMessage(TBusMessage* pMessage, EMessageStatus& status, const TNetAddr* addr) { + return Session->SendSyncMessage(pMessage, status, addr); +} + +int TBusSyncSourceSession::RegisterService(const char* hostname, TBusKey start, TBusKey end, EIpVersion ipVersion) { + return Session->RegisterService(hostname, start, end, ipVersion); +} + +int TBusSyncSourceSession::GetInFlight() { + return Session->GetInFlight(); +} + +const TBusProtocol* TBusSyncSourceSession::GetProto() const { + return Session->GetProto(); +} + +const TBusClientSession* TBusSyncSourceSession::GetBusClientSessionWorkaroundDoNotUse() const { + return Session.Get(); +} + +TBusSyncClientSessionPtr TBusMessageQueue::CreateSyncSource(TBusProtocol* proto, const TBusClientSessionConfig& config, bool needReply, const TString& name) { + TIntrusivePtr<TBusSyncSourceSessionImpl> session = new TBusSyncSourceSessionImpl(this, proto, config, needReply, name); + Add(session.Get()); + return new TBusSyncSourceSession(session); +} + +void TBusMessageQueue::Destroy(TBusSyncClientSessionPtr session) { + Destroy(session->Session.Get()); + Y_UNUSED(session->Session.Release()); +} diff --git a/library/cpp/messagebus/test/TestMessageBus.py b/library/cpp/messagebus/test/TestMessageBus.py new file mode 100644 index 0000000000..0bbaa0a313 --- /dev/null +++ b/library/cpp/messagebus/test/TestMessageBus.py @@ -0,0 +1,8 @@ +from devtools.fleur.ytest import group, constraint +from devtools.fleur.ytest.integration import UnitTestGroup + +@group +@constraint('library.messagebus') +class TestMessageBus(UnitTestGroup): + def __init__(self, context): + UnitTestGroup.__init__(self, context, 'MessageBus', 'library-messagebus-test-ut') diff --git a/library/cpp/messagebus/test/example/client/client.cpp b/library/cpp/messagebus/test/example/client/client.cpp new file mode 100644 index 0000000000..89b5f2c9be --- /dev/null +++ b/library/cpp/messagebus/test/example/client/client.cpp @@ -0,0 +1,81 @@ +#include <library/cpp/messagebus/test/example/common/proto.h> + +#include <util/random/random.h> + +using namespace NBus; +using namespace NCalculator; + +namespace NCalculator { + struct TCalculatorClient: public IBusClientHandler { + TCalculatorProtocol Proto; + TBusMessageQueuePtr MessageQueue; + TBusClientSessionPtr ClientSession; + + TCalculatorClient() { + MessageQueue = CreateMessageQueue(); + TBusClientSessionConfig config; + config.TotalTimeout = 2 * 1000; + ClientSession = TBusClientSession::Create(&Proto, this, config, MessageQueue); + } + + ~TCalculatorClient() override { + MessageQueue->Stop(); + } + + void OnReply(TAutoPtr<TBusMessage> request, TAutoPtr<TBusMessage> response0) override { + Y_VERIFY(response0->GetHeader()->Type == TResponse::MessageType, "wrong response"); + TResponse* response = VerifyDynamicCast<TResponse*>(response0.Get()); + if (request->GetHeader()->Type == TRequestSum::MessageType) { + TRequestSum* requestSum = VerifyDynamicCast<TRequestSum*>(request.Get()); + int a = requestSum->Record.GetA(); + int b = requestSum->Record.GetB(); + Cerr << a << " + " << b << " = " << response->Record.GetResult() << "\n"; + } else if (request->GetHeader()->Type == TRequestMul::MessageType) { + TRequestMul* requestMul = VerifyDynamicCast<TRequestMul*>(request.Get()); + int a = requestMul->Record.GetA(); + int b = requestMul->Record.GetB(); + Cerr << a << " * " << b << " = " << response->Record.GetResult() << "\n"; + } else { + Y_FAIL("unknown request"); + } + } + + void OnError(TAutoPtr<TBusMessage>, EMessageStatus status) override { + Cerr << "got error " << status << "\n"; + } + }; + +} + +int main(int, char**) { + TCalculatorClient client; + + for (;;) { + TNetAddr addr(TNetAddr("127.0.0.1", TCalculatorProtocol().GetPort())); + + int a = RandomNumber<unsigned>(10); + int b = RandomNumber<unsigned>(10); + EMessageStatus ok; + if (RandomNumber<bool>()) { + TAutoPtr<TRequestSum> request(new TRequestSum); + request->Record.SetA(a); + request->Record.SetB(b); + Cerr << "sending " << a << " + " << b << "\n"; + ok = client.ClientSession->SendMessageAutoPtr(request, &addr); + } else { + TAutoPtr<TRequestMul> request(new TRequestMul); + request->Record.SetA(a); + request->Record.SetB(b); + Cerr << "sending " << a << " * " << b << "\n"; + ok = client.ClientSession->SendMessageAutoPtr(request, &addr); + } + + if (ok != MESSAGE_OK) { + Cerr << "failed to send message " << ok << "\n"; + } + + Sleep(TDuration::Seconds(1)); + } + + return 0; +} diff --git a/library/cpp/messagebus/test/example/client/ya.make b/library/cpp/messagebus/test/example/client/ya.make new file mode 100644 index 0000000000..a660a01698 --- /dev/null +++ b/library/cpp/messagebus/test/example/client/ya.make @@ -0,0 +1,13 @@ +PROGRAM(messagebus_example_client) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus/test/example/common +) + +SRCS( + client.cpp +) + +END() diff --git a/library/cpp/messagebus/test/example/common/messages.proto b/library/cpp/messagebus/test/example/common/messages.proto new file mode 100644 index 0000000000..16b858fc77 --- /dev/null +++ b/library/cpp/messagebus/test/example/common/messages.proto @@ -0,0 +1,15 @@ +package NCalculator; + +message TRequestSumRecord { + required int32 A = 1; + required int32 B = 2; +} + +message TRequestMulRecord { + required int32 A = 1; + required int32 B = 2; +} + +message TResponseRecord { + required int32 Result = 1; +} diff --git a/library/cpp/messagebus/test/example/common/proto.cpp b/library/cpp/messagebus/test/example/common/proto.cpp new file mode 100644 index 0000000000..1d18aa77ea --- /dev/null +++ b/library/cpp/messagebus/test/example/common/proto.cpp @@ -0,0 +1,12 @@ +#include "proto.h" + +using namespace NCalculator; +using namespace NBus; + +TCalculatorProtocol::TCalculatorProtocol() + : TBusBufferProtocol("Calculator", 34567) +{ + RegisterType(new TRequestSum); + RegisterType(new TRequestMul); + RegisterType(new TResponse); +} diff --git a/library/cpp/messagebus/test/example/common/proto.h b/library/cpp/messagebus/test/example/common/proto.h new file mode 100644 index 0000000000..a151aac468 --- /dev/null +++ b/library/cpp/messagebus/test/example/common/proto.h @@ -0,0 +1,17 @@ +#pragma once + +#include <library/cpp/messagebus/test/example/common/messages.pb.h> + +#include <library/cpp/messagebus/ybus.h> +#include <library/cpp/messagebus/protobuf/ybusbuf.h> + +namespace NCalculator { + typedef ::NBus::TBusBufferMessage<TRequestSumRecord, 1> TRequestSum; + typedef ::NBus::TBusBufferMessage<TRequestMulRecord, 2> TRequestMul; + typedef ::NBus::TBusBufferMessage<TResponseRecord, 3> TResponse; + + struct TCalculatorProtocol: public ::NBus::TBusBufferProtocol { + TCalculatorProtocol(); + }; + +} diff --git a/library/cpp/messagebus/test/example/common/ya.make b/library/cpp/messagebus/test/example/common/ya.make new file mode 100644 index 0000000000..4da16608fc --- /dev/null +++ b/library/cpp/messagebus/test/example/common/ya.make @@ -0,0 +1,15 @@ +LIBRARY(messagebus_test_example_common) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus + library/cpp/messagebus/protobuf +) + +SRCS( + proto.cpp + messages.proto +) + +END() diff --git a/library/cpp/messagebus/test/example/server/server.cpp b/library/cpp/messagebus/test/example/server/server.cpp new file mode 100644 index 0000000000..13e52d75f5 --- /dev/null +++ b/library/cpp/messagebus/test/example/server/server.cpp @@ -0,0 +1,58 @@ +#include <library/cpp/messagebus/test/example/common/proto.h> + +using namespace NBus; +using namespace NCalculator; + +namespace NCalculator { + struct TCalculatorServer: public IBusServerHandler { + TCalculatorProtocol Proto; + TBusMessageQueuePtr MessageQueue; + TBusServerSessionPtr ServerSession; + + TCalculatorServer() { + MessageQueue = CreateMessageQueue(); + TBusServerSessionConfig config; + ServerSession = TBusServerSession::Create(&Proto, this, config, MessageQueue); + } + + ~TCalculatorServer() override { + MessageQueue->Stop(); + } + + void OnMessage(TOnMessageContext& request) override { + if (request.GetMessage()->GetHeader()->Type == TRequestSum::MessageType) { + TRequestSum* requestSum = VerifyDynamicCast<TRequestSum*>(request.GetMessage()); + int a = requestSum->Record.GetA(); + int b = requestSum->Record.GetB(); + int result = a + b; + Cerr << "requested " << a << " + " << b << ", sending " << result << "\n"; + TAutoPtr<TResponse> response(new TResponse); + response->Record.SetResult(result); + request.SendReplyMove(response); + } else if (request.GetMessage()->GetHeader()->Type == TRequestMul::MessageType) { + TRequestMul* requestMul = VerifyDynamicCast<TRequestMul*>(request.GetMessage()); + int a = requestMul->Record.GetA(); + int b = requestMul->Record.GetB(); + int result = a * b; + Cerr << "requested " << a << " * " << b << ", sending " << result << "\n"; + TAutoPtr<TResponse> response(new TResponse); + response->Record.SetResult(result); + request.SendReplyMove(response); + } else { + Y_FAIL("unknown request"); + } + } + }; +} + +int main(int, char**) { + TCalculatorServer server; + + Cerr << "listening on port " << server.ServerSession->GetActualListenPort() << "\n"; + + for (;;) { + Sleep(TDuration::Seconds(1)); + } + + return 0; +} diff --git a/library/cpp/messagebus/test/example/server/ya.make b/library/cpp/messagebus/test/example/server/ya.make new file mode 100644 index 0000000000..8cdd97cb12 --- /dev/null +++ b/library/cpp/messagebus/test/example/server/ya.make @@ -0,0 +1,13 @@ +PROGRAM(messagebus_example_server) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/messagebus/test/example/common +) + +SRCS( + server.cpp +) + +END() diff --git a/library/cpp/messagebus/test/example/ya.make b/library/cpp/messagebus/test/example/ya.make new file mode 100644 index 0000000000..f275351c29 --- /dev/null +++ b/library/cpp/messagebus/test/example/ya.make @@ -0,0 +1,7 @@ +OWNER(g:messagebus) + +RECURSE( + client + common + server +) diff --git a/library/cpp/messagebus/test/helper/alloc_counter.h b/library/cpp/messagebus/test/helper/alloc_counter.h new file mode 100644 index 0000000000..ec9041cb15 --- /dev/null +++ b/library/cpp/messagebus/test/helper/alloc_counter.h @@ -0,0 +1,21 @@ +#pragma once + +#include <util/generic/noncopyable.h> +#include <util/system/atomic.h> +#include <util/system/yassert.h> + +class TAllocCounter : TNonCopyable { +private: + TAtomic* CountPtr; + +public: + TAllocCounter(TAtomic* countPtr) + : CountPtr(countPtr) + { + AtomicIncrement(*CountPtr); + } + + ~TAllocCounter() { + Y_VERIFY(AtomicDecrement(*CountPtr) >= 0, "released too many"); + } +}; diff --git a/library/cpp/messagebus/test/helper/example.cpp b/library/cpp/messagebus/test/helper/example.cpp new file mode 100644 index 0000000000..7c6d704042 --- /dev/null +++ b/library/cpp/messagebus/test/helper/example.cpp @@ -0,0 +1,281 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "example.h" + +#include <util/generic/cast.h> + +using namespace NBus; +using namespace NBus::NTest; + +static void FillWithJunk(TArrayRef<char> data) { + TStringBuf junk = + "01234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01234567890123456789012345678901234567890123456789012345678901234567890123456789" + "01234567890123456789012345678901234567890123456789012345678901234567890123456789"; + + for (size_t i = 0; i < data.size(); i += junk.size()) { + memcpy(data.data() + i, junk.data(), Min(junk.size(), data.size() - i)); + } +} + +static TString JunkString(size_t len) { + TTempBuf temp(len); + TArrayRef<char> tempArrayRef(temp.Data(), len); + FillWithJunk(tempArrayRef); + + return TString(tempArrayRef.data(), tempArrayRef.size()); +} + +TExampleRequest::TExampleRequest(TAtomic* counterPtr, size_t payloadSize) + : TBusMessage(77) + , AllocCounter(counterPtr) + , Data(JunkString(payloadSize)) +{ +} + +TExampleRequest::TExampleRequest(ECreateUninitialized, TAtomic* counterPtr) + : TBusMessage(MESSAGE_CREATE_UNINITIALIZED) + , AllocCounter(counterPtr) +{ +} + +TExampleResponse::TExampleResponse(TAtomic* counterPtr, size_t payloadSize) + : TBusMessage(79) + , AllocCounter(counterPtr) + , Data(JunkString(payloadSize)) +{ +} + +TExampleResponse::TExampleResponse(ECreateUninitialized, TAtomic* counterPtr) + : TBusMessage(MESSAGE_CREATE_UNINITIALIZED) + , AllocCounter(counterPtr) +{ +} + +TExampleProtocol::TExampleProtocol(int port) + : TBusProtocol("Example", port) + , RequestCount(0) + , ResponseCount(0) + , RequestCountDeserialized(0) + , ResponseCountDeserialized(0) + , StartCount(0) +{ +} + +TExampleProtocol::~TExampleProtocol() { + if (UncaughtException()) { + // so it could be reported in test + return; + } + Y_VERIFY(0 == AtomicGet(RequestCount), "protocol %s: must be 0 requests allocated, actually %d", GetService(), int(RequestCount)); + Y_VERIFY(0 == AtomicGet(ResponseCount), "protocol %s: must be 0 responses allocated, actually %d", GetService(), int(ResponseCount)); + Y_VERIFY(0 == AtomicGet(RequestCountDeserialized), "protocol %s: must be 0 requests deserialized allocated, actually %d", GetService(), int(RequestCountDeserialized)); + Y_VERIFY(0 == AtomicGet(ResponseCountDeserialized), "protocol %s: must be 0 responses deserialized allocated, actually %d", GetService(), int(ResponseCountDeserialized)); + Y_VERIFY(0 == AtomicGet(StartCount), "protocol %s: must be 0 start objects allocated, actually %d", GetService(), int(StartCount)); +} + +void TExampleProtocol::Serialize(const TBusMessage* message, TBuffer& buffer) { + // Messages have no data, we recreate them from scratch + // instead of sending, so we don't need to serialize them. + if (const TExampleRequest* exampleMessage = dynamic_cast<const TExampleRequest*>(message)) { + buffer.Append(exampleMessage->Data.data(), exampleMessage->Data.size()); + } else if (const TExampleResponse* exampleReply = dynamic_cast<const TExampleResponse*>(message)) { + buffer.Append(exampleReply->Data.data(), exampleReply->Data.size()); + } else { + Y_FAIL("unknown message type"); + } +} + +TAutoPtr<TBusMessage> TExampleProtocol::Deserialize(ui16 messageType, TArrayRef<const char> payload) { + // TODO: check data + Y_UNUSED(payload); + + if (messageType == 77) { + TExampleRequest* exampleMessage = new TExampleRequest(MESSAGE_CREATE_UNINITIALIZED, &RequestCountDeserialized); + exampleMessage->Data.append(payload.data(), payload.size()); + return exampleMessage; + } else if (messageType == 79) { + TExampleResponse* exampleReply = new TExampleResponse(MESSAGE_CREATE_UNINITIALIZED, &ResponseCountDeserialized); + exampleReply->Data.append(payload.data(), payload.size()); + return exampleReply; + } else { + return nullptr; + } +} + +TExampleClient::TExampleClient(const TBusClientSessionConfig sessionConfig, int port) + : Proto(port) + , UseCompression(false) + , CrashOnError(false) + , DataSize(320) + , MessageCount(0) + , RepliesCount(0) + , Errors(0) + , LastError(MESSAGE_OK) +{ + Bus = CreateMessageQueue("TExampleClient"); + + Session = TBusClientSession::Create(&Proto, this, sessionConfig, Bus); + + Session->RegisterService("localhost"); +} + +TExampleClient::~TExampleClient() { +} + +EMessageStatus TExampleClient::SendMessage(const TNetAddr* addr) { + TAutoPtr<TExampleRequest> message(new TExampleRequest(&Proto.RequestCount, DataSize)); + message->SetCompressed(UseCompression); + return Session->SendMessageAutoPtr(message, addr); +} + +void TExampleClient::SendMessages(size_t count, const TNetAddr* addr) { + UNIT_ASSERT(MessageCount == 0); + UNIT_ASSERT(RepliesCount == 0); + UNIT_ASSERT(Errors == 0); + + WorkDone.Reset(); + MessageCount = count; + for (ssize_t i = 0; i < MessageCount; ++i) { + EMessageStatus s = SendMessage(addr); + UNIT_ASSERT_EQUAL_C(s, MESSAGE_OK, "expecting OK, got " << s); + } +} + +void TExampleClient::SendMessages(size_t count, const TNetAddr& addr) { + SendMessages(count, &addr); +} + +void TExampleClient::ResetCounters() { + MessageCount = 0; + RepliesCount = 0; + Errors = 0; + LastError = MESSAGE_OK; + + WorkDone.Reset(); +} + +void TExampleClient::WaitReplies() { + WorkDone.WaitT(TDuration::Seconds(60)); + + UNIT_ASSERT_VALUES_EQUAL(AtomicGet(RepliesCount), MessageCount); + UNIT_ASSERT_VALUES_EQUAL(AtomicGet(Errors), 0); + UNIT_ASSERT_VALUES_EQUAL(Session->GetInFlight(), 0); + + ResetCounters(); +} + +EMessageStatus TExampleClient::WaitForError() { + WorkDone.WaitT(TDuration::Seconds(60)); + + UNIT_ASSERT_VALUES_EQUAL(1, MessageCount); + UNIT_ASSERT_VALUES_EQUAL(0, AtomicGet(RepliesCount)); + UNIT_ASSERT_VALUES_EQUAL(0, Session->GetInFlight()); + UNIT_ASSERT_VALUES_EQUAL(1, Errors); + EMessageStatus result = LastError; + + ResetCounters(); + return result; +} + +void TExampleClient::WaitForError(EMessageStatus status) { + EMessageStatus error = WaitForError(); + UNIT_ASSERT_VALUES_EQUAL(status, error); +} + +void TExampleClient::SendMessagesWaitReplies(size_t count, const TNetAddr* addr) { + SendMessages(count, addr); + WaitReplies(); +} + +void TExampleClient::SendMessagesWaitReplies(size_t count, const TNetAddr& addr) { + SendMessagesWaitReplies(count, &addr); +} + +void TExampleClient::OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) { + Y_UNUSED(mess); + Y_UNUSED(reply); + + if (AtomicIncrement(RepliesCount) == MessageCount) { + WorkDone.Signal(); + } +} + +void TExampleClient::OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) { + if (CrashOnError) { + Y_FAIL("client failed: %s", ToCString(status)); + } + + Y_UNUSED(mess); + + AtomicIncrement(Errors); + LastError = status; + WorkDone.Signal(); +} + +TExampleServer::TExampleServer( + const char* name, + const TBusServerSessionConfig& sessionConfig) + : UseCompression(false) + , AckMessageBeforeSendReply(false) + , ForgetRequest(false) +{ + Bus = CreateMessageQueue(name); + Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus); +} + +TExampleServer::TExampleServer(unsigned port, const char* name) + : UseCompression(false) + , AckMessageBeforeSendReply(false) + , ForgetRequest(false) +{ + Bus = CreateMessageQueue(name); + TBusServerSessionConfig sessionConfig; + sessionConfig.ListenPort = port; + Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus); +} + +TExampleServer::~TExampleServer() { +} + +size_t TExampleServer::GetInFlight() const { + return Session->GetInFlight(); +} + +unsigned TExampleServer::GetActualListenPort() const { + return Session->GetActualListenPort(); +} + +TNetAddr TExampleServer::GetActualListenAddr() const { + return TNetAddr("127.0.0.1", GetActualListenPort()); +} + +void TExampleServer::WaitForOnMessageCount(unsigned n) { + TestSync.WaitFor(n); +} + +void TExampleServer::OnMessage(TOnMessageContext& mess) { + TestSync.Inc(); + + TExampleRequest* request = VerifyDynamicCast<TExampleRequest*>(mess.GetMessage()); + + if (ForgetRequest) { + mess.ForgetRequest(); + return; + } + + TAutoPtr<TBusMessage> reply(new TExampleResponse(&Proto.ResponseCount, DataSize.GetOrElse(request->Data.size()))); + reply->SetCompressed(UseCompression); + + EMessageStatus status; + if (AckMessageBeforeSendReply) { + TBusIdentity ident; + mess.AckMessage(ident); + status = Session->SendReply(ident, reply.Release()); // TODO: leaks on error + } else { + status = mess.SendReplyMove(reply); + } + + Y_VERIFY(status == MESSAGE_OK, "failed to send reply: %s", ToString(status).data()); +} diff --git a/library/cpp/messagebus/test/helper/example.h b/library/cpp/messagebus/test/helper/example.h new file mode 100644 index 0000000000..26b7475308 --- /dev/null +++ b/library/cpp/messagebus/test/helper/example.h @@ -0,0 +1,132 @@ +#pragma once + +#include <library/cpp/testing/unittest/registar.h> + +#include "alloc_counter.h" +#include "message_handler_error.h" + +#include <library/cpp/messagebus/ybus.h> +#include <library/cpp/messagebus/misc/test_sync.h> + +#include <util/system/event.h> + +namespace NBus { + namespace NTest { + class TExampleRequest: public TBusMessage { + friend class TExampleProtocol; + + private: + TAllocCounter AllocCounter; + + public: + TString Data; + + public: + TExampleRequest(TAtomic* counterPtr, size_t payloadSize = 320); + TExampleRequest(ECreateUninitialized, TAtomic* counterPtr); + }; + + class TExampleResponse: public TBusMessage { + friend class TExampleProtocol; + + private: + TAllocCounter AllocCounter; + + public: + TString Data; + TExampleResponse(TAtomic* counterPtr, size_t payloadSize = 320); + TExampleResponse(ECreateUninitialized, TAtomic* counterPtr); + }; + + class TExampleProtocol: public TBusProtocol { + public: + TAtomic RequestCount; + TAtomic ResponseCount; + TAtomic RequestCountDeserialized; + TAtomic ResponseCountDeserialized; + TAtomic StartCount; + + TExampleProtocol(int port = 0); + + ~TExampleProtocol() override; + + void Serialize(const TBusMessage* message, TBuffer& buffer) override; + + TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override; + }; + + class TExampleClient: private TBusClientHandlerError { + public: + TExampleProtocol Proto; + bool UseCompression; + bool CrashOnError; + size_t DataSize; + + ssize_t MessageCount; + TAtomic RepliesCount; + TAtomic Errors; + EMessageStatus LastError; + + TSystemEvent WorkDone; + + TBusMessageQueuePtr Bus; + TBusClientSessionPtr Session; + + public: + TExampleClient(const TBusClientSessionConfig sessionConfig = TBusClientSessionConfig(), int port = 0); + ~TExampleClient() override; + + EMessageStatus SendMessage(const TNetAddr* addr = nullptr); + + void SendMessages(size_t count, const TNetAddr* addr = nullptr); + void SendMessages(size_t count, const TNetAddr& addr); + + void ResetCounters(); + void WaitReplies(); + EMessageStatus WaitForError(); + void WaitForError(EMessageStatus status); + + void SendMessagesWaitReplies(size_t count, const TNetAddr* addr = nullptr); + void SendMessagesWaitReplies(size_t count, const TNetAddr& addr); + + void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override; + + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus) override; + }; + + class TExampleServer: private TBusServerHandlerError { + public: + TExampleProtocol Proto; + bool UseCompression; + bool AckMessageBeforeSendReply; + TMaybe<size_t> DataSize; // Nothing means use request size + bool ForgetRequest; + + TTestSync TestSync; + + TBusMessageQueuePtr Bus; + TBusServerSessionPtr Session; + + public: + TExampleServer( + const char* name = "TExampleServer", + const TBusServerSessionConfig& sessionConfig = TBusServerSessionConfig()); + + TExampleServer(unsigned port, const char* name = "TExampleServer"); + + ~TExampleServer() override; + + public: + size_t GetInFlight() const; + unsigned GetActualListenPort() const; + // any of + TNetAddr GetActualListenAddr() const; + + void WaitForOnMessageCount(unsigned n); + + protected: + void OnMessage(TOnMessageContext& mess) override; + }; + + } +} diff --git a/library/cpp/messagebus/test/helper/example_module.cpp b/library/cpp/messagebus/test/helper/example_module.cpp new file mode 100644 index 0000000000..65ecfcf73f --- /dev/null +++ b/library/cpp/messagebus/test/helper/example_module.cpp @@ -0,0 +1,43 @@ +#include "example_module.h" + +using namespace NBus; +using namespace NBus::NTest; + +TExampleModule::TExampleModule() + : TBusModule("TExampleModule") +{ + TBusQueueConfig queueConfig; + queueConfig.NumWorkers = 5; + Queue = CreateMessageQueue(queueConfig); +} + +void TExampleModule::StartModule() { + CreatePrivateSessions(Queue.Get()); + StartInput(); +} + +bool TExampleModule::Shutdown() { + TBusModule::Shutdown(); + return true; +} + +TBusServerSessionPtr TExampleModule::CreateExtSession(TBusMessageQueue&) { + return nullptr; +} + +TBusServerSessionPtr TExampleServerModule::CreateExtSession(TBusMessageQueue& queue) { + TBusServerSessionPtr r = CreateDefaultDestination(queue, &Proto, TBusServerSessionConfig()); + ServerAddr = TNetAddr("localhost", r->GetActualListenPort()); + return r; +} + +TExampleClientModule::TExampleClientModule() + : Source() +{ +} + +TBusServerSessionPtr TExampleClientModule::CreateExtSession(TBusMessageQueue& queue) { + Source = CreateDefaultSource(queue, &Proto, TBusServerSessionConfig()); + Source->RegisterService("localhost"); + return nullptr; +} diff --git a/library/cpp/messagebus/test/helper/example_module.h b/library/cpp/messagebus/test/helper/example_module.h new file mode 100644 index 0000000000..a0b295f613 --- /dev/null +++ b/library/cpp/messagebus/test/helper/example_module.h @@ -0,0 +1,37 @@ +#pragma once + +#include "example.h" + +#include <library/cpp/messagebus/oldmodule/module.h> + +namespace NBus { + namespace NTest { + struct TExampleModule: public TBusModule { + TExampleProtocol Proto; + TBusMessageQueuePtr Queue; + + TExampleModule(); + + void StartModule(); + + bool Shutdown() override; + + // nop by default + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override; + }; + + struct TExampleServerModule: public TExampleModule { + TNetAddr ServerAddr; + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override; + }; + + struct TExampleClientModule: public TExampleModule { + TBusClientSessionPtr Source; + + TExampleClientModule(); + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override; + }; + + } +} diff --git a/library/cpp/messagebus/test/helper/fixed_port.cpp b/library/cpp/messagebus/test/helper/fixed_port.cpp new file mode 100644 index 0000000000..258da0d1a5 --- /dev/null +++ b/library/cpp/messagebus/test/helper/fixed_port.cpp @@ -0,0 +1,10 @@ +#include "fixed_port.h" + +#include <util/system/env.h> + +#include <stdlib.h> + +bool NBus::NTest::IsFixedPortTestAllowed() { + // TODO: report skipped tests to test + return !GetEnv("MB_TESTS_SKIP_FIXED_PORT"); +} diff --git a/library/cpp/messagebus/test/helper/fixed_port.h b/library/cpp/messagebus/test/helper/fixed_port.h new file mode 100644 index 0000000000..a9c61ebc63 --- /dev/null +++ b/library/cpp/messagebus/test/helper/fixed_port.h @@ -0,0 +1,11 @@ +#pragma once + +namespace NBus { + namespace NTest { + bool IsFixedPortTestAllowed(); + + // Must not be in range OS uses for bind on random port. + const unsigned FixedPort = 4927; + + } +} diff --git a/library/cpp/messagebus/test/helper/hanging_server.cpp b/library/cpp/messagebus/test/helper/hanging_server.cpp new file mode 100644 index 0000000000..a35514b00d --- /dev/null +++ b/library/cpp/messagebus/test/helper/hanging_server.cpp @@ -0,0 +1,13 @@ +#include "hanging_server.h" + +#include <util/system/yassert.h> + +using namespace NBus; + +THangingServer::THangingServer(int port) { + BindResult = BindOnPort(port, false); +} + +int THangingServer::GetPort() const { + return BindResult.first; +} diff --git a/library/cpp/messagebus/test/helper/hanging_server.h b/library/cpp/messagebus/test/helper/hanging_server.h new file mode 100644 index 0000000000..cc9fb274d8 --- /dev/null +++ b/library/cpp/messagebus/test/helper/hanging_server.h @@ -0,0 +1,16 @@ +#pragma once + +#include <library/cpp/messagebus/network.h> + +#include <util/network/sock.h> + +class THangingServer { +private: + std::pair<unsigned, TVector<NBus::TBindResult>> BindResult; + +public: + // listen on given port, and nothing else + THangingServer(int port = 0); + // actual port + int GetPort() const; +}; diff --git a/library/cpp/messagebus/test/helper/message_handler_error.cpp b/library/cpp/messagebus/test/helper/message_handler_error.cpp new file mode 100644 index 0000000000..c09811ec67 --- /dev/null +++ b/library/cpp/messagebus/test/helper/message_handler_error.cpp @@ -0,0 +1,26 @@ +#include "message_handler_error.h" + +#include <util/system/yassert.h> + +using namespace NBus; +using namespace NBus::NTest; + +void TBusClientHandlerError::OnError(TAutoPtr<TBusMessage>, EMessageStatus status) { + Y_FAIL("must not be called, status: %s", ToString(status).data()); +} + +void TBusClientHandlerError::OnReply(TAutoPtr<TBusMessage>, TAutoPtr<TBusMessage>) { + Y_FAIL("must not be called"); +} + +void TBusClientHandlerError::OnMessageSentOneWay(TAutoPtr<TBusMessage>) { + Y_FAIL("must not be called"); +} + +void TBusServerHandlerError::OnError(TAutoPtr<TBusMessage>, EMessageStatus status) { + Y_FAIL("must not be called, status: %s", ToString(status).data()); +} + +void TBusServerHandlerError::OnMessage(TOnMessageContext&) { + Y_FAIL("must not be called"); +} diff --git a/library/cpp/messagebus/test/helper/message_handler_error.h b/library/cpp/messagebus/test/helper/message_handler_error.h new file mode 100644 index 0000000000..a314b10761 --- /dev/null +++ b/library/cpp/messagebus/test/helper/message_handler_error.h @@ -0,0 +1,19 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> + +namespace NBus { + namespace NTest { + struct TBusClientHandlerError: public IBusClientHandler { + void OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status) override; + void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) override; + void OnReply(TAutoPtr<TBusMessage> pMessage, TAutoPtr<TBusMessage> pReply) override; + }; + + struct TBusServerHandlerError: public IBusServerHandler { + void OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status) override; + void OnMessage(TOnMessageContext& pMessage) override; + }; + + } +} diff --git a/library/cpp/messagebus/test/helper/object_count_check.h b/library/cpp/messagebus/test/helper/object_count_check.h new file mode 100644 index 0000000000..1c4756e58c --- /dev/null +++ b/library/cpp/messagebus/test/helper/object_count_check.h @@ -0,0 +1,74 @@ +#pragma once + +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/remote_client_connection.h> +#include <library/cpp/messagebus/remote_client_session.h> +#include <library/cpp/messagebus/remote_server_connection.h> +#include <library/cpp/messagebus/remote_server_session.h> +#include <library/cpp/messagebus/ybus.h> +#include <library/cpp/messagebus/oldmodule/module.h> +#include <library/cpp/messagebus/scheduler/scheduler.h> + +#include <util/generic/object_counter.h> +#include <util/system/type_name.h> +#include <util/stream/output.h> + +#include <typeinfo> + +struct TObjectCountCheck { + bool Enabled; + + template <typename T> + struct TReset { + TObjectCountCheck* const Thiz; + + TReset(TObjectCountCheck* thiz) + : Thiz(thiz) + { + } + + void operator()() { + long oldValue = TObjectCounter<T>::ResetObjectCount(); + if (oldValue != 0) { + Cerr << "warning: previous counter: " << oldValue << " for " << TypeName<T>() << Endl; + Cerr << "won't check in this test" << Endl; + Thiz->Enabled = false; + } + } + }; + + TObjectCountCheck() { + Enabled = true; + DoForAllCounters<TReset>(); + } + + template <typename T> + struct TCheckZero { + TCheckZero(TObjectCountCheck*) { + } + + void operator()() { + UNIT_ASSERT_VALUES_EQUAL_C(0L, TObjectCounter<T>::ObjectCount(), TypeName<T>()); + } + }; + + ~TObjectCountCheck() { + if (Enabled) { + DoForAllCounters<TCheckZero>(); + } + } + + template <template <typename> class TOp> + void DoForAllCounters() { + TOp< ::NBus::NPrivate::TRemoteClientConnection>(this)(); + TOp< ::NBus::NPrivate::TRemoteServerConnection>(this)(); + TOp< ::NBus::NPrivate::TRemoteClientSession>(this)(); + TOp< ::NBus::NPrivate::TRemoteServerSession>(this)(); + TOp< ::NBus::NPrivate::TScheduler>(this)(); + TOp< ::NEventLoop::TEventLoop>(this)(); + TOp< ::NEventLoop::TChannel>(this)(); + TOp< ::NBus::TBusModule>(this)(); + TOp< ::NBus::TBusJob>(this)(); + } +}; diff --git a/library/cpp/messagebus/test/helper/wait_for.h b/library/cpp/messagebus/test/helper/wait_for.h new file mode 100644 index 0000000000..f09958d4c0 --- /dev/null +++ b/library/cpp/messagebus/test/helper/wait_for.h @@ -0,0 +1,14 @@ +#pragma once + +#include <util/datetime/base.h> +#include <util/system/yassert.h> + +#define UNIT_WAIT_FOR(condition) \ + do { \ + TInstant start(TInstant::Now()); \ + while (!(condition) && (TInstant::Now() - start < TDuration::Seconds(10))) { \ + Sleep(TDuration::MilliSeconds(1)); \ + } \ + /* TODO: use UNIT_ASSERT if in unittest thread */ \ + Y_VERIFY(condition, "condition failed after 10 seconds wait"); \ + } while (0) diff --git a/library/cpp/messagebus/test/helper/ya.make b/library/cpp/messagebus/test/helper/ya.make new file mode 100644 index 0000000000..97bd45f573 --- /dev/null +++ b/library/cpp/messagebus/test/helper/ya.make @@ -0,0 +1,17 @@ +LIBRARY(messagebus_test_helper) + +OWNER(g:messagebus) + +SRCS( + example.cpp + example_module.cpp + fixed_port.cpp + message_handler_error.cpp + hanging_server.cpp +) + +PEERDIR( + library/cpp/messagebus/oldmodule +) + +END() diff --git a/library/cpp/messagebus/test/perftest/messages.proto b/library/cpp/messagebus/test/perftest/messages.proto new file mode 100644 index 0000000000..8919034e7a --- /dev/null +++ b/library/cpp/messagebus/test/perftest/messages.proto @@ -0,0 +1,7 @@ +message TPerftestRequestRecord { + required string Data = 1; +} + +message TPerftestResponseRecord { + required string Data = 1; +} diff --git a/library/cpp/messagebus/test/perftest/perftest.cpp b/library/cpp/messagebus/test/perftest/perftest.cpp new file mode 100644 index 0000000000..8489319278 --- /dev/null +++ b/library/cpp/messagebus/test/perftest/perftest.cpp @@ -0,0 +1,713 @@ +#include "simple_proto.h" + +#include <library/cpp/messagebus/test/perftest/messages.pb.h> + +#include <library/cpp/messagebus/text_utils.h> +#include <library/cpp/messagebus/thread_extra.h> +#include <library/cpp/messagebus/ybus.h> +#include <library/cpp/messagebus/oldmodule/module.h> +#include <library/cpp/messagebus/protobuf/ybusbuf.h> +#include <library/cpp/messagebus/www/www.h> + +#include <library/cpp/deprecated/threadable/threadable.h> +#include <library/cpp/execprofile/profile.h> +#include <library/cpp/getopt/opt.h> +#include <library/cpp/lwtrace/start.h> +#include <library/cpp/sighandler/async_signals_handler.h> +#include <library/cpp/threading/future/legacy_future.h> + +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/random/random.h> +#include <util/stream/file.h> +#include <util/stream/output.h> +#include <util/stream/str.h> +#include <util/string/split.h> +#include <util/system/event.h> +#include <util/system/sysstat.h> +#include <util/system/thread.h> +#include <util/thread/lfqueue.h> + +#include <signal.h> +#include <stdlib.h> + +using namespace NBus; + +/////////////////////////////////////////////////////// +/// \brief Configuration parameters of the test + +const int DEFAULT_PORT = 55666; + +struct TPerftestConfig { + TString Nodes; ///< node1:port1,node2:port2 + int ClientCount; + int MessageSize; ///< size of message to send + int Delay; ///< server delay (milliseconds) + float Failure; ///< simulated failure rate + int ServerPort; + int Run; + bool ServerUseModules; + bool ExecuteOnMessageInWorkerPool; + bool ExecuteOnReplyInWorkerPool; + bool UseCompression; + bool Profile; + unsigned WwwPort; + + TPerftestConfig(); + + void Print() { + fprintf(stderr, "ClientCount=%d\n", ClientCount); + fprintf(stderr, "ServerPort=%d\n", ServerPort); + fprintf(stderr, "Delay=%d usecs\n", Delay); + fprintf(stderr, "MessageSize=%d bytes\n", MessageSize); + fprintf(stderr, "Failure=%.3f%%\n", Failure * 100.0); + fprintf(stderr, "Runtime=%d seconds\n", Run); + fprintf(stderr, "ServerUseModules=%s\n", ServerUseModules ? "true" : "false"); + fprintf(stderr, "ExecuteOnMessageInWorkerPool=%s\n", ExecuteOnMessageInWorkerPool ? "true" : "false"); + fprintf(stderr, "ExecuteOnReplyInWorkerPool=%s\n", ExecuteOnReplyInWorkerPool ? "true" : "false"); + fprintf(stderr, "UseCompression=%s\n", UseCompression ? "true" : "false"); + fprintf(stderr, "Profile=%s\n", Profile ? "true" : "false"); + fprintf(stderr, "WwwPort=%u\n", WwwPort); + } +}; + +extern TPerftestConfig* TheConfig; +extern bool TheExit; + +TVector<TNetAddr> ServerAddresses; + +struct TConfig { + TBusQueueConfig ServerQueueConfig; + TBusQueueConfig ClientQueueConfig; + TBusServerSessionConfig ServerSessionConfig; + TBusClientSessionConfig ClientSessionConfig; + bool SimpleProtocol; + +private: + void ConfigureDefaults(TBusQueueConfig& config) { + config.NumWorkers = 4; + } + + void ConfigureDefaults(TBusSessionConfig& config) { + config.MaxInFlight = 10000; + config.SendTimeout = TDuration::Seconds(20).MilliSeconds(); + config.TotalTimeout = TDuration::Seconds(60).MilliSeconds(); + } + +public: + TConfig() + : SimpleProtocol(false) + { + ConfigureDefaults(ServerQueueConfig); + ConfigureDefaults(ClientQueueConfig); + ConfigureDefaults(ServerSessionConfig); + ConfigureDefaults(ClientSessionConfig); + } + + void Print() { + // TODO: do not print server if only client and vice verse + Cerr << "server queue config:\n"; + Cerr << IndentText(ServerQueueConfig.PrintToString()); + Cerr << "server session config:" << Endl; + Cerr << IndentText(ServerSessionConfig.PrintToString()); + Cerr << "client queue config:\n"; + Cerr << IndentText(ClientQueueConfig.PrintToString()); + Cerr << "client session config:" << Endl; + Cerr << IndentText(ClientSessionConfig.PrintToString()); + Cerr << "simple protocol: " << SimpleProtocol << "\n"; + } +}; + +TConfig Config; + +//////////////////////////////////////////////////////////////// +/// \brief Fast message + +using TPerftestRequest = TBusBufferMessage<TPerftestRequestRecord, 77>; +using TPerftestResponse = TBusBufferMessage<TPerftestResponseRecord, 79>; + +static size_t RequestSize() { + return RandomNumber<size_t>(TheConfig->MessageSize * 2 + 1); +} + +TAutoPtr<TBusMessage> NewRequest() { + if (Config.SimpleProtocol) { + TAutoPtr<TSimpleMessage> r(new TSimpleMessage); + r->SetCompressed(TheConfig->UseCompression); + r->Payload = 10; + return r.Release(); + } else { + TAutoPtr<TPerftestRequest> r(new TPerftestRequest); + r->SetCompressed(TheConfig->UseCompression); + // TODO: use random content for better compression test + r->Record.SetData(TString(RequestSize(), '?')); + return r.Release(); + } +} + +void CheckRequest(TPerftestRequest* request) { + const TString& data = request->Record.GetData(); + for (size_t i = 0; i != data.size(); ++i) { + Y_VERIFY(data.at(i) == '?', "must be question mark"); + } +} + +TAutoPtr<TPerftestResponse> NewResponse(TPerftestRequest* request) { + TAutoPtr<TPerftestResponse> r(new TPerftestResponse); + r->SetCompressed(TheConfig->UseCompression); + r->Record.SetData(TString(request->Record.GetData().size(), '.')); + return r; +} + +void CheckResponse(TPerftestResponse* response) { + const TString& data = response->Record.GetData(); + for (size_t i = 0; i != data.size(); ++i) { + Y_VERIFY(data.at(i) == '.', "must be dot"); + } +} + +//////////////////////////////////////////////////////////////////// +/// \brief Fast protocol that common between client and server +class TPerftestProtocol: public TBusBufferProtocol { +public: + TPerftestProtocol() + : TBusBufferProtocol("TPerftestProtocol", TheConfig->ServerPort) + { + RegisterType(new TPerftestRequest); + RegisterType(new TPerftestResponse); + } +}; + +class TPerftestServer; +class TPerftestUsingModule; +class TPerftestClient; + +struct TTestStats { + TInstant Start; + + TAtomic Messages; + TAtomic Errors; + TAtomic Replies; + + void IncMessage() { + AtomicIncrement(Messages); + } + void IncReplies() { + AtomicDecrement(Messages); + AtomicIncrement(Replies); + } + int NumMessage() { + return AtomicGet(Messages); + } + void IncErrors() { + AtomicDecrement(Messages); + AtomicIncrement(Errors); + } + int NumErrors() { + return AtomicGet(Errors); + } + int NumReplies() { + return AtomicGet(Replies); + } + + double GetThroughput() { + return NumReplies() * 1000000.0 / (TInstant::Now() - Start).MicroSeconds(); + } + +public: + TTestStats() + : Start(TInstant::Now()) + , Messages(0) + , Errors(0) + , Replies(0) + { + } + + void PeriodicallyPrint(); +}; + +TTestStats Stats; + +//////////////////////////////////////////////////////////////////// +/// \brief Fast of the client session +class TPerftestClient : IBusClientHandler { +public: + TBusClientSessionPtr Session; + THolder<TBusProtocol> Proto; + TBusMessageQueuePtr Bus; + TVector<TBusClientConnectionPtr> Connections; + +public: + /// constructor creates instances of protocol and session + TPerftestClient() { + /// create or get instance of message queue, need one per application + Bus = CreateMessageQueue(Config.ClientQueueConfig, "client"); + + if (Config.SimpleProtocol) { + Proto.Reset(new TSimpleProtocol); + } else { + Proto.Reset(new TPerftestProtocol); + } + + Session = TBusClientSession::Create(Proto.Get(), this, Config.ClientSessionConfig, Bus); + + for (unsigned i = 0; i < ServerAddresses.size(); ++i) { + Connections.push_back(Session->GetConnection(ServerAddresses[i])); + } + } + + /// dispatch of requests is done here + void Work() { + SetCurrentThreadName("FastClient::Work"); + + while (!TheExit) { + TBusClientConnection* connection; + if (Connections.size() == 1) { + connection = Connections.front().Get(); + } else { + connection = Connections.at(RandomNumber<size_t>()).Get(); + } + + TBusMessage* message = NewRequest().Release(); + int ret = connection->SendMessage(message, true); + + if (ret == MESSAGE_OK) { + Stats.IncMessage(); + } else if (ret == MESSAGE_BUSY) { + //delete message; + //Sleep(TDuration::MilliSeconds(1)); + //continue; + Y_FAIL("unreachable"); + } else if (ret == MESSAGE_SHUTDOWN) { + delete message; + } else { + delete message; + Stats.IncErrors(); + } + } + } + + void Stop() { + Session->Shutdown(); + } + + /// actual work is being done here + void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override { + Y_UNUSED(mess); + + if (Config.SimpleProtocol) { + VerifyDynamicCast<TSimpleMessage*>(reply.Get()); + } else { + TPerftestResponse* typed = VerifyDynamicCast<TPerftestResponse*>(reply.Get()); + + CheckResponse(typed); + } + + Stats.IncReplies(); + } + + /// message that could not be delivered + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override { + Y_UNUSED(mess); + Y_UNUSED(status); + + if (TheExit) { + return; + } + + Stats.IncErrors(); + + // Y_ASSERT(TheConfig->Failure > 0.0); + } +}; + +class TPerftestServerCommon { +public: + THolder<TBusProtocol> Proto; + + TBusMessageQueuePtr Bus; + + TBusServerSessionPtr Session; + +protected: + TPerftestServerCommon(const char* name) + : Session() + { + if (Config.SimpleProtocol) { + Proto.Reset(new TSimpleProtocol); + } else { + Proto.Reset(new TPerftestProtocol); + } + + /// create or get instance of single message queue, need one for application + Bus = CreateMessageQueue(Config.ServerQueueConfig, name); + } + +public: + void Stop() { + Session->Shutdown(); + } +}; + +struct TAsyncRequest { + TBusMessage* Request; + TInstant ReceivedTime; +}; + +///////////////////////////////////////////////////////////////////// +/// \brief Fast of the server session +class TPerftestServer: public TPerftestServerCommon, public IBusServerHandler { +public: + TLockFreeQueue<TAsyncRequest> AsyncRequests; + +public: + TPerftestServer() + : TPerftestServerCommon("server") + { + /// register destination session + Session = TBusServerSession::Create(Proto.Get(), this, Config.ServerSessionConfig, Bus); + Y_ASSERT(Session && "probably somebody is listening on the same port"); + } + + /// when message comes, send reply + void OnMessage(TOnMessageContext& mess) override { + if (Config.SimpleProtocol) { + TSimpleMessage* typed = VerifyDynamicCast<TSimpleMessage*>(mess.GetMessage()); + TAutoPtr<TSimpleMessage> response(new TSimpleMessage); + response->Payload = typed->Payload; + mess.SendReplyMove(response); + return; + } + + TPerftestRequest* typed = VerifyDynamicCast<TPerftestRequest*>(mess.GetMessage()); + + CheckRequest(typed); + + /// forget replies for few messages, see what happends + if (TheConfig->Failure > RandomNumber<double>()) { + return; + } + + /// sleep requested time + if (TheConfig->Delay) { + TAsyncRequest request; + request.Request = mess.ReleaseMessage(); + request.ReceivedTime = TInstant::Now(); + AsyncRequests.Enqueue(request); + return; + } + + TAutoPtr<TPerftestResponse> reply(NewResponse(typed)); + /// sent empty reply for each message + mess.SendReplyMove(reply); + // TODO: count results + } + + void Stop() { + TPerftestServerCommon::Stop(); + } +}; + +class TPerftestUsingModule: public TPerftestServerCommon, public TBusModule { +public: + TPerftestUsingModule() + : TPerftestServerCommon("server") + , TBusModule("fast") + { + Y_VERIFY(CreatePrivateSessions(Bus.Get()), "failed to initialize dupdetect module"); + Y_VERIFY(StartInput(), "failed to start input"); + } + + ~TPerftestUsingModule() override { + Shutdown(); + } + +private: + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + TPerftestRequest* typed = VerifyDynamicCast<TPerftestRequest*>(mess); + CheckRequest(typed); + + /// sleep requested time + if (TheConfig->Delay) { + usleep(TheConfig->Delay); + } + + /// forget replies for few messages, see what happends + if (TheConfig->Failure > RandomNumber<double>()) { + return nullptr; + } + + job->SendReply(NewResponse(typed).Release()); + return nullptr; + } + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override { + return Session = CreateDefaultDestination(queue, Proto.Get(), Config.ServerSessionConfig); + } +}; + +// ./perftest/perftest -s 11456 -c localhost:11456 -r 60 -n 4 -i 5000 + +using namespace std; +using namespace NBus; + +static TNetworkAddress ParseNetworkAddress(const char* string) { + TString Name; + int Port; + + const char* port = strchr(string, ':'); + + if (port != nullptr) { + Name.append(string, port - string); + Port = atoi(port + 1); + } else { + Name.append(string); + Port = TheConfig->ServerPort != 0 ? TheConfig->ServerPort : DEFAULT_PORT; + } + + return TNetworkAddress(Name, Port); +} + +TVector<TNetAddr> ParseNodes(const TString nodes) { + TVector<TNetAddr> r; + + TVector<TString> hosts; + + size_t numh = Split(nodes.data(), ",", hosts); + + for (int i = 0; i < int(numh); i++) { + const TNetworkAddress& networkAddress = ParseNetworkAddress(hosts[i].data()); + Y_VERIFY(networkAddress.Begin() != networkAddress.End(), "no addresses"); + r.push_back(TNetAddr(networkAddress, &*networkAddress.Begin())); + } + + return r; +} + +TPerftestConfig::TPerftestConfig() { + TBusSessionConfig defaultConfig; + + ServerPort = DEFAULT_PORT; + Delay = 0; // artificial delay inside server OnMessage() + MessageSize = 200; + Failure = 0.00; + Run = 60; // in seconds + Nodes = "localhost"; + ServerUseModules = false; + ExecuteOnMessageInWorkerPool = defaultConfig.ExecuteOnMessageInWorkerPool; + ExecuteOnReplyInWorkerPool = defaultConfig.ExecuteOnReplyInWorkerPool; + UseCompression = false; + Profile = false; + WwwPort = 0; +} + +TPerftestConfig* TheConfig = new TPerftestConfig(); +bool TheExit = false; + +TSystemEvent StopEvent; + +TSimpleSharedPtr<TPerftestServer> Server; +TSimpleSharedPtr<TPerftestUsingModule> ServerUsingModule; + +TVector<TSimpleSharedPtr<TPerftestClient>> Clients; +TMutex ClientsLock; + +void stopsignal(int /*sig*/) { + fprintf(stderr, "\n-------------------- exiting ------------------\n"); + TheExit = true; + StopEvent.Signal(); +} + +// -s <num> - start server on port <num> +// -c <node:port,node:port> - start client + +void TTestStats::PeriodicallyPrint() { + SetCurrentThreadName("print-stats"); + + for (;;) { + StopEvent.WaitT(TDuration::Seconds(1)); + if (TheExit) + break; + + TVector<TSimpleSharedPtr<TPerftestClient>> clients; + { + TGuard<TMutex> guard(ClientsLock); + clients = Clients; + } + + fprintf(stderr, "replies=%d errors=%d throughput=%.3f mess/sec\n", + NumReplies(), NumErrors(), GetThroughput()); + if (!!Server) { + fprintf(stderr, "server: q: %u %s\n", + (unsigned)Server->Bus->GetExecutor()->GetWorkQueueSize(), + Server->Session->GetStatusSingleLine().data()); + } + if (!!ServerUsingModule) { + fprintf(stderr, "server: q: %u %s\n", + (unsigned)ServerUsingModule->Bus->GetExecutor()->GetWorkQueueSize(), + ServerUsingModule->Session->GetStatusSingleLine().data()); + } + for (const auto& client : clients) { + fprintf(stderr, "client: q: %u %s\n", + (unsigned)client->Bus->GetExecutor()->GetWorkQueueSize(), + client->Session->GetStatusSingleLine().data()); + } + + TStringStream stats; + + bool first = true; + if (!!Server) { + if (!first) { + stats << "\n"; + } + first = false; + stats << "server:\n"; + stats << IndentText(Server->Bus->GetStatus()); + } + if (!!ServerUsingModule) { + if (!first) { + stats << "\n"; + } + first = false; + stats << "server using modules:\n"; + stats << IndentText(ServerUsingModule->Bus->GetStatus()); + } + for (const auto& client : clients) { + if (!first) { + stats << "\n"; + } + first = false; + stats << "client:\n"; + stats << IndentText(client->Bus->GetStatus()); + } + + TUnbufferedFileOutput("stats").Write(stats.Str()); + } +} + +int main(int argc, char* argv[]) { + NLWTrace::StartLwtraceFromEnv(); + + /* unix foo */ + setvbuf(stdout, nullptr, _IONBF, 0); + setvbuf(stderr, nullptr, _IONBF, 0); + Umask(0); + SetAsyncSignalHandler(SIGINT, stopsignal); + SetAsyncSignalHandler(SIGTERM, stopsignal); +#ifndef _win_ + SetAsyncSignalHandler(SIGUSR1, stopsignal); +#endif + signal(SIGPIPE, SIG_IGN); + + NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default(); + opts.AddLongOption('s', "server-port", "server port").RequiredArgument("port").StoreResult(&TheConfig->ServerPort); + opts.AddCharOption('m', "average message size").RequiredArgument("size").StoreResult(&TheConfig->MessageSize); + opts.AddLongOption('c', "server-host", "server hosts").RequiredArgument("host[,host]...").StoreResult(&TheConfig->Nodes); + opts.AddCharOption('f', "failure rate (rational number between 0 and 1)").RequiredArgument("rate").StoreResult(&TheConfig->Failure); + opts.AddCharOption('w', "delay before reply").RequiredArgument("microseconds").StoreResult(&TheConfig->Delay); + opts.AddCharOption('r', "run duration").RequiredArgument("seconds").StoreResult(&TheConfig->Run); + opts.AddLongOption("client-count", "amount of clients").RequiredArgument("count").StoreResult(&TheConfig->ClientCount).DefaultValue("1"); + opts.AddLongOption("server-use-modules").StoreResult(&TheConfig->ServerUseModules, true); + opts.AddLongOption("on-message-in-pool", "execute OnMessage callback in worker pool") + .RequiredArgument("BOOL") + .StoreResult(&TheConfig->ExecuteOnMessageInWorkerPool); + opts.AddLongOption("on-reply-in-pool", "execute OnReply callback in worker pool") + .RequiredArgument("BOOL") + .StoreResult(&TheConfig->ExecuteOnReplyInWorkerPool); + opts.AddLongOption("compression", "use compression").RequiredArgument("BOOL").StoreResult(&TheConfig->UseCompression); + opts.AddLongOption("simple-proto").SetFlag(&Config.SimpleProtocol); + opts.AddLongOption("profile").SetFlag(&TheConfig->Profile); + opts.AddLongOption("www-port").RequiredArgument("PORT").StoreResult(&TheConfig->WwwPort); + opts.AddHelpOption(); + + Config.ServerQueueConfig.ConfigureLastGetopt(opts, "server-"); + Config.ServerSessionConfig.ConfigureLastGetopt(opts, "server-"); + Config.ClientQueueConfig.ConfigureLastGetopt(opts, "client-"); + Config.ClientSessionConfig.ConfigureLastGetopt(opts, "client-"); + + opts.SetFreeArgsMax(0); + + NLastGetopt::TOptsParseResult parseResult(&opts, argc, argv); + + TheConfig->Print(); + Config.Print(); + + if (TheConfig->Profile) { + BeginProfiling(); + } + + TIntrusivePtr<TBusWww> www(new TBusWww); + + ServerAddresses = ParseNodes(TheConfig->Nodes); + + if (TheConfig->ServerPort) { + if (TheConfig->ServerUseModules) { + ServerUsingModule = new TPerftestUsingModule(); + www->RegisterModule(ServerUsingModule.Get()); + } else { + Server = new TPerftestServer(); + www->RegisterServerSession(Server->Session); + } + } + + TVector<TSimpleSharedPtr<NThreading::TLegacyFuture<void, false>>> futures; + + if (ServerAddresses.size() > 0 && TheConfig->ClientCount > 0) { + for (int i = 0; i < TheConfig->ClientCount; ++i) { + TGuard<TMutex> guard(ClientsLock); + Clients.push_back(new TPerftestClient); + futures.push_back(new NThreading::TLegacyFuture<void, false>(std::bind(&TPerftestClient::Work, Clients.back()))); + www->RegisterClientSession(Clients.back()->Session); + } + } + + futures.push_back(new NThreading::TLegacyFuture<void, false>(std::bind(&TTestStats::PeriodicallyPrint, std::ref(Stats)))); + + THolder<TBusWwwHttpServer> wwwServer; + if (TheConfig->WwwPort != 0) { + wwwServer.Reset(new TBusWwwHttpServer(www, TheConfig->WwwPort)); + } + + /* sit here until signal terminate our process */ + StopEvent.WaitT(TDuration::Seconds(TheConfig->Run)); + TheExit = true; + StopEvent.Signal(); + + if (!!Server) { + Cerr << "Stopping server\n"; + Server->Stop(); + } + if (!!ServerUsingModule) { + Cerr << "Stopping server (using modules)\n"; + ServerUsingModule->Stop(); + } + + TVector<TSimpleSharedPtr<TPerftestClient>> clients; + { + TGuard<TMutex> guard(ClientsLock); + clients = Clients; + } + + if (!clients.empty()) { + Cerr << "Stopping clients\n"; + + for (auto& client : clients) { + client->Stop(); + } + } + + wwwServer.Destroy(); + + for (const auto& future : futures) { + future->Get(); + } + + if (TheConfig->Profile) { + EndProfiling(); + } + + Cerr << "***SUCCESS***\n"; + return 0; +} diff --git a/library/cpp/messagebus/test/perftest/simple_proto.cpp b/library/cpp/messagebus/test/perftest/simple_proto.cpp new file mode 100644 index 0000000000..19d6c15b9d --- /dev/null +++ b/library/cpp/messagebus/test/perftest/simple_proto.cpp @@ -0,0 +1,22 @@ +#include "simple_proto.h" + +#include <util/generic/cast.h> + +#include <typeinfo> + +using namespace NBus; + +void TSimpleProtocol::Serialize(const TBusMessage* mess, TBuffer& data) { + Y_VERIFY(typeid(TSimpleMessage) == typeid(*mess)); + const TSimpleMessage* typed = static_cast<const TSimpleMessage*>(mess); + data.Append((const char*)&typed->Payload, 4); +} + +TAutoPtr<TBusMessage> TSimpleProtocol::Deserialize(ui16, TArrayRef<const char> payload) { + if (payload.size() != 4) { + return nullptr; + } + TAutoPtr<TSimpleMessage> r(new TSimpleMessage); + memcpy(&r->Payload, payload.data(), 4); + return r.Release(); +} diff --git a/library/cpp/messagebus/test/perftest/simple_proto.h b/library/cpp/messagebus/test/perftest/simple_proto.h new file mode 100644 index 0000000000..4a0cc08db3 --- /dev/null +++ b/library/cpp/messagebus/test/perftest/simple_proto.h @@ -0,0 +1,29 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> + +struct TSimpleMessage: public NBus::TBusMessage { + ui32 Payload; + + TSimpleMessage() + : TBusMessage(1) + , Payload(0) + { + } + + TSimpleMessage(NBus::ECreateUninitialized) + : TBusMessage(NBus::ECreateUninitialized()) + { + } +}; + +struct TSimpleProtocol: public NBus::TBusProtocol { + TSimpleProtocol() + : NBus::TBusProtocol("simple", 55666) + { + } + + void Serialize(const NBus::TBusMessage* mess, TBuffer& data) override; + + TAutoPtr<NBus::TBusMessage> Deserialize(ui16 ty, TArrayRef<const char> payload) override; +}; diff --git a/library/cpp/messagebus/test/perftest/stackcollect.diff b/library/cpp/messagebus/test/perftest/stackcollect.diff new file mode 100644 index 0000000000..658f0141b3 --- /dev/null +++ b/library/cpp/messagebus/test/perftest/stackcollect.diff @@ -0,0 +1,13 @@ +Index: test/perftest/CMakeLists.txt +=================================================================== +--- test/perftest/CMakeLists.txt (revision 1088840) ++++ test/perftest/CMakeLists.txt (working copy) +@@ -3,7 +3,7 @@ PROGRAM(messagebus_perftest) + OWNER(nga) + + PEERDIR( +- library/cpp/execprofile ++ junk/davenger/stackcollect + library/cpp/messagebus + library/cpp/messagebus/protobuf + library/cpp/sighandler diff --git a/library/cpp/messagebus/test/perftest/ya.make b/library/cpp/messagebus/test/perftest/ya.make new file mode 100644 index 0000000000..24c2848ed5 --- /dev/null +++ b/library/cpp/messagebus/test/perftest/ya.make @@ -0,0 +1,24 @@ +PROGRAM(messagebus_perftest) + +OWNER(g:messagebus) + +PEERDIR( + library/cpp/deprecated/threadable + library/cpp/execprofile + library/cpp/getopt + library/cpp/lwtrace + library/cpp/messagebus + library/cpp/messagebus/oldmodule + library/cpp/messagebus/protobuf + library/cpp/messagebus/www + library/cpp/sighandler + library/cpp/threading/future +) + +SRCS( + messages.proto + perftest.cpp + simple_proto.cpp +) + +END() diff --git a/library/cpp/messagebus/test/ut/count_down_latch.h b/library/cpp/messagebus/test/ut/count_down_latch.h new file mode 100644 index 0000000000..5117db5731 --- /dev/null +++ b/library/cpp/messagebus/test/ut/count_down_latch.h @@ -0,0 +1,30 @@ +#pragma once + +#include <util/system/atomic.h> +#include <util/system/event.h> + +class TCountDownLatch { +private: + TAtomic Current; + TSystemEvent EventObject; + +public: + TCountDownLatch(unsigned initial) + : Current(initial) + { + } + + void CountDown() { + if (AtomicDecrement(Current) == 0) { + EventObject.Signal(); + } + } + + void Await() { + EventObject.Wait(); + } + + bool Await(TDuration timeout) { + return EventObject.WaitT(timeout); + } +}; diff --git a/library/cpp/messagebus/test/ut/locator_uniq_ut.cpp b/library/cpp/messagebus/test/ut/locator_uniq_ut.cpp new file mode 100644 index 0000000000..3fdd175d73 --- /dev/null +++ b/library/cpp/messagebus/test/ut/locator_uniq_ut.cpp @@ -0,0 +1,40 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/test_utils.h> +#include <library/cpp/messagebus/ybus.h> + +class TLocatorRegisterUniqTest: public TTestBase { + UNIT_TEST_SUITE(TLocatorRegisterUniqTest); + UNIT_TEST(TestRegister); + UNIT_TEST_SUITE_END(); + +protected: + void TestRegister(); +}; + +UNIT_TEST_SUITE_REGISTRATION(TLocatorRegisterUniqTest); + +void TLocatorRegisterUniqTest::TestRegister() { + ASSUME_IP_V4_ENABLED; + + NBus::TBusLocator locator; + const char* serviceName = "TestService"; + const char* hostName = "192.168.0.42"; + int port = 31337; + + NBus::TBusKeyVec keys; + locator.LocateKeys(serviceName, keys); + UNIT_ASSERT(keys.size() == 0); + + locator.Register(serviceName, hostName, port); + locator.LocateKeys(serviceName, keys); + /// YBUS_KEYMIN YBUS_KEYMAX range + UNIT_ASSERT(keys.size() == 1); + + TVector<NBus::TNetAddr> hosts; + UNIT_ASSERT(locator.LocateAll(serviceName, NBus::YBUS_KEYMIN, hosts) == 1); + + locator.Register(serviceName, hostName, port); + hosts.clear(); + UNIT_ASSERT(locator.LocateAll(serviceName, NBus::YBUS_KEYMIN, hosts) == 1); +} diff --git a/library/cpp/messagebus/test/ut/messagebus_ut.cpp b/library/cpp/messagebus/test/ut/messagebus_ut.cpp new file mode 100644 index 0000000000..040f9b7702 --- /dev/null +++ b/library/cpp/messagebus/test/ut/messagebus_ut.cpp @@ -0,0 +1,1151 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/fixed_port.h> +#include <library/cpp/messagebus/test/helper/hanging_server.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> +#include <library/cpp/messagebus/test/helper/wait_for.h> + +#include <library/cpp/messagebus/misc/test_sync.h> + +#include <util/network/sock.h> + +#include <utility> + +using namespace NBus; +using namespace NBus::NTest; + +namespace { + struct TExampleClientSlowOnMessageSent: public TExampleClient { + TAtomic SentCompleted; + + TSystemEvent ReplyReceived; + + TExampleClientSlowOnMessageSent() + : SentCompleted(0) + { + } + + ~TExampleClientSlowOnMessageSent() override { + Session->Shutdown(); + } + + void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override { + Y_VERIFY(AtomicGet(SentCompleted), "must be completed"); + + TExampleClient::OnReply(mess, reply); + + ReplyReceived.Signal(); + } + + void OnMessageSent(TBusMessage*) override { + Sleep(TDuration::MilliSeconds(100)); + AtomicSet(SentCompleted, 1); + } + }; + +} + +Y_UNIT_TEST_SUITE(TMessageBusTests) { + void TestDestinationTemplate(bool useCompression, bool ackMessageBeforeReply, + const TBusServerSessionConfig& sessionConfig) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + TExampleClient client(sessionConfig); + client.CrashOnError = true; + + server.UseCompression = useCompression; + client.UseCompression = useCompression; + + server.AckMessageBeforeSendReply = ackMessageBeforeReply; + + client.SendMessagesWaitReplies(100, server.GetActualListenAddr()); + UNIT_ASSERT_EQUAL(server.Session->GetInFlight(), 0); + UNIT_ASSERT_EQUAL(client.Session->GetInFlight(), 0); + } + + Y_UNIT_TEST(TestDestination) { + TestDestinationTemplate(false, false, TBusServerSessionConfig()); + } + + Y_UNIT_TEST(TestDestinationUsingAck) { + TestDestinationTemplate(false, true, TBusServerSessionConfig()); + } + + Y_UNIT_TEST(TestDestinationWithCompression) { + TestDestinationTemplate(true, false, TBusServerSessionConfig()); + } + + Y_UNIT_TEST(TestCork) { + TBusServerSessionConfig config; + config.SendThreshold = 1000000000000; + config.Cork = TDuration::MilliSeconds(10); + TestDestinationTemplate(false, false, config); + // TODO: test for cork hanging + } + + Y_UNIT_TEST(TestReconnect) { + if (!IsFixedPortTestAllowed()) { + return; + } + + TObjectCountCheck objectCountCheck; + + unsigned port = FixedPort; + TNetAddr serverAddr("localhost", port); + THolder<TExampleServer> server; + + TBusClientSessionConfig clientConfig; + clientConfig.RetryInterval = 0; + TExampleClient client(clientConfig); + + server.Reset(new TExampleServer(port, "TExampleServer 1")); + + client.SendMessagesWaitReplies(17, serverAddr); + + server.Destroy(); + + // Making the client to detect disconnection. + client.SendMessages(1, serverAddr); + EMessageStatus error = client.WaitForError(); + if (error == MESSAGE_DELIVERY_FAILED) { + client.SendMessages(1, serverAddr); + error = client.WaitForError(); + } + UNIT_ASSERT_VALUES_EQUAL(MESSAGE_CONNECT_FAILED, error); + + server.Reset(new TExampleServer(port, "TExampleServer 2")); + + client.SendMessagesWaitReplies(19, serverAddr); + } + + struct TestNoServerImplClient: public TExampleClient { + TTestSync TestSync; + int failures = 0; + + template <typename... Args> + TestNoServerImplClient(Args&&... args) + : TExampleClient(std::forward<Args>(args)...) + { + } + + ~TestNoServerImplClient() override { + Session->Shutdown(); + } + + void OnError(TAutoPtr<TBusMessage> message, EMessageStatus status) override { + Y_UNUSED(message); + + Y_VERIFY(status == MESSAGE_CONNECT_FAILED, "must be MESSAGE_CONNECT_FAILED, got %s", ToString(status).data()); + + TestSync.CheckAndIncrement((failures++) * 2); + } + }; + + void TestNoServerImpl(unsigned port, bool oneWay) { + TNetAddr noServerAddr("localhost", port); + + TestNoServerImplClient client; + + int count = 0; + for (; count < 200; ++count) { + EMessageStatus status; + if (oneWay) { + status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &noServerAddr); + } else { + TAutoPtr<TBusMessage> message(new TExampleRequest(&client.Proto.RequestCount)); + status = client.Session->SendMessageAutoPtr(message, &noServerAddr); + } + + Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data()); + + if (count == 0) { + // lame way to wait until it is connected + Sleep(TDuration::MilliSeconds(10)); + } + client.TestSync.WaitForAndIncrement(count * 2 + 1); + } + + client.TestSync.WaitForAndIncrement(count * 2); + } + + void HangingServerImpl(unsigned port) { + TNetAddr noServerAddr("localhost", port); + + TExampleClient client; + + int count = 0; + for (;; ++count) { + TAutoPtr<TBusMessage> message(new TExampleRequest(&client.Proto.RequestCount)); + EMessageStatus status = client.Session->SendMessageAutoPtr(message, &noServerAddr); + if (status == MESSAGE_BUSY) { + break; + } + UNIT_ASSERT_VALUES_EQUAL(int(MESSAGE_OK), int(status)); + + if (count == 0) { + // lame way to wait until it is connected + Sleep(TDuration::MilliSeconds(10)); + } + } + + UNIT_ASSERT_VALUES_EQUAL(client.Session->GetConfig()->MaxInFlight, count); + } + + Y_UNIT_TEST(TestHangindServer) { + TObjectCountCheck objectCountCheck; + + THangingServer server(0); + + HangingServerImpl(server.GetPort()); + } + + Y_UNIT_TEST(TestNoServer) { + TObjectCountCheck objectCountCheck; + + TestNoServerImpl(17, false); + } + + Y_UNIT_TEST(PauseInput) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + server.Session->PauseInput(true); + + TBusClientSessionConfig clientConfig; + clientConfig.MaxInFlight = 1000; + TExampleClient client(clientConfig); + + client.SendMessages(100, server.GetActualListenAddr()); + + server.TestSync.Check(0); + + server.Session->PauseInput(false); + + server.TestSync.WaitFor(100); + + client.WaitReplies(); + + server.Session->PauseInput(true); + + client.SendMessages(200, server.GetActualListenAddr()); + + server.TestSync.Check(100); + + server.Session->PauseInput(false); + + server.TestSync.WaitFor(300); + + client.WaitReplies(); + } + + struct TSendTimeoutCheckerExampleClient: public TExampleClient { + static TBusClientSessionConfig SessionConfig(bool periodLessThanConnectTimeout) { + TBusClientSessionConfig sessionConfig; + if (periodLessThanConnectTimeout) { + sessionConfig.SendTimeout = 1; + sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(50); + } else { + sessionConfig.SendTimeout = 50; + sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(1); + } + return sessionConfig; + } + + TSendTimeoutCheckerExampleClient(bool periodLessThanConnectTimeout) + : TExampleClient(SessionConfig(periodLessThanConnectTimeout)) + { + } + + ~TSendTimeoutCheckerExampleClient() override { + Session->Shutdown(); + } + + TSystemEvent ErrorHappened; + + void OnError(TAutoPtr<TBusMessage>, EMessageStatus status) override { + Y_VERIFY(status == MESSAGE_CONNECT_FAILED || status == MESSAGE_TIMEOUT, "got status: %s", ToString(status).data()); + ErrorHappened.Signal(); + } + }; + + void NoServer_SendTimeout_Callback_Impl(bool periodLessThanConnectTimeout) { + TObjectCountCheck objectCountCheck; + + TNetAddr serverAddr("localhost", 17); + + TSendTimeoutCheckerExampleClient client(periodLessThanConnectTimeout); + + client.SendMessages(1, serverAddr); + + client.ErrorHappened.WaitI(); + } + + Y_UNIT_TEST(NoServer_SendTimeout_Callback_PeriodLess) { + NoServer_SendTimeout_Callback_Impl(true); + } + + Y_UNIT_TEST(NoServer_SendTimeout_Callback_TimeoutLess) { + NoServer_SendTimeout_Callback_Impl(false); + } + + Y_UNIT_TEST(TestOnReplyCalledAfterOnMessageSent) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + TNetAddr serverAddr = server.GetActualListenAddr(); + TExampleClientSlowOnMessageSent client; + + TAutoPtr<TExampleRequest> message(new TExampleRequest(&client.Proto.RequestCount)); + EMessageStatus s = client.Session->SendMessageAutoPtr(message, &serverAddr); + UNIT_ASSERT_EQUAL(s, MESSAGE_OK); + + UNIT_ASSERT(client.ReplyReceived.WaitT(TDuration::Seconds(5))); + } + + struct TDelayReplyServer: public TBusServerHandlerError { + TBusMessageQueuePtr Bus; + TExampleProtocol Proto; + TSystemEvent MessageReceivedEvent; // 1 wait for 1 message + TBusServerSessionPtr Session; + TMutex Lock_; + TDeque<TAutoPtr<TOnMessageContext>> DelayedMessages; + + TDelayReplyServer() + : MessageReceivedEvent(TEventResetType::rAuto) + { + Bus = CreateMessageQueue("TDelayReplyServer"); + TBusServerSessionConfig sessionConfig; + sessionConfig.SendTimeout = 1000; + sessionConfig.TotalTimeout = 2001; + Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus); + if (!Session) { + ythrow yexception() << "Failed to create destination session"; + } + } + + void OnMessage(TOnMessageContext& mess) override { + Y_VERIFY(mess.IsConnectionAlive(), "connection should be alive here"); + TAutoPtr<TOnMessageContext> delayedMsg(new TOnMessageContext); + delayedMsg->Swap(mess); + auto g(Guard(Lock_)); + DelayedMessages.push_back(delayedMsg); + MessageReceivedEvent.Signal(); + } + + bool CheckClientIsAlive() { + auto g(Guard(Lock_)); + for (auto& delayedMessage : DelayedMessages) { + if (!delayedMessage->IsConnectionAlive()) { + return false; + } + } + return true; + } + + bool CheckClientIsDead() const { + auto g(Guard(Lock_)); + for (const auto& delayedMessage : DelayedMessages) { + if (delayedMessage->IsConnectionAlive()) { + return false; + } + } + return true; + } + + void ReplyToDelayedMessages() { + while (true) { + TOnMessageContext msg; + { + auto g(Guard(Lock_)); + if (DelayedMessages.empty()) { + break; + } + DelayedMessages.front()->Swap(msg); + DelayedMessages.pop_front(); + } + TAutoPtr<TBusMessage> reply(new TExampleResponse(&Proto.ResponseCount)); + msg.SendReplyMove(reply); + } + } + + size_t GetDelayedMessageCount() const { + auto g(Guard(Lock_)); + return DelayedMessages.size(); + } + + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override { + Y_UNUSED(mess); + Y_VERIFY(status == MESSAGE_SHUTDOWN, "only shutdown allowed, got %s", ToString(status).data()); + } + }; + + Y_UNIT_TEST(TestReplyCalledAfterClientDisconnected) { + TObjectCountCheck objectCountCheck; + + TDelayReplyServer server; + + THolder<TExampleClient> client(new TExampleClient); + + client->SendMessages(1, TNetAddr("localhost", server.Session->GetActualListenPort())); + + UNIT_ASSERT(server.MessageReceivedEvent.WaitT(TDuration::Seconds(5))); + + UNIT_ASSERT_VALUES_EQUAL(1, server.Session->GetInFlight()); + + client.Destroy(); + + UNIT_WAIT_FOR(server.CheckClientIsDead()); + + server.ReplyToDelayedMessages(); + + // wait until all server message are delivered + UNIT_WAIT_FOR(0 == server.Session->GetInFlight()); + } + + struct TPackUnpackServer: public TBusServerHandlerError { + TBusMessageQueuePtr Bus; + TExampleProtocol Proto; + TSystemEvent MessageReceivedEvent; + TSystemEvent ClientDiedEvent; + TBusServerSessionPtr Session; + + TPackUnpackServer() { + Bus = CreateMessageQueue("TPackUnpackServer"); + TBusServerSessionConfig sessionConfig; + Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus); + } + + void OnMessage(TOnMessageContext& mess) override { + TBusIdentity ident; + mess.AckMessage(ident); + + char packed[BUS_IDENTITY_PACKED_SIZE]; + ident.Pack(packed); + TBusIdentity resurrected; + resurrected.Unpack(packed); + + mess.GetSession()->SendReply(resurrected, new TExampleResponse(&Proto.ResponseCount)); + } + + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override { + Y_UNUSED(mess); + Y_VERIFY(status == MESSAGE_SHUTDOWN, "only shutdown allowed"); + } + }; + + Y_UNIT_TEST(PackUnpack) { + TObjectCountCheck objectCountCheck; + + TPackUnpackServer server; + + THolder<TExampleClient> client(new TExampleClient); + + client->SendMessagesWaitReplies(1, TNetAddr("localhost", server.Session->GetActualListenPort())); + } + + Y_UNIT_TEST(ClientRequestTooLarge) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + TBusClientSessionConfig clientConfig; + clientConfig.MaxMessageSize = 100; + TExampleClient client(clientConfig); + + client.DataSize = 10; + client.SendMessagesWaitReplies(1, server.GetActualListenAddr()); + + client.DataSize = 1000; + client.SendMessages(1, server.GetActualListenAddr()); + client.WaitForError(MESSAGE_MESSAGE_TOO_LARGE); + + client.DataSize = 20; + client.SendMessagesWaitReplies(10, server.GetActualListenAddr()); + + client.DataSize = 10000; + client.SendMessages(1, server.GetActualListenAddr()); + client.WaitForError(MESSAGE_MESSAGE_TOO_LARGE); + } + + struct TServerForResponseTooLarge: public TExampleServer { + TTestSync TestSync; + + static TBusServerSessionConfig Config() { + TBusServerSessionConfig config; + config.MaxMessageSize = 100; + return config; + } + + TServerForResponseTooLarge() + : TExampleServer("TServerForResponseTooLarge", Config()) + { + } + + ~TServerForResponseTooLarge() override { + Session->Shutdown(); + } + + void OnMessage(TOnMessageContext& mess) override { + TAutoPtr<TBusMessage> response; + + if (TestSync.Get() == 0) { + TestSync.CheckAndIncrement(0); + response.Reset(new TExampleResponse(&Proto.ResponseCount, 1000)); + } else { + TestSync.WaitForAndIncrement(3); + response.Reset(new TExampleResponse(&Proto.ResponseCount, 10)); + } + + mess.SendReplyMove(response); + } + + void OnError(TAutoPtr<TBusMessage>, EMessageStatus status) override { + TestSync.WaitForAndIncrement(1); + + Y_VERIFY(status == MESSAGE_MESSAGE_TOO_LARGE, "status"); + } + }; + + Y_UNIT_TEST(ServerResponseTooLarge) { + TObjectCountCheck objectCountCheck; + + TServerForResponseTooLarge server; + + TExampleClient client; + client.DataSize = 10; + + client.SendMessages(1, server.GetActualListenAddr()); + server.TestSync.WaitForAndIncrement(2); + client.ResetCounters(); + + client.SendMessages(1, server.GetActualListenAddr()); + + client.WorkDone.WaitI(); + + server.TestSync.CheckAndIncrement(4); + + UNIT_ASSERT_VALUES_EQUAL(1, client.Session->GetInFlight()); + } + + struct TServerForRequestTooLarge: public TExampleServer { + TTestSync TestSync; + + static TBusServerSessionConfig Config() { + TBusServerSessionConfig config; + config.MaxMessageSize = 100; + return config; + } + + TServerForRequestTooLarge() + : TExampleServer("TServerForRequestTooLarge", Config()) + { + } + + ~TServerForRequestTooLarge() override { + Session->Shutdown(); + } + + void OnMessage(TOnMessageContext& req) override { + unsigned n = TestSync.Get(); + if (n < 2) { + TestSync.CheckAndIncrement(n); + TAutoPtr<TExampleResponse> resp(new TExampleResponse(&Proto.ResponseCount, 10)); + req.SendReplyMove(resp); + } else { + Y_FAIL("wrong"); + } + } + }; + + Y_UNIT_TEST(ServerRequestTooLarge) { + TObjectCountCheck objectCountCheck; + + TServerForRequestTooLarge server; + + TExampleClient client; + client.DataSize = 10; + + client.SendMessagesWaitReplies(2, server.GetActualListenAddr()); + + server.TestSync.CheckAndIncrement(2); + + client.DataSize = 200; + client.SendMessages(1, server.GetActualListenAddr()); + // server closes connection, so MESSAGE_DELIVERY_FAILED is returned to client + client.WaitForError(MESSAGE_DELIVERY_FAILED); + } + + Y_UNIT_TEST(ClientResponseTooLarge) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + server.DataSize = 10; + + TBusClientSessionConfig clientSessionConfig; + clientSessionConfig.MaxMessageSize = 100; + TExampleClient client(clientSessionConfig); + client.DataSize = 10; + + client.SendMessagesWaitReplies(3, server.GetActualListenAddr()); + + server.DataSize = 1000; + + client.SendMessages(1, server.GetActualListenAddr()); + client.WaitForError(MESSAGE_DELIVERY_FAILED); + } + + Y_UNIT_TEST(ServerUnknownMessage) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + TNetAddr serverAddr = server.GetActualListenAddr(); + + TExampleClient client; + + client.SendMessagesWaitReplies(2, serverAddr); + + TAutoPtr<TBusMessage> req(new TExampleRequest(&client.Proto.RequestCount)); + req->GetHeader()->Type = 11; + client.Session->SendMessageAutoPtr(req, &serverAddr); + client.MessageCount = 1; + + client.WaitForError(MESSAGE_DELIVERY_FAILED); + } + + Y_UNIT_TEST(ServerMessageReservedIds) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + TNetAddr serverAddr = server.GetActualListenAddr(); + + TExampleClient client; + + client.SendMessagesWaitReplies(2, serverAddr); + + // This test doens't check 0, 1, YBUS_KEYINVALID because there are asserts() on sending side + + TAutoPtr<TBusMessage> req(new TExampleRequest(&client.Proto.RequestCount)); + req->GetHeader()->Id = 2; + client.Session->SendMessageAutoPtr(req, &serverAddr); + client.MessageCount = 1; + client.WaitForError(MESSAGE_DELIVERY_FAILED); + + req.Reset(new TExampleRequest(&client.Proto.RequestCount)); + req->GetHeader()->Id = YBUS_KEYLOCAL; + client.Session->SendMessageAutoPtr(req, &serverAddr); + client.MessageCount = 1; + client.WaitForError(MESSAGE_DELIVERY_FAILED); + } + + Y_UNIT_TEST(TestGetInFlightForDestination) { + TObjectCountCheck objectCountCheck; + + TDelayReplyServer server; + + TExampleClient client; + + TNetAddr addr("localhost", server.Session->GetActualListenPort()); + + UNIT_ASSERT_VALUES_EQUAL(size_t(0), client.Session->GetInFlight(addr)); + + client.SendMessages(2, &addr); + + for (size_t i = 0; i < 5; ++i) { + // One MessageReceivedEvent indicates one message, we need to wait for two + UNIT_ASSERT(server.MessageReceivedEvent.WaitT(TDuration::Seconds(5))); + if (server.GetDelayedMessageCount() == 2) { + break; + } + } + UNIT_ASSERT_VALUES_EQUAL(server.GetDelayedMessageCount(), 2); + + size_t inFlight = client.Session->GetInFlight(addr); + // 4 is for messagebus1 that adds inFlight counter twice for some reason + UNIT_ASSERT(inFlight == 2 || inFlight == 4); + + UNIT_ASSERT(server.CheckClientIsAlive()); + + server.ReplyToDelayedMessages(); + + client.WaitReplies(); + } + + struct TResetAfterSendOneWayErrorInCallbackClient: public TExampleClient { + TTestSync TestSync; + + static TBusClientSessionConfig SessionConfig() { + TBusClientSessionConfig config; + // 1 ms is not enough when test is running under valgrind + config.ConnectTimeout = 10; + config.SendTimeout = 10; + config.Secret.TimeoutPeriod = TDuration::MilliSeconds(1); + return config; + } + + TResetAfterSendOneWayErrorInCallbackClient() + : TExampleClient(SessionConfig()) + { + } + + ~TResetAfterSendOneWayErrorInCallbackClient() override { + Session->Shutdown(); + } + + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override { + TestSync.WaitForAndIncrement(0); + Y_VERIFY(status == MESSAGE_CONNECT_FAILED || status == MESSAGE_TIMEOUT, "must be connection failed, got %s", ToString(status).data()); + mess.Destroy(); + TestSync.CheckAndIncrement(1); + } + }; + + Y_UNIT_TEST(ResetAfterSendOneWayErrorInCallback) { + TObjectCountCheck objectCountCheck; + + TNetAddr noServerAddr("localhost", 17); + + TResetAfterSendOneWayErrorInCallbackClient client; + + EMessageStatus ok = client.Session->SendMessageOneWayMove(new TExampleRequest(&client.Proto.RequestCount), &noServerAddr); + UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok); + + client.TestSync.WaitForAndIncrement(2); + } + + struct TResetAfterSendMessageOneWayDuringShutdown: public TExampleClient { + TTestSync TestSync; + + ~TResetAfterSendMessageOneWayDuringShutdown() override { + Session->Shutdown(); + } + + void OnError(TAutoPtr<TBusMessage> message, EMessageStatus status) override { + TestSync.CheckAndIncrement(0); + + Y_VERIFY(status == MESSAGE_CONNECT_FAILED, "must be MESSAGE_CONNECT_FAILED, got %s", ToString(status).data()); + + // check reset is possible here + message->Reset(); + + // intentionally don't destroy the message + // we will try to resend it + Y_UNUSED(message.Release()); + + TestSync.CheckAndIncrement(1); + } + }; + + Y_UNIT_TEST(ResetAfterSendMessageOneWayDuringShutdown) { + TObjectCountCheck objectCountCheck; + + TNetAddr noServerAddr("localhost", 17); + + TResetAfterSendMessageOneWayDuringShutdown client; + + TExampleRequest* message = new TExampleRequest(&client.Proto.RequestCount); + EMessageStatus ok = client.Session->SendMessageOneWay(message, &noServerAddr); + UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok); + + client.TestSync.WaitForAndIncrement(2); + + client.Session->Shutdown(); + + ok = client.Session->SendMessageOneWay(message); + Y_VERIFY(ok == MESSAGE_SHUTDOWN, "must be shutdown when sending during shutdown, got %s", ToString(ok).data()); + + // check reset is possible here + message->Reset(); + client.TestSync.CheckAndIncrement(3); + + delete message; + } + + Y_UNIT_TEST(ResetAfterSendOneWayErrorInReturn) { + TObjectCountCheck objectCountCheck; + + TestNoServerImpl(17, true); + } + + struct TResetAfterSendOneWaySuccessClient: public TExampleClient { + TTestSync TestSync; + + ~TResetAfterSendOneWaySuccessClient() override { + Session->Shutdown(); + } + + void OnMessageSentOneWay(TAutoPtr<TBusMessage> sent) override { + TestSync.WaitForAndIncrement(0); + sent->Reset(); + TestSync.CheckAndIncrement(1); + } + }; + + Y_UNIT_TEST(ResetAfterSendOneWaySuccess) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + TNetAddr serverAddr = server.GetActualListenAddr(); + + TResetAfterSendOneWaySuccessClient client; + + EMessageStatus ok = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &serverAddr); + UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok); + // otherwize message might go to OnError(MESSAGE_SHUTDOWN) + server.WaitForOnMessageCount(1); + + client.TestSync.WaitForAndIncrement(2); + } + + Y_UNIT_TEST(GetStatus) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + TExampleClient client; + // make sure connected + client.SendMessagesWaitReplies(3, server.GetActualListenAddr()); + + server.Bus->GetStatus(); + server.Bus->GetStatus(); + server.Bus->GetStatus(); + + client.Bus->GetStatus(); + client.Bus->GetStatus(); + client.Bus->GetStatus(); + } + + Y_UNIT_TEST(BindOnRandomPort) { + TObjectCountCheck objectCountCheck; + + TBusServerSessionConfig serverConfig; + TExampleServer server; + + TExampleClient client; + TNetAddr addr(TNetAddr("127.0.0.1", server.Session->GetActualListenPort())); + client.SendMessagesWaitReplies(3, &addr); + } + + Y_UNIT_TEST(UnbindOnShutdown) { + TBusMessageQueuePtr queue(CreateMessageQueue()); + + TExampleProtocol proto; + TBusServerHandlerError handler; + TBusServerSessionPtr session = TBusServerSession::Create( + &proto, &handler, TBusServerSessionConfig(), queue); + + unsigned port = session->GetActualListenPort(); + UNIT_ASSERT(port > 0); + + session->Shutdown(); + + // fails is Shutdown() didn't unbind + THangingServer hangingServer(port); + } + + Y_UNIT_TEST(VersionNegotiation) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + TSockAddrInet addr(IpFromString("127.0.0.1"), server.Session->GetActualListenPort()); + + TInetStreamSocket socket; + int r1 = socket.Connect(&addr); + UNIT_ASSERT(r1 >= 0); + + TStreamSocketOutput output(&socket); + + TBusHeader request; + Zero(request); + request.Size = sizeof(request); + request.SetVersionInternal(0xF); // max + output.Write(&request, sizeof(request)); + + UNIT_ASSERT_VALUES_EQUAL(IsVersionNegotiation(request), true); + + TStreamSocketInput input(&socket); + + TBusHeader response; + size_t pos = 0; + + while (pos < sizeof(response)) { + size_t count = input.Read(((char*)&response) + pos, sizeof(response) - pos); + pos += count; + } + + UNIT_ASSERT_VALUES_EQUAL(sizeof(response), pos); + + UNIT_ASSERT_VALUES_EQUAL(YBUS_VERSION, response.GetVersionInternal()); + } + + struct TOnConnectionEventClient: public TExampleClient { + TTestSync Sync; + + ~TOnConnectionEventClient() override { + Session->Shutdown(); + } + + void OnClientConnectionEvent(const TClientConnectionEvent& event) override { + if (Sync.Get() > 2) { + // Test OnClientConnectionEvent_Disconnect is broken. + // Sometimes reconnect happens during server shutdown + // when acceptor connections is still alive, and + // server connection is already closed + return; + } + + if (event.GetType() == TClientConnectionEvent::CONNECTED) { + Sync.WaitForAndIncrement(0); + } else if (event.GetType() == TClientConnectionEvent::DISCONNECTED) { + Sync.WaitForAndIncrement(2); + } + } + + void OnError(TAutoPtr<TBusMessage>, EMessageStatus) override { + // We do not check for message errors in this test. + } + + void OnMessageSentOneWay(TAutoPtr<TBusMessage>) override { + } + }; + + struct TOnConnectionEventServer: public TExampleServer { + TOnConnectionEventServer() + : TExampleServer("TOnConnectionEventServer") + { + } + + ~TOnConnectionEventServer() override { + Session->Shutdown(); + } + + void OnError(TAutoPtr<TBusMessage>, EMessageStatus) override { + // We do not check for server message errors in this test. + } + }; + + Y_UNIT_TEST(OnClientConnectionEvent_Shutdown) { + TObjectCountCheck objectCountCheck; + + TOnConnectionEventServer server; + + TOnConnectionEventClient client; + + TNetAddr addr("127.0.0.1", server.Session->GetActualListenPort()); + + client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &addr); + + client.Sync.WaitForAndIncrement(1); + + client.Session->Shutdown(); + + client.Sync.WaitForAndIncrement(3); + } + + Y_UNIT_TEST(OnClientConnectionEvent_Disconnect) { + TObjectCountCheck objectCountCheck; + + THolder<TOnConnectionEventServer> server(new TOnConnectionEventServer); + + TOnConnectionEventClient client; + TNetAddr addr("127.0.0.1", server->Session->GetActualListenPort()); + + client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &addr); + + client.Sync.WaitForAndIncrement(1); + + server.Destroy(); + + client.Sync.WaitForAndIncrement(3); + } + + struct TServerForQuotaWake: public TExampleServer { + TSystemEvent GoOn; + TMutex OneLock; + + TOnMessageContext OneMessage; + + static TBusServerSessionConfig Config() { + TBusServerSessionConfig config; + + config.PerConnectionMaxInFlight = 1; + config.PerConnectionMaxInFlightBySize = 1500; + config.MaxMessageSize = 1024; + + return config; + } + + TServerForQuotaWake() + : TExampleServer("TServerForQuotaWake", Config()) + { + } + + ~TServerForQuotaWake() override { + Session->Shutdown(); + } + + void OnMessage(TOnMessageContext& req) override { + if (!GoOn.Wait(0)) { + TGuard<TMutex> guard(OneLock); + + UNIT_ASSERT(!OneMessage); + + OneMessage.Swap(req); + } else + TExampleServer::OnMessage(req); + } + + void WakeOne() { + TGuard<TMutex> guard(OneLock); + + UNIT_ASSERT(!!OneMessage); + + TExampleServer::OnMessage(OneMessage); + + TOnMessageContext().Swap(OneMessage); + } + }; + + Y_UNIT_TEST(WakeReaderOnQuota) { + const size_t test_msg_count = 64; + + TBusClientSessionConfig clientConfig; + + clientConfig.MaxInFlight = test_msg_count; + + TExampleClient client(clientConfig); + TServerForQuotaWake server; + TInstant start; + + client.MessageCount = test_msg_count; + + const NBus::TNetAddr addr = server.GetActualListenAddr(); + + for (unsigned count = 0;;) { + UNIT_ASSERT(count <= test_msg_count); + + TAutoPtr<TBusMessage> message(new TExampleRequest(&client.Proto.RequestCount)); + EMessageStatus status = client.Session->SendMessageAutoPtr(message, &addr); + + if (status == MESSAGE_OK) { + count++; + + } else if (status == MESSAGE_BUSY) { + if (count == test_msg_count) { + TInstant now = TInstant::Now(); + + if (start.GetValue() == 0) { + start = now; + + // TODO: properly check that server is blocked + } else if (start + TDuration::MilliSeconds(100) < now) { + break; + } + } + + Sleep(TDuration::MilliSeconds(10)); + + } else + UNIT_ASSERT(false); + } + + server.GoOn.Signal(); + server.WakeOne(); + + client.WaitReplies(); + + server.WaitForOnMessageCount(test_msg_count); + }; + + Y_UNIT_TEST(TestConnectionAttempts) { + TObjectCountCheck objectCountCheck; + + TNetAddr noServerAddr("localhost", 17); + TBusClientSessionConfig clientConfig; + clientConfig.RetryInterval = 100; + TestNoServerImplClient client(clientConfig); + + int count = 0; + for (; count < 10; ++count) { + EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), + &noServerAddr); + + Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data()); + client.TestSync.WaitForAndIncrement(count * 2 + 1); + + // First connection attempt is for connect call; second one is to get connect result. + UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2); + } + Sleep(TDuration::MilliSeconds(clientConfig.RetryInterval)); + for (; count < 10; ++count) { + EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), + &noServerAddr); + + Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data()); + client.TestSync.WaitForAndIncrement(count * 2 + 1); + + // First connection attempt is for connect call; second one is to get connect result. + UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 4); + } + }; + + Y_UNIT_TEST(TestConnectionAttemptsOnNoMessagesAndNotReconnectWhenIdle) { + TObjectCountCheck objectCountCheck; + + TNetAddr noServerAddr("localhost", 17); + TBusClientSessionConfig clientConfig; + clientConfig.RetryInterval = 100; + clientConfig.ReconnectWhenIdle = false; + TestNoServerImplClient client(clientConfig); + + int count = 0; + for (; count < 10; ++count) { + EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), + &noServerAddr); + + Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data()); + client.TestSync.WaitForAndIncrement(count * 2 + 1); + + // First connection attempt is for connect call; second one is to get connect result. + UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2); + } + + Sleep(TDuration::MilliSeconds(clientConfig.RetryInterval / 2)); + UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2); + Sleep(TDuration::MilliSeconds(10 * clientConfig.RetryInterval)); + UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2); + }; + + Y_UNIT_TEST(TestConnectionAttemptsOnNoMessagesAndReconnectWhenIdle) { + TObjectCountCheck objectCountCheck; + + TNetAddr noServerAddr("localhost", 17); + TBusClientSessionConfig clientConfig; + clientConfig.ReconnectWhenIdle = true; + clientConfig.RetryInterval = 100; + TestNoServerImplClient client(clientConfig); + + int count = 0; + for (; count < 10; ++count) { + EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), + &noServerAddr); + + Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data()); + client.TestSync.WaitForAndIncrement(count * 2 + 1); + + // First connection attempt is for connect call; second one is to get connect result. + UNIT_ASSERT_VALUES_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2); + } + + Sleep(TDuration::MilliSeconds(clientConfig.RetryInterval / 2)); + UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2); + Sleep(TDuration::MilliSeconds(10 * clientConfig.RetryInterval)); + // it is undeterministic how many reconnects will be during that amount of time + // but it should occur at least once + UNIT_ASSERT(client.Session->GetConnectSyscallsNumForTest(noServerAddr) > 2); + }; +}; diff --git a/library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp b/library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp new file mode 100644 index 0000000000..4083cf3b7b --- /dev/null +++ b/library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp @@ -0,0 +1,143 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/message_handler_error.h> + +#include <library/cpp/messagebus/misc/test_sync.h> +#include <library/cpp/messagebus/oldmodule/module.h> + +using namespace NBus; +using namespace NBus::NTest; + +Y_UNIT_TEST_SUITE(ModuleClientOneWay) { + struct TTestServer: public TBusServerHandlerError { + TExampleProtocol Proto; + + TTestSync* const TestSync; + + TBusMessageQueuePtr Queue; + TBusServerSessionPtr ServerSession; + + TTestServer(TTestSync* testSync) + : TestSync(testSync) + { + Queue = CreateMessageQueue(); + ServerSession = TBusServerSession::Create(&Proto, this, TBusServerSessionConfig(), Queue); + } + + void OnMessage(TOnMessageContext& context) override { + TestSync->WaitForAndIncrement(1); + context.ForgetRequest(); + } + }; + + struct TClientModule: public TBusModule { + TExampleProtocol Proto; + + TTestSync* const TestSync; + unsigned const Port; + + TBusClientSessionPtr ClientSession; + + TClientModule(TTestSync* testSync, unsigned port) + : TBusModule("m") + , TestSync(testSync) + , Port(port) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage*) override { + TestSync->WaitForAndIncrement(0); + + job->SendOneWayTo(new TExampleRequest(&Proto.RequestCount), ClientSession.Get(), TNetAddr("localhost", Port)); + + return &TClientModule::Sent; + } + + TJobHandler Sent(TBusJob* job, TBusMessage*) { + TestSync->WaitForAndIncrement(2); + job->Cancel(MESSAGE_DONT_ASK); + return nullptr; + } + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override { + ClientSession = CreateDefaultSource(queue, &Proto, TBusServerSessionConfig()); + return nullptr; + } + }; + + Y_UNIT_TEST(Simple) { + TTestSync testSync; + + TTestServer server(&testSync); + + TBusMessageQueuePtr queue = CreateMessageQueue(); + TClientModule clientModule(&testSync, server.ServerSession->GetActualListenPort()); + + clientModule.CreatePrivateSessions(queue.Get()); + clientModule.StartInput(); + + clientModule.StartJob(new TExampleRequest(&clientModule.Proto.StartCount)); + + testSync.WaitForAndIncrement(3); + + clientModule.Shutdown(); + } + + struct TSendErrorModule: public TBusModule { + TExampleProtocol Proto; + + TTestSync* const TestSync; + + TBusClientSessionPtr ClientSession; + + TSendErrorModule(TTestSync* testSync) + : TBusModule("m") + , TestSync(testSync) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage*) override { + TestSync->WaitForAndIncrement(0); + + job->SendOneWayTo(new TExampleRequest(&Proto.RequestCount), ClientSession.Get(), TNetAddr("localhost", 1)); + + return &TSendErrorModule::Sent; + } + + TJobHandler Sent(TBusJob* job, TBusMessage*) { + TestSync->WaitForAndIncrement(1); + job->Cancel(MESSAGE_DONT_ASK); + return nullptr; + } + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override { + TBusServerSessionConfig sessionConfig; + sessionConfig.ConnectTimeout = 1; + sessionConfig.SendTimeout = 1; + sessionConfig.TotalTimeout = 1; + sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(1); + ClientSession = CreateDefaultSource(queue, &Proto, sessionConfig); + return nullptr; + } + }; + + Y_UNIT_TEST(SendError) { + TTestSync testSync; + + TBusQueueConfig queueConfig; + queueConfig.NumWorkers = 5; + + TBusMessageQueuePtr queue = CreateMessageQueue(queueConfig); + TSendErrorModule clientModule(&testSync); + + clientModule.CreatePrivateSessions(queue.Get()); + clientModule.StartInput(); + + clientModule.StartJob(new TExampleRequest(&clientModule.Proto.StartCount)); + + testSync.WaitForAndIncrement(2); + + clientModule.Shutdown(); + } +} diff --git a/library/cpp/messagebus/test/ut/module_client_ut.cpp b/library/cpp/messagebus/test/ut/module_client_ut.cpp new file mode 100644 index 0000000000..ebfe185cc6 --- /dev/null +++ b/library/cpp/messagebus/test/ut/module_client_ut.cpp @@ -0,0 +1,368 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "count_down_latch.h" +#include "moduletest.h" + +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/example_module.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> +#include <library/cpp/messagebus/test/helper/wait_for.h> + +#include <library/cpp/messagebus/misc/test_sync.h> +#include <library/cpp/messagebus/oldmodule/module.h> + +#include <util/generic/cast.h> +#include <util/system/event.h> + +using namespace NBus; +using namespace NBus::NTest; + +// helper class that cleans TBusJob instance, so job's destructor can +// be completed without assertion fail. +struct TJobGuard { +public: + TJobGuard(NBus::TBusJob* job) + : Job(job) + { + } + + ~TJobGuard() { + Job->ClearAllMessageStates(); + } + +private: + NBus::TBusJob* Job; +}; + +class TMessageOk: public NBus::TBusMessage { +public: + TMessageOk() + : NBus::TBusMessage(1) + { + } +}; + +class TMessageError: public NBus::TBusMessage { +public: + TMessageError() + : NBus::TBusMessage(2) + { + } +}; + +Y_UNIT_TEST_SUITE(BusJobTest) { +#if 0 + Y_UNIT_TEST(TestPending) { + TObjectCountCheck objectCountCheck; + + TDupDetectModule module; + TBusJob job(&module, new TBusMessage(0)); + // Guard will clear the job if unit-assertion fails. + TJobGuard g(&job); + + NBus::TBusMessage* msg = new NBus::TBusMessage(1); + job.Send(msg, NULL); + NBus::TJobStateVec pending; + job.GetPending(&pending); + + UNIT_ASSERT_VALUES_EQUAL(pending.size(), 1u); + UNIT_ASSERT_EQUAL(msg, pending[0].Message); + } + + Y_UNIT_TEST(TestCallReplyHandler) { + TObjectCountCheck objectCountCheck; + + TDupDetectModule module; + NBus::TBusJob job(&module, new NBus::TBusMessage(0)); + // Guard will clear the job if unit-assertion fails. + TJobGuard g(&job); + + NBus::TBusMessage* msgOk = new TMessageOk; + NBus::TBusMessage* msgError = new TMessageError; + job.Send(msgOk, NULL); + job.Send(msgError, NULL); + + UNIT_ASSERT_EQUAL(job.GetState<TMessageOk>(), NULL); + UNIT_ASSERT_EQUAL(job.GetState<TMessageError>(), NULL); + + NBus::TBusMessage* reply = new NBus::TBusMessage(0); + job.CallReplyHandler(NBus::MESSAGE_OK, msgOk, reply); + job.CallReplyHandler(NBus::MESSAGE_TIMEOUT, msgError, NULL); + + UNIT_ASSERT_UNEQUAL(job.GetState<TMessageOk>(), NULL); + UNIT_ASSERT_UNEQUAL(job.GetState<TMessageError>(), NULL); + + UNIT_ASSERT_VALUES_EQUAL(job.GetStatus<TMessageError>(), NBus::MESSAGE_TIMEOUT); + UNIT_ASSERT_EQUAL(job.GetState<TMessageError>()->Status, NBus::MESSAGE_TIMEOUT); + + UNIT_ASSERT_VALUES_EQUAL(job.GetStatus<TMessageOk>(), NBus::MESSAGE_OK); + UNIT_ASSERT_EQUAL(job.GetState<TMessageOk>()->Reply, reply); + } +#endif + + struct TParallelOnReplyModule : TExampleClientModule { + TNetAddr ServerAddr; + + TCountDownLatch RepliesLatch; + + TParallelOnReplyModule(const TNetAddr& serverAddr) + : ServerAddr(serverAddr) + , RepliesLatch(2) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + job->Send(new TExampleRequest(&Proto.RequestCount), Source, TReplyHandler(&TParallelOnReplyModule::ReplyHandler), 0, ServerAddr); + return &TParallelOnReplyModule::HandleReplies; + } + + void ReplyHandler(TBusJob*, EMessageStatus status, TBusMessage* mess, TBusMessage* reply) { + Y_UNUSED(mess); + Y_UNUSED(reply); + Y_VERIFY(status == MESSAGE_OK, "failed to get reply: %s", ToCString(status)); + } + + TJobHandler HandleReplies(TBusJob* job, TBusMessage* mess) { + Y_UNUSED(mess); + RepliesLatch.CountDown(); + Y_VERIFY(RepliesLatch.Await(TDuration::Seconds(10)), "failed to get answers"); + job->Cancel(MESSAGE_UNKNOWN); + return nullptr; + } + }; + + Y_UNIT_TEST(TestReplyHandlerCalledInParallel) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + + TExampleProtocol proto; + + TBusQueueConfig config; + config.NumWorkers = 5; + + TParallelOnReplyModule module(server.GetActualListenAddr()); + module.StartModule(); + + module.StartJob(new TExampleRequest(&proto.StartCount)); + module.StartJob(new TExampleRequest(&proto.StartCount)); + + UNIT_ASSERT(module.RepliesLatch.Await(TDuration::Seconds(10))); + + module.Shutdown(); + } + + struct TErrorHandlerCheckerModule : TExampleModule { + TNetAddr ServerAddr; + + TBusClientSessionPtr Source; + + TCountDownLatch GotReplyLatch; + + TBusMessage* SentMessage; + + TErrorHandlerCheckerModule() + : ServerAddr("localhost", 17) + , GotReplyLatch(2) + , SentMessage() + { + } + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + TExampleRequest* message = new TExampleRequest(&Proto.RequestCount); + job->Send(message, Source, TReplyHandler(&TErrorHandlerCheckerModule::ReplyHandler), 0, ServerAddr); + SentMessage = message; + return &TErrorHandlerCheckerModule::HandleReplies; + } + + void ReplyHandler(TBusJob*, EMessageStatus status, TBusMessage* req, TBusMessage* resp) { + Y_VERIFY(status == MESSAGE_CONNECT_FAILED || status == MESSAGE_TIMEOUT, "got wrong status: %s", ToString(status).data()); + Y_VERIFY(req == SentMessage, "checking request"); + Y_VERIFY(resp == nullptr, "checking response"); + GotReplyLatch.CountDown(); + } + + TJobHandler HandleReplies(TBusJob* job, TBusMessage* mess) { + Y_UNUSED(mess); + job->Cancel(MESSAGE_UNKNOWN); + GotReplyLatch.CountDown(); + return nullptr; + } + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override { + TBusClientSessionConfig sessionConfig; + sessionConfig.SendTimeout = 1; // TODO: allow 0 + sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(10); + Source = CreateDefaultSource(queue, &Proto, sessionConfig); + Source->RegisterService("localhost"); + return nullptr; + } + }; + + Y_UNIT_TEST(ErrorHandler) { + TExampleProtocol proto; + + TBusQueueConfig config; + config.NumWorkers = 5; + + TErrorHandlerCheckerModule module; + + TBusModuleConfig moduleConfig; + moduleConfig.Secret.SchedulePeriod = TDuration::MilliSeconds(10); + module.SetConfig(moduleConfig); + + module.StartModule(); + + module.StartJob(new TExampleRequest(&proto.StartCount)); + + module.GotReplyLatch.Await(); + + module.Shutdown(); + } + + struct TSlowReplyServer: public TBusServerHandlerError { + TTestSync* const TestSync; + TBusMessageQueuePtr Bus; + TBusServerSessionPtr ServerSession; + TExampleProtocol Proto; + + TAtomic OnMessageCount; + + TSlowReplyServer(TTestSync* testSync) + : TestSync(testSync) + , OnMessageCount(0) + { + Bus = CreateMessageQueue("TSlowReplyServer"); + TBusServerSessionConfig sessionConfig; + ServerSession = TBusServerSession::Create(&Proto, this, sessionConfig, Bus); + } + + void OnMessage(TOnMessageContext& req) override { + if (AtomicIncrement(OnMessageCount) == 1) { + TestSync->WaitForAndIncrement(0); + } + TAutoPtr<TBusMessage> response(new TExampleResponse(&Proto.ResponseCount)); + req.SendReplyMove(response); + } + }; + + struct TModuleThatSendsReplyEarly: public TExampleClientModule { + TTestSync* const TestSync; + const unsigned ServerPort; + + TBusServerSessionPtr ServerSession; + TAtomic ReplyCount; + + TModuleThatSendsReplyEarly(TTestSync* testSync, unsigned serverPort) + : TestSync(testSync) + , ServerPort(serverPort) + , ServerSession(nullptr) + , ReplyCount(0) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + for (unsigned i = 0; i < 2; ++i) { + job->Send( + new TExampleRequest(&Proto.RequestCount), + Source, + TReplyHandler(&TModuleThatSendsReplyEarly::ReplyHandler), + 0, + TNetAddr("127.0.0.1", ServerPort)); + } + return &TModuleThatSendsReplyEarly::HandleReplies; + } + + void ReplyHandler(TBusJob* job, EMessageStatus status, TBusMessage* mess, TBusMessage* reply) { + Y_UNUSED(mess); + Y_UNUSED(reply); + Y_VERIFY(status == MESSAGE_OK, "failed to get reply"); + if (AtomicIncrement(ReplyCount) == 1) { + TestSync->WaitForAndIncrement(1); + job->SendReply(new TExampleResponse(&Proto.ResponseCount)); + } else { + TestSync->WaitForAndIncrement(3); + } + } + + TJobHandler HandleReplies(TBusJob* job, TBusMessage* mess) { + Y_UNUSED(mess); + job->Cancel(MESSAGE_UNKNOWN); + return nullptr; + } + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override { + TExampleClientModule::CreateExtSession(queue); + TBusServerSessionConfig sessionConfig; + return ServerSession = CreateDefaultDestination(queue, &Proto, sessionConfig); + } + }; + + Y_UNIT_TEST(SendReplyCalledBeforeAllRepliesReceived) { + TTestSync testSync; + + TSlowReplyServer slowReplyServer(&testSync); + + TModuleThatSendsReplyEarly module(&testSync, slowReplyServer.ServerSession->GetActualListenPort()); + module.StartModule(); + + TExampleClient client; + TNetAddr addr("127.0.0.1", module.ServerSession->GetActualListenPort()); + client.SendMessagesWaitReplies(1, &addr); + + testSync.WaitForAndIncrement(2); + + module.Shutdown(); + } + + struct TShutdownCalledBeforeReplyReceivedModule: public TExampleClientModule { + unsigned ServerPort; + + TTestSync TestSync; + + TShutdownCalledBeforeReplyReceivedModule(unsigned serverPort) + : ServerPort(serverPort) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage*) override { + TestSync.CheckAndIncrement(0); + + job->Send(new TExampleRequest(&Proto.RequestCount), Source, + TReplyHandler(&TShutdownCalledBeforeReplyReceivedModule::HandleReply), + 0, TNetAddr("localhost", ServerPort)); + return &TShutdownCalledBeforeReplyReceivedModule::End; + } + + void HandleReply(TBusJob*, EMessageStatus status, TBusMessage*, TBusMessage*) { + Y_VERIFY(status == MESSAGE_SHUTDOWN, "got %s", ToCString(status)); + TestSync.CheckAndIncrement(1); + } + + TJobHandler End(TBusJob* job, TBusMessage*) { + TestSync.CheckAndIncrement(2); + job->Cancel(MESSAGE_SHUTDOWN); + return nullptr; + } + }; + + Y_UNIT_TEST(ShutdownCalledBeforeReplyReceived) { + TExampleServer server; + server.ForgetRequest = true; + + TShutdownCalledBeforeReplyReceivedModule module(server.GetActualListenPort()); + + module.StartModule(); + + module.StartJob(new TExampleRequest(&module.Proto.RequestCount)); + + server.TestSync.WaitFor(1); + + module.Shutdown(); + + module.TestSync.CheckAndIncrement(3); + } +} diff --git a/library/cpp/messagebus/test/ut/module_server_ut.cpp b/library/cpp/messagebus/test/ut/module_server_ut.cpp new file mode 100644 index 0000000000..88fe1dd9b6 --- /dev/null +++ b/library/cpp/messagebus/test/ut/module_server_ut.cpp @@ -0,0 +1,119 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "count_down_latch.h" +#include "moduletest.h" + +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/example_module.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> +#include <library/cpp/messagebus/test/helper/wait_for.h> + +#include <library/cpp/messagebus/oldmodule/module.h> + +#include <util/generic/cast.h> + +using namespace NBus; +using namespace NBus::NTest; + +Y_UNIT_TEST_SUITE(ModuleServerTests) { + Y_UNIT_TEST(TestModule) { + TObjectCountCheck objectCountCheck; + + /// create or get instance of message queue, need one per application + TBusMessageQueuePtr bus(CreateMessageQueue()); + THostInfoHandler hostHandler(bus.Get()); + TDupDetectModule module(hostHandler.GetActualListenAddr()); + bool success; + success = module.Init(bus.Get()); + UNIT_ASSERT_C(success, "failed to initialize dupdetect module"); + + success = module.StartInput(); + UNIT_ASSERT_C(success, "failed to start dupdetect module"); + + TDupDetectHandler dupHandler(module.ListenAddr, bus.Get()); + dupHandler.Work(); + + UNIT_WAIT_FOR(dupHandler.NumMessages == dupHandler.NumReplies); + + module.Shutdown(); + dupHandler.DupDetect->Shutdown(); + } + + struct TParallelOnMessageModule: public TExampleServerModule { + TCountDownLatch WaitTwoRequestsLatch; + + TParallelOnMessageModule() + : WaitTwoRequestsLatch(2) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + WaitTwoRequestsLatch.CountDown(); + Y_VERIFY(WaitTwoRequestsLatch.Await(TDuration::Seconds(5)), "oops"); + + VerifyDynamicCast<TExampleRequest*>(mess); + + job->SendReply(new TExampleResponse(&Proto.ResponseCount)); + return nullptr; + } + }; + + Y_UNIT_TEST(TestOnMessageHandlerCalledInParallel) { + TObjectCountCheck objectCountCheck; + + TBusQueueConfig config; + config.NumWorkers = 5; + + TParallelOnMessageModule module; + module.StartModule(); + + TExampleClient client; + + client.SendMessagesWaitReplies(2, module.ServerAddr); + + module.Shutdown(); + } + + struct TDelayReplyServer: public TExampleServerModule { + TSystemEvent MessageReceivedEvent; + TSystemEvent ClientDiedEvent; + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + + MessageReceivedEvent.Signal(); + + Y_VERIFY(ClientDiedEvent.WaitT(TDuration::Seconds(5)), "oops"); + + job->SendReply(new TExampleResponse(&Proto.ResponseCount)); + return nullptr; + } + }; + + Y_UNIT_TEST(TestReplyCalledAfterClientDisconnected) { + TObjectCountCheck objectCountCheck; + + TBusQueueConfig config; + config.NumWorkers = 5; + + TDelayReplyServer server; + server.StartModule(); + + THolder<TExampleClient> client(new TExampleClient); + + client->SendMessages(1, server.ServerAddr); + + UNIT_ASSERT(server.MessageReceivedEvent.WaitT(TDuration::Seconds(5))); + + UNIT_ASSERT_VALUES_EQUAL(1, server.GetModuleSessionInFlight()); + + client.Destroy(); + + server.ClientDiedEvent.Signal(); + + // wait until all server message are delivered + UNIT_WAIT_FOR(0 == server.GetModuleSessionInFlight()); + + server.Shutdown(); + } +} diff --git a/library/cpp/messagebus/test/ut/moduletest.h b/library/cpp/messagebus/test/ut/moduletest.h new file mode 100644 index 0000000000..d5da72c0cb --- /dev/null +++ b/library/cpp/messagebus/test/ut/moduletest.h @@ -0,0 +1,221 @@ +#pragma once + +/////////////////////////////////////////////////////////////////// +/// \file +/// \brief Example of using local session for communication. + +#include <library/cpp/messagebus/test/helper/alloc_counter.h> +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/message_handler_error.h> + +#include <library/cpp/messagebus/ybus.h> +#include <library/cpp/messagebus/oldmodule/module.h> + +namespace NBus { + namespace NTest { + using namespace std; + +#define TYPE_HOSTINFOREQUEST 100 +#define TYPE_HOSTINFORESPONSE 101 + + //////////////////////////////////////////////////////////////////// + /// \brief DupDetect protocol that common between client and server + //////////////////////////////////////////////////////////////////// + /// \brief HostInfo request class + class THostInfoMessage: public TBusMessage { + public: + THostInfoMessage() + : TBusMessage(TYPE_HOSTINFOREQUEST) + { + } + THostInfoMessage(ECreateUninitialized) + : TBusMessage(MESSAGE_CREATE_UNINITIALIZED) + { + } + + ~THostInfoMessage() override { + } + }; + + //////////////////////////////////////////////////////////////////// + /// \brief HostInfo reply class + class THostInfoReply: public TBusMessage { + public: + THostInfoReply() + : TBusMessage(TYPE_HOSTINFORESPONSE) + { + } + THostInfoReply(ECreateUninitialized) + : TBusMessage(MESSAGE_CREATE_UNINITIALIZED) + { + } + + ~THostInfoReply() override { + } + }; + + //////////////////////////////////////////////////////////////////// + /// \brief HostInfo protocol that common between client and server + class THostInfoProtocol: public TBusProtocol { + public: + THostInfoProtocol() + : TBusProtocol("HOSTINFO", 0) + { + } + /// serialized protocol specific data into TBusData + void Serialize(const TBusMessage* mess, TBuffer& data) override { + Y_UNUSED(data); + Y_UNUSED(mess); + } + + /// deserialized TBusData into new instance of the message + TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override { + Y_UNUSED(payload); + + if (messageType == TYPE_HOSTINFOREQUEST) { + return new THostInfoMessage(MESSAGE_CREATE_UNINITIALIZED); + } else if (messageType == TYPE_HOSTINFORESPONSE) { + return new THostInfoReply(MESSAGE_CREATE_UNINITIALIZED); + } else { + Y_FAIL("unknown"); + } + } + }; + + ////////////////////////////////////////////////////////////// + /// \brief HostInfo handler (should convert it to module too) + struct THostInfoHandler: public TBusServerHandlerError { + TBusServerSessionPtr Session; + TBusServerSessionConfig HostInfoConfig; + THostInfoProtocol HostInfoProto; + + THostInfoHandler(TBusMessageQueue* queue) { + Session = TBusServerSession::Create(&HostInfoProto, this, HostInfoConfig, queue); + } + + void OnMessage(TOnMessageContext& mess) override { + usleep(10 * 1000); /// pretend we are doing something + + TAutoPtr<THostInfoReply> reply(new THostInfoReply()); + + mess.SendReplyMove(reply); + } + + TNetAddr GetActualListenAddr() { + return TNetAddr("localhost", Session->GetActualListenPort()); + } + }; + + ////////////////////////////////////////////////////////////// + /// \brief DupDetect handler (should convert it to module too) + struct TDupDetectHandler: public TBusClientHandlerError { + TNetAddr ServerAddr; + + TBusClientSessionPtr DupDetect; + TBusClientSessionConfig DupDetectConfig; + TExampleProtocol DupDetectProto; + + int NumMessages; + int NumReplies; + + TDupDetectHandler(const TNetAddr& serverAddr, TBusMessageQueuePtr queue) + : ServerAddr(serverAddr) + { + DupDetect = TBusClientSession::Create(&DupDetectProto, this, DupDetectConfig, queue); + DupDetect->RegisterService("localhost"); + } + + void Work() { + NumMessages = 10; + NumReplies = 0; + + for (int i = 0; i < NumMessages; i++) { + TExampleRequest* mess = new TExampleRequest(&DupDetectProto.RequestCount); + DupDetect->SendMessage(mess, &ServerAddr); + } + } + + void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override { + Y_UNUSED(mess); + Y_UNUSED(reply); + NumReplies++; + } + }; + + ///////////////////////////////////////////////////////////////// + /// \brief DupDetect module + + struct TDupDetectModule: public TBusModule { + TNetAddr HostInfoAddr; + + TBusClientSessionPtr HostInfoClientSession; + TBusClientSessionConfig HostInfoConfig; + THostInfoProtocol HostInfoProto; + + TExampleProtocol DupDetectProto; + TBusServerSessionConfig DupDetectConfig; + + TNetAddr ListenAddr; + + TDupDetectModule(const TNetAddr& hostInfoAddr) + : TBusModule("DUPDETECTMODULE") + , HostInfoAddr(hostInfoAddr) + { + } + + bool Init(TBusMessageQueue* queue) { + HostInfoClientSession = CreateDefaultSource(*queue, &HostInfoProto, HostInfoConfig); + HostInfoClientSession->RegisterService("localhost"); + + return TBusModule::CreatePrivateSessions(queue); + } + + TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override { + TBusServerSessionPtr session = CreateDefaultDestination(queue, &DupDetectProto, DupDetectConfig); + + ListenAddr = TNetAddr("localhost", session->GetActualListenPort()); + + return session; + } + + /// entry point into module, first function to call + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + TExampleRequest* dmess = dynamic_cast<TExampleRequest*>(mess); + Y_UNUSED(dmess); + + THostInfoMessage* hmess = new THostInfoMessage(); + + /// send message to imaginary hostinfo server + job->Send(hmess, HostInfoClientSession, TReplyHandler(), 0, HostInfoAddr); + + return TJobHandler(&TDupDetectModule::ProcessHostInfo); + } + + /// next handler is executed when all outstanding requests from previous handler is completed + TJobHandler ProcessHostInfo(TBusJob* job, TBusMessage* mess) { + TExampleRequest* dmess = dynamic_cast<TExampleRequest*>(mess); + Y_UNUSED(dmess); + + THostInfoMessage* hmess = job->Get<THostInfoMessage>(); + THostInfoReply* hreply = job->Get<THostInfoReply>(); + EMessageStatus hstatus = job->GetStatus<THostInfoMessage>(); + Y_ASSERT(hmess != nullptr); + Y_ASSERT(hreply != nullptr); + Y_ASSERT(hstatus == MESSAGE_OK); + + return TJobHandler(&TDupDetectModule::Finish); + } + + /// last handler sends reply and returns NULL + TJobHandler Finish(TBusJob* job, TBusMessage* mess) { + Y_UNUSED(mess); + + TExampleResponse* reply = new TExampleResponse(&DupDetectProto.ResponseCount); + job->SendReply(reply); + + return nullptr; + } + }; + + } +} diff --git a/library/cpp/messagebus/test/ut/one_way_ut.cpp b/library/cpp/messagebus/test/ut/one_way_ut.cpp new file mode 100644 index 0000000000..9c21227e2b --- /dev/null +++ b/library/cpp/messagebus/test/ut/one_way_ut.cpp @@ -0,0 +1,255 @@ +/////////////////////////////////////////////////////////////////// +/// \file +/// \brief Example of reply-less communication + +/// This example demostrates how asynchronous message passing library +/// can be used to send message and do not wait for reply back. +/// The usage of reply-less communication should be restricted to +/// low-throughput clients and high-throughput server to provide reasonable +/// utility. Removing replies from the communication removes any restriction +/// on how many message can be send to server and rougue clients may overwelm +/// server without thoughtput control. + +/// 1) To implement reply-less client \n + +/// Call NBus::TBusSession::AckMessage() +/// from within NBus::IMessageHandler::OnSent() handler when message has +/// gone into wire on client end. See example in NBus::NullClient::OnMessageSent(). +/// Discard identity for reply message. + +/// 2) To implement reply-less server \n + +/// Call NBus::TBusSession::AckMessage() from within NBus::IMessageHandler::OnMessage() +/// handler when message has been received on server end. +/// See example in NBus::NullServer::OnMessage(). +/// Discard identity for reply message. + +#include <library/cpp/messagebus/test/helper/alloc_counter.h> +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/hanging_server.h> +#include <library/cpp/messagebus/test/helper/message_handler_error.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> +#include <library/cpp/messagebus/test/helper/wait_for.h> + +#include <library/cpp/messagebus/ybus.h> + +using namespace std; +using namespace NBus; +using namespace NBus::NPrivate; +using namespace NBus::NTest; + +//////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////// +/// \brief Reply-less client and handler +struct NullClient : TBusClientHandlerError { + TNetAddr ServerAddr; + + TBusMessageQueuePtr Queue; + TBusClientSessionPtr Session; + TExampleProtocol Proto; + + /// constructor creates instances of protocol and session + NullClient(const TNetAddr& serverAddr, const TBusClientSessionConfig& sessionConfig = TBusClientSessionConfig()) + : ServerAddr(serverAddr) + { + UNIT_ASSERT(serverAddr.GetPort() > 0); + + /// create or get instance of message queue, need one per application + Queue = CreateMessageQueue(); + + /// register source/client session + Session = TBusClientSession::Create(&Proto, this, sessionConfig, Queue); + + /// register service, announce to clients via LocatorService + Session->RegisterService("localhost"); + } + + ~NullClient() override { + Session->Shutdown(); + } + + /// dispatch of requests is done here + void Work() { + int batch = 10; + + for (int i = 0; i < batch; i++) { + TExampleRequest* mess = new TExampleRequest(&Proto.RequestCount); + mess->Data = "TADA"; + Session->SendMessageOneWay(mess, &ServerAddr); + } + } + + void OnMessageSentOneWay(TAutoPtr<TBusMessage>) override { + } +}; + +///////////////////////////////////////////////////////////////////// +/// \brief Reply-less server and handler +class NullServer: public TBusServerHandlerError { +public: + /// session object to maintian + TBusMessageQueuePtr Queue; + TBusServerSessionPtr Session; + TExampleProtocol Proto; + +public: + TAtomic NumMessages; + + NullServer() { + NumMessages = 0; + + /// create or get instance of single message queue, need one for application + Queue = CreateMessageQueue(); + + /// register destination session + TBusServerSessionConfig sessionConfig; + Session = TBusServerSession::Create(&Proto, this, sessionConfig, Queue); + } + + ~NullServer() override { + Session->Shutdown(); + } + + /// when message comes do not send reply, just acknowledge + void OnMessage(TOnMessageContext& mess) override { + TExampleRequest* fmess = static_cast<TExampleRequest*>(mess.GetMessage()); + + Y_ASSERT(fmess->Data == "TADA"); + + /// tell session to forget this message and never expect any reply + mess.ForgetRequest(); + + AtomicIncrement(NumMessages); + } + + /// this handler should not be called because this server does not send replies + void OnSent(TAutoPtr<TBusMessage> mess) override { + Y_UNUSED(mess); + Y_FAIL("This server does not sent replies"); + } +}; + +Y_UNIT_TEST_SUITE(TMessageBusTests_OneWay) { + Y_UNIT_TEST(Simple) { + TObjectCountCheck objectCountCheck; + + NullServer server; + NullClient client(TNetAddr("localhost", server.Session->GetActualListenPort())); + + client.Work(); + + // wait until all client message are delivered + UNIT_WAIT_FOR(AtomicGet(server.NumMessages) == 10); + + // assert correct number of messages + UNIT_ASSERT_VALUES_EQUAL(AtomicGet(server.NumMessages), 10); + UNIT_ASSERT_VALUES_EQUAL(server.Session->GetInFlight(), 0); + UNIT_ASSERT_VALUES_EQUAL(client.Session->GetInFlight(), 0); + } + + struct TMessageTooLargeClient: public NullClient { + TSystemEvent GotTooLarge; + + TBusClientSessionConfig Config() { + TBusClientSessionConfig r; + r.MaxMessageSize = 1; + return r; + } + + TMessageTooLargeClient(unsigned port) + : NullClient(TNetAddr("localhost", port), Config()) + { + } + + ~TMessageTooLargeClient() override { + Session->Shutdown(); + } + + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override { + Y_UNUSED(mess); + + Y_VERIFY(status == MESSAGE_MESSAGE_TOO_LARGE, "wrong status: %s", ToCString(status)); + + GotTooLarge.Signal(); + } + }; + + Y_UNIT_TEST(MessageTooLargeOnClient) { + TObjectCountCheck objectCountCheck; + + NullServer server; + + TMessageTooLargeClient client(server.Session->GetActualListenPort()); + + EMessageStatus ok = client.Session->SendMessageOneWayMove(new TExampleRequest(&client.Proto.RequestCount), &client.ServerAddr); + UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok); + + client.GotTooLarge.WaitI(); + } + + struct TCheckTimeoutClient: public NullClient { + ~TCheckTimeoutClient() override { + Session->Shutdown(); + } + + static TBusClientSessionConfig SessionConfig() { + TBusClientSessionConfig sessionConfig; + sessionConfig.SendTimeout = 1; + sessionConfig.ConnectTimeout = 1; + sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(10); + return sessionConfig; + } + + TCheckTimeoutClient(const TNetAddr& serverAddr) + : NullClient(serverAddr, SessionConfig()) + { + } + + TSystemEvent GotError; + + /// message that could not be delivered + void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override { + Y_UNUSED(mess); + Y_UNUSED(status); // TODO: check status + + GotError.Signal(); + } + }; + + Y_UNIT_TEST(SendTimeout_Callback_NoServer) { + TObjectCountCheck objectCountCheck; + + TCheckTimeoutClient client(TNetAddr("localhost", 17)); + + EMessageStatus ok = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &client.ServerAddr); + UNIT_ASSERT_EQUAL(ok, MESSAGE_OK); + + client.GotError.WaitI(); + } + + Y_UNIT_TEST(SendTimeout_Callback_HangingServer) { + THangingServer server; + + TObjectCountCheck objectCountCheck; + + TCheckTimeoutClient client(TNetAddr("localhost", server.GetPort())); + + bool first = true; + for (;;) { + EMessageStatus ok = client.Session->SendMessageOneWayMove(new TExampleRequest(&client.Proto.RequestCount), &client.ServerAddr); + if (ok == MESSAGE_BUSY) { + UNIT_ASSERT(!first); + break; + } + UNIT_ASSERT_VALUES_EQUAL(ok, MESSAGE_OK); + first = false; + } + + // BUGBUG: The test is buggy: the client might not get any error when sending one-way messages. + // All the messages that the client has sent before he gets first MESSAGE_BUSY error might get + // serailized and written to the socket buffer, so the write queue gets drained and there are + // no messages to timeout when periodic timeout check happens. + + client.GotError.WaitI(); + } +} diff --git a/library/cpp/messagebus/test/ut/starter_ut.cpp b/library/cpp/messagebus/test/ut/starter_ut.cpp new file mode 100644 index 0000000000..dd4d3aaa5e --- /dev/null +++ b/library/cpp/messagebus/test/ut/starter_ut.cpp @@ -0,0 +1,140 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include <library/cpp/messagebus/test/helper/example_module.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> +#include <library/cpp/messagebus/test/helper/wait_for.h> + +using namespace NBus; +using namespace NBus::NTest; + +Y_UNIT_TEST_SUITE(TBusStarterTest) { + struct TStartJobTestModule: public TExampleModule { + using TBusModule::CreateDefaultStarter; + + TAtomic StartCount; + + TStartJobTestModule() + : StartCount(0) + { + } + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + AtomicIncrement(StartCount); + job->Sleep(10); + return &TStartJobTestModule::End; + } + + TJobHandler End(TBusJob* job, TBusMessage* mess) { + Y_UNUSED(mess); + AtomicIncrement(StartCount); + job->Cancel(MESSAGE_UNKNOWN); + return nullptr; + } + }; + + Y_UNIT_TEST(Test) { + TObjectCountCheck objectCountCheck; + + TBusMessageQueuePtr bus(CreateMessageQueue()); + + TStartJobTestModule module; + + //module.StartModule(); + module.CreatePrivateSessions(bus.Get()); + module.StartInput(); + + TBusSessionConfig config; + config.SendTimeout = 10; + + module.CreateDefaultStarter(*bus, config); + + UNIT_WAIT_FOR(AtomicGet(module.StartCount) >= 3); + + module.Shutdown(); + bus->Stop(); + } + + Y_UNIT_TEST(TestModuleStartJob) { + TObjectCountCheck objectCountCheck; + + TExampleProtocol proto; + + TStartJobTestModule module; + + TBusModuleConfig moduleConfig; + moduleConfig.Secret.SchedulePeriod = TDuration::MilliSeconds(10); + module.SetConfig(moduleConfig); + + module.StartModule(); + + module.StartJob(new TExampleRequest(&proto.RequestCount)); + + UNIT_WAIT_FOR(AtomicGet(module.StartCount) != 2); + + module.Shutdown(); + } + + struct TSleepModule: public TExampleServerModule { + TSystemEvent MessageReceivedEvent; + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + + MessageReceivedEvent.Signal(); + + job->Sleep(1000000000); + + return TJobHandler(&TSleepModule::Never); + } + + TJobHandler Never(TBusJob*, TBusMessage*) { + Y_FAIL("happens"); + throw 1; + } + }; + + Y_UNIT_TEST(StartJobDestroyDuringSleep) { + TObjectCountCheck objectCountCheck; + + TExampleProtocol proto; + + TSleepModule module; + + module.StartModule(); + + module.StartJob(new TExampleRequest(&proto.StartCount)); + + module.MessageReceivedEvent.WaitI(); + + module.Shutdown(); + } + + struct TSendReplyModule: public TExampleServerModule { + TSystemEvent MessageReceivedEvent; + + TJobHandler Start(TBusJob* job, TBusMessage* mess) override { + Y_UNUSED(mess); + + job->SendReply(new TExampleResponse(&Proto.ResponseCount)); + + MessageReceivedEvent.Signal(); + + return nullptr; + } + }; + + Y_UNIT_TEST(AllowSendReplyInStarted) { + TObjectCountCheck objectCountCheck; + + TExampleProtocol proto; + + TSendReplyModule module; + module.StartModule(); + module.StartJob(new TExampleRequest(&proto.StartCount)); + + module.MessageReceivedEvent.WaitI(); + + module.Shutdown(); + } +} diff --git a/library/cpp/messagebus/test/ut/sync_client_ut.cpp b/library/cpp/messagebus/test/ut/sync_client_ut.cpp new file mode 100644 index 0000000000..400128193f --- /dev/null +++ b/library/cpp/messagebus/test/ut/sync_client_ut.cpp @@ -0,0 +1,69 @@ +#include <library/cpp/messagebus/test/helper/example.h> +#include <library/cpp/messagebus/test/helper/object_count_check.h> + +namespace NBus { + namespace NTest { + using namespace std; + + //////////////////////////////////////////////////////////////////// + /// \brief Client for sending synchronous message to local server + struct TSyncClient { + TNetAddr ServerAddr; + + TExampleProtocol Proto; + TBusMessageQueuePtr Bus; + TBusSyncClientSessionPtr Session; + + int NumReplies; + int NumMessages; + + /// constructor creates instances of queue, protocol and session + TSyncClient(const TNetAddr& serverAddr) + : ServerAddr(serverAddr) + { + /// create or get instance of message queue, need one per application + Bus = CreateMessageQueue(); + + NumReplies = 0; + NumMessages = 10; + + /// register source/client session + TBusClientSessionConfig sessionConfig; + Session = Bus->CreateSyncSource(&Proto, sessionConfig); + Session->RegisterService("localhost"); + } + + ~TSyncClient() { + Session->Shutdown(); + } + + /// dispatch of requests is done here + void Work() { + for (int i = 0; i < NumMessages; i++) { + THolder<TExampleRequest> mess(new TExampleRequest(&Proto.RequestCount)); + EMessageStatus status; + THolder<TBusMessage> reply(Session->SendSyncMessage(mess.Get(), status, &ServerAddr)); + if (!!reply) { + NumReplies++; + } + } + } + }; + + Y_UNIT_TEST_SUITE(SyncClientTest) { + Y_UNIT_TEST(TestSync) { + TObjectCountCheck objectCountCheck; + + TExampleServer server; + TSyncClient client(server.GetActualListenAddr()); + client.Work(); + // assert correct number of replies + UNIT_ASSERT_EQUAL(client.NumReplies, client.NumMessages); + // assert that there is no message left in flight + UNIT_ASSERT_EQUAL(server.Session->GetInFlight(), 0); + UNIT_ASSERT_EQUAL(client.Session->GetInFlight(), 0); + } + } + + } +} diff --git a/library/cpp/messagebus/test/ut/ya.make b/library/cpp/messagebus/test/ut/ya.make new file mode 100644 index 0000000000..fe1b4961d6 --- /dev/null +++ b/library/cpp/messagebus/test/ut/ya.make @@ -0,0 +1,56 @@ +OWNER(g:messagebus) + +UNITTEST_FOR(library/cpp/messagebus) + +TIMEOUT(1200) + +SIZE(LARGE) + +TAG( + ya:not_autocheck + ya:fat +) + +FORK_SUBTESTS() + +PEERDIR( + library/cpp/testing/unittest_main + library/cpp/messagebus + library/cpp/messagebus/test/helper + library/cpp/messagebus/www +) + +SRCS( + messagebus_ut.cpp + module_client_ut.cpp + module_client_one_way_ut.cpp + module_server_ut.cpp + one_way_ut.cpp + starter_ut.cpp + sync_client_ut.cpp + locator_uniq_ut.cpp + ../../actor/actor_ut.cpp + ../../actor/ring_buffer_ut.cpp + ../../actor/tasks_ut.cpp + ../../actor/what_thread_does_guard_ut.cpp + ../../async_result_ut.cpp + ../../cc_semaphore_ut.cpp + ../../coreconn_ut.cpp + ../../duration_histogram_ut.cpp + ../../message_status_counter_ut.cpp + ../../misc/weak_ptr_ut.cpp + ../../latch_ut.cpp + ../../lfqueue_batch_ut.cpp + ../../local_flags_ut.cpp + ../../memory_ut.cpp + ../../moved_ut.cpp + ../../netaddr_ut.cpp + ../../network_ut.cpp + ../../nondestroying_holder_ut.cpp + ../../scheduler_actor_ut.cpp + ../../scheduler/scheduler_ut.cpp + ../../socket_addr_ut.cpp + ../../vector_swaps_ut.cpp +) + +END() diff --git a/library/cpp/messagebus/test/ya.make b/library/cpp/messagebus/test/ya.make new file mode 100644 index 0000000000..0dc4bd4720 --- /dev/null +++ b/library/cpp/messagebus/test/ya.make @@ -0,0 +1,7 @@ +OWNER(g:messagebus) + +RECURSE( + example + perftest + ut +) diff --git a/library/cpp/messagebus/test_utils.h b/library/cpp/messagebus/test_utils.h new file mode 100644 index 0000000000..2abdf504b1 --- /dev/null +++ b/library/cpp/messagebus/test_utils.h @@ -0,0 +1,12 @@ +#pragma once + +// Do nothing if there is no support for IPv4 +#define ASSUME_IP_V4_ENABLED \ + do { \ + try { \ + TNetworkAddress("192.168.0.42", 80); \ + } catch (const TNetworkResolutionError& ex) { \ + Y_UNUSED(ex); \ + return; \ + } \ + } while (0) diff --git a/library/cpp/messagebus/text_utils.h b/library/cpp/messagebus/text_utils.h new file mode 100644 index 0000000000..c2dcad834c --- /dev/null +++ b/library/cpp/messagebus/text_utils.h @@ -0,0 +1,3 @@ +#pragma once + +#include <library/cpp/string_utils/indent_text/indent_text.h> diff --git a/library/cpp/messagebus/thread_extra.h b/library/cpp/messagebus/thread_extra.h new file mode 100644 index 0000000000..2c79741e88 --- /dev/null +++ b/library/cpp/messagebus/thread_extra.h @@ -0,0 +1,3 @@ +#pragma once + +#include <library/cpp/messagebus/actor/thread_extra.h> diff --git a/library/cpp/messagebus/use_after_free_checker.cpp b/library/cpp/messagebus/use_after_free_checker.cpp new file mode 100644 index 0000000000..4904e7c614 --- /dev/null +++ b/library/cpp/messagebus/use_after_free_checker.cpp @@ -0,0 +1,22 @@ +#include "use_after_free_checker.h" + +#include <util/system/yassert.h> + +namespace { + const ui64 VALID = (ui64)0xAABBCCDDEEFF0011LL; + const ui64 INVALID = (ui64)0x1122334455667788LL; +} + +TUseAfterFreeChecker::TUseAfterFreeChecker() + : Magic(VALID) +{ +} + +TUseAfterFreeChecker::~TUseAfterFreeChecker() { + Y_VERIFY(Magic == VALID, "Corrupted"); + Magic = INVALID; +} + +void TUseAfterFreeChecker::CheckNotFreed() const { + Y_VERIFY(Magic == VALID, "Freed or corrupted"); +} diff --git a/library/cpp/messagebus/use_after_free_checker.h b/library/cpp/messagebus/use_after_free_checker.h new file mode 100644 index 0000000000..590b076156 --- /dev/null +++ b/library/cpp/messagebus/use_after_free_checker.h @@ -0,0 +1,31 @@ +#pragma once + +#include <util/system/platform.h> +#include <util/system/types.h> + +class TUseAfterFreeChecker { +private: + ui64 Magic; + +public: + TUseAfterFreeChecker(); + ~TUseAfterFreeChecker(); + void CheckNotFreed() const; +}; + +// check twice: in constructor and in destructor +class TUseAfterFreeCheckerGuard { +private: + const TUseAfterFreeChecker& Check; + +public: + TUseAfterFreeCheckerGuard(const TUseAfterFreeChecker& check) + : Check(check) + { + Check.CheckNotFreed(); + } + + ~TUseAfterFreeCheckerGuard() { + Check.CheckNotFreed(); + } +}; diff --git a/library/cpp/messagebus/use_count_checker.cpp b/library/cpp/messagebus/use_count_checker.cpp new file mode 100644 index 0000000000..c6243ea21f --- /dev/null +++ b/library/cpp/messagebus/use_count_checker.cpp @@ -0,0 +1,53 @@ +#include "use_count_checker.h" + +#include <util/generic/utility.h> +#include <util/system/yassert.h> + +TUseCountChecker::TUseCountChecker() { +} + +TUseCountChecker::~TUseCountChecker() { + TAtomicBase count = Counter.Val(); + Y_VERIFY(count == 0, "must not release when count is not zero: %ld", (long)count); +} + +void TUseCountChecker::Inc() { + Counter.Inc(); +} + +void TUseCountChecker::Dec() { + Counter.Dec(); +} + +TUseCountHolder::TUseCountHolder() + : CurrentChecker(nullptr) +{ +} + +TUseCountHolder::TUseCountHolder(TUseCountChecker* currentChecker) + : CurrentChecker(currentChecker) +{ + if (!!CurrentChecker) { + CurrentChecker->Inc(); + } +} + +TUseCountHolder::~TUseCountHolder() { + if (!!CurrentChecker) { + CurrentChecker->Dec(); + } +} + +TUseCountHolder& TUseCountHolder::operator=(TUseCountHolder that) { + Swap(that); + return *this; +} + +void TUseCountHolder::Swap(TUseCountHolder& that) { + DoSwap(CurrentChecker, that.CurrentChecker); +} + +void TUseCountHolder::Reset() { + TUseCountHolder tmp; + Swap(tmp); +} diff --git a/library/cpp/messagebus/use_count_checker.h b/library/cpp/messagebus/use_count_checker.h new file mode 100644 index 0000000000..70bef6fa8a --- /dev/null +++ b/library/cpp/messagebus/use_count_checker.h @@ -0,0 +1,27 @@ +#pragma once + +#include <util/generic/refcount.h> + +class TUseCountChecker { +private: + TAtomicCounter Counter; + +public: + TUseCountChecker(); + ~TUseCountChecker(); + void Inc(); + void Dec(); +}; + +class TUseCountHolder { +private: + TUseCountChecker* CurrentChecker; + +public: + TUseCountHolder(); + explicit TUseCountHolder(TUseCountChecker* currentChecker); + TUseCountHolder& operator=(TUseCountHolder that); + ~TUseCountHolder(); + void Swap(TUseCountHolder&); + void Reset(); +}; diff --git a/library/cpp/messagebus/vector_swaps.h b/library/cpp/messagebus/vector_swaps.h new file mode 100644 index 0000000000..b920bcf03e --- /dev/null +++ b/library/cpp/messagebus/vector_swaps.h @@ -0,0 +1,171 @@ +#pragma once + +#include <util/generic/array_ref.h> +#include <util/generic/noncopyable.h> +#include <util/generic/utility.h> +#include <util/system/yassert.h> + +#include <stdlib.h> + +template <typename T, class A = std::allocator<T>> +class TVectorSwaps : TNonCopyable { +private: + T* Start; + T* Finish; + T* EndOfStorage; + + void StateCheck() { + Y_ASSERT(Start <= Finish); + Y_ASSERT(Finish <= EndOfStorage); + } + +public: + typedef T* iterator; + typedef const T* const_iterator; + + typedef std::reverse_iterator<iterator> reverse_iterator; + typedef std::reverse_iterator<const_iterator> const_reverse_iterator; + + TVectorSwaps() + : Start() + , Finish() + , EndOfStorage() + { + } + + ~TVectorSwaps() { + for (size_t i = 0; i < size(); ++i) { + Start[i].~T(); + } + free(Start); + } + + operator TArrayRef<const T>() const { + return MakeArrayRef(data(), size()); + } + + operator TArrayRef<T>() { + return MakeArrayRef(data(), size()); + } + + size_t capacity() const { + return EndOfStorage - Start; + } + + size_t size() const { + return Finish - Start; + } + + bool empty() const { + return size() == 0; + } + + T* data() { + return Start; + } + + const T* data() const { + return Start; + } + + T& operator[](size_t index) { + Y_ASSERT(index < size()); + return Start[index]; + } + + const T& operator[](size_t index) const { + Y_ASSERT(index < size()); + return Start[index]; + } + + iterator begin() { + return Start; + } + + iterator end() { + return Finish; + } + + const_iterator begin() const { + return Start; + } + + const_iterator end() const { + return Finish; + } + + reverse_iterator rbegin() { + return reverse_iterator(end()); + } + reverse_iterator rend() { + return reverse_iterator(begin()); + } + + const_reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + const_reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + void swap(TVectorSwaps<T>& that) { + DoSwap(Start, that.Start); + DoSwap(Finish, that.Finish); + DoSwap(EndOfStorage, that.EndOfStorage); + } + + void reserve(size_t n) { + if (n <= capacity()) { + return; + } + + size_t newCapacity = FastClp2(n); + TVectorSwaps<T> tmp; + tmp.Start = (T*)malloc(sizeof(T) * newCapacity); + Y_VERIFY(!!tmp.Start); + + tmp.EndOfStorage = tmp.Start + newCapacity; + + for (size_t i = 0; i < size(); ++i) { + // TODO: catch exceptions + new (tmp.Start + i) T(); + DoSwap(Start[i], tmp.Start[i]); + } + + tmp.Finish = tmp.Start + size(); + + swap(tmp); + + StateCheck(); + } + + void clear() { + TVectorSwaps<T> tmp; + swap(tmp); + } + + template <class TIterator> + void insert(iterator pos, TIterator b, TIterator e) { + Y_VERIFY(pos == end(), "TODO: only insert at the end is implemented"); + + size_t count = e - b; + + reserve(size() + count); + + TIterator next = b; + + for (size_t i = 0; i < count; ++i) { + new (Start + size() + i) T(); + DoSwap(Start[size() + i], *next); + ++next; + } + + Finish += count; + + StateCheck(); + } + + void push_back(T& elem) { + insert(end(), &elem, &elem + 1); + } +}; diff --git a/library/cpp/messagebus/vector_swaps_ut.cpp b/library/cpp/messagebus/vector_swaps_ut.cpp new file mode 100644 index 0000000000..693cc6857b --- /dev/null +++ b/library/cpp/messagebus/vector_swaps_ut.cpp @@ -0,0 +1,17 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "vector_swaps.h" + +Y_UNIT_TEST_SUITE(TVectorSwapsTest) { + Y_UNIT_TEST(Simple) { + TVectorSwaps<THolder<unsigned>> v; + for (unsigned i = 0; i < 100; ++i) { + THolder<unsigned> tmp(new unsigned(i)); + v.push_back(tmp); + } + + for (unsigned i = 0; i < 100; ++i) { + UNIT_ASSERT_VALUES_EQUAL(i, *v[i]); + } + } +} diff --git a/library/cpp/messagebus/www/bus-ico.png b/library/cpp/messagebus/www/bus-ico.png Binary files differnew file mode 100644 index 0000000000..c69a461892 --- /dev/null +++ b/library/cpp/messagebus/www/bus-ico.png diff --git a/library/cpp/messagebus/www/concat_strings.h b/library/cpp/messagebus/www/concat_strings.h new file mode 100644 index 0000000000..7b730564eb --- /dev/null +++ b/library/cpp/messagebus/www/concat_strings.h @@ -0,0 +1,22 @@ +#pragma once + +#include <util/generic/string.h> +#include <util/stream/str.h> + +// ATTN: not equivalent to TString::Join - cat concat anything "outputable" to stream, not only TString convertable types. + +inline void DoConcatStrings(TStringStream&) { +} + +template <class T, class... R> +inline void DoConcatStrings(TStringStream& ss, const T& t, const R&... r) { + ss << t; + DoConcatStrings(ss, r...); +} + +template <class... R> +inline TString ConcatStrings(const R&... r) { + TStringStream ss; + DoConcatStrings(ss, r...); + return ss.Str(); +} diff --git a/library/cpp/messagebus/www/html_output.cpp b/library/cpp/messagebus/www/html_output.cpp new file mode 100644 index 0000000000..10ea2e163b --- /dev/null +++ b/library/cpp/messagebus/www/html_output.cpp @@ -0,0 +1,4 @@ +#include "html_output.h" + +Y_POD_THREAD(IOutputStream*) +HtmlOutputStreamPtr; diff --git a/library/cpp/messagebus/www/html_output.h b/library/cpp/messagebus/www/html_output.h new file mode 100644 index 0000000000..27e77adefa --- /dev/null +++ b/library/cpp/messagebus/www/html_output.h @@ -0,0 +1,324 @@ +#pragma once + +#include "concat_strings.h" + +#include <util/generic/string.h> +#include <util/stream/output.h> +#include <library/cpp/html/pcdata/pcdata.h> +#include <util/system/tls.h> + +extern Y_POD_THREAD(IOutputStream*) HtmlOutputStreamPtr; + +static IOutputStream& HtmlOutputStream() { + Y_VERIFY(!!HtmlOutputStreamPtr); + return *HtmlOutputStreamPtr; +} + +struct THtmlOutputStreamPushPop { + IOutputStream* const Prev; + + THtmlOutputStreamPushPop(IOutputStream* outputStream) + : Prev(HtmlOutputStreamPtr) + { + HtmlOutputStreamPtr = outputStream; + } + + ~THtmlOutputStreamPushPop() { + HtmlOutputStreamPtr = Prev; + } +}; + +struct TChars { + TString Text; + bool NeedEscape; + + TChars(TStringBuf text) + : Text(text) + , NeedEscape(true) + { + } + TChars(TStringBuf text, bool escape) + : Text(text) + , NeedEscape(escape) + { + } + TChars(const char* text) + : Text(text) + , NeedEscape(true) + { + } + TChars(const char* text, bool escape) + : Text(text) + , NeedEscape(escape) + { + } + + TString Escape() { + if (NeedEscape) { + return EncodeHtmlPcdata(Text); + } else { + return Text; + } + } +}; + +struct TAttr { + TString Name; + TString Value; + + TAttr(TStringBuf name, TStringBuf value) + : Name(name) + , Value(value) + { + } + + TAttr() { + } + + bool operator!() const { + return !Name; + } +}; + +static inline void Doctype() { + HtmlOutputStream() << "<!doctype html>\n"; +} + +static inline void Nl() { + HtmlOutputStream() << "\n"; +} + +static inline void Sp() { + HtmlOutputStream() << " "; +} + +static inline void Text(TStringBuf text) { + HtmlOutputStream() << EncodeHtmlPcdata(text); +} + +static inline void Line(TStringBuf text) { + Text(text); + Nl(); +} + +static inline void WriteAttr(TAttr a) { + if (!!a) { + HtmlOutputStream() << " " << a.Name << "='" << EncodeHtmlPcdata(a.Value) << "'"; + } +} + +static inline void Open(TStringBuf tag, TAttr a1 = TAttr(), TAttr a2 = TAttr(), TAttr a3 = TAttr(), TAttr a4 = TAttr()) { + HtmlOutputStream() << "<" << tag; + WriteAttr(a1); + WriteAttr(a2); + WriteAttr(a3); + WriteAttr(a4); + HtmlOutputStream() << ">"; +} + +static inline void Open(TStringBuf tag, TStringBuf cssClass, TStringBuf id = "") { + Open(tag, TAttr("class", cssClass), !!id ? TAttr("id", id) : TAttr()); +} + +static inline void OpenBlock(TStringBuf tag, TStringBuf cssClass = "") { + Open(tag, cssClass); + Nl(); +} + +static inline void Close(TStringBuf tag) { + HtmlOutputStream() << "</" << tag << ">\n"; +} + +static inline void CloseBlock(TStringBuf tag) { + Close(tag); + Nl(); +} + +static inline void TagWithContent(TStringBuf tag, TChars content) { + HtmlOutputStream() << "<" << tag << ">" << content.Escape() << "</" << tag << ">"; +} + +static inline void BlockTagWithContent(TStringBuf tag, TStringBuf content) { + TagWithContent(tag, content); + Nl(); +} + +static inline void TagWithClass(TStringBuf tag, TStringBuf cssClass) { + Open(tag, cssClass); + Close(tag); +} + +static inline void Hn(unsigned n, TStringBuf title) { + BlockTagWithContent(ConcatStrings("h", n), title); +} + +static inline void Small(TStringBuf text) { + TagWithContent("small", text); +} + +static inline void HnWithSmall(unsigned n, TStringBuf title, TStringBuf small) { + TString tagName = ConcatStrings("h", n); + Open(tagName); + HtmlOutputStream() << title; + Sp(); + Small(small); + Close(tagName); +} + +static inline void H1(TStringBuf title) { + Hn(1, title); +} + +static inline void H2(TStringBuf title) { + Hn(2, title); +} + +static inline void H3(TStringBuf title) { + Hn(3, title); +} + +static inline void H4(TStringBuf title) { + Hn(4, title); +} + +static inline void H5(TStringBuf title) { + Hn(5, title); +} + +static inline void H6(TStringBuf title) { + Hn(6, title); +} + +static inline void Pre(TStringBuf content) { + HtmlOutputStream() << "<pre>" << EncodeHtmlPcdata(content) << "</pre>\n"; +} + +static inline void Li(TStringBuf content) { + BlockTagWithContent("li", content); +} + +static inline void LiWithClass(TStringBuf cssClass, TStringBuf content) { + Open("li", cssClass); + Text(content); + Close("li"); +} + +static inline void OpenA(TStringBuf href) { + Open("a", TAttr("href", href)); +} + +static inline void A(TStringBuf href, TStringBuf text) { + OpenA(href); + Text(text); + Close("a"); +} + +static inline void Td(TStringBuf content) { + TagWithContent("td", content); +} + +static inline void Th(TStringBuf content, TStringBuf cssClass = "") { + OpenBlock("th", cssClass); + Text(content); + CloseBlock("th"); +} + +static inline void DivWithClassAndContent(TStringBuf cssClass, TStringBuf content) { + Open("div", cssClass); + Text(content); + Close("div"); +} + +static inline void BootstrapError(TStringBuf text) { + DivWithClassAndContent("alert alert-danger", text); +} + +static inline void BootstrapInfo(TStringBuf text) { + DivWithClassAndContent("alert alert-info", text); +} + +static inline void ScriptHref(TStringBuf href) { + Open("script", + TAttr("language", "javascript"), + TAttr("type", "text/javascript"), + TAttr("src", href)); + Close("script"); + Nl(); +} + +static inline void LinkStylesheet(TStringBuf href) { + Open("link", TAttr("rel", "stylesheet"), TAttr("href", href)); + Close("link"); + Nl(); +} + +static inline void LinkFavicon(TStringBuf href) { + Open("link", TAttr("rel", "shortcut icon"), TAttr("href", href)); + Close("link"); + Nl(); +} + +static inline void Title(TChars title) { + TagWithContent("title", title); + Nl(); +} + +static inline void Code(TStringBuf content) { + TagWithContent("code", content); +} + +struct TTagGuard { + const TString TagName; + + TTagGuard(TStringBuf tagName, TStringBuf cssClass, TStringBuf id = "") + : TagName(tagName) + { + Open(TagName, cssClass, id); + } + + TTagGuard(TStringBuf tagName, TAttr a1 = TAttr(), TAttr a2 = TAttr(), TAttr a3 = TAttr(), TAttr a4 = TAttr()) + : TagName(tagName) + { + Open(tagName, a1, a2, a3, a4); + } + + ~TTagGuard() { + Close(TagName); + } +}; + +struct TDivGuard: public TTagGuard { + TDivGuard(TStringBuf cssClass, TStringBuf id = "") + : TTagGuard("div", cssClass, id) + { + } + + TDivGuard(TAttr a1 = TAttr(), TAttr a2 = TAttr(), TAttr a3 = TAttr()) + : TTagGuard("div", a1, a2, a3) + { + } +}; + +struct TAGuard { + TAGuard(TStringBuf href) { + OpenA(href); + } + + ~TAGuard() { + Close("a"); + } +}; + +struct TScriptFunctionGuard { + TTagGuard Script; + + TScriptFunctionGuard() + : Script("script") + { + Line("$(function() {"); + } + + ~TScriptFunctionGuard() { + Line("});"); + } +}; diff --git a/library/cpp/messagebus/www/messagebus.js b/library/cpp/messagebus/www/messagebus.js new file mode 100644 index 0000000000..e30508b879 --- /dev/null +++ b/library/cpp/messagebus/www/messagebus.js @@ -0,0 +1,48 @@ +function logTransform(v) { + return Math.log(v + 1); +} + +function plotHist(where, hist) { + var max = hist.map(function(x) {return x[1]}).reduce(function(x, y) {return Math.max(x, y)}); + + var ticks = []; + for (var t = 1; ; t *= 10) { + if (t > max) { + break; + } + ticks.push(t); + } + + $.plot(where, [hist], + { + data: hist, + series: { + bars: { + show: true, + barWidth: 0.9 + } + }, + xaxis: { + mode: 'categories', + tickLength: 0 + }, + yaxis: { + ticks: ticks, + transform: logTransform + } + } + ); +} + +function plotQueueSize(where, data, ticks) { + $.plot(where, [data], + { + xaxis: { + ticks: ticks, + }, + yaxis: { + //transform: logTransform + } + } + ); +} diff --git a/library/cpp/messagebus/www/www.cpp b/library/cpp/messagebus/www/www.cpp new file mode 100644 index 0000000000..62ec241d85 --- /dev/null +++ b/library/cpp/messagebus/www/www.cpp @@ -0,0 +1,930 @@ +#include "www.h" + +#include "concat_strings.h" +#include "html_output.h" + +#include <library/cpp/messagebus/remote_connection_status.h> +#include <library/cpp/monlib/deprecated/json/writer.h> + +#include <library/cpp/archive/yarchive.h> +#include <library/cpp/http/fetch/httpfsm.h> +#include <library/cpp/http/fetch/httpheader.h> +#include <library/cpp/http/server/http.h> +#include <library/cpp/json/writer/json.h> +#include <library/cpp/uri/http_url.h> + +#include <util/string/cast.h> +#include <util/string/printf.h> +#include <util/system/mutex.h> + +#include <utility> + +using namespace NBus; +using namespace NBus::NPrivate; +using namespace NActor; +using namespace NActor::NPrivate; + +static const char HTTP_OK_JS[] = "HTTP/1.1 200 Ok\r\nContent-Type: text/javascript\r\nConnection: Close\r\n\r\n"; +static const char HTTP_OK_JSON[] = "HTTP/1.1 200 Ok\r\nContent-Type: application/json; charset=utf-8\r\nConnection: Close\r\n\r\n"; +static const char HTTP_OK_PNG[] = "HTTP/1.1 200 Ok\r\nContent-Type: image/png\r\nConnection: Close\r\n\r\n"; +static const char HTTP_OK_BIN[] = "HTTP/1.1 200 Ok\r\nContent-Type: application/octet-stream\r\nConnection: Close\r\n\r\n"; +static const char HTTP_OK_HTML[] = "HTTP/1.1 200 Ok\r\nContent-Type: text/html; charset=utf-8\r\nConnection: Close\r\n\r\n"; + +namespace { + typedef TIntrusivePtr<TBusModuleInternal> TBusModuleInternalPtr; + + template <typename TValuePtr> + struct TNamedValues { + TVector<std::pair<TString, TValuePtr>> Entries; + + TValuePtr FindByName(TStringBuf name) { + Y_VERIFY(!!name); + + for (unsigned i = 0; i < Entries.size(); ++i) { + if (Entries[i].first == name) { + return Entries[i].second; + } + } + return TValuePtr(); + } + + TString FindNameByPtr(TValuePtr value) { + Y_VERIFY(!!value); + + for (unsigned i = 0; i < Entries.size(); ++i) { + if (Entries[i].second.Get() == value.Get()) { + return Entries[i].first; + } + } + + Y_FAIL("unregistered"); + } + + void Add(TValuePtr p) { + Y_VERIFY(!!p); + + // Do not add twice + for (unsigned i = 0; i < Entries.size(); ++i) { + if (Entries[i].second.Get() == p.Get()) { + return; + } + } + + if (!!p->GetNameInternal()) { + TValuePtr current = FindByName(p->GetNameInternal()); + + if (!current) { + Entries.emplace_back(p->GetNameInternal(), p); + return; + } + } + + for (unsigned i = 1;; ++i) { + TString prefix = p->GetNameInternal(); + if (!prefix) { + prefix = "unnamed"; + } + TString name = ConcatStrings(prefix, "-", i); + + TValuePtr current = FindByName(name); + + if (!current) { + Entries.emplace_back(name, p); + return; + } + } + } + + size_t size() const { + return Entries.size(); + } + + bool operator!() const { + return size() == 0; + } + }; + + template <typename TSessionPtr> + struct TSessionValues: public TNamedValues<TSessionPtr> { + typedef TNamedValues<TSessionPtr> TBase; + + TVector<TString> GetNamesForQueue(TBusMessageQueue* queue) { + TVector<TString> r; + for (unsigned i = 0; i < TBase::size(); ++i) { + if (TBase::Entries[i].second->GetQueue() == queue) { + r.push_back(TBase::Entries[i].first); + } + } + return r; + } + }; +} + +namespace { + TString RootHref() { + return ConcatStrings("?"); + } + + TString QueueHref(TStringBuf name) { + return ConcatStrings("?q=", name); + } + + TString ServerSessionHref(TStringBuf name) { + return ConcatStrings("?ss=", name); + } + + TString ClientSessionHref(TStringBuf name) { + return ConcatStrings("?cs=", name); + } + + TString OldModuleHref(TStringBuf name) { + return ConcatStrings("?om=", name); + } + + /* + static void RootLink() { + A(RootHref(), "root"); + } + */ + + void QueueLink(TStringBuf name) { + A(QueueHref(name), name); + } + + void ServerSessionLink(TStringBuf name) { + A(ServerSessionHref(name), name); + } + + void ClientSessionLink(TStringBuf name) { + A(ClientSessionHref(name), name); + } + + void OldModuleLink(TStringBuf name) { + A(OldModuleHref(name), name); + } + +} + +const unsigned char WWW_STATIC_DATA[] = { +#include "www_static.inc" +}; + +class TWwwStaticLoader: public TArchiveReader { +public: + TWwwStaticLoader() + : TArchiveReader(TBlob::NoCopy(WWW_STATIC_DATA, sizeof(WWW_STATIC_DATA))) + { + } +}; + +struct TBusWww::TImpl { + // TODO: use weak pointers + TNamedValues<TBusMessageQueuePtr> Queues; + TSessionValues<TIntrusivePtr<TBusClientSession>> ClientSessions; + TSessionValues<TIntrusivePtr<TBusServerSession>> ServerSessions; + TSessionValues<TBusModuleInternalPtr> Modules; + + TMutex Mutex; + + void RegisterClientSession(TBusClientSessionPtr s) { + Y_VERIFY(!!s); + TGuard<TMutex> g(Mutex); + ClientSessions.Add(s.Get()); + Queues.Add(s->GetQueue()); + } + + void RegisterServerSession(TBusServerSessionPtr s) { + Y_VERIFY(!!s); + TGuard<TMutex> g(Mutex); + ServerSessions.Add(s.Get()); + Queues.Add(s->GetQueue()); + } + + void RegisterQueue(TBusMessageQueuePtr q) { + Y_VERIFY(!!q); + TGuard<TMutex> g(Mutex); + Queues.Add(q); + } + + void RegisterModule(TBusModule* module) { + Y_VERIFY(!!module); + TGuard<TMutex> g(Mutex); + + { + TVector<TBusClientSessionPtr> clientSessions = module->GetInternal()->GetClientSessionsInternal(); + for (unsigned i = 0; i < clientSessions.size(); ++i) { + RegisterClientSession(clientSessions[i]); + } + } + + { + TVector<TBusServerSessionPtr> serverSessions = module->GetInternal()->GetServerSessionsInternal(); + for (unsigned i = 0; i < serverSessions.size(); ++i) { + RegisterServerSession(serverSessions[i]); + } + } + + Queues.Add(module->GetInternal()->GetQueue()); + Modules.Add(module->GetInternal()); + } + + TString FindQueueNameBySessionName(TStringBuf sessionName, bool client) { + TIntrusivePtr<TBusClientSession> clientSession; + TIntrusivePtr<TBusServerSession> serverSession; + TBusSession* session; + if (client) { + clientSession = ClientSessions.FindByName(sessionName); + session = clientSession.Get(); + } else { + serverSession = ServerSessions.FindByName(sessionName); + session = serverSession.Get(); + } + Y_VERIFY(!!session); + return Queues.FindNameByPtr(session->GetQueue()); + } + + struct TRequest { + TImpl* const Outer; + IOutputStream& Os; + const TCgiParameters& CgiParams; + const TOptionalParams& Params; + + TRequest(TImpl* outer, IOutputStream& os, const TCgiParameters& cgiParams, const TOptionalParams& params) + : Outer(outer) + , Os(os) + , CgiParams(cgiParams) + , Params(params) + { + } + + void CrumbsParentLinks() { + for (unsigned i = 0; i < Params.ParentLinks.size(); ++i) { + const TLink& link = Params.ParentLinks[i]; + TTagGuard li("li"); + A(link.Href, link.Title); + } + } + + void Crumb(TStringBuf name, TStringBuf href = "") { + if (!!href) { + TTagGuard li("li"); + A(href, name); + } else { + LiWithClass("active", name); + } + } + + void BreadcrumbRoot() { + TTagGuard ol("ol", "breadcrumb"); + CrumbsParentLinks(); + Crumb("MessageBus"); + } + + void BreadcrumbQueue(TStringBuf queueName) { + TTagGuard ol("ol", "breadcrumb"); + CrumbsParentLinks(); + Crumb("MessageBus", RootHref()); + Crumb(ConcatStrings("queue ", queueName)); + } + + void BreadcrumbSession(TStringBuf sessionName, bool client) { + TString queueName = Outer->FindQueueNameBySessionName(sessionName, client); + TStringBuf whatSession = client ? "client session" : "server session"; + + TTagGuard ol("ol", "breadcrumb"); + CrumbsParentLinks(); + Crumb("MessageBus", RootHref()); + Crumb(ConcatStrings("queue ", queueName), QueueHref(queueName)); + Crumb(ConcatStrings(whatSession, " ", sessionName)); + } + + void ServeSessionsOfQueue(TBusMessageQueuePtr queue, bool includeQueue) { + TVector<TString> clientNames = Outer->ClientSessions.GetNamesForQueue(queue.Get()); + TVector<TString> serverNames = Outer->ServerSessions.GetNamesForQueue(queue.Get()); + TVector<TString> moduleNames = Outer->Modules.GetNamesForQueue(queue.Get()); + + TTagGuard table("table", "table table-condensed table-bordered"); + + { + TTagGuard colgroup("colgroup"); + TagWithClass("col", "col-md-2"); + TagWithClass("col", "col-md-2"); + TagWithClass("col", "col-md-8"); + } + + { + TTagGuard tr("tr"); + Th("What", "span2"); + Th("Name", "span2"); + Th("Status", "span6"); + } + + if (includeQueue) { + TTagGuard tr1("tr"); + Td("queue"); + + { + TTagGuard td("td"); + QueueLink(Outer->Queues.FindNameByPtr(queue)); + } + + { + TTagGuard tr2("td"); + Pre(queue->GetStatusSingleLine()); + } + } + + for (unsigned j = 0; j < clientNames.size(); ++j) { + TTagGuard tr("tr"); + Td("client session"); + + { + TTagGuard td("td"); + ClientSessionLink(clientNames[j]); + } + + { + TTagGuard td("td"); + Pre(Outer->ClientSessions.FindByName(clientNames[j])->GetStatusSingleLine()); + } + } + + for (unsigned j = 0; j < serverNames.size(); ++j) { + TTagGuard tr("tr"); + Td("server session"); + + { + TTagGuard td("td"); + ServerSessionLink(serverNames[j]); + } + + { + TTagGuard td("td"); + Pre(Outer->ServerSessions.FindByName(serverNames[j])->GetStatusSingleLine()); + } + } + + for (unsigned j = 0; j < moduleNames.size(); ++j) { + TTagGuard tr("tr"); + Td("module"); + + { + TTagGuard td("td"); + if (false) { + OldModuleLink(moduleNames[j]); + } else { + // TODO + Text(moduleNames[j]); + } + } + + { + TTagGuard td("td"); + Pre(Outer->Modules.FindByName(moduleNames[j])->GetStatusSingleLine()); + } + } + } + + void ServeQueue(const TString& name) { + TBusMessageQueuePtr queue = Outer->Queues.FindByName(name); + + if (!queue) { + BootstrapError(ConcatStrings("queue not found by name: ", name)); + return; + } + + BreadcrumbQueue(name); + + TDivGuard container("container"); + + H1(ConcatStrings("MessageBus queue ", '"', name, '"')); + + TBusMessageQueueStatus status = queue->GetStatusRecordInternal(); + + Pre(status.PrintToString()); + + ServeSessionsOfQueue(queue, false); + + HnWithSmall(3, "Peak queue size", "(stored for an hour)"); + + { + TDivGuard div; + TDivGuard div2(TAttr("id", "queue-size-graph"), TAttr("style", "height: 300px")); + } + + { + TScriptFunctionGuard script; + + NJsonWriter::TBuf data(NJsonWriter::HEM_ESCAPE_HTML); + NJsonWriter::TBuf ticks(NJsonWriter::HEM_ESCAPE_HTML); + + const TExecutorHistory& history = status.ExecutorStatus.History; + + data.BeginList(); + ticks.BeginList(); + for (unsigned i = 0; i < history.HistoryRecords.size(); ++i) { + ui64 secondOfMinute = (history.FirstHistoryRecordSecond() + i) % 60; + ui64 minuteOfHour = (history.FirstHistoryRecordSecond() + i) / 60 % 60; + + unsigned printEach; + + if (history.HistoryRecords.size() <= 500) { + printEach = 1; + } else if (history.HistoryRecords.size() <= 1000) { + printEach = 2; + } else if (history.HistoryRecords.size() <= 3000) { + printEach = 6; + } else { + printEach = 12; + } + + if (secondOfMinute % printEach != 0) { + continue; + } + + ui32 max = 0; + for (unsigned j = 0; j < printEach; ++j) { + if (i < j) { + continue; + } + max = Max<ui32>(max, history.HistoryRecords[i - j].MaxQueueSize); + } + + data.BeginList(); + data.WriteString(ToString(i)); + data.WriteInt(max); + data.EndList(); + + // TODO: can be done with flot time plugin + if (history.HistoryRecords.size() <= 20) { + ticks.BeginList(); + ticks.WriteInt(i); + ticks.WriteString(ToString(secondOfMinute)); + ticks.EndList(); + } else if (history.HistoryRecords.size() <= 60) { + if (secondOfMinute % 5 == 0) { + ticks.BeginList(); + ticks.WriteInt(i); + ticks.WriteString(ToString(secondOfMinute)); + ticks.EndList(); + } + } else { + bool needTick; + if (history.HistoryRecords.size() <= 3 * 60) { + needTick = secondOfMinute % 15 == 0; + } else if (history.HistoryRecords.size() <= 7 * 60) { + needTick = secondOfMinute % 30 == 0; + } else if (history.HistoryRecords.size() <= 20 * 60) { + needTick = secondOfMinute == 0; + } else { + needTick = secondOfMinute == 0 && minuteOfHour % 5 == 0; + } + if (needTick) { + ticks.BeginList(); + ticks.WriteInt(i); + ticks.WriteString(Sprintf(":%02u:%02u", (unsigned)minuteOfHour, (unsigned)secondOfMinute)); + ticks.EndList(); + } + } + } + ticks.EndList(); + data.EndList(); + + HtmlOutputStream() << " var data = " << data.Str() << ";\n"; + HtmlOutputStream() << " var ticks = " << ticks.Str() << ";\n"; + HtmlOutputStream() << " plotQueueSize('#queue-size-graph', data, ticks);\n"; + } + } + + void ServeSession(TStringBuf name, bool client) { + TIntrusivePtr<TBusClientSession> clientSession; + TIntrusivePtr<TBusServerSession> serverSession; + TBusSession* session; + TStringBuf whatSession; + if (client) { + whatSession = "client session"; + clientSession = Outer->ClientSessions.FindByName(name); + session = clientSession.Get(); + } else { + whatSession = "server session"; + serverSession = Outer->ServerSessions.FindByName(name); + session = serverSession.Get(); + } + if (!session) { + BootstrapError(ConcatStrings(whatSession, " not found by name: ", name)); + return; + } + + TSessionDumpStatus dumpStatus = session->GetStatusRecordInternal(); + + TBusMessageQueuePtr queue = session->GetQueue(); + TString queueName = Outer->Queues.FindNameByPtr(session->GetQueue()); + + BreadcrumbSession(name, client); + + TDivGuard container("container"); + + H1(ConcatStrings("MessageBus ", whatSession, " ", '"', name, '"')); + + TBusMessageQueueStatus queueStatus = queue->GetStatusRecordInternal(); + + { + H3(ConcatStrings("queue ", queueName)); + Pre(queueStatus.PrintToString()); + } + + TSessionDumpStatus status = session->GetStatusRecordInternal(); + + if (status.Shutdown) { + BootstrapError("Session shut down"); + return; + } + + H3("Basic"); + Pre(status.Head); + + if (status.ConnectionStatusSummary.Server) { + H3("Acceptors"); + Pre(status.Acceptors); + } + + H3("Connections"); + Pre(status.ConnectionsSummary); + + { + TDivGuard div; + TTagGuard button("button", + TAttr("type", "button"), + TAttr("class", "btn"), + TAttr("data-toggle", "collapse"), + TAttr("data-target", "#connections")); + Text("Show connection details"); + } + { + TDivGuard div(TAttr("id", "connections"), TAttr("class", "collapse")); + Pre(status.Connections); + } + + H3("TBusSessionConfig"); + Pre(status.Config.PrintToString()); + + if (!client) { + H3("Message process time histogram"); + + const TDurationHistogram& h = + dumpStatus.ConnectionStatusSummary.WriterStatus.Incremental.ProcessDurationHistogram; + + { + TDivGuard div; + TDivGuard div2(TAttr("id", "h"), TAttr("style", "height: 300px")); + } + + { + TScriptFunctionGuard script; + + NJsonWriter::TBuf buf(NJsonWriter::HEM_ESCAPE_HTML); + buf.BeginList(); + for (unsigned i = 0; i < h.Times.size(); ++i) { + TString label = TDurationHistogram::LabelBefore(i); + buf.BeginList(); + buf.WriteString(label); + buf.WriteLongLong(h.Times[i]); + buf.EndList(); + } + buf.EndList(); + + HtmlOutputStream() << " var hist = " << buf.Str() << ";\n"; + HtmlOutputStream() << " plotHist('#h', hist);\n"; + } + } + } + + void ServeDefault() { + if (!Outer->Queues) { + BootstrapError("no queues"); + return; + } + + BreadcrumbRoot(); + + TDivGuard container("container"); + + H1("MessageBus queues"); + + for (unsigned i = 0; i < Outer->Queues.size(); ++i) { + TString queueName = Outer->Queues.Entries[i].first; + TBusMessageQueuePtr queue = Outer->Queues.Entries[i].second; + + HnWithSmall(3, queueName, "(queue)"); + + ServeSessionsOfQueue(queue, true); + } + } + + void WriteQueueSensors(NMonitoring::TDeprecatedJsonWriter& sj, TStringBuf queueName, TBusMessageQueue* queue) { + auto status = queue->GetStatusRecordInternal(); + sj.OpenMetric(); + sj.WriteLabels("mb_queue", queueName, "sensor", "WorkQueueSize"); + sj.WriteValue(status.ExecutorStatus.WorkQueueSize); + sj.CloseMetric(); + } + + void WriteMessageCounterSensors(NMonitoring::TDeprecatedJsonWriter& sj, + TStringBuf labelName, TStringBuf sessionName, bool read, const TMessageCounter& counter) { + TStringBuf readOrWrite = read ? "read" : "write"; + + sj.OpenMetric(); + sj.WriteLabels(labelName, sessionName, "mb_dir", readOrWrite, "sensor", "MessageBytes"); + sj.WriteValue(counter.BytesData); + sj.WriteModeDeriv(); + sj.CloseMetric(); + + sj.OpenMetric(); + sj.WriteLabels(labelName, sessionName, "mb_dir", readOrWrite, "sensor", "MessageCount"); + sj.WriteValue(counter.Count); + sj.WriteModeDeriv(); + sj.CloseMetric(); + } + + void WriteSessionStatus(NMonitoring::TDeprecatedJsonWriter& sj, TStringBuf sessionName, bool client, + TBusSession* session) { + TStringBuf labelName = client ? "mb_client_session" : "mb_server_session"; + + auto status = session->GetStatusRecordInternal(); + + sj.OpenMetric(); + sj.WriteLabels(labelName, sessionName, "sensor", "InFlightCount"); + sj.WriteValue(status.Status.InFlightCount); + sj.CloseMetric(); + + sj.OpenMetric(); + sj.WriteLabels(labelName, sessionName, "sensor", "InFlightSize"); + sj.WriteValue(status.Status.InFlightSize); + sj.CloseMetric(); + + sj.OpenMetric(); + sj.WriteLabels(labelName, sessionName, "sensor", "SendQueueSize"); + sj.WriteValue(status.ConnectionStatusSummary.WriterStatus.SendQueueSize); + sj.CloseMetric(); + + if (client) { + sj.OpenMetric(); + sj.WriteLabels(labelName, sessionName, "sensor", "AckMessagesSize"); + sj.WriteValue(status.ConnectionStatusSummary.WriterStatus.AckMessagesSize); + sj.CloseMetric(); + } + + WriteMessageCounterSensors(sj, labelName, sessionName, false, + status.ConnectionStatusSummary.WriterStatus.Incremental.MessageCounter); + WriteMessageCounterSensors(sj, labelName, sessionName, true, + status.ConnectionStatusSummary.ReaderStatus.Incremental.MessageCounter); + } + + void ServeSolomonJson(const TString& q, const TString& cs, const TString& ss) { + Y_UNUSED(q); + Y_UNUSED(cs); + Y_UNUSED(ss); + bool all = q == "" && cs == "" && ss == ""; + + NMonitoring::TDeprecatedJsonWriter sj(&Os); + + sj.OpenDocument(); + sj.OpenMetrics(); + + for (unsigned i = 0; i < Outer->Queues.size(); ++i) { + TString queueName = Outer->Queues.Entries[i].first; + TBusMessageQueuePtr queue = Outer->Queues.Entries[i].second; + if (all || q == queueName) { + WriteQueueSensors(sj, queueName, &*queue); + } + + TVector<TString> clientNames = Outer->ClientSessions.GetNamesForQueue(queue.Get()); + TVector<TString> serverNames = Outer->ServerSessions.GetNamesForQueue(queue.Get()); + TVector<TString> moduleNames = Outer->Modules.GetNamesForQueue(queue.Get()); + for (auto& sessionName : clientNames) { + if (all || cs == sessionName) { + auto session = Outer->ClientSessions.FindByName(sessionName); + WriteSessionStatus(sj, sessionName, true, &*session); + } + } + + for (auto& sessionName : serverNames) { + if (all || ss == sessionName) { + auto session = Outer->ServerSessions.FindByName(sessionName); + WriteSessionStatus(sj, sessionName, false, &*session); + } + } + } + + sj.CloseMetrics(); + sj.CloseDocument(); + } + + void ServeStatic(IOutputStream& os, TStringBuf path) { + if (path.EndsWith(".js")) { + os << HTTP_OK_JS; + } else if (path.EndsWith(".png")) { + os << HTTP_OK_PNG; + } else { + os << HTTP_OK_BIN; + } + TBlob blob = Singleton<TWwwStaticLoader>()->ObjectBlobByKey(TString("/") + TString(path)); + os.Write(blob.Data(), blob.Size()); + } + + void HeaderJsCss() { + LinkStylesheet("//yandex.st/bootstrap/3.0.2/css/bootstrap.css"); + LinkFavicon("?file=bus-ico.png"); + ScriptHref("//yandex.st/jquery/2.0.3/jquery.js"); + ScriptHref("//yandex.st/bootstrap/3.0.2/js/bootstrap.js"); + ScriptHref("//cdnjs.cloudflare.com/ajax/libs/flot/0.8.1/jquery.flot.min.js"); + ScriptHref("//cdnjs.cloudflare.com/ajax/libs/flot/0.8.1/jquery.flot.categories.min.js"); + ScriptHref("?file=messagebus.js"); + } + + void Serve() { + THtmlOutputStreamPushPop pp(&Os); + + TCgiParameters::const_iterator file = CgiParams.Find("file"); + if (file != CgiParams.end()) { + ServeStatic(Os, file->second); + return; + } + + bool solomonJson = false; + TCgiParameters::const_iterator fmt = CgiParams.Find("fmt"); + if (fmt != CgiParams.end()) { + if (fmt->second == "solomon-json") { + solomonJson = true; + } + } + + TCgiParameters::const_iterator cs = CgiParams.Find("cs"); + TCgiParameters::const_iterator ss = CgiParams.Find("ss"); + TCgiParameters::const_iterator q = CgiParams.Find("q"); + + if (solomonJson) { + Os << HTTP_OK_JSON; + + TString qp = q != CgiParams.end() ? q->first : ""; + TString csp = cs != CgiParams.end() ? cs->first : ""; + TString ssp = ss != CgiParams.end() ? ss->first : ""; + ServeSolomonJson(qp, csp, ssp); + } else { + Os << HTTP_OK_HTML; + + Doctype(); + + TTagGuard html("html"); + { + TTagGuard head("head"); + + HeaderJsCss(); + // ✉ 🚌 + Title(TChars("MessageBus", false)); + } + + TTagGuard body("body"); + + if (cs != CgiParams.end()) { + ServeSession(cs->second, true); + } else if (ss != CgiParams.end()) { + ServeSession(ss->second, false); + } else if (q != CgiParams.end()) { + ServeQueue(q->second); + } else { + ServeDefault(); + } + } + } + }; + + void ServeHttp(IOutputStream& os, const TCgiParameters& queryArgs, const TBusWww::TOptionalParams& params) { + TGuard<TMutex> g(Mutex); + + TRequest request(this, os, queryArgs, params); + + request.Serve(); + } +}; + +NBus::TBusWww::TBusWww() + : Impl(new TImpl) +{ +} + +NBus::TBusWww::~TBusWww() { +} + +void NBus::TBusWww::RegisterClientSession(TBusClientSessionPtr s) { + Impl->RegisterClientSession(s); +} + +void TBusWww::RegisterServerSession(TBusServerSessionPtr s) { + Impl->RegisterServerSession(s); +} + +void TBusWww::RegisterQueue(TBusMessageQueuePtr q) { + Impl->RegisterQueue(q); +} + +void TBusWww::RegisterModule(TBusModule* module) { + Impl->RegisterModule(module); +} + +void TBusWww::ServeHttp(IOutputStream& httpOutputStream, + const TCgiParameters& queryArgs, + const TBusWww::TOptionalParams& params) { + Impl->ServeHttp(httpOutputStream, queryArgs, params); +} + +struct TBusWwwHttpServer::TImpl: public THttpServer::ICallBack { + TIntrusivePtr<TBusWww> Www; + THttpServer HttpServer; + + static THttpServer::TOptions MakeHttpServerOptions(unsigned port) { + Y_VERIFY(port > 0); + THttpServer::TOptions r; + r.Port = port; + return r; + } + + TImpl(TIntrusivePtr<TBusWww> www, unsigned port) + : Www(www) + , HttpServer(this, MakeHttpServerOptions(port)) + { + HttpServer.Start(); + } + + struct TClientRequestImpl: public TClientRequest { + TBusWwwHttpServer::TImpl* const Outer; + + TClientRequestImpl(TBusWwwHttpServer::TImpl* outer) + : Outer(outer) + { + } + + bool Reply(void*) override { + Outer->ServeRequest(Input(), Output()); + return true; + } + }; + + TString MakeSimpleResponse(unsigned code, TString text, TString content = "") { + if (!content) { + TStringStream contentSs; + contentSs << code << " " << text; + content = contentSs.Str(); + } + TStringStream ss; + ss << "HTTP/1.1 " + << code << " " << text << "\r\nConnection: Close\r\n\r\n" + << content; + return ss.Str(); + } + + void ServeRequest(THttpInput& input, THttpOutput& output) { + TCgiParameters cgiParams; + try { + THttpRequestHeader header; + THttpHeaderParser parser; + parser.Init(&header); + if (parser.Execute(input.FirstLine()) < 0) { + HtmlOutputStream() << MakeSimpleResponse(400, "Bad request"); + return; + } + THttpURL url; + if (url.Parse(header.GetUrl()) != THttpURL::ParsedOK) { + HtmlOutputStream() << MakeSimpleResponse(400, "Invalid url"); + return; + } + cgiParams.Scan(url.Get(THttpURL::FieldQuery)); + + TBusWww::TOptionalParams params; + //params.ParentLinks.emplace_back(); + //params.ParentLinks.back().Title = "temp"; + //params.ParentLinks.back().Href = "http://wiki.yandex-team.ru/"; + + Www->ServeHttp(output, cgiParams, params); + } catch (...) { + output << MakeSimpleResponse(500, "Exception", + TString() + "Exception: " + CurrentExceptionMessage()); + } + } + + TClientRequest* CreateClient() override { + return new TClientRequestImpl(this); + } + + ~TImpl() override { + HttpServer.Stop(); + } +}; + +NBus::TBusWwwHttpServer::TBusWwwHttpServer(TIntrusivePtr<TBusWww> www, unsigned port) + : Impl(new TImpl(www, port)) +{ +} + +NBus::TBusWwwHttpServer::~TBusWwwHttpServer() { +} diff --git a/library/cpp/messagebus/www/www.h b/library/cpp/messagebus/www/www.h new file mode 100644 index 0000000000..6cd652b477 --- /dev/null +++ b/library/cpp/messagebus/www/www.h @@ -0,0 +1,45 @@ +#pragma once + +#include <library/cpp/messagebus/ybus.h> +#include <library/cpp/messagebus/oldmodule/module.h> + +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <library/cpp/cgiparam/cgiparam.h> + +namespace NBus { + class TBusWww: public TAtomicRefCount<TBusWww> { + public: + struct TLink { + TString Title; + TString Href; + }; + + struct TOptionalParams { + TVector<TLink> ParentLinks; + }; + + TBusWww(); + ~TBusWww(); + + void RegisterClientSession(TBusClientSessionPtr); + void RegisterServerSession(TBusServerSessionPtr); + void RegisterQueue(TBusMessageQueuePtr); + void RegisterModule(TBusModule*); + + void ServeHttp(IOutputStream& httpOutputStream, const TCgiParameters& queryArgs, const TOptionalParams& params = TOptionalParams()); + + struct TImpl; + THolder<TImpl> Impl; + }; + + class TBusWwwHttpServer { + public: + TBusWwwHttpServer(TIntrusivePtr<TBusWww> www, unsigned port); + ~TBusWwwHttpServer(); + + struct TImpl; + THolder<TImpl> Impl; + }; + +} diff --git a/library/cpp/messagebus/www/ya.make b/library/cpp/messagebus/www/ya.make new file mode 100644 index 0000000000..972390cea3 --- /dev/null +++ b/library/cpp/messagebus/www/ya.make @@ -0,0 +1,29 @@ +LIBRARY() + +OWNER(g:messagebus) + +SRCS( + html_output.cpp + www.cpp +) + +ARCHIVE( + NAME www_static.inc + messagebus.js + bus-ico.png +) + +PEERDIR( + library/cpp/archive + library/cpp/cgiparam + library/cpp/html/pcdata + library/cpp/http/fetch + library/cpp/http/server + library/cpp/json/writer + library/cpp/messagebus + library/cpp/messagebus/oldmodule + library/cpp/monlib/deprecated/json + library/cpp/uri +) + +END() diff --git a/library/cpp/messagebus/ya.make b/library/cpp/messagebus/ya.make new file mode 100644 index 0000000000..e13cf06dea --- /dev/null +++ b/library/cpp/messagebus/ya.make @@ -0,0 +1,68 @@ +LIBRARY() + +OWNER(g:messagebus) + +IF (SANITIZER_TYPE == "undefined") + NO_SANITIZE() +ENDIF() + +SRCS( + acceptor.cpp + acceptor_status.cpp + connection.cpp + coreconn.cpp + duration_histogram.cpp + event_loop.cpp + futex_like.cpp + handler.cpp + key_value_printer.cpp + local_flags.cpp + locator.cpp + mb_lwtrace.cpp + message.cpp + message_counter.cpp + message_status.cpp + message_status_counter.cpp + messqueue.cpp + misc/atomic_box.h + misc/granup.h + misc/test_sync.h + misc/tokenquota.h + misc/weak_ptr.h + network.cpp + queue_config.cpp + remote_client_connection.cpp + remote_client_session.cpp + remote_client_session_semaphore.cpp + remote_connection.cpp + remote_connection_status.cpp + remote_server_connection.cpp + remote_server_session.cpp + remote_server_session_semaphore.cpp + session.cpp + session_impl.cpp + session_job_count.cpp + shutdown_state.cpp + socket_addr.cpp + storage.cpp + synchandler.cpp + use_after_free_checker.cpp + use_count_checker.cpp + ybus.h +) + +PEERDIR( + contrib/libs/sparsehash + library/cpp/codecs + library/cpp/deprecated/enum_codegen + library/cpp/getopt/small + library/cpp/lwtrace + library/cpp/messagebus/actor + library/cpp/messagebus/config + library/cpp/messagebus/monitoring + library/cpp/messagebus/scheduler + library/cpp/string_utils/indent_text + library/cpp/threading/future +) + +END() diff --git a/library/cpp/messagebus/ybus.h b/library/cpp/messagebus/ybus.h new file mode 100644 index 0000000000..de21ad8521 --- /dev/null +++ b/library/cpp/messagebus/ybus.h @@ -0,0 +1,205 @@ +#pragma once + +/// Asynchronous Messaging Library implements framework for sending and +/// receiving messages between loosely connected processes. + +#include "coreconn.h" +#include "defs.h" +#include "handler.h" +#include "handler_impl.h" +#include "local_flags.h" +#include "locator.h" +#include "message.h" +#include "message_status.h" +#include "network.h" +#include "queue_config.h" +#include "remote_connection_status.h" +#include "session.h" +#include "session_config.h" +#include "socket_addr.h" + +#include <library/cpp/messagebus/actor/executor.h> +#include <library/cpp/messagebus/scheduler/scheduler.h> + +#include <library/cpp/codecs/codecs.h> + +#include <util/generic/array_ref.h> +#include <util/generic/buffer.h> +#include <util/generic/noncopyable.h> +#include <util/generic/ptr.h> +#include <util/stream/input.h> +#include <util/system/atomic.h> +#include <util/system/condvar.h> +#include <util/system/type_name.h> +#include <util/system/event.h> +#include <util/system/mutex.h> + +namespace NBus { + //////////////////////////////////////////////////////// + /// \brief Common structure to store address information + + int CompareByHost(const IRemoteAddr& l, const IRemoteAddr& r) noexcept; + bool operator<(const TNetAddr& a1, const TNetAddr& a2); // compare by addresses + + ///////////////////////////////////////////////////////////////////////// + /// \brief Handles routing and data encoding to/from wire + + /// Protocol is stateless threadsafe singleton object that + /// encapsulates relationship between a message (TBusMessage) object + /// and destination server. Protocol object is reponsible for serializing in-memory + /// message and reply into the wire, retuning name of the service and resource + /// distribution key for given protocol. + + /// Protocol object should transparently handle messages and replies. + /// This is interface only class, actuall instances of the protocols + /// should be created using templates inhereted from this base class. + class TBusProtocol { + private: + TString ServiceName; + int ServicePort; + + public: + TBusProtocol(TBusService name = "UNKNOWN", int port = 0) + : ServiceName(name) + , ServicePort(port) + { + } + + /// returns service type for this protocol and message + TBusService GetService() const { + return ServiceName.data(); + } + + /// returns port number for destination session to open socket + int GetPort() const { + return ServicePort; + } + + virtual ~TBusProtocol() { + } + + /// \brief serialized protocol specific data into TBusData + /// \note buffer passed to the function (data) is not empty, use append functions + virtual void Serialize(const TBusMessage* mess, TBuffer& data) = 0; + + /// deserialized TBusData into new instance of the message + virtual TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) = 0; + + /// returns key for messages of this protocol + virtual TBusKey GetKey(const TBusMessage*) { + return YBUS_KEYMIN; + } + + /// default implementation of routing policy to allow overrides + virtual EMessageStatus GetDestination(const TBusClientSession* session, TBusMessage* mess, TBusLocator* locator, TNetAddr* addr); + + /// codec for transport level compression + virtual NCodecs::TCodecPtr GetTransportCodec(void) const { + return NCodecs::ICodec::GetInstance("snappy"); + } + }; + + class TBusSyncSourceSession: public TAtomicRefCount<TBusSyncSourceSession> { + friend class TBusMessageQueue; + + public: + TBusSyncSourceSession(TIntrusivePtr< ::NBus::NPrivate::TBusSyncSourceSessionImpl> session); + ~TBusSyncSourceSession(); + + void Shutdown(); + + TBusMessage* SendSyncMessage(TBusMessage* pMessage, EMessageStatus& status, const TNetAddr* addr = nullptr); + + int RegisterService(const char* hostname, TBusKey start = YBUS_KEYMIN, TBusKey end = YBUS_KEYMAX, EIpVersion ipVersion = EIP_VERSION_4); + + int GetInFlight(); + + const TBusProtocol* GetProto() const; + + const TBusClientSession* GetBusClientSessionWorkaroundDoNotUse() const; // It's for TLoadBalancedProtocol::GetDestination() function that really needs TBusClientSession* unlike all other protocols. Look at review 32425 (http://rb.yandex-team.ru/arc/r/32425/) for more information. + private: + TIntrusivePtr< ::NBus::NPrivate::TBusSyncSourceSessionImpl> Session; + }; + + using TBusSyncClientSessionPtr = TIntrusivePtr<TBusSyncSourceSession>; + + /////////////////////////////////////////////////////////////////// + /// \brief Main message queue object, need one per application + class TBusMessageQueue: public TAtomicRefCount<TBusMessageQueue> { + /// allow mesage queue to be created only via factory + friend TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, NActor::TExecutorPtr executor, TBusLocator* locator, const char* name); + friend class ::NBus::NPrivate::TRemoteConnection; + friend struct ::NBus::NPrivate::TBusSessionImpl; + friend class ::NBus::NPrivate::TAcceptor; + friend struct ::NBus::TBusServerSession; + + private: + const TBusQueueConfig Config; + TMutex Lock; + TList<TIntrusivePtr< ::NBus::NPrivate::TBusSessionImpl>> Sessions; + TSimpleIntrusivePtr<TBusLocator> Locator; + NPrivate::TScheduler Scheduler; + + ::NActor::TExecutorPtr WorkQueue; + + TAtomic Running; + TSystemEvent ShutdownComplete; + + private: + /// constructor is protected, used NBus::CreateMessageQueue() to create a instance + TBusMessageQueue(const TBusQueueConfig& config, NActor::TExecutorPtr executor, TBusLocator* locator, const char* name); + + public: + TString GetNameInternal() const; + + ~TBusMessageQueue(); + + void Stop(); + bool IsRunning(); + + public: + void EnqueueWork(TArrayRef< ::NActor::IWorkItem* const> w) { + WorkQueue->EnqueueWork(w); + } + + ::NActor::TExecutor* GetExecutor() { + return WorkQueue.Get(); + } + + TString GetStatus(ui16 flags = YBUS_STATUS_CONNS) const; + // without sessions + NPrivate::TBusMessageQueueStatus GetStatusRecordInternal() const; + TString GetStatusSelf() const; + TString GetStatusSingleLine() const; + + TBusLocator* GetLocator() const { + return Locator.Get(); + } + + TBusClientSessionPtr CreateSource(TBusProtocol* proto, IBusClientHandler* handler, const TBusClientSessionConfig& config, const TString& name = ""); + TBusSyncClientSessionPtr CreateSyncSource(TBusProtocol* proto, const TBusClientSessionConfig& config, bool needReply = true, const TString& name = ""); + TBusServerSessionPtr CreateDestination(TBusProtocol* proto, IBusServerHandler* hander, const TBusServerSessionConfig& config, const TString& name = ""); + TBusServerSessionPtr CreateDestination(TBusProtocol* proto, IBusServerHandler* hander, const TBusServerSessionConfig& config, const TVector<TBindResult>& bindTo, const TString& name = ""); + + private: + void Destroy(TBusSession* session); + void Destroy(TBusSyncClientSessionPtr session); + + public: + void Schedule(NPrivate::IScheduleItemAutoPtr i); + + private: + void DestroyAllSessions(); + void Add(TIntrusivePtr< ::NBus::NPrivate::TBusSessionImpl> session); + void Remove(TBusSession* session); + }; + + ///////////////////////////////////////////////////////////////// + /// Factory methods to construct message queue + TBusMessageQueuePtr CreateMessageQueue(const char* name = ""); + TBusMessageQueuePtr CreateMessageQueue(NActor::TExecutorPtr executor, const char* name = ""); + TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, const char* name = ""); + TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, TBusLocator* locator, const char* name = ""); + TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, NActor::TExecutorPtr executor, TBusLocator* locator, const char* name = ""); + +} |