aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/messagebus
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/messagebus
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/messagebus')
-rw-r--r--library/cpp/messagebus/acceptor.cpp127
-rw-r--r--library/cpp/messagebus/acceptor.h60
-rw-r--r--library/cpp/messagebus/acceptor_status.cpp68
-rw-r--r--library/cpp/messagebus/acceptor_status.h35
-rw-r--r--library/cpp/messagebus/actor/actor.h144
-rw-r--r--library/cpp/messagebus/actor/actor_ut.cpp157
-rw-r--r--library/cpp/messagebus/actor/executor.cpp338
-rw-r--r--library/cpp/messagebus/actor/executor.h105
-rw-r--r--library/cpp/messagebus/actor/queue_for_actor.h74
-rw-r--r--library/cpp/messagebus/actor/queue_in_actor.h80
-rw-r--r--library/cpp/messagebus/actor/ring_buffer.h135
-rw-r--r--library/cpp/messagebus/actor/ring_buffer_ut.cpp60
-rw-r--r--library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h91
-rw-r--r--library/cpp/messagebus/actor/tasks.h48
-rw-r--r--library/cpp/messagebus/actor/tasks_ut.cpp37
-rw-r--r--library/cpp/messagebus/actor/temp_tls_vector.h40
-rw-r--r--library/cpp/messagebus/actor/thread_extra.cpp30
-rw-r--r--library/cpp/messagebus/actor/thread_extra.h41
-rw-r--r--library/cpp/messagebus/actor/what_thread_does.cpp22
-rw-r--r--library/cpp/messagebus/actor/what_thread_does.h28
-rw-r--r--library/cpp/messagebus/actor/what_thread_does_guard.h40
-rw-r--r--library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp13
-rw-r--r--library/cpp/messagebus/actor/ya.make11
-rw-r--r--library/cpp/messagebus/all.lwt8
-rw-r--r--library/cpp/messagebus/all/ya.make10
-rw-r--r--library/cpp/messagebus/async_result.h54
-rw-r--r--library/cpp/messagebus/async_result_ut.cpp37
-rw-r--r--library/cpp/messagebus/base.h11
-rw-r--r--library/cpp/messagebus/cc_semaphore.h36
-rw-r--r--library/cpp/messagebus/cc_semaphore_ut.cpp45
-rw-r--r--library/cpp/messagebus/codegen.h4
-rw-r--r--library/cpp/messagebus/config/codegen.h10
-rw-r--r--library/cpp/messagebus/config/defs.h82
-rw-r--r--library/cpp/messagebus/config/netaddr.cpp183
-rw-r--r--library/cpp/messagebus/config/netaddr.h86
-rw-r--r--library/cpp/messagebus/config/session_config.cpp157
-rw-r--r--library/cpp/messagebus/config/session_config.h65
-rw-r--r--library/cpp/messagebus/config/ya.make15
-rw-r--r--library/cpp/messagebus/connection.cpp16
-rw-r--r--library/cpp/messagebus/connection.h61
-rw-r--r--library/cpp/messagebus/coreconn.cpp30
-rw-r--r--library/cpp/messagebus/coreconn.h67
-rw-r--r--library/cpp/messagebus/coreconn_ut.cpp25
-rw-r--r--library/cpp/messagebus/debug_receiver/debug_receiver.cpp42
-rw-r--r--library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp20
-rw-r--r--library/cpp/messagebus/debug_receiver/debug_receiver_handler.h10
-rw-r--r--library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp20
-rw-r--r--library/cpp/messagebus/debug_receiver/debug_receiver_proto.h27
-rw-r--r--library/cpp/messagebus/debug_receiver/ya.make17
-rw-r--r--library/cpp/messagebus/defs.h4
-rw-r--r--library/cpp/messagebus/dummy_debugger.h9
-rw-r--r--library/cpp/messagebus/duration_histogram.cpp74
-rw-r--r--library/cpp/messagebus/duration_histogram.h45
-rw-r--r--library/cpp/messagebus/duration_histogram_ut.cpp38
-rw-r--r--library/cpp/messagebus/event_loop.cpp370
-rw-r--r--library/cpp/messagebus/event_loop.h72
-rw-r--r--library/cpp/messagebus/extra_ref.h36
-rw-r--r--library/cpp/messagebus/futex_like.cpp55
-rw-r--r--library/cpp/messagebus/futex_like.h86
-rw-r--r--library/cpp/messagebus/handler.cpp36
-rw-r--r--library/cpp/messagebus/handler.h135
-rw-r--r--library/cpp/messagebus/handler_impl.h23
-rw-r--r--library/cpp/messagebus/hash.h19
-rw-r--r--library/cpp/messagebus/key_value_printer.cpp46
-rw-r--r--library/cpp/messagebus/key_value_printer.h28
-rw-r--r--library/cpp/messagebus/latch.h53
-rw-r--r--library/cpp/messagebus/latch_ut.cpp20
-rw-r--r--library/cpp/messagebus/left_right_buffer.h78
-rw-r--r--library/cpp/messagebus/lfqueue_batch.h36
-rw-r--r--library/cpp/messagebus/lfqueue_batch_ut.cpp56
-rw-r--r--library/cpp/messagebus/local_flags.cpp32
-rw-r--r--library/cpp/messagebus/local_flags.h26
-rw-r--r--library/cpp/messagebus/local_flags_ut.cpp18
-rw-r--r--library/cpp/messagebus/local_tasks.h23
-rw-r--r--library/cpp/messagebus/locator.cpp427
-rw-r--r--library/cpp/messagebus/locator.h93
-rw-r--r--library/cpp/messagebus/mb_lwtrace.cpp12
-rw-r--r--library/cpp/messagebus/mb_lwtrace.h19
-rw-r--r--library/cpp/messagebus/memory.h42
-rw-r--r--library/cpp/messagebus/memory_ut.cpp13
-rw-r--r--library/cpp/messagebus/message.cpp198
-rw-r--r--library/cpp/messagebus/message.h272
-rw-r--r--library/cpp/messagebus/message_counter.cpp46
-rw-r--r--library/cpp/messagebus/message_counter.h36
-rw-r--r--library/cpp/messagebus/message_ptr_and_header.h36
-rw-r--r--library/cpp/messagebus/message_status.cpp13
-rw-r--r--library/cpp/messagebus/message_status.h57
-rw-r--r--library/cpp/messagebus/message_status_counter.cpp71
-rw-r--r--library/cpp/messagebus/message_status_counter.h36
-rw-r--r--library/cpp/messagebus/message_status_counter_ut.cpp23
-rw-r--r--library/cpp/messagebus/messqueue.cpp198
-rw-r--r--library/cpp/messagebus/misc/atomic_box.h34
-rw-r--r--library/cpp/messagebus/misc/granup.h50
-rw-r--r--library/cpp/messagebus/misc/test_sync.h75
-rw-r--r--library/cpp/messagebus/misc/tokenquota.h83
-rw-r--r--library/cpp/messagebus/misc/weak_ptr.h99
-rw-r--r--library/cpp/messagebus/misc/weak_ptr_ut.cpp46
-rw-r--r--library/cpp/messagebus/monitoring/mon_proto.proto55
-rw-r--r--library/cpp/messagebus/monitoring/ya.make15
-rw-r--r--library/cpp/messagebus/moved.h39
-rw-r--r--library/cpp/messagebus/moved_ut.cpp22
-rw-r--r--library/cpp/messagebus/netaddr.h4
-rw-r--r--library/cpp/messagebus/netaddr_ut.cpp21
-rw-r--r--library/cpp/messagebus/network.cpp156
-rw-r--r--library/cpp/messagebus/network.h28
-rw-r--r--library/cpp/messagebus/network_ut.cpp65
-rw-r--r--library/cpp/messagebus/nondestroying_holder.h39
-rw-r--r--library/cpp/messagebus/nondestroying_holder_ut.cpp12
-rw-r--r--library/cpp/messagebus/oldmodule/module.cpp881
-rw-r--r--library/cpp/messagebus/oldmodule/module.h410
-rw-r--r--library/cpp/messagebus/oldmodule/startsession.cpp65
-rw-r--r--library/cpp/messagebus/oldmodule/startsession.h34
-rw-r--r--library/cpp/messagebus/oldmodule/ya.make15
-rw-r--r--library/cpp/messagebus/protobuf/ya.make15
-rw-r--r--library/cpp/messagebus/protobuf/ybusbuf.cpp88
-rw-r--r--library/cpp/messagebus/protobuf/ybusbuf.h233
-rw-r--r--library/cpp/messagebus/queue_config.cpp22
-rw-r--r--library/cpp/messagebus/queue_config.h19
-rw-r--r--library/cpp/messagebus/rain_check/core/coro.cpp60
-rw-r--r--library/cpp/messagebus/rain_check/core/coro.h58
-rw-r--r--library/cpp/messagebus/rain_check/core/coro_stack.cpp41
-rw-r--r--library/cpp/messagebus/rain_check/core/coro_stack.h54
-rw-r--r--library/cpp/messagebus/rain_check/core/coro_ut.cpp106
-rw-r--r--library/cpp/messagebus/rain_check/core/env.cpp3
-rw-r--r--library/cpp/messagebus/rain_check/core/env.h47
-rw-r--r--library/cpp/messagebus/rain_check/core/fwd.h18
-rw-r--r--library/cpp/messagebus/rain_check/core/rain_check.cpp1
-rw-r--r--library/cpp/messagebus/rain_check/core/rain_check.h8
-rw-r--r--library/cpp/messagebus/rain_check/core/simple.cpp18
-rw-r--r--library/cpp/messagebus/rain_check/core/simple.h62
-rw-r--r--library/cpp/messagebus/rain_check/core/simple_ut.cpp59
-rw-r--r--library/cpp/messagebus/rain_check/core/sleep.cpp47
-rw-r--r--library/cpp/messagebus/rain_check/core/sleep.h24
-rw-r--r--library/cpp/messagebus/rain_check/core/sleep_ut.cpp46
-rw-r--r--library/cpp/messagebus/rain_check/core/spawn.cpp5
-rw-r--r--library/cpp/messagebus/rain_check/core/spawn.h50
-rw-r--r--library/cpp/messagebus/rain_check/core/spawn_ut.cpp145
-rw-r--r--library/cpp/messagebus/rain_check/core/task.cpp216
-rw-r--r--library/cpp/messagebus/rain_check/core/task.h184
-rw-r--r--library/cpp/messagebus/rain_check/core/track.cpp66
-rw-r--r--library/cpp/messagebus/rain_check/core/track.h97
-rw-r--r--library/cpp/messagebus/rain_check/core/track_ut.cpp45
-rw-r--r--library/cpp/messagebus/rain_check/core/ya.make25
-rw-r--r--library/cpp/messagebus/rain_check/http/client.cpp154
-rw-r--r--library/cpp/messagebus/rain_check/http/client.h78
-rw-r--r--library/cpp/messagebus/rain_check/http/client_ut.cpp205
-rw-r--r--library/cpp/messagebus/rain_check/http/http_code_extractor.cpp39
-rw-r--r--library/cpp/messagebus/rain_check/http/http_code_extractor.h16
-rw-r--r--library/cpp/messagebus/rain_check/http/ya.make17
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp98
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/messagebus_client.h67
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp146
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp17
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/messagebus_server.h46
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp51
-rw-r--r--library/cpp/messagebus/rain_check/messagebus/ya.make15
-rw-r--r--library/cpp/messagebus/rain_check/test/TestRainCheck.py8
-rw-r--r--library/cpp/messagebus/rain_check/test/helper/misc.cpp27
-rw-r--r--library/cpp/messagebus/rain_check/test/helper/misc.h57
-rw-r--r--library/cpp/messagebus/rain_check/test/helper/ya.make13
-rw-r--r--library/cpp/messagebus/rain_check/test/perftest/perftest.cpp154
-rw-r--r--library/cpp/messagebus/rain_check/test/perftest/ya.make14
-rw-r--r--library/cpp/messagebus/rain_check/test/ut/test.h13
-rw-r--r--library/cpp/messagebus/rain_check/test/ut/ya.make24
-rw-r--r--library/cpp/messagebus/rain_check/test/ya.make6
-rw-r--r--library/cpp/messagebus/rain_check/ya.make8
-rw-r--r--library/cpp/messagebus/ref_counted.h6
-rw-r--r--library/cpp/messagebus/remote_client_connection.cpp343
-rw-r--r--library/cpp/messagebus/remote_client_connection.h65
-rw-r--r--library/cpp/messagebus/remote_client_session.cpp127
-rw-r--r--library/cpp/messagebus/remote_client_session.h59
-rw-r--r--library/cpp/messagebus/remote_client_session_semaphore.cpp67
-rw-r--r--library/cpp/messagebus/remote_client_session_semaphore.h42
-rw-r--r--library/cpp/messagebus/remote_connection.cpp974
-rw-r--r--library/cpp/messagebus/remote_connection.h294
-rw-r--r--library/cpp/messagebus/remote_connection_status.cpp265
-rw-r--r--library/cpp/messagebus/remote_connection_status.h214
-rw-r--r--library/cpp/messagebus/remote_server_connection.cpp73
-rw-r--r--library/cpp/messagebus/remote_server_connection.h32
-rw-r--r--library/cpp/messagebus/remote_server_session.cpp206
-rw-r--r--library/cpp/messagebus/remote_server_session.h54
-rw-r--r--library/cpp/messagebus/remote_server_session_semaphore.cpp59
-rw-r--r--library/cpp/messagebus/remote_server_session_semaphore.h42
-rw-r--r--library/cpp/messagebus/scheduler/scheduler.cpp119
-rw-r--r--library/cpp/messagebus/scheduler/scheduler.h68
-rw-r--r--library/cpp/messagebus/scheduler/scheduler_ut.cpp36
-rw-r--r--library/cpp/messagebus/scheduler/ya.make13
-rw-r--r--library/cpp/messagebus/scheduler_actor.h85
-rw-r--r--library/cpp/messagebus/scheduler_actor_ut.cpp48
-rw-r--r--library/cpp/messagebus/session.cpp130
-rw-r--r--library/cpp/messagebus/session.h225
-rw-r--r--library/cpp/messagebus/session_config.h4
-rw-r--r--library/cpp/messagebus/session_impl.cpp650
-rw-r--r--library/cpp/messagebus/session_impl.h259
-rw-r--r--library/cpp/messagebus/session_job_count.cpp22
-rw-r--r--library/cpp/messagebus/session_job_count.h39
-rw-r--r--library/cpp/messagebus/shutdown_state.cpp20
-rw-r--r--library/cpp/messagebus/shutdown_state.h22
-rw-r--r--library/cpp/messagebus/socket_addr.cpp79
-rw-r--r--library/cpp/messagebus/socket_addr.h113
-rw-r--r--library/cpp/messagebus/socket_addr_ut.cpp15
-rw-r--r--library/cpp/messagebus/storage.cpp161
-rw-r--r--library/cpp/messagebus/storage.h94
-rw-r--r--library/cpp/messagebus/synchandler.cpp198
-rw-r--r--library/cpp/messagebus/test/TestMessageBus.py8
-rw-r--r--library/cpp/messagebus/test/example/client/client.cpp81
-rw-r--r--library/cpp/messagebus/test/example/client/ya.make13
-rw-r--r--library/cpp/messagebus/test/example/common/messages.proto15
-rw-r--r--library/cpp/messagebus/test/example/common/proto.cpp12
-rw-r--r--library/cpp/messagebus/test/example/common/proto.h17
-rw-r--r--library/cpp/messagebus/test/example/common/ya.make15
-rw-r--r--library/cpp/messagebus/test/example/server/server.cpp58
-rw-r--r--library/cpp/messagebus/test/example/server/ya.make13
-rw-r--r--library/cpp/messagebus/test/example/ya.make7
-rw-r--r--library/cpp/messagebus/test/helper/alloc_counter.h21
-rw-r--r--library/cpp/messagebus/test/helper/example.cpp281
-rw-r--r--library/cpp/messagebus/test/helper/example.h132
-rw-r--r--library/cpp/messagebus/test/helper/example_module.cpp43
-rw-r--r--library/cpp/messagebus/test/helper/example_module.h37
-rw-r--r--library/cpp/messagebus/test/helper/fixed_port.cpp10
-rw-r--r--library/cpp/messagebus/test/helper/fixed_port.h11
-rw-r--r--library/cpp/messagebus/test/helper/hanging_server.cpp13
-rw-r--r--library/cpp/messagebus/test/helper/hanging_server.h16
-rw-r--r--library/cpp/messagebus/test/helper/message_handler_error.cpp26
-rw-r--r--library/cpp/messagebus/test/helper/message_handler_error.h19
-rw-r--r--library/cpp/messagebus/test/helper/object_count_check.h74
-rw-r--r--library/cpp/messagebus/test/helper/wait_for.h14
-rw-r--r--library/cpp/messagebus/test/helper/ya.make17
-rw-r--r--library/cpp/messagebus/test/perftest/messages.proto7
-rw-r--r--library/cpp/messagebus/test/perftest/perftest.cpp713
-rw-r--r--library/cpp/messagebus/test/perftest/simple_proto.cpp22
-rw-r--r--library/cpp/messagebus/test/perftest/simple_proto.h29
-rw-r--r--library/cpp/messagebus/test/perftest/stackcollect.diff13
-rw-r--r--library/cpp/messagebus/test/perftest/ya.make24
-rw-r--r--library/cpp/messagebus/test/ut/count_down_latch.h30
-rw-r--r--library/cpp/messagebus/test/ut/locator_uniq_ut.cpp40
-rw-r--r--library/cpp/messagebus/test/ut/messagebus_ut.cpp1151
-rw-r--r--library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp143
-rw-r--r--library/cpp/messagebus/test/ut/module_client_ut.cpp368
-rw-r--r--library/cpp/messagebus/test/ut/module_server_ut.cpp119
-rw-r--r--library/cpp/messagebus/test/ut/moduletest.h221
-rw-r--r--library/cpp/messagebus/test/ut/one_way_ut.cpp255
-rw-r--r--library/cpp/messagebus/test/ut/starter_ut.cpp140
-rw-r--r--library/cpp/messagebus/test/ut/sync_client_ut.cpp69
-rw-r--r--library/cpp/messagebus/test/ut/ya.make56
-rw-r--r--library/cpp/messagebus/test/ya.make7
-rw-r--r--library/cpp/messagebus/test_utils.h12
-rw-r--r--library/cpp/messagebus/text_utils.h3
-rw-r--r--library/cpp/messagebus/thread_extra.h3
-rw-r--r--library/cpp/messagebus/use_after_free_checker.cpp22
-rw-r--r--library/cpp/messagebus/use_after_free_checker.h31
-rw-r--r--library/cpp/messagebus/use_count_checker.cpp53
-rw-r--r--library/cpp/messagebus/use_count_checker.h27
-rw-r--r--library/cpp/messagebus/vector_swaps.h171
-rw-r--r--library/cpp/messagebus/vector_swaps_ut.cpp17
-rw-r--r--library/cpp/messagebus/www/bus-ico.pngbin0 -> 2208 bytes
-rw-r--r--library/cpp/messagebus/www/concat_strings.h22
-rw-r--r--library/cpp/messagebus/www/html_output.cpp4
-rw-r--r--library/cpp/messagebus/www/html_output.h324
-rw-r--r--library/cpp/messagebus/www/messagebus.js48
-rw-r--r--library/cpp/messagebus/www/www.cpp930
-rw-r--r--library/cpp/messagebus/www/www.h45
-rw-r--r--library/cpp/messagebus/www/ya.make29
-rw-r--r--library/cpp/messagebus/ya.make68
-rw-r--r--library/cpp/messagebus/ybus.h205
265 files changed, 23011 insertions, 0 deletions
diff --git a/library/cpp/messagebus/acceptor.cpp b/library/cpp/messagebus/acceptor.cpp
new file mode 100644
index 0000000000..64a38619c2
--- /dev/null
+++ b/library/cpp/messagebus/acceptor.cpp
@@ -0,0 +1,127 @@
+#include "acceptor.h"
+
+#include "key_value_printer.h"
+#include "mb_lwtrace.h"
+#include "network.h"
+
+#include <util/network/init.h>
+#include <util/system/defaults.h>
+#include <util/system/error.h>
+#include <util/system/yassert.h>
+
+LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER)
+
+using namespace NActor;
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TAcceptor::TAcceptor(TBusSessionImpl* session, ui64 acceptorId, SOCKET socket, const TNetAddr& addr)
+ : TActor<TAcceptor>(session->Queue->WorkQueue.Get())
+ , AcceptorId(acceptorId)
+ , Session(session)
+ , GranStatus(session->Config.Secret.StatusFlushPeriod)
+{
+ SetNonBlock(socket, true);
+
+ Channel = Session->ReadEventLoop.Register(socket, this);
+ Channel->EnableRead();
+
+ Stats.AcceptorId = acceptorId;
+ Stats.Fd = socket;
+ Stats.ListenAddr = addr;
+
+ SendStatus(TInstant::Now());
+}
+
+void TAcceptor::Act(TDefaultTag) {
+ EShutdownState state = ShutdownState.State.Get();
+
+ if (state == SS_SHUTDOWN_COMPLETE) {
+ return;
+ }
+
+ TInstant now = TInstant::Now();
+
+ if (state == SS_SHUTDOWN_COMMAND) {
+ if (!!Channel) {
+ Channel->Unregister();
+ Channel.Drop();
+ Stats.Fd = INVALID_SOCKET;
+ }
+
+ SendStatus(now);
+
+ Session->GetDeadAcceptorStatusQueue()->EnqueueAndSchedule(Stats);
+ Stats.ResetIncremental();
+
+ ShutdownState.CompleteShutdown();
+ return;
+ }
+
+ THolder<TOpaqueAddr> addr(new TOpaqueAddr());
+ SOCKET acceptedSocket = accept(Channel->GetSocket(), addr->MutableAddr(), addr->LenPtr());
+
+ int acceptErrno = LastSystemError();
+
+ if (acceptedSocket == INVALID_SOCKET) {
+ if (LastSystemError() != EWOULDBLOCK) {
+ Stats.LastAcceptErrorErrno = acceptErrno;
+ Stats.LastAcceptErrorInstant = now;
+ ++Stats.AcceptErrorCount;
+ }
+ } else {
+ TSocketHolder s(acceptedSocket);
+ try {
+ SetKeepAlive(s, true);
+ SetNoDelay(s, Session->Config.TcpNoDelay);
+ SetSockOptTcpCork(s, Session->Config.TcpCork);
+ SetCloseOnExec(s, true);
+ SetNonBlock(s, true);
+ if (Session->Config.SocketToS >= 0) {
+ SetSocketToS(s, addr.Get(), Session->Config.SocketToS);
+ }
+ } catch (...) {
+ // It means that connection was reset just now
+ // TODO: do something better
+ goto skipAccept;
+ }
+
+ {
+ TOnAccept onAccept;
+ onAccept.s = s.Release();
+ onAccept.addr = TNetAddr(addr.Release());
+ onAccept.now = now;
+
+ LWPROBE(Accepted, ToString(onAccept.addr));
+
+ Session->GetOnAcceptQueue()->EnqueueAndSchedule(onAccept);
+
+ Stats.LastAcceptSuccessInstant = now;
+ ++Stats.AcceptSuccessCount;
+ }
+
+ skipAccept:;
+ }
+
+ Channel->EnableRead();
+
+ SendStatus(now);
+}
+
+void TAcceptor::SendStatus(TInstant now) {
+ GranStatus.Listen.Update(Stats, now);
+}
+
+void TAcceptor::HandleEvent(SOCKET socket, void* cookie) {
+ Y_UNUSED(socket);
+ Y_UNUSED(cookie);
+
+ GetActor()->Schedule();
+}
+
+void TAcceptor::Shutdown() {
+ ShutdownState.ShutdownCommand();
+ GetActor()->Schedule();
+
+ ShutdownState.ShutdownComplete.WaitI();
+}
diff --git a/library/cpp/messagebus/acceptor.h b/library/cpp/messagebus/acceptor.h
new file mode 100644
index 0000000000..57cb010bf2
--- /dev/null
+++ b/library/cpp/messagebus/acceptor.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include "acceptor_status.h"
+#include "defs.h"
+#include "event_loop.h"
+#include "netaddr.h"
+#include "session_impl.h"
+#include "shutdown_state.h"
+
+#include <library/cpp/messagebus/actor/actor.h>
+
+#include <util/system/event.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TAcceptor
+ : public NEventLoop::IEventHandler,
+ private ::NActor::TActor<TAcceptor> {
+ friend struct TBusSessionImpl;
+ friend class ::NActor::TActor<TAcceptor>;
+
+ public:
+ TAcceptor(TBusSessionImpl* session, ui64 acceptorId, SOCKET socket, const TNetAddr& addr);
+
+ void HandleEvent(SOCKET socket, void* cookie) override;
+
+ void Shutdown();
+
+ inline ::NActor::TActor<TAcceptor>* GetActor() {
+ return this;
+ }
+
+ private:
+ void SendStatus(TInstant now);
+ void Act(::NActor::TDefaultTag);
+
+ private:
+ const ui64 AcceptorId;
+
+ TBusSessionImpl* const Session;
+ NEventLoop::TChannelPtr Channel;
+
+ TAcceptorStatus Stats;
+
+ TAtomicShutdownState ShutdownState;
+
+ struct TGranStatus {
+ TGranStatus(TDuration gran)
+ : Listen(gran)
+ {
+ }
+
+ TGranUp<TAcceptorStatus> Listen;
+ };
+
+ TGranStatus GranStatus;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/acceptor_status.cpp b/library/cpp/messagebus/acceptor_status.cpp
new file mode 100644
index 0000000000..5006ff68ae
--- /dev/null
+++ b/library/cpp/messagebus/acceptor_status.cpp
@@ -0,0 +1,68 @@
+#include "acceptor_status.h"
+
+#include "key_value_printer.h"
+
+#include <util/stream/format.h>
+#include <util/stream/output.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TAcceptorStatus::TAcceptorStatus()
+ : Summary(false)
+ , AcceptorId(0)
+ , Fd(INVALID_SOCKET)
+{
+ ResetIncremental();
+}
+
+void TAcceptorStatus::ResetIncremental() {
+ AcceptSuccessCount = 0;
+ AcceptErrorCount = 0;
+ LastAcceptErrorErrno = 0;
+ LastAcceptErrorInstant = TInstant();
+ LastAcceptSuccessInstant = TInstant();
+}
+
+TAcceptorStatus& TAcceptorStatus::operator+=(const TAcceptorStatus& that) {
+ Y_ASSERT(Summary);
+ Y_ASSERT(AcceptorId == 0);
+
+ AcceptSuccessCount += that.AcceptSuccessCount;
+ LastAcceptSuccessInstant = Max(LastAcceptSuccessInstant, that.LastAcceptSuccessInstant);
+
+ AcceptErrorCount += that.AcceptErrorCount;
+ if (that.LastAcceptErrorInstant > LastAcceptErrorInstant) {
+ LastAcceptErrorInstant = that.LastAcceptErrorInstant;
+ LastAcceptErrorErrno = that.LastAcceptErrorErrno;
+ }
+
+ return *this;
+}
+
+TString TAcceptorStatus::PrintToString() const {
+ TStringStream ss;
+
+ if (!Summary) {
+ ss << "acceptor (" << AcceptorId << "), fd=" << Fd << ", addr=" << ListenAddr << Endl;
+ }
+
+ TKeyValuePrinter p;
+
+ p.AddRow("accept error count", LeftPad(AcceptErrorCount, 4));
+
+ if (AcceptErrorCount > 0) {
+ p.AddRow("last accept error",
+ TString() + LastSystemErrorText(LastAcceptErrorErrno) + " at " + LastAcceptErrorInstant.ToString());
+ }
+
+ p.AddRow("accept success count", LeftPad(AcceptSuccessCount, 4));
+ if (AcceptSuccessCount > 0) {
+ p.AddRow("last accept success",
+ TString() + "at " + LastAcceptSuccessInstant.ToString());
+ }
+
+ ss << p.PrintToString();
+
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/acceptor_status.h b/library/cpp/messagebus/acceptor_status.h
new file mode 100644
index 0000000000..6aa1404f4d
--- /dev/null
+++ b/library/cpp/messagebus/acceptor_status.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "netaddr.h"
+
+#include <util/network/init.h>
+
+namespace NBus {
+ namespace NPrivate {
+ struct TAcceptorStatus {
+ bool Summary;
+
+ ui64 AcceptorId;
+
+ SOCKET Fd;
+
+ TNetAddr ListenAddr;
+
+ unsigned AcceptSuccessCount;
+ TInstant LastAcceptSuccessInstant;
+
+ unsigned AcceptErrorCount;
+ TInstant LastAcceptErrorInstant;
+ int LastAcceptErrorErrno;
+
+ void ResetIncremental();
+
+ TAcceptorStatus();
+
+ TAcceptorStatus& operator+=(const TAcceptorStatus& that);
+
+ TString PrintToString() const;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/actor/actor.h b/library/cpp/messagebus/actor/actor.h
new file mode 100644
index 0000000000..9b8f20298a
--- /dev/null
+++ b/library/cpp/messagebus/actor/actor.h
@@ -0,0 +1,144 @@
+#pragma once
+
+#include "executor.h"
+#include "tasks.h"
+#include "what_thread_does.h"
+
+#include <util/system/yassert.h>
+
+namespace NActor {
+ class IActor: protected IWorkItem {
+ public:
+ // TODO: make private
+ TTasks Tasks;
+
+ public:
+ virtual void ScheduleHereV() = 0;
+ virtual void ScheduleV() = 0;
+ virtual void ScheduleHereAtMostOnceV() = 0;
+
+ // TODO: make private
+ virtual void RefV() = 0;
+ virtual void UnRefV() = 0;
+
+ // mute warnings
+ ~IActor() override {
+ }
+ };
+
+ struct TDefaultTag {};
+
+ template <typename TThis, typename TTag = TDefaultTag>
+ class TActor: public IActor {
+ private:
+ TExecutor* const Executor;
+
+ public:
+ TActor(TExecutor* executor)
+ : Executor(executor)
+ {
+ }
+
+ void AddTaskFromActorLoop() {
+ bool schedule = Tasks.AddTask();
+ // TODO: check thread id
+ Y_ASSERT(!schedule);
+ }
+
+ /**
+ * Schedule actor.
+ *
+ * If actor is sleeping, then actor will be executed right now.
+ * If actor is executing right now, it will be executed one more time.
+ * If this method is called multiple time, actor will be re-executed no more than one more time.
+ */
+ void Schedule() {
+ if (Tasks.AddTask()) {
+ EnqueueWork();
+ }
+ }
+
+ /**
+ * Schedule actor, execute it in current thread.
+ *
+ * If actor is running, continue executing where it is executing.
+ * If actor is sleeping, execute it in current thread.
+ *
+ * Operation is useful for tasks that are likely to complete quickly.
+ */
+ void ScheduleHere() {
+ if (Tasks.AddTask()) {
+ Loop();
+ }
+ }
+
+ /**
+ * Schedule actor, execute in current thread no more than once.
+ *
+ * If actor is running, continue executing where it is executing.
+ * If actor is sleeping, execute one iteration here, and if actor got new tasks,
+ * reschedule it in worker pool.
+ */
+ void ScheduleHereAtMostOnce() {
+ if (Tasks.AddTask()) {
+ bool fetched = Tasks.FetchTask();
+ Y_VERIFY(fetched, "happens");
+
+ DoAct();
+
+ // if someone added more tasks, schedule them
+ if (Tasks.FetchTask()) {
+ bool added = Tasks.AddTask();
+ Y_VERIFY(!added, "happens");
+ EnqueueWork();
+ }
+ }
+ }
+
+ void ScheduleHereV() override {
+ ScheduleHere();
+ }
+ void ScheduleV() override {
+ Schedule();
+ }
+ void ScheduleHereAtMostOnceV() override {
+ ScheduleHereAtMostOnce();
+ }
+ void RefV() override {
+ GetThis()->Ref();
+ }
+ void UnRefV() override {
+ GetThis()->UnRef();
+ }
+
+ private:
+ TThis* GetThis() {
+ return static_cast<TThis*>(this);
+ }
+
+ void EnqueueWork() {
+ GetThis()->Ref();
+ Executor->EnqueueWork({this});
+ }
+
+ void DoAct() {
+ WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC();
+
+ GetThis()->Act(TTag());
+ }
+
+ void Loop() {
+ // TODO: limit number of iterations
+ while (Tasks.FetchTask()) {
+ DoAct();
+ }
+ }
+
+ void DoWork() override {
+ Y_ASSERT(GetThis()->RefCount() >= 1);
+ Loop();
+ GetThis()->UnRef();
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/actor/actor_ut.cpp b/library/cpp/messagebus/actor/actor_ut.cpp
new file mode 100644
index 0000000000..b76ab55bfa
--- /dev/null
+++ b/library/cpp/messagebus/actor/actor_ut.cpp
@@ -0,0 +1,157 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "actor.h"
+#include "queue_in_actor.h"
+
+#include <library/cpp/messagebus/misc/test_sync.h>
+
+#include <util/generic/object_counter.h>
+#include <util/system/event.h>
+
+using namespace NActor;
+
+template <typename TThis>
+struct TTestActorBase: public TAtomicRefCount<TThis>, public TActor<TThis> {
+ TTestSync Started;
+ TTestSync Acted;
+
+ TTestActorBase(TExecutor* executor)
+ : TActor<TThis>(executor)
+ {
+ }
+
+ void Act(TDefaultTag) {
+ Started.Inc();
+ static_cast<TThis*>(this)->Act2();
+ Acted.Inc();
+ }
+};
+
+struct TNopActor: public TTestActorBase<TNopActor> {
+ TObjectCounter<TNopActor> AllocCounter;
+
+ TNopActor(TExecutor* executor)
+ : TTestActorBase<TNopActor>(executor)
+ {
+ }
+
+ void Act2() {
+ }
+};
+
+struct TWaitForSignalActor: public TTestActorBase<TWaitForSignalActor> {
+ TWaitForSignalActor(TExecutor* executor)
+ : TTestActorBase<TWaitForSignalActor>(executor)
+ {
+ }
+
+ TSystemEvent WaitFor;
+
+ void Act2() {
+ WaitFor.Wait();
+ }
+};
+
+struct TDecrementAndSendActor: public TTestActorBase<TDecrementAndSendActor>, public TQueueInActor<TDecrementAndSendActor, int> {
+ TSystemEvent Done;
+
+ TDecrementAndSendActor* Next;
+
+ TDecrementAndSendActor(TExecutor* executor)
+ : TTestActorBase<TDecrementAndSendActor>(executor)
+ , Next(nullptr)
+ {
+ }
+
+ void ProcessItem(TDefaultTag, TDefaultTag, int n) {
+ if (n == 0) {
+ Done.Signal();
+ } else {
+ Next->EnqueueAndSchedule(n - 1);
+ }
+ }
+
+ void Act(TDefaultTag) {
+ DequeueAll();
+ }
+};
+
+struct TObjectCountChecker {
+ TObjectCountChecker() {
+ CheckCounts();
+ }
+
+ ~TObjectCountChecker() {
+ CheckCounts();
+ }
+
+ void CheckCounts() {
+ UNIT_ASSERT_VALUES_EQUAL(TAtomicBase(0), TObjectCounter<TNopActor>::ObjectCount());
+ UNIT_ASSERT_VALUES_EQUAL(TAtomicBase(0), TObjectCounter<TWaitForSignalActor>::ObjectCount());
+ UNIT_ASSERT_VALUES_EQUAL(TAtomicBase(0), TObjectCounter<TDecrementAndSendActor>::ObjectCount());
+ }
+};
+
+Y_UNIT_TEST_SUITE(TActor) {
+ Y_UNIT_TEST(Simple) {
+ TObjectCountChecker objectCountChecker;
+
+ TExecutor executor(4);
+
+ TIntrusivePtr<TNopActor> actor(new TNopActor(&executor));
+
+ actor->Schedule();
+
+ actor->Acted.WaitFor(1u);
+ }
+
+ Y_UNIT_TEST(ScheduleAfterStart) {
+ TObjectCountChecker objectCountChecker;
+
+ TExecutor executor(4);
+
+ TIntrusivePtr<TWaitForSignalActor> actor(new TWaitForSignalActor(&executor));
+
+ actor->Schedule();
+
+ actor->Started.WaitFor(1);
+
+ actor->Schedule();
+
+ actor->WaitFor.Signal();
+
+ // make sure Act is called second time
+ actor->Acted.WaitFor(2u);
+ }
+
+ void ComplexImpl(int queueSize, int actorCount) {
+ TObjectCountChecker objectCountChecker;
+
+ TExecutor executor(queueSize);
+
+ TVector<TIntrusivePtr<TDecrementAndSendActor>> actors;
+ for (int i = 0; i < actorCount; ++i) {
+ actors.push_back(new TDecrementAndSendActor(&executor));
+ }
+
+ for (int i = 0; i < actorCount; ++i) {
+ actors.at(i)->Next = &*actors.at((i + 1) % actorCount);
+ }
+
+ for (int i = 0; i < actorCount; ++i) {
+ actors.at(i)->EnqueueAndSchedule(10000);
+ }
+
+ for (int i = 0; i < actorCount; ++i) {
+ actors.at(i)->Done.WaitI();
+ }
+ }
+
+ Y_UNIT_TEST(ComplexContention) {
+ ComplexImpl(4, 6);
+ }
+
+ Y_UNIT_TEST(ComplexNoContention) {
+ ComplexImpl(6, 4);
+ }
+}
diff --git a/library/cpp/messagebus/actor/executor.cpp b/library/cpp/messagebus/actor/executor.cpp
new file mode 100644
index 0000000000..7a2227a458
--- /dev/null
+++ b/library/cpp/messagebus/actor/executor.cpp
@@ -0,0 +1,338 @@
+#include "executor.h"
+
+#include "thread_extra.h"
+#include "what_thread_does.h"
+#include "what_thread_does_guard.h"
+
+#include <util/generic/utility.h>
+#include <util/random/random.h>
+#include <util/stream/str.h>
+#include <util/system/tls.h>
+#include <util/system/yassert.h>
+
+#include <array>
+
+using namespace NActor;
+using namespace NActor::NPrivate;
+
+namespace {
+ struct THistoryInternal {
+ struct TRecord {
+ TAtomic MaxQueueSize;
+
+ TRecord()
+ : MaxQueueSize()
+ {
+ }
+
+ TExecutorHistory::THistoryRecord Capture() {
+ TExecutorHistory::THistoryRecord r;
+ r.MaxQueueSize = AtomicGet(MaxQueueSize);
+ return r;
+ }
+ };
+
+ ui64 Start;
+ ui64 LastTime;
+
+ std::array<TRecord, 3600> Records;
+
+ THistoryInternal() {
+ Start = TInstant::Now().Seconds();
+ LastTime = Start - 1;
+ }
+
+ TRecord& GetRecordForTime(ui64 time) {
+ return Records[time % Records.size()];
+ }
+
+ TRecord& GetNowRecord(ui64 now) {
+ for (ui64 t = LastTime + 1; t <= now; ++t) {
+ GetRecordForTime(t) = TRecord();
+ }
+ LastTime = now;
+ return GetRecordForTime(now);
+ }
+
+ TExecutorHistory Capture() {
+ TExecutorHistory history;
+ ui64 now = TInstant::Now().Seconds();
+ ui64 lastHistoryRecord = now - 1;
+ ui32 historySize = Min<ui32>(lastHistoryRecord - Start, Records.size() - 1);
+ history.HistoryRecords.resize(historySize);
+ for (ui32 i = 0; i < historySize; ++i) {
+ history.HistoryRecords[i] = GetRecordForTime(lastHistoryRecord - historySize + i).Capture();
+ }
+ history.LastHistoryRecordSecond = lastHistoryRecord;
+ return history;
+ }
+ };
+
+}
+
+Y_POD_STATIC_THREAD(TExecutor*)
+ThreadCurrentExecutor;
+
+static const char* NoLocation = "nowhere";
+
+struct TExecutorWorkerThreadLocalData {
+ ui32 MaxQueueSize;
+};
+
+static TExecutorWorkerThreadLocalData WorkerNoThreadLocalData;
+Y_POD_STATIC_THREAD(TExecutorWorkerThreadLocalData)
+WorkerThreadLocalData;
+
+namespace NActor {
+ struct TExecutorWorker {
+ TExecutor* const Executor;
+ TThread Thread;
+ const char** WhatThreadDoesLocation;
+ TExecutorWorkerThreadLocalData* ThreadLocalData;
+
+ TExecutorWorker(TExecutor* executor)
+ : Executor(executor)
+ , Thread(RunThreadProc, this)
+ , WhatThreadDoesLocation(&NoLocation)
+ , ThreadLocalData(&::WorkerNoThreadLocalData)
+ {
+ Thread.Start();
+ }
+
+ void Run() {
+ WhatThreadDoesLocation = ::WhatThreadDoesLocation();
+ AtomicSet(ThreadLocalData, &::WorkerThreadLocalData);
+ WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC();
+ Executor->RunWorker();
+ }
+
+ static void* RunThreadProc(void* thiz0) {
+ TExecutorWorker* thiz = (TExecutorWorker*)thiz0;
+ thiz->Run();
+ return nullptr;
+ }
+ };
+
+ struct TExecutor::TImpl {
+ TExecutor* const Executor;
+ THistoryInternal History;
+
+ TSystemEvent HelperStopSignal;
+ TThread HelperThread;
+
+ TImpl(TExecutor* executor)
+ : Executor(executor)
+ , HelperThread(HelperThreadProc, this)
+ {
+ }
+
+ void RunHelper() {
+ ui64 nowSeconds = TInstant::Now().Seconds();
+ for (;;) {
+ TInstant nextStop = TInstant::Seconds(nowSeconds + 1) + TDuration::MilliSeconds(RandomNumber<ui32>(1000));
+
+ if (HelperStopSignal.WaitD(nextStop)) {
+ return;
+ }
+
+ nowSeconds = nextStop.Seconds();
+
+ THistoryInternal::TRecord& record = History.GetNowRecord(nowSeconds);
+
+ ui32 maxQueueSize = Executor->GetMaxQueueSizeAndClear();
+ if (maxQueueSize > record.MaxQueueSize) {
+ AtomicSet(record.MaxQueueSize, maxQueueSize);
+ }
+ }
+ }
+
+ static void* HelperThreadProc(void* impl0) {
+ TImpl* impl = (TImpl*)impl0;
+ impl->RunHelper();
+ return nullptr;
+ }
+ };
+
+}
+
+static TExecutor::TConfig MakeConfig(unsigned workerCount) {
+ TExecutor::TConfig config;
+ config.WorkerCount = workerCount;
+ return config;
+}
+
+TExecutor::TExecutor(size_t workerCount)
+ : Config(MakeConfig(workerCount))
+{
+ Init();
+}
+
+TExecutor::TExecutor(const TExecutor::TConfig& config)
+ : Config(config)
+{
+ Init();
+}
+
+void TExecutor::Init() {
+ Impl.Reset(new TImpl(this));
+
+ AtomicSet(ExitWorkers, 0);
+
+ Y_VERIFY(Config.WorkerCount > 0);
+
+ for (size_t i = 0; i < Config.WorkerCount; i++) {
+ WorkerThreads.push_back(new TExecutorWorker(this));
+ }
+
+ Impl->HelperThread.Start();
+}
+
+TExecutor::~TExecutor() {
+ Stop();
+}
+
+void TExecutor::Stop() {
+ AtomicSet(ExitWorkers, 1);
+
+ Impl->HelperStopSignal.Signal();
+ Impl->HelperThread.Join();
+
+ {
+ TWhatThreadDoesAcquireGuard<TMutex> guard(WorkMutex, "executor: acquiring lock for Stop");
+ WorkAvailable.BroadCast();
+ }
+
+ for (size_t i = 0; i < WorkerThreads.size(); i++) {
+ WorkerThreads[i]->Thread.Join();
+ }
+
+ // TODO: make queue empty at this point
+ ProcessWorkQueueHere();
+}
+
+void TExecutor::EnqueueWork(TArrayRef<IWorkItem* const> wis) {
+ if (wis.empty())
+ return;
+
+ if (Y_UNLIKELY(AtomicGet(ExitWorkers) != 0)) {
+ Y_VERIFY(WorkItems.Empty(), "executor %s: cannot add tasks after queue shutdown", Config.Name);
+ }
+
+ TWhatThreadDoesPushPop pp("executor: EnqueueWork");
+
+ WorkItems.PushAll(wis);
+
+ {
+ if (wis.size() == 1) {
+ TWhatThreadDoesAcquireGuard<TMutex> g(WorkMutex, "executor: acquiring lock for EnqueueWork");
+ WorkAvailable.Signal();
+ } else {
+ TWhatThreadDoesAcquireGuard<TMutex> g(WorkMutex, "executor: acquiring lock for EnqueueWork");
+ WorkAvailable.BroadCast();
+ }
+ }
+}
+
+size_t TExecutor::GetWorkQueueSize() const {
+ return WorkItems.Size();
+}
+
+using namespace NTSAN;
+
+ui32 TExecutor::GetMaxQueueSizeAndClear() const {
+ ui32 max = 0;
+ for (unsigned i = 0; i < WorkerThreads.size(); ++i) {
+ TExecutorWorkerThreadLocalData* wtls = RelaxedLoad(&WorkerThreads[i]->ThreadLocalData);
+ max = Max<ui32>(max, RelaxedLoad(&wtls->MaxQueueSize));
+ RelaxedStore<ui32>(&wtls->MaxQueueSize, 0);
+ }
+ return max;
+}
+
+TString TExecutor::GetStatus() const {
+ return GetStatusRecordInternal().Status;
+}
+
+TString TExecutor::GetStatusSingleLine() const {
+ TStringStream ss;
+ ss << "work items: " << GetWorkQueueSize();
+ return ss.Str();
+}
+
+TExecutorStatus TExecutor::GetStatusRecordInternal() const {
+ TExecutorStatus r;
+
+ r.WorkQueueSize = GetWorkQueueSize();
+
+ {
+ TStringStream ss;
+ ss << "work items: " << GetWorkQueueSize() << "\n";
+ ss << "workers:\n";
+ for (unsigned i = 0; i < WorkerThreads.size(); ++i) {
+ ss << "-- " << AtomicGet(*AtomicGet(WorkerThreads[i]->WhatThreadDoesLocation)) << "\n";
+ }
+ r.Status = ss.Str();
+ }
+
+ r.History = Impl->History.Capture();
+
+ return r;
+}
+
+bool TExecutor::IsInExecutorThread() const {
+ return ThreadCurrentExecutor == this;
+}
+
+TAutoPtr<IWorkItem> TExecutor::DequeueWork() {
+ IWorkItem* wi = reinterpret_cast<IWorkItem*>(1);
+ size_t queueSize = Max<size_t>();
+ if (!WorkItems.TryPop(&wi, &queueSize)) {
+ TWhatThreadDoesAcquireGuard<TMutex> g(WorkMutex, "executor: acquiring lock for DequeueWork");
+ while (!WorkItems.TryPop(&wi, &queueSize)) {
+ if (AtomicGet(ExitWorkers) != 0)
+ return nullptr;
+
+ TWhatThreadDoesPushPop pp("waiting for work on condvar");
+ WorkAvailable.Wait(WorkMutex);
+ }
+ }
+
+ auto& wtls = TlsRef(WorkerThreadLocalData);
+
+ if (queueSize > RelaxedLoad(&wtls.MaxQueueSize)) {
+ RelaxedStore<ui32>(&wtls.MaxQueueSize, queueSize);
+ }
+
+ return wi;
+}
+
+void TExecutor::RunWorkItem(TAutoPtr<IWorkItem> wi) {
+ WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC();
+ wi.Release()->DoWork();
+}
+
+void TExecutor::ProcessWorkQueueHere() {
+ IWorkItem* wi;
+ while (WorkItems.TryPop(&wi)) {
+ RunWorkItem(wi);
+ }
+}
+
+void TExecutor::RunWorker() {
+ Y_VERIFY(!ThreadCurrentExecutor, "state check");
+ ThreadCurrentExecutor = this;
+
+ SetCurrentThreadName("wrkr");
+
+ for (;;) {
+ TAutoPtr<IWorkItem> wi = DequeueWork();
+ if (!wi) {
+ break;
+ }
+ // Note for messagebus users: make sure program crashes
+ // on uncaught exception in thread, otherewise messagebus may just hang on error.
+ RunWorkItem(wi);
+ }
+
+ ThreadCurrentExecutor = (TExecutor*)nullptr;
+}
diff --git a/library/cpp/messagebus/actor/executor.h b/library/cpp/messagebus/actor/executor.h
new file mode 100644
index 0000000000..7292d8be53
--- /dev/null
+++ b/library/cpp/messagebus/actor/executor.h
@@ -0,0 +1,105 @@
+#pragma once
+
+#include "ring_buffer_with_spin_lock.h"
+
+#include <util/generic/array_ref.h>
+#include <util/generic/vector.h>
+#include <util/system/condvar.h>
+#include <util/system/event.h>
+#include <util/system/mutex.h>
+#include <util/system/thread.h>
+#include <util/thread/lfqueue.h>
+
+namespace NActor {
+ namespace NPrivate {
+ struct TExecutorHistory {
+ struct THistoryRecord {
+ ui32 MaxQueueSize;
+ };
+ TVector<THistoryRecord> HistoryRecords;
+ ui64 LastHistoryRecordSecond;
+
+ ui64 FirstHistoryRecordSecond() const {
+ return LastHistoryRecordSecond - HistoryRecords.size() + 1;
+ }
+ };
+
+ struct TExecutorStatus {
+ size_t WorkQueueSize = 0;
+ TExecutorHistory History;
+ TString Status;
+ };
+ }
+
+ class IWorkItem {
+ public:
+ virtual ~IWorkItem() {
+ }
+ virtual void DoWork(/* must release this */) = 0;
+ };
+
+ struct TExecutorWorker;
+
+ class TExecutor: public TAtomicRefCount<TExecutor> {
+ friend struct TExecutorWorker;
+
+ public:
+ struct TConfig {
+ size_t WorkerCount;
+ const char* Name;
+
+ TConfig()
+ : WorkerCount(1)
+ , Name()
+ {
+ }
+ };
+
+ private:
+ struct TImpl;
+ THolder<TImpl> Impl;
+
+ const TConfig Config;
+
+ TAtomic ExitWorkers;
+
+ TVector<TAutoPtr<TExecutorWorker>> WorkerThreads;
+
+ TRingBufferWithSpinLock<IWorkItem*> WorkItems;
+
+ TMutex WorkMutex;
+ TCondVar WorkAvailable;
+
+ public:
+ explicit TExecutor(size_t workerCount);
+ TExecutor(const TConfig& config);
+ ~TExecutor();
+
+ void Stop();
+
+ void EnqueueWork(TArrayRef<IWorkItem* const> w);
+
+ size_t GetWorkQueueSize() const;
+ TString GetStatus() const;
+ TString GetStatusSingleLine() const;
+ NPrivate::TExecutorStatus GetStatusRecordInternal() const;
+
+ bool IsInExecutorThread() const;
+
+ private:
+ void Init();
+
+ TAutoPtr<IWorkItem> DequeueWork();
+
+ void ProcessWorkQueueHere();
+
+ inline void RunWorkItem(TAutoPtr<IWorkItem>);
+
+ void RunWorker();
+
+ ui32 GetMaxQueueSizeAndClear() const;
+ };
+
+ using TExecutorPtr = TIntrusivePtr<TExecutor>;
+
+}
diff --git a/library/cpp/messagebus/actor/queue_for_actor.h b/library/cpp/messagebus/actor/queue_for_actor.h
new file mode 100644
index 0000000000..40fa536b82
--- /dev/null
+++ b/library/cpp/messagebus/actor/queue_for_actor.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <util/generic/vector.h>
+#include <util/system/yassert.h>
+#include <util/thread/lfstack.h>
+#include <util/thread/singleton.h>
+
+// TODO: include from correct directory
+#include "temp_tls_vector.h"
+
+namespace NActor {
+ namespace NPrivate {
+ struct TTagForTl {};
+
+ }
+
+ template <typename T>
+ class TQueueForActor {
+ private:
+ TLockFreeStack<T> Queue;
+
+ public:
+ ~TQueueForActor() {
+ Y_VERIFY(Queue.IsEmpty());
+ }
+
+ bool IsEmpty() {
+ return Queue.IsEmpty();
+ }
+
+ void Enqueue(const T& value) {
+ Queue.Enqueue(value);
+ }
+
+ template <typename TCollection>
+ void EnqueueAll(const TCollection& all) {
+ Queue.EnqueueAll(all);
+ }
+
+ void Clear() {
+ TVector<T> tmp;
+ Queue.DequeueAll(&tmp);
+ }
+
+ template <typename TFunc>
+ void DequeueAll(const TFunc& func
+ // TODO: , std::enable_if_t<TFunctionParamCount<TFunc>::Value == 1>* = 0
+ ) {
+ TTempTlsVector<T> temp;
+
+ Queue.DequeueAllSingleConsumer(temp.GetVector());
+
+ for (typename TVector<T>::reverse_iterator i = temp.GetVector()->rbegin(); i != temp.GetVector()->rend(); ++i) {
+ func(*i);
+ }
+
+ temp.Clear();
+
+ if (temp.Capacity() * sizeof(T) > 64 * 1024) {
+ temp.Shrink();
+ }
+ }
+
+ template <typename TFunc>
+ void DequeueAllLikelyEmpty(const TFunc& func) {
+ if (Y_LIKELY(IsEmpty())) {
+ return;
+ }
+
+ DequeueAll(func);
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/actor/queue_in_actor.h b/library/cpp/messagebus/actor/queue_in_actor.h
new file mode 100644
index 0000000000..9865996532
--- /dev/null
+++ b/library/cpp/messagebus/actor/queue_in_actor.h
@@ -0,0 +1,80 @@
+#pragma once
+
+#include "actor.h"
+#include "queue_for_actor.h"
+
+#include <functional>
+
+namespace NActor {
+ template <typename TItem>
+ class IQueueInActor {
+ public:
+ virtual void EnqueueAndScheduleV(const TItem& item) = 0;
+ virtual void DequeueAllV() = 0;
+ virtual void DequeueAllLikelyEmptyV() = 0;
+
+ virtual ~IQueueInActor() {
+ }
+ };
+
+ template <typename TThis, typename TItem, typename TActorTag = TDefaultTag, typename TQueueTag = TDefaultTag>
+ class TQueueInActor: public IQueueInActor<TItem> {
+ typedef TQueueInActor<TThis, TItem, TActorTag, TQueueTag> TSelf;
+
+ public:
+ // TODO: make protected
+ TQueueForActor<TItem> QueueInActor;
+
+ private:
+ TActor<TThis, TActorTag>* GetActor() {
+ return GetThis();
+ }
+
+ TThis* GetThis() {
+ return static_cast<TThis*>(this);
+ }
+
+ void ProcessItem(const TItem& item) {
+ GetThis()->ProcessItem(TActorTag(), TQueueTag(), item);
+ }
+
+ public:
+ void EnqueueAndNoSchedule(const TItem& item) {
+ QueueInActor.Enqueue(item);
+ }
+
+ void EnqueueAndSchedule(const TItem& item) {
+ EnqueueAndNoSchedule(item);
+ GetActor()->Schedule();
+ }
+
+ void EnqueueAndScheduleV(const TItem& item) override {
+ EnqueueAndSchedule(item);
+ }
+
+ void Clear() {
+ QueueInActor.Clear();
+ }
+
+ void DequeueAll() {
+ QueueInActor.DequeueAll(std::bind(&TSelf::ProcessItem, this, std::placeholders::_1));
+ }
+
+ void DequeueAllV() override {
+ return DequeueAll();
+ }
+
+ void DequeueAllLikelyEmpty() {
+ QueueInActor.DequeueAllLikelyEmpty(std::bind(&TSelf::ProcessItem, this, std::placeholders::_1));
+ }
+
+ void DequeueAllLikelyEmptyV() override {
+ return DequeueAllLikelyEmpty();
+ }
+
+ bool IsEmpty() {
+ return QueueInActor.IsEmpty();
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/actor/ring_buffer.h b/library/cpp/messagebus/actor/ring_buffer.h
new file mode 100644
index 0000000000..ec5706f7c7
--- /dev/null
+++ b/library/cpp/messagebus/actor/ring_buffer.h
@@ -0,0 +1,135 @@
+#pragma once
+
+#include <util/generic/array_ref.h>
+#include <util/generic/maybe.h>
+#include <util/generic/utility.h>
+#include <util/generic/vector.h>
+#include <util/system/yassert.h>
+
+template <typename T>
+struct TRingBuffer {
+private:
+ ui32 CapacityPow;
+ ui32 CapacityMask;
+ ui32 Capacity;
+ ui32 WritePos;
+ ui32 ReadPos;
+ TVector<T> Data;
+
+ void StateCheck() const {
+ Y_ASSERT(Capacity == Data.size());
+ Y_ASSERT(Capacity == (1u << CapacityPow));
+ Y_ASSERT((Capacity & CapacityMask) == 0u);
+ Y_ASSERT(Capacity - CapacityMask == 1u);
+ Y_ASSERT(WritePos < Capacity);
+ Y_ASSERT(ReadPos < Capacity);
+ }
+
+ size_t Writable() const {
+ return (Capacity + ReadPos - WritePos - 1) & CapacityMask;
+ }
+
+ void ReserveWritable(ui32 sz) {
+ if (sz <= Writable())
+ return;
+
+ ui32 newCapacityPow = CapacityPow;
+ while ((1u << newCapacityPow) < sz + ui32(Size()) + 1u) {
+ ++newCapacityPow;
+ }
+ ui32 newCapacity = 1u << newCapacityPow;
+ ui32 newCapacityMask = newCapacity - 1u;
+ TVector<T> newData(newCapacity);
+ ui32 oldSize = Size();
+ // Copy old elements
+ for (size_t i = 0; i < oldSize; ++i) {
+ newData[i] = Get(i);
+ }
+
+ CapacityPow = newCapacityPow;
+ Capacity = newCapacity;
+ CapacityMask = newCapacityMask;
+ Data.swap(newData);
+ ReadPos = 0;
+ WritePos = oldSize;
+
+ StateCheck();
+ }
+
+ const T& Get(ui32 i) const {
+ return Data[(ReadPos + i) & CapacityMask];
+ }
+
+public:
+ TRingBuffer()
+ : CapacityPow(0)
+ , CapacityMask(0)
+ , Capacity(1 << CapacityPow)
+ , WritePos(0)
+ , ReadPos(0)
+ , Data(Capacity)
+ {
+ StateCheck();
+ }
+
+ size_t Size() const {
+ return (Capacity + WritePos - ReadPos) & CapacityMask;
+ }
+
+ bool Empty() const {
+ return WritePos == ReadPos;
+ }
+
+ void PushAll(TArrayRef<const T> value) {
+ ReserveWritable(value.size());
+
+ ui32 secondSize;
+ ui32 firstSize;
+
+ if (WritePos + value.size() <= Capacity) {
+ firstSize = value.size();
+ secondSize = 0;
+ } else {
+ firstSize = Capacity - WritePos;
+ secondSize = value.size() - firstSize;
+ }
+
+ for (size_t i = 0; i < firstSize; ++i) {
+ Data[WritePos + i] = value[i];
+ }
+
+ for (size_t i = 0; i < secondSize; ++i) {
+ Data[i] = value[firstSize + i];
+ }
+
+ WritePos = (WritePos + value.size()) & CapacityMask;
+ StateCheck();
+ }
+
+ void Push(const T& t) {
+ PushAll(MakeArrayRef(&t, 1));
+ }
+
+ bool TryPop(T* r) {
+ StateCheck();
+ if (Empty()) {
+ return false;
+ }
+ *r = Data[ReadPos];
+ ReadPos = (ReadPos + 1) & CapacityMask;
+ return true;
+ }
+
+ TMaybe<T> TryPop() {
+ T tmp;
+ if (TryPop(&tmp)) {
+ return tmp;
+ } else {
+ return TMaybe<T>();
+ }
+ }
+
+ T Pop() {
+ return *TryPop();
+ }
+};
diff --git a/library/cpp/messagebus/actor/ring_buffer_ut.cpp b/library/cpp/messagebus/actor/ring_buffer_ut.cpp
new file mode 100644
index 0000000000..bdb379b3a9
--- /dev/null
+++ b/library/cpp/messagebus/actor/ring_buffer_ut.cpp
@@ -0,0 +1,60 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "ring_buffer.h"
+
+#include <util/random/random.h>
+
+Y_UNIT_TEST_SUITE(RingBuffer) {
+ struct TRingBufferTester {
+ TRingBuffer<unsigned> RingBuffer;
+
+ unsigned NextPush;
+ unsigned NextPop;
+
+ TRingBufferTester()
+ : NextPush()
+ , NextPop()
+ {
+ }
+
+ void Push() {
+ //Cerr << "push " << NextPush << "\n";
+ RingBuffer.Push(NextPush);
+ NextPush += 1;
+ }
+
+ void Pop() {
+ //Cerr << "pop " << NextPop << "\n";
+ unsigned popped = RingBuffer.Pop();
+ UNIT_ASSERT_VALUES_EQUAL(NextPop, popped);
+ NextPop += 1;
+ }
+
+ bool Empty() const {
+ UNIT_ASSERT_VALUES_EQUAL(RingBuffer.Size(), NextPush - NextPop);
+ UNIT_ASSERT_VALUES_EQUAL(RingBuffer.Empty(), RingBuffer.Size() == 0);
+ return RingBuffer.Empty();
+ }
+ };
+
+ void Iter() {
+ TRingBufferTester rb;
+
+ while (rb.NextPush < 1000) {
+ rb.Push();
+ while (!rb.Empty() && RandomNumber<bool>()) {
+ rb.Pop();
+ }
+ }
+
+ while (!rb.Empty()) {
+ rb.Pop();
+ }
+ }
+
+ Y_UNIT_TEST(Random) {
+ for (unsigned i = 0; i < 100; ++i) {
+ Iter();
+ }
+ }
+}
diff --git a/library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h b/library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h
new file mode 100644
index 0000000000..f0b7cd90e4
--- /dev/null
+++ b/library/cpp/messagebus/actor/ring_buffer_with_spin_lock.h
@@ -0,0 +1,91 @@
+#pragma once
+
+#include "ring_buffer.h"
+
+#include <util/system/spinlock.h>
+
+template <typename T>
+class TRingBufferWithSpinLock {
+private:
+ TRingBuffer<T> RingBuffer;
+ TSpinLock SpinLock;
+ TAtomic CachedSize;
+
+public:
+ TRingBufferWithSpinLock()
+ : CachedSize(0)
+ {
+ }
+
+ void Push(const T& t) {
+ PushAll(t);
+ }
+
+ void PushAll(TArrayRef<const T> collection) {
+ if (collection.empty()) {
+ return;
+ }
+
+ TGuard<TSpinLock> Guard(SpinLock);
+ RingBuffer.PushAll(collection);
+ AtomicSet(CachedSize, RingBuffer.Size());
+ }
+
+ bool TryPop(T* r, size_t* sizePtr = nullptr) {
+ if (AtomicGet(CachedSize) == 0) {
+ return false;
+ }
+
+ bool ok;
+ size_t size;
+ {
+ TGuard<TSpinLock> Guard(SpinLock);
+ ok = RingBuffer.TryPop(r);
+ size = RingBuffer.Size();
+ AtomicSet(CachedSize, size);
+ }
+ if (!!sizePtr) {
+ *sizePtr = size;
+ }
+ return ok;
+ }
+
+ TMaybe<T> TryPop() {
+ T tmp;
+ if (TryPop(&tmp)) {
+ return tmp;
+ } else {
+ return TMaybe<T>();
+ }
+ }
+
+ bool PushAllAndTryPop(TArrayRef<const T> collection, T* r) {
+ if (collection.size() == 0) {
+ return TryPop(r);
+ } else {
+ if (AtomicGet(CachedSize) == 0) {
+ *r = collection[0];
+ if (collection.size() > 1) {
+ TGuard<TSpinLock> guard(SpinLock);
+ RingBuffer.PushAll(MakeArrayRef(collection.data() + 1, collection.size() - 1));
+ AtomicSet(CachedSize, RingBuffer.Size());
+ }
+ } else {
+ TGuard<TSpinLock> guard(SpinLock);
+ RingBuffer.PushAll(collection);
+ *r = RingBuffer.Pop();
+ AtomicSet(CachedSize, RingBuffer.Size());
+ }
+ return true;
+ }
+ }
+
+ bool Empty() const {
+ return AtomicGet(CachedSize) == 0;
+ }
+
+ size_t Size() const {
+ TGuard<TSpinLock> Guard(SpinLock);
+ return RingBuffer.Size();
+ }
+};
diff --git a/library/cpp/messagebus/actor/tasks.h b/library/cpp/messagebus/actor/tasks.h
new file mode 100644
index 0000000000..31d35931d2
--- /dev/null
+++ b/library/cpp/messagebus/actor/tasks.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <util/system/atomic.h>
+#include <util/system/yassert.h>
+
+namespace NActor {
+ class TTasks {
+ enum {
+ // order of values is important
+ E_WAITING,
+ E_RUNNING_NO_TASKS,
+ E_RUNNING_GOT_TASKS,
+ };
+
+ private:
+ TAtomic State;
+
+ public:
+ TTasks()
+ : State(E_WAITING)
+ {
+ }
+
+ // @return true iff caller have to either schedule task or execute it
+ bool AddTask() {
+ // High contention case optimization: AtomicGet is cheaper than AtomicSwap.
+ if (E_RUNNING_GOT_TASKS == AtomicGet(State)) {
+ return false;
+ }
+
+ TAtomicBase oldState = AtomicSwap(&State, E_RUNNING_GOT_TASKS);
+ return oldState == E_WAITING;
+ }
+
+ // called by executor
+ // @return true iff we have to recheck queues
+ bool FetchTask() {
+ TAtomicBase newState = AtomicDecrement(State);
+ if (newState == E_RUNNING_NO_TASKS) {
+ return true;
+ } else if (newState == E_WAITING) {
+ return false;
+ }
+ Y_FAIL("unknown");
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/actor/tasks_ut.cpp b/library/cpp/messagebus/actor/tasks_ut.cpp
new file mode 100644
index 0000000000..d80e8451a5
--- /dev/null
+++ b/library/cpp/messagebus/actor/tasks_ut.cpp
@@ -0,0 +1,37 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "tasks.h"
+
+using namespace NActor;
+
+Y_UNIT_TEST_SUITE(TTasks) {
+ Y_UNIT_TEST(AddTask_FetchTask_Simple) {
+ TTasks tasks;
+
+ UNIT_ASSERT(tasks.AddTask());
+ UNIT_ASSERT(!tasks.AddTask());
+ UNIT_ASSERT(!tasks.AddTask());
+
+ UNIT_ASSERT(tasks.FetchTask());
+ UNIT_ASSERT(!tasks.FetchTask());
+
+ UNIT_ASSERT(tasks.AddTask());
+ }
+
+ Y_UNIT_TEST(AddTask_FetchTask_AddTask) {
+ TTasks tasks;
+
+ UNIT_ASSERT(tasks.AddTask());
+ UNIT_ASSERT(!tasks.AddTask());
+
+ UNIT_ASSERT(tasks.FetchTask());
+ UNIT_ASSERT(!tasks.AddTask());
+ UNIT_ASSERT(tasks.FetchTask());
+ UNIT_ASSERT(!tasks.AddTask());
+ UNIT_ASSERT(!tasks.AddTask());
+ UNIT_ASSERT(tasks.FetchTask());
+ UNIT_ASSERT(!tasks.FetchTask());
+
+ UNIT_ASSERT(tasks.AddTask());
+ }
+}
diff --git a/library/cpp/messagebus/actor/temp_tls_vector.h b/library/cpp/messagebus/actor/temp_tls_vector.h
new file mode 100644
index 0000000000..675d92f5b0
--- /dev/null
+++ b/library/cpp/messagebus/actor/temp_tls_vector.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include "thread_extra.h"
+
+#include <util/generic/vector.h>
+#include <util/system/yassert.h>
+
+template <typename T, typename TTag = void, template <typename, class> class TVectorType = TVector>
+class TTempTlsVector {
+private:
+ struct TTagForTls {};
+
+ TVectorType<T, std::allocator<T>>* Vector;
+
+public:
+ TVectorType<T, std::allocator<T>>* GetVector() {
+ return Vector;
+ }
+
+ TTempTlsVector() {
+ Vector = FastTlsSingletonWithTag<TVectorType<T, std::allocator<T>>, TTagForTls>();
+ Y_ASSERT(Vector->empty());
+ }
+
+ ~TTempTlsVector() {
+ Clear();
+ }
+
+ void Clear() {
+ Vector->clear();
+ }
+
+ size_t Capacity() const noexcept {
+ return Vector->capacity();
+ }
+
+ void Shrink() {
+ Vector->shrink_to_fit();
+ }
+};
diff --git a/library/cpp/messagebus/actor/thread_extra.cpp b/library/cpp/messagebus/actor/thread_extra.cpp
new file mode 100644
index 0000000000..048480f255
--- /dev/null
+++ b/library/cpp/messagebus/actor/thread_extra.cpp
@@ -0,0 +1,30 @@
+#include "thread_extra.h"
+
+#include <util/stream/str.h>
+#include <util/system/execpath.h>
+#include <util/system/platform.h>
+#include <util/system/thread.h>
+
+namespace {
+#ifdef _linux_
+ TString GetExecName() {
+ TString execPath = GetExecPath();
+ size_t lastSlash = execPath.find_last_of('/');
+ if (lastSlash == TString::npos) {
+ return execPath;
+ } else {
+ return execPath.substr(lastSlash + 1);
+ }
+ }
+#endif
+}
+
+void SetCurrentThreadName(const char* name) {
+#ifdef _linux_
+ TStringStream linuxName;
+ linuxName << GetExecName() << "." << name;
+ TThread::SetCurrentThreadName(linuxName.Str().data());
+#else
+ TThread::SetCurrentThreadName(name);
+#endif
+}
diff --git a/library/cpp/messagebus/actor/thread_extra.h b/library/cpp/messagebus/actor/thread_extra.h
new file mode 100644
index 0000000000..b5aa151618
--- /dev/null
+++ b/library/cpp/messagebus/actor/thread_extra.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#include <util/thread/singleton.h>
+
+namespace NTSAN {
+ template <typename T>
+ inline void RelaxedStore(volatile T* a, T x) {
+ static_assert(std::is_integral<T>::value || std::is_pointer<T>::value, "expect std::is_integral<T>::value || std::is_pointer<T>::value");
+#ifdef _win_
+ *a = x;
+#else
+ __atomic_store_n(a, x, __ATOMIC_RELAXED);
+#endif
+ }
+
+ template <typename T>
+ inline T RelaxedLoad(volatile T* a) {
+#ifdef _win_
+ return *a;
+#else
+ return __atomic_load_n(a, __ATOMIC_RELAXED);
+#endif
+ }
+
+}
+
+void SetCurrentThreadName(const char* name);
+
+namespace NThreadExtra {
+ namespace NPrivate {
+ template <typename TValue, typename TTag>
+ struct TValueHolder {
+ TValue Value;
+ };
+ }
+}
+
+template <typename TValue, typename TTag>
+static inline TValue* FastTlsSingletonWithTag() {
+ return &FastTlsSingleton< ::NThreadExtra::NPrivate::TValueHolder<TValue, TTag>>()->Value;
+}
diff --git a/library/cpp/messagebus/actor/what_thread_does.cpp b/library/cpp/messagebus/actor/what_thread_does.cpp
new file mode 100644
index 0000000000..bebb6a888c
--- /dev/null
+++ b/library/cpp/messagebus/actor/what_thread_does.cpp
@@ -0,0 +1,22 @@
+#include "what_thread_does.h"
+
+#include "thread_extra.h"
+
+#include <util/system/tls.h>
+
+Y_POD_STATIC_THREAD(const char*)
+WhatThreadDoes;
+
+const char* PushWhatThreadDoes(const char* what) {
+ const char* r = NTSAN::RelaxedLoad(&WhatThreadDoes);
+ NTSAN::RelaxedStore(&WhatThreadDoes, what);
+ return r;
+}
+
+void PopWhatThreadDoes(const char* prev) {
+ NTSAN::RelaxedStore(&WhatThreadDoes, prev);
+}
+
+const char** WhatThreadDoesLocation() {
+ return &WhatThreadDoes;
+}
diff --git a/library/cpp/messagebus/actor/what_thread_does.h b/library/cpp/messagebus/actor/what_thread_does.h
new file mode 100644
index 0000000000..235d2c3700
--- /dev/null
+++ b/library/cpp/messagebus/actor/what_thread_does.h
@@ -0,0 +1,28 @@
+#pragma once
+
+const char* PushWhatThreadDoes(const char* what);
+void PopWhatThreadDoes(const char* prev);
+const char** WhatThreadDoesLocation();
+
+struct TWhatThreadDoesPushPop {
+private:
+ const char* Prev;
+
+public:
+ TWhatThreadDoesPushPop(const char* what) {
+ Prev = PushWhatThreadDoes(what);
+ }
+
+ ~TWhatThreadDoesPushPop() {
+ PopWhatThreadDoes(Prev);
+ }
+};
+
+#ifdef __GNUC__
+#define WHAT_THREAD_DOES_FUNCTION __PRETTY_FUNCTION__
+#else
+#define WHAT_THREAD_DOES_FUNCTION __FUNCTION__
+#endif
+
+#define WHAT_THREAD_DOES_PUSH_POP_CURRENT_FUNC() \
+ TWhatThreadDoesPushPop whatThreadDoesPushPopCurrentFunc(WHAT_THREAD_DOES_FUNCTION)
diff --git a/library/cpp/messagebus/actor/what_thread_does_guard.h b/library/cpp/messagebus/actor/what_thread_does_guard.h
new file mode 100644
index 0000000000..f104e9e173
--- /dev/null
+++ b/library/cpp/messagebus/actor/what_thread_does_guard.h
@@ -0,0 +1,40 @@
+#pragma once
+
+#include "what_thread_does.h"
+
+template <class T>
+class TWhatThreadDoesAcquireGuard: public TNonCopyable {
+public:
+ inline TWhatThreadDoesAcquireGuard(const T& t, const char* acquire) noexcept {
+ Init(&t, acquire);
+ }
+
+ inline TWhatThreadDoesAcquireGuard(const T* t, const char* acquire) noexcept {
+ Init(t, acquire);
+ }
+
+ inline ~TWhatThreadDoesAcquireGuard() {
+ Release();
+ }
+
+ inline void Release() noexcept {
+ if (WasAcquired()) {
+ const_cast<T*>(T_)->Release();
+ T_ = nullptr;
+ }
+ }
+
+ inline bool WasAcquired() const noexcept {
+ return T_ != nullptr;
+ }
+
+private:
+ inline void Init(const T* t, const char* acquire) noexcept {
+ T_ = const_cast<T*>(t);
+ TWhatThreadDoesPushPop pp(acquire);
+ T_->Acquire();
+ }
+
+private:
+ T* T_;
+};
diff --git a/library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp b/library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp
new file mode 100644
index 0000000000..e4b218a7ca
--- /dev/null
+++ b/library/cpp/messagebus/actor/what_thread_does_guard_ut.cpp
@@ -0,0 +1,13 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "what_thread_does_guard.h"
+
+#include <util/system/mutex.h>
+
+Y_UNIT_TEST_SUITE(WhatThreadDoesGuard) {
+ Y_UNIT_TEST(Simple) {
+ TMutex mutex;
+
+ TWhatThreadDoesAcquireGuard<TMutex> guard(mutex, "acquiring my mutex");
+ }
+}
diff --git a/library/cpp/messagebus/actor/ya.make b/library/cpp/messagebus/actor/ya.make
new file mode 100644
index 0000000000..59bd1b0b99
--- /dev/null
+++ b/library/cpp/messagebus/actor/ya.make
@@ -0,0 +1,11 @@
+LIBRARY(messagebus_actor)
+
+OWNER(g:messagebus)
+
+SRCS(
+ executor.cpp
+ thread_extra.cpp
+ what_thread_does.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/all.lwt b/library/cpp/messagebus/all.lwt
new file mode 100644
index 0000000000..0f04be4b2c
--- /dev/null
+++ b/library/cpp/messagebus/all.lwt
@@ -0,0 +1,8 @@
+Blocks {
+ ProbeDesc {
+ Group: "MessagebusRare"
+ }
+ Action {
+ PrintToStderrAction {}
+ }
+}
diff --git a/library/cpp/messagebus/all/ya.make b/library/cpp/messagebus/all/ya.make
new file mode 100644
index 0000000000..ffa2dbfabc
--- /dev/null
+++ b/library/cpp/messagebus/all/ya.make
@@ -0,0 +1,10 @@
+OWNER(g:messagebus)
+
+RECURSE_ROOT_RELATIVE(
+ library/python/messagebus
+ library/cpp/messagebus/debug_receiver
+ library/cpp/messagebus/oldmodule
+ library/cpp/messagebus/rain_check
+ library/cpp/messagebus/test
+ library/cpp/messagebus/www
+)
diff --git a/library/cpp/messagebus/async_result.h b/library/cpp/messagebus/async_result.h
new file mode 100644
index 0000000000..d24dde284a
--- /dev/null
+++ b/library/cpp/messagebus/async_result.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include <util/generic/maybe.h>
+#include <util/generic/noncopyable.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+#include <util/system/yassert.h>
+
+#include <functional>
+
+// probably this thing should have been called TFuture
+template <typename T>
+class TAsyncResult : TNonCopyable {
+private:
+ TMutex Mutex;
+ TCondVar CondVar;
+
+ TMaybe<T> Result;
+
+ typedef void TOnResult(const T&);
+
+ std::function<TOnResult> OnResult;
+
+public:
+ void SetResult(const T& result) {
+ TGuard<TMutex> guard(Mutex);
+ Y_VERIFY(!Result, "cannot set result twice");
+ Result = result;
+ CondVar.BroadCast();
+
+ if (!!OnResult) {
+ OnResult(result);
+ }
+ }
+
+ const T& GetResult() {
+ TGuard<TMutex> guard(Mutex);
+ while (!Result) {
+ CondVar.Wait(Mutex);
+ }
+ return *Result;
+ }
+
+ template <typename TFunc>
+ void AndThen(const TFunc& onResult) {
+ TGuard<TMutex> guard(Mutex);
+ if (!!Result) {
+ onResult(*Result);
+ } else {
+ Y_ASSERT(!OnResult);
+ OnResult = std::function<TOnResult>(onResult);
+ }
+ }
+};
diff --git a/library/cpp/messagebus/async_result_ut.cpp b/library/cpp/messagebus/async_result_ut.cpp
new file mode 100644
index 0000000000..2e96492afd
--- /dev/null
+++ b/library/cpp/messagebus/async_result_ut.cpp
@@ -0,0 +1,37 @@
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "async_result.h"
+
+namespace {
+ void SetValue(int* location, const int& value) {
+ *location = value;
+ }
+
+}
+
+Y_UNIT_TEST_SUITE(TAsyncResult) {
+ Y_UNIT_TEST(AndThen_Here) {
+ TAsyncResult<int> r;
+
+ int var = 1;
+
+ r.SetResult(17);
+
+ r.AndThen(std::bind(&SetValue, &var, std::placeholders::_1));
+
+ UNIT_ASSERT_VALUES_EQUAL(17, var);
+ }
+
+ Y_UNIT_TEST(AndThen_Later) {
+ TAsyncResult<int> r;
+
+ int var = 1;
+
+ r.AndThen(std::bind(&SetValue, &var, std::placeholders::_1));
+
+ r.SetResult(17);
+
+ UNIT_ASSERT_VALUES_EQUAL(17, var);
+ }
+}
diff --git a/library/cpp/messagebus/base.h b/library/cpp/messagebus/base.h
new file mode 100644
index 0000000000..79fccc312e
--- /dev/null
+++ b/library/cpp/messagebus/base.h
@@ -0,0 +1,11 @@
+#pragma once
+
+#include <util/system/defaults.h>
+
+namespace NBus {
+ /// millis since epoch
+ using TBusInstant = ui64;
+ /// returns time in milliseconds
+ TBusInstant Now();
+
+}
diff --git a/library/cpp/messagebus/cc_semaphore.h b/library/cpp/messagebus/cc_semaphore.h
new file mode 100644
index 0000000000..0df8a3d664
--- /dev/null
+++ b/library/cpp/messagebus/cc_semaphore.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "latch.h"
+
+template <typename TThis>
+class TComplexConditionSemaphore {
+private:
+ TLatch Latch;
+
+public:
+ void Updated() {
+ if (GetThis()->TryWait()) {
+ Latch.Unlock();
+ }
+ }
+
+ void Wait() {
+ while (!GetThis()->TryWait()) {
+ Latch.Lock();
+ if (GetThis()->TryWait()) {
+ Latch.Unlock();
+ return;
+ }
+ Latch.Wait();
+ }
+ }
+
+ bool IsLocked() {
+ return Latch.IsLocked();
+ }
+
+private:
+ TThis* GetThis() {
+ return static_cast<TThis*>(this);
+ }
+};
diff --git a/library/cpp/messagebus/cc_semaphore_ut.cpp b/library/cpp/messagebus/cc_semaphore_ut.cpp
new file mode 100644
index 0000000000..206bb7c96a
--- /dev/null
+++ b/library/cpp/messagebus/cc_semaphore_ut.cpp
@@ -0,0 +1,45 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "cc_semaphore.h"
+
+#include <util/system/atomic.h>
+
+namespace {
+ struct TTestSemaphore: public TComplexConditionSemaphore<TTestSemaphore> {
+ TAtomic Current;
+
+ TTestSemaphore()
+ : Current(0)
+ {
+ }
+
+ bool TryWait() {
+ return AtomicGet(Current) > 0;
+ }
+
+ void Aquire() {
+ Wait();
+ AtomicDecrement(Current);
+ }
+
+ void Release() {
+ AtomicIncrement(Current);
+ Updated();
+ }
+ };
+}
+
+Y_UNIT_TEST_SUITE(TComplexConditionSemaphore) {
+ Y_UNIT_TEST(Simple) {
+ TTestSemaphore sema;
+ UNIT_ASSERT(!sema.TryWait());
+ sema.Release();
+ UNIT_ASSERT(sema.TryWait());
+ sema.Release();
+ UNIT_ASSERT(sema.TryWait());
+ sema.Aquire();
+ UNIT_ASSERT(sema.TryWait());
+ sema.Aquire();
+ UNIT_ASSERT(!sema.TryWait());
+ }
+}
diff --git a/library/cpp/messagebus/codegen.h b/library/cpp/messagebus/codegen.h
new file mode 100644
index 0000000000..83e969e811
--- /dev/null
+++ b/library/cpp/messagebus/codegen.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include <library/cpp/messagebus/config/codegen.h>
+
diff --git a/library/cpp/messagebus/config/codegen.h b/library/cpp/messagebus/config/codegen.h
new file mode 100644
index 0000000000..97ddada005
--- /dev/null
+++ b/library/cpp/messagebus/config/codegen.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#define COMMA ,
+
+#define STRUCT_FIELD_GEN(name, type, ...) type name;
+
+#define STRUCT_FIELD_INIT(name, type, defa) name(defa)
+#define STRUCT_FIELD_INIT_DEFAULT(name, type, ...) name()
+
+#define STRUCT_FIELD_PRINT(name, ...) ss << #name << "=" << name << "\n";
diff --git a/library/cpp/messagebus/config/defs.h b/library/cpp/messagebus/config/defs.h
new file mode 100644
index 0000000000..92b1df9969
--- /dev/null
+++ b/library/cpp/messagebus/config/defs.h
@@ -0,0 +1,82 @@
+#pragma once
+
+// unique tag to fix pragma once gcc glueing: ./library/cpp/messagebus/defs.h
+
+#include "codegen.h"
+#include "netaddr.h"
+
+#include <library/cpp/deprecated/enum_codegen/enum_codegen.h>
+
+#include <util/generic/list.h>
+
+#include <utility>
+
+// For historical reasons TCrawlerModule need to access
+// APIs that should be private.
+class TCrawlerModule;
+
+struct TDebugReceiverHandler;
+
+namespace NBus {
+ namespace NPrivate {
+ class TAcceptor;
+ struct TBusSessionImpl;
+ class TRemoteServerSession;
+ class TRemoteClientSession;
+ class TRemoteConnection;
+ class TRemoteServerConnection;
+ class TRemoteClientConnection;
+ class TBusSyncSourceSessionImpl;
+
+ struct TBusMessagePtrAndHeader;
+
+ struct TSessionDumpStatus;
+
+ struct TClientRequestImpl;
+
+ }
+
+ class TBusSession;
+ struct TBusServerSession;
+ struct TBusClientSession;
+ class TBusProtocol;
+ class TBusMessage;
+ class TBusMessageConnection;
+ class TBusMessageQueue;
+ class TBusLocator;
+ struct TBusQueueConfig;
+ struct TBusSessionConfig;
+ struct TBusHeader;
+
+ class IThreadHandler;
+
+ using TBusKey = ui64;
+ using TBusMessageList = TList<TBusMessage*>;
+ using TBusKeyVec = TVector<std::pair<TBusKey, TBusKey>>;
+
+ using TBusMessageQueuePtr = TIntrusivePtr<TBusMessageQueue>;
+
+ class TBusModule;
+
+ using TBusData = TString;
+ using TBusService = const char*;
+
+#define YBUS_KEYMIN TBusKey(0L)
+#define YBUS_KEYMAX TBusKey(-1L)
+#define YBUS_KEYLOCAL TBusKey(7L)
+#define YBUS_KEYINVALID TBusKey(99999999L)
+
+ // Check that generated id is valid for remote message
+ inline bool IsBusKeyValid(TBusKey key) {
+ return key != YBUS_KEYINVALID && key != YBUS_KEYMAX && key > YBUS_KEYLOCAL;
+ }
+
+#define YBUS_VERSION 0
+
+#define YBUS_INFINITE (1u << 30u)
+
+#define YBUS_STATUS_BASIC 0x0000
+#define YBUS_STATUS_CONNS 0x0001
+#define YBUS_STATUS_INFLIGHT 0x0002
+
+}
diff --git a/library/cpp/messagebus/config/netaddr.cpp b/library/cpp/messagebus/config/netaddr.cpp
new file mode 100644
index 0000000000..962ac538e2
--- /dev/null
+++ b/library/cpp/messagebus/config/netaddr.cpp
@@ -0,0 +1,183 @@
+#include "netaddr.h"
+
+#include <util/network/address.h>
+
+#include <cstdlib>
+
+namespace NBus {
+ const char* ToCString(EIpVersion ipVersion) {
+ switch (ipVersion) {
+ case EIP_VERSION_ANY:
+ return "EIP_VERSION_ANY";
+ case EIP_VERSION_4:
+ return "EIP_VERSION_4";
+ case EIP_VERSION_6:
+ return "EIP_VERSION_6";
+ }
+ Y_FAIL();
+ }
+
+ int ToAddrFamily(EIpVersion ipVersion) {
+ switch (ipVersion) {
+ case EIP_VERSION_ANY:
+ return AF_UNSPEC;
+ case EIP_VERSION_4:
+ return AF_INET;
+ case EIP_VERSION_6:
+ return AF_INET6;
+ }
+ Y_FAIL();
+ }
+
+ class TNetworkAddressRef: private TNetworkAddress, public TAddrInfo {
+ public:
+ TNetworkAddressRef(const TNetworkAddress& na, const TAddrInfo& ai)
+ : TNetworkAddress(na)
+ , TAddrInfo(ai)
+ {
+ }
+ };
+
+ static bool Compare(const IRemoteAddr& l, const IRemoteAddr& r) noexcept {
+ if (l.Addr()->sa_family != r.Addr()->sa_family) {
+ return false;
+ }
+
+ switch (l.Addr()->sa_family) {
+ case AF_INET: {
+ return memcmp(&(((const sockaddr_in*)l.Addr())->sin_addr), &(((const sockaddr_in*)r.Addr())->sin_addr), sizeof(in_addr)) == 0 &&
+ ((const sockaddr_in*)l.Addr())->sin_port == ((const sockaddr_in*)r.Addr())->sin_port;
+ }
+
+ case AF_INET6: {
+ return memcmp(&(((const sockaddr_in6*)l.Addr())->sin6_addr), &(((const sockaddr_in6*)r.Addr())->sin6_addr), sizeof(in6_addr)) == 0 &&
+ ((const sockaddr_in6*)l.Addr())->sin6_port == ((const sockaddr_in6*)r.Addr())->sin6_port;
+ }
+ }
+
+ return memcmp(l.Addr(), r.Addr(), Min<size_t>(l.Len(), r.Len())) == 0;
+ }
+
+ TNetAddr::TNetAddr()
+ : Ptr(new TOpaqueAddr)
+ {
+ }
+
+ TNetAddr::TNetAddr(TAutoPtr<IRemoteAddr> addr)
+ : Ptr(addr)
+ {
+ Y_VERIFY(!!Ptr);
+ }
+
+ namespace {
+ using namespace NAddr;
+
+ const char* Describe(EIpVersion version) {
+ switch (version) {
+ case EIP_VERSION_4:
+ return "ipv4 address";
+ case EIP_VERSION_6:
+ return "ipv6 address";
+ case EIP_VERSION_ANY:
+ return "any address";
+ default:
+ Y_FAIL("unreachable");
+ }
+ }
+
+ TAutoPtr<IRemoteAddr> MakeAddress(const TNetworkAddress& na, EIpVersion requireVersion, EIpVersion preferVersion) {
+ TAutoPtr<IRemoteAddr> addr;
+ for (TNetworkAddress::TIterator it = na.Begin(); it != na.End(); ++it) {
+ if (IsFamilyAllowed(it->ai_family, requireVersion)) {
+ if (IsFamilyAllowed(it->ai_family, preferVersion)) {
+ return new TNetworkAddressRef(na, &*it);
+ } else if (!addr) {
+ addr.Reset(new TNetworkAddressRef(na, &*it));
+ }
+ }
+ }
+ return addr;
+ }
+ TAutoPtr<IRemoteAddr> MakeAddress(TStringBuf host, int port, EIpVersion requireVersion, EIpVersion preferVersion) {
+ TString hostString(host);
+ TNetworkAddress na(hostString, port);
+ return MakeAddress(na, requireVersion, preferVersion);
+ }
+ TAutoPtr<IRemoteAddr> MakeAddress(const char* hostPort, EIpVersion requireVersion, EIpVersion preferVersion) {
+ const char* portStr = strchr(hostPort, ':');
+ if (!portStr) {
+ ythrow TNetAddr::TError() << "port not specified in " << hostPort;
+ }
+ int port = atoi(portStr + 1);
+ TNetworkAddress na(TString(hostPort, portStr), port);
+ return MakeAddress(na, requireVersion, preferVersion);
+ }
+ }
+
+ TNetAddr::TNetAddr(const char* hostPort, EIpVersion requireVersion /*= EIP_VERSION_ANY*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/)
+ : Ptr(MakeAddress(hostPort, requireVersion, preferVersion))
+ {
+ if (!Ptr) {
+ ythrow TNetAddr::TError() << "cannot resolve " << hostPort << " into " << Describe(requireVersion);
+ }
+ }
+
+ TNetAddr::TNetAddr(TStringBuf host, int port, EIpVersion requireVersion /*= EIP_VERSION_ANY*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/)
+ : Ptr(MakeAddress(host, port, requireVersion, preferVersion))
+ {
+ if (!Ptr) {
+ ythrow TNetAddr::TError() << "cannot resolve " << host << ":" << port << " into " << Describe(requireVersion);
+ }
+ }
+
+ TNetAddr::TNetAddr(const TNetworkAddress& na, EIpVersion requireVersion /*= EIP_VERSION_ANY*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/)
+ : Ptr(MakeAddress(na, requireVersion, preferVersion))
+ {
+ if (!Ptr) {
+ ythrow TNetAddr::TError() << "cannot resolve into " << Describe(requireVersion);
+ }
+ }
+
+ TNetAddr::TNetAddr(const TNetworkAddress& na, const TAddrInfo& ai)
+ : Ptr(new TNetworkAddressRef(na, ai))
+ {
+ }
+
+ const sockaddr* TNetAddr::Addr() const {
+ return Ptr->Addr();
+ }
+
+ socklen_t TNetAddr::Len() const {
+ return Ptr->Len();
+ }
+
+ int TNetAddr::GetPort() const {
+ switch (Ptr->Addr()->sa_family) {
+ case AF_INET:
+ return InetToHost(((sockaddr_in*)Ptr->Addr())->sin_port);
+ case AF_INET6:
+ return InetToHost(((sockaddr_in6*)Ptr->Addr())->sin6_port);
+ default:
+ Y_FAIL("unknown AF: %d", (int)Ptr->Addr()->sa_family);
+ throw 1;
+ }
+ }
+
+ bool TNetAddr::IsIpv4() const {
+ return Ptr->Addr()->sa_family == AF_INET;
+ }
+
+ bool TNetAddr::IsIpv6() const {
+ return Ptr->Addr()->sa_family == AF_INET6;
+ }
+
+ bool TNetAddr::operator==(const TNetAddr& rhs) const {
+ return Ptr == rhs.Ptr || Compare(*Ptr, *rhs.Ptr);
+ }
+
+}
+
+template <>
+void Out<NBus::TNetAddr>(IOutputStream& out, const NBus::TNetAddr& addr) {
+ Out<NAddr::IRemoteAddr>(out, addr);
+}
diff --git a/library/cpp/messagebus/config/netaddr.h b/library/cpp/messagebus/config/netaddr.h
new file mode 100644
index 0000000000..b79c0cc355
--- /dev/null
+++ b/library/cpp/messagebus/config/netaddr.h
@@ -0,0 +1,86 @@
+#pragma once
+
+#include <util/digest/numeric.h>
+#include <util/generic/hash.h>
+#include <util/generic/ptr.h>
+#include <util/generic/strbuf.h>
+#include <util/generic/vector.h>
+#include <util/network/address.h>
+
+namespace NBus {
+ using namespace NAddr;
+
+ /// IP protocol version.
+ enum EIpVersion {
+ EIP_VERSION_4 = 1,
+ EIP_VERSION_6 = 2,
+ EIP_VERSION_ANY = EIP_VERSION_4 | EIP_VERSION_6,
+ };
+
+ inline bool IsFamilyAllowed(ui16 sa_family, EIpVersion ipVersion) {
+ if (ipVersion == EIP_VERSION_4 && sa_family != AF_INET) {
+ return false;
+ }
+ if (ipVersion == EIP_VERSION_6 && sa_family != AF_INET6) {
+ return false;
+ }
+ return true;
+ }
+
+ const char* ToCString(EIpVersion);
+ int ToAddrFamily(EIpVersion);
+
+ /// Hold referenced pointer to address description structure (ex. sockaddr_storage)
+ /// It's make possible to work with IPv4 / IPv6 addresses simultaneously
+ class TNetAddr: public IRemoteAddr {
+ public:
+ class TError: public yexception {
+ };
+
+ TNetAddr();
+ TNetAddr(TAutoPtr<IRemoteAddr> addr);
+ TNetAddr(const char* hostPort, EIpVersion requireVersion = EIP_VERSION_ANY, EIpVersion preferVersion = EIP_VERSION_ANY);
+ TNetAddr(TStringBuf host, int port, EIpVersion requireVersion = EIP_VERSION_ANY, EIpVersion preferVersion = EIP_VERSION_ANY);
+ TNetAddr(const TNetworkAddress& na, EIpVersion requireVersion = EIP_VERSION_ANY, EIpVersion preferVersion = EIP_VERSION_ANY);
+ TNetAddr(const TNetworkAddress& na, const TAddrInfo& ai);
+
+ bool operator==(const TNetAddr&) const;
+ bool operator!=(const TNetAddr& other) const {
+ return !(*this == other);
+ }
+ inline explicit operator bool() const noexcept {
+ return !!Ptr;
+ }
+
+ const sockaddr* Addr() const override;
+ socklen_t Len() const override;
+
+ bool IsIpv4() const;
+ bool IsIpv6() const;
+ int GetPort() const;
+
+ private:
+ TAtomicSharedPtr<IRemoteAddr> Ptr;
+ };
+
+ using TSockAddrInVector = TVector<TNetAddr>;
+
+ struct TNetAddrHostPortHash {
+ inline size_t operator()(const TNetAddr& a) const {
+ const sockaddr* s = a.Addr();
+ const sockaddr_in* const sa = reinterpret_cast<const sockaddr_in*>(s);
+ const sockaddr_in6* const sa6 = reinterpret_cast<const sockaddr_in6*>(s);
+
+ switch (s->sa_family) {
+ case AF_INET:
+ return CombineHashes<size_t>(ComputeHash(TStringBuf(reinterpret_cast<const char*>(&sa->sin_addr), sizeof(sa->sin_addr))), IntHashImpl(sa->sin_port));
+
+ case AF_INET6:
+ return CombineHashes<size_t>(ComputeHash(TStringBuf(reinterpret_cast<const char*>(&sa6->sin6_addr), sizeof(sa6->sin6_addr))), IntHashImpl(sa6->sin6_port));
+ }
+
+ return ComputeHash(TStringBuf(reinterpret_cast<const char*>(s), a.Len()));
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/config/session_config.cpp b/library/cpp/messagebus/config/session_config.cpp
new file mode 100644
index 0000000000..fbbbb106c9
--- /dev/null
+++ b/library/cpp/messagebus/config/session_config.cpp
@@ -0,0 +1,157 @@
+#include "session_config.h"
+
+#include <util/generic/strbuf.h>
+#include <util/string/hex.h>
+
+using namespace NBus;
+
+TBusSessionConfig::TSecret::TSecret()
+ : TimeoutPeriod(TDuration::Seconds(1))
+ , StatusFlushPeriod(TDuration::MilliSeconds(400))
+{
+}
+
+TBusSessionConfig::TBusSessionConfig()
+ : BUS_SESSION_CONFIG_MAP(STRUCT_FIELD_INIT, COMMA)
+{
+}
+
+TString TBusSessionConfig::PrintToString() const {
+ TStringStream ss;
+ BUS_SESSION_CONFIG_MAP(STRUCT_FIELD_PRINT, )
+ return ss.Str();
+}
+
+static int ParseDurationForMessageBus(const char* option) {
+ return TDuration::Parse(option).MilliSeconds();
+}
+
+static int ParseToSForMessageBus(const char* option) {
+ int tos;
+ TStringBuf str(option);
+ if (str.StartsWith("0x")) {
+ str = str.Tail(2);
+ Y_VERIFY(str.length() == 2, "ToS must be a number between 0x00 and 0xFF");
+ tos = String2Byte(str.data());
+ } else {
+ tos = FromString<int>(option);
+ }
+ Y_VERIFY(tos >= 0 && tos <= 255, "ToS must be between 0x00 and 0xFF");
+ return tos;
+}
+
+template <class T>
+static T ParseWithKmgSuffixT(const char* option) {
+ TStringBuf str(option);
+ T multiplier = 1;
+ if (str.EndsWith('k')) {
+ multiplier = 1024;
+ str = str.Head(str.size() - 1);
+ } else if (str.EndsWith('m')) {
+ multiplier = 1024 * 1024;
+ str = str.Head(str.size() - 1);
+ } else if (str.EndsWith('g')) {
+ multiplier = 1024 * 1024 * 1024;
+ str = str.Head(str.size() - 1);
+ }
+ return FromString<T>(str) * multiplier;
+}
+
+static ui64 ParseWithKmgSuffix(const char* option) {
+ return ParseWithKmgSuffixT<ui64>(option);
+}
+
+static i64 ParseWithKmgSuffixS(const char* option) {
+ return ParseWithKmgSuffixT<i64>(option);
+}
+
+void TBusSessionConfig::ConfigureLastGetopt(NLastGetopt::TOpts& opts,
+ const TString& prefix) {
+ opts.AddLongOption(prefix + "total-timeout")
+ .RequiredArgument("MILLISECONDS")
+ .DefaultValue(ToString(TotalTimeout))
+ .StoreMappedResultT<const char*>(&TotalTimeout,
+ &ParseDurationForMessageBus);
+ opts.AddLongOption(prefix + "connect-timeout")
+ .RequiredArgument("MILLISECONDS")
+ .DefaultValue(ToString(ConnectTimeout))
+ .StoreMappedResultT<const char*>(&ConnectTimeout,
+ &ParseDurationForMessageBus);
+ opts.AddLongOption(prefix + "send-timeout")
+ .RequiredArgument("MILLISECONDS")
+ .DefaultValue(ToString(SendTimeout))
+ .StoreMappedResultT<const char*>(&SendTimeout,
+ &ParseDurationForMessageBus);
+ opts.AddLongOption(prefix + "send-threshold")
+ .RequiredArgument("BYTES")
+ .DefaultValue(ToString(SendThreshold))
+ .StoreMappedResultT<const char*>(&SendThreshold, &ParseWithKmgSuffix);
+
+ opts.AddLongOption(prefix + "max-in-flight")
+ .RequiredArgument("COUNT")
+ .DefaultValue(ToString(MaxInFlight))
+ .StoreMappedResultT<const char*>(&MaxInFlight, &ParseWithKmgSuffix);
+ opts.AddLongOption(prefix + "max-in-flight-by-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(
+ ToString(MaxInFlightBySize))
+ .StoreMappedResultT<const char*>(&MaxInFlightBySize, &ParseWithKmgSuffixS);
+ opts.AddLongOption(prefix + "per-con-max-in-flight")
+ .RequiredArgument("COUNT")
+ .DefaultValue(ToString(PerConnectionMaxInFlight))
+ .StoreMappedResultT<const char*>(&PerConnectionMaxInFlight,
+ &ParseWithKmgSuffix);
+ opts.AddLongOption(prefix + "per-con-max-in-flight-by-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(
+ ToString(PerConnectionMaxInFlightBySize))
+ .StoreMappedResultT<const char*>(&PerConnectionMaxInFlightBySize,
+ &ParseWithKmgSuffix);
+
+ opts.AddLongOption(prefix + "default-buffer-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(ToString(DefaultBufferSize))
+ .StoreMappedResultT<const char*>(&DefaultBufferSize,
+ &ParseWithKmgSuffix);
+ opts.AddLongOption(prefix + "max-buffer-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(ToString(MaxBufferSize))
+ .StoreMappedResultT<const char*>(&MaxBufferSize, &ParseWithKmgSuffix);
+ opts.AddLongOption(prefix + "max-message-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(ToString(MaxMessageSize))
+ .StoreMappedResultT<const char*>(&MaxMessageSize, &ParseWithKmgSuffix);
+ opts.AddLongOption(prefix + "socket-recv-buffer-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(ToString(SocketRecvBufferSize))
+ .StoreMappedResultT<const char*>(&SocketRecvBufferSize,
+ &ParseWithKmgSuffix);
+ opts.AddLongOption(prefix + "socket-send-buffer-size")
+ .RequiredArgument("BYTES")
+ .DefaultValue(ToString(SocketSendBufferSize))
+ .StoreMappedResultT<const char*>(&SocketSendBufferSize,
+ &ParseWithKmgSuffix);
+
+ opts.AddLongOption(prefix + "socket-tos")
+ .RequiredArgument("[0x00, 0xFF]")
+ .StoreMappedResultT<const char*>(&SocketToS, &ParseToSForMessageBus);
+ ;
+ opts.AddLongOption(prefix + "tcp-cork")
+ .RequiredArgument("BOOL")
+ .DefaultValue(ToString(TcpCork))
+ .StoreResult(&TcpCork);
+ opts.AddLongOption(prefix + "cork")
+ .RequiredArgument("SECONDS")
+ .DefaultValue(
+ ToString(Cork.Seconds()))
+ .StoreMappedResultT<const char*>(&Cork, &TDuration::Parse);
+
+ opts.AddLongOption(prefix + "on-message-in-pool")
+ .RequiredArgument("BOOL")
+ .DefaultValue(ToString(ExecuteOnMessageInWorkerPool))
+ .StoreResult(&ExecuteOnMessageInWorkerPool);
+ opts.AddLongOption(prefix + "on-reply-in-pool")
+ .RequiredArgument("BOOL")
+ .DefaultValue(ToString(ExecuteOnReplyInWorkerPool))
+ .StoreResult(&ExecuteOnReplyInWorkerPool);
+}
diff --git a/library/cpp/messagebus/config/session_config.h b/library/cpp/messagebus/config/session_config.h
new file mode 100644
index 0000000000..84753350a9
--- /dev/null
+++ b/library/cpp/messagebus/config/session_config.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include "codegen.h"
+#include "defs.h"
+
+#include <library/cpp/getopt/last_getopt.h>
+
+#include <util/generic/string.h>
+
+namespace NBus {
+#define BUS_SESSION_CONFIG_MAP(XX, comma) \
+ XX(Name, TString, "") \
+ comma \
+ XX(NumRetries, int, 0) comma \
+ XX(RetryInterval, int, 1000) comma \
+ XX(ReconnectWhenIdle, bool, false) comma \
+ XX(MaxInFlight, i64, 1000) comma \
+ XX(PerConnectionMaxInFlight, unsigned, 0) comma \
+ XX(PerConnectionMaxInFlightBySize, unsigned, 0) comma \
+ XX(MaxInFlightBySize, i64, -1) comma \
+ XX(TotalTimeout, i64, 0) comma \
+ XX(SendTimeout, i64, 0) comma \
+ XX(ConnectTimeout, i64, 0) comma \
+ XX(DefaultBufferSize, size_t, 10 * 1024) comma \
+ XX(MaxBufferSize, size_t, 1024 * 1024) comma \
+ XX(SocketRecvBufferSize, unsigned, 0) comma \
+ XX(SocketSendBufferSize, unsigned, 0) comma \
+ XX(SocketToS, int, -1) comma \
+ XX(SendThreshold, size_t, 10 * 1024) comma \
+ XX(Cork, TDuration, TDuration::Zero()) comma \
+ XX(MaxMessageSize, unsigned, 26 << 20) comma \
+ XX(TcpNoDelay, bool, false) comma \
+ XX(TcpCork, bool, false) comma \
+ XX(ExecuteOnMessageInWorkerPool, bool, true) comma \
+ XX(ExecuteOnReplyInWorkerPool, bool, true) comma \
+ XX(ReusePort, bool, false) comma \
+ XX(ListenPort, unsigned, 0) /* TODO: server only */
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief Configuration for client and server session
+ struct TBusSessionConfig {
+ BUS_SESSION_CONFIG_MAP(STRUCT_FIELD_GEN, )
+
+ struct TSecret {
+ TDuration TimeoutPeriod;
+ TDuration StatusFlushPeriod;
+
+ TSecret();
+ };
+
+ // secret options are available, but you shouldn't probably use them
+ TSecret Secret;
+
+ /// initialized with default settings
+ TBusSessionConfig();
+
+ TString PrintToString() const;
+
+ void ConfigureLastGetopt(NLastGetopt::TOpts&, const TString& prefix = "mb-");
+ };
+
+ using TBusClientSessionConfig = TBusSessionConfig;
+ using TBusServerSessionConfig = TBusSessionConfig;
+
+} // NBus
diff --git a/library/cpp/messagebus/config/ya.make b/library/cpp/messagebus/config/ya.make
new file mode 100644
index 0000000000..20c7dfed19
--- /dev/null
+++ b/library/cpp/messagebus/config/ya.make
@@ -0,0 +1,15 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/getopt
+ library/cpp/deprecated/enum_codegen
+)
+
+SRCS(
+ netaddr.cpp
+ session_config.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/connection.cpp b/library/cpp/messagebus/connection.cpp
new file mode 100644
index 0000000000..07580ce18a
--- /dev/null
+++ b/library/cpp/messagebus/connection.cpp
@@ -0,0 +1,16 @@
+#include "connection.h"
+
+#include "remote_client_connection.h"
+
+#include <util/generic/cast.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+void TBusClientConnectionPtrOps::Ref(TBusClientConnection* c) {
+ return CheckedCast<TRemoteClientConnection*>(c)->Ref();
+}
+
+void TBusClientConnectionPtrOps::UnRef(TBusClientConnection* c) {
+ return CheckedCast<TRemoteClientConnection*>(c)->UnRef();
+}
diff --git a/library/cpp/messagebus/connection.h b/library/cpp/messagebus/connection.h
new file mode 100644
index 0000000000..b1df64ddc1
--- /dev/null
+++ b/library/cpp/messagebus/connection.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include "defs.h"
+#include "message.h"
+
+#include <util/generic/ptr.h>
+
+namespace NBus {
+ struct TBusClientConnection {
+ /// if you want to open connection early
+ virtual void OpenConnection() = 0;
+
+ /// Send message to the destination
+ /// If addr is set then use it as destination.
+ /// Takes ownership of addr (see ClearState method).
+ virtual EMessageStatus SendMessage(TBusMessage* pMes, bool wait = false) = 0;
+
+ virtual EMessageStatus SendMessageOneWay(TBusMessage* pMes, bool wait = false) = 0;
+
+ /// Like SendMessage but cares about message
+ template <typename T /* <: TBusMessage */>
+ EMessageStatus SendMessageAutoPtr(const TAutoPtr<T>& mes, bool wait = false) {
+ EMessageStatus status = SendMessage(mes.Get(), wait);
+ if (status == MESSAGE_OK)
+ Y_UNUSED(mes.Release());
+ return status;
+ }
+
+ /// Like SendMessageOneWay but cares about message
+ template <typename T /* <: TBusMessage */>
+ EMessageStatus SendMessageOneWayAutoPtr(const TAutoPtr<T>& mes, bool wait = false) {
+ EMessageStatus status = SendMessageOneWay(mes.Get(), wait);
+ if (status == MESSAGE_OK)
+ Y_UNUSED(mes.Release());
+ return status;
+ }
+
+ EMessageStatus SendMessageMove(TBusMessageAutoPtr message, bool wait = false) {
+ return SendMessageAutoPtr(message, wait);
+ }
+
+ EMessageStatus SendMessageOneWayMove(TBusMessageAutoPtr message, bool wait = false) {
+ return SendMessageOneWayAutoPtr(message, wait);
+ }
+
+ // TODO: implement similar one-way methods
+
+ virtual ~TBusClientConnection() {
+ }
+ };
+
+ namespace NPrivate {
+ struct TBusClientConnectionPtrOps {
+ static void Ref(TBusClientConnection*);
+ static void UnRef(TBusClientConnection*);
+ };
+ }
+
+ using TBusClientConnectionPtr = TIntrusivePtr<TBusClientConnection, NPrivate::TBusClientConnectionPtrOps>;
+
+}
diff --git a/library/cpp/messagebus/coreconn.cpp b/library/cpp/messagebus/coreconn.cpp
new file mode 100644
index 0000000000..d9411bb5db
--- /dev/null
+++ b/library/cpp/messagebus/coreconn.cpp
@@ -0,0 +1,30 @@
+#include "coreconn.h"
+
+#include "remote_connection.h"
+
+#include <util/datetime/base.h>
+#include <util/generic/yexception.h>
+#include <util/network/socket.h>
+#include <util/string/util.h>
+#include <util/system/thread.h>
+
+namespace NBus {
+ TBusInstant Now() {
+ return millisec();
+ }
+
+ EIpVersion MakeIpVersion(bool allowIpv4, bool allowIpv6) {
+ if (allowIpv4) {
+ if (allowIpv6) {
+ return EIP_VERSION_ANY;
+ } else {
+ return EIP_VERSION_4;
+ }
+ } else if (allowIpv6) {
+ return EIP_VERSION_6;
+ }
+
+ ythrow yexception() << "Neither of IPv4/IPv6 is allowed.";
+ }
+
+}
diff --git a/library/cpp/messagebus/coreconn.h b/library/cpp/messagebus/coreconn.h
new file mode 100644
index 0000000000..fca228d82e
--- /dev/null
+++ b/library/cpp/messagebus/coreconn.h
@@ -0,0 +1,67 @@
+#pragma once
+
+//////////////////////////////////////////////////////////////
+/// \file
+/// \brief Definitions for asynchonous connection queue
+
+#include "base.h"
+#include "event_loop.h"
+#include "netaddr.h"
+
+#include <util/datetime/base.h>
+#include <util/generic/algorithm.h>
+#include <util/generic/list.h>
+#include <util/generic/map.h>
+#include <util/generic/set.h>
+#include <util/generic/string.h>
+#include <util/generic/vector.h>
+#include <util/network/address.h>
+#include <util/network/ip.h>
+#include <util/network/poller.h>
+#include <util/string/util.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+#include <util/system/thread.h>
+#include <util/thread/lfqueue.h>
+
+#include <deque>
+#include <utility>
+
+#ifdef NO_ERROR
+#undef NO_ERROR
+#endif
+
+#define BUS_WORKER_CONDVAR
+//#define BUS_WORKER_MIXED
+
+namespace NBus {
+ class TBusConnection;
+ class TBusConnectionFactory;
+ class TBusServerFactory;
+
+ using TBusConnectionList = TList<TBusConnection*>;
+
+ /// @throw yexception
+ EIpVersion MakeIpVersion(bool allowIpv4, bool allowIpv6);
+
+ inline bool WouldBlock() {
+ int syserr = LastSystemError();
+ return syserr == EAGAIN || syserr == EINPROGRESS || syserr == EWOULDBLOCK || syserr == EINTR;
+ }
+
+ class TBusSession;
+
+ struct TMaxConnectedException: public yexception {
+ TMaxConnectedException(unsigned maxConnect) {
+ yexception& exc = *this;
+ exc << TStringBuf("Exceeded maximum number of outstanding connections: ");
+ exc << maxConnect;
+ }
+ };
+
+ enum EPollType {
+ POLL_READ,
+ POLL_WRITE
+ };
+
+}
diff --git a/library/cpp/messagebus/coreconn_ut.cpp b/library/cpp/messagebus/coreconn_ut.cpp
new file mode 100644
index 0000000000..beb6850f26
--- /dev/null
+++ b/library/cpp/messagebus/coreconn_ut.cpp
@@ -0,0 +1,25 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "coreconn.h"
+
+#include <util/generic/yexception.h>
+
+Y_UNIT_TEST_SUITE(TMakeIpVersionTest) {
+ using namespace NBus;
+
+ Y_UNIT_TEST(IpV4Allowed) {
+ UNIT_ASSERT_EQUAL(MakeIpVersion(true, false), EIP_VERSION_4);
+ }
+
+ Y_UNIT_TEST(IpV6Allowed) {
+ UNIT_ASSERT_EQUAL(MakeIpVersion(false, true), EIP_VERSION_6);
+ }
+
+ Y_UNIT_TEST(AllAllowed) {
+ UNIT_ASSERT_EQUAL(MakeIpVersion(true, true), EIP_VERSION_ANY);
+ }
+
+ Y_UNIT_TEST(NothingAllowed) {
+ UNIT_ASSERT_EXCEPTION(MakeIpVersion(false, false), yexception);
+ }
+}
diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver.cpp b/library/cpp/messagebus/debug_receiver/debug_receiver.cpp
new file mode 100644
index 0000000000..23b02d1003
--- /dev/null
+++ b/library/cpp/messagebus/debug_receiver/debug_receiver.cpp
@@ -0,0 +1,42 @@
+#include "debug_receiver_handler.h"
+#include "debug_receiver_proto.h"
+
+#include <library/cpp/messagebus/ybus.h>
+
+#include <library/cpp/getopt/last_getopt.h>
+#include <library/cpp/lwtrace/all.h>
+
+using namespace NBus;
+
+int main(int argc, char** argv) {
+ NLWTrace::StartLwtraceFromEnv();
+
+ TBusQueueConfig queueConfig;
+ TBusServerSessionConfig sessionConfig;
+
+ NLastGetopt::TOpts opts;
+
+ queueConfig.ConfigureLastGetopt(opts);
+ sessionConfig.ConfigureLastGetopt(opts);
+
+ opts.AddLongOption("port").Required().RequiredArgument("PORT").StoreResult(&sessionConfig.ListenPort);
+
+ opts.SetFreeArgsMax(0);
+
+ NLastGetopt::TOptsParseResult r(&opts, argc, argv);
+
+ TBusMessageQueuePtr q(CreateMessageQueue(queueConfig));
+
+ TDebugReceiverProtocol proto;
+ TDebugReceiverHandler handler;
+
+ TBusServerSessionPtr serverSession = TBusServerSession::Create(&proto, &handler, sessionConfig, q);
+ // TODO: race is here
+ handler.ServerSession = serverSession.Get();
+
+ for (;;) {
+ Sleep(TDuration::Hours(17));
+ }
+
+ return 0;
+}
diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp
new file mode 100644
index 0000000000..05f99e94ca
--- /dev/null
+++ b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.cpp
@@ -0,0 +1,20 @@
+#include "debug_receiver_handler.h"
+
+#include "debug_receiver_proto.h"
+
+#include <util/generic/cast.h>
+#include <util/string/printf.h>
+
+void TDebugReceiverHandler::OnError(TAutoPtr<NBus::TBusMessage>, NBus::EMessageStatus status) {
+ Cerr << "error " << status << "\n";
+}
+
+void TDebugReceiverHandler::OnMessage(NBus::TOnMessageContext& message) {
+ TDebugReceiverMessage* typedMessage = VerifyDynamicCast<TDebugReceiverMessage*>(message.GetMessage());
+ Cerr << "type=" << typedMessage->GetHeader()->Type
+ << " size=" << typedMessage->GetHeader()->Size
+ << " flags=" << Sprintf("0x%04x", (int)typedMessage->GetHeader()->FlagsInternal)
+ << "\n";
+
+ message.ForgetRequest();
+}
diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_handler.h b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.h
new file mode 100644
index 0000000000..0aed6b9984
--- /dev/null
+++ b/library/cpp/messagebus/debug_receiver/debug_receiver_handler.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+
+struct TDebugReceiverHandler: public NBus::IBusServerHandler {
+ NBus::TBusServerSession* ServerSession;
+
+ void OnError(TAutoPtr<NBus::TBusMessage> pMessage, NBus::EMessageStatus status) override;
+ void OnMessage(NBus::TOnMessageContext& message) override;
+};
diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp
new file mode 100644
index 0000000000..0c74f9ecc3
--- /dev/null
+++ b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.cpp
@@ -0,0 +1,20 @@
+#include "debug_receiver_proto.h"
+
+using namespace NBus;
+
+TDebugReceiverProtocol::TDebugReceiverProtocol()
+ : TBusProtocol("debug receiver", 0)
+{
+}
+
+void TDebugReceiverProtocol::Serialize(const NBus::TBusMessage*, TBuffer&) {
+ Y_FAIL("it is receiver only");
+}
+
+TAutoPtr<NBus::TBusMessage> TDebugReceiverProtocol::Deserialize(ui16, TArrayRef<const char> payload) {
+ THolder<TDebugReceiverMessage> r(new TDebugReceiverMessage(ECreateUninitialized()));
+
+ r->Payload.Append(payload.data(), payload.size());
+
+ return r.Release();
+}
diff --git a/library/cpp/messagebus/debug_receiver/debug_receiver_proto.h b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.h
new file mode 100644
index 0000000000..d34710dcf7
--- /dev/null
+++ b/library/cpp/messagebus/debug_receiver/debug_receiver_proto.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+
+struct TDebugReceiverMessage: public NBus::TBusMessage {
+ /// constructor to create messages on sending end
+ TDebugReceiverMessage(ui16 type)
+ : NBus::TBusMessage(type)
+ {
+ }
+
+ /// constructor with serialzed data to examine the header
+ TDebugReceiverMessage(NBus::ECreateUninitialized)
+ : NBus::TBusMessage(NBus::ECreateUninitialized())
+ {
+ }
+
+ TBuffer Payload;
+};
+
+struct TDebugReceiverProtocol: public NBus::TBusProtocol {
+ TDebugReceiverProtocol();
+
+ void Serialize(const NBus::TBusMessage* mess, TBuffer& data) override;
+
+ TAutoPtr<NBus::TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override;
+};
diff --git a/library/cpp/messagebus/debug_receiver/ya.make b/library/cpp/messagebus/debug_receiver/ya.make
new file mode 100644
index 0000000000..f1b14d35bb
--- /dev/null
+++ b/library/cpp/messagebus/debug_receiver/ya.make
@@ -0,0 +1,17 @@
+PROGRAM(messagebus_debug_receiver)
+
+OWNER(g:messagebus)
+
+SRCS(
+ debug_receiver.cpp
+ debug_receiver_proto.cpp
+ debug_receiver_handler.cpp
+)
+
+PEERDIR(
+ library/cpp/getopt
+ library/cpp/lwtrace
+ library/cpp/messagebus
+)
+
+END()
diff --git a/library/cpp/messagebus/defs.h b/library/cpp/messagebus/defs.h
new file mode 100644
index 0000000000..cb553acc45
--- /dev/null
+++ b/library/cpp/messagebus/defs.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include <library/cpp/messagebus/config/defs.h>
+
diff --git a/library/cpp/messagebus/dummy_debugger.h b/library/cpp/messagebus/dummy_debugger.h
new file mode 100644
index 0000000000..89a4e18716
--- /dev/null
+++ b/library/cpp/messagebus/dummy_debugger.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <util/datetime/base.h>
+#include <util/stream/output.h>
+
+#define MB_TRACE() \
+ do { \
+ Cerr << TInstant::Now() << " " << __FILE__ << ":" << __LINE__ << " " << __FUNCTION__ << Endl; \
+ } while (false)
diff --git a/library/cpp/messagebus/duration_histogram.cpp b/library/cpp/messagebus/duration_histogram.cpp
new file mode 100644
index 0000000000..32a0001d41
--- /dev/null
+++ b/library/cpp/messagebus/duration_histogram.cpp
@@ -0,0 +1,74 @@
+#include "duration_histogram.h"
+
+#include <util/generic/singleton.h>
+#include <util/stream/str.h>
+
+namespace {
+ ui64 SecondsRound(TDuration d) {
+ if (d.MilliSeconds() % 1000 >= 500) {
+ return d.Seconds() + 1;
+ } else {
+ return d.Seconds();
+ }
+ }
+
+ ui64 MilliSecondsRound(TDuration d) {
+ if (d.MicroSeconds() % 1000 >= 500) {
+ return d.MilliSeconds() + 1;
+ } else {
+ return d.MilliSeconds();
+ }
+ }
+
+ ui64 MinutesRound(TDuration d) {
+ if (d.Seconds() % 60 >= 30) {
+ return d.Minutes() + 1;
+ } else {
+ return d.Minutes();
+ }
+ }
+
+}
+
+namespace {
+ struct TMarks {
+ std::array<TDuration, TDurationHistogram::Buckets> Marks;
+
+ TMarks() {
+ Marks[0] = TDuration::Zero();
+ for (unsigned i = 1; i < TDurationHistogram::Buckets; ++i) {
+ if (i >= TDurationHistogram::SecondBoundary) {
+ Marks[i] = TDuration::Seconds(1) * (1 << (i - TDurationHistogram::SecondBoundary));
+ } else {
+ Marks[i] = TDuration::Seconds(1) / (1 << (TDurationHistogram::SecondBoundary - i));
+ }
+ }
+ }
+ };
+}
+
+TString TDurationHistogram::LabelBefore(unsigned i) {
+ Y_VERIFY(i < Buckets);
+
+ TDuration d = Singleton<TMarks>()->Marks[i];
+
+ TStringStream ss;
+ if (d == TDuration::Zero()) {
+ ss << "0";
+ } else if (d < TDuration::Seconds(1)) {
+ ss << MilliSecondsRound(d) << "ms";
+ } else if (d < TDuration::Minutes(1)) {
+ ss << SecondsRound(d) << "s";
+ } else {
+ ss << MinutesRound(d) << "m";
+ }
+ return ss.Str();
+}
+
+TString TDurationHistogram::PrintToString() const {
+ TStringStream ss;
+ for (auto time : Times) {
+ ss << time << "\n";
+ }
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/duration_histogram.h b/library/cpp/messagebus/duration_histogram.h
new file mode 100644
index 0000000000..ed060b0101
--- /dev/null
+++ b/library/cpp/messagebus/duration_histogram.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <util/datetime/base.h>
+#include <util/generic/bitops.h>
+#include <util/generic/string.h>
+
+#include <array>
+
+struct TDurationHistogram {
+ static const unsigned Buckets = 20;
+ std::array<ui64, Buckets> Times;
+
+ static const unsigned SecondBoundary = 11;
+
+ TDurationHistogram() {
+ Times.fill(0);
+ }
+
+ static unsigned BucketFor(TDuration d) {
+ ui64 units = d.MicroSeconds() * (1 << SecondBoundary) / 1000000;
+ if (units == 0) {
+ return 0;
+ }
+ unsigned bucket = GetValueBitCount(units) - 1;
+ if (bucket >= Buckets) {
+ bucket = Buckets - 1;
+ }
+ return bucket;
+ }
+
+ void AddTime(TDuration d) {
+ Times[BucketFor(d)] += 1;
+ }
+
+ TDurationHistogram& operator+=(const TDurationHistogram& that) {
+ for (unsigned i = 0; i < Times.size(); ++i) {
+ Times[i] += that.Times[i];
+ }
+ return *this;
+ }
+
+ static TString LabelBefore(unsigned i);
+
+ TString PrintToString() const;
+};
diff --git a/library/cpp/messagebus/duration_histogram_ut.cpp b/library/cpp/messagebus/duration_histogram_ut.cpp
new file mode 100644
index 0000000000..01bcc095e9
--- /dev/null
+++ b/library/cpp/messagebus/duration_histogram_ut.cpp
@@ -0,0 +1,38 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "duration_histogram.h"
+
+Y_UNIT_TEST_SUITE(TDurationHistogramTest) {
+ Y_UNIT_TEST(BucketFor) {
+ UNIT_ASSERT_VALUES_EQUAL(0u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(0)));
+ UNIT_ASSERT_VALUES_EQUAL(0u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(1)));
+ UNIT_ASSERT_VALUES_EQUAL(0u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(900)));
+ UNIT_ASSERT_VALUES_EQUAL(1u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(1500)));
+ UNIT_ASSERT_VALUES_EQUAL(2u, TDurationHistogram::BucketFor(TDuration::MicroSeconds(2500)));
+
+ unsigned sb = TDurationHistogram::SecondBoundary;
+
+ UNIT_ASSERT_VALUES_EQUAL(sb - 1, TDurationHistogram::BucketFor(TDuration::MilliSeconds(999)));
+ UNIT_ASSERT_VALUES_EQUAL(sb, TDurationHistogram::BucketFor(TDuration::MilliSeconds(1000)));
+ UNIT_ASSERT_VALUES_EQUAL(sb, TDurationHistogram::BucketFor(TDuration::MilliSeconds(1001)));
+
+ UNIT_ASSERT_VALUES_EQUAL(TDurationHistogram::Buckets - 1, TDurationHistogram::BucketFor(TDuration::Hours(1)));
+ }
+
+ Y_UNIT_TEST(Simple) {
+ TDurationHistogram h1;
+ h1.AddTime(TDuration::MicroSeconds(1));
+ UNIT_ASSERT_VALUES_EQUAL(1u, h1.Times.front());
+
+ TDurationHistogram h2;
+ h1.AddTime(TDuration::Hours(1));
+ UNIT_ASSERT_VALUES_EQUAL(1u, h1.Times.back());
+ }
+
+ Y_UNIT_TEST(LabelFor) {
+ for (unsigned i = 0; i < TDurationHistogram::Buckets; ++i) {
+ TDurationHistogram::LabelBefore(i);
+ //Cerr << TDurationHistogram::LabelBefore(i) << "\n";
+ }
+ }
+}
diff --git a/library/cpp/messagebus/event_loop.cpp b/library/cpp/messagebus/event_loop.cpp
new file mode 100644
index 0000000000..f685135bed
--- /dev/null
+++ b/library/cpp/messagebus/event_loop.cpp
@@ -0,0 +1,370 @@
+#include "event_loop.h"
+
+#include "network.h"
+#include "thread_extra.h"
+
+#include <util/generic/hash.h>
+#include <util/network/pair.h>
+#include <util/network/poller.h>
+#include <util/system/event.h>
+#include <util/system/mutex.h>
+#include <util/system/thread.h>
+#include <util/system/yassert.h>
+#include <util/thread/lfqueue.h>
+
+#include <errno.h>
+
+using namespace NEventLoop;
+
+namespace {
+ enum ERunningState {
+ EVENT_LOOP_CREATED,
+ EVENT_LOOP_RUNNING,
+ EVENT_LOOP_STOPPED,
+ };
+
+ enum EOperation {
+ OP_READ = 1,
+ OP_WRITE = 2,
+ OP_READ_WRITE = OP_READ | OP_WRITE,
+ };
+}
+
+class TChannel::TImpl {
+public:
+ TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr, void* cookie);
+ ~TImpl();
+
+ void EnableRead();
+ void DisableRead();
+ void EnableWrite();
+ void DisableWrite();
+
+ void Unregister();
+
+ SOCKET GetSocket() const;
+ TSocket GetSocketPtr() const;
+
+ void Update(int pollerFlags, bool enable);
+ void CallHandler();
+
+ TEventLoop::TImpl* EventLoop;
+ TSocket Socket;
+ TEventHandlerPtr EventHandler;
+ void* Cookie;
+
+ TMutex Mutex;
+
+ int CurrentFlags;
+ bool Close;
+};
+
+class TEventLoop::TImpl {
+public:
+ TImpl(const char* name);
+
+ void Run();
+ void Wakeup();
+ void Stop();
+
+ TChannelPtr Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie);
+ void Unregister(SOCKET socket);
+
+ typedef THashMap<SOCKET, TChannelPtr> TData;
+
+ void AddToPoller(SOCKET socket, void* cookie, int flags);
+
+ TMutex Mutex;
+
+ const char* Name;
+
+ TAtomic RunningState;
+ TAtomic StopSignal;
+ TSystemEvent StoppedEvent;
+ TData Data;
+
+ TLockFreeQueue<SOCKET> SocketsToRemove;
+
+ TSocketPoller Poller;
+ TSocketHolder WakeupReadSocket;
+ TSocketHolder WakeupWriteSocket;
+};
+
+TChannel::~TChannel() {
+}
+
+void TChannel::EnableRead() {
+ Impl->EnableRead();
+}
+
+void TChannel::DisableRead() {
+ Impl->DisableRead();
+}
+
+void TChannel::EnableWrite() {
+ Impl->EnableWrite();
+}
+
+void TChannel::DisableWrite() {
+ Impl->DisableWrite();
+}
+
+void TChannel::Unregister() {
+ Impl->Unregister();
+}
+
+SOCKET TChannel::GetSocket() const {
+ return Impl->GetSocket();
+}
+
+TSocket TChannel::GetSocketPtr() const {
+ return Impl->GetSocketPtr();
+}
+
+TChannel::TChannel(TImpl* impl)
+ : Impl(impl)
+{
+}
+
+TEventLoop::TEventLoop(const char* name)
+ : Impl(new TImpl(name))
+{
+}
+
+TEventLoop::~TEventLoop() {
+}
+
+void TEventLoop::Run() {
+ Impl->Run();
+}
+
+void TEventLoop::Stop() {
+ Impl->Stop();
+}
+
+bool TEventLoop::IsRunning() {
+ return AtomicGet(Impl->RunningState) == EVENT_LOOP_RUNNING;
+}
+
+TChannelPtr TEventLoop::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) {
+ return Impl->Register(socket, eventHandler, cookie);
+}
+
+TChannel::TImpl::TImpl(TEventLoop::TImpl* eventLoop, TSocket socket, TEventHandlerPtr eventHandler, void* cookie)
+ : EventLoop(eventLoop)
+ , Socket(socket)
+ , EventHandler(eventHandler)
+ , Cookie(cookie)
+ , CurrentFlags(0)
+ , Close(false)
+{
+}
+
+TChannel::TImpl::~TImpl() {
+ Y_ASSERT(Close);
+}
+
+void TChannel::TImpl::EnableRead() {
+ Update(OP_READ, true);
+}
+
+void TChannel::TImpl::DisableRead() {
+ Update(OP_READ, false);
+}
+
+void TChannel::TImpl::EnableWrite() {
+ Update(OP_WRITE, true);
+}
+
+void TChannel::TImpl::DisableWrite() {
+ Update(OP_WRITE, false);
+}
+
+void TChannel::TImpl::Unregister() {
+ TGuard<TMutex> guard(Mutex);
+
+ if (Close) {
+ return;
+ }
+
+ Close = true;
+ if (CurrentFlags != 0) {
+ EventLoop->Poller.Unwait(Socket);
+ CurrentFlags = 0;
+ }
+ EventHandler.Drop();
+
+ EventLoop->SocketsToRemove.Enqueue(Socket);
+ EventLoop->Wakeup();
+}
+
+void TChannel::TImpl::Update(int flags, bool enable) {
+ TGuard<TMutex> guard(Mutex);
+
+ if (Close) {
+ return;
+ }
+
+ int newFlags = enable ? (CurrentFlags | flags) : (CurrentFlags & ~flags);
+
+ if (CurrentFlags == newFlags) {
+ return;
+ }
+
+ if (!newFlags) {
+ EventLoop->Poller.Unwait(Socket);
+ } else {
+ void* cookie = reinterpret_cast<void*>(this);
+ EventLoop->AddToPoller(Socket, cookie, newFlags);
+ }
+
+ CurrentFlags = newFlags;
+}
+
+SOCKET TChannel::TImpl::GetSocket() const {
+ return Socket;
+}
+
+TSocket TChannel::TImpl::GetSocketPtr() const {
+ return Socket;
+}
+
+void TChannel::TImpl::CallHandler() {
+ TEventHandlerPtr handler;
+
+ {
+ TGuard<TMutex> guard(Mutex);
+
+ // other thread may have re-added socket to epoll
+ // so even if CurrentFlags is 0, epoll may fire again
+ // so please use non-blocking operations
+ CurrentFlags = 0;
+
+ if (Close) {
+ return;
+ }
+
+ handler = EventHandler;
+ }
+
+ if (!!handler) {
+ handler->HandleEvent(Socket, Cookie);
+ }
+}
+
+TEventLoop::TImpl::TImpl(const char* name)
+ : Name(name)
+ , RunningState(EVENT_LOOP_CREATED)
+ , StopSignal(0)
+{
+ SOCKET wakeupSockets[2];
+
+ if (SocketPair(wakeupSockets) < 0) {
+ Y_FAIL("failed to create socket pair for wakeup sockets: %s", LastSystemErrorText());
+ }
+
+ TSocketHolder wakeupReadSocket(wakeupSockets[0]);
+ TSocketHolder wakeupWriteSocket(wakeupSockets[1]);
+
+ WakeupReadSocket.Swap(wakeupReadSocket);
+ WakeupWriteSocket.Swap(wakeupWriteSocket);
+
+ SetNonBlock(WakeupWriteSocket, true);
+ SetNonBlock(WakeupReadSocket, true);
+
+ Poller.WaitRead(WakeupReadSocket,
+ reinterpret_cast<void*>(this));
+}
+
+void TEventLoop::TImpl::Run() {
+ bool res = AtomicCas(&RunningState, EVENT_LOOP_RUNNING, EVENT_LOOP_CREATED);
+ Y_VERIFY(res, "Invalid mbus event loop state");
+
+ if (!!Name) {
+ SetCurrentThreadName(Name);
+ }
+
+ while (AtomicGet(StopSignal) == 0) {
+ void* cookies[1024];
+ const size_t count = Poller.WaitI(cookies, Y_ARRAY_SIZE(cookies));
+
+ void** end = cookies + count;
+ for (void** c = cookies; c != end; ++c) {
+ TChannel::TImpl* s = reinterpret_cast<TChannel::TImpl*>(*c);
+
+ if (*c == this) {
+ char buf[0x1000];
+ if (NBus::NPrivate::SocketRecv(WakeupReadSocket, buf) < 0) {
+ Y_FAIL("failed to recv from wakeup socket: %s", LastSystemErrorText());
+ }
+ continue;
+ }
+
+ s->CallHandler();
+ }
+
+ SOCKET socket = -1;
+ while (SocketsToRemove.Dequeue(&socket)) {
+ TGuard<TMutex> guard(Mutex);
+ Y_VERIFY(Data.erase(socket) == 1, "must be removed once");
+ }
+ }
+
+ {
+ TGuard<TMutex> guard(Mutex);
+ for (auto& it : Data) {
+ it.second->Unregister();
+ }
+
+ // release file descriptors
+ Data.clear();
+ }
+
+ res = AtomicCas(&RunningState, EVENT_LOOP_STOPPED, EVENT_LOOP_RUNNING);
+
+ Y_VERIFY(res);
+
+ StoppedEvent.Signal();
+}
+
+void TEventLoop::TImpl::Stop() {
+ AtomicSet(StopSignal, 1);
+
+ if (AtomicGet(RunningState) == EVENT_LOOP_RUNNING) {
+ Wakeup();
+
+ StoppedEvent.WaitI();
+ }
+}
+
+TChannelPtr TEventLoop::TImpl::Register(TSocket socket, TEventHandlerPtr eventHandler, void* cookie) {
+ Y_VERIFY(socket != INVALID_SOCKET, "must be a valid socket");
+
+ TChannelPtr channel = new TChannel(new TChannel::TImpl(this, socket, eventHandler, cookie));
+
+ TGuard<TMutex> guard(Mutex);
+
+ Y_VERIFY(Data.insert(std::make_pair(socket, channel)).second, "must not be already inserted");
+
+ return channel;
+}
+
+void TEventLoop::TImpl::Wakeup() {
+ if (NBus::NPrivate::SocketSend(WakeupWriteSocket, TArrayRef<const char>("", 1)) < 0) {
+ if (LastSystemError() != EAGAIN) {
+ Y_FAIL("failed to send to wakeup socket: %s", LastSystemErrorText());
+ }
+ }
+}
+
+void TEventLoop::TImpl::AddToPoller(SOCKET socket, void* cookie, int flags) {
+ if (flags == OP_READ) {
+ Poller.WaitReadOneShot(socket, cookie);
+ } else if (flags == OP_WRITE) {
+ Poller.WaitWriteOneShot(socket, cookie);
+ } else if (flags == OP_READ_WRITE) {
+ Poller.WaitReadWriteOneShot(socket, cookie);
+ } else {
+ Y_FAIL("Wrong flags: %d", int(flags));
+ }
+}
diff --git a/library/cpp/messagebus/event_loop.h b/library/cpp/messagebus/event_loop.h
new file mode 100644
index 0000000000..d5b0a53b0c
--- /dev/null
+++ b/library/cpp/messagebus/event_loop.h
@@ -0,0 +1,72 @@
+#pragma once
+
+#include <util/generic/object_counter.h>
+#include <util/generic/ptr.h>
+#include <util/network/init.h>
+#include <util/network/socket.h>
+
+namespace NEventLoop {
+ struct IEventHandler
+ : public TAtomicRefCount<IEventHandler> {
+ virtual void HandleEvent(SOCKET socket, void* cookie) = 0;
+ virtual ~IEventHandler() {
+ }
+ };
+
+ typedef TIntrusivePtr<IEventHandler> TEventHandlerPtr;
+
+ class TEventLoop;
+
+ // TODO: make TChannel itself a pointer
+ // to avoid confusion with Drop and Unregister
+ class TChannel
+ : public TAtomicRefCount<TChannel> {
+ public:
+ ~TChannel();
+
+ void EnableRead();
+ void DisableRead();
+ void EnableWrite();
+ void DisableWrite();
+
+ void Unregister();
+
+ SOCKET GetSocket() const;
+ TSocket GetSocketPtr() const;
+
+ private:
+ class TImpl;
+ friend class TEventLoop;
+
+ TObjectCounter<TChannel> ObjectCounter;
+
+ TChannel(TImpl*);
+
+ private:
+ THolder<TImpl> Impl;
+ };
+
+ typedef TIntrusivePtr<TChannel> TChannelPtr;
+
+ class TEventLoop {
+ public:
+ TEventLoop(const char* name = nullptr);
+ ~TEventLoop();
+
+ void Run();
+ void Stop();
+ bool IsRunning();
+
+ TChannelPtr Register(TSocket socket, TEventHandlerPtr, void* cookie = nullptr);
+
+ private:
+ class TImpl;
+ friend class TChannel;
+
+ TObjectCounter<TEventLoop> ObjectCounter;
+
+ private:
+ THolder<TImpl> Impl;
+ };
+
+}
diff --git a/library/cpp/messagebus/extra_ref.h b/library/cpp/messagebus/extra_ref.h
new file mode 100644
index 0000000000..2927123266
--- /dev/null
+++ b/library/cpp/messagebus/extra_ref.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <util/system/yassert.h>
+
+class TExtraRef {
+ TAtomic Holds;
+
+public:
+ TExtraRef()
+ : Holds(false)
+ {
+ }
+ ~TExtraRef() {
+ Y_VERIFY(!Holds);
+ }
+
+ template <typename TThis>
+ void Retain(TThis* thiz) {
+ if (AtomicGet(Holds)) {
+ return;
+ }
+ if (AtomicCas(&Holds, 1, 0)) {
+ thiz->Ref();
+ }
+ }
+
+ template <typename TThis>
+ void Release(TThis* thiz) {
+ if (!AtomicGet(Holds)) {
+ return;
+ }
+ if (AtomicCas(&Holds, 0, 1)) {
+ thiz->UnRef();
+ }
+ }
+};
diff --git a/library/cpp/messagebus/futex_like.cpp b/library/cpp/messagebus/futex_like.cpp
new file mode 100644
index 0000000000..7f965126db
--- /dev/null
+++ b/library/cpp/messagebus/futex_like.cpp
@@ -0,0 +1,55 @@
+#include <util/system/platform.h>
+
+#ifdef _linux_
+#include <sys/syscall.h>
+#include <linux/futex.h>
+
+#if !defined(SYS_futex)
+#define SYS_futex __NR_futex
+#endif
+#endif
+
+#include <errno.h>
+
+#include <util/system/yassert.h>
+
+#include "futex_like.h"
+
+#ifdef _linux_
+namespace {
+ int futex(int* uaddr, int op, int val, const struct timespec* timeout,
+ int* uaddr2, int val3) {
+ return syscall(SYS_futex, uaddr, op, val, timeout, uaddr2, val3);
+ }
+}
+#endif
+
+void TFutexLike::Wake(size_t count) {
+ Y_ASSERT(count > 0);
+#ifdef _linux_
+ if (count > unsigned(Max<int>())) {
+ count = Max<int>();
+ }
+ int r = futex(&Value, FUTEX_WAKE, count, nullptr, nullptr, 0);
+ Y_VERIFY(r >= 0, "futex_wake failed: %s", strerror(errno));
+#else
+ TGuard<TMutex> guard(Mutex);
+ if (count == 1) {
+ CondVar.Signal();
+ } else {
+ CondVar.BroadCast();
+ }
+#endif
+}
+
+void TFutexLike::Wait(int expected) {
+#ifdef _linux_
+ int r = futex(&Value, FUTEX_WAIT, expected, nullptr, nullptr, 0);
+ Y_VERIFY(r >= 0 || errno == EWOULDBLOCK, "futex_wait failed: %s", strerror(errno));
+#else
+ TGuard<TMutex> guard(Mutex);
+ if (expected == Get()) {
+ CondVar.WaitI(Mutex);
+ }
+#endif
+}
diff --git a/library/cpp/messagebus/futex_like.h b/library/cpp/messagebus/futex_like.h
new file mode 100644
index 0000000000..31d60c60f1
--- /dev/null
+++ b/library/cpp/messagebus/futex_like.h
@@ -0,0 +1,86 @@
+#pragma once
+
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+#include <util/system/platform.h>
+
+class TFutexLike {
+private:
+#ifdef _linux_
+ int Value;
+#else
+ TAtomic Value;
+ TMutex Mutex;
+ TCondVar CondVar;
+#endif
+
+public:
+ TFutexLike()
+ : Value(0)
+ {
+ }
+
+ int AddAndGet(int add) {
+#ifdef _linux_
+ //return __atomic_add_fetch(&Value, add, __ATOMIC_SEQ_CST);
+ return __sync_add_and_fetch(&Value, add);
+#else
+ return AtomicAdd(Value, add);
+#endif
+ }
+
+ int GetAndAdd(int add) {
+ return AddAndGet(add) - add;
+ }
+
+// until we have modern GCC
+#if 0
+ int GetAndSet(int newValue) {
+#ifdef _linux_
+ return __atomic_exchange_n(&Value, newValue, __ATOMIC_SEQ_CST);
+#else
+ return AtomicSwap(&Value, newValue);
+#endif
+ }
+#endif
+
+ int Get() {
+#ifdef _linux_
+ //return __atomic_load_n(&Value, __ATOMIC_SEQ_CST);
+ __sync_synchronize();
+ return Value;
+#else
+ return AtomicGet(Value);
+#endif
+ }
+
+ void Set(int newValue) {
+#ifdef _linux_
+ //__atomic_store_n(&Value, newValue, __ATOMIC_SEQ_CST);
+ Value = newValue;
+ __sync_synchronize();
+#else
+ AtomicSet(Value, newValue);
+#endif
+ }
+
+ int GetAndIncrement() {
+ return AddAndGet(1) - 1;
+ }
+
+ int IncrementAndGet() {
+ return AddAndGet(1);
+ }
+
+ int GetAndDecrement() {
+ return AddAndGet(-1) + 1;
+ }
+
+ int DecrementAndGet() {
+ return AddAndGet(-1);
+ }
+
+ void Wake(size_t count = Max<size_t>());
+
+ void Wait(int expected);
+};
diff --git a/library/cpp/messagebus/handler.cpp b/library/cpp/messagebus/handler.cpp
new file mode 100644
index 0000000000..333bd52934
--- /dev/null
+++ b/library/cpp/messagebus/handler.cpp
@@ -0,0 +1,36 @@
+#include "handler.h"
+
+#include "remote_server_connection.h"
+#include "ybus.h"
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+void IBusErrorHandler::OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status) {
+ Y_UNUSED(pMessage);
+ Y_UNUSED(status);
+}
+void IBusServerHandler::OnSent(TAutoPtr<TBusMessage> pMessage) {
+ Y_UNUSED(pMessage);
+}
+void IBusClientHandler::OnMessageSent(TBusMessage* pMessage) {
+ Y_UNUSED(pMessage);
+}
+void IBusClientHandler::OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) {
+ Y_UNUSED(pMessage);
+}
+
+void IBusClientHandler::OnClientConnectionEvent(const TClientConnectionEvent&) {
+}
+
+void TOnMessageContext::ForgetRequest() {
+ Session->ForgetRequest(Ident);
+}
+
+TNetAddr TOnMessageContext::GetPeerAddrNetAddr() const {
+ return Ident.GetNetAddr();
+}
+
+bool TOnMessageContext::IsConnectionAlive() const {
+ return !!Ident.Connection && Ident.Connection->IsAlive();
+}
diff --git a/library/cpp/messagebus/handler.h b/library/cpp/messagebus/handler.h
new file mode 100644
index 0000000000..60002c68a6
--- /dev/null
+++ b/library/cpp/messagebus/handler.h
@@ -0,0 +1,135 @@
+#pragma once
+
+#include "defs.h"
+#include "message.h"
+#include "message_status.h"
+#include "use_after_free_checker.h"
+#include "use_count_checker.h"
+
+#include <util/generic/noncopyable.h>
+
+namespace NBus {
+ /////////////////////////////////////////////////////////////////
+ /// \brief Interface to message bus handler
+
+ struct IBusErrorHandler {
+ friend struct ::NBus::NPrivate::TBusSessionImpl;
+
+ private:
+ TUseAfterFreeChecker UseAfterFreeChecker;
+ TUseCountChecker UseCountChecker;
+
+ public:
+ /// called when message or reply can't be delivered
+ virtual void OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status);
+
+ virtual ~IBusErrorHandler() {
+ }
+ };
+
+ class TClientConnectionEvent {
+ public:
+ enum EType {
+ CONNECTED,
+ DISCONNECTED,
+ };
+
+ private:
+ EType Type;
+ ui64 Id;
+ TNetAddr Addr;
+
+ public:
+ TClientConnectionEvent(EType type, ui64 id, TNetAddr addr)
+ : Type(type)
+ , Id(id)
+ , Addr(addr)
+ {
+ }
+
+ EType GetType() const {
+ return Type;
+ }
+ ui64 GetId() const {
+ return Id;
+ }
+ TNetAddr GetAddr() const {
+ return Addr;
+ }
+ };
+
+ class TOnMessageContext : TNonCopyable {
+ private:
+ THolder<TBusMessage> Message;
+ TBusIdentity Ident;
+ // TODO: we don't need to store session, we have connection in ident
+ TBusServerSession* Session;
+
+ public:
+ TOnMessageContext()
+ : Session()
+ {
+ }
+ TOnMessageContext(TAutoPtr<TBusMessage> message, TBusIdentity& ident, TBusServerSession* session)
+ : Message(message)
+ , Session(session)
+ {
+ Ident.Swap(ident);
+ }
+
+ bool IsInWork() const {
+ return Ident.IsInWork();
+ }
+
+ bool operator!() const {
+ return !IsInWork();
+ }
+
+ TBusMessage* GetMessage() {
+ return Message.Get();
+ }
+
+ TBusMessage* ReleaseMessage() {
+ return Message.Release();
+ }
+
+ TBusServerSession* GetSession() {
+ return Session;
+ }
+
+ template <typename U /* <: TBusMessage */>
+ EMessageStatus SendReplyAutoPtr(TAutoPtr<U>& rep);
+
+ EMessageStatus SendReplyMove(TBusMessageAutoPtr response);
+
+ void AckMessage(TBusIdentity& ident);
+
+ void ForgetRequest();
+
+ void Swap(TOnMessageContext& that) {
+ DoSwap(Message, that.Message);
+ Ident.Swap(that.Ident);
+ DoSwap(Session, that.Session);
+ }
+
+ TNetAddr GetPeerAddrNetAddr() const;
+
+ bool IsConnectionAlive() const;
+ };
+
+ struct IBusServerHandler: public IBusErrorHandler {
+ virtual void OnMessage(TOnMessageContext& onMessage) = 0;
+ /// called when reply has been sent from destination
+ virtual void OnSent(TAutoPtr<TBusMessage> pMessage);
+ };
+
+ struct IBusClientHandler: public IBusErrorHandler {
+ /// called on source when reply arrives from destination
+ virtual void OnReply(TAutoPtr<TBusMessage> pMessage, TAutoPtr<TBusMessage> pReply) = 0;
+ /// called when client side message has gone into wire, place to call AckMessage()
+ virtual void OnMessageSent(TBusMessage* pMessage);
+ virtual void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage);
+ virtual void OnClientConnectionEvent(const TClientConnectionEvent&);
+ };
+
+}
diff --git a/library/cpp/messagebus/handler_impl.h b/library/cpp/messagebus/handler_impl.h
new file mode 100644
index 0000000000..6593f04cc3
--- /dev/null
+++ b/library/cpp/messagebus/handler_impl.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include "handler.h"
+#include "local_flags.h"
+#include "session.h"
+
+namespace NBus {
+ template <typename U /* <: TBusMessage */>
+ EMessageStatus TOnMessageContext::SendReplyAutoPtr(TAutoPtr<U>& response) {
+ return Session->SendReplyAutoPtr(Ident, response);
+ }
+
+ inline EMessageStatus TOnMessageContext::SendReplyMove(TBusMessageAutoPtr response) {
+ return SendReplyAutoPtr(response);
+ }
+
+ inline void TOnMessageContext::AckMessage(TBusIdentity& ident) {
+ Y_VERIFY(Ident.LocalFlags == NPrivate::MESSAGE_IN_WORK);
+ Y_VERIFY(ident.LocalFlags == 0);
+ Ident.Swap(ident);
+ }
+
+}
diff --git a/library/cpp/messagebus/hash.h b/library/cpp/messagebus/hash.h
new file mode 100644
index 0000000000..cc1b136a86
--- /dev/null
+++ b/library/cpp/messagebus/hash.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <util/str_stl.h>
+#include <util/digest/numeric.h>
+
+namespace NBus {
+ namespace NPrivate {
+ template <typename T>
+ size_t Hash(const T& val) {
+ return THash<T>()(val);
+ }
+
+ template <typename T, typename U>
+ size_t HashValues(const T& a, const U& b) {
+ return CombineHashes(Hash(a), Hash(b));
+ }
+
+ }
+}
diff --git a/library/cpp/messagebus/key_value_printer.cpp b/library/cpp/messagebus/key_value_printer.cpp
new file mode 100644
index 0000000000..c8592145c7
--- /dev/null
+++ b/library/cpp/messagebus/key_value_printer.cpp
@@ -0,0 +1,46 @@
+#include "key_value_printer.h"
+
+#include <util/stream/format.h>
+
+TKeyValuePrinter::TKeyValuePrinter(const TString& sep)
+ : Sep(sep)
+{
+}
+
+TKeyValuePrinter::~TKeyValuePrinter() {
+}
+
+void TKeyValuePrinter::AddRowImpl(const TString& key, const TString& value, bool alignLeft) {
+ Keys.push_back(key);
+ Values.push_back(value);
+ AlignLefts.push_back(alignLeft);
+}
+
+TString TKeyValuePrinter::PrintToString() const {
+ if (Keys.empty()) {
+ return TString();
+ }
+
+ size_t keyWidth = 0;
+ size_t valueWidth = 0;
+
+ for (size_t i = 0; i < Keys.size(); ++i) {
+ keyWidth = Max(keyWidth, Keys.at(i).size());
+ valueWidth = Max(valueWidth, Values.at(i).size());
+ }
+
+ TStringStream ss;
+
+ for (size_t i = 0; i < Keys.size(); ++i) {
+ ss << RightPad(Keys.at(i), keyWidth);
+ ss << Sep;
+ if (AlignLefts.at(i)) {
+ ss << Values.at(i);
+ } else {
+ ss << LeftPad(Values.at(i), valueWidth);
+ }
+ ss << Endl;
+ }
+
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/key_value_printer.h b/library/cpp/messagebus/key_value_printer.h
new file mode 100644
index 0000000000..bca1fde50e
--- /dev/null
+++ b/library/cpp/messagebus/key_value_printer.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <util/generic/string.h>
+#include <util/generic/typetraits.h>
+#include <util/generic/vector.h>
+#include <util/string/cast.h>
+
+class TKeyValuePrinter {
+private:
+ TString Sep;
+ TVector<TString> Keys;
+ TVector<TString> Values;
+ TVector<bool> AlignLefts;
+
+public:
+ TKeyValuePrinter(const TString& sep = TString(": "));
+ ~TKeyValuePrinter();
+
+ template <typename TKey, typename TValue>
+ void AddRow(const TKey& key, const TValue& value, bool leftAlign = !std::is_integral<TValue>::value) {
+ return AddRowImpl(ToString(key), ToString(value), leftAlign);
+ }
+
+ TString PrintToString() const;
+
+private:
+ void AddRowImpl(const TString& key, const TString& value, bool leftAlign);
+};
diff --git a/library/cpp/messagebus/latch.h b/library/cpp/messagebus/latch.h
new file mode 100644
index 0000000000..373f4c0e13
--- /dev/null
+++ b/library/cpp/messagebus/latch.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+
+class TLatch {
+private:
+ // 0 for unlocked, 1 for locked
+ TAtomic Locked;
+ TMutex Mutex;
+ TCondVar CondVar;
+
+public:
+ TLatch()
+ : Locked(0)
+ {
+ }
+
+ void Wait() {
+ // optimistic path
+ if (AtomicGet(Locked) == 0) {
+ return;
+ }
+
+ TGuard<TMutex> guard(Mutex);
+ while (AtomicGet(Locked) == 1) {
+ CondVar.WaitI(Mutex);
+ }
+ }
+
+ bool TryWait() {
+ return AtomicGet(Locked) == 0;
+ }
+
+ void Unlock() {
+ // optimistic path
+ if (AtomicGet(Locked) == 0) {
+ return;
+ }
+
+ TGuard<TMutex> guard(Mutex);
+ AtomicSet(Locked, 0);
+ CondVar.BroadCast();
+ }
+
+ void Lock() {
+ AtomicSet(Locked, 1);
+ }
+
+ bool IsLocked() {
+ return AtomicGet(Locked);
+ }
+};
diff --git a/library/cpp/messagebus/latch_ut.cpp b/library/cpp/messagebus/latch_ut.cpp
new file mode 100644
index 0000000000..bfab04f527
--- /dev/null
+++ b/library/cpp/messagebus/latch_ut.cpp
@@ -0,0 +1,20 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "latch.h"
+
+Y_UNIT_TEST_SUITE(TLatch) {
+ Y_UNIT_TEST(Simple) {
+ TLatch latch;
+ UNIT_ASSERT(latch.TryWait());
+ latch.Lock();
+ UNIT_ASSERT(!latch.TryWait());
+ latch.Lock();
+ latch.Lock();
+ UNIT_ASSERT(!latch.TryWait());
+ latch.Unlock();
+ UNIT_ASSERT(latch.TryWait());
+ latch.Unlock();
+ latch.Unlock();
+ UNIT_ASSERT(latch.TryWait());
+ }
+}
diff --git a/library/cpp/messagebus/left_right_buffer.h b/library/cpp/messagebus/left_right_buffer.h
new file mode 100644
index 0000000000..f937cefad0
--- /dev/null
+++ b/library/cpp/messagebus/left_right_buffer.h
@@ -0,0 +1,78 @@
+#pragma once
+
+#include <util/generic/buffer.h>
+#include <util/generic/noncopyable.h>
+#include <util/system/yassert.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TLeftRightBuffer : TNonCopyable {
+ private:
+ TBuffer Buffer;
+ size_t Left;
+
+ void CheckInvariant() {
+ Y_ASSERT(Left <= Buffer.Size());
+ }
+
+ public:
+ TLeftRightBuffer()
+ : Left(0)
+ {
+ }
+
+ TBuffer& GetBuffer() {
+ return Buffer;
+ }
+
+ size_t Capacity() {
+ return Buffer.Capacity();
+ }
+
+ void Clear() {
+ Buffer.Clear();
+ Left = 0;
+ }
+
+ void Reset() {
+ Buffer.Reset();
+ Left = 0;
+ }
+
+ void Compact() {
+ Buffer.ChopHead(Left);
+ Left = 0;
+ }
+
+ char* LeftPos() {
+ return Buffer.Data() + Left;
+ }
+
+ size_t LeftSize() {
+ return Left;
+ }
+
+ void LeftProceed(size_t count) {
+ Y_ASSERT(count <= Size());
+ Left += count;
+ }
+
+ size_t Size() {
+ return Buffer.Size() - Left;
+ }
+
+ bool Empty() {
+ return Size() == 0;
+ }
+
+ char* RightPos() {
+ return Buffer.Data() + Buffer.Size();
+ }
+
+ size_t Avail() {
+ return Buffer.Avail();
+ }
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/lfqueue_batch.h b/library/cpp/messagebus/lfqueue_batch.h
new file mode 100644
index 0000000000..8128d3154d
--- /dev/null
+++ b/library/cpp/messagebus/lfqueue_batch.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <library/cpp/messagebus/actor/temp_tls_vector.h>
+
+#include <util/generic/vector.h>
+#include <util/thread/lfstack.h>
+
+template <typename T, template <typename, class> class TVectorType = TVector>
+class TLockFreeQueueBatch {
+private:
+ TLockFreeStack<TVectorType<T, std::allocator<T>>*> Stack;
+
+public:
+ bool IsEmpty() {
+ return Stack.IsEmpty();
+ }
+
+ void EnqueueAll(TAutoPtr<TVectorType<T, std::allocator<T>>> vec) {
+ Stack.Enqueue(vec.Release());
+ }
+
+ void DequeueAllSingleConsumer(TVectorType<T, std::allocator<T>>* r) {
+ TTempTlsVector<TVectorType<T, std::allocator<T>>*> vs;
+ Stack.DequeueAllSingleConsumer(vs.GetVector());
+
+ for (typename TVector<TVectorType<T, std::allocator<T>>*>::reverse_iterator i = vs.GetVector()->rbegin();
+ i != vs.GetVector()->rend(); ++i) {
+ if (i == vs.GetVector()->rend()) {
+ r->swap(**i);
+ } else {
+ r->insert(r->end(), (*i)->begin(), (*i)->end());
+ }
+ delete *i;
+ }
+ }
+};
diff --git a/library/cpp/messagebus/lfqueue_batch_ut.cpp b/library/cpp/messagebus/lfqueue_batch_ut.cpp
new file mode 100644
index 0000000000..f80434c0d4
--- /dev/null
+++ b/library/cpp/messagebus/lfqueue_batch_ut.cpp
@@ -0,0 +1,56 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "lfqueue_batch.h"
+
+Y_UNIT_TEST_SUITE(TLockFreeQueueBatch) {
+ Y_UNIT_TEST(Order1) {
+ TLockFreeQueueBatch<unsigned> q;
+ {
+ TAutoPtr<TVector<unsigned>> v(new TVector<unsigned>);
+ v->push_back(0);
+ v->push_back(1);
+ q.EnqueueAll(v);
+ }
+
+ TVector<unsigned> r;
+ q.DequeueAllSingleConsumer(&r);
+
+ UNIT_ASSERT_VALUES_EQUAL(2u, r.size());
+ for (unsigned i = 0; i < 2; ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(i, r[i]);
+ }
+
+ r.clear();
+ q.DequeueAllSingleConsumer(&r);
+ UNIT_ASSERT_VALUES_EQUAL(0u, r.size());
+ }
+
+ Y_UNIT_TEST(Order2) {
+ TLockFreeQueueBatch<unsigned> q;
+ {
+ TAutoPtr<TVector<unsigned>> v(new TVector<unsigned>);
+ v->push_back(0);
+ v->push_back(1);
+ q.EnqueueAll(v);
+ }
+ {
+ TAutoPtr<TVector<unsigned>> v(new TVector<unsigned>);
+ v->push_back(2);
+ v->push_back(3);
+ v->push_back(4);
+ q.EnqueueAll(v);
+ }
+
+ TVector<unsigned> r;
+ q.DequeueAllSingleConsumer(&r);
+
+ UNIT_ASSERT_VALUES_EQUAL(5u, r.size());
+ for (unsigned i = 0; i < 5; ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(i, r[i]);
+ }
+
+ r.clear();
+ q.DequeueAllSingleConsumer(&r);
+ UNIT_ASSERT_VALUES_EQUAL(0u, r.size());
+ }
+}
diff --git a/library/cpp/messagebus/local_flags.cpp b/library/cpp/messagebus/local_flags.cpp
new file mode 100644
index 0000000000..877e533f76
--- /dev/null
+++ b/library/cpp/messagebus/local_flags.cpp
@@ -0,0 +1,32 @@
+#include "local_flags.h"
+
+#include <util/stream/str.h>
+#include <util/string/printf.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TString NBus::NPrivate::LocalFlagSetToString(ui32 flags0) {
+ if (flags0 == 0) {
+ return "0";
+ }
+
+ ui32 flags = flags0;
+
+ TStringStream ss;
+#define P(name, value, ...) \
+ do \
+ if (flags & value) { \
+ if (!ss.Str().empty()) { \
+ ss << "|"; \
+ } \
+ ss << #name; \
+ flags &= ~name; \
+ } \
+ while (false);
+ MESSAGE_LOCAL_FLAGS_MAP(P)
+ if (flags != 0) {
+ return Sprintf("0x%x", unsigned(flags0));
+ }
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/local_flags.h b/library/cpp/messagebus/local_flags.h
new file mode 100644
index 0000000000..f589283188
--- /dev/null
+++ b/library/cpp/messagebus/local_flags.h
@@ -0,0 +1,26 @@
+#pragma once
+
+#include <library/cpp/deprecated/enum_codegen/enum_codegen.h>
+
+#include <util/generic/string.h>
+#include <util/stream/output.h>
+
+namespace NBus {
+ namespace NPrivate {
+#define MESSAGE_LOCAL_FLAGS_MAP(XX) \
+ XX(MESSAGE_REPLY_INTERNAL, 0x0001) \
+ XX(MESSAGE_IN_WORK, 0x0002) \
+ XX(MESSAGE_IN_FLIGHT_ON_CLIENT, 0x0004) \
+ XX(MESSAGE_REPLY_IS_BEGING_SENT, 0x0008) \
+ XX(MESSAGE_ONE_WAY_INTERNAL, 0x0010) \
+ /**/
+
+ enum EMessageLocalFlags {
+ MESSAGE_LOCAL_FLAGS_MAP(ENUM_VALUE_GEN)
+ };
+
+ ENUM_TO_STRING(EMessageLocalFlags, MESSAGE_LOCAL_FLAGS_MAP)
+
+ TString LocalFlagSetToString(ui32);
+ }
+}
diff --git a/library/cpp/messagebus/local_flags_ut.cpp b/library/cpp/messagebus/local_flags_ut.cpp
new file mode 100644
index 0000000000..189d73eb0f
--- /dev/null
+++ b/library/cpp/messagebus/local_flags_ut.cpp
@@ -0,0 +1,18 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "local_flags.h"
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+Y_UNIT_TEST_SUITE(EMessageLocalFlags) {
+ Y_UNIT_TEST(TestLocalFlagSetToString) {
+ UNIT_ASSERT_VALUES_EQUAL("0", LocalFlagSetToString(0));
+ UNIT_ASSERT_VALUES_EQUAL("MESSAGE_REPLY_INTERNAL",
+ LocalFlagSetToString(MESSAGE_REPLY_INTERNAL));
+ UNIT_ASSERT_VALUES_EQUAL("MESSAGE_IN_WORK|MESSAGE_IN_FLIGHT_ON_CLIENT",
+ LocalFlagSetToString(MESSAGE_IN_WORK | MESSAGE_IN_FLIGHT_ON_CLIENT));
+ UNIT_ASSERT_VALUES_EQUAL("0xff3456",
+ LocalFlagSetToString(0xff3456));
+ }
+}
diff --git a/library/cpp/messagebus/local_tasks.h b/library/cpp/messagebus/local_tasks.h
new file mode 100644
index 0000000000..d8e801a457
--- /dev/null
+++ b/library/cpp/messagebus/local_tasks.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <util/system/atomic.h>
+
+class TLocalTasks {
+private:
+ TAtomic GotTasks;
+
+public:
+ TLocalTasks()
+ : GotTasks(0)
+ {
+ }
+
+ void AddTask() {
+ AtomicSet(GotTasks, 1);
+ }
+
+ bool FetchTask() {
+ bool gotTasks = AtomicCas(&GotTasks, 0, 1);
+ return gotTasks;
+ }
+};
diff --git a/library/cpp/messagebus/locator.cpp b/library/cpp/messagebus/locator.cpp
new file mode 100644
index 0000000000..e38a35c426
--- /dev/null
+++ b/library/cpp/messagebus/locator.cpp
@@ -0,0 +1,427 @@
+////////////////////////////////////////////////////////////////////////////
+/// \file
+/// \brief Implementation of locator service
+
+#include "locator.h"
+
+#include "ybus.h"
+
+#include <util/generic/hash_set.h>
+#include <util/system/hostname.h>
+
+namespace NBus {
+ using namespace NAddr;
+
+ static TIpPort GetAddrPort(const IRemoteAddr& addr) {
+ switch (addr.Addr()->sa_family) {
+ case AF_INET: {
+ return ntohs(((const sockaddr_in*)addr.Addr())->sin_port);
+ }
+
+ case AF_INET6: {
+ return ntohs(((const sockaddr_in6*)addr.Addr())->sin6_port);
+ }
+
+ default: {
+ ythrow yexception() << "not implemented";
+ break;
+ }
+ }
+ }
+
+ static inline bool GetIp6AddressFromVector(const TVector<TNetAddr>& addrs, TNetAddr* addr) {
+ for (size_t i = 1; i < addrs.size(); ++i) {
+ if (addrs[i - 1].Addr()->sa_family == addrs[i].Addr()->sa_family) {
+ return false;
+ }
+
+ if (GetAddrPort(addrs[i - 1]) != GetAddrPort(addrs[i])) {
+ return false;
+ }
+ }
+
+ for (size_t i = 0; i < addrs.size(); ++i) {
+ if (addrs[i].Addr()->sa_family == AF_INET6) {
+ *addr = addrs[i];
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ EMessageStatus TBusProtocol::GetDestination(const TBusClientSession*, TBusMessage* mess, TBusLocator* locator, TNetAddr* addr) {
+ TBusService service = GetService();
+ TBusKey key = GetKey(mess);
+ TVector<TNetAddr> addrs;
+
+ /// check for special local key
+ if (key == YBUS_KEYLOCAL) {
+ locator->GetLocalAddresses(service, addrs);
+ } else {
+ /// lookup address/port in the locator table
+ locator->LocateAll(service, key, addrs);
+ }
+
+ if (addrs.size() == 0) {
+ return MESSAGE_SERVICE_UNKNOWN;
+ } else if (addrs.size() == 1) {
+ *addr = addrs[0];
+ } else {
+ if (!GetIp6AddressFromVector(addrs, addr)) {
+ /// default policy can't make choice for you here, overide GetDestination() function
+ /// to implement custom routing strategy for your service.
+ return MESSAGE_SERVICE_TOOMANY;
+ }
+ }
+
+ return MESSAGE_OK;
+ }
+
+ static const sockaddr_in* SockAddrIpV4(const IRemoteAddr& a) {
+ return (const sockaddr_in*)a.Addr();
+ }
+
+ static const sockaddr_in6* SockAddrIpV6(const IRemoteAddr& a) {
+ return (const sockaddr_in6*)a.Addr();
+ }
+
+ static bool IsAddressEqual(const IRemoteAddr& a1, const IRemoteAddr& a2) {
+ if (a1.Addr()->sa_family == a2.Addr()->sa_family) {
+ if (a1.Addr()->sa_family == AF_INET) {
+ return memcmp(&SockAddrIpV4(a1)->sin_addr, &SockAddrIpV4(a2)->sin_addr, sizeof(in_addr)) == 0;
+ } else {
+ return memcmp(&SockAddrIpV6(a1)->sin6_addr, &SockAddrIpV6(a2)->sin6_addr, sizeof(in6_addr)) == 0;
+ }
+ }
+ return false;
+ }
+
+ TBusLocator::TBusLocator()
+ : MyInterfaces(GetNetworkInterfaces())
+ {
+ }
+
+ bool TBusLocator::TItem::operator<(const TItem& y) const {
+ const TItem& x = *this;
+
+ if (x.ServiceId == y.ServiceId) {
+ return (x.End < y.End) || ((x.End == y.End) && CompareByHost(x.Addr, y.Addr) < 0);
+ }
+ return x.ServiceId < y.ServiceId;
+ }
+
+ bool TBusLocator::TItem::operator==(const TItem& y) const {
+ return ServiceId == y.ServiceId && Start == y.Start && End == y.End && Addr == y.Addr;
+ }
+
+ TBusLocator::TItem::TItem(TServiceId serviceId, TBusKey start, TBusKey end, const TNetAddr& addr)
+ : ServiceId(serviceId)
+ , Start(start)
+ , End(end)
+ , Addr(addr)
+ {
+ }
+
+ bool TBusLocator::IsLocal(const TNetAddr& addr) {
+ for (const auto& myInterface : MyInterfaces) {
+ if (IsAddressEqual(addr, *myInterface.Address)) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ TBusLocator::TServiceId TBusLocator::GetServiceId(const char* name) {
+ const char* c = ServiceIdSet.insert(name).first->c_str();
+ return (ui64)c;
+ }
+
+ int TBusLocator::RegisterBreak(TBusService service, const TVector<TBusKey>& starts, const TNetAddr& addr) {
+ TGuard<TMutex> G(Lock);
+
+ TServiceId serviceId = GetServiceId(service);
+ for (size_t i = 0; i < starts.size(); ++i) {
+ RegisterBreak(serviceId, starts[i], addr);
+ }
+ return 0;
+ }
+
+ int TBusLocator::RegisterBreak(TServiceId serviceId, const TBusKey start, const TNetAddr& addr) {
+ TItems::const_iterator it = Items.lower_bound(TItem(serviceId, 0, start, addr));
+ TItems::const_iterator service_it =
+ Items.lower_bound(TItem(serviceId, 0, 0, TNetAddr()));
+
+ THolder<TItem> left;
+ THolder<TItem> right;
+ if ((it != Items.end() || Items.begin() != Items.end()) && service_it != Items.end() && service_it->ServiceId == serviceId) {
+ if (it == Items.end()) {
+ --it;
+ }
+ const TItem& item = *it;
+ left.Reset(new TItem(serviceId, item.Start,
+ Max<TBusKey>(item.Start, start - 1), item.Addr));
+ right.Reset(new TItem(serviceId, start, item.End, addr));
+ Items.erase(*it);
+ } else {
+ left.Reset(new TItem(serviceId, YBUS_KEYMIN, start, addr));
+ if (start < YBUS_KEYMAX) {
+ right.Reset(new TItem(serviceId, start + 1, YBUS_KEYMAX, addr));
+ }
+ }
+ Items.insert(*left);
+ Items.insert(*right);
+ NormalizeBreaks(serviceId);
+ return 0;
+ }
+
+ int TBusLocator::UnregisterBreak(TBusService service, const TNetAddr& addr) {
+ TGuard<TMutex> G(Lock);
+
+ TServiceId serviceId = GetServiceId(service);
+ return UnregisterBreak(serviceId, addr);
+ }
+
+ int TBusLocator::UnregisterBreak(TServiceId serviceId, const TNetAddr& addr) {
+ int deleted = 0;
+ TItems::iterator it = Items.begin();
+ while (it != Items.end()) {
+ const TItem& item = *it;
+ if (item.ServiceId != serviceId) {
+ ++it;
+ continue;
+ }
+ TItems::iterator itErase = it++;
+ if (item.ServiceId == serviceId && item.Addr == addr) {
+ Items.erase(itErase);
+ deleted += 1;
+ }
+ }
+
+ if (Items.begin() == Items.end()) {
+ return deleted;
+ }
+ TBusKey keyItem = YBUS_KEYMAX;
+ it = Items.end();
+ TItems::iterator first = it;
+ do {
+ --it;
+ // item.Start is not used in set comparison function
+ // so you can't violate set sort order by changing it
+ // hence const_cast()
+ TItem& item = const_cast<TItem&>(*it);
+ if (item.ServiceId != serviceId) {
+ continue;
+ }
+ first = it;
+ if (item.End < keyItem) {
+ item.End = keyItem;
+ }
+ keyItem = item.Start - 1;
+ } while (it != Items.begin());
+
+ if (first != Items.end() && first->Start != 0) {
+ TItem item(serviceId, YBUS_KEYMIN, first->Start - 1, first->Addr);
+ Items.insert(item);
+ }
+
+ NormalizeBreaks(serviceId);
+ return deleted;
+ }
+
+ void TBusLocator::NormalizeBreaks(TServiceId serviceId) {
+ TItems::const_iterator first = Items.lower_bound(TItem(serviceId, YBUS_KEYMIN, YBUS_KEYMIN, TNetAddr()));
+ TItems::const_iterator last = Items.end();
+
+ if ((Items.end() != first) && (first->ServiceId == serviceId)) {
+ if (serviceId != Max<TServiceId>()) {
+ last = Items.lower_bound(TItem(serviceId + 1, YBUS_KEYMIN, YBUS_KEYMIN, TNetAddr()));
+ }
+
+ --last;
+ Y_ASSERT(Items.end() != last);
+ Y_ASSERT(last->ServiceId == serviceId);
+
+ TItem& beg = const_cast<TItem&>(*first);
+ beg.Addr = last->Addr;
+ }
+ }
+
+ int TBusLocator::LocateAll(TBusService service, TBusKey key, TVector<TNetAddr>& addrs) {
+ TGuard<TMutex> G(Lock);
+ Y_VERIFY(addrs.empty(), "Non emtpy addresses");
+
+ TServiceId serviceId = GetServiceId(service);
+ TItems::const_iterator it;
+
+ for (it = Items.lower_bound(TItem(serviceId, 0, key, TNetAddr()));
+ it != Items.end() && it->ServiceId == serviceId && it->Start <= key && key <= it->End;
+ ++it) {
+ const TItem& item = *it;
+ addrs.push_back(item.Addr);
+ }
+
+ if (addrs.size() == 0) {
+ return -1;
+ }
+ return (int)addrs.size();
+ }
+
+ int TBusLocator::Locate(TBusService service, TBusKey key, TNetAddr* addr) {
+ TGuard<TMutex> G(Lock);
+
+ TServiceId serviceId = GetServiceId(service);
+ TItems::const_iterator it;
+
+ it = Items.lower_bound(TItem(serviceId, 0, key, TNetAddr()));
+
+ if (it != Items.end()) {
+ const TItem& item = *it;
+ if (item.ServiceId == serviceId && item.Start <= key && key < item.End) {
+ *addr = item.Addr;
+
+ return 0;
+ }
+ }
+
+ return -1;
+ }
+
+ int TBusLocator::GetLocalPort(TBusService service) {
+ TGuard<TMutex> G(Lock);
+ TServiceId serviceId = GetServiceId(service);
+ TItems::const_iterator it;
+ int port = 0;
+
+ for (it = Items.lower_bound(TItem(serviceId, 0, 0, TNetAddr())); it != Items.end(); ++it) {
+ const TItem& item = *it;
+ if (item.ServiceId != serviceId) {
+ break;
+ }
+
+ if (IsLocal(item.Addr)) {
+ if (port != 0 && port != GetAddrPort(item.Addr)) {
+ Y_ASSERT(0 && "Can't decide which port to use.");
+ return 0;
+ }
+ port = GetAddrPort(item.Addr);
+ }
+ }
+
+ return port;
+ }
+
+ int TBusLocator::GetLocalAddresses(TBusService service, TVector<TNetAddr>& addrs) {
+ TGuard<TMutex> G(Lock);
+ TServiceId serviceId = GetServiceId(service);
+ TItems::const_iterator it;
+
+ for (it = Items.lower_bound(TItem(serviceId, 0, 0, TNetAddr())); it != Items.end(); ++it) {
+ const TItem& item = *it;
+ if (item.ServiceId != serviceId) {
+ break;
+ }
+
+ if (IsLocal(item.Addr)) {
+ addrs.push_back(item.Addr);
+ }
+ }
+
+ if (addrs.size() == 0) {
+ return -1;
+ }
+
+ return (int)addrs.size();
+ }
+
+ int TBusLocator::LocateHost(TBusService service, TBusKey key, TString* host, int* port, bool* isLocal) {
+ int ret;
+ TNetAddr addr;
+ ret = Locate(service, key, &addr);
+ if (ret != 0) {
+ return ret;
+ }
+
+ {
+ TGuard<TMutex> G(Lock);
+ THostAddrMap::const_iterator it = HostAddrMap.find(addr);
+ if (it == HostAddrMap.end()) {
+ return -1;
+ }
+ *host = it->second;
+ }
+
+ *port = GetAddrPort(addr);
+ if (isLocal != nullptr) {
+ *isLocal = IsLocal(addr);
+ }
+ return 0;
+ }
+
+ int TBusLocator::LocateKeys(TBusService service, TBusKeyVec& keys, bool onlyLocal) {
+ TGuard<TMutex> G(Lock);
+ Y_VERIFY(keys.empty(), "Non empty keys");
+
+ TServiceId serviceId = GetServiceId(service);
+ TItems::const_iterator it;
+ for (it = Items.begin(); it != Items.end(); ++it) {
+ const TItem& item = *it;
+ if (item.ServiceId != serviceId) {
+ continue;
+ }
+ if (onlyLocal && !IsLocal(item.Addr)) {
+ continue;
+ }
+ keys.push_back(std::make_pair(item.Start, item.End));
+ }
+ return (int)keys.size();
+ }
+
+ int TBusLocator::Register(TBusService service, const char* hostName, int port, TBusKey start /*= YBUS_KEYMIN*/, TBusKey end /*= YBUS_KEYMAX*/, EIpVersion requireVersion /*= EIP_VERSION_4*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) {
+ TNetAddr addr(hostName, port, requireVersion, preferVersion); // throws
+ {
+ TGuard<TMutex> G(Lock);
+ HostAddrMap[addr] = hostName;
+ }
+ Register(service, start, end, addr);
+ return 0;
+ }
+
+ int TBusLocator::Register(TBusService service, TBusKey start, TBusKey end, const TNetworkAddress& na, EIpVersion requireVersion /*= EIP_VERSION_4*/, EIpVersion preferVersion /*= EIP_VERSION_ANY*/) {
+ TNetAddr addr(na, requireVersion, preferVersion); // throws
+ Register(service, start, end, addr);
+ return 0;
+ }
+
+ int TBusLocator::Register(TBusService service, TBusKey start, TBusKey end, const TNetAddr& addr) {
+ TGuard<TMutex> G(Lock);
+
+ TServiceId serviceId = GetServiceId(service);
+ TItems::const_iterator it;
+
+ TItem itemToReg(serviceId, start, end, addr);
+ for (it = Items.lower_bound(TItem(serviceId, 0, start, TNetAddr()));
+ it != Items.end() && it->ServiceId == serviceId;
+ ++it) {
+ const TItem& item = *it;
+ if (item == itemToReg) {
+ return 0;
+ }
+ if ((item.Start < start && start < item.End) || (item.Start < end && end < item.End)) {
+ Y_FAIL("Overlap in registered keys with non-identical range");
+ }
+ }
+
+ Items.insert(itemToReg);
+ return 0;
+ }
+
+ int TBusLocator::Unregister(TBusService service, TBusKey start, TBusKey end) {
+ TGuard<TMutex> G(Lock);
+ TServiceId serviceId = GetServiceId(service);
+ Items.erase(TItem(serviceId, start, end, TNetAddr()));
+ return 0;
+ }
+
+}
diff --git a/library/cpp/messagebus/locator.h b/library/cpp/messagebus/locator.h
new file mode 100644
index 0000000000..f8556a3fce
--- /dev/null
+++ b/library/cpp/messagebus/locator.h
@@ -0,0 +1,93 @@
+#pragma once
+
+#include "defs.h"
+
+#include <util/generic/hash.h>
+#include <util/generic/map.h>
+#include <util/generic/set.h>
+#include <util/generic/string.h>
+#include <util/network/interface.h>
+#include <util/system/mutex.h>
+
+namespace NBus {
+ ///////////////////////////////////////////////
+ /// \brief Client interface to locator service
+
+ /// This interface abstracts clustering/location service that
+ /// allows clients find servers (address, port) using "name" and "key".
+ /// The instance lives in TBusMessageQueue-object, but can be shared by different queues.
+ class TBusLocator: public TAtomicRefCount<TBusLocator>, public TNonCopyable {
+ private:
+ typedef ui64 TServiceId;
+ typedef TSet<TString> TServiceIdSet;
+ TServiceIdSet ServiceIdSet;
+ TServiceId GetServiceId(const char* name);
+
+ typedef TMap<TNetAddr, TString> THostAddrMap;
+ THostAddrMap HostAddrMap;
+
+ TNetworkInterfaceList MyInterfaces;
+
+ struct TItem {
+ TServiceId ServiceId;
+ TBusKey Start;
+ TBusKey End;
+ TNetAddr Addr;
+
+ bool operator<(const TItem& y) const;
+
+ bool operator==(const TItem& y) const;
+
+ TItem(TServiceId serviceId, TBusKey start, TBusKey end, const TNetAddr& addr);
+ };
+
+ typedef TMultiSet<TItem> TItems;
+ TItems Items;
+ TMutex Lock;
+
+ int RegisterBreak(TServiceId serviceId, const TBusKey start, const TNetAddr& addr);
+ int UnregisterBreak(TServiceId serviceId, const TNetAddr& addr);
+
+ void NormalizeBreaks(TServiceId serviceId);
+
+ private:
+ int Register(TBusService service, TBusKey start, TBusKey end, const TNetAddr& addr);
+
+ public:
+ /// creates instance that obtains location table from locator server (not implemented)
+ TBusLocator();
+
+ /// returns true if this address is on the same node for YBUS_KEYLOCAL
+ bool IsLocal(const TNetAddr& addr);
+
+ /// returns first address for service and key
+ int Locate(TBusService service, TBusKey key, TNetAddr* addr);
+
+ /// returns all addresses mathing service and key
+ int LocateAll(TBusService service, TBusKey key, TVector<TNetAddr>& addrs);
+
+ /// returns actual host name for service and key
+ int LocateHost(TBusService service, TBusKey key, TString* host, int* port, bool* isLocal = nullptr);
+
+ /// returns all key ranges for the given service
+ int LocateKeys(TBusService service, TBusKeyVec& keys, bool onlyLocal = false);
+
+ /// returns port on the local node for the service
+ int GetLocalPort(TBusService service);
+
+ /// returns addresses of the local node for the service
+ int GetLocalAddresses(TBusService service, TVector<TNetAddr>& addrs);
+
+ /// register service instance
+ int Register(TBusService service, TBusKey start, TBusKey end, const TNetworkAddress& addr, EIpVersion requireVersion = EIP_VERSION_4, EIpVersion preferVersion = EIP_VERSION_ANY);
+ /// @throws yexception
+ int Register(TBusService service, const char* host, int port, TBusKey start = YBUS_KEYMIN, TBusKey end = YBUS_KEYMAX, EIpVersion requireVersion = EIP_VERSION_4, EIpVersion preferVersion = EIP_VERSION_ANY);
+
+ /// unregister service instance
+ int Unregister(TBusService service, TBusKey start, TBusKey end);
+
+ int RegisterBreak(TBusService service, const TVector<TBusKey>& starts, const TNetAddr& addr);
+ int UnregisterBreak(TBusService service, const TNetAddr& addr);
+ };
+
+}
diff --git a/library/cpp/messagebus/mb_lwtrace.cpp b/library/cpp/messagebus/mb_lwtrace.cpp
new file mode 100644
index 0000000000..c54cd5ab71
--- /dev/null
+++ b/library/cpp/messagebus/mb_lwtrace.cpp
@@ -0,0 +1,12 @@
+#include "mb_lwtrace.h"
+
+#include <library/cpp/lwtrace/all.h>
+
+#include <util/generic/singleton.h>
+
+LWTRACE_DEFINE_PROVIDER(LWTRACE_MESSAGEBUS_PROVIDER)
+
+void NBus::InitBusLwtrace() {
+ // Function is nop, and needed only to make sure TBusLwtraceInit loaded.
+ // It won't be necessary when pg@ implements GLOBAL in arc.
+}
diff --git a/library/cpp/messagebus/mb_lwtrace.h b/library/cpp/messagebus/mb_lwtrace.h
new file mode 100644
index 0000000000..e62728b265
--- /dev/null
+++ b/library/cpp/messagebus/mb_lwtrace.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <library/cpp/lwtrace/all.h>
+
+#include <util/generic/string.h>
+
+#define LWTRACE_MESSAGEBUS_PROVIDER(PROBE, EVENT, GROUPS, TYPES, NAMES) \
+ PROBE(Error, GROUPS("MessagebusRare"), TYPES(TString, TString, TString), NAMES("status", "address", "misc")) \
+ PROBE(ServerUnknownVersion, GROUPS("MessagebusRare"), TYPES(TString, ui32), NAMES("address", "version")) \
+ PROBE(Accepted, GROUPS("MessagebusRare"), TYPES(TString), NAMES("address")) \
+ PROBE(Disconnected, GROUPS("MessagebusRare"), TYPES(TString), NAMES("address")) \
+ PROBE(Read, GROUPS(), TYPES(ui32), NAMES("size")) \
+ /**/
+
+LWTRACE_DECLARE_PROVIDER(LWTRACE_MESSAGEBUS_PROVIDER)
+
+namespace NBus {
+ void InitBusLwtrace();
+}
diff --git a/library/cpp/messagebus/memory.h b/library/cpp/messagebus/memory.h
new file mode 100644
index 0000000000..b2c0544491
--- /dev/null
+++ b/library/cpp/messagebus/memory.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#ifndef CACHE_LINE_SIZE
+#define CACHE_LINE_SIZE 64
+#endif
+
+#define CONCAT(a, b) a##b
+#define LABEL(a) CONCAT(UniqueName_, a)
+#define UNIQUE_NAME LABEL(__LINE__)
+
+#define CACHE_LINE_PADDING char UNIQUE_NAME[CACHE_LINE_SIZE];
+
+static inline void* MallocAligned(size_t size, size_t alignment) {
+ void** ptr = (void**)malloc(size + alignment + sizeof(size_t*));
+ if (!ptr) {
+ return nullptr;
+ }
+
+ size_t mask = ~(alignment - 1);
+ intptr_t roundedDown = intptr_t(ptr) & mask;
+ void** alignedPtr = (void**)(roundedDown + alignment);
+ alignedPtr[-1] = ptr;
+ return alignedPtr;
+}
+
+static inline void FreeAligned(void* ptr) {
+ if (!ptr) {
+ return;
+ }
+
+ void** typedPtr = (void**)ptr;
+ void* originalPtr = typedPtr[-1];
+ free(originalPtr);
+}
+
+static inline void* MallocCacheAligned(size_t size) {
+ return MallocAligned(size, CACHE_LINE_SIZE);
+}
+
+static inline void FreeCacheAligned(void* ptr) {
+ return FreeAligned(ptr);
+}
diff --git a/library/cpp/messagebus/memory_ut.cpp b/library/cpp/messagebus/memory_ut.cpp
new file mode 100644
index 0000000000..00654f28a1
--- /dev/null
+++ b/library/cpp/messagebus/memory_ut.cpp
@@ -0,0 +1,13 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "memory.h"
+
+Y_UNIT_TEST_SUITE(MallocAligned) {
+ Y_UNIT_TEST(Test) {
+ for (size_t size = 0; size < 1000; ++size) {
+ void* ptr = MallocAligned(size, 128);
+ UNIT_ASSERT(uintptr_t(ptr) % 128 == 0);
+ FreeAligned(ptr);
+ }
+ }
+}
diff --git a/library/cpp/messagebus/message.cpp b/library/cpp/messagebus/message.cpp
new file mode 100644
index 0000000000..bfa7ed8e9b
--- /dev/null
+++ b/library/cpp/messagebus/message.cpp
@@ -0,0 +1,198 @@
+#include "remote_server_connection.h"
+#include "ybus.h"
+
+#include <util/random/random.h>
+#include <util/string/printf.h>
+#include <util/system/atomic.h>
+
+#include <string.h>
+
+using namespace NBus;
+
+namespace NBus {
+ using namespace NBus::NPrivate;
+
+ TBusIdentity::TBusIdentity()
+ : MessageId(0)
+ , Size(0)
+ , Flags(0)
+ , LocalFlags(0)
+ {
+ }
+
+ TBusIdentity::~TBusIdentity() {
+ // TODO: print local flags
+#ifndef NDEBUG
+ Y_VERIFY(LocalFlags == 0, "local flags must be zero at this point; message type is %s",
+ MessageType.value_or("unknown").c_str());
+#else
+ Y_VERIFY(LocalFlags == 0, "local flags must be zero at this point");
+#endif
+ }
+
+ TNetAddr TBusIdentity::GetNetAddr() const {
+ if (!!Connection) {
+ return Connection->GetAddr();
+ } else {
+ Y_FAIL();
+ }
+ }
+
+ void TBusIdentity::Pack(char* dest) {
+ memcpy(dest, this, sizeof(TBusIdentity));
+ LocalFlags = 0;
+
+ // prevent decref
+ new (&Connection) TIntrusivePtr<TRemoteServerConnection>;
+ }
+
+ void TBusIdentity::Unpack(const char* src) {
+ Y_VERIFY(LocalFlags == 0);
+ Y_VERIFY(!Connection);
+
+ memcpy(this, src, sizeof(TBusIdentity));
+ }
+
+ void TBusHeader::GenerateId() {
+ for (;;) {
+ Id = RandomNumber<TBusKey>();
+ // Skip reserved ids
+ if (IsBusKeyValid(Id))
+ return;
+ }
+ }
+
+ TBusMessage::TBusMessage(ui16 type, int approxsize)
+ //: TCtr("BusMessage")
+ : TRefCounted<TBusMessage, TAtomicCounter, TDelete>(1)
+ , LocalFlags(0)
+ , RequestSize(0)
+ , Data(nullptr)
+ {
+ Y_UNUSED(approxsize);
+ GetHeader()->Type = type;
+ DoReset();
+ }
+
+ TBusMessage::TBusMessage(ECreateUninitialized)
+ //: TCtr("BusMessage")
+ : TRefCounted<TBusMessage, TAtomicCounter, TDelete>(1)
+ , LocalFlags(0)
+ , Data(nullptr)
+ {
+ }
+
+ TString TBusMessage::Describe() const {
+ return Sprintf("object type: %s, message type: %d", TypeName(*this).data(), int(GetHeader()->Type));
+ }
+
+ TBusMessage::~TBusMessage() {
+#ifndef NDEBUG
+ Y_VERIFY(GetHeader()->Id != YBUS_KEYINVALID, "must not be invalid key, message type: %d, ", int(Type));
+ GetHeader()->Id = YBUS_KEYINVALID;
+ Data = (void*)17;
+ CheckClean();
+#endif
+ }
+
+ void TBusMessage::DoReset() {
+ GetHeader()->SendTime = 0;
+ GetHeader()->Size = 0;
+ GetHeader()->FlagsInternal = 0;
+ GetHeader()->GenerateId();
+ GetHeader()->SetVersionInternal();
+ }
+
+ void TBusMessage::Reset() {
+ CheckClean();
+ DoReset();
+ }
+
+ void TBusMessage::CheckClean() const {
+ if (Y_UNLIKELY(LocalFlags != 0)) {
+ TString describe = Describe();
+ TString localFlags = LocalFlagSetToString(LocalFlags);
+ Y_FAIL("message local flags must be zero, got: %s, message: %s", localFlags.data(), describe.data());
+ }
+ }
+
+ ///////////////////////////////////////////////////////
+ /// \brief Unpacks header from network order
+
+ /// \todo ntoh instead of memcpy
+ int TBusHeader::ReadHeader(TArrayRef<const char> data) {
+ Y_ASSERT(data.size() >= sizeof(TBusHeader));
+ memcpy(this, data.data(), sizeof(TBusHeader));
+ return sizeof(TBusHeader);
+ }
+
+ ///////////////////////////////////////////////////////
+ /// \brief Packs header to network order
+
+ //////////////////////////////////////////////////////////
+ /// \brief serialize message identity to be used to construct reply message
+
+ /// function stores messageid, flags and connection reply address into the buffer
+ /// that can later be used to construct a reply to the message
+ void TBusMessage::GetIdentity(TBusIdentity& data) const {
+ data.MessageId = GetHeader()->Id;
+ data.Size = GetHeader()->Size;
+ data.Flags = GetHeader()->FlagsInternal;
+ //data.LocalFlags = LocalFlags;
+ }
+
+ ////////////////////////////////////////////////////////////
+ /// \brief set message identity from serialized form
+
+ /// function restores messageid, flags and connection reply address from the buffer
+ /// into the reply message
+ void TBusMessage::SetIdentity(const TBusIdentity& data) {
+ // TODO: wrong assertion: YBUS_KEYMIN is 0
+ Y_ASSERT(data.MessageId != 0);
+ bool compressed = IsCompressed();
+ GetHeader()->Id = data.MessageId;
+ GetHeader()->FlagsInternal = data.Flags;
+ LocalFlags = data.LocalFlags & ~MESSAGE_IN_WORK;
+ ReplyTo = data.Connection->PeerAddrSocketAddr;
+ SetCompressed(compressed || IsCompressedResponse());
+ }
+
+ void TBusMessage::SetCompressed(bool v) {
+ if (v) {
+ GetHeader()->FlagsInternal |= MESSAGE_COMPRESS_INTERNAL;
+ } else {
+ GetHeader()->FlagsInternal &= ~(MESSAGE_COMPRESS_INTERNAL);
+ }
+ }
+
+ void TBusMessage::SetCompressedResponse(bool v) {
+ if (v) {
+ GetHeader()->FlagsInternal |= MESSAGE_COMPRESS_RESPONSE;
+ } else {
+ GetHeader()->FlagsInternal &= ~(MESSAGE_COMPRESS_RESPONSE);
+ }
+ }
+
+ TString TBusIdentity::ToString() const {
+ TStringStream ss;
+ ss << "msg-id=" << MessageId
+ << " size=" << Size;
+ if (!!Connection) {
+ ss << " conn=" << Connection->GetAddr();
+ }
+ ss
+ << " flags=" << Flags
+ << " local-flags=" << LocalFlags
+#ifndef NDEBUG
+ << " msg-type= " << MessageType.value_or("unknown").c_str()
+#endif
+ ;
+ return ss.Str();
+ }
+
+}
+
+template <>
+void Out<TBusIdentity>(IOutputStream& os, TTypeTraits<TBusIdentity>::TFuncParam ident) {
+ os << ident.ToString();
+}
diff --git a/library/cpp/messagebus/message.h b/library/cpp/messagebus/message.h
new file mode 100644
index 0000000000..005ca10c65
--- /dev/null
+++ b/library/cpp/messagebus/message.h
@@ -0,0 +1,272 @@
+#pragma once
+
+#include "base.h"
+#include "local_flags.h"
+#include "message_status.h"
+#include "netaddr.h"
+#include "socket_addr.h"
+
+#include <util/generic/array_ref.h>
+#include <util/generic/noncopyable.h>
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/system/defaults.h>
+#include <util/system/type_name.h>
+#include <util/system/yassert.h>
+
+#include <optional>
+#include <typeinfo>
+
+namespace NBus {
+ ///////////////////////////////////////////////////////////////////
+ /// \brief Structure to preserve identity from message to reply
+ struct TBusIdentity : TNonCopyable {
+ friend class TBusMessage;
+ friend class NPrivate::TRemoteServerSession;
+ friend struct NPrivate::TClientRequestImpl;
+ friend class TOnMessageContext;
+
+ // TODO: make private
+ TBusKey MessageId;
+
+ private:
+ ui32 Size;
+ TIntrusivePtr<NPrivate::TRemoteServerConnection> Connection;
+ ui16 Flags;
+ ui32 LocalFlags;
+ TInstant RecvTime;
+
+#ifndef NDEBUG
+ std::optional<TString> MessageType;
+#endif
+
+ private:
+ // TODO: drop
+ TNetAddr GetNetAddr() const;
+
+ public:
+ void Pack(char* dest);
+ void Unpack(const char* src);
+
+ bool IsInWork() const {
+ return LocalFlags & NPrivate::MESSAGE_IN_WORK;
+ }
+
+ // for internal use only
+ void BeginWork() {
+ SetInWork(true);
+ }
+
+ // for internal use only
+ void EndWork() {
+ SetInWork(false);
+ }
+
+ TBusIdentity();
+ ~TBusIdentity();
+
+ void Swap(TBusIdentity& that) {
+ DoSwap(MessageId, that.MessageId);
+ DoSwap(Size, that.Size);
+ DoSwap(Connection, that.Connection);
+ DoSwap(Flags, that.Flags);
+ DoSwap(LocalFlags, that.LocalFlags);
+ DoSwap(RecvTime, that.RecvTime);
+#ifndef NDEBUG
+ DoSwap(MessageType, that.MessageType);
+#endif
+ }
+
+ TString ToString() const;
+
+ private:
+ void SetInWork(bool inWork) {
+ if (LocalFlags == 0 && inWork) {
+ LocalFlags = NPrivate::MESSAGE_IN_WORK;
+ } else if (LocalFlags == NPrivate::MESSAGE_IN_WORK && !inWork) {
+ LocalFlags = 0;
+ } else {
+ Y_FAIL("impossible combination of flag and parameter: %s %d",
+ inWork ? "true" : "false", unsigned(LocalFlags));
+ }
+ }
+
+ void SetMessageType(const std::type_info& messageTypeInfo) {
+#ifndef NDEBUG
+ Y_VERIFY(!MessageType, "state check");
+ MessageType = TypeName(messageTypeInfo);
+#else
+ Y_UNUSED(messageTypeInfo);
+#endif
+ }
+ };
+
+ static const size_t BUS_IDENTITY_PACKED_SIZE = sizeof(TBusIdentity);
+
+ ///////////////////////////////////////////////////////////////
+ /// \brief Message flags in TBusHeader.Flags
+ enum EMessageFlags {
+ MESSAGE_COMPRESS_INTERNAL = 0x8000, ///< message is compressed
+ MESSAGE_COMPRESS_RESPONSE = 0x4000, ///< message prefers compressed response
+ MESSAGE_VERSION_INTERNAL = 0x00F0, ///< these bits are used as version
+ };
+
+//////////////////////////////////////////////////////////
+/// \brief Message header present in all message send and received
+
+/// This header is send into the wire.
+/// \todo fix for low/high end, 32/64bit some day
+#pragma pack(1)
+ struct TBusHeader {
+ friend class TBusMessage;
+
+ TBusKey Id = 0; ///< unique message ID
+ ui32 Size = 0; ///< total size of the message
+ TBusInstant SendTime = 0; ///< time the message was sent
+ ui16 FlagsInternal = 0; ///< TRACE is one of the flags
+ ui16 Type = 0; ///< to be used by TBusProtocol
+
+ int GetVersionInternal() {
+ return (FlagsInternal & MESSAGE_VERSION_INTERNAL) >> 4;
+ }
+ void SetVersionInternal(unsigned ver = YBUS_VERSION) {
+ FlagsInternal |= (ver << 4);
+ }
+
+ public:
+ TBusHeader() {
+ }
+ TBusHeader(TArrayRef<const char> data) {
+ ReadHeader(data);
+ }
+
+ private:
+ /// function for serialization/deserialization of the header
+ /// returns number of bytes written/read
+ int ReadHeader(TArrayRef<const char> data);
+
+ void GenerateId();
+ };
+#pragma pack()
+
+#define TBUSMAX_MESSAGE 26 * 1024 * 1024 + sizeof(NBus::TBusHeader) ///< is't it enough?
+#define TBUSMIN_MESSAGE sizeof(NBus::TBusHeader) ///< can't be less then header
+
+ inline bool IsVersionNegotiation(const NBus::TBusHeader& header) {
+ return header.Id == 0 && header.Size == sizeof(TBusHeader);
+ }
+
+ //////////////////////////////////////////////////////////
+ /// \brief Base class for all messages passed in the system
+
+ enum ECreateUninitialized {
+ MESSAGE_CREATE_UNINITIALIZED,
+ };
+
+ class TBusMessage
+ : protected TBusHeader,
+ public TRefCounted<TBusMessage, TAtomicCounter, TDelete>,
+ private TNonCopyable {
+ friend class TLocalSession;
+ friend struct ::NBus::NPrivate::TBusSessionImpl;
+ friend class ::NBus::NPrivate::TRemoteServerSession;
+ friend class ::NBus::NPrivate::TRemoteClientSession;
+ friend class ::NBus::NPrivate::TRemoteConnection;
+ friend class ::NBus::NPrivate::TRemoteClientConnection;
+ friend class ::NBus::NPrivate::TRemoteServerConnection;
+ friend struct ::NBus::NPrivate::TBusMessagePtrAndHeader;
+
+ private:
+ ui32 LocalFlags;
+
+ /// connection identity for reply set by PushMessage()
+ NPrivate::TBusSocketAddr ReplyTo;
+ // server-side response only, hack
+ ui32 RequestSize;
+
+ TInstant RecvTime;
+
+ public:
+ /// constructor to create messages on sending end
+ TBusMessage(ui16 type, int approxsize = sizeof(TBusHeader));
+
+ /// constructor with serialzed data to examine the header
+ TBusMessage(ECreateUninitialized);
+
+ // slow, for diagnostics only
+ virtual TString Describe() const;
+
+ // must be called if this message object needs to be reused
+ void Reset();
+
+ void CheckClean() const;
+
+ void SetCompressed(bool);
+ void SetCompressedResponse(bool);
+
+ private:
+ bool IsCompressed() const {
+ return FlagsInternal & MESSAGE_COMPRESS_INTERNAL;
+ }
+ bool IsCompressedResponse() const {
+ return FlagsInternal & MESSAGE_COMPRESS_RESPONSE;
+ }
+
+ public:
+ /// can have private data to destroy
+ virtual ~TBusMessage();
+
+ /// returns header of the message
+ TBusHeader* GetHeader() {
+ return this;
+ }
+ const TBusHeader* GetHeader() const {
+ return this;
+ }
+
+ /// helper to return type for protocol object to unpack object
+ static ui16 GetType(TArrayRef<const char> data) {
+ return TBusHeader(data).Type;
+ }
+
+ /// returns payload data
+ static TArrayRef<const char> GetPayload(TArrayRef<const char> data) {
+ return data.Slice(sizeof(TBusHeader));
+ }
+
+ private:
+ void DoReset();
+
+ /// serialize message identity to be used to construct reply message
+ void GetIdentity(TBusIdentity& ident) const;
+
+ /// set message identity from serialized form
+ void SetIdentity(const TBusIdentity& ident);
+
+ public:
+ TNetAddr GetReplyTo() const {
+ return ReplyTo.ToNetAddr();
+ }
+
+ /// store of application specific data, never serialized into wire
+ void* Data;
+ };
+
+ class TBusMessageAutoPtr: public TAutoPtr<TBusMessage> {
+ public:
+ TBusMessageAutoPtr() {
+ }
+
+ TBusMessageAutoPtr(TBusMessage* message)
+ : TAutoPtr<TBusMessage>(message)
+ {
+ }
+
+ template <typename T1>
+ TBusMessageAutoPtr(const TAutoPtr<T1>& that)
+ : TAutoPtr<TBusMessage>(that.Release())
+ {
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/message_counter.cpp b/library/cpp/messagebus/message_counter.cpp
new file mode 100644
index 0000000000..04d9343f6a
--- /dev/null
+++ b/library/cpp/messagebus/message_counter.cpp
@@ -0,0 +1,46 @@
+#include "message_counter.h"
+
+#include <util/stream/str.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TMessageCounter::TMessageCounter()
+ : BytesData(0)
+ , BytesNetwork(0)
+ , Count(0)
+ , CountCompressed(0)
+ , CountCompressionRequests(0)
+{
+}
+
+TMessageCounter& TMessageCounter::operator+=(const TMessageCounter& that) {
+ BytesData += that.BytesData;
+ BytesNetwork += that.BytesNetwork;
+ Count += that.Count;
+ CountCompressed += that.CountCompressed;
+ CountCompressionRequests += that.CountCompressionRequests;
+ return *this;
+}
+
+TString TMessageCounter::ToString(bool reader) const {
+ if (reader) {
+ Y_ASSERT(CountCompressionRequests == 0);
+ }
+
+ TStringStream readValue;
+ readValue << Count;
+ if (CountCompressionRequests != 0 || CountCompressed != 0) {
+ readValue << " (" << CountCompressed << " compr";
+ if (!reader) {
+ readValue << ", " << CountCompressionRequests << " compr reqs";
+ }
+ readValue << ")";
+ }
+ readValue << ", ";
+ readValue << BytesData << "b";
+ if (BytesNetwork != BytesData) {
+ readValue << " (" << BytesNetwork << "b network)";
+ }
+ return readValue.Str();
+}
diff --git a/library/cpp/messagebus/message_counter.h b/library/cpp/messagebus/message_counter.h
new file mode 100644
index 0000000000..e4be1180b0
--- /dev/null
+++ b/library/cpp/messagebus/message_counter.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include <util/generic/string.h>
+
+#include <cstddef>
+
+namespace NBus {
+ namespace NPrivate {
+ struct TMessageCounter {
+ size_t BytesData;
+ size_t BytesNetwork;
+ size_t Count;
+ size_t CountCompressed;
+ size_t CountCompressionRequests; // reader only
+
+ void AddMessage(size_t bytesData, size_t bytesCompressed, bool Compressed, bool compressionRequested) {
+ BytesData += bytesData;
+ BytesNetwork += bytesCompressed;
+ Count += 1;
+ if (Compressed) {
+ CountCompressed += 1;
+ }
+ if (compressionRequested) {
+ CountCompressionRequests += 1;
+ }
+ }
+
+ TMessageCounter& operator+=(const TMessageCounter& that);
+
+ TString ToString(bool reader) const;
+
+ TMessageCounter();
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/message_ptr_and_header.h b/library/cpp/messagebus/message_ptr_and_header.h
new file mode 100644
index 0000000000..9b4e2fd270
--- /dev/null
+++ b/library/cpp/messagebus/message_ptr_and_header.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "message.h"
+#include "nondestroying_holder.h"
+
+#include <util/generic/noncopyable.h>
+#include <util/generic/utility.h>
+
+namespace NBus {
+ namespace NPrivate {
+ struct TBusMessagePtrAndHeader : TNonCopyable {
+ TNonDestroyingHolder<TBusMessage> MessagePtr;
+ TBusHeader Header;
+ ui32 LocalFlags;
+
+ TBusMessagePtrAndHeader()
+ : LocalFlags()
+ {
+ }
+
+ explicit TBusMessagePtrAndHeader(TBusMessage* messagePtr)
+ : MessagePtr(messagePtr)
+ , Header(*MessagePtr->GetHeader())
+ , LocalFlags(MessagePtr->LocalFlags)
+ {
+ }
+
+ void Swap(TBusMessagePtrAndHeader& that) {
+ DoSwap(MessagePtr, that.MessagePtr);
+ DoSwap(Header, that.Header);
+ DoSwap(LocalFlags, that.LocalFlags);
+ }
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/message_status.cpp b/library/cpp/messagebus/message_status.cpp
new file mode 100644
index 0000000000..41ad62b73f
--- /dev/null
+++ b/library/cpp/messagebus/message_status.cpp
@@ -0,0 +1,13 @@
+#include "message_status.h"
+
+using namespace NBus;
+
+const char* NBus::MessageStatusDescription(EMessageStatus messageStatus) {
+#define MESSAGE_STATUS_DESCRIPTION_GEN(name, description, ...) \
+ if (messageStatus == name) \
+ return description;
+
+ MESSAGE_STATUS_MAP(MESSAGE_STATUS_DESCRIPTION_GEN)
+
+ return "Unknown";
+}
diff --git a/library/cpp/messagebus/message_status.h b/library/cpp/messagebus/message_status.h
new file mode 100644
index 0000000000..e1878960b3
--- /dev/null
+++ b/library/cpp/messagebus/message_status.h
@@ -0,0 +1,57 @@
+#pragma once
+
+#include "codegen.h"
+#include "defs.h"
+
+#include <library/cpp/deprecated/enum_codegen/enum_codegen.h>
+
+namespace NBus {
+////////////////////////////////////////////////////////////////
+/// \brief Status of message communication
+
+#define MESSAGE_STATUS_MAP(XX) \
+ XX(MESSAGE_OK, "OK") \
+ XX(MESSAGE_CONNECT_FAILED, "Connect failed") \
+ XX(MESSAGE_TIMEOUT, "Message timed out") \
+ XX(MESSAGE_SERVICE_UNKNOWN, "Locator hasn't found address for key") \
+ XX(MESSAGE_BUSY, "Too many messages in flight") \
+ XX(MESSAGE_UNKNOWN, "Request not found by id, usually it means that message is timed out") \
+ XX(MESSAGE_DESERIALIZE_ERROR, "Deserialize by TBusProtocol failed") \
+ XX(MESSAGE_HEADER_CORRUPTED, "Header corrupted") \
+ XX(MESSAGE_DECOMPRESS_ERROR, "Failed to decompress") \
+ XX(MESSAGE_MESSAGE_TOO_LARGE, "Message too large") \
+ XX(MESSAGE_REPLY_FAILED, "Unused by messagebus, used by other code") \
+ XX(MESSAGE_DELIVERY_FAILED, "Message delivery failed because connection is closed") \
+ XX(MESSAGE_INVALID_VERSION, "Protocol error: invalid version") \
+ XX(MESSAGE_SERVICE_TOOMANY, "Locator failed to resolve address") \
+ XX(MESSAGE_SHUTDOWN, "Failure because of either session or connection shutdown") \
+ XX(MESSAGE_DONT_ASK, "Internal error code used by modules")
+
+ enum EMessageStatus {
+ MESSAGE_STATUS_MAP(ENUM_VALUE_GEN_NO_VALUE)
+ MESSAGE_STATUS_COUNT
+ };
+
+ ENUM_TO_STRING(EMessageStatus, MESSAGE_STATUS_MAP)
+
+ const char* MessageStatusDescription(EMessageStatus);
+
+ static inline const char* GetMessageStatus(EMessageStatus status) {
+ return ToCString(status);
+ }
+
+ // For lwtrace
+ struct TMessageStatusField {
+ typedef int TStoreType;
+ typedef int TFuncParam;
+
+ static void ToString(int value, TString* out) {
+ *out = GetMessageStatus((NBus::EMessageStatus)value);
+ }
+
+ static int ToStoreType(int value) {
+ return value;
+ }
+ };
+
+} // ns
diff --git a/library/cpp/messagebus/message_status_counter.cpp b/library/cpp/messagebus/message_status_counter.cpp
new file mode 100644
index 0000000000..891c8f5bb2
--- /dev/null
+++ b/library/cpp/messagebus/message_status_counter.cpp
@@ -0,0 +1,71 @@
+#include "message_status_counter.h"
+
+#include "key_value_printer.h"
+#include "text_utils.h"
+
+#include <library/cpp/messagebus/monitoring/mon_proto.pb.h>
+
+#include <util/stream/str.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TMessageStatusCounter::TMessageStatusCounter() {
+ Zero(Counts);
+}
+
+TMessageStatusCounter& TMessageStatusCounter::operator+=(const TMessageStatusCounter& that) {
+ for (size_t i = 0; i < MESSAGE_STATUS_COUNT; ++i) {
+ Counts[i] += that.Counts[i];
+ }
+ return *this;
+}
+
+TString TMessageStatusCounter::PrintToString() const {
+ TStringStream ss;
+ TKeyValuePrinter p;
+ bool hasNonZeros = false;
+ bool hasZeros = false;
+ for (size_t i = 0; i < MESSAGE_STATUS_COUNT; ++i) {
+ if (i == MESSAGE_OK) {
+ Y_VERIFY(Counts[i] == 0);
+ continue;
+ }
+ if (Counts[i] != 0) {
+ p.AddRow(EMessageStatus(i), Counts[i]);
+ const char* description = MessageStatusDescription(EMessageStatus(i));
+ // TODO: add third column
+ Y_UNUSED(description);
+
+ hasNonZeros = true;
+ } else {
+ hasZeros = true;
+ }
+ }
+ if (!hasNonZeros) {
+ ss << "message status counts are zeros\n";
+ } else {
+ if (hasZeros) {
+ ss << "message status counts are zeros, except:\n";
+ } else {
+ ss << "message status counts:\n";
+ }
+ ss << IndentText(p.PrintToString());
+ }
+ return ss.Str();
+}
+
+void TMessageStatusCounter::FillErrorsProtobuf(TConnectionStatusMonRecord* status) const {
+ status->clear_errorcountbystatus();
+ for (size_t i = 0; i < MESSAGE_STATUS_COUNT; ++i) {
+ if (i == MESSAGE_OK) {
+ Y_VERIFY(Counts[i] == 0);
+ continue;
+ }
+ if (Counts[i] != 0) {
+ TMessageStatusRecord* description = status->add_errorcountbystatus();
+ description->SetStatus(TMessageStatusCounter::MessageStatusToProtobuf((EMessageStatus)i));
+ description->SetCount(Counts[i]);
+ }
+ }
+}
diff --git a/library/cpp/messagebus/message_status_counter.h b/library/cpp/messagebus/message_status_counter.h
new file mode 100644
index 0000000000..e8ba2fdd31
--- /dev/null
+++ b/library/cpp/messagebus/message_status_counter.h
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "message_status.h"
+
+#include <library/cpp/messagebus/monitoring/mon_proto.pb.h>
+
+#include <util/generic/string.h>
+
+#include <array>
+
+namespace NBus {
+ namespace NPrivate {
+ struct TMessageStatusCounter {
+ static TMessageStatusRecord::EMessageStatus MessageStatusToProtobuf(EMessageStatus status) {
+ return (TMessageStatusRecord::EMessageStatus)status;
+ }
+
+ std::array<unsigned, MESSAGE_STATUS_COUNT> Counts;
+
+ unsigned& operator[](EMessageStatus index) {
+ return Counts[index];
+ }
+ const unsigned& operator[](EMessageStatus index) const {
+ return Counts[index];
+ }
+
+ TMessageStatusCounter();
+
+ TMessageStatusCounter& operator+=(const TMessageStatusCounter&);
+
+ TString PrintToString() const;
+ void FillErrorsProtobuf(TConnectionStatusMonRecord*) const;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/message_status_counter_ut.cpp b/library/cpp/messagebus/message_status_counter_ut.cpp
new file mode 100644
index 0000000000..9598651329
--- /dev/null
+++ b/library/cpp/messagebus/message_status_counter_ut.cpp
@@ -0,0 +1,23 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "message_status_counter.h"
+
+#include <library/cpp/messagebus/monitoring/mon_proto.pb.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+Y_UNIT_TEST_SUITE(MessageStatusCounter) {
+ Y_UNIT_TEST(MessageStatusConversion) {
+ const ::google::protobuf::EnumDescriptor* descriptor =
+ TMessageStatusRecord_EMessageStatus_descriptor();
+
+ for (int i = 0; i < MESSAGE_STATUS_COUNT; i++) {
+ const ::google::protobuf::EnumValueDescriptor* valueDescriptor =
+ descriptor->FindValueByName(ToString((EMessageStatus)i));
+ UNIT_ASSERT_UNEQUAL(valueDescriptor, nullptr);
+ UNIT_ASSERT_EQUAL(valueDescriptor->number(), i);
+ }
+ UNIT_ASSERT_EQUAL(MESSAGE_STATUS_COUNT, descriptor->value_count());
+ }
+}
diff --git a/library/cpp/messagebus/messqueue.cpp b/library/cpp/messagebus/messqueue.cpp
new file mode 100644
index 0000000000..3474d62705
--- /dev/null
+++ b/library/cpp/messagebus/messqueue.cpp
@@ -0,0 +1,198 @@
+#include "key_value_printer.h"
+#include "mb_lwtrace.h"
+#include "remote_client_session.h"
+#include "remote_server_session.h"
+#include "ybus.h"
+
+#include <util/generic/singleton.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+using namespace NActor;
+
+TBusMessageQueuePtr NBus::CreateMessageQueue(const TBusQueueConfig& config, TExecutorPtr executor, TBusLocator* locator, const char* name) {
+ return new TBusMessageQueue(config, executor, locator, name);
+}
+
+TBusMessageQueuePtr NBus::CreateMessageQueue(const TBusQueueConfig& config, TBusLocator* locator, const char* name) {
+ TExecutor::TConfig executorConfig;
+ executorConfig.WorkerCount = config.NumWorkers;
+ executorConfig.Name = name;
+ TExecutorPtr executor = new TExecutor(executorConfig);
+ return CreateMessageQueue(config, executor, locator, name);
+}
+
+TBusMessageQueuePtr NBus::CreateMessageQueue(const TBusQueueConfig& config, const char* name) {
+ return CreateMessageQueue(config, new TBusLocator, name);
+}
+
+TBusMessageQueuePtr NBus::CreateMessageQueue(TExecutorPtr executor, const char* name) {
+ return CreateMessageQueue(TBusQueueConfig(), executor, new TBusLocator, name);
+}
+
+TBusMessageQueuePtr NBus::CreateMessageQueue(const char* name) {
+ TBusQueueConfig config;
+ return CreateMessageQueue(config, name);
+}
+
+namespace {
+ TBusQueueConfig QueueConfigFillDefaults(const TBusQueueConfig& orig, const TString& name) {
+ TBusQueueConfig patched = orig;
+ if (!patched.Name) {
+ patched.Name = name;
+ }
+ return patched;
+ }
+}
+
+TBusMessageQueue::TBusMessageQueue(const TBusQueueConfig& config, TExecutorPtr executor, TBusLocator* locator, const char* name)
+ : Config(QueueConfigFillDefaults(config, name))
+ , Locator(locator)
+ , WorkQueue(executor)
+ , Running(1)
+{
+ InitBusLwtrace();
+ InitNetworkSubSystem();
+}
+
+TBusMessageQueue::~TBusMessageQueue() {
+ Stop();
+}
+
+void TBusMessageQueue::Stop() {
+ if (!AtomicCas(&Running, 0, 1)) {
+ ShutdownComplete.WaitI();
+ return;
+ }
+
+ Scheduler.Stop();
+
+ DestroyAllSessions();
+
+ WorkQueue->Stop();
+
+ ShutdownComplete.Signal();
+}
+
+bool TBusMessageQueue::IsRunning() {
+ return AtomicGet(Running);
+}
+
+TBusMessageQueueStatus TBusMessageQueue::GetStatusRecordInternal() const {
+ TBusMessageQueueStatus r;
+ r.ExecutorStatus = WorkQueue->GetStatusRecordInternal();
+ r.Config = Config;
+ return r;
+}
+
+TString TBusMessageQueue::GetStatusSelf() const {
+ return GetStatusRecordInternal().PrintToString();
+}
+
+TString TBusMessageQueue::GetStatusSingleLine() const {
+ return WorkQueue->GetStatusSingleLine();
+}
+
+TString TBusMessageQueue::GetStatus(ui16 flags) const {
+ TStringStream ss;
+
+ ss << GetStatusSelf();
+
+ TList<TIntrusivePtr<TBusSessionImpl>> sessions;
+ {
+ TGuard<TMutex> scope(Lock);
+ sessions = Sessions;
+ }
+
+ for (TList<TIntrusivePtr<TBusSessionImpl>>::const_iterator session = sessions.begin();
+ session != sessions.end(); ++session) {
+ ss << Endl;
+ ss << (*session)->GetStatus(flags);
+ }
+
+ ss << Endl;
+ ss << "object counts (not necessarily owned by this message queue):" << Endl;
+ TKeyValuePrinter p;
+ p.AddRow("TRemoteClientConnection", TObjectCounter<TRemoteClientConnection>::ObjectCount(), false);
+ p.AddRow("TRemoteServerConnection", TObjectCounter<TRemoteServerConnection>::ObjectCount(), false);
+ p.AddRow("TRemoteClientSession", TObjectCounter<TRemoteClientSession>::ObjectCount(), false);
+ p.AddRow("TRemoteServerSession", TObjectCounter<TRemoteServerSession>::ObjectCount(), false);
+ p.AddRow("NEventLoop::TEventLoop", TObjectCounter<NEventLoop::TEventLoop>::ObjectCount(), false);
+ p.AddRow("NEventLoop::TChannel", TObjectCounter<NEventLoop::TChannel>::ObjectCount(), false);
+ ss << p.PrintToString();
+
+ return ss.Str();
+}
+
+TBusClientSessionPtr TBusMessageQueue::CreateSource(TBusProtocol* proto, IBusClientHandler* handler, const TBusClientSessionConfig& config, const TString& name) {
+ TRemoteClientSessionPtr session(new TRemoteClientSession(this, proto, handler, config, name));
+ Add(session.Get());
+ return session.Get();
+}
+
+TBusServerSessionPtr TBusMessageQueue::CreateDestination(TBusProtocol* proto, IBusServerHandler* handler, const TBusClientSessionConfig& config, const TString& name) {
+ TRemoteServerSessionPtr session(new TRemoteServerSession(this, proto, handler, config, name));
+ try {
+ int port = config.ListenPort;
+ if (port == 0) {
+ port = Locator->GetLocalPort(proto->GetService());
+ }
+ if (port == 0) {
+ port = proto->GetPort();
+ }
+
+ session->Listen(port, this);
+
+ Add(session.Get());
+ return session.Release();
+ } catch (...) {
+ Y_FAIL("create destination failure: %s", CurrentExceptionMessage().c_str());
+ }
+}
+
+TBusServerSessionPtr TBusMessageQueue::CreateDestination(TBusProtocol* proto, IBusServerHandler* handler, const TBusServerSessionConfig& config, const TVector<TBindResult>& bindTo, const TString& name) {
+ TRemoteServerSessionPtr session(new TRemoteServerSession(this, proto, handler, config, name));
+ try {
+ session->Listen(bindTo, this);
+ Add(session.Get());
+ return session.Release();
+ } catch (...) {
+ Y_FAIL("create destination failure: %s", CurrentExceptionMessage().c_str());
+ }
+}
+
+void TBusMessageQueue::Add(TIntrusivePtr<TBusSessionImpl> session) {
+ TGuard<TMutex> scope(Lock);
+ Sessions.push_back(session);
+}
+
+void TBusMessageQueue::Remove(TBusSession* session) {
+ TGuard<TMutex> scope(Lock);
+ TList<TIntrusivePtr<TBusSessionImpl>>::iterator it = std::find(Sessions.begin(), Sessions.end(), session);
+ Y_VERIFY(it != Sessions.end(), "do not destroy session twice");
+ Sessions.erase(it);
+}
+
+void TBusMessageQueue::Destroy(TBusSession* session) {
+ session->Shutdown();
+}
+
+void TBusMessageQueue::DestroyAllSessions() {
+ TList<TIntrusivePtr<TBusSessionImpl>> sessions;
+ {
+ TGuard<TMutex> scope(Lock);
+ sessions = Sessions;
+ }
+
+ for (auto& session : sessions) {
+ Y_VERIFY(session->IsDown(), "Session must be shut down prior to queue shutdown");
+ }
+}
+
+void TBusMessageQueue::Schedule(IScheduleItemAutoPtr i) {
+ Scheduler.Schedule(i);
+}
+
+TString TBusMessageQueue::GetNameInternal() const {
+ return Config.Name;
+}
diff --git a/library/cpp/messagebus/misc/atomic_box.h b/library/cpp/messagebus/misc/atomic_box.h
new file mode 100644
index 0000000000..401621f933
--- /dev/null
+++ b/library/cpp/messagebus/misc/atomic_box.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <util/system/atomic.h>
+
+// TAtomic with human interface
+template <typename T>
+class TAtomicBox {
+private:
+ union {
+ TAtomic Value;
+ // when T is enum, it is convenient to inspect its content in gdb
+ T ValueForDebugger;
+ };
+
+ static_assert(sizeof(T) <= sizeof(TAtomic), "expect sizeof(T) <= sizeof(TAtomic)");
+
+public:
+ TAtomicBox(T value = T())
+ : Value(value)
+ {
+ }
+
+ void Set(T value) {
+ AtomicSet(Value, (TAtomic)value);
+ }
+
+ T Get() const {
+ return (T)AtomicGet(Value);
+ }
+
+ bool CompareAndSet(T expected, T set) {
+ return AtomicCas(&Value, (TAtomicBase)set, (TAtomicBase)expected);
+ }
+};
diff --git a/library/cpp/messagebus/misc/granup.h b/library/cpp/messagebus/misc/granup.h
new file mode 100644
index 0000000000..36ecfebc93
--- /dev/null
+++ b/library/cpp/messagebus/misc/granup.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include <util/datetime/base.h>
+#include <util/system/guard.h>
+#include <util/system/mutex.h>
+#include <util/system/spinlock.h>
+
+namespace NBus {
+ template <typename TItem, typename TLocker = TSpinLock>
+ class TGranUp {
+ public:
+ TGranUp(TDuration gran)
+ : Gran(gran)
+ , Next(TInstant::MicroSeconds(0))
+ {
+ }
+
+ template <typename TFunctor>
+ void Update(TFunctor functor, TInstant now, bool force = false) {
+ if (force || now > Next)
+ Set(functor(), now);
+ }
+
+ void Update(const TItem& item, TInstant now, bool force = false) {
+ if (force || now > Next)
+ Set(item, now);
+ }
+
+ TItem Get() const noexcept {
+ TGuard<TLocker> guard(Lock);
+
+ return Item;
+ }
+
+ protected:
+ void Set(const TItem& item, TInstant now) {
+ TGuard<TLocker> guard(Lock);
+
+ Item = item;
+
+ Next = now + Gran;
+ }
+
+ private:
+ const TDuration Gran;
+ TLocker Lock;
+ TItem Item;
+ TInstant Next;
+ };
+}
diff --git a/library/cpp/messagebus/misc/test_sync.h b/library/cpp/messagebus/misc/test_sync.h
new file mode 100644
index 0000000000..be3f4f20b8
--- /dev/null
+++ b/library/cpp/messagebus/misc/test_sync.h
@@ -0,0 +1,75 @@
+#pragma once
+
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+
+class TTestSync {
+private:
+ unsigned Current;
+
+ TMutex Mutex;
+ TCondVar CondVar;
+
+public:
+ TTestSync()
+ : Current(0)
+ {
+ }
+
+ void Inc() {
+ TGuard<TMutex> guard(Mutex);
+
+ DoInc();
+ CondVar.BroadCast();
+ }
+
+ unsigned Get() {
+ TGuard<TMutex> guard(Mutex);
+
+ return Current;
+ }
+
+ void WaitFor(unsigned n) {
+ TGuard<TMutex> guard(Mutex);
+
+ Y_VERIFY(Current <= n, "too late, waiting for %d, already %d", n, Current);
+
+ while (n > Current) {
+ CondVar.WaitI(Mutex);
+ }
+ }
+
+ void WaitForAndIncrement(unsigned n) {
+ TGuard<TMutex> guard(Mutex);
+
+ Y_VERIFY(Current <= n, "too late, waiting for %d, already %d", n, Current);
+
+ while (n > Current) {
+ CondVar.WaitI(Mutex);
+ }
+
+ DoInc();
+ CondVar.BroadCast();
+ }
+
+ void CheckAndIncrement(unsigned n) {
+ TGuard<TMutex> guard(Mutex);
+
+ Y_VERIFY(Current == n, "must be %d, currently %d", n, Current);
+
+ DoInc();
+ CondVar.BroadCast();
+ }
+
+ void Check(unsigned n) {
+ TGuard<TMutex> guard(Mutex);
+
+ Y_VERIFY(Current == n, "must be %d, currently %d", n, Current);
+ }
+
+private:
+ void DoInc() {
+ unsigned r = ++Current;
+ Y_UNUSED(r);
+ }
+};
diff --git a/library/cpp/messagebus/misc/tokenquota.h b/library/cpp/messagebus/misc/tokenquota.h
new file mode 100644
index 0000000000..190547fa54
--- /dev/null
+++ b/library/cpp/messagebus/misc/tokenquota.h
@@ -0,0 +1,83 @@
+#pragma once
+
+#include <util/system/atomic.h>
+
+namespace NBus {
+ /* Consumer and feeder quota model impl.
+
+ Consumer thread only calls:
+ Acquire(), fetches tokens for usage from bucket;
+ Consume(), eats given amount of tokens, must not
+ be greater than Value() items;
+
+ Other threads (feeders) calls:
+ Return(), put used tokens back to bucket;
+ */
+
+ class TTokenQuota {
+ public:
+ TTokenQuota(bool enabled, size_t tokens, size_t wake)
+ : Enabled(tokens > 0 ? enabled : false)
+ , Acquired(0)
+ , WakeLev(wake < 1 ? Max<size_t>(1, tokens / 2) : 0)
+ , Tokens_(tokens)
+ {
+ Y_UNUSED(padd_);
+ }
+
+ bool Acquire(TAtomic level = 1, bool force = false) {
+ level = Max(TAtomicBase(level), TAtomicBase(1));
+
+ if (Enabled && (Acquired < level || force)) {
+ Acquired += AtomicSwap(&Tokens_, 0);
+ }
+
+ return !Enabled || Acquired >= level;
+ }
+
+ void Consume(size_t items) {
+ if (Enabled) {
+ Y_ASSERT(Acquired >= TAtomicBase(items));
+
+ Acquired -= items;
+ }
+ }
+
+ bool Return(size_t items_) noexcept {
+ if (!Enabled || items_ == 0)
+ return false;
+
+ const TAtomic items = items_;
+ const TAtomic value = AtomicAdd(Tokens_, items);
+
+ return (value - items < WakeLev && value >= WakeLev);
+ }
+
+ bool IsEnabled() const noexcept {
+ return Enabled;
+ }
+
+ bool IsAboveWake() const noexcept {
+ return !Enabled || (WakeLev <= AtomicGet(Tokens_));
+ }
+
+ size_t Tokens() const noexcept {
+ return Acquired + AtomicGet(Tokens_);
+ }
+
+ size_t Check(const TAtomic level) const noexcept {
+ return !Enabled || level <= Acquired;
+ }
+
+ private:
+ bool Enabled;
+ TAtomicBase Acquired;
+ const TAtomicBase WakeLev;
+ TAtomic Tokens_;
+
+ /* This padd requires for align Tokens_ member on its own
+ CPU cacheline. */
+
+ ui64 padd_;
+ };
+}
diff --git a/library/cpp/messagebus/misc/weak_ptr.h b/library/cpp/messagebus/misc/weak_ptr.h
new file mode 100644
index 0000000000..70fdeb0e2a
--- /dev/null
+++ b/library/cpp/messagebus/misc/weak_ptr.h
@@ -0,0 +1,99 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+#include <util/system/mutex.h>
+
+template <typename T>
+struct TWeakPtr;
+
+template <typename TSelf>
+struct TWeakRefCounted {
+ template <typename>
+ friend struct TWeakPtr;
+
+private:
+ struct TRef: public TAtomicRefCount<TRef> {
+ TMutex Mutex;
+ TSelf* Outer;
+
+ TRef(TSelf* outer)
+ : Outer(outer)
+ {
+ }
+
+ void Release() {
+ TGuard<TMutex> g(Mutex);
+ Y_ASSERT(!!Outer);
+ Outer = nullptr;
+ }
+
+ TIntrusivePtr<TSelf> Get() {
+ TGuard<TMutex> g(Mutex);
+ Y_ASSERT(!Outer || Outer->RefCount() > 0);
+ return Outer;
+ }
+ };
+
+ TAtomicCounter Counter;
+ TIntrusivePtr<TRef> RefPtr;
+
+public:
+ TWeakRefCounted()
+ : RefPtr(new TRef(static_cast<TSelf*>(this)))
+ {
+ }
+
+ void Ref() {
+ Counter.Inc();
+ }
+
+ void UnRef() {
+ if (Counter.Dec() == 0) {
+ RefPtr->Release();
+
+ // drop is to prevent dtor from reading it
+ RefPtr.Drop();
+
+ delete static_cast<TSelf*>(this);
+ }
+ }
+
+ void DecRef() {
+ Counter.Dec();
+ }
+
+ unsigned RefCount() const {
+ return Counter.Val();
+ }
+};
+
+template <typename T>
+struct TWeakPtr {
+private:
+ typedef TIntrusivePtr<typename T::TRef> TRefPtr;
+ TRefPtr RefPtr;
+
+public:
+ TWeakPtr() {
+ }
+
+ TWeakPtr(T* t) {
+ if (!!t) {
+ RefPtr = t->RefPtr;
+ }
+ }
+
+ TWeakPtr(TIntrusivePtr<T> t) {
+ if (!!t) {
+ RefPtr = t->RefPtr;
+ }
+ }
+
+ TIntrusivePtr<T> Get() {
+ if (!RefPtr) {
+ return nullptr;
+ } else {
+ return RefPtr->Get();
+ }
+ }
+};
diff --git a/library/cpp/messagebus/misc/weak_ptr_ut.cpp b/library/cpp/messagebus/misc/weak_ptr_ut.cpp
new file mode 100644
index 0000000000..5a325278db
--- /dev/null
+++ b/library/cpp/messagebus/misc/weak_ptr_ut.cpp
@@ -0,0 +1,46 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "weak_ptr.h"
+
+Y_UNIT_TEST_SUITE(TWeakPtrTest) {
+ struct TWeakPtrTester: public TWeakRefCounted<TWeakPtrTester> {
+ int* const CounterPtr;
+
+ TWeakPtrTester(int* counterPtr)
+ : CounterPtr(counterPtr)
+ {
+ }
+ ~TWeakPtrTester() {
+ ++*CounterPtr;
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ int destroyCount = 0;
+
+ TIntrusivePtr<TWeakPtrTester> p(new TWeakPtrTester(&destroyCount));
+
+ UNIT_ASSERT(!!p);
+ UNIT_ASSERT_VALUES_EQUAL(1u, p->RefCount());
+
+ TWeakPtr<TWeakPtrTester> p2(p);
+
+ UNIT_ASSERT_VALUES_EQUAL(1u, p->RefCount());
+
+ {
+ TIntrusivePtr<TWeakPtrTester> p3 = p2.Get();
+ UNIT_ASSERT(!!p3);
+ UNIT_ASSERT_VALUES_EQUAL(2u, p->RefCount());
+ }
+
+ p.Drop();
+ UNIT_ASSERT_VALUES_EQUAL(1, destroyCount);
+
+ {
+ TIntrusivePtr<TWeakPtrTester> p3 = p2.Get();
+ UNIT_ASSERT(!p3);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(1, destroyCount);
+ }
+}
diff --git a/library/cpp/messagebus/monitoring/mon_proto.proto b/library/cpp/messagebus/monitoring/mon_proto.proto
new file mode 100644
index 0000000000..73b6614481
--- /dev/null
+++ b/library/cpp/messagebus/monitoring/mon_proto.proto
@@ -0,0 +1,55 @@
+import "library/cpp/monlib/encode/legacy_protobuf/protos/metric_meta.proto";
+
+package NBus;
+
+option java_package = "ru.yandex.messagebus.monitoring.proto";
+
+message TMessageStatusRecord {
+ enum EMessageStatus {
+ MESSAGE_OK = 0;
+ MESSAGE_CONNECT_FAILED = 1;
+ MESSAGE_TIMEOUT = 2;
+ MESSAGE_SERVICE_UNKNOWN = 3;
+ MESSAGE_BUSY = 4;
+ MESSAGE_UNKNOWN = 5;
+ MESSAGE_DESERIALIZE_ERROR = 6;
+ MESSAGE_HEADER_CORRUPTED = 7;
+ MESSAGE_DECOMPRESS_ERROR = 8;
+ MESSAGE_MESSAGE_TOO_LARGE = 9;
+ MESSAGE_REPLY_FAILED = 10;
+ MESSAGE_DELIVERY_FAILED = 11;
+ MESSAGE_INVALID_VERSION = 12;
+ MESSAGE_SERVICE_TOOMANY = 13;
+ MESSAGE_SHUTDOWN = 14;
+ MESSAGE_DONT_ASK = 15;
+ }
+
+ optional EMessageStatus Status = 1;
+ optional uint32 Count = 2;
+}
+
+message TConnectionStatusMonRecord {
+ optional uint32 SendQueueSize = 1 [ (NMonProto.Metric).Type = GAUGE ];
+ // client only
+ optional uint32 AckMessagesSize = 2 [ (NMonProto.Metric).Type = GAUGE ];
+ optional uint32 ErrorCount = 3 [ (NMonProto.Metric).Type = RATE ];
+
+ optional uint64 WriteBytes = 10 [ (NMonProto.Metric).Type = RATE ];
+ optional uint64 WriteBytesCompressed = 11;
+ optional uint64 WriteMessages = 12 [ (NMonProto.Metric).Type = RATE ];
+ optional uint64 WriteSyscalls = 13;
+ optional uint64 WriteActs = 14;
+ optional uint64 ReadBytes = 20 [ (NMonProto.Metric).Type = RATE ];
+ optional uint64 ReadBytesCompressed = 21;
+ optional uint64 ReadMessages = 22 [ (NMonProto.Metric).Type = RATE ];
+ optional uint64 ReadSyscalls = 23;
+ optional uint64 ReadActs = 24;
+
+ repeated TMessageStatusRecord ErrorCountByStatus = 25;
+}
+
+message TSessionStatusMonRecord {
+ optional uint32 InFlight = 1 [ (NMonProto.Metric).Type = GAUGE ];
+ optional uint32 ConnectionCount = 2 [ (NMonProto.Metric).Type = GAUGE ];
+ optional uint32 ConnectCount = 3 [ (NMonProto.Metric).Type = RATE ];
+}
diff --git a/library/cpp/messagebus/monitoring/ya.make b/library/cpp/messagebus/monitoring/ya.make
new file mode 100644
index 0000000000..25782492b1
--- /dev/null
+++ b/library/cpp/messagebus/monitoring/ya.make
@@ -0,0 +1,15 @@
+PROTO_LIBRARY()
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/monlib/encode/legacy_protobuf/protos
+)
+
+SRCS(
+ mon_proto.proto
+)
+
+EXCLUDE_TAGS(GO_PROTO)
+
+END()
diff --git a/library/cpp/messagebus/moved.h b/library/cpp/messagebus/moved.h
new file mode 100644
index 0000000000..ede8dcd244
--- /dev/null
+++ b/library/cpp/messagebus/moved.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <util/generic/utility.h>
+
+template <typename T>
+class TMoved {
+private:
+ mutable T Value;
+
+public:
+ TMoved() {
+ }
+ TMoved(const TMoved<T>& that) {
+ DoSwap(Value, that.Value);
+ }
+ TMoved(const T& that) {
+ DoSwap(Value, const_cast<T&>(that));
+ }
+
+ void swap(TMoved& that) {
+ DoSwap(Value, that.Value);
+ }
+
+ T& operator*() {
+ return Value;
+ }
+
+ const T& operator*() const {
+ return Value;
+ }
+
+ T* operator->() {
+ return &Value;
+ }
+
+ const T* operator->() const {
+ return &Value;
+ }
+};
diff --git a/library/cpp/messagebus/moved_ut.cpp b/library/cpp/messagebus/moved_ut.cpp
new file mode 100644
index 0000000000..c1a07cce7e
--- /dev/null
+++ b/library/cpp/messagebus/moved_ut.cpp
@@ -0,0 +1,22 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "moved.h"
+
+Y_UNIT_TEST_SUITE(TMovedTest) {
+ Y_UNIT_TEST(Simple) {
+ TMoved<THolder<int>> h1(MakeHolder<int>(10));
+ TMoved<THolder<int>> h2 = h1;
+ UNIT_ASSERT(!*h1);
+ UNIT_ASSERT(!!*h2);
+ UNIT_ASSERT_VALUES_EQUAL(10, **h2);
+ }
+
+ void Foo(TMoved<THolder<int>> h) {
+ UNIT_ASSERT_VALUES_EQUAL(11, **h);
+ }
+
+ Y_UNIT_TEST(PassToFunction) {
+ THolder<int> h(new int(11));
+ Foo(h);
+ }
+}
diff --git a/library/cpp/messagebus/netaddr.h b/library/cpp/messagebus/netaddr.h
new file mode 100644
index 0000000000..f915c8c574
--- /dev/null
+++ b/library/cpp/messagebus/netaddr.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include <library/cpp/messagebus/config/netaddr.h>
+
diff --git a/library/cpp/messagebus/netaddr_ut.cpp b/library/cpp/messagebus/netaddr_ut.cpp
new file mode 100644
index 0000000000..e5c68bf402
--- /dev/null
+++ b/library/cpp/messagebus/netaddr_ut.cpp
@@ -0,0 +1,21 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "netaddr.h"
+#include "test_utils.h"
+
+using namespace NBus;
+
+Y_UNIT_TEST_SUITE(TNetAddr) {
+ Y_UNIT_TEST(ResolveIpv4) {
+ ASSUME_IP_V4_ENABLED;
+ UNIT_ASSERT(TNetAddr("ns1.yandex.ru", 80, EIP_VERSION_4).IsIpv4());
+ }
+
+ Y_UNIT_TEST(ResolveIpv6) {
+ UNIT_ASSERT(TNetAddr("ns1.yandex.ru", 80, EIP_VERSION_6).IsIpv6());
+ }
+
+ Y_UNIT_TEST(ResolveAny) {
+ TNetAddr("ns1.yandex.ru", 80, EIP_VERSION_ANY);
+ }
+}
diff --git a/library/cpp/messagebus/network.cpp b/library/cpp/messagebus/network.cpp
new file mode 100644
index 0000000000..304bedae5a
--- /dev/null
+++ b/library/cpp/messagebus/network.cpp
@@ -0,0 +1,156 @@
+#include "network.h"
+
+#include <util/generic/maybe.h>
+#include <util/generic/ptr.h>
+#include <util/network/init.h>
+#include <util/network/socket.h>
+#include <util/system/platform.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+namespace {
+ TBindResult BindOnPortProto(int port, int af, bool reusePort) {
+ Y_VERIFY(af == AF_INET || af == AF_INET6, "wrong af");
+
+ SOCKET fd = ::socket(af, SOCK_STREAM, 0);
+ if (fd == INVALID_SOCKET) {
+ ythrow TSystemError() << "failed to create a socket";
+ }
+
+ int one = 1;
+ int r1 = SetSockOpt(fd, SOL_SOCKET, SO_REUSEADDR, one);
+ if (r1 < 0) {
+ ythrow TSystemError() << "failed to setsockopt SO_REUSEADDR";
+ }
+
+#ifdef SO_REUSEPORT
+ if (reusePort) {
+ int r = SetSockOpt(fd, SOL_SOCKET, SO_REUSEPORT, one);
+ if (r < 0) {
+ ythrow TSystemError() << "failed to setsockopt SO_REUSEPORT";
+ }
+ }
+#else
+ Y_UNUSED(reusePort);
+#endif
+
+ THolder<TOpaqueAddr> addr(new TOpaqueAddr);
+ sockaddr* sa = addr->MutableAddr();
+ sa->sa_family = af;
+ socklen_t len;
+ if (af == AF_INET) {
+ len = sizeof(sockaddr_in);
+ ((sockaddr_in*)sa)->sin_port = HostToInet((ui16)port);
+ ((sockaddr_in*)sa)->sin_addr.s_addr = INADDR_ANY;
+ } else {
+ len = sizeof(sockaddr_in6);
+ ((sockaddr_in6*)sa)->sin6_port = HostToInet((ui16)port);
+ }
+
+ if (af == AF_INET6) {
+ FixIPv6ListenSocket(fd);
+ }
+
+ int r2 = ::bind(fd, sa, len);
+ if (r2 < 0) {
+ ythrow TSystemError() << "failed to bind on port " << port;
+ }
+
+ int rsn = ::getsockname(fd, addr->MutableAddr(), addr->LenPtr());
+ if (rsn < 0) {
+ ythrow TSystemError() << "failed to getsockname";
+ }
+
+ int r3 = ::listen(fd, 50);
+ if (r3 < 0) {
+ ythrow TSystemError() << "listen failed";
+ }
+
+ TBindResult r;
+ r.Socket.Reset(new TSocketHolder(fd));
+ r.Addr = TNetAddr(addr.Release());
+ return r;
+ }
+
+ TMaybe<TBindResult> TryBindOnPortProto(int port, int af, bool reusePort) {
+ try {
+ return {BindOnPortProto(port, af, reusePort)};
+ } catch (const TSystemError&) {
+ return {};
+ }
+ }
+
+ std::pair<unsigned, TVector<TBindResult>> AggregateBindResults(TBindResult&& r1, TBindResult&& r2) {
+ Y_VERIFY(r1.Addr.GetPort() == r2.Addr.GetPort(), "internal");
+ std::pair<unsigned, TVector<TBindResult>> r;
+ r.second.reserve(2);
+
+ r.first = r1.Addr.GetPort();
+ r.second.emplace_back(std::move(r1));
+ r.second.emplace_back(std::move(r2));
+ return r;
+ }
+}
+
+std::pair<unsigned, TVector<TBindResult>> NBus::BindOnPort(int port, bool reusePort) {
+ std::pair<unsigned, TVector<TBindResult>> r;
+ r.second.reserve(2);
+
+ if (port != 0) {
+ return AggregateBindResults(BindOnPortProto(port, AF_INET, reusePort),
+ BindOnPortProto(port, AF_INET6, reusePort));
+ }
+
+ // use nothrow versions in cycle
+ for (int i = 0; i < 1000; ++i) {
+ TMaybe<TBindResult> in4 = TryBindOnPortProto(0, AF_INET, reusePort);
+ if (!in4) {
+ continue;
+ }
+
+ TMaybe<TBindResult> in6 = TryBindOnPortProto(in4->Addr.GetPort(), AF_INET6, reusePort);
+ if (!in6) {
+ continue;
+ }
+
+ return AggregateBindResults(std::move(*in4), std::move(*in6));
+ }
+
+ TBindResult in4 = BindOnPortProto(0, AF_INET, reusePort);
+ TBindResult in6 = BindOnPortProto(in4.Addr.GetPort(), AF_INET6, reusePort);
+ return AggregateBindResults(std::move(in4), std::move(in6));
+}
+
+void NBus::NPrivate::SetSockOptTcpCork(SOCKET s, bool value) {
+#ifdef _linux_
+ CheckedSetSockOpt(s, IPPROTO_TCP, TCP_CORK, (int)value, "TCP_CORK");
+#else
+ Y_UNUSED(s);
+ Y_UNUSED(value);
+#endif
+}
+
+ssize_t NBus::NPrivate::SocketSend(SOCKET s, TArrayRef<const char> data) {
+ int flags = 0;
+#if defined(_linux_) || defined(_freebsd_)
+ flags |= MSG_NOSIGNAL;
+#endif
+ ssize_t r = ::send(s, data.data(), data.size(), flags);
+ if (r < 0) {
+ Y_VERIFY(LastSystemError() != EBADF, "bad fd");
+ }
+ return r;
+}
+
+ssize_t NBus::NPrivate::SocketRecv(SOCKET s, TArrayRef<char> buffer) {
+ int flags = 0;
+#if defined(_linux_) || defined(_freebsd_)
+ flags |= MSG_NOSIGNAL;
+#endif
+ ssize_t r = ::recv(s, buffer.data(), buffer.size(), flags);
+ if (r < 0) {
+ Y_VERIFY(LastSystemError() != EBADF, "bad fd");
+ }
+ return r;
+}
diff --git a/library/cpp/messagebus/network.h b/library/cpp/messagebus/network.h
new file mode 100644
index 0000000000..cc4bd76ea3
--- /dev/null
+++ b/library/cpp/messagebus/network.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "netaddr.h"
+
+#include <util/generic/array_ref.h>
+#include <util/generic/ptr.h>
+#include <util/network/socket.h>
+
+#include <utility>
+
+namespace NBus {
+ namespace NPrivate {
+ void SetSockOptTcpCork(SOCKET s, bool value);
+
+ [[nodiscard]] ssize_t SocketSend(SOCKET s, TArrayRef<const char> data);
+
+ [[nodiscard]] ssize_t SocketRecv(SOCKET s, TArrayRef<char> buffer);
+
+ }
+
+ struct TBindResult {
+ TSimpleSharedPtr<TSocketHolder> Socket;
+ TNetAddr Addr;
+ };
+
+ std::pair<unsigned, TVector<TBindResult>> BindOnPort(int port, bool reusePort);
+
+}
diff --git a/library/cpp/messagebus/network_ut.cpp b/library/cpp/messagebus/network_ut.cpp
new file mode 100644
index 0000000000..f1798419db
--- /dev/null
+++ b/library/cpp/messagebus/network_ut.cpp
@@ -0,0 +1,65 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "network.h"
+
+#include <library/cpp/messagebus/test/helper/fixed_port.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+using namespace NBus::NTest;
+
+namespace {
+ int GetSockPort(SOCKET socket) {
+ sockaddr_storage addr;
+ Zero(addr);
+
+ socklen_t len = sizeof(addr);
+
+ int r = ::getsockname(socket, (sockaddr*)&addr, &len);
+ UNIT_ASSERT(r >= 0);
+
+ if (addr.ss_family == AF_INET) {
+ sockaddr_in* addr_in = (sockaddr_in*)&addr;
+ return InetToHost(addr_in->sin_port);
+ } else if (addr.ss_family == AF_INET6) {
+ sockaddr_in6* addr_in6 = (sockaddr_in6*)&addr;
+ return InetToHost(addr_in6->sin6_port);
+ } else {
+ UNIT_FAIL("unknown AF");
+ throw 1;
+ }
+ }
+}
+
+Y_UNIT_TEST_SUITE(Network) {
+ Y_UNIT_TEST(BindOnPortConcrete) {
+ if (!IsFixedPortTestAllowed()) {
+ return;
+ }
+
+ TVector<TBindResult> r = BindOnPort(FixedPort, false).second;
+ UNIT_ASSERT_VALUES_EQUAL(size_t(2), r.size());
+
+ for (TVector<TBindResult>::iterator i = r.begin(); i != r.end(); ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(i->Addr.GetPort(), GetSockPort(i->Socket->operator SOCKET()));
+ }
+ }
+
+ Y_UNIT_TEST(BindOnPortRandom) {
+ TVector<TBindResult> r = BindOnPort(0, false).second;
+ UNIT_ASSERT_VALUES_EQUAL(size_t(2), r.size());
+
+ for (TVector<TBindResult>::iterator i = r.begin(); i != r.end(); ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(i->Addr.GetPort(), GetSockPort(i->Socket->operator SOCKET()));
+ UNIT_ASSERT(i->Addr.GetPort() > 0);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(r.at(0).Addr.GetPort(), r.at(1).Addr.GetPort());
+ }
+
+ Y_UNIT_TEST(BindOnBusyPort) {
+ auto r = BindOnPort(0, false);
+
+ UNIT_ASSERT_EXCEPTION_CONTAINS(BindOnPort(r.first, false), TSystemError, "failed to bind on port " + ToString(r.first));
+ }
+}
diff --git a/library/cpp/messagebus/nondestroying_holder.h b/library/cpp/messagebus/nondestroying_holder.h
new file mode 100644
index 0000000000..f4725d696f
--- /dev/null
+++ b/library/cpp/messagebus/nondestroying_holder.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+
+template <typename T>
+class TNonDestroyingHolder: public THolder<T> {
+public:
+ TNonDestroyingHolder(T* t = nullptr) noexcept
+ : THolder<T>(t)
+ {
+ }
+
+ TNonDestroyingHolder(TAutoPtr<T> t) noexcept
+ : THolder<T>(t)
+ {
+ }
+
+ ~TNonDestroyingHolder() {
+ Y_VERIFY(!*this, "stored object must be explicitly released");
+ }
+};
+
+template <class T>
+class TNonDestroyingAutoPtr: public TAutoPtr<T> {
+public:
+ inline TNonDestroyingAutoPtr(T* t = 0) noexcept
+ : TAutoPtr<T>(t)
+ {
+ }
+
+ inline TNonDestroyingAutoPtr(const TAutoPtr<T>& t) noexcept
+ : TAutoPtr<T>(t.Release())
+ {
+ }
+
+ inline ~TNonDestroyingAutoPtr() {
+ Y_VERIFY(!*this, "stored object must be explicitly released");
+ }
+};
diff --git a/library/cpp/messagebus/nondestroying_holder_ut.cpp b/library/cpp/messagebus/nondestroying_holder_ut.cpp
new file mode 100644
index 0000000000..208042a2ba
--- /dev/null
+++ b/library/cpp/messagebus/nondestroying_holder_ut.cpp
@@ -0,0 +1,12 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "nondestroying_holder.h"
+
+Y_UNIT_TEST_SUITE(TNonDestroyingHolder) {
+ Y_UNIT_TEST(ToAutoPtr) {
+ TNonDestroyingHolder<int> h(new int(11));
+ TAutoPtr<int> i(h);
+ UNIT_ASSERT_VALUES_EQUAL(11, *i);
+ UNIT_ASSERT(!h);
+ }
+}
diff --git a/library/cpp/messagebus/oldmodule/module.cpp b/library/cpp/messagebus/oldmodule/module.cpp
new file mode 100644
index 0000000000..24bd778799
--- /dev/null
+++ b/library/cpp/messagebus/oldmodule/module.cpp
@@ -0,0 +1,881 @@
+#include "module.h"
+
+#include <library/cpp/messagebus/scheduler_actor.h>
+#include <library/cpp/messagebus/thread_extra.h>
+#include <library/cpp/messagebus/actor/actor.h>
+#include <library/cpp/messagebus/actor/queue_in_actor.h>
+#include <library/cpp/messagebus/actor/what_thread_does.h>
+#include <library/cpp/messagebus/actor/what_thread_does_guard.h>
+
+#include <util/generic/singleton.h>
+#include <util/string/printf.h>
+#include <util/system/event.h>
+
+using namespace NActor;
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+namespace {
+ Y_POD_STATIC_THREAD(TBusJob*)
+ ThreadCurrentJob;
+
+ struct TThreadCurrentJobGuard {
+ TBusJob* Prev;
+
+ TThreadCurrentJobGuard(TBusJob* job)
+ : Prev(ThreadCurrentJob)
+ {
+ Y_ASSERT(!ThreadCurrentJob || ThreadCurrentJob == job);
+ ThreadCurrentJob = job;
+ }
+ ~TThreadCurrentJobGuard() {
+ ThreadCurrentJob = Prev;
+ }
+ };
+
+ void ClearState(NBus::TJobState* state) {
+ /// skip sendbacks handlers
+ if (state->Message != state->Reply) {
+ if (state->Message) {
+ delete state->Message;
+ state->Message = nullptr;
+ }
+
+ if (state->Reply) {
+ delete state->Reply;
+ state->Reply = nullptr;
+ }
+ }
+ }
+
+ void ClearJobStateVector(NBus::TJobStateVec* vec) {
+ Y_ASSERT(vec);
+
+ for (auto& call : *vec) {
+ ClearState(&call);
+ }
+
+ vec->clear();
+ }
+
+}
+
+namespace NBus {
+ namespace NPrivate {
+ class TJobStorage {
+ };
+
+ struct TModuleClientHandler
+ : public IBusClientHandler {
+ TModuleClientHandler(TBusModuleImpl* module)
+ : Module(module)
+ {
+ }
+
+ void OnReply(TAutoPtr<TBusMessage> req, TAutoPtr<TBusMessage> reply) override;
+ void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) override;
+ void OnError(TAutoPtr<TBusMessage> msg, EMessageStatus status) override;
+ void OnClientConnectionEvent(const TClientConnectionEvent& event) override;
+
+ TBusModuleImpl* const Module;
+ };
+
+ struct TModuleServerHandler
+ : public IBusServerHandler {
+ TModuleServerHandler(TBusModuleImpl* module)
+ : Module(module)
+ {
+ }
+
+ void OnMessage(TOnMessageContext& msg) override;
+
+ TBusModuleImpl* const Module;
+ };
+
+ struct TBusModuleImpl: public TBusModuleInternal {
+ TBusModule* const Module;
+
+ TBusMessageQueue* Queue;
+
+ TScheduler Scheduler;
+
+ const char* const Name;
+
+ typedef TList<TJobRunner*> TBusJobList;
+ /// jobs currently in-flight on this module
+ TBusJobList Jobs;
+ /// module level mutex
+ TMutex Lock;
+ TCondVar ShutdownCondVar;
+ TAtomic JobCount;
+
+ enum EState {
+ CREATED,
+ RUNNING,
+ STOPPED,
+ };
+
+ TAtomic State;
+ TBusModuleConfig ModuleConfig;
+ TBusServerSessionPtr ExternalSession;
+ /// protocol for local proxy session
+ THolder<IBusClientHandler> ModuleClientHandler;
+ THolder<IBusServerHandler> ModuleServerHandler;
+ TVector<TSimpleSharedPtr<TBusStarter>> Starters;
+
+ // Sessions must be destroyed before
+ // ModuleClientHandler / ModuleServerHandler
+ TVector<TBusClientSessionPtr> ClientSessions;
+ TVector<TBusServerSessionPtr> ServerSessions;
+
+ TBusModuleImpl(TBusModule* module, const char* name)
+ : Module(module)
+ , Queue()
+ , Name(name)
+ , JobCount(0)
+ , State(CREATED)
+ , ExternalSession(nullptr)
+ , ModuleClientHandler(new TModuleClientHandler(this))
+ , ModuleServerHandler(new TModuleServerHandler(this))
+ {
+ }
+
+ ~TBusModuleImpl() override {
+ // Shutdown cannot be called from destructor,
+ // because module has virtual methods.
+ Y_VERIFY(State != RUNNING, "if running, must explicitly call Shutdown() before destructor");
+
+ Scheduler.Stop();
+
+ while (!Jobs.empty()) {
+ DestroyJob(Jobs.front());
+ }
+ Y_VERIFY(JobCount == 0, "state check");
+ }
+
+ void OnMessageReceived(TAutoPtr<TBusMessage> msg, TOnMessageContext&);
+
+ void AddJob(TJobRunner* jobRunner);
+
+ void DestroyJob(TJobRunner* job);
+
+ /// terminate job on this message
+ void CancelJob(TBusJob* job, EMessageStatus status);
+ /// prints statuses of jobs
+ TString GetStatus(unsigned flags);
+
+ size_t Size() const {
+ return AtomicGet(JobCount);
+ }
+
+ void Shutdown();
+
+ TVector<TBusClientSessionPtr> GetClientSessionsInternal() override {
+ return ClientSessions;
+ }
+
+ TVector<TBusServerSessionPtr> GetServerSessionsInternal() override {
+ return ServerSessions;
+ }
+
+ TBusMessageQueue* GetQueue() override {
+ return Queue;
+ }
+
+ TString GetNameInternal() override {
+ return Name;
+ }
+
+ TString GetStatusSingleLine() override {
+ TStringStream ss;
+ ss << "jobs: " << Size();
+ return ss.Str();
+ }
+
+ void OnClientConnectionEvent(const TClientConnectionEvent& event) {
+ Module->OnClientConnectionEvent(event);
+ }
+ };
+
+ struct TJobResponseMessage {
+ TBusMessage* Request;
+ TBusMessage* Response;
+ EMessageStatus Status;
+
+ TJobResponseMessage(TBusMessage* request, TBusMessage* response, EMessageStatus status)
+ : Request(request)
+ , Response(response)
+ , Status(status)
+ {
+ }
+ };
+
+ struct TJobRunner: public TAtomicRefCount<TJobRunner>,
+ public NActor::TActor<TJobRunner>,
+ public NActor::TQueueInActor<TJobRunner, TJobResponseMessage>,
+ public TScheduleActor<TJobRunner> {
+ THolder<TBusJob> Job;
+
+ TList<TJobRunner*>::iterator JobStorageIterator;
+
+ TJobRunner(TAutoPtr<TBusJob> job)
+ : NActor::TActor<TJobRunner>(job->ModuleImpl->Queue->GetExecutor())
+ , TScheduleActor<TJobRunner>(&job->ModuleImpl->Scheduler)
+ , Job(job.Release())
+ , JobStorageIterator()
+ {
+ Job->Runner = this;
+ }
+
+ ~TJobRunner() override {
+ Y_ASSERT(JobStorageIterator == TList<TJobRunner*>::iterator());
+ }
+
+ void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, const TJobResponseMessage& message) {
+ Job->CallReplyHandler(message.Status, message.Request, message.Response);
+ }
+
+ void Destroy() {
+ if (!!Job->OnMessageContext) {
+ if (!Job->ReplySent) {
+ Job->OnMessageContext.ForgetRequest();
+ }
+ }
+ Job->ModuleImpl->DestroyJob(this);
+ }
+
+ void Act(NActor::TDefaultTag) {
+ if (JobStorageIterator == TList<TJobRunner*>::iterator()) {
+ return;
+ }
+
+ if (Job->SleepUntil != 0) {
+ if (AtomicGet(Job->ModuleImpl->State) == TBusModuleImpl::STOPPED) {
+ Destroy();
+ return;
+ }
+ }
+
+ TThreadCurrentJobGuard g(Job.Get());
+
+ NActor::TQueueInActor<TJobRunner, TJobResponseMessage>::DequeueAll();
+
+ if (Alarm.FetchTask()) {
+ if (Job->AnyPendingToSend()) {
+ Y_ASSERT(Job->SleepUntil == 0);
+ Job->SendPending();
+ if (Job->AnyPendingToSend()) {
+ }
+ } else {
+ // regular alarm
+ Y_ASSERT(Job->Pending.empty());
+ Y_ASSERT(Job->SleepUntil != 0);
+ Job->SleepUntil = 0;
+ }
+ }
+
+ for (;;) {
+ if (Job->Pending.empty() && !!Job->Handler && Job->Status == MESSAGE_OK) {
+ TWhatThreadDoesPushPop pp("do call job handler (do not confuse with reply handler)");
+
+ Job->Handler = Job->Handler(Job->Module, Job.Get(), Job->Message);
+ }
+
+ if (Job->SleepUntil != 0) {
+ ScheduleAt(TInstant::MilliSeconds(Job->SleepUntil));
+ return;
+ }
+
+ Job->SendPending();
+
+ if (Job->AnyPendingToSend()) {
+ ScheduleAt(TInstant::Now() + TDuration::Seconds(1));
+ return;
+ }
+
+ if (!Job->Pending.empty()) {
+ // waiting replies
+ return;
+ }
+
+ if (Job->IsDone()) {
+ Destroy();
+ return;
+ }
+ }
+ }
+ };
+
+ }
+
+ static inline TJobRunner* GetJob(TBusMessage* message) {
+ return (TJobRunner*)message->Data;
+ }
+
+ static inline void SetJob(TBusMessage* message, TJobRunner* job) {
+ message->Data = job;
+ }
+
+ TBusJob::TBusJob(TBusModule* module, TBusMessage* message)
+ : Status(MESSAGE_OK)
+ , Runner()
+ , Message(message)
+ , ReplySent(false)
+ , Module(module)
+ , ModuleImpl(module->Impl.Get())
+ , SleepUntil(0)
+ {
+ Handler = TJobHandler(&TBusModule::Start);
+ }
+
+ TBusJob::~TBusJob() {
+ Y_ASSERT(Pending.size() == 0);
+ //Y_ASSERT(SleepUntil == 0);
+
+ ClearAllMessageStates();
+ }
+
+ TNetAddr TBusJob::GetPeerAddrNetAddr() const {
+ Y_VERIFY(!!OnMessageContext);
+ return OnMessageContext.GetPeerAddrNetAddr();
+ }
+
+ void TBusJob::CheckThreadCurrentJob() {
+ Y_ASSERT(ThreadCurrentJob == this);
+ }
+
+ /////////////////////////////////////////////////////////
+ /// \brief Send messages in pending list
+
+ /// If at least one message is gone return true
+ /// If message has not been send, move it to Finished with appropriate error code
+ bool TBusJob::SendPending() {
+ // Iterator type must be size_t, not vector::iterator,
+ // because `DoCallReplyHandler` may call `Send` that modifies `Pending` vector,
+ // that in turn invalidates iterator.
+ // Implementation assumes that `DoCallReplyHandler` only pushes back to `Pending`
+ // (not erases, and not inserts) so iteration by index is valid.
+ size_t it = 0;
+ while (it != Pending.size()) {
+ TJobState& call = Pending[it];
+
+ if (call.Status == MESSAGE_DONT_ASK) {
+ EMessageStatus getAddressStatus = MESSAGE_OK;
+ TNetAddr addr;
+ if (call.UseAddr) {
+ addr = call.Addr;
+ } else {
+ getAddressStatus = const_cast<TBusProtocol*>(call.Session->GetProto())->GetDestination(call.Session, call.Message, call.Session->GetQueue()->GetLocator(), &addr);
+ }
+
+ if (getAddressStatus == MESSAGE_OK) {
+ // hold extra reference for each request in flight
+ Runner->Ref();
+
+ if (call.OneWay) {
+ call.Status = call.Session->SendMessageOneWay(call.Message, &addr);
+ } else {
+ call.Status = call.Session->SendMessage(call.Message, &addr);
+ }
+
+ if (call.Status != MESSAGE_OK) {
+ Runner->UnRef();
+ }
+
+ } else {
+ call.Status = getAddressStatus;
+ }
+ }
+
+ if (call.Status == MESSAGE_OK) {
+ ++it; // keep pending list until we get reply
+ } else if (call.Status == MESSAGE_BUSY) {
+ Y_FAIL("MESSAGE_BUSY is prohibited in modules. Please increase MaxInFlight");
+ } else if (call.Status == MESSAGE_CONNECT_FAILED && call.NumRetries < call.MaxRetries) {
+ ++it; // try up to call.MaxRetries times to send message
+ call.NumRetries++;
+ DoCallReplyHandler(call);
+ call.Status = MESSAGE_DONT_ASK;
+ call.Message->Reset(); // generate new Id
+ } else {
+ Finished.push_back(call);
+ DoCallReplyHandler(call);
+ Pending.erase(Pending.begin() + it);
+ }
+ }
+ return Pending.size() > 0;
+ }
+
+ bool TBusJob::AnyPendingToSend() {
+ for (unsigned i = 0; i < Pending.size(); ++i) {
+ if (Pending[i].Status == MESSAGE_DONT_ASK) {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ bool TBusJob::IsDone() {
+ bool r = (SleepUntil == 0 && Pending.size() == 0 && (Handler == nullptr || Status != MESSAGE_OK));
+ return r;
+ }
+
+ void TBusJob::CallJobHandlerOnly() {
+ TThreadCurrentJobGuard threadCurrentJobGuard(this);
+ TWhatThreadDoesPushPop pp("do call job handler (do not confuse with reply handler)");
+
+ Handler = Handler(ModuleImpl->Module, this, Message);
+ }
+
+ bool TBusJob::CallJobHandler() {
+ /// go on as far as we can go without waiting
+ while (!IsDone()) {
+ /// call the handler
+ CallJobHandlerOnly();
+
+ /// quit if job is canceled
+ if (Status != MESSAGE_OK) {
+ break;
+ }
+
+ /// there are messages to send and wait for reply
+ SendPending();
+
+ if (!Pending.empty()) {
+ break;
+ }
+
+ /// asked to sleep
+ if (SleepUntil) {
+ break;
+ }
+ }
+
+ Y_VERIFY(!(Pending.size() == 0 && Handler == nullptr && Status == MESSAGE_OK && !ReplySent),
+ "Handler returned NULL without Cancel() or SendReply() for message=%016" PRIx64 " type=%d",
+ Message->GetHeader()->Id, Message->GetHeader()->Type);
+
+ return IsDone();
+ }
+
+ void TBusJob::DoCallReplyHandler(TJobState& call) {
+ if (call.Handler) {
+ TWhatThreadDoesPushPop pp("do call reply handler (do not confuse with job handler)");
+
+ TThreadCurrentJobGuard threadCurrentJobGuard(this);
+ (Module->*(call.Handler))(this, call.Status, call.Message, call.Reply);
+ }
+ }
+
+ int TBusJob::CallReplyHandler(EMessageStatus status, TBusMessage* mess, TBusMessage* reply) {
+ /// find handler for given message and update it's status
+ size_t i = 0;
+ for (; i < Pending.size(); ++i) {
+ TJobState& call = Pending[i];
+ if (call.Message == mess) {
+ break;
+ }
+ }
+
+ /// if not found, report error
+ if (i == Pending.size()) {
+ Y_FAIL("must not happen");
+ }
+
+ /// fill in response into job state
+ TJobState& call = Pending[i];
+ call.Status = status;
+ Y_ASSERT(call.Message == mess);
+ call.Reply = reply;
+
+ if ((status == MESSAGE_TIMEOUT || status == MESSAGE_DELIVERY_FAILED) && call.NumRetries < call.MaxRetries) {
+ call.NumRetries++;
+ call.Status = MESSAGE_DONT_ASK;
+ call.Message->Reset(); // generate new Id
+ DoCallReplyHandler(call);
+ return 0;
+ }
+
+ /// call the handler if provided
+ DoCallReplyHandler(call);
+
+ /// move job state into the finished stack
+ Finished.push_back(Pending[i]);
+ Pending.erase(Pending.begin() + i);
+
+ return 0;
+ }
+
+ ///////////////////////////////////////////////////////////////
+ /// send message to any other session or application
+ void TBusJob::Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler, size_t maxRetries) {
+ CheckThreadCurrentJob();
+
+ SetJob(mess.Get(), Runner);
+ Pending.push_back(TJobState(rhandler, MESSAGE_DONT_ASK, mess.Release(), session, nullptr, maxRetries, nullptr, false));
+ }
+
+ void TBusJob::Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler, size_t maxRetries, const TNetAddr& addr) {
+ CheckThreadCurrentJob();
+
+ SetJob(mess.Get(), Runner);
+ Pending.push_back(TJobState(rhandler, MESSAGE_DONT_ASK, mess.Release(), session, nullptr, maxRetries, &addr, false));
+ }
+
+ void TBusJob::SendOneWayTo(TBusMessageAutoPtr req, TBusClientSession* session, const TNetAddr& addr) {
+ CheckThreadCurrentJob();
+
+ SetJob(req.Get(), Runner);
+ Pending.push_back(TJobState(nullptr, MESSAGE_DONT_ASK, req.Release(), session, nullptr, 0, &addr, true));
+ }
+
+ void TBusJob::SendOneWayWithLocator(TBusMessageAutoPtr req, TBusClientSession* session) {
+ CheckThreadCurrentJob();
+
+ SetJob(req.Get(), Runner);
+ Pending.push_back(TJobState(nullptr, MESSAGE_DONT_ASK, req.Release(), session, nullptr, 0, nullptr, true));
+ }
+
+ ///////////////////////////////////////////////////////////////
+ /// send reply to the starter message
+ void TBusJob::SendReply(TBusMessageAutoPtr reply) {
+ CheckThreadCurrentJob();
+
+ Y_VERIFY(!ReplySent, "cannot call SendReply twice");
+ ReplySent = true;
+ if (!OnMessageContext)
+ return;
+
+ EMessageStatus ok = OnMessageContext.SendReplyMove(reply);
+ if (ok != MESSAGE_OK) {
+ // TODO: count errors
+ }
+ }
+
+ /// set the flag to terminate job at the earliest convenience
+ void TBusJob::Cancel(EMessageStatus status) {
+ CheckThreadCurrentJob();
+
+ Status = status;
+ }
+
+ void TBusJob::ClearState(TJobState& call) {
+ TJobStateVec::iterator it;
+ for (it = Finished.begin(); it != Finished.end(); ++it) {
+ TJobState& state = *it;
+ if (&call == &state) {
+ ::ClearState(&call);
+ Finished.erase(it);
+ return;
+ }
+ }
+ Y_ASSERT(0);
+ }
+
+ void TBusJob::ClearAllMessageStates() {
+ ClearJobStateVector(&Finished);
+ ClearJobStateVector(&Pending);
+ }
+
+ void TBusJob::Sleep(int milliSeconds) {
+ CheckThreadCurrentJob();
+
+ Y_VERIFY(Pending.empty(), "sleep is not allowed when there are pending job");
+ Y_VERIFY(SleepUntil == 0, "must not override sleep");
+
+ SleepUntil = Now() + milliSeconds;
+ }
+
+ TString TBusJob::GetStatus(unsigned flags) {
+ TString strReturn;
+ strReturn += Sprintf(" job=%016" PRIx64 " type=%d sent=%d pending=%d (%d) %s\n",
+ Message->GetHeader()->Id,
+ (int)Message->GetHeader()->Type,
+ (int)(Now() - Message->GetHeader()->SendTime) / 1000,
+ (int)Pending.size(),
+ (int)Finished.size(),
+ Status != MESSAGE_OK ? ToString(Status).data() : "");
+
+ TJobStateVec::iterator it;
+ for (it = Pending.begin(); it != Pending.end(); ++it) {
+ TJobState& call = *it;
+ strReturn += call.GetStatus(flags);
+ }
+ return strReturn;
+ }
+
+ TString TJobState::GetStatus(unsigned flags) {
+ Y_UNUSED(flags);
+ TString strReturn;
+ strReturn += Sprintf(" pending=%016" PRIx64 " type=%d (%s) sent=%d %s\n",
+ Message->GetHeader()->Id,
+ (int)Message->GetHeader()->Type,
+ Session->GetProto()->GetService(),
+ (int)(Now() - Message->GetHeader()->SendTime) / 1000,
+ ToString(Status).data());
+ return strReturn;
+ }
+
+ //////////////////////////////////////////////////////////////////////
+
+ void TBusModuleImpl::CancelJob(TBusJob* job, EMessageStatus status) {
+ TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for CancelJob");
+ if (job) {
+ job->Cancel(status);
+ }
+ }
+
+ TString TBusModuleImpl::GetStatus(unsigned flags) {
+ Y_UNUSED(flags);
+ TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for GetStatus");
+ TString strReturn = Sprintf("JobsInFlight=%d\n", (int)Jobs.size());
+ for (auto job : Jobs) {
+ //strReturn += job->Job->GetStatus(flags);
+ Y_UNUSED(job);
+ strReturn += "TODO\n";
+ }
+ return strReturn;
+ }
+
+ TBusModuleConfig::TBusModuleConfig()
+ : StarterMaxInFlight(1000)
+ {
+ }
+
+ TBusModuleConfig::TSecret::TSecret()
+ : SchedulePeriod(TDuration::Seconds(1))
+ {
+ }
+
+ TBusModule::TBusModule(const char* name)
+ : Impl(new TBusModuleImpl(this, name))
+ {
+ }
+
+ TBusModule::~TBusModule() {
+ }
+
+ const char* TBusModule::GetName() const {
+ return Impl->Name;
+ }
+
+ void TBusModule::SetConfig(const TBusModuleConfig& config) {
+ Impl->ModuleConfig = config;
+ }
+
+ bool TBusModule::StartInput() {
+ Y_VERIFY(Impl->State == TBusModuleImpl::CREATED, "state check");
+ Y_VERIFY(!!Impl->Queue, "state check");
+ Impl->State = TBusModuleImpl::RUNNING;
+
+ Y_ASSERT(!Impl->ExternalSession);
+ TBusServerSessionPtr extSession = CreateExtSession(*Impl->Queue);
+ if (extSession != nullptr) {
+ Impl->ExternalSession = extSession;
+ }
+
+ return true;
+ }
+
+ bool TBusModule::Shutdown() {
+ Impl->Shutdown();
+
+ return true;
+ }
+
+ TBusJob* TBusModule::CreateJobInstance(TBusMessage* message) {
+ TBusJob* job = new TBusJob(this, message);
+ return job;
+ }
+
+ /**
+Example for external session creation:
+
+TBusSession* TMyModule::CreateExtSession(TBusMessageQueue& queue) {
+ TBusSession* session = CreateDefaultDestination(queue, &ExternalProto, ExternalConfig);
+ session->RegisterService(hostname, begin, end);
+ return session;
+*/
+
+ bool TBusModule::CreatePrivateSessions(TBusMessageQueue* queue) {
+ Impl->Queue = queue;
+ return true;
+ }
+
+ int TBusModule::GetModuleSessionInFlight() const {
+ return Impl->Size();
+ }
+
+ TIntrusivePtr<TBusModuleInternal> TBusModule::GetInternal() {
+ return Impl.Get();
+ }
+
+ TBusServerSessionPtr TBusModule::CreateDefaultDestination(
+ TBusMessageQueue& queue, TBusProtocol* proto, const TBusServerSessionConfig& config, const TString& name) {
+ TBusServerSessionConfig patchedConfig = config;
+ patchedConfig.ExecuteOnMessageInWorkerPool = false;
+ if (!patchedConfig.Name) {
+ patchedConfig.Name = name;
+ }
+ if (!patchedConfig.Name) {
+ patchedConfig.Name = Impl->Name;
+ }
+ TBusServerSessionPtr session =
+ TBusServerSession::Create(proto, Impl->ModuleServerHandler.Get(), patchedConfig, &queue);
+ Impl->ServerSessions.push_back(session);
+ return session;
+ }
+
+ TBusClientSessionPtr TBusModule::CreateDefaultSource(
+ TBusMessageQueue& queue, TBusProtocol* proto, const TBusClientSessionConfig& config, const TString& name) {
+ TBusClientSessionConfig patchedConfig = config;
+ patchedConfig.ExecuteOnReplyInWorkerPool = false;
+ if (!patchedConfig.Name) {
+ patchedConfig.Name = name;
+ }
+ if (!patchedConfig.Name) {
+ patchedConfig.Name = Impl->Name;
+ }
+ TBusClientSessionPtr session =
+ TBusClientSession::Create(proto, Impl->ModuleClientHandler.Get(), patchedConfig, &queue);
+ Impl->ClientSessions.push_back(session);
+ return session;
+ }
+
+ TBusStarter* TBusModule::CreateDefaultStarter(TBusMessageQueue&, const TBusSessionConfig& config) {
+ TBusStarter* session = new TBusStarter(this, config);
+ Impl->Starters.push_back(session);
+ return session;
+ }
+
+ void TBusModule::OnClientConnectionEvent(const TClientConnectionEvent& event) {
+ Y_UNUSED(event);
+ }
+
+ TString TBusModule::GetStatus(unsigned flags) {
+ TString strReturn = Sprintf("%s\n", Impl->Name);
+ strReturn += Impl->GetStatus(flags);
+ return strReturn;
+ }
+
+}
+
+void TBusModuleImpl::AddJob(TJobRunner* jobRunner) {
+ TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for AddJob");
+ Jobs.push_back(jobRunner);
+ jobRunner->JobStorageIterator = Jobs.end();
+ --jobRunner->JobStorageIterator;
+}
+
+void TBusModuleImpl::DestroyJob(TJobRunner* job) {
+ Y_ASSERT(job->JobStorageIterator != TList<TJobRunner*>::iterator());
+
+ {
+ TWhatThreadDoesAcquireGuard<TMutex> G(Lock, "modules: acquiring lock for DestroyJob");
+ int jobCount = AtomicDecrement(JobCount);
+ Y_VERIFY(jobCount >= 0, "decremented too much");
+ Jobs.erase(job->JobStorageIterator);
+
+ if (AtomicGet(State) == STOPPED) {
+ if (jobCount == 0) {
+ ShutdownCondVar.BroadCast();
+ }
+ }
+ }
+
+ job->JobStorageIterator = TList<TJobRunner*>::iterator();
+}
+
+void TBusModuleImpl::OnMessageReceived(TAutoPtr<TBusMessage> msg0, TOnMessageContext& context) {
+ TBusMessage* msg = !!msg0 ? msg0.Get() : context.GetMessage();
+ Y_VERIFY(!!msg);
+
+ THolder<TJobRunner> jobRunner(new TJobRunner(Module->CreateJobInstance(msg)));
+ jobRunner->Job->MessageHolder.Reset(msg0.Release());
+ jobRunner->Job->OnMessageContext.Swap(context);
+ SetJob(jobRunner->Job->Message, jobRunner.Get());
+
+ AtomicIncrement(JobCount);
+
+ AddJob(jobRunner.Get());
+
+ jobRunner.Release()->Schedule();
+}
+
+void TBusModuleImpl::Shutdown() {
+ if (AtomicGet(State) != TBusModuleImpl::RUNNING) {
+ AtomicSet(State, TBusModuleImpl::STOPPED);
+ return;
+ }
+ AtomicSet(State, TBusModuleImpl::STOPPED);
+
+ for (auto& clientSession : ClientSessions) {
+ clientSession->Shutdown();
+ }
+ for (auto& serverSession : ServerSessions) {
+ serverSession->Shutdown();
+ }
+
+ for (size_t starter = 0; starter < Starters.size(); ++starter) {
+ Starters[starter]->Shutdown();
+ }
+
+ {
+ TWhatThreadDoesAcquireGuard<TMutex> guard(Lock, "modules: acquiring lock for Shutdown");
+ for (auto& Job : Jobs) {
+ Job->Schedule();
+ }
+
+ while (!Jobs.empty()) {
+ ShutdownCondVar.WaitI(Lock);
+ }
+ }
+}
+
+EMessageStatus TBusModule::StartJob(TAutoPtr<TBusMessage> message) {
+ Y_VERIFY(Impl->State == TBusModuleImpl::RUNNING);
+ Y_VERIFY(!!Impl->Queue);
+
+ if ((unsigned)AtomicGet(Impl->JobCount) >= Impl->ModuleConfig.StarterMaxInFlight) {
+ return MESSAGE_BUSY;
+ }
+
+ TOnMessageContext dummy;
+ Impl->OnMessageReceived(message.Release(), dummy);
+
+ return MESSAGE_OK;
+}
+
+void TModuleServerHandler::OnMessage(TOnMessageContext& msg) {
+ Module->OnMessageReceived(nullptr, msg);
+}
+
+void TModuleClientHandler::OnReply(TAutoPtr<TBusMessage> req, TAutoPtr<TBusMessage> resp) {
+ TJobRunner* job = GetJob(req.Get());
+ Y_ASSERT(job);
+ Y_ASSERT(job->Job->Message != req.Get());
+ job->EnqueueAndSchedule(TJobResponseMessage(req.Release(), resp.Release(), MESSAGE_OK));
+ job->UnRef();
+}
+
+void TModuleClientHandler::OnMessageSentOneWay(TAutoPtr<TBusMessage> req) {
+ TJobRunner* job = GetJob(req.Get());
+ Y_ASSERT(job);
+ Y_ASSERT(job->Job->Message != req.Get());
+ job->EnqueueAndSchedule(TJobResponseMessage(req.Release(), nullptr, MESSAGE_OK));
+ job->UnRef();
+}
+
+void TModuleClientHandler::OnError(TAutoPtr<TBusMessage> msg, EMessageStatus status) {
+ TJobRunner* job = GetJob(msg.Get());
+ if (job) {
+ Y_ASSERT(job->Job->Message != msg.Get());
+ job->EnqueueAndSchedule(TJobResponseMessage(msg.Release(), nullptr, status));
+ job->UnRef();
+ }
+}
+
+void TModuleClientHandler::OnClientConnectionEvent(const TClientConnectionEvent& event) {
+ Module->OnClientConnectionEvent(event);
+}
diff --git a/library/cpp/messagebus/oldmodule/module.h b/library/cpp/messagebus/oldmodule/module.h
new file mode 100644
index 0000000000..8d1c4a5d52
--- /dev/null
+++ b/library/cpp/messagebus/oldmodule/module.h
@@ -0,0 +1,410 @@
+#pragma once
+
+///////////////////////////////////////////////////////////////////////////
+/// \file
+/// \brief Application interface for modules
+
+/// NBus::TBusModule provides foundation for implementation of asynchnous
+/// modules that communicate with multiple external or local sessions
+/// NBus::TBusSession.
+
+/// To implement the module some virtual functions needs to be overridden:
+
+/// NBus::TBusModule::CreateExtSession() creates and registers an
+/// external session that receives incoming messages as input for module
+/// processing.
+
+/// When new incoming message arrives the new NBus::TBusJob is created.
+/// NBus::TBusJob is somewhat similar to a thread, it maintains all the state
+/// during processing of one incoming message. Default implementation of
+/// NBus::TBusJob will maintain all send and received messages during
+/// lifetime of this job. Each message, status and reply can be found
+/// within NBus::TJobState using NBus::TBusJob::GetState(). If your module
+/// needs to maintain an additional information during lifetime of the job
+/// you can derive your own class from NBus::TBusJob and override job
+/// factory method NBus::IJobFactory::CreateJobInstance() to create your instances.
+
+/// Processing of a given message starts with a call to NBus::TBusModule::Start()
+/// handler that should be overridden in the module implementation. Within
+/// the callback handler module can perform any computation and access any
+/// datastore tables that it needs. The handler can also access any module
+/// variables. However, same handler can be called from multiple threads so,
+/// it is recommended that handler only access read-only module level variables.
+
+/// Handler should use NBus::TBusJob::Send() to send messages to other client
+/// sessions and it can use NBus::TBusJob::Reply() to send reply to the main
+/// job message. When handler is done, it returns the pointer to the next handler to call
+/// when all pending messages have cleared. If handler
+/// returns pointer to itself the module will reschedule execution of this handler
+/// for a later time. This should be done in case NBus::TBusJob::Send() returns
+/// error (not MESSAGE_OK)
+
+#include "startsession.h"
+
+#include <library/cpp/messagebus/ybus.h>
+
+#include <util/generic/noncopyable.h>
+#include <util/generic/object_counter.h>
+
+namespace NBus {
+ class TBusJob;
+ class TBusModule;
+
+ namespace NPrivate {
+ struct TCallJobHandlerWorkItem;
+ struct TBusModuleImpl;
+ struct TModuleServerHandler;
+ struct TModuleClientHandler;
+ struct TJobRunner;
+ }
+
+ class TJobHandler {
+ protected:
+ typedef TJobHandler (TBusModule::*TBusHandlerPtr)(TBusJob* job, TBusMessage* mess);
+ TBusHandlerPtr MyPtr;
+
+ public:
+ template <class B>
+ TJobHandler(TJobHandler (B::*fptr)(TBusJob* job, TBusMessage* mess)) {
+ MyPtr = static_cast<TBusHandlerPtr>(fptr);
+ }
+ TJobHandler(TBusHandlerPtr fptr = nullptr) {
+ MyPtr = fptr;
+ }
+ TJobHandler(const TJobHandler&) = default;
+ TJobHandler& operator =(const TJobHandler&) = default;
+ bool operator==(TJobHandler h) const {
+ return MyPtr == h.MyPtr;
+ }
+ bool operator!=(TJobHandler h) const {
+ return MyPtr != h.MyPtr;
+ }
+ bool operator!() const {
+ return !MyPtr;
+ }
+ TJobHandler operator()(TBusModule* b, TBusJob* job, TBusMessage* mess) {
+ return (b->*MyPtr)(job, mess);
+ }
+ };
+
+ typedef void (TBusModule::*TReplyHandler)(TBusJob* job, EMessageStatus status, TBusMessage* mess, TBusMessage* reply);
+
+ ////////////////////////////////////////////////////
+ /// \brief Pending message state
+
+ struct TJobState {
+ friend class TBusJob;
+ friend class ::TCrawlerModule;
+
+ TReplyHandler Handler;
+ EMessageStatus Status;
+ TBusMessage* Message;
+ TBusMessage* Reply;
+ TBusClientSession* Session;
+ size_t NumRetries;
+ size_t MaxRetries;
+ // If != NULL then use it as destination.
+ TNetAddr Addr;
+ bool UseAddr;
+ bool OneWay;
+
+ private:
+ TJobState(TReplyHandler handler,
+ EMessageStatus status,
+ TBusMessage* mess, TBusClientSession* session, TBusMessage* reply, size_t maxRetries = 0,
+ const TNetAddr* addr = nullptr, bool oneWay = false)
+ : Handler(handler)
+ , Status(status)
+ , Message(mess)
+ , Reply(reply)
+ , Session(session)
+ , NumRetries(0)
+ , MaxRetries(maxRetries)
+ , OneWay(oneWay)
+ {
+ if (!!addr) {
+ Addr = *addr;
+ }
+ UseAddr = !!addr;
+ }
+
+ public:
+ TString GetStatus(unsigned flags);
+ };
+
+ using TJobStateVec = TVector<TJobState>;
+
+ /////////////////////////////////////////////////////////
+ /// \brief Execution item = thread
+
+ /// Maintains internal state of document in computation
+ class TBusJob {
+ TObjectCounter<TBusJob> ObjectCounter;
+
+ private:
+ void CheckThreadCurrentJob();
+
+ public:
+ /// given a module and starter message
+ TBusJob(TBusModule* module, TBusMessage* message);
+
+ /// destructor will free all the message that were send and received
+ virtual ~TBusJob();
+
+ TBusMessage* GetMessage() const {
+ return Message;
+ }
+
+ TNetAddr GetPeerAddrNetAddr() const;
+
+ /// send message to any other session or application
+ /// If addr is set then use it as destination.
+ void Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler, size_t maxRetries, const TNetAddr& addr);
+ void Send(TBusMessageAutoPtr mess, TBusClientSession* session, TReplyHandler rhandler = nullptr, size_t maxRetries = 0);
+
+ void SendOneWayTo(TBusMessageAutoPtr req, TBusClientSession* session, const TNetAddr& addr);
+ void SendOneWayWithLocator(TBusMessageAutoPtr req, TBusClientSession* session);
+
+ /// send reply to the starter message
+ virtual void SendReply(TBusMessageAutoPtr reply);
+
+ /// set the flag to terminate job at the earliest convenience
+ void Cancel(EMessageStatus status);
+
+ /// helper to put item on finished list of states
+ /// It should not be a part of public API,
+ /// so prohibit it for all except current users.
+ private:
+ friend class ::TCrawlerModule;
+ void PutState(const TJobState& state) {
+ Finished.push_back(state);
+ }
+
+ public:
+ /// retrieve all pending messages
+ void GetPending(TJobStateVec* stateVec) {
+ Y_ASSERT(stateVec);
+ *stateVec = Pending;
+ }
+
+ /// helper function to find state of previously sent messages
+ template <class MessageType>
+ TJobState* GetState(int* startFrom = nullptr) {
+ for (int i = startFrom ? *startFrom : 0; i < int(Finished.size()); i++) {
+ TJobState* call = &Finished[i];
+ if (call->Reply != nullptr && dynamic_cast<MessageType*>(call->Reply)) {
+ if (startFrom) {
+ *startFrom = i;
+ }
+ return call;
+ }
+ if (call->Message != nullptr && dynamic_cast<MessageType*>(call->Message)) {
+ if (startFrom) {
+ *startFrom = i;
+ }
+ return call;
+ }
+ }
+ return nullptr;
+ }
+
+ /// helper function to find response for previously sent messages
+ template <class MessageType>
+ MessageType* Get(int* startFrom = nullptr) {
+ for (int i = startFrom ? *startFrom : 0; i < int(Finished.size()); i++) {
+ TJobState& call = Finished[i];
+ if (call.Reply != nullptr && dynamic_cast<MessageType*>(call.Reply)) {
+ if (startFrom) {
+ *startFrom = i;
+ }
+ return static_cast<MessageType*>(call.Reply);
+ }
+ if (call.Message != nullptr && dynamic_cast<MessageType*>(call.Message)) {
+ if (startFrom) {
+ *startFrom = i;
+ }
+ return static_cast<MessageType*>(call.Message);
+ }
+ }
+ return nullptr;
+ }
+
+ /// helper function to find status for previously sent message
+ template <class MessageType>
+ EMessageStatus GetStatus(int* startFrom = nullptr) {
+ for (int i = startFrom ? *startFrom : 0; i < int(Finished.size()); i++) {
+ TJobState& call = Finished[i];
+ if (call.Message != nullptr && dynamic_cast<MessageType*>(call.Message)) {
+ if (startFrom) {
+ *startFrom = i;
+ }
+ return call.Status;
+ }
+ }
+ return MESSAGE_UNKNOWN;
+ }
+
+ /// helper function to clear state of previosly sent messages
+ template <class MessageType>
+ void Clear() {
+ for (size_t i = 0; i < Finished.size();) {
+ // `Finished.size() - i` decreases with each iteration
+ // we either increment i, or remove element from Finished.
+ TJobState& call = Finished[i];
+ if (call.Message != nullptr && dynamic_cast<MessageType*>(call.Message)) {
+ ClearState(call);
+ } else {
+ ++i;
+ }
+ }
+ }
+
+ /// helper function to clear state in order to try again
+ void ClearState(TJobState& state);
+
+ /// clears all message states
+ void ClearAllMessageStates();
+
+ /// returns true if job is done
+ bool IsDone();
+
+ /// return human reabable status of this job
+ virtual TString GetStatus(unsigned flags);
+
+ /// set sleep time for job
+ void Sleep(int milliSeconds);
+
+ void CallJobHandlerOnly();
+
+ private:
+ bool CallJobHandler();
+ void DoCallReplyHandler(TJobState&);
+ /// send out all Pending jobs, failed sends will be migrated to Finished
+ bool SendPending();
+ bool AnyPendingToSend();
+
+ public:
+ /// helper to call from OnReply() and OnError()
+ int CallReplyHandler(EMessageStatus status, TBusMessage* mess, TBusMessage* reply);
+
+ public:
+ TJobHandler Handler; ///< job handler to be executed within next CallJobHandler()
+ EMessageStatus Status; ///< set != MESSAGE_OK if job should terminate asap
+ private:
+ NPrivate::TJobRunner* Runner;
+ TBusMessage* Message;
+ THolder<TBusMessage> MessageHolder;
+ TOnMessageContext OnMessageContext; // starter
+ public:
+ bool ReplySent;
+
+ private:
+ friend class TBusModule;
+ friend struct NPrivate::TBusModuleImpl;
+ friend struct NPrivate::TCallJobHandlerWorkItem;
+ friend struct NPrivate::TModuleServerHandler;
+ friend struct NPrivate::TModuleClientHandler;
+ friend struct NPrivate::TJobRunner;
+
+ TJobStateVec Pending; ///< messages currently outstanding via Send()
+ TJobStateVec Finished; ///< messages that were replied to
+ TBusModule* Module;
+ NPrivate::TBusModuleImpl* ModuleImpl; ///< module which created the job
+ TBusInstant SleepUntil; ///< time to wakeup, 0 if no sleep
+ };
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief Classes to implement basic module functionality
+
+ class IJobFactory {
+ protected:
+ virtual ~IJobFactory() {
+ }
+
+ public:
+ /// job factory method, override to create custom jobs
+ virtual TBusJob* CreateJobInstance(TBusMessage* message) = 0;
+ };
+
+ struct TBusModuleConfig {
+ unsigned StarterMaxInFlight;
+
+ struct TSecret {
+ TDuration SchedulePeriod;
+
+ TSecret();
+ };
+ TSecret Secret;
+
+ TBusModuleConfig();
+ };
+
+ namespace NPrivate {
+ struct TBusModuleInternal: public TAtomicRefCount<TBusModuleInternal> {
+ virtual TVector<TBusClientSessionPtr> GetClientSessionsInternal() = 0;
+ virtual TVector<TBusServerSessionPtr> GetServerSessionsInternal() = 0;
+ virtual TBusMessageQueue* GetQueue() = 0;
+
+ virtual TString GetNameInternal() = 0;
+
+ virtual TString GetStatusSingleLine() = 0;
+
+ virtual ~TBusModuleInternal() {
+ }
+ };
+ }
+
+ class TBusModule: public IJobFactory, TNonCopyable {
+ friend class TBusJob;
+
+ TObjectCounter<TBusModule> ObjectCounter;
+
+ TIntrusivePtr<NPrivate::TBusModuleImpl> Impl;
+
+ public:
+ /// Each module should have a name which is used as protocol service
+ TBusModule(const char* name);
+ ~TBusModule() override;
+
+ const char* GetName() const;
+
+ void SetConfig(const TBusModuleConfig& config);
+
+ /// get status of all jobs in flight
+ TString GetStatus(unsigned flags = 0);
+
+ /// called when application is about to start
+ virtual bool StartInput();
+ /// called when application is about to exit
+ virtual bool Shutdown();
+
+ // this default implementation just creates TBusJob object
+ TBusJob* CreateJobInstance(TBusMessage* message) override;
+
+ EMessageStatus StartJob(TAutoPtr<TBusMessage> message);
+
+ /// creates private sessions, calls CreateExtSession(), should be called before StartInput()
+ bool CreatePrivateSessions(TBusMessageQueue* queue);
+
+ virtual void OnClientConnectionEvent(const TClientConnectionEvent& event);
+
+ public:
+ /// entry point into module, first function to call
+ virtual TJobHandler Start(TBusJob* job, TBusMessage* mess) = 0;
+
+ protected:
+ /// override this function to create destination session
+ virtual TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) = 0;
+
+ public:
+ int GetModuleSessionInFlight() const;
+
+ TIntrusivePtr<NPrivate::TBusModuleInternal> GetInternal();
+
+ protected:
+ TBusServerSessionPtr CreateDefaultDestination(TBusMessageQueue& queue, TBusProtocol* proto, const TBusServerSessionConfig& config, const TString& name = TString());
+ TBusClientSessionPtr CreateDefaultSource(TBusMessageQueue& queue, TBusProtocol* proto, const TBusClientSessionConfig& config, const TString& name = TString());
+ TBusStarter* CreateDefaultStarter(TBusMessageQueue& unused, const TBusSessionConfig& config);
+ };
+
+}
diff --git a/library/cpp/messagebus/oldmodule/startsession.cpp b/library/cpp/messagebus/oldmodule/startsession.cpp
new file mode 100644
index 0000000000..7c38801d62
--- /dev/null
+++ b/library/cpp/messagebus/oldmodule/startsession.cpp
@@ -0,0 +1,65 @@
+///////////////////////////////////////////////////////////
+/// \file
+/// \brief Starter session implementation
+
+/// Starter session will generate emtpy message to insert
+/// into local session that are registered under same protocol
+
+/// Starter (will one day) automatically adjust number
+/// of message inflight to make sure that at least one of source
+/// sessions within message queue is at the limit (bottle neck)
+
+/// Maximum number of messages that starter will instert into
+/// the pipeline is configured by NBus::TBusSessionConfig::MaxInFlight
+
+#include "startsession.h"
+
+#include "module.h"
+
+#include <library/cpp/messagebus/ybus.h>
+
+namespace NBus {
+ void* TBusStarter::_starter(void* data) {
+ TBusStarter* pThis = static_cast<TBusStarter*>(data);
+ pThis->Starter();
+ return nullptr;
+ }
+
+ TBusStarter::TBusStarter(TBusModule* module, const TBusSessionConfig& config)
+ : Module(module)
+ , Config(config)
+ , StartThread(_starter, this)
+ , Exiting(false)
+ {
+ StartThread.Start();
+ }
+
+ TBusStarter::~TBusStarter() {
+ Shutdown();
+ }
+
+ void TBusStarter::Shutdown() {
+ {
+ TGuard<TMutex> g(ExitLock);
+ Exiting = true;
+ ExitSignal.Signal();
+ }
+ StartThread.Join();
+ }
+
+ void TBusStarter::Starter() {
+ TGuard<TMutex> g(ExitLock);
+ while (!Exiting) {
+ TAutoPtr<TBusMessage> empty(new TBusMessage(0));
+
+ EMessageStatus status = Module->StartJob(empty);
+
+ if (Config.SendTimeout > 0) {
+ ExitSignal.WaitT(ExitLock, TDuration::MilliSeconds(Config.SendTimeout));
+ } else {
+ ExitSignal.WaitT(ExitLock, (status == MESSAGE_BUSY) ? TDuration::MilliSeconds(1) : TDuration::Zero());
+ }
+ }
+ }
+
+}
diff --git a/library/cpp/messagebus/oldmodule/startsession.h b/library/cpp/messagebus/oldmodule/startsession.h
new file mode 100644
index 0000000000..5e26e7e1e5
--- /dev/null
+++ b/library/cpp/messagebus/oldmodule/startsession.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+
+#include <util/system/thread.h>
+
+namespace NBus {
+ class TBusModule;
+
+ class TBusStarter {
+ private:
+ TBusModule* Module;
+ TBusSessionConfig Config;
+ TThread StartThread;
+ bool Exiting;
+ TCondVar ExitSignal;
+ TMutex ExitLock;
+
+ static void* _starter(void* data);
+
+ void Starter();
+
+ TString GetStatus(ui16 /*flags=YBUS_STATUS_CONNS*/) {
+ return "";
+ }
+
+ public:
+ TBusStarter(TBusModule* module, const TBusSessionConfig& config);
+ ~TBusStarter();
+
+ void Shutdown();
+ };
+
+}
diff --git a/library/cpp/messagebus/oldmodule/ya.make b/library/cpp/messagebus/oldmodule/ya.make
new file mode 100644
index 0000000000..ca5eae74f0
--- /dev/null
+++ b/library/cpp/messagebus/oldmodule/ya.make
@@ -0,0 +1,15 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus
+ library/cpp/messagebus/actor
+)
+
+SRCS(
+ module.cpp
+ startsession.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/protobuf/ya.make b/library/cpp/messagebus/protobuf/ya.make
new file mode 100644
index 0000000000..64ff240b51
--- /dev/null
+++ b/library/cpp/messagebus/protobuf/ya.make
@@ -0,0 +1,15 @@
+LIBRARY(messagebus_protobuf)
+
+OWNER(g:messagebus)
+
+SRCS(
+ ybusbuf.cpp
+)
+
+PEERDIR(
+ contrib/libs/protobuf
+ library/cpp/messagebus
+ library/cpp/messagebus/actor
+)
+
+END()
diff --git a/library/cpp/messagebus/protobuf/ybusbuf.cpp b/library/cpp/messagebus/protobuf/ybusbuf.cpp
new file mode 100644
index 0000000000..63415b3737
--- /dev/null
+++ b/library/cpp/messagebus/protobuf/ybusbuf.cpp
@@ -0,0 +1,88 @@
+#include "ybusbuf.h"
+
+#include <library/cpp/messagebus/actor/what_thread_does.h>
+
+#include <google/protobuf/io/coded_stream.h>
+
+using namespace NBus;
+
+TBusBufferProtocol::TBusBufferProtocol(TBusService name, int port)
+ : TBusProtocol(name, port)
+{
+}
+
+TBusBufferProtocol::~TBusBufferProtocol() {
+ for (auto& type : Types) {
+ delete type;
+ }
+}
+
+TBusBufferBase* TBusBufferProtocol::FindType(int type) {
+ for (unsigned i = 0; i < Types.size(); i++) {
+ if (Types[i]->GetHeader()->Type == type) {
+ return Types[i];
+ }
+ }
+ return nullptr;
+}
+
+bool TBusBufferProtocol::IsRegisteredType(unsigned type) {
+ return TypeMask[type >> 5] & (1 << (type & ((1 << 5) - 1)));
+}
+
+void TBusBufferProtocol::RegisterType(TAutoPtr<TBusBufferBase> mess) {
+ ui32 type = mess->GetHeader()->Type;
+ TypeMask[type >> 5] |= 1 << (type & ((1 << 5) - 1));
+
+ Types.push_back(mess.Release());
+}
+
+TArrayRef<TBusBufferBase* const> TBusBufferProtocol::GetTypes() const {
+ return Types;
+}
+
+void TBusBufferProtocol::Serialize(const TBusMessage* mess, TBuffer& data) {
+ TWhatThreadDoesPushPop pp("serialize protobuf message");
+
+ const TBusHeader* header = mess->GetHeader();
+
+ if (!IsRegisteredType(header->Type)) {
+ Y_FAIL("unknown message type: %d", int(header->Type));
+ return;
+ }
+
+ // cast the base from real message
+ const TBusBufferBase* bmess = CheckedCast<const TBusBufferBase*>(mess);
+
+ unsigned size = bmess->GetRecord()->ByteSize();
+ data.Reserve(data.Size() + size);
+
+ char* after = (char*)bmess->GetRecord()->SerializeWithCachedSizesToArray((ui8*)data.Pos());
+ Y_VERIFY(after - data.Pos() == size);
+
+ data.Advance(size);
+}
+
+TAutoPtr<TBusMessage> TBusBufferProtocol::Deserialize(ui16 messageType, TArrayRef<const char> payload) {
+ TWhatThreadDoesPushPop pp("deserialize protobuf message");
+
+ TBusBufferBase* messageTemplate = FindType(messageType);
+ if (messageTemplate == nullptr) {
+ return nullptr;
+ //Y_FAIL("unknown message type: %d", unsigned(messageType));
+ }
+
+ // clone the base
+ TAutoPtr<TBusBufferBase> bmess = messageTemplate->New();
+
+ // Need to override protobuf message size limit
+ // NOTE: the payload size has already been checked against session MaxMessageSize
+ google::protobuf::io::CodedInputStream input(reinterpret_cast<const ui8*>(payload.data()), payload.size());
+ input.SetTotalBytesLimit(payload.size());
+
+ bool ok = bmess->GetRecord()->ParseFromCodedStream(&input) && input.ConsumedEntireMessage();
+ if (!ok) {
+ return nullptr;
+ }
+ return bmess.Release();
+}
diff --git a/library/cpp/messagebus/protobuf/ybusbuf.h b/library/cpp/messagebus/protobuf/ybusbuf.h
new file mode 100644
index 0000000000..57b4267ea5
--- /dev/null
+++ b/library/cpp/messagebus/protobuf/ybusbuf.h
@@ -0,0 +1,233 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+
+#include <google/protobuf/descriptor.h>
+#include <google/protobuf/message.h>
+
+#include <util/generic/cast.h>
+#include <util/generic/vector.h>
+#include <util/stream/mem.h>
+
+#include <array>
+
+namespace NBus {
+ using TBusBufferRecord = ::google::protobuf::Message;
+
+ template <class TBufferMessage>
+ class TBusBufferMessagePtr;
+ template <class TBufferMessage>
+ class TBusBufferMessageAutoPtr;
+
+ class TBusBufferBase: public TBusMessage {
+ public:
+ TBusBufferBase(int type)
+ : TBusMessage((ui16)type)
+ {
+ }
+ TBusBufferBase(ECreateUninitialized)
+ : TBusMessage(MESSAGE_CREATE_UNINITIALIZED)
+ {
+ }
+
+ ui16 GetType() const {
+ return GetHeader()->Type;
+ }
+
+ virtual TBusBufferRecord* GetRecord() const = 0;
+ virtual TBusBufferBase* New() = 0;
+ };
+
+ ///////////////////////////////////////////////////////////////////
+ /// \brief Template for all messages that have protobuf description
+
+ /// @param TBufferRecord is record described in .proto file with namespace
+ /// @param MessageFile is offset for .proto file message ids
+
+ /// \attention If you want one protocol NBus::TBusBufferProtocol to handle
+ /// messageges described in different .proto files, make sure that they have
+ /// unique values for MessageFile
+
+ template <class TBufferRecord, int MType>
+ class TBusBufferMessage: public TBusBufferBase {
+ public:
+ static const int MessageType = MType;
+
+ typedef TBusBufferMessagePtr<TBusBufferMessage<TBufferRecord, MType>> TPtr;
+ typedef TBusBufferMessageAutoPtr<TBusBufferMessage<TBufferRecord, MType>> TAutoPtr;
+
+ public:
+ typedef TBufferRecord RecordType;
+ TBufferRecord Record;
+
+ public:
+ TBusBufferMessage()
+ : TBusBufferBase(MessageType)
+ {
+ }
+ TBusBufferMessage(ECreateUninitialized)
+ : TBusBufferBase(MESSAGE_CREATE_UNINITIALIZED)
+ {
+ }
+ explicit TBusBufferMessage(const TBufferRecord& record)
+ : TBusBufferBase(MessageType)
+ , Record(record)
+ {
+ }
+ explicit TBusBufferMessage(TBufferRecord&& record)
+ : TBusBufferBase(MessageType)
+ , Record(std::move(record))
+ {
+ }
+
+ public:
+ TBusBufferRecord* GetRecord() const override {
+ return (TBusBufferRecord*)&Record;
+ }
+ TBusBufferBase* New() override {
+ return new TBusBufferMessage<TBufferRecord, MessageType>();
+ }
+ };
+
+ template <class TSelf, class TBufferMessage>
+ class TBusBufferMessagePtrBase {
+ public:
+ typedef typename TBufferMessage::RecordType RecordType;
+
+ private:
+ TSelf* GetSelf() {
+ return static_cast<TSelf*>(this);
+ }
+ const TSelf* GetSelf() const {
+ return static_cast<const TSelf*>(this);
+ }
+
+ public:
+ RecordType* operator->() {
+ Y_ASSERT(GetSelf()->Get());
+ return &(GetSelf()->Get()->Record);
+ }
+ const RecordType* operator->() const {
+ Y_ASSERT(GetSelf()->Get());
+ return &(GetSelf()->Get()->Record);
+ }
+ RecordType& operator*() {
+ Y_ASSERT(GetSelf()->Get());
+ return GetSelf()->Get()->Record;
+ }
+ const RecordType& operator*() const {
+ Y_ASSERT(GetSelf()->Get());
+ return GetSelf()->Get()->Record;
+ }
+
+ TBusHeader* GetHeader() {
+ return GetSelf()->Get()->GetHeader();
+ }
+ const TBusHeader* GetHeader() const {
+ return GetSelf()->Get()->GetHeader();
+ }
+ };
+
+ template <class TBufferMessage>
+ class TBusBufferMessagePtr: public TBusBufferMessagePtrBase<TBusBufferMessagePtr<TBufferMessage>, TBufferMessage> {
+ protected:
+ TBufferMessage* Holder;
+
+ public:
+ TBusBufferMessagePtr(TBufferMessage* mess)
+ : Holder(mess)
+ {
+ }
+ static TBusBufferMessagePtr<TBufferMessage> DynamicCast(TBusMessage* message) {
+ return dynamic_cast<TBufferMessage*>(message);
+ }
+ TBufferMessage* Get() {
+ return Holder;
+ }
+ const TBufferMessage* Get() const {
+ return Holder;
+ }
+
+ operator TBufferMessage*() {
+ return Holder;
+ }
+ operator const TBufferMessage*() const {
+ return Holder;
+ }
+
+ operator TAutoPtr<TBusMessage>() {
+ TAutoPtr<TBusMessage> r(Holder);
+ Holder = 0;
+ return r;
+ }
+ operator TBusMessageAutoPtr() {
+ TBusMessageAutoPtr r(Holder);
+ Holder = nullptr;
+ return r;
+ }
+ };
+
+ template <class TBufferMessage>
+ class TBusBufferMessageAutoPtr: public TBusBufferMessagePtrBase<TBusBufferMessageAutoPtr<TBufferMessage>, TBufferMessage> {
+ public:
+ TAutoPtr<TBufferMessage> AutoPtr;
+
+ public:
+ TBusBufferMessageAutoPtr() {
+ }
+ TBusBufferMessageAutoPtr(TBufferMessage* message)
+ : AutoPtr(message)
+ {
+ }
+
+ TBufferMessage* Get() {
+ return AutoPtr.Get();
+ }
+ const TBufferMessage* Get() const {
+ return AutoPtr.Get();
+ }
+
+ TBufferMessage* Release() const {
+ return AutoPtr.Release();
+ }
+
+ operator TAutoPtr<TBusMessage>() {
+ return AutoPtr.Release();
+ }
+ operator TBusMessageAutoPtr() {
+ return AutoPtr.Release();
+ }
+ };
+
+ /////////////////////////////////////////////
+ /// \brief Generic protocol object for messages descibed with protobuf
+
+ /// \attention If you mix messages in the same protocol from more than
+ /// .proto file make sure that they have different MessageFile parameter
+ /// in the NBus::TBusBufferMessage template
+
+ class TBusBufferProtocol: public TBusProtocol {
+ private:
+ TVector<TBusBufferBase*> Types;
+ std::array<ui32, ((1 << 16) >> 5)> TypeMask;
+
+ TBusBufferBase* FindType(int type);
+ bool IsRegisteredType(unsigned type);
+
+ public:
+ TBusBufferProtocol(TBusService name, int port);
+
+ ~TBusBufferProtocol() override;
+
+ /// register all the message that this protocol should handle
+ void RegisterType(TAutoPtr<TBusBufferBase> mess);
+
+ TArrayRef<TBusBufferBase* const> GetTypes() const;
+
+ /// serialized protocol specific data into TBusData
+ void Serialize(const TBusMessage* mess, TBuffer& data) override;
+
+ TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override;
+ };
+
+}
diff --git a/library/cpp/messagebus/queue_config.cpp b/library/cpp/messagebus/queue_config.cpp
new file mode 100644
index 0000000000..78fb52ee49
--- /dev/null
+++ b/library/cpp/messagebus/queue_config.cpp
@@ -0,0 +1,22 @@
+#include "queue_config.h"
+
+using namespace NBus;
+
+TBusQueueConfig::TBusQueueConfig() {
+ // workers and listeners configuratioin
+ NumWorkers = 1;
+}
+
+void TBusQueueConfig::ConfigureLastGetopt(
+ NLastGetopt::TOpts& opts, const TString& prefix) {
+ opts.AddLongOption(prefix + "worker-count")
+ .RequiredArgument("COUNT")
+ .DefaultValue(ToString(NumWorkers))
+ .StoreResult(&NumWorkers);
+}
+
+TString TBusQueueConfig::PrintToString() const {
+ TStringStream ss;
+ ss << "NumWorkers=" << NumWorkers << "\n";
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/queue_config.h b/library/cpp/messagebus/queue_config.h
new file mode 100644
index 0000000000..a9955f0c70
--- /dev/null
+++ b/library/cpp/messagebus/queue_config.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <library/cpp/getopt/last_getopt.h>
+
+namespace NBus {
+ //////////////////////////////////////////////////////////////////
+ /// \brief Configuration for message queue
+ struct TBusQueueConfig {
+ TString Name;
+ int NumWorkers; ///< number of threads calling OnMessage(), OnReply() handlers
+
+ TBusQueueConfig(); ///< initializes with default settings
+
+ void ConfigureLastGetopt(NLastGetopt::TOpts&, const TString& prefix = "mb-");
+
+ TString PrintToString() const;
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/coro.cpp b/library/cpp/messagebus/rain_check/core/coro.cpp
new file mode 100644
index 0000000000..500841dd5b
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/coro.cpp
@@ -0,0 +1,60 @@
+#include "coro.h"
+
+#include "coro_stack.h"
+
+#include <util/system/tls.h>
+#include <util/system/yassert.h>
+
+using namespace NRainCheck;
+
+TContClosure TCoroTaskRunner::ContClosure(TCoroTaskRunner* runner, TArrayRef<char> memRegion) {
+ TContClosure contClosure;
+ contClosure.TrampoLine = runner;
+ contClosure.Stack = memRegion;
+ return contClosure;
+}
+
+TCoroTaskRunner::TCoroTaskRunner(IEnv* env, ISubtaskListener* parent, TAutoPtr<ICoroTask> impl)
+ : TTaskRunnerBase(env, parent, impl.Release())
+ , Stack(GetImpl()->StackSize)
+ , ContMachineContext(ContClosure(this, Stack.MemRegion()))
+ , CoroDone(false)
+{
+}
+
+TCoroTaskRunner::~TCoroTaskRunner() {
+ Y_ASSERT(CoroDone);
+}
+
+Y_POD_STATIC_THREAD(TContMachineContext*)
+CallerContext;
+Y_POD_STATIC_THREAD(TCoroTaskRunner*)
+Task;
+
+bool TCoroTaskRunner::ReplyReceived() {
+ Y_ASSERT(!CoroDone);
+
+ TContMachineContext me;
+
+ CallerContext = &me;
+ Task = this;
+
+ me.SwitchTo(&ContMachineContext);
+
+ Stack.VerifyNoStackOverflow();
+
+ Y_ASSERT(CallerContext == &me);
+ Y_ASSERT(Task == this);
+
+ return !CoroDone;
+}
+
+void NRainCheck::TCoroTaskRunner::DoRun() {
+ GetImpl()->Run();
+ CoroDone = true;
+ ContMachineContext.SwitchTo(CallerContext);
+}
+
+void NRainCheck::ICoroTask::WaitForSubtasks() {
+ Task->ContMachineContext.SwitchTo(CallerContext);
+}
diff --git a/library/cpp/messagebus/rain_check/core/coro.h b/library/cpp/messagebus/rain_check/core/coro.h
new file mode 100644
index 0000000000..95e2a30f9b
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/coro.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include "coro_stack.h"
+#include "task.h"
+
+#include <util/generic/ptr.h>
+#include <util/memory/alloc.h>
+#include <util/system/align.h>
+#include <util/system/context.h>
+#include <util/system/valgrind.h>
+
+namespace NRainCheck {
+ class ICoroTask;
+
+ class TCoroTaskRunner: public TTaskRunnerBase, private ITrampoLine {
+ friend class ICoroTask;
+
+ private:
+ NPrivate::TCoroStack Stack;
+ TContMachineContext ContMachineContext;
+ bool CoroDone;
+
+ public:
+ TCoroTaskRunner(IEnv* env, ISubtaskListener* parent, TAutoPtr<ICoroTask> impl);
+ ~TCoroTaskRunner() override;
+
+ private:
+ static TContClosure ContClosure(TCoroTaskRunner* runner, TArrayRef<char> memRegion);
+
+ bool ReplyReceived() override /* override */;
+
+ void DoRun() override /* override */;
+
+ ICoroTask* GetImpl() {
+ return (ICoroTask*)GetImplBase();
+ }
+ };
+
+ class ICoroTask: public ITaskBase {
+ friend class TCoroTaskRunner;
+
+ private:
+ size_t StackSize;
+
+ public:
+ typedef TCoroTaskRunner TTaskRunner;
+ typedef ICoroTask ITask;
+
+ ICoroTask(size_t stackSize = 0x2000)
+ : StackSize(stackSize)
+ {
+ }
+
+ virtual void Run() = 0;
+ static void WaitForSubtasks();
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/coro_stack.cpp b/library/cpp/messagebus/rain_check/core/coro_stack.cpp
new file mode 100644
index 0000000000..83b984ca6e
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/coro_stack.cpp
@@ -0,0 +1,41 @@
+#include "coro_stack.h"
+
+#include <util/generic/singleton.h>
+#include <util/system/valgrind.h>
+
+#include <cstdlib>
+#include <stdio.h>
+
+using namespace NRainCheck;
+using namespace NRainCheck::NPrivate;
+
+TCoroStack::TCoroStack(size_t size)
+ : SizeValue(size)
+{
+ Y_VERIFY(size % sizeof(ui32) == 0);
+ Y_VERIFY(size >= 0x1000);
+
+ DataHolder.Reset(malloc(size));
+
+ // register in valgrind
+
+ *MagicNumberLocation() = MAGIC_NUMBER;
+
+#if defined(WITH_VALGRIND)
+ ValgrindStackId = VALGRIND_STACK_REGISTER(Data(), (char*)Data() + Size());
+#endif
+}
+
+TCoroStack::~TCoroStack() {
+#if defined(WITH_VALGRIND)
+ VALGRIND_STACK_DEREGISTER(ValgrindStackId);
+#endif
+
+ VerifyNoStackOverflow();
+}
+
+void TCoroStack::FailStackOverflow() {
+ static const char message[] = "stack overflow\n";
+ fputs(message, stderr);
+ abort();
+}
diff --git a/library/cpp/messagebus/rain_check/core/coro_stack.h b/library/cpp/messagebus/rain_check/core/coro_stack.h
new file mode 100644
index 0000000000..2f3520e6e4
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/coro_stack.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include <util/generic/array_ref.h>
+#include <util/generic/ptr.h>
+#include <util/system/valgrind.h>
+
+namespace NRainCheck {
+ namespace NPrivate {
+ struct TCoroStack {
+ THolder<void, TFree> DataHolder;
+ size_t SizeValue;
+
+#if defined(WITH_VALGRIND)
+ size_t ValgrindStackId;
+#endif
+
+ TCoroStack(size_t size);
+ ~TCoroStack();
+
+ void* Data() {
+ return DataHolder.Get();
+ }
+
+ size_t Size() {
+ return SizeValue;
+ }
+
+ TArrayRef<char> MemRegion() {
+ return TArrayRef((char*)Data(), Size());
+ }
+
+ ui32* MagicNumberLocation() {
+#if STACK_GROW_DOWN == 1
+ return (ui32*)Data();
+#elif STACK_GROW_DOWN == 0
+ return ((ui32*)(((char*)Data()) + Size())) - 1;
+#else
+#error "unknown"
+#endif
+ }
+
+ static void FailStackOverflow();
+
+ inline void VerifyNoStackOverflow() noexcept {
+ if (Y_UNLIKELY(*MagicNumberLocation() != MAGIC_NUMBER)) {
+ FailStackOverflow();
+ }
+ }
+
+ static const ui32 MAGIC_NUMBER = 0xAB4D15FE;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/core/coro_ut.cpp b/library/cpp/messagebus/rain_check/core/coro_ut.cpp
new file mode 100644
index 0000000000..61a33584a5
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/coro_ut.cpp
@@ -0,0 +1,106 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "coro.h"
+#include "spawn.h"
+
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+using namespace NRainCheck;
+
+Y_UNIT_TEST_SUITE(RainCheckCoro) {
+ struct TSimpleCoroTask : ICoroTask {
+ TTestSync* const TestSync;
+
+ TSimpleCoroTask(TTestEnv*, TTestSync* testSync)
+ : TestSync(testSync)
+ {
+ }
+
+ void Run() override {
+ TestSync->WaitForAndIncrement(0);
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TTestSync testSync;
+
+ TTestEnv env;
+
+ TIntrusivePtr<TCoroTaskRunner> task = env.SpawnTask<TSimpleCoroTask>(&testSync);
+ testSync.WaitForAndIncrement(1);
+ }
+
+ struct TSleepCoroTask : ICoroTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+
+ TSleepCoroTask(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ TSubtaskCompletion SleepCompletion;
+
+ void Run() override {
+ Env->SleepService.Sleep(&SleepCompletion, TDuration::MilliSeconds(1));
+ WaitForSubtasks();
+ TestSync->WaitForAndIncrement(0);
+ }
+ };
+
+ Y_UNIT_TEST(Sleep) {
+ TTestSync testSync;
+
+ TTestEnv env;
+
+ TIntrusivePtr<TCoroTaskRunner> task = env.SpawnTask<TSleepCoroTask>(&testSync);
+
+ testSync.WaitForAndIncrement(1);
+ }
+
+ struct TSubtask : ICoroTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+
+ TSubtask(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ void Run() override {
+ TestSync->CheckAndIncrement(1);
+ }
+ };
+
+ struct TSpawnCoroTask : ICoroTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+
+ TSpawnCoroTask(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ TSubtaskCompletion SubtaskCompletion;
+
+ void Run() override {
+ TestSync->CheckAndIncrement(0);
+ SpawnSubtask<TSubtask>(Env, &SubtaskCompletion, TestSync);
+ WaitForSubtasks();
+ TestSync->CheckAndIncrement(2);
+ }
+ };
+
+ Y_UNIT_TEST(Spawn) {
+ TTestSync testSync;
+
+ TTestEnv env;
+
+ TIntrusivePtr<TCoroTaskRunner> task = env.SpawnTask<TSpawnCoroTask>(&testSync);
+
+ testSync.WaitForAndIncrement(3);
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/core/env.cpp b/library/cpp/messagebus/rain_check/core/env.cpp
new file mode 100644
index 0000000000..fdc0000dbd
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/env.cpp
@@ -0,0 +1,3 @@
+#include "env.h"
+
+using namespace NRainCheck;
diff --git a/library/cpp/messagebus/rain_check/core/env.h b/library/cpp/messagebus/rain_check/core/env.h
new file mode 100644
index 0000000000..f6dd7fceb6
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/env.h
@@ -0,0 +1,47 @@
+#pragma once
+
+#include "sleep.h"
+#include "spawn.h"
+
+#include <library/cpp/messagebus/actor/executor.h>
+
+#include <util/generic/ptr.h>
+
+namespace NRainCheck {
+ struct IEnv {
+ virtual ::NActor::TExecutor* GetExecutor() = 0;
+ virtual ~IEnv() {
+ }
+ };
+
+ template <typename TSelf>
+ struct TEnvTemplate: public IEnv {
+ template <typename TTask, typename TParam>
+ TIntrusivePtr<typename TTask::TTaskRunner> SpawnTask(TParam param) {
+ return ::NRainCheck::SpawnTask<TTask, TSelf>((TSelf*)this, param);
+ }
+ };
+
+ template <typename TSelf>
+ struct TSimpleEnvTemplate: public TEnvTemplate<TSelf> {
+ ::NActor::TExecutorPtr Executor;
+ TSleepService SleepService;
+
+ TSimpleEnvTemplate(unsigned threadCount = 0)
+ : Executor(new ::NActor::TExecutor(threadCount != 0 ? threadCount : 4))
+ {
+ }
+
+ ::NActor::TExecutor* GetExecutor() override {
+ return Executor.Get();
+ }
+ };
+
+ struct TSimpleEnv: public TSimpleEnvTemplate<TSimpleEnv> {
+ TSimpleEnv(unsigned threadCount = 0)
+ : TSimpleEnvTemplate<TSimpleEnv>(threadCount)
+ {
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/fwd.h b/library/cpp/messagebus/rain_check/core/fwd.h
new file mode 100644
index 0000000000..b43ff8c17c
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/fwd.h
@@ -0,0 +1,18 @@
+#pragma once
+
+namespace NRainCheck {
+ namespace NPrivate {
+ }
+
+ class ITaskBase;
+ class ISimpleTask;
+ class ICoroTask;
+
+ struct ISubtaskListener;
+
+ class TTaskRunnerBase;
+
+ class TSubtaskCompletion;
+ struct IEnv;
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/rain_check.cpp b/library/cpp/messagebus/rain_check/core/rain_check.cpp
new file mode 100644
index 0000000000..2ea1f9e21b
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/rain_check.cpp
@@ -0,0 +1 @@
+#include "rain_check.h"
diff --git a/library/cpp/messagebus/rain_check/core/rain_check.h b/library/cpp/messagebus/rain_check/core/rain_check.h
new file mode 100644
index 0000000000..0f289717a2
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/rain_check.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include "coro.h"
+#include "env.h"
+#include "simple.h"
+#include "sleep.h"
+#include "spawn.h"
+#include "task.h"
diff --git a/library/cpp/messagebus/rain_check/core/simple.cpp b/library/cpp/messagebus/rain_check/core/simple.cpp
new file mode 100644
index 0000000000..70182b2f93
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/simple.cpp
@@ -0,0 +1,18 @@
+#include "simple.h"
+
+using namespace NRainCheck;
+
+TSimpleTaskRunner::TSimpleTaskRunner(IEnv* env, ISubtaskListener* parentTask, TAutoPtr<ISimpleTask> impl)
+ : TTaskRunnerBase(env, parentTask, impl.Release())
+ , ContinueFunc(&ISimpleTask::Start)
+{
+}
+
+TSimpleTaskRunner::~TSimpleTaskRunner() {
+ Y_ASSERT(!ContinueFunc);
+}
+
+bool TSimpleTaskRunner::ReplyReceived() {
+ ContinueFunc = (GetImpl()->*(ContinueFunc.Func))();
+ return !!ContinueFunc;
+}
diff --git a/library/cpp/messagebus/rain_check/core/simple.h b/library/cpp/messagebus/rain_check/core/simple.h
new file mode 100644
index 0000000000..20e1bf19f5
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/simple.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include "task.h"
+
+namespace NRainCheck {
+ class ISimpleTask;
+
+ // Function called on continue
+ class TContinueFunc {
+ friend class TSimpleTaskRunner;
+
+ typedef TContinueFunc (ISimpleTask::*TFunc)();
+ TFunc Func;
+
+ public:
+ TContinueFunc()
+ : Func(nullptr)
+ {
+ }
+
+ TContinueFunc(void*)
+ : Func(nullptr)
+ {
+ }
+
+ template <typename TTask>
+ TContinueFunc(TContinueFunc (TTask::*func)())
+ : Func((TFunc)func)
+ {
+ static_assert((std::is_base_of<ISimpleTask, TTask>::value), "expect (std::is_base_of<ISimpleTask, TTask>::value)");
+ }
+
+ bool operator!() const {
+ return !Func;
+ }
+ };
+
+ class TSimpleTaskRunner: public TTaskRunnerBase {
+ public:
+ TSimpleTaskRunner(IEnv* env, ISubtaskListener* parentTask, TAutoPtr<ISimpleTask>);
+ ~TSimpleTaskRunner() override;
+
+ private:
+ // Function to be called on completion of all pending tasks.
+ TContinueFunc ContinueFunc;
+
+ bool ReplyReceived() override /* override */;
+
+ ISimpleTask* GetImpl() {
+ return (ISimpleTask*)GetImplBase();
+ }
+ };
+
+ class ISimpleTask: public ITaskBase {
+ public:
+ typedef TSimpleTaskRunner TTaskRunner;
+ typedef ISimpleTask ITask;
+
+ virtual TContinueFunc Start() = 0;
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/simple_ut.cpp b/library/cpp/messagebus/rain_check/core/simple_ut.cpp
new file mode 100644
index 0000000000..d4545e05aa
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/simple_ut.cpp
@@ -0,0 +1,59 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+#include <library/cpp/messagebus/latch.h>
+
+#include <util/system/event.h>
+
+using namespace NRainCheck;
+
+Y_UNIT_TEST_SUITE(RainCheckSimple) {
+ struct TTaskWithCompletionCallback: public ISimpleTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+
+ TTaskWithCompletionCallback(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ TSubtaskCompletion SleepCompletion;
+
+ TContinueFunc Start() override {
+ TestSync->CheckAndIncrement(0);
+
+ Env->SleepService.Sleep(&SleepCompletion, TDuration::MilliSeconds(1));
+ SleepCompletion.SetCompletionCallback(&TTaskWithCompletionCallback::SleepCompletionCallback);
+
+ return &TTaskWithCompletionCallback::Last;
+ }
+
+ void SleepCompletionCallback(TSubtaskCompletion* completion) {
+ Y_VERIFY(completion == &SleepCompletion);
+ TestSync->CheckAndIncrement(1);
+
+ Env->SleepService.Sleep(&SleepCompletion, TDuration::MilliSeconds(1));
+ SleepCompletion.SetCompletionCallback(&TTaskWithCompletionCallback::NextSleepCompletionCallback);
+ }
+
+ void NextSleepCompletionCallback(TSubtaskCompletion*) {
+ TestSync->CheckAndIncrement(2);
+ }
+
+ TContinueFunc Last() {
+ TestSync->CheckAndIncrement(3);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(CompletionCallback) {
+ TTestEnv env;
+ TTestSync testSync;
+
+ env.SpawnTask<TTaskWithCompletionCallback>(&testSync);
+
+ testSync.WaitForAndIncrement(4);
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/core/sleep.cpp b/library/cpp/messagebus/rain_check/core/sleep.cpp
new file mode 100644
index 0000000000..f5d0b4cac9
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/sleep.cpp
@@ -0,0 +1,47 @@
+#include "rain_check.h"
+
+#include <util/system/yassert.h>
+
+using namespace NRainCheck;
+using namespace NRainCheck::NPrivate;
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TSleepService::TSleepService(::NBus::NPrivate::TScheduler* scheduler)
+ : Scheduler(scheduler)
+{
+}
+
+NRainCheck::TSleepService::TSleepService()
+ : SchedulerHolder(new TScheduler)
+ , Scheduler(SchedulerHolder.Get())
+{
+}
+
+NRainCheck::TSleepService::~TSleepService() {
+ if (!!SchedulerHolder) {
+ Scheduler->Stop();
+ }
+}
+
+namespace {
+ struct TSleepServiceScheduleItem: public IScheduleItem {
+ ISubtaskListener* const Parent;
+
+ TSleepServiceScheduleItem(ISubtaskListener* parent, TInstant time)
+ : IScheduleItem(time)
+ , Parent(parent)
+ {
+ }
+
+ void Do() override {
+ Parent->SetDone();
+ }
+ };
+}
+
+void TSleepService::Sleep(TSubtaskCompletion* r, TDuration duration) {
+ TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask();
+ r->SetRunning(current);
+ Scheduler->Schedule(new TSleepServiceScheduleItem(r, TInstant::Now() + duration));
+}
diff --git a/library/cpp/messagebus/rain_check/core/sleep.h b/library/cpp/messagebus/rain_check/core/sleep.h
new file mode 100644
index 0000000000..1a7a1f8674
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/sleep.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include "fwd.h"
+
+#include <library/cpp/messagebus/scheduler/scheduler.h>
+
+#include <util/datetime/base.h>
+
+namespace NRainCheck {
+ class TSleepService {
+ private:
+ THolder< ::NBus::NPrivate::TScheduler> SchedulerHolder;
+ ::NBus::NPrivate::TScheduler* const Scheduler;
+
+ public:
+ TSleepService(::NBus::NPrivate::TScheduler*);
+ TSleepService();
+ ~TSleepService();
+
+ // Wake up a task after given duration.
+ void Sleep(TSubtaskCompletion* r, TDuration);
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/sleep_ut.cpp b/library/cpp/messagebus/rain_check/core/sleep_ut.cpp
new file mode 100644
index 0000000000..2ae85a87b1
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/sleep_ut.cpp
@@ -0,0 +1,46 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+#include <util/system/event.h>
+
+using namespace NRainCheck;
+using namespace NActor;
+
+Y_UNIT_TEST_SUITE(Sleep) {
+ struct TTestTask: public ISimpleTask {
+ TSimpleEnv* const Env;
+ TTestSync* const TestSync;
+
+ TTestTask(TSimpleEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ TSubtaskCompletion Sleep;
+
+ TContinueFunc Start() override {
+ Env->SleepService.Sleep(&Sleep, TDuration::MilliSeconds(1));
+
+ TestSync->CheckAndIncrement(0);
+
+ return &TTestTask::Continue;
+ }
+
+ TContinueFunc Continue() {
+ TestSync->CheckAndIncrement(1);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Test) {
+ TTestSync testSync;
+
+ TSimpleEnv env;
+
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TTestTask>(&testSync);
+
+ testSync.WaitForAndIncrement(2);
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/core/spawn.cpp b/library/cpp/messagebus/rain_check/core/spawn.cpp
new file mode 100644
index 0000000000..c570355fbe
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/spawn.cpp
@@ -0,0 +1,5 @@
+#include "spawn.h"
+
+void NRainCheck::NPrivate::SpawnTaskImpl(TTaskRunnerBase* task) {
+ task->Schedule();
+}
diff --git a/library/cpp/messagebus/rain_check/core/spawn.h b/library/cpp/messagebus/rain_check/core/spawn.h
new file mode 100644
index 0000000000..f2b146bf29
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/spawn.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include "coro.h"
+#include "simple.h"
+#include "task.h"
+
+namespace NRainCheck {
+ namespace NPrivate {
+ void SpawnTaskImpl(TTaskRunnerBase* task);
+
+ template <typename TTask, typename ITask, typename TRunner, typename TEnv, typename TParam>
+ TIntrusivePtr<TRunner> SpawnTaskWithRunner(TEnv* env, TParam param1, ISubtaskListener* subtaskListener) {
+ static_assert((std::is_base_of<ITask, TTask>::value), "expect (std::is_base_of<ITask, TTask>::value)");
+ TIntrusivePtr<TRunner> task(new TRunner(env, subtaskListener, new TTask(env, param1)));
+ NPrivate::SpawnTaskImpl(task.Get());
+ return task;
+ }
+
+ template <typename TTask, typename ITask, typename TRunner, typename TEnv>
+ void SpawnSubtaskWithRunner(TEnv* env, TSubtaskCompletion* completion) {
+ static_assert((std::is_base_of<ITask, TTask>::value), "expect (std::is_base_of<ITask, TTask>::value)");
+ TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask();
+ completion->SetRunning(current);
+ NPrivate::SpawnTaskImpl(new TRunner(env, completion, new TTask(env)));
+ }
+
+ template <typename TTask, typename ITask, typename TRunner, typename TEnv, typename TParam>
+ void SpawnSubtaskWithRunner(TEnv* env, TSubtaskCompletion* completion, TParam param) {
+ static_assert((std::is_base_of<ITask, TTask>::value), "expect (std::is_base_of<ITask, TTask>::value)");
+ TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask();
+ completion->SetRunning(current);
+ NPrivate::SpawnTaskImpl(new TRunner(env, completion, new TTask(env, param)));
+ }
+
+ }
+
+ // Instantiate and start a task with given parameter.
+ template <typename TTask, typename TEnv, typename TParam>
+ TIntrusivePtr<typename TTask::TTaskRunner> SpawnTask(TEnv* env, TParam param1, ISubtaskListener* subtaskListener = &TNopSubtaskListener::Instance) {
+ return NPrivate::SpawnTaskWithRunner<
+ TTask, typename TTask::ITask, typename TTask::TTaskRunner, TEnv, TParam>(env, param1, subtaskListener);
+ }
+
+ // Instantiate and start subtask of given task.
+ template <typename TTask, typename TEnv, typename TParam>
+ void SpawnSubtask(TEnv* env, TSubtaskCompletion* completion, TParam param) {
+ return NPrivate::SpawnSubtaskWithRunner<TTask, typename TTask::ITask, typename TTask::TTaskRunner>(env, completion, param);
+ }
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/spawn_ut.cpp b/library/cpp/messagebus/rain_check/core/spawn_ut.cpp
new file mode 100644
index 0000000000..ba5a5e41cf
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/spawn_ut.cpp
@@ -0,0 +1,145 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/rain_check/test/helper/misc.h>
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+#include <library/cpp/messagebus/latch.h>
+
+#include <util/system/event.h>
+
+#include <array>
+
+using namespace NRainCheck;
+using namespace NActor;
+
+Y_UNIT_TEST_SUITE(Spawn) {
+ struct TTestTask: public ISimpleTask {
+ TTestSync* const TestSync;
+
+ TTestTask(TSimpleEnv*, TTestSync* testSync)
+ : TestSync(testSync)
+ , I(0)
+ {
+ }
+
+ TSystemEvent Started;
+
+ unsigned I;
+
+ TContinueFunc Start() override {
+ if (I < 4) {
+ I += 1;
+ return &TTestTask::Start;
+ }
+ TestSync->CheckAndIncrement(0);
+ return &TTestTask::Continue;
+ }
+
+ TContinueFunc Continue() {
+ TestSync->CheckAndIncrement(1);
+
+ Started.Signal();
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Continuation) {
+ TTestSync testSync;
+
+ TSimpleEnv env;
+
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TTestTask>(&testSync);
+
+ testSync.WaitForAndIncrement(2);
+ }
+
+ struct TSubtask: public ISimpleTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+
+ TSubtask(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ TContinueFunc Start() override {
+ Sleep(TDuration::MilliSeconds(1));
+ TestSync->CheckAndIncrement(1);
+ return nullptr;
+ }
+ };
+
+ struct TSpawnTask: public ISimpleTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+
+ TSpawnTask(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ {
+ }
+
+ TSubtaskCompletion SubtaskCompletion;
+
+ TContinueFunc Start() override {
+ TestSync->CheckAndIncrement(0);
+ SpawnSubtask<TSubtask>(Env, &SubtaskCompletion, TestSync);
+ return &TSpawnTask::Continue;
+ }
+
+ TContinueFunc Continue() {
+ TestSync->CheckAndIncrement(2);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Subtask) {
+ TTestSync testSync;
+
+ TTestEnv env;
+
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TSpawnTask>(&testSync);
+
+ testSync.WaitForAndIncrement(3);
+ }
+
+ struct TSpawnLongTask: public ISimpleTask {
+ TTestEnv* const Env;
+ TTestSync* const TestSync;
+ unsigned I;
+
+ TSpawnLongTask(TTestEnv* env, TTestSync* testSync)
+ : Env(env)
+ , TestSync(testSync)
+ , I(0)
+ {
+ }
+
+ std::array<TSubtaskCompletion, 3> Subtasks;
+
+ TContinueFunc Start() override {
+ if (I == 1000) {
+ TestSync->CheckAndIncrement(0);
+ return nullptr;
+ }
+
+ for (auto& subtask : Subtasks) {
+ SpawnSubtask<TNopSimpleTask>(Env, &subtask, "");
+ }
+
+ ++I;
+ return &TSpawnLongTask::Start;
+ }
+ };
+
+ Y_UNIT_TEST(SubtaskLong) {
+ TTestSync testSync;
+
+ TTestEnv env;
+
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TSpawnLongTask>(&testSync);
+
+ testSync.WaitForAndIncrement(1);
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/core/task.cpp b/library/cpp/messagebus/rain_check/core/task.cpp
new file mode 100644
index 0000000000..a098437d53
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/task.cpp
@@ -0,0 +1,216 @@
+#include "rain_check.h"
+
+#include <library/cpp/messagebus/actor/temp_tls_vector.h>
+
+#include <util/system/type_name.h>
+#include <util/system/tls.h>
+
+using namespace NRainCheck;
+using namespace NRainCheck::NPrivate;
+
+using namespace NActor;
+
+namespace {
+ Y_POD_STATIC_THREAD(TTaskRunnerBase*)
+ ThreadCurrentTask;
+}
+
+void TNopSubtaskListener::SetDone() {
+}
+
+TNopSubtaskListener TNopSubtaskListener::Instance;
+
+TTaskRunnerBase::TTaskRunnerBase(IEnv* env, ISubtaskListener* parentTask, TAutoPtr<ITaskBase> impl)
+ : TActor<TTaskRunnerBase>(env->GetExecutor())
+ , Impl(impl)
+ , ParentTask(parentTask)
+ //, HoldsSelfReference(false)
+ , Done(false)
+ , SetDoneCalled(false)
+{
+}
+
+TTaskRunnerBase::~TTaskRunnerBase() {
+ Y_ASSERT(Done);
+}
+
+namespace {
+ struct TRunningInThisThreadGuard {
+ TTaskRunnerBase* const Task;
+ TRunningInThisThreadGuard(TTaskRunnerBase* task)
+ : Task(task)
+ {
+ Y_ASSERT(!ThreadCurrentTask);
+ ThreadCurrentTask = task;
+ }
+
+ ~TRunningInThisThreadGuard() {
+ Y_ASSERT(ThreadCurrentTask == Task);
+ ThreadCurrentTask = nullptr;
+ }
+ };
+}
+
+void NRainCheck::TTaskRunnerBase::Act(NActor::TDefaultTag) {
+ Y_ASSERT(RefCount() > 0);
+
+ TRunningInThisThreadGuard g(this);
+
+ //RetainRef();
+
+ for (;;) {
+ TTempTlsVector<TSubtaskCompletion*> temp;
+
+ temp.GetVector()->swap(Pending);
+
+ for (auto& pending : *temp.GetVector()) {
+ if (pending->IsComplete()) {
+ pending->FireCompletionCallback(GetImplBase());
+ } else {
+ Pending.push_back(pending);
+ }
+ }
+
+ if (!Pending.empty()) {
+ return;
+ }
+
+ if (!Done) {
+ Done = !ReplyReceived();
+ } else {
+ if (Pending.empty()) {
+ if (!SetDoneCalled) {
+ ParentTask->SetDone();
+ SetDoneCalled = true;
+ }
+ //ReleaseRef();
+ return;
+ }
+ }
+ }
+}
+
+bool TTaskRunnerBase::IsRunningInThisThread() const {
+ return ThreadCurrentTask == this;
+}
+
+TSubtaskCompletion::~TSubtaskCompletion() {
+ ESubtaskState state = State.Get();
+ Y_ASSERT(state == CREATED || state == DONE || state == CANCELED);
+}
+
+void TSubtaskCompletion::FireCompletionCallback(ITaskBase* task) {
+ Y_ASSERT(IsComplete());
+
+ if (!!CompletionFunc) {
+ TSubtaskCompletionFunc temp = CompletionFunc;
+ // completion func must be reset before calling it,
+ // because function may set it back
+ CompletionFunc = TSubtaskCompletionFunc();
+ (task->*(temp.Func))(this);
+ }
+}
+
+void NRainCheck::TSubtaskCompletion::Cancel() {
+ for (;;) {
+ ESubtaskState state = State.Get();
+ if (state == CREATED && State.CompareAndSet(CREATED, CANCELED)) {
+ return;
+ }
+ if (state == RUNNING && State.CompareAndSet(RUNNING, CANCEL_REQUESTED)) {
+ return;
+ }
+ if (state == DONE && State.CompareAndSet(DONE, CANCELED)) {
+ return;
+ }
+ if (state == CANCEL_REQUESTED || state == CANCELED) {
+ return;
+ }
+ }
+}
+
+void TSubtaskCompletion::SetRunning(TTaskRunnerBase* parent) {
+ Y_ASSERT(!TaskRunner);
+ Y_ASSERT(!!parent);
+
+ TaskRunner = parent;
+
+ parent->Pending.push_back(this);
+
+ parent->RefV();
+
+ for (;;) {
+ ESubtaskState current = State.Get();
+ if (current != CREATED && current != DONE) {
+ Y_FAIL("current state should be CREATED or DONE: %s", ToCString(current));
+ }
+ if (State.CompareAndSet(current, RUNNING)) {
+ return;
+ }
+ }
+}
+
+void TSubtaskCompletion::SetDone() {
+ Y_ASSERT(!!TaskRunner);
+ TTaskRunnerBase* temp = TaskRunner;
+ TaskRunner = nullptr;
+
+ for (;;) {
+ ESubtaskState state = State.Get();
+ if (state == RUNNING) {
+ if (State.CompareAndSet(RUNNING, DONE)) {
+ break;
+ }
+ } else if (state == CANCEL_REQUESTED) {
+ if (State.CompareAndSet(CANCEL_REQUESTED, CANCELED)) {
+ break;
+ }
+ } else {
+ Y_FAIL("cannot SetDone: unknown state: %s", ToCString(state));
+ }
+ }
+
+ temp->ScheduleV();
+ temp->UnRefV();
+}
+
+#if 0
+void NRainCheck::TTaskRunnerBase::RetainRef()
+{
+ if (HoldsSelfReference) {
+ return;
+ }
+ HoldsSelfReference = true;
+ Ref();
+}
+
+void NRainCheck::TTaskRunnerBase::ReleaseRef()
+{
+ if (!HoldsSelfReference) {
+ return;
+ }
+ HoldsSelfReference = false;
+ DecRef();
+}
+#endif
+
+void TTaskRunnerBase::AssertInThisThread() const {
+ Y_ASSERT(IsRunningInThisThread());
+}
+
+TTaskRunnerBase* TTaskRunnerBase::CurrentTask() {
+ Y_VERIFY(!!ThreadCurrentTask);
+ return ThreadCurrentTask;
+}
+
+ITaskBase* TTaskRunnerBase::CurrentTaskImpl() {
+ return CurrentTask()->GetImplBase();
+}
+
+TString TTaskRunnerBase::GetStatusSingleLine() {
+ return TypeName(*Impl);
+}
+
+bool NRainCheck::AreWeInsideTask() {
+ return ThreadCurrentTask != nullptr;
+}
diff --git a/library/cpp/messagebus/rain_check/core/task.h b/library/cpp/messagebus/rain_check/core/task.h
new file mode 100644
index 0000000000..7d8778bcda
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/task.h
@@ -0,0 +1,184 @@
+#pragma once
+
+#include "fwd.h"
+
+#include <library/cpp/messagebus/actor/actor.h>
+#include <library/cpp/messagebus/misc/atomic_box.h>
+
+#include <library/cpp/deprecated/enum_codegen/enum_codegen.h>
+
+#include <util/generic/noncopyable.h>
+#include <util/generic/ptr.h>
+#include <util/thread/lfstack.h>
+
+namespace NRainCheck {
+ struct ISubtaskListener {
+ virtual void SetDone() = 0;
+ virtual ~ISubtaskListener() {
+ }
+ };
+
+ struct TNopSubtaskListener: public ISubtaskListener {
+ void SetDone() override;
+
+ static TNopSubtaskListener Instance;
+ };
+
+ class TSubtaskCompletionFunc {
+ friend class TSubtaskCompletion;
+
+ typedef void (ITaskBase::*TFunc)(TSubtaskCompletion*);
+ TFunc Func;
+
+ public:
+ TSubtaskCompletionFunc()
+ : Func(nullptr)
+ {
+ }
+
+ TSubtaskCompletionFunc(void*)
+ : Func(nullptr)
+ {
+ }
+
+ template <typename TTask>
+ TSubtaskCompletionFunc(void (TTask::*func)(TSubtaskCompletion*))
+ : Func((TFunc)func)
+ {
+ static_assert((std::is_base_of<ITaskBase, TTask>::value), "expect (std::is_base_of<ITaskBase, TTask>::value)");
+ }
+
+ bool operator!() const {
+ return !Func;
+ }
+ };
+
+ template <typename T>
+ class TTaskFuture;
+
+#define SUBTASK_STATE_MAP(XX) \
+ XX(CREATED, "Initial") \
+ XX(RUNNING, "Running") \
+ XX(DONE, "Completed") \
+ XX(CANCEL_REQUESTED, "Cancel requested, but still executing") \
+ XX(CANCELED, "Canceled") \
+ /**/
+
+ enum ESubtaskState {
+ SUBTASK_STATE_MAP(ENUM_VALUE_GEN_NO_VALUE)
+ };
+
+ ENUM_TO_STRING(ESubtaskState, SUBTASK_STATE_MAP)
+
+ class TSubtaskCompletion : TNonCopyable, public ISubtaskListener {
+ friend struct TTaskAccessor;
+
+ private:
+ TAtomicBox<ESubtaskState> State;
+ TTaskRunnerBase* volatile TaskRunner;
+ TSubtaskCompletionFunc CompletionFunc;
+
+ public:
+ TSubtaskCompletion()
+ : State(CREATED)
+ , TaskRunner()
+ {
+ }
+ ~TSubtaskCompletion() override;
+
+ // Either done or cancel requested or cancelled
+ bool IsComplete() const {
+ ESubtaskState state = State.Get();
+ switch (state) {
+ case RUNNING:
+ return false;
+ case DONE:
+ return true;
+ case CANCEL_REQUESTED:
+ return false;
+ case CANCELED:
+ return true;
+ case CREATED:
+ Y_FAIL("not started");
+ default:
+ Y_FAIL("unknown value: %u", (unsigned)state);
+ }
+ }
+
+ void FireCompletionCallback(ITaskBase*);
+
+ void SetCompletionCallback(TSubtaskCompletionFunc func) {
+ CompletionFunc = func;
+ }
+
+ // Completed, but not cancelled
+ bool IsDone() const {
+ return State.Get() == DONE;
+ }
+
+ // Request cancel by actor
+ // Does nothing but marks task cancelled,
+ // and allows proceeding to next callback
+ void Cancel();
+
+ // called by service provider implementations
+ // must not be called by actor
+ void SetRunning(TTaskRunnerBase* parent);
+ void SetDone() override;
+ };
+
+ // See ISimpleTask, ICoroTask
+ class TTaskRunnerBase: public TAtomicRefCount<TTaskRunnerBase>, public NActor::TActor<TTaskRunnerBase> {
+ friend class NActor::TActor<TTaskRunnerBase>;
+ friend class TContinueFunc;
+ friend struct TTaskAccessor;
+ friend class TSubtaskCompletion;
+
+ private:
+ THolder<ITaskBase> Impl;
+
+ ISubtaskListener* const ParentTask;
+ // While task is running, it holds extra reference to self.
+ //bool HoldsSelfReference;
+ bool Done;
+ bool SetDoneCalled;
+
+ // Subtasks currently executed.
+ TVector<TSubtaskCompletion*> Pending;
+
+ void Act(NActor::TDefaultTag);
+
+ public:
+ // Construct task. Task is not automatically started.
+ TTaskRunnerBase(IEnv*, ISubtaskListener* parent, TAutoPtr<ITaskBase> impl);
+ ~TTaskRunnerBase() override;
+
+ bool IsRunningInThisThread() const;
+ void AssertInThisThread() const;
+ static TTaskRunnerBase* CurrentTask();
+ static ITaskBase* CurrentTaskImpl();
+
+ TString GetStatusSingleLine();
+
+ protected:
+ //void RetainRef();
+ //void ReleaseRef();
+ ITaskBase* GetImplBase() {
+ return Impl.Get();
+ }
+
+ private:
+ // true if need to call again
+ virtual bool ReplyReceived() = 0;
+ };
+
+ class ITaskBase {
+ public:
+ virtual ~ITaskBase() {
+ }
+ };
+
+ // Check that current method executed inside some task.
+ bool AreWeInsideTask();
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/track.cpp b/library/cpp/messagebus/rain_check/core/track.cpp
new file mode 100644
index 0000000000..092a51a214
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/track.cpp
@@ -0,0 +1,66 @@
+#include "track.h"
+
+using namespace NRainCheck;
+using namespace NRainCheck::NPrivate;
+
+void TTaskTrackerReceipt::SetDone() {
+ TaskTracker->GetQueue<TTaskTrackerReceipt*>()->EnqueueAndSchedule(this);
+}
+
+TString TTaskTrackerReceipt::GetStatusSingleLine() {
+ return Task->GetStatusSingleLine();
+}
+
+TTaskTracker::TTaskTracker(NActor::TExecutor* executor)
+ : NActor::TActor<TTaskTracker>(executor)
+{
+}
+
+TTaskTracker::~TTaskTracker() {
+ Y_ASSERT(Tasks.Empty());
+}
+
+void TTaskTracker::Shutdown() {
+ ShutdownFlag.Set(true);
+ Schedule();
+ ShutdownEvent.WaitI();
+}
+
+void TTaskTracker::ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, ITaskFactory* taskFactory) {
+ THolder<ITaskFactory> holder(taskFactory);
+
+ THolder<TTaskTrackerReceipt> receipt(new TTaskTrackerReceipt(this));
+ receipt->Task = taskFactory->NewTask(receipt.Get());
+
+ Tasks.PushBack(receipt.Release());
+}
+
+void TTaskTracker::ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, TTaskTrackerReceipt* receipt) {
+ Y_ASSERT(!receipt->Empty());
+ receipt->Unlink();
+ delete receipt;
+}
+
+void TTaskTracker::ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, TAsyncResult<TTaskTrackerStatus>* status) {
+ TTaskTrackerStatus s;
+ s.Size = Tasks.Size();
+ status->SetResult(s);
+}
+
+void TTaskTracker::Act(NActor::TDefaultTag) {
+ GetQueue<TAsyncResult<TTaskTrackerStatus>*>()->DequeueAll();
+ GetQueue<ITaskFactory*>()->DequeueAll();
+ GetQueue<TTaskTrackerReceipt*>()->DequeueAll();
+
+ if (ShutdownFlag.Get()) {
+ if (Tasks.Empty()) {
+ ShutdownEvent.Signal();
+ }
+ }
+}
+
+ui32 TTaskTracker::Size() {
+ TAsyncResult<TTaskTrackerStatus> r;
+ GetQueue<TAsyncResult<TTaskTrackerStatus>*>()->EnqueueAndSchedule(&r);
+ return r.GetResult().Size;
+}
diff --git a/library/cpp/messagebus/rain_check/core/track.h b/library/cpp/messagebus/rain_check/core/track.h
new file mode 100644
index 0000000000..d387de7574
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/track.h
@@ -0,0 +1,97 @@
+#pragma once
+
+#include "spawn.h"
+#include "task.h"
+
+#include <library/cpp/messagebus/async_result.h>
+#include <library/cpp/messagebus/actor/queue_in_actor.h>
+#include <library/cpp/messagebus/misc/atomic_box.h>
+
+#include <util/generic/intrlist.h>
+#include <util/system/event.h>
+
+namespace NRainCheck {
+ class TTaskTracker;
+
+ namespace NPrivate {
+ struct ITaskFactory {
+ virtual TIntrusivePtr<TTaskRunnerBase> NewTask(ISubtaskListener*) = 0;
+ virtual ~ITaskFactory() {
+ }
+ };
+
+ struct TTaskTrackerReceipt: public ISubtaskListener, public TIntrusiveListItem<TTaskTrackerReceipt> {
+ TTaskTracker* const TaskTracker;
+ TIntrusivePtr<TTaskRunnerBase> Task;
+
+ TTaskTrackerReceipt(TTaskTracker* taskTracker)
+ : TaskTracker(taskTracker)
+ {
+ }
+
+ void SetDone() override;
+
+ TString GetStatusSingleLine();
+ };
+
+ struct TTaskTrackerStatus {
+ ui32 Size;
+ };
+
+ }
+
+ class TTaskTracker
+ : public TAtomicRefCount<TTaskTracker>,
+ public NActor::TActor<TTaskTracker>,
+ public NActor::TQueueInActor<TTaskTracker, NPrivate::ITaskFactory*>,
+ public NActor::TQueueInActor<TTaskTracker, NPrivate::TTaskTrackerReceipt*>,
+ public NActor::TQueueInActor<TTaskTracker, TAsyncResult<NPrivate::TTaskTrackerStatus>*> {
+ friend struct NPrivate::TTaskTrackerReceipt;
+
+ private:
+ TAtomicBox<bool> ShutdownFlag;
+ TSystemEvent ShutdownEvent;
+
+ TIntrusiveList<NPrivate::TTaskTrackerReceipt> Tasks;
+
+ template <typename TItem>
+ NActor::TQueueInActor<TTaskTracker, TItem>* GetQueue() {
+ return this;
+ }
+
+ public:
+ TTaskTracker(NActor::TExecutor* executor);
+ ~TTaskTracker() override;
+
+ void Shutdown();
+
+ void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, NPrivate::ITaskFactory*);
+ void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, NPrivate::TTaskTrackerReceipt*);
+ void ProcessItem(NActor::TDefaultTag, NActor::TDefaultTag, TAsyncResult<NPrivate::TTaskTrackerStatus>*);
+
+ void Act(NActor::TDefaultTag);
+
+ template <typename TTask, typename TEnv, typename TParam>
+ void Spawn(TEnv* env, TParam param) {
+ struct TTaskFactory: public NPrivate::ITaskFactory {
+ TEnv* const Env;
+ TParam Param;
+
+ TTaskFactory(TEnv* env, TParam param)
+ : Env(env)
+ , Param(param)
+ {
+ }
+
+ TIntrusivePtr<TTaskRunnerBase> NewTask(ISubtaskListener* subtaskListener) override {
+ return NRainCheck::SpawnTask<TTask>(Env, Param, subtaskListener).Get();
+ }
+ };
+
+ GetQueue<NPrivate::ITaskFactory*>()->EnqueueAndSchedule(new TTaskFactory(env, param));
+ }
+
+ ui32 Size();
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/core/track_ut.cpp b/library/cpp/messagebus/rain_check/core/track_ut.cpp
new file mode 100644
index 0000000000..05f7de1319
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/track_ut.cpp
@@ -0,0 +1,45 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "track.h"
+
+#include <library/cpp/messagebus/rain_check/test/helper/misc.h>
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+using namespace NRainCheck;
+
+Y_UNIT_TEST_SUITE(TaskTracker) {
+ struct TTaskForTracker: public ISimpleTask {
+ TTestSync* const TestSync;
+
+ TTaskForTracker(TTestEnv*, TTestSync* testSync)
+ : TestSync(testSync)
+ {
+ }
+
+ TContinueFunc Start() override {
+ TestSync->WaitForAndIncrement(0);
+ TestSync->WaitForAndIncrement(2);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TTestEnv env;
+
+ TIntrusivePtr<TTaskTracker> tracker(new TTaskTracker(env.GetExecutor()));
+
+ TTestSync testSync;
+
+ tracker->Spawn<TTaskForTracker>(&env, &testSync);
+
+ testSync.WaitFor(1);
+
+ UNIT_ASSERT_VALUES_EQUAL(1u, tracker->Size());
+
+ testSync.CheckAndIncrement(1);
+
+ testSync.WaitForAndIncrement(3);
+
+ tracker->Shutdown();
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/core/ya.make b/library/cpp/messagebus/rain_check/core/ya.make
new file mode 100644
index 0000000000..c6fb5640d4
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/core/ya.make
@@ -0,0 +1,25 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/coroutine/engine
+ library/cpp/deprecated/enum_codegen
+ library/cpp/messagebus
+ library/cpp/messagebus/actor
+ library/cpp/messagebus/scheduler
+)
+
+SRCS(
+ coro.cpp
+ coro_stack.cpp
+ env.cpp
+ rain_check.cpp
+ simple.cpp
+ sleep.cpp
+ spawn.cpp
+ task.cpp
+ track.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/rain_check/http/client.cpp b/library/cpp/messagebus/rain_check/http/client.cpp
new file mode 100644
index 0000000000..5ef5ceeece
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/http/client.cpp
@@ -0,0 +1,154 @@
+#include "client.h"
+
+#include "http_code_extractor.h"
+
+#include <library/cpp/http/io/stream.h>
+#include <library/cpp/neh/factory.h>
+#include <library/cpp/neh/http_common.h>
+#include <library/cpp/neh/location.h>
+#include <library/cpp/neh/neh.h>
+
+#include <util/generic/ptr.h>
+#include <util/generic/strbuf.h>
+#include <util/network/socket.h>
+#include <util/stream/str.h>
+
+namespace NRainCheck {
+ class THttpCallback: public NNeh::IOnRecv {
+ public:
+ THttpCallback(NRainCheck::THttpFuture* future)
+ : Future(future)
+ {
+ Y_VERIFY(!!future, "future is NULL");
+ }
+
+ void OnRecv(NNeh::THandle& handle) override {
+ THolder<THttpCallback> self(this);
+ NNeh::TResponseRef response = handle.Get();
+ Future->SetDoneAndSchedule(response);
+ }
+
+ private:
+ NRainCheck::THttpFuture* const Future;
+ };
+
+ THttpFuture::THttpFuture()
+ : Task(nullptr)
+ , ErrorCode(THttpFuture::NoError)
+ {
+ }
+
+ THttpFuture::~THttpFuture() {
+ }
+
+ bool THttpFuture::HasError() const {
+ return (ErrorCode != THttpFuture::NoError);
+ }
+
+ THttpFuture::EError THttpFuture::GetErrorCode() const {
+ return ErrorCode;
+ }
+
+ TString THttpFuture::GetErrorDescription() const {
+ return ErrorDescription;
+ }
+
+ THttpClientService::THttpClientService()
+ : GetProtocol(NNeh::ProtocolFactory()->Protocol("http"))
+ , FullProtocol(NNeh::ProtocolFactory()->Protocol("full"))
+ {
+ Y_VERIFY(!!GetProtocol, "GET protocol is NULL.");
+ Y_VERIFY(!!FullProtocol, "POST protocol is NULL.");
+ }
+
+ THttpClientService::~THttpClientService() {
+ }
+
+ void THttpClientService::SendPost(TString addr, const TString& data, const THttpHeaders& headers, THttpFuture* future) {
+ Y_VERIFY(!!future, "future is NULL.");
+
+ TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask();
+ future->SetRunning(current);
+ future->Task = current;
+
+ THolder<THttpCallback> callback(new THttpCallback(future));
+ NNeh::TServiceStatRef stat;
+ try {
+ NNeh::TMessage msg(addr.replace(0, NNeh::TParsedLocation(addr).Scheme.size(), "post"), data);
+ TStringStream headersText;
+ headers.OutTo(&headersText);
+ NNeh::NHttp::MakeFullRequest(msg, headersText.Str(), TString());
+ FullProtocol->ScheduleRequest(msg, callback.Get(), stat);
+ Y_UNUSED(callback.Release());
+ } catch (const TNetworkResolutionError& err) {
+ future->SetFail(THttpFuture::CantResolveNameError, err.AsStrBuf());
+ } catch (const yexception& err) {
+ future->SetFail(THttpFuture::OtherError, err.AsStrBuf());
+ }
+ }
+
+ void THttpClientService::Send(const TString& request, THttpFuture* future) {
+ Y_VERIFY(!!future, "future is NULL.");
+
+ TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask();
+ future->SetRunning(current);
+ future->Task = current;
+
+ THolder<THttpCallback> callback(new THttpCallback(future));
+ NNeh::TServiceStatRef stat;
+ try {
+ GetProtocol->ScheduleRequest(NNeh::TMessage::FromString(request),
+ callback.Get(),
+ stat);
+ Y_UNUSED(callback.Release());
+ } catch (const TNetworkResolutionError& err) {
+ future->SetFail(THttpFuture::CantResolveNameError, err.AsStrBuf());
+ } catch (const yexception& err) {
+ future->SetFail(THttpFuture::OtherError, err.AsStrBuf());
+ }
+ }
+
+ bool THttpFuture::HasHttpCode() const {
+ return !!HttpCode;
+ }
+
+ bool THttpFuture::HasResponseBody() const {
+ return !!Response;
+ }
+
+ ui32 THttpFuture::GetHttpCode() const {
+ Y_ASSERT(IsDone());
+ Y_ASSERT(HasHttpCode());
+
+ return static_cast<ui32>(*HttpCode);
+ }
+
+ TString THttpFuture::GetResponseBody() const {
+ Y_ASSERT(IsDone());
+ Y_ASSERT(HasResponseBody());
+
+ return Response->Data;
+ }
+
+ void THttpFuture::SetDoneAndSchedule(TAutoPtr<NNeh::TResponse> response) {
+ if (!response->IsError()) {
+ ErrorCode = THttpFuture::NoError;
+ HttpCode = HttpCodes::HTTP_OK;
+ } else {
+ ErrorCode = THttpFuture::BadHttpCodeError;
+ ErrorDescription = response->GetErrorText();
+
+ HttpCode = TryGetHttpCodeFromErrorDescription(ErrorDescription);
+ }
+ Response.Reset(response);
+ SetDone();
+ }
+
+ void THttpFuture::SetFail(THttpFuture::EError errorCode, const TStringBuf& errorDescription) {
+ ErrorCode = errorCode;
+ ErrorDescription = errorDescription;
+ Response.Destroy();
+ SetDone();
+ }
+
+}
diff --git a/library/cpp/messagebus/rain_check/http/client.h b/library/cpp/messagebus/rain_check/http/client.h
new file mode 100644
index 0000000000..d4199c4c98
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/http/client.h
@@ -0,0 +1,78 @@
+#pragma once
+
+#include <library/cpp/messagebus/rain_check/core/task.h>
+
+#include <library/cpp/http/misc/httpcodes.h>
+
+#include <util/generic/maybe.h>
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/system/defaults.h>
+#include <util/system/yassert.h>
+
+class THttpHeaders;
+
+namespace NNeh {
+ class IProtocol;
+ struct TResponse;
+}
+
+namespace NRainCheck {
+ class THttpCallback;
+ class THttpClientService;
+
+ class THttpFuture: public TSubtaskCompletion {
+ public:
+ enum EError {
+ NoError = 0,
+
+ CantResolveNameError = 1,
+ BadHttpCodeError = 2,
+
+ OtherError = 100
+ };
+
+ private:
+ friend class THttpCallback;
+ friend class THttpClientService;
+
+ public:
+ THttpFuture();
+ ~THttpFuture() override;
+
+ bool HasHttpCode() const;
+ bool HasResponseBody() const;
+
+ ui32 GetHttpCode() const;
+ TString GetResponseBody() const;
+
+ bool HasError() const;
+ EError GetErrorCode() const;
+ TString GetErrorDescription() const;
+
+ private:
+ void SetDoneAndSchedule(TAutoPtr<NNeh::TResponse> response);
+ void SetFail(EError errorCode, const TStringBuf& errorDescription);
+
+ private:
+ TTaskRunnerBase* Task;
+ TMaybe<HttpCodes> HttpCode;
+ THolder<NNeh::TResponse> Response;
+ EError ErrorCode;
+ TString ErrorDescription;
+ };
+
+ class THttpClientService {
+ public:
+ THttpClientService();
+ virtual ~THttpClientService();
+
+ void Send(const TString& request, THttpFuture* future);
+ void SendPost(TString addr, const TString& data, const THttpHeaders& headers, THttpFuture* future);
+
+ private:
+ NNeh::IProtocol* const GetProtocol;
+ NNeh::IProtocol* const FullProtocol;
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/http/client_ut.cpp b/library/cpp/messagebus/rain_check/http/client_ut.cpp
new file mode 100644
index 0000000000..1628114391
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/http/client_ut.cpp
@@ -0,0 +1,205 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "client.h"
+#include "http_code_extractor.h"
+
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+#include <library/cpp/messagebus/test/helper/fixed_port.h>
+
+#include <library/cpp/http/io/stream.h>
+#include <library/cpp/neh/rpc.h>
+
+#include <util/generic/cast.h>
+#include <util/generic/ptr.h>
+#include <util/generic/strbuf.h>
+#include <util/generic/string.h>
+#include <util/generic/vector.h>
+#include <util/network/ip.h>
+#include <util/stream/str.h>
+#include <util/string/printf.h>
+#include <util/system/defaults.h>
+#include <util/system/yassert.h>
+
+#include <cstdlib>
+#include <utility>
+
+using namespace NRainCheck;
+using namespace NBus::NTest;
+
+namespace {
+ class THttpClientEnv: public TTestEnvTemplate<THttpClientEnv> {
+ public:
+ THttpClientService HttpClientService;
+ };
+
+ const TString TEST_SERVICE = "test-service";
+ const TString TEST_GET_PARAMS = "p=GET";
+ const TString TEST_POST_PARAMS = "p=POST";
+ const TString TEST_POST_HEADERS = "Content-Type: application/json\r\n";
+ const TString TEST_GET_RECV = "GET was ok.";
+ const TString TEST_POST_RECV = "POST was ok.";
+
+ TString BuildServiceLocation(ui32 port) {
+ return Sprintf("http://*:%" PRIu32 "/%s", port, TEST_SERVICE.data());
+ }
+
+ TString BuildPostServiceLocation(ui32 port) {
+ return Sprintf("post://*:%" PRIu32 "/%s", port + 1, TEST_SERVICE.data());
+ }
+
+ TString BuildGetTestRequest(ui32 port) {
+ return BuildServiceLocation(port) + "?" + TEST_GET_PARAMS;
+ }
+
+ class TSimpleServer {
+ public:
+ inline void ServeRequest(const NNeh::IRequestRef& req) {
+ NNeh::TData response;
+ if (req->Data() == TEST_GET_PARAMS) {
+ response.assign(TEST_GET_RECV.begin(), TEST_GET_RECV.end());
+ } else {
+ response.assign(TEST_POST_RECV.begin(), TEST_POST_RECV.end());
+ }
+ req->SendReply(response);
+ }
+ };
+
+ NNeh::IServicesRef RunServer(ui32 port, TSimpleServer& server) {
+ NNeh::IServicesRef runner = NNeh::CreateLoop();
+ runner->Add(BuildServiceLocation(port), server);
+ runner->Add(BuildPostServiceLocation(port), server);
+
+ try {
+ const int THR_POOL_SIZE = 2;
+ runner->ForkLoop(THR_POOL_SIZE);
+ } catch (...) {
+ Y_FAIL("Can't run server: %s", CurrentExceptionMessage().data());
+ }
+
+ return runner;
+ }
+ enum ERequestType {
+ RT_HTTP_GET = 0,
+ RT_HTTP_POST = 1
+ };
+
+ using TTaskParam = std::pair<TIpPort, ERequestType>;
+
+ class THttpClientTask: public ISimpleTask {
+ public:
+ THttpClientTask(THttpClientEnv* env, TTaskParam param)
+ : Env(env)
+ , ServerPort(param.first)
+ , ReqType(param.second)
+ {
+ }
+
+ TContinueFunc Start() override {
+ switch (ReqType) {
+ case RT_HTTP_GET: {
+ TString getRequest = BuildGetTestRequest(ServerPort);
+ for (size_t i = 0; i < 3; ++i) {
+ Requests.push_back(new THttpFuture());
+ Env->HttpClientService.Send(getRequest, Requests[i].Get());
+ }
+ break;
+ }
+ case RT_HTTP_POST: {
+ TString servicePath = BuildPostServiceLocation(ServerPort);
+ TStringInput headersText(TEST_POST_HEADERS);
+ THttpHeaders headers(&headersText);
+ for (size_t i = 0; i < 3; ++i) {
+ Requests.push_back(new THttpFuture());
+ Env->HttpClientService.SendPost(servicePath, TEST_POST_PARAMS, headers, Requests[i].Get());
+ }
+ break;
+ }
+ }
+
+ return &THttpClientTask::GotReplies;
+ }
+
+ TContinueFunc GotReplies() {
+ const TString& TEST_OK_RECV = (ReqType == RT_HTTP_GET) ? TEST_GET_RECV : TEST_POST_RECV;
+ for (size_t i = 0; i < Requests.size(); ++i) {
+ UNIT_ASSERT_EQUAL(Requests[i]->GetHttpCode(), 200);
+ UNIT_ASSERT_EQUAL(Requests[i]->GetResponseBody(), TEST_OK_RECV);
+ }
+
+ Env->TestSync.CheckAndIncrement(0);
+
+ return nullptr;
+ }
+
+ THttpClientEnv* const Env;
+ const TIpPort ServerPort;
+ const ERequestType ReqType;
+
+ TVector<TSimpleSharedPtr<THttpFuture>> Requests;
+ };
+
+} // anonymous namespace
+
+Y_UNIT_TEST_SUITE(RainCheckHttpClient) {
+ static const TIpPort SERVER_PORT = 4000;
+
+ Y_UNIT_TEST(Simple) {
+ // TODO: randomize port
+ if (!IsFixedPortTestAllowed()) {
+ return;
+ }
+
+ TSimpleServer server;
+ NNeh::IServicesRef runner = RunServer(SERVER_PORT, server);
+
+ THttpClientEnv env;
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<THttpClientTask>(TTaskParam(SERVER_PORT, RT_HTTP_GET));
+
+ env.TestSync.WaitForAndIncrement(1);
+ }
+
+ Y_UNIT_TEST(SimplePost) {
+ // TODO: randomize port
+ if (!IsFixedPortTestAllowed()) {
+ return;
+ }
+
+ TSimpleServer server;
+ NNeh::IServicesRef runner = RunServer(SERVER_PORT, server);
+
+ THttpClientEnv env;
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<THttpClientTask>(TTaskParam(SERVER_PORT, RT_HTTP_POST));
+
+ env.TestSync.WaitForAndIncrement(1);
+ }
+
+ Y_UNIT_TEST(HttpCodeExtraction) {
+ // Find "request failed(" string, then copy len("HTTP/1.X NNN") chars and try to convert NNN to HTTP code.
+
+#define CHECK_VALID_LINE(line, code) \
+ UNIT_ASSERT_NO_EXCEPTION(TryGetHttpCodeFromErrorDescription(line)); \
+ UNIT_ASSERT(!!TryGetHttpCodeFromErrorDescription(line)); \
+ UNIT_ASSERT_EQUAL(*TryGetHttpCodeFromErrorDescription(line), code)
+
+ CHECK_VALID_LINE(TStringBuf("library/cpp/neh/http.cpp:<LINE>: request failed(HTTP/1.0 200 Some random message"), 200);
+ CHECK_VALID_LINE(TStringBuf("library/cpp/neh/http.cpp:<LINE>: request failed(HTTP/1.0 404 Some random message"), 404);
+ CHECK_VALID_LINE(TStringBuf("request failed(HTTP/1.0 100 Some random message"), 100);
+ CHECK_VALID_LINE(TStringBuf("request failed(HTTP/1.0 105)"), 105);
+ CHECK_VALID_LINE(TStringBuf("request failed(HTTP/1.1 2004 Some random message"), 200);
+#undef CHECK_VALID_LINE
+
+#define CHECK_INVALID_LINE(line) \
+ UNIT_ASSERT_NO_EXCEPTION(TryGetHttpCodeFromErrorDescription(line)); \
+ UNIT_ASSERT(!TryGetHttpCodeFromErrorDescription(line))
+
+ CHECK_INVALID_LINE(TStringBuf("library/cpp/neh/http.cpp:<LINE>: request failed(HTTP/1.1 1 Some random message"));
+ CHECK_INVALID_LINE(TStringBuf("request failed(HTTP/1.0 asdf Some random message"));
+ CHECK_INVALID_LINE(TStringBuf("HTTP/1.0 200 Some random message"));
+ CHECK_INVALID_LINE(TStringBuf("request failed(HTTP/1.0 2x00 Some random message"));
+ CHECK_INVALID_LINE(TStringBuf("HTTP/1.0 200 Some random message"));
+ CHECK_INVALID_LINE(TStringBuf("HTTP/1.0 200"));
+ CHECK_INVALID_LINE(TStringBuf("request failed(HTTP/1.1 3334 Some random message"));
+#undef CHECK_INVALID_LINE
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/http/http_code_extractor.cpp b/library/cpp/messagebus/rain_check/http/http_code_extractor.cpp
new file mode 100644
index 0000000000..51d75762f6
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/http/http_code_extractor.cpp
@@ -0,0 +1,39 @@
+#include "http_code_extractor.h"
+
+#include <library/cpp/http/io/stream.h>
+#include <library/cpp/http/misc/httpcodes.h>
+
+#include <util/generic/maybe.h>
+#include <util/generic/strbuf.h>
+#include <util/string/cast.h>
+
+namespace NRainCheck {
+ TMaybe<HttpCodes> TryGetHttpCodeFromErrorDescription(const TStringBuf& errorMessage) {
+ // Try to get HttpCode from library/cpp/neh response.
+ // If response has HttpCode and it is not 200 OK, library/cpp/neh will send a message
+ // "library/cpp/neh/http.cpp:<LINE>: request failed(<FIRST-HTTP-RESPONSE-LINE>)"
+ // (see library/cpp/neh/http.cpp:625). So, we will try to parse this message and
+ // find out HttpCode in it. It is bad temporary solution, but we have no choice.
+ const TStringBuf SUBSTR = "request failed(";
+ const size_t SUBSTR_LEN = SUBSTR.size();
+ const size_t FIRST_LINE_LEN = TStringBuf("HTTP/1.X NNN").size();
+
+ TMaybe<HttpCodes> httpCode;
+
+ const size_t substrPos = errorMessage.find(SUBSTR);
+ if (substrPos != TStringBuf::npos) {
+ const TStringBuf firstLineStart = errorMessage.SubStr(substrPos + SUBSTR_LEN, FIRST_LINE_LEN);
+ try {
+ httpCode = static_cast<HttpCodes>(ParseHttpRetCode(firstLineStart));
+ if (*httpCode < HTTP_CONTINUE || *httpCode >= HTTP_CODE_MAX) {
+ httpCode = Nothing();
+ }
+ } catch (const TFromStringException& ex) {
+ // Can't parse HttpCode: it is OK, because ErrorDescription can be random string.
+ }
+ }
+
+ return httpCode;
+ }
+
+}
diff --git a/library/cpp/messagebus/rain_check/http/http_code_extractor.h b/library/cpp/messagebus/rain_check/http/http_code_extractor.h
new file mode 100644
index 0000000000..33b565fa1c
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/http/http_code_extractor.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <library/cpp/http/misc/httpcodes.h>
+
+#include <util/generic/maybe.h>
+#include <util/generic/strbuf.h>
+
+namespace NRainCheck {
+ // Try to get HttpCode from library/cpp/neh response.
+ // If response has HttpCode and it is not 200 OK, library/cpp/neh will send a message
+ // "library/cpp/neh/http.cpp:<LINE>: request failed(<FIRST-HTTP-RESPONSE-LINE>)"
+ // (see library/cpp/neh/http.cpp:625). So, we will try to parse this message and
+ // find out HttpCode in it. It is bad temporary solution, but we have no choice.
+ TMaybe<HttpCodes> TryGetHttpCodeFromErrorDescription(const TStringBuf& errorMessage);
+
+}
diff --git a/library/cpp/messagebus/rain_check/http/ya.make b/library/cpp/messagebus/rain_check/http/ya.make
new file mode 100644
index 0000000000..ef13329df3
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/http/ya.make
@@ -0,0 +1,17 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+SRCS(
+ client.cpp
+ http_code_extractor.cpp
+)
+
+PEERDIR(
+ library/cpp/messagebus/rain_check/core
+ library/cpp/neh
+ library/cpp/http/misc
+ library/cpp/http/io
+)
+
+END()
diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp
new file mode 100644
index 0000000000..daac8d9a99
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.cpp
@@ -0,0 +1,98 @@
+#include "messagebus_client.h"
+
+using namespace NRainCheck;
+using namespace NBus;
+
+TBusClientService::TBusClientService(
+ const NBus::TBusSessionConfig& config,
+ NBus::TBusProtocol* proto,
+ NBus::TBusMessageQueue* queue) {
+ Session = queue->CreateSource(proto, this, config);
+}
+
+TBusClientService::~TBusClientService() {
+ Session->Shutdown();
+}
+
+void TBusClientService::SendCommon(NBus::TBusMessage* message, const NBus::TNetAddr&, TBusFuture* future) {
+ TTaskRunnerBase* current = TTaskRunnerBase::CurrentTask();
+
+ future->SetRunning(current);
+
+ future->Task = current;
+
+ // after this statement message is owned by both messagebus and future
+ future->Request.Reset(message);
+
+ // TODO: allow cookie in messagebus
+ message->Data = future;
+}
+
+void TBusClientService::ProcessResultCommon(NBus::TBusMessageAutoPtr message,
+ const NBus::TNetAddr&, TBusFuture* future,
+ NBus::EMessageStatus status) {
+ Y_UNUSED(message.Release());
+
+ if (status == NBus::MESSAGE_OK) {
+ return;
+ }
+
+ future->SetDoneAndSchedule(status, nullptr);
+}
+
+void TBusClientService::SendOneWay(
+ NBus::TBusMessageAutoPtr message, const NBus::TNetAddr& addr,
+ TBusFuture* future) {
+ SendCommon(message.Get(), addr, future);
+
+ EMessageStatus ok = Session->SendMessageOneWay(message.Get(), &addr, false);
+ ProcessResultCommon(message, addr, future, ok);
+}
+
+NBus::TBusClientSessionPtr TBusClientService::GetSessionForMonitoring() const {
+ return Session;
+}
+
+void TBusClientService::Send(
+ TBusMessageAutoPtr message, const TNetAddr& addr,
+ TBusFuture* future) {
+ SendCommon(message.Get(), addr, future);
+
+ EMessageStatus ok = Session->SendMessage(message.Get(), &addr, false);
+ ProcessResultCommon(message, addr, future, ok);
+}
+
+void TBusClientService::OnReply(
+ TAutoPtr<TBusMessage> request,
+ TAutoPtr<TBusMessage> response) {
+ TBusFuture* future = (TBusFuture*)request->Data;
+ Y_ASSERT(future->Request.Get() == request.Get());
+ Y_UNUSED(request.Release());
+ future->SetDoneAndSchedule(MESSAGE_OK, response);
+}
+
+void NRainCheck::TBusClientService::OnMessageSentOneWay(
+ TAutoPtr<NBus::TBusMessage> request) {
+ TBusFuture* future = (TBusFuture*)request->Data;
+ Y_ASSERT(future->Request.Get() == request.Get());
+ Y_UNUSED(request.Release());
+ future->SetDoneAndSchedule(MESSAGE_OK, nullptr);
+}
+
+void TBusClientService::OnError(
+ TAutoPtr<TBusMessage> message, NBus::EMessageStatus status) {
+ if (message->Data == nullptr) {
+ return;
+ }
+
+ TBusFuture* future = (TBusFuture*)message->Data;
+ Y_ASSERT(future->Request.Get() == message.Get());
+ Y_UNUSED(message.Release());
+ future->SetDoneAndSchedule(status, nullptr);
+}
+
+void TBusFuture::SetDoneAndSchedule(EMessageStatus status, TAutoPtr<TBusMessage> response) {
+ Status = status;
+ Response.Reset(response.Release());
+ SetDone();
+}
diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_client.h b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.h
new file mode 100644
index 0000000000..0a291cdea6
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_client.h
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <library/cpp/messagebus/rain_check/core/task.h>
+
+#include <library/cpp/messagebus/ybus.h>
+
+namespace NRainCheck {
+ class TBusFuture: public TSubtaskCompletion {
+ friend class TBusClientService;
+
+ private:
+ THolder<NBus::TBusMessage> Request;
+ THolder<NBus::TBusMessage> Response;
+ NBus::EMessageStatus Status;
+
+ private:
+ TTaskRunnerBase* Task;
+
+ void SetDoneAndSchedule(NBus::EMessageStatus, TAutoPtr<NBus::TBusMessage>);
+
+ public:
+ // TODO: add MESSAGE_UNDEFINED
+ TBusFuture()
+ : Status(NBus::MESSAGE_DONT_ASK)
+ , Task(nullptr)
+ {
+ }
+
+ NBus::TBusMessage* GetRequest() const {
+ return Request.Get();
+ }
+
+ NBus::TBusMessage* GetResponse() const {
+ Y_ASSERT(IsDone());
+ return Response.Get();
+ }
+
+ NBus::EMessageStatus GetStatus() const {
+ Y_ASSERT(IsDone());
+ return Status;
+ }
+ };
+
+ class TBusClientService: private NBus::IBusClientHandler {
+ private:
+ NBus::TBusClientSessionPtr Session;
+
+ public:
+ TBusClientService(const NBus::TBusSessionConfig&, NBus::TBusProtocol*, NBus::TBusMessageQueue*);
+ ~TBusClientService() override;
+
+ void Send(NBus::TBusMessageAutoPtr, const NBus::TNetAddr&, TBusFuture* future);
+ void SendOneWay(NBus::TBusMessageAutoPtr, const NBus::TNetAddr&, TBusFuture* future);
+
+ // Use it only for monitoring
+ NBus::TBusClientSessionPtr GetSessionForMonitoring() const;
+
+ private:
+ void SendCommon(NBus::TBusMessage*, const NBus::TNetAddr&, TBusFuture* future);
+ void ProcessResultCommon(NBus::TBusMessageAutoPtr, const NBus::TNetAddr&, TBusFuture* future, NBus::EMessageStatus);
+
+ void OnReply(TAutoPtr<NBus::TBusMessage> pMessage, TAutoPtr<NBus::TBusMessage> pReply) override;
+ void OnError(TAutoPtr<NBus::TBusMessage> pMessage, NBus::EMessageStatus status) override;
+ void OnMessageSentOneWay(TAutoPtr<NBus::TBusMessage>) override;
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp
new file mode 100644
index 0000000000..1b3618558b
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_client_ut.cpp
@@ -0,0 +1,146 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "messagebus_client.h"
+
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+
+#include <util/generic/cast.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+using namespace NRainCheck;
+
+struct TMessageBusClientEnv: public TTestEnvTemplate<TMessageBusClientEnv> {
+ // TODO: use same thread pool
+ TBusMessageQueuePtr Queue;
+ TExampleProtocol Proto;
+ TBusClientService BusClientService;
+
+ static TBusQueueConfig QueueConfig() {
+ TBusQueueConfig r;
+ r.NumWorkers = 4;
+ return r;
+ }
+
+ TMessageBusClientEnv()
+ : Queue(CreateMessageQueue(GetExecutor()))
+ , BusClientService(TBusSessionConfig(), &Proto, Queue.Get())
+ {
+ }
+};
+
+Y_UNIT_TEST_SUITE(RainCheckMessageBusClient) {
+ struct TSimpleTask: public ISimpleTask {
+ TMessageBusClientEnv* const Env;
+
+ const unsigned ServerPort;
+
+ TSimpleTask(TMessageBusClientEnv* env, unsigned serverPort)
+ : Env(env)
+ , ServerPort(serverPort)
+ {
+ }
+
+ TVector<TSimpleSharedPtr<TBusFuture>> Requests;
+
+ TContinueFunc Start() override {
+ for (unsigned i = 0; i < 3; ++i) {
+ Requests.push_back(new TBusFuture);
+ TNetAddr addr("localhost", ServerPort);
+ Env->BusClientService.Send(new TExampleRequest(&Env->Proto.RequestCount), addr, Requests[i].Get());
+ }
+
+ return TContinueFunc(&TSimpleTask::GotReplies);
+ }
+
+ TContinueFunc GotReplies() {
+ for (unsigned i = 0; i < Requests.size(); ++i) {
+ Y_VERIFY(Requests[i]->GetStatus() == MESSAGE_OK);
+ VerifyDynamicCast<TExampleResponse*>(Requests[i]->GetResponse());
+ }
+ Env->TestSync.CheckAndIncrement(0);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ TMessageBusClientEnv env;
+
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TSimpleTask>(server.GetActualListenPort());
+
+ env.TestSync.WaitForAndIncrement(1);
+ }
+
+ struct TOneWayServer: public NBus::IBusServerHandler {
+ TTestSync* const TestSync;
+ TExampleProtocol Proto;
+ NBus::TBusMessageQueuePtr Queue;
+ NBus::TBusServerSessionPtr Session;
+
+ TOneWayServer(TTestSync* testSync)
+ : TestSync(testSync)
+ {
+ Queue = CreateMessageQueue();
+ Session = Queue->CreateDestination(&Proto, this, NBus::TBusSessionConfig());
+ }
+
+ void OnMessage(NBus::TOnMessageContext& context) override {
+ TestSync->CheckAndIncrement(1);
+ context.ForgetRequest();
+ }
+ };
+
+ struct TOneWayTask: public ISimpleTask {
+ TMessageBusClientEnv* const Env;
+
+ const unsigned ServerPort;
+
+ TOneWayTask(TMessageBusClientEnv* env, unsigned serverPort)
+ : Env(env)
+ , ServerPort(serverPort)
+ {
+ }
+
+ TVector<TSimpleSharedPtr<TBusFuture>> Requests;
+
+ TContinueFunc Start() override {
+ Env->TestSync.CheckAndIncrement(0);
+
+ for (unsigned i = 0; i < 1; ++i) {
+ Requests.push_back(new TBusFuture);
+ TNetAddr addr("localhost", ServerPort);
+ Env->BusClientService.SendOneWay(new TExampleRequest(&Env->Proto.RequestCount), addr, Requests[i].Get());
+ }
+
+ return TContinueFunc(&TOneWayTask::GotReplies);
+ }
+
+ TContinueFunc GotReplies() {
+ for (unsigned i = 0; i < Requests.size(); ++i) {
+ Y_VERIFY(Requests[i]->GetStatus() == MESSAGE_OK);
+ Y_VERIFY(!Requests[i]->GetResponse());
+ }
+ Env->TestSync.WaitForAndIncrement(2);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(OneWay) {
+ TObjectCountCheck objectCountCheck;
+
+ TMessageBusClientEnv env;
+
+ TOneWayServer server(&env.TestSync);
+
+ TIntrusivePtr<TSimpleTaskRunner> task = env.SpawnTask<TOneWayTask>(server.Session->GetActualListenPort());
+
+ env.TestSync.WaitForAndIncrement(3);
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp
new file mode 100644
index 0000000000..5d4b13d664
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.cpp
@@ -0,0 +1,17 @@
+#include "messagebus_server.h"
+
+#include <library/cpp/messagebus/rain_check/core/spawn.h>
+
+using namespace NRainCheck;
+
+TBusTaskStarter::TBusTaskStarter(TAutoPtr<ITaskFactory> taskFactory)
+ : TaskFactory(taskFactory)
+{
+}
+
+void TBusTaskStarter::OnMessage(NBus::TOnMessageContext& onMessage) {
+ TaskFactory->NewTask(onMessage);
+}
+
+TBusTaskStarter::~TBusTaskStarter() {
+}
diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_server.h b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.h
new file mode 100644
index 0000000000..1334f05fe4
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_server.h
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <library/cpp/messagebus/rain_check/core/spawn.h>
+#include <library/cpp/messagebus/rain_check/core/task.h>
+
+#include <library/cpp/messagebus/ybus.h>
+
+#include <util/system/yassert.h>
+
+namespace NRainCheck {
+ class TBusTaskStarter: public NBus::IBusServerHandler {
+ private:
+ struct ITaskFactory {
+ virtual void NewTask(NBus::TOnMessageContext&) = 0;
+ virtual ~ITaskFactory() {
+ }
+ };
+
+ THolder<ITaskFactory> TaskFactory;
+
+ void OnMessage(NBus::TOnMessageContext&) override;
+
+ public:
+ TBusTaskStarter(TAutoPtr<ITaskFactory>);
+ ~TBusTaskStarter() override;
+
+ public:
+ template <typename TTask, typename TEnv>
+ static TAutoPtr<TBusTaskStarter> NewStarter(TEnv* env) {
+ struct TTaskFactory: public ITaskFactory {
+ TEnv* const Env;
+
+ TTaskFactory(TEnv* env)
+ : Env(env)
+ {
+ }
+
+ void NewTask(NBus::TOnMessageContext& context) override {
+ SpawnTask<TTask, TEnv, NBus::TOnMessageContext&>(Env, context);
+ }
+ };
+
+ return new TBusTaskStarter(new TTaskFactory(env));
+ }
+ };
+}
diff --git a/library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp b/library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp
new file mode 100644
index 0000000000..7c11399f1b
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/messagebus_server_ut.cpp
@@ -0,0 +1,51 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "messagebus_server.h"
+
+#include <library/cpp/messagebus/rain_check/test/ut/test.h>
+
+#include <library/cpp/messagebus/test/helper/example.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+using namespace NRainCheck;
+
+struct TMessageBusServerEnv: public TTestEnvTemplate<TMessageBusServerEnv> {
+ TExampleProtocol Proto;
+};
+
+Y_UNIT_TEST_SUITE(RainCheckMessageBusServer) {
+ struct TSimpleServerTask: public ISimpleTask {
+ private:
+ TMessageBusServerEnv* const Env;
+ TOnMessageContext MessageContext;
+
+ public:
+ TSimpleServerTask(TMessageBusServerEnv* env, TOnMessageContext& messageContext)
+ : Env(env)
+ {
+ MessageContext.Swap(messageContext);
+ }
+
+ TContinueFunc Start() override {
+ MessageContext.SendReplyMove(new TExampleResponse(&Env->Proto.ResponseCount));
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TMessageBusServerEnv env;
+
+ THolder<TBusTaskStarter> starter(TBusTaskStarter::NewStarter<TSimpleServerTask>(&env));
+
+ TBusMessageQueuePtr queue(CreateMessageQueue(env.GetExecutor()));
+
+ TExampleProtocol proto;
+
+ TBusServerSessionPtr session = queue->CreateDestination(&env.Proto, starter.Get(), TBusSessionConfig());
+
+ TExampleClient client;
+
+ client.SendMessagesWaitReplies(1, TNetAddr("localhost", session->GetActualListenPort()));
+ }
+}
diff --git a/library/cpp/messagebus/rain_check/messagebus/ya.make b/library/cpp/messagebus/rain_check/messagebus/ya.make
new file mode 100644
index 0000000000..defdac9a61
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/messagebus/ya.make
@@ -0,0 +1,15 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus
+ library/cpp/messagebus/rain_check/core
+)
+
+SRCS(
+ messagebus_client.cpp
+ messagebus_server.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/rain_check/test/TestRainCheck.py b/library/cpp/messagebus/rain_check/test/TestRainCheck.py
new file mode 100644
index 0000000000..92ed727b62
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/TestRainCheck.py
@@ -0,0 +1,8 @@
+from devtools.fleur.ytest import group, constraint
+from devtools.fleur.ytest.integration import UnitTestGroup
+
+@group
+@constraint('library.messagebus')
+class TestMessageBus3(UnitTestGroup):
+ def __init__(self, context):
+ UnitTestGroup.__init__(self, context, 'MessageBus', 'library-messagebus-rain_check-test-ut')
diff --git a/library/cpp/messagebus/rain_check/test/helper/misc.cpp b/library/cpp/messagebus/rain_check/test/helper/misc.cpp
new file mode 100644
index 0000000000..c0fcb27252
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/helper/misc.cpp
@@ -0,0 +1,27 @@
+#include "misc.h"
+
+#include <util/system/yassert.h>
+
+using namespace NRainCheck;
+
+void TSpawnNopTasksCoroTask::Run() {
+ Y_VERIFY(Count <= Completion.size());
+ for (unsigned i = 0; i < Count; ++i) {
+ SpawnSubtask<TNopCoroTask>(Env, &Completion[i], "");
+ }
+
+ WaitForSubtasks();
+}
+
+TContinueFunc TSpawnNopTasksSimpleTask::Start() {
+ Y_VERIFY(Count <= Completion.size());
+ for (unsigned i = 0; i < Count; ++i) {
+ SpawnSubtask<TNopSimpleTask>(Env, &Completion[i], "");
+ }
+
+ return &TSpawnNopTasksSimpleTask::Join;
+}
+
+TContinueFunc TSpawnNopTasksSimpleTask::Join() {
+ return nullptr;
+}
diff --git a/library/cpp/messagebus/rain_check/test/helper/misc.h b/library/cpp/messagebus/rain_check/test/helper/misc.h
new file mode 100644
index 0000000000..9150be4d2f
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/helper/misc.h
@@ -0,0 +1,57 @@
+#pragma once
+
+#include <library/cpp/messagebus/rain_check/core/rain_check.h>
+
+#include <array>
+
+namespace NRainCheck {
+ struct TNopSimpleTask: public ISimpleTask {
+ TNopSimpleTask(IEnv*, const void*) {
+ }
+
+ TContinueFunc Start() override {
+ return nullptr;
+ }
+ };
+
+ struct TNopCoroTask: public ICoroTask {
+ TNopCoroTask(IEnv*, const void*) {
+ }
+
+ void Run() override {
+ }
+ };
+
+ struct TSpawnNopTasksCoroTask: public ICoroTask {
+ IEnv* const Env;
+ unsigned const Count;
+
+ TSpawnNopTasksCoroTask(IEnv* env, unsigned count)
+ : Env(env)
+ , Count(count)
+ {
+ }
+
+ std::array<TSubtaskCompletion, 2> Completion;
+
+ void Run() override;
+ };
+
+ struct TSpawnNopTasksSimpleTask: public ISimpleTask {
+ IEnv* const Env;
+ unsigned const Count;
+
+ TSpawnNopTasksSimpleTask(IEnv* env, unsigned count)
+ : Env(env)
+ , Count(count)
+ {
+ }
+
+ std::array<TSubtaskCompletion, 2> Completion;
+
+ TContinueFunc Start() override;
+
+ TContinueFunc Join();
+ };
+
+}
diff --git a/library/cpp/messagebus/rain_check/test/helper/ya.make b/library/cpp/messagebus/rain_check/test/helper/ya.make
new file mode 100644
index 0000000000..aa9e4e6d81
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/helper/ya.make
@@ -0,0 +1,13 @@
+LIBRARY(messagebus-rain_check-test-helper)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus/rain_check/core
+)
+
+SRCS(
+ misc.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/rain_check/test/perftest/perftest.cpp b/library/cpp/messagebus/rain_check/test/perftest/perftest.cpp
new file mode 100644
index 0000000000..22edbd8c6b
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/perftest/perftest.cpp
@@ -0,0 +1,154 @@
+#include <library/cpp/messagebus/rain_check/test/helper/misc.h>
+
+#include <library/cpp/messagebus/rain_check/core/rain_check.h>
+
+#include <util/datetime/base.h>
+
+#include <array>
+
+using namespace NRainCheck;
+
+static const unsigned SUBTASKS = 2;
+
+struct TRainCheckPerftestEnv: public TSimpleEnvTemplate<TRainCheckPerftestEnv> {
+ unsigned SubtasksPerTask;
+
+ TRainCheckPerftestEnv()
+ : TSimpleEnvTemplate<TRainCheckPerftestEnv>(4)
+ , SubtasksPerTask(1000)
+ {
+ }
+};
+
+struct TCoroOuter: public ICoroTask {
+ TRainCheckPerftestEnv* const Env;
+
+ TCoroOuter(TRainCheckPerftestEnv* env)
+ : Env(env)
+ {
+ }
+
+ void Run() override {
+ for (;;) {
+ TInstant start = TInstant::Now();
+
+ unsigned count = 0;
+
+ unsigned current = 1000;
+
+ do {
+ for (unsigned i = 0; i < current; ++i) {
+ std::array<TSubtaskCompletion, SUBTASKS> completion;
+
+ for (unsigned j = 0; j < SUBTASKS; ++j) {
+ //SpawnSubtask<TNopSimpleTask>(Env, &completion[j]);
+ //SpawnSubtask<TSpawnNopTasksCoroTask>(Env, &completion[j], SUBTASKS);
+ SpawnSubtask<TSpawnNopTasksSimpleTask>(Env, &completion[j], SUBTASKS);
+ }
+
+ WaitForSubtasks();
+ }
+
+ count += current;
+ current *= 2;
+ } while (TInstant::Now() - start < TDuration::Seconds(1));
+
+ TDuration d = TInstant::Now() - start;
+ unsigned dns = d.NanoSeconds() / count;
+ Cerr << dns << "ns per spawn/join\n";
+ }
+ }
+};
+
+struct TSimpleOuter: public ISimpleTask {
+ TRainCheckPerftestEnv* const Env;
+
+ TSimpleOuter(TRainCheckPerftestEnv* env, const void*)
+ : Env(env)
+ {
+ }
+
+ TInstant StartInstant;
+ unsigned Count;
+ unsigned Current;
+ unsigned I;
+
+ TContinueFunc Start() override {
+ StartInstant = TInstant::Now();
+ Count = 0;
+ Current = 1000;
+ I = 0;
+
+ return &TSimpleOuter::Spawn;
+ }
+
+ std::array<TSubtaskCompletion, SUBTASKS> Completion;
+
+ TContinueFunc Spawn() {
+ for (unsigned j = 0; j < SUBTASKS; ++j) {
+ //SpawnSubtask<TNopSimpleTask>(Env, &Completion[j]);
+ //SpawnSubtask<TSpawnNopTasksCoroTask>(Env, &Completion[j], SUBTASKS);
+ SpawnSubtask<TSpawnNopTasksSimpleTask>(Env, &Completion[j], SUBTASKS);
+ }
+
+ return &TSimpleOuter::Join;
+ }
+
+ TContinueFunc Join() {
+ I += 1;
+ if (I != Current) {
+ return &TSimpleOuter::Spawn;
+ }
+
+ I = 0;
+ Count += Current;
+ Current *= 2;
+
+ TDuration d = TInstant::Now() - StartInstant;
+ if (d < TDuration::Seconds(1)) {
+ return &TSimpleOuter::Spawn;
+ }
+
+ unsigned dns = d.NanoSeconds() / Count;
+ Cerr << dns << "ns per spawn/join\n";
+
+ return &TSimpleOuter::Start;
+ }
+};
+
+struct TReproduceCrashTask: public ISimpleTask {
+ TRainCheckPerftestEnv* const Env;
+
+ TReproduceCrashTask(TRainCheckPerftestEnv* env)
+ : Env(env)
+ {
+ }
+
+ std::array<TSubtaskCompletion, SUBTASKS> Completion;
+
+ TContinueFunc Start() override {
+ for (unsigned j = 0; j < 2; ++j) {
+ //SpawnSubtask<TNopSimpleTask>(Env, &Completion[j]);
+ SpawnSubtask<TSpawnNopTasksSimpleTask>(Env, &Completion[j], SUBTASKS);
+ }
+
+ return &TReproduceCrashTask::Start;
+ }
+};
+
+int main(int argc, char** argv) {
+ Y_UNUSED(argc);
+ Y_UNUSED(argv);
+
+ TRainCheckPerftestEnv env;
+
+ env.SpawnTask<TSimpleOuter>("");
+ //env.SpawnTask<TCoroOuter>();
+ //env.SpawnTask<TReproduceCrashTask>();
+
+ for (;;) {
+ Sleep(TDuration::Hours(1));
+ }
+
+ return 0;
+}
diff --git a/library/cpp/messagebus/rain_check/test/perftest/ya.make b/library/cpp/messagebus/rain_check/test/perftest/ya.make
new file mode 100644
index 0000000000..7330a71700
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/perftest/ya.make
@@ -0,0 +1,14 @@
+PROGRAM(messagebus_rain_check_perftest)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus/rain_check/core
+ library/cpp/messagebus/rain_check/test/helper
+)
+
+SRCS(
+ perftest.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/rain_check/test/ut/test.h b/library/cpp/messagebus/rain_check/test/ut/test.h
new file mode 100644
index 0000000000..724f6b7530
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/ut/test.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <library/cpp/messagebus/rain_check/core/rain_check.h>
+
+#include <library/cpp/messagebus/misc/test_sync.h>
+
+template <typename TSelf>
+struct TTestEnvTemplate: public NRainCheck::TSimpleEnvTemplate<TSelf> {
+ TTestSync TestSync;
+};
+
+struct TTestEnv: public TTestEnvTemplate<TTestEnv> {
+};
diff --git a/library/cpp/messagebus/rain_check/test/ut/ya.make b/library/cpp/messagebus/rain_check/test/ut/ya.make
new file mode 100644
index 0000000000..9f7a93417a
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/ut/ya.make
@@ -0,0 +1,24 @@
+PROGRAM(library-messagebus-rain_check-test-ut)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/testing/unittest_main
+ library/cpp/messagebus/rain_check/core
+ library/cpp/messagebus/rain_check/http
+ library/cpp/messagebus/rain_check/messagebus
+ library/cpp/messagebus/test/helper
+)
+
+SRCS(
+ ../../core/coro_ut.cpp
+ ../../core/simple_ut.cpp
+ ../../core/sleep_ut.cpp
+ ../../core/spawn_ut.cpp
+ ../../core/track_ut.cpp
+ ../../http/client_ut.cpp
+ ../../messagebus/messagebus_client_ut.cpp
+ ../../messagebus/messagebus_server_ut.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/rain_check/test/ya.make b/library/cpp/messagebus/rain_check/test/ya.make
new file mode 100644
index 0000000000..4c1d6f8161
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/test/ya.make
@@ -0,0 +1,6 @@
+OWNER(g:messagebus)
+
+RECURSE(
+ perftest
+ ut
+)
diff --git a/library/cpp/messagebus/rain_check/ya.make b/library/cpp/messagebus/rain_check/ya.make
new file mode 100644
index 0000000000..966d54c232
--- /dev/null
+++ b/library/cpp/messagebus/rain_check/ya.make
@@ -0,0 +1,8 @@
+OWNER(g:messagebus)
+
+RECURSE(
+ core
+ http
+ messagebus
+ test
+)
diff --git a/library/cpp/messagebus/ref_counted.h b/library/cpp/messagebus/ref_counted.h
new file mode 100644
index 0000000000..29b87764e3
--- /dev/null
+++ b/library/cpp/messagebus/ref_counted.h
@@ -0,0 +1,6 @@
+#pragma once
+
+class TAtomicRefCountedObject: public TAtomicRefCount<TAtomicRefCountedObject> {
+ virtual ~TAtomicRefCountedObject() {
+ }
+};
diff --git a/library/cpp/messagebus/remote_client_connection.cpp b/library/cpp/messagebus/remote_client_connection.cpp
new file mode 100644
index 0000000000..8c7a6db3a8
--- /dev/null
+++ b/library/cpp/messagebus/remote_client_connection.cpp
@@ -0,0 +1,343 @@
+#include "remote_client_connection.h"
+
+#include "mb_lwtrace.h"
+#include "network.h"
+#include "remote_client_session.h"
+
+#include <library/cpp/messagebus/actor/executor.h>
+#include <library/cpp/messagebus/actor/temp_tls_vector.h>
+
+#include <util/generic/cast.h>
+#include <util/thread/singleton.h>
+
+LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER)
+
+using namespace NActor;
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TRemoteClientConnection::TRemoteClientConnection(TRemoteClientSessionPtr session, ui64 id, TNetAddr addr)
+ : TRemoteConnection(session.Get(), id, addr)
+ , ClientHandler(GetSession()->ClientHandler)
+{
+ Y_VERIFY(addr.GetPort() > 0, "must connect to non-zero port");
+
+ ScheduleWrite();
+}
+
+TRemoteClientSession* TRemoteClientConnection::GetSession() {
+ return CheckedCast<TRemoteClientSession*>(Session.Get());
+}
+
+TBusMessage* TRemoteClientConnection::PopAck(TBusKey id) {
+ return AckMessages.Pop(id);
+}
+
+SOCKET TRemoteClientConnection::CreateSocket(const TNetAddr& addr) {
+ SOCKET handle = socket(addr.Addr()->sa_family, SOCK_STREAM, 0);
+ Y_VERIFY(handle != INVALID_SOCKET, "failed to create socket: %s", LastSystemErrorText());
+
+ TSocketHolder s(handle);
+
+ SetNonBlock(s, true);
+ SetNoDelay(s, Config.TcpNoDelay);
+ SetSockOptTcpCork(s, Config.TcpCork);
+ SetCloseOnExec(s, true);
+ SetKeepAlive(s, true);
+ if (Config.SocketRecvBufferSize != 0) {
+ SetInputBuffer(s, Config.SocketRecvBufferSize);
+ }
+ if (Config.SocketSendBufferSize != 0) {
+ SetOutputBuffer(s, Config.SocketSendBufferSize);
+ }
+ if (Config.SocketToS >= 0) {
+ SetSocketToS(s, &addr, Config.SocketToS);
+ }
+
+ return s.Release();
+}
+
+void TRemoteClientConnection::TryConnect() {
+ if (AtomicGet(WriterData.Down)) {
+ return;
+ }
+ Y_VERIFY(!WriterData.Status.Connected);
+
+ TInstant now = TInstant::Now();
+
+ if (!WriterData.Channel) {
+ if ((now - LastConnectAttempt) < TDuration::MilliSeconds(Config.RetryInterval)) {
+ DropEnqueuedData(MESSAGE_CONNECT_FAILED, MESSAGE_CONNECT_FAILED);
+ return;
+ }
+ LastConnectAttempt = now;
+
+ TSocket connectSocket(CreateSocket(PeerAddr));
+ WriterData.SetChannel(Session->WriteEventLoop.Register(connectSocket, this, WriteCookie));
+ }
+
+ if (BeforeSendQueue.IsEmpty() && WriterData.SendQueue.Empty() && !Config.ReconnectWhenIdle) {
+ // TryConnect is called from Writer::Act, which is called in cycle
+ // from session's ScheduleTimeoutMessages via Cron. This prevent these excessive connects.
+ return;
+ }
+
+ ++WriterData.Status.ConnectSyscalls;
+
+ int ret = connect(WriterData.Channel->GetSocket(), PeerAddr.Addr(), PeerAddr.Len());
+ int err = ret ? LastSystemError() : 0;
+
+ if (!ret || (ret && err == EISCONN)) {
+ WriterData.Status.ConnectTime = now;
+ ++WriterData.SocketVersion;
+
+ WriterData.Channel->DisableWrite();
+ WriterData.Status.Connected = true;
+ AtomicSet(ReturnConnectFailedImmediately, false);
+
+ WriterData.Status.MyAddr = TNetAddr(GetSockAddr(WriterData.Channel->GetSocket()));
+
+ TSocket readSocket = WriterData.Channel->GetSocketPtr();
+
+ ReaderGetSocketQueue()->EnqueueAndSchedule(TWriterToReaderSocketMessage(readSocket, WriterData.SocketVersion));
+
+ FireClientConnectionEvent(TClientConnectionEvent::CONNECTED);
+
+ ScheduleWrite();
+ } else {
+ if (WouldBlock() || err == EALREADY) {
+ WriterData.Channel->EnableWrite();
+ } else {
+ WriterData.DropChannel();
+ WriterData.Status.MyAddr = TNetAddr();
+ WriterData.Status.Connected = false;
+ WriterData.Status.ConnectError = err;
+
+ DropEnqueuedData(MESSAGE_CONNECT_FAILED, MESSAGE_CONNECT_FAILED);
+ }
+ }
+}
+
+void TRemoteClientConnection::HandleEvent(SOCKET socket, void* cookie) {
+ Y_UNUSED(socket);
+ Y_ASSERT(cookie == WriteCookie || cookie == ReadCookie);
+ if (cookie == ReadCookie) {
+ ScheduleRead();
+ } else {
+ ScheduleWrite();
+ }
+}
+
+void TRemoteClientConnection::WriterFillStatus() {
+ TRemoteConnection::WriterFillStatus();
+ WriterData.Status.AckMessagesSize = AckMessages.Size();
+}
+
+void TRemoteClientConnection::BeforeTryWrite() {
+ ProcessReplyQueue();
+ TimeoutMessages();
+}
+
+namespace NBus {
+ namespace NPrivate {
+ class TInvokeOnReply: public IWorkItem {
+ private:
+ TRemoteClientSession* RemoteClientSession;
+ TNonDestroyingHolder<TBusMessage> Request;
+ TBusMessagePtrAndHeader Response;
+
+ public:
+ TInvokeOnReply(TRemoteClientSession* session,
+ TNonDestroyingAutoPtr<TBusMessage> request, TBusMessagePtrAndHeader& response)
+ : RemoteClientSession(session)
+ , Request(request)
+ {
+ Response.Swap(response);
+ }
+
+ void DoWork() override {
+ THolder<TInvokeOnReply> holder(this);
+ RemoteClientSession->ReleaseInFlightAndCallOnReply(Request.Release(), Response);
+ // TODO: TRemoteClientSessionSemaphore should be enough
+ RemoteClientSession->JobCount.Decrement();
+ }
+ };
+
+ }
+}
+
+void TRemoteClientConnection::ProcessReplyQueue() {
+ if (AtomicGet(WriterData.Down)) {
+ return;
+ }
+
+ bool executeInWorkerPool = Session->Config.ExecuteOnReplyInWorkerPool;
+
+ TTempTlsVector<TBusMessagePtrAndHeader, void, TVectorSwaps> replyQueueTemp;
+ TTempTlsVector< ::NActor::IWorkItem*> workQueueTemp;
+
+ ReplyQueue.DequeueAllSingleConsumer(replyQueueTemp.GetVector());
+ if (executeInWorkerPool) {
+ workQueueTemp.GetVector()->reserve(replyQueueTemp.GetVector()->size());
+ }
+
+ for (auto& resp : *replyQueueTemp.GetVector()) {
+ TBusMessage* req = PopAck(resp.Header.Id);
+
+ if (!req) {
+ WriterErrorMessage(resp.MessagePtr.Release(), MESSAGE_UNKNOWN);
+ continue;
+ }
+
+ if (executeInWorkerPool) {
+ workQueueTemp.GetVector()->push_back(new TInvokeOnReply(GetSession(), req, resp));
+ } else {
+ GetSession()->ReleaseInFlightAndCallOnReply(req, resp);
+ }
+ }
+
+ if (executeInWorkerPool) {
+ Session->JobCount.Add(workQueueTemp.GetVector()->size());
+ Session->Queue->EnqueueWork(*workQueueTemp.GetVector());
+ }
+}
+
+void TRemoteClientConnection::TimeoutMessages() {
+ if (!TimeToTimeoutMessages.FetchTask()) {
+ return;
+ }
+
+ TMessagesPtrs timedOutMessages;
+
+ TInstant sendDeadline;
+ TInstant ackDeadline;
+ if (IsReturnConnectFailedImmediately()) {
+ sendDeadline = TInstant::Max();
+ ackDeadline = TInstant::Max();
+ } else {
+ TInstant now = TInstant::Now();
+ sendDeadline = now - TDuration::MilliSeconds(Session->Config.SendTimeout);
+ ackDeadline = now - TDuration::MilliSeconds(Session->Config.TotalTimeout);
+ }
+
+ {
+ TMessagesPtrs temp;
+ WriterData.SendQueue.Timeout(sendDeadline, &temp);
+ timedOutMessages.insert(timedOutMessages.end(), temp.begin(), temp.end());
+ }
+
+ // Ignores message that is being written currently (that is stored
+ // in WriteMessage). It is not a big problem, because after written
+ // to the network, message will be placed to the AckMessages queue,
+ // and timed out on the next iteration of this procedure.
+
+ {
+ TMessagesPtrs temp;
+ AckMessages.Timeout(ackDeadline, &temp);
+ timedOutMessages.insert(timedOutMessages.end(), temp.begin(), temp.end());
+ }
+
+ ResetOneWayFlag(timedOutMessages);
+
+ GetSession()->ReleaseInFlight(timedOutMessages);
+ WriterErrorMessages(timedOutMessages, MESSAGE_TIMEOUT);
+}
+
+void TRemoteClientConnection::ScheduleTimeoutMessages() {
+ TimeToTimeoutMessages.AddTask();
+ ScheduleWrite();
+}
+
+void TRemoteClientConnection::ReaderProcessMessageUnknownVersion(TArrayRef<const char>) {
+ LWPROBE(Error, ToString(MESSAGE_INVALID_VERSION), ToString(PeerAddr), "");
+ ReaderData.Status.Incremental.StatusCounter[MESSAGE_INVALID_VERSION] += 1;
+ // TODO: close connection
+ Y_FAIL("unknown message");
+}
+
+void TRemoteClientConnection::ClearOutgoingQueue(TMessagesPtrs& result, bool reconnect) {
+ Y_ASSERT(result.empty());
+
+ TRemoteConnection::ClearOutgoingQueue(result, reconnect);
+ AckMessages.Clear(&result);
+
+ ResetOneWayFlag(result);
+ GetSession()->ReleaseInFlight(result);
+}
+
+void TRemoteClientConnection::MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) {
+ for (auto& message : messages) {
+ bool oneWay = message.LocalFlags & MESSAGE_ONE_WAY_INTERNAL;
+
+ if (oneWay) {
+ message.MessagePtr->LocalFlags &= ~MESSAGE_ONE_WAY_INTERNAL;
+
+ TBusMessage* ackMsg = this->PopAck(message.Header.Id);
+ if (!ackMsg) {
+ // TODO: expired?
+ }
+
+ if (ackMsg != message.MessagePtr.Get()) {
+ // TODO: non-unique id?
+ }
+
+ GetSession()->ReleaseInFlight({message.MessagePtr.Get()});
+ ClientHandler->OnMessageSentOneWay(message.MessagePtr.Release());
+ } else {
+ ClientHandler->OnMessageSent(message.MessagePtr.Get());
+ AckMessages.Push(message);
+ }
+ }
+}
+
+EMessageStatus TRemoteClientConnection::SendMessage(TBusMessage* req, bool wait) {
+ return SendMessageImpl(req, wait, false);
+}
+
+EMessageStatus TRemoteClientConnection::SendMessageOneWay(TBusMessage* req, bool wait) {
+ return SendMessageImpl(req, wait, true);
+}
+
+EMessageStatus TRemoteClientConnection::SendMessageImpl(TBusMessage* msg, bool wait, bool oneWay) {
+ msg->CheckClean();
+
+ if (Session->IsDown()) {
+ return MESSAGE_SHUTDOWN;
+ }
+
+ if (wait) {
+ Y_VERIFY(!Session->Queue->GetExecutor()->IsInExecutorThread());
+ GetSession()->ClientRemoteInFlight.Wait();
+ } else {
+ if (!GetSession()->ClientRemoteInFlight.TryWait()) {
+ return MESSAGE_BUSY;
+ }
+ }
+
+ GetSession()->AcquireInFlight({msg});
+
+ EMessageStatus ret = MESSAGE_OK;
+
+ if (oneWay) {
+ msg->LocalFlags |= MESSAGE_ONE_WAY_INTERNAL;
+ }
+
+ msg->GetHeader()->SendTime = Now();
+
+ if (IsReturnConnectFailedImmediately()) {
+ ret = MESSAGE_CONNECT_FAILED;
+ goto clean;
+ }
+
+ Send(msg);
+
+ return MESSAGE_OK;
+clean:
+ msg->LocalFlags &= ~MESSAGE_ONE_WAY_INTERNAL;
+ GetSession()->ReleaseInFlight({msg});
+ return ret;
+}
+
+void TRemoteClientConnection::OpenConnection() {
+ // TODO
+}
diff --git a/library/cpp/messagebus/remote_client_connection.h b/library/cpp/messagebus/remote_client_connection.h
new file mode 100644
index 0000000000..fe80b7d2f9
--- /dev/null
+++ b/library/cpp/messagebus/remote_client_connection.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include "connection.h"
+#include "local_tasks.h"
+#include "remote_client_session.h"
+#include "remote_connection.h"
+
+#include <util/generic/object_counter.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TRemoteClientConnection: public TRemoteConnection, public TBusClientConnection {
+ friend class TRemoteConnection;
+ friend struct TBusSessionImpl;
+ friend class TRemoteClientSession;
+
+ private:
+ TObjectCounter<TRemoteClientConnection> ObjectCounter;
+
+ TSyncAckMessages AckMessages;
+
+ TLocalTasks TimeToTimeoutMessages;
+
+ IBusClientHandler* const ClientHandler;
+
+ public:
+ TRemoteClientConnection(TRemoteClientSessionPtr session, ui64 id, TNetAddr addr);
+
+ inline TRemoteClientSession* GetSession();
+
+ SOCKET CreateSocket(const TNetAddr& addr);
+
+ void TryConnect() override;
+
+ void HandleEvent(SOCKET socket, void* cookie) override;
+
+ TBusMessage* PopAck(TBusKey id);
+
+ void WriterFillStatus() override;
+
+ void ClearOutgoingQueue(TMessagesPtrs& result, bool reconnect) override;
+
+ void BeforeTryWrite() override;
+
+ void ProcessReplyQueue();
+
+ void MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) override;
+
+ void TimeoutMessages();
+
+ void ScheduleTimeoutMessages();
+
+ void ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) override;
+
+ EMessageStatus SendMessage(TBusMessage* pMes, bool wait) override;
+
+ EMessageStatus SendMessageOneWay(TBusMessage* pMes, bool wait) override;
+
+ EMessageStatus SendMessageImpl(TBusMessage*, bool wait, bool oneWay);
+
+ void OpenConnection() override;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_client_session.cpp b/library/cpp/messagebus/remote_client_session.cpp
new file mode 100644
index 0000000000..3bc421944f
--- /dev/null
+++ b/library/cpp/messagebus/remote_client_session.cpp
@@ -0,0 +1,127 @@
+#include "remote_client_session.h"
+
+#include "mb_lwtrace.h"
+#include "remote_client_connection.h"
+
+#include <library/cpp/messagebus/scheduler/scheduler.h>
+
+#include <util/generic/cast.h>
+#include <util/system/defaults.h>
+
+LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER)
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TRemoteClientSession::TRemoteClientSession(TBusMessageQueue* queue,
+ TBusProtocol* proto, IBusClientHandler* handler,
+ const TBusClientSessionConfig& config, const TString& name)
+ : TBusSessionImpl(true, queue, proto, handler, config, name)
+ , ClientRemoteInFlight(config.MaxInFlight, "ClientRemoteInFlight")
+ , ClientHandler(handler)
+{
+}
+
+TRemoteClientSession::~TRemoteClientSession() {
+ //Cerr << "~TRemoteClientSession" << Endl;
+}
+
+void TRemoteClientSession::OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) {
+ TAutoPtr<TVectorSwaps<TBusMessagePtrAndHeader>> temp(new TVectorSwaps<TBusMessagePtrAndHeader>);
+ temp->swap(newMsg);
+ c->ReplyQueue.EnqueueAll(temp);
+ c->ScheduleWrite();
+}
+
+EMessageStatus TRemoteClientSession::SendMessageImpl(TBusMessage* msg, const TNetAddr* addr, bool wait, bool oneWay) {
+ if (Y_UNLIKELY(IsDown())) {
+ return MESSAGE_SHUTDOWN;
+ }
+
+ TBusSocketAddr resolvedAddr;
+ EMessageStatus ret = GetMessageDestination(msg, addr, &resolvedAddr);
+ if (ret != MESSAGE_OK) {
+ return ret;
+ }
+
+ msg->ReplyTo = resolvedAddr;
+
+ TRemoteConnectionPtr c = ((TBusSessionImpl*)this)->GetConnection(resolvedAddr, true);
+ Y_ASSERT(!!c);
+
+ return CheckedCast<TRemoteClientConnection*>(c.Get())->SendMessageImpl(msg, wait, oneWay);
+}
+
+EMessageStatus TRemoteClientSession::SendMessage(TBusMessage* msg, const TNetAddr* addr, bool wait) {
+ return SendMessageImpl(msg, addr, wait, false);
+}
+
+EMessageStatus TRemoteClientSession::SendMessageOneWay(TBusMessage* pMes, const TNetAddr* addr, bool wait) {
+ return SendMessageImpl(pMes, addr, wait, true);
+}
+
+int TRemoteClientSession::GetInFlight() const noexcept {
+ return ClientRemoteInFlight.GetCurrent();
+}
+
+void TRemoteClientSession::FillStatus() {
+ TBusSessionImpl::FillStatus();
+
+ StatusData.Status.InFlightCount = ClientRemoteInFlight.GetCurrent();
+ StatusData.Status.InputPaused = false;
+}
+
+void TRemoteClientSession::AcquireInFlight(TArrayRef<TBusMessage* const> messages) {
+ for (auto message : messages) {
+ Y_ASSERT(!(message->LocalFlags & MESSAGE_IN_FLIGHT_ON_CLIENT));
+ message->LocalFlags |= MESSAGE_IN_FLIGHT_ON_CLIENT;
+ }
+ ClientRemoteInFlight.IncrementMultiple(messages.size());
+}
+
+void TRemoteClientSession::ReleaseInFlight(TArrayRef<TBusMessage* const> messages) {
+ for (auto message : messages) {
+ Y_ASSERT(message->LocalFlags & MESSAGE_IN_FLIGHT_ON_CLIENT);
+ message->LocalFlags &= ~MESSAGE_IN_FLIGHT_ON_CLIENT;
+ }
+ ClientRemoteInFlight.ReleaseMultiple(messages.size());
+}
+
+void TRemoteClientSession::ReleaseInFlightAndCallOnReply(TNonDestroyingAutoPtr<TBusMessage> request, TBusMessagePtrAndHeader& response) {
+ ReleaseInFlight({request.Get()});
+ if (Y_UNLIKELY(AtomicGet(Down))) {
+ InvokeOnError(request, MESSAGE_SHUTDOWN);
+ InvokeOnError(response.MessagePtr.Release(), MESSAGE_SHUTDOWN);
+
+ TRemoteConnectionReaderIncrementalStatus counter;
+ LWPROBE(Error, ToString(MESSAGE_SHUTDOWN), "", "");
+ counter.StatusCounter[MESSAGE_SHUTDOWN] += 1;
+ GetDeadConnectionReaderStatusQueue()->EnqueueAndSchedule(counter);
+ } else {
+ TWhatThreadDoesPushPop pp("OnReply");
+ ClientHandler->OnReply(request, response.MessagePtr.Release());
+ }
+}
+
+EMessageStatus TRemoteClientSession::GetMessageDestination(TBusMessage* mess, const TNetAddr* addrp, TBusSocketAddr* dest) {
+ if (addrp) {
+ *dest = *addrp;
+ } else {
+ TNetAddr tmp;
+ EMessageStatus ret = const_cast<TBusProtocol*>(GetProto())->GetDestination(this, mess, GetQueue()->GetLocator(), &tmp);
+ if (ret != MESSAGE_OK) {
+ return ret;
+ }
+ *dest = tmp;
+ }
+ return MESSAGE_OK;
+}
+
+void TRemoteClientSession::OpenConnection(const TNetAddr& addr) {
+ GetConnection(addr)->OpenConnection();
+}
+
+TBusClientConnectionPtr TRemoteClientSession::GetConnection(const TNetAddr& addr) {
+ // TODO: GetConnection should not open
+ return CheckedCast<TRemoteClientConnection*>(((TBusSessionImpl*)this)->GetConnection(addr, true).Get());
+}
diff --git a/library/cpp/messagebus/remote_client_session.h b/library/cpp/messagebus/remote_client_session.h
new file mode 100644
index 0000000000..7160d0dae9
--- /dev/null
+++ b/library/cpp/messagebus/remote_client_session.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#include "remote_client_session_semaphore.h"
+#include "session_impl.h"
+
+#include <util/generic/array_ref.h>
+#include <util/generic/object_counter.h>
+
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4250) // 'NBus::NPrivate::TRemoteClientSession' : inherits 'NBus::NPrivate::TBusSessionImpl::NBus::NPrivate::TBusSessionImpl::GetConfig' via dominance
+#endif
+
+namespace NBus {
+ namespace NPrivate {
+ using TRemoteClientSessionPtr = TIntrusivePtr<TRemoteClientSession>;
+
+ class TRemoteClientSession: public TBusClientSession, public TBusSessionImpl {
+ friend class TRemoteClientConnection;
+ friend class TInvokeOnReply;
+
+ public:
+ TObjectCounter<TRemoteClientSession> ObjectCounter;
+
+ TRemoteClientSessionSemaphore ClientRemoteInFlight;
+ IBusClientHandler* const ClientHandler;
+
+ public:
+ TRemoteClientSession(TBusMessageQueue* queue, TBusProtocol* proto,
+ IBusClientHandler* handler,
+ const TBusSessionConfig& config, const TString& name);
+
+ ~TRemoteClientSession() override;
+
+ void OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) override;
+
+ EMessageStatus SendMessageImpl(TBusMessage* msg, const TNetAddr* addr, bool wait, bool oneWay);
+ EMessageStatus SendMessage(TBusMessage* msg, const TNetAddr* addr = nullptr, bool wait = false) override;
+ EMessageStatus SendMessageOneWay(TBusMessage* msg, const TNetAddr* addr = nullptr, bool wait = false) override;
+
+ int GetInFlight() const noexcept override;
+ void FillStatus() override;
+ void AcquireInFlight(TArrayRef<TBusMessage* const> messages);
+ void ReleaseInFlight(TArrayRef<TBusMessage* const> messages);
+ void ReleaseInFlightAndCallOnReply(TNonDestroyingAutoPtr<TBusMessage> request, TBusMessagePtrAndHeader& response);
+
+ EMessageStatus GetMessageDestination(TBusMessage* mess, const TNetAddr* addrp, TBusSocketAddr* dest);
+
+ void OpenConnection(const TNetAddr&) override;
+
+ TBusClientConnectionPtr GetConnection(const TNetAddr&) override;
+ };
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_client_session_semaphore.cpp b/library/cpp/messagebus/remote_client_session_semaphore.cpp
new file mode 100644
index 0000000000..f877ed4257
--- /dev/null
+++ b/library/cpp/messagebus/remote_client_session_semaphore.cpp
@@ -0,0 +1,67 @@
+#include "remote_client_session_semaphore.h"
+
+#include <util/stream/output.h>
+#include <util/system/yassert.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TRemoteClientSessionSemaphore::TRemoteClientSessionSemaphore(TAtomicBase limit, const char* name)
+ : Name(name)
+ , Limit(limit)
+ , Current(0)
+ , StopSignal(0)
+{
+ Y_VERIFY(limit > 0, "limit must be > 0");
+ Y_UNUSED(Name);
+}
+
+TRemoteClientSessionSemaphore::~TRemoteClientSessionSemaphore() {
+ Y_VERIFY(AtomicGet(Current) == 0);
+}
+
+bool TRemoteClientSessionSemaphore::TryAcquire() {
+ if (!TryWait()) {
+ return false;
+ }
+
+ AtomicIncrement(Current);
+ return true;
+}
+
+bool TRemoteClientSessionSemaphore::TryWait() {
+ if (AtomicGet(Current) < Limit)
+ return true;
+ if (Y_UNLIKELY(AtomicGet(StopSignal)))
+ return true;
+ return false;
+}
+
+void TRemoteClientSessionSemaphore::Acquire() {
+ Wait();
+
+ Increment();
+}
+
+void TRemoteClientSessionSemaphore::Increment() {
+ IncrementMultiple(1);
+}
+
+void TRemoteClientSessionSemaphore::IncrementMultiple(TAtomicBase count) {
+ AtomicAdd(Current, count);
+ Updated();
+}
+
+void TRemoteClientSessionSemaphore::Release() {
+ ReleaseMultiple(1);
+}
+
+void TRemoteClientSessionSemaphore::ReleaseMultiple(TAtomicBase count) {
+ AtomicSub(Current, count);
+ Updated();
+}
+
+void TRemoteClientSessionSemaphore::Stop() {
+ AtomicSet(StopSignal, 1);
+ Updated();
+}
diff --git a/library/cpp/messagebus/remote_client_session_semaphore.h b/library/cpp/messagebus/remote_client_session_semaphore.h
new file mode 100644
index 0000000000..286ca3c86f
--- /dev/null
+++ b/library/cpp/messagebus/remote_client_session_semaphore.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include "cc_semaphore.h"
+
+#include <util/generic/noncopyable.h>
+#include <util/system/atomic.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TRemoteClientSessionSemaphore: public TComplexConditionSemaphore<TRemoteClientSessionSemaphore> {
+ private:
+ const char* const Name;
+
+ TAtomicBase const Limit;
+ TAtomic Current;
+ TAtomic StopSignal;
+
+ public:
+ TRemoteClientSessionSemaphore(TAtomicBase limit, const char* name = "unnamed");
+ ~TRemoteClientSessionSemaphore();
+
+ TAtomicBase GetCurrent() const {
+ return AtomicGet(Current);
+ }
+
+ void Acquire();
+ bool TryAcquire();
+ void Increment();
+ void IncrementMultiple(TAtomicBase count);
+ bool TryWait();
+ void Release();
+ void ReleaseMultiple(TAtomicBase count);
+ void Stop();
+
+ private:
+ void CheckNeedToUnlock();
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_connection.cpp b/library/cpp/messagebus/remote_connection.cpp
new file mode 100644
index 0000000000..22932569db
--- /dev/null
+++ b/library/cpp/messagebus/remote_connection.cpp
@@ -0,0 +1,974 @@
+#include "remote_connection.h"
+
+#include "key_value_printer.h"
+#include "mb_lwtrace.h"
+#include "network.h"
+#include "remote_client_connection.h"
+#include "remote_client_session.h"
+#include "remote_server_session.h"
+#include "session_impl.h"
+
+#include <library/cpp/messagebus/actor/what_thread_does.h>
+
+#include <util/generic/cast.h>
+#include <util/network/init.h>
+#include <util/system/atomic.h>
+
+LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER)
+
+using namespace NActor;
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+namespace NBus {
+ namespace NPrivate {
+ TRemoteConnection::TRemoteConnection(TRemoteSessionPtr session, ui64 connectionId, TNetAddr addr)
+ : TActor<TRemoteConnection, TWriterTag>(session->Queue->WorkQueue.Get())
+ , TActor<TRemoteConnection, TReaderTag>(session->Queue->WorkQueue.Get())
+ , TScheduleActor<TRemoteConnection, TWriterTag>(&session->Queue->Scheduler)
+ , Session(session)
+ , Proto(session->Proto)
+ , Config(session->Config)
+ , RemovedFromSession(false)
+ , ConnectionId(connectionId)
+ , PeerAddr(addr)
+ , PeerAddrSocketAddr(addr)
+ , CreatedTime(TInstant::Now())
+ , ReturnConnectFailedImmediately(false)
+ , GranStatus(Config.Secret.StatusFlushPeriod)
+ , QuotaMsg(!Session->IsSource_, Config.PerConnectionMaxInFlight, 0)
+ , QuotaBytes(!Session->IsSource_, Config.PerConnectionMaxInFlightBySize, 0)
+ , MaxBufferSize(session->Config.MaxBufferSize)
+ , ShutdownReason(MESSAGE_OK)
+ {
+ WriterData.Status.ConnectionId = connectionId;
+ WriterData.Status.PeerAddr = PeerAddr;
+ ReaderData.Status.ConnectionId = connectionId;
+
+ const TInstant now = TInstant::Now();
+
+ WriterFillStatus();
+
+ GranStatus.Writer.Update(WriterData.Status, now, true);
+ GranStatus.Reader.Update(ReaderData.Status, now, true);
+ }
+
+ TRemoteConnection::~TRemoteConnection() {
+ Y_VERIFY(ReplyQueue.IsEmpty());
+ }
+
+ TRemoteConnection::TWriterData::TWriterData()
+ : Down(0)
+ , SocketVersion(0)
+ , InFlight(0)
+ , AwakeFlags(0)
+ , State(WRITER_FILLING)
+ {
+ }
+
+ TRemoteConnection::TWriterData::~TWriterData() {
+ Y_VERIFY(AtomicGet(Down));
+ Y_VERIFY(SendQueue.Empty());
+ }
+
+ bool TRemoteConnection::TReaderData::HasBytesInBuf(size_t bytes) noexcept {
+ size_t left = Buffer.Size() - Offset;
+
+ return (MoreBytes = left >= bytes ? 0 : bytes - left) == 0;
+ }
+
+ void TRemoteConnection::TWriterData::SetChannel(NEventLoop::TChannelPtr channel) {
+ Y_VERIFY(!Channel, "must not have channel");
+ Y_VERIFY(Buffer.GetBuffer().Empty() && Buffer.LeftSize() == 0, "buffer must be empty");
+ Y_VERIFY(State == WRITER_FILLING, "state must be initial");
+ Channel = channel;
+ }
+
+ void TRemoteConnection::TReaderData::SetChannel(NEventLoop::TChannelPtr channel) {
+ Y_VERIFY(!Channel, "must not have channel");
+ Y_VERIFY(Buffer.Empty(), "buffer must be empty");
+ Channel = channel;
+ }
+
+ void TRemoteConnection::TWriterData::DropChannel() {
+ if (!!Channel) {
+ Channel->Unregister();
+ Channel.Drop();
+ }
+
+ Buffer.Reset();
+ State = WRITER_FILLING;
+ }
+
+ void TRemoteConnection::TReaderData::DropChannel() {
+ // TODO: make Drop call Unregister
+ if (!!Channel) {
+ Channel->Unregister();
+ Channel.Drop();
+ }
+ Buffer.Reset();
+ Offset = 0;
+ }
+
+ TRemoteConnection::TReaderData::TReaderData()
+ : Down(0)
+ , SocketVersion(0)
+ , Offset(0)
+ , MoreBytes(0)
+ {
+ }
+
+ TRemoteConnection::TReaderData::~TReaderData() {
+ Y_VERIFY(AtomicGet(Down));
+ }
+
+ void TRemoteConnection::Send(TNonDestroyingAutoPtr<TBusMessage> msg) {
+ BeforeSendQueue.Enqueue(msg.Release());
+ AtomicIncrement(WriterData.InFlight);
+ ScheduleWrite();
+ }
+
+ void TRemoteConnection::ClearOutgoingQueue(TMessagesPtrs& result, bool reconnect) {
+ if (!reconnect) {
+ // Do not clear send queue if reconnecting
+ WriterData.SendQueue.Clear(&result);
+ }
+ }
+
+ void TRemoteConnection::Shutdown(EMessageStatus status) {
+ ScheduleShutdown(status);
+
+ ReaderData.ShutdownComplete.WaitI();
+ WriterData.ShutdownComplete.WaitI();
+ }
+
+ void TRemoteConnection::TryConnect() {
+ Y_FAIL("TryConnect is client connection only operation");
+ }
+
+ void TRemoteConnection::ScheduleRead() {
+ GetReaderActor()->Schedule();
+ }
+
+ void TRemoteConnection::ScheduleWrite() {
+ GetWriterActor()->Schedule();
+ }
+
+ void TRemoteConnection::WriterRotateCounters() {
+ if (!WriterData.TimeToRotateCounters.FetchTask()) {
+ return;
+ }
+
+ WriterData.Status.DurationCounterPrev = WriterData.Status.DurationCounter;
+ Reset(WriterData.Status.DurationCounter);
+ }
+
+ void TRemoteConnection::WriterSendStatus(TInstant now, bool force) {
+ GranStatus.Writer.Update(std::bind(&TRemoteConnection::WriterGetStatus, this), now, force);
+ }
+
+ void TRemoteConnection::ReaderSendStatus(TInstant now, bool force) {
+ GranStatus.Reader.Update(std::bind(&TRemoteConnection::ReaderFillStatus, this), now, force);
+ }
+
+ const TRemoteConnectionReaderStatus& TRemoteConnection::ReaderFillStatus() {
+ ReaderData.Status.BufferSize = ReaderData.Buffer.Capacity();
+ ReaderData.Status.QuotaMsg = QuotaMsg.Tokens();
+ ReaderData.Status.QuotaBytes = QuotaBytes.Tokens();
+
+ return ReaderData.Status;
+ }
+
+ void TRemoteConnection::ProcessItem(TReaderTag, ::NActor::TDefaultTag, TWriterToReaderSocketMessage readSocket) {
+ if (AtomicGet(ReaderData.Down)) {
+ ReaderData.Status.Fd = INVALID_SOCKET;
+ return;
+ }
+
+ ReaderData.DropChannel();
+
+ ReaderData.Status.Fd = readSocket.Socket;
+ ReaderData.SocketVersion = readSocket.SocketVersion;
+
+ if (readSocket.Socket != INVALID_SOCKET) {
+ ReaderData.SetChannel(Session->ReadEventLoop.Register(readSocket.Socket, this, ReadCookie));
+ ReaderData.Channel->EnableRead();
+ }
+ }
+
+ void TRemoteConnection::ProcessItem(TWriterTag, TReconnectTag, ui32 socketVersion) {
+ Y_VERIFY(socketVersion <= WriterData.SocketVersion, "something weird");
+
+ if (WriterData.SocketVersion != socketVersion) {
+ return;
+ }
+ Y_VERIFY(WriterData.Status.Connected, "must be connected at this point");
+ Y_VERIFY(!!WriterData.Channel, "must have channel at this point");
+
+ WriterData.Status.Connected = false;
+ WriterData.DropChannel();
+ WriterData.Status.MyAddr = TNetAddr();
+ ++WriterData.SocketVersion;
+ LastConnectAttempt = TInstant();
+
+ TMessagesPtrs cleared;
+ ClearOutgoingQueue(cleared, true);
+ WriterErrorMessages(cleared, MESSAGE_DELIVERY_FAILED);
+
+ FireClientConnectionEvent(TClientConnectionEvent::DISCONNECTED);
+
+ ReaderGetSocketQueue()->EnqueueAndSchedule(TWriterToReaderSocketMessage(INVALID_SOCKET, WriterData.SocketVersion));
+ }
+
+ void TRemoteConnection::ProcessItem(TWriterTag, TWakeReaderTag, ui32 awakeFlags) {
+ WriterData.AwakeFlags |= awakeFlags;
+
+ ReadQuotaWakeup();
+ }
+
+ void TRemoteConnection::Act(TReaderTag) {
+ TInstant now = TInstant::Now();
+
+ ReaderData.Status.Acts += 1;
+
+ ReaderGetSocketQueue()->DequeueAllLikelyEmpty();
+
+ if (AtomicGet(ReaderData.Down)) {
+ ReaderData.DropChannel();
+
+ ReaderProcessStatusDown();
+ ReaderData.ShutdownComplete.Signal();
+
+ } else if (!!ReaderData.Channel) {
+ Y_ASSERT(ReaderData.ReadMessages.empty());
+
+ for (int i = 0;; ++i) {
+ if (i == 100) {
+ // perform other tasks
+ GetReaderActor()->AddTaskFromActorLoop();
+ break;
+ }
+
+ if (NeedInterruptRead()) {
+ ReaderData.Channel->EnableRead();
+ break;
+ }
+
+ if (!ReaderFillBuffer())
+ break;
+
+ if (!ReaderProcessBuffer())
+ break;
+ }
+
+ ReaderFlushMessages();
+ }
+
+ ReaderSendStatus(now);
+ }
+
+ bool TRemoteConnection::QuotaAcquire(size_t msg, size_t bytes) {
+ ui32 wakeFlags = 0;
+
+ if (!QuotaMsg.Acquire(msg))
+ wakeFlags |= WAKE_QUOTA_MSG;
+
+ else if (!QuotaBytes.Acquire(bytes))
+ wakeFlags |= WAKE_QUOTA_BYTES;
+
+ if (wakeFlags) {
+ ReaderData.Status.QuotaExhausted++;
+
+ WriterGetWakeQueue()->EnqueueAndSchedule(wakeFlags);
+ }
+
+ return wakeFlags == 0;
+ }
+
+ void TRemoteConnection::QuotaConsume(size_t msg, size_t bytes) {
+ QuotaMsg.Consume(msg);
+ QuotaBytes.Consume(bytes);
+ }
+
+ void TRemoteConnection::QuotaReturnSelf(size_t items, size_t bytes) {
+ if (QuotaReturnValues(items, bytes))
+ ReadQuotaWakeup();
+ }
+
+ void TRemoteConnection::QuotaReturnAside(size_t items, size_t bytes) {
+ if (QuotaReturnValues(items, bytes) && !AtomicGet(WriterData.Down))
+ WriterGetWakeQueue()->EnqueueAndSchedule(0x0);
+ }
+
+ bool TRemoteConnection::QuotaReturnValues(size_t items, size_t bytes) {
+ bool rMsg = QuotaMsg.Return(items);
+ bool rBytes = QuotaBytes.Return(bytes);
+
+ return rMsg || rBytes;
+ }
+
+ void TRemoteConnection::ReadQuotaWakeup() {
+ const ui32 mask = WriterData.AwakeFlags & WriteWakeFlags();
+
+ if (mask && mask == WriterData.AwakeFlags) {
+ WriterData.Status.ReaderWakeups++;
+ WriterData.AwakeFlags = 0;
+
+ ScheduleRead();
+ }
+ }
+
+ ui32 TRemoteConnection::WriteWakeFlags() const {
+ ui32 awakeFlags = 0;
+
+ if (QuotaMsg.IsAboveWake())
+ awakeFlags |= WAKE_QUOTA_MSG;
+
+ if (QuotaBytes.IsAboveWake())
+ awakeFlags |= WAKE_QUOTA_BYTES;
+
+ return awakeFlags;
+ }
+
+ bool TRemoteConnection::ReaderProcessBuffer() {
+ TInstant now = TInstant::Now();
+
+ for (;;) {
+ if (!ReaderData.HasBytesInBuf(sizeof(TBusHeader))) {
+ break;
+ }
+
+ TBusHeader header(MakeArrayRef(ReaderData.Buffer.Data() + ReaderData.Offset, ReaderData.Buffer.Size() - ReaderData.Offset));
+
+ if (header.Size < sizeof(TBusHeader)) {
+ LWPROBE(Error, ToString(MESSAGE_HEADER_CORRUPTED), ToString(PeerAddr), ToString(header.Size));
+ ReaderData.Status.Incremental.StatusCounter[MESSAGE_HEADER_CORRUPTED] += 1;
+ ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_HEADER_CORRUPTED, false);
+ return false;
+ }
+
+ if (!IsVersionNegotiation(header) && !IsBusKeyValid(header.Id)) {
+ LWPROBE(Error, ToString(MESSAGE_HEADER_CORRUPTED), ToString(PeerAddr), ToString(header.Size));
+ ReaderData.Status.Incremental.StatusCounter[MESSAGE_HEADER_CORRUPTED] += 1;
+ ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_HEADER_CORRUPTED, false);
+ return false;
+ }
+
+ if (header.Size > Config.MaxMessageSize) {
+ LWPROBE(Error, ToString(MESSAGE_MESSAGE_TOO_LARGE), ToString(PeerAddr), ToString(header.Size));
+ ReaderData.Status.Incremental.StatusCounter[MESSAGE_MESSAGE_TOO_LARGE] += 1;
+ ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_MESSAGE_TOO_LARGE, false);
+ return false;
+ }
+
+ if (!ReaderData.HasBytesInBuf(header.Size)) {
+ if (ReaderData.Offset == 0) {
+ ReaderData.Buffer.Reserve(header.Size);
+ }
+ break;
+ }
+
+ if (!QuotaAcquire(1, header.Size))
+ return false;
+
+ if (!MessageRead(MakeArrayRef(ReaderData.Buffer.Data() + ReaderData.Offset, header.Size), now)) {
+ return false;
+ }
+
+ ReaderData.Offset += header.Size;
+ }
+
+ ReaderData.Buffer.ChopHead(ReaderData.Offset);
+ ReaderData.Offset = 0;
+
+ if (ReaderData.Buffer.Capacity() > MaxBufferSize && ReaderData.Buffer.Size() <= MaxBufferSize) {
+ ReaderData.Status.Incremental.BufferDrops += 1;
+
+ TBuffer temp;
+ // probably should use another constant
+ temp.Reserve(Config.DefaultBufferSize);
+ temp.Append(ReaderData.Buffer.Data(), ReaderData.Buffer.Size());
+
+ ReaderData.Buffer.Swap(temp);
+ }
+
+ return true;
+ }
+
+ bool TRemoteConnection::ReaderFillBuffer() {
+ if (!ReaderData.BufferMore())
+ return true;
+
+ if (ReaderData.Buffer.Avail() == 0) {
+ if (ReaderData.Buffer.Size() == 0) {
+ ReaderData.Buffer.Reserve(Config.DefaultBufferSize);
+ } else {
+ ReaderData.Buffer.Reserve(ReaderData.Buffer.Size() * 2);
+ }
+ }
+
+ Y_ASSERT(ReaderData.Buffer.Avail() > 0);
+
+ ssize_t bytes;
+ {
+ TWhatThreadDoesPushPop pp("recv syscall");
+ bytes = SocketRecv(ReaderData.Channel->GetSocket(), TArrayRef<char>(ReaderData.Buffer.Pos(), ReaderData.Buffer.Avail()));
+ }
+
+ if (bytes < 0) {
+ if (WouldBlock()) {
+ ReaderData.Channel->EnableRead();
+ return false;
+ } else {
+ ReaderData.Channel->DisableRead();
+ ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_DELIVERY_FAILED, false);
+ return false;
+ }
+ }
+
+ if (bytes == 0) {
+ ReaderData.Channel->DisableRead();
+ // TODO: incorrect: it is possible that only input is shutdown, and output is available
+ ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_DELIVERY_FAILED, false);
+ return false;
+ }
+
+ ReaderData.Status.Incremental.NetworkOps += 1;
+
+ ReaderData.Buffer.Advance(bytes);
+ ReaderData.MoreBytes = 0;
+ return true;
+ }
+
+ void TRemoteConnection::ClearBeforeSendQueue(EMessageStatus reason) {
+ BeforeSendQueue.DequeueAll(std::bind(&TRemoteConnection::WriterBeforeWriteErrorMessage, this, std::placeholders::_1, reason));
+ }
+
+ void TRemoteConnection::ClearReplyQueue(EMessageStatus reason) {
+ TVectorSwaps<TBusMessagePtrAndHeader> replyQueueTemp;
+ Y_ASSERT(replyQueueTemp.empty());
+ ReplyQueue.DequeueAllSingleConsumer(&replyQueueTemp);
+
+ TVector<TBusMessage*> messages;
+ for (TVectorSwaps<TBusMessagePtrAndHeader>::reverse_iterator message = replyQueueTemp.rbegin();
+ message != replyQueueTemp.rend(); ++message) {
+ messages.push_back(message->MessagePtr.Release());
+ }
+
+ WriterErrorMessages(messages, reason);
+
+ replyQueueTemp.clear();
+ }
+
+ void TRemoteConnection::ProcessBeforeSendQueueMessage(TBusMessage* message, TInstant now) {
+ // legacy clients expect this field to be set
+ if (!Session->IsSource_) {
+ message->SendTime = now.MilliSeconds();
+ }
+
+ WriterData.SendQueue.PushBack(message);
+ }
+
+ void TRemoteConnection::ProcessBeforeSendQueue(TInstant now) {
+ BeforeSendQueue.DequeueAll(std::bind(&TRemoteConnection::ProcessBeforeSendQueueMessage, this, std::placeholders::_1, now));
+ }
+
+ void TRemoteConnection::WriterFillInFlight() {
+ // this is hack for TLoadBalancedProtocol
+ WriterFillStatus();
+ AtomicSet(WriterData.InFlight, WriterData.Status.GetInFlight());
+ }
+
+ const TRemoteConnectionWriterStatus& TRemoteConnection::WriterGetStatus() {
+ WriterRotateCounters();
+ WriterFillStatus();
+
+ return WriterData.Status;
+ }
+
+ void TRemoteConnection::WriterFillStatus() {
+ if (!!WriterData.Channel) {
+ WriterData.Status.Fd = WriterData.Channel->GetSocket();
+ } else {
+ WriterData.Status.Fd = INVALID_SOCKET;
+ }
+ WriterData.Status.BufferSize = WriterData.Buffer.Capacity();
+ WriterData.Status.SendQueueSize = WriterData.SendQueue.Size();
+ WriterData.Status.State = WriterData.State;
+ }
+
+ void TRemoteConnection::WriterProcessStatusDown() {
+ Session->GetDeadConnectionWriterStatusQueue()->EnqueueAndSchedule(WriterData.Status.Incremental);
+ Reset(WriterData.Status.Incremental);
+ }
+
+ void TRemoteConnection::ReaderProcessStatusDown() {
+ Session->GetDeadConnectionReaderStatusQueue()->EnqueueAndSchedule(ReaderData.Status.Incremental);
+ Reset(ReaderData.Status.Incremental);
+ }
+
+ void TRemoteConnection::ProcessWriterDown() {
+ if (!RemovedFromSession) {
+ Session->GetRemoveConnectionQueue()->EnqueueAndSchedule(this);
+
+ if (Session->IsSource_) {
+ if (WriterData.Status.Connected) {
+ FireClientConnectionEvent(TClientConnectionEvent::DISCONNECTED);
+ }
+ }
+
+ LWPROBE(Disconnected, ToString(PeerAddr));
+ RemovedFromSession = true;
+ }
+
+ WriterData.DropChannel();
+
+ DropEnqueuedData(ShutdownReason, MESSAGE_SHUTDOWN);
+
+ WriterProcessStatusDown();
+
+ WriterData.ShutdownComplete.Signal();
+ }
+
+ void TRemoteConnection::DropEnqueuedData(EMessageStatus reason, EMessageStatus reasonForQueues) {
+ ClearReplyQueue(reasonForQueues);
+ ClearBeforeSendQueue(reasonForQueues);
+ WriterGetReconnectQueue()->Clear();
+ WriterGetWakeQueue()->Clear();
+
+ TMessagesPtrs cleared;
+ ClearOutgoingQueue(cleared, false);
+
+ if (!Session->IsSource_) {
+ for (auto& i : cleared) {
+ TBusMessagePtrAndHeader h(i);
+ CheckedCast<TRemoteServerSession*>(Session.Get())->ReleaseInWorkResponses(MakeArrayRef(&h, 1));
+ // assignment back is weird
+ i = h.MessagePtr.Release();
+ // and this part is not batch
+ }
+ }
+
+ WriterErrorMessages(cleared, reason);
+ }
+
+ void TRemoteConnection::BeforeTryWrite() {
+ }
+
+ void TRemoteConnection::Act(TWriterTag) {
+ TInstant now = TInstant::Now();
+
+ WriterData.Status.Acts += 1;
+
+ if (Y_UNLIKELY(AtomicGet(WriterData.Down))) {
+ // dump status must work even if WriterDown
+ WriterSendStatus(now, true);
+ ProcessWriterDown();
+ return;
+ }
+
+ ProcessBeforeSendQueue(now);
+
+ BeforeTryWrite();
+
+ WriterFillInFlight();
+
+ WriterGetReconnectQueue()->DequeueAllLikelyEmpty();
+
+ if (!WriterData.Status.Connected) {
+ TryConnect();
+ } else {
+ for (int i = 0;; ++i) {
+ if (i == 100) {
+ // perform other tasks
+ GetWriterActor()->AddTaskFromActorLoop();
+ break;
+ }
+
+ if (WriterData.State == WRITER_FILLING) {
+ WriterFillBuffer();
+
+ if (WriterData.State == WRITER_FILLING) {
+ WriterData.Channel->DisableWrite();
+ break;
+ }
+
+ Y_ASSERT(!WriterData.Buffer.Empty());
+ }
+
+ if (WriterData.State == WRITER_FLUSHING) {
+ WriterFlushBuffer();
+
+ if (WriterData.State == WRITER_FLUSHING) {
+ break;
+ }
+ }
+ }
+ }
+
+ WriterGetWakeQueue()->DequeueAllLikelyEmpty();
+
+ WriterSendStatus(now);
+ }
+
+ void TRemoteConnection::WriterFlushBuffer() {
+ Y_ASSERT(WriterData.State == WRITER_FLUSHING);
+ Y_ASSERT(!WriterData.Buffer.Empty());
+
+ WriterData.CorkUntil = TInstant::Zero();
+
+ while (!WriterData.Buffer.Empty()) {
+ ssize_t bytes;
+ {
+ TWhatThreadDoesPushPop pp("send syscall");
+ bytes = SocketSend(WriterData.Channel->GetSocket(), TArrayRef<const char>(WriterData.Buffer.LeftPos(), WriterData.Buffer.Size()));
+ }
+
+ if (bytes < 0) {
+ if (WouldBlock()) {
+ WriterData.Channel->EnableWrite();
+ return;
+ } else {
+ WriterData.Channel->DisableWrite();
+ ScheduleShutdownOnServerOrReconnectOnClient(MESSAGE_DELIVERY_FAILED, true);
+ return;
+ }
+ }
+
+ WriterData.Status.Incremental.NetworkOps += 1;
+
+ WriterData.Buffer.LeftProceed(bytes);
+ }
+
+ WriterData.Buffer.Clear();
+ if (WriterData.Buffer.Capacity() > MaxBufferSize) {
+ WriterData.Status.Incremental.BufferDrops += 1;
+ WriterData.Buffer.Reset();
+ }
+
+ WriterData.State = WRITER_FILLING;
+ }
+
+ void TRemoteConnection::ScheduleShutdownOnServerOrReconnectOnClient(EMessageStatus status, bool writer) {
+ if (Session->IsSource_) {
+ WriterGetReconnectQueue()->EnqueueAndSchedule(writer ? WriterData.SocketVersion : ReaderData.SocketVersion);
+ } else {
+ ScheduleShutdown(status);
+ }
+ }
+
+ void TRemoteConnection::ScheduleShutdown(EMessageStatus status) {
+ ShutdownReason = status;
+
+ AtomicSet(ReaderData.Down, 1);
+ ScheduleRead();
+
+ AtomicSet(WriterData.Down, 1);
+ ScheduleWrite();
+ }
+
+ void TRemoteConnection::CallSerialize(TBusMessage* msg, TBuffer& buffer) const {
+ size_t posForAssertion = buffer.Size();
+ Proto->Serialize(msg, buffer);
+ Y_VERIFY(buffer.Size() >= posForAssertion,
+ "incorrect Serialize implementation, pos before serialize: %d, pos after serialize: %d",
+ int(posForAssertion), int(buffer.Size()));
+ }
+
+ namespace {
+ inline void WriteHeader(const TBusHeader& header, TBuffer& data) {
+ data.Reserve(data.Size() + sizeof(TBusHeader));
+ /// \todo hton instead of memcpy
+ memcpy(data.Data() + data.Size(), &header, sizeof(TBusHeader));
+ data.Advance(sizeof(TBusHeader));
+ }
+
+ inline void WriteDummyHeader(TBuffer& data) {
+ data.Resize(data.Size() + sizeof(TBusHeader));
+ }
+
+ }
+
+ void TRemoteConnection::SerializeMessage(TBusMessage* msg, TBuffer* data, TMessageCounter* counter) const {
+ size_t pos = data->Size();
+
+ size_t dataSize;
+
+ bool compressionRequested = msg->IsCompressed();
+
+ if (compressionRequested) {
+ TBuffer compdata;
+ TBuffer plaindata;
+ CallSerialize(msg, plaindata);
+
+ dataSize = sizeof(TBusHeader) + plaindata.Size();
+
+ NCodecs::TCodecPtr c = Proto->GetTransportCodec();
+ c->Encode(TStringBuf{plaindata.data(), plaindata.size()}, compdata);
+
+ if (compdata.Size() < plaindata.Size()) {
+ plaindata.Clear();
+ msg->GetHeader()->Size = sizeof(TBusHeader) + compdata.Size();
+ WriteHeader(*msg->GetHeader(), *data);
+ data->Append(compdata.Data(), compdata.Size());
+ } else {
+ compdata.Clear();
+ msg->SetCompressed(false);
+ msg->GetHeader()->Size = sizeof(TBusHeader) + plaindata.Size();
+ WriteHeader(*msg->GetHeader(), *data);
+ data->Append(plaindata.Data(), plaindata.Size());
+ }
+ } else {
+ WriteDummyHeader(*data);
+ CallSerialize(msg, *data);
+
+ dataSize = msg->GetHeader()->Size = data->Size() - pos;
+
+ data->Proceed(pos);
+ WriteHeader(*msg->GetHeader(), *data);
+ data->Proceed(pos + msg->GetHeader()->Size);
+ }
+
+ Y_ASSERT(msg->GetHeader()->Size == data->Size() - pos);
+ counter->AddMessage(dataSize, data->Size() - pos, msg->IsCompressed(), compressionRequested);
+ }
+
+ TBusMessage* TRemoteConnection::DeserializeMessage(TArrayRef<const char> dataRef, const TBusHeader* header, TMessageCounter* messageCounter, EMessageStatus* status) const {
+ size_t dataSize;
+
+ TBusMessage* message;
+ if (header->FlagsInternal & MESSAGE_COMPRESS_INTERNAL) {
+ TBuffer msg;
+ {
+ TBuffer plaindata;
+ NCodecs::TCodecPtr c = Proto->GetTransportCodec();
+ try {
+ TArrayRef<const char> payload = TBusMessage::GetPayload(dataRef);
+ c->Decode(TStringBuf{payload.data(), payload.size()}, plaindata);
+ } catch (...) {
+ // catch all, because
+ // http://nga.at.yandex-team.ru/replies.xml?item_no=3884
+ *status = MESSAGE_DECOMPRESS_ERROR;
+ return nullptr;
+ }
+
+ msg.Append(dataRef.data(), sizeof(TBusHeader));
+ msg.Append(plaindata.Data(), plaindata.Size());
+ }
+ TArrayRef<const char> msgRef(msg.Data(), msg.Size());
+ dataSize = sizeof(TBusHeader) + msgRef.size();
+ // TODO: track error types
+ message = Proto->Deserialize(header->Type, msgRef.Slice(sizeof(TBusHeader))).Release();
+ if (!message) {
+ *status = MESSAGE_DESERIALIZE_ERROR;
+ return nullptr;
+ }
+ *message->GetHeader() = *header;
+ message->SetCompressed(true);
+ } else {
+ dataSize = dataRef.size();
+ message = Proto->Deserialize(header->Type, dataRef.Slice(sizeof(TBusHeader))).Release();
+ if (!message) {
+ *status = MESSAGE_DESERIALIZE_ERROR;
+ return nullptr;
+ }
+ *message->GetHeader() = *header;
+ }
+
+ messageCounter->AddMessage(dataSize, dataRef.size(), header->FlagsInternal & MESSAGE_COMPRESS_INTERNAL, false);
+
+ return message;
+ }
+
+ void TRemoteConnection::ResetOneWayFlag(TArrayRef<TBusMessage*> messages) {
+ for (auto message : messages) {
+ message->LocalFlags &= ~MESSAGE_ONE_WAY_INTERNAL;
+ }
+ }
+
+ void TRemoteConnection::ReaderFlushMessages() {
+ if (!ReaderData.ReadMessages.empty()) {
+ Session->OnMessageReceived(this, ReaderData.ReadMessages);
+ ReaderData.ReadMessages.clear();
+ }
+ }
+
+ // @return false if actor should break
+ bool TRemoteConnection::MessageRead(TArrayRef<const char> readDataRef, TInstant now) {
+ TBusHeader header(readDataRef);
+
+ Y_ASSERT(readDataRef.size() == header.Size);
+
+ if (header.GetVersionInternal() != YBUS_VERSION) {
+ ReaderProcessMessageUnknownVersion(readDataRef);
+ return true;
+ }
+
+ EMessageStatus deserializeFailureStatus = MESSAGE_OK;
+ TBusMessage* r = DeserializeMessage(readDataRef, &header, &ReaderData.Status.Incremental.MessageCounter, &deserializeFailureStatus);
+
+ if (!r) {
+ Y_VERIFY(deserializeFailureStatus != MESSAGE_OK, "state check");
+ LWPROBE(Error, ToString(deserializeFailureStatus), ToString(PeerAddr), "");
+ ReaderData.Status.Incremental.StatusCounter[deserializeFailureStatus] += 1;
+ ScheduleShutdownOnServerOrReconnectOnClient(deserializeFailureStatus, false);
+ return false;
+ }
+
+ LWPROBE(Read, r->GetHeader()->Size);
+
+ r->ReplyTo = PeerAddrSocketAddr;
+
+ TBusMessagePtrAndHeader h(r);
+ r->RecvTime = now;
+
+ QuotaConsume(1, header.Size);
+
+ ReaderData.ReadMessages.push_back(h);
+ if (ReaderData.ReadMessages.size() >= 100) {
+ ReaderFlushMessages();
+ }
+
+ return true;
+ }
+
+ void TRemoteConnection::WriterFillBuffer() {
+ Y_ASSERT(WriterData.State == WRITER_FILLING);
+
+ Y_ASSERT(WriterData.Buffer.LeftSize() == 0);
+
+ if (Y_UNLIKELY(!WrongVersionRequests.IsEmpty())) {
+ TVector<TBusHeader> headers;
+ WrongVersionRequests.DequeueAllSingleConsumer(&headers);
+ for (TVector<TBusHeader>::reverse_iterator header = headers.rbegin();
+ header != headers.rend(); ++header) {
+ TBusHeader response = *header;
+ response.SendTime = NBus::Now();
+ response.Size = sizeof(TBusHeader);
+ response.FlagsInternal = 0;
+ response.SetVersionInternal(YBUS_VERSION);
+ WriteHeader(response, WriterData.Buffer.GetBuffer());
+ }
+
+ Y_ASSERT(!WriterData.Buffer.Empty());
+ WriterData.State = WRITER_FLUSHING;
+ return;
+ }
+
+ TTempTlsVector<TBusMessagePtrAndHeader, void, TVectorSwaps> writeMessages;
+
+ for (;;) {
+ THolder<TBusMessage> writeMessage(WriterData.SendQueue.PopFront());
+ if (!writeMessage) {
+ break;
+ }
+
+ if (Config.Cork != TDuration::Zero()) {
+ if (WriterData.CorkUntil == TInstant::Zero()) {
+ WriterData.CorkUntil = TInstant::Now() + Config.Cork;
+ }
+ }
+
+ size_t sizeBeforeSerialize = WriterData.Buffer.Size();
+
+ TMessageCounter messageCounter = WriterData.Status.Incremental.MessageCounter;
+
+ SerializeMessage(writeMessage.Get(), &WriterData.Buffer.GetBuffer(), &messageCounter);
+
+ size_t written = WriterData.Buffer.Size() - sizeBeforeSerialize;
+ if (written > Config.MaxMessageSize) {
+ WriterData.Buffer.GetBuffer().EraseBack(written);
+ WriterBeforeWriteErrorMessage(writeMessage.Release(), MESSAGE_MESSAGE_TOO_LARGE);
+ continue;
+ }
+
+ WriterData.Status.Incremental.MessageCounter = messageCounter;
+
+ TBusMessagePtrAndHeader h(writeMessage.Release());
+ writeMessages.GetVector()->push_back(h);
+
+ Y_ASSERT(!WriterData.Buffer.Empty());
+ if (WriterData.Buffer.Size() >= Config.SendThreshold) {
+ break;
+ }
+ }
+
+ if (!WriterData.Buffer.Empty()) {
+ if (WriterData.Buffer.Size() >= Config.SendThreshold) {
+ WriterData.State = WRITER_FLUSHING;
+ } else if (WriterData.CorkUntil == TInstant::Zero()) {
+ WriterData.State = WRITER_FLUSHING;
+ } else if (TInstant::Now() >= WriterData.CorkUntil) {
+ WriterData.State = WRITER_FLUSHING;
+ } else {
+ // keep filling
+ Y_ASSERT(WriterData.State == WRITER_FILLING);
+ GetWriterSchedulerActor()->ScheduleAt(WriterData.CorkUntil);
+ }
+ } else {
+ // keep filling
+ Y_ASSERT(WriterData.State == WRITER_FILLING);
+ }
+
+ size_t bytes = MessageSize(*writeMessages.GetVector());
+
+ QuotaReturnSelf(writeMessages.GetVector()->size(), bytes);
+
+ // This is called before `send` syscall inducing latency
+ MessageSent(*writeMessages.GetVector());
+ }
+
+ size_t TRemoteConnection::MessageSize(TArrayRef<TBusMessagePtrAndHeader> messages) {
+ size_t size = 0;
+ for (const auto& message : messages)
+ size += message.MessagePtr->RequestSize;
+
+ return size;
+ }
+
+ size_t TRemoteConnection::GetInFlight() {
+ return AtomicGet(WriterData.InFlight);
+ }
+
+ size_t TRemoteConnection::GetConnectSyscallsNumForTest() {
+ return WriterData.Status.ConnectSyscalls;
+ }
+
+ void TRemoteConnection::WriterBeforeWriteErrorMessage(TBusMessage* message, EMessageStatus status) {
+ if (Session->IsSource_) {
+ CheckedCast<TRemoteClientSession*>(Session.Get())->ReleaseInFlight({message});
+ WriterErrorMessage(message, status);
+ } else {
+ TBusMessagePtrAndHeader h(message);
+ CheckedCast<TRemoteServerSession*>(Session.Get())->ReleaseInWorkResponses(MakeArrayRef(&h, 1));
+ WriterErrorMessage(h.MessagePtr.Release(), status);
+ }
+ }
+
+ void TRemoteConnection::WriterErrorMessage(TNonDestroyingAutoPtr<TBusMessage> m, EMessageStatus status) {
+ TBusMessage* released = m.Release();
+ WriterErrorMessages(MakeArrayRef(&released, 1), status);
+ }
+
+ void TRemoteConnection::WriterErrorMessages(const TArrayRef<TBusMessage*> ms, EMessageStatus status) {
+ ResetOneWayFlag(ms);
+
+ WriterData.Status.Incremental.StatusCounter[status] += ms.size();
+ for (auto m : ms) {
+ Session->InvokeOnError(m, status);
+ }
+ }
+
+ void TRemoteConnection::FireClientConnectionEvent(TClientConnectionEvent::EType type) {
+ Y_VERIFY(Session->IsSource_, "state check");
+ TClientConnectionEvent event(type, ConnectionId, PeerAddr);
+ TRemoteClientSession* session = CheckedCast<TRemoteClientSession*>(Session.Get());
+ session->ClientHandler->OnClientConnectionEvent(event);
+ }
+
+ bool TRemoteConnection::IsAlive() const {
+ return !AtomicGet(WriterData.Down);
+ }
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_connection.h b/library/cpp/messagebus/remote_connection.h
new file mode 100644
index 0000000000..4538947368
--- /dev/null
+++ b/library/cpp/messagebus/remote_connection.h
@@ -0,0 +1,294 @@
+#pragma once
+
+#include "async_result.h"
+#include "defs.h"
+#include "event_loop.h"
+#include "left_right_buffer.h"
+#include "lfqueue_batch.h"
+#include "message_ptr_and_header.h"
+#include "nondestroying_holder.h"
+#include "remote_connection_status.h"
+#include "scheduler_actor.h"
+#include "socket_addr.h"
+#include "storage.h"
+#include "vector_swaps.h"
+#include "ybus.h"
+#include "misc/granup.h"
+#include "misc/tokenquota.h"
+
+#include <library/cpp/messagebus/actor/actor.h>
+#include <library/cpp/messagebus/actor/executor.h>
+#include <library/cpp/messagebus/actor/queue_for_actor.h>
+#include <library/cpp/messagebus/actor/queue_in_actor.h>
+#include <library/cpp/messagebus/scheduler/scheduler.h>
+
+#include <util/system/atomic.h>
+#include <util/system/event.h>
+#include <util/thread/lfstack.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TRemoteConnection;
+
+ typedef TIntrusivePtr<TRemoteConnection> TRemoteConnectionPtr;
+ typedef TIntrusivePtr<TBusSessionImpl> TRemoteSessionPtr;
+
+ static void* const WriteCookie = (void*)1;
+ static void* const ReadCookie = (void*)2;
+
+ enum {
+ WAKE_QUOTA_MSG = 0x01,
+ WAKE_QUOTA_BYTES = 0x02
+ };
+
+ struct TWriterTag {};
+ struct TReaderTag {};
+ struct TReconnectTag {};
+ struct TWakeReaderTag {};
+
+ struct TWriterToReaderSocketMessage {
+ TSocket Socket;
+ ui32 SocketVersion;
+
+ TWriterToReaderSocketMessage(TSocket socket, ui32 socketVersion)
+ : Socket(socket)
+ , SocketVersion(socketVersion)
+ {
+ }
+ };
+
+ class TRemoteConnection
+ : public NEventLoop::IEventHandler,
+ public ::NActor::TActor<TRemoteConnection, TWriterTag>,
+ public ::NActor::TActor<TRemoteConnection, TReaderTag>,
+ private ::NActor::TQueueInActor<TRemoteConnection, TWriterToReaderSocketMessage, TReaderTag>,
+ private ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TReconnectTag>,
+ private ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TWakeReaderTag>,
+ public TScheduleActor<TRemoteConnection, TWriterTag> {
+ friend struct TBusSessionImpl;
+ friend class TRemoteClientSession;
+ friend class TRemoteServerSession;
+ friend class ::NActor::TQueueInActor<TRemoteConnection, TWriterToReaderSocketMessage, TReaderTag>;
+ friend class ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TReconnectTag>;
+ friend class ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TWakeReaderTag>;
+
+ protected:
+ ::NActor::TQueueInActor<TRemoteConnection, TWriterToReaderSocketMessage, TReaderTag>* ReaderGetSocketQueue() {
+ return this;
+ }
+
+ ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TReconnectTag>* WriterGetReconnectQueue() {
+ return this;
+ }
+
+ ::NActor::TQueueInActor<TRemoteConnection, ui32, TWriterTag, TWakeReaderTag>* WriterGetWakeQueue() {
+ return this;
+ }
+
+ protected:
+ TRemoteConnection(TRemoteSessionPtr session, ui64 connectionId, TNetAddr addr);
+ ~TRemoteConnection() override;
+
+ virtual void ClearOutgoingQueue(TMessagesPtrs&, bool reconnect /* or shutdown */);
+
+ public:
+ void Send(TNonDestroyingAutoPtr<TBusMessage> msg);
+ void Shutdown(EMessageStatus status);
+
+ inline const TNetAddr& GetAddr() const noexcept;
+
+ private:
+ friend class TScheduleConnect;
+ friend class TWorkIO;
+
+ protected:
+ static size_t MessageSize(TArrayRef<TBusMessagePtrAndHeader>);
+ bool QuotaAcquire(size_t msg, size_t bytes);
+ void QuotaConsume(size_t msg, size_t bytes);
+ void QuotaReturnSelf(size_t items, size_t bytes);
+ bool QuotaReturnValues(size_t items, size_t bytes);
+
+ bool ReaderProcessBuffer();
+ bool ReaderFillBuffer();
+ void ReaderFlushMessages();
+
+ void ReadQuotaWakeup();
+ ui32 WriteWakeFlags() const;
+
+ virtual bool NeedInterruptRead() {
+ return false;
+ }
+
+ public:
+ virtual void TryConnect();
+ void ProcessItem(TReaderTag, ::NActor::TDefaultTag, TWriterToReaderSocketMessage);
+ void ProcessItem(TWriterTag, TReconnectTag, ui32 socketVersion);
+ void ProcessItem(TWriterTag, TWakeReaderTag, ui32 awakeFlags);
+ void Act(TReaderTag);
+ inline void WriterBeforeWriteErrorMessage(TBusMessage*, EMessageStatus);
+ void ClearBeforeSendQueue(EMessageStatus reasonForQueues);
+ void ClearReplyQueue(EMessageStatus reasonForQueues);
+ inline void ProcessBeforeSendQueueMessage(TBusMessage*, TInstant now);
+ void ProcessBeforeSendQueue(TInstant now);
+ void WriterProcessStatusDown();
+ void ReaderProcessStatusDown();
+ void ProcessWriterDown();
+ void DropEnqueuedData(EMessageStatus reason, EMessageStatus reasonForQueues);
+ const TRemoteConnectionWriterStatus& WriterGetStatus();
+ virtual void WriterFillStatus();
+ void WriterFillInFlight();
+ virtual void BeforeTryWrite();
+ void Act(TWriterTag);
+ void ScheduleRead();
+ void ScheduleWrite();
+ void ScheduleShutdownOnServerOrReconnectOnClient(EMessageStatus status, bool writer);
+ void ScheduleShutdown(EMessageStatus status);
+ void WriterFlushBuffer();
+ void WriterFillBuffer();
+ void ReaderSendStatus(TInstant now, bool force = false);
+ const TRemoteConnectionReaderStatus& ReaderFillStatus();
+ void WriterRotateCounters();
+ void WriterSendStatus(TInstant now, bool force = false);
+ void WriterSendStatusIfNecessary(TInstant now);
+ void QuotaReturnAside(size_t items, size_t bytes);
+ virtual void ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) = 0;
+ bool MessageRead(TArrayRef<const char> dataRef, TInstant now);
+ virtual void MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) = 0;
+
+ void CallSerialize(TBusMessage* msg, TBuffer& buffer) const;
+ void SerializeMessage(TBusMessage* msg, TBuffer* data, TMessageCounter* counter) const;
+ TBusMessage* DeserializeMessage(TArrayRef<const char> dataRef, const TBusHeader* header, TMessageCounter* messageCounter, EMessageStatus* status) const;
+
+ void ResetOneWayFlag(TArrayRef<TBusMessage*>);
+
+ inline ::NActor::TActor<TRemoteConnection, TWriterTag>* GetWriterActor() {
+ return this;
+ }
+ inline ::NActor::TActor<TRemoteConnection, TReaderTag>* GetReaderActor() {
+ return this;
+ }
+ inline TScheduleActor<TRemoteConnection, TWriterTag>* GetWriterSchedulerActor() {
+ return this;
+ }
+
+ void WriterErrorMessage(TNonDestroyingAutoPtr<TBusMessage> m, EMessageStatus status);
+ // takes ownership of ms
+ void WriterErrorMessages(const TArrayRef<TBusMessage*> ms, EMessageStatus status);
+
+ void FireClientConnectionEvent(TClientConnectionEvent::EType);
+
+ size_t GetInFlight();
+ size_t GetConnectSyscallsNumForTest();
+
+ bool IsReturnConnectFailedImmediately() {
+ return (bool)AtomicGet(ReturnConnectFailedImmediately);
+ }
+
+ bool IsAlive() const;
+
+ TRemoteSessionPtr Session;
+ TBusProtocol* const Proto;
+ TBusSessionConfig const Config;
+ bool RemovedFromSession;
+ const ui64 ConnectionId;
+ const TNetAddr PeerAddr;
+ const TBusSocketAddr PeerAddrSocketAddr;
+
+ const TInstant CreatedTime;
+ TInstant LastConnectAttempt;
+ TAtomic ReturnConnectFailedImmediately;
+
+ protected:
+ ::NActor::TQueueForActor<TBusMessage*> BeforeSendQueue;
+ TLockFreeStack<TBusHeader> WrongVersionRequests;
+
+ struct TWriterData {
+ TAtomic Down;
+
+ NEventLoop::TChannelPtr Channel;
+ ui32 SocketVersion;
+
+ TRemoteConnectionWriterStatus Status;
+ TInstant StatusLastSendTime;
+
+ TLocalTasks TimeToRotateCounters;
+
+ TAtomic InFlight;
+
+ TTimedMessages SendQueue;
+ ui32 AwakeFlags;
+ EWriterState State;
+ TLeftRightBuffer Buffer;
+ TInstant CorkUntil;
+
+ TSystemEvent ShutdownComplete;
+
+ void SetChannel(NEventLoop::TChannelPtr channel);
+ void DropChannel();
+
+ TWriterData();
+ ~TWriterData();
+ };
+
+ struct TReaderData {
+ TAtomic Down;
+
+ NEventLoop::TChannelPtr Channel;
+ ui32 SocketVersion;
+
+ TRemoteConnectionReaderStatus Status;
+ TInstant StatusLastSendTime;
+
+ TBuffer Buffer;
+ size_t Offset; /* offset in read buffer */
+ size_t MoreBytes; /* more bytes required from socket */
+ TVectorSwaps<TBusMessagePtrAndHeader> ReadMessages;
+
+ TSystemEvent ShutdownComplete;
+
+ bool BufferMore() const noexcept {
+ return MoreBytes > 0;
+ }
+
+ bool HasBytesInBuf(size_t bytes) noexcept;
+ void SetChannel(NEventLoop::TChannelPtr channel);
+ void DropChannel();
+
+ TReaderData();
+ ~TReaderData();
+ };
+
+ // owned by session status actor
+ struct TGranStatus {
+ TGranStatus(TDuration gran)
+ : Writer(gran)
+ , Reader(gran)
+ {
+ }
+
+ TGranUp<TRemoteConnectionWriterStatus> Writer;
+ TGranUp<TRemoteConnectionReaderStatus> Reader;
+ };
+
+ TWriterData WriterData;
+ TReaderData ReaderData;
+ TGranStatus GranStatus;
+ TTokenQuota QuotaMsg;
+ TTokenQuota QuotaBytes;
+
+ size_t MaxBufferSize;
+
+ // client connection only
+ TLockFreeQueueBatch<TBusMessagePtrAndHeader, TVectorSwaps> ReplyQueue;
+
+ EMessageStatus ShutdownReason;
+ };
+
+ inline const TNetAddr& TRemoteConnection::GetAddr() const noexcept {
+ return PeerAddr;
+ }
+
+ typedef TIntrusivePtr<TRemoteConnection> TRemoteConnectionPtr;
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_connection_status.cpp b/library/cpp/messagebus/remote_connection_status.cpp
new file mode 100644
index 0000000000..2c48b2a287
--- /dev/null
+++ b/library/cpp/messagebus/remote_connection_status.cpp
@@ -0,0 +1,265 @@
+#include "remote_connection_status.h"
+
+#include "key_value_printer.h"
+
+#include <library/cpp/messagebus/monitoring/mon_proto.pb.h>
+
+#include <util/stream/format.h>
+#include <util/stream/output.h>
+#include <util/system/yassert.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+template <typename T>
+static void Add(T& thiz, const T& that) {
+ thiz += that;
+}
+
+template <typename T>
+static void Max(T& thiz, const T& that) {
+ if (that > thiz) {
+ thiz = that;
+ }
+}
+
+template <typename T>
+static void AssertZero(T& thiz, const T& that) {
+ Y_ASSERT(thiz == T());
+ Y_UNUSED(that);
+}
+
+TDurationCounter::TDurationCounter()
+ : DURATION_COUNTER_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TDuration TDurationCounter::AvgDuration() const {
+ if (Count == 0) {
+ return TDuration::Zero();
+ } else {
+ return SumDuration / Count;
+ }
+}
+
+TDurationCounter& TDurationCounter::operator+=(const TDurationCounter& that) {
+ DURATION_COUNTER_MAP(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+TString TDurationCounter::ToString() const {
+ if (Count == 0) {
+ return "0";
+ } else {
+ TStringStream ss;
+ ss << "avg: " << AvgDuration() << ", max: " << MaxDuration << ", count: " << Count;
+ return ss.Str();
+ }
+}
+
+TRemoteConnectionStatusBase::TRemoteConnectionStatusBase()
+ : REMOTE_CONNECTION_STATUS_BASE_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TRemoteConnectionStatusBase& TRemoteConnectionStatusBase ::operator+=(const TRemoteConnectionStatusBase& that) {
+ REMOTE_CONNECTION_STATUS_BASE_MAP(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+TRemoteConnectionIncrementalStatusBase::TRemoteConnectionIncrementalStatusBase()
+ : REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TRemoteConnectionIncrementalStatusBase& TRemoteConnectionIncrementalStatusBase::operator+=(
+ const TRemoteConnectionIncrementalStatusBase& that) {
+ REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+TRemoteConnectionReaderIncrementalStatus::TRemoteConnectionReaderIncrementalStatus()
+ : REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TRemoteConnectionReaderIncrementalStatus& TRemoteConnectionReaderIncrementalStatus::operator+=(
+ const TRemoteConnectionReaderIncrementalStatus& that) {
+ TRemoteConnectionIncrementalStatusBase::operator+=(that);
+ REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+TRemoteConnectionReaderStatus::TRemoteConnectionReaderStatus()
+ : REMOTE_CONNECTION_READER_STATUS_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TRemoteConnectionReaderStatus& TRemoteConnectionReaderStatus::operator+=(const TRemoteConnectionReaderStatus& that) {
+ TRemoteConnectionStatusBase::operator+=(that);
+ REMOTE_CONNECTION_READER_STATUS_MAP(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+TRemoteConnectionWriterIncrementalStatus::TRemoteConnectionWriterIncrementalStatus()
+ : REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TRemoteConnectionWriterIncrementalStatus& TRemoteConnectionWriterIncrementalStatus::operator+=(
+ const TRemoteConnectionWriterIncrementalStatus& that) {
+ TRemoteConnectionIncrementalStatusBase::operator+=(that);
+ REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+TRemoteConnectionWriterStatus::TRemoteConnectionWriterStatus()
+ : REMOTE_CONNECTION_WRITER_STATUS(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TRemoteConnectionWriterStatus& TRemoteConnectionWriterStatus::operator+=(const TRemoteConnectionWriterStatus& that) {
+ TRemoteConnectionStatusBase::operator+=(that);
+ REMOTE_CONNECTION_WRITER_STATUS(STRUCT_FIELD_ADD, )
+ return *this;
+}
+
+size_t TRemoteConnectionWriterStatus::GetInFlight() const {
+ return SendQueueSize + AckMessagesSize;
+}
+
+TConnectionStatusMonRecord TRemoteConnectionStatus::GetStatusProtobuf() const {
+ TConnectionStatusMonRecord status;
+
+ // TODO: fill unfilled fields
+ status.SetSendQueueSize(WriterStatus.SendQueueSize);
+ status.SetAckMessagesSize(WriterStatus.AckMessagesSize);
+ // status.SetErrorCount();
+ // status.SetWriteBytes();
+ // status.SetWriteBytesCompressed();
+ // status.SetWriteMessages();
+ status.SetWriteSyscalls(WriterStatus.Incremental.NetworkOps);
+ status.SetWriteActs(WriterStatus.Acts);
+ // status.SetReadBytes();
+ // status.SetReadBytesCompressed();
+ // status.SetReadMessages();
+ status.SetReadSyscalls(ReaderStatus.Incremental.NetworkOps);
+ status.SetReadActs(ReaderStatus.Acts);
+
+ TMessageStatusCounter sumStatusCounter;
+ sumStatusCounter += WriterStatus.Incremental.StatusCounter;
+ sumStatusCounter += ReaderStatus.Incremental.StatusCounter;
+ sumStatusCounter.FillErrorsProtobuf(&status);
+
+ return status;
+}
+
+TString TRemoteConnectionStatus::PrintToString() const {
+ TStringStream ss;
+
+ TKeyValuePrinter p;
+
+ if (!Summary) {
+ // TODO: print MyAddr too, but only if it is set
+ ss << WriterStatus.PeerAddr << " (" << WriterStatus.ConnectionId << ")"
+ << ", writefd=" << WriterStatus.Fd
+ << ", readfd=" << ReaderStatus.Fd
+ << Endl;
+ if (WriterStatus.Connected) {
+ p.AddRow("connect time", WriterStatus.ConnectTime.ToString());
+ p.AddRow("writer state", ToCString(WriterStatus.State));
+ } else {
+ ss << "not connected";
+ if (WriterStatus.ConnectError != 0) {
+ ss << ", last connect error: " << LastSystemErrorText(WriterStatus.ConnectError);
+ }
+ ss << Endl;
+ }
+ }
+ if (!Server) {
+ p.AddRow("connect syscalls", WriterStatus.ConnectSyscalls);
+ }
+
+ p.AddRow("send queue", LeftPad(WriterStatus.SendQueueSize, 6));
+
+ if (Server) {
+ p.AddRow("quota msg", LeftPad(ReaderStatus.QuotaMsg, 6));
+ p.AddRow("quota bytes", LeftPad(ReaderStatus.QuotaBytes, 6));
+ p.AddRow("quota exhausted", LeftPad(ReaderStatus.QuotaExhausted, 6));
+ p.AddRow("reader wakeups", LeftPad(WriterStatus.ReaderWakeups, 6));
+ } else {
+ p.AddRow("ack messages", LeftPad(WriterStatus.AckMessagesSize, 6));
+ }
+
+ p.AddRow("written", WriterStatus.Incremental.MessageCounter.ToString(false));
+ p.AddRow("read", ReaderStatus.Incremental.MessageCounter.ToString(true));
+
+ p.AddRow("write syscalls", LeftPad(WriterStatus.Incremental.NetworkOps, 12));
+ p.AddRow("read syscalls", LeftPad(ReaderStatus.Incremental.NetworkOps, 12));
+
+ p.AddRow("write acts", LeftPad(WriterStatus.Acts, 12));
+ p.AddRow("read acts", LeftPad(ReaderStatus.Acts, 12));
+
+ p.AddRow("write buffer cap", LeftPad(WriterStatus.BufferSize, 12));
+ p.AddRow("read buffer cap", LeftPad(ReaderStatus.BufferSize, 12));
+
+ p.AddRow("write buffer drops", LeftPad(WriterStatus.Incremental.BufferDrops, 10));
+ p.AddRow("read buffer drops", LeftPad(ReaderStatus.Incremental.BufferDrops, 10));
+
+ if (Server) {
+ p.AddRow("process dur", WriterStatus.DurationCounterPrev.ToString());
+ }
+
+ ss << p.PrintToString();
+
+ if (false && Server) {
+ ss << "time histogram:\n";
+ ss << WriterStatus.Incremental.ProcessDurationHistogram.PrintToString();
+ }
+
+ TMessageStatusCounter sumStatusCounter;
+ sumStatusCounter += WriterStatus.Incremental.StatusCounter;
+ sumStatusCounter += ReaderStatus.Incremental.StatusCounter;
+
+ ss << sumStatusCounter.PrintToString();
+
+ return ss.Str();
+}
+
+TRemoteConnectionStatus::TRemoteConnectionStatus()
+ : REMOTE_CONNECTION_STATUS_MAP(STRUCT_FIELD_INIT_DEFAULT, COMMA)
+{
+}
+
+TString TSessionDumpStatus::PrintToString() const {
+ if (Shutdown) {
+ return "shutdown";
+ }
+
+ TStringStream ss;
+ ss << Head;
+ if (ConnectionStatusSummary.Server) {
+ ss << "\n";
+ ss << Acceptors;
+ }
+ ss << "\n";
+ ss << "connections summary:" << Endl;
+ ss << ConnectionsSummary;
+ if (!!Connections) {
+ ss << "\n";
+ ss << Connections;
+ }
+ ss << "\n";
+ ss << Config.PrintToString();
+ return ss.Str();
+}
+
+TString TBusMessageQueueStatus::PrintToString() const {
+ TStringStream ss;
+ ss << "work queue:\n";
+ ss << ExecutorStatus.Status;
+ ss << "\n";
+ ss << "queue config:\n";
+ ss << Config.PrintToString();
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/remote_connection_status.h b/library/cpp/messagebus/remote_connection_status.h
new file mode 100644
index 0000000000..5db10e51ea
--- /dev/null
+++ b/library/cpp/messagebus/remote_connection_status.h
@@ -0,0 +1,214 @@
+#pragma once
+
+#include "codegen.h"
+#include "duration_histogram.h"
+#include "message_counter.h"
+#include "message_status_counter.h"
+#include "queue_config.h"
+#include "session_config.h"
+
+#include <library/cpp/messagebus/actor/executor.h>
+
+#include <library/cpp/deprecated/enum_codegen/enum_codegen.h>
+
+namespace NBus {
+ class TConnectionStatusMonRecord;
+}
+
+namespace NBus {
+ namespace NPrivate {
+#define WRITER_STATE_MAP(XX) \
+ XX(WRITER_UNKNOWN) \
+ XX(WRITER_FILLING) \
+ XX(WRITER_FLUSHING) \
+ /**/
+
+ // TODO: move elsewhere
+ enum EWriterState {
+ WRITER_STATE_MAP(ENUM_VALUE_GEN_NO_VALUE)
+ };
+
+ ENUM_TO_STRING(EWriterState, WRITER_STATE_MAP)
+
+#define STRUCT_FIELD_ADD(name, type, func) func(name, that.name);
+
+ template <typename T>
+ void Reset(T& t) {
+ t.~T();
+ new (&t) T();
+ }
+
+#define DURATION_COUNTER_MAP(XX, comma) \
+ XX(Count, unsigned, Add) \
+ comma \
+ XX(SumDuration, TDuration, Add) comma \
+ XX(MaxDuration, TDuration, Max) /**/
+
+ struct TDurationCounter {
+ DURATION_COUNTER_MAP(STRUCT_FIELD_GEN, )
+
+ TDuration AvgDuration() const;
+
+ TDurationCounter();
+
+ void AddDuration(TDuration d) {
+ Count += 1;
+ SumDuration += d;
+ if (d > MaxDuration) {
+ MaxDuration = d;
+ }
+ }
+
+ TDurationCounter& operator+=(const TDurationCounter&);
+
+ TString ToString() const;
+ };
+
+#define REMOTE_CONNECTION_STATUS_BASE_MAP(XX, comma) \
+ XX(ConnectionId, ui64, AssertZero) \
+ comma \
+ XX(Fd, SOCKET, AssertZero) comma \
+ XX(Acts, ui64, Add) comma \
+ XX(BufferSize, ui64, Add) /**/
+
+ struct TRemoteConnectionStatusBase {
+ REMOTE_CONNECTION_STATUS_BASE_MAP(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionStatusBase& operator+=(const TRemoteConnectionStatusBase&);
+
+ TRemoteConnectionStatusBase();
+ };
+
+#define REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(XX, comma) \
+ XX(BufferDrops, unsigned, Add) \
+ comma \
+ XX(NetworkOps, unsigned, Add) /**/
+
+ struct TRemoteConnectionIncrementalStatusBase {
+ REMOTE_CONNECTION_INCREMENTAL_STATUS_BASE_MAP(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionIncrementalStatusBase& operator+=(const TRemoteConnectionIncrementalStatusBase&);
+
+ TRemoteConnectionIncrementalStatusBase();
+ };
+
+#define REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(XX, comma) \
+ XX(MessageCounter, TMessageCounter, Add) \
+ comma \
+ XX(StatusCounter, TMessageStatusCounter, Add) /**/
+
+ struct TRemoteConnectionReaderIncrementalStatus: public TRemoteConnectionIncrementalStatusBase {
+ REMOTE_CONNECTION_READER_INCREMENTAL_STATUS_MAP(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionReaderIncrementalStatus& operator+=(const TRemoteConnectionReaderIncrementalStatus&);
+
+ TRemoteConnectionReaderIncrementalStatus();
+ };
+
+#define REMOTE_CONNECTION_READER_STATUS_MAP(XX, comma) \
+ XX(QuotaMsg, size_t, Add) \
+ comma \
+ XX(QuotaBytes, size_t, Add) comma \
+ XX(QuotaExhausted, size_t, Add) comma \
+ XX(Incremental, TRemoteConnectionReaderIncrementalStatus, Add) /**/
+
+ struct TRemoteConnectionReaderStatus: public TRemoteConnectionStatusBase {
+ REMOTE_CONNECTION_READER_STATUS_MAP(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionReaderStatus& operator+=(const TRemoteConnectionReaderStatus&);
+
+ TRemoteConnectionReaderStatus();
+ };
+
+#define REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(XX, comma) \
+ XX(MessageCounter, TMessageCounter, Add) \
+ comma \
+ XX(StatusCounter, TMessageStatusCounter, Add) comma \
+ XX(ProcessDurationHistogram, TDurationHistogram, Add) /**/
+
+ struct TRemoteConnectionWriterIncrementalStatus: public TRemoteConnectionIncrementalStatusBase {
+ REMOTE_CONNECTION_WRITER_INCREMENTAL_STATUS(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionWriterIncrementalStatus& operator+=(const TRemoteConnectionWriterIncrementalStatus&);
+
+ TRemoteConnectionWriterIncrementalStatus();
+ };
+
+#define REMOTE_CONNECTION_WRITER_STATUS(XX, comma) \
+ XX(Connected, bool, AssertZero) \
+ comma \
+ XX(ConnectTime, TInstant, AssertZero) comma /* either connect time on client or accept time on server */ \
+ XX(ConnectError, int, AssertZero) comma \
+ XX(ConnectSyscalls, unsigned, Add) comma \
+ XX(PeerAddr, TNetAddr, AssertZero) comma \
+ XX(MyAddr, TNetAddr, AssertZero) comma \
+ XX(State, EWriterState, AssertZero) comma \
+ XX(SendQueueSize, size_t, Add) comma \
+ XX(AckMessagesSize, size_t, Add) comma /* client only */ \
+ XX(DurationCounter, TDurationCounter, Add) comma /* server only */ \
+ XX(DurationCounterPrev, TDurationCounter, Add) comma /* server only */ \
+ XX(Incremental, TRemoteConnectionWriterIncrementalStatus, Add) comma \
+ XX(ReaderWakeups, size_t, Add) /**/
+
+ struct TRemoteConnectionWriterStatus: public TRemoteConnectionStatusBase {
+ REMOTE_CONNECTION_WRITER_STATUS(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionWriterStatus();
+
+ TRemoteConnectionWriterStatus& operator+=(const TRemoteConnectionWriterStatus&);
+
+ size_t GetInFlight() const;
+ };
+
+#define REMOTE_CONNECTION_STATUS_MAP(XX, comma) \
+ XX(Summary, bool) \
+ comma \
+ XX(Server, bool) /**/
+
+ struct TRemoteConnectionStatus {
+ REMOTE_CONNECTION_STATUS_MAP(STRUCT_FIELD_GEN, )
+
+ TRemoteConnectionReaderStatus ReaderStatus;
+ TRemoteConnectionWriterStatus WriterStatus;
+
+ TRemoteConnectionStatus();
+
+ TString PrintToString() const;
+ TConnectionStatusMonRecord GetStatusProtobuf() const;
+ };
+
+ struct TBusSessionStatus {
+ size_t InFlightCount;
+ size_t InFlightSize;
+ bool InputPaused;
+
+ TBusSessionStatus();
+ };
+
+ struct TSessionDumpStatus {
+ bool Shutdown;
+ TString Head;
+ TString Acceptors;
+ TString ConnectionsSummary;
+ TString Connections;
+ TBusSessionStatus Status;
+ TRemoteConnectionStatus ConnectionStatusSummary;
+ TBusSessionConfig Config;
+
+ TSessionDumpStatus()
+ : Shutdown(false)
+ {
+ }
+
+ TString PrintToString() const;
+ };
+
+ // without sessions
+ struct TBusMessageQueueStatus {
+ NActor::NPrivate::TExecutorStatus ExecutorStatus;
+ TBusQueueConfig Config;
+
+ TString PrintToString() const;
+ };
+ }
+}
diff --git a/library/cpp/messagebus/remote_server_connection.cpp b/library/cpp/messagebus/remote_server_connection.cpp
new file mode 100644
index 0000000000..74be34ded9
--- /dev/null
+++ b/library/cpp/messagebus/remote_server_connection.cpp
@@ -0,0 +1,73 @@
+#include "remote_server_connection.h"
+
+#include "mb_lwtrace.h"
+#include "remote_server_session.h"
+
+#include <util/generic/cast.h>
+
+LWTRACE_USING(LWTRACE_MESSAGEBUS_PROVIDER)
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TRemoteServerConnection::TRemoteServerConnection(TRemoteServerSessionPtr session, ui64 id, TNetAddr addr)
+ : TRemoteConnection(session.Get(), id, addr)
+{
+}
+
+void TRemoteServerConnection::Init(SOCKET socket, TInstant now) {
+ WriterData.Status.ConnectTime = now;
+ WriterData.Status.Connected = true;
+
+ Y_VERIFY(socket != INVALID_SOCKET, "must be a valid socket");
+
+ TSocket readSocket(socket);
+ TSocket writeSocket = readSocket;
+
+ // this must not be done in constructor, because if event loop is stopped,
+ // this is deleted
+ WriterData.SetChannel(Session->WriteEventLoop.Register(writeSocket, this, WriteCookie));
+ WriterData.SocketVersion = 1;
+
+ ReaderGetSocketQueue()->EnqueueAndSchedule(TWriterToReaderSocketMessage(readSocket, WriterData.SocketVersion));
+}
+
+TRemoteServerSession* TRemoteServerConnection::GetSession() {
+ return CheckedCast<TRemoteServerSession*>(Session.Get());
+}
+
+void TRemoteServerConnection::HandleEvent(SOCKET socket, void* cookie) {
+ Y_UNUSED(socket);
+ Y_ASSERT(cookie == ReadCookie || cookie == WriteCookie);
+ if (cookie == ReadCookie) {
+ GetSession()->ServerOwnedMessages.Wait();
+ ScheduleRead();
+ } else {
+ ScheduleWrite();
+ }
+}
+
+bool TRemoteServerConnection::NeedInterruptRead() {
+ return !GetSession()->ServerOwnedMessages.TryWait();
+}
+
+void TRemoteServerConnection::MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) {
+ TInstant now = TInstant::Now();
+
+ GetSession()->ReleaseInWorkResponses(messages);
+ for (auto& message : messages) {
+ TInstant recvTime = message.MessagePtr->RecvTime;
+ GetSession()->ServerHandler->OnSent(message.MessagePtr.Release());
+ TDuration d = now - recvTime;
+ WriterData.Status.DurationCounter.AddDuration(d);
+ WriterData.Status.Incremental.ProcessDurationHistogram.AddTime(d);
+ }
+}
+
+void TRemoteServerConnection::ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) {
+ TBusHeader header(dataRef);
+ // TODO: full version hex
+ LWPROBE(ServerUnknownVersion, ToString(PeerAddr), header.GetVersionInternal());
+ WrongVersionRequests.Enqueue(header);
+ GetWriterActor()->Schedule();
+}
diff --git a/library/cpp/messagebus/remote_server_connection.h b/library/cpp/messagebus/remote_server_connection.h
new file mode 100644
index 0000000000..63d7f20646
--- /dev/null
+++ b/library/cpp/messagebus/remote_server_connection.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "session_impl.h"
+
+#include <util/generic/object_counter.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TRemoteServerConnection: public TRemoteConnection {
+ friend struct TBusSessionImpl;
+ friend class TRemoteServerSession;
+
+ TObjectCounter<TRemoteServerConnection> ObjectCounter;
+
+ public:
+ TRemoteServerConnection(TRemoteServerSessionPtr session, ui64 id, TNetAddr addr);
+
+ void Init(SOCKET socket, TInstant now);
+
+ inline TRemoteServerSession* GetSession();
+
+ void HandleEvent(SOCKET socket, void* cookie) override;
+
+ bool NeedInterruptRead() override;
+
+ void MessageSent(TArrayRef<TBusMessagePtrAndHeader> messages) override;
+
+ void ReaderProcessMessageUnknownVersion(TArrayRef<const char> dataRef) override;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_server_session.cpp b/library/cpp/messagebus/remote_server_session.cpp
new file mode 100644
index 0000000000..6abbf88a60
--- /dev/null
+++ b/library/cpp/messagebus/remote_server_session.cpp
@@ -0,0 +1,206 @@
+#include "remote_server_session.h"
+
+#include "remote_connection.h"
+#include "remote_server_connection.h"
+
+#include <library/cpp/messagebus/actor/temp_tls_vector.h>
+
+#include <util/generic/cast.h>
+#include <util/stream/output.h>
+#include <util/system/yassert.h>
+
+#include <typeinfo>
+
+using namespace NActor;
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TRemoteServerSession::TRemoteServerSession(TBusMessageQueue* queue,
+ TBusProtocol* proto, IBusServerHandler* handler,
+ const TBusServerSessionConfig& config, const TString& name)
+ : TBusSessionImpl(false, queue, proto, handler, config, name)
+ , ServerOwnedMessages(config.MaxInFlight, config.MaxInFlightBySize, "ServerOwnedMessages")
+ , ServerHandler(handler)
+{
+ if (config.PerConnectionMaxInFlightBySize > 0) {
+ if (config.PerConnectionMaxInFlightBySize < config.MaxMessageSize)
+ ythrow yexception()
+ << "too low PerConnectionMaxInFlightBySize value";
+ }
+}
+
+namespace NBus {
+ namespace NPrivate {
+ class TInvokeOnMessage: public IWorkItem {
+ private:
+ TRemoteServerSession* RemoteServerSession;
+ TBusMessagePtrAndHeader Request;
+ TIntrusivePtr<TRemoteServerConnection> Connection;
+
+ public:
+ TInvokeOnMessage(TRemoteServerSession* session, TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& connection)
+ : RemoteServerSession(session)
+ {
+ Y_ASSERT(!!connection);
+ Connection.Swap(connection);
+
+ Request.Swap(request);
+ }
+
+ void DoWork() override {
+ THolder<TInvokeOnMessage> holder(this);
+ RemoteServerSession->InvokeOnMessage(Request, Connection);
+ // TODO: TRemoteServerSessionSemaphore should be enough
+ RemoteServerSession->JobCount.Decrement();
+ }
+ };
+
+ }
+}
+
+void TRemoteServerSession::OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& messages) {
+ AcquireInWorkRequests(messages);
+
+ bool executeInPool = Config.ExecuteOnMessageInWorkerPool;
+
+ TTempTlsVector< ::IWorkItem*> workQueueTemp;
+
+ if (executeInPool) {
+ workQueueTemp.GetVector()->reserve(messages.size());
+ }
+
+ for (auto& message : messages) {
+ // TODO: incref once
+ TIntrusivePtr<TRemoteServerConnection> connection(CheckedCast<TRemoteServerConnection*>(c));
+ if (executeInPool) {
+ workQueueTemp.GetVector()->push_back(new TInvokeOnMessage(this, message, connection));
+ } else {
+ InvokeOnMessage(message, connection);
+ }
+ }
+
+ if (executeInPool) {
+ JobCount.Add(workQueueTemp.GetVector()->size());
+ Queue->EnqueueWork(*workQueueTemp.GetVector());
+ }
+}
+
+void TRemoteServerSession::InvokeOnMessage(TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& conn) {
+ if (Y_UNLIKELY(AtomicGet(Down))) {
+ ReleaseInWorkRequests(*conn.Get(), request.MessagePtr.Get());
+ InvokeOnError(request.MessagePtr.Release(), MESSAGE_SHUTDOWN);
+ } else {
+ TWhatThreadDoesPushPop pp("OnMessage");
+
+ TBusIdentity ident;
+
+ ident.Connection.Swap(conn);
+ request.MessagePtr->GetIdentity(ident);
+
+ Y_ASSERT(request.MessagePtr->LocalFlags & MESSAGE_IN_WORK);
+ DoSwap(request.MessagePtr->LocalFlags, ident.LocalFlags);
+
+ ident.RecvTime = request.MessagePtr->RecvTime;
+
+#ifndef NDEBUG
+ auto& message = *request.MessagePtr;
+ ident.SetMessageType(typeid(message));
+#endif
+
+ TOnMessageContext context(request.MessagePtr.Release(), ident, this);
+ ServerHandler->OnMessage(context);
+ }
+}
+
+EMessageStatus TRemoteServerSession::ForgetRequest(const TBusIdentity& ident) {
+ ReleaseInWork(const_cast<TBusIdentity&>(ident));
+
+ return MESSAGE_OK;
+}
+
+EMessageStatus TRemoteServerSession::SendReply(const TBusIdentity& ident, TBusMessage* reply) {
+ reply->CheckClean();
+
+ ConvertInWork(const_cast<TBusIdentity&>(ident), reply);
+
+ reply->RecvTime = ident.RecvTime;
+
+ ident.Connection->Send(reply);
+
+ return MESSAGE_OK;
+}
+
+int TRemoteServerSession::GetInFlight() const noexcept {
+ return ServerOwnedMessages.GetCurrentCount();
+}
+
+void TRemoteServerSession::FillStatus() {
+ TBusSessionImpl::FillStatus();
+
+ // TODO: weird
+ StatusData.Status.InFlightCount = ServerOwnedMessages.GetCurrentCount();
+ StatusData.Status.InFlightSize = ServerOwnedMessages.GetCurrentSize();
+ StatusData.Status.InputPaused = ServerOwnedMessages.IsLocked();
+}
+
+void TRemoteServerSession::AcquireInWorkRequests(TArrayRef<const TBusMessagePtrAndHeader> messages) {
+ TAtomicBase size = 0;
+ for (auto message = messages.begin(); message != messages.end(); ++message) {
+ Y_ASSERT(!(message->MessagePtr->LocalFlags & MESSAGE_IN_WORK));
+ message->MessagePtr->LocalFlags |= MESSAGE_IN_WORK;
+ size += message->MessagePtr->GetHeader()->Size;
+ }
+
+ ServerOwnedMessages.IncrementMultiple(messages.size(), size);
+}
+
+void TRemoteServerSession::ReleaseInWorkResponses(TArrayRef<const TBusMessagePtrAndHeader> responses) {
+ TAtomicBase size = 0;
+ for (auto response = responses.begin(); response != responses.end(); ++response) {
+ Y_ASSERT((response->MessagePtr->LocalFlags & MESSAGE_REPLY_IS_BEGING_SENT));
+ response->MessagePtr->LocalFlags &= ~MESSAGE_REPLY_IS_BEGING_SENT;
+ size += response->MessagePtr->RequestSize;
+ }
+
+ ServerOwnedMessages.ReleaseMultiple(responses.size(), size);
+}
+
+void TRemoteServerSession::ReleaseInWorkRequests(TRemoteConnection& con, TBusMessage* request) {
+ Y_ASSERT((request->LocalFlags & MESSAGE_IN_WORK));
+ request->LocalFlags &= ~MESSAGE_IN_WORK;
+
+ const size_t size = request->GetHeader()->Size;
+
+ con.QuotaReturnAside(1, size);
+ ServerOwnedMessages.ReleaseMultiple(1, size);
+}
+
+void TRemoteServerSession::ReleaseInWork(TBusIdentity& ident) {
+ ident.SetInWork(false);
+ ident.Connection->QuotaReturnAside(1, ident.Size);
+
+ ServerOwnedMessages.ReleaseMultiple(1, ident.Size);
+}
+
+void TRemoteServerSession::ConvertInWork(TBusIdentity& req, TBusMessage* reply) {
+ reply->SetIdentity(req);
+
+ req.SetInWork(false);
+ Y_ASSERT(!(reply->LocalFlags & MESSAGE_REPLY_IS_BEGING_SENT));
+ reply->LocalFlags |= MESSAGE_REPLY_IS_BEGING_SENT;
+ reply->RequestSize = req.Size;
+}
+
+void TRemoteServerSession::Shutdown() {
+ ServerOwnedMessages.Stop();
+ TBusSessionImpl::Shutdown();
+}
+
+void TRemoteServerSession::PauseInput(bool pause) {
+ ServerOwnedMessages.PauseByUsed(pause);
+}
+
+unsigned TRemoteServerSession::GetActualListenPort() {
+ Y_VERIFY(Config.ListenPort > 0, "state check");
+ return Config.ListenPort;
+}
diff --git a/library/cpp/messagebus/remote_server_session.h b/library/cpp/messagebus/remote_server_session.h
new file mode 100644
index 0000000000..f5c266a7f7
--- /dev/null
+++ b/library/cpp/messagebus/remote_server_session.h
@@ -0,0 +1,54 @@
+#pragma once
+
+#include "remote_server_session_semaphore.h"
+#include "session_impl.h"
+
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4250) // 'NBus::NPrivate::TRemoteClientSession' : inherits 'NBus::NPrivate::TBusSessionImpl::NBus::NPrivate::TBusSessionImpl::GetConfig' via dominance
+#endif
+
+namespace NBus {
+ namespace NPrivate {
+ class TRemoteServerSession: public TBusServerSession, public TBusSessionImpl {
+ friend class TRemoteServerConnection;
+
+ private:
+ TObjectCounter<TRemoteServerSession> ObjectCounter;
+
+ TRemoteServerSessionSemaphore ServerOwnedMessages;
+ IBusServerHandler* const ServerHandler;
+
+ public:
+ TRemoteServerSession(TBusMessageQueue* queue, TBusProtocol* proto,
+ IBusServerHandler* handler,
+ const TBusSessionConfig& config, const TString& name);
+
+ void OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) override;
+ void InvokeOnMessage(TBusMessagePtrAndHeader& request, TIntrusivePtr<TRemoteServerConnection>& conn);
+
+ EMessageStatus SendReply(const TBusIdentity& ident, TBusMessage* pRep) override;
+
+ EMessageStatus ForgetRequest(const TBusIdentity& ident) override;
+
+ int GetInFlight() const noexcept override;
+ void FillStatus() override;
+
+ void Shutdown() override;
+
+ void PauseInput(bool pause) override;
+ unsigned GetActualListenPort() override;
+
+ void AcquireInWorkRequests(TArrayRef<const TBusMessagePtrAndHeader> requests);
+ void ReleaseInWorkResponses(TArrayRef<const TBusMessagePtrAndHeader> responses);
+ void ReleaseInWorkRequests(TRemoteConnection&, TBusMessage*);
+ void ReleaseInWork(TBusIdentity&);
+ void ConvertInWork(TBusIdentity& req, TBusMessage* reply);
+ };
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+ }
+}
diff --git a/library/cpp/messagebus/remote_server_session_semaphore.cpp b/library/cpp/messagebus/remote_server_session_semaphore.cpp
new file mode 100644
index 0000000000..6094a3586e
--- /dev/null
+++ b/library/cpp/messagebus/remote_server_session_semaphore.cpp
@@ -0,0 +1,59 @@
+#include "remote_server_session_semaphore.h"
+
+#include <util/stream/output.h>
+#include <util/system/yassert.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TRemoteServerSessionSemaphore::TRemoteServerSessionSemaphore(
+ TAtomicBase limitCount, TAtomicBase limitSize, const char* name)
+ : Name(name)
+ , LimitCount(limitCount)
+ , LimitSize(limitSize)
+ , CurrentCount(0)
+ , CurrentSize(0)
+ , PausedByUser(0)
+ , StopSignal(0)
+{
+ Y_VERIFY(limitCount > 0, "limit must be > 0");
+ Y_UNUSED(Name);
+}
+
+TRemoteServerSessionSemaphore::~TRemoteServerSessionSemaphore() {
+ Y_VERIFY(AtomicGet(CurrentCount) == 0);
+ // TODO: fix spider and enable
+ //Y_VERIFY(AtomicGet(CurrentSize) == 0);
+}
+
+bool TRemoteServerSessionSemaphore::TryWait() {
+ if (Y_UNLIKELY(AtomicGet(StopSignal)))
+ return true;
+ if (AtomicGet(PausedByUser))
+ return false;
+ if (AtomicGet(CurrentCount) < LimitCount && (LimitSize < 0 || AtomicGet(CurrentSize) < LimitSize))
+ return true;
+ return false;
+}
+
+void TRemoteServerSessionSemaphore::IncrementMultiple(TAtomicBase count, TAtomicBase size) {
+ AtomicAdd(CurrentCount, count);
+ AtomicAdd(CurrentSize, size);
+ Updated();
+}
+
+void TRemoteServerSessionSemaphore::ReleaseMultiple(TAtomicBase count, TAtomicBase size) {
+ AtomicSub(CurrentCount, count);
+ AtomicSub(CurrentSize, size);
+ Updated();
+}
+
+void TRemoteServerSessionSemaphore::Stop() {
+ AtomicSet(StopSignal, 1);
+ Updated();
+}
+
+void TRemoteServerSessionSemaphore::PauseByUsed(bool pause) {
+ AtomicSet(PausedByUser, pause);
+ Updated();
+}
diff --git a/library/cpp/messagebus/remote_server_session_semaphore.h b/library/cpp/messagebus/remote_server_session_semaphore.h
new file mode 100644
index 0000000000..de714fd342
--- /dev/null
+++ b/library/cpp/messagebus/remote_server_session_semaphore.h
@@ -0,0 +1,42 @@
+#pragma once
+
+#include "cc_semaphore.h"
+
+#include <util/generic/noncopyable.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TRemoteServerSessionSemaphore: public TComplexConditionSemaphore<TRemoteServerSessionSemaphore> {
+ private:
+ const char* const Name;
+
+ TAtomicBase const LimitCount;
+ TAtomicBase const LimitSize;
+ TAtomic CurrentCount;
+ TAtomic CurrentSize;
+ TAtomic PausedByUser;
+ TAtomic StopSignal;
+
+ public:
+ TRemoteServerSessionSemaphore(TAtomicBase limitCount, TAtomicBase limitSize, const char* name = "unnamed");
+ ~TRemoteServerSessionSemaphore();
+
+ TAtomicBase GetCurrentCount() const {
+ return AtomicGet(CurrentCount);
+ }
+ TAtomicBase GetCurrentSize() const {
+ return AtomicGet(CurrentSize);
+ }
+
+ void IncrementMultiple(TAtomicBase count, TAtomicBase size);
+ bool TryWait();
+ void ReleaseMultiple(TAtomicBase count, TAtomicBase size);
+ void Stop();
+ void PauseByUsed(bool pause);
+
+ private:
+ void CheckNeedToUnlock();
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/scheduler/scheduler.cpp b/library/cpp/messagebus/scheduler/scheduler.cpp
new file mode 100644
index 0000000000..5a5fe52894
--- /dev/null
+++ b/library/cpp/messagebus/scheduler/scheduler.cpp
@@ -0,0 +1,119 @@
+#include "scheduler.h"
+
+#include <util/datetime/base.h>
+#include <util/generic/algorithm.h>
+#include <util/generic/yexception.h>
+
+//#include "dummy_debugger.h"
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+class TScheduleDeadlineCompare {
+public:
+ bool operator()(const IScheduleItemAutoPtr& i1, const IScheduleItemAutoPtr& i2) const noexcept {
+ return i1->GetScheduleTime() > i2->GetScheduleTime();
+ }
+};
+
+TScheduler::TScheduler()
+ : StopThread(false)
+ , Thread([&] { this->SchedulerThread(); })
+{
+}
+
+TScheduler::~TScheduler() {
+ Y_VERIFY(StopThread, "state check");
+}
+
+size_t TScheduler::Size() const {
+ TGuard<TLock> guard(Lock);
+ return Items.size() + (!!NextItem ? 1 : 0);
+}
+
+void TScheduler::Stop() {
+ {
+ TGuard<TLock> guard(Lock);
+ Y_VERIFY(!StopThread, "Scheduler already stopped");
+ StopThread = true;
+ CondVar.Signal();
+ }
+ Thread.Get();
+
+ if (!!NextItem) {
+ NextItem.Destroy();
+ }
+
+ for (auto& item : Items) {
+ item.Destroy();
+ }
+}
+
+void TScheduler::Schedule(TAutoPtr<IScheduleItem> i) {
+ TGuard<TLock> lock(Lock);
+ if (StopThread)
+ return;
+
+ if (!!NextItem) {
+ if (i->GetScheduleTime() < NextItem->GetScheduleTime()) {
+ DoSwap(i, NextItem);
+ }
+ }
+
+ Items.push_back(i);
+ PushHeap(Items.begin(), Items.end(), TScheduleDeadlineCompare());
+
+ FillNextItem();
+
+ CondVar.Signal();
+}
+
+void TScheduler::FillNextItem() {
+ if (!NextItem && !Items.empty()) {
+ PopHeap(Items.begin(), Items.end(), TScheduleDeadlineCompare());
+ NextItem = Items.back();
+ Items.erase(Items.end() - 1);
+ }
+}
+
+void TScheduler::SchedulerThread() {
+ for (;;) {
+ IScheduleItemAutoPtr current;
+
+ {
+ TGuard<TLock> guard(Lock);
+
+ if (StopThread) {
+ break;
+ }
+
+ if (!!NextItem) {
+ CondVar.WaitD(Lock, NextItem->GetScheduleTime());
+ } else {
+ CondVar.WaitI(Lock);
+ }
+
+ if (StopThread) {
+ break;
+ }
+
+ // signal comes if either scheduler is to be stopped of there's work to do
+ Y_VERIFY(!!NextItem, "state check");
+
+ if (TInstant::Now() < NextItem->GetScheduleTime()) {
+ // NextItem is updated since WaitD
+ continue;
+ }
+
+ current = NextItem.Release();
+ }
+
+ current->Do();
+ current.Destroy();
+
+ {
+ TGuard<TLock> guard(Lock);
+ FillNextItem();
+ }
+ }
+}
diff --git a/library/cpp/messagebus/scheduler/scheduler.h b/library/cpp/messagebus/scheduler/scheduler.h
new file mode 100644
index 0000000000..afcc0de55d
--- /dev/null
+++ b/library/cpp/messagebus/scheduler/scheduler.h
@@ -0,0 +1,68 @@
+#pragma once
+
+#include <library/cpp/threading/future/legacy_future.h>
+
+#include <util/datetime/base.h>
+#include <util/generic/object_counter.h>
+#include <util/generic/ptr.h>
+#include <util/generic/vector.h>
+#include <util/system/atomic.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+#include <util/system/thread.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class IScheduleItem {
+ public:
+ inline IScheduleItem(TInstant scheduleTime) noexcept;
+ virtual ~IScheduleItem() {
+ }
+
+ virtual void Do() = 0;
+ inline TInstant GetScheduleTime() const noexcept;
+
+ private:
+ TInstant ScheduleTime;
+ };
+
+ using IScheduleItemAutoPtr = TAutoPtr<IScheduleItem>;
+
+ class TScheduler {
+ public:
+ TScheduler();
+ ~TScheduler();
+ void Stop();
+ void Schedule(TAutoPtr<IScheduleItem> i);
+
+ size_t Size() const;
+
+ private:
+ void SchedulerThread();
+
+ void FillNextItem();
+
+ private:
+ TVector<IScheduleItemAutoPtr> Items;
+ IScheduleItemAutoPtr NextItem;
+ typedef TMutex TLock;
+ TLock Lock;
+ TCondVar CondVar;
+
+ TObjectCounter<TScheduler> ObjectCounter;
+
+ bool StopThread;
+ NThreading::TLegacyFuture<> Thread;
+ };
+
+ inline IScheduleItem::IScheduleItem(TInstant scheduleTime) noexcept
+ : ScheduleTime(scheduleTime)
+ {
+ }
+
+ inline TInstant IScheduleItem::GetScheduleTime() const noexcept {
+ return ScheduleTime;
+ }
+
+ }
+}
diff --git a/library/cpp/messagebus/scheduler/scheduler_ut.cpp b/library/cpp/messagebus/scheduler/scheduler_ut.cpp
new file mode 100644
index 0000000000..a5ea641c10
--- /dev/null
+++ b/library/cpp/messagebus/scheduler/scheduler_ut.cpp
@@ -0,0 +1,36 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "scheduler.h"
+
+#include <library/cpp/messagebus/misc/test_sync.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+Y_UNIT_TEST_SUITE(TSchedulerTests) {
+ struct TSimpleScheduleItem: public IScheduleItem {
+ TTestSync* const TestSync;
+
+ TSimpleScheduleItem(TTestSync* testSync)
+ : IScheduleItem((TInstant::Now() + TDuration::MilliSeconds(1)))
+ , TestSync(testSync)
+ {
+ }
+
+ void Do() override {
+ TestSync->WaitForAndIncrement(0);
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TTestSync testSync;
+
+ TScheduler scheduler;
+
+ scheduler.Schedule(new TSimpleScheduleItem(&testSync));
+
+ testSync.WaitForAndIncrement(1);
+
+ scheduler.Stop();
+ }
+}
diff --git a/library/cpp/messagebus/scheduler/ya.make b/library/cpp/messagebus/scheduler/ya.make
new file mode 100644
index 0000000000..dcb7408a20
--- /dev/null
+++ b/library/cpp/messagebus/scheduler/ya.make
@@ -0,0 +1,13 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/threading/future
+)
+
+SRCS(
+ scheduler.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/scheduler_actor.h b/library/cpp/messagebus/scheduler_actor.h
new file mode 100644
index 0000000000..d0c23c94c4
--- /dev/null
+++ b/library/cpp/messagebus/scheduler_actor.h
@@ -0,0 +1,85 @@
+#pragma once
+
+#include "local_tasks.h"
+
+#include <library/cpp/messagebus/actor/actor.h>
+#include <library/cpp/messagebus/actor/what_thread_does_guard.h>
+#include <library/cpp/messagebus/scheduler/scheduler.h>
+
+#include <util/system/mutex.h>
+
+namespace NBus {
+ namespace NPrivate {
+ template <typename TThis, typename TTag = NActor::TDefaultTag>
+ class TScheduleActor {
+ typedef NActor::TActor<TThis, TTag> TActorForMe;
+
+ private:
+ TScheduler* const Scheduler;
+
+ TMutex Mutex;
+
+ TInstant ScheduleTime;
+
+ public:
+ TLocalTasks Alarm;
+
+ private:
+ struct TScheduleItemImpl: public IScheduleItem {
+ TIntrusivePtr<TThis> Thiz;
+
+ TScheduleItemImpl(TIntrusivePtr<TThis> thiz, TInstant when)
+ : IScheduleItem(when)
+ , Thiz(thiz)
+ {
+ }
+
+ void Do() override {
+ {
+ TWhatThreadDoesAcquireGuard<TMutex> guard(Thiz->Mutex, "scheduler actor: acquiring lock for Do");
+
+ if (Thiz->ScheduleTime == TInstant::Max()) {
+ // was already fired
+ return;
+ }
+
+ Thiz->ScheduleTime = TInstant::Max();
+ }
+
+ Thiz->Alarm.AddTask();
+ Thiz->GetActorForMe()->Schedule();
+ }
+ };
+
+ public:
+ TScheduleActor(TScheduler* scheduler)
+ : Scheduler(scheduler)
+ , ScheduleTime(TInstant::Max())
+ {
+ }
+
+ /// call Act(TTag) at specified time, unless it is already scheduled at earlier time.
+ void ScheduleAt(TInstant when) {
+ TWhatThreadDoesAcquireGuard<TMutex> guard(Mutex, "scheduler: acquiring lock for ScheduleAt");
+
+ if (when > ScheduleTime) {
+ // already scheduled
+ return;
+ }
+
+ ScheduleTime = when;
+ Scheduler->Schedule(new TScheduleItemImpl(GetThis(), when));
+ }
+
+ private:
+ TThis* GetThis() {
+ return static_cast<TThis*>(this);
+ }
+
+ TActorForMe* GetActorForMe() {
+ return static_cast<TActorForMe*>(GetThis());
+ }
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/scheduler_actor_ut.cpp b/library/cpp/messagebus/scheduler_actor_ut.cpp
new file mode 100644
index 0000000000..e81ffd3186
--- /dev/null
+++ b/library/cpp/messagebus/scheduler_actor_ut.cpp
@@ -0,0 +1,48 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "scheduler_actor.h"
+#include "misc/test_sync.h"
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+using namespace NActor;
+
+Y_UNIT_TEST_SUITE(TSchedulerActorTests) {
+ struct TMyActor: public TAtomicRefCount<TMyActor>, public TActor<TMyActor>, public TScheduleActor<TMyActor> {
+ TTestSync TestSync;
+
+ TMyActor(TExecutor* executor, TScheduler* scheduler)
+ : TActor<TMyActor>(executor)
+ , TScheduleActor<TMyActor>(scheduler)
+ , Iteration(0)
+ {
+ }
+
+ unsigned Iteration;
+
+ void Act(TDefaultTag) {
+ if (!Alarm.FetchTask()) {
+ Y_FAIL("must not have no spurious wakeups in test");
+ }
+
+ TestSync.WaitForAndIncrement(Iteration++);
+ if (Iteration <= 5) {
+ ScheduleAt(TInstant::Now() + TDuration::MilliSeconds(Iteration));
+ }
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TExecutor executor(1);
+ TScheduler scheduler;
+
+ TIntrusivePtr<TMyActor> actor(new TMyActor(&executor, &scheduler));
+
+ actor->ScheduleAt(TInstant::Now() + TDuration::MilliSeconds(1));
+
+ actor->TestSync.WaitForAndIncrement(6);
+
+ // TODO: stop in destructor
+ scheduler.Stop();
+ }
+}
diff --git a/library/cpp/messagebus/session.cpp b/library/cpp/messagebus/session.cpp
new file mode 100644
index 0000000000..46a7ece6a8
--- /dev/null
+++ b/library/cpp/messagebus/session.cpp
@@ -0,0 +1,130 @@
+#include "ybus.h"
+
+#include <util/generic/cast.h>
+
+using namespace NBus;
+
+namespace NBus {
+ TBusSession::TBusSession() {
+ }
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief Adds peer of connection into connection list
+
+ int CompareByHost(const IRemoteAddr& l, const IRemoteAddr& r) noexcept {
+ if (l.Addr()->sa_family != r.Addr()->sa_family) {
+ return l.Addr()->sa_family < r.Addr()->sa_family ? -1 : +1;
+ }
+
+ switch (l.Addr()->sa_family) {
+ case AF_INET: {
+ return memcmp(&(((const sockaddr_in*)l.Addr())->sin_addr), &(((const sockaddr_in*)r.Addr())->sin_addr), sizeof(in_addr));
+ }
+
+ case AF_INET6: {
+ return memcmp(&(((const sockaddr_in6*)l.Addr())->sin6_addr), &(((const sockaddr_in6*)r.Addr())->sin6_addr), sizeof(in6_addr));
+ }
+ }
+
+ return memcmp(l.Addr(), r.Addr(), Min<size_t>(l.Len(), r.Len()));
+ }
+
+ bool operator<(const TNetAddr& a1, const TNetAddr& a2) {
+ return CompareByHost(a1, a2) < 0;
+ }
+
+ size_t TBusSession::GetInFlight(const TNetAddr& addr) const {
+ size_t r;
+ GetInFlightBulk({addr}, MakeArrayRef(&r, 1));
+ return r;
+ }
+
+ size_t TBusSession::GetConnectSyscallsNumForTest(const TNetAddr& addr) const {
+ size_t r;
+ GetConnectSyscallsNumBulkForTest({addr}, MakeArrayRef(&r, 1));
+ return r;
+ }
+
+ // Split 'host' into name and port taking into account that host can be specified
+ // as ipv6 address ('[<ipv6 address]:port' notion).
+ bool SplitHost(const TString& host, TString* hostName, TString* portNum) {
+ hostName->clear();
+ portNum->clear();
+
+ // Simple check that we have to deal with ipv6 address specification or
+ // just host name or ipv4 address.
+ if (!host.empty() && (host[0] == '[')) {
+ size_t pos = host.find(']');
+ if (pos < 2 || pos == TString::npos) {
+ // '[]' and '[<address>' are errors.
+ return false;
+ }
+
+ *hostName = host.substr(1, pos - 1);
+
+ pos++;
+ if (pos != host.length()) {
+ if (host[pos] != ':') {
+ // Do not allow '[...]a' but '[...]:' is ok (as for ipv4 before
+ return false;
+ }
+
+ *portNum = host.substr(pos + 1);
+ }
+ } else {
+ size_t pos = host.find(':');
+ if (pos != TString::npos) {
+ if (pos == 0) {
+ // Treat ':<port>' as errors but allow or '<host>:' for compatibility.
+ return false;
+ }
+
+ *portNum = host.substr(pos + 1);
+ }
+
+ *hostName = host.substr(0, pos);
+ }
+
+ return true;
+ }
+
+ /// registers external session on host:port with locator service
+ int TBusSession::RegisterService(const char* host, TBusKey start /*= YBUS_KEYMIN*/, TBusKey end /*= YBUS_KEYMAX*/, EIpVersion ipVersion) {
+ TString hostName;
+ TString port;
+ int portNum;
+
+ if (!SplitHost(host, &hostName, &port)) {
+ hostName = host;
+ }
+
+ if (port.empty()) {
+ portNum = GetProto()->GetPort();
+ } else {
+ try {
+ portNum = FromString<int>(port);
+ } catch (const TFromStringException&) {
+ return -1;
+ }
+ }
+
+ TBusService service = GetProto()->GetService();
+ return GetQueue()->GetLocator()->Register(service, hostName.data(), portNum, start, end, ipVersion);
+ }
+
+ TBusSession::~TBusSession() {
+ }
+
+}
+
+TBusClientSessionPtr TBusClientSession::Create(TBusProtocol* proto, IBusClientHandler* handler, const TBusClientSessionConfig& config, TBusMessageQueuePtr queue) {
+ return queue->CreateSource(proto, handler, config);
+}
+
+TBusServerSessionPtr TBusServerSession::Create(TBusProtocol* proto, IBusServerHandler* handler, const TBusServerSessionConfig& config, TBusMessageQueuePtr queue) {
+ return queue->CreateDestination(proto, handler, config);
+}
+
+TBusServerSessionPtr TBusServerSession::Create(TBusProtocol* proto, IBusServerHandler* handler, const TBusServerSessionConfig& config, TBusMessageQueuePtr queue, const TVector<TBindResult>& bindTo) {
+ return queue->CreateDestination(proto, handler, config, bindTo);
+}
diff --git a/library/cpp/messagebus/session.h b/library/cpp/messagebus/session.h
new file mode 100644
index 0000000000..fb12ab7c22
--- /dev/null
+++ b/library/cpp/messagebus/session.h
@@ -0,0 +1,225 @@
+#pragma once
+
+#include "connection.h"
+#include "defs.h"
+#include "handler.h"
+#include "message.h"
+#include "netaddr.h"
+#include "network.h"
+#include "session_config.h"
+#include "misc/weak_ptr.h"
+
+#include <library/cpp/messagebus/monitoring/mon_proto.pb.h>
+
+#include <util/generic/array_ref.h>
+#include <util/generic/ptr.h>
+
+namespace NBus {
+ template <typename TBusSessionSubclass>
+ class TBusSessionPtr;
+ using TBusClientSessionPtr = TBusSessionPtr<TBusClientSession>;
+ using TBusServerSessionPtr = TBusSessionPtr<TBusServerSession>;
+
+ ///////////////////////////////////////////////////////////////////
+ /// \brief Interface of session object.
+
+ /// Each client and server
+ /// should instantiate session object to be able to communicate via bus
+ /// client: sess = queue->CreateSource(protocol, handler);
+ /// server: sess = queue->CreateDestination(protocol, handler);
+
+ class TBusSession: public TWeakRefCounted<TBusSession> {
+ public:
+ size_t GetInFlight(const TNetAddr& addr) const;
+ size_t GetConnectSyscallsNumForTest(const TNetAddr& addr) const;
+
+ virtual void GetInFlightBulk(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const = 0;
+ virtual void GetConnectSyscallsNumBulkForTest(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const = 0;
+
+ virtual int GetInFlight() const noexcept = 0;
+ /// monitoring status of current session and it's connections
+ virtual TString GetStatus(ui16 flags = YBUS_STATUS_CONNS) = 0;
+ virtual TConnectionStatusMonRecord GetStatusProtobuf() = 0;
+ virtual NPrivate::TSessionDumpStatus GetStatusRecordInternal() = 0;
+ virtual TString GetStatusSingleLine() = 0;
+ /// return session config
+ virtual const TBusSessionConfig* GetConfig() const noexcept = 0;
+ /// return session protocol
+ virtual const TBusProtocol* GetProto() const noexcept = 0;
+ virtual TBusMessageQueue* GetQueue() const noexcept = 0;
+
+ /// registers external session on host:port with locator service
+ int RegisterService(const char* hostname, TBusKey start = YBUS_KEYMIN, TBusKey end = YBUS_KEYMAX, EIpVersion ipVersion = EIP_VERSION_4);
+
+ protected:
+ TBusSession();
+
+ public:
+ virtual TString GetNameInternal() = 0;
+
+ virtual void Shutdown() = 0;
+
+ virtual ~TBusSession();
+ };
+
+ struct TBusClientSession: public virtual TBusSession {
+ typedef ::NBus::NPrivate::TRemoteClientSession TImpl;
+
+ static TBusClientSessionPtr Create(
+ TBusProtocol* proto,
+ IBusClientHandler* handler,
+ const TBusClientSessionConfig& config,
+ TBusMessageQueuePtr queue);
+
+ virtual TBusClientConnectionPtr GetConnection(const TNetAddr&) = 0;
+
+ /// if you want to open connection early
+ virtual void OpenConnection(const TNetAddr&) = 0;
+
+ /// Send message to the destination
+ /// If addr is set then use it as destination.
+ /// Takes ownership of addr (see ClearState method).
+ virtual EMessageStatus SendMessage(TBusMessage* pMes, const TNetAddr* addr = nullptr, bool wait = false) = 0;
+
+ virtual EMessageStatus SendMessageOneWay(TBusMessage* pMes, const TNetAddr* addr = nullptr, bool wait = false) = 0;
+
+ /// Like SendMessage but cares about message
+ template <typename T /* <: TBusMessage */>
+ EMessageStatus SendMessageAutoPtr(const TAutoPtr<T>& mes, const TNetAddr* addr = nullptr, bool wait = false) {
+ EMessageStatus status = SendMessage(mes.Get(), addr, wait);
+ if (status == MESSAGE_OK)
+ Y_UNUSED(mes.Release());
+ return status;
+ }
+
+ /// Like SendMessageOneWay but cares about message
+ template <typename T /* <: TBusMessage */>
+ EMessageStatus SendMessageOneWayAutoPtr(const TAutoPtr<T>& mes, const TNetAddr* addr = nullptr, bool wait = false) {
+ EMessageStatus status = SendMessageOneWay(mes.Get(), addr, wait);
+ if (status == MESSAGE_OK)
+ Y_UNUSED(mes.Release());
+ return status;
+ }
+
+ EMessageStatus SendMessageMove(TBusMessageAutoPtr message, const TNetAddr* addr = nullptr, bool wait = false) {
+ return SendMessageAutoPtr(message, addr, wait);
+ }
+
+ EMessageStatus SendMessageOneWayMove(TBusMessageAutoPtr message, const TNetAddr* addr = nullptr, bool wait = false) {
+ return SendMessageOneWayAutoPtr(message, addr, wait);
+ }
+
+ // TODO: implement similar one-way methods
+ };
+
+ struct TBusServerSession: public virtual TBusSession {
+ typedef ::NBus::NPrivate::TRemoteServerSession TImpl;
+
+ static TBusServerSessionPtr Create(
+ TBusProtocol* proto,
+ IBusServerHandler* handler,
+ const TBusServerSessionConfig& config,
+ TBusMessageQueuePtr queue);
+
+ static TBusServerSessionPtr Create(
+ TBusProtocol* proto,
+ IBusServerHandler* handler,
+ const TBusServerSessionConfig& config,
+ TBusMessageQueuePtr queue,
+ const TVector<TBindResult>& bindTo);
+
+ // TODO: make parameter non-const
+ virtual EMessageStatus SendReply(const TBusIdentity& ident, TBusMessage* pRep) = 0;
+
+ // TODO: make parameter non-const
+ virtual EMessageStatus ForgetRequest(const TBusIdentity& ident) = 0;
+
+ template <typename U /* <: TBusMessage */>
+ EMessageStatus SendReplyAutoPtr(TBusIdentity& ident, TAutoPtr<U>& resp) {
+ EMessageStatus status = SendReply(const_cast<const TBusIdentity&>(ident), resp.Get());
+ if (status == MESSAGE_OK) {
+ Y_UNUSED(resp.Release());
+ }
+ return status;
+ }
+
+ EMessageStatus SendReplyMove(TBusIdentity& ident, TBusMessageAutoPtr resp) {
+ return SendReplyAutoPtr(ident, resp);
+ }
+
+ /// Pause input from the network.
+ /// It is valid to call this method in parallel.
+ /// TODO: pull this method up to TBusSession.
+ virtual void PauseInput(bool pause) = 0;
+ virtual unsigned GetActualListenPort() = 0;
+ };
+
+ namespace NPrivate {
+ template <typename TBusSessionSubclass>
+ class TBusOwnerSessionPtr: public TAtomicRefCount<TBusOwnerSessionPtr<TBusSessionSubclass>> {
+ private:
+ TIntrusivePtr<TBusSessionSubclass> Ptr;
+
+ public:
+ TBusOwnerSessionPtr(TBusSessionSubclass* session)
+ : Ptr(session)
+ {
+ Y_ASSERT(!!Ptr);
+ }
+
+ ~TBusOwnerSessionPtr() {
+ Ptr->Shutdown();
+ }
+
+ TBusSessionSubclass* Get() const {
+ return reinterpret_cast<TBusSessionSubclass*>(Ptr.Get());
+ }
+ };
+
+ }
+
+ template <typename TBusSessionSubclass>
+ class TBusSessionPtr {
+ private:
+ TIntrusivePtr<NPrivate::TBusOwnerSessionPtr<TBusSessionSubclass>> SmartPtr;
+ TBusSessionSubclass* Ptr;
+
+ public:
+ TBusSessionPtr()
+ : Ptr()
+ {
+ }
+ TBusSessionPtr(TBusSessionSubclass* session)
+ : SmartPtr(!!session ? new NPrivate::TBusOwnerSessionPtr<TBusSessionSubclass>(session) : nullptr)
+ , Ptr(session)
+ {
+ }
+
+ TBusSessionSubclass* Get() const {
+ return Ptr;
+ }
+ operator TBusSessionSubclass*() {
+ return Get();
+ }
+ TBusSessionSubclass& operator*() const {
+ return *Get();
+ }
+ TBusSessionSubclass* operator->() const {
+ return Get();
+ }
+
+ bool operator!() const {
+ return !Ptr;
+ }
+
+ void Swap(TBusSessionPtr& t) noexcept {
+ DoSwap(SmartPtr, t.SmartPtr);
+ DoSwap(Ptr, t.Ptr);
+ }
+
+ void Drop() {
+ TBusSessionPtr().Swap(*this);
+ }
+ };
+
+}
diff --git a/library/cpp/messagebus/session_config.h b/library/cpp/messagebus/session_config.h
new file mode 100644
index 0000000000..37df97e986
--- /dev/null
+++ b/library/cpp/messagebus/session_config.h
@@ -0,0 +1,4 @@
+#pragma once
+
+#include <library/cpp/messagebus/config/session_config.h>
+
diff --git a/library/cpp/messagebus/session_impl.cpp b/library/cpp/messagebus/session_impl.cpp
new file mode 100644
index 0000000000..ddf9f360c4
--- /dev/null
+++ b/library/cpp/messagebus/session_impl.cpp
@@ -0,0 +1,650 @@
+#include "session_impl.h"
+
+#include "acceptor.h"
+#include "network.h"
+#include "remote_client_connection.h"
+#include "remote_client_session.h"
+#include "remote_server_connection.h"
+#include "remote_server_session.h"
+#include "misc/weak_ptr.h"
+
+#include <util/generic/cast.h>
+
+using namespace NActor;
+using namespace NBus;
+using namespace NBus::NPrivate;
+using namespace NEventLoop;
+
+namespace {
+ class TScheduleSession: public IScheduleItem {
+ public:
+ TScheduleSession(TBusSessionImpl* session, TInstant deadline)
+ : IScheduleItem(deadline)
+ , Session(session)
+ , SessionImpl(session)
+ {
+ }
+
+ void Do() override {
+ TIntrusivePtr<TBusSession> session = Session.Get();
+ if (!!session) {
+ SessionImpl->Cron();
+ }
+ }
+
+ private:
+ TWeakPtr<TBusSession> Session;
+ // Work around TWeakPtr limitation
+ TBusSessionImpl* SessionImpl;
+ };
+}
+
+TConnectionsAcceptorsSnapshot::TConnectionsAcceptorsSnapshot()
+ : LastConnectionId(0)
+ , LastAcceptorId(0)
+{
+}
+
+struct TBusSessionImpl::TImpl {
+ TRemoteConnectionWriterIncrementalStatus DeadConnectionWriterStatusSummary;
+ TRemoteConnectionReaderIncrementalStatus DeadConnectionReaderStatusSummary;
+ TAcceptorStatus DeadAcceptorStatusSummary;
+};
+
+namespace {
+ TBusSessionConfig SessionConfigFillDefaults(const TBusSessionConfig& config, const TString& name) {
+ TBusSessionConfig copy = config;
+ if (copy.TotalTimeout == 0 && copy.SendTimeout == 0) {
+ copy.TotalTimeout = TDuration::Seconds(60).MilliSeconds();
+ copy.SendTimeout = TDuration::Seconds(15).MilliSeconds();
+ } else if (copy.TotalTimeout == 0) {
+ Y_ASSERT(copy.SendTimeout != 0);
+ copy.TotalTimeout = config.SendTimeout + TDuration::MilliSeconds(10).MilliSeconds();
+ } else if (copy.SendTimeout == 0) {
+ Y_ASSERT(copy.TotalTimeout != 0);
+ if ((ui64)copy.TotalTimeout > (ui64)TDuration::MilliSeconds(10).MilliSeconds()) {
+ copy.SendTimeout = copy.TotalTimeout - TDuration::MilliSeconds(10).MilliSeconds();
+ } else {
+ copy.SendTimeout = copy.TotalTimeout;
+ }
+ } else {
+ Y_ASSERT(copy.TotalTimeout != 0);
+ Y_ASSERT(copy.SendTimeout != 0);
+ }
+
+ if (copy.ConnectTimeout == 0) {
+ copy.ConnectTimeout = copy.SendTimeout;
+ }
+
+ Y_VERIFY(copy.SendTimeout > 0, "SendTimeout must be > 0");
+ Y_VERIFY(copy.TotalTimeout > 0, "TotalTimeout must be > 0");
+ Y_VERIFY(copy.ConnectTimeout > 0, "ConnectTimeout must be > 0");
+ Y_VERIFY(copy.TotalTimeout >= copy.SendTimeout, "TotalTimeout must be >= SendTimeout");
+
+ if (!copy.Name) {
+ copy.Name = name;
+ }
+
+ return copy;
+ }
+}
+
+TBusSessionImpl::TBusSessionImpl(bool isSource, TBusMessageQueue* queue, TBusProtocol* proto,
+ IBusErrorHandler* handler,
+ const TBusSessionConfig& config, const TString& name)
+ : TActor<TBusSessionImpl, TStatusTag>(queue->WorkQueue.Get())
+ , TActor<TBusSessionImpl, TConnectionTag>(queue->WorkQueue.Get())
+ , Impl(new TImpl)
+ , IsSource_(isSource)
+ , Queue(queue)
+ , Proto(proto)
+ , ProtoName(Proto->GetService())
+ , ErrorHandler(handler)
+ , HandlerUseCountHolder(&handler->UseCountChecker)
+ , Config(SessionConfigFillDefaults(config, name))
+ , WriteEventLoop("wr-el")
+ , ReadEventLoop("rd-el")
+ , LastAcceptorId(0)
+ , LastConnectionId(0)
+ , Down(0)
+{
+ Impl->DeadAcceptorStatusSummary.Summary = true;
+
+ ReadEventLoopThread.Reset(new NThreading::TLegacyFuture<void, false>(std::bind(&TEventLoop::Run, std::ref(ReadEventLoop))));
+ WriteEventLoopThread.Reset(new NThreading::TLegacyFuture<void, false>(std::bind(&TEventLoop::Run, std::ref(WriteEventLoop))));
+
+ Queue->Schedule(IScheduleItemAutoPtr(new TScheduleSession(this, TInstant::Now() + Config.Secret.TimeoutPeriod)));
+}
+
+TBusSessionImpl::~TBusSessionImpl() {
+ Y_VERIFY(Down);
+ Y_VERIFY(ShutdownCompleteEvent.WaitT(TDuration::Zero()));
+ Y_VERIFY(!WriteEventLoop.IsRunning());
+ Y_VERIFY(!ReadEventLoop.IsRunning());
+}
+
+TBusSessionStatus::TBusSessionStatus()
+ : InFlightCount(0)
+ , InFlightSize(0)
+ , InputPaused(false)
+{
+}
+
+void TBusSessionImpl::Shutdown() {
+ if (!AtomicCas(&Down, 1, 0)) {
+ ShutdownCompleteEvent.WaitI();
+ return;
+ }
+
+ Y_VERIFY(Queue->IsRunning(), "Session must be shut down prior to queue shutdown");
+
+ TUseAfterFreeCheckerGuard handlerAliveCheckedGuard(ErrorHandler->UseAfterFreeChecker);
+
+ // For legacy clients that don't use smart pointers
+ TIntrusivePtr<TBusSessionImpl> thiz(this);
+
+ Queue->Remove(this);
+
+ // shutdown event loops first, so they won't send more events
+ // to acceptors and connections
+ ReadEventLoop.Stop();
+ WriteEventLoop.Stop();
+ ReadEventLoopThread->Get();
+ WriteEventLoopThread->Get();
+
+ // shutdown acceptors before connections
+ // so they won't create more connections
+ TVector<TAcceptorPtr> acceptors;
+ GetAcceptors(&acceptors);
+ {
+ TGuard<TMutex> guard(ConnectionsLock);
+ Acceptors.clear();
+ }
+
+ for (auto& acceptor : acceptors) {
+ acceptor->Shutdown();
+ }
+
+ // shutdown connections
+ TVector<TRemoteConnectionPtr> cs;
+ GetConnections(&cs);
+
+ for (auto& c : cs) {
+ c->Shutdown(MESSAGE_SHUTDOWN);
+ }
+
+ // shutdown connections actor
+ // must shutdown after connections destroyed
+ ConnectionsData.ShutdownState.ShutdownCommand();
+ GetConnectionsActor()->Schedule();
+ ConnectionsData.ShutdownState.ShutdownComplete.WaitI();
+
+ // finally shutdown status actor
+ StatusData.ShutdownState.ShutdownCommand();
+ GetStatusActor()->Schedule();
+ StatusData.ShutdownState.ShutdownComplete.WaitI();
+
+ // Make sure no one references IMessageHandler after Shutdown()
+ JobCount.WaitForZero();
+ HandlerUseCountHolder.Reset();
+
+ ShutdownCompleteEvent.Signal();
+}
+
+bool TBusSessionImpl::IsDown() {
+ return static_cast<bool>(AtomicGet(Down));
+}
+
+size_t TBusSessionImpl::GetInFlightImpl(const TNetAddr& addr) const {
+ TRemoteConnectionPtr conn = const_cast<TBusSessionImpl*>(this)->GetConnection(addr, false);
+ if (!!conn) {
+ return conn->GetInFlight();
+ } else {
+ return 0;
+ }
+}
+
+void TBusSessionImpl::GetInFlightBulk(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const {
+ Y_VERIFY(addrs.size() == results.size(), "input.size != output.size");
+ for (size_t i = 0; i < addrs.size(); ++i) {
+ results[i] = GetInFlightImpl(addrs[i]);
+ }
+}
+
+size_t TBusSessionImpl::GetConnectSyscallsNumForTestImpl(const TNetAddr& addr) const {
+ TRemoteConnectionPtr conn = const_cast<TBusSessionImpl*>(this)->GetConnection(addr, false);
+ if (!!conn) {
+ return conn->GetConnectSyscallsNumForTest();
+ } else {
+ return 0;
+ }
+}
+
+void TBusSessionImpl::GetConnectSyscallsNumBulkForTest(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const {
+ Y_VERIFY(addrs.size() == results.size(), "input.size != output.size");
+ for (size_t i = 0; i < addrs.size(); ++i) {
+ results[i] = GetConnectSyscallsNumForTestImpl(addrs[i]);
+ }
+}
+
+void TBusSessionImpl::FillStatus() {
+}
+
+TSessionDumpStatus TBusSessionImpl::GetStatusRecordInternal() {
+ // Probably useless, because it returns cached info now
+ Y_VERIFY(!Queue->GetExecutor()->IsInExecutorThread(),
+ "GetStatus must not be called from executor thread");
+
+ TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex);
+ // TODO: returns zeros for a second after start
+ // (until first cron)
+ return StatusData.StatusDumpCached;
+}
+
+TString TBusSessionImpl::GetStatus(ui16 flags) {
+ Y_UNUSED(flags);
+
+ return GetStatusRecordInternal().PrintToString();
+}
+
+TConnectionStatusMonRecord TBusSessionImpl::GetStatusProtobuf() {
+ Y_VERIFY(!Queue->GetExecutor()->IsInExecutorThread(),
+ "GetStatus must not be called from executor thread");
+
+ TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex);
+
+ return StatusData.StatusDumpCached.ConnectionStatusSummary.GetStatusProtobuf();
+}
+
+TString TBusSessionImpl::GetStatusSingleLine() {
+ TSessionDumpStatus status = GetStatusRecordInternal();
+
+ TStringStream ss;
+ ss << "in-flight: " << status.Status.InFlightCount;
+ if (IsSource_) {
+ ss << " ack: " << status.ConnectionStatusSummary.WriterStatus.AckMessagesSize;
+ }
+ ss << " send-q: " << status.ConnectionStatusSummary.WriterStatus.SendQueueSize;
+ return ss.Str();
+}
+
+void TBusSessionImpl::ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionWriterIncrementalStatus& connectionStatus) {
+ Impl->DeadConnectionWriterStatusSummary += connectionStatus;
+}
+
+void TBusSessionImpl::ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionReaderIncrementalStatus& connectionStatus) {
+ Impl->DeadConnectionReaderStatusSummary += connectionStatus;
+}
+
+void TBusSessionImpl::ProcessItem(TStatusTag, TDeadConnectionTag, const TAcceptorStatus& acceptorStatus) {
+ Impl->DeadAcceptorStatusSummary += acceptorStatus;
+}
+
+void TBusSessionImpl::ProcessItem(TConnectionTag, ::NActor::TDefaultTag, const TOnAccept& onAccept) {
+ TSocketHolder socket(onAccept.s);
+
+ if (AtomicGet(Down)) {
+ // do not create connections after shutdown initiated
+ return;
+ }
+
+ //if (Connections.find(addr) != Connections.end()) {
+ // TODO: it is possible
+ // won't be a problem after socket address replaced with id
+ //}
+
+ TRemoteConnectionPtr c(new TRemoteServerConnection(VerifyDynamicCast<TRemoteServerSession*>(this), ++LastConnectionId, onAccept.addr));
+
+ VerifyDynamicCast<TRemoteServerConnection*>(c.Get())->Init(socket.Release(), onAccept.now);
+
+ InsertConnectionLockAcquired(c.Get());
+}
+
+void TBusSessionImpl::ProcessItem(TConnectionTag, TRemoveTag, TRemoteConnectionPtr c) {
+ TAddrRemoteConnections::iterator it1 = Connections.find(c->PeerAddrSocketAddr);
+ if (it1 != Connections.end()) {
+ if (it1->second.Get() == c.Get()) {
+ Connections.erase(it1);
+ }
+ }
+
+ THashMap<ui64, TRemoteConnectionPtr>::iterator it2 = ConnectionsById.find(c->ConnectionId);
+ if (it2 != ConnectionsById.end()) {
+ ConnectionsById.erase(it2);
+ }
+
+ SendSnapshotToStatusActor();
+}
+
+void TBusSessionImpl::ProcessConnectionsAcceptorsShapshotQueueItem(TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> snapshot) {
+ for (TVector<TRemoteConnectionPtr>::const_iterator connection = snapshot->Connections.begin();
+ connection != snapshot->Connections.end(); ++connection) {
+ Y_ASSERT((*connection)->ConnectionId <= snapshot->LastConnectionId);
+ }
+
+ for (TVector<TAcceptorPtr>::const_iterator acceptor = snapshot->Acceptors.begin();
+ acceptor != snapshot->Acceptors.end(); ++acceptor) {
+ Y_ASSERT((*acceptor)->AcceptorId <= snapshot->LastAcceptorId);
+ }
+
+ StatusData.ConnectionsAcceptorsSnapshot = snapshot;
+}
+
+void TBusSessionImpl::StatusUpdateCachedDumpIfNecessary(TInstant now) {
+ if (now - StatusData.StatusDumpCachedLastUpdate > Config.Secret.StatusFlushPeriod) {
+ StatusUpdateCachedDump();
+ StatusData.StatusDumpCachedLastUpdate = now;
+ }
+}
+
+void TBusSessionImpl::StatusUpdateCachedDump() {
+ TSessionDumpStatus r;
+
+ if (AtomicGet(Down)) {
+ r.Shutdown = true;
+ TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex);
+ StatusData.StatusDumpCached = r;
+ return;
+ }
+
+ // TODO: make thread-safe
+ FillStatus();
+
+ r.Status = StatusData.Status;
+
+ {
+ TStringStream ss;
+
+ TString name = Config.Name;
+ if (!name) {
+ name = "unnamed";
+ }
+
+ ss << (IsSource_ ? "client" : "server") << " session " << name << ", proto " << Proto->GetService() << Endl;
+ ss << "in flight: " << r.Status.InFlightCount;
+ if (!IsSource_) {
+ ss << ", " << r.Status.InFlightSize << "b";
+ }
+ if (r.Status.InputPaused) {
+ ss << " (input paused)";
+ }
+ ss << "\n";
+
+ r.Head = ss.Str();
+ }
+
+ TVector<TRemoteConnectionPtr>& connections = StatusData.ConnectionsAcceptorsSnapshot->Connections;
+ TVector<TAcceptorPtr>& acceptors = StatusData.ConnectionsAcceptorsSnapshot->Acceptors;
+
+ r.ConnectionStatusSummary = TRemoteConnectionStatus();
+ r.ConnectionStatusSummary.Summary = true;
+ r.ConnectionStatusSummary.Server = !IsSource_;
+ r.ConnectionStatusSummary.WriterStatus.Incremental = Impl->DeadConnectionWriterStatusSummary;
+ r.ConnectionStatusSummary.ReaderStatus.Incremental = Impl->DeadConnectionReaderStatusSummary;
+
+ TAcceptorStatus acceptorStatusSummary = Impl->DeadAcceptorStatusSummary;
+
+ {
+ TStringStream ss;
+
+ for (TVector<TAcceptorPtr>::const_iterator acceptor = acceptors.begin();
+ acceptor != acceptors.end(); ++acceptor) {
+ const TAcceptorStatus status = (*acceptor)->GranStatus.Listen.Get();
+
+ acceptorStatusSummary += status;
+
+ if (acceptor != acceptors.begin()) {
+ ss << "\n";
+ }
+ ss << status.PrintToString();
+ }
+
+ r.Acceptors = ss.Str();
+ }
+
+ {
+ TStringStream ss;
+
+ for (TVector<TRemoteConnectionPtr>::const_iterator connection = connections.begin();
+ connection != connections.end(); ++connection) {
+ if (connection != connections.begin()) {
+ ss << "\n";
+ }
+
+ TRemoteConnectionStatus status;
+ status.Server = !IsSource_;
+ status.ReaderStatus = (*connection)->GranStatus.Reader.Get();
+ status.WriterStatus = (*connection)->GranStatus.Writer.Get();
+
+ ss << status.PrintToString();
+
+ r.ConnectionStatusSummary.ReaderStatus += status.ReaderStatus;
+ r.ConnectionStatusSummary.WriterStatus += status.WriterStatus;
+ }
+
+ r.ConnectionsSummary = r.ConnectionStatusSummary.PrintToString();
+ r.Connections = ss.Str();
+ }
+
+ r.Config = Config;
+
+ TGuard<TMutex> guard(StatusData.StatusDumpCachedMutex);
+ StatusData.StatusDumpCached = r;
+}
+
+TBusSessionImpl::TStatusData::TStatusData()
+ : ConnectionsAcceptorsSnapshot(new TConnectionsAcceptorsSnapshot)
+{
+}
+
+void TBusSessionImpl::Act(TStatusTag) {
+ TInstant now = TInstant::Now();
+
+ EShutdownState shutdownState = StatusData.ShutdownState.State.Get();
+
+ StatusData.ConnectionsAcceptorsSnapshotsQueue.DequeueAllLikelyEmpty(std::bind(&TBusSessionImpl::ProcessConnectionsAcceptorsShapshotQueueItem, this, std::placeholders::_1));
+
+ GetDeadConnectionWriterStatusQueue()->DequeueAllLikelyEmpty();
+ GetDeadConnectionReaderStatusQueue()->DequeueAllLikelyEmpty();
+ GetDeadAcceptorStatusQueue()->DequeueAllLikelyEmpty();
+
+ // TODO: check queues are empty if already stopped
+
+ if (shutdownState != SS_RUNNING) {
+ // important to beak cyclic link session -> connection -> session
+ StatusData.ConnectionsAcceptorsSnapshot->Connections.clear();
+ StatusData.ConnectionsAcceptorsSnapshot->Acceptors.clear();
+ }
+
+ if (shutdownState == SS_SHUTDOWN_COMMAND) {
+ StatusData.ShutdownState.CompleteShutdown();
+ }
+
+ StatusUpdateCachedDumpIfNecessary(now);
+}
+
+TBusSessionImpl::TConnectionsData::TConnectionsData() {
+}
+
+void TBusSessionImpl::Act(TConnectionTag) {
+ TConnectionsGuard guard(ConnectionsLock);
+
+ EShutdownState shutdownState = ConnectionsData.ShutdownState.State.Get();
+ if (shutdownState == SS_SHUTDOWN_COMPLETE) {
+ Y_VERIFY(GetRemoveConnectionQueue()->IsEmpty());
+ Y_VERIFY(GetOnAcceptQueue()->IsEmpty());
+ }
+
+ GetRemoveConnectionQueue()->DequeueAllLikelyEmpty();
+ GetOnAcceptQueue()->DequeueAllLikelyEmpty();
+
+ if (shutdownState == SS_SHUTDOWN_COMMAND) {
+ ConnectionsData.ShutdownState.CompleteShutdown();
+ }
+}
+
+void TBusSessionImpl::Listen(int port, TBusMessageQueue* q) {
+ Listen(BindOnPort(port, Config.ReusePort).second, q);
+}
+
+void TBusSessionImpl::Listen(const TVector<TBindResult>& bindTo, TBusMessageQueue* q) {
+ Y_ASSERT(q == Queue);
+ int actualPort = -1;
+
+ for (const TBindResult& br : bindTo) {
+ if (actualPort == -1) {
+ actualPort = br.Addr.GetPort();
+ } else {
+ Y_VERIFY(actualPort == br.Addr.GetPort(), "state check");
+ }
+ if (Config.SocketToS >= 0) {
+ SetSocketToS(*br.Socket, &(br.Addr), Config.SocketToS);
+ }
+
+ TAcceptorPtr acceptor(new TAcceptor(this, ++LastAcceptorId, br.Socket->Release(), br.Addr));
+
+ TConnectionsGuard guard(ConnectionsLock);
+ InsertAcceptorLockAcquired(acceptor.Get());
+ }
+
+ Config.ListenPort = actualPort;
+}
+
+void TBusSessionImpl::SendSnapshotToStatusActor() {
+ //Y_ASSERT(ConnectionsLock.IsLocked());
+
+ TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> snapshot(new TConnectionsAcceptorsSnapshot);
+ GetAcceptorsLockAquired(&snapshot->Acceptors);
+ GetConnectionsLockAquired(&snapshot->Connections);
+ snapshot->LastAcceptorId = LastAcceptorId;
+ snapshot->LastConnectionId = LastConnectionId;
+ StatusData.ConnectionsAcceptorsSnapshotsQueue.Enqueue(snapshot);
+ GetStatusActor()->Schedule();
+}
+
+void TBusSessionImpl::InsertConnectionLockAcquired(TRemoteConnection* connection) {
+ //Y_ASSERT(ConnectionsLock.IsLocked());
+
+ Connections.insert(std::make_pair(connection->PeerAddrSocketAddr, connection));
+ // connection for given adds may already exist at this point
+ // (so we overwrite old connection)
+ // after reconnect, if previous connections wasn't shutdown yet
+
+ bool inserted2 = ConnectionsById.insert(std::make_pair(connection->ConnectionId, connection)).second;
+ Y_VERIFY(inserted2, "state check: must be inserted (2)");
+
+ SendSnapshotToStatusActor();
+}
+
+void TBusSessionImpl::InsertAcceptorLockAcquired(TAcceptor* acceptor) {
+ //Y_ASSERT(ConnectionsLock.IsLocked());
+
+ Acceptors.push_back(acceptor);
+
+ SendSnapshotToStatusActor();
+}
+
+void TBusSessionImpl::GetConnections(TVector<TRemoteConnectionPtr>* r) {
+ TConnectionsGuard guard(ConnectionsLock);
+ GetConnectionsLockAquired(r);
+}
+
+void TBusSessionImpl::GetAcceptors(TVector<TAcceptorPtr>* r) {
+ TConnectionsGuard guard(ConnectionsLock);
+ GetAcceptorsLockAquired(r);
+}
+
+void TBusSessionImpl::GetConnectionsLockAquired(TVector<TRemoteConnectionPtr>* r) {
+ //Y_ASSERT(ConnectionsLock.IsLocked());
+
+ r->reserve(Connections.size());
+
+ for (auto& connection : Connections) {
+ r->push_back(connection.second);
+ }
+}
+
+void TBusSessionImpl::GetAcceptorsLockAquired(TVector<TAcceptorPtr>* r) {
+ //Y_ASSERT(ConnectionsLock.IsLocked());
+
+ r->reserve(Acceptors.size());
+
+ for (auto& acceptor : Acceptors) {
+ r->push_back(acceptor);
+ }
+}
+
+TRemoteConnectionPtr TBusSessionImpl::GetConnectionById(ui64 id) {
+ TConnectionsGuard guard(ConnectionsLock);
+
+ THashMap<ui64, TRemoteConnectionPtr>::const_iterator it = ConnectionsById.find(id);
+ if (it == ConnectionsById.end()) {
+ return nullptr;
+ } else {
+ return it->second;
+ }
+}
+
+TAcceptorPtr TBusSessionImpl::GetAcceptorById(ui64 id) {
+ TGuard<TMutex> guard(ConnectionsLock);
+
+ for (const auto& Acceptor : Acceptors) {
+ if (Acceptor->AcceptorId == id) {
+ return Acceptor;
+ }
+ }
+
+ return nullptr;
+}
+
+void TBusSessionImpl::InvokeOnError(TNonDestroyingAutoPtr<TBusMessage> message, EMessageStatus status) {
+ message->CheckClean();
+ ErrorHandler->OnError(message, status);
+}
+
+TRemoteConnectionPtr TBusSessionImpl::GetConnection(const TBusSocketAddr& addr, bool create) {
+ TConnectionsGuard guard(ConnectionsLock);
+
+ TAddrRemoteConnections::const_iterator it = Connections.find(addr);
+ if (it != Connections.end()) {
+ return it->second;
+ }
+
+ if (!create) {
+ return TRemoteConnectionPtr();
+ }
+
+ Y_VERIFY(IsSource_, "must be source");
+
+ TRemoteConnectionPtr c(new TRemoteClientConnection(VerifyDynamicCast<TRemoteClientSession*>(this), ++LastConnectionId, addr.ToNetAddr()));
+ InsertConnectionLockAcquired(c.Get());
+
+ return c;
+}
+
+void TBusSessionImpl::Cron() {
+ TVector<TRemoteConnectionPtr> connections;
+ GetConnections(&connections);
+
+ for (const auto& it : connections) {
+ TRemoteConnection* connection = it.Get();
+ if (IsSource_) {
+ VerifyDynamicCast<TRemoteClientConnection*>(connection)->ScheduleTimeoutMessages();
+ } else {
+ VerifyDynamicCast<TRemoteServerConnection*>(connection)->WriterData.TimeToRotateCounters.AddTask();
+ // no schedule: do not rotate if there's no traffic
+ }
+ }
+
+ // status updates are sent without scheduling
+ GetStatusActor()->Schedule();
+
+ Queue->Schedule(IScheduleItemAutoPtr(new TScheduleSession(this, TInstant::Now() + Config.Secret.TimeoutPeriod)));
+}
+
+TString TBusSessionImpl::GetNameInternal() {
+ if (!!Config.Name) {
+ return Config.Name;
+ }
+ return ProtoName;
+}
diff --git a/library/cpp/messagebus/session_impl.h b/library/cpp/messagebus/session_impl.h
new file mode 100644
index 0000000000..90ef246ff8
--- /dev/null
+++ b/library/cpp/messagebus/session_impl.h
@@ -0,0 +1,259 @@
+#pragma once
+
+#include "acceptor_status.h"
+#include "async_result.h"
+#include "event_loop.h"
+#include "netaddr.h"
+#include "remote_connection.h"
+#include "remote_connection_status.h"
+#include "session_job_count.h"
+#include "shutdown_state.h"
+#include "ybus.h"
+
+#include <library/cpp/messagebus/actor/actor.h>
+#include <library/cpp/messagebus/actor/queue_in_actor.h>
+#include <library/cpp/messagebus/monitoring/mon_proto.pb.h>
+
+#include <library/cpp/threading/future/legacy_future.h>
+
+#include <util/generic/array_ref.h>
+#include <util/generic/string.h>
+
+namespace NBus {
+ namespace NPrivate {
+ typedef TIntrusivePtr<TRemoteClientConnection> TRemoteClientConnectionPtr;
+ typedef TIntrusivePtr<TRemoteServerConnection> TRemoteServerConnectionPtr;
+
+ typedef TIntrusivePtr<TRemoteServerSession> TRemoteServerSessionPtr;
+
+ typedef TIntrusivePtr<TAcceptor> TAcceptorPtr;
+ typedef TVector<TAcceptorPtr> TAcceptorsPtrs;
+
+ struct TConnectionsAcceptorsSnapshot {
+ TVector<TRemoteConnectionPtr> Connections;
+ TVector<TAcceptorPtr> Acceptors;
+ ui64 LastConnectionId;
+ ui64 LastAcceptorId;
+
+ TConnectionsAcceptorsSnapshot();
+ };
+
+ typedef TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> TConnectionsAcceptorsSnapshotPtr;
+
+ struct TOnAccept {
+ SOCKET s;
+ TNetAddr addr;
+ TInstant now;
+ };
+
+ struct TStatusTag {};
+ struct TConnectionTag {};
+
+ struct TDeadConnectionTag {};
+ struct TRemoveTag {};
+
+ struct TBusSessionImpl
+ : public virtual TBusSession,
+ private ::NActor::TActor<TBusSessionImpl, TStatusTag>,
+ private ::NActor::TActor<TBusSessionImpl, TConnectionTag>
+
+ ,
+ private ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionWriterIncrementalStatus, TStatusTag, TDeadConnectionTag>,
+ private ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionReaderIncrementalStatus, TStatusTag, TDeadConnectionTag>,
+ private ::NActor::TQueueInActor<TBusSessionImpl, TAcceptorStatus, TStatusTag, TDeadConnectionTag>
+
+ ,
+ private ::NActor::TQueueInActor<TBusSessionImpl, TOnAccept, TConnectionTag>,
+ private ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionPtr, TConnectionTag, TRemoveTag> {
+ friend class TAcceptor;
+ friend class TRemoteConnection;
+ friend class TRemoteServerConnection;
+ friend class ::NActor::TActor<TBusSessionImpl, TStatusTag>;
+ friend class ::NActor::TActor<TBusSessionImpl, TConnectionTag>;
+ friend class ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionWriterIncrementalStatus, TStatusTag, TDeadConnectionTag>;
+ friend class ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionReaderIncrementalStatus, TStatusTag, TDeadConnectionTag>;
+ friend class ::NActor::TQueueInActor<TBusSessionImpl, TAcceptorStatus, TStatusTag, TDeadConnectionTag>;
+ friend class ::NActor::TQueueInActor<TBusSessionImpl, TOnAccept, TConnectionTag>;
+ friend class ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionPtr, TConnectionTag, TRemoveTag>;
+
+ public:
+ ::NActor::TQueueInActor<TBusSessionImpl, TOnAccept, TConnectionTag>* GetOnAcceptQueue() {
+ return this;
+ }
+
+ ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionPtr, TConnectionTag, TRemoveTag>* GetRemoveConnectionQueue() {
+ return this;
+ }
+
+ ::NActor::TActor<TBusSessionImpl, TConnectionTag>* GetConnectionActor() {
+ return this;
+ }
+
+ typedef TGuard<TMutex> TConnectionsGuard;
+
+ TBusSessionImpl(bool isSource, TBusMessageQueue* queue, TBusProtocol* proto,
+ IBusErrorHandler* handler,
+ const TBusSessionConfig& config, const TString& name);
+
+ ~TBusSessionImpl() override;
+
+ void Shutdown() override;
+ bool IsDown();
+
+ size_t GetInFlightImpl(const TNetAddr& addr) const;
+ size_t GetConnectSyscallsNumForTestImpl(const TNetAddr& addr) const;
+
+ void GetInFlightBulk(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const override;
+ void GetConnectSyscallsNumBulkForTest(TArrayRef<const TNetAddr> addrs, TArrayRef<size_t> results) const override;
+
+ virtual void FillStatus();
+ TSessionDumpStatus GetStatusRecordInternal() override;
+ TString GetStatus(ui16 flags = YBUS_STATUS_CONNS) override;
+ TConnectionStatusMonRecord GetStatusProtobuf() override;
+ TString GetStatusSingleLine() override;
+
+ void ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionWriterIncrementalStatus&);
+ void ProcessItem(TStatusTag, TDeadConnectionTag, const TRemoteConnectionReaderIncrementalStatus&);
+ void ProcessItem(TStatusTag, TDeadConnectionTag, const TAcceptorStatus&);
+ void ProcessItem(TStatusTag, ::NActor::TDefaultTag, const TAcceptorStatus&);
+ void ProcessItem(TConnectionTag, ::NActor::TDefaultTag, const TOnAccept&);
+ void ProcessItem(TConnectionTag, TRemoveTag, TRemoteConnectionPtr);
+ void ProcessConnectionsAcceptorsShapshotQueueItem(TAtomicSharedPtr<TConnectionsAcceptorsSnapshot>);
+ void StatusUpdateCachedDump();
+ void StatusUpdateCachedDumpIfNecessary(TInstant now);
+ void Act(TStatusTag);
+ void Act(TConnectionTag);
+
+ TBusProtocol* GetProto() const noexcept override;
+ const TBusSessionConfig* GetConfig() const noexcept override;
+ TBusMessageQueue* GetQueue() const noexcept override;
+ TString GetNameInternal() override;
+
+ virtual void OnMessageReceived(TRemoteConnection* c, TVectorSwaps<TBusMessagePtrAndHeader>& newMsg) = 0;
+
+ void Listen(int port, TBusMessageQueue* q);
+ void Listen(const TVector<TBindResult>& bindTo, TBusMessageQueue* q);
+ TBusConnection* Accept(SOCKET listen);
+
+ inline ::NActor::TActor<TBusSessionImpl, TStatusTag>* GetStatusActor() {
+ return this;
+ }
+ inline ::NActor::TActor<TBusSessionImpl, TConnectionTag>* GetConnectionsActor() {
+ return this;
+ }
+
+ typedef THashMap<TBusSocketAddr, TRemoteConnectionPtr> TAddrRemoteConnections;
+
+ void SendSnapshotToStatusActor();
+
+ void InsertConnectionLockAcquired(TRemoteConnection* connection);
+ void InsertAcceptorLockAcquired(TAcceptor* acceptor);
+
+ void GetConnections(TVector<TRemoteConnectionPtr>*);
+ void GetAcceptors(TVector<TAcceptorPtr>*);
+ void GetConnectionsLockAquired(TVector<TRemoteConnectionPtr>*);
+ void GetAcceptorsLockAquired(TVector<TAcceptorPtr>*);
+
+ TRemoteConnectionPtr GetConnection(const TBusSocketAddr& addr, bool create);
+ TRemoteConnectionPtr GetConnectionById(ui64 id);
+ TAcceptorPtr GetAcceptorById(ui64 id);
+
+ void InvokeOnError(TNonDestroyingAutoPtr<TBusMessage>, EMessageStatus);
+
+ void Cron();
+
+ TBusSessionJobCount JobCount;
+
+ // TODO: replace with actor
+ TMutex ConnectionsLock;
+
+ struct TImpl;
+ THolder<TImpl> Impl;
+
+ const bool IsSource_;
+
+ TBusMessageQueue* const Queue;
+ TBusProtocol* const Proto;
+ // copied to be available after Proto dies
+ const TString ProtoName;
+
+ IBusErrorHandler* const ErrorHandler;
+ TUseCountHolder HandlerUseCountHolder;
+ TBusSessionConfig Config; // TODO: make const
+
+ NEventLoop::TEventLoop WriteEventLoop;
+ NEventLoop::TEventLoop ReadEventLoop;
+ THolder<NThreading::TLegacyFuture<void, false>> ReadEventLoopThread;
+ THolder<NThreading::TLegacyFuture<void, false>> WriteEventLoopThread;
+
+ THashMap<ui64, TRemoteConnectionPtr> ConnectionsById;
+ TAddrRemoteConnections Connections;
+ TAcceptorsPtrs Acceptors;
+
+ struct TStatusData {
+ TAtomicSharedPtr<TConnectionsAcceptorsSnapshot> ConnectionsAcceptorsSnapshot;
+ ::NActor::TQueueForActor<TAtomicSharedPtr<TConnectionsAcceptorsSnapshot>> ConnectionsAcceptorsSnapshotsQueue;
+
+ TAtomicShutdownState ShutdownState;
+
+ TBusSessionStatus Status;
+
+ TSessionDumpStatus StatusDumpCached;
+ TMutex StatusDumpCachedMutex;
+ TInstant StatusDumpCachedLastUpdate;
+
+ TStatusData();
+ };
+ TStatusData StatusData;
+
+ struct TConnectionsData {
+ TAtomicShutdownState ShutdownState;
+
+ TConnectionsData();
+ };
+ TConnectionsData ConnectionsData;
+
+ ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionWriterIncrementalStatus,
+ TStatusTag, TDeadConnectionTag>*
+ GetDeadConnectionWriterStatusQueue() {
+ return this;
+ }
+
+ ::NActor::TQueueInActor<TBusSessionImpl, TRemoteConnectionReaderIncrementalStatus,
+ TStatusTag, TDeadConnectionTag>*
+ GetDeadConnectionReaderStatusQueue() {
+ return this;
+ }
+
+ ::NActor::TQueueInActor<TBusSessionImpl, TAcceptorStatus,
+ TStatusTag, TDeadConnectionTag>*
+ GetDeadAcceptorStatusQueue() {
+ return this;
+ }
+
+ template <typename TItem>
+ ::NActor::IQueueInActor<TItem>* GetQueue() {
+ return this;
+ }
+
+ ui64 LastAcceptorId;
+ ui64 LastConnectionId;
+
+ TAtomic Down;
+ TSystemEvent ShutdownCompleteEvent;
+ };
+
+ inline TBusProtocol* TBusSessionImpl::GetProto() const noexcept {
+ return Proto;
+ }
+
+ inline const TBusSessionConfig* TBusSessionImpl::GetConfig() const noexcept {
+ return &Config;
+ }
+
+ inline TBusMessageQueue* TBusSessionImpl::GetQueue() const noexcept {
+ return Queue;
+ }
+
+ }
+}
diff --git a/library/cpp/messagebus/session_job_count.cpp b/library/cpp/messagebus/session_job_count.cpp
new file mode 100644
index 0000000000..33322b1910
--- /dev/null
+++ b/library/cpp/messagebus/session_job_count.cpp
@@ -0,0 +1,22 @@
+#include "session_job_count.h"
+
+#include <util/system/yassert.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+TBusSessionJobCount::TBusSessionJobCount()
+ : JobCount(0)
+{
+}
+
+TBusSessionJobCount::~TBusSessionJobCount() {
+ Y_VERIFY(JobCount == 0, "must be 0 job count to destroy job");
+}
+
+void TBusSessionJobCount::WaitForZero() {
+ TGuard<TMutex> guard(Mutex);
+ while (AtomicGet(JobCount) > 0) {
+ CondVar.WaitI(Mutex);
+ }
+}
diff --git a/library/cpp/messagebus/session_job_count.h b/library/cpp/messagebus/session_job_count.h
new file mode 100644
index 0000000000..23aca618b1
--- /dev/null
+++ b/library/cpp/messagebus/session_job_count.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <util/system/atomic.h>
+#include <util/system/condvar.h>
+#include <util/system/mutex.h>
+
+namespace NBus {
+ namespace NPrivate {
+ class TBusSessionJobCount {
+ private:
+ TAtomic JobCount;
+
+ TMutex Mutex;
+ TCondVar CondVar;
+
+ public:
+ TBusSessionJobCount();
+ ~TBusSessionJobCount();
+
+ void Add(unsigned delta) {
+ AtomicAdd(JobCount, delta);
+ }
+
+ void Increment() {
+ Add(1);
+ }
+
+ void Decrement() {
+ if (AtomicDecrement(JobCount) == 0) {
+ TGuard<TMutex> guard(Mutex);
+ CondVar.BroadCast();
+ }
+ }
+
+ void WaitForZero();
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/shutdown_state.cpp b/library/cpp/messagebus/shutdown_state.cpp
new file mode 100644
index 0000000000..a4e2bfa8b2
--- /dev/null
+++ b/library/cpp/messagebus/shutdown_state.cpp
@@ -0,0 +1,20 @@
+#include "shutdown_state.h"
+
+#include <util/system/yassert.h>
+
+void TAtomicShutdownState::ShutdownCommand() {
+ Y_VERIFY(State.CompareAndSet(SS_RUNNING, SS_SHUTDOWN_COMMAND));
+}
+
+void TAtomicShutdownState::CompleteShutdown() {
+ Y_VERIFY(State.CompareAndSet(SS_SHUTDOWN_COMMAND, SS_SHUTDOWN_COMPLETE));
+ ShutdownComplete.Signal();
+}
+
+bool TAtomicShutdownState::IsRunning() {
+ return State.Get() == SS_RUNNING;
+}
+
+TAtomicShutdownState::~TAtomicShutdownState() {
+ Y_VERIFY(SS_SHUTDOWN_COMPLETE == State.Get());
+}
diff --git a/library/cpp/messagebus/shutdown_state.h b/library/cpp/messagebus/shutdown_state.h
new file mode 100644
index 0000000000..86bd7110ae
--- /dev/null
+++ b/library/cpp/messagebus/shutdown_state.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include "misc/atomic_box.h"
+
+#include <util/system/event.h>
+
+enum EShutdownState {
+ SS_RUNNING,
+ SS_SHUTDOWN_COMMAND,
+ SS_SHUTDOWN_COMPLETE,
+};
+
+struct TAtomicShutdownState {
+ TAtomicBox<EShutdownState> State;
+ TSystemEvent ShutdownComplete;
+
+ void ShutdownCommand();
+ void CompleteShutdown();
+ bool IsRunning();
+
+ ~TAtomicShutdownState();
+};
diff --git a/library/cpp/messagebus/socket_addr.cpp b/library/cpp/messagebus/socket_addr.cpp
new file mode 100644
index 0000000000..c1b3a28fbe
--- /dev/null
+++ b/library/cpp/messagebus/socket_addr.cpp
@@ -0,0 +1,79 @@
+#include "socket_addr.h"
+
+#include "netaddr.h"
+
+#include <util/network/address.h>
+#include <util/network/init.h>
+#include <util/system/yassert.h>
+
+using namespace NAddr;
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+static_assert(ADDR_UNSPEC == 0, "expect ADDR_UNSPEC == 0");
+
+NBus::NPrivate::TBusSocketAddr::TBusSocketAddr(const NAddr::IRemoteAddr* addr)
+ : IPv6ScopeID(0)
+{
+ const sockaddr* sa = addr->Addr();
+
+ switch ((EAddrFamily)sa->sa_family) {
+ case AF_UNSPEC: {
+ IpAddr.Clear();
+ Port = 0;
+ break;
+ }
+ case AF_INET: {
+ IpAddr.SetInAddr(((const sockaddr_in*)sa)->sin_addr);
+ Port = InetToHost(((const sockaddr_in*)sa)->sin_port);
+ break;
+ }
+ case AF_INET6: {
+ IpAddr.SetIn6Addr(((const sockaddr_in6*)sa)->sin6_addr);
+ Port = InetToHost(((const sockaddr_in*)sa)->sin_port);
+ IPv6ScopeID = InetToHost(((const sockaddr_in6*)sa)->sin6_scope_id);
+ break;
+ }
+ default:
+ Y_FAIL("unknown address family");
+ }
+}
+
+NBus::NPrivate::TBusSocketAddr::TBusSocketAddr(TStringBuf host, unsigned port) {
+ *this = TNetAddr(host, port);
+}
+
+NBus::NPrivate::TBusSocketAddr::TBusSocketAddr(const TNetAddr& addr) {
+ *this = TBusSocketAddr(&addr);
+}
+
+TNetAddr NBus::NPrivate::TBusSocketAddr::ToNetAddr() const {
+ sockaddr_storage storage;
+ Zero(storage);
+
+ storage.ss_family = (ui16)IpAddr.GetAddrFamily();
+
+ switch (IpAddr.GetAddrFamily()) {
+ case ADDR_UNSPEC:
+ return TNetAddr();
+ case ADDR_IPV4: {
+ ((sockaddr_in*)&storage)->sin_addr = IpAddr.GetInAddr();
+ ((sockaddr_in*)&storage)->sin_port = HostToInet(Port);
+ break;
+ }
+ case ADDR_IPV6: {
+ ((sockaddr_in6*)&storage)->sin6_addr = IpAddr.GetIn6Addr();
+ ((sockaddr_in6*)&storage)->sin6_port = HostToInet(Port);
+ ((sockaddr_in6*)&storage)->sin6_scope_id = HostToInet(IPv6ScopeID);
+ break;
+ }
+ }
+
+ return TNetAddr(new TOpaqueAddr((sockaddr*)&storage));
+}
+
+template <>
+void Out<TBusSocketAddr>(IOutputStream& out, const TBusSocketAddr& addr) {
+ out << addr.ToNetAddr();
+}
diff --git a/library/cpp/messagebus/socket_addr.h b/library/cpp/messagebus/socket_addr.h
new file mode 100644
index 0000000000..959eafe689
--- /dev/null
+++ b/library/cpp/messagebus/socket_addr.h
@@ -0,0 +1,113 @@
+#pragma once
+
+#include "hash.h"
+
+#include <util/generic/hash.h>
+#include <util/generic/utility.h>
+#include <util/network/address.h>
+#include <util/network/init.h>
+
+#include <string.h>
+
+namespace NBus {
+ class TNetAddr;
+}
+
+namespace NBus {
+ namespace NPrivate {
+ enum EAddrFamily {
+ ADDR_UNSPEC = AF_UNSPEC,
+ ADDR_IPV4 = AF_INET,
+ ADDR_IPV6 = AF_INET6,
+ };
+
+ class TBusIpAddr {
+ private:
+ EAddrFamily Af;
+
+ union {
+ in_addr In4;
+ in6_addr In6;
+ };
+
+ public:
+ TBusIpAddr() {
+ Clear();
+ }
+
+ EAddrFamily GetAddrFamily() const {
+ return Af;
+ }
+
+ void Clear() {
+ Zero(*this);
+ }
+
+ in_addr GetInAddr() const {
+ Y_ASSERT(Af == ADDR_IPV4);
+ return In4;
+ }
+
+ void SetInAddr(const in_addr& in4) {
+ Clear();
+ Af = ADDR_IPV4;
+ In4 = in4;
+ }
+
+ in6_addr GetIn6Addr() const {
+ Y_ASSERT(Af == ADDR_IPV6);
+ return In6;
+ }
+
+ void SetIn6Addr(const in6_addr& in6) {
+ Clear();
+ Af = ADDR_IPV6;
+ In6 = in6;
+ }
+
+ bool operator==(const TBusIpAddr& that) const {
+ return memcmp(this, &that, sizeof(that)) == 0;
+ }
+ };
+
+ class TBusSocketAddr {
+ public:
+ TBusIpAddr IpAddr;
+ ui16 Port;
+
+ //Only makes sense for IPv6 link-local addresses
+ ui32 IPv6ScopeID;
+
+ TBusSocketAddr()
+ : Port(0)
+ , IPv6ScopeID(0)
+ {
+ }
+
+ TBusSocketAddr(const NAddr::IRemoteAddr*);
+ TBusSocketAddr(const TNetAddr&);
+ TBusSocketAddr(TStringBuf host, unsigned port);
+
+ TNetAddr ToNetAddr() const;
+
+ bool operator==(const TBusSocketAddr& that) const {
+ return IpAddr == that.IpAddr && Port == that.Port;
+ }
+ };
+
+ }
+}
+
+template <>
+struct THash<NBus::NPrivate::TBusIpAddr> {
+ inline size_t operator()(const NBus::NPrivate::TBusIpAddr& a) const {
+ return ComputeHash(TStringBuf((const char*)&a, sizeof(a)));
+ }
+};
+
+template <>
+struct THash<NBus::NPrivate::TBusSocketAddr> {
+ inline size_t operator()(const NBus::NPrivate::TBusSocketAddr& a) const {
+ return HashValues(a.IpAddr, a.Port);
+ }
+};
diff --git a/library/cpp/messagebus/socket_addr_ut.cpp b/library/cpp/messagebus/socket_addr_ut.cpp
new file mode 100644
index 0000000000..783bb62a86
--- /dev/null
+++ b/library/cpp/messagebus/socket_addr_ut.cpp
@@ -0,0 +1,15 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "netaddr.h"
+#include "socket_addr.h"
+
+#include <util/string/cast.h>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+Y_UNIT_TEST_SUITE(TBusSocketAddr) {
+ Y_UNIT_TEST(Simple) {
+ UNIT_ASSERT_VALUES_EQUAL(TString("127.0.0.1:80"), ToString(TBusSocketAddr("127.0.0.1", 80)));
+ }
+}
diff --git a/library/cpp/messagebus/storage.cpp b/library/cpp/messagebus/storage.cpp
new file mode 100644
index 0000000000..efefc87340
--- /dev/null
+++ b/library/cpp/messagebus/storage.cpp
@@ -0,0 +1,161 @@
+#include "storage.h"
+
+#include <typeinfo>
+
+namespace NBus {
+ namespace NPrivate {
+ TTimedMessages::TTimedMessages() {
+ }
+
+ TTimedMessages::~TTimedMessages() {
+ Y_VERIFY(Items.empty());
+ }
+
+ void TTimedMessages::PushBack(TNonDestroyingAutoPtr<TBusMessage> m) {
+ TItem i;
+ i.Message.Reset(m.Release());
+ Items.push_back(i);
+ }
+
+ TNonDestroyingAutoPtr<TBusMessage> TTimedMessages::PopFront() {
+ TBusMessage* r = nullptr;
+ if (!Items.empty()) {
+ r = Items.front()->Message.Release();
+ Items.pop_front();
+ }
+ return r;
+ }
+
+ bool TTimedMessages::Empty() const {
+ return Items.empty();
+ }
+
+ size_t TTimedMessages::Size() const {
+ return Items.size();
+ }
+
+ void TTimedMessages::Timeout(TInstant before, TMessagesPtrs* r) {
+ // shortcut
+ if (before == TInstant::Max()) {
+ Clear(r);
+ return;
+ }
+
+ while (!Items.empty()) {
+ TItem& i = *Items.front();
+ if (TInstant::MilliSeconds(i.Message->GetHeader()->SendTime) > before) {
+ break;
+ }
+ r->push_back(i.Message.Release());
+ Items.pop_front();
+ }
+ }
+
+ void TTimedMessages::Clear(TMessagesPtrs* r) {
+ while (!Items.empty()) {
+ r->push_back(Items.front()->Message.Release());
+ Items.pop_front();
+ }
+ }
+
+ TSyncAckMessages::TSyncAckMessages() {
+ KeyToMessage.set_empty_key(0);
+ KeyToMessage.set_deleted_key(1);
+ }
+
+ TSyncAckMessages::~TSyncAckMessages() {
+ Y_VERIFY(KeyToMessage.empty());
+ Y_VERIFY(TimedItems.empty());
+ }
+
+ void TSyncAckMessages::Push(TBusMessagePtrAndHeader& m) {
+ // Perform garbage collection if `TimedMessages` contain too many junk data
+ if (TimedItems.size() > 1000 && TimedItems.size() > KeyToMessage.size() * 4) {
+ Gc();
+ }
+
+ TValue value = {m.MessagePtr.Release()};
+
+ std::pair<TKeyToMessage::iterator, bool> p = KeyToMessage.insert(TKeyToMessage::value_type(m.Header.Id, value));
+ Y_VERIFY(p.second, "non-unique id; %s", value.Message->Describe().data());
+
+ TTimedItem item = {m.Header.Id, m.Header.SendTime};
+ TimedItems.push_back(item);
+ }
+
+ TBusMessage* TSyncAckMessages::Pop(TBusKey id) {
+ TKeyToMessage::iterator it = KeyToMessage.find(id);
+ if (it == KeyToMessage.end()) {
+ return nullptr;
+ }
+ TValue v = it->second;
+ KeyToMessage.erase(it);
+
+ // `TimedMessages` still contain record about this message
+
+ return v.Message;
+ }
+
+ void TSyncAckMessages::Timeout(TInstant before, TMessagesPtrs* r) {
+ // shortcut
+ if (before == TInstant::Max()) {
+ Clear(r);
+ return;
+ }
+
+ Y_ASSERT(r->empty());
+
+ while (!TimedItems.empty()) {
+ TTimedItem i = TimedItems.front();
+ if (TInstant::MilliSeconds(i.SendTime) > before) {
+ break;
+ }
+
+ TKeyToMessage::iterator itMessage = KeyToMessage.find(i.Key);
+
+ if (itMessage != KeyToMessage.end()) {
+ r->push_back(itMessage->second.Message);
+ KeyToMessage.erase(itMessage);
+ }
+
+ TimedItems.pop_front();
+ }
+ }
+
+ void TSyncAckMessages::Clear(TMessagesPtrs* r) {
+ for (TKeyToMessage::const_iterator i = KeyToMessage.begin(); i != KeyToMessage.end(); ++i) {
+ r->push_back(i->second.Message);
+ }
+
+ KeyToMessage.clear();
+ TimedItems.clear();
+ }
+
+ void TSyncAckMessages::Gc() {
+ TDeque<TTimedItem> tmp;
+
+ for (auto& timedItem : TimedItems) {
+ if (KeyToMessage.find(timedItem.Key) == KeyToMessage.end()) {
+ continue;
+ }
+ tmp.push_back(timedItem);
+ }
+
+ TimedItems.swap(tmp);
+ }
+
+ void TSyncAckMessages::RemoveAll(const TMessagesPtrs& messages) {
+ for (auto message : messages) {
+ TKeyToMessage::iterator it = KeyToMessage.find(message->GetHeader()->Id);
+ Y_VERIFY(it != KeyToMessage.end(), "delete non-existent message");
+ KeyToMessage.erase(it);
+ }
+ }
+
+ void TSyncAckMessages::DumpState() {
+ Cerr << TimedItems.size() << Endl;
+ Cerr << KeyToMessage.size() << Endl;
+ }
+
+ }
+}
diff --git a/library/cpp/messagebus/storage.h b/library/cpp/messagebus/storage.h
new file mode 100644
index 0000000000..7d168844ed
--- /dev/null
+++ b/library/cpp/messagebus/storage.h
@@ -0,0 +1,94 @@
+#pragma once
+
+#include "message_ptr_and_header.h"
+#include "moved.h"
+#include "ybus.h"
+
+#include <contrib/libs/sparsehash/src/sparsehash/dense_hash_map>
+
+#include <util/generic/deque.h>
+#include <util/generic/noncopyable.h>
+#include <util/generic/utility.h>
+
+namespace NBus {
+ namespace NPrivate {
+ typedef TVector<TBusMessage*> TMessagesPtrs;
+
+ class TTimedMessages {
+ public:
+ TTimedMessages();
+ ~TTimedMessages();
+
+ struct TItem {
+ THolder<TBusMessage> Message;
+
+ void Swap(TItem& that) {
+ DoSwap(Message, that.Message);
+ }
+ };
+
+ typedef TDeque<TMoved<TItem>> TItems;
+
+ void PushBack(TNonDestroyingAutoPtr<TBusMessage> m);
+ TNonDestroyingAutoPtr<TBusMessage> PopFront();
+ bool Empty() const;
+ size_t Size() const;
+
+ void Timeout(TInstant before, TMessagesPtrs* r);
+ void Clear(TMessagesPtrs* r);
+
+ private:
+ TItems Items;
+ };
+
+ class TSyncAckMessages : TNonCopyable {
+ public:
+ TSyncAckMessages();
+ ~TSyncAckMessages();
+
+ void Push(TBusMessagePtrAndHeader& m);
+ TBusMessage* Pop(TBusKey id);
+
+ void Timeout(TInstant before, TMessagesPtrs* r);
+
+ void Clear(TMessagesPtrs* r);
+
+ size_t Size() const {
+ return KeyToMessage.size();
+ }
+
+ void RemoveAll(const TMessagesPtrs&);
+
+ void Gc();
+
+ void DumpState();
+
+ private:
+ struct TTimedItem {
+ TBusKey Key;
+ TBusInstant SendTime;
+ };
+
+ typedef TDeque<TTimedItem> TTimedItems;
+ typedef TDeque<TTimedItem>::iterator TTimedIterator;
+
+ TTimedItems TimedItems;
+
+ struct TValue {
+ TBusMessage* Message;
+ };
+
+ // keys are already random, no need to hash them further
+ struct TIdHash {
+ size_t operator()(TBusKey value) const {
+ return value;
+ }
+ };
+
+ typedef google::dense_hash_map<TBusKey, TValue, TIdHash> TKeyToMessage;
+
+ TKeyToMessage KeyToMessage;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/synchandler.cpp b/library/cpp/messagebus/synchandler.cpp
new file mode 100644
index 0000000000..8e891d66b3
--- /dev/null
+++ b/library/cpp/messagebus/synchandler.cpp
@@ -0,0 +1,198 @@
+#include "remote_client_session.h"
+#include "remote_connection.h"
+#include "ybus.h"
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+
+/////////////////////////////////////////////////////////////////
+/// Object that encapsulates all messgae data required for sending
+/// a message synchronously and receiving a reply. It includes:
+/// 1. ConditionVariable to wait on message reply
+/// 2. Lock used by condition variable
+/// 3. Message reply
+/// 4. Reply status
+struct TBusSyncMessageData {
+ TCondVar ReplyEvent;
+ TMutex ReplyLock;
+ TBusMessage* Reply;
+ EMessageStatus ReplyStatus;
+
+ TBusSyncMessageData()
+ : Reply(nullptr)
+ , ReplyStatus(MESSAGE_DONT_ASK)
+ {
+ }
+};
+
+class TSyncHandler: public IBusClientHandler {
+public:
+ TSyncHandler(bool expectReply = true)
+ : ExpectReply(expectReply)
+ , Session(nullptr)
+ {
+ }
+ ~TSyncHandler() override {
+ }
+
+ void OnReply(TAutoPtr<TBusMessage> pMessage0, TAutoPtr<TBusMessage> pReply0) override {
+ TBusMessage* pMessage = pMessage0.Release();
+ TBusMessage* pReply = pReply0.Release();
+
+ if (!ExpectReply) { // Maybe need VERIFY, but it will be better to support backward compatibility here.
+ return;
+ }
+
+ TBusSyncMessageData* data = static_cast<TBusSyncMessageData*>(pMessage->Data);
+ SignalResult(data, pReply, MESSAGE_OK);
+ }
+
+ void OnError(TAutoPtr<TBusMessage> pMessage0, EMessageStatus status) override {
+ TBusMessage* pMessage = pMessage0.Release();
+ TBusSyncMessageData* data = static_cast<TBusSyncMessageData*>(pMessage->Data);
+ if (!data) {
+ return;
+ }
+
+ SignalResult(data, /*pReply=*/nullptr, status);
+ }
+
+ void OnMessageSent(TBusMessage* pMessage) override {
+ Y_UNUSED(pMessage);
+ Y_ASSERT(ExpectReply);
+ }
+
+ void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) override {
+ Y_ASSERT(!ExpectReply);
+ TBusSyncMessageData* data = static_cast<TBusSyncMessageData*>(pMessage.Release()->Data);
+ SignalResult(data, /*pReply=*/nullptr, MESSAGE_OK);
+ }
+
+ void SetSession(TRemoteClientSession* session) {
+ if (!ExpectReply) {
+ Session = session;
+ }
+ }
+
+private:
+ void SignalResult(TBusSyncMessageData* data, TBusMessage* pReply, EMessageStatus status) const {
+ Y_VERIFY(data, "Message data is set to NULL.");
+ TGuard<TMutex> G(data->ReplyLock);
+ data->Reply = pReply;
+ data->ReplyStatus = status;
+ data->ReplyEvent.Signal();
+ }
+
+private:
+ // This is weird, because in regular client one-way-ness is selected per call, not per session.
+ bool ExpectReply;
+ TRemoteClientSession* Session;
+};
+
+namespace NBus {
+ namespace NPrivate {
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4250) // 'NBus::NPrivate::TRemoteClientSession' : inherits 'NBus::NPrivate::TBusSessionImpl::NBus::NPrivate::TBusSessionImpl::GetConfig' via dominance
+#endif
+
+ ///////////////////////////////////////////////////////////////////////////
+ class TBusSyncSourceSessionImpl
+ : private TSyncHandler
+ // TODO: do not extend TRemoteClientSession
+ ,
+ public TRemoteClientSession {
+ private:
+ bool NeedReply;
+
+ public:
+ TBusSyncSourceSessionImpl(TBusMessageQueue* queue, TBusProtocol* proto, const TBusClientSessionConfig& config, bool needReply, const TString& name)
+ : TSyncHandler(needReply)
+ , TRemoteClientSession(queue, proto, this, config, name)
+ , NeedReply(needReply)
+ {
+ SetSession(this);
+ }
+
+ TBusMessage* SendSyncMessage(TBusMessage* pMessage, EMessageStatus& status, const TNetAddr* addr = nullptr) {
+ Y_VERIFY(!Queue->GetExecutor()->IsInExecutorThread(),
+ "SendSyncMessage must not be called from executor thread");
+
+ TBusMessage* reply = nullptr;
+ THolder<TBusSyncMessageData> data(new TBusSyncMessageData());
+
+ pMessage->Data = data.Get();
+
+ {
+ TGuard<TMutex> G(data->ReplyLock);
+ if (NeedReply) {
+ status = SendMessage(pMessage, addr, false); // probably should be true
+ } else {
+ status = SendMessageOneWay(pMessage, addr);
+ }
+
+ if (status == MESSAGE_OK) {
+ data->ReplyEvent.Wait(data->ReplyLock);
+ TBusSyncMessageData* rdata = static_cast<TBusSyncMessageData*>(pMessage->Data);
+ Y_VERIFY(rdata == data.Get(), "Message data pointer should not be modified.");
+ reply = rdata->Reply;
+ status = rdata->ReplyStatus;
+ }
+ }
+
+ // deletion of message and reply is a job of application.
+ pMessage->Data = nullptr;
+
+ return reply;
+ }
+ };
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+ }
+}
+
+TBusSyncSourceSession::TBusSyncSourceSession(TIntrusivePtr< ::NBus::NPrivate::TBusSyncSourceSessionImpl> session)
+ : Session(session)
+{
+}
+
+TBusSyncSourceSession::~TBusSyncSourceSession() {
+ Shutdown();
+}
+
+void TBusSyncSourceSession::Shutdown() {
+ Session->Shutdown();
+}
+
+TBusMessage* TBusSyncSourceSession::SendSyncMessage(TBusMessage* pMessage, EMessageStatus& status, const TNetAddr* addr) {
+ return Session->SendSyncMessage(pMessage, status, addr);
+}
+
+int TBusSyncSourceSession::RegisterService(const char* hostname, TBusKey start, TBusKey end, EIpVersion ipVersion) {
+ return Session->RegisterService(hostname, start, end, ipVersion);
+}
+
+int TBusSyncSourceSession::GetInFlight() {
+ return Session->GetInFlight();
+}
+
+const TBusProtocol* TBusSyncSourceSession::GetProto() const {
+ return Session->GetProto();
+}
+
+const TBusClientSession* TBusSyncSourceSession::GetBusClientSessionWorkaroundDoNotUse() const {
+ return Session.Get();
+}
+
+TBusSyncClientSessionPtr TBusMessageQueue::CreateSyncSource(TBusProtocol* proto, const TBusClientSessionConfig& config, bool needReply, const TString& name) {
+ TIntrusivePtr<TBusSyncSourceSessionImpl> session = new TBusSyncSourceSessionImpl(this, proto, config, needReply, name);
+ Add(session.Get());
+ return new TBusSyncSourceSession(session);
+}
+
+void TBusMessageQueue::Destroy(TBusSyncClientSessionPtr session) {
+ Destroy(session->Session.Get());
+ Y_UNUSED(session->Session.Release());
+}
diff --git a/library/cpp/messagebus/test/TestMessageBus.py b/library/cpp/messagebus/test/TestMessageBus.py
new file mode 100644
index 0000000000..0bbaa0a313
--- /dev/null
+++ b/library/cpp/messagebus/test/TestMessageBus.py
@@ -0,0 +1,8 @@
+from devtools.fleur.ytest import group, constraint
+from devtools.fleur.ytest.integration import UnitTestGroup
+
+@group
+@constraint('library.messagebus')
+class TestMessageBus(UnitTestGroup):
+ def __init__(self, context):
+ UnitTestGroup.__init__(self, context, 'MessageBus', 'library-messagebus-test-ut')
diff --git a/library/cpp/messagebus/test/example/client/client.cpp b/library/cpp/messagebus/test/example/client/client.cpp
new file mode 100644
index 0000000000..89b5f2c9be
--- /dev/null
+++ b/library/cpp/messagebus/test/example/client/client.cpp
@@ -0,0 +1,81 @@
+#include <library/cpp/messagebus/test/example/common/proto.h>
+
+#include <util/random/random.h>
+
+using namespace NBus;
+using namespace NCalculator;
+
+namespace NCalculator {
+ struct TCalculatorClient: public IBusClientHandler {
+ TCalculatorProtocol Proto;
+ TBusMessageQueuePtr MessageQueue;
+ TBusClientSessionPtr ClientSession;
+
+ TCalculatorClient() {
+ MessageQueue = CreateMessageQueue();
+ TBusClientSessionConfig config;
+ config.TotalTimeout = 2 * 1000;
+ ClientSession = TBusClientSession::Create(&Proto, this, config, MessageQueue);
+ }
+
+ ~TCalculatorClient() override {
+ MessageQueue->Stop();
+ }
+
+ void OnReply(TAutoPtr<TBusMessage> request, TAutoPtr<TBusMessage> response0) override {
+ Y_VERIFY(response0->GetHeader()->Type == TResponse::MessageType, "wrong response");
+ TResponse* response = VerifyDynamicCast<TResponse*>(response0.Get());
+ if (request->GetHeader()->Type == TRequestSum::MessageType) {
+ TRequestSum* requestSum = VerifyDynamicCast<TRequestSum*>(request.Get());
+ int a = requestSum->Record.GetA();
+ int b = requestSum->Record.GetB();
+ Cerr << a << " + " << b << " = " << response->Record.GetResult() << "\n";
+ } else if (request->GetHeader()->Type == TRequestMul::MessageType) {
+ TRequestMul* requestMul = VerifyDynamicCast<TRequestMul*>(request.Get());
+ int a = requestMul->Record.GetA();
+ int b = requestMul->Record.GetB();
+ Cerr << a << " * " << b << " = " << response->Record.GetResult() << "\n";
+ } else {
+ Y_FAIL("unknown request");
+ }
+ }
+
+ void OnError(TAutoPtr<TBusMessage>, EMessageStatus status) override {
+ Cerr << "got error " << status << "\n";
+ }
+ };
+
+}
+
+int main(int, char**) {
+ TCalculatorClient client;
+
+ for (;;) {
+ TNetAddr addr(TNetAddr("127.0.0.1", TCalculatorProtocol().GetPort()));
+
+ int a = RandomNumber<unsigned>(10);
+ int b = RandomNumber<unsigned>(10);
+ EMessageStatus ok;
+ if (RandomNumber<bool>()) {
+ TAutoPtr<TRequestSum> request(new TRequestSum);
+ request->Record.SetA(a);
+ request->Record.SetB(b);
+ Cerr << "sending " << a << " + " << b << "\n";
+ ok = client.ClientSession->SendMessageAutoPtr(request, &addr);
+ } else {
+ TAutoPtr<TRequestMul> request(new TRequestMul);
+ request->Record.SetA(a);
+ request->Record.SetB(b);
+ Cerr << "sending " << a << " * " << b << "\n";
+ ok = client.ClientSession->SendMessageAutoPtr(request, &addr);
+ }
+
+ if (ok != MESSAGE_OK) {
+ Cerr << "failed to send message " << ok << "\n";
+ }
+
+ Sleep(TDuration::Seconds(1));
+ }
+
+ return 0;
+}
diff --git a/library/cpp/messagebus/test/example/client/ya.make b/library/cpp/messagebus/test/example/client/ya.make
new file mode 100644
index 0000000000..a660a01698
--- /dev/null
+++ b/library/cpp/messagebus/test/example/client/ya.make
@@ -0,0 +1,13 @@
+PROGRAM(messagebus_example_client)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus/test/example/common
+)
+
+SRCS(
+ client.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/test/example/common/messages.proto b/library/cpp/messagebus/test/example/common/messages.proto
new file mode 100644
index 0000000000..16b858fc77
--- /dev/null
+++ b/library/cpp/messagebus/test/example/common/messages.proto
@@ -0,0 +1,15 @@
+package NCalculator;
+
+message TRequestSumRecord {
+ required int32 A = 1;
+ required int32 B = 2;
+}
+
+message TRequestMulRecord {
+ required int32 A = 1;
+ required int32 B = 2;
+}
+
+message TResponseRecord {
+ required int32 Result = 1;
+}
diff --git a/library/cpp/messagebus/test/example/common/proto.cpp b/library/cpp/messagebus/test/example/common/proto.cpp
new file mode 100644
index 0000000000..1d18aa77ea
--- /dev/null
+++ b/library/cpp/messagebus/test/example/common/proto.cpp
@@ -0,0 +1,12 @@
+#include "proto.h"
+
+using namespace NCalculator;
+using namespace NBus;
+
+TCalculatorProtocol::TCalculatorProtocol()
+ : TBusBufferProtocol("Calculator", 34567)
+{
+ RegisterType(new TRequestSum);
+ RegisterType(new TRequestMul);
+ RegisterType(new TResponse);
+}
diff --git a/library/cpp/messagebus/test/example/common/proto.h b/library/cpp/messagebus/test/example/common/proto.h
new file mode 100644
index 0000000000..a151aac468
--- /dev/null
+++ b/library/cpp/messagebus/test/example/common/proto.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <library/cpp/messagebus/test/example/common/messages.pb.h>
+
+#include <library/cpp/messagebus/ybus.h>
+#include <library/cpp/messagebus/protobuf/ybusbuf.h>
+
+namespace NCalculator {
+ typedef ::NBus::TBusBufferMessage<TRequestSumRecord, 1> TRequestSum;
+ typedef ::NBus::TBusBufferMessage<TRequestMulRecord, 2> TRequestMul;
+ typedef ::NBus::TBusBufferMessage<TResponseRecord, 3> TResponse;
+
+ struct TCalculatorProtocol: public ::NBus::TBusBufferProtocol {
+ TCalculatorProtocol();
+ };
+
+}
diff --git a/library/cpp/messagebus/test/example/common/ya.make b/library/cpp/messagebus/test/example/common/ya.make
new file mode 100644
index 0000000000..4da16608fc
--- /dev/null
+++ b/library/cpp/messagebus/test/example/common/ya.make
@@ -0,0 +1,15 @@
+LIBRARY(messagebus_test_example_common)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus
+ library/cpp/messagebus/protobuf
+)
+
+SRCS(
+ proto.cpp
+ messages.proto
+)
+
+END()
diff --git a/library/cpp/messagebus/test/example/server/server.cpp b/library/cpp/messagebus/test/example/server/server.cpp
new file mode 100644
index 0000000000..13e52d75f5
--- /dev/null
+++ b/library/cpp/messagebus/test/example/server/server.cpp
@@ -0,0 +1,58 @@
+#include <library/cpp/messagebus/test/example/common/proto.h>
+
+using namespace NBus;
+using namespace NCalculator;
+
+namespace NCalculator {
+ struct TCalculatorServer: public IBusServerHandler {
+ TCalculatorProtocol Proto;
+ TBusMessageQueuePtr MessageQueue;
+ TBusServerSessionPtr ServerSession;
+
+ TCalculatorServer() {
+ MessageQueue = CreateMessageQueue();
+ TBusServerSessionConfig config;
+ ServerSession = TBusServerSession::Create(&Proto, this, config, MessageQueue);
+ }
+
+ ~TCalculatorServer() override {
+ MessageQueue->Stop();
+ }
+
+ void OnMessage(TOnMessageContext& request) override {
+ if (request.GetMessage()->GetHeader()->Type == TRequestSum::MessageType) {
+ TRequestSum* requestSum = VerifyDynamicCast<TRequestSum*>(request.GetMessage());
+ int a = requestSum->Record.GetA();
+ int b = requestSum->Record.GetB();
+ int result = a + b;
+ Cerr << "requested " << a << " + " << b << ", sending " << result << "\n";
+ TAutoPtr<TResponse> response(new TResponse);
+ response->Record.SetResult(result);
+ request.SendReplyMove(response);
+ } else if (request.GetMessage()->GetHeader()->Type == TRequestMul::MessageType) {
+ TRequestMul* requestMul = VerifyDynamicCast<TRequestMul*>(request.GetMessage());
+ int a = requestMul->Record.GetA();
+ int b = requestMul->Record.GetB();
+ int result = a * b;
+ Cerr << "requested " << a << " * " << b << ", sending " << result << "\n";
+ TAutoPtr<TResponse> response(new TResponse);
+ response->Record.SetResult(result);
+ request.SendReplyMove(response);
+ } else {
+ Y_FAIL("unknown request");
+ }
+ }
+ };
+}
+
+int main(int, char**) {
+ TCalculatorServer server;
+
+ Cerr << "listening on port " << server.ServerSession->GetActualListenPort() << "\n";
+
+ for (;;) {
+ Sleep(TDuration::Seconds(1));
+ }
+
+ return 0;
+}
diff --git a/library/cpp/messagebus/test/example/server/ya.make b/library/cpp/messagebus/test/example/server/ya.make
new file mode 100644
index 0000000000..8cdd97cb12
--- /dev/null
+++ b/library/cpp/messagebus/test/example/server/ya.make
@@ -0,0 +1,13 @@
+PROGRAM(messagebus_example_server)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/messagebus/test/example/common
+)
+
+SRCS(
+ server.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/test/example/ya.make b/library/cpp/messagebus/test/example/ya.make
new file mode 100644
index 0000000000..f275351c29
--- /dev/null
+++ b/library/cpp/messagebus/test/example/ya.make
@@ -0,0 +1,7 @@
+OWNER(g:messagebus)
+
+RECURSE(
+ client
+ common
+ server
+)
diff --git a/library/cpp/messagebus/test/helper/alloc_counter.h b/library/cpp/messagebus/test/helper/alloc_counter.h
new file mode 100644
index 0000000000..ec9041cb15
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/alloc_counter.h
@@ -0,0 +1,21 @@
+#pragma once
+
+#include <util/generic/noncopyable.h>
+#include <util/system/atomic.h>
+#include <util/system/yassert.h>
+
+class TAllocCounter : TNonCopyable {
+private:
+ TAtomic* CountPtr;
+
+public:
+ TAllocCounter(TAtomic* countPtr)
+ : CountPtr(countPtr)
+ {
+ AtomicIncrement(*CountPtr);
+ }
+
+ ~TAllocCounter() {
+ Y_VERIFY(AtomicDecrement(*CountPtr) >= 0, "released too many");
+ }
+};
diff --git a/library/cpp/messagebus/test/helper/example.cpp b/library/cpp/messagebus/test/helper/example.cpp
new file mode 100644
index 0000000000..7c6d704042
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/example.cpp
@@ -0,0 +1,281 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "example.h"
+
+#include <util/generic/cast.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+static void FillWithJunk(TArrayRef<char> data) {
+ TStringBuf junk =
+ "01234567890123456789012345678901234567890123456789012345678901234567890123456789"
+ "01234567890123456789012345678901234567890123456789012345678901234567890123456789"
+ "01234567890123456789012345678901234567890123456789012345678901234567890123456789"
+ "01234567890123456789012345678901234567890123456789012345678901234567890123456789";
+
+ for (size_t i = 0; i < data.size(); i += junk.size()) {
+ memcpy(data.data() + i, junk.data(), Min(junk.size(), data.size() - i));
+ }
+}
+
+static TString JunkString(size_t len) {
+ TTempBuf temp(len);
+ TArrayRef<char> tempArrayRef(temp.Data(), len);
+ FillWithJunk(tempArrayRef);
+
+ return TString(tempArrayRef.data(), tempArrayRef.size());
+}
+
+TExampleRequest::TExampleRequest(TAtomic* counterPtr, size_t payloadSize)
+ : TBusMessage(77)
+ , AllocCounter(counterPtr)
+ , Data(JunkString(payloadSize))
+{
+}
+
+TExampleRequest::TExampleRequest(ECreateUninitialized, TAtomic* counterPtr)
+ : TBusMessage(MESSAGE_CREATE_UNINITIALIZED)
+ , AllocCounter(counterPtr)
+{
+}
+
+TExampleResponse::TExampleResponse(TAtomic* counterPtr, size_t payloadSize)
+ : TBusMessage(79)
+ , AllocCounter(counterPtr)
+ , Data(JunkString(payloadSize))
+{
+}
+
+TExampleResponse::TExampleResponse(ECreateUninitialized, TAtomic* counterPtr)
+ : TBusMessage(MESSAGE_CREATE_UNINITIALIZED)
+ , AllocCounter(counterPtr)
+{
+}
+
+TExampleProtocol::TExampleProtocol(int port)
+ : TBusProtocol("Example", port)
+ , RequestCount(0)
+ , ResponseCount(0)
+ , RequestCountDeserialized(0)
+ , ResponseCountDeserialized(0)
+ , StartCount(0)
+{
+}
+
+TExampleProtocol::~TExampleProtocol() {
+ if (UncaughtException()) {
+ // so it could be reported in test
+ return;
+ }
+ Y_VERIFY(0 == AtomicGet(RequestCount), "protocol %s: must be 0 requests allocated, actually %d", GetService(), int(RequestCount));
+ Y_VERIFY(0 == AtomicGet(ResponseCount), "protocol %s: must be 0 responses allocated, actually %d", GetService(), int(ResponseCount));
+ Y_VERIFY(0 == AtomicGet(RequestCountDeserialized), "protocol %s: must be 0 requests deserialized allocated, actually %d", GetService(), int(RequestCountDeserialized));
+ Y_VERIFY(0 == AtomicGet(ResponseCountDeserialized), "protocol %s: must be 0 responses deserialized allocated, actually %d", GetService(), int(ResponseCountDeserialized));
+ Y_VERIFY(0 == AtomicGet(StartCount), "protocol %s: must be 0 start objects allocated, actually %d", GetService(), int(StartCount));
+}
+
+void TExampleProtocol::Serialize(const TBusMessage* message, TBuffer& buffer) {
+ // Messages have no data, we recreate them from scratch
+ // instead of sending, so we don't need to serialize them.
+ if (const TExampleRequest* exampleMessage = dynamic_cast<const TExampleRequest*>(message)) {
+ buffer.Append(exampleMessage->Data.data(), exampleMessage->Data.size());
+ } else if (const TExampleResponse* exampleReply = dynamic_cast<const TExampleResponse*>(message)) {
+ buffer.Append(exampleReply->Data.data(), exampleReply->Data.size());
+ } else {
+ Y_FAIL("unknown message type");
+ }
+}
+
+TAutoPtr<TBusMessage> TExampleProtocol::Deserialize(ui16 messageType, TArrayRef<const char> payload) {
+ // TODO: check data
+ Y_UNUSED(payload);
+
+ if (messageType == 77) {
+ TExampleRequest* exampleMessage = new TExampleRequest(MESSAGE_CREATE_UNINITIALIZED, &RequestCountDeserialized);
+ exampleMessage->Data.append(payload.data(), payload.size());
+ return exampleMessage;
+ } else if (messageType == 79) {
+ TExampleResponse* exampleReply = new TExampleResponse(MESSAGE_CREATE_UNINITIALIZED, &ResponseCountDeserialized);
+ exampleReply->Data.append(payload.data(), payload.size());
+ return exampleReply;
+ } else {
+ return nullptr;
+ }
+}
+
+TExampleClient::TExampleClient(const TBusClientSessionConfig sessionConfig, int port)
+ : Proto(port)
+ , UseCompression(false)
+ , CrashOnError(false)
+ , DataSize(320)
+ , MessageCount(0)
+ , RepliesCount(0)
+ , Errors(0)
+ , LastError(MESSAGE_OK)
+{
+ Bus = CreateMessageQueue("TExampleClient");
+
+ Session = TBusClientSession::Create(&Proto, this, sessionConfig, Bus);
+
+ Session->RegisterService("localhost");
+}
+
+TExampleClient::~TExampleClient() {
+}
+
+EMessageStatus TExampleClient::SendMessage(const TNetAddr* addr) {
+ TAutoPtr<TExampleRequest> message(new TExampleRequest(&Proto.RequestCount, DataSize));
+ message->SetCompressed(UseCompression);
+ return Session->SendMessageAutoPtr(message, addr);
+}
+
+void TExampleClient::SendMessages(size_t count, const TNetAddr* addr) {
+ UNIT_ASSERT(MessageCount == 0);
+ UNIT_ASSERT(RepliesCount == 0);
+ UNIT_ASSERT(Errors == 0);
+
+ WorkDone.Reset();
+ MessageCount = count;
+ for (ssize_t i = 0; i < MessageCount; ++i) {
+ EMessageStatus s = SendMessage(addr);
+ UNIT_ASSERT_EQUAL_C(s, MESSAGE_OK, "expecting OK, got " << s);
+ }
+}
+
+void TExampleClient::SendMessages(size_t count, const TNetAddr& addr) {
+ SendMessages(count, &addr);
+}
+
+void TExampleClient::ResetCounters() {
+ MessageCount = 0;
+ RepliesCount = 0;
+ Errors = 0;
+ LastError = MESSAGE_OK;
+
+ WorkDone.Reset();
+}
+
+void TExampleClient::WaitReplies() {
+ WorkDone.WaitT(TDuration::Seconds(60));
+
+ UNIT_ASSERT_VALUES_EQUAL(AtomicGet(RepliesCount), MessageCount);
+ UNIT_ASSERT_VALUES_EQUAL(AtomicGet(Errors), 0);
+ UNIT_ASSERT_VALUES_EQUAL(Session->GetInFlight(), 0);
+
+ ResetCounters();
+}
+
+EMessageStatus TExampleClient::WaitForError() {
+ WorkDone.WaitT(TDuration::Seconds(60));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, MessageCount);
+ UNIT_ASSERT_VALUES_EQUAL(0, AtomicGet(RepliesCount));
+ UNIT_ASSERT_VALUES_EQUAL(0, Session->GetInFlight());
+ UNIT_ASSERT_VALUES_EQUAL(1, Errors);
+ EMessageStatus result = LastError;
+
+ ResetCounters();
+ return result;
+}
+
+void TExampleClient::WaitForError(EMessageStatus status) {
+ EMessageStatus error = WaitForError();
+ UNIT_ASSERT_VALUES_EQUAL(status, error);
+}
+
+void TExampleClient::SendMessagesWaitReplies(size_t count, const TNetAddr* addr) {
+ SendMessages(count, addr);
+ WaitReplies();
+}
+
+void TExampleClient::SendMessagesWaitReplies(size_t count, const TNetAddr& addr) {
+ SendMessagesWaitReplies(count, &addr);
+}
+
+void TExampleClient::OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) {
+ Y_UNUSED(mess);
+ Y_UNUSED(reply);
+
+ if (AtomicIncrement(RepliesCount) == MessageCount) {
+ WorkDone.Signal();
+ }
+}
+
+void TExampleClient::OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) {
+ if (CrashOnError) {
+ Y_FAIL("client failed: %s", ToCString(status));
+ }
+
+ Y_UNUSED(mess);
+
+ AtomicIncrement(Errors);
+ LastError = status;
+ WorkDone.Signal();
+}
+
+TExampleServer::TExampleServer(
+ const char* name,
+ const TBusServerSessionConfig& sessionConfig)
+ : UseCompression(false)
+ , AckMessageBeforeSendReply(false)
+ , ForgetRequest(false)
+{
+ Bus = CreateMessageQueue(name);
+ Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus);
+}
+
+TExampleServer::TExampleServer(unsigned port, const char* name)
+ : UseCompression(false)
+ , AckMessageBeforeSendReply(false)
+ , ForgetRequest(false)
+{
+ Bus = CreateMessageQueue(name);
+ TBusServerSessionConfig sessionConfig;
+ sessionConfig.ListenPort = port;
+ Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus);
+}
+
+TExampleServer::~TExampleServer() {
+}
+
+size_t TExampleServer::GetInFlight() const {
+ return Session->GetInFlight();
+}
+
+unsigned TExampleServer::GetActualListenPort() const {
+ return Session->GetActualListenPort();
+}
+
+TNetAddr TExampleServer::GetActualListenAddr() const {
+ return TNetAddr("127.0.0.1", GetActualListenPort());
+}
+
+void TExampleServer::WaitForOnMessageCount(unsigned n) {
+ TestSync.WaitFor(n);
+}
+
+void TExampleServer::OnMessage(TOnMessageContext& mess) {
+ TestSync.Inc();
+
+ TExampleRequest* request = VerifyDynamicCast<TExampleRequest*>(mess.GetMessage());
+
+ if (ForgetRequest) {
+ mess.ForgetRequest();
+ return;
+ }
+
+ TAutoPtr<TBusMessage> reply(new TExampleResponse(&Proto.ResponseCount, DataSize.GetOrElse(request->Data.size())));
+ reply->SetCompressed(UseCompression);
+
+ EMessageStatus status;
+ if (AckMessageBeforeSendReply) {
+ TBusIdentity ident;
+ mess.AckMessage(ident);
+ status = Session->SendReply(ident, reply.Release()); // TODO: leaks on error
+ } else {
+ status = mess.SendReplyMove(reply);
+ }
+
+ Y_VERIFY(status == MESSAGE_OK, "failed to send reply: %s", ToString(status).data());
+}
diff --git a/library/cpp/messagebus/test/helper/example.h b/library/cpp/messagebus/test/helper/example.h
new file mode 100644
index 0000000000..26b7475308
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/example.h
@@ -0,0 +1,132 @@
+#pragma once
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "alloc_counter.h"
+#include "message_handler_error.h"
+
+#include <library/cpp/messagebus/ybus.h>
+#include <library/cpp/messagebus/misc/test_sync.h>
+
+#include <util/system/event.h>
+
+namespace NBus {
+ namespace NTest {
+ class TExampleRequest: public TBusMessage {
+ friend class TExampleProtocol;
+
+ private:
+ TAllocCounter AllocCounter;
+
+ public:
+ TString Data;
+
+ public:
+ TExampleRequest(TAtomic* counterPtr, size_t payloadSize = 320);
+ TExampleRequest(ECreateUninitialized, TAtomic* counterPtr);
+ };
+
+ class TExampleResponse: public TBusMessage {
+ friend class TExampleProtocol;
+
+ private:
+ TAllocCounter AllocCounter;
+
+ public:
+ TString Data;
+ TExampleResponse(TAtomic* counterPtr, size_t payloadSize = 320);
+ TExampleResponse(ECreateUninitialized, TAtomic* counterPtr);
+ };
+
+ class TExampleProtocol: public TBusProtocol {
+ public:
+ TAtomic RequestCount;
+ TAtomic ResponseCount;
+ TAtomic RequestCountDeserialized;
+ TAtomic ResponseCountDeserialized;
+ TAtomic StartCount;
+
+ TExampleProtocol(int port = 0);
+
+ ~TExampleProtocol() override;
+
+ void Serialize(const TBusMessage* message, TBuffer& buffer) override;
+
+ TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override;
+ };
+
+ class TExampleClient: private TBusClientHandlerError {
+ public:
+ TExampleProtocol Proto;
+ bool UseCompression;
+ bool CrashOnError;
+ size_t DataSize;
+
+ ssize_t MessageCount;
+ TAtomic RepliesCount;
+ TAtomic Errors;
+ EMessageStatus LastError;
+
+ TSystemEvent WorkDone;
+
+ TBusMessageQueuePtr Bus;
+ TBusClientSessionPtr Session;
+
+ public:
+ TExampleClient(const TBusClientSessionConfig sessionConfig = TBusClientSessionConfig(), int port = 0);
+ ~TExampleClient() override;
+
+ EMessageStatus SendMessage(const TNetAddr* addr = nullptr);
+
+ void SendMessages(size_t count, const TNetAddr* addr = nullptr);
+ void SendMessages(size_t count, const TNetAddr& addr);
+
+ void ResetCounters();
+ void WaitReplies();
+ EMessageStatus WaitForError();
+ void WaitForError(EMessageStatus status);
+
+ void SendMessagesWaitReplies(size_t count, const TNetAddr* addr = nullptr);
+ void SendMessagesWaitReplies(size_t count, const TNetAddr& addr);
+
+ void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override;
+
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus) override;
+ };
+
+ class TExampleServer: private TBusServerHandlerError {
+ public:
+ TExampleProtocol Proto;
+ bool UseCompression;
+ bool AckMessageBeforeSendReply;
+ TMaybe<size_t> DataSize; // Nothing means use request size
+ bool ForgetRequest;
+
+ TTestSync TestSync;
+
+ TBusMessageQueuePtr Bus;
+ TBusServerSessionPtr Session;
+
+ public:
+ TExampleServer(
+ const char* name = "TExampleServer",
+ const TBusServerSessionConfig& sessionConfig = TBusServerSessionConfig());
+
+ TExampleServer(unsigned port, const char* name = "TExampleServer");
+
+ ~TExampleServer() override;
+
+ public:
+ size_t GetInFlight() const;
+ unsigned GetActualListenPort() const;
+ // any of
+ TNetAddr GetActualListenAddr() const;
+
+ void WaitForOnMessageCount(unsigned n);
+
+ protected:
+ void OnMessage(TOnMessageContext& mess) override;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/test/helper/example_module.cpp b/library/cpp/messagebus/test/helper/example_module.cpp
new file mode 100644
index 0000000000..65ecfcf73f
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/example_module.cpp
@@ -0,0 +1,43 @@
+#include "example_module.h"
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+TExampleModule::TExampleModule()
+ : TBusModule("TExampleModule")
+{
+ TBusQueueConfig queueConfig;
+ queueConfig.NumWorkers = 5;
+ Queue = CreateMessageQueue(queueConfig);
+}
+
+void TExampleModule::StartModule() {
+ CreatePrivateSessions(Queue.Get());
+ StartInput();
+}
+
+bool TExampleModule::Shutdown() {
+ TBusModule::Shutdown();
+ return true;
+}
+
+TBusServerSessionPtr TExampleModule::CreateExtSession(TBusMessageQueue&) {
+ return nullptr;
+}
+
+TBusServerSessionPtr TExampleServerModule::CreateExtSession(TBusMessageQueue& queue) {
+ TBusServerSessionPtr r = CreateDefaultDestination(queue, &Proto, TBusServerSessionConfig());
+ ServerAddr = TNetAddr("localhost", r->GetActualListenPort());
+ return r;
+}
+
+TExampleClientModule::TExampleClientModule()
+ : Source()
+{
+}
+
+TBusServerSessionPtr TExampleClientModule::CreateExtSession(TBusMessageQueue& queue) {
+ Source = CreateDefaultSource(queue, &Proto, TBusServerSessionConfig());
+ Source->RegisterService("localhost");
+ return nullptr;
+}
diff --git a/library/cpp/messagebus/test/helper/example_module.h b/library/cpp/messagebus/test/helper/example_module.h
new file mode 100644
index 0000000000..a0b295f613
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/example_module.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include "example.h"
+
+#include <library/cpp/messagebus/oldmodule/module.h>
+
+namespace NBus {
+ namespace NTest {
+ struct TExampleModule: public TBusModule {
+ TExampleProtocol Proto;
+ TBusMessageQueuePtr Queue;
+
+ TExampleModule();
+
+ void StartModule();
+
+ bool Shutdown() override;
+
+ // nop by default
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override;
+ };
+
+ struct TExampleServerModule: public TExampleModule {
+ TNetAddr ServerAddr;
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override;
+ };
+
+ struct TExampleClientModule: public TExampleModule {
+ TBusClientSessionPtr Source;
+
+ TExampleClientModule();
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/test/helper/fixed_port.cpp b/library/cpp/messagebus/test/helper/fixed_port.cpp
new file mode 100644
index 0000000000..258da0d1a5
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/fixed_port.cpp
@@ -0,0 +1,10 @@
+#include "fixed_port.h"
+
+#include <util/system/env.h>
+
+#include <stdlib.h>
+
+bool NBus::NTest::IsFixedPortTestAllowed() {
+ // TODO: report skipped tests to test
+ return !GetEnv("MB_TESTS_SKIP_FIXED_PORT");
+}
diff --git a/library/cpp/messagebus/test/helper/fixed_port.h b/library/cpp/messagebus/test/helper/fixed_port.h
new file mode 100644
index 0000000000..a9c61ebc63
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/fixed_port.h
@@ -0,0 +1,11 @@
+#pragma once
+
+namespace NBus {
+ namespace NTest {
+ bool IsFixedPortTestAllowed();
+
+ // Must not be in range OS uses for bind on random port.
+ const unsigned FixedPort = 4927;
+
+ }
+}
diff --git a/library/cpp/messagebus/test/helper/hanging_server.cpp b/library/cpp/messagebus/test/helper/hanging_server.cpp
new file mode 100644
index 0000000000..a35514b00d
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/hanging_server.cpp
@@ -0,0 +1,13 @@
+#include "hanging_server.h"
+
+#include <util/system/yassert.h>
+
+using namespace NBus;
+
+THangingServer::THangingServer(int port) {
+ BindResult = BindOnPort(port, false);
+}
+
+int THangingServer::GetPort() const {
+ return BindResult.first;
+}
diff --git a/library/cpp/messagebus/test/helper/hanging_server.h b/library/cpp/messagebus/test/helper/hanging_server.h
new file mode 100644
index 0000000000..cc9fb274d8
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/hanging_server.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <library/cpp/messagebus/network.h>
+
+#include <util/network/sock.h>
+
+class THangingServer {
+private:
+ std::pair<unsigned, TVector<NBus::TBindResult>> BindResult;
+
+public:
+ // listen on given port, and nothing else
+ THangingServer(int port = 0);
+ // actual port
+ int GetPort() const;
+};
diff --git a/library/cpp/messagebus/test/helper/message_handler_error.cpp b/library/cpp/messagebus/test/helper/message_handler_error.cpp
new file mode 100644
index 0000000000..c09811ec67
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/message_handler_error.cpp
@@ -0,0 +1,26 @@
+#include "message_handler_error.h"
+
+#include <util/system/yassert.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+void TBusClientHandlerError::OnError(TAutoPtr<TBusMessage>, EMessageStatus status) {
+ Y_FAIL("must not be called, status: %s", ToString(status).data());
+}
+
+void TBusClientHandlerError::OnReply(TAutoPtr<TBusMessage>, TAutoPtr<TBusMessage>) {
+ Y_FAIL("must not be called");
+}
+
+void TBusClientHandlerError::OnMessageSentOneWay(TAutoPtr<TBusMessage>) {
+ Y_FAIL("must not be called");
+}
+
+void TBusServerHandlerError::OnError(TAutoPtr<TBusMessage>, EMessageStatus status) {
+ Y_FAIL("must not be called, status: %s", ToString(status).data());
+}
+
+void TBusServerHandlerError::OnMessage(TOnMessageContext&) {
+ Y_FAIL("must not be called");
+}
diff --git a/library/cpp/messagebus/test/helper/message_handler_error.h b/library/cpp/messagebus/test/helper/message_handler_error.h
new file mode 100644
index 0000000000..a314b10761
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/message_handler_error.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+
+namespace NBus {
+ namespace NTest {
+ struct TBusClientHandlerError: public IBusClientHandler {
+ void OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status) override;
+ void OnMessageSentOneWay(TAutoPtr<TBusMessage> pMessage) override;
+ void OnReply(TAutoPtr<TBusMessage> pMessage, TAutoPtr<TBusMessage> pReply) override;
+ };
+
+ struct TBusServerHandlerError: public IBusServerHandler {
+ void OnError(TAutoPtr<TBusMessage> pMessage, EMessageStatus status) override;
+ void OnMessage(TOnMessageContext& pMessage) override;
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/test/helper/object_count_check.h b/library/cpp/messagebus/test/helper/object_count_check.h
new file mode 100644
index 0000000000..1c4756e58c
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/object_count_check.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/remote_client_connection.h>
+#include <library/cpp/messagebus/remote_client_session.h>
+#include <library/cpp/messagebus/remote_server_connection.h>
+#include <library/cpp/messagebus/remote_server_session.h>
+#include <library/cpp/messagebus/ybus.h>
+#include <library/cpp/messagebus/oldmodule/module.h>
+#include <library/cpp/messagebus/scheduler/scheduler.h>
+
+#include <util/generic/object_counter.h>
+#include <util/system/type_name.h>
+#include <util/stream/output.h>
+
+#include <typeinfo>
+
+struct TObjectCountCheck {
+ bool Enabled;
+
+ template <typename T>
+ struct TReset {
+ TObjectCountCheck* const Thiz;
+
+ TReset(TObjectCountCheck* thiz)
+ : Thiz(thiz)
+ {
+ }
+
+ void operator()() {
+ long oldValue = TObjectCounter<T>::ResetObjectCount();
+ if (oldValue != 0) {
+ Cerr << "warning: previous counter: " << oldValue << " for " << TypeName<T>() << Endl;
+ Cerr << "won't check in this test" << Endl;
+ Thiz->Enabled = false;
+ }
+ }
+ };
+
+ TObjectCountCheck() {
+ Enabled = true;
+ DoForAllCounters<TReset>();
+ }
+
+ template <typename T>
+ struct TCheckZero {
+ TCheckZero(TObjectCountCheck*) {
+ }
+
+ void operator()() {
+ UNIT_ASSERT_VALUES_EQUAL_C(0L, TObjectCounter<T>::ObjectCount(), TypeName<T>());
+ }
+ };
+
+ ~TObjectCountCheck() {
+ if (Enabled) {
+ DoForAllCounters<TCheckZero>();
+ }
+ }
+
+ template <template <typename> class TOp>
+ void DoForAllCounters() {
+ TOp< ::NBus::NPrivate::TRemoteClientConnection>(this)();
+ TOp< ::NBus::NPrivate::TRemoteServerConnection>(this)();
+ TOp< ::NBus::NPrivate::TRemoteClientSession>(this)();
+ TOp< ::NBus::NPrivate::TRemoteServerSession>(this)();
+ TOp< ::NBus::NPrivate::TScheduler>(this)();
+ TOp< ::NEventLoop::TEventLoop>(this)();
+ TOp< ::NEventLoop::TChannel>(this)();
+ TOp< ::NBus::TBusModule>(this)();
+ TOp< ::NBus::TBusJob>(this)();
+ }
+};
diff --git a/library/cpp/messagebus/test/helper/wait_for.h b/library/cpp/messagebus/test/helper/wait_for.h
new file mode 100644
index 0000000000..f09958d4c0
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/wait_for.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <util/datetime/base.h>
+#include <util/system/yassert.h>
+
+#define UNIT_WAIT_FOR(condition) \
+ do { \
+ TInstant start(TInstant::Now()); \
+ while (!(condition) && (TInstant::Now() - start < TDuration::Seconds(10))) { \
+ Sleep(TDuration::MilliSeconds(1)); \
+ } \
+ /* TODO: use UNIT_ASSERT if in unittest thread */ \
+ Y_VERIFY(condition, "condition failed after 10 seconds wait"); \
+ } while (0)
diff --git a/library/cpp/messagebus/test/helper/ya.make b/library/cpp/messagebus/test/helper/ya.make
new file mode 100644
index 0000000000..97bd45f573
--- /dev/null
+++ b/library/cpp/messagebus/test/helper/ya.make
@@ -0,0 +1,17 @@
+LIBRARY(messagebus_test_helper)
+
+OWNER(g:messagebus)
+
+SRCS(
+ example.cpp
+ example_module.cpp
+ fixed_port.cpp
+ message_handler_error.cpp
+ hanging_server.cpp
+)
+
+PEERDIR(
+ library/cpp/messagebus/oldmodule
+)
+
+END()
diff --git a/library/cpp/messagebus/test/perftest/messages.proto b/library/cpp/messagebus/test/perftest/messages.proto
new file mode 100644
index 0000000000..8919034e7a
--- /dev/null
+++ b/library/cpp/messagebus/test/perftest/messages.proto
@@ -0,0 +1,7 @@
+message TPerftestRequestRecord {
+ required string Data = 1;
+}
+
+message TPerftestResponseRecord {
+ required string Data = 1;
+}
diff --git a/library/cpp/messagebus/test/perftest/perftest.cpp b/library/cpp/messagebus/test/perftest/perftest.cpp
new file mode 100644
index 0000000000..8489319278
--- /dev/null
+++ b/library/cpp/messagebus/test/perftest/perftest.cpp
@@ -0,0 +1,713 @@
+#include "simple_proto.h"
+
+#include <library/cpp/messagebus/test/perftest/messages.pb.h>
+
+#include <library/cpp/messagebus/text_utils.h>
+#include <library/cpp/messagebus/thread_extra.h>
+#include <library/cpp/messagebus/ybus.h>
+#include <library/cpp/messagebus/oldmodule/module.h>
+#include <library/cpp/messagebus/protobuf/ybusbuf.h>
+#include <library/cpp/messagebus/www/www.h>
+
+#include <library/cpp/deprecated/threadable/threadable.h>
+#include <library/cpp/execprofile/profile.h>
+#include <library/cpp/getopt/opt.h>
+#include <library/cpp/lwtrace/start.h>
+#include <library/cpp/sighandler/async_signals_handler.h>
+#include <library/cpp/threading/future/legacy_future.h>
+
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <util/generic/vector.h>
+#include <util/generic/yexception.h>
+#include <util/random/random.h>
+#include <util/stream/file.h>
+#include <util/stream/output.h>
+#include <util/stream/str.h>
+#include <util/string/split.h>
+#include <util/system/event.h>
+#include <util/system/sysstat.h>
+#include <util/system/thread.h>
+#include <util/thread/lfqueue.h>
+
+#include <signal.h>
+#include <stdlib.h>
+
+using namespace NBus;
+
+///////////////////////////////////////////////////////
+/// \brief Configuration parameters of the test
+
+const int DEFAULT_PORT = 55666;
+
+struct TPerftestConfig {
+ TString Nodes; ///< node1:port1,node2:port2
+ int ClientCount;
+ int MessageSize; ///< size of message to send
+ int Delay; ///< server delay (milliseconds)
+ float Failure; ///< simulated failure rate
+ int ServerPort;
+ int Run;
+ bool ServerUseModules;
+ bool ExecuteOnMessageInWorkerPool;
+ bool ExecuteOnReplyInWorkerPool;
+ bool UseCompression;
+ bool Profile;
+ unsigned WwwPort;
+
+ TPerftestConfig();
+
+ void Print() {
+ fprintf(stderr, "ClientCount=%d\n", ClientCount);
+ fprintf(stderr, "ServerPort=%d\n", ServerPort);
+ fprintf(stderr, "Delay=%d usecs\n", Delay);
+ fprintf(stderr, "MessageSize=%d bytes\n", MessageSize);
+ fprintf(stderr, "Failure=%.3f%%\n", Failure * 100.0);
+ fprintf(stderr, "Runtime=%d seconds\n", Run);
+ fprintf(stderr, "ServerUseModules=%s\n", ServerUseModules ? "true" : "false");
+ fprintf(stderr, "ExecuteOnMessageInWorkerPool=%s\n", ExecuteOnMessageInWorkerPool ? "true" : "false");
+ fprintf(stderr, "ExecuteOnReplyInWorkerPool=%s\n", ExecuteOnReplyInWorkerPool ? "true" : "false");
+ fprintf(stderr, "UseCompression=%s\n", UseCompression ? "true" : "false");
+ fprintf(stderr, "Profile=%s\n", Profile ? "true" : "false");
+ fprintf(stderr, "WwwPort=%u\n", WwwPort);
+ }
+};
+
+extern TPerftestConfig* TheConfig;
+extern bool TheExit;
+
+TVector<TNetAddr> ServerAddresses;
+
+struct TConfig {
+ TBusQueueConfig ServerQueueConfig;
+ TBusQueueConfig ClientQueueConfig;
+ TBusServerSessionConfig ServerSessionConfig;
+ TBusClientSessionConfig ClientSessionConfig;
+ bool SimpleProtocol;
+
+private:
+ void ConfigureDefaults(TBusQueueConfig& config) {
+ config.NumWorkers = 4;
+ }
+
+ void ConfigureDefaults(TBusSessionConfig& config) {
+ config.MaxInFlight = 10000;
+ config.SendTimeout = TDuration::Seconds(20).MilliSeconds();
+ config.TotalTimeout = TDuration::Seconds(60).MilliSeconds();
+ }
+
+public:
+ TConfig()
+ : SimpleProtocol(false)
+ {
+ ConfigureDefaults(ServerQueueConfig);
+ ConfigureDefaults(ClientQueueConfig);
+ ConfigureDefaults(ServerSessionConfig);
+ ConfigureDefaults(ClientSessionConfig);
+ }
+
+ void Print() {
+ // TODO: do not print server if only client and vice verse
+ Cerr << "server queue config:\n";
+ Cerr << IndentText(ServerQueueConfig.PrintToString());
+ Cerr << "server session config:" << Endl;
+ Cerr << IndentText(ServerSessionConfig.PrintToString());
+ Cerr << "client queue config:\n";
+ Cerr << IndentText(ClientQueueConfig.PrintToString());
+ Cerr << "client session config:" << Endl;
+ Cerr << IndentText(ClientSessionConfig.PrintToString());
+ Cerr << "simple protocol: " << SimpleProtocol << "\n";
+ }
+};
+
+TConfig Config;
+
+////////////////////////////////////////////////////////////////
+/// \brief Fast message
+
+using TPerftestRequest = TBusBufferMessage<TPerftestRequestRecord, 77>;
+using TPerftestResponse = TBusBufferMessage<TPerftestResponseRecord, 79>;
+
+static size_t RequestSize() {
+ return RandomNumber<size_t>(TheConfig->MessageSize * 2 + 1);
+}
+
+TAutoPtr<TBusMessage> NewRequest() {
+ if (Config.SimpleProtocol) {
+ TAutoPtr<TSimpleMessage> r(new TSimpleMessage);
+ r->SetCompressed(TheConfig->UseCompression);
+ r->Payload = 10;
+ return r.Release();
+ } else {
+ TAutoPtr<TPerftestRequest> r(new TPerftestRequest);
+ r->SetCompressed(TheConfig->UseCompression);
+ // TODO: use random content for better compression test
+ r->Record.SetData(TString(RequestSize(), '?'));
+ return r.Release();
+ }
+}
+
+void CheckRequest(TPerftestRequest* request) {
+ const TString& data = request->Record.GetData();
+ for (size_t i = 0; i != data.size(); ++i) {
+ Y_VERIFY(data.at(i) == '?', "must be question mark");
+ }
+}
+
+TAutoPtr<TPerftestResponse> NewResponse(TPerftestRequest* request) {
+ TAutoPtr<TPerftestResponse> r(new TPerftestResponse);
+ r->SetCompressed(TheConfig->UseCompression);
+ r->Record.SetData(TString(request->Record.GetData().size(), '.'));
+ return r;
+}
+
+void CheckResponse(TPerftestResponse* response) {
+ const TString& data = response->Record.GetData();
+ for (size_t i = 0; i != data.size(); ++i) {
+ Y_VERIFY(data.at(i) == '.', "must be dot");
+ }
+}
+
+////////////////////////////////////////////////////////////////////
+/// \brief Fast protocol that common between client and server
+class TPerftestProtocol: public TBusBufferProtocol {
+public:
+ TPerftestProtocol()
+ : TBusBufferProtocol("TPerftestProtocol", TheConfig->ServerPort)
+ {
+ RegisterType(new TPerftestRequest);
+ RegisterType(new TPerftestResponse);
+ }
+};
+
+class TPerftestServer;
+class TPerftestUsingModule;
+class TPerftestClient;
+
+struct TTestStats {
+ TInstant Start;
+
+ TAtomic Messages;
+ TAtomic Errors;
+ TAtomic Replies;
+
+ void IncMessage() {
+ AtomicIncrement(Messages);
+ }
+ void IncReplies() {
+ AtomicDecrement(Messages);
+ AtomicIncrement(Replies);
+ }
+ int NumMessage() {
+ return AtomicGet(Messages);
+ }
+ void IncErrors() {
+ AtomicDecrement(Messages);
+ AtomicIncrement(Errors);
+ }
+ int NumErrors() {
+ return AtomicGet(Errors);
+ }
+ int NumReplies() {
+ return AtomicGet(Replies);
+ }
+
+ double GetThroughput() {
+ return NumReplies() * 1000000.0 / (TInstant::Now() - Start).MicroSeconds();
+ }
+
+public:
+ TTestStats()
+ : Start(TInstant::Now())
+ , Messages(0)
+ , Errors(0)
+ , Replies(0)
+ {
+ }
+
+ void PeriodicallyPrint();
+};
+
+TTestStats Stats;
+
+////////////////////////////////////////////////////////////////////
+/// \brief Fast of the client session
+class TPerftestClient : IBusClientHandler {
+public:
+ TBusClientSessionPtr Session;
+ THolder<TBusProtocol> Proto;
+ TBusMessageQueuePtr Bus;
+ TVector<TBusClientConnectionPtr> Connections;
+
+public:
+ /// constructor creates instances of protocol and session
+ TPerftestClient() {
+ /// create or get instance of message queue, need one per application
+ Bus = CreateMessageQueue(Config.ClientQueueConfig, "client");
+
+ if (Config.SimpleProtocol) {
+ Proto.Reset(new TSimpleProtocol);
+ } else {
+ Proto.Reset(new TPerftestProtocol);
+ }
+
+ Session = TBusClientSession::Create(Proto.Get(), this, Config.ClientSessionConfig, Bus);
+
+ for (unsigned i = 0; i < ServerAddresses.size(); ++i) {
+ Connections.push_back(Session->GetConnection(ServerAddresses[i]));
+ }
+ }
+
+ /// dispatch of requests is done here
+ void Work() {
+ SetCurrentThreadName("FastClient::Work");
+
+ while (!TheExit) {
+ TBusClientConnection* connection;
+ if (Connections.size() == 1) {
+ connection = Connections.front().Get();
+ } else {
+ connection = Connections.at(RandomNumber<size_t>()).Get();
+ }
+
+ TBusMessage* message = NewRequest().Release();
+ int ret = connection->SendMessage(message, true);
+
+ if (ret == MESSAGE_OK) {
+ Stats.IncMessage();
+ } else if (ret == MESSAGE_BUSY) {
+ //delete message;
+ //Sleep(TDuration::MilliSeconds(1));
+ //continue;
+ Y_FAIL("unreachable");
+ } else if (ret == MESSAGE_SHUTDOWN) {
+ delete message;
+ } else {
+ delete message;
+ Stats.IncErrors();
+ }
+ }
+ }
+
+ void Stop() {
+ Session->Shutdown();
+ }
+
+ /// actual work is being done here
+ void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override {
+ Y_UNUSED(mess);
+
+ if (Config.SimpleProtocol) {
+ VerifyDynamicCast<TSimpleMessage*>(reply.Get());
+ } else {
+ TPerftestResponse* typed = VerifyDynamicCast<TPerftestResponse*>(reply.Get());
+
+ CheckResponse(typed);
+ }
+
+ Stats.IncReplies();
+ }
+
+ /// message that could not be delivered
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override {
+ Y_UNUSED(mess);
+ Y_UNUSED(status);
+
+ if (TheExit) {
+ return;
+ }
+
+ Stats.IncErrors();
+
+ // Y_ASSERT(TheConfig->Failure > 0.0);
+ }
+};
+
+class TPerftestServerCommon {
+public:
+ THolder<TBusProtocol> Proto;
+
+ TBusMessageQueuePtr Bus;
+
+ TBusServerSessionPtr Session;
+
+protected:
+ TPerftestServerCommon(const char* name)
+ : Session()
+ {
+ if (Config.SimpleProtocol) {
+ Proto.Reset(new TSimpleProtocol);
+ } else {
+ Proto.Reset(new TPerftestProtocol);
+ }
+
+ /// create or get instance of single message queue, need one for application
+ Bus = CreateMessageQueue(Config.ServerQueueConfig, name);
+ }
+
+public:
+ void Stop() {
+ Session->Shutdown();
+ }
+};
+
+struct TAsyncRequest {
+ TBusMessage* Request;
+ TInstant ReceivedTime;
+};
+
+/////////////////////////////////////////////////////////////////////
+/// \brief Fast of the server session
+class TPerftestServer: public TPerftestServerCommon, public IBusServerHandler {
+public:
+ TLockFreeQueue<TAsyncRequest> AsyncRequests;
+
+public:
+ TPerftestServer()
+ : TPerftestServerCommon("server")
+ {
+ /// register destination session
+ Session = TBusServerSession::Create(Proto.Get(), this, Config.ServerSessionConfig, Bus);
+ Y_ASSERT(Session && "probably somebody is listening on the same port");
+ }
+
+ /// when message comes, send reply
+ void OnMessage(TOnMessageContext& mess) override {
+ if (Config.SimpleProtocol) {
+ TSimpleMessage* typed = VerifyDynamicCast<TSimpleMessage*>(mess.GetMessage());
+ TAutoPtr<TSimpleMessage> response(new TSimpleMessage);
+ response->Payload = typed->Payload;
+ mess.SendReplyMove(response);
+ return;
+ }
+
+ TPerftestRequest* typed = VerifyDynamicCast<TPerftestRequest*>(mess.GetMessage());
+
+ CheckRequest(typed);
+
+ /// forget replies for few messages, see what happends
+ if (TheConfig->Failure > RandomNumber<double>()) {
+ return;
+ }
+
+ /// sleep requested time
+ if (TheConfig->Delay) {
+ TAsyncRequest request;
+ request.Request = mess.ReleaseMessage();
+ request.ReceivedTime = TInstant::Now();
+ AsyncRequests.Enqueue(request);
+ return;
+ }
+
+ TAutoPtr<TPerftestResponse> reply(NewResponse(typed));
+ /// sent empty reply for each message
+ mess.SendReplyMove(reply);
+ // TODO: count results
+ }
+
+ void Stop() {
+ TPerftestServerCommon::Stop();
+ }
+};
+
+class TPerftestUsingModule: public TPerftestServerCommon, public TBusModule {
+public:
+ TPerftestUsingModule()
+ : TPerftestServerCommon("server")
+ , TBusModule("fast")
+ {
+ Y_VERIFY(CreatePrivateSessions(Bus.Get()), "failed to initialize dupdetect module");
+ Y_VERIFY(StartInput(), "failed to start input");
+ }
+
+ ~TPerftestUsingModule() override {
+ Shutdown();
+ }
+
+private:
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ TPerftestRequest* typed = VerifyDynamicCast<TPerftestRequest*>(mess);
+ CheckRequest(typed);
+
+ /// sleep requested time
+ if (TheConfig->Delay) {
+ usleep(TheConfig->Delay);
+ }
+
+ /// forget replies for few messages, see what happends
+ if (TheConfig->Failure > RandomNumber<double>()) {
+ return nullptr;
+ }
+
+ job->SendReply(NewResponse(typed).Release());
+ return nullptr;
+ }
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override {
+ return Session = CreateDefaultDestination(queue, Proto.Get(), Config.ServerSessionConfig);
+ }
+};
+
+// ./perftest/perftest -s 11456 -c localhost:11456 -r 60 -n 4 -i 5000
+
+using namespace std;
+using namespace NBus;
+
+static TNetworkAddress ParseNetworkAddress(const char* string) {
+ TString Name;
+ int Port;
+
+ const char* port = strchr(string, ':');
+
+ if (port != nullptr) {
+ Name.append(string, port - string);
+ Port = atoi(port + 1);
+ } else {
+ Name.append(string);
+ Port = TheConfig->ServerPort != 0 ? TheConfig->ServerPort : DEFAULT_PORT;
+ }
+
+ return TNetworkAddress(Name, Port);
+}
+
+TVector<TNetAddr> ParseNodes(const TString nodes) {
+ TVector<TNetAddr> r;
+
+ TVector<TString> hosts;
+
+ size_t numh = Split(nodes.data(), ",", hosts);
+
+ for (int i = 0; i < int(numh); i++) {
+ const TNetworkAddress& networkAddress = ParseNetworkAddress(hosts[i].data());
+ Y_VERIFY(networkAddress.Begin() != networkAddress.End(), "no addresses");
+ r.push_back(TNetAddr(networkAddress, &*networkAddress.Begin()));
+ }
+
+ return r;
+}
+
+TPerftestConfig::TPerftestConfig() {
+ TBusSessionConfig defaultConfig;
+
+ ServerPort = DEFAULT_PORT;
+ Delay = 0; // artificial delay inside server OnMessage()
+ MessageSize = 200;
+ Failure = 0.00;
+ Run = 60; // in seconds
+ Nodes = "localhost";
+ ServerUseModules = false;
+ ExecuteOnMessageInWorkerPool = defaultConfig.ExecuteOnMessageInWorkerPool;
+ ExecuteOnReplyInWorkerPool = defaultConfig.ExecuteOnReplyInWorkerPool;
+ UseCompression = false;
+ Profile = false;
+ WwwPort = 0;
+}
+
+TPerftestConfig* TheConfig = new TPerftestConfig();
+bool TheExit = false;
+
+TSystemEvent StopEvent;
+
+TSimpleSharedPtr<TPerftestServer> Server;
+TSimpleSharedPtr<TPerftestUsingModule> ServerUsingModule;
+
+TVector<TSimpleSharedPtr<TPerftestClient>> Clients;
+TMutex ClientsLock;
+
+void stopsignal(int /*sig*/) {
+ fprintf(stderr, "\n-------------------- exiting ------------------\n");
+ TheExit = true;
+ StopEvent.Signal();
+}
+
+// -s <num> - start server on port <num>
+// -c <node:port,node:port> - start client
+
+void TTestStats::PeriodicallyPrint() {
+ SetCurrentThreadName("print-stats");
+
+ for (;;) {
+ StopEvent.WaitT(TDuration::Seconds(1));
+ if (TheExit)
+ break;
+
+ TVector<TSimpleSharedPtr<TPerftestClient>> clients;
+ {
+ TGuard<TMutex> guard(ClientsLock);
+ clients = Clients;
+ }
+
+ fprintf(stderr, "replies=%d errors=%d throughput=%.3f mess/sec\n",
+ NumReplies(), NumErrors(), GetThroughput());
+ if (!!Server) {
+ fprintf(stderr, "server: q: %u %s\n",
+ (unsigned)Server->Bus->GetExecutor()->GetWorkQueueSize(),
+ Server->Session->GetStatusSingleLine().data());
+ }
+ if (!!ServerUsingModule) {
+ fprintf(stderr, "server: q: %u %s\n",
+ (unsigned)ServerUsingModule->Bus->GetExecutor()->GetWorkQueueSize(),
+ ServerUsingModule->Session->GetStatusSingleLine().data());
+ }
+ for (const auto& client : clients) {
+ fprintf(stderr, "client: q: %u %s\n",
+ (unsigned)client->Bus->GetExecutor()->GetWorkQueueSize(),
+ client->Session->GetStatusSingleLine().data());
+ }
+
+ TStringStream stats;
+
+ bool first = true;
+ if (!!Server) {
+ if (!first) {
+ stats << "\n";
+ }
+ first = false;
+ stats << "server:\n";
+ stats << IndentText(Server->Bus->GetStatus());
+ }
+ if (!!ServerUsingModule) {
+ if (!first) {
+ stats << "\n";
+ }
+ first = false;
+ stats << "server using modules:\n";
+ stats << IndentText(ServerUsingModule->Bus->GetStatus());
+ }
+ for (const auto& client : clients) {
+ if (!first) {
+ stats << "\n";
+ }
+ first = false;
+ stats << "client:\n";
+ stats << IndentText(client->Bus->GetStatus());
+ }
+
+ TUnbufferedFileOutput("stats").Write(stats.Str());
+ }
+}
+
+int main(int argc, char* argv[]) {
+ NLWTrace::StartLwtraceFromEnv();
+
+ /* unix foo */
+ setvbuf(stdout, nullptr, _IONBF, 0);
+ setvbuf(stderr, nullptr, _IONBF, 0);
+ Umask(0);
+ SetAsyncSignalHandler(SIGINT, stopsignal);
+ SetAsyncSignalHandler(SIGTERM, stopsignal);
+#ifndef _win_
+ SetAsyncSignalHandler(SIGUSR1, stopsignal);
+#endif
+ signal(SIGPIPE, SIG_IGN);
+
+ NLastGetopt::TOpts opts = NLastGetopt::TOpts::Default();
+ opts.AddLongOption('s', "server-port", "server port").RequiredArgument("port").StoreResult(&TheConfig->ServerPort);
+ opts.AddCharOption('m', "average message size").RequiredArgument("size").StoreResult(&TheConfig->MessageSize);
+ opts.AddLongOption('c', "server-host", "server hosts").RequiredArgument("host[,host]...").StoreResult(&TheConfig->Nodes);
+ opts.AddCharOption('f', "failure rate (rational number between 0 and 1)").RequiredArgument("rate").StoreResult(&TheConfig->Failure);
+ opts.AddCharOption('w', "delay before reply").RequiredArgument("microseconds").StoreResult(&TheConfig->Delay);
+ opts.AddCharOption('r', "run duration").RequiredArgument("seconds").StoreResult(&TheConfig->Run);
+ opts.AddLongOption("client-count", "amount of clients").RequiredArgument("count").StoreResult(&TheConfig->ClientCount).DefaultValue("1");
+ opts.AddLongOption("server-use-modules").StoreResult(&TheConfig->ServerUseModules, true);
+ opts.AddLongOption("on-message-in-pool", "execute OnMessage callback in worker pool")
+ .RequiredArgument("BOOL")
+ .StoreResult(&TheConfig->ExecuteOnMessageInWorkerPool);
+ opts.AddLongOption("on-reply-in-pool", "execute OnReply callback in worker pool")
+ .RequiredArgument("BOOL")
+ .StoreResult(&TheConfig->ExecuteOnReplyInWorkerPool);
+ opts.AddLongOption("compression", "use compression").RequiredArgument("BOOL").StoreResult(&TheConfig->UseCompression);
+ opts.AddLongOption("simple-proto").SetFlag(&Config.SimpleProtocol);
+ opts.AddLongOption("profile").SetFlag(&TheConfig->Profile);
+ opts.AddLongOption("www-port").RequiredArgument("PORT").StoreResult(&TheConfig->WwwPort);
+ opts.AddHelpOption();
+
+ Config.ServerQueueConfig.ConfigureLastGetopt(opts, "server-");
+ Config.ServerSessionConfig.ConfigureLastGetopt(opts, "server-");
+ Config.ClientQueueConfig.ConfigureLastGetopt(opts, "client-");
+ Config.ClientSessionConfig.ConfigureLastGetopt(opts, "client-");
+
+ opts.SetFreeArgsMax(0);
+
+ NLastGetopt::TOptsParseResult parseResult(&opts, argc, argv);
+
+ TheConfig->Print();
+ Config.Print();
+
+ if (TheConfig->Profile) {
+ BeginProfiling();
+ }
+
+ TIntrusivePtr<TBusWww> www(new TBusWww);
+
+ ServerAddresses = ParseNodes(TheConfig->Nodes);
+
+ if (TheConfig->ServerPort) {
+ if (TheConfig->ServerUseModules) {
+ ServerUsingModule = new TPerftestUsingModule();
+ www->RegisterModule(ServerUsingModule.Get());
+ } else {
+ Server = new TPerftestServer();
+ www->RegisterServerSession(Server->Session);
+ }
+ }
+
+ TVector<TSimpleSharedPtr<NThreading::TLegacyFuture<void, false>>> futures;
+
+ if (ServerAddresses.size() > 0 && TheConfig->ClientCount > 0) {
+ for (int i = 0; i < TheConfig->ClientCount; ++i) {
+ TGuard<TMutex> guard(ClientsLock);
+ Clients.push_back(new TPerftestClient);
+ futures.push_back(new NThreading::TLegacyFuture<void, false>(std::bind(&TPerftestClient::Work, Clients.back())));
+ www->RegisterClientSession(Clients.back()->Session);
+ }
+ }
+
+ futures.push_back(new NThreading::TLegacyFuture<void, false>(std::bind(&TTestStats::PeriodicallyPrint, std::ref(Stats))));
+
+ THolder<TBusWwwHttpServer> wwwServer;
+ if (TheConfig->WwwPort != 0) {
+ wwwServer.Reset(new TBusWwwHttpServer(www, TheConfig->WwwPort));
+ }
+
+ /* sit here until signal terminate our process */
+ StopEvent.WaitT(TDuration::Seconds(TheConfig->Run));
+ TheExit = true;
+ StopEvent.Signal();
+
+ if (!!Server) {
+ Cerr << "Stopping server\n";
+ Server->Stop();
+ }
+ if (!!ServerUsingModule) {
+ Cerr << "Stopping server (using modules)\n";
+ ServerUsingModule->Stop();
+ }
+
+ TVector<TSimpleSharedPtr<TPerftestClient>> clients;
+ {
+ TGuard<TMutex> guard(ClientsLock);
+ clients = Clients;
+ }
+
+ if (!clients.empty()) {
+ Cerr << "Stopping clients\n";
+
+ for (auto& client : clients) {
+ client->Stop();
+ }
+ }
+
+ wwwServer.Destroy();
+
+ for (const auto& future : futures) {
+ future->Get();
+ }
+
+ if (TheConfig->Profile) {
+ EndProfiling();
+ }
+
+ Cerr << "***SUCCESS***\n";
+ return 0;
+}
diff --git a/library/cpp/messagebus/test/perftest/simple_proto.cpp b/library/cpp/messagebus/test/perftest/simple_proto.cpp
new file mode 100644
index 0000000000..19d6c15b9d
--- /dev/null
+++ b/library/cpp/messagebus/test/perftest/simple_proto.cpp
@@ -0,0 +1,22 @@
+#include "simple_proto.h"
+
+#include <util/generic/cast.h>
+
+#include <typeinfo>
+
+using namespace NBus;
+
+void TSimpleProtocol::Serialize(const TBusMessage* mess, TBuffer& data) {
+ Y_VERIFY(typeid(TSimpleMessage) == typeid(*mess));
+ const TSimpleMessage* typed = static_cast<const TSimpleMessage*>(mess);
+ data.Append((const char*)&typed->Payload, 4);
+}
+
+TAutoPtr<TBusMessage> TSimpleProtocol::Deserialize(ui16, TArrayRef<const char> payload) {
+ if (payload.size() != 4) {
+ return nullptr;
+ }
+ TAutoPtr<TSimpleMessage> r(new TSimpleMessage);
+ memcpy(&r->Payload, payload.data(), 4);
+ return r.Release();
+}
diff --git a/library/cpp/messagebus/test/perftest/simple_proto.h b/library/cpp/messagebus/test/perftest/simple_proto.h
new file mode 100644
index 0000000000..4a0cc08db3
--- /dev/null
+++ b/library/cpp/messagebus/test/perftest/simple_proto.h
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+
+struct TSimpleMessage: public NBus::TBusMessage {
+ ui32 Payload;
+
+ TSimpleMessage()
+ : TBusMessage(1)
+ , Payload(0)
+ {
+ }
+
+ TSimpleMessage(NBus::ECreateUninitialized)
+ : TBusMessage(NBus::ECreateUninitialized())
+ {
+ }
+};
+
+struct TSimpleProtocol: public NBus::TBusProtocol {
+ TSimpleProtocol()
+ : NBus::TBusProtocol("simple", 55666)
+ {
+ }
+
+ void Serialize(const NBus::TBusMessage* mess, TBuffer& data) override;
+
+ TAutoPtr<NBus::TBusMessage> Deserialize(ui16 ty, TArrayRef<const char> payload) override;
+};
diff --git a/library/cpp/messagebus/test/perftest/stackcollect.diff b/library/cpp/messagebus/test/perftest/stackcollect.diff
new file mode 100644
index 0000000000..658f0141b3
--- /dev/null
+++ b/library/cpp/messagebus/test/perftest/stackcollect.diff
@@ -0,0 +1,13 @@
+Index: test/perftest/CMakeLists.txt
+===================================================================
+--- test/perftest/CMakeLists.txt (revision 1088840)
++++ test/perftest/CMakeLists.txt (working copy)
+@@ -3,7 +3,7 @@ PROGRAM(messagebus_perftest)
+ OWNER(nga)
+
+ PEERDIR(
+- library/cpp/execprofile
++ junk/davenger/stackcollect
+ library/cpp/messagebus
+ library/cpp/messagebus/protobuf
+ library/cpp/sighandler
diff --git a/library/cpp/messagebus/test/perftest/ya.make b/library/cpp/messagebus/test/perftest/ya.make
new file mode 100644
index 0000000000..24c2848ed5
--- /dev/null
+++ b/library/cpp/messagebus/test/perftest/ya.make
@@ -0,0 +1,24 @@
+PROGRAM(messagebus_perftest)
+
+OWNER(g:messagebus)
+
+PEERDIR(
+ library/cpp/deprecated/threadable
+ library/cpp/execprofile
+ library/cpp/getopt
+ library/cpp/lwtrace
+ library/cpp/messagebus
+ library/cpp/messagebus/oldmodule
+ library/cpp/messagebus/protobuf
+ library/cpp/messagebus/www
+ library/cpp/sighandler
+ library/cpp/threading/future
+)
+
+SRCS(
+ messages.proto
+ perftest.cpp
+ simple_proto.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/test/ut/count_down_latch.h b/library/cpp/messagebus/test/ut/count_down_latch.h
new file mode 100644
index 0000000000..5117db5731
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/count_down_latch.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <util/system/atomic.h>
+#include <util/system/event.h>
+
+class TCountDownLatch {
+private:
+ TAtomic Current;
+ TSystemEvent EventObject;
+
+public:
+ TCountDownLatch(unsigned initial)
+ : Current(initial)
+ {
+ }
+
+ void CountDown() {
+ if (AtomicDecrement(Current) == 0) {
+ EventObject.Signal();
+ }
+ }
+
+ void Await() {
+ EventObject.Wait();
+ }
+
+ bool Await(TDuration timeout) {
+ return EventObject.WaitT(timeout);
+ }
+};
diff --git a/library/cpp/messagebus/test/ut/locator_uniq_ut.cpp b/library/cpp/messagebus/test/ut/locator_uniq_ut.cpp
new file mode 100644
index 0000000000..3fdd175d73
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/locator_uniq_ut.cpp
@@ -0,0 +1,40 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/test_utils.h>
+#include <library/cpp/messagebus/ybus.h>
+
+class TLocatorRegisterUniqTest: public TTestBase {
+ UNIT_TEST_SUITE(TLocatorRegisterUniqTest);
+ UNIT_TEST(TestRegister);
+ UNIT_TEST_SUITE_END();
+
+protected:
+ void TestRegister();
+};
+
+UNIT_TEST_SUITE_REGISTRATION(TLocatorRegisterUniqTest);
+
+void TLocatorRegisterUniqTest::TestRegister() {
+ ASSUME_IP_V4_ENABLED;
+
+ NBus::TBusLocator locator;
+ const char* serviceName = "TestService";
+ const char* hostName = "192.168.0.42";
+ int port = 31337;
+
+ NBus::TBusKeyVec keys;
+ locator.LocateKeys(serviceName, keys);
+ UNIT_ASSERT(keys.size() == 0);
+
+ locator.Register(serviceName, hostName, port);
+ locator.LocateKeys(serviceName, keys);
+ /// YBUS_KEYMIN YBUS_KEYMAX range
+ UNIT_ASSERT(keys.size() == 1);
+
+ TVector<NBus::TNetAddr> hosts;
+ UNIT_ASSERT(locator.LocateAll(serviceName, NBus::YBUS_KEYMIN, hosts) == 1);
+
+ locator.Register(serviceName, hostName, port);
+ hosts.clear();
+ UNIT_ASSERT(locator.LocateAll(serviceName, NBus::YBUS_KEYMIN, hosts) == 1);
+}
diff --git a/library/cpp/messagebus/test/ut/messagebus_ut.cpp b/library/cpp/messagebus/test/ut/messagebus_ut.cpp
new file mode 100644
index 0000000000..040f9b7702
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/messagebus_ut.cpp
@@ -0,0 +1,1151 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/fixed_port.h>
+#include <library/cpp/messagebus/test/helper/hanging_server.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+#include <library/cpp/messagebus/test/helper/wait_for.h>
+
+#include <library/cpp/messagebus/misc/test_sync.h>
+
+#include <util/network/sock.h>
+
+#include <utility>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+namespace {
+ struct TExampleClientSlowOnMessageSent: public TExampleClient {
+ TAtomic SentCompleted;
+
+ TSystemEvent ReplyReceived;
+
+ TExampleClientSlowOnMessageSent()
+ : SentCompleted(0)
+ {
+ }
+
+ ~TExampleClientSlowOnMessageSent() override {
+ Session->Shutdown();
+ }
+
+ void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override {
+ Y_VERIFY(AtomicGet(SentCompleted), "must be completed");
+
+ TExampleClient::OnReply(mess, reply);
+
+ ReplyReceived.Signal();
+ }
+
+ void OnMessageSent(TBusMessage*) override {
+ Sleep(TDuration::MilliSeconds(100));
+ AtomicSet(SentCompleted, 1);
+ }
+ };
+
+}
+
+Y_UNIT_TEST_SUITE(TMessageBusTests) {
+ void TestDestinationTemplate(bool useCompression, bool ackMessageBeforeReply,
+ const TBusServerSessionConfig& sessionConfig) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ TExampleClient client(sessionConfig);
+ client.CrashOnError = true;
+
+ server.UseCompression = useCompression;
+ client.UseCompression = useCompression;
+
+ server.AckMessageBeforeSendReply = ackMessageBeforeReply;
+
+ client.SendMessagesWaitReplies(100, server.GetActualListenAddr());
+ UNIT_ASSERT_EQUAL(server.Session->GetInFlight(), 0);
+ UNIT_ASSERT_EQUAL(client.Session->GetInFlight(), 0);
+ }
+
+ Y_UNIT_TEST(TestDestination) {
+ TestDestinationTemplate(false, false, TBusServerSessionConfig());
+ }
+
+ Y_UNIT_TEST(TestDestinationUsingAck) {
+ TestDestinationTemplate(false, true, TBusServerSessionConfig());
+ }
+
+ Y_UNIT_TEST(TestDestinationWithCompression) {
+ TestDestinationTemplate(true, false, TBusServerSessionConfig());
+ }
+
+ Y_UNIT_TEST(TestCork) {
+ TBusServerSessionConfig config;
+ config.SendThreshold = 1000000000000;
+ config.Cork = TDuration::MilliSeconds(10);
+ TestDestinationTemplate(false, false, config);
+ // TODO: test for cork hanging
+ }
+
+ Y_UNIT_TEST(TestReconnect) {
+ if (!IsFixedPortTestAllowed()) {
+ return;
+ }
+
+ TObjectCountCheck objectCountCheck;
+
+ unsigned port = FixedPort;
+ TNetAddr serverAddr("localhost", port);
+ THolder<TExampleServer> server;
+
+ TBusClientSessionConfig clientConfig;
+ clientConfig.RetryInterval = 0;
+ TExampleClient client(clientConfig);
+
+ server.Reset(new TExampleServer(port, "TExampleServer 1"));
+
+ client.SendMessagesWaitReplies(17, serverAddr);
+
+ server.Destroy();
+
+ // Making the client to detect disconnection.
+ client.SendMessages(1, serverAddr);
+ EMessageStatus error = client.WaitForError();
+ if (error == MESSAGE_DELIVERY_FAILED) {
+ client.SendMessages(1, serverAddr);
+ error = client.WaitForError();
+ }
+ UNIT_ASSERT_VALUES_EQUAL(MESSAGE_CONNECT_FAILED, error);
+
+ server.Reset(new TExampleServer(port, "TExampleServer 2"));
+
+ client.SendMessagesWaitReplies(19, serverAddr);
+ }
+
+ struct TestNoServerImplClient: public TExampleClient {
+ TTestSync TestSync;
+ int failures = 0;
+
+ template <typename... Args>
+ TestNoServerImplClient(Args&&... args)
+ : TExampleClient(std::forward<Args>(args)...)
+ {
+ }
+
+ ~TestNoServerImplClient() override {
+ Session->Shutdown();
+ }
+
+ void OnError(TAutoPtr<TBusMessage> message, EMessageStatus status) override {
+ Y_UNUSED(message);
+
+ Y_VERIFY(status == MESSAGE_CONNECT_FAILED, "must be MESSAGE_CONNECT_FAILED, got %s", ToString(status).data());
+
+ TestSync.CheckAndIncrement((failures++) * 2);
+ }
+ };
+
+ void TestNoServerImpl(unsigned port, bool oneWay) {
+ TNetAddr noServerAddr("localhost", port);
+
+ TestNoServerImplClient client;
+
+ int count = 0;
+ for (; count < 200; ++count) {
+ EMessageStatus status;
+ if (oneWay) {
+ status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &noServerAddr);
+ } else {
+ TAutoPtr<TBusMessage> message(new TExampleRequest(&client.Proto.RequestCount));
+ status = client.Session->SendMessageAutoPtr(message, &noServerAddr);
+ }
+
+ Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data());
+
+ if (count == 0) {
+ // lame way to wait until it is connected
+ Sleep(TDuration::MilliSeconds(10));
+ }
+ client.TestSync.WaitForAndIncrement(count * 2 + 1);
+ }
+
+ client.TestSync.WaitForAndIncrement(count * 2);
+ }
+
+ void HangingServerImpl(unsigned port) {
+ TNetAddr noServerAddr("localhost", port);
+
+ TExampleClient client;
+
+ int count = 0;
+ for (;; ++count) {
+ TAutoPtr<TBusMessage> message(new TExampleRequest(&client.Proto.RequestCount));
+ EMessageStatus status = client.Session->SendMessageAutoPtr(message, &noServerAddr);
+ if (status == MESSAGE_BUSY) {
+ break;
+ }
+ UNIT_ASSERT_VALUES_EQUAL(int(MESSAGE_OK), int(status));
+
+ if (count == 0) {
+ // lame way to wait until it is connected
+ Sleep(TDuration::MilliSeconds(10));
+ }
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(client.Session->GetConfig()->MaxInFlight, count);
+ }
+
+ Y_UNIT_TEST(TestHangindServer) {
+ TObjectCountCheck objectCountCheck;
+
+ THangingServer server(0);
+
+ HangingServerImpl(server.GetPort());
+ }
+
+ Y_UNIT_TEST(TestNoServer) {
+ TObjectCountCheck objectCountCheck;
+
+ TestNoServerImpl(17, false);
+ }
+
+ Y_UNIT_TEST(PauseInput) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+ server.Session->PauseInput(true);
+
+ TBusClientSessionConfig clientConfig;
+ clientConfig.MaxInFlight = 1000;
+ TExampleClient client(clientConfig);
+
+ client.SendMessages(100, server.GetActualListenAddr());
+
+ server.TestSync.Check(0);
+
+ server.Session->PauseInput(false);
+
+ server.TestSync.WaitFor(100);
+
+ client.WaitReplies();
+
+ server.Session->PauseInput(true);
+
+ client.SendMessages(200, server.GetActualListenAddr());
+
+ server.TestSync.Check(100);
+
+ server.Session->PauseInput(false);
+
+ server.TestSync.WaitFor(300);
+
+ client.WaitReplies();
+ }
+
+ struct TSendTimeoutCheckerExampleClient: public TExampleClient {
+ static TBusClientSessionConfig SessionConfig(bool periodLessThanConnectTimeout) {
+ TBusClientSessionConfig sessionConfig;
+ if (periodLessThanConnectTimeout) {
+ sessionConfig.SendTimeout = 1;
+ sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(50);
+ } else {
+ sessionConfig.SendTimeout = 50;
+ sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(1);
+ }
+ return sessionConfig;
+ }
+
+ TSendTimeoutCheckerExampleClient(bool periodLessThanConnectTimeout)
+ : TExampleClient(SessionConfig(periodLessThanConnectTimeout))
+ {
+ }
+
+ ~TSendTimeoutCheckerExampleClient() override {
+ Session->Shutdown();
+ }
+
+ TSystemEvent ErrorHappened;
+
+ void OnError(TAutoPtr<TBusMessage>, EMessageStatus status) override {
+ Y_VERIFY(status == MESSAGE_CONNECT_FAILED || status == MESSAGE_TIMEOUT, "got status: %s", ToString(status).data());
+ ErrorHappened.Signal();
+ }
+ };
+
+ void NoServer_SendTimeout_Callback_Impl(bool periodLessThanConnectTimeout) {
+ TObjectCountCheck objectCountCheck;
+
+ TNetAddr serverAddr("localhost", 17);
+
+ TSendTimeoutCheckerExampleClient client(periodLessThanConnectTimeout);
+
+ client.SendMessages(1, serverAddr);
+
+ client.ErrorHappened.WaitI();
+ }
+
+ Y_UNIT_TEST(NoServer_SendTimeout_Callback_PeriodLess) {
+ NoServer_SendTimeout_Callback_Impl(true);
+ }
+
+ Y_UNIT_TEST(NoServer_SendTimeout_Callback_TimeoutLess) {
+ NoServer_SendTimeout_Callback_Impl(false);
+ }
+
+ Y_UNIT_TEST(TestOnReplyCalledAfterOnMessageSent) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+ TNetAddr serverAddr = server.GetActualListenAddr();
+ TExampleClientSlowOnMessageSent client;
+
+ TAutoPtr<TExampleRequest> message(new TExampleRequest(&client.Proto.RequestCount));
+ EMessageStatus s = client.Session->SendMessageAutoPtr(message, &serverAddr);
+ UNIT_ASSERT_EQUAL(s, MESSAGE_OK);
+
+ UNIT_ASSERT(client.ReplyReceived.WaitT(TDuration::Seconds(5)));
+ }
+
+ struct TDelayReplyServer: public TBusServerHandlerError {
+ TBusMessageQueuePtr Bus;
+ TExampleProtocol Proto;
+ TSystemEvent MessageReceivedEvent; // 1 wait for 1 message
+ TBusServerSessionPtr Session;
+ TMutex Lock_;
+ TDeque<TAutoPtr<TOnMessageContext>> DelayedMessages;
+
+ TDelayReplyServer()
+ : MessageReceivedEvent(TEventResetType::rAuto)
+ {
+ Bus = CreateMessageQueue("TDelayReplyServer");
+ TBusServerSessionConfig sessionConfig;
+ sessionConfig.SendTimeout = 1000;
+ sessionConfig.TotalTimeout = 2001;
+ Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus);
+ if (!Session) {
+ ythrow yexception() << "Failed to create destination session";
+ }
+ }
+
+ void OnMessage(TOnMessageContext& mess) override {
+ Y_VERIFY(mess.IsConnectionAlive(), "connection should be alive here");
+ TAutoPtr<TOnMessageContext> delayedMsg(new TOnMessageContext);
+ delayedMsg->Swap(mess);
+ auto g(Guard(Lock_));
+ DelayedMessages.push_back(delayedMsg);
+ MessageReceivedEvent.Signal();
+ }
+
+ bool CheckClientIsAlive() {
+ auto g(Guard(Lock_));
+ for (auto& delayedMessage : DelayedMessages) {
+ if (!delayedMessage->IsConnectionAlive()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ bool CheckClientIsDead() const {
+ auto g(Guard(Lock_));
+ for (const auto& delayedMessage : DelayedMessages) {
+ if (delayedMessage->IsConnectionAlive()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ void ReplyToDelayedMessages() {
+ while (true) {
+ TOnMessageContext msg;
+ {
+ auto g(Guard(Lock_));
+ if (DelayedMessages.empty()) {
+ break;
+ }
+ DelayedMessages.front()->Swap(msg);
+ DelayedMessages.pop_front();
+ }
+ TAutoPtr<TBusMessage> reply(new TExampleResponse(&Proto.ResponseCount));
+ msg.SendReplyMove(reply);
+ }
+ }
+
+ size_t GetDelayedMessageCount() const {
+ auto g(Guard(Lock_));
+ return DelayedMessages.size();
+ }
+
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override {
+ Y_UNUSED(mess);
+ Y_VERIFY(status == MESSAGE_SHUTDOWN, "only shutdown allowed, got %s", ToString(status).data());
+ }
+ };
+
+ Y_UNIT_TEST(TestReplyCalledAfterClientDisconnected) {
+ TObjectCountCheck objectCountCheck;
+
+ TDelayReplyServer server;
+
+ THolder<TExampleClient> client(new TExampleClient);
+
+ client->SendMessages(1, TNetAddr("localhost", server.Session->GetActualListenPort()));
+
+ UNIT_ASSERT(server.MessageReceivedEvent.WaitT(TDuration::Seconds(5)));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, server.Session->GetInFlight());
+
+ client.Destroy();
+
+ UNIT_WAIT_FOR(server.CheckClientIsDead());
+
+ server.ReplyToDelayedMessages();
+
+ // wait until all server message are delivered
+ UNIT_WAIT_FOR(0 == server.Session->GetInFlight());
+ }
+
+ struct TPackUnpackServer: public TBusServerHandlerError {
+ TBusMessageQueuePtr Bus;
+ TExampleProtocol Proto;
+ TSystemEvent MessageReceivedEvent;
+ TSystemEvent ClientDiedEvent;
+ TBusServerSessionPtr Session;
+
+ TPackUnpackServer() {
+ Bus = CreateMessageQueue("TPackUnpackServer");
+ TBusServerSessionConfig sessionConfig;
+ Session = TBusServerSession::Create(&Proto, this, sessionConfig, Bus);
+ }
+
+ void OnMessage(TOnMessageContext& mess) override {
+ TBusIdentity ident;
+ mess.AckMessage(ident);
+
+ char packed[BUS_IDENTITY_PACKED_SIZE];
+ ident.Pack(packed);
+ TBusIdentity resurrected;
+ resurrected.Unpack(packed);
+
+ mess.GetSession()->SendReply(resurrected, new TExampleResponse(&Proto.ResponseCount));
+ }
+
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override {
+ Y_UNUSED(mess);
+ Y_VERIFY(status == MESSAGE_SHUTDOWN, "only shutdown allowed");
+ }
+ };
+
+ Y_UNIT_TEST(PackUnpack) {
+ TObjectCountCheck objectCountCheck;
+
+ TPackUnpackServer server;
+
+ THolder<TExampleClient> client(new TExampleClient);
+
+ client->SendMessagesWaitReplies(1, TNetAddr("localhost", server.Session->GetActualListenPort()));
+ }
+
+ Y_UNIT_TEST(ClientRequestTooLarge) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ TBusClientSessionConfig clientConfig;
+ clientConfig.MaxMessageSize = 100;
+ TExampleClient client(clientConfig);
+
+ client.DataSize = 10;
+ client.SendMessagesWaitReplies(1, server.GetActualListenAddr());
+
+ client.DataSize = 1000;
+ client.SendMessages(1, server.GetActualListenAddr());
+ client.WaitForError(MESSAGE_MESSAGE_TOO_LARGE);
+
+ client.DataSize = 20;
+ client.SendMessagesWaitReplies(10, server.GetActualListenAddr());
+
+ client.DataSize = 10000;
+ client.SendMessages(1, server.GetActualListenAddr());
+ client.WaitForError(MESSAGE_MESSAGE_TOO_LARGE);
+ }
+
+ struct TServerForResponseTooLarge: public TExampleServer {
+ TTestSync TestSync;
+
+ static TBusServerSessionConfig Config() {
+ TBusServerSessionConfig config;
+ config.MaxMessageSize = 100;
+ return config;
+ }
+
+ TServerForResponseTooLarge()
+ : TExampleServer("TServerForResponseTooLarge", Config())
+ {
+ }
+
+ ~TServerForResponseTooLarge() override {
+ Session->Shutdown();
+ }
+
+ void OnMessage(TOnMessageContext& mess) override {
+ TAutoPtr<TBusMessage> response;
+
+ if (TestSync.Get() == 0) {
+ TestSync.CheckAndIncrement(0);
+ response.Reset(new TExampleResponse(&Proto.ResponseCount, 1000));
+ } else {
+ TestSync.WaitForAndIncrement(3);
+ response.Reset(new TExampleResponse(&Proto.ResponseCount, 10));
+ }
+
+ mess.SendReplyMove(response);
+ }
+
+ void OnError(TAutoPtr<TBusMessage>, EMessageStatus status) override {
+ TestSync.WaitForAndIncrement(1);
+
+ Y_VERIFY(status == MESSAGE_MESSAGE_TOO_LARGE, "status");
+ }
+ };
+
+ Y_UNIT_TEST(ServerResponseTooLarge) {
+ TObjectCountCheck objectCountCheck;
+
+ TServerForResponseTooLarge server;
+
+ TExampleClient client;
+ client.DataSize = 10;
+
+ client.SendMessages(1, server.GetActualListenAddr());
+ server.TestSync.WaitForAndIncrement(2);
+ client.ResetCounters();
+
+ client.SendMessages(1, server.GetActualListenAddr());
+
+ client.WorkDone.WaitI();
+
+ server.TestSync.CheckAndIncrement(4);
+
+ UNIT_ASSERT_VALUES_EQUAL(1, client.Session->GetInFlight());
+ }
+
+ struct TServerForRequestTooLarge: public TExampleServer {
+ TTestSync TestSync;
+
+ static TBusServerSessionConfig Config() {
+ TBusServerSessionConfig config;
+ config.MaxMessageSize = 100;
+ return config;
+ }
+
+ TServerForRequestTooLarge()
+ : TExampleServer("TServerForRequestTooLarge", Config())
+ {
+ }
+
+ ~TServerForRequestTooLarge() override {
+ Session->Shutdown();
+ }
+
+ void OnMessage(TOnMessageContext& req) override {
+ unsigned n = TestSync.Get();
+ if (n < 2) {
+ TestSync.CheckAndIncrement(n);
+ TAutoPtr<TExampleResponse> resp(new TExampleResponse(&Proto.ResponseCount, 10));
+ req.SendReplyMove(resp);
+ } else {
+ Y_FAIL("wrong");
+ }
+ }
+ };
+
+ Y_UNIT_TEST(ServerRequestTooLarge) {
+ TObjectCountCheck objectCountCheck;
+
+ TServerForRequestTooLarge server;
+
+ TExampleClient client;
+ client.DataSize = 10;
+
+ client.SendMessagesWaitReplies(2, server.GetActualListenAddr());
+
+ server.TestSync.CheckAndIncrement(2);
+
+ client.DataSize = 200;
+ client.SendMessages(1, server.GetActualListenAddr());
+ // server closes connection, so MESSAGE_DELIVERY_FAILED is returned to client
+ client.WaitForError(MESSAGE_DELIVERY_FAILED);
+ }
+
+ Y_UNIT_TEST(ClientResponseTooLarge) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ server.DataSize = 10;
+
+ TBusClientSessionConfig clientSessionConfig;
+ clientSessionConfig.MaxMessageSize = 100;
+ TExampleClient client(clientSessionConfig);
+ client.DataSize = 10;
+
+ client.SendMessagesWaitReplies(3, server.GetActualListenAddr());
+
+ server.DataSize = 1000;
+
+ client.SendMessages(1, server.GetActualListenAddr());
+ client.WaitForError(MESSAGE_DELIVERY_FAILED);
+ }
+
+ Y_UNIT_TEST(ServerUnknownMessage) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+ TNetAddr serverAddr = server.GetActualListenAddr();
+
+ TExampleClient client;
+
+ client.SendMessagesWaitReplies(2, serverAddr);
+
+ TAutoPtr<TBusMessage> req(new TExampleRequest(&client.Proto.RequestCount));
+ req->GetHeader()->Type = 11;
+ client.Session->SendMessageAutoPtr(req, &serverAddr);
+ client.MessageCount = 1;
+
+ client.WaitForError(MESSAGE_DELIVERY_FAILED);
+ }
+
+ Y_UNIT_TEST(ServerMessageReservedIds) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+ TNetAddr serverAddr = server.GetActualListenAddr();
+
+ TExampleClient client;
+
+ client.SendMessagesWaitReplies(2, serverAddr);
+
+ // This test doens't check 0, 1, YBUS_KEYINVALID because there are asserts() on sending side
+
+ TAutoPtr<TBusMessage> req(new TExampleRequest(&client.Proto.RequestCount));
+ req->GetHeader()->Id = 2;
+ client.Session->SendMessageAutoPtr(req, &serverAddr);
+ client.MessageCount = 1;
+ client.WaitForError(MESSAGE_DELIVERY_FAILED);
+
+ req.Reset(new TExampleRequest(&client.Proto.RequestCount));
+ req->GetHeader()->Id = YBUS_KEYLOCAL;
+ client.Session->SendMessageAutoPtr(req, &serverAddr);
+ client.MessageCount = 1;
+ client.WaitForError(MESSAGE_DELIVERY_FAILED);
+ }
+
+ Y_UNIT_TEST(TestGetInFlightForDestination) {
+ TObjectCountCheck objectCountCheck;
+
+ TDelayReplyServer server;
+
+ TExampleClient client;
+
+ TNetAddr addr("localhost", server.Session->GetActualListenPort());
+
+ UNIT_ASSERT_VALUES_EQUAL(size_t(0), client.Session->GetInFlight(addr));
+
+ client.SendMessages(2, &addr);
+
+ for (size_t i = 0; i < 5; ++i) {
+ // One MessageReceivedEvent indicates one message, we need to wait for two
+ UNIT_ASSERT(server.MessageReceivedEvent.WaitT(TDuration::Seconds(5)));
+ if (server.GetDelayedMessageCount() == 2) {
+ break;
+ }
+ }
+ UNIT_ASSERT_VALUES_EQUAL(server.GetDelayedMessageCount(), 2);
+
+ size_t inFlight = client.Session->GetInFlight(addr);
+ // 4 is for messagebus1 that adds inFlight counter twice for some reason
+ UNIT_ASSERT(inFlight == 2 || inFlight == 4);
+
+ UNIT_ASSERT(server.CheckClientIsAlive());
+
+ server.ReplyToDelayedMessages();
+
+ client.WaitReplies();
+ }
+
+ struct TResetAfterSendOneWayErrorInCallbackClient: public TExampleClient {
+ TTestSync TestSync;
+
+ static TBusClientSessionConfig SessionConfig() {
+ TBusClientSessionConfig config;
+ // 1 ms is not enough when test is running under valgrind
+ config.ConnectTimeout = 10;
+ config.SendTimeout = 10;
+ config.Secret.TimeoutPeriod = TDuration::MilliSeconds(1);
+ return config;
+ }
+
+ TResetAfterSendOneWayErrorInCallbackClient()
+ : TExampleClient(SessionConfig())
+ {
+ }
+
+ ~TResetAfterSendOneWayErrorInCallbackClient() override {
+ Session->Shutdown();
+ }
+
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override {
+ TestSync.WaitForAndIncrement(0);
+ Y_VERIFY(status == MESSAGE_CONNECT_FAILED || status == MESSAGE_TIMEOUT, "must be connection failed, got %s", ToString(status).data());
+ mess.Destroy();
+ TestSync.CheckAndIncrement(1);
+ }
+ };
+
+ Y_UNIT_TEST(ResetAfterSendOneWayErrorInCallback) {
+ TObjectCountCheck objectCountCheck;
+
+ TNetAddr noServerAddr("localhost", 17);
+
+ TResetAfterSendOneWayErrorInCallbackClient client;
+
+ EMessageStatus ok = client.Session->SendMessageOneWayMove(new TExampleRequest(&client.Proto.RequestCount), &noServerAddr);
+ UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok);
+
+ client.TestSync.WaitForAndIncrement(2);
+ }
+
+ struct TResetAfterSendMessageOneWayDuringShutdown: public TExampleClient {
+ TTestSync TestSync;
+
+ ~TResetAfterSendMessageOneWayDuringShutdown() override {
+ Session->Shutdown();
+ }
+
+ void OnError(TAutoPtr<TBusMessage> message, EMessageStatus status) override {
+ TestSync.CheckAndIncrement(0);
+
+ Y_VERIFY(status == MESSAGE_CONNECT_FAILED, "must be MESSAGE_CONNECT_FAILED, got %s", ToString(status).data());
+
+ // check reset is possible here
+ message->Reset();
+
+ // intentionally don't destroy the message
+ // we will try to resend it
+ Y_UNUSED(message.Release());
+
+ TestSync.CheckAndIncrement(1);
+ }
+ };
+
+ Y_UNIT_TEST(ResetAfterSendMessageOneWayDuringShutdown) {
+ TObjectCountCheck objectCountCheck;
+
+ TNetAddr noServerAddr("localhost", 17);
+
+ TResetAfterSendMessageOneWayDuringShutdown client;
+
+ TExampleRequest* message = new TExampleRequest(&client.Proto.RequestCount);
+ EMessageStatus ok = client.Session->SendMessageOneWay(message, &noServerAddr);
+ UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok);
+
+ client.TestSync.WaitForAndIncrement(2);
+
+ client.Session->Shutdown();
+
+ ok = client.Session->SendMessageOneWay(message);
+ Y_VERIFY(ok == MESSAGE_SHUTDOWN, "must be shutdown when sending during shutdown, got %s", ToString(ok).data());
+
+ // check reset is possible here
+ message->Reset();
+ client.TestSync.CheckAndIncrement(3);
+
+ delete message;
+ }
+
+ Y_UNIT_TEST(ResetAfterSendOneWayErrorInReturn) {
+ TObjectCountCheck objectCountCheck;
+
+ TestNoServerImpl(17, true);
+ }
+
+ struct TResetAfterSendOneWaySuccessClient: public TExampleClient {
+ TTestSync TestSync;
+
+ ~TResetAfterSendOneWaySuccessClient() override {
+ Session->Shutdown();
+ }
+
+ void OnMessageSentOneWay(TAutoPtr<TBusMessage> sent) override {
+ TestSync.WaitForAndIncrement(0);
+ sent->Reset();
+ TestSync.CheckAndIncrement(1);
+ }
+ };
+
+ Y_UNIT_TEST(ResetAfterSendOneWaySuccess) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+ TNetAddr serverAddr = server.GetActualListenAddr();
+
+ TResetAfterSendOneWaySuccessClient client;
+
+ EMessageStatus ok = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &serverAddr);
+ UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok);
+ // otherwize message might go to OnError(MESSAGE_SHUTDOWN)
+ server.WaitForOnMessageCount(1);
+
+ client.TestSync.WaitForAndIncrement(2);
+ }
+
+ Y_UNIT_TEST(GetStatus) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ TExampleClient client;
+ // make sure connected
+ client.SendMessagesWaitReplies(3, server.GetActualListenAddr());
+
+ server.Bus->GetStatus();
+ server.Bus->GetStatus();
+ server.Bus->GetStatus();
+
+ client.Bus->GetStatus();
+ client.Bus->GetStatus();
+ client.Bus->GetStatus();
+ }
+
+ Y_UNIT_TEST(BindOnRandomPort) {
+ TObjectCountCheck objectCountCheck;
+
+ TBusServerSessionConfig serverConfig;
+ TExampleServer server;
+
+ TExampleClient client;
+ TNetAddr addr(TNetAddr("127.0.0.1", server.Session->GetActualListenPort()));
+ client.SendMessagesWaitReplies(3, &addr);
+ }
+
+ Y_UNIT_TEST(UnbindOnShutdown) {
+ TBusMessageQueuePtr queue(CreateMessageQueue());
+
+ TExampleProtocol proto;
+ TBusServerHandlerError handler;
+ TBusServerSessionPtr session = TBusServerSession::Create(
+ &proto, &handler, TBusServerSessionConfig(), queue);
+
+ unsigned port = session->GetActualListenPort();
+ UNIT_ASSERT(port > 0);
+
+ session->Shutdown();
+
+ // fails is Shutdown() didn't unbind
+ THangingServer hangingServer(port);
+ }
+
+ Y_UNIT_TEST(VersionNegotiation) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ TSockAddrInet addr(IpFromString("127.0.0.1"), server.Session->GetActualListenPort());
+
+ TInetStreamSocket socket;
+ int r1 = socket.Connect(&addr);
+ UNIT_ASSERT(r1 >= 0);
+
+ TStreamSocketOutput output(&socket);
+
+ TBusHeader request;
+ Zero(request);
+ request.Size = sizeof(request);
+ request.SetVersionInternal(0xF); // max
+ output.Write(&request, sizeof(request));
+
+ UNIT_ASSERT_VALUES_EQUAL(IsVersionNegotiation(request), true);
+
+ TStreamSocketInput input(&socket);
+
+ TBusHeader response;
+ size_t pos = 0;
+
+ while (pos < sizeof(response)) {
+ size_t count = input.Read(((char*)&response) + pos, sizeof(response) - pos);
+ pos += count;
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(sizeof(response), pos);
+
+ UNIT_ASSERT_VALUES_EQUAL(YBUS_VERSION, response.GetVersionInternal());
+ }
+
+ struct TOnConnectionEventClient: public TExampleClient {
+ TTestSync Sync;
+
+ ~TOnConnectionEventClient() override {
+ Session->Shutdown();
+ }
+
+ void OnClientConnectionEvent(const TClientConnectionEvent& event) override {
+ if (Sync.Get() > 2) {
+ // Test OnClientConnectionEvent_Disconnect is broken.
+ // Sometimes reconnect happens during server shutdown
+ // when acceptor connections is still alive, and
+ // server connection is already closed
+ return;
+ }
+
+ if (event.GetType() == TClientConnectionEvent::CONNECTED) {
+ Sync.WaitForAndIncrement(0);
+ } else if (event.GetType() == TClientConnectionEvent::DISCONNECTED) {
+ Sync.WaitForAndIncrement(2);
+ }
+ }
+
+ void OnError(TAutoPtr<TBusMessage>, EMessageStatus) override {
+ // We do not check for message errors in this test.
+ }
+
+ void OnMessageSentOneWay(TAutoPtr<TBusMessage>) override {
+ }
+ };
+
+ struct TOnConnectionEventServer: public TExampleServer {
+ TOnConnectionEventServer()
+ : TExampleServer("TOnConnectionEventServer")
+ {
+ }
+
+ ~TOnConnectionEventServer() override {
+ Session->Shutdown();
+ }
+
+ void OnError(TAutoPtr<TBusMessage>, EMessageStatus) override {
+ // We do not check for server message errors in this test.
+ }
+ };
+
+ Y_UNIT_TEST(OnClientConnectionEvent_Shutdown) {
+ TObjectCountCheck objectCountCheck;
+
+ TOnConnectionEventServer server;
+
+ TOnConnectionEventClient client;
+
+ TNetAddr addr("127.0.0.1", server.Session->GetActualListenPort());
+
+ client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &addr);
+
+ client.Sync.WaitForAndIncrement(1);
+
+ client.Session->Shutdown();
+
+ client.Sync.WaitForAndIncrement(3);
+ }
+
+ Y_UNIT_TEST(OnClientConnectionEvent_Disconnect) {
+ TObjectCountCheck objectCountCheck;
+
+ THolder<TOnConnectionEventServer> server(new TOnConnectionEventServer);
+
+ TOnConnectionEventClient client;
+ TNetAddr addr("127.0.0.1", server->Session->GetActualListenPort());
+
+ client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &addr);
+
+ client.Sync.WaitForAndIncrement(1);
+
+ server.Destroy();
+
+ client.Sync.WaitForAndIncrement(3);
+ }
+
+ struct TServerForQuotaWake: public TExampleServer {
+ TSystemEvent GoOn;
+ TMutex OneLock;
+
+ TOnMessageContext OneMessage;
+
+ static TBusServerSessionConfig Config() {
+ TBusServerSessionConfig config;
+
+ config.PerConnectionMaxInFlight = 1;
+ config.PerConnectionMaxInFlightBySize = 1500;
+ config.MaxMessageSize = 1024;
+
+ return config;
+ }
+
+ TServerForQuotaWake()
+ : TExampleServer("TServerForQuotaWake", Config())
+ {
+ }
+
+ ~TServerForQuotaWake() override {
+ Session->Shutdown();
+ }
+
+ void OnMessage(TOnMessageContext& req) override {
+ if (!GoOn.Wait(0)) {
+ TGuard<TMutex> guard(OneLock);
+
+ UNIT_ASSERT(!OneMessage);
+
+ OneMessage.Swap(req);
+ } else
+ TExampleServer::OnMessage(req);
+ }
+
+ void WakeOne() {
+ TGuard<TMutex> guard(OneLock);
+
+ UNIT_ASSERT(!!OneMessage);
+
+ TExampleServer::OnMessage(OneMessage);
+
+ TOnMessageContext().Swap(OneMessage);
+ }
+ };
+
+ Y_UNIT_TEST(WakeReaderOnQuota) {
+ const size_t test_msg_count = 64;
+
+ TBusClientSessionConfig clientConfig;
+
+ clientConfig.MaxInFlight = test_msg_count;
+
+ TExampleClient client(clientConfig);
+ TServerForQuotaWake server;
+ TInstant start;
+
+ client.MessageCount = test_msg_count;
+
+ const NBus::TNetAddr addr = server.GetActualListenAddr();
+
+ for (unsigned count = 0;;) {
+ UNIT_ASSERT(count <= test_msg_count);
+
+ TAutoPtr<TBusMessage> message(new TExampleRequest(&client.Proto.RequestCount));
+ EMessageStatus status = client.Session->SendMessageAutoPtr(message, &addr);
+
+ if (status == MESSAGE_OK) {
+ count++;
+
+ } else if (status == MESSAGE_BUSY) {
+ if (count == test_msg_count) {
+ TInstant now = TInstant::Now();
+
+ if (start.GetValue() == 0) {
+ start = now;
+
+ // TODO: properly check that server is blocked
+ } else if (start + TDuration::MilliSeconds(100) < now) {
+ break;
+ }
+ }
+
+ Sleep(TDuration::MilliSeconds(10));
+
+ } else
+ UNIT_ASSERT(false);
+ }
+
+ server.GoOn.Signal();
+ server.WakeOne();
+
+ client.WaitReplies();
+
+ server.WaitForOnMessageCount(test_msg_count);
+ };
+
+ Y_UNIT_TEST(TestConnectionAttempts) {
+ TObjectCountCheck objectCountCheck;
+
+ TNetAddr noServerAddr("localhost", 17);
+ TBusClientSessionConfig clientConfig;
+ clientConfig.RetryInterval = 100;
+ TestNoServerImplClient client(clientConfig);
+
+ int count = 0;
+ for (; count < 10; ++count) {
+ EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount),
+ &noServerAddr);
+
+ Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data());
+ client.TestSync.WaitForAndIncrement(count * 2 + 1);
+
+ // First connection attempt is for connect call; second one is to get connect result.
+ UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2);
+ }
+ Sleep(TDuration::MilliSeconds(clientConfig.RetryInterval));
+ for (; count < 10; ++count) {
+ EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount),
+ &noServerAddr);
+
+ Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data());
+ client.TestSync.WaitForAndIncrement(count * 2 + 1);
+
+ // First connection attempt is for connect call; second one is to get connect result.
+ UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 4);
+ }
+ };
+
+ Y_UNIT_TEST(TestConnectionAttemptsOnNoMessagesAndNotReconnectWhenIdle) {
+ TObjectCountCheck objectCountCheck;
+
+ TNetAddr noServerAddr("localhost", 17);
+ TBusClientSessionConfig clientConfig;
+ clientConfig.RetryInterval = 100;
+ clientConfig.ReconnectWhenIdle = false;
+ TestNoServerImplClient client(clientConfig);
+
+ int count = 0;
+ for (; count < 10; ++count) {
+ EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount),
+ &noServerAddr);
+
+ Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data());
+ client.TestSync.WaitForAndIncrement(count * 2 + 1);
+
+ // First connection attempt is for connect call; second one is to get connect result.
+ UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2);
+ }
+
+ Sleep(TDuration::MilliSeconds(clientConfig.RetryInterval / 2));
+ UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2);
+ Sleep(TDuration::MilliSeconds(10 * clientConfig.RetryInterval));
+ UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2);
+ };
+
+ Y_UNIT_TEST(TestConnectionAttemptsOnNoMessagesAndReconnectWhenIdle) {
+ TObjectCountCheck objectCountCheck;
+
+ TNetAddr noServerAddr("localhost", 17);
+ TBusClientSessionConfig clientConfig;
+ clientConfig.ReconnectWhenIdle = true;
+ clientConfig.RetryInterval = 100;
+ TestNoServerImplClient client(clientConfig);
+
+ int count = 0;
+ for (; count < 10; ++count) {
+ EMessageStatus status = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount),
+ &noServerAddr);
+
+ Y_VERIFY(status == MESSAGE_OK, "must be MESSAGE_OK, got %s", ToString(status).data());
+ client.TestSync.WaitForAndIncrement(count * 2 + 1);
+
+ // First connection attempt is for connect call; second one is to get connect result.
+ UNIT_ASSERT_VALUES_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2);
+ }
+
+ Sleep(TDuration::MilliSeconds(clientConfig.RetryInterval / 2));
+ UNIT_ASSERT_EQUAL(client.Session->GetConnectSyscallsNumForTest(noServerAddr), 2);
+ Sleep(TDuration::MilliSeconds(10 * clientConfig.RetryInterval));
+ // it is undeterministic how many reconnects will be during that amount of time
+ // but it should occur at least once
+ UNIT_ASSERT(client.Session->GetConnectSyscallsNumForTest(noServerAddr) > 2);
+ };
+};
diff --git a/library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp b/library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp
new file mode 100644
index 0000000000..4083cf3b7b
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/module_client_one_way_ut.cpp
@@ -0,0 +1,143 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/message_handler_error.h>
+
+#include <library/cpp/messagebus/misc/test_sync.h>
+#include <library/cpp/messagebus/oldmodule/module.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+Y_UNIT_TEST_SUITE(ModuleClientOneWay) {
+ struct TTestServer: public TBusServerHandlerError {
+ TExampleProtocol Proto;
+
+ TTestSync* const TestSync;
+
+ TBusMessageQueuePtr Queue;
+ TBusServerSessionPtr ServerSession;
+
+ TTestServer(TTestSync* testSync)
+ : TestSync(testSync)
+ {
+ Queue = CreateMessageQueue();
+ ServerSession = TBusServerSession::Create(&Proto, this, TBusServerSessionConfig(), Queue);
+ }
+
+ void OnMessage(TOnMessageContext& context) override {
+ TestSync->WaitForAndIncrement(1);
+ context.ForgetRequest();
+ }
+ };
+
+ struct TClientModule: public TBusModule {
+ TExampleProtocol Proto;
+
+ TTestSync* const TestSync;
+ unsigned const Port;
+
+ TBusClientSessionPtr ClientSession;
+
+ TClientModule(TTestSync* testSync, unsigned port)
+ : TBusModule("m")
+ , TestSync(testSync)
+ , Port(port)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage*) override {
+ TestSync->WaitForAndIncrement(0);
+
+ job->SendOneWayTo(new TExampleRequest(&Proto.RequestCount), ClientSession.Get(), TNetAddr("localhost", Port));
+
+ return &TClientModule::Sent;
+ }
+
+ TJobHandler Sent(TBusJob* job, TBusMessage*) {
+ TestSync->WaitForAndIncrement(2);
+ job->Cancel(MESSAGE_DONT_ASK);
+ return nullptr;
+ }
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override {
+ ClientSession = CreateDefaultSource(queue, &Proto, TBusServerSessionConfig());
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Simple) {
+ TTestSync testSync;
+
+ TTestServer server(&testSync);
+
+ TBusMessageQueuePtr queue = CreateMessageQueue();
+ TClientModule clientModule(&testSync, server.ServerSession->GetActualListenPort());
+
+ clientModule.CreatePrivateSessions(queue.Get());
+ clientModule.StartInput();
+
+ clientModule.StartJob(new TExampleRequest(&clientModule.Proto.StartCount));
+
+ testSync.WaitForAndIncrement(3);
+
+ clientModule.Shutdown();
+ }
+
+ struct TSendErrorModule: public TBusModule {
+ TExampleProtocol Proto;
+
+ TTestSync* const TestSync;
+
+ TBusClientSessionPtr ClientSession;
+
+ TSendErrorModule(TTestSync* testSync)
+ : TBusModule("m")
+ , TestSync(testSync)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage*) override {
+ TestSync->WaitForAndIncrement(0);
+
+ job->SendOneWayTo(new TExampleRequest(&Proto.RequestCount), ClientSession.Get(), TNetAddr("localhost", 1));
+
+ return &TSendErrorModule::Sent;
+ }
+
+ TJobHandler Sent(TBusJob* job, TBusMessage*) {
+ TestSync->WaitForAndIncrement(1);
+ job->Cancel(MESSAGE_DONT_ASK);
+ return nullptr;
+ }
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override {
+ TBusServerSessionConfig sessionConfig;
+ sessionConfig.ConnectTimeout = 1;
+ sessionConfig.SendTimeout = 1;
+ sessionConfig.TotalTimeout = 1;
+ sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(1);
+ ClientSession = CreateDefaultSource(queue, &Proto, sessionConfig);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(SendError) {
+ TTestSync testSync;
+
+ TBusQueueConfig queueConfig;
+ queueConfig.NumWorkers = 5;
+
+ TBusMessageQueuePtr queue = CreateMessageQueue(queueConfig);
+ TSendErrorModule clientModule(&testSync);
+
+ clientModule.CreatePrivateSessions(queue.Get());
+ clientModule.StartInput();
+
+ clientModule.StartJob(new TExampleRequest(&clientModule.Proto.StartCount));
+
+ testSync.WaitForAndIncrement(2);
+
+ clientModule.Shutdown();
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/module_client_ut.cpp b/library/cpp/messagebus/test/ut/module_client_ut.cpp
new file mode 100644
index 0000000000..ebfe185cc6
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/module_client_ut.cpp
@@ -0,0 +1,368 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "count_down_latch.h"
+#include "moduletest.h"
+
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/example_module.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+#include <library/cpp/messagebus/test/helper/wait_for.h>
+
+#include <library/cpp/messagebus/misc/test_sync.h>
+#include <library/cpp/messagebus/oldmodule/module.h>
+
+#include <util/generic/cast.h>
+#include <util/system/event.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+// helper class that cleans TBusJob instance, so job's destructor can
+// be completed without assertion fail.
+struct TJobGuard {
+public:
+ TJobGuard(NBus::TBusJob* job)
+ : Job(job)
+ {
+ }
+
+ ~TJobGuard() {
+ Job->ClearAllMessageStates();
+ }
+
+private:
+ NBus::TBusJob* Job;
+};
+
+class TMessageOk: public NBus::TBusMessage {
+public:
+ TMessageOk()
+ : NBus::TBusMessage(1)
+ {
+ }
+};
+
+class TMessageError: public NBus::TBusMessage {
+public:
+ TMessageError()
+ : NBus::TBusMessage(2)
+ {
+ }
+};
+
+Y_UNIT_TEST_SUITE(BusJobTest) {
+#if 0
+ Y_UNIT_TEST(TestPending) {
+ TObjectCountCheck objectCountCheck;
+
+ TDupDetectModule module;
+ TBusJob job(&module, new TBusMessage(0));
+ // Guard will clear the job if unit-assertion fails.
+ TJobGuard g(&job);
+
+ NBus::TBusMessage* msg = new NBus::TBusMessage(1);
+ job.Send(msg, NULL);
+ NBus::TJobStateVec pending;
+ job.GetPending(&pending);
+
+ UNIT_ASSERT_VALUES_EQUAL(pending.size(), 1u);
+ UNIT_ASSERT_EQUAL(msg, pending[0].Message);
+ }
+
+ Y_UNIT_TEST(TestCallReplyHandler) {
+ TObjectCountCheck objectCountCheck;
+
+ TDupDetectModule module;
+ NBus::TBusJob job(&module, new NBus::TBusMessage(0));
+ // Guard will clear the job if unit-assertion fails.
+ TJobGuard g(&job);
+
+ NBus::TBusMessage* msgOk = new TMessageOk;
+ NBus::TBusMessage* msgError = new TMessageError;
+ job.Send(msgOk, NULL);
+ job.Send(msgError, NULL);
+
+ UNIT_ASSERT_EQUAL(job.GetState<TMessageOk>(), NULL);
+ UNIT_ASSERT_EQUAL(job.GetState<TMessageError>(), NULL);
+
+ NBus::TBusMessage* reply = new NBus::TBusMessage(0);
+ job.CallReplyHandler(NBus::MESSAGE_OK, msgOk, reply);
+ job.CallReplyHandler(NBus::MESSAGE_TIMEOUT, msgError, NULL);
+
+ UNIT_ASSERT_UNEQUAL(job.GetState<TMessageOk>(), NULL);
+ UNIT_ASSERT_UNEQUAL(job.GetState<TMessageError>(), NULL);
+
+ UNIT_ASSERT_VALUES_EQUAL(job.GetStatus<TMessageError>(), NBus::MESSAGE_TIMEOUT);
+ UNIT_ASSERT_EQUAL(job.GetState<TMessageError>()->Status, NBus::MESSAGE_TIMEOUT);
+
+ UNIT_ASSERT_VALUES_EQUAL(job.GetStatus<TMessageOk>(), NBus::MESSAGE_OK);
+ UNIT_ASSERT_EQUAL(job.GetState<TMessageOk>()->Reply, reply);
+ }
+#endif
+
+ struct TParallelOnReplyModule : TExampleClientModule {
+ TNetAddr ServerAddr;
+
+ TCountDownLatch RepliesLatch;
+
+ TParallelOnReplyModule(const TNetAddr& serverAddr)
+ : ServerAddr(serverAddr)
+ , RepliesLatch(2)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+ job->Send(new TExampleRequest(&Proto.RequestCount), Source, TReplyHandler(&TParallelOnReplyModule::ReplyHandler), 0, ServerAddr);
+ return &TParallelOnReplyModule::HandleReplies;
+ }
+
+ void ReplyHandler(TBusJob*, EMessageStatus status, TBusMessage* mess, TBusMessage* reply) {
+ Y_UNUSED(mess);
+ Y_UNUSED(reply);
+ Y_VERIFY(status == MESSAGE_OK, "failed to get reply: %s", ToCString(status));
+ }
+
+ TJobHandler HandleReplies(TBusJob* job, TBusMessage* mess) {
+ Y_UNUSED(mess);
+ RepliesLatch.CountDown();
+ Y_VERIFY(RepliesLatch.Await(TDuration::Seconds(10)), "failed to get answers");
+ job->Cancel(MESSAGE_UNKNOWN);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(TestReplyHandlerCalledInParallel) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+
+ TExampleProtocol proto;
+
+ TBusQueueConfig config;
+ config.NumWorkers = 5;
+
+ TParallelOnReplyModule module(server.GetActualListenAddr());
+ module.StartModule();
+
+ module.StartJob(new TExampleRequest(&proto.StartCount));
+ module.StartJob(new TExampleRequest(&proto.StartCount));
+
+ UNIT_ASSERT(module.RepliesLatch.Await(TDuration::Seconds(10)));
+
+ module.Shutdown();
+ }
+
+ struct TErrorHandlerCheckerModule : TExampleModule {
+ TNetAddr ServerAddr;
+
+ TBusClientSessionPtr Source;
+
+ TCountDownLatch GotReplyLatch;
+
+ TBusMessage* SentMessage;
+
+ TErrorHandlerCheckerModule()
+ : ServerAddr("localhost", 17)
+ , GotReplyLatch(2)
+ , SentMessage()
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+ TExampleRequest* message = new TExampleRequest(&Proto.RequestCount);
+ job->Send(message, Source, TReplyHandler(&TErrorHandlerCheckerModule::ReplyHandler), 0, ServerAddr);
+ SentMessage = message;
+ return &TErrorHandlerCheckerModule::HandleReplies;
+ }
+
+ void ReplyHandler(TBusJob*, EMessageStatus status, TBusMessage* req, TBusMessage* resp) {
+ Y_VERIFY(status == MESSAGE_CONNECT_FAILED || status == MESSAGE_TIMEOUT, "got wrong status: %s", ToString(status).data());
+ Y_VERIFY(req == SentMessage, "checking request");
+ Y_VERIFY(resp == nullptr, "checking response");
+ GotReplyLatch.CountDown();
+ }
+
+ TJobHandler HandleReplies(TBusJob* job, TBusMessage* mess) {
+ Y_UNUSED(mess);
+ job->Cancel(MESSAGE_UNKNOWN);
+ GotReplyLatch.CountDown();
+ return nullptr;
+ }
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override {
+ TBusClientSessionConfig sessionConfig;
+ sessionConfig.SendTimeout = 1; // TODO: allow 0
+ sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(10);
+ Source = CreateDefaultSource(queue, &Proto, sessionConfig);
+ Source->RegisterService("localhost");
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(ErrorHandler) {
+ TExampleProtocol proto;
+
+ TBusQueueConfig config;
+ config.NumWorkers = 5;
+
+ TErrorHandlerCheckerModule module;
+
+ TBusModuleConfig moduleConfig;
+ moduleConfig.Secret.SchedulePeriod = TDuration::MilliSeconds(10);
+ module.SetConfig(moduleConfig);
+
+ module.StartModule();
+
+ module.StartJob(new TExampleRequest(&proto.StartCount));
+
+ module.GotReplyLatch.Await();
+
+ module.Shutdown();
+ }
+
+ struct TSlowReplyServer: public TBusServerHandlerError {
+ TTestSync* const TestSync;
+ TBusMessageQueuePtr Bus;
+ TBusServerSessionPtr ServerSession;
+ TExampleProtocol Proto;
+
+ TAtomic OnMessageCount;
+
+ TSlowReplyServer(TTestSync* testSync)
+ : TestSync(testSync)
+ , OnMessageCount(0)
+ {
+ Bus = CreateMessageQueue("TSlowReplyServer");
+ TBusServerSessionConfig sessionConfig;
+ ServerSession = TBusServerSession::Create(&Proto, this, sessionConfig, Bus);
+ }
+
+ void OnMessage(TOnMessageContext& req) override {
+ if (AtomicIncrement(OnMessageCount) == 1) {
+ TestSync->WaitForAndIncrement(0);
+ }
+ TAutoPtr<TBusMessage> response(new TExampleResponse(&Proto.ResponseCount));
+ req.SendReplyMove(response);
+ }
+ };
+
+ struct TModuleThatSendsReplyEarly: public TExampleClientModule {
+ TTestSync* const TestSync;
+ const unsigned ServerPort;
+
+ TBusServerSessionPtr ServerSession;
+ TAtomic ReplyCount;
+
+ TModuleThatSendsReplyEarly(TTestSync* testSync, unsigned serverPort)
+ : TestSync(testSync)
+ , ServerPort(serverPort)
+ , ServerSession(nullptr)
+ , ReplyCount(0)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+ for (unsigned i = 0; i < 2; ++i) {
+ job->Send(
+ new TExampleRequest(&Proto.RequestCount),
+ Source,
+ TReplyHandler(&TModuleThatSendsReplyEarly::ReplyHandler),
+ 0,
+ TNetAddr("127.0.0.1", ServerPort));
+ }
+ return &TModuleThatSendsReplyEarly::HandleReplies;
+ }
+
+ void ReplyHandler(TBusJob* job, EMessageStatus status, TBusMessage* mess, TBusMessage* reply) {
+ Y_UNUSED(mess);
+ Y_UNUSED(reply);
+ Y_VERIFY(status == MESSAGE_OK, "failed to get reply");
+ if (AtomicIncrement(ReplyCount) == 1) {
+ TestSync->WaitForAndIncrement(1);
+ job->SendReply(new TExampleResponse(&Proto.ResponseCount));
+ } else {
+ TestSync->WaitForAndIncrement(3);
+ }
+ }
+
+ TJobHandler HandleReplies(TBusJob* job, TBusMessage* mess) {
+ Y_UNUSED(mess);
+ job->Cancel(MESSAGE_UNKNOWN);
+ return nullptr;
+ }
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override {
+ TExampleClientModule::CreateExtSession(queue);
+ TBusServerSessionConfig sessionConfig;
+ return ServerSession = CreateDefaultDestination(queue, &Proto, sessionConfig);
+ }
+ };
+
+ Y_UNIT_TEST(SendReplyCalledBeforeAllRepliesReceived) {
+ TTestSync testSync;
+
+ TSlowReplyServer slowReplyServer(&testSync);
+
+ TModuleThatSendsReplyEarly module(&testSync, slowReplyServer.ServerSession->GetActualListenPort());
+ module.StartModule();
+
+ TExampleClient client;
+ TNetAddr addr("127.0.0.1", module.ServerSession->GetActualListenPort());
+ client.SendMessagesWaitReplies(1, &addr);
+
+ testSync.WaitForAndIncrement(2);
+
+ module.Shutdown();
+ }
+
+ struct TShutdownCalledBeforeReplyReceivedModule: public TExampleClientModule {
+ unsigned ServerPort;
+
+ TTestSync TestSync;
+
+ TShutdownCalledBeforeReplyReceivedModule(unsigned serverPort)
+ : ServerPort(serverPort)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage*) override {
+ TestSync.CheckAndIncrement(0);
+
+ job->Send(new TExampleRequest(&Proto.RequestCount), Source,
+ TReplyHandler(&TShutdownCalledBeforeReplyReceivedModule::HandleReply),
+ 0, TNetAddr("localhost", ServerPort));
+ return &TShutdownCalledBeforeReplyReceivedModule::End;
+ }
+
+ void HandleReply(TBusJob*, EMessageStatus status, TBusMessage*, TBusMessage*) {
+ Y_VERIFY(status == MESSAGE_SHUTDOWN, "got %s", ToCString(status));
+ TestSync.CheckAndIncrement(1);
+ }
+
+ TJobHandler End(TBusJob* job, TBusMessage*) {
+ TestSync.CheckAndIncrement(2);
+ job->Cancel(MESSAGE_SHUTDOWN);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(ShutdownCalledBeforeReplyReceived) {
+ TExampleServer server;
+ server.ForgetRequest = true;
+
+ TShutdownCalledBeforeReplyReceivedModule module(server.GetActualListenPort());
+
+ module.StartModule();
+
+ module.StartJob(new TExampleRequest(&module.Proto.RequestCount));
+
+ server.TestSync.WaitFor(1);
+
+ module.Shutdown();
+
+ module.TestSync.CheckAndIncrement(3);
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/module_server_ut.cpp b/library/cpp/messagebus/test/ut/module_server_ut.cpp
new file mode 100644
index 0000000000..88fe1dd9b6
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/module_server_ut.cpp
@@ -0,0 +1,119 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "count_down_latch.h"
+#include "moduletest.h"
+
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/example_module.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+#include <library/cpp/messagebus/test/helper/wait_for.h>
+
+#include <library/cpp/messagebus/oldmodule/module.h>
+
+#include <util/generic/cast.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+Y_UNIT_TEST_SUITE(ModuleServerTests) {
+ Y_UNIT_TEST(TestModule) {
+ TObjectCountCheck objectCountCheck;
+
+ /// create or get instance of message queue, need one per application
+ TBusMessageQueuePtr bus(CreateMessageQueue());
+ THostInfoHandler hostHandler(bus.Get());
+ TDupDetectModule module(hostHandler.GetActualListenAddr());
+ bool success;
+ success = module.Init(bus.Get());
+ UNIT_ASSERT_C(success, "failed to initialize dupdetect module");
+
+ success = module.StartInput();
+ UNIT_ASSERT_C(success, "failed to start dupdetect module");
+
+ TDupDetectHandler dupHandler(module.ListenAddr, bus.Get());
+ dupHandler.Work();
+
+ UNIT_WAIT_FOR(dupHandler.NumMessages == dupHandler.NumReplies);
+
+ module.Shutdown();
+ dupHandler.DupDetect->Shutdown();
+ }
+
+ struct TParallelOnMessageModule: public TExampleServerModule {
+ TCountDownLatch WaitTwoRequestsLatch;
+
+ TParallelOnMessageModule()
+ : WaitTwoRequestsLatch(2)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ WaitTwoRequestsLatch.CountDown();
+ Y_VERIFY(WaitTwoRequestsLatch.Await(TDuration::Seconds(5)), "oops");
+
+ VerifyDynamicCast<TExampleRequest*>(mess);
+
+ job->SendReply(new TExampleResponse(&Proto.ResponseCount));
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(TestOnMessageHandlerCalledInParallel) {
+ TObjectCountCheck objectCountCheck;
+
+ TBusQueueConfig config;
+ config.NumWorkers = 5;
+
+ TParallelOnMessageModule module;
+ module.StartModule();
+
+ TExampleClient client;
+
+ client.SendMessagesWaitReplies(2, module.ServerAddr);
+
+ module.Shutdown();
+ }
+
+ struct TDelayReplyServer: public TExampleServerModule {
+ TSystemEvent MessageReceivedEvent;
+ TSystemEvent ClientDiedEvent;
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+
+ MessageReceivedEvent.Signal();
+
+ Y_VERIFY(ClientDiedEvent.WaitT(TDuration::Seconds(5)), "oops");
+
+ job->SendReply(new TExampleResponse(&Proto.ResponseCount));
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(TestReplyCalledAfterClientDisconnected) {
+ TObjectCountCheck objectCountCheck;
+
+ TBusQueueConfig config;
+ config.NumWorkers = 5;
+
+ TDelayReplyServer server;
+ server.StartModule();
+
+ THolder<TExampleClient> client(new TExampleClient);
+
+ client->SendMessages(1, server.ServerAddr);
+
+ UNIT_ASSERT(server.MessageReceivedEvent.WaitT(TDuration::Seconds(5)));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, server.GetModuleSessionInFlight());
+
+ client.Destroy();
+
+ server.ClientDiedEvent.Signal();
+
+ // wait until all server message are delivered
+ UNIT_WAIT_FOR(0 == server.GetModuleSessionInFlight());
+
+ server.Shutdown();
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/moduletest.h b/library/cpp/messagebus/test/ut/moduletest.h
new file mode 100644
index 0000000000..d5da72c0cb
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/moduletest.h
@@ -0,0 +1,221 @@
+#pragma once
+
+///////////////////////////////////////////////////////////////////
+/// \file
+/// \brief Example of using local session for communication.
+
+#include <library/cpp/messagebus/test/helper/alloc_counter.h>
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/message_handler_error.h>
+
+#include <library/cpp/messagebus/ybus.h>
+#include <library/cpp/messagebus/oldmodule/module.h>
+
+namespace NBus {
+ namespace NTest {
+ using namespace std;
+
+#define TYPE_HOSTINFOREQUEST 100
+#define TYPE_HOSTINFORESPONSE 101
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief DupDetect protocol that common between client and server
+ ////////////////////////////////////////////////////////////////////
+ /// \brief HostInfo request class
+ class THostInfoMessage: public TBusMessage {
+ public:
+ THostInfoMessage()
+ : TBusMessage(TYPE_HOSTINFOREQUEST)
+ {
+ }
+ THostInfoMessage(ECreateUninitialized)
+ : TBusMessage(MESSAGE_CREATE_UNINITIALIZED)
+ {
+ }
+
+ ~THostInfoMessage() override {
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief HostInfo reply class
+ class THostInfoReply: public TBusMessage {
+ public:
+ THostInfoReply()
+ : TBusMessage(TYPE_HOSTINFORESPONSE)
+ {
+ }
+ THostInfoReply(ECreateUninitialized)
+ : TBusMessage(MESSAGE_CREATE_UNINITIALIZED)
+ {
+ }
+
+ ~THostInfoReply() override {
+ }
+ };
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief HostInfo protocol that common between client and server
+ class THostInfoProtocol: public TBusProtocol {
+ public:
+ THostInfoProtocol()
+ : TBusProtocol("HOSTINFO", 0)
+ {
+ }
+ /// serialized protocol specific data into TBusData
+ void Serialize(const TBusMessage* mess, TBuffer& data) override {
+ Y_UNUSED(data);
+ Y_UNUSED(mess);
+ }
+
+ /// deserialized TBusData into new instance of the message
+ TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) override {
+ Y_UNUSED(payload);
+
+ if (messageType == TYPE_HOSTINFOREQUEST) {
+ return new THostInfoMessage(MESSAGE_CREATE_UNINITIALIZED);
+ } else if (messageType == TYPE_HOSTINFORESPONSE) {
+ return new THostInfoReply(MESSAGE_CREATE_UNINITIALIZED);
+ } else {
+ Y_FAIL("unknown");
+ }
+ }
+ };
+
+ //////////////////////////////////////////////////////////////
+ /// \brief HostInfo handler (should convert it to module too)
+ struct THostInfoHandler: public TBusServerHandlerError {
+ TBusServerSessionPtr Session;
+ TBusServerSessionConfig HostInfoConfig;
+ THostInfoProtocol HostInfoProto;
+
+ THostInfoHandler(TBusMessageQueue* queue) {
+ Session = TBusServerSession::Create(&HostInfoProto, this, HostInfoConfig, queue);
+ }
+
+ void OnMessage(TOnMessageContext& mess) override {
+ usleep(10 * 1000); /// pretend we are doing something
+
+ TAutoPtr<THostInfoReply> reply(new THostInfoReply());
+
+ mess.SendReplyMove(reply);
+ }
+
+ TNetAddr GetActualListenAddr() {
+ return TNetAddr("localhost", Session->GetActualListenPort());
+ }
+ };
+
+ //////////////////////////////////////////////////////////////
+ /// \brief DupDetect handler (should convert it to module too)
+ struct TDupDetectHandler: public TBusClientHandlerError {
+ TNetAddr ServerAddr;
+
+ TBusClientSessionPtr DupDetect;
+ TBusClientSessionConfig DupDetectConfig;
+ TExampleProtocol DupDetectProto;
+
+ int NumMessages;
+ int NumReplies;
+
+ TDupDetectHandler(const TNetAddr& serverAddr, TBusMessageQueuePtr queue)
+ : ServerAddr(serverAddr)
+ {
+ DupDetect = TBusClientSession::Create(&DupDetectProto, this, DupDetectConfig, queue);
+ DupDetect->RegisterService("localhost");
+ }
+
+ void Work() {
+ NumMessages = 10;
+ NumReplies = 0;
+
+ for (int i = 0; i < NumMessages; i++) {
+ TExampleRequest* mess = new TExampleRequest(&DupDetectProto.RequestCount);
+ DupDetect->SendMessage(mess, &ServerAddr);
+ }
+ }
+
+ void OnReply(TAutoPtr<TBusMessage> mess, TAutoPtr<TBusMessage> reply) override {
+ Y_UNUSED(mess);
+ Y_UNUSED(reply);
+ NumReplies++;
+ }
+ };
+
+ /////////////////////////////////////////////////////////////////
+ /// \brief DupDetect module
+
+ struct TDupDetectModule: public TBusModule {
+ TNetAddr HostInfoAddr;
+
+ TBusClientSessionPtr HostInfoClientSession;
+ TBusClientSessionConfig HostInfoConfig;
+ THostInfoProtocol HostInfoProto;
+
+ TExampleProtocol DupDetectProto;
+ TBusServerSessionConfig DupDetectConfig;
+
+ TNetAddr ListenAddr;
+
+ TDupDetectModule(const TNetAddr& hostInfoAddr)
+ : TBusModule("DUPDETECTMODULE")
+ , HostInfoAddr(hostInfoAddr)
+ {
+ }
+
+ bool Init(TBusMessageQueue* queue) {
+ HostInfoClientSession = CreateDefaultSource(*queue, &HostInfoProto, HostInfoConfig);
+ HostInfoClientSession->RegisterService("localhost");
+
+ return TBusModule::CreatePrivateSessions(queue);
+ }
+
+ TBusServerSessionPtr CreateExtSession(TBusMessageQueue& queue) override {
+ TBusServerSessionPtr session = CreateDefaultDestination(queue, &DupDetectProto, DupDetectConfig);
+
+ ListenAddr = TNetAddr("localhost", session->GetActualListenPort());
+
+ return session;
+ }
+
+ /// entry point into module, first function to call
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ TExampleRequest* dmess = dynamic_cast<TExampleRequest*>(mess);
+ Y_UNUSED(dmess);
+
+ THostInfoMessage* hmess = new THostInfoMessage();
+
+ /// send message to imaginary hostinfo server
+ job->Send(hmess, HostInfoClientSession, TReplyHandler(), 0, HostInfoAddr);
+
+ return TJobHandler(&TDupDetectModule::ProcessHostInfo);
+ }
+
+ /// next handler is executed when all outstanding requests from previous handler is completed
+ TJobHandler ProcessHostInfo(TBusJob* job, TBusMessage* mess) {
+ TExampleRequest* dmess = dynamic_cast<TExampleRequest*>(mess);
+ Y_UNUSED(dmess);
+
+ THostInfoMessage* hmess = job->Get<THostInfoMessage>();
+ THostInfoReply* hreply = job->Get<THostInfoReply>();
+ EMessageStatus hstatus = job->GetStatus<THostInfoMessage>();
+ Y_ASSERT(hmess != nullptr);
+ Y_ASSERT(hreply != nullptr);
+ Y_ASSERT(hstatus == MESSAGE_OK);
+
+ return TJobHandler(&TDupDetectModule::Finish);
+ }
+
+ /// last handler sends reply and returns NULL
+ TJobHandler Finish(TBusJob* job, TBusMessage* mess) {
+ Y_UNUSED(mess);
+
+ TExampleResponse* reply = new TExampleResponse(&DupDetectProto.ResponseCount);
+ job->SendReply(reply);
+
+ return nullptr;
+ }
+ };
+
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/one_way_ut.cpp b/library/cpp/messagebus/test/ut/one_way_ut.cpp
new file mode 100644
index 0000000000..9c21227e2b
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/one_way_ut.cpp
@@ -0,0 +1,255 @@
+///////////////////////////////////////////////////////////////////
+/// \file
+/// \brief Example of reply-less communication
+
+/// This example demostrates how asynchronous message passing library
+/// can be used to send message and do not wait for reply back.
+/// The usage of reply-less communication should be restricted to
+/// low-throughput clients and high-throughput server to provide reasonable
+/// utility. Removing replies from the communication removes any restriction
+/// on how many message can be send to server and rougue clients may overwelm
+/// server without thoughtput control.
+
+/// 1) To implement reply-less client \n
+
+/// Call NBus::TBusSession::AckMessage()
+/// from within NBus::IMessageHandler::OnSent() handler when message has
+/// gone into wire on client end. See example in NBus::NullClient::OnMessageSent().
+/// Discard identity for reply message.
+
+/// 2) To implement reply-less server \n
+
+/// Call NBus::TBusSession::AckMessage() from within NBus::IMessageHandler::OnMessage()
+/// handler when message has been received on server end.
+/// See example in NBus::NullServer::OnMessage().
+/// Discard identity for reply message.
+
+#include <library/cpp/messagebus/test/helper/alloc_counter.h>
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/hanging_server.h>
+#include <library/cpp/messagebus/test/helper/message_handler_error.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+#include <library/cpp/messagebus/test/helper/wait_for.h>
+
+#include <library/cpp/messagebus/ybus.h>
+
+using namespace std;
+using namespace NBus;
+using namespace NBus::NPrivate;
+using namespace NBus::NTest;
+
+////////////////////////////////////////////////////////////////////
+////////////////////////////////////////////////////////////////////
+/// \brief Reply-less client and handler
+struct NullClient : TBusClientHandlerError {
+ TNetAddr ServerAddr;
+
+ TBusMessageQueuePtr Queue;
+ TBusClientSessionPtr Session;
+ TExampleProtocol Proto;
+
+ /// constructor creates instances of protocol and session
+ NullClient(const TNetAddr& serverAddr, const TBusClientSessionConfig& sessionConfig = TBusClientSessionConfig())
+ : ServerAddr(serverAddr)
+ {
+ UNIT_ASSERT(serverAddr.GetPort() > 0);
+
+ /// create or get instance of message queue, need one per application
+ Queue = CreateMessageQueue();
+
+ /// register source/client session
+ Session = TBusClientSession::Create(&Proto, this, sessionConfig, Queue);
+
+ /// register service, announce to clients via LocatorService
+ Session->RegisterService("localhost");
+ }
+
+ ~NullClient() override {
+ Session->Shutdown();
+ }
+
+ /// dispatch of requests is done here
+ void Work() {
+ int batch = 10;
+
+ for (int i = 0; i < batch; i++) {
+ TExampleRequest* mess = new TExampleRequest(&Proto.RequestCount);
+ mess->Data = "TADA";
+ Session->SendMessageOneWay(mess, &ServerAddr);
+ }
+ }
+
+ void OnMessageSentOneWay(TAutoPtr<TBusMessage>) override {
+ }
+};
+
+/////////////////////////////////////////////////////////////////////
+/// \brief Reply-less server and handler
+class NullServer: public TBusServerHandlerError {
+public:
+ /// session object to maintian
+ TBusMessageQueuePtr Queue;
+ TBusServerSessionPtr Session;
+ TExampleProtocol Proto;
+
+public:
+ TAtomic NumMessages;
+
+ NullServer() {
+ NumMessages = 0;
+
+ /// create or get instance of single message queue, need one for application
+ Queue = CreateMessageQueue();
+
+ /// register destination session
+ TBusServerSessionConfig sessionConfig;
+ Session = TBusServerSession::Create(&Proto, this, sessionConfig, Queue);
+ }
+
+ ~NullServer() override {
+ Session->Shutdown();
+ }
+
+ /// when message comes do not send reply, just acknowledge
+ void OnMessage(TOnMessageContext& mess) override {
+ TExampleRequest* fmess = static_cast<TExampleRequest*>(mess.GetMessage());
+
+ Y_ASSERT(fmess->Data == "TADA");
+
+ /// tell session to forget this message and never expect any reply
+ mess.ForgetRequest();
+
+ AtomicIncrement(NumMessages);
+ }
+
+ /// this handler should not be called because this server does not send replies
+ void OnSent(TAutoPtr<TBusMessage> mess) override {
+ Y_UNUSED(mess);
+ Y_FAIL("This server does not sent replies");
+ }
+};
+
+Y_UNIT_TEST_SUITE(TMessageBusTests_OneWay) {
+ Y_UNIT_TEST(Simple) {
+ TObjectCountCheck objectCountCheck;
+
+ NullServer server;
+ NullClient client(TNetAddr("localhost", server.Session->GetActualListenPort()));
+
+ client.Work();
+
+ // wait until all client message are delivered
+ UNIT_WAIT_FOR(AtomicGet(server.NumMessages) == 10);
+
+ // assert correct number of messages
+ UNIT_ASSERT_VALUES_EQUAL(AtomicGet(server.NumMessages), 10);
+ UNIT_ASSERT_VALUES_EQUAL(server.Session->GetInFlight(), 0);
+ UNIT_ASSERT_VALUES_EQUAL(client.Session->GetInFlight(), 0);
+ }
+
+ struct TMessageTooLargeClient: public NullClient {
+ TSystemEvent GotTooLarge;
+
+ TBusClientSessionConfig Config() {
+ TBusClientSessionConfig r;
+ r.MaxMessageSize = 1;
+ return r;
+ }
+
+ TMessageTooLargeClient(unsigned port)
+ : NullClient(TNetAddr("localhost", port), Config())
+ {
+ }
+
+ ~TMessageTooLargeClient() override {
+ Session->Shutdown();
+ }
+
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override {
+ Y_UNUSED(mess);
+
+ Y_VERIFY(status == MESSAGE_MESSAGE_TOO_LARGE, "wrong status: %s", ToCString(status));
+
+ GotTooLarge.Signal();
+ }
+ };
+
+ Y_UNIT_TEST(MessageTooLargeOnClient) {
+ TObjectCountCheck objectCountCheck;
+
+ NullServer server;
+
+ TMessageTooLargeClient client(server.Session->GetActualListenPort());
+
+ EMessageStatus ok = client.Session->SendMessageOneWayMove(new TExampleRequest(&client.Proto.RequestCount), &client.ServerAddr);
+ UNIT_ASSERT_VALUES_EQUAL(MESSAGE_OK, ok);
+
+ client.GotTooLarge.WaitI();
+ }
+
+ struct TCheckTimeoutClient: public NullClient {
+ ~TCheckTimeoutClient() override {
+ Session->Shutdown();
+ }
+
+ static TBusClientSessionConfig SessionConfig() {
+ TBusClientSessionConfig sessionConfig;
+ sessionConfig.SendTimeout = 1;
+ sessionConfig.ConnectTimeout = 1;
+ sessionConfig.Secret.TimeoutPeriod = TDuration::MilliSeconds(10);
+ return sessionConfig;
+ }
+
+ TCheckTimeoutClient(const TNetAddr& serverAddr)
+ : NullClient(serverAddr, SessionConfig())
+ {
+ }
+
+ TSystemEvent GotError;
+
+ /// message that could not be delivered
+ void OnError(TAutoPtr<TBusMessage> mess, EMessageStatus status) override {
+ Y_UNUSED(mess);
+ Y_UNUSED(status); // TODO: check status
+
+ GotError.Signal();
+ }
+ };
+
+ Y_UNIT_TEST(SendTimeout_Callback_NoServer) {
+ TObjectCountCheck objectCountCheck;
+
+ TCheckTimeoutClient client(TNetAddr("localhost", 17));
+
+ EMessageStatus ok = client.Session->SendMessageOneWay(new TExampleRequest(&client.Proto.RequestCount), &client.ServerAddr);
+ UNIT_ASSERT_EQUAL(ok, MESSAGE_OK);
+
+ client.GotError.WaitI();
+ }
+
+ Y_UNIT_TEST(SendTimeout_Callback_HangingServer) {
+ THangingServer server;
+
+ TObjectCountCheck objectCountCheck;
+
+ TCheckTimeoutClient client(TNetAddr("localhost", server.GetPort()));
+
+ bool first = true;
+ for (;;) {
+ EMessageStatus ok = client.Session->SendMessageOneWayMove(new TExampleRequest(&client.Proto.RequestCount), &client.ServerAddr);
+ if (ok == MESSAGE_BUSY) {
+ UNIT_ASSERT(!first);
+ break;
+ }
+ UNIT_ASSERT_VALUES_EQUAL(ok, MESSAGE_OK);
+ first = false;
+ }
+
+ // BUGBUG: The test is buggy: the client might not get any error when sending one-way messages.
+ // All the messages that the client has sent before he gets first MESSAGE_BUSY error might get
+ // serailized and written to the socket buffer, so the write queue gets drained and there are
+ // no messages to timeout when periodic timeout check happens.
+
+ client.GotError.WaitI();
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/starter_ut.cpp b/library/cpp/messagebus/test/ut/starter_ut.cpp
new file mode 100644
index 0000000000..dd4d3aaa5e
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/starter_ut.cpp
@@ -0,0 +1,140 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <library/cpp/messagebus/test/helper/example_module.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+#include <library/cpp/messagebus/test/helper/wait_for.h>
+
+using namespace NBus;
+using namespace NBus::NTest;
+
+Y_UNIT_TEST_SUITE(TBusStarterTest) {
+ struct TStartJobTestModule: public TExampleModule {
+ using TBusModule::CreateDefaultStarter;
+
+ TAtomic StartCount;
+
+ TStartJobTestModule()
+ : StartCount(0)
+ {
+ }
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+ AtomicIncrement(StartCount);
+ job->Sleep(10);
+ return &TStartJobTestModule::End;
+ }
+
+ TJobHandler End(TBusJob* job, TBusMessage* mess) {
+ Y_UNUSED(mess);
+ AtomicIncrement(StartCount);
+ job->Cancel(MESSAGE_UNKNOWN);
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(Test) {
+ TObjectCountCheck objectCountCheck;
+
+ TBusMessageQueuePtr bus(CreateMessageQueue());
+
+ TStartJobTestModule module;
+
+ //module.StartModule();
+ module.CreatePrivateSessions(bus.Get());
+ module.StartInput();
+
+ TBusSessionConfig config;
+ config.SendTimeout = 10;
+
+ module.CreateDefaultStarter(*bus, config);
+
+ UNIT_WAIT_FOR(AtomicGet(module.StartCount) >= 3);
+
+ module.Shutdown();
+ bus->Stop();
+ }
+
+ Y_UNIT_TEST(TestModuleStartJob) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleProtocol proto;
+
+ TStartJobTestModule module;
+
+ TBusModuleConfig moduleConfig;
+ moduleConfig.Secret.SchedulePeriod = TDuration::MilliSeconds(10);
+ module.SetConfig(moduleConfig);
+
+ module.StartModule();
+
+ module.StartJob(new TExampleRequest(&proto.RequestCount));
+
+ UNIT_WAIT_FOR(AtomicGet(module.StartCount) != 2);
+
+ module.Shutdown();
+ }
+
+ struct TSleepModule: public TExampleServerModule {
+ TSystemEvent MessageReceivedEvent;
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+
+ MessageReceivedEvent.Signal();
+
+ job->Sleep(1000000000);
+
+ return TJobHandler(&TSleepModule::Never);
+ }
+
+ TJobHandler Never(TBusJob*, TBusMessage*) {
+ Y_FAIL("happens");
+ throw 1;
+ }
+ };
+
+ Y_UNIT_TEST(StartJobDestroyDuringSleep) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleProtocol proto;
+
+ TSleepModule module;
+
+ module.StartModule();
+
+ module.StartJob(new TExampleRequest(&proto.StartCount));
+
+ module.MessageReceivedEvent.WaitI();
+
+ module.Shutdown();
+ }
+
+ struct TSendReplyModule: public TExampleServerModule {
+ TSystemEvent MessageReceivedEvent;
+
+ TJobHandler Start(TBusJob* job, TBusMessage* mess) override {
+ Y_UNUSED(mess);
+
+ job->SendReply(new TExampleResponse(&Proto.ResponseCount));
+
+ MessageReceivedEvent.Signal();
+
+ return nullptr;
+ }
+ };
+
+ Y_UNIT_TEST(AllowSendReplyInStarted) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleProtocol proto;
+
+ TSendReplyModule module;
+ module.StartModule();
+ module.StartJob(new TExampleRequest(&proto.StartCount));
+
+ module.MessageReceivedEvent.WaitI();
+
+ module.Shutdown();
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/sync_client_ut.cpp b/library/cpp/messagebus/test/ut/sync_client_ut.cpp
new file mode 100644
index 0000000000..400128193f
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/sync_client_ut.cpp
@@ -0,0 +1,69 @@
+#include <library/cpp/messagebus/test/helper/example.h>
+#include <library/cpp/messagebus/test/helper/object_count_check.h>
+
+namespace NBus {
+ namespace NTest {
+ using namespace std;
+
+ ////////////////////////////////////////////////////////////////////
+ /// \brief Client for sending synchronous message to local server
+ struct TSyncClient {
+ TNetAddr ServerAddr;
+
+ TExampleProtocol Proto;
+ TBusMessageQueuePtr Bus;
+ TBusSyncClientSessionPtr Session;
+
+ int NumReplies;
+ int NumMessages;
+
+ /// constructor creates instances of queue, protocol and session
+ TSyncClient(const TNetAddr& serverAddr)
+ : ServerAddr(serverAddr)
+ {
+ /// create or get instance of message queue, need one per application
+ Bus = CreateMessageQueue();
+
+ NumReplies = 0;
+ NumMessages = 10;
+
+ /// register source/client session
+ TBusClientSessionConfig sessionConfig;
+ Session = Bus->CreateSyncSource(&Proto, sessionConfig);
+ Session->RegisterService("localhost");
+ }
+
+ ~TSyncClient() {
+ Session->Shutdown();
+ }
+
+ /// dispatch of requests is done here
+ void Work() {
+ for (int i = 0; i < NumMessages; i++) {
+ THolder<TExampleRequest> mess(new TExampleRequest(&Proto.RequestCount));
+ EMessageStatus status;
+ THolder<TBusMessage> reply(Session->SendSyncMessage(mess.Get(), status, &ServerAddr));
+ if (!!reply) {
+ NumReplies++;
+ }
+ }
+ }
+ };
+
+ Y_UNIT_TEST_SUITE(SyncClientTest) {
+ Y_UNIT_TEST(TestSync) {
+ TObjectCountCheck objectCountCheck;
+
+ TExampleServer server;
+ TSyncClient client(server.GetActualListenAddr());
+ client.Work();
+ // assert correct number of replies
+ UNIT_ASSERT_EQUAL(client.NumReplies, client.NumMessages);
+ // assert that there is no message left in flight
+ UNIT_ASSERT_EQUAL(server.Session->GetInFlight(), 0);
+ UNIT_ASSERT_EQUAL(client.Session->GetInFlight(), 0);
+ }
+ }
+
+ }
+}
diff --git a/library/cpp/messagebus/test/ut/ya.make b/library/cpp/messagebus/test/ut/ya.make
new file mode 100644
index 0000000000..fe1b4961d6
--- /dev/null
+++ b/library/cpp/messagebus/test/ut/ya.make
@@ -0,0 +1,56 @@
+OWNER(g:messagebus)
+
+UNITTEST_FOR(library/cpp/messagebus)
+
+TIMEOUT(1200)
+
+SIZE(LARGE)
+
+TAG(
+ ya:not_autocheck
+ ya:fat
+)
+
+FORK_SUBTESTS()
+
+PEERDIR(
+ library/cpp/testing/unittest_main
+ library/cpp/messagebus
+ library/cpp/messagebus/test/helper
+ library/cpp/messagebus/www
+)
+
+SRCS(
+ messagebus_ut.cpp
+ module_client_ut.cpp
+ module_client_one_way_ut.cpp
+ module_server_ut.cpp
+ one_way_ut.cpp
+ starter_ut.cpp
+ sync_client_ut.cpp
+ locator_uniq_ut.cpp
+ ../../actor/actor_ut.cpp
+ ../../actor/ring_buffer_ut.cpp
+ ../../actor/tasks_ut.cpp
+ ../../actor/what_thread_does_guard_ut.cpp
+ ../../async_result_ut.cpp
+ ../../cc_semaphore_ut.cpp
+ ../../coreconn_ut.cpp
+ ../../duration_histogram_ut.cpp
+ ../../message_status_counter_ut.cpp
+ ../../misc/weak_ptr_ut.cpp
+ ../../latch_ut.cpp
+ ../../lfqueue_batch_ut.cpp
+ ../../local_flags_ut.cpp
+ ../../memory_ut.cpp
+ ../../moved_ut.cpp
+ ../../netaddr_ut.cpp
+ ../../network_ut.cpp
+ ../../nondestroying_holder_ut.cpp
+ ../../scheduler_actor_ut.cpp
+ ../../scheduler/scheduler_ut.cpp
+ ../../socket_addr_ut.cpp
+ ../../vector_swaps_ut.cpp
+)
+
+END()
diff --git a/library/cpp/messagebus/test/ya.make b/library/cpp/messagebus/test/ya.make
new file mode 100644
index 0000000000..0dc4bd4720
--- /dev/null
+++ b/library/cpp/messagebus/test/ya.make
@@ -0,0 +1,7 @@
+OWNER(g:messagebus)
+
+RECURSE(
+ example
+ perftest
+ ut
+)
diff --git a/library/cpp/messagebus/test_utils.h b/library/cpp/messagebus/test_utils.h
new file mode 100644
index 0000000000..2abdf504b1
--- /dev/null
+++ b/library/cpp/messagebus/test_utils.h
@@ -0,0 +1,12 @@
+#pragma once
+
+// Do nothing if there is no support for IPv4
+#define ASSUME_IP_V4_ENABLED \
+ do { \
+ try { \
+ TNetworkAddress("192.168.0.42", 80); \
+ } catch (const TNetworkResolutionError& ex) { \
+ Y_UNUSED(ex); \
+ return; \
+ } \
+ } while (0)
diff --git a/library/cpp/messagebus/text_utils.h b/library/cpp/messagebus/text_utils.h
new file mode 100644
index 0000000000..c2dcad834c
--- /dev/null
+++ b/library/cpp/messagebus/text_utils.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include <library/cpp/string_utils/indent_text/indent_text.h>
diff --git a/library/cpp/messagebus/thread_extra.h b/library/cpp/messagebus/thread_extra.h
new file mode 100644
index 0000000000..2c79741e88
--- /dev/null
+++ b/library/cpp/messagebus/thread_extra.h
@@ -0,0 +1,3 @@
+#pragma once
+
+#include <library/cpp/messagebus/actor/thread_extra.h>
diff --git a/library/cpp/messagebus/use_after_free_checker.cpp b/library/cpp/messagebus/use_after_free_checker.cpp
new file mode 100644
index 0000000000..4904e7c614
--- /dev/null
+++ b/library/cpp/messagebus/use_after_free_checker.cpp
@@ -0,0 +1,22 @@
+#include "use_after_free_checker.h"
+
+#include <util/system/yassert.h>
+
+namespace {
+ const ui64 VALID = (ui64)0xAABBCCDDEEFF0011LL;
+ const ui64 INVALID = (ui64)0x1122334455667788LL;
+}
+
+TUseAfterFreeChecker::TUseAfterFreeChecker()
+ : Magic(VALID)
+{
+}
+
+TUseAfterFreeChecker::~TUseAfterFreeChecker() {
+ Y_VERIFY(Magic == VALID, "Corrupted");
+ Magic = INVALID;
+}
+
+void TUseAfterFreeChecker::CheckNotFreed() const {
+ Y_VERIFY(Magic == VALID, "Freed or corrupted");
+}
diff --git a/library/cpp/messagebus/use_after_free_checker.h b/library/cpp/messagebus/use_after_free_checker.h
new file mode 100644
index 0000000000..590b076156
--- /dev/null
+++ b/library/cpp/messagebus/use_after_free_checker.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <util/system/platform.h>
+#include <util/system/types.h>
+
+class TUseAfterFreeChecker {
+private:
+ ui64 Magic;
+
+public:
+ TUseAfterFreeChecker();
+ ~TUseAfterFreeChecker();
+ void CheckNotFreed() const;
+};
+
+// check twice: in constructor and in destructor
+class TUseAfterFreeCheckerGuard {
+private:
+ const TUseAfterFreeChecker& Check;
+
+public:
+ TUseAfterFreeCheckerGuard(const TUseAfterFreeChecker& check)
+ : Check(check)
+ {
+ Check.CheckNotFreed();
+ }
+
+ ~TUseAfterFreeCheckerGuard() {
+ Check.CheckNotFreed();
+ }
+};
diff --git a/library/cpp/messagebus/use_count_checker.cpp b/library/cpp/messagebus/use_count_checker.cpp
new file mode 100644
index 0000000000..c6243ea21f
--- /dev/null
+++ b/library/cpp/messagebus/use_count_checker.cpp
@@ -0,0 +1,53 @@
+#include "use_count_checker.h"
+
+#include <util/generic/utility.h>
+#include <util/system/yassert.h>
+
+TUseCountChecker::TUseCountChecker() {
+}
+
+TUseCountChecker::~TUseCountChecker() {
+ TAtomicBase count = Counter.Val();
+ Y_VERIFY(count == 0, "must not release when count is not zero: %ld", (long)count);
+}
+
+void TUseCountChecker::Inc() {
+ Counter.Inc();
+}
+
+void TUseCountChecker::Dec() {
+ Counter.Dec();
+}
+
+TUseCountHolder::TUseCountHolder()
+ : CurrentChecker(nullptr)
+{
+}
+
+TUseCountHolder::TUseCountHolder(TUseCountChecker* currentChecker)
+ : CurrentChecker(currentChecker)
+{
+ if (!!CurrentChecker) {
+ CurrentChecker->Inc();
+ }
+}
+
+TUseCountHolder::~TUseCountHolder() {
+ if (!!CurrentChecker) {
+ CurrentChecker->Dec();
+ }
+}
+
+TUseCountHolder& TUseCountHolder::operator=(TUseCountHolder that) {
+ Swap(that);
+ return *this;
+}
+
+void TUseCountHolder::Swap(TUseCountHolder& that) {
+ DoSwap(CurrentChecker, that.CurrentChecker);
+}
+
+void TUseCountHolder::Reset() {
+ TUseCountHolder tmp;
+ Swap(tmp);
+}
diff --git a/library/cpp/messagebus/use_count_checker.h b/library/cpp/messagebus/use_count_checker.h
new file mode 100644
index 0000000000..70bef6fa8a
--- /dev/null
+++ b/library/cpp/messagebus/use_count_checker.h
@@ -0,0 +1,27 @@
+#pragma once
+
+#include <util/generic/refcount.h>
+
+class TUseCountChecker {
+private:
+ TAtomicCounter Counter;
+
+public:
+ TUseCountChecker();
+ ~TUseCountChecker();
+ void Inc();
+ void Dec();
+};
+
+class TUseCountHolder {
+private:
+ TUseCountChecker* CurrentChecker;
+
+public:
+ TUseCountHolder();
+ explicit TUseCountHolder(TUseCountChecker* currentChecker);
+ TUseCountHolder& operator=(TUseCountHolder that);
+ ~TUseCountHolder();
+ void Swap(TUseCountHolder&);
+ void Reset();
+};
diff --git a/library/cpp/messagebus/vector_swaps.h b/library/cpp/messagebus/vector_swaps.h
new file mode 100644
index 0000000000..b920bcf03e
--- /dev/null
+++ b/library/cpp/messagebus/vector_swaps.h
@@ -0,0 +1,171 @@
+#pragma once
+
+#include <util/generic/array_ref.h>
+#include <util/generic/noncopyable.h>
+#include <util/generic/utility.h>
+#include <util/system/yassert.h>
+
+#include <stdlib.h>
+
+template <typename T, class A = std::allocator<T>>
+class TVectorSwaps : TNonCopyable {
+private:
+ T* Start;
+ T* Finish;
+ T* EndOfStorage;
+
+ void StateCheck() {
+ Y_ASSERT(Start <= Finish);
+ Y_ASSERT(Finish <= EndOfStorage);
+ }
+
+public:
+ typedef T* iterator;
+ typedef const T* const_iterator;
+
+ typedef std::reverse_iterator<iterator> reverse_iterator;
+ typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
+
+ TVectorSwaps()
+ : Start()
+ , Finish()
+ , EndOfStorage()
+ {
+ }
+
+ ~TVectorSwaps() {
+ for (size_t i = 0; i < size(); ++i) {
+ Start[i].~T();
+ }
+ free(Start);
+ }
+
+ operator TArrayRef<const T>() const {
+ return MakeArrayRef(data(), size());
+ }
+
+ operator TArrayRef<T>() {
+ return MakeArrayRef(data(), size());
+ }
+
+ size_t capacity() const {
+ return EndOfStorage - Start;
+ }
+
+ size_t size() const {
+ return Finish - Start;
+ }
+
+ bool empty() const {
+ return size() == 0;
+ }
+
+ T* data() {
+ return Start;
+ }
+
+ const T* data() const {
+ return Start;
+ }
+
+ T& operator[](size_t index) {
+ Y_ASSERT(index < size());
+ return Start[index];
+ }
+
+ const T& operator[](size_t index) const {
+ Y_ASSERT(index < size());
+ return Start[index];
+ }
+
+ iterator begin() {
+ return Start;
+ }
+
+ iterator end() {
+ return Finish;
+ }
+
+ const_iterator begin() const {
+ return Start;
+ }
+
+ const_iterator end() const {
+ return Finish;
+ }
+
+ reverse_iterator rbegin() {
+ return reverse_iterator(end());
+ }
+ reverse_iterator rend() {
+ return reverse_iterator(begin());
+ }
+
+ const_reverse_iterator rbegin() const {
+ return reverse_iterator(end());
+ }
+ const_reverse_iterator rend() const {
+ return reverse_iterator(begin());
+ }
+
+ void swap(TVectorSwaps<T>& that) {
+ DoSwap(Start, that.Start);
+ DoSwap(Finish, that.Finish);
+ DoSwap(EndOfStorage, that.EndOfStorage);
+ }
+
+ void reserve(size_t n) {
+ if (n <= capacity()) {
+ return;
+ }
+
+ size_t newCapacity = FastClp2(n);
+ TVectorSwaps<T> tmp;
+ tmp.Start = (T*)malloc(sizeof(T) * newCapacity);
+ Y_VERIFY(!!tmp.Start);
+
+ tmp.EndOfStorage = tmp.Start + newCapacity;
+
+ for (size_t i = 0; i < size(); ++i) {
+ // TODO: catch exceptions
+ new (tmp.Start + i) T();
+ DoSwap(Start[i], tmp.Start[i]);
+ }
+
+ tmp.Finish = tmp.Start + size();
+
+ swap(tmp);
+
+ StateCheck();
+ }
+
+ void clear() {
+ TVectorSwaps<T> tmp;
+ swap(tmp);
+ }
+
+ template <class TIterator>
+ void insert(iterator pos, TIterator b, TIterator e) {
+ Y_VERIFY(pos == end(), "TODO: only insert at the end is implemented");
+
+ size_t count = e - b;
+
+ reserve(size() + count);
+
+ TIterator next = b;
+
+ for (size_t i = 0; i < count; ++i) {
+ new (Start + size() + i) T();
+ DoSwap(Start[size() + i], *next);
+ ++next;
+ }
+
+ Finish += count;
+
+ StateCheck();
+ }
+
+ void push_back(T& elem) {
+ insert(end(), &elem, &elem + 1);
+ }
+};
diff --git a/library/cpp/messagebus/vector_swaps_ut.cpp b/library/cpp/messagebus/vector_swaps_ut.cpp
new file mode 100644
index 0000000000..693cc6857b
--- /dev/null
+++ b/library/cpp/messagebus/vector_swaps_ut.cpp
@@ -0,0 +1,17 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "vector_swaps.h"
+
+Y_UNIT_TEST_SUITE(TVectorSwapsTest) {
+ Y_UNIT_TEST(Simple) {
+ TVectorSwaps<THolder<unsigned>> v;
+ for (unsigned i = 0; i < 100; ++i) {
+ THolder<unsigned> tmp(new unsigned(i));
+ v.push_back(tmp);
+ }
+
+ for (unsigned i = 0; i < 100; ++i) {
+ UNIT_ASSERT_VALUES_EQUAL(i, *v[i]);
+ }
+ }
+}
diff --git a/library/cpp/messagebus/www/bus-ico.png b/library/cpp/messagebus/www/bus-ico.png
new file mode 100644
index 0000000000..c69a461892
--- /dev/null
+++ b/library/cpp/messagebus/www/bus-ico.png
Binary files differ
diff --git a/library/cpp/messagebus/www/concat_strings.h b/library/cpp/messagebus/www/concat_strings.h
new file mode 100644
index 0000000000..7b730564eb
--- /dev/null
+++ b/library/cpp/messagebus/www/concat_strings.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#include <util/generic/string.h>
+#include <util/stream/str.h>
+
+// ATTN: not equivalent to TString::Join - cat concat anything "outputable" to stream, not only TString convertable types.
+
+inline void DoConcatStrings(TStringStream&) {
+}
+
+template <class T, class... R>
+inline void DoConcatStrings(TStringStream& ss, const T& t, const R&... r) {
+ ss << t;
+ DoConcatStrings(ss, r...);
+}
+
+template <class... R>
+inline TString ConcatStrings(const R&... r) {
+ TStringStream ss;
+ DoConcatStrings(ss, r...);
+ return ss.Str();
+}
diff --git a/library/cpp/messagebus/www/html_output.cpp b/library/cpp/messagebus/www/html_output.cpp
new file mode 100644
index 0000000000..10ea2e163b
--- /dev/null
+++ b/library/cpp/messagebus/www/html_output.cpp
@@ -0,0 +1,4 @@
+#include "html_output.h"
+
+Y_POD_THREAD(IOutputStream*)
+HtmlOutputStreamPtr;
diff --git a/library/cpp/messagebus/www/html_output.h b/library/cpp/messagebus/www/html_output.h
new file mode 100644
index 0000000000..27e77adefa
--- /dev/null
+++ b/library/cpp/messagebus/www/html_output.h
@@ -0,0 +1,324 @@
+#pragma once
+
+#include "concat_strings.h"
+
+#include <util/generic/string.h>
+#include <util/stream/output.h>
+#include <library/cpp/html/pcdata/pcdata.h>
+#include <util/system/tls.h>
+
+extern Y_POD_THREAD(IOutputStream*) HtmlOutputStreamPtr;
+
+static IOutputStream& HtmlOutputStream() {
+ Y_VERIFY(!!HtmlOutputStreamPtr);
+ return *HtmlOutputStreamPtr;
+}
+
+struct THtmlOutputStreamPushPop {
+ IOutputStream* const Prev;
+
+ THtmlOutputStreamPushPop(IOutputStream* outputStream)
+ : Prev(HtmlOutputStreamPtr)
+ {
+ HtmlOutputStreamPtr = outputStream;
+ }
+
+ ~THtmlOutputStreamPushPop() {
+ HtmlOutputStreamPtr = Prev;
+ }
+};
+
+struct TChars {
+ TString Text;
+ bool NeedEscape;
+
+ TChars(TStringBuf text)
+ : Text(text)
+ , NeedEscape(true)
+ {
+ }
+ TChars(TStringBuf text, bool escape)
+ : Text(text)
+ , NeedEscape(escape)
+ {
+ }
+ TChars(const char* text)
+ : Text(text)
+ , NeedEscape(true)
+ {
+ }
+ TChars(const char* text, bool escape)
+ : Text(text)
+ , NeedEscape(escape)
+ {
+ }
+
+ TString Escape() {
+ if (NeedEscape) {
+ return EncodeHtmlPcdata(Text);
+ } else {
+ return Text;
+ }
+ }
+};
+
+struct TAttr {
+ TString Name;
+ TString Value;
+
+ TAttr(TStringBuf name, TStringBuf value)
+ : Name(name)
+ , Value(value)
+ {
+ }
+
+ TAttr() {
+ }
+
+ bool operator!() const {
+ return !Name;
+ }
+};
+
+static inline void Doctype() {
+ HtmlOutputStream() << "<!doctype html>\n";
+}
+
+static inline void Nl() {
+ HtmlOutputStream() << "\n";
+}
+
+static inline void Sp() {
+ HtmlOutputStream() << " ";
+}
+
+static inline void Text(TStringBuf text) {
+ HtmlOutputStream() << EncodeHtmlPcdata(text);
+}
+
+static inline void Line(TStringBuf text) {
+ Text(text);
+ Nl();
+}
+
+static inline void WriteAttr(TAttr a) {
+ if (!!a) {
+ HtmlOutputStream() << " " << a.Name << "='" << EncodeHtmlPcdata(a.Value) << "'";
+ }
+}
+
+static inline void Open(TStringBuf tag, TAttr a1 = TAttr(), TAttr a2 = TAttr(), TAttr a3 = TAttr(), TAttr a4 = TAttr()) {
+ HtmlOutputStream() << "<" << tag;
+ WriteAttr(a1);
+ WriteAttr(a2);
+ WriteAttr(a3);
+ WriteAttr(a4);
+ HtmlOutputStream() << ">";
+}
+
+static inline void Open(TStringBuf tag, TStringBuf cssClass, TStringBuf id = "") {
+ Open(tag, TAttr("class", cssClass), !!id ? TAttr("id", id) : TAttr());
+}
+
+static inline void OpenBlock(TStringBuf tag, TStringBuf cssClass = "") {
+ Open(tag, cssClass);
+ Nl();
+}
+
+static inline void Close(TStringBuf tag) {
+ HtmlOutputStream() << "</" << tag << ">\n";
+}
+
+static inline void CloseBlock(TStringBuf tag) {
+ Close(tag);
+ Nl();
+}
+
+static inline void TagWithContent(TStringBuf tag, TChars content) {
+ HtmlOutputStream() << "<" << tag << ">" << content.Escape() << "</" << tag << ">";
+}
+
+static inline void BlockTagWithContent(TStringBuf tag, TStringBuf content) {
+ TagWithContent(tag, content);
+ Nl();
+}
+
+static inline void TagWithClass(TStringBuf tag, TStringBuf cssClass) {
+ Open(tag, cssClass);
+ Close(tag);
+}
+
+static inline void Hn(unsigned n, TStringBuf title) {
+ BlockTagWithContent(ConcatStrings("h", n), title);
+}
+
+static inline void Small(TStringBuf text) {
+ TagWithContent("small", text);
+}
+
+static inline void HnWithSmall(unsigned n, TStringBuf title, TStringBuf small) {
+ TString tagName = ConcatStrings("h", n);
+ Open(tagName);
+ HtmlOutputStream() << title;
+ Sp();
+ Small(small);
+ Close(tagName);
+}
+
+static inline void H1(TStringBuf title) {
+ Hn(1, title);
+}
+
+static inline void H2(TStringBuf title) {
+ Hn(2, title);
+}
+
+static inline void H3(TStringBuf title) {
+ Hn(3, title);
+}
+
+static inline void H4(TStringBuf title) {
+ Hn(4, title);
+}
+
+static inline void H5(TStringBuf title) {
+ Hn(5, title);
+}
+
+static inline void H6(TStringBuf title) {
+ Hn(6, title);
+}
+
+static inline void Pre(TStringBuf content) {
+ HtmlOutputStream() << "<pre>" << EncodeHtmlPcdata(content) << "</pre>\n";
+}
+
+static inline void Li(TStringBuf content) {
+ BlockTagWithContent("li", content);
+}
+
+static inline void LiWithClass(TStringBuf cssClass, TStringBuf content) {
+ Open("li", cssClass);
+ Text(content);
+ Close("li");
+}
+
+static inline void OpenA(TStringBuf href) {
+ Open("a", TAttr("href", href));
+}
+
+static inline void A(TStringBuf href, TStringBuf text) {
+ OpenA(href);
+ Text(text);
+ Close("a");
+}
+
+static inline void Td(TStringBuf content) {
+ TagWithContent("td", content);
+}
+
+static inline void Th(TStringBuf content, TStringBuf cssClass = "") {
+ OpenBlock("th", cssClass);
+ Text(content);
+ CloseBlock("th");
+}
+
+static inline void DivWithClassAndContent(TStringBuf cssClass, TStringBuf content) {
+ Open("div", cssClass);
+ Text(content);
+ Close("div");
+}
+
+static inline void BootstrapError(TStringBuf text) {
+ DivWithClassAndContent("alert alert-danger", text);
+}
+
+static inline void BootstrapInfo(TStringBuf text) {
+ DivWithClassAndContent("alert alert-info", text);
+}
+
+static inline void ScriptHref(TStringBuf href) {
+ Open("script",
+ TAttr("language", "javascript"),
+ TAttr("type", "text/javascript"),
+ TAttr("src", href));
+ Close("script");
+ Nl();
+}
+
+static inline void LinkStylesheet(TStringBuf href) {
+ Open("link", TAttr("rel", "stylesheet"), TAttr("href", href));
+ Close("link");
+ Nl();
+}
+
+static inline void LinkFavicon(TStringBuf href) {
+ Open("link", TAttr("rel", "shortcut icon"), TAttr("href", href));
+ Close("link");
+ Nl();
+}
+
+static inline void Title(TChars title) {
+ TagWithContent("title", title);
+ Nl();
+}
+
+static inline void Code(TStringBuf content) {
+ TagWithContent("code", content);
+}
+
+struct TTagGuard {
+ const TString TagName;
+
+ TTagGuard(TStringBuf tagName, TStringBuf cssClass, TStringBuf id = "")
+ : TagName(tagName)
+ {
+ Open(TagName, cssClass, id);
+ }
+
+ TTagGuard(TStringBuf tagName, TAttr a1 = TAttr(), TAttr a2 = TAttr(), TAttr a3 = TAttr(), TAttr a4 = TAttr())
+ : TagName(tagName)
+ {
+ Open(tagName, a1, a2, a3, a4);
+ }
+
+ ~TTagGuard() {
+ Close(TagName);
+ }
+};
+
+struct TDivGuard: public TTagGuard {
+ TDivGuard(TStringBuf cssClass, TStringBuf id = "")
+ : TTagGuard("div", cssClass, id)
+ {
+ }
+
+ TDivGuard(TAttr a1 = TAttr(), TAttr a2 = TAttr(), TAttr a3 = TAttr())
+ : TTagGuard("div", a1, a2, a3)
+ {
+ }
+};
+
+struct TAGuard {
+ TAGuard(TStringBuf href) {
+ OpenA(href);
+ }
+
+ ~TAGuard() {
+ Close("a");
+ }
+};
+
+struct TScriptFunctionGuard {
+ TTagGuard Script;
+
+ TScriptFunctionGuard()
+ : Script("script")
+ {
+ Line("$(function() {");
+ }
+
+ ~TScriptFunctionGuard() {
+ Line("});");
+ }
+};
diff --git a/library/cpp/messagebus/www/messagebus.js b/library/cpp/messagebus/www/messagebus.js
new file mode 100644
index 0000000000..e30508b879
--- /dev/null
+++ b/library/cpp/messagebus/www/messagebus.js
@@ -0,0 +1,48 @@
+function logTransform(v) {
+ return Math.log(v + 1);
+}
+
+function plotHist(where, hist) {
+ var max = hist.map(function(x) {return x[1]}).reduce(function(x, y) {return Math.max(x, y)});
+
+ var ticks = [];
+ for (var t = 1; ; t *= 10) {
+ if (t > max) {
+ break;
+ }
+ ticks.push(t);
+ }
+
+ $.plot(where, [hist],
+ {
+ data: hist,
+ series: {
+ bars: {
+ show: true,
+ barWidth: 0.9
+ }
+ },
+ xaxis: {
+ mode: 'categories',
+ tickLength: 0
+ },
+ yaxis: {
+ ticks: ticks,
+ transform: logTransform
+ }
+ }
+ );
+}
+
+function plotQueueSize(where, data, ticks) {
+ $.plot(where, [data],
+ {
+ xaxis: {
+ ticks: ticks,
+ },
+ yaxis: {
+ //transform: logTransform
+ }
+ }
+ );
+}
diff --git a/library/cpp/messagebus/www/www.cpp b/library/cpp/messagebus/www/www.cpp
new file mode 100644
index 0000000000..62ec241d85
--- /dev/null
+++ b/library/cpp/messagebus/www/www.cpp
@@ -0,0 +1,930 @@
+#include "www.h"
+
+#include "concat_strings.h"
+#include "html_output.h"
+
+#include <library/cpp/messagebus/remote_connection_status.h>
+#include <library/cpp/monlib/deprecated/json/writer.h>
+
+#include <library/cpp/archive/yarchive.h>
+#include <library/cpp/http/fetch/httpfsm.h>
+#include <library/cpp/http/fetch/httpheader.h>
+#include <library/cpp/http/server/http.h>
+#include <library/cpp/json/writer/json.h>
+#include <library/cpp/uri/http_url.h>
+
+#include <util/string/cast.h>
+#include <util/string/printf.h>
+#include <util/system/mutex.h>
+
+#include <utility>
+
+using namespace NBus;
+using namespace NBus::NPrivate;
+using namespace NActor;
+using namespace NActor::NPrivate;
+
+static const char HTTP_OK_JS[] = "HTTP/1.1 200 Ok\r\nContent-Type: text/javascript\r\nConnection: Close\r\n\r\n";
+static const char HTTP_OK_JSON[] = "HTTP/1.1 200 Ok\r\nContent-Type: application/json; charset=utf-8\r\nConnection: Close\r\n\r\n";
+static const char HTTP_OK_PNG[] = "HTTP/1.1 200 Ok\r\nContent-Type: image/png\r\nConnection: Close\r\n\r\n";
+static const char HTTP_OK_BIN[] = "HTTP/1.1 200 Ok\r\nContent-Type: application/octet-stream\r\nConnection: Close\r\n\r\n";
+static const char HTTP_OK_HTML[] = "HTTP/1.1 200 Ok\r\nContent-Type: text/html; charset=utf-8\r\nConnection: Close\r\n\r\n";
+
+namespace {
+ typedef TIntrusivePtr<TBusModuleInternal> TBusModuleInternalPtr;
+
+ template <typename TValuePtr>
+ struct TNamedValues {
+ TVector<std::pair<TString, TValuePtr>> Entries;
+
+ TValuePtr FindByName(TStringBuf name) {
+ Y_VERIFY(!!name);
+
+ for (unsigned i = 0; i < Entries.size(); ++i) {
+ if (Entries[i].first == name) {
+ return Entries[i].second;
+ }
+ }
+ return TValuePtr();
+ }
+
+ TString FindNameByPtr(TValuePtr value) {
+ Y_VERIFY(!!value);
+
+ for (unsigned i = 0; i < Entries.size(); ++i) {
+ if (Entries[i].second.Get() == value.Get()) {
+ return Entries[i].first;
+ }
+ }
+
+ Y_FAIL("unregistered");
+ }
+
+ void Add(TValuePtr p) {
+ Y_VERIFY(!!p);
+
+ // Do not add twice
+ for (unsigned i = 0; i < Entries.size(); ++i) {
+ if (Entries[i].second.Get() == p.Get()) {
+ return;
+ }
+ }
+
+ if (!!p->GetNameInternal()) {
+ TValuePtr current = FindByName(p->GetNameInternal());
+
+ if (!current) {
+ Entries.emplace_back(p->GetNameInternal(), p);
+ return;
+ }
+ }
+
+ for (unsigned i = 1;; ++i) {
+ TString prefix = p->GetNameInternal();
+ if (!prefix) {
+ prefix = "unnamed";
+ }
+ TString name = ConcatStrings(prefix, "-", i);
+
+ TValuePtr current = FindByName(name);
+
+ if (!current) {
+ Entries.emplace_back(name, p);
+ return;
+ }
+ }
+ }
+
+ size_t size() const {
+ return Entries.size();
+ }
+
+ bool operator!() const {
+ return size() == 0;
+ }
+ };
+
+ template <typename TSessionPtr>
+ struct TSessionValues: public TNamedValues<TSessionPtr> {
+ typedef TNamedValues<TSessionPtr> TBase;
+
+ TVector<TString> GetNamesForQueue(TBusMessageQueue* queue) {
+ TVector<TString> r;
+ for (unsigned i = 0; i < TBase::size(); ++i) {
+ if (TBase::Entries[i].second->GetQueue() == queue) {
+ r.push_back(TBase::Entries[i].first);
+ }
+ }
+ return r;
+ }
+ };
+}
+
+namespace {
+ TString RootHref() {
+ return ConcatStrings("?");
+ }
+
+ TString QueueHref(TStringBuf name) {
+ return ConcatStrings("?q=", name);
+ }
+
+ TString ServerSessionHref(TStringBuf name) {
+ return ConcatStrings("?ss=", name);
+ }
+
+ TString ClientSessionHref(TStringBuf name) {
+ return ConcatStrings("?cs=", name);
+ }
+
+ TString OldModuleHref(TStringBuf name) {
+ return ConcatStrings("?om=", name);
+ }
+
+ /*
+ static void RootLink() {
+ A(RootHref(), "root");
+ }
+ */
+
+ void QueueLink(TStringBuf name) {
+ A(QueueHref(name), name);
+ }
+
+ void ServerSessionLink(TStringBuf name) {
+ A(ServerSessionHref(name), name);
+ }
+
+ void ClientSessionLink(TStringBuf name) {
+ A(ClientSessionHref(name), name);
+ }
+
+ void OldModuleLink(TStringBuf name) {
+ A(OldModuleHref(name), name);
+ }
+
+}
+
+const unsigned char WWW_STATIC_DATA[] = {
+#include "www_static.inc"
+};
+
+class TWwwStaticLoader: public TArchiveReader {
+public:
+ TWwwStaticLoader()
+ : TArchiveReader(TBlob::NoCopy(WWW_STATIC_DATA, sizeof(WWW_STATIC_DATA)))
+ {
+ }
+};
+
+struct TBusWww::TImpl {
+ // TODO: use weak pointers
+ TNamedValues<TBusMessageQueuePtr> Queues;
+ TSessionValues<TIntrusivePtr<TBusClientSession>> ClientSessions;
+ TSessionValues<TIntrusivePtr<TBusServerSession>> ServerSessions;
+ TSessionValues<TBusModuleInternalPtr> Modules;
+
+ TMutex Mutex;
+
+ void RegisterClientSession(TBusClientSessionPtr s) {
+ Y_VERIFY(!!s);
+ TGuard<TMutex> g(Mutex);
+ ClientSessions.Add(s.Get());
+ Queues.Add(s->GetQueue());
+ }
+
+ void RegisterServerSession(TBusServerSessionPtr s) {
+ Y_VERIFY(!!s);
+ TGuard<TMutex> g(Mutex);
+ ServerSessions.Add(s.Get());
+ Queues.Add(s->GetQueue());
+ }
+
+ void RegisterQueue(TBusMessageQueuePtr q) {
+ Y_VERIFY(!!q);
+ TGuard<TMutex> g(Mutex);
+ Queues.Add(q);
+ }
+
+ void RegisterModule(TBusModule* module) {
+ Y_VERIFY(!!module);
+ TGuard<TMutex> g(Mutex);
+
+ {
+ TVector<TBusClientSessionPtr> clientSessions = module->GetInternal()->GetClientSessionsInternal();
+ for (unsigned i = 0; i < clientSessions.size(); ++i) {
+ RegisterClientSession(clientSessions[i]);
+ }
+ }
+
+ {
+ TVector<TBusServerSessionPtr> serverSessions = module->GetInternal()->GetServerSessionsInternal();
+ for (unsigned i = 0; i < serverSessions.size(); ++i) {
+ RegisterServerSession(serverSessions[i]);
+ }
+ }
+
+ Queues.Add(module->GetInternal()->GetQueue());
+ Modules.Add(module->GetInternal());
+ }
+
+ TString FindQueueNameBySessionName(TStringBuf sessionName, bool client) {
+ TIntrusivePtr<TBusClientSession> clientSession;
+ TIntrusivePtr<TBusServerSession> serverSession;
+ TBusSession* session;
+ if (client) {
+ clientSession = ClientSessions.FindByName(sessionName);
+ session = clientSession.Get();
+ } else {
+ serverSession = ServerSessions.FindByName(sessionName);
+ session = serverSession.Get();
+ }
+ Y_VERIFY(!!session);
+ return Queues.FindNameByPtr(session->GetQueue());
+ }
+
+ struct TRequest {
+ TImpl* const Outer;
+ IOutputStream& Os;
+ const TCgiParameters& CgiParams;
+ const TOptionalParams& Params;
+
+ TRequest(TImpl* outer, IOutputStream& os, const TCgiParameters& cgiParams, const TOptionalParams& params)
+ : Outer(outer)
+ , Os(os)
+ , CgiParams(cgiParams)
+ , Params(params)
+ {
+ }
+
+ void CrumbsParentLinks() {
+ for (unsigned i = 0; i < Params.ParentLinks.size(); ++i) {
+ const TLink& link = Params.ParentLinks[i];
+ TTagGuard li("li");
+ A(link.Href, link.Title);
+ }
+ }
+
+ void Crumb(TStringBuf name, TStringBuf href = "") {
+ if (!!href) {
+ TTagGuard li("li");
+ A(href, name);
+ } else {
+ LiWithClass("active", name);
+ }
+ }
+
+ void BreadcrumbRoot() {
+ TTagGuard ol("ol", "breadcrumb");
+ CrumbsParentLinks();
+ Crumb("MessageBus");
+ }
+
+ void BreadcrumbQueue(TStringBuf queueName) {
+ TTagGuard ol("ol", "breadcrumb");
+ CrumbsParentLinks();
+ Crumb("MessageBus", RootHref());
+ Crumb(ConcatStrings("queue ", queueName));
+ }
+
+ void BreadcrumbSession(TStringBuf sessionName, bool client) {
+ TString queueName = Outer->FindQueueNameBySessionName(sessionName, client);
+ TStringBuf whatSession = client ? "client session" : "server session";
+
+ TTagGuard ol("ol", "breadcrumb");
+ CrumbsParentLinks();
+ Crumb("MessageBus", RootHref());
+ Crumb(ConcatStrings("queue ", queueName), QueueHref(queueName));
+ Crumb(ConcatStrings(whatSession, " ", sessionName));
+ }
+
+ void ServeSessionsOfQueue(TBusMessageQueuePtr queue, bool includeQueue) {
+ TVector<TString> clientNames = Outer->ClientSessions.GetNamesForQueue(queue.Get());
+ TVector<TString> serverNames = Outer->ServerSessions.GetNamesForQueue(queue.Get());
+ TVector<TString> moduleNames = Outer->Modules.GetNamesForQueue(queue.Get());
+
+ TTagGuard table("table", "table table-condensed table-bordered");
+
+ {
+ TTagGuard colgroup("colgroup");
+ TagWithClass("col", "col-md-2");
+ TagWithClass("col", "col-md-2");
+ TagWithClass("col", "col-md-8");
+ }
+
+ {
+ TTagGuard tr("tr");
+ Th("What", "span2");
+ Th("Name", "span2");
+ Th("Status", "span6");
+ }
+
+ if (includeQueue) {
+ TTagGuard tr1("tr");
+ Td("queue");
+
+ {
+ TTagGuard td("td");
+ QueueLink(Outer->Queues.FindNameByPtr(queue));
+ }
+
+ {
+ TTagGuard tr2("td");
+ Pre(queue->GetStatusSingleLine());
+ }
+ }
+
+ for (unsigned j = 0; j < clientNames.size(); ++j) {
+ TTagGuard tr("tr");
+ Td("client session");
+
+ {
+ TTagGuard td("td");
+ ClientSessionLink(clientNames[j]);
+ }
+
+ {
+ TTagGuard td("td");
+ Pre(Outer->ClientSessions.FindByName(clientNames[j])->GetStatusSingleLine());
+ }
+ }
+
+ for (unsigned j = 0; j < serverNames.size(); ++j) {
+ TTagGuard tr("tr");
+ Td("server session");
+
+ {
+ TTagGuard td("td");
+ ServerSessionLink(serverNames[j]);
+ }
+
+ {
+ TTagGuard td("td");
+ Pre(Outer->ServerSessions.FindByName(serverNames[j])->GetStatusSingleLine());
+ }
+ }
+
+ for (unsigned j = 0; j < moduleNames.size(); ++j) {
+ TTagGuard tr("tr");
+ Td("module");
+
+ {
+ TTagGuard td("td");
+ if (false) {
+ OldModuleLink(moduleNames[j]);
+ } else {
+ // TODO
+ Text(moduleNames[j]);
+ }
+ }
+
+ {
+ TTagGuard td("td");
+ Pre(Outer->Modules.FindByName(moduleNames[j])->GetStatusSingleLine());
+ }
+ }
+ }
+
+ void ServeQueue(const TString& name) {
+ TBusMessageQueuePtr queue = Outer->Queues.FindByName(name);
+
+ if (!queue) {
+ BootstrapError(ConcatStrings("queue not found by name: ", name));
+ return;
+ }
+
+ BreadcrumbQueue(name);
+
+ TDivGuard container("container");
+
+ H1(ConcatStrings("MessageBus queue ", '"', name, '"'));
+
+ TBusMessageQueueStatus status = queue->GetStatusRecordInternal();
+
+ Pre(status.PrintToString());
+
+ ServeSessionsOfQueue(queue, false);
+
+ HnWithSmall(3, "Peak queue size", "(stored for an hour)");
+
+ {
+ TDivGuard div;
+ TDivGuard div2(TAttr("id", "queue-size-graph"), TAttr("style", "height: 300px"));
+ }
+
+ {
+ TScriptFunctionGuard script;
+
+ NJsonWriter::TBuf data(NJsonWriter::HEM_ESCAPE_HTML);
+ NJsonWriter::TBuf ticks(NJsonWriter::HEM_ESCAPE_HTML);
+
+ const TExecutorHistory& history = status.ExecutorStatus.History;
+
+ data.BeginList();
+ ticks.BeginList();
+ for (unsigned i = 0; i < history.HistoryRecords.size(); ++i) {
+ ui64 secondOfMinute = (history.FirstHistoryRecordSecond() + i) % 60;
+ ui64 minuteOfHour = (history.FirstHistoryRecordSecond() + i) / 60 % 60;
+
+ unsigned printEach;
+
+ if (history.HistoryRecords.size() <= 500) {
+ printEach = 1;
+ } else if (history.HistoryRecords.size() <= 1000) {
+ printEach = 2;
+ } else if (history.HistoryRecords.size() <= 3000) {
+ printEach = 6;
+ } else {
+ printEach = 12;
+ }
+
+ if (secondOfMinute % printEach != 0) {
+ continue;
+ }
+
+ ui32 max = 0;
+ for (unsigned j = 0; j < printEach; ++j) {
+ if (i < j) {
+ continue;
+ }
+ max = Max<ui32>(max, history.HistoryRecords[i - j].MaxQueueSize);
+ }
+
+ data.BeginList();
+ data.WriteString(ToString(i));
+ data.WriteInt(max);
+ data.EndList();
+
+ // TODO: can be done with flot time plugin
+ if (history.HistoryRecords.size() <= 20) {
+ ticks.BeginList();
+ ticks.WriteInt(i);
+ ticks.WriteString(ToString(secondOfMinute));
+ ticks.EndList();
+ } else if (history.HistoryRecords.size() <= 60) {
+ if (secondOfMinute % 5 == 0) {
+ ticks.BeginList();
+ ticks.WriteInt(i);
+ ticks.WriteString(ToString(secondOfMinute));
+ ticks.EndList();
+ }
+ } else {
+ bool needTick;
+ if (history.HistoryRecords.size() <= 3 * 60) {
+ needTick = secondOfMinute % 15 == 0;
+ } else if (history.HistoryRecords.size() <= 7 * 60) {
+ needTick = secondOfMinute % 30 == 0;
+ } else if (history.HistoryRecords.size() <= 20 * 60) {
+ needTick = secondOfMinute == 0;
+ } else {
+ needTick = secondOfMinute == 0 && minuteOfHour % 5 == 0;
+ }
+ if (needTick) {
+ ticks.BeginList();
+ ticks.WriteInt(i);
+ ticks.WriteString(Sprintf(":%02u:%02u", (unsigned)minuteOfHour, (unsigned)secondOfMinute));
+ ticks.EndList();
+ }
+ }
+ }
+ ticks.EndList();
+ data.EndList();
+
+ HtmlOutputStream() << " var data = " << data.Str() << ";\n";
+ HtmlOutputStream() << " var ticks = " << ticks.Str() << ";\n";
+ HtmlOutputStream() << " plotQueueSize('#queue-size-graph', data, ticks);\n";
+ }
+ }
+
+ void ServeSession(TStringBuf name, bool client) {
+ TIntrusivePtr<TBusClientSession> clientSession;
+ TIntrusivePtr<TBusServerSession> serverSession;
+ TBusSession* session;
+ TStringBuf whatSession;
+ if (client) {
+ whatSession = "client session";
+ clientSession = Outer->ClientSessions.FindByName(name);
+ session = clientSession.Get();
+ } else {
+ whatSession = "server session";
+ serverSession = Outer->ServerSessions.FindByName(name);
+ session = serverSession.Get();
+ }
+ if (!session) {
+ BootstrapError(ConcatStrings(whatSession, " not found by name: ", name));
+ return;
+ }
+
+ TSessionDumpStatus dumpStatus = session->GetStatusRecordInternal();
+
+ TBusMessageQueuePtr queue = session->GetQueue();
+ TString queueName = Outer->Queues.FindNameByPtr(session->GetQueue());
+
+ BreadcrumbSession(name, client);
+
+ TDivGuard container("container");
+
+ H1(ConcatStrings("MessageBus ", whatSession, " ", '"', name, '"'));
+
+ TBusMessageQueueStatus queueStatus = queue->GetStatusRecordInternal();
+
+ {
+ H3(ConcatStrings("queue ", queueName));
+ Pre(queueStatus.PrintToString());
+ }
+
+ TSessionDumpStatus status = session->GetStatusRecordInternal();
+
+ if (status.Shutdown) {
+ BootstrapError("Session shut down");
+ return;
+ }
+
+ H3("Basic");
+ Pre(status.Head);
+
+ if (status.ConnectionStatusSummary.Server) {
+ H3("Acceptors");
+ Pre(status.Acceptors);
+ }
+
+ H3("Connections");
+ Pre(status.ConnectionsSummary);
+
+ {
+ TDivGuard div;
+ TTagGuard button("button",
+ TAttr("type", "button"),
+ TAttr("class", "btn"),
+ TAttr("data-toggle", "collapse"),
+ TAttr("data-target", "#connections"));
+ Text("Show connection details");
+ }
+ {
+ TDivGuard div(TAttr("id", "connections"), TAttr("class", "collapse"));
+ Pre(status.Connections);
+ }
+
+ H3("TBusSessionConfig");
+ Pre(status.Config.PrintToString());
+
+ if (!client) {
+ H3("Message process time histogram");
+
+ const TDurationHistogram& h =
+ dumpStatus.ConnectionStatusSummary.WriterStatus.Incremental.ProcessDurationHistogram;
+
+ {
+ TDivGuard div;
+ TDivGuard div2(TAttr("id", "h"), TAttr("style", "height: 300px"));
+ }
+
+ {
+ TScriptFunctionGuard script;
+
+ NJsonWriter::TBuf buf(NJsonWriter::HEM_ESCAPE_HTML);
+ buf.BeginList();
+ for (unsigned i = 0; i < h.Times.size(); ++i) {
+ TString label = TDurationHistogram::LabelBefore(i);
+ buf.BeginList();
+ buf.WriteString(label);
+ buf.WriteLongLong(h.Times[i]);
+ buf.EndList();
+ }
+ buf.EndList();
+
+ HtmlOutputStream() << " var hist = " << buf.Str() << ";\n";
+ HtmlOutputStream() << " plotHist('#h', hist);\n";
+ }
+ }
+ }
+
+ void ServeDefault() {
+ if (!Outer->Queues) {
+ BootstrapError("no queues");
+ return;
+ }
+
+ BreadcrumbRoot();
+
+ TDivGuard container("container");
+
+ H1("MessageBus queues");
+
+ for (unsigned i = 0; i < Outer->Queues.size(); ++i) {
+ TString queueName = Outer->Queues.Entries[i].first;
+ TBusMessageQueuePtr queue = Outer->Queues.Entries[i].second;
+
+ HnWithSmall(3, queueName, "(queue)");
+
+ ServeSessionsOfQueue(queue, true);
+ }
+ }
+
+ void WriteQueueSensors(NMonitoring::TDeprecatedJsonWriter& sj, TStringBuf queueName, TBusMessageQueue* queue) {
+ auto status = queue->GetStatusRecordInternal();
+ sj.OpenMetric();
+ sj.WriteLabels("mb_queue", queueName, "sensor", "WorkQueueSize");
+ sj.WriteValue(status.ExecutorStatus.WorkQueueSize);
+ sj.CloseMetric();
+ }
+
+ void WriteMessageCounterSensors(NMonitoring::TDeprecatedJsonWriter& sj,
+ TStringBuf labelName, TStringBuf sessionName, bool read, const TMessageCounter& counter) {
+ TStringBuf readOrWrite = read ? "read" : "write";
+
+ sj.OpenMetric();
+ sj.WriteLabels(labelName, sessionName, "mb_dir", readOrWrite, "sensor", "MessageBytes");
+ sj.WriteValue(counter.BytesData);
+ sj.WriteModeDeriv();
+ sj.CloseMetric();
+
+ sj.OpenMetric();
+ sj.WriteLabels(labelName, sessionName, "mb_dir", readOrWrite, "sensor", "MessageCount");
+ sj.WriteValue(counter.Count);
+ sj.WriteModeDeriv();
+ sj.CloseMetric();
+ }
+
+ void WriteSessionStatus(NMonitoring::TDeprecatedJsonWriter& sj, TStringBuf sessionName, bool client,
+ TBusSession* session) {
+ TStringBuf labelName = client ? "mb_client_session" : "mb_server_session";
+
+ auto status = session->GetStatusRecordInternal();
+
+ sj.OpenMetric();
+ sj.WriteLabels(labelName, sessionName, "sensor", "InFlightCount");
+ sj.WriteValue(status.Status.InFlightCount);
+ sj.CloseMetric();
+
+ sj.OpenMetric();
+ sj.WriteLabels(labelName, sessionName, "sensor", "InFlightSize");
+ sj.WriteValue(status.Status.InFlightSize);
+ sj.CloseMetric();
+
+ sj.OpenMetric();
+ sj.WriteLabels(labelName, sessionName, "sensor", "SendQueueSize");
+ sj.WriteValue(status.ConnectionStatusSummary.WriterStatus.SendQueueSize);
+ sj.CloseMetric();
+
+ if (client) {
+ sj.OpenMetric();
+ sj.WriteLabels(labelName, sessionName, "sensor", "AckMessagesSize");
+ sj.WriteValue(status.ConnectionStatusSummary.WriterStatus.AckMessagesSize);
+ sj.CloseMetric();
+ }
+
+ WriteMessageCounterSensors(sj, labelName, sessionName, false,
+ status.ConnectionStatusSummary.WriterStatus.Incremental.MessageCounter);
+ WriteMessageCounterSensors(sj, labelName, sessionName, true,
+ status.ConnectionStatusSummary.ReaderStatus.Incremental.MessageCounter);
+ }
+
+ void ServeSolomonJson(const TString& q, const TString& cs, const TString& ss) {
+ Y_UNUSED(q);
+ Y_UNUSED(cs);
+ Y_UNUSED(ss);
+ bool all = q == "" && cs == "" && ss == "";
+
+ NMonitoring::TDeprecatedJsonWriter sj(&Os);
+
+ sj.OpenDocument();
+ sj.OpenMetrics();
+
+ for (unsigned i = 0; i < Outer->Queues.size(); ++i) {
+ TString queueName = Outer->Queues.Entries[i].first;
+ TBusMessageQueuePtr queue = Outer->Queues.Entries[i].second;
+ if (all || q == queueName) {
+ WriteQueueSensors(sj, queueName, &*queue);
+ }
+
+ TVector<TString> clientNames = Outer->ClientSessions.GetNamesForQueue(queue.Get());
+ TVector<TString> serverNames = Outer->ServerSessions.GetNamesForQueue(queue.Get());
+ TVector<TString> moduleNames = Outer->Modules.GetNamesForQueue(queue.Get());
+ for (auto& sessionName : clientNames) {
+ if (all || cs == sessionName) {
+ auto session = Outer->ClientSessions.FindByName(sessionName);
+ WriteSessionStatus(sj, sessionName, true, &*session);
+ }
+ }
+
+ for (auto& sessionName : serverNames) {
+ if (all || ss == sessionName) {
+ auto session = Outer->ServerSessions.FindByName(sessionName);
+ WriteSessionStatus(sj, sessionName, false, &*session);
+ }
+ }
+ }
+
+ sj.CloseMetrics();
+ sj.CloseDocument();
+ }
+
+ void ServeStatic(IOutputStream& os, TStringBuf path) {
+ if (path.EndsWith(".js")) {
+ os << HTTP_OK_JS;
+ } else if (path.EndsWith(".png")) {
+ os << HTTP_OK_PNG;
+ } else {
+ os << HTTP_OK_BIN;
+ }
+ TBlob blob = Singleton<TWwwStaticLoader>()->ObjectBlobByKey(TString("/") + TString(path));
+ os.Write(blob.Data(), blob.Size());
+ }
+
+ void HeaderJsCss() {
+ LinkStylesheet("//yandex.st/bootstrap/3.0.2/css/bootstrap.css");
+ LinkFavicon("?file=bus-ico.png");
+ ScriptHref("//yandex.st/jquery/2.0.3/jquery.js");
+ ScriptHref("//yandex.st/bootstrap/3.0.2/js/bootstrap.js");
+ ScriptHref("//cdnjs.cloudflare.com/ajax/libs/flot/0.8.1/jquery.flot.min.js");
+ ScriptHref("//cdnjs.cloudflare.com/ajax/libs/flot/0.8.1/jquery.flot.categories.min.js");
+ ScriptHref("?file=messagebus.js");
+ }
+
+ void Serve() {
+ THtmlOutputStreamPushPop pp(&Os);
+
+ TCgiParameters::const_iterator file = CgiParams.Find("file");
+ if (file != CgiParams.end()) {
+ ServeStatic(Os, file->second);
+ return;
+ }
+
+ bool solomonJson = false;
+ TCgiParameters::const_iterator fmt = CgiParams.Find("fmt");
+ if (fmt != CgiParams.end()) {
+ if (fmt->second == "solomon-json") {
+ solomonJson = true;
+ }
+ }
+
+ TCgiParameters::const_iterator cs = CgiParams.Find("cs");
+ TCgiParameters::const_iterator ss = CgiParams.Find("ss");
+ TCgiParameters::const_iterator q = CgiParams.Find("q");
+
+ if (solomonJson) {
+ Os << HTTP_OK_JSON;
+
+ TString qp = q != CgiParams.end() ? q->first : "";
+ TString csp = cs != CgiParams.end() ? cs->first : "";
+ TString ssp = ss != CgiParams.end() ? ss->first : "";
+ ServeSolomonJson(qp, csp, ssp);
+ } else {
+ Os << HTTP_OK_HTML;
+
+ Doctype();
+
+ TTagGuard html("html");
+ {
+ TTagGuard head("head");
+
+ HeaderJsCss();
+ // &#x2709; &#x1f68c;
+ Title(TChars("MessageBus", false));
+ }
+
+ TTagGuard body("body");
+
+ if (cs != CgiParams.end()) {
+ ServeSession(cs->second, true);
+ } else if (ss != CgiParams.end()) {
+ ServeSession(ss->second, false);
+ } else if (q != CgiParams.end()) {
+ ServeQueue(q->second);
+ } else {
+ ServeDefault();
+ }
+ }
+ }
+ };
+
+ void ServeHttp(IOutputStream& os, const TCgiParameters& queryArgs, const TBusWww::TOptionalParams& params) {
+ TGuard<TMutex> g(Mutex);
+
+ TRequest request(this, os, queryArgs, params);
+
+ request.Serve();
+ }
+};
+
+NBus::TBusWww::TBusWww()
+ : Impl(new TImpl)
+{
+}
+
+NBus::TBusWww::~TBusWww() {
+}
+
+void NBus::TBusWww::RegisterClientSession(TBusClientSessionPtr s) {
+ Impl->RegisterClientSession(s);
+}
+
+void TBusWww::RegisterServerSession(TBusServerSessionPtr s) {
+ Impl->RegisterServerSession(s);
+}
+
+void TBusWww::RegisterQueue(TBusMessageQueuePtr q) {
+ Impl->RegisterQueue(q);
+}
+
+void TBusWww::RegisterModule(TBusModule* module) {
+ Impl->RegisterModule(module);
+}
+
+void TBusWww::ServeHttp(IOutputStream& httpOutputStream,
+ const TCgiParameters& queryArgs,
+ const TBusWww::TOptionalParams& params) {
+ Impl->ServeHttp(httpOutputStream, queryArgs, params);
+}
+
+struct TBusWwwHttpServer::TImpl: public THttpServer::ICallBack {
+ TIntrusivePtr<TBusWww> Www;
+ THttpServer HttpServer;
+
+ static THttpServer::TOptions MakeHttpServerOptions(unsigned port) {
+ Y_VERIFY(port > 0);
+ THttpServer::TOptions r;
+ r.Port = port;
+ return r;
+ }
+
+ TImpl(TIntrusivePtr<TBusWww> www, unsigned port)
+ : Www(www)
+ , HttpServer(this, MakeHttpServerOptions(port))
+ {
+ HttpServer.Start();
+ }
+
+ struct TClientRequestImpl: public TClientRequest {
+ TBusWwwHttpServer::TImpl* const Outer;
+
+ TClientRequestImpl(TBusWwwHttpServer::TImpl* outer)
+ : Outer(outer)
+ {
+ }
+
+ bool Reply(void*) override {
+ Outer->ServeRequest(Input(), Output());
+ return true;
+ }
+ };
+
+ TString MakeSimpleResponse(unsigned code, TString text, TString content = "") {
+ if (!content) {
+ TStringStream contentSs;
+ contentSs << code << " " << text;
+ content = contentSs.Str();
+ }
+ TStringStream ss;
+ ss << "HTTP/1.1 "
+ << code << " " << text << "\r\nConnection: Close\r\n\r\n"
+ << content;
+ return ss.Str();
+ }
+
+ void ServeRequest(THttpInput& input, THttpOutput& output) {
+ TCgiParameters cgiParams;
+ try {
+ THttpRequestHeader header;
+ THttpHeaderParser parser;
+ parser.Init(&header);
+ if (parser.Execute(input.FirstLine()) < 0) {
+ HtmlOutputStream() << MakeSimpleResponse(400, "Bad request");
+ return;
+ }
+ THttpURL url;
+ if (url.Parse(header.GetUrl()) != THttpURL::ParsedOK) {
+ HtmlOutputStream() << MakeSimpleResponse(400, "Invalid url");
+ return;
+ }
+ cgiParams.Scan(url.Get(THttpURL::FieldQuery));
+
+ TBusWww::TOptionalParams params;
+ //params.ParentLinks.emplace_back();
+ //params.ParentLinks.back().Title = "temp";
+ //params.ParentLinks.back().Href = "http://wiki.yandex-team.ru/";
+
+ Www->ServeHttp(output, cgiParams, params);
+ } catch (...) {
+ output << MakeSimpleResponse(500, "Exception",
+ TString() + "Exception: " + CurrentExceptionMessage());
+ }
+ }
+
+ TClientRequest* CreateClient() override {
+ return new TClientRequestImpl(this);
+ }
+
+ ~TImpl() override {
+ HttpServer.Stop();
+ }
+};
+
+NBus::TBusWwwHttpServer::TBusWwwHttpServer(TIntrusivePtr<TBusWww> www, unsigned port)
+ : Impl(new TImpl(www, port))
+{
+}
+
+NBus::TBusWwwHttpServer::~TBusWwwHttpServer() {
+}
diff --git a/library/cpp/messagebus/www/www.h b/library/cpp/messagebus/www/www.h
new file mode 100644
index 0000000000..6cd652b477
--- /dev/null
+++ b/library/cpp/messagebus/www/www.h
@@ -0,0 +1,45 @@
+#pragma once
+
+#include <library/cpp/messagebus/ybus.h>
+#include <library/cpp/messagebus/oldmodule/module.h>
+
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+#include <library/cpp/cgiparam/cgiparam.h>
+
+namespace NBus {
+ class TBusWww: public TAtomicRefCount<TBusWww> {
+ public:
+ struct TLink {
+ TString Title;
+ TString Href;
+ };
+
+ struct TOptionalParams {
+ TVector<TLink> ParentLinks;
+ };
+
+ TBusWww();
+ ~TBusWww();
+
+ void RegisterClientSession(TBusClientSessionPtr);
+ void RegisterServerSession(TBusServerSessionPtr);
+ void RegisterQueue(TBusMessageQueuePtr);
+ void RegisterModule(TBusModule*);
+
+ void ServeHttp(IOutputStream& httpOutputStream, const TCgiParameters& queryArgs, const TOptionalParams& params = TOptionalParams());
+
+ struct TImpl;
+ THolder<TImpl> Impl;
+ };
+
+ class TBusWwwHttpServer {
+ public:
+ TBusWwwHttpServer(TIntrusivePtr<TBusWww> www, unsigned port);
+ ~TBusWwwHttpServer();
+
+ struct TImpl;
+ THolder<TImpl> Impl;
+ };
+
+}
diff --git a/library/cpp/messagebus/www/ya.make b/library/cpp/messagebus/www/ya.make
new file mode 100644
index 0000000000..972390cea3
--- /dev/null
+++ b/library/cpp/messagebus/www/ya.make
@@ -0,0 +1,29 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+SRCS(
+ html_output.cpp
+ www.cpp
+)
+
+ARCHIVE(
+ NAME www_static.inc
+ messagebus.js
+ bus-ico.png
+)
+
+PEERDIR(
+ library/cpp/archive
+ library/cpp/cgiparam
+ library/cpp/html/pcdata
+ library/cpp/http/fetch
+ library/cpp/http/server
+ library/cpp/json/writer
+ library/cpp/messagebus
+ library/cpp/messagebus/oldmodule
+ library/cpp/monlib/deprecated/json
+ library/cpp/uri
+)
+
+END()
diff --git a/library/cpp/messagebus/ya.make b/library/cpp/messagebus/ya.make
new file mode 100644
index 0000000000..e13cf06dea
--- /dev/null
+++ b/library/cpp/messagebus/ya.make
@@ -0,0 +1,68 @@
+LIBRARY()
+
+OWNER(g:messagebus)
+
+IF (SANITIZER_TYPE == "undefined")
+ NO_SANITIZE()
+ENDIF()
+
+SRCS(
+ acceptor.cpp
+ acceptor_status.cpp
+ connection.cpp
+ coreconn.cpp
+ duration_histogram.cpp
+ event_loop.cpp
+ futex_like.cpp
+ handler.cpp
+ key_value_printer.cpp
+ local_flags.cpp
+ locator.cpp
+ mb_lwtrace.cpp
+ message.cpp
+ message_counter.cpp
+ message_status.cpp
+ message_status_counter.cpp
+ messqueue.cpp
+ misc/atomic_box.h
+ misc/granup.h
+ misc/test_sync.h
+ misc/tokenquota.h
+ misc/weak_ptr.h
+ network.cpp
+ queue_config.cpp
+ remote_client_connection.cpp
+ remote_client_session.cpp
+ remote_client_session_semaphore.cpp
+ remote_connection.cpp
+ remote_connection_status.cpp
+ remote_server_connection.cpp
+ remote_server_session.cpp
+ remote_server_session_semaphore.cpp
+ session.cpp
+ session_impl.cpp
+ session_job_count.cpp
+ shutdown_state.cpp
+ socket_addr.cpp
+ storage.cpp
+ synchandler.cpp
+ use_after_free_checker.cpp
+ use_count_checker.cpp
+ ybus.h
+)
+
+PEERDIR(
+ contrib/libs/sparsehash
+ library/cpp/codecs
+ library/cpp/deprecated/enum_codegen
+ library/cpp/getopt/small
+ library/cpp/lwtrace
+ library/cpp/messagebus/actor
+ library/cpp/messagebus/config
+ library/cpp/messagebus/monitoring
+ library/cpp/messagebus/scheduler
+ library/cpp/string_utils/indent_text
+ library/cpp/threading/future
+)
+
+END()
diff --git a/library/cpp/messagebus/ybus.h b/library/cpp/messagebus/ybus.h
new file mode 100644
index 0000000000..de21ad8521
--- /dev/null
+++ b/library/cpp/messagebus/ybus.h
@@ -0,0 +1,205 @@
+#pragma once
+
+/// Asynchronous Messaging Library implements framework for sending and
+/// receiving messages between loosely connected processes.
+
+#include "coreconn.h"
+#include "defs.h"
+#include "handler.h"
+#include "handler_impl.h"
+#include "local_flags.h"
+#include "locator.h"
+#include "message.h"
+#include "message_status.h"
+#include "network.h"
+#include "queue_config.h"
+#include "remote_connection_status.h"
+#include "session.h"
+#include "session_config.h"
+#include "socket_addr.h"
+
+#include <library/cpp/messagebus/actor/executor.h>
+#include <library/cpp/messagebus/scheduler/scheduler.h>
+
+#include <library/cpp/codecs/codecs.h>
+
+#include <util/generic/array_ref.h>
+#include <util/generic/buffer.h>
+#include <util/generic/noncopyable.h>
+#include <util/generic/ptr.h>
+#include <util/stream/input.h>
+#include <util/system/atomic.h>
+#include <util/system/condvar.h>
+#include <util/system/type_name.h>
+#include <util/system/event.h>
+#include <util/system/mutex.h>
+
+namespace NBus {
+ ////////////////////////////////////////////////////////
+ /// \brief Common structure to store address information
+
+ int CompareByHost(const IRemoteAddr& l, const IRemoteAddr& r) noexcept;
+ bool operator<(const TNetAddr& a1, const TNetAddr& a2); // compare by addresses
+
+ /////////////////////////////////////////////////////////////////////////
+ /// \brief Handles routing and data encoding to/from wire
+
+ /// Protocol is stateless threadsafe singleton object that
+ /// encapsulates relationship between a message (TBusMessage) object
+ /// and destination server. Protocol object is reponsible for serializing in-memory
+ /// message and reply into the wire, retuning name of the service and resource
+ /// distribution key for given protocol.
+
+ /// Protocol object should transparently handle messages and replies.
+ /// This is interface only class, actuall instances of the protocols
+ /// should be created using templates inhereted from this base class.
+ class TBusProtocol {
+ private:
+ TString ServiceName;
+ int ServicePort;
+
+ public:
+ TBusProtocol(TBusService name = "UNKNOWN", int port = 0)
+ : ServiceName(name)
+ , ServicePort(port)
+ {
+ }
+
+ /// returns service type for this protocol and message
+ TBusService GetService() const {
+ return ServiceName.data();
+ }
+
+ /// returns port number for destination session to open socket
+ int GetPort() const {
+ return ServicePort;
+ }
+
+ virtual ~TBusProtocol() {
+ }
+
+ /// \brief serialized protocol specific data into TBusData
+ /// \note buffer passed to the function (data) is not empty, use append functions
+ virtual void Serialize(const TBusMessage* mess, TBuffer& data) = 0;
+
+ /// deserialized TBusData into new instance of the message
+ virtual TAutoPtr<TBusMessage> Deserialize(ui16 messageType, TArrayRef<const char> payload) = 0;
+
+ /// returns key for messages of this protocol
+ virtual TBusKey GetKey(const TBusMessage*) {
+ return YBUS_KEYMIN;
+ }
+
+ /// default implementation of routing policy to allow overrides
+ virtual EMessageStatus GetDestination(const TBusClientSession* session, TBusMessage* mess, TBusLocator* locator, TNetAddr* addr);
+
+ /// codec for transport level compression
+ virtual NCodecs::TCodecPtr GetTransportCodec(void) const {
+ return NCodecs::ICodec::GetInstance("snappy");
+ }
+ };
+
+ class TBusSyncSourceSession: public TAtomicRefCount<TBusSyncSourceSession> {
+ friend class TBusMessageQueue;
+
+ public:
+ TBusSyncSourceSession(TIntrusivePtr< ::NBus::NPrivate::TBusSyncSourceSessionImpl> session);
+ ~TBusSyncSourceSession();
+
+ void Shutdown();
+
+ TBusMessage* SendSyncMessage(TBusMessage* pMessage, EMessageStatus& status, const TNetAddr* addr = nullptr);
+
+ int RegisterService(const char* hostname, TBusKey start = YBUS_KEYMIN, TBusKey end = YBUS_KEYMAX, EIpVersion ipVersion = EIP_VERSION_4);
+
+ int GetInFlight();
+
+ const TBusProtocol* GetProto() const;
+
+ const TBusClientSession* GetBusClientSessionWorkaroundDoNotUse() const; // It's for TLoadBalancedProtocol::GetDestination() function that really needs TBusClientSession* unlike all other protocols. Look at review 32425 (http://rb.yandex-team.ru/arc/r/32425/) for more information.
+ private:
+ TIntrusivePtr< ::NBus::NPrivate::TBusSyncSourceSessionImpl> Session;
+ };
+
+ using TBusSyncClientSessionPtr = TIntrusivePtr<TBusSyncSourceSession>;
+
+ ///////////////////////////////////////////////////////////////////
+ /// \brief Main message queue object, need one per application
+ class TBusMessageQueue: public TAtomicRefCount<TBusMessageQueue> {
+ /// allow mesage queue to be created only via factory
+ friend TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, NActor::TExecutorPtr executor, TBusLocator* locator, const char* name);
+ friend class ::NBus::NPrivate::TRemoteConnection;
+ friend struct ::NBus::NPrivate::TBusSessionImpl;
+ friend class ::NBus::NPrivate::TAcceptor;
+ friend struct ::NBus::TBusServerSession;
+
+ private:
+ const TBusQueueConfig Config;
+ TMutex Lock;
+ TList<TIntrusivePtr< ::NBus::NPrivate::TBusSessionImpl>> Sessions;
+ TSimpleIntrusivePtr<TBusLocator> Locator;
+ NPrivate::TScheduler Scheduler;
+
+ ::NActor::TExecutorPtr WorkQueue;
+
+ TAtomic Running;
+ TSystemEvent ShutdownComplete;
+
+ private:
+ /// constructor is protected, used NBus::CreateMessageQueue() to create a instance
+ TBusMessageQueue(const TBusQueueConfig& config, NActor::TExecutorPtr executor, TBusLocator* locator, const char* name);
+
+ public:
+ TString GetNameInternal() const;
+
+ ~TBusMessageQueue();
+
+ void Stop();
+ bool IsRunning();
+
+ public:
+ void EnqueueWork(TArrayRef< ::NActor::IWorkItem* const> w) {
+ WorkQueue->EnqueueWork(w);
+ }
+
+ ::NActor::TExecutor* GetExecutor() {
+ return WorkQueue.Get();
+ }
+
+ TString GetStatus(ui16 flags = YBUS_STATUS_CONNS) const;
+ // without sessions
+ NPrivate::TBusMessageQueueStatus GetStatusRecordInternal() const;
+ TString GetStatusSelf() const;
+ TString GetStatusSingleLine() const;
+
+ TBusLocator* GetLocator() const {
+ return Locator.Get();
+ }
+
+ TBusClientSessionPtr CreateSource(TBusProtocol* proto, IBusClientHandler* handler, const TBusClientSessionConfig& config, const TString& name = "");
+ TBusSyncClientSessionPtr CreateSyncSource(TBusProtocol* proto, const TBusClientSessionConfig& config, bool needReply = true, const TString& name = "");
+ TBusServerSessionPtr CreateDestination(TBusProtocol* proto, IBusServerHandler* hander, const TBusServerSessionConfig& config, const TString& name = "");
+ TBusServerSessionPtr CreateDestination(TBusProtocol* proto, IBusServerHandler* hander, const TBusServerSessionConfig& config, const TVector<TBindResult>& bindTo, const TString& name = "");
+
+ private:
+ void Destroy(TBusSession* session);
+ void Destroy(TBusSyncClientSessionPtr session);
+
+ public:
+ void Schedule(NPrivate::IScheduleItemAutoPtr i);
+
+ private:
+ void DestroyAllSessions();
+ void Add(TIntrusivePtr< ::NBus::NPrivate::TBusSessionImpl> session);
+ void Remove(TBusSession* session);
+ };
+
+ /////////////////////////////////////////////////////////////////
+ /// Factory methods to construct message queue
+ TBusMessageQueuePtr CreateMessageQueue(const char* name = "");
+ TBusMessageQueuePtr CreateMessageQueue(NActor::TExecutorPtr executor, const char* name = "");
+ TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, const char* name = "");
+ TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, TBusLocator* locator, const char* name = "");
+ TBusMessageQueuePtr CreateMessageQueue(const TBusQueueConfig& config, NActor::TExecutorPtr executor, TBusLocator* locator, const char* name = "");
+
+}