diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/grpc/server/grpc_request.h | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/grpc/server/grpc_request.h')
-rw-r--r-- | library/cpp/grpc/server/grpc_request.h | 543 |
1 files changed, 543 insertions, 0 deletions
diff --git a/library/cpp/grpc/server/grpc_request.h b/library/cpp/grpc/server/grpc_request.h new file mode 100644 index 0000000000..5bd8d3902b --- /dev/null +++ b/library/cpp/grpc/server/grpc_request.h @@ -0,0 +1,543 @@ +#pragma once + +#include <google/protobuf/text_format.h> +#include <google/protobuf/arena.h> +#include <google/protobuf/message.h> + +#include <library/cpp/monlib/dynamic_counters/counters.h> +#include <library/cpp/logger/priority.h> + +#include "grpc_response.h" +#include "event_callback.h" +#include "grpc_async_ctx_base.h" +#include "grpc_counters.h" +#include "grpc_request_base.h" +#include "grpc_server.h" +#include "logger.h" + +#include <util/system/hp_timer.h> + +#include <grpc++/server.h> +#include <grpc++/server_context.h> +#include <grpc++/support/async_stream.h> +#include <grpc++/support/async_unary_call.h> +#include <grpc++/support/byte_buffer.h> +#include <grpc++/impl/codegen/async_stream.h> + +namespace NGrpc { + +class IStreamAdaptor { +public: + using TPtr = std::unique_ptr<IStreamAdaptor>; + virtual void Enqueue(std::function<void()>&& fn, bool urgent) = 0; + virtual size_t ProcessNext() = 0; + virtual ~IStreamAdaptor() = default; +}; + +IStreamAdaptor::TPtr CreateStreamAdaptor(); + +/////////////////////////////////////////////////////////////////////////////// +template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter, typename TOutProtoPrinter> +class TGRpcRequestImpl + : public TBaseAsyncContext<TService> + , public IQueueEvent + , public IRequestContextBase +{ + using TThis = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>; + +public: + using TOnRequest = std::function<void (IRequestContextBase* ctx)>; + using TRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*, + grpc::ServerAsyncResponseWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*); + using TStreamRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*, + grpc::ServerAsyncWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*); + + TGRpcRequestImpl(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + TOnRequest cb, + TRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters, + IGRpcRequestLimiterPtr limiter) + : TBaseAsyncContext<TService>(service, cq) + , Server_(server) + , Cb_(cb) + , RequestCallback_(requestCallback) + , StreamRequestCallback_(nullptr) + , Name_(name) + , Logger_(std::move(logger)) + , Counters_(std::move(counters)) + , RequestLimiter_(std::move(limiter)) + , Writer_(new grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>(&this->Context)) + , StateFunc_(&TThis::SetRequestDone) + { + AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false); + Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_); + Y_VERIFY(Request_); + GRPC_LOG_DEBUG(Logger_, "[%p] created request Name# %s", this, Name_); + FinishPromise_ = NThreading::NewPromise<EFinishStatus>(); + } + + TGRpcRequestImpl(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + TOnRequest cb, + TStreamRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters, + IGRpcRequestLimiterPtr limiter) + : TBaseAsyncContext<TService>(service, cq) + , Server_(server) + , Cb_(cb) + , RequestCallback_(nullptr) + , StreamRequestCallback_(requestCallback) + , Name_(name) + , Logger_(std::move(logger)) + , Counters_(std::move(counters)) + , RequestLimiter_(std::move(limiter)) + , StreamWriter_(new grpc::ServerAsyncWriter<TUniversalResponse<TOut>>(&this->Context)) + , StateFunc_(&TThis::SetRequestDone) + { + AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false); + Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_); + Y_VERIFY(Request_); + GRPC_LOG_DEBUG(Logger_, "[%p] created streaming request Name# %s", this, Name_); + FinishPromise_ = NThreading::NewPromise<EFinishStatus>(); + StreamAdaptor_ = CreateStreamAdaptor(); + } + + TAsyncFinishResult GetFinishFuture() override { + return FinishPromise_.GetFuture(); + } + + TString GetPeer() const override { + return TString(this->Context.peer()); + } + + bool SslServer() const override { + return Server_->SslServer(); + } + + void Run() { + // Start request unless server is shutting down + if (auto guard = Server_->ProtectShutdown()) { + Ref(); //For grpc c runtime + this->Context.AsyncNotifyWhenDone(OnFinishTag.Prepare()); + if (RequestCallback_) { + (this->Service->*RequestCallback_) + (&this->Context, Request_, + reinterpret_cast<grpc::ServerAsyncResponseWriter<TOut>*>(Writer_.Get()), this->CQ, this->CQ, GetGRpcTag()); + } else { + (this->Service->*StreamRequestCallback_) + (&this->Context, Request_, + reinterpret_cast<grpc::ServerAsyncWriter<TOut>*>(StreamWriter_.Get()), this->CQ, this->CQ, GetGRpcTag()); + } + } + } + + ~TGRpcRequestImpl() { + // No direct dtor call allowed + Y_ASSERT(RefCount() == 0); + } + + bool Execute(bool ok) override { + return (this->*StateFunc_)(ok); + } + + void DestroyRequest() override { + if (RequestRegistered_) { + Server_->DeregisterRequestCtx(this); + RequestRegistered_ = false; + } + UnRef(); + } + + TInstant Deadline() const override { + return TBaseAsyncContext<TService>::Deadline(); + } + + TSet<TStringBuf> GetPeerMetaKeys() const override { + return TBaseAsyncContext<TService>::GetPeerMetaKeys(); + } + + TVector<TStringBuf> GetPeerMetaValues(TStringBuf key) const override { + return TBaseAsyncContext<TService>::GetPeerMetaValues(key); + } + + grpc_compression_level GetCompressionLevel() const override { + return TBaseAsyncContext<TService>::GetCompressionLevel(); + } + + //! Get pointer to the request's message. + const NProtoBuf::Message* GetRequest() const override { + return Request_; + } + + TAuthState& GetAuthState() override { + return AuthState_; + } + + void Reply(NProtoBuf::Message* resp, ui32 status) override { + ResponseStatus = status; + WriteDataOk(resp); + } + + void Reply(grpc::ByteBuffer* resp, ui32 status) override { + ResponseStatus = status; + WriteByteDataOk(resp); + } + + void ReplyError(grpc::StatusCode code, const TString& msg) override { + FinishGrpcStatus(code, msg, false); + } + + void ReplyUnauthenticated(const TString& in) override { + const TString message = in.empty() ? TString("unauthenticated") : TString("unauthenticated, ") + in; + FinishGrpcStatus(grpc::StatusCode::UNAUTHENTICATED, message, false); + } + + void SetNextReplyCallback(TOnNextReply&& cb) override { + NextReplyCb_ = cb; + } + + void AddTrailingMetadata(const TString& key, const TString& value) override { + this->Context.AddTrailingMetadata(key, value); + } + + void FinishStreamingOk() override { + GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (enqueued)", this, Name_, + this->Context.peer().c_str()); + auto cb = [this]() { + StateFunc_ = &TThis::SetFinishDone; + GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (pushed to grpc)", this, Name_, + this->Context.peer().c_str()); + + StreamWriter_->Finish(grpc::Status::OK, GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), false); + } + + google::protobuf::Arena* GetArena() override { + return &Arena_; + } + + void UseDatabase(const TString& database) override { + Counters_->UseDatabase(database); + } + +private: + void Clone() { + if (!Server_->IsShuttingDown()) { + if (RequestCallback_) { + MakeIntrusive<TThis>( + Server_, this->Service, this->CQ, Cb_, RequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run(); + } else { + MakeIntrusive<TThis>( + Server_, this->Service, this->CQ, Cb_, StreamRequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run(); + } + } + } + + void WriteDataOk(NProtoBuf::Message* resp) { + auto makeResponseString = [&] { + TString x; + TOutProtoPrinter printer; + printer.SetSingleLineMode(true); + printer.PrintToString(*resp, &x); + return x; + }; + + auto sz = (size_t)resp->ByteSize(); + if (Writer_) { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s", this, Name_, + makeResponseString().data(), this->Context.peer().c_str()); + StateFunc_ = &TThis::SetFinishDone; + ResponseSize = sz; + Y_VERIFY(this->Context.c_call()); + Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag()); + } else { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (enqueued)", + this, Name_, makeResponseString().data(), this->Context.peer().c_str()); + + // because of std::function cannot hold move-only captured object + // we allocate shared object on heap to avoid message copy + auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp); + auto cb = [this, uResp = std::move(uResp), sz, &makeResponseString]() { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (pushed to grpc)", + this, Name_, makeResponseString().data(), this->Context.peer().c_str()); + StateFunc_ = &TThis::NextReply; + ResponseSize += sz; + StreamWriter_->Write(*uResp, GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), false); + } + } + + void WriteByteDataOk(grpc::ByteBuffer* resp) { + auto sz = resp->Length(); + if (Writer_) { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s", this, Name_, + this->Context.peer().c_str()); + StateFunc_ = &TThis::SetFinishDone; + ResponseSize = sz; + Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag()); + } else { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (enqueued)", this, Name_, + this->Context.peer().c_str()); + + // because of std::function cannot hold move-only captured object + // we allocate shared object on heap to avoid buffer copy + auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp); + auto cb = [this, uResp = std::move(uResp), sz]() { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (pushed to grpc)", + this, Name_, this->Context.peer().c_str()); + StateFunc_ = &TThis::NextReply; + ResponseSize += sz; + StreamWriter_->Write(*uResp, GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), false); + } + } + + void FinishGrpcStatus(grpc::StatusCode code, const TString& msg, bool urgent) { + Y_VERIFY(code != grpc::OK); + if (code == grpc::StatusCode::UNAUTHENTICATED) { + Counters_->CountNotAuthenticated(); + } else if (code == grpc::StatusCode::RESOURCE_EXHAUSTED) { + Counters_->CountResourceExhausted(); + } + + if (Writer_) { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)", this, + Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); + StateFunc_ = &TThis::SetFinishError; + TOut resp; + Writer_->Finish(TUniversalResponseRef<TOut>(&resp), grpc::Status(code, msg), GetGRpcTag()); + } else { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)" + " (enqueued)", this, Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); + auto cb = [this, code, msg]() { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)" + " (pushed to grpc)", this, Name_, msg.c_str(), + this->Context.peer().c_str(), (int)code); + StateFunc_ = &TThis::SetFinishError; + StreamWriter_->Finish(grpc::Status(code, msg), GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), urgent); + } + } + + bool SetRequestDone(bool ok) { + auto makeRequestString = [&] { + TString resp; + if (ok) { + TInProtoPrinter printer; + printer.SetSingleLineMode(true); + printer.PrintToString(*Request_, &resp); + } else { + resp = "<not ok>"; + } + return resp; + }; + GRPC_LOG_DEBUG(Logger_, "[%p] received request Name# %s ok# %s data# %s peer# %s", this, Name_, + ok ? "true" : "false", makeRequestString().data(), this->Context.peer().c_str()); + + if (this->Context.c_call() == nullptr) { + Y_VERIFY(!ok); + // One ref by OnFinishTag, grpc will not call this tag if no request received + UnRef(); + } else if (!(RequestRegistered_ = Server_->RegisterRequestCtx(this))) { + // Request cannot be registered due to shutdown + // It's unsafe to continue, so drop this request without processing + GRPC_LOG_DEBUG(Logger_, "[%p] dropping request Name# %s due to shutdown", this, Name_); + this->Context.TryCancel(); + return false; + } + + Clone(); // TODO: Request pool? + if (!ok) { + Counters_->CountNotOkRequest(); + return false; + } + + if (IncRequest()) { + // Adjust counters. + RequestSize = Request_->ByteSize(); + Counters_->StartProcessing(RequestSize); + RequestTimer.Reset(); + + if (!SslServer()) { + Counters_->CountRequestWithoutTls(); + } + + //TODO: Move this in to grpc_request_proxy + auto maybeDatabase = GetPeerMetaValues(TStringBuf("x-ydb-database")); + if (maybeDatabase.empty()) { + Counters_->CountRequestsWithoutDatabase(); + } + auto maybeToken = GetPeerMetaValues(TStringBuf("x-ydb-auth-ticket")); + if (maybeToken.empty() || maybeToken[0].empty()) { + TString db{maybeDatabase ? maybeDatabase[0] : TStringBuf{}}; + Counters_->CountRequestsWithoutToken(); + GRPC_LOG_DEBUG(Logger_, "[%p] received request without user token " + "Name# %s data# %s peer# %s database# %s", this, Name_, + makeRequestString().data(), this->Context.peer().c_str(), db.c_str()); + } + + // Handle current request. + Cb_(this); + } else { + //This request has not been counted + SkipUpdateCountersOnError = true; + FinishGrpcStatus(grpc::StatusCode::RESOURCE_EXHAUSTED, "no resource", true); + } + return true; + } + + bool NextReply(bool ok) { + auto logCb = [this, ok](int left) { + GRPC_LOG_DEBUG(Logger_, "[%p] ready for next reply Name# %s ok# %s peer# %s left# %d", this, Name_, + ok ? "true" : "false", this->Context.peer().c_str(), left); + }; + + if (!ok) { + logCb(-1); + DecRequest(); + Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, + TDuration::Seconds(RequestTimer.Passed())); + return false; + } + + Ref(); // To prevent destroy during this call in case of execution Finish + size_t left = StreamAdaptor_->ProcessNext(); + logCb(left); + if (NextReplyCb_) { + NextReplyCb_(left); + } + // Now it is safe to destroy even if Finish was called + UnRef(); + return true; + } + + bool SetFinishDone(bool ok) { + GRPC_LOG_DEBUG(Logger_, "[%p] finished request Name# %s ok# %s peer# %s", this, Name_, + ok ? "true" : "false", this->Context.peer().c_str()); + //PrintBackTrace(); + DecRequest(); + Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, + TDuration::Seconds(RequestTimer.Passed())); + return false; + } + + bool SetFinishError(bool ok) { + GRPC_LOG_DEBUG(Logger_, "[%p] finished request with error Name# %s ok# %s peer# %s", this, Name_, + ok ? "true" : "false", this->Context.peer().c_str()); + if (!SkipUpdateCountersOnError) { + DecRequest(); + Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, + TDuration::Seconds(RequestTimer.Passed())); + } + return false; + } + + // Returns pointer to IQueueEvent to pass into grpc c runtime + // Implicit C style cast from this to void* is wrong due to multiple inheritance + void* GetGRpcTag() { + return static_cast<IQueueEvent*>(this); + } + + void OnFinish(EQueueEventStatus evStatus) { + if (this->Context.IsCancelled()) { + FinishPromise_.SetValue(EFinishStatus::CANCEL); + } else { + FinishPromise_.SetValue(evStatus == EQueueEventStatus::OK ? EFinishStatus::OK : EFinishStatus::ERROR); + } + } + + bool IncRequest() { + if (!Server_->IncRequest()) + return false; + + if (!RequestLimiter_) + return true; + + if (!RequestLimiter_->IncRequest()) { + Server_->DecRequest(); + return false; + } + + return true; + } + + void DecRequest() { + if (RequestLimiter_) { + RequestLimiter_->DecRequest(); + } + Server_->DecRequest(); + } + + using TStateFunc = bool (TThis::*)(bool); + TService* Server_; + TOnRequest Cb_; + TRequestCallback RequestCallback_; + TStreamRequestCallback StreamRequestCallback_; + const char* const Name_; + TLoggerPtr Logger_; + ICounterBlockPtr Counters_; + IGRpcRequestLimiterPtr RequestLimiter_; + + THolder<grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>> Writer_; + THolder<grpc::ServerAsyncWriterInterface<TUniversalResponse<TOut>>> StreamWriter_; + TStateFunc StateFunc_; + TIn* Request_; + + google::protobuf::Arena Arena_; + TOnNextReply NextReplyCb_; + ui32 RequestSize = 0; + ui32 ResponseSize = 0; + ui32 ResponseStatus = 0; + THPTimer RequestTimer; + TAuthState AuthState_ = 0; + bool RequestRegistered_ = false; + + using TFixedEvent = TQueueFixedEvent<TGRpcRequestImpl>; + TFixedEvent OnFinishTag = { this, &TGRpcRequestImpl::OnFinish }; + NThreading::TPromise<EFinishStatus> FinishPromise_; + bool SkipUpdateCountersOnError = false; + IStreamAdaptor::TPtr StreamAdaptor_; +}; + +template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter=google::protobuf::TextFormat::Printer, typename TOutProtoPrinter=google::protobuf::TextFormat::Printer> +class TGRpcRequest: public TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter> { + using TBase = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>; +public: + TGRpcRequest(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + typename TBase::TOnRequest cb, + typename TBase::TRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters, + IGRpcRequestLimiterPtr limiter = nullptr) + : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), std::move(limiter)} + { + } + + TGRpcRequest(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + typename TBase::TOnRequest cb, + typename TBase::TStreamRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters) + : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), nullptr} + { + } +}; + +} // namespace NGrpc |