aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/grpc/client/grpc_client_low.h
diff options
context:
space:
mode:
authorDevtools Arcadia <arcadia-devtools@yandex-team.ru>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/grpc/client/grpc_client_low.h
downloadydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/grpc/client/grpc_client_low.h')
-rw-r--r--library/cpp/grpc/client/grpc_client_low.h1399
1 files changed, 1399 insertions, 0 deletions
diff --git a/library/cpp/grpc/client/grpc_client_low.h b/library/cpp/grpc/client/grpc_client_low.h
new file mode 100644
index 0000000000..ab0a0627be
--- /dev/null
+++ b/library/cpp/grpc/client/grpc_client_low.h
@@ -0,0 +1,1399 @@
+#pragma once
+
+#include "grpc_common.h"
+
+#include <util/thread/factory.h>
+#include <grpc++/grpc++.h>
+#include <grpc++/support/async_stream.h>
+#include <grpc++/support/async_unary_call.h>
+
+#include <deque>
+#include <typeindex>
+#include <typeinfo>
+#include <variant>
+#include <vector>
+#include <unordered_map>
+#include <unordered_set>
+#include <mutex>
+#include <shared_mutex>
+
+/*
+ * This file contains low level logic for grpc
+ * This file should not be used in high level code without special reason
+ */
+namespace NGrpc {
+
+const size_t DEFAULT_NUM_THREADS = 2;
+
+////////////////////////////////////////////////////////////////////////////////
+
+void EnableGRpcTracing();
+
+////////////////////////////////////////////////////////////////////////////////
+
+struct TTcpKeepAliveSettings {
+ bool Enabled;
+ size_t Idle;
+ size_t Count;
+ size_t Interval;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+
+// Common interface used to execute action from grpc cq routine
+class IQueueClientEvent {
+public:
+ virtual ~IQueueClientEvent() = default;
+
+ //! Execute an action defined by implementation
+ virtual bool Execute(bool ok) = 0;
+
+ //! Finish and destroy event
+ virtual void Destroy() = 0;
+};
+
+// Implementation of IQueueClientEvent that reduces allocations
+template<class TSelf>
+class TQueueClientFixedEvent : private IQueueClientEvent {
+ using TCallback = void (TSelf::*)(bool);
+
+public:
+ TQueueClientFixedEvent(TSelf* self, TCallback callback)
+ : Self(self)
+ , Callback(callback)
+ { }
+
+ IQueueClientEvent* Prepare() {
+ Self->Ref();
+ return this;
+ }
+
+private:
+ bool Execute(bool ok) override {
+ ((*Self).*Callback)(ok);
+ return false;
+ }
+
+ void Destroy() override {
+ Self->UnRef();
+ }
+
+private:
+ TSelf* const Self;
+ TCallback const Callback;
+};
+
+class IQueueClientContext;
+using IQueueClientContextPtr = std::shared_ptr<IQueueClientContext>;
+
+// Provider of IQueueClientContext instances
+class IQueueClientContextProvider {
+public:
+ virtual ~IQueueClientContextProvider() = default;
+
+ virtual IQueueClientContextPtr CreateContext() = 0;
+};
+
+// Activity context for a low-level client
+class IQueueClientContext : public IQueueClientContextProvider {
+public:
+ virtual ~IQueueClientContext() = default;
+
+ //! Returns CompletionQueue associated with the client
+ virtual grpc::CompletionQueue* CompletionQueue() = 0;
+
+ //! Returns true if context has been cancelled
+ virtual bool IsCancelled() const = 0;
+
+ //! Tries to cancel context, calling all registered callbacks
+ virtual bool Cancel() = 0;
+
+ //! Subscribes callback to cancellation
+ //
+ // Note there's no way to unsubscribe, if subscription is temporary
+ // make sure you create a new context with CreateContext and release
+ // it as soon as it's no longer needed.
+ virtual void SubscribeCancel(std::function<void()> callback) = 0;
+
+ //! Subscribes callback to cancellation
+ //
+ // This alias is for compatibility with older code.
+ void SubscribeStop(std::function<void()> callback) {
+ SubscribeCancel(std::move(callback));
+ }
+};
+
+// Represents grpc status and error message string
+struct TGrpcStatus {
+ TString Msg;
+ TString Details;
+ int GRpcStatusCode;
+ bool InternalError;
+
+ TGrpcStatus()
+ : GRpcStatusCode(grpc::StatusCode::OK)
+ , InternalError(false)
+ { }
+
+ TGrpcStatus(TString msg, int statusCode, bool internalError)
+ : Msg(std::move(msg))
+ , GRpcStatusCode(statusCode)
+ , InternalError(internalError)
+ { }
+
+ TGrpcStatus(grpc::StatusCode status, TString msg, TString details = {})
+ : Msg(std::move(msg))
+ , Details(std::move(details))
+ , GRpcStatusCode(status)
+ , InternalError(false)
+ { }
+
+ TGrpcStatus(const grpc::Status& status)
+ : TGrpcStatus(status.error_code(), TString(status.error_message()), TString(status.error_details()))
+ { }
+
+ TGrpcStatus& operator=(const grpc::Status& status) {
+ Msg = TString(status.error_message());
+ Details = TString(status.error_details());
+ GRpcStatusCode = status.error_code();
+ InternalError = false;
+ return *this;
+ }
+
+ static TGrpcStatus Internal(TString msg) {
+ return { std::move(msg), -1, true };
+ }
+
+ bool Ok() const {
+ return !InternalError && GRpcStatusCode == grpc::StatusCode::OK;
+ }
+};
+
+bool inline IsGRpcStatusGood(const TGrpcStatus& status) {
+ return status.Ok();
+}
+
+// Response callback type - this callback will be called when request is finished
+// (or after getting each chunk in case of streaming mode)
+template<typename TResponse>
+using TResponseCallback = std::function<void (TGrpcStatus&&, TResponse&&)>;
+
+template<typename TResponse>
+using TAdvancedResponseCallback = std::function<void (const grpc::ClientContext&, TGrpcStatus&&, TResponse&&)>;
+
+// Call associated metadata
+struct TCallMeta {
+ std::shared_ptr<grpc::CallCredentials> CallCredentials;
+ std::vector<std::pair<TString, TString>> Aux;
+ std::variant<TDuration, TInstant> Timeout; // timeout as duration from now or time point in future
+};
+
+class TGRpcRequestProcessorCommon {
+protected:
+ void ApplyMeta(const TCallMeta& meta) {
+ for (const auto& rec : meta.Aux) {
+ Context.AddMetadata(rec.first, rec.second);
+ }
+ if (meta.CallCredentials) {
+ Context.set_credentials(meta.CallCredentials);
+ }
+ if (const TDuration* timeout = std::get_if<TDuration>(&meta.Timeout)) {
+ if (*timeout) {
+ auto deadline = gpr_time_add(
+ gpr_now(GPR_CLOCK_MONOTONIC),
+ gpr_time_from_micros(timeout->MicroSeconds(), GPR_TIMESPAN));
+ Context.set_deadline(deadline);
+ }
+ } else if (const TInstant* deadline = std::get_if<TInstant>(&meta.Timeout)) {
+ if (*deadline) {
+ Context.set_deadline(gpr_time_from_micros(deadline->MicroSeconds(), GPR_CLOCK_MONOTONIC));
+ }
+ }
+ }
+
+ void GetInitialMetadata(std::unordered_multimap<TString, TString>* metadata) {
+ for (const auto& [key, value] : Context.GetServerInitialMetadata()) {
+ metadata->emplace(
+ TString(key.begin(), key.end()),
+ TString(value.begin(), value.end())
+ );
+ }
+ }
+
+ grpc::Status Status;
+ grpc::ClientContext Context;
+ std::shared_ptr<IQueueClientContext> LocalContext;
+};
+
+template<typename TStub, typename TRequest, typename TResponse>
+class TSimpleRequestProcessor
+ : public TThrRefBase
+ , public IQueueClientEvent
+ , public TGRpcRequestProcessorCommon {
+ using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>;
+ template<typename> friend class TServiceConnection;
+public:
+ using TPtr = TIntrusivePtr<TSimpleRequestProcessor>;
+ using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*);
+
+ explicit TSimpleRequestProcessor(TResponseCallback<TResponse>&& callback)
+ : Callback_(std::move(callback))
+ { }
+
+ ~TSimpleRequestProcessor() {
+ if (!Replied_ && Callback_) {
+ Callback_(TGrpcStatus::Internal("request left unhandled"), std::move(Reply_));
+ Callback_ = nullptr; // free resources as early as possible
+ }
+ }
+
+ bool Execute(bool ok) override {
+ {
+ std::unique_lock<std::mutex> guard(Mutex_);
+ LocalContext.reset();
+ }
+ TGrpcStatus status;
+ if (ok) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ Replied_ = true;
+ Callback_(std::move(status), std::move(Reply_));
+ Callback_ = nullptr; // free resources as early as possible
+ return false;
+ }
+
+ void Destroy() override {
+ UnRef();
+ }
+
+private:
+ IQueueClientEvent* FinishedEvent() {
+ Ref();
+ return this;
+ }
+
+ void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) {
+ auto context = provider->CreateContext();
+ if (!context) {
+ Replied_ = true;
+ Callback_(TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_));
+ Callback_ = nullptr;
+ return;
+ }
+ {
+ std::unique_lock<std::mutex> guard(Mutex_);
+ LocalContext = context;
+ Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue());
+ Reader_->Finish(&Reply_, &Status, FinishedEvent());
+ }
+ context->SubscribeStop([self = TPtr(this)] {
+ self->Stop();
+ });
+ }
+
+ void Stop() {
+ Context.TryCancel();
+ }
+
+ TResponseCallback<TResponse> Callback_;
+ TResponse Reply_;
+ std::mutex Mutex_;
+ TAsyncReaderPtr Reader_;
+
+ bool Replied_ = false;
+};
+
+template<typename TStub, typename TRequest, typename TResponse>
+class TAdvancedRequestProcessor
+ : public TThrRefBase
+ , public IQueueClientEvent
+ , public TGRpcRequestProcessorCommon {
+ using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>;
+ template<typename> friend class TServiceConnection;
+public:
+ using TPtr = TIntrusivePtr<TAdvancedRequestProcessor>;
+ using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*);
+
+ explicit TAdvancedRequestProcessor(TAdvancedResponseCallback<TResponse>&& callback)
+ : Callback_(std::move(callback))
+ { }
+
+ ~TAdvancedRequestProcessor() {
+ if (!Replied_ && Callback_) {
+ Callback_(Context, TGrpcStatus::Internal("request left unhandled"), std::move(Reply_));
+ Callback_ = nullptr; // free resources as early as possible
+ }
+ }
+
+ bool Execute(bool ok) override {
+ {
+ std::unique_lock<std::mutex> guard(Mutex_);
+ LocalContext.reset();
+ }
+ TGrpcStatus status;
+ if (ok) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ Replied_ = true;
+ Callback_(Context, std::move(status), std::move(Reply_));
+ Callback_ = nullptr; // free resources as early as possible
+ return false;
+ }
+
+ void Destroy() override {
+ UnRef();
+ }
+
+private:
+ IQueueClientEvent* FinishedEvent() {
+ Ref();
+ return this;
+ }
+
+ void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) {
+ auto context = provider->CreateContext();
+ if (!context) {
+ Replied_ = true;
+ Callback_(Context, TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_));
+ Callback_ = nullptr;
+ return;
+ }
+ {
+ std::unique_lock<std::mutex> guard(Mutex_);
+ LocalContext = context;
+ Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue());
+ Reader_->Finish(&Reply_, &Status, FinishedEvent());
+ }
+ context->SubscribeStop([self = TPtr(this)] {
+ self->Stop();
+ });
+ }
+
+ void Stop() {
+ Context.TryCancel();
+ }
+
+ TAdvancedResponseCallback<TResponse> Callback_;
+ TResponse Reply_;
+ std::mutex Mutex_;
+ TAsyncReaderPtr Reader_;
+
+ bool Replied_ = false;
+};
+
+template<class TResponse>
+class IStreamRequestReadProcessor : public TThrRefBase {
+public:
+ using TPtr = TIntrusivePtr<IStreamRequestReadProcessor>;
+ using TReadCallback = std::function<void(TGrpcStatus&&)>;
+
+ /**
+ * Asynchronously cancel the request
+ */
+ virtual void Cancel() = 0;
+
+ /**
+ * Scheduled initial server metadata read from the stream
+ */
+ virtual void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) = 0;
+
+ /**
+ * Scheduled response read from the stream
+ * Callback will be called with the status if it failed
+ * Only one Read or Finish call may be active at a time
+ */
+ virtual void Read(TResponse* response, TReadCallback callback) = 0;
+
+ /**
+ * Stop reading and gracefully finish the stream
+ * Only one Read or Finish call may be active at a time
+ */
+ virtual void Finish(TReadCallback callback) = 0;
+
+ /**
+ * Additional callback to be called when stream has finished
+ */
+ virtual void AddFinishedCallback(TReadCallback callback) = 0;
+};
+
+template<class TRequest, class TResponse>
+class IStreamRequestReadWriteProcessor : public IStreamRequestReadProcessor<TResponse> {
+public:
+ using TPtr = TIntrusivePtr<IStreamRequestReadWriteProcessor>;
+ using TWriteCallback = std::function<void(TGrpcStatus&&)>;
+
+ /**
+ * Scheduled request write to the stream
+ */
+ virtual void Write(TRequest&& request, TWriteCallback callback = { }) = 0;
+};
+
+class TGRpcKeepAliveSocketMutator;
+
+// Class to hold stubs allocated on channel.
+// It is poor documented part of grpc. See KIKIMR-6109 and comment to this commit
+
+// Stub holds shared_ptr<ChannelInterface>, so we can destroy this holder even if
+// request processor using stub
+class TStubsHolder : public TNonCopyable {
+ using TypeInfoRef = std::reference_wrapper<const std::type_info>;
+
+ struct THasher {
+ std::size_t operator()(TypeInfoRef code) const {
+ return code.get().hash_code();
+ }
+ };
+
+ struct TEqualTo {
+ bool operator()(TypeInfoRef lhs, TypeInfoRef rhs) const {
+ return lhs.get() == rhs.get();
+ }
+ };
+public:
+ TStubsHolder(std::shared_ptr<grpc::ChannelInterface> channel)
+ : ChannelInterface_(channel)
+ {}
+
+ // Returns true if channel can't be used to perform request now
+ bool IsChannelBroken() const {
+ auto state = ChannelInterface_->GetState(false);
+ return state == GRPC_CHANNEL_SHUTDOWN ||
+ state == GRPC_CHANNEL_TRANSIENT_FAILURE;
+ }
+
+ template<typename TStub>
+ std::shared_ptr<TStub> GetOrCreateStub() {
+ const auto& stubId = typeid(TStub);
+ {
+ std::shared_lock readGuard(RWMutex_);
+ const auto it = Stubs_.find(stubId);
+ if (it != Stubs_.end()) {
+ return std::static_pointer_cast<TStub>(it->second);
+ }
+ }
+ {
+ std::unique_lock writeGuard(RWMutex_);
+ auto it = Stubs_.emplace(stubId, nullptr);
+ if (!it.second) {
+ return std::static_pointer_cast<TStub>(it.first->second);
+ } else {
+ it.first->second = std::make_shared<TStub>(ChannelInterface_);
+ return std::static_pointer_cast<TStub>(it.first->second);
+ }
+ }
+ }
+
+ const TInstant& GetLastUseTime() const {
+ return LastUsed_;
+ }
+
+ void SetLastUseTime(const TInstant& time) {
+ LastUsed_ = time;
+ }
+private:
+ TInstant LastUsed_ = Now();
+ std::shared_mutex RWMutex_;
+ std::unordered_map<TypeInfoRef, std::shared_ptr<void>, THasher, TEqualTo> Stubs_;
+ std::shared_ptr<grpc::ChannelInterface> ChannelInterface_;
+};
+
+class TChannelPool {
+public:
+ TChannelPool(const TTcpKeepAliveSettings& tcpKeepAliveSettings, const TDuration& expireTime = TDuration::Minutes(6));
+ //Allows to CreateStub from TStubsHolder under lock
+ //The callback will be called just during GetStubsHolderLocked call
+ void GetStubsHolderLocked(const TString& channelId, const TGRpcClientConfig& config, std::function<void(TStubsHolder&)> cb);
+ void DeleteChannel(const TString& channelId);
+ void DeleteExpiredStubsHolders();
+private:
+ std::shared_mutex RWMutex_;
+ std::unordered_map<TString, TStubsHolder> Pool_;
+ std::multimap<TInstant, TString> LastUsedQueue_;
+ TTcpKeepAliveSettings TcpKeepAliveSettings_;
+ TDuration ExpireTime_;
+ TDuration UpdateReUseTime_;
+ void EraseFromQueueByTime(const TInstant& lastUseTime, const TString& channelId);
+};
+
+template<class TResponse>
+using TStreamReaderCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadProcessor<TResponse>::TPtr)>;
+
+template<typename TStub, typename TRequest, typename TResponse>
+class TStreamRequestReadProcessor
+ : public IStreamRequestReadProcessor<TResponse>
+ , public TGRpcRequestProcessorCommon {
+ template<typename> friend class TServiceConnection;
+public:
+ using TSelf = TStreamRequestReadProcessor;
+ using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncReader<TResponse>>;
+ using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*, void*);
+ using TReaderCallback = TStreamReaderCallback<TResponse>;
+ using TPtr = TIntrusivePtr<TSelf>;
+ using TBase = IStreamRequestReadProcessor<TResponse>;
+ using TReadCallback = typename TBase::TReadCallback;
+
+ explicit TStreamRequestReadProcessor(TReaderCallback&& callback)
+ : Callback(std::move(callback))
+ {
+ Y_VERIFY(Callback, "Missing connected callback");
+ }
+
+ void Cancel() override {
+ Context.TryCancel();
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Cancelled = true;
+ if (Started && !ReadFinished) {
+ if (!ReadActive) {
+ ReadFinished = true;
+ }
+ if (ReadFinished) {
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ }
+ }
+ }
+ }
+
+ void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
+ if (!Finished && !HasInitialMetadata) {
+ ReadActive = true;
+ ReadCallback = std::move(callback);
+ InitialMetadata = metadata;
+ if (!ReadFinished) {
+ Stream->ReadInitialMetadata(OnReadDoneTag.Prepare());
+ }
+ return;
+ }
+ if (!HasInitialMetadata) {
+ if (FinishedOk) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ } else {
+ GetInitialMetadata(metadata);
+ }
+ }
+
+ callback(std::move(status));
+ }
+
+ void Read(TResponse* message, TReadCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
+ if (!Finished) {
+ ReadActive = true;
+ ReadCallback = std::move(callback);
+ if (!ReadFinished) {
+ Stream->Read(message, OnReadDoneTag.Prepare());
+ }
+ return;
+ }
+ if (FinishedOk) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ }
+
+ if (status.Ok()) {
+ status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
+ }
+
+ callback(std::move(status));
+ }
+
+ void Finish(TReadCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
+ if (!Finished) {
+ ReadActive = true;
+ FinishCallback = std::move(callback);
+ if (!ReadFinished) {
+ ReadFinished = true;
+ }
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ return;
+ }
+ if (FinishedOk) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ }
+
+ callback(std::move(status));
+ }
+
+ void AddFinishedCallback(TReadCallback callback) override {
+ Y_VERIFY(callback, "Unexpected empty callback");
+
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ if (!Finished) {
+ FinishedCallbacks.emplace_back().swap(callback);
+ return;
+ }
+
+ if (FinishedOk) {
+ status = Status;
+ } else if (Cancelled) {
+ status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ }
+
+ callback(std::move(status));
+ }
+
+private:
+ void Start(TStub& stub, const TRequest& request, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) {
+ auto context = provider->CreateContext();
+ if (!context) {
+ auto callback = std::move(Callback);
+ TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down");
+ callback(std::move(status), nullptr);
+ return;
+ }
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ LocalContext = context;
+ Stream = (stub.*asyncRequest)(&Context, request, context->CompletionQueue(), OnStartDoneTag.Prepare());
+ }
+
+ context->SubscribeStop([self = TPtr(this)] {
+ self->Cancel();
+ });
+ }
+
+ void OnReadDone(bool ok) {
+ TGrpcStatus status;
+ TReadCallback callback;
+ std::unordered_multimap<TString, TString>* initialMetadata = nullptr;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(ReadActive, "Unexpected Read done callback");
+ Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag");
+
+ if (!ok || Cancelled) {
+ ReadFinished = true;
+
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ if (!ok) {
+ // Keep ReadActive=true, so callback is called
+ // after the call is finished with an error
+ return;
+ }
+ }
+
+ callback = std::move(ReadCallback);
+ ReadCallback = nullptr;
+ ReadActive = false;
+ initialMetadata = InitialMetadata;
+ InitialMetadata = nullptr;
+ HasInitialMetadata = true;
+ }
+
+ if (initialMetadata) {
+ GetInitialMetadata(initialMetadata);
+ }
+
+ callback(std::move(status));
+ }
+
+ void OnStartDone(bool ok) {
+ TReaderCallback callback;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Started = true;
+ if (!ok || Cancelled) {
+ ReadFinished = true;
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ return;
+ }
+ callback = std::move(Callback);
+ Callback = nullptr;
+ }
+
+ callback({ }, typename TBase::TPtr(this));
+ }
+
+ void OnFinished(bool ok) {
+ TGrpcStatus status;
+ std::vector<TReadCallback> finishedCallbacks;
+ TReaderCallback startCallback;
+ TReadCallback readCallback;
+ TReadCallback finishCallback;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+
+ Finished = true;
+ FinishedOk = ok;
+ LocalContext.reset();
+
+ if (ok) {
+ status = Status;
+ } else if (Cancelled) {
+ status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+
+ finishedCallbacks.swap(FinishedCallbacks);
+
+ if (Callback) {
+ Y_VERIFY(!ReadActive);
+ startCallback = std::move(Callback);
+ Callback = nullptr;
+ } else if (ReadActive) {
+ if (ReadCallback) {
+ readCallback = std::move(ReadCallback);
+ ReadCallback = nullptr;
+ } else {
+ finishCallback = std::move(FinishCallback);
+ FinishCallback = nullptr;
+ }
+ ReadActive = false;
+ }
+ }
+
+ for (auto& finishedCallback : finishedCallbacks) {
+ auto statusCopy = status;
+ finishedCallback(std::move(statusCopy));
+ }
+
+ if (startCallback) {
+ if (status.Ok()) {
+ status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure");
+ }
+ startCallback(std::move(status), nullptr);
+ } else if (readCallback) {
+ if (status.Ok()) {
+ status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
+ }
+ readCallback(std::move(status));
+ } else if (finishCallback) {
+ finishCallback(std::move(status));
+ }
+ }
+
+ TReaderCallback Callback;
+ TAsyncReaderPtr Stream;
+ using TFixedEvent = TQueueClientFixedEvent<TSelf>;
+ std::mutex Mutex;
+ TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone };
+ TFixedEvent OnStartDoneTag = { this, &TSelf::OnStartDone };
+ TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished };
+
+ TReadCallback ReadCallback;
+ TReadCallback FinishCallback;
+ std::vector<TReadCallback> FinishedCallbacks;
+ std::unordered_multimap<TString, TString>* InitialMetadata = nullptr;
+ bool Started = false;
+ bool HasInitialMetadata = false;
+ bool ReadActive = false;
+ bool ReadFinished = false;
+ bool Finished = false;
+ bool Cancelled = false;
+ bool FinishedOk = false;
+};
+
+template<class TRequest, class TResponse>
+using TStreamConnectedCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadWriteProcessor<TRequest, TResponse>::TPtr)>;
+
+template<class TStub, class TRequest, class TResponse>
+class TStreamRequestReadWriteProcessor
+ : public IStreamRequestReadWriteProcessor<TRequest, TResponse>
+ , public TGRpcRequestProcessorCommon {
+public:
+ using TSelf = TStreamRequestReadWriteProcessor;
+ using TBase = IStreamRequestReadWriteProcessor<TRequest, TResponse>;
+ using TPtr = TIntrusivePtr<TSelf>;
+ using TConnectedCallback = TStreamConnectedCallback<TRequest, TResponse>;
+ using TReadCallback = typename TBase::TReadCallback;
+ using TWriteCallback = typename TBase::TWriteCallback;
+ using TAsyncReaderWriterPtr = std::unique_ptr<grpc::ClientAsyncReaderWriter<TRequest, TResponse>>;
+ using TAsyncRequest = TAsyncReaderWriterPtr (TStub::*)(grpc::ClientContext*, grpc::CompletionQueue*, void*);
+
+ explicit TStreamRequestReadWriteProcessor(TConnectedCallback&& callback)
+ : ConnectedCallback(std::move(callback))
+ {
+ Y_VERIFY(ConnectedCallback, "Missing connected callback");
+ }
+
+ void Cancel() override {
+ Context.TryCancel();
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Cancelled = true;
+ if (Started && !(ReadFinished && WriteFinished)) {
+ if (!ReadActive) {
+ ReadFinished = true;
+ }
+ if (!WriteActive) {
+ WriteFinished = true;
+ }
+ if (ReadFinished && WriteFinished) {
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ }
+ }
+ }
+ }
+
+ void Write(TRequest&& request, TWriteCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ if (Cancelled || ReadFinished || WriteFinished) {
+ status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped");
+ } else if (WriteActive) {
+ auto& item = WriteQueue.emplace_back();
+ item.Callback.swap(callback);
+ item.Request.Swap(&request);
+ } else {
+ WriteActive = true;
+ WriteCallback.swap(callback);
+ Stream->Write(request, OnWriteDoneTag.Prepare());
+ }
+ }
+
+ if (!status.Ok() && callback) {
+ callback(std::move(status));
+ }
+ }
+
+ void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
+ if (!Finished && !HasInitialMetadata) {
+ ReadActive = true;
+ ReadCallback = std::move(callback);
+ InitialMetadata = metadata;
+ if (!ReadFinished) {
+ Stream->ReadInitialMetadata(OnReadDoneTag.Prepare());
+ }
+ return;
+ }
+ if (!HasInitialMetadata) {
+ if (FinishedOk) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ } else {
+ GetInitialMetadata(metadata);
+ }
+ }
+
+ callback(std::move(status));
+ }
+
+ void Read(TResponse* message, TReadCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
+ if (!Finished) {
+ ReadActive = true;
+ ReadCallback = std::move(callback);
+ if (!ReadFinished) {
+ Stream->Read(message, OnReadDoneTag.Prepare());
+ }
+ return;
+ }
+ if (FinishedOk) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ }
+
+ if (status.Ok()) {
+ status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
+ }
+
+ callback(std::move(status));
+ }
+
+ void Finish(TReadCallback callback) override {
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected");
+ if (!Finished) {
+ ReadActive = true;
+ FinishCallback = std::move(callback);
+ if (!ReadFinished) {
+ ReadFinished = true;
+ if (!WriteActive) {
+ WriteFinished = true;
+ }
+ if (WriteFinished) {
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ }
+ }
+ return;
+ }
+ if (FinishedOk) {
+ status = Status;
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ }
+
+ callback(std::move(status));
+ }
+
+ void AddFinishedCallback(TReadCallback callback) override {
+ Y_VERIFY(callback, "Unexpected empty callback");
+
+ TGrpcStatus status;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ if (!Finished) {
+ FinishedCallbacks.emplace_back().swap(callback);
+ return;
+ }
+
+ if (FinishedOk) {
+ status = Status;
+ } else if (Cancelled) {
+ status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+ }
+
+ callback(std::move(status));
+ }
+
+private:
+ template<typename> friend class TServiceConnection;
+
+ void Start(TStub& stub, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) {
+ auto context = provider->CreateContext();
+ if (!context) {
+ auto callback = std::move(ConnectedCallback);
+ TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down");
+ callback(std::move(status), nullptr);
+ return;
+ }
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ LocalContext = context;
+ Stream = (stub.*asyncRequest)(&Context, context->CompletionQueue(), OnConnectedTag.Prepare());
+ }
+
+ context->SubscribeStop([self = TPtr(this)] {
+ self->Cancel();
+ });
+ }
+
+private:
+ void OnConnected(bool ok) {
+ TConnectedCallback callback;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Started = true;
+ if (!ok || Cancelled) {
+ ReadFinished = true;
+ WriteFinished = true;
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ return;
+ }
+
+ callback = std::move(ConnectedCallback);
+ ConnectedCallback = nullptr;
+ }
+
+ callback({ }, typename TBase::TPtr(this));
+ }
+
+ void OnReadDone(bool ok) {
+ TGrpcStatus status;
+ TReadCallback callback;
+ std::unordered_multimap<TString, TString>* initialMetadata = nullptr;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(ReadActive, "Unexpected Read done callback");
+ Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag");
+
+ if (!ok || Cancelled || WriteFinished) {
+ ReadFinished = true;
+ if (!WriteActive) {
+ WriteFinished = true;
+ }
+ if (WriteFinished) {
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ }
+ if (!ok) {
+ // Keep ReadActive=true, so callback is called
+ // after the call is finished with an error
+ return;
+ }
+ }
+
+ callback = std::move(ReadCallback);
+ ReadCallback = nullptr;
+ ReadActive = false;
+ initialMetadata = InitialMetadata;
+ InitialMetadata = nullptr;
+ HasInitialMetadata = true;
+ }
+
+ if (initialMetadata) {
+ GetInitialMetadata(initialMetadata);
+ }
+
+ callback(std::move(status));
+ }
+
+ void OnWriteDone(bool ok) {
+ TWriteCallback okCallback;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Y_VERIFY(WriteActive, "Unexpected Write done callback");
+ Y_VERIFY(!WriteFinished, "Unexpected WriteFinished flag");
+
+ if (ok) {
+ okCallback.swap(WriteCallback);
+ } else if (WriteCallback) {
+ // Put callback back on the queue until OnFinished
+ auto& item = WriteQueue.emplace_front();
+ item.Callback.swap(WriteCallback);
+ }
+
+ if (!ok || Cancelled) {
+ WriteActive = false;
+ WriteFinished = true;
+ if (!ReadActive) {
+ ReadFinished = true;
+ }
+ if (ReadFinished) {
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ }
+ } else if (!WriteQueue.empty()) {
+ WriteCallback.swap(WriteQueue.front().Callback);
+ Stream->Write(WriteQueue.front().Request, OnWriteDoneTag.Prepare());
+ WriteQueue.pop_front();
+ } else {
+ WriteActive = false;
+ if (ReadFinished) {
+ WriteFinished = true;
+ Stream->Finish(&Status, OnFinishedTag.Prepare());
+ }
+ }
+ }
+
+ if (okCallback) {
+ okCallback(TGrpcStatus());
+ }
+ }
+
+ void OnFinished(bool ok) {
+ TGrpcStatus status;
+ std::deque<TWriteItem> writesDropped;
+ std::vector<TReadCallback> finishedCallbacks;
+ TConnectedCallback connectedCallback;
+ TReadCallback readCallback;
+ TReadCallback finishCallback;
+
+ {
+ std::unique_lock<std::mutex> guard(Mutex);
+ Finished = true;
+ FinishedOk = ok;
+ LocalContext.reset();
+
+ if (ok) {
+ status = Status;
+ } else if (Cancelled) {
+ status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled");
+ } else {
+ status = TGrpcStatus::Internal("Unexpected error");
+ }
+
+ writesDropped.swap(WriteQueue);
+ finishedCallbacks.swap(FinishedCallbacks);
+
+ if (ConnectedCallback) {
+ Y_VERIFY(!ReadActive);
+ connectedCallback = std::move(ConnectedCallback);
+ ConnectedCallback = nullptr;
+ } else if (ReadActive) {
+ if (ReadCallback) {
+ readCallback = std::move(ReadCallback);
+ ReadCallback = nullptr;
+ } else {
+ finishCallback = std::move(FinishCallback);
+ FinishCallback = nullptr;
+ }
+ ReadActive = false;
+ }
+ }
+
+ for (auto& item : writesDropped) {
+ if (item.Callback) {
+ TGrpcStatus writeStatus = status;
+ if (writeStatus.Ok()) {
+ writeStatus = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped");
+ }
+ item.Callback(std::move(writeStatus));
+ }
+ }
+
+ for (auto& finishedCallback : finishedCallbacks) {
+ TGrpcStatus statusCopy = status;
+ finishedCallback(std::move(statusCopy));
+ }
+
+ if (connectedCallback) {
+ if (status.Ok()) {
+ status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure");
+ }
+ connectedCallback(std::move(status), nullptr);
+ } else if (readCallback) {
+ if (status.Ok()) {
+ status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF");
+ }
+ readCallback(std::move(status));
+ } else if (finishCallback) {
+ finishCallback(std::move(status));
+ }
+ }
+
+private:
+ struct TWriteItem {
+ TWriteCallback Callback;
+ TRequest Request;
+ };
+
+private:
+ using TFixedEvent = TQueueClientFixedEvent<TSelf>;
+
+ TFixedEvent OnConnectedTag = { this, &TSelf::OnConnected };
+ TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone };
+ TFixedEvent OnWriteDoneTag = { this, &TSelf::OnWriteDone };
+ TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished };
+
+private:
+ std::mutex Mutex;
+ TAsyncReaderWriterPtr Stream;
+ TConnectedCallback ConnectedCallback;
+ TReadCallback ReadCallback;
+ TReadCallback FinishCallback;
+ std::vector<TReadCallback> FinishedCallbacks;
+ std::deque<TWriteItem> WriteQueue;
+ TWriteCallback WriteCallback;
+ std::unordered_multimap<TString, TString>* InitialMetadata = nullptr;
+ bool Started = false;
+ bool HasInitialMetadata = false;
+ bool ReadActive = false;
+ bool ReadFinished = false;
+ bool WriteActive = false;
+ bool WriteFinished = false;
+ bool Finished = false;
+ bool Cancelled = false;
+ bool FinishedOk = false;
+};
+
+class TGRpcClientLow;
+
+template<typename TGRpcService>
+class TServiceConnection {
+ using TStub = typename TGRpcService::Stub;
+ friend class TGRpcClientLow;
+
+public:
+ /*
+ * Start simple request
+ */
+ template<typename TRequest, typename TResponse>
+ void DoRequest(const TRequest& request,
+ TResponseCallback<TResponse> callback,
+ typename TSimpleRequestProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest,
+ const TCallMeta& metas = { },
+ IQueueClientContextProvider* provider = nullptr)
+ {
+ auto processor = MakeIntrusive<TSimpleRequestProcessor<TStub, TRequest, TResponse>>(std::move(callback));
+ processor->ApplyMeta(metas);
+ processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_);
+ }
+
+ /*
+ * Start simple request
+ */
+ template<typename TRequest, typename TResponse>
+ void DoAdvancedRequest(const TRequest& request,
+ TAdvancedResponseCallback<TResponse> callback,
+ typename TAdvancedRequestProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest,
+ const TCallMeta& metas = { },
+ IQueueClientContextProvider* provider = nullptr)
+ {
+ auto processor = MakeIntrusive<TAdvancedRequestProcessor<TStub, TRequest, TResponse>>(std::move(callback));
+ processor->ApplyMeta(metas);
+ processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_);
+ }
+
+ /*
+ * Start bidirectional streamming
+ */
+ template<typename TRequest, typename TResponse>
+ void DoStreamRequest(TStreamConnectedCallback<TRequest, TResponse> callback,
+ typename TStreamRequestReadWriteProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest,
+ const TCallMeta& metas = { },
+ IQueueClientContextProvider* provider = nullptr)
+ {
+ auto processor = MakeIntrusive<TStreamRequestReadWriteProcessor<TStub, TRequest, TResponse>>(std::move(callback));
+ processor->ApplyMeta(metas);
+ processor->Start(*Stub_, std::move(asyncRequest), provider ? provider : Provider_);
+ }
+
+ /*
+ * Start streaming response reading (one request, many responses)
+ */
+ template<typename TRequest, typename TResponse>
+ void DoStreamRequest(const TRequest& request,
+ TStreamReaderCallback<TResponse> callback,
+ typename TStreamRequestReadProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest,
+ const TCallMeta& metas = { },
+ IQueueClientContextProvider* provider = nullptr)
+ {
+ auto processor = MakeIntrusive<TStreamRequestReadProcessor<TStub, TRequest, TResponse>>(std::move(callback));
+ processor->ApplyMeta(metas);
+ processor->Start(*Stub_, request, std::move(asyncRequest), provider ? provider : Provider_);
+ }
+
+private:
+ TServiceConnection(std::shared_ptr<grpc::ChannelInterface> ci,
+ IQueueClientContextProvider* provider)
+ : Stub_(TGRpcService::NewStub(ci))
+ , Provider_(provider)
+ {
+ Y_VERIFY(Provider_, "Connection does not have a queue provider");
+ }
+
+ TServiceConnection(TStubsHolder& holder,
+ IQueueClientContextProvider* provider)
+ : Stub_(holder.GetOrCreateStub<TStub>())
+ , Provider_(provider)
+ {
+ Y_VERIFY(Provider_, "Connection does not have a queue provider");
+ }
+
+ std::shared_ptr<TStub> Stub_;
+ IQueueClientContextProvider* Provider_;
+};
+
+class TGRpcClientLow
+ : public IQueueClientContextProvider
+{
+ class TContextImpl;
+ friend class TContextImpl;
+
+ enum ECqState : TAtomicBase {
+ WORKING = 0,
+ STOP_SILENT = 1,
+ STOP_EXPLICIT = 2,
+ };
+
+public:
+ explicit TGRpcClientLow(size_t numWorkerThread = DEFAULT_NUM_THREADS, bool useCompletionQueuePerThread = false);
+ ~TGRpcClientLow();
+
+ // Tries to stop all currently running requests (via their stop callbacks)
+ // Will shutdown CQ and drain events once all requests have finished
+ // No new requests may be started after this call
+ void Stop(bool wait = false);
+
+ // Waits until all currently running requests finish execution
+ void WaitIdle();
+
+ inline bool IsStopping() const {
+ switch (GetCqState()) {
+ case WORKING:
+ return false;
+ case STOP_SILENT:
+ case STOP_EXPLICIT:
+ return true;
+ }
+
+ Y_UNREACHABLE();
+ }
+
+ IQueueClientContextPtr CreateContext() override;
+
+ template<typename TGRpcService>
+ std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(const TGRpcClientConfig& config) {
+ return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(CreateChannelInterface(config), this));
+ }
+
+ template<typename TGRpcService>
+ std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(TStubsHolder& holder) {
+ return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(holder, this));
+ }
+
+ // Tests only, not thread-safe
+ void AddWorkerThreadForTest();
+
+private:
+ using IThreadRef = std::unique_ptr<IThreadFactory::IThread>;
+ using CompletionQueueRef = std::unique_ptr<grpc::CompletionQueue>;
+ void Init(size_t numWorkerThread);
+
+ inline ECqState GetCqState() const { return (ECqState) AtomicGet(CqState_); }
+ inline void SetCqState(ECqState state) { AtomicSet(CqState_, state); }
+
+ void StopInternal(bool silent);
+ void WaitInternal();
+
+ void ForgetContext(TContextImpl* context);
+
+private:
+ bool UseCompletionQueuePerThread_;
+ std::vector<CompletionQueueRef> CQS_;
+ std::vector<IThreadRef> WorkerThreads_;
+ TAtomic CqState_ = -1;
+
+ std::mutex Mtx_;
+ std::condition_variable ContextsEmpty_;
+ std::unordered_set<TContextImpl*> Contexts_;
+
+ std::mutex JoinMutex_;
+};
+
+} // namespace NGRpc