aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author: babenko <babenko@yandex-team.com>, 2024-11-06 00:23:24 +0300
committer: babenko <babenko@yandex-team.com>, 2024-11-06 00:35:51 +0300
commit: 0dfaae1982275e6fa2dd300e462594c0c6559bd8 (patch)
tree: baef0a069394dd175618afcb92e566e6be6851d2
parent: c3e0efb6adf3a6c9840911d45dc3c19935cb5a08 (diff)
download: ydb-0dfaae1982275e6fa2dd300e462594c0c6559bd8.tar.gz
Fix incoming request profiling: increment counters even for non-accepted requests (2nd attempt)
commit_hash:202ccff8f8d6833ffea6c63601d15c37c56d802a
-rw-r--r-- yt/yt/core/rpc/config.h | 2
-rw-r--r-- yt/yt/core/rpc/service_detail.cpp | 505
-rw-r--r-- yt/yt/core/rpc/service_detail.h | 80
3 files changed, 315 insertions, 272 deletions
diff --git a/yt/yt/core/rpc/config.h b/yt/yt/core/rpc/config.h
index c3ce1d08d2..203d62b8bc 100644
--- a/yt/yt/core/rpc/config.h
+++ b/yt/yt/core/rpc/config.h
@@ -132,7 +132,7 @@ public:
std::optional<bool> EnableErrorCodeCounter;
std::optional<ERequestTracingMode> TracingMode;
TTimeHistogramConfigPtr TimeHistogram;
- THashMap<TString, TMethodConfigPtr> Methods;
+ THashMap<std::string, TMethodConfigPtr> Methods;
std::optional<int> AuthenticationQueueSizeLimit;
std::optional<TDuration> PendingPayloadsTimeout;
std::optional<bool> Pooled;
diff --git a/yt/yt/core/rpc/service_detail.cpp b/yt/yt/core/rpc/service_detail.cpp
index 5fe6fef0ee..b51529f85d 100644
--- a/yt/yt/core/rpc/service_detail.cpp
+++ b/yt/yt/core/rpc/service_detail.cpp
@@ -49,6 +49,11 @@ using NYT::ToProto;
static const auto DefaultLoggingSuppressionFailedRequestThrottlerConfig = TThroughputThrottlerConfig::Create(1'000);
constexpr int MaxUserAgentLength = 200;
+constexpr TStringBuf UnknownUserAgent = "<unknown>";
+
+constexpr TStringBuf UnknownUserName = "<unknown>";
+constexpr TStringBuf UnknownMethodName = "<unknown>";
+
constexpr auto ServiceLivenessCheckPeriod = TDuration::MilliSeconds(100);
////////////////////////////////////////////////////////////////////////////////
@@ -294,15 +299,15 @@ auto TServiceBase::TMethodDescriptor::SetHandleMethodError(bool value) const ->
////////////////////////////////////////////////////////////////////////////////
-TServiceBase::TErrorCodeCounter::TErrorCodeCounter(NProfiling::TProfiler profiler)
+TServiceBase::TErrorCodeCounters::TErrorCodeCounters(NProfiling::TProfiler profiler)
: Profiler_(std::move(profiler))
{ }
-void TServiceBase::TErrorCodeCounter::Increment(TErrorCode code)
+NProfiling::TCounter* TServiceBase::TErrorCodeCounters::GetCounter(TErrorCode code)
{
- CodeToCounter_.FindOrInsert(code, [&] {
+ return CodeToCounter_.FindOrInsert(code, [&] {
return Profiler_.WithTag("code", ToString(code)).Counter("/code_count");
- }).first->Increment();
+ }).first;
}
////////////////////////////////////////////////////////////////////////////////
@@ -319,7 +324,7 @@ TServiceBase::TMethodPerformanceCounters::TMethodPerformanceCounters(
, RequestMessageAttachmentSizeCounter(profiler.Counter("/request_message_attachment_bytes"))
, ResponseMessageBodySizeCounter(profiler.Counter("/response_message_body_bytes"))
, ResponseMessageAttachmentSizeCounter(profiler.Counter("/response_message_attachment_bytes"))
- , ErrorCodeCounter(profiler)
+ , ErrorCodeCounters(profiler)
{
if (timeHistogramConfig && timeHistogramConfig->CustomBounds) {
const auto& customBounds = *timeHistogramConfig->CustomBounds;
@@ -355,7 +360,7 @@ TServiceBase::TRuntimeMethodInfo::TRuntimeMethodInfo(
Format("%v.%v ->", ServiceId.ServiceName, Descriptor.Method)))
, RequestQueueSizeLimitErrorCounter(Profiler.Counter("/request_queue_size_errors"))
, RequestQueueByteSizeLimitErrorCounter(Profiler.Counter("/request_queue_byte_size_errors"))
- , UnauthenticatedRequestsCounter(Profiler.Counter("/unauthenticated_requests"))
+ , UnauthenticatedRequestCounter(Profiler.Counter("/unauthenticated_request_count"))
, LoggingSuppressionFailedRequestThrottler(
CreateReconfigurableThroughputThrottler(
DefaultLoggingSuppressionFailedRequestThrottlerConfig))
@@ -368,33 +373,44 @@ TRequestQueue* TServiceBase::TRuntimeMethodInfo::GetDefaultRequestQueue()
////////////////////////////////////////////////////////////////////////////////
+// Constructs the service-wide performance counters.
+// The stored profiler is derived from |profiler| via WithHot().WithSparse().
+TServiceBase::TPerformanceCounters::TPerformanceCounters(const NProfiling::TProfiler& profiler)
+ : Profiler_(profiler.WithHot().WithSparse())
+{ }
+
+// Returns a pointer to the request counter associated with |userAgent|.
+// The lookup goes through RequestsPerUserAgent_.FindOrInsert; the supplied factory
+// builds a "/user_agent" counter carrying a required "user_agent" tag.
+// Unlike the previous IncrementRequestsPerUserAgent helper, this does not increment
+// the counter; the caller is expected to do so.
+NProfiling::TCounter* TServiceBase::TPerformanceCounters::GetRequestsPerUserAgentCounter(TStringBuf userAgent)
+{
+ return RequestsPerUserAgent_.FindOrInsert(userAgent, [&] {
+ // NOTE(review): presumably invoked only when |userAgent| has no entry yet -- confirm against TSyncMap::FindOrInsert.
+ return Profiler_.WithRequiredTag("user_agent", TString(userAgent)).Counter("/user_agent");
+ }).first;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
class TServiceBase::TServiceContext
: public TServiceContextBase
{
public:
TServiceContext(
TServiceBasePtr&& service,
- TAcceptedRequest&& acceptedRequest,
+ TIncomingRequest&& incomingRequest,
NLogging::TLogger logger)
: TServiceContextBase(
- std::move(acceptedRequest.Header),
- std::move(acceptedRequest.Message),
- std::move(acceptedRequest.MemoryGuard),
- std::move(acceptedRequest.MemoryUsageTracker),
+ std::move(incomingRequest.Header),
+ std::move(incomingRequest.Message),
+ std::move(incomingRequest.MemoryGuard),
+ std::move(incomingRequest.MemoryUsageTracker),
std::move(logger),
- acceptedRequest.RuntimeInfo->LogLevel.load(std::memory_order::relaxed))
+ incomingRequest.RuntimeInfo->LogLevel.load(std::memory_order::relaxed))
, Service_(std::move(service))
- , RequestId_(acceptedRequest.RequestId)
- , ReplyBus_(std::move(acceptedRequest.ReplyBus))
- , RuntimeInfo_(acceptedRequest.RuntimeInfo)
- , TraceContext_(std::move(acceptedRequest.TraceContext))
- , RequestQueue_(acceptedRequest.RequestQueue)
- , ThrottledError_(std::move(acceptedRequest.ThrottledError))
- , MethodPerformanceCounters_(Service_->GetMethodPerformanceCounters(
- RuntimeInfo_,
- {GetAuthenticationIdentity().UserTag, RequestQueue_}))
+ , RequestId_(incomingRequest.RequestId)
+ , ReplyBus_(std::move(incomingRequest.ReplyBus))
+ , RuntimeInfo_(incomingRequest.RuntimeInfo)
+ , TraceContext_(std::move(incomingRequest.TraceContext))
+ , RequestQueue_(incomingRequest.RequestQueue)
+ , ThrottledError_(std::move(incomingRequest.ThrottledError))
+ , MethodPerformanceCounters_(incomingRequest.MethodPerformanceCounters)
, PerformanceCounters_(Service_->GetPerformanceCounters())
- , ArriveInstant_(NProfiling::GetInstant())
+ , ArriveInstant_(incomingRequest.ArriveInstant)
{
YT_ASSERT(RequestMessage_);
YT_ASSERT(ReplyBus_);
@@ -782,24 +798,6 @@ private:
void Initialize()
{
- constexpr TStringBuf UnknownUserAgent = "unknown";
- auto userAgent = RequestHeader_->has_user_agent()
- ? TStringBuf(RequestHeader_->user_agent())
- : UnknownUserAgent;
- PerformanceCounters_->IncrementRequestsPerUserAgent(userAgent.SubString(0, MaxUserAgentLength));
-
- MethodPerformanceCounters_->RequestCounter.Increment();
- MethodPerformanceCounters_->RequestMessageBodySizeCounter.Increment(
- GetMessageBodySize(RequestMessage_));
- MethodPerformanceCounters_->RequestMessageAttachmentSizeCounter.Increment(
- GetTotalMessageAttachmentSize(RequestMessage_));
-
- if (RequestHeader_->has_start_time()) {
- auto retryStart = FromProto<TInstant>(RequestHeader_->start_time());
- auto now = NProfiling::GetInstant();
- MethodPerformanceCounters_->RemoteWaitTimeCounter.Record(now - retryStart);
- }
-
// COMPAT(danilalexeev): legacy RPC codecs
RequestCodec_ = RequestHeader_->has_request_codec()
? CheckedEnumCast<NCompression::ECodec>(RequestHeader_->request_codec())
@@ -1066,7 +1064,8 @@ private:
MethodPerformanceCounters_->TotalTimeCounter.Record(*TotalTime_);
if (!Error_.IsOK()) {
if (Service_->EnableErrorCodeCounter_.load()) {
- MethodPerformanceCounters_->ErrorCodeCounter.Increment(Error_.GetNonTrivialCode());
+ const auto* counter = MethodPerformanceCounters_->ErrorCodeCounters.GetCounter(Error_.GetNonTrivialCode());
+ counter->Increment();
} else {
MethodPerformanceCounters_->FailedRequestCounter.Increment();
}
@@ -1658,23 +1657,20 @@ void TRequestQueue::SubscribeToThrottlers()
////////////////////////////////////////////////////////////////////////////////
-struct TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals
+bool TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals::operator()(
+ const TNonowningPerformanceCountersKey& lhs,
+ const TNonowningPerformanceCountersKey& rhs) const
{
- bool operator()(
- const TNonowningPerformanceCountersKey& lhs,
- const TNonowningPerformanceCountersKey& rhs) const
- {
- return lhs == rhs;
- }
+ return lhs == rhs;
+}
- bool operator()(
- const TOwningPerformanceCountersKey& lhs,
- const TNonowningPerformanceCountersKey& rhs) const
- {
- const auto& [lhsUserTag, lhsRequestQueue] = lhs;
- return TNonowningPerformanceCountersKey{lhsUserTag, lhsRequestQueue} == rhs;
- }
-};
+bool TServiceBase::TRuntimeMethodInfo::TPerformanceCountersKeyEquals::operator()(
+ const TOwningPerformanceCountersKey& lhs,
+ const TNonowningPerformanceCountersKey& rhs) const
+{
+ const auto& [lhsUserTag, lhsRequestQueue] = lhs;
+ return TNonowningPerformanceCountersKey{lhsUserTag, lhsRequestQueue} == rhs;
+}
////////////////////////////////////////////////////////////////////////////////
@@ -1698,6 +1694,7 @@ TServiceBase::TServiceBase(
BIND(&TServiceBase::OnServiceLivenessCheck, MakeWeak(this)),
ServiceLivenessCheckPeriod))
, PerformanceCounters_(New<TServiceBase::TPerformanceCounters>(RpcServerProfiler))
+ , UnknownMethodPerformanceCounters_(CreateUnknownMethodPerformanceCounters())
{
RegisterMethod(RPC_SERVICE_METHOD_DESC(Discover)
.SetInvoker(TDispatcher::Get()->GetHeavyInvoker())
@@ -1723,94 +1720,103 @@ void TServiceBase::HandleRequest(
{
SetActive();
- auto method = FromProto<TString>(header->method());
+ auto arriveInstant = NProfiling::GetInstant();
auto requestId = FromProto<TRequestId>(header->request_id());
+ auto userAgent = header->has_user_agent()
+ ? TStringBuf(header->user_agent()).SubString(0, MaxUserAgentLength)
+ : UnknownUserAgent;
+ auto method = TStringBuf(header->method());
+ auto user = header->has_user()
+ ? TStringBuf(header->user())
+ : (Authenticator_ ? UnknownUserName : RootUserName);
+ auto userTag = header->has_user_tag()
+ ? TStringBuf(header->user_tag())
+ : user;
+
+ DoHandleRequest(TIncomingRequest{
+ .ArriveInstant = arriveInstant,
+ .RequestId = requestId,
+ .ReplyBus = std::move(replyBus),
+ .Header = std::move(header),
+ .UserAgent = userAgent,
+ .Method = method,
+ .User = user,
+ .UserTag = userTag,
+ .Message = std::move(message),
+ .MemoryUsageTracker = MemoryUsageTracker_,
+ });
+}
- auto replyError = [&] (TError error) {
- ReplyError(std::move(error), *header, replyBus);
- };
-
+void TServiceBase::DoHandleRequest(TIncomingRequest&& incomingRequest)
+{
if (IsStopped()) {
- replyError(TError(
- NRpc::EErrorCode::Unavailable,
- "Service is stopped"));
+ ReplyError(
+ TError(NRpc::EErrorCode::Unavailable, "Service is stopped"),
+ std::move(incomingRequest));
return;
}
- if (auto error = DoCheckRequestCompatibility(*header); !error.IsOK()) {
- replyError(std::move(error));
+ incomingRequest.RuntimeInfo = FindMethodInfo(incomingRequest.Method);
+ if (!incomingRequest.RuntimeInfo) {
+ ReplyError(
+ TError(NRpc::EErrorCode::NoSuchMethod, "Unknown method"),
+ std::move(incomingRequest));
return;
}
- auto* runtimeInfo = FindMethodInfo(method);
- if (!runtimeInfo) {
- replyError(TError(
- NRpc::EErrorCode::NoSuchMethod,
- "Unknown method"));
+ incomingRequest.RequestQueue = GetRequestQueue(incomingRequest.RuntimeInfo, *incomingRequest.Header);
+
+ if (auto error = DoCheckRequestCompatibility(*incomingRequest.Header); !error.IsOK()) {
+ ReplyError(std::move(error), std::move(incomingRequest));
return;
}
- auto memoryGuard = TMemoryUsageTrackerGuard::Acquire(MemoryUsageTracker_, TypicalRequestSize);
- message = TrackMemory(MemoryUsageTracker_, std::move(message));
+ incomingRequest.MemoryGuard = TMemoryUsageTrackerGuard::Acquire(MemoryUsageTracker_, TypicalRequestSize);
+ incomingRequest.Message = TrackMemory(MemoryUsageTracker_, std::move(incomingRequest.Message));
+
if (MemoryUsageTracker_ && MemoryUsageTracker_->IsExceeded()) {
- return replyError(TError(
- NRpc::EErrorCode::MemoryPressure,
- "Request is dropped due to high memory pressure"));
+ ReplyError(
+ TError(NRpc::EErrorCode::MemoryPressure, "Request is dropped due to high memory pressure"),
+ std::move(incomingRequest));
+ return;
}
- auto tracingMode = runtimeInfo->TracingMode.load(std::memory_order::relaxed);
- auto traceContext = tracingMode == ERequestTracingMode::Disable
- ? NTracing::TTraceContextPtr()
- : GetOrCreateHandlerTraceContext(*header, tracingMode == ERequestTracingMode::Force);
- if (traceContext && traceContext->IsRecorded()) {
- traceContext->AddTag(EndpointAnnotation, replyBus->GetEndpointDescription());
+ if (auto tracingMode = incomingRequest.RuntimeInfo->TracingMode.load(std::memory_order::relaxed); tracingMode != ERequestTracingMode::Disable) {
+ incomingRequest.TraceContext = GetOrCreateHandlerTraceContext(*incomingRequest.Header, tracingMode == ERequestTracingMode::Force);
}
- auto* requestQueue = GetRequestQueue(runtimeInfo, *header);
- RegisterRequestQueue(runtimeInfo, requestQueue);
+ if (incomingRequest.TraceContext && incomingRequest.TraceContext->IsRecorded()) {
+ incomingRequest.TraceContext->AddTag(EndpointAnnotation, incomingRequest.ReplyBus->GetEndpointDescription());
+ }
- auto maybeThrottled = GetThrottledError(*header);
+ incomingRequest.ThrottledError = GetThrottledError(*incomingRequest.Header);
- if (requestQueue->IsQueueSizeLimitExceeded()) {
- runtimeInfo->RequestQueueSizeLimitErrorCounter.Increment();
- replyError(TError(
- NRpc::EErrorCode::RequestQueueSizeLimitExceeded,
- "Request queue size limit exceeded")
- << TErrorAttribute("limit", runtimeInfo->QueueSizeLimit.load())
- << TErrorAttribute("queue", requestQueue->GetName())
- << maybeThrottled);
+ if (incomingRequest.RequestQueue->IsQueueSizeLimitExceeded()) {
+ incomingRequest.RuntimeInfo->RequestQueueSizeLimitErrorCounter.Increment();
+ ReplyError(
+ TError(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, "Request queue size limit exceeded")
+ << TErrorAttribute("limit", incomingRequest.RuntimeInfo->QueueSizeLimit.load())
+ << TErrorAttribute("queue", incomingRequest.RequestQueue->GetName())
+ << incomingRequest.ThrottledError,
+ std::move(incomingRequest));
return;
}
- if (requestQueue->IsQueueByteSizeLimitExceeded()) {
- runtimeInfo->RequestQueueByteSizeLimitErrorCounter.Increment();
- replyError(TError(
- NRpc::EErrorCode::RequestQueueSizeLimitExceeded,
- "Request queue bytes size limit exceeded")
- << TErrorAttribute("limit", runtimeInfo->QueueByteSizeLimit.load())
- << TErrorAttribute("queue", requestQueue->GetName())
- << maybeThrottled);
+ if (incomingRequest.RequestQueue->IsQueueByteSizeLimitExceeded()) {
+ incomingRequest.RuntimeInfo->RequestQueueByteSizeLimitErrorCounter.Increment();
+ ReplyError(
+ TError(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, "Request queue bytes size limit exceeded")
+ << TErrorAttribute("limit", incomingRequest.RuntimeInfo->QueueByteSizeLimit.load())
+ << TErrorAttribute("queue", incomingRequest.RequestQueue->GetName())
+ << incomingRequest.ThrottledError,
+ std::move(incomingRequest));
return;
}
- TCurrentTraceContextGuard traceContextGuard(traceContext);
-
- // NOTE: Do not use replyError() after this line.
- TAcceptedRequest acceptedRequest{
- .RequestId = requestId,
- .ReplyBus = std::move(replyBus),
- .RuntimeInfo = std::move(runtimeInfo),
- .TraceContext = std::move(traceContext),
- .Header = std::move(header),
- .Message = std::move(message),
- .RequestQueue = requestQueue,
- .ThrottledError = maybeThrottled,
- .MemoryGuard = std::move(memoryGuard),
- .MemoryUsageTracker = MemoryUsageTracker_,
- };
+ TCurrentTraceContextGuard traceContextGuard(incomingRequest.TraceContext);
- if (!IsAuthenticationNeeded(acceptedRequest)) {
- HandleAuthenticatedRequest(std::move(acceptedRequest));
+ if (!IsAuthenticationNeeded(incomingRequest)) {
+ HandleAuthenticatedRequest(std::move(incomingRequest));
return;
}
@@ -1818,11 +1824,10 @@ void TServiceBase::HandleRequest(
auto authenticationQueueSizeLimit = AuthenticationQueueSizeLimit_.load(std::memory_order::relaxed);
auto authenticationQueueSize = AuthenticationQueueSize_.load(std::memory_order::relaxed);
if (authenticationQueueSize > authenticationQueueSizeLimit) {
- auto error = TError(
- NRpc::EErrorCode::RequestQueueSizeLimitExceeded,
- "Authentication request queue size limit exceeded")
- << TErrorAttribute("limit", authenticationQueueSizeLimit);
- ReplyError(error, *acceptedRequest.Header, acceptedRequest.ReplyBus);
+ ReplyError(
+ TError(NRpc::EErrorCode::RequestQueueSizeLimitExceeded, "Authentication request queue size limit exceeded")
+ << TErrorAttribute("limit", authenticationQueueSizeLimit),
+ std::move(incomingRequest));
return;
}
++AuthenticationQueueSize_;
@@ -1830,37 +1835,35 @@ void TServiceBase::HandleRequest(
NProfiling::TWallTimer timer;
TAuthenticationContext authenticationContext{
- .Header = acceptedRequest.Header.get(),
- .UserIP = acceptedRequest.ReplyBus->GetEndpointNetworkAddress(),
- .IsLocal = acceptedRequest.ReplyBus->IsEndpointLocal(),
+ .Header = incomingRequest.Header.get(),
+ .UserIP = incomingRequest.ReplyBus->GetEndpointNetworkAddress(),
+ .IsLocal = incomingRequest.ReplyBus->IsEndpointLocal(),
};
if (Authenticator_->CanAuthenticate(authenticationContext)) {
auto asyncAuthResult = Authenticator_->AsyncAuthenticate(authenticationContext);
if (asyncAuthResult.IsSet()) {
- OnRequestAuthenticated(timer, std::move(acceptedRequest), asyncAuthResult.Get());
+ OnRequestAuthenticated(timer, std::move(incomingRequest), asyncAuthResult.Get());
} else {
asyncAuthResult.Subscribe(
- BIND(&TServiceBase::OnRequestAuthenticated, MakeStrong(this), timer, Passed(std::move(acceptedRequest))));
+ BIND(&TServiceBase::OnRequestAuthenticated, MakeStrong(this), timer, Passed(std::move(incomingRequest))));
}
} else {
- OnRequestAuthenticated(timer, std::move(acceptedRequest), TError(
+ OnRequestAuthenticated(timer, std::move(incomingRequest), TError(
NYT::NRpc::EErrorCode::AuthenticationError,
"Request is missing credentials"));
}
}
-void TServiceBase::ReplyError(
- TError error,
- const NProto::TRequestHeader& header,
- const IBusPtr& replyBus)
+void TServiceBase::ReplyError(TError error, TIncomingRequest&& incomingRequest)
{
- auto requestId = FromProto<TRequestId>(header.request_id());
+ ProfileRequest(&incomingRequest);
+
auto richError = std::move(error)
- << TErrorAttribute("request_id", requestId)
+ << TErrorAttribute("request_id", incomingRequest.RequestId)
<< TErrorAttribute("realm_id", ServiceId_.RealmId)
<< TErrorAttribute("service", ServiceId_.ServiceName)
- << TErrorAttribute("method", header.method())
- << TErrorAttribute("endpoint", replyBus->GetEndpointDescription());
+ << TErrorAttribute("method", incomingRequest.Method)
+ << TErrorAttribute("endpoint", incomingRequest.ReplyBus->GetEndpointDescription());
auto code = richError.GetCode();
auto logLevel =
@@ -1869,8 +1872,8 @@ void TServiceBase::ReplyError(
: NLogging::ELogLevel::Debug;
YT_LOG_EVENT(Logger, logLevel, richError);
- auto errorMessage = CreateErrorResponseMessage(requestId, richError);
- YT_UNUSED_FUTURE(replyBus->Send(errorMessage));
+ auto errorMessage = CreateErrorResponseMessage(incomingRequest.RequestId, richError);
+ YT_UNUSED_FUTURE(incomingRequest.ReplyBus->Send(errorMessage));
}
void TServiceBase::OnMethodError(const TError& /*error*/, const TString& /*method*/)
@@ -1878,76 +1881,70 @@ void TServiceBase::OnMethodError(const TError& /*error*/, const TString& /*metho
void TServiceBase::OnRequestAuthenticated(
const NProfiling::TWallTimer& timer,
- TAcceptedRequest&& acceptedRequest,
+ TIncomingRequest&& incomingRequest,
const TErrorOr<TAuthenticationResult>& authResultOrError)
{
AuthenticationTimer_.Record(timer.GetElapsedTime());
--AuthenticationQueueSize_;
- auto& requestHeader = *acceptedRequest.Header;
-
- if (authResultOrError.IsOK()) {
- const auto& authResult = authResultOrError.Value();
- const auto& Logger = RpcServerLogger;
- YT_LOG_DEBUG("Request authenticated (RequestId: %v, User: %v, Realm: %v)",
- acceptedRequest.RequestId,
- authResult.User,
- authResult.Realm);
- const auto& authenticatedUser = authResult.User;
- if (requestHeader.has_user()) {
- const auto& user = requestHeader.user();
- if (user != authenticatedUser) {
- ReplyError(
- TError(
- NRpc::EErrorCode::AuthenticationError,
- "Manually specified and authenticated users mismatch")
- << TErrorAttribute("user", user)
- << TErrorAttribute("authenticated_user", authenticatedUser),
- requestHeader,
- acceptedRequest.ReplyBus);
- return;
- }
- }
- requestHeader.set_user(ToProto(authResult.User));
-
- auto* credentialsExt = requestHeader.MutableExtension(
- NRpc::NProto::TCredentialsExt::credentials_ext);
- if (credentialsExt->user_ticket().empty()) {
- credentialsExt->set_user_ticket(std::move(authResult.UserTicket));
- }
- HandleAuthenticatedRequest(std::move(acceptedRequest));
- } else {
+ if (!authResultOrError.IsOK()) {
ReplyError(
- TError(
- NRpc::EErrorCode::AuthenticationError,
- "Request authentication failed")
+ TError(NRpc::EErrorCode::AuthenticationError, "Request authentication failed")
<< authResultOrError,
- requestHeader,
- acceptedRequest.ReplyBus);
+ std::move(incomingRequest));
+ return;
+ }
+
+ const auto& authResult = authResultOrError.Value();
+ const auto& Logger = RpcServerLogger;
+ YT_LOG_DEBUG("Request authenticated (RequestId: %v, User: %v, Realm: %v)",
+ incomingRequest.RequestId,
+ authResult.User,
+ authResult.Realm);
+ const auto& authenticatedUser = authResult.User;
+ if (incomingRequest.Header->has_user()) {
+ const auto& user = incomingRequest.Header->user();
+ if (user != authenticatedUser) {
+ ReplyError(
+ TError(NRpc::EErrorCode::AuthenticationError, "Manually specified and authenticated users mismatch")
+ << TErrorAttribute("user", user)
+ << TErrorAttribute("authenticated_user", authenticatedUser),
+ std::move(incomingRequest));
+ return;
+ }
}
+
+ incomingRequest.Header->set_user(ToProto(authResult.User));
+ incomingRequest.User = TStringBuf(incomingRequest.Header->user());
+ incomingRequest.UserTag = incomingRequest.Header->has_user_tag()
+ ? incomingRequest.UserTag
+ : incomingRequest.User;
+
+ auto* credentialsExt = incomingRequest.Header->MutableExtension(
+ NRpc::NProto::TCredentialsExt::credentials_ext);
+ if (credentialsExt->user_ticket().empty()) {
+ credentialsExt->set_user_ticket(std::move(authResult.UserTicket));
+ }
+
+ HandleAuthenticatedRequest(std::move(incomingRequest));
}
-bool TServiceBase::IsAuthenticationNeeded(const TAcceptedRequest& acceptedRequest)
+bool TServiceBase::IsAuthenticationNeeded(const TIncomingRequest& incomingRequest)
{
return
Authenticator_.operator bool() &&
- !acceptedRequest.RuntimeInfo->Descriptor.System;
+ !incomingRequest.RuntimeInfo->Descriptor.System;
}
-void TServiceBase::HandleAuthenticatedRequest(TAcceptedRequest&& acceptedRequest)
+void TServiceBase::HandleAuthenticatedRequest(TIncomingRequest&& incomingRequest)
{
- if (!acceptedRequest.ReplyBus->IsEndpointLocal()) {
- bool authenticated = acceptedRequest.Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext) &&
- acceptedRequest.Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext).has_service_ticket();
- if (!authenticated) {
- acceptedRequest.RuntimeInfo->UnauthenticatedRequestsCounter.Increment();
- }
- }
+ ProfileRequest(&incomingRequest);
auto context = New<TServiceContext>(
this,
- std::move(acceptedRequest),
+ std::move(incomingRequest),
Logger);
+
auto* requestQueue = context->GetRequestQueue();
requestQueue->OnRequestArrived(std::move(context));
}
@@ -1957,55 +1954,49 @@ TRequestQueue* TServiceBase::GetRequestQueue(
const NRpc::NProto::TRequestHeader& requestHeader)
{
TRequestQueue* requestQueue = nullptr;
- if (auto& provider = runtimeInfo->Descriptor.RequestQueueProvider) {
+ if (const auto& provider = runtimeInfo->Descriptor.RequestQueueProvider) {
requestQueue = provider->GetQueue(requestHeader);
}
if (!requestQueue) {
requestQueue = runtimeInfo->DefaultRequestQueue.Get();
}
- return requestQueue;
-}
-void TServiceBase::RegisterRequestQueue(
- TRuntimeMethodInfo* runtimeInfo,
- TRequestQueue* requestQueue)
-{
- if (!requestQueue->Register(this, runtimeInfo)) {
- return;
- }
+ if (requestQueue->Register(this, runtimeInfo)) {
+ const auto& method = runtimeInfo->Descriptor.Method;
+ YT_LOG_DEBUG("Request queue registered (Method: %v, Queue: %v)",
+ method,
+ requestQueue->GetName());
- const auto& method = runtimeInfo->Descriptor.Method;
- YT_LOG_DEBUG("Request queue registered (Method: %v, Queue: %v)",
- method,
- requestQueue->GetName());
+ auto profiler = runtimeInfo->Profiler.WithSparse();
+ if (runtimeInfo->Descriptor.RequestQueueProvider) {
+ profiler = profiler.WithTag("queue", requestQueue->GetName());
+ }
+ profiler.AddFuncGauge("/request_queue_size", MakeStrong(this), [=] {
+ return requestQueue->GetQueueSize();
+ });
+ profiler.AddFuncGauge("/request_queue_byte_size", MakeStrong(this), [=] {
+ return requestQueue->GetQueueByteSize();
+ });
+ profiler.AddFuncGauge("/concurrency", MakeStrong(this), [=] {
+ return requestQueue->GetConcurrency();
+ });
+ profiler.AddFuncGauge("/concurrency_byte", MakeStrong(this), [=] {
+ return requestQueue->GetConcurrencyByte();
+ });
- auto profiler = runtimeInfo->Profiler.WithSparse();
- if (runtimeInfo->Descriptor.RequestQueueProvider) {
- profiler = profiler.WithTag("queue", requestQueue->GetName());
- }
- profiler.AddFuncGauge("/request_queue_size", MakeStrong(this), [=] {
- return requestQueue->GetQueueSize();
- });
- profiler.AddFuncGauge("/request_queue_byte_size", MakeStrong(this), [=] {
- return requestQueue->GetQueueByteSize();
- });
- profiler.AddFuncGauge("/concurrency", MakeStrong(this), [=] {
- return requestQueue->GetConcurrency();
- });
- profiler.AddFuncGauge("/concurrency_byte", MakeStrong(this), [=] {
- return requestQueue->GetConcurrencyByte();
- });
+ TMethodConfigPtr methodConfig;
+ if (auto config = Config_.Acquire()) {
+ methodConfig = GetOrDefault(config->Methods, method);
+ }
+ ConfigureRequestQueue(runtimeInfo, requestQueue, methodConfig);
- TMethodConfigPtr methodConfig;
- if (auto config = Config_.Acquire()) {
- methodConfig = GetOrDefault(config->Methods, method);
+ {
+ auto guard = Guard(runtimeInfo->RequestQueuesLock);
+ runtimeInfo->RequestQueues.push_back(requestQueue);
+ }
}
- ConfigureRequestQueue(runtimeInfo, requestQueue, methodConfig);
- {
- auto guard = Guard(runtimeInfo->RequestQueuesLock);
- runtimeInfo->RequestQueues.push_back(requestQueue);
- }
+ return requestQueue;
}
void TServiceBase::ConfigureRequestQueue(
@@ -2362,6 +2353,15 @@ void TServiceBase::OnPendingPayloadsLeaseExpired(TRequestId requestId)
}
}
+// Creates the fallback method performance counters, tagged with UnknownMethodName and
+// UnknownUserName. GetMethodPerformanceCounters returns these when the incoming request
+// has no resolved runtime info or request queue.
+//
+// NOTE(review): this is called from the TServiceBase constructor's member-initializer list,
+// so it relies on Profiler_ and TimeHistogramConfig_ being initialized before
+// UnknownMethodPerformanceCounters_ -- confirm the member declaration order in the header.
+TServiceBase::TMethodPerformanceCountersPtr TServiceBase::CreateUnknownMethodPerformanceCounters()
+{
+ auto profiler = Profiler_
+ .WithSparse()
+ // NOTE(review): the third argument (-1) is presumably the parent tag index -- confirm against TProfiler::WithTag.
+ .WithTag("method", std::string(UnknownMethodName), -1)
+ .WithTag("user", std::string(UnknownUserName));
+ return New<TMethodPerformanceCounters>(profiler, TimeHistogramConfig_.Acquire());
+}
+
TServiceBase::TMethodPerformanceCountersPtr TServiceBase::CreateMethodPerformanceCounters(
TRuntimeMethodInfo* runtimeInfo,
const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key)
@@ -2380,12 +2380,18 @@ TServiceBase::TMethodPerformanceCountersPtr TServiceBase::CreateMethodPerformanc
}
TServiceBase::TMethodPerformanceCounters* TServiceBase::GetMethodPerformanceCounters(
- TRuntimeMethodInfo* runtimeInfo,
- const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key)
+ const TIncomingRequest& incomingRequest)
{
- auto [userTag, requestQueue] = key;
+ auto* requestQueue = incomingRequest.RequestQueue;
+ const auto& runtimeInfo = incomingRequest.RuntimeInfo;
+
+ // Handle a partially parsed request.
+ if (!requestQueue || !runtimeInfo) {
+ return UnknownMethodPerformanceCounters_.Get();
+ }
// Fast path.
+ auto userTag = incomingRequest.UserTag;
if (userTag == RootUserName && requestQueue == runtimeInfo->DefaultRequestQueue.Get()) {
if (EnablePerUserProfiling_.load(std::memory_order::relaxed)) {
return runtimeInfo->RootPerformanceCounters.Get();
@@ -2394,12 +2400,14 @@ TServiceBase::TMethodPerformanceCounters* TServiceBase::GetMethodPerformanceCoun
}
}
+ // Slow path.
if (!EnablePerUserProfiling_.load(std::memory_order::relaxed)) {
userTag = {};
}
- auto actualKey = TRuntimeMethodInfo::TNonowningPerformanceCountersKey{userTag, requestQueue};
- return runtimeInfo->PerformanceCountersMap.FindOrInsert(actualKey, [&] {
- return CreateMethodPerformanceCounters(runtimeInfo, actualKey);
+
+ auto key = TRuntimeMethodInfo::TNonowningPerformanceCountersKey{userTag, requestQueue};
+ return runtimeInfo->PerformanceCountersMap.FindOrInsert(key, [&] {
+ return CreateMethodPerformanceCounters(runtimeInfo, key);
}).first->Get();
}
@@ -2408,6 +2416,33 @@ TServiceBase::TPerformanceCounters* TServiceBase::GetPerformanceCounters()
return PerformanceCounters_.Get();
}
+// Records arrival-time profiling for an incoming request.
+// Invoked both from ReplyError (requests rejected before a service context exists) and from
+// HandleAuthenticatedRequest (requests that proceed to a service context), so the counters
+// below are updated for non-accepted requests as well.
+//
+// |incomingRequest| is updated in place: its MethodPerformanceCounters field is set to the
+// counters resolved here.
+void TServiceBase::ProfileRequest(TIncomingRequest* incomingRequest)
+{
+ // Per-user-agent request count.
+ const auto* requestsPerUserAgentCounter = PerformanceCounters_->GetRequestsPerUserAgentCounter(incomingRequest->UserAgent);
+ requestsPerUserAgentCounter->Increment();
+
+ // Falls back to the "unknown method" counters when the request has no runtime info or request queue.
+ auto* methodPerformanceCounters = GetMethodPerformanceCounters(*incomingRequest);
+ // NB: Save the counters for future use on response.
+ incomingRequest->MethodPerformanceCounters = methodPerformanceCounters;
+
+ // Request count and request message body/attachment sizes.
+ methodPerformanceCounters->RequestCounter.Increment();
+ methodPerformanceCounters->RequestMessageBodySizeCounter.Increment(GetMessageBodySize(incomingRequest->Message));
+ methodPerformanceCounters->RequestMessageAttachmentSizeCounter.Increment(GetTotalMessageAttachmentSize(incomingRequest->Message));
+
+ // Remote wait time is measured from the header's start_time (when present) to the recorded arrival instant.
+ if (incomingRequest->Header->has_start_time()) {
+ auto retryStart = FromProto<TInstant>(incomingRequest->Header->start_time());
+ methodPerformanceCounters->RemoteWaitTimeCounter.Record(incomingRequest->ArriveInstant - retryStart);
+ }
+
+ // For non-local endpoints, count requests whose credentials extension carries no service ticket.
+ if (!incomingRequest->ReplyBus->IsEndpointLocal()) {
+ bool authenticated = incomingRequest->Header->HasExtension(NRpc::NProto::TCredentialsExt::credentials_ext) &&
+ incomingRequest->Header->GetExtension(NRpc::NProto::TCredentialsExt::credentials_ext).has_service_ticket();
+ // RuntimeInfo may be unset: ReplyError can be reached before the method lookup has succeeded.
+ if (!authenticated && incomingRequest->RuntimeInfo) {
+ incomingRequest->RuntimeInfo->UnauthenticatedRequestCounter.Increment();
+ }
+ }
+}
+
void TServiceBase::SetActive()
{
// Fast path.
@@ -2726,13 +2761,13 @@ TFuture<void> TServiceBase::Stop()
return StopResult_.ToFuture();
}
-TServiceBase::TRuntimeMethodInfo* TServiceBase::FindMethodInfo(const TString& method)
+TServiceBase::TRuntimeMethodInfo* TServiceBase::FindMethodInfo(TStringBuf method)
{
auto it = MethodMap_.find(method);
return it == MethodMap_.end() ? nullptr : it->second.Get();
}
-TServiceBase::TRuntimeMethodInfo* TServiceBase::GetMethodInfoOrThrow(const TString& method)
+TServiceBase::TRuntimeMethodInfo* TServiceBase::GetMethodInfoOrThrow(TStringBuf method)
{
auto* runtimeInfo = FindMethodInfo(method);
if (!runtimeInfo) {
diff --git a/yt/yt/core/rpc/service_detail.h b/yt/yt/core/rpc/service_detail.h
index 76f26d8224..b12bc45ebe 100644
--- a/yt/yt/core/rpc/service_detail.h
+++ b/yt/yt/core/rpc/service_detail.h
@@ -664,11 +664,12 @@ protected:
TMethodDescriptor SetHandleMethodError(bool value) const;
};
- struct TErrorCodeCounter
+ class TErrorCodeCounters
{
- explicit TErrorCodeCounter(NProfiling::TProfiler profiler);
+ public:
+ explicit TErrorCodeCounters(NProfiling::TProfiler profiler);
- void Increment(TErrorCode code);
+ NProfiling::TCounter* GetCounter(TErrorCode code);
private:
const NProfiling::TProfiler Profiler_;
@@ -724,7 +725,7 @@ protected:
NProfiling::TCounter ResponseMessageAttachmentSizeCounter;
//! Counts the number of errors, per error code.
- TErrorCodeCounter ErrorCodeCounter;
+ TErrorCodeCounters ErrorCodeCounters;
};
using TMethodPerformanceCountersPtr = TIntrusivePtr<TMethodPerformanceCounters>;
@@ -762,7 +763,7 @@ protected:
NProfiling::TCounter RequestQueueSizeLimitErrorCounter;
NProfiling::TCounter RequestQueueByteSizeLimitErrorCounter;
- NProfiling::TCounter UnauthenticatedRequestsCounter;
+ NProfiling::TCounter UnauthenticatedRequestCounter;
std::atomic<NLogging::ELogLevel> LogLevel = {};
std::atomic<TDuration> LoggingSuppressionTimeout = {};
@@ -770,13 +771,24 @@ protected:
using TNonowningPerformanceCountersKey = std::tuple<TStringBuf, TRequestQueue*>;
using TOwningPerformanceCountersKey = std::tuple<TString, TRequestQueue*>;
using TPerformanceCountersKeyHash = THash<TNonowningPerformanceCountersKey>;
- struct TPerformanceCountersKeyEquals;
+
+ struct TPerformanceCountersKeyEquals
+ {
+ bool operator()(
+ const TNonowningPerformanceCountersKey& lhs,
+ const TNonowningPerformanceCountersKey& rhs) const;
+ bool operator()(
+ const TOwningPerformanceCountersKey& lhs,
+ const TNonowningPerformanceCountersKey& rhs) const;
+ };
+
using TPerformanceCountersMap = NConcurrency::TSyncMap<
TOwningPerformanceCountersKey,
TMethodPerformanceCountersPtr,
TPerformanceCountersKeyHash,
TPerformanceCountersKeyEquals
>;
+
TPerformanceCountersMap PerformanceCountersMap;
TMethodPerformanceCountersPtr BasePerformanceCounters;
TMethodPerformanceCountersPtr RootPerformanceCounters;
@@ -796,16 +808,9 @@ protected:
: public TRefCounted
{
public:
- explicit TPerformanceCounters(const NProfiling::TProfiler& profiler)
- : Profiler_(profiler.WithHot().WithSparse())
- { }
+ explicit TPerformanceCounters(const NProfiling::TProfiler& profiler);
- void IncrementRequestsPerUserAgent(TStringBuf userAgent)
- {
- RequestsPerUserAgent_.FindOrInsert(userAgent, [&] {
- return Profiler_.WithRequiredTag("user_agent", TString(userAgent)).Counter("/user_agent");
- }).first->Increment();
- }
+ NProfiling::TCounter* GetRequestsPerUserAgentCounter(TStringBuf userAgent);
private:
const NProfiling::TProfiler Profiler_;
@@ -853,10 +858,10 @@ protected:
//! Returns a (non-owning!) pointer to TRuntimeMethodInfo for a given method's name
//! or |nullptr| if no such method is registered.
- TRuntimeMethodInfo* FindMethodInfo(const TString& method);
+ TRuntimeMethodInfo* FindMethodInfo(TStringBuf method);
//! Similar to #FindMethodInfo but throws if no method is found.
- TRuntimeMethodInfo* GetMethodInfoOrThrow(const TString& method);
+ TRuntimeMethodInfo* GetMethodInfoOrThrow(TStringBuf method);
//! Returns the default invoker passed during construction.
const IInvokerPtr& GetDefaultInvoker() const;
@@ -903,12 +908,9 @@ protected:
virtual std::optional<TError> GetThrottledError(const NProto::TRequestHeader& requestHeader);
protected:
- void ReplyError(
- TError error,
- const NProto::TRequestHeader& header,
- const NYT::NBus::IBusPtr& replyBus);
-
- virtual void OnMethodError(const TError& error, const TString& method);
+ virtual void OnMethodError(
+ const TError& error,
+ const TString& method);
private:
friend class TRequestQueue;
@@ -994,17 +996,24 @@ private:
THashMap<TString, TDiscoverRequestSet> DiscoverRequestsByPayload_;
YT_DECLARE_SPIN_LOCK(NThreading::TReaderWriterSpinLock, DiscoverRequestsByPayloadLock_);
- TPerformanceCountersPtr PerformanceCounters_;
+ const TPerformanceCountersPtr PerformanceCounters_;
+ const TMethodPerformanceCountersPtr UnknownMethodPerformanceCounters_;
- struct TAcceptedRequest
+ struct TIncomingRequest
{
+ TInstant ArriveInstant;
TRequestId RequestId;
NYT::NBus::IBusPtr ReplyBus;
- TRuntimeMethodInfo* RuntimeInfo;
+ TRuntimeMethodInfo* RuntimeInfo = nullptr;
NTracing::TTraceContextPtr TraceContext;
std::unique_ptr<NRpc::NProto::TRequestHeader> Header;
+ TStringBuf UserAgent;
+ TStringBuf Method;
+ TStringBuf User;
+ TStringBuf UserTag;
TSharedRefArray Message;
- TRequestQueue* RequestQueue;
+ TRequestQueue* RequestQueue = nullptr;
+ TMethodPerformanceCounters* MethodPerformanceCounters = nullptr;
std::optional<TError> ThrottledError;
TMemoryUsageTrackerGuard MemoryGuard;
IMemoryUsageTrackerPtr MemoryUsageTracker;
@@ -1019,19 +1028,18 @@ private:
void OnRequestTimeout(TRequestId requestId, ERequestProcessingStage stage, bool aborted);
void OnReplyBusTerminated(const NYT::TWeakPtr<NYT::NBus::IBus>& busWeak, const TError& error);
+ void DoHandleRequest(TIncomingRequest&& incomingRequest);
+ void ReplyError(TError error, TIncomingRequest&& incomingRequest);
void OnRequestAuthenticated(
const NProfiling::TWallTimer& timer,
- TAcceptedRequest&& acceptedRequest,
+ TIncomingRequest&& incomingRequest,
const TErrorOr<TAuthenticationResult>& authResultOrError);
- bool IsAuthenticationNeeded(const TAcceptedRequest& acceptedRequest);
- void HandleAuthenticatedRequest(TAcceptedRequest&& acceptedRequest);
+ bool IsAuthenticationNeeded(const TIncomingRequest& incomingRequest);
+ void HandleAuthenticatedRequest(TIncomingRequest&& incomingRequest);
TRequestQueue* GetRequestQueue(
TRuntimeMethodInfo* runtimeInfo,
const NRpc::NProto::TRequestHeader& requestHeader);
- void RegisterRequestQueue(
- TRuntimeMethodInfo* runtimeInfo,
- TRequestQueue* requestQueue);
void ConfigureRequestQueue(
TRuntimeMethodInfo* runtimeInfo,
TRequestQueue* requestQueue,
@@ -1054,14 +1062,14 @@ private:
std::vector<TStreamingPayload> GetAndErasePendingPayloads(TRequestId requestId);
void OnPendingPayloadsLeaseExpired(TRequestId requestId);
+ TMethodPerformanceCountersPtr CreateUnknownMethodPerformanceCounters();
TMethodPerformanceCountersPtr CreateMethodPerformanceCounters(
TRuntimeMethodInfo* runtimeInfo,
const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key);
TMethodPerformanceCounters* GetMethodPerformanceCounters(
- TRuntimeMethodInfo* runtimeInfo,
- const TRuntimeMethodInfo::TNonowningPerformanceCountersKey& key);
-
+ const TIncomingRequest& incomingRequest);
TPerformanceCounters* GetPerformanceCounters();
+ void ProfileRequest(TIncomingRequest* incomingRequest);
void SetActive();
void ValidateInactive();