aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/grpc/server/grpc_request.h
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/grpc/server/grpc_request.h
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/grpc/server/grpc_request.h')
-rw-r--r--library/cpp/grpc/server/grpc_request.h543
1 files changed, 543 insertions, 0 deletions
diff --git a/library/cpp/grpc/server/grpc_request.h b/library/cpp/grpc/server/grpc_request.h
new file mode 100644
index 0000000000..5bd8d3902b
--- /dev/null
+++ b/library/cpp/grpc/server/grpc_request.h
@@ -0,0 +1,543 @@
+#pragma once
+
+#include <google/protobuf/text_format.h>
+#include <google/protobuf/arena.h>
+#include <google/protobuf/message.h>
+
+#include <library/cpp/monlib/dynamic_counters/counters.h>
+#include <library/cpp/logger/priority.h>
+
+#include "grpc_response.h"
+#include "event_callback.h"
+#include "grpc_async_ctx_base.h"
+#include "grpc_counters.h"
+#include "grpc_request_base.h"
+#include "grpc_server.h"
+#include "logger.h"
+
+#include <util/system/hp_timer.h>
+
+#include <grpc++/server.h>
+#include <grpc++/server_context.h>
+#include <grpc++/support/async_stream.h>
+#include <grpc++/support/async_unary_call.h>
+#include <grpc++/support/byte_buffer.h>
+#include <grpc++/impl/codegen/async_stream.h>
+
+namespace NGrpc {
+
+class IStreamAdaptor {
+public:
+ using TPtr = std::unique_ptr<IStreamAdaptor>;
+ virtual void Enqueue(std::function<void()>&& fn, bool urgent) = 0;
+ virtual size_t ProcessNext() = 0;
+ virtual ~IStreamAdaptor() = default;
+};
+
+IStreamAdaptor::TPtr CreateStreamAdaptor();
+
+///////////////////////////////////////////////////////////////////////////////
+template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter, typename TOutProtoPrinter>
+class TGRpcRequestImpl
+ : public TBaseAsyncContext<TService>
+ , public IQueueEvent
+ , public IRequestContextBase
+{
+ using TThis = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>;
+
+public:
+ using TOnRequest = std::function<void (IRequestContextBase* ctx)>;
+ using TRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*,
+ grpc::ServerAsyncResponseWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*);
+ using TStreamRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*,
+ grpc::ServerAsyncWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*);
+
+ TGRpcRequestImpl(TService* server,
+ typename TService::TCurrentGRpcService::AsyncService* service,
+ grpc::ServerCompletionQueue* cq,
+ TOnRequest cb,
+ TRequestCallback requestCallback,
+ const char* name,
+ TLoggerPtr logger,
+ ICounterBlockPtr counters,
+ IGRpcRequestLimiterPtr limiter)
+ : TBaseAsyncContext<TService>(service, cq)
+ , Server_(server)
+ , Cb_(cb)
+ , RequestCallback_(requestCallback)
+ , StreamRequestCallback_(nullptr)
+ , Name_(name)
+ , Logger_(std::move(logger))
+ , Counters_(std::move(counters))
+ , RequestLimiter_(std::move(limiter))
+ , Writer_(new grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>(&this->Context))
+ , StateFunc_(&TThis::SetRequestDone)
+ {
+ AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false);
+ Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_);
+ Y_VERIFY(Request_);
+ GRPC_LOG_DEBUG(Logger_, "[%p] created request Name# %s", this, Name_);
+ FinishPromise_ = NThreading::NewPromise<EFinishStatus>();
+ }
+
+ TGRpcRequestImpl(TService* server,
+ typename TService::TCurrentGRpcService::AsyncService* service,
+ grpc::ServerCompletionQueue* cq,
+ TOnRequest cb,
+ TStreamRequestCallback requestCallback,
+ const char* name,
+ TLoggerPtr logger,
+ ICounterBlockPtr counters,
+ IGRpcRequestLimiterPtr limiter)
+ : TBaseAsyncContext<TService>(service, cq)
+ , Server_(server)
+ , Cb_(cb)
+ , RequestCallback_(nullptr)
+ , StreamRequestCallback_(requestCallback)
+ , Name_(name)
+ , Logger_(std::move(logger))
+ , Counters_(std::move(counters))
+ , RequestLimiter_(std::move(limiter))
+ , StreamWriter_(new grpc::ServerAsyncWriter<TUniversalResponse<TOut>>(&this->Context))
+ , StateFunc_(&TThis::SetRequestDone)
+ {
+ AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false);
+ Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_);
+ Y_VERIFY(Request_);
+ GRPC_LOG_DEBUG(Logger_, "[%p] created streaming request Name# %s", this, Name_);
+ FinishPromise_ = NThreading::NewPromise<EFinishStatus>();
+ StreamAdaptor_ = CreateStreamAdaptor();
+ }
+
+ TAsyncFinishResult GetFinishFuture() override {
+ return FinishPromise_.GetFuture();
+ }
+
+ TString GetPeer() const override {
+ return TString(this->Context.peer());
+ }
+
+ bool SslServer() const override {
+ return Server_->SslServer();
+ }
+
+ void Run() {
+ // Start request unless server is shutting down
+ if (auto guard = Server_->ProtectShutdown()) {
+ Ref(); //For grpc c runtime
+ this->Context.AsyncNotifyWhenDone(OnFinishTag.Prepare());
+ if (RequestCallback_) {
+ (this->Service->*RequestCallback_)
+ (&this->Context, Request_,
+ reinterpret_cast<grpc::ServerAsyncResponseWriter<TOut>*>(Writer_.Get()), this->CQ, this->CQ, GetGRpcTag());
+ } else {
+ (this->Service->*StreamRequestCallback_)
+ (&this->Context, Request_,
+ reinterpret_cast<grpc::ServerAsyncWriter<TOut>*>(StreamWriter_.Get()), this->CQ, this->CQ, GetGRpcTag());
+ }
+ }
+ }
+
+ ~TGRpcRequestImpl() {
+ // No direct dtor call allowed
+ Y_ASSERT(RefCount() == 0);
+ }
+
+ bool Execute(bool ok) override {
+ return (this->*StateFunc_)(ok);
+ }
+
+ void DestroyRequest() override {
+ if (RequestRegistered_) {
+ Server_->DeregisterRequestCtx(this);
+ RequestRegistered_ = false;
+ }
+ UnRef();
+ }
+
+ TInstant Deadline() const override {
+ return TBaseAsyncContext<TService>::Deadline();
+ }
+
+ TSet<TStringBuf> GetPeerMetaKeys() const override {
+ return TBaseAsyncContext<TService>::GetPeerMetaKeys();
+ }
+
+ TVector<TStringBuf> GetPeerMetaValues(TStringBuf key) const override {
+ return TBaseAsyncContext<TService>::GetPeerMetaValues(key);
+ }
+
+ grpc_compression_level GetCompressionLevel() const override {
+ return TBaseAsyncContext<TService>::GetCompressionLevel();
+ }
+
+ //! Get pointer to the request's message.
+ const NProtoBuf::Message* GetRequest() const override {
+ return Request_;
+ }
+
+ TAuthState& GetAuthState() override {
+ return AuthState_;
+ }
+
+ void Reply(NProtoBuf::Message* resp, ui32 status) override {
+ ResponseStatus = status;
+ WriteDataOk(resp);
+ }
+
+ void Reply(grpc::ByteBuffer* resp, ui32 status) override {
+ ResponseStatus = status;
+ WriteByteDataOk(resp);
+ }
+
+ void ReplyError(grpc::StatusCode code, const TString& msg) override {
+ FinishGrpcStatus(code, msg, false);
+ }
+
+ void ReplyUnauthenticated(const TString& in) override {
+ const TString message = in.empty() ? TString("unauthenticated") : TString("unauthenticated, ") + in;
+ FinishGrpcStatus(grpc::StatusCode::UNAUTHENTICATED, message, false);
+ }
+
+ void SetNextReplyCallback(TOnNextReply&& cb) override {
+ NextReplyCb_ = cb;
+ }
+
+ void AddTrailingMetadata(const TString& key, const TString& value) override {
+ this->Context.AddTrailingMetadata(key, value);
+ }
+
+ void FinishStreamingOk() override {
+ GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (enqueued)", this, Name_,
+ this->Context.peer().c_str());
+ auto cb = [this]() {
+ StateFunc_ = &TThis::SetFinishDone;
+ GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (pushed to grpc)", this, Name_,
+ this->Context.peer().c_str());
+
+ StreamWriter_->Finish(grpc::Status::OK, GetGRpcTag());
+ };
+ StreamAdaptor_->Enqueue(std::move(cb), false);
+ }
+
+ google::protobuf::Arena* GetArena() override {
+ return &Arena_;
+ }
+
+ void UseDatabase(const TString& database) override {
+ Counters_->UseDatabase(database);
+ }
+
+private:
+ void Clone() {
+ if (!Server_->IsShuttingDown()) {
+ if (RequestCallback_) {
+ MakeIntrusive<TThis>(
+ Server_, this->Service, this->CQ, Cb_, RequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run();
+ } else {
+ MakeIntrusive<TThis>(
+ Server_, this->Service, this->CQ, Cb_, StreamRequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run();
+ }
+ }
+ }
+
+ void WriteDataOk(NProtoBuf::Message* resp) {
+ auto makeResponseString = [&] {
+ TString x;
+ TOutProtoPrinter printer;
+ printer.SetSingleLineMode(true);
+ printer.PrintToString(*resp, &x);
+ return x;
+ };
+
+ auto sz = (size_t)resp->ByteSize();
+ if (Writer_) {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s", this, Name_,
+ makeResponseString().data(), this->Context.peer().c_str());
+ StateFunc_ = &TThis::SetFinishDone;
+ ResponseSize = sz;
+ Y_VERIFY(this->Context.c_call());
+ Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag());
+ } else {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (enqueued)",
+ this, Name_, makeResponseString().data(), this->Context.peer().c_str());
+
+ // because of std::function cannot hold move-only captured object
+ // we allocate shared object on heap to avoid message copy
+ auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp);
+ auto cb = [this, uResp = std::move(uResp), sz, &makeResponseString]() {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (pushed to grpc)",
+ this, Name_, makeResponseString().data(), this->Context.peer().c_str());
+ StateFunc_ = &TThis::NextReply;
+ ResponseSize += sz;
+ StreamWriter_->Write(*uResp, GetGRpcTag());
+ };
+ StreamAdaptor_->Enqueue(std::move(cb), false);
+ }
+ }
+
+ void WriteByteDataOk(grpc::ByteBuffer* resp) {
+ auto sz = resp->Length();
+ if (Writer_) {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s", this, Name_,
+ this->Context.peer().c_str());
+ StateFunc_ = &TThis::SetFinishDone;
+ ResponseSize = sz;
+ Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag());
+ } else {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (enqueued)", this, Name_,
+ this->Context.peer().c_str());
+
+ // because of std::function cannot hold move-only captured object
+ // we allocate shared object on heap to avoid buffer copy
+ auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp);
+ auto cb = [this, uResp = std::move(uResp), sz]() {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (pushed to grpc)",
+ this, Name_, this->Context.peer().c_str());
+ StateFunc_ = &TThis::NextReply;
+ ResponseSize += sz;
+ StreamWriter_->Write(*uResp, GetGRpcTag());
+ };
+ StreamAdaptor_->Enqueue(std::move(cb), false);
+ }
+ }
+
+ void FinishGrpcStatus(grpc::StatusCode code, const TString& msg, bool urgent) {
+ Y_VERIFY(code != grpc::OK);
+ if (code == grpc::StatusCode::UNAUTHENTICATED) {
+ Counters_->CountNotAuthenticated();
+ } else if (code == grpc::StatusCode::RESOURCE_EXHAUSTED) {
+ Counters_->CountResourceExhausted();
+ }
+
+ if (Writer_) {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)", this,
+ Name_, msg.c_str(), this->Context.peer().c_str(), (int)code);
+ StateFunc_ = &TThis::SetFinishError;
+ TOut resp;
+ Writer_->Finish(TUniversalResponseRef<TOut>(&resp), grpc::Status(code, msg), GetGRpcTag());
+ } else {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)"
+ " (enqueued)", this, Name_, msg.c_str(), this->Context.peer().c_str(), (int)code);
+ auto cb = [this, code, msg]() {
+ GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)"
+ " (pushed to grpc)", this, Name_, msg.c_str(),
+ this->Context.peer().c_str(), (int)code);
+ StateFunc_ = &TThis::SetFinishError;
+ StreamWriter_->Finish(grpc::Status(code, msg), GetGRpcTag());
+ };
+ StreamAdaptor_->Enqueue(std::move(cb), urgent);
+ }
+ }
+
+ bool SetRequestDone(bool ok) {
+ auto makeRequestString = [&] {
+ TString resp;
+ if (ok) {
+ TInProtoPrinter printer;
+ printer.SetSingleLineMode(true);
+ printer.PrintToString(*Request_, &resp);
+ } else {
+ resp = "<not ok>";
+ }
+ return resp;
+ };
+ GRPC_LOG_DEBUG(Logger_, "[%p] received request Name# %s ok# %s data# %s peer# %s", this, Name_,
+ ok ? "true" : "false", makeRequestString().data(), this->Context.peer().c_str());
+
+ if (this->Context.c_call() == nullptr) {
+ Y_VERIFY(!ok);
+ // One ref by OnFinishTag, grpc will not call this tag if no request received
+ UnRef();
+ } else if (!(RequestRegistered_ = Server_->RegisterRequestCtx(this))) {
+ // Request cannot be registered due to shutdown
+ // It's unsafe to continue, so drop this request without processing
+ GRPC_LOG_DEBUG(Logger_, "[%p] dropping request Name# %s due to shutdown", this, Name_);
+ this->Context.TryCancel();
+ return false;
+ }
+
+ Clone(); // TODO: Request pool?
+ if (!ok) {
+ Counters_->CountNotOkRequest();
+ return false;
+ }
+
+ if (IncRequest()) {
+ // Adjust counters.
+ RequestSize = Request_->ByteSize();
+ Counters_->StartProcessing(RequestSize);
+ RequestTimer.Reset();
+
+ if (!SslServer()) {
+ Counters_->CountRequestWithoutTls();
+ }
+
+ //TODO: Move this in to grpc_request_proxy
+ auto maybeDatabase = GetPeerMetaValues(TStringBuf("x-ydb-database"));
+ if (maybeDatabase.empty()) {
+ Counters_->CountRequestsWithoutDatabase();
+ }
+ auto maybeToken = GetPeerMetaValues(TStringBuf("x-ydb-auth-ticket"));
+ if (maybeToken.empty() || maybeToken[0].empty()) {
+ TString db{maybeDatabase ? maybeDatabase[0] : TStringBuf{}};
+ Counters_->CountRequestsWithoutToken();
+ GRPC_LOG_DEBUG(Logger_, "[%p] received request without user token "
+ "Name# %s data# %s peer# %s database# %s", this, Name_,
+ makeRequestString().data(), this->Context.peer().c_str(), db.c_str());
+ }
+
+ // Handle current request.
+ Cb_(this);
+ } else {
+ //This request has not been counted
+ SkipUpdateCountersOnError = true;
+ FinishGrpcStatus(grpc::StatusCode::RESOURCE_EXHAUSTED, "no resource", true);
+ }
+ return true;
+ }
+
+ bool NextReply(bool ok) {
+ auto logCb = [this, ok](int left) {
+ GRPC_LOG_DEBUG(Logger_, "[%p] ready for next reply Name# %s ok# %s peer# %s left# %d", this, Name_,
+ ok ? "true" : "false", this->Context.peer().c_str(), left);
+ };
+
+ if (!ok) {
+ logCb(-1);
+ DecRequest();
+ Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus,
+ TDuration::Seconds(RequestTimer.Passed()));
+ return false;
+ }
+
+ Ref(); // To prevent destroy during this call in case of execution Finish
+ size_t left = StreamAdaptor_->ProcessNext();
+ logCb(left);
+ if (NextReplyCb_) {
+ NextReplyCb_(left);
+ }
+ // Now it is safe to destroy even if Finish was called
+ UnRef();
+ return true;
+ }
+
+ bool SetFinishDone(bool ok) {
+ GRPC_LOG_DEBUG(Logger_, "[%p] finished request Name# %s ok# %s peer# %s", this, Name_,
+ ok ? "true" : "false", this->Context.peer().c_str());
+ //PrintBackTrace();
+ DecRequest();
+ Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus,
+ TDuration::Seconds(RequestTimer.Passed()));
+ return false;
+ }
+
+ bool SetFinishError(bool ok) {
+ GRPC_LOG_DEBUG(Logger_, "[%p] finished request with error Name# %s ok# %s peer# %s", this, Name_,
+ ok ? "true" : "false", this->Context.peer().c_str());
+ if (!SkipUpdateCountersOnError) {
+ DecRequest();
+ Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus,
+ TDuration::Seconds(RequestTimer.Passed()));
+ }
+ return false;
+ }
+
+ // Returns pointer to IQueueEvent to pass into grpc c runtime
+ // Implicit C style cast from this to void* is wrong due to multiple inheritance
+ void* GetGRpcTag() {
+ return static_cast<IQueueEvent*>(this);
+ }
+
+ void OnFinish(EQueueEventStatus evStatus) {
+ if (this->Context.IsCancelled()) {
+ FinishPromise_.SetValue(EFinishStatus::CANCEL);
+ } else {
+ FinishPromise_.SetValue(evStatus == EQueueEventStatus::OK ? EFinishStatus::OK : EFinishStatus::ERROR);
+ }
+ }
+
+ bool IncRequest() {
+ if (!Server_->IncRequest())
+ return false;
+
+ if (!RequestLimiter_)
+ return true;
+
+ if (!RequestLimiter_->IncRequest()) {
+ Server_->DecRequest();
+ return false;
+ }
+
+ return true;
+ }
+
+ void DecRequest() {
+ if (RequestLimiter_) {
+ RequestLimiter_->DecRequest();
+ }
+ Server_->DecRequest();
+ }
+
+ using TStateFunc = bool (TThis::*)(bool);
+ TService* Server_;
+ TOnRequest Cb_;
+ TRequestCallback RequestCallback_;
+ TStreamRequestCallback StreamRequestCallback_;
+ const char* const Name_;
+ TLoggerPtr Logger_;
+ ICounterBlockPtr Counters_;
+ IGRpcRequestLimiterPtr RequestLimiter_;
+
+ THolder<grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>> Writer_;
+ THolder<grpc::ServerAsyncWriterInterface<TUniversalResponse<TOut>>> StreamWriter_;
+ TStateFunc StateFunc_;
+ TIn* Request_;
+
+ google::protobuf::Arena Arena_;
+ TOnNextReply NextReplyCb_;
+ ui32 RequestSize = 0;
+ ui32 ResponseSize = 0;
+ ui32 ResponseStatus = 0;
+ THPTimer RequestTimer;
+ TAuthState AuthState_ = 0;
+ bool RequestRegistered_ = false;
+
+ using TFixedEvent = TQueueFixedEvent<TGRpcRequestImpl>;
+ TFixedEvent OnFinishTag = { this, &TGRpcRequestImpl::OnFinish };
+ NThreading::TPromise<EFinishStatus> FinishPromise_;
+ bool SkipUpdateCountersOnError = false;
+ IStreamAdaptor::TPtr StreamAdaptor_;
+};
+
+template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter=google::protobuf::TextFormat::Printer, typename TOutProtoPrinter=google::protobuf::TextFormat::Printer>
+class TGRpcRequest: public TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter> {
+ using TBase = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>;
+public:
+ TGRpcRequest(TService* server,
+ typename TService::TCurrentGRpcService::AsyncService* service,
+ grpc::ServerCompletionQueue* cq,
+ typename TBase::TOnRequest cb,
+ typename TBase::TRequestCallback requestCallback,
+ const char* name,
+ TLoggerPtr logger,
+ ICounterBlockPtr counters,
+ IGRpcRequestLimiterPtr limiter = nullptr)
+ : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), std::move(limiter)}
+ {
+ }
+
+ TGRpcRequest(TService* server,
+ typename TService::TCurrentGRpcService::AsyncService* service,
+ grpc::ServerCompletionQueue* cq,
+ typename TBase::TOnRequest cb,
+ typename TBase::TStreamRequestCallback requestCallback,
+ const char* name,
+ TLoggerPtr logger,
+ ICounterBlockPtr counters)
+ : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), nullptr}
+ {
+ }
+};
+
+} // namespace NGrpc