diff options
author | babenko <babenko@yandex-team.com> | 2024-11-06 00:23:24 +0300 |
---|---|---|
committer | babenko <babenko@yandex-team.com> | 2024-11-06 00:35:51 +0300 |
commit | 0dfaae1982275e6fa2dd300e462594c0c6559bd8 (patch) | |
tree | baef0a069394dd175618afcb92e566e6be6851d2 | |
parent | c3e0efb6adf3a6c9840911d45dc3c19935cb5a08 (diff) | |
download | ydb-0dfaae1982275e6fa2dd300e462594c0c6559bd8.tar.gz |
Fix incoming request profiling: increment counters even for non-accepted requests (2nd attempt)
commit_hash:202ccff8f8d6833ffea6c63601d15c37c56d802a
-rw-r--r-- | yt/yt/core/rpc/config.h | 2 | ||||
-rw-r--r-- | yt/yt/core/rpc/service_detail.cpp | 505 | ||||
-rw-r--r-- | yt/yt/core/rpc/service_detail.h | 80 |
3 files changed, 315 insertions, 272 deletions
diff --git a/yt/yt/core/rpc/config.h b/yt/yt/core/rpc/config.h index c3ce1d08d2..203d62b8bc 100644 --- a/yt/yt/core/rpc/config.h +++ b/yt/yt/core/rpc/config.h @@ -132,7 +132,7 @@ public: std::optional<bool> EnableErrorCodeCounter; std::optional<ERequestTracingMode> TracingMode; TTimeHistogramConfigPtr TimeHistogram; - THashMap<TString, TMethodConfigPtr> Methods; + THashMap<std::string, TMethodConfigPtr> Methods; std::optional<int> AuthenticationQueueSizeLimit; std::optional<TDuration> PendingPayloadsTimeout; std::optional<bool> Pooled; diff --git a/yt/yt/core/rpc/service_detail.cpp b/yt/yt/core/rpc/service_detail.cpp index 5fe6fef0ee..b51529f85d 100644 --- a/yt/yt/core/rpc/service_detail.cpp +++ b/yt/yt/core/rpc/service_detail.cpp @@ -49,6 +49,11 @@ using NYT::ToProto; static const auto DefaultLoggingSuppressionFailedRequestThrottlerConfig = TThroughputThrottlerConfig::Create(1'000); constexpr int MaxUserAgentLength = 200; +constexpr TStringBuf UnknownUserAgent = "<unknown>"; + +constexpr TStringBuf UnknownUserName = "<unknown>"; +constexpr TStringBuf UnknownMethodName = "<unknown>"; + constexpr auto ServiceLivenessCheckPeriod = TDuration::MilliSeconds(100); //////////////////////////////////////////////////////////////////////////////// @@ -294,15 +299,15 @@ auto TServiceBase::TMethodDescriptor::SetHandleMethodError(bool value) const -> //////////////////////////////////////////////////////////////////////////////// -TServiceBase::TErrorCodeCounter::TErrorCodeCounter(NProfiling::TProfiler profiler) +TServiceBase::TErrorCodeCounters::TErrorCodeCounters(NProfiling::TProfiler profiler) : Profiler_(std::move(profiler)) { } -void TServiceBase::TErrorCodeCounter::Increment(TErrorCode code) +NProfiling::TCounter* TServiceBase::TErrorCodeCounters::GetCounter(TErrorCode code) { - CodeToCounter_.FindOrInsert(code, [&] { + return CodeToCounter_.FindOrInsert(code, [&] { return Profiler_.WithTag("code", ToString(code)).Counter("/code_count"); - }).first->Increment(); + }).first; } //////////////////////////////////////////////////////////////////////////////// @@ -319,7 +324,7 @@ TServiceBase::TMethodPerformanceCounters::TMethodPerformanceCounters( , RequestMessageAttachmentSizeCounter(profiler.Counter("/request_message_attachment_bytes")) , ResponseMessageBodySizeCounter(profiler.Counter("/response_message_body_bytes")) , ResponseMessageAttachmentSizeCounter(profiler.Counter("/response_message_attachment_bytes")) - , ErrorCodeCounter(profiler) + , ErrorCodeCounters(profiler) { if (timeHistogramConfig && timeHistogramConfig->CustomBounds) { const auto& customBounds = *timeHistogramConfig->CustomBounds; @@ -355,7 +360,7 @@ TServiceBase::TRuntimeMethodInfo::TRuntimeMethodInfo( Format("%v.%v ->", ServiceId.ServiceName, Descriptor.Method))) , RequestQueueSizeLimitErrorCounter(Profiler.Counter("/request_queue_size_errors")) , RequestQueueByteSizeLimitErrorCounter(Profiler.Counter("/request_queue_byte_size_errors")) - , UnauthenticatedRequestsCounter(Profiler.Counter("/unauthenticated_requests")) + , UnauthenticatedRequestCounter(Profiler.Counter("/unauthenticated_request_count")) , LoggingSuppressionFailedRequestThrottler( CreateReconfigurableThroughputThrottler( DefaultLoggingSuppressionFailedRequestThrottlerConfig)) @@ -368,33 +373,44 @@ TRequestQueue* TServiceBase::TRuntimeMethodInfo::GetDefaultRequestQueue() //////////////////////////////////////////////////////////////////////////////// +TServiceBase::TPerformanceCounters::TPerformanceCounters(const NProfiling::TProfiler& profiler) + : Profiler_(profiler.WithHot().WithSparse()) +{ } + +NProfiling::TCounter* TServiceBase::TPerformanceCounters::GetRequestsPerUserAgentCounter(TStringBuf userAgent) +{ + return RequestsPerUserAgent_.FindOrInsert(userAgent, [&] { + return Profiler_.WithRequiredTag("user_agent", TString(userAgent)).Counter("/user_agent"); + }).first; +} + +//////////////////////////////////////////////////////////////////////////////// + class TServiceBase::TServiceContext : public TServiceContextBase { public: TServiceContext( TServiceBasePtr&& service, - TAcceptedRequest&& acceptedRequest, + TIncomingRequest&& incomingRequest, NLogging::TLogger logger) : TServiceContextBase( - std::move(acceptedRequest.Header), - std::move(acceptedRequest.Message), - std::move(acceptedRequest.MemoryGuard), - std::move(acceptedRequest.MemoryUsageTracker), + std::move(incomingRequest.Header), + std::move(incomingRequest.Message), + std::move(incomingRequest.MemoryGuard), + std::move(incomingRequest.MemoryUsageTracker), std::move(logger), - acceptedRequest.RuntimeInfo->LogLevel.load(std::memory_order::relaxed)) + incomingRequest.RuntimeInfo->LogLevel.load(std::memory_order::relaxed)) , Service_(std::move(service)) - , RequestId_(acceptedRequest.RequestId) - , ReplyBus_(std::move(acceptedRequest.ReplyBus)) - , RuntimeInfo_(acceptedRequest.RuntimeInfo) - , TraceContext_(std::move(acceptedRequest.TraceContext)) - , RequestQueue_(acceptedRequest.RequestQueue) - , ThrottledError_(std::move(acceptedRequest.ThrottledError)) - , MethodPerformanceCounters_(Service_->GetMethodPerformanceCounters( - RuntimeInfo_, - {GetAuthenticationIdentity().UserTag, RequestQueue_})) + , RequestId_(incomingRequest.RequestId) + , ReplyBus_(std::move(incomingRequest.ReplyBus)) + , RuntimeInfo_(incomingRequest.RuntimeInfo) + , TraceContext_(std::move(incomingRequest.TraceContext)) + , RequestQueue_(incomingRequest.RequestQueue) + , ThrottledError_(std::move(incomingRequest.ThrottledError)) + , MethodPerformanceCounters_(incomingRequest.MethodPerformanceCounters) , PerformanceCounters_(Service_->GetPerformanceCounters()) - , ArriveInstant_(NProfiling::GetInstant()) + , ArriveInstant_(incomingRequest.ArriveInstant) { YT_ASSERT(RequestMessage_); YT_ASSERT(ReplyBus_); @@ -782,24 +798,6 @@ private: void Initialize() { - constexpr TStringBuf UnknownUserAgent = "unknown"; - auto userAgent = RequestHeader_->has_user_agent() - ? TStringBuf(RequestHeader_->user_agent()) - : UnknownUserAgent; - PerformanceCounters_->IncrementRequestsPerUserAgent(userAgent.SubString(0, MaxUserAgentLength)); - - MethodPerformanceCounters_->RequestCounter.Increment(); - MethodPerformanceCounters_->RequestMessageBodySizeCounter.Increment( - GetMessageBodySize(RequestMessage_)); - MethodPerformanceCounters_->RequestMessageAttachmentSizeCounter.Increment( - GetTotalMessageAttachmentSize(RequestMessage_)); - - if (RequestHeader_->has_start_time()) { - auto retryStart = FromProto<TInstant>(RequestHeader_->start_time()); - auto now = NProfiling::GetInstant(); - MethodPerformanceCounters_->RemoteWaitTimeCounter.Record(now - retryStart); - } - // COMPAT(danilalexeev): legacy RPC codecs RequestCodec_ = RequestHeader_->has_request_codec() ? CheckedEnumCast<NCompression::ECodec>(RequestHeader_->request_codec()) @@ -1066,7 +1064,8 @@ private: MethodPerformanceCounters_->TotalTimeCounter.Record(*TotalTime_); if (!Error_.IsOK()) { if (Service_->EnableErrorCodeCounter_.load()) { - MethodPerformanceCounters_->ErrorCodeCounter.Increment(Error_.GetNonTrivialCode()); + const auto* counter = MethodPerformanceCounters_->ErrorCodeCounters.GetCounter(Error_.GetNonTrivialCode()); + counter->Increment(); } else { MethodPerformanceCounters_->FailedRequestCounter.Increment(); } @@ -1658,23 +1657,20 @@ void TRequestQueue::SubscribeToThrottlers() //////////////////////////////////////////////////////////////////////////////// -struct TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals +bool TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals::operator()( + const TNonowningPerformanceCountersKey& lhs, + const TNonowningPerformanceCountersKey& rhs) const { - bool operator()( - const TNonowningPerformanceCountersKey& lhs, - const TNonowningPerformanceCountersKey& rhs) const - { - return lhs == rhs; - } + return lhs == rhs; +} - bool operator()( - const TOwningPerformanceCountersKey& lhs, - const TNonowningPerformanceCountersKey& rhs) const - { - const auto& [lhsUserTag, lhsRequestQueue] = lhs; - return TNonowningPerformanceCountersKey{lhsUserTag, lhsRequestQueue} == rhs; - } -}; +bool TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals::operator()( + const TOwningPerformanceCountersKey& lhs, + const TNonowningPerformanceCountersKey& rhs) const +{ + const auto& [lhsUserTag, lhsRequestQueue] = lhs; + return TNonowningPerformanceCountersKey{lhsUserTag, lhsRequestQueue} == rhs; +} //////////////////////////////////////////////////////////////////////////////// @@ -1698,6 +1694,7 @@ TServiceBase::TServiceBase( BIND(&TServiceBase::OnServiceLivenessCheck, MakeWeak(this)), ServiceLivenessCheckPeriod)) , PerformanceCounters_(New<TServiceBase::TPerformanceCounters>(RpcServerProfiler)) + , UnknownMethodPerformanceCounters_(CreateUnknownMethodPerformanceCounters()) { RegisterMethod(RPC_SERVICE_METHOD_DESC(Discover) .SetInvoker(TDispatcher::Get()->GetHeavyInvoker()) @@ -1723,94 +1720,103 @@ void TServiceBase::HandleRequest( { SetActive(); - auto method = FromProto<TString>(header->method()); + auto arriveInstant = NProfiling::GetInstant(); auto requestId = FromProto<TRequestId>(header->request_id()); + auto userAgent = header->has_user_agent() + ? TStringBuf(header->user_agent()).SubString(0, MaxUserAgentLength) + : UnknownUserAgent; + auto method = TStringBuf(header->method()); + auto user = header->has_user() + ? TStringBuf(header->user()) + : (Authenticator_ ? UnknownUserName : RootUserName); + auto userTag = header->has_user_tag() + ? TStringBuf(header->user_tag()) + : user; + + DoHandleRequest(TIncomingRequest{ + .ArriveInstant = arriveInstant, + .RequestId = requestId, + .ReplyBus = std::move(replyBus), + .Header = std::move(header), + .UserAgent = userAgent, + .Method = method, + .User = user, + .UserTag = userTag, + .Message = std::move(message), + .MemoryUsageTracker = MemoryUsageTracker_, + }); +} - auto replyError = [&] (TError error) { - ReplyError(std::move(error), *header, replyBus); - }; - +void TServiceBase::DoHandleRequest(TIncomingRequest&& incomingRequest) +{ if (IsStopped()) { - replyError(TError( - NRpc::EErrorCode::Unavailable, - "Service is stopped")); + ReplyError( + TError(NRpc::EErrorCode::Unavailable, "Service is stopped"), + std::move(incomingRequest)); return; } - if (auto error = DoCheckRequestCompatibility(*header); !error.IsOK()) { - replyError(std::move(error)); + incomingRequest.RuntimeInfo = FindMethodInfo(incomingRequest.Method); + if (!incomingRequest.RuntimeInfo) { + ReplyError( + TError(NRpc::EErrorCode::NoSuchMethod, "Unknown method"), + std::move(incomingRequest)); return; } - auto* runtimeInfo = FindMethodInfo(method); - if (!runtimeInfo) { - replyError(TError( - NRpc::EErrorCode::NoSuchMethod, - "Unknown method")); + incomingRequest.RequestQueue = GetRequestQueue(incomingRequest.RuntimeInfo, *incomingRequest.Header); + + if (auto error = DoCheckRequestCompatibility(*incomingRequest.Header); !error.IsOK()) { + ReplyError(std::move(error), std::move(incomingRequest)); return; } - auto memoryGuard = TMemoryUsageTrackerGuard::Acquire(MemoryUsageTracker_, TypicalRequestSize); - message = TrackMemory(MemoryUsageTracker_, std::move(message)); + incomingRequest.MemoryGuard = TMemoryUsageTrackerGuard::Acquire(MemoryUsageTracker_, TypicalRequestSize); + incomingRequest.Message = TrackMemory(MemoryUsageTracker_, std::move(incomingRequest.Message)); + if (MemoryUsageTracker_ && MemoryUsageTracker_->IsExceeded()) { - return replyError(TError( - NRpc::EErrorCode::MemoryPressure, - "Request is dropped due to high memory pressure")); + ReplyError( + TError(NRpc::EErrorCode::MemoryPressure, "Request is dropped due to high memory pressure"), + std::move(incomingRequest)); + return; } - auto tracingMode = runtimeInfo->TracingMode.load(std::memory_order::relaxed); - auto traceContext = tracingMode == ERequestTracingMode::Disable - ? NTracing::TTraceContextPtr() - : GetOrCreateHandlerTraceContext(*header, tracingMode == ERequestTracingMode::Force); - if (traceContext && traceContext->IsRecorded()) { - traceContext->AddTag(EndpointAnnotation, replyBus->GetEndpointDescription()); + if (auto tracingMode = incomingRequest.RuntimeInfo->TracingMode.load(std::memory_order::relaxed); tracingMode != ERequestTracingMode::Disable) { + incomingRequest.TraceContext = GetOrCreateHandlerTraceContext(*incomingRequest.Header, tracingMode == ERequestTracingMode::Force); } - auto* requestQueue = GetRequestQueue(runtimeInfo, *header); - RegisterRequestQueue(runtimeInfo, requestQueue); + if (incomingRequest.TraceContext && incomingRequest.TraceContext->IsRecorded()) { + incomingRequest.TraceContext->AddTag(EndpointAnnotation, incomingRequest.ReplyBus->GetEndpointDescription()); + } - auto maybeThrottled = GetThrottledError(*header); + incomingRequest.ThrottledError = GetThrottledError(*incomingRequest.Header); - if (requestQueue->IsQueueSizeLimitExceeded()) { - runtimeInfo->RequestQueueSizeLimitErrorCounter.Increment(); - replyError(TError( - NRpc::EErrorCode::RequestQueueSizeLimitExceeded, - "Request queue size limit exceeded") - << TErrorAttribute("limit", runtimeInfo->QueueSizeLimit.load()) - << TErrorAttribute("queue", requestQueue->GetName()) - << maybeThrottled); + if (incomingRequest.RequestQueue->IsQueueSizeLimitExceeded()) { + incomingRequest.RuntimeInfo->RequestQueueSizeLimitErrorCounter.Increment(); + ReplyError( + TError(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, "Request queue size limit exceeded") + << TErrorAttribute("limit", incomingRequest.RuntimeInfo->QueueSizeLimit.load()) + << TErrorAttribute("queue", incomingRequest.RequestQueue->GetName()) + << incomingRequest.ThrottledError, + std::move(incomingRequest)); return; } - if (requestQueue->IsQueueByteSizeLimitExceeded()) { - runtimeInfo->RequestQueueByteSizeLimitErrorCounter.Increment(); - replyError(TError( - NRpc::EErrorCode::RequestQueueSizeLimitExceeded, - "Request queue bytes size limit exceeded") - << TErrorAttribute("limit", runtimeInfo->QueueByteSizeLimit.load()) - << TErrorAttribute("queue", requestQueue->GetName()) - << maybeThrottled); + if (incomingRequest.RequestQueue->IsQueueByteSizeLimitExceeded()) { + incomingRequest.RuntimeInfo->RequestQueueByteSizeLimitErrorCounter.Increment(); + ReplyError( + TError(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, "Request queue bytes size limit exceeded") + << TErrorAttribute("limit", incomingRequest.RuntimeInfo->QueueByteSizeLimit.load()) + << TErrorAttribute("queue", incomingRequest.RequestQueue->GetName()) + << incomingRequest.ThrottledError, + std::move(incomingRequest)); return; } - TCurrentTraceContextGuard traceContextGuard(traceContext); - - // NOTE: Do not use replyError() after this line. - TAcceptedRequest acceptedRequest{ - .RequestId = requestId, - .ReplyBus = std::move(replyBus), - .RuntimeInfo = std::move(runtimeInfo), - .TraceContext = std::move(traceContext), - .Header = std::move(header), - .Message = std::move(message), - .RequestQueue = requestQueue, - .ThrottledError = maybeThrottled, - .MemoryGuard = std::move(memoryGuard), - .MemoryUsageTracker = MemoryUsageTracker_, - }; + TCurrentTraceContextGuard traceContextGuard(incomingRequest.TraceContext); - if (!IsAuthenticationNeeded(acceptedRequest)) { - HandleAuthenticatedRequest(std::move(acceptedRequest)); + if (!IsAuthenticationNeeded(incomingRequest)) { + HandleAuthenticatedRequest(std::move(incomingRequest)); return; } @@ -1818,11 +1824,10 @@ void TServiceBase::HandleRequest( auto authenticationQueueSizeLimit = AuthenticationQueueSizeLimit_.load(std::memory_order::relaxed); auto authenticationQueueSize = AuthenticationQueueSize_.load(std::memory_order::relaxed); if (authenticationQueueSize > authenticationQueueSizeLimit) { - auto error = TError( - NRpc::EErrorCode::RequestQueueSizeLimitExceeded, - "Authentication request queue size limit exceeded") - << TErrorAttribute("limit", authenticationQueueSizeLimit); - ReplyError(error, *acceptedRequest.Header, acceptedRequest.ReplyBus); + ReplyError( + TError(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, "Authentication request queue size limit exceeded") + << TErrorAttribute("limit", authenticationQueueSizeLimit), + std::move(incomingRequest)); return; } ++AuthenticationQueueSize_; @@ -1830,37 +1835,35 @@ void TServiceBase::HandleRequest( NProfiling::TWallTimer timer; TAuthenticationContext authenticationContext{ - .Header = acceptedRequest.Header.get(), - .UserIP = acceptedRequest.ReplyBus->GetEndpointNetworkAddress(), - .IsLocal = acceptedRequest.ReplyBus->IsEndpointLocal(), + .Header = incomingRequest.Header.get(), + .UserIP = incomingRequest.ReplyBus->GetEndpointNetworkAddress(), + .IsLocal = incomingRequest.ReplyBus->IsEndpointLocal(), }; if (Authenticator_->CanAuthenticate(authenticationContext)) { auto asyncAuthResult = Authenticator_->AsyncAuthenticate(authenticationContext); if (asyncAuthResult.IsSet()) { - OnRequestAuthenticated(timer, std::move(acceptedRequest), asyncAuthResult.Get()); + OnRequestAuthenticated(timer, std::move(incomingRequest), asyncAuthResult.Get()); } else { asyncAuthResult.Subscribe( - BIND(&TServiceBase::OnRequestAuthenticated, MakeStrong(this), timer, Passed(std::move(acceptedRequest)))); + BIND(&TServiceBase::OnRequestAuthenticated, MakeStrong(this), timer, Passed(std::move(incomingRequest)))); } } else { - OnRequestAuthenticated(timer, std::move(acceptedRequest), TError( + OnRequestAuthenticated(timer, std::move(incomingRequest), TError( NYT::NRpc::EErrorCode::AuthenticationError, "Request is missing credentials")); } } -void TServiceBase::ReplyError( - TError error, - const NProto::TRequestHeader& header, - const IBusPtr& replyBus) +void TServiceBase::ReplyError(TError error, TIncomingRequest&& incomingRequest) { - auto requestId = FromProto<TRequestId>(header.request_id()); + ProfileRequest(&incomingRequest); + auto richError = std::move(error) - << TErrorAttribute("request_id", requestId) + << TErrorAttribute("request_id", incomingRequest.RequestId) << TErrorAttribute("realm_id", ServiceId_.RealmId) << TErrorAttribute("service", ServiceId_.ServiceName) - << TErrorAttribute("method", header.method()) - << TErrorAttribute("endpoint", replyBus->GetEndpointDescription()); + << TErrorAttribute("method", incomingRequest.Method) + << TErrorAttribute("endpoint", incomingRequest.ReplyBus->GetEndpointDescription()); auto code = richError.GetCode(); auto logLevel = @@ -1869,8 +1872,8 @@ void TServiceBase::ReplyError( : NLogging::ELogLevel::Debug; YT_LOG_EVENT(Logger, logLevel, richError); - auto errorMessage = CreateErrorResponseMessage(requestId, richError); - YT_UNUSED_FUTURE(replyBus->Send(errorMessage)); + auto errorMessage = CreateErrorResponseMessage(incomingRequest.RequestId, richError); + YT_UNUSED_FUTURE(incomingRequest.ReplyBus->Send(errorMessage)); } void TServiceBase::OnMethodError(const TError& /*error*/, const TString& /*method*/) @@ -1878,76 +1881,70 @@ void TServiceBase::OnMethodError(const TError& /*error*/, const TString& /*metho void TServiceBase::OnRequestAuthenticated( const NProfiling::TWallTimer& timer, - TAcceptedRequest&& acceptedRequest, + TIncomingRequest&& incomingRequest, const TErrorOr<TAuthenticationResult>& authResultOrError) { AuthenticationTimer_.Record(timer.GetElapsedTime()); --AuthenticationQueueSize_; - auto& requestHeader = *acceptedRequest.Header; - - if (authResultOrError.IsOK()) { - const auto& authResult = authResultOrError.Value(); - const auto& Logger = RpcServerLogger; - YT_LOG_DEBUG("Request authenticated (RequestId: %v, User: %v, Realm: %v)", - acceptedRequest.RequestId, - authResult.User, - authResult.Realm); - const auto& authenticatedUser = authResult.User; - if (requestHeader.has_user()) { - const auto& user = requestHeader.user(); - if (user != authenticatedUser) { - ReplyError( - TError( - NRpc::EErrorCode::AuthenticationError, - "Manually specified and authenticated users mismatch") - << TErrorAttribute("user", user) - << TErrorAttribute("authenticated_user", authenticatedUser), - requestHeader, - acceptedRequest.ReplyBus); - return; - } - } - requestHeader.set_user(ToProto(authResult.User)); - - auto* credentialsExt = requestHeader.MutableExtension( - NRpc::NProto::TCredentialsExt::credentials_ext); - if (credentialsExt->user_ticket().empty()) { - credentialsExt->set_user_ticket(std::move(authResult.UserTicket)); - } - HandleAuthenticatedRequest(std::move(acceptedRequest)); - } else { + if (!authResultOrError.IsOK()) { ReplyError( - TError( - NRpc::EErrorCode::AuthenticationError, - "Request authentication failed") + TError(NRpc::EErrorCode::AuthenticationError, "Request authentication failed") << authResultOrError, - requestHeader, - acceptedRequest.ReplyBus); + std::move(incomingRequest)); + return; + } + + const auto& authResult = authResultOrError.Value(); + const auto& Logger = RpcServerLogger; + YT_LOG_DEBUG("Request authenticated (RequestId: %v, User: %v, Realm: %v)", + incomingRequest.RequestId, + authResult.User, + authResult.Realm); + const auto& authenticatedUser = authResult.User; + if (incomingRequest.Header->has_user()) { + const auto& user = incomingRequest.Header->user(); + if (user != authenticatedUser) { + ReplyError( + TError(NRpc::EErrorCode::AuthenticationError, "Manually specified and authenticated users mismatch") + << TErrorAttribute("user", user) + << TErrorAttribute("authenticated_user", authenticatedUser), + std::move(incomingRequest)); + return; + } } + + incomingRequest.Header->set_user(ToProto(authResult.User)); + incomingRequest.User = TStringBuf(incomingRequest.Header->user()); + incomingRequest.UserTag = incomingRequest.Header->has_user_tag() + ? incomingRequest.UserTag + : incomingRequest.User; + + auto* credentialsExt = incomingRequest.Header->MutableExtension( + NRpc::NProto::TCredentialsExt::credentials_ext); + if (credentialsExt->user_ticket().empty()) { + credentialsExt->set_user_ticket(std::move(authResult.UserTicket)); + } + + HandleAuthenticatedRequest(std::move(incomingRequest)); } -bool TServiceBase::IsAuthenticationNeeded(const TAcceptedRequest& acceptedRequest) +bool TServiceBase::IsAuthenticationNeeded(const TIncomingRequest& incomingRequest) { return Authenticator_.operator bool() && - !acceptedRequest.RuntimeInfo->Descriptor.System; + !incomingRequest.RuntimeInfo->Descriptor.System; } -void TServiceBase::HandleAuthenticatedRequest(TAcceptedRequest&& acceptedRequest) +void TServiceBase::HandleAuthenticatedRequest(TIncomingRequest&& incomingRequest) { - if (!acceptedRequest.ReplyBus->IsEndpointLocal()) { - bool authenticated = acceptedRequest.Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext) && - acceptedRequest.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext).has_service_ticket(); - if (!authenticated) { - acceptedRequest.RuntimeInfo->UnauthenticatedRequestsCounter.Increment(); - } - } + ProfileRequest(&incomingRequest); auto context = New<TServiceContext>( this, - std::move(acceptedRequest), + std::move(incomingRequest), Logger); + auto* requestQueue = context->GetRequestQueue(); requestQueue->OnRequestArrived(std::move(context)); } @@ -1957,55 +1954,49 @@ TRequestQueue* TServiceBase::GetRequestQueue( const NRpc::NProto::TRequestHeader& requestHeader) { TRequestQueue* requestQueue = nullptr; - if (auto& provider = runtimeInfo->Descriptor.RequestQueueProvider) { + if (const auto& provider = runtimeInfo->Descriptor.RequestQueueProvider) { requestQueue = provider->GetQueue(requestHeader); } if (!requestQueue) { requestQueue = runtimeInfo->DefaultRequestQueue.Get(); } - return requestQueue; -} -void TServiceBase::RegisterRequestQueue( - TRuntimeMethodInfo* runtimeInfo, - TRequestQueue* requestQueue) -{ - if (!requestQueue->Register(this, runtimeInfo)) { - return; - } + if (requestQueue->Register(this, runtimeInfo)) { + const auto& method = runtimeInfo->Descriptor.Method; + YT_LOG_DEBUG("Request queue registered (Method: %v, Queue: %v)", + method, + requestQueue->GetName()); - const auto& method = runtimeInfo->Descriptor.Method; - YT_LOG_DEBUG("Request queue registered (Method: %v, Queue: %v)", - method, - requestQueue->GetName()); + auto profiler = runtimeInfo->Profiler.WithSparse(); + if (runtimeInfo->Descriptor.RequestQueueProvider) { + profiler = profiler.WithTag("queue", requestQueue->GetName()); + } + profiler.AddFuncGauge("/request_queue_size", MakeStrong(this), [=] { + return requestQueue->GetQueueSize(); + }); + profiler.AddFuncGauge("/request_queue_byte_size", MakeStrong(this), [=] { + return requestQueue->GetQueueByteSize(); + }); + profiler.AddFuncGauge("/concurrency", MakeStrong(this), [=] { + return requestQueue->GetConcurrency(); + }); + profiler.AddFuncGauge("/concurrency_byte", MakeStrong(this), [=] { + return requestQueue->GetConcurrencyByte(); + }); - auto profiler = runtimeInfo->Profiler.WithSparse(); - if (runtimeInfo->Descriptor.RequestQueueProvider) { - profiler = profiler.WithTag("queue", requestQueue->GetName()); - } - profiler.AddFuncGauge("/request_queue_size", MakeStrong(this), [=] { - return requestQueue->GetQueueSize(); - }); - profiler.AddFuncGauge("/request_queue_byte_size", MakeStrong(this), [=] { - return requestQueue->GetQueueByteSize(); - }); - profiler.AddFuncGauge("/concurrency", MakeStrong(this), [=] { - return requestQueue->GetConcurrency(); - }); - profiler.AddFuncGauge("/concurrency_byte", MakeStrong(this), [=] { - return requestQueue->GetConcurrencyByte(); - }); + TMethodConfigPtr methodConfig; + if (auto config = Config_.Acquire()) { + methodConfig = GetOrDefault(config->Methods, method); + } + ConfigureRequestQueue(runtimeInfo, requestQueue, methodConfig); - TMethodConfigPtr methodConfig; - if (auto config = Config_.Acquire()) { - methodConfig = GetOrDefault(config->Methods, method); + { + auto guard = Guard(runtimeInfo->RequestQueuesLock); + runtimeInfo->RequestQueues.push_back(requestQueue); + } } - ConfigureRequestQueue(runtimeInfo, requestQueue, methodConfig); - { - auto guard = Guard(runtimeInfo->RequestQueuesLock); - runtimeInfo->RequestQueues.push_back(requestQueue); - } + return requestQueue; } void TServiceBase::ConfigureRequestQueue( @@ -2362,6 +2353,15 @@ void TServiceBase::OnPendingPayloadsLeaseExpired(TRequestId requestId) } } +TServiceBase::TMethodPerformanceCountersPtr TServiceBase::CreateUnknownMethodPerformanceCounters() +{ + auto profiler = Profiler_ + .WithSparse() + .WithTag("method", std::string(UnknownMethodName), -1) + .WithTag("user", std::string(UnknownUserName)); + return New<TMethodPerformanceCounters>(profiler, TimeHistogramConfig_.Acquire()); +} + TServiceBase::TMethodPerformanceCountersPtr TServiceBase::CreateMethodPerformanceCounters( TRuntimeMethodInfo* runtimeInfo, const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key) @@ -2380,12 +2380,18 @@ TServiceBase::TMethodPerformanceCountersPtr TServiceBase::CreateMethodPerformanc } TServiceBase::TMethodPerformanceCounters* TServiceBase::GetMethodPerformanceCounters( - TRuntimeMethodInfo* runtimeInfo, - const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key) + const TIncomingRequest& incomingRequest) { - auto [userTag, requestQueue] = key; + auto* requestQueue = incomingRequest.RequestQueue; + const auto& runtimeInfo = incomingRequest.RuntimeInfo; + + // Handle a partially parsed request. + if (!requestQueue || !runtimeInfo) { + return UnknownMethodPerformanceCounters_.Get(); + } // Fast path. + auto userTag = incomingRequest.UserTag; if (userTag == RootUserName && requestQueue == runtimeInfo->DefaultRequestQueue.Get()) { if (EnablePerUserProfiling_.load(std::memory_order::relaxed)) { return runtimeInfo->RootPerformanceCounters.Get(); @@ -2394,12 +2400,14 @@ TServiceBase::TMethodPerformanceCounters* TServiceBase::GetMethodPerformanceCoun } } + // Slow path. if (!EnablePerUserProfiling_.load(std::memory_order::relaxed)) { userTag = {}; } - auto actualKey = TRuntimeMethodInfo::TNonowningPerformanceCountersKey{userTag, requestQueue}; - return runtimeInfo->PerformanceCountersMap.FindOrInsert(actualKey, [&] { - return CreateMethodPerformanceCounters(runtimeInfo, actualKey); + + auto key = TRuntimeMethodInfo::TNonowningPerformanceCountersKey{userTag, requestQueue}; + return runtimeInfo->PerformanceCountersMap.FindOrInsert(key, [&] { + return CreateMethodPerformanceCounters(runtimeInfo, key); }).first->Get(); } @@ -2408,6 +2416,33 @@ TServiceBase::TPerformanceCounters* TServiceBase::GetPerformanceCounters() return PerformanceCounters_.Get(); } +void TServiceBase::ProfileRequest(TIncomingRequest* incomingRequest) +{ + const auto* requestsPerUserAgentCounter = PerformanceCounters_->GetRequestsPerUserAgentCounter(incomingRequest->UserAgent); + requestsPerUserAgentCounter->Increment(); + + auto* methodPerformanceCounters = GetMethodPerformanceCounters(*incomingRequest); + // NB: Save the counters for future use on response. + incomingRequest->MethodPerformanceCounters = methodPerformanceCounters; + + methodPerformanceCounters->RequestCounter.Increment(); + methodPerformanceCounters->RequestMessageBodySizeCounter.Increment(GetMessageBodySize(incomingRequest->Message)); + methodPerformanceCounters->RequestMessageAttachmentSizeCounter.Increment(GetTotalMessageAttachmentSize(incomingRequest->Message)); + + if (incomingRequest->Header->has_start_time()) { + auto retryStart = FromProto<TInstant>(incomingRequest->Header->start_time()); + methodPerformanceCounters->RemoteWaitTimeCounter.Record(incomingRequest->ArriveInstant - retryStart); + } + + if (!incomingRequest->ReplyBus->IsEndpointLocal()) { + bool authenticated = incomingRequest->Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext) && + incomingRequest->Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext).has_service_ticket(); + if (!authenticated && incomingRequest->RuntimeInfo) { + incomingRequest->RuntimeInfo->UnauthenticatedRequestCounter.Increment(); + } + } +} + void TServiceBase::SetActive() { // Fast path. @@ -2726,13 +2761,13 @@ TFuture<void> TServiceBase::Stop() return StopResult_.ToFuture(); } -TServiceBase::TRuntimeMethodInfo* TServiceBase::FindMethodInfo(const TString& method) +TServiceBase::TRuntimeMethodInfo* TServiceBase::FindMethodInfo(TStringBuf method) { auto it = MethodMap_.find(method); return it == MethodMap_.end() ? nullptr : it->second.Get(); } -TServiceBase::TRuntimeMethodInfo* TServiceBase::GetMethodInfoOrThrow(const TString& method) +TServiceBase::TRuntimeMethodInfo* TServiceBase::GetMethodInfoOrThrow(TStringBuf method) { auto* runtimeInfo = FindMethodInfo(method); if (!runtimeInfo) { diff --git a/yt/yt/core/rpc/service_detail.h b/yt/yt/core/rpc/service_detail.h index 76f26d8224..b12bc45ebe 100644 --- a/yt/yt/core/rpc/service_detail.h +++ b/yt/yt/core/rpc/service_detail.h @@ -664,11 +664,12 @@ protected: TMethodDescriptor SetHandleMethodError(bool value) const; }; - struct TErrorCodeCounter + class TErrorCodeCounters { - explicit TErrorCodeCounter(NProfiling::TProfiler profiler); + public: + explicit TErrorCodeCounters(NProfiling::TProfiler profiler); - void Increment(TErrorCode code); + NProfiling::TCounter* GetCounter(TErrorCode code); private: const NProfiling::TProfiler Profiler_; @@ -724,7 +725,7 @@ protected: NProfiling::TCounter ResponseMessageAttachmentSizeCounter; //! Counts the number of errors, per error code. - TErrorCodeCounter ErrorCodeCounter; + TErrorCodeCounters ErrorCodeCounters; }; using TMethodPerformanceCountersPtr = TIntrusivePtr<TMethodPerformanceCounters>; @@ -762,7 +763,7 @@ protected: NProfiling::TCounter RequestQueueSizeLimitErrorCounter; NProfiling::TCounter RequestQueueByteSizeLimitErrorCounter; - NProfiling::TCounter UnauthenticatedRequestsCounter; + NProfiling::TCounter UnauthenticatedRequestCounter; std::atomic<NLogging::ELogLevel> LogLevel = {}; std::atomic<TDuration> LoggingSuppressionTimeout = {}; @@ -770,13 +771,24 @@ protected: using TNonowningPerformanceCountersKey = std::tuple<TStringBuf, TRequestQueue*>; using TOwningPerformanceCountersKey = std::tuple<TString, TRequestQueue*>; using TPerformanceCountersKeyHash = THash<TNonowningPerformanceCountersKey>; - struct TPerformanceCountersKeyEquals; + + struct TPerformanceCountersKeyEquals + { + bool operator()( + const TNonowningPerformanceCountersKey& lhs, + const TNonowningPerformanceCountersKey& rhs) const; + bool operator()( + const TOwningPerformanceCountersKey& lhs, + const TNonowningPerformanceCountersKey& rhs) const; + }; + using TPerformanceCountersMap = NConcurrency::TSyncMap< TOwningPerformanceCountersKey, TMethodPerformanceCountersPtr, TPerformanceCountersKeyHash, TPerformanceCountersKeyEquals >; + TPerformanceCountersMap PerformanceCountersMap; TMethodPerformanceCountersPtr BasePerformanceCounters; TMethodPerformanceCountersPtr RootPerformanceCounters; @@ -796,16 +808,9 @@ protected: : public TRefCounted { public: - explicit TPerformanceCounters(const NProfiling::TProfiler& profiler) - : Profiler_(profiler.WithHot().WithSparse()) - { } + explicit TPerformanceCounters(const NProfiling::TProfiler& profiler); - void IncrementRequestsPerUserAgent(TStringBuf userAgent) - { - RequestsPerUserAgent_.FindOrInsert(userAgent, [&] { - return Profiler_.WithRequiredTag("user_agent", TString(userAgent)).Counter("/user_agent"); - }).first->Increment(); - } + NProfiling::TCounter* GetRequestsPerUserAgentCounter(TStringBuf userAgent); private: const NProfiling::TProfiler Profiler_; @@ -853,10 +858,10 @@ protected: //! Returns a (non-owning!) pointer to TRuntimeMethodInfo for a given method's name //! or |nullptr| if no such method is registered. - TRuntimeMethodInfo* FindMethodInfo(const TString& method); + TRuntimeMethodInfo* FindMethodInfo(TStringBuf method); //! Similar to #FindMethodInfo but throws if no method is found. - TRuntimeMethodInfo* GetMethodInfoOrThrow(const TString& method); + TRuntimeMethodInfo* GetMethodInfoOrThrow(TStringBuf method); //! Returns the default invoker passed during construction. const IInvokerPtr& GetDefaultInvoker() const; @@ -903,12 +908,9 @@ protected: virtual std::optional<TError> GetThrottledError(const NProto::TRequestHeader& requestHeader); protected: - void ReplyError( - TError error, - const NProto::TRequestHeader& header, - const NYT::NBus::IBusPtr& replyBus); - - virtual void OnMethodError(const TError& error, const TString& method); + virtual void OnMethodError( + const TError& error, + const TString& method); private: friend class TRequestQueue; @@ -994,17 +996,24 @@ private: THashMap<TString, TDiscoverRequestSet> DiscoverRequestsByPayload_; YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, DiscoverRequestsByPayloadLock_); - TPerformanceCountersPtr PerformanceCounters_; + const TPerformanceCountersPtr PerformanceCounters_; + const TMethodPerformanceCountersPtr UnknownMethodPerformanceCounters_; - struct TAcceptedRequest + struct TIncomingRequest { + TInstant ArriveInstant; TRequestId RequestId; NYT::NBus::IBusPtr ReplyBus; - TRuntimeMethodInfo* RuntimeInfo; + TRuntimeMethodInfo* RuntimeInfo = nullptr; NTracing::TTraceContextPtr TraceContext; std::unique_ptr<NRpc::NProto::TRequestHeader> Header; + TStringBuf UserAgent; + TStringBuf Method; + TStringBuf User; + TStringBuf UserTag; TSharedRefArray Message; - TRequestQueue* RequestQueue; + TRequestQueue* RequestQueue = nullptr; + TMethodPerformanceCounters* MethodPerformanceCounters = nullptr; std::optional<TError> ThrottledError; TMemoryUsageTrackerGuard MemoryGuard; IMemoryUsageTrackerPtr MemoryUsageTracker; @@ -1019,19 +1028,18 @@ private: void OnRequestTimeout(TRequestId requestId, ERequestProcessingStage stage, bool aborted); void OnReplyBusTerminated(const NYT::TWeakPtr<NYT::NBus::IBus>& busWeak, const TError& error); + void DoHandleRequest(TIncomingRequest&& incomingRequest); + void ReplyError(TError error, TIncomingRequest&& incomingRequest); void OnRequestAuthenticated( const NProfiling::TWallTimer& timer, - TAcceptedRequest&& acceptedRequest, + TIncomingRequest&& incomingRequest, const TErrorOr<TAuthenticationResult>& authResultOrError); - bool IsAuthenticationNeeded(const TAcceptedRequest& acceptedRequest); - void HandleAuthenticatedRequest(TAcceptedRequest&& acceptedRequest); + bool IsAuthenticationNeeded(const TIncomingRequest& incomingRequest); + void HandleAuthenticatedRequest(TIncomingRequest&& incomingRequest); TRequestQueue* GetRequestQueue( TRuntimeMethodInfo* runtimeInfo, const NRpc::NProto::TRequestHeader& requestHeader); - void RegisterRequestQueue( - TRuntimeMethodInfo* runtimeInfo, - TRequestQueue* requestQueue); void ConfigureRequestQueue( TRuntimeMethodInfo* runtimeInfo, TRequestQueue* requestQueue, @@ -1054,14 +1062,14 @@ private: std::vector<TStreamingPayload> GetAndErasePendingPayloads(TRequestId requestId); void OnPendingPayloadsLeaseExpired(TRequestId requestId); + TMethodPerformanceCountersPtr CreateUnknownMethodPerformanceCounters(); TMethodPerformanceCountersPtr CreateMethodPerformanceCounters( TRuntimeMethodInfo* runtimeInfo, const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key); TMethodPerformanceCounters* GetMethodPerformanceCounters( - TRuntimeMethodInfo* runtimeInfo, - const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key); - + const TIncomingRequest& incomingRequest); TPerformanceCounters* GetPerformanceCounters(); + void ProfileRequest(TIncomingRequest* incomingRequest); void SetActive(); void ValidateInactive(); |