#pragma once #include <google/protobuf/text_format.h> #include <google/protobuf/arena.h> #include <google/protobuf/message.h> #include <library/cpp/monlib/dynamic_counters/counters.h> #include <library/cpp/logger/priority.h> #include "grpc_response.h" #include "event_callback.h" #include "grpc_async_ctx_base.h" #include "grpc_counters.h" #include "grpc_request_base.h" #include "grpc_server.h" #include "logger.h" #include <util/system/hp_timer.h> #include <grpc++/server.h> #include <grpc++/server_context.h> #include <grpc++/support/async_stream.h> #include <grpc++/support/async_unary_call.h> #include <grpc++/support/byte_buffer.h> #include <grpc++/impl/codegen/async_stream.h> namespace NGrpc { class IStreamAdaptor { public: using TPtr = std::unique_ptr<IStreamAdaptor>; virtual void Enqueue(std::function<void()>&& fn, bool urgent) = 0; virtual size_t ProcessNext() = 0; virtual ~IStreamAdaptor() = default; }; IStreamAdaptor::TPtr CreateStreamAdaptor(); /////////////////////////////////////////////////////////////////////////////// template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter, typename TOutProtoPrinter> class TGRpcRequestImpl : public TBaseAsyncContext<TService> , public IQueueEvent , public IRequestContextBase { using TThis = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>; public: using TOnRequest = std::function<void (IRequestContextBase* ctx)>; using TRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*, grpc::ServerAsyncResponseWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*); using TStreamRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*, grpc::ServerAsyncWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*); TGRpcRequestImpl(TService* server, typename TService::TCurrentGRpcService::AsyncService* service, grpc::ServerCompletionQueue* cq, TOnRequest cb, TRequestCallback requestCallback, const char* name, TLoggerPtr logger, ICounterBlockPtr counters, IGRpcRequestLimiterPtr limiter) : TBaseAsyncContext<TService>(service, cq) , Server_(server) , Cb_(cb) , RequestCallback_(requestCallback) , StreamRequestCallback_(nullptr) , Name_(name) , Logger_(std::move(logger)) , Counters_(std::move(counters)) , RequestLimiter_(std::move(limiter)) , Writer_(new grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>(&this->Context)) , StateFunc_(&TThis::SetRequestDone) { AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false); Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_); Y_VERIFY(Request_); GRPC_LOG_DEBUG(Logger_, "[%p] created request Name# %s", this, Name_); FinishPromise_ = NThreading::NewPromise<EFinishStatus>(); } TGRpcRequestImpl(TService* server, typename TService::TCurrentGRpcService::AsyncService* service, grpc::ServerCompletionQueue* cq, TOnRequest cb, TStreamRequestCallback requestCallback, const char* name, TLoggerPtr logger, ICounterBlockPtr counters, IGRpcRequestLimiterPtr limiter) : TBaseAsyncContext<TService>(service, cq) , Server_(server) , Cb_(cb) , RequestCallback_(nullptr) , StreamRequestCallback_(requestCallback) , Name_(name) , Logger_(std::move(logger)) , Counters_(std::move(counters)) , RequestLimiter_(std::move(limiter)) , StreamWriter_(new grpc::ServerAsyncWriter<TUniversalResponse<TOut>>(&this->Context)) , StateFunc_(&TThis::SetRequestDone) { AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false); Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_); Y_VERIFY(Request_); GRPC_LOG_DEBUG(Logger_, "[%p] created streaming request Name# %s", this, Name_); FinishPromise_ = NThreading::NewPromise<EFinishStatus>(); StreamAdaptor_ = CreateStreamAdaptor(); } TAsyncFinishResult GetFinishFuture() override { return FinishPromise_.GetFuture(); } TString GetPeer() const override { return TString(this->Context.peer()); } bool SslServer() const override { return Server_->SslServer(); } void Run() { // Start request unless server is shutting down if (auto guard = Server_->ProtectShutdown()) { Ref(); //For grpc c runtime this->Context.AsyncNotifyWhenDone(OnFinishTag.Prepare()); if (RequestCallback_) { (this->Service->*RequestCallback_) (&this->Context, Request_, reinterpret_cast<grpc::ServerAsyncResponseWriter<TOut>*>(Writer_.Get()), this->CQ, this->CQ, GetGRpcTag()); } else { (this->Service->*StreamRequestCallback_) (&this->Context, Request_, reinterpret_cast<grpc::ServerAsyncWriter<TOut>*>(StreamWriter_.Get()), this->CQ, this->CQ, GetGRpcTag()); } } } ~TGRpcRequestImpl() { // No direct dtor call allowed Y_ASSERT(RefCount() == 0); } bool Execute(bool ok) override { return (this->*StateFunc_)(ok); } void DestroyRequest() override { if (RequestRegistered_) { Server_->DeregisterRequestCtx(this); RequestRegistered_ = false; } UnRef(); } TInstant Deadline() const override { return TBaseAsyncContext<TService>::Deadline(); } TSet<TStringBuf> GetPeerMetaKeys() const override { return TBaseAsyncContext<TService>::GetPeerMetaKeys(); } TVector<TStringBuf> GetPeerMetaValues(TStringBuf key) const override { return TBaseAsyncContext<TService>::GetPeerMetaValues(key); } grpc_compression_level GetCompressionLevel() const override { return TBaseAsyncContext<TService>::GetCompressionLevel(); } //! Get pointer to the request's message. const NProtoBuf::Message* GetRequest() const override { return Request_; } TAuthState& GetAuthState() override { return AuthState_; } void Reply(NProtoBuf::Message* resp, ui32 status) override { ResponseStatus = status; WriteDataOk(resp); } void Reply(grpc::ByteBuffer* resp, ui32 status) override { ResponseStatus = status; WriteByteDataOk(resp); } void ReplyError(grpc::StatusCode code, const TString& msg) override { FinishGrpcStatus(code, msg, false); } void ReplyUnauthenticated(const TString& in) override { const TString message = in.empty() ? TString("unauthenticated") : TString("unauthenticated, ") + in; FinishGrpcStatus(grpc::StatusCode::UNAUTHENTICATED, message, false); } void SetNextReplyCallback(TOnNextReply&& cb) override { NextReplyCb_ = cb; } void AddTrailingMetadata(const TString& key, const TString& value) override { this->Context.AddTrailingMetadata(key, value); } void FinishStreamingOk() override { GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (enqueued)", this, Name_, this->Context.peer().c_str()); auto cb = [this]() { StateFunc_ = &TThis::SetFinishDone; GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (pushed to grpc)", this, Name_, this->Context.peer().c_str()); StreamWriter_->Finish(grpc::Status::OK, GetGRpcTag()); }; StreamAdaptor_->Enqueue(std::move(cb), false); } google::protobuf::Arena* GetArena() override { return &Arena_; } void UseDatabase(const TString& database) override { Counters_->UseDatabase(database); } private: void Clone() { if (!Server_->IsShuttingDown()) { if (RequestCallback_) { MakeIntrusive<TThis>( Server_, this->Service, this->CQ, Cb_, RequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run(); } else { MakeIntrusive<TThis>( Server_, this->Service, this->CQ, Cb_, StreamRequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run(); } } } void WriteDataOk(NProtoBuf::Message* resp) { auto makeResponseString = [&] { TString x; TOutProtoPrinter printer; printer.SetSingleLineMode(true); printer.PrintToString(*resp, &x); return x; }; auto sz = (size_t)resp->ByteSize(); if (Writer_) { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s", this, Name_, makeResponseString().data(), this->Context.peer().c_str()); StateFunc_ = &TThis::SetFinishDone; ResponseSize = sz; Y_VERIFY(this->Context.c_call()); Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag()); } else { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (enqueued)", this, Name_, makeResponseString().data(), this->Context.peer().c_str()); // because of std::function cannot hold move-only captured object // we allocate shared object on heap to avoid message copy auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp); auto cb = [this, uResp = std::move(uResp), sz, &makeResponseString]() { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (pushed to grpc)", this, Name_, makeResponseString().data(), this->Context.peer().c_str()); StateFunc_ = &TThis::NextReply; ResponseSize += sz; StreamWriter_->Write(*uResp, GetGRpcTag()); }; StreamAdaptor_->Enqueue(std::move(cb), false); } } void WriteByteDataOk(grpc::ByteBuffer* resp) { auto sz = resp->Length(); if (Writer_) { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s", this, Name_, this->Context.peer().c_str()); StateFunc_ = &TThis::SetFinishDone; ResponseSize = sz; Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag()); } else { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (enqueued)", this, Name_, this->Context.peer().c_str()); // because of std::function cannot hold move-only captured object // we allocate shared object on heap to avoid buffer copy auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp); auto cb = [this, uResp = std::move(uResp), sz]() { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (pushed to grpc)", this, Name_, this->Context.peer().c_str()); StateFunc_ = &TThis::NextReply; ResponseSize += sz; StreamWriter_->Write(*uResp, GetGRpcTag()); }; StreamAdaptor_->Enqueue(std::move(cb), false); } } void FinishGrpcStatus(grpc::StatusCode code, const TString& msg, bool urgent) { Y_VERIFY(code != grpc::OK); if (code == grpc::StatusCode::UNAUTHENTICATED) { Counters_->CountNotAuthenticated(); } else if (code == grpc::StatusCode::RESOURCE_EXHAUSTED) { Counters_->CountResourceExhausted(); } if (Writer_) { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)", this, Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); StateFunc_ = &TThis::SetFinishError; TOut resp; Writer_->Finish(TUniversalResponseRef<TOut>(&resp), grpc::Status(code, msg), GetGRpcTag()); } else { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)" " (enqueued)", this, Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); auto cb = [this, code, msg]() { GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)" " (pushed to grpc)", this, Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); StateFunc_ = &TThis::SetFinishError; StreamWriter_->Finish(grpc::Status(code, msg), GetGRpcTag()); }; StreamAdaptor_->Enqueue(std::move(cb), urgent); } } bool SetRequestDone(bool ok) { auto makeRequestString = [&] { TString resp; if (ok) { TInProtoPrinter printer; printer.SetSingleLineMode(true); printer.PrintToString(*Request_, &resp); } else { resp = "<not ok>"; } return resp; }; GRPC_LOG_DEBUG(Logger_, "[%p] received request Name# %s ok# %s data# %s peer# %s", this, Name_, ok ? "true" : "false", makeRequestString().data(), this->Context.peer().c_str()); if (this->Context.c_call() == nullptr) { Y_VERIFY(!ok); // One ref by OnFinishTag, grpc will not call this tag if no request received UnRef(); } else if (!(RequestRegistered_ = Server_->RegisterRequestCtx(this))) { // Request cannot be registered due to shutdown // It's unsafe to continue, so drop this request without processing GRPC_LOG_DEBUG(Logger_, "[%p] dropping request Name# %s due to shutdown", this, Name_); this->Context.TryCancel(); return false; } Clone(); // TODO: Request pool? if (!ok) { Counters_->CountNotOkRequest(); return false; } if (IncRequest()) { // Adjust counters. RequestSize = Request_->ByteSize(); Counters_->StartProcessing(RequestSize); RequestTimer.Reset(); if (!SslServer()) { Counters_->CountRequestWithoutTls(); } //TODO: Move this in to grpc_request_proxy auto maybeDatabase = GetPeerMetaValues(TStringBuf("x-ydb-database")); if (maybeDatabase.empty()) { Counters_->CountRequestsWithoutDatabase(); } auto maybeToken = GetPeerMetaValues(TStringBuf("x-ydb-auth-ticket")); if (maybeToken.empty() || maybeToken[0].empty()) { TString db{maybeDatabase ? maybeDatabase[0] : TStringBuf{}}; Counters_->CountRequestsWithoutToken(); GRPC_LOG_DEBUG(Logger_, "[%p] received request without user token " "Name# %s data# %s peer# %s database# %s", this, Name_, makeRequestString().data(), this->Context.peer().c_str(), db.c_str()); } // Handle current request. Cb_(this); } else { //This request has not been counted SkipUpdateCountersOnError = true; FinishGrpcStatus(grpc::StatusCode::RESOURCE_EXHAUSTED, "no resource", true); } return true; } bool NextReply(bool ok) { auto logCb = [this, ok](int left) { GRPC_LOG_DEBUG(Logger_, "[%p] ready for next reply Name# %s ok# %s peer# %s left# %d", this, Name_, ok ? "true" : "false", this->Context.peer().c_str(), left); }; if (!ok) { logCb(-1); DecRequest(); Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, TDuration::Seconds(RequestTimer.Passed())); return false; } Ref(); // To prevent destroy during this call in case of execution Finish size_t left = StreamAdaptor_->ProcessNext(); logCb(left); if (NextReplyCb_) { NextReplyCb_(left); } // Now it is safe to destroy even if Finish was called UnRef(); return true; } bool SetFinishDone(bool ok) { GRPC_LOG_DEBUG(Logger_, "[%p] finished request Name# %s ok# %s peer# %s", this, Name_, ok ? "true" : "false", this->Context.peer().c_str()); //PrintBackTrace(); DecRequest(); Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, TDuration::Seconds(RequestTimer.Passed())); return false; } bool SetFinishError(bool ok) { GRPC_LOG_DEBUG(Logger_, "[%p] finished request with error Name# %s ok# %s peer# %s", this, Name_, ok ? "true" : "false", this->Context.peer().c_str()); if (!SkipUpdateCountersOnError) { DecRequest(); Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, TDuration::Seconds(RequestTimer.Passed())); } return false; } // Returns pointer to IQueueEvent to pass into grpc c runtime // Implicit C style cast from this to void* is wrong due to multiple inheritance void* GetGRpcTag() { return static_cast<IQueueEvent*>(this); } void OnFinish(EQueueEventStatus evStatus) { if (this->Context.IsCancelled()) { FinishPromise_.SetValue(EFinishStatus::CANCEL); } else { FinishPromise_.SetValue(evStatus == EQueueEventStatus::OK ? EFinishStatus::OK : EFinishStatus::ERROR); } } bool IncRequest() { if (!Server_->IncRequest()) return false; if (!RequestLimiter_) return true; if (!RequestLimiter_->IncRequest()) { Server_->DecRequest(); return false; } return true; } void DecRequest() { if (RequestLimiter_) { RequestLimiter_->DecRequest(); } Server_->DecRequest(); } using TStateFunc = bool (TThis::*)(bool); TService* Server_; TOnRequest Cb_; TRequestCallback RequestCallback_; TStreamRequestCallback StreamRequestCallback_; const char* const Name_; TLoggerPtr Logger_; ICounterBlockPtr Counters_; IGRpcRequestLimiterPtr RequestLimiter_; THolder<grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>> Writer_; THolder<grpc::ServerAsyncWriterInterface<TUniversalResponse<TOut>>> StreamWriter_; TStateFunc StateFunc_; TIn* Request_; google::protobuf::Arena Arena_; TOnNextReply NextReplyCb_; ui32 RequestSize = 0; ui32 ResponseSize = 0; ui32 ResponseStatus = 0; THPTimer RequestTimer; TAuthState AuthState_ = 0; bool RequestRegistered_ = false; using TFixedEvent = TQueueFixedEvent<TGRpcRequestImpl>; TFixedEvent OnFinishTag = { this, &TGRpcRequestImpl::OnFinish }; NThreading::TPromise<EFinishStatus> FinishPromise_; bool SkipUpdateCountersOnError = false; IStreamAdaptor::TPtr StreamAdaptor_; }; template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter=google::protobuf::TextFormat::Printer, typename TOutProtoPrinter=google::protobuf::TextFormat::Printer> class TGRpcRequest: public TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter> { using TBase = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>; public: TGRpcRequest(TService* server, typename TService::TCurrentGRpcService::AsyncService* service, grpc::ServerCompletionQueue* cq, typename TBase::TOnRequest cb, typename TBase::TRequestCallback requestCallback, const char* name, TLoggerPtr logger, ICounterBlockPtr counters, IGRpcRequestLimiterPtr limiter = nullptr) : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), std::move(limiter)} { } TGRpcRequest(TService* server, typename TService::TCurrentGRpcService::AsyncService* service, grpc::ServerCompletionQueue* cq, typename TBase::TOnRequest cb, typename TBase::TStreamRequestCallback requestCallback, const char* name, TLoggerPtr logger, ICounterBlockPtr counters) : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), nullptr} { } }; } // namespace NGrpc