aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbabenko <babenko@yandex-team.com>2025-01-19 20:18:26 +0300
committerbabenko <babenko@yandex-team.com>2025-01-19 20:35:20 +0300
commit02877bfb2eed135564651ea971c8c26f6b173266 (patch)
tree61aa7eb7b655564d6af75ecd37ebd1bec9485d82
parent73afb75a4295ec99f0c8c49ab30fbff2b7a542fe (diff)
downloadydb-02877bfb2eed135564651ea971c8c26f6b173266.tar.gz
Implement local bypass for intra-process Bus communication
commit_hash:8332b96f0bc1eb8faef360bf472d47b08923a79d
-rw-r--r--yt/yt/core/bus/bus.h2
-rw-r--r--yt/yt/core/bus/client.h2
-rw-r--r--yt/yt/core/bus/private.h4
-rw-r--r--yt/yt/core/bus/tcp/client.cpp44
-rw-r--r--yt/yt/core/bus/tcp/config.cpp9
-rw-r--r--yt/yt/core/bus/tcp/config.h6
-rw-r--r--yt/yt/core/bus/tcp/connection.cpp231
-rw-r--r--yt/yt/core/bus/tcp/connection.h12
-rw-r--r--yt/yt/core/bus/tcp/dispatcher.h1
-rw-r--r--yt/yt/core/bus/tcp/dispatcher_impl.cpp67
-rw-r--r--yt/yt/core/bus/tcp/dispatcher_impl.h14
-rw-r--r--yt/yt/core/bus/tcp/local_bypass.cpp139
-rw-r--r--yt/yt/core/bus/tcp/local_bypass.h32
-rw-r--r--yt/yt/core/bus/tcp/private.h2
-rw-r--r--yt/yt/core/bus/tcp/server.cpp130
-rw-r--r--yt/yt/core/net/address.cpp8
-rw-r--r--yt/yt/core/rpc/bus/channel.cpp3
-rw-r--r--yt/yt/core/rpc/unittests/lib/common.h78
-rw-r--r--yt/yt/core/rpc/unittests/rpc_ut.cpp29
-rw-r--r--yt/yt/core/ya.make1
20 files changed, 625 insertions, 189 deletions
diff --git a/yt/yt/core/bus/bus.h b/yt/yt/core/bus/bus.h
index 6c0ded55a4..4c08a78440 100644
--- a/yt/yt/core/bus/bus.h
+++ b/yt/yt/core/bus/bus.h
@@ -126,7 +126,7 @@ struct IBus
/*!
* Does not block -- termination typically happens in background.
* It is safe to call this method multiple times.
- * On terminated the instance is no longer usable.
+ * After termination the instance is no longer usable.
* \note Thread affinity: any.
*/
diff --git a/yt/yt/core/bus/client.h b/yt/yt/core/bus/client.h
index 8c23346b95..fa12665fc2 100644
--- a/yt/yt/core/bus/client.h
+++ b/yt/yt/core/bus/client.h
@@ -24,7 +24,7 @@ struct IBusClient
//! Typically used for logging.
virtual const std::string& GetEndpointDescription() const = 0;
- //! Returns the bus' endpoint attributes.
+ //! Returns the bus endpoint attributes.
//! Typically used for constructing errors.
virtual const NYTree::IAttributeDictionary& GetEndpointAttributes() const = 0;
diff --git a/yt/yt/core/bus/private.h b/yt/yt/core/bus/private.h
index 1f8d52f96a..e334303fdf 100644
--- a/yt/yt/core/bus/private.h
+++ b/yt/yt/core/bus/private.h
@@ -15,6 +15,10 @@ namespace NYT::NBus {
YT_DEFINE_GLOBAL(const NLogging::TLogger, BusLogger, "Bus");
YT_DEFINE_GLOBAL(const NProfiling::TProfiler, BusProfiler, "/bus");
+////////////////////////////////////////////////////////////////////////////////
+
+DECLARE_REFCOUNTED_STRUCT(IMessageHandler)
+
using TConnectionId = TGuid;
using TPacketId = TGuid;
diff --git a/yt/yt/core/bus/tcp/client.cpp b/yt/yt/core/bus/tcp/client.cpp
index d2ddf6b24e..5dd0117bee 100644
--- a/yt/yt/core/bus/tcp/client.cpp
+++ b/yt/yt/core/bus/tcp/client.cpp
@@ -146,20 +146,9 @@ public:
: Config_(std::move(config))
, PacketTranscoderFactory_(packetTranscoderFactory)
, MemoryUsageTracker_(std::move(memoryUsageTracker))
- {
- if (Config_->Address) {
- EndpointDescription_ = *Config_->Address;
- } else if (Config_->UnixDomainSocketPath) {
- EndpointDescription_ = Format("unix://%v", *Config_->UnixDomainSocketPath);
- }
-
- EndpointAttributes_ = ConvertToAttributes(BuildYsonStringFluently()
- .BeginMap()
- .Item("address").Value(EndpointDescription_)
- .Item("encryption_mode").Value(Config_->EncryptionMode)
- .Item("verification_mode").Value(Config_->VerificationMode)
- .EndMap());
- }
+ , EndpointDescription_(MakeEndpointDescription(Config_))
+ , EndpointAttributes_(MakeEndpointAttributes(Config_, EndpointDescription_))
+ { }
const std::string& GetEndpointDescription() const override
{
@@ -215,13 +204,32 @@ public:
private:
const TBusClientConfigPtr Config_;
-
IPacketTranscoderFactory* const PacketTranscoderFactory_;
-
const IMemoryUsageTrackerPtr MemoryUsageTracker_;
+ const std::string EndpointDescription_;
+ const IAttributeDictionaryPtr EndpointAttributes_;
- std::string EndpointDescription_;
- IAttributeDictionaryPtr EndpointAttributes_;
+ static std::string MakeEndpointDescription(const TBusClientConfigPtr& config)
+ {
+ if (config->Address) {
+ return *config->Address;
+ } else if (config->UnixDomainSocketPath) {
+ return Format("unix://%v", *config->UnixDomainSocketPath);
+ }
+ YT_ABORT();
+ }
+
+ static IAttributeDictionaryPtr MakeEndpointAttributes(
+ const TBusClientConfigPtr& config,
+ const std::string& endpointDescription)
+ {
+ return ConvertToAttributes(BuildYsonStringFluently()
+ .BeginMap()
+ .Item("address").Value(endpointDescription)
+ .Item("encryption_mode").Value(config->EncryptionMode)
+ .Item("verification_mode").Value(config->VerificationMode)
+ .EndMap());
+ }
};
////////////////////////////////////////////////////////////////////////////////
diff --git a/yt/yt/core/bus/tcp/config.cpp b/yt/yt/core/bus/tcp/config.cpp
index 18e8371bf0..f67c6fba6c 100644
--- a/yt/yt/core/bus/tcp/config.cpp
+++ b/yt/yt/core/bus/tcp/config.cpp
@@ -52,6 +52,9 @@ void TTcpDispatcherConfig::Register(TRegistrar registrar)
registrar.Parameter("bus_certs_directory_path", &TThis::BusCertsDirectoryPath)
.Default();
+
+ registrar.Parameter("enable_local_bypass", &TThis::EnableLocalBypass)
+ .Default(false);
}
TTcpDispatcherConfigPtr TTcpDispatcherConfig::ApplyDynamic(
@@ -63,6 +66,7 @@ TTcpDispatcherConfigPtr TTcpDispatcherConfig::ApplyDynamic(
UpdateYsonStructField(mergedConfig->Networks, dynamicConfig->Networks);
UpdateYsonStructField(mergedConfig->MultiplexingBands, dynamicConfig->MultiplexingBands);
UpdateYsonStructField(mergedConfig->BusCertsDirectoryPath, dynamicConfig->BusCertsDirectoryPath);
+ UpdateYsonStructField(mergedConfig->EnableLocalBypass, dynamicConfig->EnableLocalBypass);
mergedConfig->Postprocess();
return mergedConfig;
}
@@ -89,6 +93,9 @@ void TTcpDispatcherDynamicConfig::Register(TRegistrar registrar)
registrar.Parameter("bus_certs_directory_path", &TThis::BusCertsDirectoryPath)
.Default();
+
+ registrar.Parameter("enable_local_bypass", &TThis::EnableLocalBypass)
+ .Default();
}
////////////////////////////////////////////////////////////////////////////////
@@ -141,6 +148,8 @@ void TBusConfig::Register(TRegistrar registrar)
.Default(true);
registrar.Parameter("generate_checksums", &TThis::GenerateChecksums)
.Default(true);
+ registrar.Parameter("enable_local_bypass", &TThis::EnableLocalBypass)
+ .Default(true);
registrar.Parameter("encryption_mode", &TThis::EncryptionMode)
.Default(EEncryptionMode::Optional);
registrar.Parameter("verification_mode", &TThis::VerificationMode)
diff --git a/yt/yt/core/bus/tcp/config.h b/yt/yt/core/bus/tcp/config.h
index 914035e2d9..491612fabb 100644
--- a/yt/yt/core/bus/tcp/config.h
+++ b/yt/yt/core/bus/tcp/config.h
@@ -50,6 +50,8 @@ public:
//! Used to store TLS/SSL certificate files.
std::optional<TString> BusCertsDirectoryPath;
+ bool EnableLocalBypass;
+
TTcpDispatcherConfigPtr ApplyDynamic(const TTcpDispatcherDynamicConfigPtr& dynamicConfig) const;
REGISTER_YSON_STRUCT(TTcpDispatcherConfig);
@@ -78,7 +80,7 @@ public:
//! Used to store TLS/SSL certificate files.
std::optional<TString> BusCertsDirectoryPath;
- static void Setup(auto&& registrar);
+ std::optional<bool> EnableLocalBypass;
REGISTER_YSON_STRUCT(TTcpDispatcherDynamicConfig);
@@ -107,6 +109,8 @@ public:
bool VerifyChecksums;
bool GenerateChecksums;
+ bool EnableLocalBypass;
+
// Ssl options.
EEncryptionMode EncryptionMode;
EVerificationMode VerificationMode;
diff --git a/yt/yt/core/bus/tcp/connection.cpp b/yt/yt/core/bus/tcp/connection.cpp
index 2ab85dcca0..6807a9d8d4 100644
--- a/yt/yt/core/bus/tcp/connection.cpp
+++ b/yt/yt/core/bus/tcp/connection.cpp
@@ -1,7 +1,7 @@
#include "connection.h"
#include "config.h"
-#include "server.h"
+#include "local_bypass.h"
#include "dispatcher_impl.h"
#include "ssl_context.h"
#include "ssl_helpers.h"
@@ -125,12 +125,9 @@ TTcpConnection::TTcpConnection(
, UnixDomainSocketPath_(unixDomainSocketPath)
, Handler_(std::move(handler))
, Poller_(std::move(poller))
- , LoggingTag_(Format("ConnectionId: %v, ConnectionType: %v, RemoteAddress: %v, EncryptionMode: %v, VerificationMode: %v",
+ , LoggingTag_(Format("ConnectionId: %v, Endpoint: %v",
Id_,
- ConnectionType_,
- EndpointDescription_,
- Config_->EncryptionMode,
- Config_->VerificationMode))
+ EndpointDescription_))
, Logger(BusLogger().WithRawTag(LoggingTag_))
, GenerateChecksums_(Config_->GenerateChecksums)
, Socket_(socket)
@@ -196,15 +193,24 @@ void TTcpConnection::Close()
EncodedFragments_.clear();
+ ILocalMessageHandlerPtr localBypassHandler;
{
auto guard = Guard(Lock_);
FlushStatistics();
+ localBypassHandler = LocalBypassHandler_;
+ }
+
+ if (localBypassHandler) {
+ localBypassHandler->UnsubscribeTerminated(LocalBypassTerminatedCallback_);
}
}
void TTcpConnection::Start()
{
- YT_LOG_DEBUG("Starting TCP connection");
+ YT_LOG_DEBUG("Starting TCP connection (ConnectionType: %v, EncryptionMode: %v, VerificationMode: %v)",
+ ConnectionType_,
+ Config_->EncryptionMode,
+ Config_->VerificationMode);
// Offline in PendingControl_ prevents retrying events until end of Open().
YT_VERIFY(Any(static_cast<EPollControl>(PendingControl_.load()) & EPollControl::Offline));
@@ -480,17 +486,26 @@ void TTcpConnection::OnAddressResolveFinished(const TErrorOr<TNetworkAddress>& r
}
TNetworkAddress address(result.Value(), Port_);
- OnAddressResolved(address);
-
YT_LOG_DEBUG("Connection network address resolved (Address: %v, NetworkName: %v)",
address,
NetworkName_);
+
+ OnAddressResolved(address);
}
void TTcpConnection::OnAddressResolved(const TNetworkAddress& address)
{
State_ = EState::Opening;
+
SetupNetwork(address);
+
+ if (Config_->EnableLocalBypass) {
+ if (auto localHandler = TTcpDispatcher::TImpl::Get()->FindLocalBypassMessageHandler(address)) {
+ InitLocalBypass(std::move(localHandler), address);
+ return;
+ }
+ }
+
ConnectSocket(address);
}
@@ -609,6 +624,61 @@ int TTcpConnection::GetSocketPort()
}
}
+void TTcpConnection::InitLocalBypass(
+ ILocalMessageHandlerPtr localBypassHandler,
+ const NNet::TNetworkAddress& address)
+{
+ {
+ auto guard = Guard(Lock_);
+
+ if (State_ != EState::Opening) {
+ return;
+ }
+
+ State_ = EState::Open;
+
+ LocalBypassHandler_ = std::move(localBypassHandler);
+
+ LocalBypassReplyBus_ = CreateLocalBypassReplyBus(
+ address,
+ LocalBypassHandler_,
+ Handler_);
+
+ LocalBypassTerminatedCallback_ = BIND(&TTcpConnection::OnLocalBypassHandlerTerminated, MakeWeak(this));
+ LocalBypassHandler_->SubscribeTerminated(LocalBypassTerminatedCallback_);
+
+ UpdateConnectionCount(+1);
+
+ LocalBypassActive_.store(true, std::memory_order::release);
+
+ YT_LOG_DEBUG("Local bypass activated");
+ }
+
+ ReadyPromise_.TrySet();
+
+ FlushQueuedMessagesToLocalBypass();
+}
+
+void TTcpConnection::OnLocalBypassHandlerTerminated(const TError& error)
+{
+ Abort(error);
+}
+
+void TTcpConnection::FlushQueuedMessagesToLocalBypass()
+{
+ QueuedMessages_.DequeueAll(
+ /*reverse*/ true,
+ [&] (auto& queuedMessage) {
+ // Log first to avoid producing weird traces.
+ YT_LOG_DEBUG("Queued message sent via local bypass (PacketId: %v)", queuedMessage.PacketId);
+ LocalBypassHandler_->HandleMessage(std::move(queuedMessage.Message), LocalBypassReplyBus_);
+
+ if (queuedMessage.Promise) {
+ queuedMessage.Promise.TrySet();
+ }
+ });
+}
+
void TTcpConnection::ConnectSocket(const TNetworkAddress& address)
{
auto dialer = CreateAsyncDialer(
@@ -734,37 +804,13 @@ TFuture<void> TTcpConnection::Send(TSharedRefArray message, const TSendOptions&
}
}
- TQueuedMessage queuedMessage(std::move(message), options);
- auto promise = queuedMessage.Promise;
- auto pendingOutPayloadBytes = PendingOutPayloadBytes_.fetch_add(queuedMessage.PayloadSize);
-
- // Log first to avoid producing weird traces.
- YT_LOG_DEBUG("Outcoming message enqueued (PacketId: %v, PendingOutPayloadBytes: %v)",
- queuedMessage.PacketId,
- pendingOutPayloadBytes);
-
- if (LastIncompleteWriteTime_ == std::numeric_limits<NProfiling::TCpuInstant>::max()) {
- // Arm stall detection.
- LastIncompleteWriteTime_ = NProfiling::GetCpuInstant();
- }
-
- QueuedMessages_.Enqueue(std::move(queuedMessage));
-
- // Wake up the event processing if needed.
- {
- auto previousPendingControl = static_cast<EPollControl>(PendingControl_.fetch_or(static_cast<ui64>(EPollControl::Write)));
- if (None(previousPendingControl)) {
- YT_LOG_TRACE("Retrying event processing for Send");
- Poller_->Retry(this);
- }
- }
-
- // Double-check the state not to leave any dangling outcoming messages.
- if (State_.load() == EState::Closed) {
- DiscardOutcomingMessages();
+ // Fast path for bypass.
+ if (LocalBypassActive_.load(std::memory_order::acquire)) {
+ return SendViaLocalBypass(std::move(message), options);
}
- return promise;
+ // Slow path.
+ return SendViaSocket(std::move(message), options);
}
void TTcpConnection::SetTosLevel(TTosLevel tosLevel)
@@ -1276,6 +1322,57 @@ bool TTcpConnection::OnHandshakePacketReceived()
return true;
}
+TFuture<void> TTcpConnection::SendViaSocket(TSharedRefArray message, const TSendOptions& options)
+{
+ TQueuedMessage queuedMessage(std::move(message), options);
+ auto promise = queuedMessage.Promise;
+ auto pendingOutPayloadBytes = PendingOutPayloadBytes_.fetch_add(queuedMessage.PayloadSize);
+
+ // Log first to avoid producing weird traces.
+ YT_LOG_DEBUG("Outcoming message enqueued (PacketId: %v, PendingOutPayloadBytes: %v)",
+ queuedMessage.PacketId,
+ pendingOutPayloadBytes);
+
+ if (LastIncompleteWriteTime_ == std::numeric_limits<NProfiling::TCpuInstant>::max()) {
+ // Arm stall detection.
+ LastIncompleteWriteTime_ = NProfiling::GetCpuInstant();
+ }
+
+ QueuedMessages_.Enqueue(std::move(queuedMessage));
+
+ // Wake up the event processing if needed.
+ {
+ auto previousPendingControl = static_cast<EPollControl>(PendingControl_.fetch_or(static_cast<ui64>(EPollControl::Write)));
+ if (None(previousPendingControl)) {
+ YT_LOG_TRACE("Retrying event processing for Send");
+ Poller_->Retry(this);
+ }
+ }
+
+ // Double-check the state not to leave any dangling outcoming messages.
+ if (State_.load() == EState::Closed) {
+ DiscardOutcomingMessages();
+ }
+
+ // Another double-check to prevent a race with bypass activation.
+ if (LocalBypassActive_.load(std::memory_order::acquire)) {
+ FlushQueuedMessagesToLocalBypass();
+ }
+
+ return promise;
+}
+
+TFuture<void> TTcpConnection::SendViaLocalBypass(TSharedRefArray message, const TSendOptions& /*options*/)
+{
+ // Log first to avoid producing weird traces.
+ YT_LOG_DEBUG("Outcoming message sent via local bypass");
+
+ LocalBypassHandler_->HandleMessage(std::move(message), LocalBypassReplyBus_);
+
+ // No delivery tracking for local bypass.
+ return VoidFuture;
+}
+
TTcpConnection::TPacket* TTcpConnection::EnqueuePacket(
EPacketType type,
EPacketFlags flags,
@@ -1657,38 +1754,36 @@ void TTcpConnection::OnTerminate()
void TTcpConnection::ProcessQueuedMessages()
{
- auto messages = QueuedMessages_.DequeueAll();
-
- for (auto it = messages.rbegin(); it != messages.rend(); ++it) {
- auto& queuedMessage = *it;
-
- auto packetId = queuedMessage.PacketId;
- auto flags = queuedMessage.Options.TrackingLevel == EDeliveryTrackingLevel::Full
- ? EPacketFlags::RequestAcknowledgement
- : EPacketFlags::None;
-
- auto* packet = EnqueuePacket(
- EPacketType::Message,
- flags,
- GenerateChecksums_ ? queuedMessage.Options.ChecksummedPartCount : 0,
- packetId,
- std::move(queuedMessage.Message),
- queuedMessage.PayloadSize);
-
- packet->Promise = queuedMessage.Promise;
- if (queuedMessage.Options.EnableSendCancelation) {
- packet->EnableCancel(MakeStrong(this));
- }
+ QueuedMessages_.DequeueAll(
+ /*reverse*/ true,
+ [&] (auto& queuedMessage) {
+ auto packetId = queuedMessage.PacketId;
+ auto flags = queuedMessage.Options.TrackingLevel == EDeliveryTrackingLevel::Full
+ ? EPacketFlags::RequestAcknowledgement
+ : EPacketFlags::None;
+
+ auto* packet = EnqueuePacket(
+ EPacketType::Message,
+ flags,
+ GenerateChecksums_ ? queuedMessage.Options.ChecksummedPartCount : 0,
+ packetId,
+ std::move(queuedMessage.Message),
+ queuedMessage.PayloadSize);
+
+ packet->Promise = queuedMessage.Promise;
+ if (queuedMessage.Options.EnableSendCancelation) {
+ packet->EnableCancel(MakeStrong(this));
+ }
- YT_LOG_DEBUG("Outcoming message dequeued (PacketId: %v, PacketSize: %v, Flags: %v)",
- packetId,
- packet->PacketSize,
- flags);
+ YT_LOG_DEBUG("Outcoming message dequeued (PacketId: %v, PacketSize: %v, Flags: %v)",
+ packetId,
+ packet->PacketSize,
+ flags);
- if (queuedMessage.Promise && !queuedMessage.Options.EnableSendCancelation && !Any(flags & EPacketFlags::RequestAcknowledgement)) {
- queuedMessage.Promise.TrySet();
- }
- }
+ if (queuedMessage.Promise && !queuedMessage.Options.EnableSendCancelation && None(flags & EPacketFlags::RequestAcknowledgement)) {
+ queuedMessage.Promise.TrySet();
+ }
+ });
}
void TTcpConnection::DiscardOutcomingMessages()
diff --git a/yt/yt/core/bus/tcp/connection.h b/yt/yt/core/bus/tcp/connection.h
index a591566428..3d5447263f 100644
--- a/yt/yt/core/bus/tcp/connection.h
+++ b/yt/yt/core/bus/tcp/connection.h
@@ -284,6 +284,11 @@ private:
size_t MaxFragmentsPerWrite_ = 256;
+ ILocalMessageHandlerPtr LocalBypassHandler_;
+ IBusPtr LocalBypassReplyBus_;
+ TCallback<void(const TError&)> LocalBypassTerminatedCallback_;
+ std::atomic<bool> LocalBypassActive_ = false;
+
void Open(TGuard<NThreading::TSpinLock>& guard);
void Close();
void CloseSslSession(ESslState newSslState);
@@ -296,6 +301,10 @@ private:
int GetSocketPort();
+ void InitLocalBypass(ILocalMessageHandlerPtr localBypassHandler, const NNet::TNetworkAddress& address);
+ void OnLocalBypassHandlerTerminated(const TError& error);
+ void FlushQueuedMessagesToLocalBypass();
+
void ConnectSocket(const NNet::TNetworkAddress& address);
void OnDialerFinished(const TErrorOr<TFileDescriptor>& fdOrError);
@@ -323,6 +332,9 @@ private:
bool OnHandshakePacketReceived();
bool OnSslAckPacketReceived();
+ TFuture<void> SendViaSocket(TSharedRefArray message, const TSendOptions& options);
+ TFuture<void> SendViaLocalBypass(TSharedRefArray message, const TSendOptions& options);
+
TPacket* EnqueuePacket(
EPacketType type,
EPacketFlags flags,
diff --git a/yt/yt/core/bus/tcp/dispatcher.h b/yt/yt/core/bus/tcp/dispatcher.h
index b39b015c88..1e6ba8fd97 100644
--- a/yt/yt/core/bus/tcp/dispatcher.h
+++ b/yt/yt/core/bus/tcp/dispatcher.h
@@ -76,6 +76,7 @@ private:
friend class TTcpBusServerBase;
template <class TServer>
friend class TTcpBusServerProxy;
+ friend class TCompositeBusServer;
class TImpl;
const TIntrusivePtr<TImpl> Impl_;
diff --git a/yt/yt/core/bus/tcp/dispatcher_impl.cpp b/yt/yt/core/bus/tcp/dispatcher_impl.cpp
index 5a6334ddbc..c5cb88624b 100644
--- a/yt/yt/core/bus/tcp/dispatcher_impl.cpp
+++ b/yt/yt/core/bus/tcp/dispatcher_impl.cpp
@@ -79,7 +79,7 @@ IPollerPtr TTcpDispatcher::TImpl::GetOrCreatePoller(
const TString& threadNamePrefix)
{
{
- auto guard = ReaderGuard(PollerLock_);
+ auto guard = ReaderGuard(PollersLock_);
if (*pollerPtr) {
return *pollerPtr;
}
@@ -87,13 +87,14 @@ IPollerPtr TTcpDispatcher::TImpl::GetOrCreatePoller(
IPollerPtr poller;
{
- auto guard = WriterGuard(PollerLock_);
+ auto guard = WriterGuard(PollersLock_);
+ auto config = Config_.Acquire();
if (!*pollerPtr) {
if (isXfer) {
*pollerPtr = CreateThreadPoolPoller(
- Config_->ThreadPoolSize,
+ config->ThreadPoolSize,
threadNamePrefix,
- Config_->ThreadPoolPollingPeriod);
+ config->ThreadPoolPollingPeriod);
} else {
*pollerPtr = CreateThreadPoolPoller(/*threadCount*/ 1, threadNamePrefix);
}
@@ -181,13 +182,13 @@ IPollerPtr TTcpDispatcher::TImpl::GetXferPoller()
void TTcpDispatcher::TImpl::Configure(const TTcpDispatcherConfigPtr& config)
{
{
- auto guard = WriterGuard(PollerLock_);
+ auto guard = WriterGuard(PollersLock_);
- Config_ = config;
+ Config_.Store(config);
if (XferPoller_) {
- XferPoller_->SetThreadCount(Config_->ThreadPoolSize);
- XferPoller_->SetPollingPeriod(Config_->ThreadPoolPollingPeriod);
+ XferPoller_->SetThreadCount(config->ThreadPoolSize);
+ XferPoller_->SetPollingPeriod(config->ThreadPoolPollingPeriod);
}
}
@@ -259,12 +260,7 @@ void TTcpDispatcher::TImpl::CollectSensors(ISensorWriter* writer)
}
});
- TTcpDispatcherConfigPtr config;
- {
- auto guard = ReaderGuard(PollerLock_);
- config = Config_;
- }
-
+ auto config = Config_.Acquire();
if (config->NetworkBandwidth) {
writer->AddGauge("/network_bandwidth_limit", *config->NetworkBandwidth);
}
@@ -348,8 +344,47 @@ void TTcpDispatcher::TImpl::OnPeriodicCheck()
std::optional<TString> TTcpDispatcher::TImpl::GetBusCertsDirectoryPath() const
{
- auto guard = ReaderGuard(PollerLock_);
- return Config_->BusCertsDirectoryPath;
+ return Config_.Acquire()->BusCertsDirectoryPath;
+}
+
+void TTcpDispatcher::TImpl::RegisterLocalMessageHandler(int port, const ILocalMessageHandlerPtr& handler)
+{
+ {
+ auto guard = WriterGuard(LocalMessageHandlersLock_);
+ if (!LocalMessageHandlers_.emplace(port, handler).second) {
+ THROW_ERROR_EXCEPTION("Local message handler is already registered for port %v",
+ port);
+ }
+ }
+
+ YT_LOG_INFO("Local message handler registered (Port: %v)",
+ port);
+}
+
+void TTcpDispatcher::TImpl::UnregisterLocalMessageHandler(int port)
+{
+ {
+ auto guard = WriterGuard(LocalMessageHandlersLock_);
+ LocalMessageHandlers_.erase(port);
+ }
+
+ YT_LOG_INFO("Local message handler unregistered (Port: %v)",
+ port);
+}
+
+ILocalMessageHandlerPtr TTcpDispatcher::TImpl::FindLocalBypassMessageHandler(const TNetworkAddress& address)
+{
+ if (!TAddressResolver::Get()->IsLocalAddress(address)) {
+ return nullptr;
+ }
+
+ if (!Config_.Acquire()->EnableLocalBypass) {
+ YT_LOG_INFO("XXX disabled %v", address);
+ return nullptr;
+ }
+
+ auto guard = ReaderGuard(LocalMessageHandlersLock_);
+ return GetOrDefault(LocalMessageHandlers_, address.GetPort());
}
////////////////////////////////////////////////////////////////////////////////
diff --git a/yt/yt/core/bus/tcp/dispatcher_impl.h b/yt/yt/core/bus/tcp/dispatcher_impl.h
index 09ed61d50d..054607bc3e 100644
--- a/yt/yt/core/bus/tcp/dispatcher_impl.h
+++ b/yt/yt/core/bus/tcp/dispatcher_impl.h
@@ -16,6 +16,8 @@
#include <library/cpp/yt/threading/rw_spin_lock.h>
#include <library/cpp/yt/threading/fork_aware_rw_spin_lock.h>
+#include <library/cpp/yt/memory/atomic_intrusive_ptr.h>
+
#include <atomic>
namespace NYT::NBus {
@@ -57,6 +59,10 @@ public:
std::optional<TString> GetBusCertsDirectoryPath() const;
+ void RegisterLocalMessageHandler(int port, const ILocalMessageHandlerPtr& handler);
+ void UnregisterLocalMessageHandler(int port);
+ ILocalMessageHandlerPtr FindLocalBypassMessageHandler(const NNet::TNetworkAddress& address);
+
private:
friend class TTcpDispatcher;
@@ -73,8 +79,9 @@ private:
std::vector<TTcpConnectionPtr> GetConnections();
void BuildOrchid(NYson::IYsonConsumer* consumer);
- YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, PollerLock_);
- TTcpDispatcherConfigPtr Config_ = New<TTcpDispatcherConfig>();
+ TAtomicIntrusivePtr<TTcpDispatcherConfig> Config_{New<TTcpDispatcherConfig>()};
+
+ YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, PollersLock_);
NConcurrency::IThreadPoolPollerPtr AcceptorPoller_;
NConcurrency::IThreadPoolPollerPtr XferPoller_;
@@ -106,6 +113,9 @@ private:
};
TEnumIndexedArray<EMultiplexingBand, TBandDescriptor> BandToDescriptor_;
+
+ YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, LocalMessageHandlersLock_);
+ THashMap<int, ILocalMessageHandlerPtr> LocalMessageHandlers_;
};
////////////////////////////////////////////////////////////////////////////////
diff --git a/yt/yt/core/bus/tcp/local_bypass.cpp b/yt/yt/core/bus/tcp/local_bypass.cpp
new file mode 100644
index 0000000000..911cbd805b
--- /dev/null
+++ b/yt/yt/core/bus/tcp/local_bypass.cpp
@@ -0,0 +1,139 @@
+#include "local_bypass.h"
+
+#include <yt/yt/core/bus/bus.h>
+
+#include <yt/yt/core/ytree/fluent.h>
+
+#include <yt/yt/core/net/address.h>
+
+namespace NYT::NBus {
+
+using namespace NYTree;
+
+////////////////////////////////////////////////////////////////////////////////
+
+class TLocalBypassReplyBus
+ : public IBus
+{
+public:
+ TLocalBypassReplyBus(
+ const NNet::TNetworkAddress& localAddress,
+ ILocalMessageHandlerPtr serverHandler,
+ IMessageHandlerPtr clientHandler)
+ : LocalAddress_(localAddress)
+ , ServerHandler_(std::move(serverHandler))
+ , ClientHandler_(std::move(clientHandler))
+ , EndpointDescription_(Format("local-bypass:%v", LocalAddress_))
+ , EndpointAttributes_(ConvertToAttributes(BuildYsonStringFluently()
+ .BeginMap()
+ .Item("local_bypass").Value(true)
+ .Item("address").Value(EndpointDescription_)
+ .EndMap()))
+ , ServerHandlerTerminatedCallback_(BIND(&TLocalBypassReplyBus::OnServerHandlerTerminated, MakeWeak(this)))
+ {
+ ServerHandler_->SubscribeTerminated(ServerHandlerTerminatedCallback_);
+ }
+
+ ~TLocalBypassReplyBus()
+ {
+ ServerHandler_->UnsubscribeTerminated(ServerHandlerTerminatedCallback_);
+ }
+
+ // IBus overrides.
+ const std::string& GetEndpointDescription() const override
+ {
+ return EndpointDescription_;
+ }
+
+ const NYTree::IAttributeDictionary& GetEndpointAttributes() const override
+ {
+ return *EndpointAttributes_;
+ }
+
+ TBusNetworkStatistics GetNetworkStatistics() const override
+ {
+ return {};
+ }
+
+ const std::string& GetEndpointAddress() const override
+ {
+ return EndpointDescription_;
+ }
+
+ const NNet::TNetworkAddress& GetEndpointNetworkAddress() const override
+ {
+ return LocalAddress_;
+ }
+
+ bool IsEndpointLocal() const override
+ {
+ return true;
+ }
+
+ bool IsEncrypted() const override
+ {
+ return false;
+ }
+
+ TFuture<void> GetReadyFuture() const override
+ {
+ return VoidFuture;
+ }
+
+ TFuture<void> Send(TSharedRefArray message, const NBus::TSendOptions& /*options*/) override
+ {
+ ClientHandler_->HandleMessage(std::move(message), /*replyBus*/ nullptr);
+ return VoidFuture;
+ }
+
+ void SetTosLevel(TTosLevel /*tosLevel*/) override
+ { }
+
+ void Terminate(const TError& error) override
+ {
+ TerminatedList_.Fire(error);
+ }
+
+ void SubscribeTerminated(const TCallback<void(const TError&)>& callback) override
+ {
+ TerminatedList_.Subscribe(callback);
+ }
+
+ void UnsubscribeTerminated(const TCallback<void(const TError&)>& callback) override
+ {
+ TerminatedList_.Unsubscribe(callback);
+ }
+
+private:
+ const NNet::TNetworkAddress LocalAddress_;
+ const ILocalMessageHandlerPtr ServerHandler_;
+ const IMessageHandlerPtr ClientHandler_;
+ const std::string EndpointDescription_;
+ const IAttributeDictionaryPtr EndpointAttributes_;
+ const TCallback<void(const TError&)> ServerHandlerTerminatedCallback_;
+
+ TSingleShotCallbackList<void(const TError&)> TerminatedList_;
+
+ void OnServerHandlerTerminated(const TError& error)
+ {
+ TerminatedList_.Fire(error);
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+IBusPtr CreateLocalBypassReplyBus(
+ const NNet::TNetworkAddress& localAddress,
+ ILocalMessageHandlerPtr serverHandler,
+ IMessageHandlerPtr clientHandler)
+{
+ return New<TLocalBypassReplyBus>(
+ localAddress,
+ std::move(serverHandler),
+ std::move(clientHandler));
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NBus
+
diff --git a/yt/yt/core/bus/tcp/local_bypass.h b/yt/yt/core/bus/tcp/local_bypass.h
new file mode 100644
index 0000000000..221e7045e4
--- /dev/null
+++ b/yt/yt/core/bus/tcp/local_bypass.h
@@ -0,0 +1,32 @@
+#pragma once
+
+#include "private.h"
+
+#include <yt/yt/core/bus/bus.h>
+
+#include <yt/yt/core/net/public.h>
+
+namespace NYT::NBus {
+
+////////////////////////////////////////////////////////////////////////////////
+
+//! An in-process variant of IMessageHandler that is used implement local bypass.
+struct ILocalMessageHandler
+ : public IMessageHandler
+{
+ DECLARE_INTERFACE_SIGNAL(void(const TError&), Terminated);
+};
+
+DEFINE_REFCOUNTED_TYPE(ILocalMessageHandler)
+
+////////////////////////////////////////////////////////////////////////////////
+
+IBusPtr CreateLocalBypassReplyBus(
+ const NNet::TNetworkAddress& localAddress,
+ ILocalMessageHandlerPtr serverHandler,
+ IMessageHandlerPtr clientHandler);
+
+////////////////////////////////////////////////////////////////////////////////
+
+} // namespace NYT::NBus
+
diff --git a/yt/yt/core/bus/tcp/private.h b/yt/yt/core/bus/tcp/private.h
index a7114a203c..57e627eea2 100644
--- a/yt/yt/core/bus/tcp/private.h
+++ b/yt/yt/core/bus/tcp/private.h
@@ -6,6 +6,8 @@ namespace NYT::NBus {
////////////////////////////////////////////////////////////////////////////////
+DECLARE_REFCOUNTED_STRUCT(ILocalMessageHandler)
+
constexpr ui32 HandshakeMessageSignature = 0x68737562;
////////////////////////////////////////////////////////////////////////////////
diff --git a/yt/yt/core/bus/tcp/server.cpp b/yt/yt/core/bus/tcp/server.cpp
index a447b0a35f..a9fcf0c8a0 100644
--- a/yt/yt/core/bus/tcp/server.cpp
+++ b/yt/yt/core/bus/tcp/server.cpp
@@ -1,5 +1,6 @@
#include "server.h"
#include "config.h"
+#include "local_bypass.h"
#include "server.h"
#include "connection.h"
#include "dispatcher_impl.h"
@@ -50,18 +51,12 @@ public:
, Handler_(std::move(handler))
, PacketTranscoderFactory_(std::move(packetTranscoderFactory))
, MemoryUsageTracker_(std::move(memoryUsageTracker))
+ , Logger(MakeLogger(Config_))
{
YT_VERIFY(Config_);
YT_VERIFY(Poller_);
YT_VERIFY(Handler_);
YT_VERIFY(MemoryUsageTracker_);
-
- if (Config_->Port) {
- Logger.AddTag("ServerPort: %v", *Config_->Port);
- }
- if (Config_->UnixDomainSocketPath) {
- Logger.AddTag("UnixDomainSocketPath: %v", *Config_->UnixDomainSocketPath);
- }
}
~TTcpBusServerBase()
@@ -72,17 +67,23 @@ public:
void Start()
{
OpenServerSocket();
+
if (!Poller_->TryRegister(this)) {
CloseServerSocket();
THROW_ERROR_EXCEPTION("Cannot register server pollable");
}
+
ArmPoller();
+
+ YT_LOG_INFO("Bus server started");
}
TFuture<void> Stop()
{
YT_LOG_INFO("Stopping Bus server");
+
UnarmPoller();
+
return Poller_->Unregister(this).Apply(BIND([this, this_ = MakeStrong(this)] {
YT_LOG_INFO("Bus server stopped");
}));
@@ -94,12 +95,12 @@ public:
return Logger.GetTag();
}
- void OnEvent(EPollControl /*control*/) override
+ void OnEvent(EPollControl /*control*/) final
{
OnAccept();
}
- void OnShutdown() override
+ void OnShutdown() final
{
{
auto guard = Guard(ControlSpinLock_);
@@ -123,23 +124,32 @@ protected:
const TBusServerConfigPtr Config_;
const IPollerPtr Poller_;
const IMessageHandlerPtr Handler_;
-
IPacketTranscoderFactory* const PacketTranscoderFactory_;
-
const IMemoryUsageTrackerPtr MemoryUsageTracker_;
+ const NLogging::TLogger Logger;
+
YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, ControlSpinLock_);
SOCKET ServerSocket_ = INVALID_SOCKET;
YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, ConnectionsSpinLock_);
THashSet<TTcpConnectionPtr> Connections_;
- NLogging::TLogger Logger = BusLogger();
-
virtual void CreateServerSocket() = 0;
-
virtual void InitClientSocket(SOCKET clientSocket) = 0;
+ static NLogging::TLogger MakeLogger(const TBusServerConfigPtr& config)
+ {
+ auto logger = BusLogger();
+ if (config->Port) {
+ logger.AddTag("ServerPort: %v", *config->Port);
+ }
+ if (config->UnixDomainSocketPath) {
+ logger.AddTag("UnixDomainSocketPath: %v", *config->UnixDomainSocketPath);
+ }
+ return logger;
+ }
+
void OnConnectionTerminated(const TTcpConnectionPtr& connection, const TError& /*error*/)
{
auto guard = WriterGuard(ConnectionsSpinLock_);
@@ -263,7 +273,7 @@ protected:
{
auto guard = WriterGuard(ConnectionsSpinLock_);
- YT_VERIFY(Connections_.insert(connection).second);
+ EmplaceOrCrash(Connections_, connection);
}
connection->SubscribeTerminated(BIND_NO_PROPAGATE(
@@ -321,7 +331,7 @@ public:
using TTcpBusServerBase::TTcpBusServerBase;
private:
- void CreateServerSocket() override
+ void CreateServerSocket() final
{
ServerSocket_ = CreateTcpServerSocket();
@@ -329,7 +339,7 @@ private:
BindSocket(serverAddress, Format("Failed to bind a server socket to port %v", Config_->Port));
}
- void InitClientSocket(SOCKET clientSocket) override
+ void InitClientSocket(SOCKET clientSocket) final
{
if (Config_->EnableNoDelay) {
if (!TrySetSocketNoDelay(clientSocket)) {
@@ -364,7 +374,7 @@ public:
{ }
private:
- void CreateServerSocket() override
+ void CreateServerSocket() final
{
ServerSocket_ = CreateUnixServerSocket();
@@ -381,7 +391,7 @@ private:
}
}
- void InitClientSocket(SOCKET /*clientSocket*/) override
+ void InitClientSocket(SOCKET /*clientSocket*/) final
{ }
};
@@ -414,7 +424,7 @@ public:
YT_UNUSED_FUTURE(Stop());
}
- void Start(IMessageHandlerPtr handler) override
+ void Start(IMessageHandlerPtr handler) final
{
auto server = New<TServer>(
Config_,
@@ -427,7 +437,7 @@ public:
server->Start();
}
- TFuture<void> Stop() override
+ TFuture<void> Stop() final
{
if (auto server = Server_.Exchange(nullptr)) {
return server->Stop();
@@ -438,9 +448,7 @@ public:
private:
const TBusServerConfigPtr Config_;
-
IPacketTranscoderFactory* const PacketTranscoderFactory_;
-
const IMemoryUsageTrackerPtr MemoryUsageTracker_;
TAtomicIntrusivePtr<TServer> Server_;
@@ -448,34 +456,91 @@ private:
////////////////////////////////////////////////////////////////////////////////
+DECLARE_REFCOUNTED_CLASS(TLocalMessageHandler)
+
+class TLocalMessageHandler
+ : public ILocalMessageHandler
+{
+public:
+ explicit TLocalMessageHandler(IMessageHandlerPtr underlying)
+ : Underlying_(std::move(underlying))
+ { }
+
+ void HandleMessage(TSharedRefArray message, IBusPtr replyBus) noexcept final
+ {
+ Underlying_->HandleMessage(std::move(message), std::move(replyBus));
+ }
+
+ void SubscribeTerminated(const TCallback<void(const TError&)>& callback) final
+ {
+ Terminated_.Subscribe(callback);
+ }
+
+ void UnsubscribeTerminated(const TCallback<void(const TError&)>& callback) final
+ {
+ Terminated_.Unsubscribe(callback);
+ }
+
+ void Terminate(const TError& error)
+ {
+ Terminated_.Fire(error);
+ }
+
+private:
+ const IMessageHandlerPtr Underlying_;
+
+ TSingleShotCallbackList<void(const TError&)> Terminated_;
+};
+
+DEFINE_REFCOUNTED_TYPE(TLocalMessageHandler)
+
+////////////////////////////////////////////////////////////////////////////////
+
class TCompositeBusServer
: public IBusServer
{
public:
- explicit TCompositeBusServer(std::vector<IBusServerPtr> servers)
- : Servers_(std::move(servers))
+ TCompositeBusServer(
+ TBusServerConfigPtr config,
+ std::vector<IBusServerPtr> servers)
+ : Config_(std::move(config))
+ , Servers_(std::move(servers))
{ }
// IBusServer implementation.
- void Start(IMessageHandlerPtr handler) override
+ void Start(IMessageHandlerPtr handler) final
{
for (const auto& server : Servers_) {
server->Start(handler);
}
+
+ if (Config_->EnableLocalBypass && Config_->Port) {
+ LocalHandler_ = New<TLocalMessageHandler>(std::move(handler));
+ TTcpDispatcher::TImpl::Get()->RegisterLocalMessageHandler(*Config_->Port, LocalHandler_);
+ }
}
- TFuture<void> Stop() override
+ TFuture<void> Stop() final
{
- std::vector<TFuture<void>> asyncResults;
+ if (Config_->EnableLocalBypass && Config_->Port) {
+ LocalHandler_->Terminate(TError(NRpc::EErrorCode::TransportError, "Local server stopped"));
+ LocalHandler_.Reset();
+ TTcpDispatcher::TImpl::Get()->UnregisterLocalMessageHandler(*Config_->Port);
+ }
+
+ std::vector<TFuture<void>> futures;
for (const auto& server : Servers_) {
- asyncResults.push_back(server->Stop());
+ futures.push_back(server->Stop());
}
- return AllSucceeded(asyncResults);
+ return AllSucceeded(futures);
}
private:
+ const TBusServerConfigPtr Config_;
const std::vector<IBusServerPtr> Servers_;
+
+ TLocalMessageHandlerPtr LocalHandler_;
};
////////////////////////////////////////////////////////////////////////////////
@@ -494,6 +559,7 @@ IBusServerPtr CreateBusServer(
packetTranscoderFactory,
memoryUsageTracker));
}
+
#ifdef _linux_
// Abstract unix sockets are supported only on Linux.
servers.push_back(
@@ -503,7 +569,9 @@ IBusServerPtr CreateBusServer(
memoryUsageTracker));
#endif
- return New<TCompositeBusServer>(std::move(servers));
+ return New<TCompositeBusServer>(
+ config,
+ std::move(servers));
}
////////////////////////////////////////////////////////////////////////////////
diff --git a/yt/yt/core/net/address.cpp b/yt/yt/core/net/address.cpp
index d048167af3..70a5d78bf5 100644
--- a/yt/yt/core/net/address.cpp
+++ b/yt/yt/core/net/address.cpp
@@ -1029,10 +1029,12 @@ public:
bool IsLocalAddress(const TNetworkAddress& address)
{
- TNetworkAddress localIP{address, 0};
-
+ if (!address.IsIP()) {
+ return false;
+ }
+ TNetworkAddress candidateAddress(address, /*port*/ 0);
const auto& localAddresses = GetLocalAddresses();
- return std::find(localAddresses.begin(), localAddresses.end(), localIP) != localAddresses.end();
+ return std::find(localAddresses.begin(), localAddresses.end(), candidateAddress) != localAddresses.end();
}
void PurgeCache()
diff --git a/yt/yt/core/rpc/bus/channel.cpp b/yt/yt/core/rpc/bus/channel.cpp
index 3a0214e4de..d90dc774d9 100644
--- a/yt/yt/core/rpc/bus/channel.cpp
+++ b/yt/yt/core/rpc/bus/channel.cpp
@@ -284,8 +284,7 @@ private:
void HandleMessage(TSharedRefArray message, IBusPtr replyBus) noexcept override
{
- auto session_ = Session_.Lock();
- if (session_) {
+ if (auto session_ = Session_.Lock()) {
session_->HandleMessage(std::move(message), std::move(replyBus));
}
}
diff --git a/yt/yt/core/rpc/unittests/lib/common.h b/yt/yt/core/rpc/unittests/lib/common.h
index 9d2f5fca81..5326f017cc 100644
--- a/yt/yt/core/rpc/unittests/lib/common.h
+++ b/yt/yt/core/rpc/unittests/lib/common.h
@@ -9,6 +9,7 @@
#include <yt/yt/core/bus/tcp/config.h>
#include <yt/yt/core/bus/tcp/client.h>
+#include <yt/yt/core/bus/tcp/dispatcher.h>
#include <yt/yt/core/bus/tcp/server.h>
#include <yt/yt/core/crypto/config.h>
@@ -75,21 +76,25 @@ class TRpcTestBase
public:
void SetUp() final
{
- bool secure = TImpl::Secure;
-
WorkerPool_ = NConcurrency::CreateThreadPool(4, "Worker");
MemoryUsageTracker_ = New<TTestNodeMemoryTracker>(32_MB);
- TestService_ = CreateTestService(WorkerPool_->GetInvoker(), secure, {}, MemoryUsageTracker_);
+ TestService_ = CreateTestService(WorkerPool_->GetInvoker(), TImpl::Secure, {}, MemoryUsageTracker_);
auto services = std::vector<IServicePtr>{
TestService_,
- CreateNoBaggageService(WorkerPool_->GetInvoker())
+ CreateNoBaggageService(WorkerPool_->GetInvoker()),
};
Host_ = TImpl::CreateTestServerHost(
NTesting::GetFreePort(),
std::move(services),
MemoryUsageTracker_);
+
+ // Make sure local bypass is globally enabled.
+ // Individual tests will toggle per-connection flag to actually enable this feature.
+ auto config = New<NYT::NBus::TTcpDispatcherConfig>();
+ config->EnableLocalBypass = true;
+ NYT::NBus::TTcpDispatcher::Get()->Configure(config);
}
void TearDown() final
@@ -98,14 +103,13 @@ public:
}
IChannelPtr CreateChannel(
- const std::optional<TString>& address = std::nullopt,
+ const std::optional<TString>& address = {},
THashMap<TString, NYTree::INodePtr> grpcArguments = {})
{
- if (address) {
- return TImpl::CreateChannel(*address, Host_->GetAddress(), std::move(grpcArguments));
- } else {
- return TImpl::CreateChannel(Host_->GetAddress(), Host_->GetAddress(), std::move(grpcArguments));
- }
+ return TImpl::CreateChannel(
+ address.value_or(Host_->GetAddress()),
+ Host_->GetAddress(),
+ std::move(grpcArguments));
}
TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() const
@@ -166,13 +170,14 @@ public:
static constexpr bool AllowTransportErrors = false;
static constexpr bool Secure = false;
static constexpr int MaxSimultaneousRequestCount = 1000;
+ static constexpr bool MemoryUsageTrackingEnabled = TImpl::MemoryUsageTrackingEnabled;
static TTestServerHostPtr CreateTestServerHost(
NTesting::TPortHolder port,
std::vector<IServicePtr> services,
TTestNodeMemoryTrackerPtr memoryUsageTracker)
{
- auto busServer = MakeBusServer(port, memoryUsageTracker);
+ auto busServer = CreateBusServer(port, memoryUsageTracker);
auto server = NRpc::NBus::CreateBusServer(busServer);
return New<TTestServerHost>(
@@ -190,34 +195,39 @@ public:
return TImpl::CreateChannel(address, serverAddress, std::move(grpcArguments));
}
- static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker)
+ static NYT::NBus::IBusServerPtr CreateBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker)
{
- return TImpl::MakeBusServer(port, memoryUsageTracker);
+ return TImpl::CreateBusServer(port, memoryUsageTracker);
}
};
////////////////////////////////////////////////////////////////////////////////
-template <bool ForceTcp>
+template <bool EnableLocalBypass>
class TRpcOverBusImpl
{
public:
+ static constexpr bool MemoryUsageTrackingEnabled = !EnableLocalBypass;
+
static IChannelPtr CreateChannel(
const std::string& address,
const std::string& /*serverAddress*/,
THashMap<TString, NYTree::INodePtr> /*grpcArguments*/)
{
- auto client = CreateBusClient(NYT::NBus::TBusClientConfig::CreateTcp(address));
- return NRpc::NBus::CreateBusChannel(client);
+ auto config = NYT::NBus::TBusClientConfig::CreateTcp(address);
+ config->EnableLocalBypass = EnableLocalBypass;
+ auto client = CreateBusClient(std::move(config));
+ return NRpc::NBus::CreateBusChannel(std::move(client));
}
- static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker)
+ static NYT::NBus::IBusServerPtr CreateBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker)
{
- auto busConfig = NYT::NBus::TBusServerConfig::CreateTcp(port);
- return CreateBusServer(
- busConfig,
+ auto config = NYT::NBus::TBusServerConfig::CreateTcp(port);
+ config->EnableLocalBypass = EnableLocalBypass;
+ return NYT::NBus::CreateBusServer(
+ std::move(config),
NYT::NBus::GetYTPacketTranscoderFactory(),
- memoryUsageTracker);
+ std::move(memoryUsageTracker));
}
};
@@ -439,12 +449,14 @@ public:
class TRpcOverUdsImpl
{
public:
- static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker)
+ static constexpr bool MemoryUsageTrackingEnabled = true;
+
+ static NYT::NBus::IBusServerPtr CreateBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker)
{
SocketPath_ = GetWorkPath() + "/socket_" + ToString(port);
- auto busConfig = NYT::NBus::TBusServerConfig::CreateUds(SocketPath_);
- return CreateBusServer(
- busConfig,
+ auto config = NYT::NBus::TBusServerConfig::CreateUds(SocketPath_);
+ return NYT::NBus::CreateBusServer(
+ config,
NYT::NBus::GetYTPacketTranscoderFactory(),
memoryUsageTracker);
}
@@ -454,9 +466,9 @@ public:
const std::string& serverAddress,
THashMap<TString, NYTree::INodePtr> /*grpcArguments*/)
{
- auto clientConfig = NYT::NBus::TBusClientConfig::CreateUds(
+ auto config = NYT::NBus::TBusClientConfig::CreateUds(
address == serverAddress ? SocketPath_ : address);
- auto client = CreateBusClient(clientConfig);
+ auto client = CreateBusClient(config);
return NRpc::NBus::CreateBusChannel(client);
}
@@ -530,9 +542,9 @@ public:
using TAllTransports = ::testing::Types<
#ifdef _linux_
TRpcOverBus<TRpcOverUdsImpl>,
- TRpcOverBus<TRpcOverBusImpl<true>>,
#endif
TRpcOverBus<TRpcOverBusImpl<false>>,
+ TRpcOverBus<TRpcOverBusImpl<true>>,
TRpcOverGrpcImpl<false, false>,
TRpcOverGrpcImpl<false, true>,
TRpcOverGrpcImpl<true, false>,
@@ -544,9 +556,9 @@ using TAllTransports = ::testing::Types<
using TWithAttachments = ::testing::Types<
#ifdef _linux_
TRpcOverBus<TRpcOverUdsImpl>,
- TRpcOverBus<TRpcOverBusImpl<true>>,
#endif
TRpcOverBus<TRpcOverBusImpl<false>>,
+ TRpcOverBus<TRpcOverBusImpl<true>>,
TRpcOverGrpcImpl<false, false>,
TRpcOverGrpcImpl<false, true>,
TRpcOverGrpcImpl<true, false>,
@@ -554,10 +566,8 @@ using TWithAttachments = ::testing::Types<
>;
using TWithoutUds = ::testing::Types<
-#ifdef _linux_
- TRpcOverBus<TRpcOverBusImpl<true>>,
-#endif
TRpcOverBus<TRpcOverBusImpl<false>>,
+ TRpcOverBus<TRpcOverBusImpl<true>>,
TRpcOverGrpcImpl<false, false>,
TRpcOverGrpcImpl<true, false>,
TRpcOverHttpImpl<false>,
@@ -567,9 +577,9 @@ using TWithoutUds = ::testing::Types<
using TWithoutGrpc = ::testing::Types<
#ifdef _linux_
TRpcOverBus<TRpcOverUdsImpl>,
- TRpcOverBus<TRpcOverBusImpl<true>>,
#endif
- TRpcOverBus<TRpcOverBusImpl<false>>
+ TRpcOverBus<TRpcOverBusImpl<false>>,
+ TRpcOverBus<TRpcOverBusImpl<true>>
>;
using TGrpcOnly = ::testing::Types<
diff --git a/yt/yt/core/rpc/unittests/rpc_ut.cpp b/yt/yt/core/rpc/unittests/rpc_ut.cpp
index e3ce7376ee..94e8d13e7a 100644
--- a/yt/yt/core/rpc/unittests/rpc_ut.cpp
+++ b/yt/yt/core/rpc/unittests/rpc_ut.cpp
@@ -508,9 +508,9 @@ TYPED_TEST(TAttachmentsTest, RegularAttachments)
const auto& attachments = rsp->Attachments();
EXPECT_EQ(3u, attachments.size());
- EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0]));
- EXPECT_EQ("from_", StringFromSharedRef(attachments[1]));
- EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2]));
+ EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0]));
+ EXPECT_EQ("from_", StringFromSharedRef(attachments[1]));
+ EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2]));
}
TYPED_TEST(TNotGrpcTest, TrackedRegularAttachments)
@@ -536,11 +536,13 @@ TYPED_TEST(TNotGrpcTest, TrackedRegularAttachments)
// header + body = 79 bytes.
// attachments = 22 bytes.
// sum is 4219 bytes.
- EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4197 + 32768 + std::ssize(GetRpcUserAgent()));
+ if (TypeParam::MemoryUsageTrackingEnabled) {
+ EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4197 + 32768 + std::ssize(GetRpcUserAgent()));
+ }
EXPECT_EQ(3u, attachments.size());
- EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0]));
- EXPECT_EQ("from_", StringFromSharedRef(attachments[1]));
- EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2]));
+ EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0]));
+ EXPECT_EQ("from_", StringFromSharedRef(attachments[1]));
+ EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2]));
}
TYPED_TEST(TAttachmentsTest, NullAndEmptyAttachments)
@@ -602,7 +604,9 @@ TYPED_TEST(TNotGrpcTest, Compression)
// attachmentStrings[1].size() = 36 * 2 bytes from decoder.
// attachmentStrings[2].size() = 90 * 2 bytes from decoder.
// sum is 4584 bytes.
- EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4562 + 32768 + std::ssize(GetRpcUserAgent()));
+ if (TypeParam::MemoryUsageTrackingEnabled) {
+ EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4562 + 32768 + std::ssize(GetRpcUserAgent()));
+ }
EXPECT_TRUE(rsp->message() == message);
EXPECT_GE(rsp->GetResponseMessage().Size(), static_cast<size_t>(2));
const auto& serializedResponseBody = SerializeProtoToRefWithCompression(*rsp, responseCodecId);
@@ -828,7 +832,7 @@ TYPED_TEST(TNotGrpcTest, MemoryTracking)
Sleep(TDuration::MilliSeconds(200));
- {
+ if (TypeParam::MemoryUsageTrackingEnabled) {
auto rpcUsage = memoryUsageTracker->GetTotalUsage();
// 1285268 = 32768 + 1252500 = 32768 + 4096 * 300 + 300 * 79 (header + body).
@@ -849,7 +853,7 @@ TYPED_TEST(TNotGrpcTest, MemoryTrackingMultipleConnections)
WaitFor(req->Invoke().AsVoid()).ThrowOnError();
}
- {
+ if (TypeParam::MemoryUsageTrackingEnabled) {
// 11082900 / 300 = 36974 = 32768 + 4096 + 79 (header + body).
// 4 KB - stub for request.
// See NYT::NBus::TPacketDecoder::TChunkedMemoryTrackingAllocator::Allocate.
@@ -877,7 +881,7 @@ TYPED_TEST(TNotGrpcTest, MemoryTrackingMultipleConcurrent)
Sleep(TDuration::MilliSeconds(100));
- {
+ if (TypeParam::MemoryUsageTrackingEnabled) {
auto rpcUsage = memoryUsageTracker->GetUsed();
// connections count - per connection size.
@@ -901,7 +905,8 @@ TYPED_TEST(TNotGrpcTest, MemoryOvercommit)
req->set_request_codec(ToProto(requestCodecId));
req->Attachments().push_back(TSharedRef::FromString(TString(6_KB, 'x')));
WaitFor(req->Invoke()).ThrowOnError();
- {
+
+ if (TypeParam::MemoryUsageTrackingEnabled) {
auto rpcUsage = memoryReferenceUsageTracker->GetTotalUsage();
// Attachment allocator proactively allocate slice of 4 KB.
diff --git a/yt/yt/core/ya.make b/yt/yt/core/ya.make
index a47b5ddfbf..857e56920b 100644
--- a/yt/yt/core/ya.make
+++ b/yt/yt/core/ya.make
@@ -30,6 +30,7 @@ SRCS(
GLOBAL bus/tcp/configure_dispatcher.cpp
bus/tcp/packet.cpp
bus/tcp/client.cpp
+ bus/tcp/local_bypass.cpp
bus/tcp/server.cpp
bus/tcp/ssl_context.cpp
bus/tcp/ssl_helpers.cpp