diff options
author | robot-piglet <robot-piglet@yandex-team.com> | 2024-05-14 15:28:27 +0300 |
---|---|---|
committer | robot-piglet <robot-piglet@yandex-team.com> | 2024-05-14 15:37:47 +0300 |
commit | a5eea9faefdbd1611f0efe300a49631ecab2abcf (patch) | |
tree | 62e7c333884815d354a52522ee0fa6f24adf3d60 | |
parent | afe084d86e0cab46922104f31fbc88625b54230f (diff) | |
download | ydb-a5eea9faefdbd1611f0efe300a49631ecab2abcf.tar.gz |
Intermediate changes
-rw-r--r-- | yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp | 85 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/lib/common.cpp | 126 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/lib/common.h | 211 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/roaming_channel_ut.cpp | 2 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp | 2 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/rpc_shutdown_ut.cpp | 2 | ||||
-rw-r--r-- | yt/yt/core/rpc/unittests/rpc_ut.cpp | 10 | ||||
-rw-r--r-- | yt/yt/core/test_framework/test_memory_tracker.cpp | 133 | ||||
-rw-r--r-- | yt/yt/core/test_framework/test_memory_tracker.h | 82 | ||||
-rw-r--r-- | yt/yt/core/test_framework/test_server_host.cpp | 58 | ||||
-rw-r--r-- | yt/yt/core/test_framework/test_server_host.h | 46 | ||||
-rw-r--r-- | yt/yt/core/test_framework/ya.make | 2 |
12 files changed, 429 insertions, 330 deletions
diff --git a/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp b/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp index fbd523cc63..f927969c24 100644 --- a/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp +++ b/yt/yt/core/rpc/unittests/handle_channel_failure_ut.cpp @@ -12,13 +12,15 @@ class THandleChannelFailureTestBase : public ::testing::Test { public: - IServerPtr CreateServer( - const TTestServerHost& serverHost, - IMemoryUsageTrackerPtr memoryUsageTracker) + TTestServerHostPtr CreateTestServerHost( + NTesting::TPortHolder port, + std::vector<IServicePtr> services, + TTestNodeMemoryTrackerPtr memoryUsageTracker) { - return TImpl::CreateServer( - serverHost.GetPort(), - memoryUsageTracker); + return TImpl::CreateTestServerHost( + std::move(port), + std::move(services), + std::move(memoryUsageTracker)); } IChannelPtr CreateChannel(const TString& address) @@ -36,39 +38,50 @@ TYPED_TEST_SUITE(THandleChannelFailureTest, TWithoutUds); TYPED_TEST(THandleChannelFailureTest, HandleChannelFailureTest) { - TTestServerHost outerServer; - TTestServerHost innerServer; + auto workerPool = NConcurrency::CreateThreadPool(4, "Worker"); - outerServer.InitilizeAddress(); - innerServer.InitilizeAddress(); + auto outerMemoryUsageTracker = New<TTestNodeMemoryTracker>(32_MB); + auto innerMemoryUsageTracker = New<TTestNodeMemoryTracker>(32_MB); + + auto outerServices = std::vector<IServicePtr>{ + CreateTestService( + workerPool->GetInvoker(), + false, + BIND([&] (const TString& address) { + return this->CreateChannel(address); + }), + outerMemoryUsageTracker), + CreateNoBaggageService(workerPool->GetInvoker()) + }; + + auto innerServices = std::vector<IServicePtr>{ + CreateTestService( + workerPool->GetInvoker(), + false, + BIND([&] (const TString& address) { + return this->CreateChannel(address); + }), + innerMemoryUsageTracker), + CreateNoBaggageService(workerPool->GetInvoker()) + }; + + TTestServerHostPtr outerHost = this->CreateTestServerHost( + NTesting::GetFreePort(), + outerServices, + outerMemoryUsageTracker); + + TTestServerHostPtr innerHost = this->CreateTestServerHost( + NTesting::GetFreePort(), + innerServices, + innerMemoryUsageTracker); auto finally = Finally([&] { - outerServer.TearDown(); - innerServer.TearDown(); + outerHost->TearDown(); + innerHost->TearDown(); }); - auto workerPool = NConcurrency::CreateThreadPool(4, "Worker"); - - outerServer.InitializeServer( - this->CreateServer( - outerServer, - outerServer.GetMemoryUsageTracker()), - workerPool->GetInvoker(), - /*secure*/ false, - BIND([&] (const TString& address) { - return this->CreateChannel(address); - })); - - innerServer.InitializeServer( - this->CreateServer( - innerServer, - innerServer.GetMemoryUsageTracker()), - workerPool->GetInvoker(), - /*secure*/ false, - /*createChannel*/ {}); - { - auto channel = this->CreateChannel(outerServer.GetAddress()); + auto channel = this->CreateChannel(outerHost->GetAddress()); TTestProxy proxy(channel); auto req = proxy.GetChannelFailureError(); auto error = req->Invoke().Get(); @@ -81,7 +94,7 @@ TYPED_TEST(THandleChannelFailureTest, HandleChannelFailureTest) int failCount = 0; auto channel = CreateFailureDetectingChannel( - this->CreateChannel(outerServer.GetAddress()), + this->CreateChannel(outerHost->GetAddress()), /*acknowledgementTimeout*/ std::nullopt, BIND([&] (const IChannelPtr& /*channel*/, const TError& error) { ++failCount; @@ -101,7 +114,7 @@ TYPED_TEST(THandleChannelFailureTest, HandleChannelFailureTest) int failCount = 0; auto channel = CreateFailureDetectingChannel( - this->CreateChannel(outerServer.GetAddress()), + this->CreateChannel(outerHost->GetAddress()), /*acknowledgementTimeout*/ std::nullopt, BIND([&] (const IChannelPtr& /*channel*/, const TError& error) { ++failCount; @@ -110,7 +123,7 @@ TYPED_TEST(THandleChannelFailureTest, HandleChannelFailureTest) TTestProxy proxy(channel); auto req = proxy.GetChannelFailureError(); - req->set_redirection_address(innerServer.GetAddress()); + req->set_redirection_address(innerHost->GetAddress()); auto error = req->Invoke().Get(); ASSERT_FALSE(error.IsOK()); ASSERT_TRUE(error.FindMatching(NRpc::EErrorCode::Unavailable)); diff --git a/yt/yt/core/rpc/unittests/lib/common.cpp b/yt/yt/core/rpc/unittests/lib/common.cpp index a1f72f5c07..d906d82fd4 100644 --- a/yt/yt/core/rpc/unittests/lib/common.cpp +++ b/yt/yt/core/rpc/unittests/lib/common.cpp @@ -8,130 +8,4 @@ TString TRpcOverUdsImpl::SocketPath_ = ""; //////////////////////////////////////////////////////////////////////////////// -TTestNodeMemoryTracker::TTestNodeMemoryTracker(size_t limit) - : Usage_(0) - , Limit_(limit) -{ } - -i64 TTestNodeMemoryTracker::GetLimit() const -{ - auto guard = Guard(Lock_); - return Limit_; -} - -i64 TTestNodeMemoryTracker::GetUsed() const -{ - auto guard = Guard(Lock_); - return Usage_; -} - -i64 TTestNodeMemoryTracker::GetFree() const -{ - auto guard = Guard(Lock_); - return Limit_ - Usage_; -} - -bool TTestNodeMemoryTracker::IsExceeded() const -{ - auto guard = Guard(Lock_); - return Limit_ - Usage_ <= 0; -} - -TError TTestNodeMemoryTracker::TryAcquire(i64 size) -{ - auto guard = Guard(Lock_); - return DoTryAcquire(size); -} - -TError TTestNodeMemoryTracker::DoTryAcquire(i64 size) -{ - if (Usage_ + size >= Limit_) { - return TError("Memory exceeded"); - } - - Usage_ += size; - TotalUsage_ += size; - - return {}; -} - -TError TTestNodeMemoryTracker::TryChange(i64 size) -{ - auto guard = Guard(Lock_); - - if (size > Usage_) { - return DoTryAcquire(size - Usage_); - } else if (size < Usage_) { - DoRelease(Usage_ - size); - } - - return {}; -} - -bool TTestNodeMemoryTracker::Acquire(i64 size) -{ - auto guard = Guard(Lock_); - DoAcquire(size); - return Usage_ >= Limit_; -} - -void TTestNodeMemoryTracker::Release(i64 size) -{ - auto guard = Guard(Lock_); - DoRelease(size); -} - -void TTestNodeMemoryTracker::SetLimit(i64 size) -{ - auto guard = Guard(Lock_); - Limit_ = size; -} - -void TTestNodeMemoryTracker::DoAcquire(i64 size) -{ - Usage_ += size; - TotalUsage_ += size; -} - -void TTestNodeMemoryTracker::DoRelease(i64 size) -{ - Usage_ -= size; -} - -void TTestNodeMemoryTracker::ClearTotalUsage() -{ - TotalUsage_ = 0; -} - -i64 TTestNodeMemoryTracker::GetTotalUsage() const -{ - return TotalUsage_; -} - -TSharedRef TTestNodeMemoryTracker::Track(TSharedRef reference, bool keepExistingTracking) -{ - if (!reference) { - return reference; - } - - auto rawReference = TRef(reference); - const auto& holder = reference.GetHolder(); - - // Reference could be without a holder, e.g. empty reference. - if (!holder) { - YT_VERIFY(reference.Begin() == TRef::MakeEmpty().Begin()); - return reference; - } - - auto guard = TMemoryUsageTrackerGuard::Acquire(this, reference.Size()); - - auto underlyingHolder = holder->Clone({.KeepMemoryReferenceTracking = keepExistingTracking}); - auto underlyingReference = TSharedRef(rawReference, std::move(underlyingHolder)); - return TSharedRef( - rawReference, - New<TTestTrackedReferenceHolder>(std::move(underlyingReference), std::move(guard))); -} - -//////////////////////////////////////////////////////////////////////////////// - } // namespace NYT::NRpc diff --git a/yt/yt/core/rpc/unittests/lib/common.h b/yt/yt/core/rpc/unittests/lib/common.h index f87c544c20..a10f442f26 100644 --- a/yt/yt/core/rpc/unittests/lib/common.h +++ b/yt/yt/core/rpc/unittests/lib/common.h @@ -1,6 +1,8 @@ #pragma once #include <yt/yt/core/test_framework/framework.h> +#include <yt/yt/core/test_framework/test_memory_tracker.h> +#include <yt/yt/core/test_framework/test_server_host.h> #include <yt/yt/core/bus/bus.h> #include <yt/yt/core/bus/server.h> @@ -53,168 +55,37 @@ #include <library/cpp/testing/common/env.h> #include <library/cpp/testing/common/network.h> -#include <random> - namespace NYT::NRpc { //////////////////////////////////////////////////////////////////////////////// -class TTestNodeMemoryTracker - : public IMemoryUsageTracker -{ -public: - explicit TTestNodeMemoryTracker(size_t limit); - - i64 GetLimit() const override; - i64 GetUsed() const override; - i64 GetFree() const override; - bool IsExceeded() const override; - - TError TryAcquire(i64 size) override; - TError TryChange(i64 size) override; - bool Acquire(i64 size) override; - void Release(i64 size) override; - void SetLimit(i64 size) override; - - void ClearTotalUsage(); - i64 GetTotalUsage() const; - - TSharedRef Track( - TSharedRef reference, - bool keepHolder = false) override; -private: - class TTestTrackedReferenceHolder - : public TSharedRangeHolder - { - public: - TTestTrackedReferenceHolder( - TSharedRef underlying, - TMemoryUsageTrackerGuard guard) - : Underlying_(std::move(underlying)) - , Guard_(std::move(guard)) - { } - - TSharedRangeHolderPtr Clone(const TSharedRangeHolderCloneOptions& options) override - { - if (options.KeepMemoryReferenceTracking) { - return this; - } - return Underlying_.GetHolder()->Clone(options); - } - - std::optional<size_t> GetTotalByteSize() const override - { - return Underlying_.GetHolder()->GetTotalByteSize(); - } - - private: - const TSharedRef Underlying_; - const TMemoryUsageTrackerGuard Guard_; - }; - - YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, Lock_); - i64 Usage_; - i64 Limit_; - i64 TotalUsage_; - - TError DoTryAcquire(i64 size); - void DoAcquire(i64 size); - void DoRelease(i64 size); -}; - -DECLARE_REFCOUNTED_CLASS(TTestNodeMemoryTracker) -DEFINE_REFCOUNTED_TYPE(TTestNodeMemoryTracker) - -//////////////////////////////////////////////////////////////////////////////// - -class TTestServerHost -{ -public: - void InitilizeAddress() - { - Port_ = NTesting::GetFreePort(); - Address_ = Format("localhost:%v", Port_); - MemoryUsageTracker_ = New<TTestNodeMemoryTracker>(32_MB); - } - - void InitializeServer( - IServerPtr server, - const IInvokerPtr& invoker, - bool secure, - TTestCreateChannelCallback createChannel) - { - Server_ = server; - TestService_ = CreateTestService(invoker, secure, createChannel, MemoryUsageTracker_); - NoBaggageService_ = CreateNoBaggageService(invoker); - Server_->RegisterService(TestService_); - Server_->RegisterService(NoBaggageService_); - Server_->Start(); - } - - void TearDown() - { - Server_->Stop().Get().ThrowOnError(); - Server_.Reset(); - } - - const NTesting::TPortHolder& GetPort() const - { - return Port_; - } - - TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() - { - return MemoryUsageTracker_; - } - - TString GetAddress() const - { - return Address_; - } - - ITestServicePtr GetTestService() - { - return TestService_; - } - - IServerPtr GetServer() - { - return Server_; - } - -protected: - NTesting::TPortHolder Port_; - TString Address_; - - ITestServicePtr TestService_; - IServicePtr NoBaggageService_; - IServerPtr Server_; - TTestNodeMemoryTrackerPtr MemoryUsageTracker_; -}; - -//////////////////////////////////////////////////////////////////////////////// - template <class TImpl> -class TTestBase +class TRpcTestBase : public ::testing::Test { public: void SetUp() final { - Host_.InitilizeAddress(); + bool secure = TImpl::Secure; WorkerPool_ = NConcurrency::CreateThreadPool(4, "Worker"); - bool secure = TImpl::Secure; - Host_.InitializeServer( - TImpl::CreateServer(Host_.GetPort(), Host_.GetMemoryUsageTracker()), - WorkerPool_->GetInvoker(), - secure, - /*createChannel*/ {}); + MemoryUsageTracker_ = New<TTestNodeMemoryTracker>(32_MB); + TestService_ = CreateTestService(WorkerPool_->GetInvoker(), secure, {}, MemoryUsageTracker_); + + auto services = std::vector<IServicePtr>{ + TestService_, + CreateNoBaggageService(WorkerPool_->GetInvoker()) + }; + + Host_ = TImpl::CreateTestServerHost( + NTesting::GetFreePort(), + std::move(services), + MemoryUsageTracker_); } void TearDown() final { - Host_.TearDown(); + Host_->TearDown(); } IChannelPtr CreateChannel( @@ -222,25 +93,25 @@ public: THashMap<TString, NYTree::INodePtr> grpcArguments = {}) { if (address) { - return TImpl::CreateChannel(*address, Host_.GetAddress(), std::move(grpcArguments)); + return TImpl::CreateChannel(*address, Host_->GetAddress(), std::move(grpcArguments)); } else { - return TImpl::CreateChannel(Host_.GetAddress(), Host_.GetAddress(), std::move(grpcArguments)); + return TImpl::CreateChannel(Host_->GetAddress(), Host_->GetAddress(), std::move(grpcArguments)); } } - TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() + TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() const { - return Host_.GetMemoryUsageTracker(); + return Host_->GetMemoryUsageTracker(); } - ITestServicePtr GetTestService() + ITestServicePtr GetTestService() const { - return Host_.GetTestService(); + return TestService_; } - IServerPtr GetServer() + IServerPtr GetServer() const { - return Host_.GetServer(); + return Host_->GetServer(); } static bool CheckCancelCode(TErrorCode code) @@ -267,7 +138,9 @@ public: private: NConcurrency::IThreadPoolPtr WorkerPool_; - TTestServerHost Host_; + TTestNodeMemoryTrackerPtr MemoryUsageTracker_; + TTestServerHostPtr Host_; + ITestServicePtr TestService_; }; //////////////////////////////////////////////////////////////////////////////// @@ -279,10 +152,19 @@ public: static constexpr bool AllowTransportErrors = false; static constexpr bool Secure = false; - static IServerPtr CreateServer(ui16 port, IMemoryUsageTrackerPtr memoryUsageTracker) + static TTestServerHostPtr CreateTestServerHost( + NTesting::TPortHolder port, + std::vector<IServicePtr> services, + TTestNodeMemoryTrackerPtr memoryUsageTracker) { auto busServer = MakeBusServer(port, memoryUsageTracker); - return NRpc::NBus::CreateBusServer(busServer); + auto server = NRpc::NBus::CreateBusServer(busServer); + + return New<TTestServerHost>( + std::move(port), + server, + services, + memoryUsageTracker); } static IChannelPtr CreateChannel( @@ -500,9 +382,10 @@ public: return NGrpc::CreateGrpcChannel(channelConfig); } - static IServerPtr CreateServer( - ui16 port, - IMemoryUsageTrackerPtr /*memoryUsageTracker*/) + static TTestServerHostPtr CreateTestServerHost( + NTesting::TPortHolder port, + std::vector<IServicePtr> services, + TTestNodeMemoryTrackerPtr memoryUsageTracker) { auto serverAddressConfig = New<NGrpc::TServerAddressConfig>(); if (EnableSsl) { @@ -524,7 +407,13 @@ public: auto serverConfig = New<NGrpc::TServerConfig>(); serverConfig->Addresses.push_back(serverAddressConfig); - return NGrpc::CreateServer(serverConfig); + + auto server = NGrpc::CreateServer(serverConfig); + return New<TTestServerHost>( + std::move(port), + std::move(server), + std::move(services), + std::move(memoryUsageTracker)); } }; diff --git a/yt/yt/core/rpc/unittests/roaming_channel_ut.cpp b/yt/yt/core/rpc/unittests/roaming_channel_ut.cpp index de5fdfbfc5..3c1d214ab9 100644 --- a/yt/yt/core/rpc/unittests/roaming_channel_ut.cpp +++ b/yt/yt/core/rpc/unittests/roaming_channel_ut.cpp @@ -134,7 +134,7 @@ private: }; template <class TImpl> -using TRpcTest = TTestBase<TImpl>; +using TRpcTest = TRpcTestBase<TImpl>; TYPED_TEST_SUITE(TRpcTest, TAllTransports); TYPED_TEST(TRpcTest, RoamingChannelNever) diff --git a/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp b/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp index f8bc40e8ed..b2ad0c5ddf 100644 --- a/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp +++ b/yt/yt/core/rpc/unittests/rpc_allocation_tags_ut.cpp @@ -24,7 +24,7 @@ constexpr auto MemoryAllocationTag = "memory_tag"; //////////////////////////////////////////////////////////////////////////////// template <class TImpl> -using TRpcTest = TTestBase<TImpl>; +using TRpcTest = TRpcTestBase<TImpl>; TYPED_TEST_SUITE(TRpcTest, TAllTransports); TYPED_TEST(TRpcTest, ResponseWithAllocationTags) diff --git a/yt/yt/core/rpc/unittests/rpc_shutdown_ut.cpp b/yt/yt/core/rpc/unittests/rpc_shutdown_ut.cpp index 5a2d043847..cbd48d41cb 100644 --- a/yt/yt/core/rpc/unittests/rpc_shutdown_ut.cpp +++ b/yt/yt/core/rpc/unittests/rpc_shutdown_ut.cpp @@ -8,7 +8,7 @@ namespace NYT::NRpc { namespace { template <class TImpl> -using TRpcShutdownTest = TTestBase<TImpl>; +using TRpcShutdownTest = TRpcTestBase<TImpl>; TYPED_TEST_SUITE(TRpcShutdownTest, TAllTransports); diff --git a/yt/yt/core/rpc/unittests/rpc_ut.cpp b/yt/yt/core/rpc/unittests/rpc_ut.cpp index 7bd52df717..f4c7cf1338 100644 --- a/yt/yt/core/rpc/unittests/rpc_ut.cpp +++ b/yt/yt/core/rpc/unittests/rpc_ut.cpp @@ -1,5 +1,7 @@ #include <yt/yt/core/rpc/unittests/lib/common.h> +#include <random> + namespace NYT::NRpc { namespace { @@ -41,13 +43,13 @@ TString StringFromSharedRef(const TSharedRef& sharedRef) //////////////////////////////////////////////////////////////////////////////// template <class TImpl> -using TRpcTest = TTestBase<TImpl>; +using TRpcTest = TRpcTestBase<TImpl>; template <class TImpl> -using TNotUdsTest = TTestBase<TImpl>; +using TNotUdsTest = TRpcTestBase<TImpl>; template <class TImpl> -using TNotGrpcTest = TTestBase<TImpl>; +using TNotGrpcTest = TRpcTestBase<TImpl>; template <class TImpl> -using TGrpcTest = TTestBase<TImpl>; +using TGrpcTest = TRpcTestBase<TImpl>; TYPED_TEST_SUITE(TRpcTest, TAllTransports); TYPED_TEST_SUITE(TNotUdsTest, TWithoutUds); TYPED_TEST_SUITE(TNotGrpcTest, TWithoutGrpc); diff --git a/yt/yt/core/test_framework/test_memory_tracker.cpp b/yt/yt/core/test_framework/test_memory_tracker.cpp new file mode 100644 index 0000000000..e6ba6b8786 --- /dev/null +++ b/yt/yt/core/test_framework/test_memory_tracker.cpp @@ -0,0 +1,133 @@ +#include "test_memory_tracker.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TTestNodeMemoryTracker::TTestNodeMemoryTracker(size_t limit) + : Usage_(0) + , Limit_(limit) +{ } + +i64 TTestNodeMemoryTracker::GetLimit() const +{ + auto guard = Guard(Lock_); + return Limit_; +} + +i64 TTestNodeMemoryTracker::GetUsed() const +{ + auto guard = Guard(Lock_); + return Usage_; +} + +i64 TTestNodeMemoryTracker::GetFree() const +{ + auto guard = Guard(Lock_); + return Limit_ - Usage_; +} + +bool TTestNodeMemoryTracker::IsExceeded() const +{ + auto guard = Guard(Lock_); + return Limit_ - Usage_ <= 0; +} + +TError TTestNodeMemoryTracker::TryAcquire(i64 size) +{ + auto guard = Guard(Lock_); + return DoTryAcquire(size); +} + +TError TTestNodeMemoryTracker::DoTryAcquire(i64 size) +{ + if (Usage_ + size >= Limit_) { + return TError("Memory exceeded"); + } + + Usage_ += size; + TotalUsage_ += size; + + return {}; +} + +TError TTestNodeMemoryTracker::TryChange(i64 size) +{ + auto guard = Guard(Lock_); + + if (size > Usage_) { + return DoTryAcquire(size - Usage_); + } else if (size < Usage_) { + DoRelease(Usage_ - size); + } + + return {}; +} + +bool TTestNodeMemoryTracker::Acquire(i64 size) +{ + auto guard = Guard(Lock_); + DoAcquire(size); + return Usage_ >= Limit_; +} + +void TTestNodeMemoryTracker::Release(i64 size) +{ + auto guard = Guard(Lock_); + DoRelease(size); +} + +void TTestNodeMemoryTracker::SetLimit(i64 size) +{ + auto guard = Guard(Lock_); + Limit_ = size; +} + +void TTestNodeMemoryTracker::DoAcquire(i64 size) +{ + Usage_ += size; + TotalUsage_ += size; +} + +void TTestNodeMemoryTracker::DoRelease(i64 size) +{ + Usage_ -= size; +} + +void TTestNodeMemoryTracker::ClearTotalUsage() +{ + TotalUsage_ = 0; +} + +i64 TTestNodeMemoryTracker::GetTotalUsage() const +{ + return TotalUsage_; +} + +TSharedRef TTestNodeMemoryTracker::Track(TSharedRef reference, bool keepExistingTracking) +{ + if (!reference) { + return reference; + } + + auto rawReference = TRef(reference); + const auto& holder = reference.GetHolder(); + + // Reference could be without a holder, e.g. empty reference. + if (!holder) { + YT_VERIFY(reference.Begin() == TRef::MakeEmpty().Begin()); + return reference; + } + + auto guard = TMemoryUsageTrackerGuard::Acquire(this, reference.Size()); + + auto underlyingHolder = holder->Clone({.KeepMemoryReferenceTracking = keepExistingTracking}); + auto underlyingReference = TSharedRef(rawReference, std::move(underlyingHolder)); + return TSharedRef( + rawReference, + New<TTestTrackedReferenceHolder>(std::move(underlyingReference), std::move(guard))); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/core/test_framework/test_memory_tracker.h b/yt/yt/core/test_framework/test_memory_tracker.h new file mode 100644 index 0000000000..1ef5f635b1 --- /dev/null +++ b/yt/yt/core/test_framework/test_memory_tracker.h @@ -0,0 +1,82 @@ +#pragma once + +#include <yt/yt/core/test_framework/framework.h> + +#include <yt/yt/core/bus/public.h> + +#include <yt/yt/core/misc/fs.h> +#include <yt/yt/core/misc/memory_usage_tracker.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TTestNodeMemoryTracker + : public IMemoryUsageTracker +{ +public: + explicit TTestNodeMemoryTracker(size_t limit); + + i64 GetLimit() const override; + i64 GetUsed() const override; + i64 GetFree() const override; + bool IsExceeded() const override; + + TError TryAcquire(i64 size) override; + TError TryChange(i64 size) override; + bool Acquire(i64 size) override; + void Release(i64 size) override; + void SetLimit(i64 size) override; + + void ClearTotalUsage(); + i64 GetTotalUsage() const; + + TSharedRef Track( + TSharedRef reference, + bool keepHolder = false) override; +private: + class TTestTrackedReferenceHolder + : public TSharedRangeHolder + { + public: + TTestTrackedReferenceHolder( + TSharedRef underlying, + TMemoryUsageTrackerGuard guard) + : Underlying_(std::move(underlying)) + , Guard_(std::move(guard)) + { } + + TSharedRangeHolderPtr Clone(const TSharedRangeHolderCloneOptions& options) override + { + if (options.KeepMemoryReferenceTracking) { + return this; + } + return Underlying_.GetHolder()->Clone(options); + } + + std::optional<size_t> GetTotalByteSize() const override + { + return Underlying_.GetHolder()->GetTotalByteSize(); + } + + private: + const TSharedRef Underlying_; + const TMemoryUsageTrackerGuard Guard_; + }; + + YT_DECLARE_SPIN_LOCK(NThreading::TSpinLock, Lock_); + i64 Usage_; + i64 Limit_; + i64 TotalUsage_; + + TError DoTryAcquire(i64 size); + void DoAcquire(i64 size); + void DoRelease(i64 size); +}; + +DECLARE_REFCOUNTED_CLASS(TTestNodeMemoryTracker) +DEFINE_REFCOUNTED_TYPE(TTestNodeMemoryTracker) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/core/test_framework/test_server_host.cpp b/yt/yt/core/test_framework/test_server_host.cpp new file mode 100644 index 0000000000..98977b3f64 --- /dev/null +++ b/yt/yt/core/test_framework/test_server_host.cpp @@ -0,0 +1,58 @@ +#include "test_server_host.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +TTestServerHost::TTestServerHost( + NTesting::TPortHolder port, + NRpc::IServerPtr server, + std::vector<NRpc::IServicePtr> services, + TTestNodeMemoryTrackerPtr memoryUsageTracker) + : Port_(std::move(port)) + , Server_(std::move(server)) + , Services_(std::move(services)) + , MemoryUsageTracker_(std::move(memoryUsageTracker)) +{ + InitializeServer(); +} + +void TTestServerHost::InitializeServer() +{ + for (const auto& service : Services_) { + Server_->RegisterService(service); + } + + Server_->Start(); +} + +void TTestServerHost::TearDown() +{ + Server_->Stop().Get().ThrowOnError(); + Server_.Reset(); + Port_.Reset(); +} + +TTestNodeMemoryTrackerPtr TTestServerHost::GetMemoryUsageTracker() const +{ + return MemoryUsageTracker_; +} + +TString TTestServerHost::GetAddress() const +{ + return Format("localhost:%v", static_cast<ui16>(Port_)); +} + +std::vector<NRpc::IServicePtr> TTestServerHost::GetServices() const +{ + return Services_; +} + +NRpc::IServerPtr TTestServerHost::GetServer() const +{ + return Server_; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/yt/yt/core/test_framework/test_server_host.h b/yt/yt/core/test_framework/test_server_host.h new file mode 100644 index 0000000000..670e6328fb --- /dev/null +++ b/yt/yt/core/test_framework/test_server_host.h @@ -0,0 +1,46 @@ +#pragma once + +#include "test_memory_tracker.h" + +#include <yt/yt/core/bus/public.h> + +#include <yt/yt/core/rpc/service_detail.h> + +#include <library/cpp/testing/common/network.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +class TTestServerHost + : public TRefCounted +{ +public: + TTestServerHost( + NTesting::TPortHolder port, + NRpc::IServerPtr server, + std::vector<NRpc::IServicePtr> services, + TTestNodeMemoryTrackerPtr memoryUsageTracker); + + void TearDown(); + + TTestNodeMemoryTrackerPtr GetMemoryUsageTracker() const; + TString GetAddress() const; + std::vector<NRpc::IServicePtr> GetServices() const; + NRpc::IServerPtr GetServer() const; + +protected: + NTesting::TPortHolder Port_; + NRpc::IServerPtr Server_; + const std::vector<NRpc::IServicePtr> Services_; + const TTestNodeMemoryTrackerPtr MemoryUsageTracker_; + + void InitializeServer(); +}; + +DECLARE_REFCOUNTED_CLASS(TTestServerHost) +DEFINE_REFCOUNTED_TYPE(TTestServerHost) + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT::NRpc diff --git a/yt/yt/core/test_framework/ya.make b/yt/yt/core/test_framework/ya.make index 29e0366771..909b739a08 100644 --- a/yt/yt/core/test_framework/ya.make +++ b/yt/yt/core/test_framework/ya.make @@ -4,6 +4,8 @@ INCLUDE(${ARCADIA_ROOT}/yt/ya_cpp.make.inc) SRCS( fixed_growth_string_output.cpp + test_memory_tracker.cpp + test_server_host.cpp GLOBAL framework.cpp ) |