diff options
author | babenko <babenko@yandex-team.com> | 2025-01-19 20:18:26 +0300 |
---|---|---|
committer | babenko <babenko@yandex-team.com> | 2025-01-19 20:35:20 +0300 |
commit | 02877bfb2eed135564651ea971c8c26f6b173266 (patch) | |
tree | 61aa7eb7b655564d6af75ecd37ebd1bec9485d82 | |
parent | 73afb75a4295ec99f0c8c49ab30fbff2b7a542fe (diff) | |
download | ydb-02877bfb2eed135564651ea971c8c26f6b173266.tar.gz |
Implement local bypass for intra-process Bus communication
commit_hash:8332b96f0bc1eb8faef360bf472d47b08923a79d
-rw-r--r-- | yt/yt/core/bus/bus.h | 2 | ||||
-rw-r--r-- | yt/yt/core/bus/client.h | 2 | ||||
-rw-r--r-- | yt/yt/core/bus/private.h | 4 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/client.cpp | 44 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/config.cpp | 9 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/config.h | 6 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/connection.cpp | 231 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/connection.h | 12 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/dispatcher.h | 1 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/dispatcher_impl.cpp | 67 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/dispatcher_impl.h | 14 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/local_bypass.cpp | 139 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/local_bypass.h | 32 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/private.h | 2 | ||||
-rw-r--r-- | yt/yt/core/bus/tcp/server.cpp | 130 | ||||
-rw-r--r-- | yt/yt/core/net/address.cpp | 8 | ||||
-rw-r--r-- | yt/yt/core/rpc/bus/channel.cpp | 3 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/lib/common.h | 78 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/rpc_ut.cpp | 29 | ||||
-rw-r--r-- | yt/yt/core/ya.make | 1 |
20 files changed, 625 insertions, 189 deletions
diff --git a/yt/yt/core/bus/bus.h b/yt/yt/core/bus/bus.h index 6c0ded55a4..4c08a78440 100644 --- a/yt/yt/core/bus/bus.h +++ b/yt/yt/core/bus/bus.h @@ -126,7 +126,7 @@ struct IBus /*! * Does not block -- termination typically happens in background. * It is safe to call this method multiple times. - * On terminated the instance is no longer usable. + * After termination the instance is no longer usable. * \note Thread affinity: any. */ diff --git a/yt/yt/core/bus/client.h b/yt/yt/core/bus/client.h index 8c23346b95..fa12665fc2 100644 --- a/yt/yt/core/bus/client.h +++ b/yt/yt/core/bus/client.h @@ -24,7 +24,7 @@ struct IBusClient //! Typically used for logging. virtual const std::string& GetEndpointDescription() const = 0; - //! Returns the bus' endpoint attributes. + //! Returns the bus endpoint attributes. //! Typically used for constructing errors. virtual const NYTree::IAttributeDictionary& GetEndpointAttributes() const = 0; diff --git a/yt/yt/core/bus/private.h b/yt/yt/core/bus/private.h index 1f8d52f96a..e334303fdf 100644 --- a/yt/yt/core/bus/private.h +++ b/yt/yt/core/bus/private.h @@ -15,6 +15,10 @@ namespace NYT::NBus { YT_DEFINE_GLOBAL(const NLogging::TLogger, BusLogger, "Bus"); YT_DEFINE_GLOBAL(const NProfiling::TProfiler, BusProfiler, "/bus"); +//////////////////////////////////////////////////////////////////////////////// + +DECLARE_REFCOUNTED_STRUCT(IMessageHandler) + using TConnectionId = TGuid; using TPacketId = TGuid; diff --git a/yt/yt/core/bus/tcp/client.cpp b/yt/yt/core/bus/tcp/client.cpp index d2ddf6b24e..5dd0117bee 100644 --- a/yt/yt/core/bus/tcp/client.cpp +++ b/yt/yt/core/bus/tcp/client.cpp @@ -146,20 +146,9 @@ public: : Config_(std::move(config)) , PacketTranscoderFactory_(packetTranscoderFactory) , MemoryUsageTracker_(std::move(memoryUsageTracker)) - { - if (Config_->Address) { - EndpointDescription_ = *Config_->Address; - } else if (Config_->UnixDomainSocketPath) { - EndpointDescription_ = Format("unix://%v", *Config_->UnixDomainSocketPath); - } - - EndpointAttributes_ = ConvertToAttributes(BuildYsonStringFluently() - .BeginMap() - .Item("address").Value(EndpointDescription_) - .Item("encryption_mode").Value(Config_->EncryptionMode) - .Item("verification_mode").Value(Config_->VerificationMode) - .EndMap()); - } + , EndpointDescription_(MakeEndpointDescription(Config_)) + , EndpointAttributes_(MakeEndpointAttributes(Config_, EndpointDescription_)) + { } const std::string& GetEndpointDescription() const override { @@ -215,13 +204,32 @@ public: private: const TBusClientConfigPtr Config_; - IPacketTranscoderFactory* const PacketTranscoderFactory_; - const IMemoryUsageTrackerPtr MemoryUsageTracker_; + const std::string EndpointDescription_; + const IAttributeDictionaryPtr EndpointAttributes_; - std::string EndpointDescription_; - IAttributeDictionaryPtr EndpointAttributes_; + static std::string MakeEndpointDescription(const TBusClientConfigPtr& config) + { + if (config->Address) { + return *config->Address; + } else if (config->UnixDomainSocketPath) { + return Format("unix://%v", *config->UnixDomainSocketPath); + } + YT_ABORT(); + } + + static IAttributeDictionaryPtr MakeEndpointAttributes( + const TBusClientConfigPtr& config, + const std::string& endpointDescription) + { + return ConvertToAttributes(BuildYsonStringFluently() + .BeginMap() + .Item("address").Value(endpointDescription) + .Item("encryption_mode").Value(config->EncryptionMode) + .Item("verification_mode").Value(config->VerificationMode) + .EndMap()); + } }; //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/config.cpp b/yt/yt/core/bus/tcp/config.cpp index 18e8371bf0..f67c6fba6c 100644 --- a/yt/yt/core/bus/tcp/config.cpp +++ b/yt/yt/core/bus/tcp/config.cpp @@ -52,6 +52,9 @@ void TTcpDispatcherConfig::Register(TRegistrar registrar) registrar.Parameter("bus_certs_directory_path", &TThis::BusCertsDirectoryPath) .Default(); + + registrar.Parameter("enable_local_bypass", &TThis::EnableLocalBypass) + .Default(false); } TTcpDispatcherConfigPtr TTcpDispatcherConfig::ApplyDynamic( @@ -63,6 +66,7 @@ TTcpDispatcherConfigPtr TTcpDispatcherConfig::ApplyDynamic( UpdateYsonStructField(mergedConfig->Networks, dynamicConfig->Networks); UpdateYsonStructField(mergedConfig->MultiplexingBands, dynamicConfig->MultiplexingBands); UpdateYsonStructField(mergedConfig->BusCertsDirectoryPath, dynamicConfig->BusCertsDirectoryPath); + UpdateYsonStructField(mergedConfig->EnableLocalBypass, dynamicConfig->EnableLocalBypass); mergedConfig->Postprocess(); return mergedConfig; } @@ -89,6 +93,9 @@ void TTcpDispatcherDynamicConfig::Register(TRegistrar registrar) registrar.Parameter("bus_certs_directory_path", &TThis::BusCertsDirectoryPath) .Default(); + + registrar.Parameter("enable_local_bypass", &TThis::EnableLocalBypass) + .Default(); } //////////////////////////////////////////////////////////////////////////////// @@ -141,6 +148,8 @@ void TBusConfig::Register(TRegistrar registrar) .Default(true); registrar.Parameter("generate_checksums", &TThis::GenerateChecksums) .Default(true); + registrar.Parameter("enable_local_bypass", &TThis::EnableLocalBypass) + .Default(true); registrar.Parameter("encryption_mode", &TThis::EncryptionMode) .Default(EEncryptionMode::Optional); registrar.Parameter("verification_mode", &TThis::VerificationMode) diff --git a/yt/yt/core/bus/tcp/config.h b/yt/yt/core/bus/tcp/config.h index 914035e2d9..491612fabb 100644 --- a/yt/yt/core/bus/tcp/config.h +++ b/yt/yt/core/bus/tcp/config.h @@ -50,6 +50,8 @@ public: //! Used to store TLS/SSL certificate files. std::optional<TString> BusCertsDirectoryPath; + bool EnableLocalBypass; + TTcpDispatcherConfigPtr ApplyDynamic(const TTcpDispatcherDynamicConfigPtr& dynamicConfig) const; REGISTER_YSON_STRUCT(TTcpDispatcherConfig); @@ -78,7 +80,7 @@ public: //! Used to store TLS/SSL certificate files. std::optional<TString> BusCertsDirectoryPath; - static void Setup(auto&& registrar); + std::optional<bool> EnableLocalBypass; REGISTER_YSON_STRUCT(TTcpDispatcherDynamicConfig); @@ -107,6 +109,8 @@ public: bool VerifyChecksums; bool GenerateChecksums; + bool EnableLocalBypass; + // Ssl options. EEncryptionMode EncryptionMode; EVerificationMode VerificationMode; diff --git a/yt/yt/core/bus/tcp/connection.cpp b/yt/yt/core/bus/tcp/connection.cpp index 2ab85dcca0..6807a9d8d4 100644 --- a/yt/yt/core/bus/tcp/connection.cpp +++ b/yt/yt/core/bus/tcp/connection.cpp @@ -1,7 +1,7 @@ #include "connection.h" #include "config.h" -#include "server.h" +#include "local_bypass.h" #include "dispatcher_impl.h" #include "ssl_context.h" #include "ssl_helpers.h" @@ -125,12 +125,9 @@ TTcpConnection::TTcpConnection( , UnixDomainSocketPath_(unixDomainSocketPath) , Handler_(std::move(handler)) , Poller_(std::move(poller)) - , LoggingTag_(Format("ConnectionId: %v, ConnectionType: %v, RemoteAddress: %v, EncryptionMode: %v, VerificationMode: %v", + , LoggingTag_(Format("ConnectionId: %v, Endpoint: %v", Id_, - ConnectionType_, - EndpointDescription_, - Config_->EncryptionMode, - Config_->VerificationMode)) + EndpointDescription_)) , Logger(BusLogger().WithRawTag(LoggingTag_)) , GenerateChecksums_(Config_->GenerateChecksums) , Socket_(socket) @@ -196,15 +193,24 @@ void TTcpConnection::Close() EncodedFragments_.clear(); + ILocalMessageHandlerPtr localBypassHandler; { auto guard = Guard(Lock_); FlushStatistics(); + localBypassHandler = LocalBypassHandler_; + } + + if (localBypassHandler) { + localBypassHandler->UnsubscribeTerminated(LocalBypassTerminatedCallback_); } } void TTcpConnection::Start() { - YT_LOG_DEBUG("Starting TCP connection"); + YT_LOG_DEBUG("Starting TCP connection (ConnectionType: %v, EncryptionMode: %v, VerificationMode: %v)", + ConnectionType_, + Config_->EncryptionMode, + Config_->VerificationMode); // Offline in PendingControl_ prevents retrying events until end of Open(). YT_VERIFY(Any(static_cast<EPollControl>(PendingControl_.load()) & EPollControl::Offline)); @@ -480,17 +486,26 @@ void TTcpConnection::OnAddressResolveFinished(const TErrorOr<TNetworkAddress>& r } TNetworkAddress address(result.Value(), Port_); - OnAddressResolved(address); - YT_LOG_DEBUG("Connection network address resolved (Address: %v, NetworkName: %v)", address, NetworkName_); + + OnAddressResolved(address); } void TTcpConnection::OnAddressResolved(const TNetworkAddress& address) { State_ = EState::Opening; + SetupNetwork(address); + + if (Config_->EnableLocalBypass) { + if (auto localHandler = TTcpDispatcher::TImpl::Get()->FindLocalBypassMessageHandler(address)) { + InitLocalBypass(std::move(localHandler), address); + return; + } + } + ConnectSocket(address); } @@ -609,6 +624,61 @@ int TTcpConnection::GetSocketPort() } } +void TTcpConnection::InitLocalBypass( + ILocalMessageHandlerPtr localBypassHandler, + const NNet::TNetworkAddress& address) +{ + { + auto guard = Guard(Lock_); + + if (State_ != EState::Opening) { + return; + } + + State_ = EState::Open; + + LocalBypassHandler_ = std::move(localBypassHandler); + + LocalBypassReplyBus_ = CreateLocalBypassReplyBus( + address, + LocalBypassHandler_, + Handler_); + + LocalBypassTerminatedCallback_ = BIND(&TTcpConnection::OnLocalBypassHandlerTerminated, MakeWeak(this)); + LocalBypassHandler_->SubscribeTerminated(LocalBypassTerminatedCallback_); + + UpdateConnectionCount(+1); + + LocalBypassActive_.store(true, std::memory_order::release); + + YT_LOG_DEBUG("Local bypass activated"); + } + + ReadyPromise_.TrySet(); + + FlushQueuedMessagesToLocalBypass(); +} + +void TTcpConnection::OnLocalBypassHandlerTerminated(const TError& error) +{ + Abort(error); +} + +void TTcpConnection::FlushQueuedMessagesToLocalBypass() +{ + QueuedMessages_.DequeueAll( + /*reverse*/ true, + [&] (auto& queuedMessage) { + // Log first to avoid producing weird traces. + YT_LOG_DEBUG("Queued message sent via local bypass (PacketId: %v)", queuedMessage.PacketId); + LocalBypassHandler_->HandleMessage(std::move(queuedMessage.Message), LocalBypassReplyBus_); + + if (queuedMessage.Promise) { + queuedMessage.Promise.TrySet(); + } + }); +} + void TTcpConnection::ConnectSocket(const TNetworkAddress& address) { auto dialer = CreateAsyncDialer( @@ -734,37 +804,13 @@ TFuture<void> TTcpConnection::Send(TSharedRefArray message, const TSendOptions& } } - TQueuedMessage queuedMessage(std::move(message), options); - auto promise = queuedMessage.Promise; - auto pendingOutPayloadBytes = PendingOutPayloadBytes_.fetch_add(queuedMessage.PayloadSize); - - // Log first to avoid producing weird traces. - YT_LOG_DEBUG("Outcoming message enqueued (PacketId: %v, PendingOutPayloadBytes: %v)", - queuedMessage.PacketId, - pendingOutPayloadBytes); - - if (LastIncompleteWriteTime_ == std::numeric_limits<NProfiling::TCpuInstant>::max()) { - // Arm stall detection. - LastIncompleteWriteTime_ = NProfiling::GetCpuInstant(); - } - - QueuedMessages_.Enqueue(std::move(queuedMessage)); - - // Wake up the event processing if needed. - { - auto previousPendingControl = static_cast<EPollControl>(PendingControl_.fetch_or(static_cast<ui64>(EPollControl::Write))); - if (None(previousPendingControl)) { - YT_LOG_TRACE("Retrying event processing for Send"); - Poller_->Retry(this); - } - } - - // Double-check the state not to leave any dangling outcoming messages. - if (State_.load() == EState::Closed) { - DiscardOutcomingMessages(); + // Fast path for bypass. + if (LocalBypassActive_.load(std::memory_order::acquire)) { + return SendViaLocalBypass(std::move(message), options); } - return promise; + // Slow path. + return SendViaSocket(std::move(message), options); } void TTcpConnection::SetTosLevel(TTosLevel tosLevel) @@ -1276,6 +1322,57 @@ bool TTcpConnection::OnHandshakePacketReceived() return true; } +TFuture<void> TTcpConnection::SendViaSocket(TSharedRefArray message, const TSendOptions& options) +{ + TQueuedMessage queuedMessage(std::move(message), options); + auto promise = queuedMessage.Promise; + auto pendingOutPayloadBytes = PendingOutPayloadBytes_.fetch_add(queuedMessage.PayloadSize); + + // Log first to avoid producing weird traces. + YT_LOG_DEBUG("Outcoming message enqueued (PacketId: %v, PendingOutPayloadBytes: %v)", + queuedMessage.PacketId, + pendingOutPayloadBytes); + + if (LastIncompleteWriteTime_ == std::numeric_limits<NProfiling::TCpuInstant>::max()) { + // Arm stall detection. + LastIncompleteWriteTime_ = NProfiling::GetCpuInstant(); + } + + QueuedMessages_.Enqueue(std::move(queuedMessage)); + + // Wake up the event processing if needed. + { + auto previousPendingControl = static_cast<EPollControl>(PendingControl_.fetch_or(static_cast<ui64>(EPollControl::Write))); + if (None(previousPendingControl)) { + YT_LOG_TRACE("Retrying event processing for Send"); + Poller_->Retry(this); + } + } + + // Double-check the state not to leave any dangling outcoming messages. + if (State_.load() == EState::Closed) { + DiscardOutcomingMessages(); + } + + // Another double-check to prevent a race with bypass activation. + if (LocalBypassActive_.load(std::memory_order::acquire)) { + FlushQueuedMessagesToLocalBypass(); + } + + return promise; +} + +TFuture<void> TTcpConnection::SendViaLocalBypass(TSharedRefArray message, const TSendOptions& /*options*/) +{ + // Log first to avoid producing weird traces. + YT_LOG_DEBUG("Outcoming message sent via local bypass"); + + LocalBypassHandler_->HandleMessage(std::move(message), LocalBypassReplyBus_); + + // No delivery tracking for local bypass. + return VoidFuture; +} + TTcpConnection::TPacket* TTcpConnection::EnqueuePacket( EPacketType type, EPacketFlags flags, @@ -1657,38 +1754,36 @@ void TTcpConnection::OnTerminate() void TTcpConnection::ProcessQueuedMessages() { - auto messages = QueuedMessages_.DequeueAll(); - - for (auto it = messages.rbegin(); it != messages.rend(); ++it) { - auto& queuedMessage = *it; - - auto packetId = queuedMessage.PacketId; - auto flags = queuedMessage.Options.TrackingLevel == EDeliveryTrackingLevel::Full - ? EPacketFlags::RequestAcknowledgement - : EPacketFlags::None; - - auto* packet = EnqueuePacket( - EPacketType::Message, - flags, - GenerateChecksums_ ? queuedMessage.Options.ChecksummedPartCount : 0, - packetId, - std::move(queuedMessage.Message), - queuedMessage.PayloadSize); - - packet->Promise = queuedMessage.Promise; - if (queuedMessage.Options.EnableSendCancelation) { - packet->EnableCancel(MakeStrong(this)); - } + QueuedMessages_.DequeueAll( + /*reverse*/ true, + [&] (auto& queuedMessage) { + auto packetId = queuedMessage.PacketId; + auto flags = queuedMessage.Options.TrackingLevel == EDeliveryTrackingLevel::Full + ? EPacketFlags::RequestAcknowledgement + : EPacketFlags::None; + + auto* packet = EnqueuePacket( + EPacketType::Message, + flags, + GenerateChecksums_ ? queuedMessage.Options.ChecksummedPartCount : 0, + packetId, + std::move(queuedMessage.Message), + queuedMessage.PayloadSize); + + packet->Promise = queuedMessage.Promise; + if (queuedMessage.Options.EnableSendCancelation) { + packet->EnableCancel(MakeStrong(this)); + } - YT_LOG_DEBUG("Outcoming message dequeued (PacketId: %v, PacketSize: %v, Flags: %v)", - packetId, - packet->PacketSize, - flags); + YT_LOG_DEBUG("Outcoming message dequeued (PacketId: %v, PacketSize: %v, Flags: %v)", + packetId, + packet->PacketSize, + flags); - if (queuedMessage.Promise && !queuedMessage.Options.EnableSendCancelation && !Any(flags & EPacketFlags::RequestAcknowledgement)) { - queuedMessage.Promise.TrySet(); - } - } + if (queuedMessage.Promise && !queuedMessage.Options.EnableSendCancelation && None(flags & EPacketFlags::RequestAcknowledgement)) { + queuedMessage.Promise.TrySet(); + } + }); } void TTcpConnection::DiscardOutcomingMessages() diff --git a/yt/yt/core/bus/tcp/connection.h b/yt/yt/core/bus/tcp/connection.h index a591566428..3d5447263f 100644 --- a/yt/yt/core/bus/tcp/connection.h +++ b/yt/yt/core/bus/tcp/connection.h @@ -284,6 +284,11 @@ private: size_t MaxFragmentsPerWrite_ = 256; + ILocalMessageHandlerPtr LocalBypassHandler_; + IBusPtr LocalBypassReplyBus_; + TCallback<void(const TError&)> LocalBypassTerminatedCallback_; + std::atomic<bool> LocalBypassActive_ = false; + void Open(TGuard<NThreading::TSpinLock>& guard); void Close(); void CloseSslSession(ESslState newSslState); @@ -296,6 +301,10 @@ private: int GetSocketPort(); + void InitLocalBypass(ILocalMessageHandlerPtr localBypassHandler, const NNet::TNetworkAddress& address); + void OnLocalBypassHandlerTerminated(const TError& error); + void FlushQueuedMessagesToLocalBypass(); + void ConnectSocket(const NNet::TNetworkAddress& address); void OnDialerFinished(const TErrorOr<TFileDescriptor>& fdOrError); @@ -323,6 +332,9 @@ private: bool OnHandshakePacketReceived(); bool OnSslAckPacketReceived(); + TFuture<void> SendViaSocket(TSharedRefArray message, const TSendOptions& options); + TFuture<void> SendViaLocalBypass(TSharedRefArray message, const TSendOptions& options); + TPacket* EnqueuePacket( EPacketType type, EPacketFlags flags, diff --git a/yt/yt/core/bus/tcp/dispatcher.h b/yt/yt/core/bus/tcp/dispatcher.h index b39b015c88..1e6ba8fd97 100644 --- a/yt/yt/core/bus/tcp/dispatcher.h +++ b/yt/yt/core/bus/tcp/dispatcher.h @@ -76,6 +76,7 @@ private: friend class TTcpBusServerBase; template <class TServer> friend class TTcpBusServerProxy; + friend class TCompositeBusServer; class TImpl; const TIntrusivePtr<TImpl> Impl_; diff --git a/yt/yt/core/bus/tcp/dispatcher_impl.cpp b/yt/yt/core/bus/tcp/dispatcher_impl.cpp index 5a6334ddbc..c5cb88624b 100644 --- a/yt/yt/core/bus/tcp/dispatcher_impl.cpp +++ b/yt/yt/core/bus/tcp/dispatcher_impl.cpp @@ -79,7 +79,7 @@ IPollerPtr TTcpDispatcher::TImpl::GetOrCreatePoller( const TString& threadNamePrefix) { { - auto guard = ReaderGuard(PollerLock_); + auto guard = ReaderGuard(PollersLock_); if (*pollerPtr) { return *pollerPtr; } @@ -87,13 +87,14 @@ IPollerPtr TTcpDispatcher::TImpl::GetOrCreatePoller( IPollerPtr poller; { - auto guard = WriterGuard(PollerLock_); + auto guard = WriterGuard(PollersLock_); + auto config = Config_.Acquire(); if (!*pollerPtr) { if (isXfer) { *pollerPtr = CreateThreadPoolPoller( - Config_->ThreadPoolSize, + config->ThreadPoolSize, threadNamePrefix, - Config_->ThreadPoolPollingPeriod); + config->ThreadPoolPollingPeriod); } else { *pollerPtr = CreateThreadPoolPoller(/*threadCount*/ 1, threadNamePrefix); } @@ -181,13 +182,13 @@ IPollerPtr TTcpDispatcher::TImpl::GetXferPoller() void TTcpDispatcher::TImpl::Configure(const TTcpDispatcherConfigPtr& config) { { - auto guard = WriterGuard(PollerLock_); + auto guard = WriterGuard(PollersLock_); - Config_ = config; + Config_.Store(config); if (XferPoller_) { - XferPoller_->SetThreadCount(Config_->ThreadPoolSize); - XferPoller_->SetPollingPeriod(Config_->ThreadPoolPollingPeriod); + XferPoller_->SetThreadCount(config->ThreadPoolSize); + XferPoller_->SetPollingPeriod(config->ThreadPoolPollingPeriod); } } @@ -259,12 +260,7 @@ void TTcpDispatcher::TImpl::CollectSensors(ISensorWriter* writer) } }); - TTcpDispatcherConfigPtr config; - { - auto guard = ReaderGuard(PollerLock_); - config = Config_; - } - + auto config = Config_.Acquire(); if (config->NetworkBandwidth) { writer->AddGauge("/network_bandwidth_limit", *config->NetworkBandwidth); } @@ -348,8 +344,47 @@ void TTcpDispatcher::TImpl::OnPeriodicCheck() std::optional<TString> TTcpDispatcher::TImpl::GetBusCertsDirectoryPath() const { - auto guard = ReaderGuard(PollerLock_); - return Config_->BusCertsDirectoryPath; + return Config_.Acquire()->BusCertsDirectoryPath; +} + +void TTcpDispatcher::TImpl::RegisterLocalMessageHandler(int port, const ILocalMessageHandlerPtr& handler) +{ + { + auto guard = WriterGuard(LocalMessageHandlersLock_); + if (!LocalMessageHandlers_.emplace(port, handler).second) { + THROW_ERROR_EXCEPTION("Local message handler is already registered for port %v", + port); + } + } + + YT_LOG_INFO("Local message handler registered (Port: %v)", + port); +} + +void TTcpDispatcher::TImpl::UnregisterLocalMessageHandler(int port) +{ + { + auto guard = WriterGuard(LocalMessageHandlersLock_); + LocalMessageHandlers_.erase(port); + } + + YT_LOG_INFO("Local message handler unregistered (Port: %v)", + port); +} + +ILocalMessageHandlerPtr TTcpDispatcher::TImpl::FindLocalBypassMessageHandler(const TNetworkAddress& address) +{ + if (!TAddressResolver::Get()->IsLocalAddress(address)) { + return nullptr; + } + + if (!Config_.Acquire()->EnableLocalBypass) { + YT_LOG_INFO("XXX disabled %v", address); + return nullptr; + } + + auto guard = ReaderGuard(LocalMessageHandlersLock_); + return GetOrDefault(LocalMessageHandlers_, address.GetPort()); } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/dispatcher_impl.h b/yt/yt/core/bus/tcp/dispatcher_impl.h index 09ed61d50d..054607bc3e 100644 --- a/yt/yt/core/bus/tcp/dispatcher_impl.h +++ b/yt/yt/core/bus/tcp/dispatcher_impl.h @@ -16,6 +16,8 @@ #include <library/cpp/yt/threading/rw_spin_lock.h> #include <library/cpp/yt/threading/fork_aware_rw_spin_lock.h> +#include <library/cpp/yt/memory/atomic_intrusive_ptr.h> + #include <atomic> namespace NYT::NBus { @@ -57,6 +59,10 @@ public: std::optional<TString> GetBusCertsDirectoryPath() const; + void RegisterLocalMessageHandler(int port, const ILocalMessageHandlerPtr& handler); + void UnregisterLocalMessageHandler(int port); + ILocalMessageHandlerPtr FindLocalBypassMessageHandler(const NNet::TNetworkAddress& address); + private: friend class TTcpDispatcher; @@ -73,8 +79,9 @@ private: std::vector<TTcpConnectionPtr> GetConnections(); void BuildOrchid(NYson::IYsonConsumer* consumer); - YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, PollerLock_); - TTcpDispatcherConfigPtr Config_ = New<TTcpDispatcherConfig>(); + TAtomicIntrusivePtr<TTcpDispatcherConfig> Config_{New<TTcpDispatcherConfig>()}; + + YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, PollersLock_); NConcurrency::IThreadPoolPollerPtr AcceptorPoller_; NConcurrency::IThreadPoolPollerPtr XferPoller_; @@ -106,6 +113,9 @@ private: }; TEnumIndexedArray<EMultiplexingBand, TBandDescriptor> BandToDescriptor_; + + YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, LocalMessageHandlersLock_); + THashMap<int, ILocalMessageHandlerPtr> LocalMessageHandlers_; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/local_bypass.cpp b/yt/yt/core/bus/tcp/local_bypass.cpp new file mode 100644 index 0000000000..911cbd805b --- /dev/null +++ b/yt/yt/core/bus/tcp/local_bypass.cpp @@ -0,0 +1,139 @@ +#include "local_bypass.h" + +#include <yt/yt/core/bus/bus.h> + +#include <yt/yt/core/ytree/fluent.h> + +#include <yt/yt/core/net/address.h> + +namespace NYT::NBus { + +using namespace NYTree; + +//////////////////////////////////////////////////////////////////////////////// + +class TLocalBypassReplyBus + : public IBus +{ +public: + TLocalBypassReplyBus( + const NNet::TNetworkAddress& localAddress, + ILocalMessageHandlerPtr serverHandler, + IMessageHandlerPtr clientHandler) + : LocalAddress_(localAddress) + , ServerHandler_(std::move(serverHandler)) + , ClientHandler_(std::move(clientHandler)) + , EndpointDescription_(Format("local-bypass:%v", LocalAddress_)) + , EndpointAttributes_(ConvertToAttributes(BuildYsonStringFluently() + .BeginMap() + .Item("local_bypass").Value(true) + .Item("address").Value(EndpointDescription_) + .EndMap())) + , ServerHandlerTerminatedCallback_(BIND(&TLocalBypassReplyBus::OnServerHandlerTerminated, MakeWeak(this))) + { + ServerHandler_->SubscribeTerminated(ServerHandlerTerminatedCallback_); + } + + ~TLocalBypassReplyBus() + { + ServerHandler_->UnsubscribeTerminated(ServerHandlerTerminatedCallback_); + } + + // IBus overrides. + const std::string& GetEndpointDescription() const override + { + return EndpointDescription_; + } + + const NYTree::IAttributeDictionary& GetEndpointAttributes() const override + { + return *EndpointAttributes_; + } + + TBusNetworkStatistics GetNetworkStatistics() const override + { + return {}; + } + + const std::string& GetEndpointAddress() const override + { + return EndpointDescription_; + } + + const NNet::TNetworkAddress& GetEndpointNetworkAddress() const override + { + return LocalAddress_; + } + + bool IsEndpointLocal() const override + { + return true; + } + + bool IsEncrypted() const override + { + return false; + } + + TFuture<void> GetReadyFuture() const override + { + return VoidFuture; + } + + TFuture<void> Send(TSharedRefArray message, const NBus::TSendOptions& /*options*/) override + { + ClientHandler_->HandleMessage(std::move(message), /*replyBus*/ nullptr); + return VoidFuture; + } + + void SetTosLevel(TTosLevel /*tosLevel*/) override + { } + + void Terminate(const TError& error) override + { + TerminatedList_.Fire(error); + } + + void SubscribeTerminated(const TCallback<void(const TError&)>& callback) override + { + TerminatedList_.Subscribe(callback); + } + + void UnsubscribeTerminated(const TCallback<void(const TError&)>& callback) override + { + TerminatedList_.Unsubscribe(callback); + } + +private: + const NNet::TNetworkAddress LocalAddress_; + const ILocalMessageHandlerPtr ServerHandler_; + const IMessageHandlerPtr ClientHandler_; + const std::string EndpointDescription_; + const IAttributeDictionaryPtr EndpointAttributes_; + const TCallback<void(const TError&)> ServerHandlerTerminatedCallback_; + + TSingleShotCallbackList<void(const TError&)> TerminatedList_; + + void OnServerHandlerTerminated(const TError& error) + { + TerminatedList_.Fire(error); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +IBusPtr CreateLocalBypassReplyBus( + const NNet::TNetworkAddress& localAddress, + ILocalMessageHandlerPtr serverHandler, + IMessageHandlerPtr clientHandler) +{ + return New<TLocalBypassReplyBus>( + localAddress, + std::move(serverHandler), + std::move(clientHandler)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBus + diff --git a/yt/yt/core/bus/tcp/local_bypass.h b/yt/yt/core/bus/tcp/local_bypass.h new file mode 100644 index 0000000000..221e7045e4 --- /dev/null +++ b/yt/yt/core/bus/tcp/local_bypass.h @@ -0,0 +1,32 @@ +#pragma once + +#include "private.h" + +#include <yt/yt/core/bus/bus.h> + +#include <yt/yt/core/net/public.h> + +namespace NYT::NBus { + +//////////////////////////////////////////////////////////////////////////////// + +//! An in-process variant of IMessageHandler that is used implement local bypass. +struct ILocalMessageHandler + : public IMessageHandler +{ + DECLARE_INTERFACE_SIGNAL(void(const TError&), Terminated); +}; + +DEFINE_REFCOUNTED_TYPE(ILocalMessageHandler) + +//////////////////////////////////////////////////////////////////////////////// + +IBusPtr CreateLocalBypassReplyBus( + const NNet::TNetworkAddress& localAddress, + ILocalMessageHandlerPtr serverHandler, + IMessageHandlerPtr clientHandler); + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NBus + diff --git a/yt/yt/core/bus/tcp/private.h b/yt/yt/core/bus/tcp/private.h index a7114a203c..57e627eea2 100644 --- a/yt/yt/core/bus/tcp/private.h +++ b/yt/yt/core/bus/tcp/private.h @@ -6,6 +6,8 @@ namespace NYT::NBus { //////////////////////////////////////////////////////////////////////////////// +DECLARE_REFCOUNTED_STRUCT(ILocalMessageHandler) + constexpr ui32 HandshakeMessageSignature = 0x68737562; //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/bus/tcp/server.cpp b/yt/yt/core/bus/tcp/server.cpp index a447b0a35f..a9fcf0c8a0 100644 --- a/yt/yt/core/bus/tcp/server.cpp +++ b/yt/yt/core/bus/tcp/server.cpp @@ -1,5 +1,6 @@ #include "server.h" #include "config.h" +#include "local_bypass.h" #include "server.h" #include "connection.h" #include "dispatcher_impl.h" @@ -50,18 +51,12 @@ public: , Handler_(std::move(handler)) , PacketTranscoderFactory_(std::move(packetTranscoderFactory)) , MemoryUsageTracker_(std::move(memoryUsageTracker)) + , Logger(MakeLogger(Config_)) { YT_VERIFY(Config_); YT_VERIFY(Poller_); YT_VERIFY(Handler_); YT_VERIFY(MemoryUsageTracker_); - - if (Config_->Port) { - Logger.AddTag("ServerPort: %v", *Config_->Port); - } - if (Config_->UnixDomainSocketPath) { - Logger.AddTag("UnixDomainSocketPath: %v", *Config_->UnixDomainSocketPath); - } } ~TTcpBusServerBase() @@ -72,17 +67,23 @@ public: void Start() { OpenServerSocket(); + if (!Poller_->TryRegister(this)) { CloseServerSocket(); THROW_ERROR_EXCEPTION("Cannot register server pollable"); } + ArmPoller(); + + YT_LOG_INFO("Bus server started"); } TFuture<void> Stop() { YT_LOG_INFO("Stopping Bus server"); + UnarmPoller(); + return Poller_->Unregister(this).Apply(BIND([this, this_ = MakeStrong(this)] { YT_LOG_INFO("Bus server stopped"); })); @@ -94,12 +95,12 @@ public: return Logger.GetTag(); } - void OnEvent(EPollControl /*control*/) override + void OnEvent(EPollControl /*control*/) final { OnAccept(); } - void OnShutdown() override + void OnShutdown() final { { auto guard = Guard(ControlSpinLock_); @@ -123,23 +124,32 @@ protected: const TBusServerConfigPtr Config_; const IPollerPtr Poller_; const IMessageHandlerPtr Handler_; - IPacketTranscoderFactory* const PacketTranscoderFactory_; - const IMemoryUsageTrackerPtr MemoryUsageTracker_; + const NLogging::TLogger Logger; + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, ControlSpinLock_); SOCKET ServerSocket_ = INVALID_SOCKET; YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, ConnectionsSpinLock_); THashSet<TTcpConnectionPtr> Connections_; - NLogging::TLogger Logger = BusLogger(); - virtual void CreateServerSocket() = 0; - virtual void InitClientSocket(SOCKET clientSocket) = 0; + static NLogging::TLogger MakeLogger(const TBusServerConfigPtr& config) + { + auto logger = BusLogger(); + if (config->Port) { + logger.AddTag("ServerPort: %v", *config->Port); + } + if (config->UnixDomainSocketPath) { + logger.AddTag("UnixDomainSocketPath: %v", *config->UnixDomainSocketPath); + } + return logger; + } + void OnConnectionTerminated(const TTcpConnectionPtr& connection, const TError& /*error*/) { auto guard = WriterGuard(ConnectionsSpinLock_); @@ -263,7 +273,7 @@ protected: { auto guard = WriterGuard(ConnectionsSpinLock_); - YT_VERIFY(Connections_.insert(connection).second); + EmplaceOrCrash(Connections_, connection); } connection->SubscribeTerminated(BIND_NO_PROPAGATE( @@ -321,7 +331,7 @@ public: using TTcpBusServerBase::TTcpBusServerBase; private: - void CreateServerSocket() override + void CreateServerSocket() final { ServerSocket_ = CreateTcpServerSocket(); @@ -329,7 +339,7 @@ private: BindSocket(serverAddress, Format("Failed to bind a server socket to port %v", Config_->Port)); } - void InitClientSocket(SOCKET clientSocket) override + void InitClientSocket(SOCKET clientSocket) final { if (Config_->EnableNoDelay) { if (!TrySetSocketNoDelay(clientSocket)) { @@ -364,7 +374,7 @@ public: { } private: - void CreateServerSocket() override + void CreateServerSocket() final { ServerSocket_ = CreateUnixServerSocket(); @@ -381,7 +391,7 @@ private: } } - void InitClientSocket(SOCKET /*clientSocket*/) override + void InitClientSocket(SOCKET /*clientSocket*/) final { } }; @@ -414,7 +424,7 @@ public: YT_UNUSED_FUTURE(Stop()); } - void Start(IMessageHandlerPtr handler) override + void Start(IMessageHandlerPtr handler) final { auto server = New<TServer>( Config_, @@ -427,7 +437,7 @@ public: server->Start(); } - TFuture<void> Stop() override + TFuture<void> Stop() final { if (auto server = Server_.Exchange(nullptr)) { return server->Stop(); @@ -438,9 +448,7 @@ public: private: const TBusServerConfigPtr Config_; - IPacketTranscoderFactory* const PacketTranscoderFactory_; - const IMemoryUsageTrackerPtr MemoryUsageTracker_; TAtomicIntrusivePtr<TServer> Server_; @@ -448,34 +456,91 @@ private: //////////////////////////////////////////////////////////////////////////////// +DECLARE_REFCOUNTED_CLASS(TLocalMessageHandler) + +class TLocalMessageHandler + : public ILocalMessageHandler +{ +public: + explicit TLocalMessageHandler(IMessageHandlerPtr underlying) + : Underlying_(std::move(underlying)) + { } + + void HandleMessage(TSharedRefArray message, IBusPtr replyBus) noexcept final + { + Underlying_->HandleMessage(std::move(message), std::move(replyBus)); + } + + void SubscribeTerminated(const TCallback<void(const TError&)>& callback) final + { + Terminated_.Subscribe(callback); + } + + void UnsubscribeTerminated(const TCallback<void(const TError&)>& callback) final + { + Terminated_.Unsubscribe(callback); + } + + void Terminate(const TError& error) + { + Terminated_.Fire(error); + } + +private: + const IMessageHandlerPtr Underlying_; + + TSingleShotCallbackList<void(const TError&)> Terminated_; +}; + +DEFINE_REFCOUNTED_TYPE(TLocalMessageHandler) + +//////////////////////////////////////////////////////////////////////////////// + class TCompositeBusServer : public IBusServer { public: - explicit TCompositeBusServer(std::vector<IBusServerPtr> servers) - : Servers_(std::move(servers)) + TCompositeBusServer( + TBusServerConfigPtr config, + std::vector<IBusServerPtr> servers) + : Config_(std::move(config)) + , Servers_(std::move(servers)) { } // IBusServer implementation. - void Start(IMessageHandlerPtr handler) override + void Start(IMessageHandlerPtr handler) final { for (const auto& server : Servers_) { server->Start(handler); } + + if (Config_->EnableLocalBypass && Config_->Port) { + LocalHandler_ = New<TLocalMessageHandler>(std::move(handler)); + TTcpDispatcher::TImpl::Get()->RegisterLocalMessageHandler(*Config_->Port, LocalHandler_); + } } - TFuture<void> Stop() override + TFuture<void> Stop() final { - std::vector<TFuture<void>> asyncResults; + if (Config_->EnableLocalBypass && Config_->Port) { + LocalHandler_->Terminate(TError(NRpc::EErrorCode::TransportError, "Local server stopped")); + LocalHandler_.Reset(); + TTcpDispatcher::TImpl::Get()->UnregisterLocalMessageHandler(*Config_->Port); + } + + std::vector<TFuture<void>> futures; for (const auto& server : Servers_) { - asyncResults.push_back(server->Stop()); + futures.push_back(server->Stop()); } - return AllSucceeded(asyncResults); + return AllSucceeded(futures); } private: + const TBusServerConfigPtr Config_; const std::vector<IBusServerPtr> Servers_; + + TLocalMessageHandlerPtr LocalHandler_; }; //////////////////////////////////////////////////////////////////////////////// @@ -494,6 +559,7 @@ IBusServerPtr CreateBusServer( packetTranscoderFactory, memoryUsageTracker)); } + #ifdef _linux_ // Abstract unix sockets are supported only on Linux. servers.push_back( @@ -503,7 +569,9 @@ IBusServerPtr CreateBusServer( memoryUsageTracker)); #endif - return New<TCompositeBusServer>(std::move(servers)); + return New<TCompositeBusServer>( + config, + std::move(servers)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/yt/yt/core/net/address.cpp b/yt/yt/core/net/address.cpp index d048167af3..70a5d78bf5 100644 --- a/yt/yt/core/net/address.cpp +++ b/yt/yt/core/net/address.cpp @@ -1029,10 +1029,12 @@ public: bool IsLocalAddress(const TNetworkAddress& address) { - TNetworkAddress localIP{address, 0}; - + if (!address.IsIP()) { + return false; + } + TNetworkAddress candidateAddress(address, /*port*/ 0); const auto& localAddresses = GetLocalAddresses(); - return std::find(localAddresses.begin(), localAddresses.end(), localIP) != localAddresses.end(); + return std::find(localAddresses.begin(), localAddresses.end(), candidateAddress) != localAddresses.end(); } void PurgeCache() diff --git a/yt/yt/core/rpc/bus/channel.cpp b/yt/yt/core/rpc/bus/channel.cpp index 3a0214e4de..d90dc774d9 100644 --- a/yt/yt/core/rpc/bus/channel.cpp +++ b/yt/yt/core/rpc/bus/channel.cpp @@ -284,8 +284,7 @@ private: void HandleMessage(TSharedRefArray message, IBusPtr replyBus) noexcept override { - auto session_ = Session_.Lock(); - if (session_) { + if (auto session_ = Session_.Lock()) { session_->HandleMessage(std::move(message), std::move(replyBus)); } } diff --git a/yt/yt/core/rpc/unittests/lib/common.h b/yt/yt/core/rpc/unittests/lib/common.h index 9d2f5fca81..5326f017cc 100644 --- a/yt/yt/core/rpc/unittests/lib/common.h +++ b/yt/yt/core/rpc/unittests/lib/common.h @@ -9,6 +9,7 @@ #include <yt/yt/core/bus/tcp/config.h> #include <yt/yt/core/bus/tcp/client.h> +#include <yt/yt/core/bus/tcp/dispatcher.h> #include <yt/yt/core/bus/tcp/server.h> #include <yt/yt/core/crypto/config.h> @@ -75,21 +76,25 @@ class TRpcTestBase public: void SetUp() final { - bool secure = TImpl::Secure; - WorkerPool_ = NConcurrency::CreateThreadPool(4, "Worker"); MemoryUsageTracker_ = New<TTestNodeMemoryTracker>(32_MB); - TestService_ = CreateTestService(WorkerPool_->GetInvoker(), secure, {}, MemoryUsageTracker_); + TestService_ = CreateTestService(WorkerPool_->GetInvoker(), TImpl::Secure, {}, MemoryUsageTracker_); auto services = std::vector<IServicePtr>{ TestService_, - CreateNoBaggageService(WorkerPool_->GetInvoker()) + CreateNoBaggageService(WorkerPool_->GetInvoker()), }; Host_ = TImpl::CreateTestServerHost( NTesting::GetFreePort(), std::move(services), MemoryUsageTracker_); + + // Make sure local bypass is globally enabled. + // Individual tests will toggle per-connection flag to actually enable this feature. + auto config = New<NYT::NBus::TTcpDispatcherConfig>(); + config->EnableLocalBypass = true; + NYT::NBus::TTcpDispatcher::Get()->Configure(config); } void TearDown() final @@ -98,14 +103,13 @@ public: } IChannelPtr CreateChannel( - const std::optional<TString>& address = std::nullopt, + const std::optional<TString>& address = {}, THashMap<TString, NYTree::INodePtr> grpcArguments = {}) { - if (address) { - return TImpl::CreateChannel(*address, Host_->GetAddress(), std::move(grpcArguments)); - } else { - return TImpl::CreateChannel(Host_->GetAddress(), Host_->GetAddress(), std::move(grpcArguments)); - } + return TImpl::CreateChannel( + address.value_or(Host_->GetAddress()), + Host_->GetAddress(), + std::move(grpcArguments)); } TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() const @@ -166,13 +170,14 @@ public: static constexpr bool AllowTransportErrors = false; static constexpr bool Secure = false; static constexpr int MaxSimultaneousRequestCount = 1000; + static constexpr bool MemoryUsageTrackingEnabled = TImpl::MemoryUsageTrackingEnabled; static TTestServerHostPtr CreateTestServerHost( NTesting::TPortHolder port, std::vector<IServicePtr> services, TTestNodeMemoryTrackerPtr memoryUsageTracker) { - auto busServer = MakeBusServer(port, memoryUsageTracker); + auto busServer = CreateBusServer(port, memoryUsageTracker); auto server = NRpc::NBus::CreateBusServer(busServer); return New<TTestServerHost>( @@ -190,34 +195,39 @@ public: return TImpl::CreateChannel(address, serverAddress, std::move(grpcArguments)); } - static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) + static NYT::NBus::IBusServerPtr CreateBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { - return TImpl::MakeBusServer(port, memoryUsageTracker); + return TImpl::CreateBusServer(port, memoryUsageTracker); } }; //////////////////////////////////////////////////////////////////////////////// -template <bool ForceTcp> +template <bool EnableLocalBypass> class TRpcOverBusImpl { public: + static constexpr bool MemoryUsageTrackingEnabled = !EnableLocalBypass; + static IChannelPtr CreateChannel( const std::string& address, const std::string& /*serverAddress*/, THashMap<TString, NYTree::INodePtr> /*grpcArguments*/) { - auto client = CreateBusClient(NYT::NBus::TBusClientConfig::CreateTcp(address)); - return NRpc::NBus::CreateBusChannel(client); + auto config = NYT::NBus::TBusClientConfig::CreateTcp(address); + config->EnableLocalBypass = EnableLocalBypass; + auto client = CreateBusClient(std::move(config)); + return NRpc::NBus::CreateBusChannel(std::move(client)); } - static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) + static NYT::NBus::IBusServerPtr CreateBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { - auto busConfig = NYT::NBus::TBusServerConfig::CreateTcp(port); - return CreateBusServer( - busConfig, + auto config = NYT::NBus::TBusServerConfig::CreateTcp(port); + config->EnableLocalBypass = EnableLocalBypass; + return NYT::NBus::CreateBusServer( + std::move(config), NYT::NBus::GetYTPacketTranscoderFactory(), - memoryUsageTracker); + std::move(memoryUsageTracker)); } }; @@ -439,12 +449,14 @@ public: class TRpcOverUdsImpl { public: - static NYT::NBus::IBusServerPtr MakeBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) + static constexpr bool MemoryUsageTrackingEnabled = true; + + static NYT::NBus::IBusServerPtr CreateBusServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) { SocketPath_ = GetWorkPath() + "/socket_" + ToString(port); - auto busConfig = NYT::NBus::TBusServerConfig::CreateUds(SocketPath_); - return CreateBusServer( - busConfig, + auto config = NYT::NBus::TBusServerConfig::CreateUds(SocketPath_); + return NYT::NBus::CreateBusServer( + config, NYT::NBus::GetYTPacketTranscoderFactory(), memoryUsageTracker); } @@ -454,9 +466,9 @@ public: const std::string& serverAddress, THashMap<TString, NYTree::INodePtr> /*grpcArguments*/) { - auto clientConfig = NYT::NBus::TBusClientConfig::CreateUds( + auto config = NYT::NBus::TBusClientConfig::CreateUds( address == serverAddress ? SocketPath_ : address); - auto client = CreateBusClient(clientConfig); + auto client = CreateBusClient(config); return NRpc::NBus::CreateBusChannel(client); } @@ -530,9 +542,9 @@ public: using TAllTransports = ::testing::Types< #ifdef _linux_ TRpcOverBus<TRpcOverUdsImpl>, - TRpcOverBus<TRpcOverBusImpl<true>>, #endif TRpcOverBus<TRpcOverBusImpl<false>>, + TRpcOverBus<TRpcOverBusImpl<true>>, TRpcOverGrpcImpl<false, false>, TRpcOverGrpcImpl<false, true>, TRpcOverGrpcImpl<true, false>, @@ -544,9 +556,9 @@ using TAllTransports = ::testing::Types< using TWithAttachments = ::testing::Types< #ifdef _linux_ TRpcOverBus<TRpcOverUdsImpl>, - TRpcOverBus<TRpcOverBusImpl<true>>, #endif TRpcOverBus<TRpcOverBusImpl<false>>, + TRpcOverBus<TRpcOverBusImpl<true>>, TRpcOverGrpcImpl<false, false>, TRpcOverGrpcImpl<false, true>, TRpcOverGrpcImpl<true, false>, @@ -554,10 +566,8 @@ using TWithAttachments = ::testing::Types< >; using TWithoutUds = ::testing::Types< -#ifdef _linux_ - TRpcOverBus<TRpcOverBusImpl<true>>, -#endif TRpcOverBus<TRpcOverBusImpl<false>>, + TRpcOverBus<TRpcOverBusImpl<true>>, TRpcOverGrpcImpl<false, false>, TRpcOverGrpcImpl<true, false>, TRpcOverHttpImpl<false>, @@ -567,9 +577,9 @@ using TWithoutUds = ::testing::Types< using TWithoutGrpc = ::testing::Types< #ifdef _linux_ TRpcOverBus<TRpcOverUdsImpl>, - TRpcOverBus<TRpcOverBusImpl<true>>, #endif - TRpcOverBus<TRpcOverBusImpl<false>> + TRpcOverBus<TRpcOverBusImpl<false>>, + TRpcOverBus<TRpcOverBusImpl<true>> >; using TGrpcOnly = ::testing::Types< diff --git a/yt/yt/core/rpc/unittests/rpc_ut.cpp b/yt/yt/core/rpc/unittests/rpc_ut.cpp index e3ce7376ee..94e8d13e7a 100644 --- a/yt/yt/core/rpc/unittests/rpc_ut.cpp +++ b/yt/yt/core/rpc/unittests/rpc_ut.cpp @@ -508,9 +508,9 @@ TYPED_TEST(TAttachmentsTest, RegularAttachments) const auto& attachments = rsp->Attachments(); EXPECT_EQ(3u, attachments.size()); - EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0])); - EXPECT_EQ("from_", StringFromSharedRef(attachments[1])); - EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2])); + EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0])); + EXPECT_EQ("from_", StringFromSharedRef(attachments[1])); + EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2])); } TYPED_TEST(TNotGrpcTest, TrackedRegularAttachments) @@ -536,11 +536,13 @@ TYPED_TEST(TNotGrpcTest, TrackedRegularAttachments) // header + body = 79 bytes. // attachments = 22 bytes. // sum is 4219 bytes. - EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4197 + 32768 + std::ssize(GetRpcUserAgent())); + if (TypeParam::MemoryUsageTrackingEnabled) { + EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4197 + 32768 + std::ssize(GetRpcUserAgent())); + } EXPECT_EQ(3u, attachments.size()); - EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0])); - EXPECT_EQ("from_", StringFromSharedRef(attachments[1])); - EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2])); + EXPECT_EQ("Hello_", StringFromSharedRef(attachments[0])); + EXPECT_EQ("from_", StringFromSharedRef(attachments[1])); + EXPECT_EQ("TTestProxy_", StringFromSharedRef(attachments[2])); } TYPED_TEST(TAttachmentsTest, NullAndEmptyAttachments) @@ -602,7 +604,9 @@ TYPED_TEST(TNotGrpcTest, Compression) // attachmentStrings[1].size() = 36 * 2 bytes from decoder. // attachmentStrings[2].size() = 90 * 2 bytes from decoder. // sum is 4584 bytes. - EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4562 + 32768 + std::ssize(GetRpcUserAgent())); + if (TypeParam::MemoryUsageTrackingEnabled) { + EXPECT_GE(memoryUsageTracker->GetTotalUsage(), 4562 + 32768 + std::ssize(GetRpcUserAgent())); + } EXPECT_TRUE(rsp->message() == message); EXPECT_GE(rsp->GetResponseMessage().Size(), static_cast<size_t>(2)); const auto& serializedResponseBody = SerializeProtoToRefWithCompression(*rsp, responseCodecId); @@ -828,7 +832,7 @@ TYPED_TEST(TNotGrpcTest, MemoryTracking) Sleep(TDuration::MilliSeconds(200)); - { + if (TypeParam::MemoryUsageTrackingEnabled) { auto rpcUsage = memoryUsageTracker->GetTotalUsage(); // 1285268 = 32768 + 1252500 = 32768 + 4096 * 300 + 300 * 79 (header + body). @@ -849,7 +853,7 @@ TYPED_TEST(TNotGrpcTest, MemoryTrackingMultipleConnections) WaitFor(req->Invoke().AsVoid()).ThrowOnError(); } - { + if (TypeParam::MemoryUsageTrackingEnabled) { // 11082900 / 300 = 36974 = 32768 + 4096 + 79 (header + body). // 4 KB - stub for request. // See NYT::NBus::TPacketDecoder::TChunkedMemoryTrackingAllocator::Allocate. @@ -877,7 +881,7 @@ TYPED_TEST(TNotGrpcTest, MemoryTrackingMultipleConcurrent) Sleep(TDuration::MilliSeconds(100)); - { + if (TypeParam::MemoryUsageTrackingEnabled) { auto rpcUsage = memoryUsageTracker->GetUsed(); // connections count - per connection size. @@ -901,7 +905,8 @@ TYPED_TEST(TNotGrpcTest, MemoryOvercommit) req->set_request_codec(ToProto(requestCodecId)); req->Attachments().push_back(TSharedRef::FromString(TString(6_KB, 'x'))); WaitFor(req->Invoke()).ThrowOnError(); - { + + if (TypeParam::MemoryUsageTrackingEnabled) { auto rpcUsage = memoryReferenceUsageTracker->GetTotalUsage(); // Attachment allocator proactively allocate slice of 4 KB. diff --git a/yt/yt/core/ya.make b/yt/yt/core/ya.make index a47b5ddfbf..857e56920b 100644 --- a/yt/yt/core/ya.make +++ b/yt/yt/core/ya.make @@ -30,6 +30,7 @@ SRCS( GLOBAL bus/tcp/configure_dispatcher.cpp bus/tcp/packet.cpp bus/tcp/client.cpp + bus/tcp/local_bypass.cpp bus/tcp/server.cpp bus/tcp/ssl_context.cpp bus/tcp/ssl_helpers.cpp |