diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/grpc/client | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/grpc/client')
-rw-r--r-- | library/cpp/grpc/client/grpc_client_low.cpp | 586 | ||||
-rw-r--r-- | library/cpp/grpc/client/grpc_client_low.h | 1399 | ||||
-rw-r--r-- | library/cpp/grpc/client/grpc_common.h | 84 | ||||
-rw-r--r-- | library/cpp/grpc/client/ut/grpc_client_low_ut.cpp | 61 | ||||
-rw-r--r-- | library/cpp/grpc/client/ut/ya.make | 11 | ||||
-rw-r--r-- | library/cpp/grpc/client/ya.make | 20 |
6 files changed, 2161 insertions, 0 deletions
diff --git a/library/cpp/grpc/client/grpc_client_low.cpp b/library/cpp/grpc/client/grpc_client_low.cpp new file mode 100644 index 00000000000..73cc908ef82 --- /dev/null +++ b/library/cpp/grpc/client/grpc_client_low.cpp @@ -0,0 +1,586 @@ +#include "grpc_client_low.h" +#include <contrib/libs/grpc/src/core/lib/iomgr/socket_mutator.h> +#include <contrib/libs/grpc/include/grpc/support/log.h> + +#include <library/cpp/containers/stack_vector/stack_vec.h> + +#include <util/string/printf.h> +#include <util/system/thread.h> +#include <util/random/random.h> + +#if !defined(_WIN32) && !defined(_WIN64) +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#endif + +namespace NGrpc { + +void EnableGRpcTracing() { + grpc_tracer_set_enabled("tcp", true); + grpc_tracer_set_enabled("client_channel", true); + grpc_tracer_set_enabled("channel", true); + grpc_tracer_set_enabled("api", true); + grpc_tracer_set_enabled("connectivity_state", true); + grpc_tracer_set_enabled("handshaker", true); + grpc_tracer_set_enabled("http", true); + grpc_tracer_set_enabled("http2_stream_state", true); + grpc_tracer_set_enabled("op_failure", true); + grpc_tracer_set_enabled("timer", true); + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); +} + +class TGRpcKeepAliveSocketMutator : public grpc_socket_mutator { +public: + TGRpcKeepAliveSocketMutator(int idle, int count, int interval) + : Idle_(idle) + , Count_(count) + , Interval_(interval) + { + grpc_socket_mutator_init(this, &VTable); + } +private: + static TGRpcKeepAliveSocketMutator* Cast(grpc_socket_mutator* mutator) { + return static_cast<TGRpcKeepAliveSocketMutator*>(mutator); + } + + template<typename TVal> + bool SetOption(int fd, int level, int optname, const TVal& value) { + return setsockopt(fd, level, optname, reinterpret_cast<const char*>(&value), sizeof(value)) == 0; + } + bool SetOption(int fd) { + if (!SetOption(fd, SOL_SOCKET, SO_KEEPALIVE, 1)) { + Cerr << Sprintf("Failed to set SO_KEEPALIVE option: %s", strerror(errno)) << Endl; + return false; + } +#ifdef _linux_ + if (Idle_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPIDLE, Idle_)) { + Cerr << Sprintf("Failed to set TCP_KEEPIDLE option: %s", strerror(errno)) << Endl; + return false; + } + if (Count_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPCNT, Count_)) { + Cerr << Sprintf("Failed to set TCP_KEEPCNT option: %s", strerror(errno)) << Endl; + return false; + } + if (Interval_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPINTVL, Interval_)) { + Cerr << Sprintf("Failed to set TCP_KEEPINTVL option: %s", strerror(errno)) << Endl; + return false; + } +#endif + return true; + } + static bool Mutate(int fd, grpc_socket_mutator* mutator) { + auto self = Cast(mutator); + return self->SetOption(fd); + } + static int Compare(grpc_socket_mutator* a, grpc_socket_mutator* b) { + const auto* selfA = Cast(a); + const auto* selfB = Cast(b); + auto tupleA = std::make_tuple(selfA->Idle_, selfA->Count_, selfA->Interval_); + auto tupleB = std::make_tuple(selfB->Idle_, selfB->Count_, selfB->Interval_); + return tupleA < tupleB ? -1 : tupleA > tupleB ? 1 : 0; + } + static void Destroy(grpc_socket_mutator* mutator) { + delete Cast(mutator); + } + + static grpc_socket_mutator_vtable VTable; + const int Idle_; + const int Count_; + const int Interval_; +}; + +grpc_socket_mutator_vtable TGRpcKeepAliveSocketMutator::VTable = + { + &TGRpcKeepAliveSocketMutator::Mutate, + &TGRpcKeepAliveSocketMutator::Compare, + &TGRpcKeepAliveSocketMutator::Destroy + }; + +TChannelPool::TChannelPool(const TTcpKeepAliveSettings& tcpKeepAliveSettings, const TDuration& expireTime) + : TcpKeepAliveSettings_(tcpKeepAliveSettings) + , ExpireTime_(expireTime) + , UpdateReUseTime_(ExpireTime_ * 0.3 < TDuration::Seconds(20) ? ExpireTime_ * 0.3 : TDuration::Seconds(20)) +{} + +void TChannelPool::GetStubsHolderLocked( + const TString& channelId, + const TGRpcClientConfig& config, + std::function<void(TStubsHolder&)> cb) +{ + { + std::shared_lock readGuard(RWMutex_); + const auto it = Pool_.find(channelId); + if (it != Pool_.end()) { + if (!it->second.IsChannelBroken() && !(Now() > it->second.GetLastUseTime() + UpdateReUseTime_)) { + return cb(it->second); + } + } + } + { + std::unique_lock writeGuard(RWMutex_); + { + auto it = Pool_.find(channelId); + if (it != Pool_.end()) { + if (!it->second.IsChannelBroken()) { + EraseFromQueueByTime(it->second.GetLastUseTime(), channelId); + auto now = Now(); + LastUsedQueue_.emplace(now, channelId); + it->second.SetLastUseTime(now); + return cb(it->second); + } else { + // This channel can't be used. Remove from pool to create new one + EraseFromQueueByTime(it->second.GetLastUseTime(), channelId); + Pool_.erase(it); + } + } + } + TGRpcKeepAliveSocketMutator* mutator = nullptr; + // will be destroyed inside grpc + if (TcpKeepAliveSettings_.Enabled) { + mutator = new TGRpcKeepAliveSocketMutator( + TcpKeepAliveSettings_.Idle, + TcpKeepAliveSettings_.Count, + TcpKeepAliveSettings_.Interval + ); + } + cb(Pool_.emplace(channelId, CreateChannelInterface(config, mutator)).first->second); + LastUsedQueue_.emplace(Pool_.at(channelId).GetLastUseTime(), channelId); + } +} + +void TChannelPool::DeleteChannel(const TString& channelId) { + std::unique_lock writeLock(RWMutex_); + auto poolIt = Pool_.find(channelId); + if (poolIt != Pool_.end()) { + EraseFromQueueByTime(poolIt->second.GetLastUseTime(), channelId); + Pool_.erase(poolIt); + } +} + +void TChannelPool::DeleteExpiredStubsHolders() { + std::unique_lock writeLock(RWMutex_); + auto lastExpired = LastUsedQueue_.lower_bound(Now() - ExpireTime_); + for (auto i = LastUsedQueue_.begin(); i != lastExpired; ++i){ + Pool_.erase(i->second); + } + LastUsedQueue_.erase(LastUsedQueue_.begin(), lastExpired); +} + +void TChannelPool::EraseFromQueueByTime(const TInstant& lastUseTime, const TString& channelId) { + auto [begin, end] = LastUsedQueue_.equal_range(lastUseTime); + auto pos = std::find_if(begin, end, [&](auto a){return a.second == channelId;}); + Y_VERIFY(pos != LastUsedQueue_.end(), "data corruption at TChannelPool"); + LastUsedQueue_.erase(pos); +} + +static void PullEvents(grpc::CompletionQueue* cq) { + TThread::SetCurrentThreadName("grpc_client"); + while (true) { + void* tag; + bool ok; + + if (!cq->Next(&tag, &ok)) { + break; + } + + if (auto* ev = static_cast<IQueueClientEvent*>(tag)) { + if (!ev->Execute(ok)) { + ev->Destroy(); + } + } + } +} + +class TGRpcClientLow::TContextImpl final + : public std::enable_shared_from_this<TContextImpl> + , public IQueueClientContext +{ + friend class TGRpcClientLow; + + using TCallback = std::function<void()>; + using TContextPtr = std::shared_ptr<TContextImpl>; + +public: + ~TContextImpl() override { + Y_VERIFY(CountChildren() == 0, + "Destructor called with non-empty children"); + + if (Parent) { + Parent->ForgetContext(this); + } else if (Y_LIKELY(Owner)) { + Owner->ForgetContext(this); + } + } + + /** + * Helper for locking child pointer from a parent container + */ + static TContextPtr LockChildPtr(TContextImpl* ptr) { + if (ptr) { + // N.B. it is safe to do as long as it's done under a mutex and + // pointer is among valid children. When that's the case we + // know that TContextImpl destructor has not finished yet, so + // the object is valid. The lock() method may return nullptr + // though, if the object is being destructed right now. + return ptr->weak_from_this().lock(); + } else { + return nullptr; + } + } + + void ForgetContext(TContextImpl* child) { + std::unique_lock<std::mutex> guard(Mutex); + + auto removed = RemoveChild(child); + Y_VERIFY(removed, "Unexpected ForgetContext(%p)", child); + } + + IQueueClientContextPtr CreateContext() override { + auto self = shared_from_this(); + auto child = std::make_shared<TContextImpl>(); + + { + std::unique_lock<std::mutex> guard(Mutex); + + AddChild(child.get()); + + // It's now safe to initialize parent and owner + child->Parent = std::move(self); + child->Owner = Owner; + child->CQ = CQ; + + // Propagate cancellation to a child context + if (Cancelled.load(std::memory_order_relaxed)) { + child->Cancelled.store(true, std::memory_order_relaxed); + } + } + + return child; + } + + grpc::CompletionQueue* CompletionQueue() override { + Y_VERIFY(Owner, "Uninitialized context"); + return CQ; + } + + bool IsCancelled() const override { + return Cancelled.load(std::memory_order_acquire); + } + + bool Cancel() override { + TStackVec<TCallback, 1> callbacks; + TStackVec<TContextPtr, 2> children; + + { + std::unique_lock<std::mutex> guard(Mutex); + + if (Cancelled.load(std::memory_order_relaxed)) { + // Already cancelled in another thread + return false; + } + + callbacks.reserve(Callbacks.size()); + children.reserve(CountChildren()); + + for (auto& callback : Callbacks) { + callbacks.emplace_back().swap(callback); + } + Callbacks.clear(); + + // Collect all children we need to cancel + // N.B. we don't clear children links (cleared by destructors) + // N.B. some children may be stuck in destructors at the moment + for (TContextImpl* ptr : InlineChildren) { + if (auto child = LockChildPtr(ptr)) { + children.emplace_back(std::move(child)); + } + } + for (auto* ptr : Children) { + if (auto child = LockChildPtr(ptr)) { + children.emplace_back(std::move(child)); + } + } + + Cancelled.store(true, std::memory_order_release); + } + + // Call directly subscribed callbacks + if (callbacks) { + RunCallbacksNoExcept(callbacks); + } + + // Cancel all children + for (auto& child : children) { + child->Cancel(); + child.reset(); + } + + return true; + } + + void SubscribeCancel(TCallback callback) override { + Y_VERIFY(callback, "SubscribeCancel called with an empty callback"); + + { + std::unique_lock<std::mutex> guard(Mutex); + + if (!Cancelled.load(std::memory_order_relaxed)) { + Callbacks.emplace_back().swap(callback); + return; + } + } + + // Already cancelled, run immediately + callback(); + } + +private: + void AddChild(TContextImpl* child) { + for (TContextImpl*& slot : InlineChildren) { + if (!slot) { + slot = child; + return; + } + } + + Children.insert(child); + } + + bool RemoveChild(TContextImpl* child) { + for (TContextImpl*& slot : InlineChildren) { + if (slot == child) { + slot = nullptr; + return true; + } + } + + return Children.erase(child); + } + + size_t CountChildren() { + size_t count = 0; + + for (TContextImpl* ptr : InlineChildren) { + if (ptr) { + ++count; + } + } + + return count + Children.size(); + } + + template<class TCallbacks> + static void RunCallbacksNoExcept(TCallbacks& callbacks) noexcept { + for (auto& callback : callbacks) { + if (callback) { + callback(); + callback = nullptr; + } + } + } + +private: + // We want a simple lock here, without extra memory allocations + std::mutex Mutex; + + // These fields are initialized on successful registration + TContextPtr Parent; + TGRpcClientLow* Owner = nullptr; + grpc::CompletionQueue* CQ = nullptr; + + // Some children are stored inline, others are in a set + std::array<TContextImpl*, 2> InlineChildren{ { nullptr, nullptr } }; + std::unordered_set<TContextImpl*> Children; + + // Single callback is stored without extra allocations + TStackVec<TCallback, 1> Callbacks; + + // Atomic flag for a faster IsCancelled() implementation + std::atomic<bool> Cancelled; +}; + +TGRpcClientLow::TGRpcClientLow(size_t numWorkerThread, bool useCompletionQueuePerThread) + : UseCompletionQueuePerThread_(useCompletionQueuePerThread) +{ + Init(numWorkerThread); +} + +void TGRpcClientLow::Init(size_t numWorkerThread) { + SetCqState(WORKING); + if (UseCompletionQueuePerThread_) { + for (size_t i = 0; i < numWorkerThread; i++) { + CQS_.push_back(std::make_unique<grpc::CompletionQueue>()); + auto* cq = CQS_.back().get(); + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } + } else { + CQS_.push_back(std::make_unique<grpc::CompletionQueue>()); + auto* cq = CQS_.back().get(); + for (size_t i = 0; i < numWorkerThread; i++) { + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } + } +} + +void TGRpcClientLow::AddWorkerThreadForTest() { + if (UseCompletionQueuePerThread_) { + CQS_.push_back(std::make_unique<grpc::CompletionQueue>()); + auto* cq = CQS_.back().get(); + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } else { + auto* cq = CQS_.back().get(); + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } +} + +TGRpcClientLow::~TGRpcClientLow() { + StopInternal(true); + WaitInternal(); +} + +void TGRpcClientLow::Stop(bool wait) { + StopInternal(false); + + if (wait) { + WaitInternal(); + } +} + +void TGRpcClientLow::StopInternal(bool silent) { + bool shutdown; + + TVector<TContextImpl::TContextPtr> cancelQueue; + + { + std::unique_lock<std::mutex> guard(Mtx_); + + auto allowStateChange = [&]() { + switch (GetCqState()) { + case WORKING: + return true; + case STOP_SILENT: + return !silent; + case STOP_EXPLICIT: + return false; + } + + Y_UNREACHABLE(); + }; + + if (!allowStateChange()) { + // Completion queue is already stopping + return; + } + + SetCqState(silent ? STOP_SILENT : STOP_EXPLICIT); + + if (!silent && !Contexts_.empty()) { + cancelQueue.reserve(Contexts_.size()); + for (auto* ptr : Contexts_) { + // N.B. some contexts may be stuck in destructors + if (auto context = TContextImpl::LockChildPtr(ptr)) { + cancelQueue.emplace_back(std::move(context)); + } + } + } + + shutdown = Contexts_.empty(); + } + + for (auto& context : cancelQueue) { + context->Cancel(); + context.reset(); + } + + if (shutdown) { + for (auto& cq : CQS_) { + cq->Shutdown(); + } + } +} + +void TGRpcClientLow::WaitInternal() { + std::unique_lock<std::mutex> guard(JoinMutex_); + + for (auto& ti : WorkerThreads_) { + ti->Join(); + } +} + +void TGRpcClientLow::WaitIdle() { + std::unique_lock<std::mutex> guard(Mtx_); + + while (!Contexts_.empty()) { + ContextsEmpty_.wait(guard); + } +} + +std::shared_ptr<IQueueClientContext> TGRpcClientLow::CreateContext() { + std::unique_lock<std::mutex> guard(Mtx_); + + auto allowCreateContext = [&]() { + switch (GetCqState()) { + case WORKING: + return true; + case STOP_SILENT: + case STOP_EXPLICIT: + return false; + } + + Y_UNREACHABLE(); + }; + + if (!allowCreateContext()) { + // New context creation is forbidden + return nullptr; + } + + auto context = std::make_shared<TContextImpl>(); + Contexts_.insert(context.get()); + context->Owner = this; + if (UseCompletionQueuePerThread_) { + context->CQ = CQS_[RandomNumber(CQS_.size())].get(); + } else { + context->CQ = CQS_[0].get(); + } + return context; +} + +void TGRpcClientLow::ForgetContext(TContextImpl* context) { + bool shutdown = false; + + { + std::unique_lock<std::mutex> guard(Mtx_); + + if (!Contexts_.erase(context)) { + Y_FAIL("Unexpected ForgetContext(%p)", context); + } + + if (Contexts_.empty()) { + if (IsStopping()) { + shutdown = true; + } + + ContextsEmpty_.notify_all(); + } + } + + if (shutdown) { + // This was the last context, shutdown CQ + for (auto& cq : CQS_) { + cq->Shutdown(); + } + } +} + +} // namespace NGRpc diff --git a/library/cpp/grpc/client/grpc_client_low.h b/library/cpp/grpc/client/grpc_client_low.h new file mode 100644 index 00000000000..ab0a0627be0 --- /dev/null +++ b/library/cpp/grpc/client/grpc_client_low.h @@ -0,0 +1,1399 @@ +#pragma once + +#include "grpc_common.h" + +#include <util/thread/factory.h> +#include <grpc++/grpc++.h> +#include <grpc++/support/async_stream.h> +#include <grpc++/support/async_unary_call.h> + +#include <deque> +#include <typeindex> +#include <typeinfo> +#include <variant> +#include <vector> +#include <unordered_map> +#include <unordered_set> +#include <mutex> +#include <shared_mutex> + +/* + * This file contains low level logic for grpc + * This file should not be used in high level code without special reason + */ +namespace NGrpc { + +const size_t DEFAULT_NUM_THREADS = 2; + +//////////////////////////////////////////////////////////////////////////////// + +void EnableGRpcTracing(); + +//////////////////////////////////////////////////////////////////////////////// + +struct TTcpKeepAliveSettings { + bool Enabled; + size_t Idle; + size_t Count; + size_t Interval; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Common interface used to execute action from grpc cq routine +class IQueueClientEvent { +public: + virtual ~IQueueClientEvent() = default; + + //! Execute an action defined by implementation + virtual bool Execute(bool ok) = 0; + + //! Finish and destroy event + virtual void Destroy() = 0; +}; + +// Implementation of IQueueClientEvent that reduces allocations +template<class TSelf> +class TQueueClientFixedEvent : private IQueueClientEvent { + using TCallback = void (TSelf::*)(bool); + +public: + TQueueClientFixedEvent(TSelf* self, TCallback callback) + : Self(self) + , Callback(callback) + { } + + IQueueClientEvent* Prepare() { + Self->Ref(); + return this; + } + +private: + bool Execute(bool ok) override { + ((*Self).*Callback)(ok); + return false; + } + + void Destroy() override { + Self->UnRef(); + } + +private: + TSelf* const Self; + TCallback const Callback; +}; + +class IQueueClientContext; +using IQueueClientContextPtr = std::shared_ptr<IQueueClientContext>; + +// Provider of IQueueClientContext instances +class IQueueClientContextProvider { +public: + virtual ~IQueueClientContextProvider() = default; + + virtual IQueueClientContextPtr CreateContext() = 0; +}; + +// Activity context for a low-level client +class IQueueClientContext : public IQueueClientContextProvider { +public: + virtual ~IQueueClientContext() = default; + + //! Returns CompletionQueue associated with the client + virtual grpc::CompletionQueue* CompletionQueue() = 0; + + //! Returns true if context has been cancelled + virtual bool IsCancelled() const = 0; + + //! Tries to cancel context, calling all registered callbacks + virtual bool Cancel() = 0; + + //! Subscribes callback to cancellation + // + // Note there's no way to unsubscribe, if subscription is temporary + // make sure you create a new context with CreateContext and release + // it as soon as it's no longer needed. + virtual void SubscribeCancel(std::function<void()> callback) = 0; + + //! Subscribes callback to cancellation + // + // This alias is for compatibility with older code. + void SubscribeStop(std::function<void()> callback) { + SubscribeCancel(std::move(callback)); + } +}; + +// Represents grpc status and error message string +struct TGrpcStatus { + TString Msg; + TString Details; + int GRpcStatusCode; + bool InternalError; + + TGrpcStatus() + : GRpcStatusCode(grpc::StatusCode::OK) + , InternalError(false) + { } + + TGrpcStatus(TString msg, int statusCode, bool internalError) + : Msg(std::move(msg)) + , GRpcStatusCode(statusCode) + , InternalError(internalError) + { } + + TGrpcStatus(grpc::StatusCode status, TString msg, TString details = {}) + : Msg(std::move(msg)) + , Details(std::move(details)) + , GRpcStatusCode(status) + , InternalError(false) + { } + + TGrpcStatus(const grpc::Status& status) + : TGrpcStatus(status.error_code(), TString(status.error_message()), TString(status.error_details())) + { } + + TGrpcStatus& operator=(const grpc::Status& status) { + Msg = TString(status.error_message()); + Details = TString(status.error_details()); + GRpcStatusCode = status.error_code(); + InternalError = false; + return *this; + } + + static TGrpcStatus Internal(TString msg) { + return { std::move(msg), -1, true }; + } + + bool Ok() const { + return !InternalError && GRpcStatusCode == grpc::StatusCode::OK; + } +}; + +bool inline IsGRpcStatusGood(const TGrpcStatus& status) { + return status.Ok(); +} + +// Response callback type - this callback will be called when request is finished +// (or after getting each chunk in case of streaming mode) +template<typename TResponse> +using TResponseCallback = std::function<void (TGrpcStatus&&, TResponse&&)>; + +template<typename TResponse> +using TAdvancedResponseCallback = std::function<void (const grpc::ClientContext&, TGrpcStatus&&, TResponse&&)>; + +// Call associated metadata +struct TCallMeta { + std::shared_ptr<grpc::CallCredentials> CallCredentials; + std::vector<std::pair<TString, TString>> Aux; + std::variant<TDuration, TInstant> Timeout; // timeout as duration from now or time point in future +}; + +class TGRpcRequestProcessorCommon { +protected: + void ApplyMeta(const TCallMeta& meta) { + for (const auto& rec : meta.Aux) { + Context.AddMetadata(rec.first, rec.second); + } + if (meta.CallCredentials) { + Context.set_credentials(meta.CallCredentials); + } + if (const TDuration* timeout = std::get_if<TDuration>(&meta.Timeout)) { + if (*timeout) { + auto deadline = gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(timeout->MicroSeconds(), GPR_TIMESPAN)); + Context.set_deadline(deadline); + } + } else if (const TInstant* deadline = std::get_if<TInstant>(&meta.Timeout)) { + if (*deadline) { + Context.set_deadline(gpr_time_from_micros(deadline->MicroSeconds(), GPR_CLOCK_MONOTONIC)); + } + } + } + + void GetInitialMetadata(std::unordered_multimap<TString, TString>* metadata) { + for (const auto& [key, value] : Context.GetServerInitialMetadata()) { + metadata->emplace( + TString(key.begin(), key.end()), + TString(value.begin(), value.end()) + ); + } + } + + grpc::Status Status; + grpc::ClientContext Context; + std::shared_ptr<IQueueClientContext> LocalContext; +}; + +template<typename TStub, typename TRequest, typename TResponse> +class TSimpleRequestProcessor + : public TThrRefBase + , public IQueueClientEvent + , public TGRpcRequestProcessorCommon { + using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>; + template<typename> friend class TServiceConnection; +public: + using TPtr = TIntrusivePtr<TSimpleRequestProcessor>; + using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*); + + explicit TSimpleRequestProcessor(TResponseCallback<TResponse>&& callback) + : Callback_(std::move(callback)) + { } + + ~TSimpleRequestProcessor() { + if (!Replied_ && Callback_) { + Callback_(TGrpcStatus::Internal("request left unhandled"), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + } + } + + bool Execute(bool ok) override { + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext.reset(); + } + TGrpcStatus status; + if (ok) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + Replied_ = true; + Callback_(std::move(status), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + return false; + } + + void Destroy() override { + UnRef(); + } + +private: + IQueueClientEvent* FinishedEvent() { + Ref(); + return this; + } + + void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + Replied_ = true; + Callback_(TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_)); + Callback_ = nullptr; + return; + } + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext = context; + Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue()); + Reader_->Finish(&Reply_, &Status, FinishedEvent()); + } + context->SubscribeStop([self = TPtr(this)] { + self->Stop(); + }); + } + + void Stop() { + Context.TryCancel(); + } + + TResponseCallback<TResponse> Callback_; + TResponse Reply_; + std::mutex Mutex_; + TAsyncReaderPtr Reader_; + + bool Replied_ = false; +}; + +template<typename TStub, typename TRequest, typename TResponse> +class TAdvancedRequestProcessor + : public TThrRefBase + , public IQueueClientEvent + , public TGRpcRequestProcessorCommon { + using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>; + template<typename> friend class TServiceConnection; +public: + using TPtr = TIntrusivePtr<TAdvancedRequestProcessor>; + using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*); + + explicit TAdvancedRequestProcessor(TAdvancedResponseCallback<TResponse>&& callback) + : Callback_(std::move(callback)) + { } + + ~TAdvancedRequestProcessor() { + if (!Replied_ && Callback_) { + Callback_(Context, TGrpcStatus::Internal("request left unhandled"), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + } + } + + bool Execute(bool ok) override { + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext.reset(); + } + TGrpcStatus status; + if (ok) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + Replied_ = true; + Callback_(Context, std::move(status), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + return false; + } + + void Destroy() override { + UnRef(); + } + +private: + IQueueClientEvent* FinishedEvent() { + Ref(); + return this; + } + + void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + Replied_ = true; + Callback_(Context, TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_)); + Callback_ = nullptr; + return; + } + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext = context; + Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue()); + Reader_->Finish(&Reply_, &Status, FinishedEvent()); + } + context->SubscribeStop([self = TPtr(this)] { + self->Stop(); + }); + } + + void Stop() { + Context.TryCancel(); + } + + TAdvancedResponseCallback<TResponse> Callback_; + TResponse Reply_; + std::mutex Mutex_; + TAsyncReaderPtr Reader_; + + bool Replied_ = false; +}; + +template<class TResponse> +class IStreamRequestReadProcessor : public TThrRefBase { +public: + using TPtr = TIntrusivePtr<IStreamRequestReadProcessor>; + using TReadCallback = std::function<void(TGrpcStatus&&)>; + + /** + * Asynchronously cancel the request + */ + virtual void Cancel() = 0; + + /** + * Scheduled initial server metadata read from the stream + */ + virtual void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) = 0; + + /** + * Scheduled response read from the stream + * Callback will be called with the status if it failed + * Only one Read or Finish call may be active at a time + */ + virtual void Read(TResponse* response, TReadCallback callback) = 0; + + /** + * Stop reading and gracefully finish the stream + * Only one Read or Finish call may be active at a time + */ + virtual void Finish(TReadCallback callback) = 0; + + /** + * Additional callback to be called when stream has finished + */ + virtual void AddFinishedCallback(TReadCallback callback) = 0; +}; + +template<class TRequest, class TResponse> +class IStreamRequestReadWriteProcessor : public IStreamRequestReadProcessor<TResponse> { +public: + using TPtr = TIntrusivePtr<IStreamRequestReadWriteProcessor>; + using TWriteCallback = std::function<void(TGrpcStatus&&)>; + + /** + * Scheduled request write to the stream + */ + virtual void Write(TRequest&& request, TWriteCallback callback = { }) = 0; +}; + +class TGRpcKeepAliveSocketMutator; + +// Class to hold stubs allocated on channel. +// It is poor documented part of grpc. See KIKIMR-6109 and comment to this commit + +// Stub holds shared_ptr<ChannelInterface>, so we can destroy this holder even if +// request processor using stub +class TStubsHolder : public TNonCopyable { + using TypeInfoRef = std::reference_wrapper<const std::type_info>; + + struct THasher { + std::size_t operator()(TypeInfoRef code) const { + return code.get().hash_code(); + } + }; + + struct TEqualTo { + bool operator()(TypeInfoRef lhs, TypeInfoRef rhs) const { + return lhs.get() == rhs.get(); + } + }; +public: + TStubsHolder(std::shared_ptr<grpc::ChannelInterface> channel) + : ChannelInterface_(channel) + {} + + // Returns true if channel can't be used to perform request now + bool IsChannelBroken() const { + auto state = ChannelInterface_->GetState(false); + return state == GRPC_CHANNEL_SHUTDOWN || + state == GRPC_CHANNEL_TRANSIENT_FAILURE; + } + + template<typename TStub> + std::shared_ptr<TStub> GetOrCreateStub() { + const auto& stubId = typeid(TStub); + { + std::shared_lock readGuard(RWMutex_); + const auto it = Stubs_.find(stubId); + if (it != Stubs_.end()) { + return std::static_pointer_cast<TStub>(it->second); + } + } + { + std::unique_lock writeGuard(RWMutex_); + auto it = Stubs_.emplace(stubId, nullptr); + if (!it.second) { + return std::static_pointer_cast<TStub>(it.first->second); + } else { + it.first->second = std::make_shared<TStub>(ChannelInterface_); + return std::static_pointer_cast<TStub>(it.first->second); + } + } + } + + const TInstant& GetLastUseTime() const { + return LastUsed_; + } + + void SetLastUseTime(const TInstant& time) { + LastUsed_ = time; + } +private: + TInstant LastUsed_ = Now(); + std::shared_mutex RWMutex_; + std::unordered_map<TypeInfoRef, std::shared_ptr<void>, THasher, TEqualTo> Stubs_; + std::shared_ptr<grpc::ChannelInterface> ChannelInterface_; +}; + +class TChannelPool { +public: + TChannelPool(const TTcpKeepAliveSettings& tcpKeepAliveSettings, const TDuration& expireTime = TDuration::Minutes(6)); + //Allows to CreateStub from TStubsHolder under lock + //The callback will be called just during GetStubsHolderLocked call + void GetStubsHolderLocked(const TString& channelId, const TGRpcClientConfig& config, std::function<void(TStubsHolder&)> cb); + void DeleteChannel(const TString& channelId); + void DeleteExpiredStubsHolders(); +private: + std::shared_mutex RWMutex_; + std::unordered_map<TString, TStubsHolder> Pool_; + std::multimap<TInstant, TString> LastUsedQueue_; + TTcpKeepAliveSettings TcpKeepAliveSettings_; + TDuration ExpireTime_; + TDuration UpdateReUseTime_; + void EraseFromQueueByTime(const TInstant& lastUseTime, const TString& channelId); +}; + +template<class TResponse> +using TStreamReaderCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadProcessor<TResponse>::TPtr)>; + +template<typename TStub, typename TRequest, typename TResponse> +class TStreamRequestReadProcessor + : public IStreamRequestReadProcessor<TResponse> + , public TGRpcRequestProcessorCommon { + template<typename> friend class TServiceConnection; +public: + using TSelf = TStreamRequestReadProcessor; + using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncReader<TResponse>>; + using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*, void*); + using TReaderCallback = TStreamReaderCallback<TResponse>; + using TPtr = TIntrusivePtr<TSelf>; + using TBase = IStreamRequestReadProcessor<TResponse>; + using TReadCallback = typename TBase::TReadCallback; + + explicit TStreamRequestReadProcessor(TReaderCallback&& callback) + : Callback(std::move(callback)) + { + Y_VERIFY(Callback, "Missing connected callback"); + } + + void Cancel() override { + Context.TryCancel(); + + { + std::unique_lock<std::mutex> guard(Mutex); + Cancelled = true; + if (Started && !ReadFinished) { + if (!ReadActive) { + ReadFinished = true; + } + if (ReadFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + } + } + + void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished && !HasInitialMetadata) { + ReadActive = true; + ReadCallback = std::move(callback); + InitialMetadata = metadata; + if (!ReadFinished) { + Stream->ReadInitialMetadata(OnReadDoneTag.Prepare()); + } + return; + } + if (!HasInitialMetadata) { + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } else { + GetInitialMetadata(metadata); + } + } + + callback(std::move(status)); + } + + void Read(TResponse* message, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + ReadCallback = std::move(callback); + if (!ReadFinished) { + Stream->Read(message, OnReadDoneTag.Prepare()); + } + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + + callback(std::move(status)); + } + + void Finish(TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + FinishCallback = std::move(callback); + if (!ReadFinished) { + ReadFinished = true; + } + Stream->Finish(&Status, OnFinishedTag.Prepare()); + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + + void AddFinishedCallback(TReadCallback callback) override { + Y_VERIFY(callback, "Unexpected empty callback"); + + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + if (!Finished) { + FinishedCallbacks.emplace_back().swap(callback); + return; + } + + if (FinishedOk) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + +private: + void Start(TStub& stub, const TRequest& request, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + auto callback = std::move(Callback); + TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down"); + callback(std::move(status), nullptr); + return; + } + + { + std::unique_lock<std::mutex> guard(Mutex); + LocalContext = context; + Stream = (stub.*asyncRequest)(&Context, request, context->CompletionQueue(), OnStartDoneTag.Prepare()); + } + + context->SubscribeStop([self = TPtr(this)] { + self->Cancel(); + }); + } + + void OnReadDone(bool ok) { + TGrpcStatus status; + TReadCallback callback; + std::unordered_multimap<TString, TString>* initialMetadata = nullptr; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(ReadActive, "Unexpected Read done callback"); + Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag"); + + if (!ok || Cancelled) { + ReadFinished = true; + + Stream->Finish(&Status, OnFinishedTag.Prepare()); + if (!ok) { + // Keep ReadActive=true, so callback is called + // after the call is finished with an error + return; + } + } + + callback = std::move(ReadCallback); + ReadCallback = nullptr; + ReadActive = false; + initialMetadata = InitialMetadata; + InitialMetadata = nullptr; + HasInitialMetadata = true; + } + + if (initialMetadata) { + GetInitialMetadata(initialMetadata); + } + + callback(std::move(status)); + } + + void OnStartDone(bool ok) { + TReaderCallback callback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Started = true; + if (!ok || Cancelled) { + ReadFinished = true; + Stream->Finish(&Status, OnFinishedTag.Prepare()); + return; + } + callback = std::move(Callback); + Callback = nullptr; + } + + callback({ }, typename TBase::TPtr(this)); + } + + void OnFinished(bool ok) { + TGrpcStatus status; + std::vector<TReadCallback> finishedCallbacks; + TReaderCallback startCallback; + TReadCallback readCallback; + TReadCallback finishCallback; + + { + std::unique_lock<std::mutex> guard(Mutex); + + Finished = true; + FinishedOk = ok; + LocalContext.reset(); + + if (ok) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + + finishedCallbacks.swap(FinishedCallbacks); + + if (Callback) { + Y_VERIFY(!ReadActive); + startCallback = std::move(Callback); + Callback = nullptr; + } else if (ReadActive) { + if (ReadCallback) { + readCallback = std::move(ReadCallback); + ReadCallback = nullptr; + } else { + finishCallback = std::move(FinishCallback); + FinishCallback = nullptr; + } + ReadActive = false; + } + } + + for (auto& finishedCallback : finishedCallbacks) { + auto statusCopy = status; + finishedCallback(std::move(statusCopy)); + } + + if (startCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure"); + } + startCallback(std::move(status), nullptr); + } else if (readCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + readCallback(std::move(status)); + } else if (finishCallback) { + finishCallback(std::move(status)); + } + } + + TReaderCallback Callback; + TAsyncReaderPtr Stream; + using TFixedEvent = TQueueClientFixedEvent<TSelf>; + std::mutex Mutex; + TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone }; + TFixedEvent OnStartDoneTag = { this, &TSelf::OnStartDone }; + TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished }; + + TReadCallback ReadCallback; + TReadCallback FinishCallback; + std::vector<TReadCallback> FinishedCallbacks; + std::unordered_multimap<TString, TString>* InitialMetadata = nullptr; + bool Started = false; + bool HasInitialMetadata = false; + bool ReadActive = false; + bool ReadFinished = false; + bool Finished = false; + bool Cancelled = false; + bool FinishedOk = false; +}; + +template<class TRequest, class TResponse> +using TStreamConnectedCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadWriteProcessor<TRequest, TResponse>::TPtr)>; + +template<class TStub, class TRequest, class TResponse> +class TStreamRequestReadWriteProcessor + : public IStreamRequestReadWriteProcessor<TRequest, TResponse> + , public TGRpcRequestProcessorCommon { +public: + using TSelf = TStreamRequestReadWriteProcessor; + using TBase = IStreamRequestReadWriteProcessor<TRequest, TResponse>; + using TPtr = TIntrusivePtr<TSelf>; + using TConnectedCallback = TStreamConnectedCallback<TRequest, TResponse>; + using TReadCallback = typename TBase::TReadCallback; + using TWriteCallback = typename TBase::TWriteCallback; + using TAsyncReaderWriterPtr = std::unique_ptr<grpc::ClientAsyncReaderWriter<TRequest, TResponse>>; + using TAsyncRequest = TAsyncReaderWriterPtr (TStub::*)(grpc::ClientContext*, grpc::CompletionQueue*, void*); + + explicit TStreamRequestReadWriteProcessor(TConnectedCallback&& callback) + : ConnectedCallback(std::move(callback)) + { + Y_VERIFY(ConnectedCallback, "Missing connected callback"); + } + + void Cancel() override { + Context.TryCancel(); + + { + std::unique_lock<std::mutex> guard(Mutex); + Cancelled = true; + if (Started && !(ReadFinished && WriteFinished)) { + if (!ReadActive) { + ReadFinished = true; + } + if (!WriteActive) { + WriteFinished = true; + } + if (ReadFinished && WriteFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + } + } + + void Write(TRequest&& request, TWriteCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + if (Cancelled || ReadFinished || WriteFinished) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped"); + } else if (WriteActive) { + auto& item = WriteQueue.emplace_back(); + item.Callback.swap(callback); + item.Request.Swap(&request); + } else { + WriteActive = true; + WriteCallback.swap(callback); + Stream->Write(request, OnWriteDoneTag.Prepare()); + } + } + + if (!status.Ok() && callback) { + callback(std::move(status)); + } + } + + void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished && !HasInitialMetadata) { + ReadActive = true; + ReadCallback = std::move(callback); + InitialMetadata = metadata; + if (!ReadFinished) { + Stream->ReadInitialMetadata(OnReadDoneTag.Prepare()); + } + return; + } + if (!HasInitialMetadata) { + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } else { + GetInitialMetadata(metadata); + } + } + + callback(std::move(status)); + } + + void Read(TResponse* message, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + ReadCallback = std::move(callback); + if (!ReadFinished) { + Stream->Read(message, OnReadDoneTag.Prepare()); + } + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + + callback(std::move(status)); + } + + void Finish(TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + FinishCallback = std::move(callback); + if (!ReadFinished) { + ReadFinished = true; + if (!WriteActive) { + WriteFinished = true; + } + if (WriteFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + + void AddFinishedCallback(TReadCallback callback) override { + Y_VERIFY(callback, "Unexpected empty callback"); + + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + if (!Finished) { + FinishedCallbacks.emplace_back().swap(callback); + return; + } + + if (FinishedOk) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + +private: + template<typename> friend class TServiceConnection; + + void Start(TStub& stub, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + auto callback = std::move(ConnectedCallback); + TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down"); + callback(std::move(status), nullptr); + return; + } + + { + std::unique_lock<std::mutex> guard(Mutex); + LocalContext = context; + Stream = (stub.*asyncRequest)(&Context, context->CompletionQueue(), OnConnectedTag.Prepare()); + } + + context->SubscribeStop([self = TPtr(this)] { + self->Cancel(); + }); + } + +private: + void OnConnected(bool ok) { + TConnectedCallback callback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Started = true; + if (!ok || Cancelled) { + ReadFinished = true; + WriteFinished = true; + Stream->Finish(&Status, OnFinishedTag.Prepare()); + return; + } + + callback = std::move(ConnectedCallback); + ConnectedCallback = nullptr; + } + + callback({ }, typename TBase::TPtr(this)); + } + + void OnReadDone(bool ok) { + TGrpcStatus status; + TReadCallback callback; + std::unordered_multimap<TString, TString>* initialMetadata = nullptr; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(ReadActive, "Unexpected Read done callback"); + Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag"); + + if (!ok || Cancelled || WriteFinished) { + ReadFinished = true; + if (!WriteActive) { + WriteFinished = true; + } + if (WriteFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + if (!ok) { + // Keep ReadActive=true, so callback is called + // after the call is finished with an error + return; + } + } + + callback = std::move(ReadCallback); + ReadCallback = nullptr; + ReadActive = false; + initialMetadata = InitialMetadata; + InitialMetadata = nullptr; + HasInitialMetadata = true; + } + + if (initialMetadata) { + GetInitialMetadata(initialMetadata); + } + + callback(std::move(status)); + } + + void OnWriteDone(bool ok) { + TWriteCallback okCallback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(WriteActive, "Unexpected Write done callback"); + Y_VERIFY(!WriteFinished, "Unexpected WriteFinished flag"); + + if (ok) { + okCallback.swap(WriteCallback); + } else if (WriteCallback) { + // Put callback back on the queue until OnFinished + auto& item = WriteQueue.emplace_front(); + item.Callback.swap(WriteCallback); + } + + if (!ok || Cancelled) { + WriteActive = false; + WriteFinished = true; + if (!ReadActive) { + ReadFinished = true; + } + if (ReadFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } else if (!WriteQueue.empty()) { + WriteCallback.swap(WriteQueue.front().Callback); + Stream->Write(WriteQueue.front().Request, OnWriteDoneTag.Prepare()); + WriteQueue.pop_front(); + } else { + WriteActive = false; + if (ReadFinished) { + WriteFinished = true; + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + } + + if (okCallback) { + okCallback(TGrpcStatus()); + } + } + + void OnFinished(bool ok) { + TGrpcStatus status; + std::deque<TWriteItem> writesDropped; + std::vector<TReadCallback> finishedCallbacks; + TConnectedCallback connectedCallback; + TReadCallback readCallback; + TReadCallback finishCallback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Finished = true; + FinishedOk = ok; + LocalContext.reset(); + + if (ok) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + + writesDropped.swap(WriteQueue); + finishedCallbacks.swap(FinishedCallbacks); + + if (ConnectedCallback) { + Y_VERIFY(!ReadActive); + connectedCallback = std::move(ConnectedCallback); + ConnectedCallback = nullptr; + } else if (ReadActive) { + if (ReadCallback) { + readCallback = std::move(ReadCallback); + ReadCallback = nullptr; + } else { + finishCallback = std::move(FinishCallback); + FinishCallback = nullptr; + } + ReadActive = false; + } + } + + for (auto& item : writesDropped) { + if (item.Callback) { + TGrpcStatus writeStatus = status; + if (writeStatus.Ok()) { + writeStatus = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped"); + } + item.Callback(std::move(writeStatus)); + } + } + + for (auto& finishedCallback : finishedCallbacks) { + TGrpcStatus statusCopy = status; + finishedCallback(std::move(statusCopy)); + } + + if (connectedCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure"); + } + connectedCallback(std::move(status), nullptr); + } else if (readCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + readCallback(std::move(status)); + } else if (finishCallback) { + finishCallback(std::move(status)); + } + } + +private: + struct TWriteItem { + TWriteCallback Callback; + TRequest Request; + }; + +private: + using TFixedEvent = TQueueClientFixedEvent<TSelf>; + + TFixedEvent OnConnectedTag = { this, &TSelf::OnConnected }; + TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone }; + TFixedEvent OnWriteDoneTag = { this, &TSelf::OnWriteDone }; + TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished }; + +private: + std::mutex Mutex; + TAsyncReaderWriterPtr Stream; + TConnectedCallback ConnectedCallback; + TReadCallback ReadCallback; + TReadCallback FinishCallback; + std::vector<TReadCallback> FinishedCallbacks; + std::deque<TWriteItem> WriteQueue; + TWriteCallback WriteCallback; + std::unordered_multimap<TString, TString>* InitialMetadata = nullptr; + bool Started = false; + bool HasInitialMetadata = false; + bool ReadActive = false; + bool ReadFinished = false; + bool WriteActive = false; + bool WriteFinished = false; + bool Finished = false; + bool Cancelled = false; + bool FinishedOk = false; +}; + +class TGRpcClientLow; + +template<typename TGRpcService> +class TServiceConnection { + using TStub = typename TGRpcService::Stub; + friend class TGRpcClientLow; + +public: + /* + * Start simple request + */ + template<typename TRequest, typename TResponse> + void DoRequest(const TRequest& request, + TResponseCallback<TResponse> callback, + typename TSimpleRequestProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TSimpleRequestProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_); + } + + /* + * Start simple request + */ + template<typename TRequest, typename TResponse> + void DoAdvancedRequest(const TRequest& request, + TAdvancedResponseCallback<TResponse> callback, + typename TAdvancedRequestProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TAdvancedRequestProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_); + } + + /* + * Start bidirectional streamming + */ + template<typename TRequest, typename TResponse> + void DoStreamRequest(TStreamConnectedCallback<TRequest, TResponse> callback, + typename TStreamRequestReadWriteProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TStreamRequestReadWriteProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, std::move(asyncRequest), provider ? provider : Provider_); + } + + /* + * Start streaming response reading (one request, many responses) + */ + template<typename TRequest, typename TResponse> + void DoStreamRequest(const TRequest& request, + TStreamReaderCallback<TResponse> callback, + typename TStreamRequestReadProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TStreamRequestReadProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, request, std::move(asyncRequest), provider ? provider : Provider_); + } + +private: + TServiceConnection(std::shared_ptr<grpc::ChannelInterface> ci, + IQueueClientContextProvider* provider) + : Stub_(TGRpcService::NewStub(ci)) + , Provider_(provider) + { + Y_VERIFY(Provider_, "Connection does not have a queue provider"); + } + + TServiceConnection(TStubsHolder& holder, + IQueueClientContextProvider* provider) + : Stub_(holder.GetOrCreateStub<TStub>()) + , Provider_(provider) + { + Y_VERIFY(Provider_, "Connection does not have a queue provider"); + } + + std::shared_ptr<TStub> Stub_; + IQueueClientContextProvider* Provider_; +}; + +class TGRpcClientLow + : public IQueueClientContextProvider +{ + class TContextImpl; + friend class TContextImpl; + + enum ECqState : TAtomicBase { + WORKING = 0, + STOP_SILENT = 1, + STOP_EXPLICIT = 2, + }; + +public: + explicit TGRpcClientLow(size_t numWorkerThread = DEFAULT_NUM_THREADS, bool useCompletionQueuePerThread = false); + ~TGRpcClientLow(); + + // Tries to stop all currently running requests (via their stop callbacks) + // Will shutdown CQ and drain events once all requests have finished + // No new requests may be started after this call + void Stop(bool wait = false); + + // Waits until all currently running requests finish execution + void WaitIdle(); + + inline bool IsStopping() const { + switch (GetCqState()) { + case WORKING: + return false; + case STOP_SILENT: + case STOP_EXPLICIT: + return true; + } + + Y_UNREACHABLE(); + } + + IQueueClientContextPtr CreateContext() override; + + template<typename TGRpcService> + std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(const TGRpcClientConfig& config) { + return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(CreateChannelInterface(config), this)); + } + + template<typename TGRpcService> + std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(TStubsHolder& holder) { + return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(holder, this)); + } + + // Tests only, not thread-safe + void AddWorkerThreadForTest(); + +private: + using IThreadRef = std::unique_ptr<IThreadFactory::IThread>; + using CompletionQueueRef = std::unique_ptr<grpc::CompletionQueue>; + void Init(size_t numWorkerThread); + + inline ECqState GetCqState() const { return (ECqState) AtomicGet(CqState_); } + inline void SetCqState(ECqState state) { AtomicSet(CqState_, state); } + + void StopInternal(bool silent); + void WaitInternal(); + + void ForgetContext(TContextImpl* context); + +private: + bool UseCompletionQueuePerThread_; + std::vector<CompletionQueueRef> CQS_; + std::vector<IThreadRef> WorkerThreads_; + TAtomic CqState_ = -1; + + std::mutex Mtx_; + std::condition_variable ContextsEmpty_; + std::unordered_set<TContextImpl*> Contexts_; + + std::mutex JoinMutex_; +}; + +} // namespace NGRpc diff --git a/library/cpp/grpc/client/grpc_common.h b/library/cpp/grpc/client/grpc_common.h new file mode 100644 index 00000000000..ffcdafe0458 --- /dev/null +++ b/library/cpp/grpc/client/grpc_common.h @@ -0,0 +1,84 @@ +#pragma once + +#include <grpc++/grpc++.h> +#include <grpc++/resource_quota.h> + +#include <util/datetime/base.h> +#include <unordered_map> +#include <util/generic/string.h> + +constexpr ui64 DEFAULT_GRPC_MESSAGE_SIZE_LIMIT = 64000000; + +namespace NGrpc { + +struct TGRpcClientConfig { + TString Locator; // format host:port + TDuration Timeout = TDuration::Max(); // request timeout + ui64 MaxMessageSize = DEFAULT_GRPC_MESSAGE_SIZE_LIMIT; // Max request and response size + ui64 MaxInboundMessageSize = 0; // overrides MaxMessageSize for incoming requests + ui64 MaxOutboundMessageSize = 0; // overrides MaxMessageSize for outgoing requests + ui32 MaxInFlight = 0; + bool EnableSsl = false; + TString SslCaCert; //Implicitly enables Ssl if not empty + grpc_compression_algorithm CompressionAlgoritm = GRPC_COMPRESS_NONE; + ui64 MemQuota = 0; + std::unordered_map<TString, TString> StringChannelParams; + std::unordered_map<TString, int> IntChannelParams; + TString LoadBalancingPolicy = { }; + TString SslTargetNameOverride = { }; + + TGRpcClientConfig() = default; + TGRpcClientConfig(const TGRpcClientConfig&) = default; + TGRpcClientConfig(TGRpcClientConfig&&) = default; + TGRpcClientConfig& operator=(const TGRpcClientConfig&) = default; + TGRpcClientConfig& operator=(TGRpcClientConfig&&) = default; + + TGRpcClientConfig(const TString& locator, TDuration timeout = TDuration::Max(), + ui64 maxMessageSize = DEFAULT_GRPC_MESSAGE_SIZE_LIMIT, ui32 maxInFlight = 0, TString caCert = "", + grpc_compression_algorithm compressionAlgorithm = GRPC_COMPRESS_NONE, bool enableSsl = false) + : Locator(locator) + , Timeout(timeout) + , MaxMessageSize(maxMessageSize) + , MaxInFlight(maxInFlight) + , EnableSsl(enableSsl) + , SslCaCert(caCert) + , CompressionAlgoritm(compressionAlgorithm) + {} +}; + +inline std::shared_ptr<grpc::ChannelInterface> CreateChannelInterface(const TGRpcClientConfig& config, grpc_socket_mutator* mutator = nullptr){ + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(config.MaxInboundMessageSize ? config.MaxInboundMessageSize : config.MaxMessageSize); + args.SetMaxSendMessageSize(config.MaxOutboundMessageSize ? config.MaxOutboundMessageSize : config.MaxMessageSize); + args.SetCompressionAlgorithm(config.CompressionAlgoritm); + + for (const auto& kvp: config.StringChannelParams) { + args.SetString(kvp.first, kvp.second); + } + + for (const auto& kvp: config.IntChannelParams) { + args.SetInt(kvp.first, kvp.second); + } + + if (config.MemQuota) { + grpc::ResourceQuota quota; + quota.Resize(config.MemQuota); + args.SetResourceQuota(quota); + } + if (mutator) { + args.SetSocketMutator(mutator); + } + if (!config.LoadBalancingPolicy.empty()) { + args.SetLoadBalancingPolicyName(config.LoadBalancingPolicy); + } + if (!config.SslTargetNameOverride.empty()) { + args.SetSslTargetNameOverride(config.SslTargetNameOverride); + } + if (config.EnableSsl || config.SslCaCert) { + return grpc::CreateCustomChannel(config.Locator, grpc::SslCredentials(grpc::SslCredentialsOptions{config.SslCaCert, "", ""}), args); + } else { + return grpc::CreateCustomChannel(config.Locator, grpc::InsecureChannelCredentials(), args); + } +} + +} // namespace NGRpc diff --git a/library/cpp/grpc/client/ut/grpc_client_low_ut.cpp b/library/cpp/grpc/client/ut/grpc_client_low_ut.cpp new file mode 100644 index 00000000000..b8af2a518fd --- /dev/null +++ b/library/cpp/grpc/client/ut/grpc_client_low_ut.cpp @@ -0,0 +1,61 @@ +#include <library/cpp/grpc/client/grpc_client_low.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NGrpc; + +class TTestStub { +public: + std::shared_ptr<grpc::ChannelInterface> ChannelInterface; + TTestStub(std::shared_ptr<grpc::ChannelInterface> channelInterface) + : ChannelInterface(channelInterface) + {} +}; + +Y_UNIT_TEST_SUITE(ChannelPoolTests) { + Y_UNIT_TEST(UnusedStubsHoldersDeletion) { + TGRpcClientConfig clientConfig("invalid_host:invalid_port"); + TTcpKeepAliveSettings tcpKeepAliveSettings = + { + true, + 30, // NYdb::TCP_KEEPALIVE_IDLE, unused in UT, but is necessary in constructor + 5, // NYdb::TCP_KEEPALIVE_COUNT, unused in UT, but is necessary in constructor + 10 // NYdb::TCP_KEEPALIVE_INTERVAL, unused in UT, but is necessary in constructor + }; + auto channelPool = TChannelPool(tcpKeepAliveSettings, TDuration::MilliSeconds(250)); + std::vector<std::weak_ptr<grpc::ChannelInterface>> ChannelInterfacesWeak; + + { + std::vector<std::shared_ptr<TTestStub>> stubsHoldersShared; + auto storeStubsHolders = [&](TStubsHolder& stubsHolder) { + stubsHoldersShared.emplace_back(stubsHolder.GetOrCreateStub<TTestStub>()); + ChannelInterfacesWeak.emplace_back((*stubsHoldersShared.rbegin())->ChannelInterface); + return; + }; + for (int i = 0; i < 10; ++i) { + channelPool.GetStubsHolderLocked( + ToString(i), + clientConfig, + storeStubsHolders + ); + } + } + + auto now = Now(); + while (Now() < now + TDuration::MilliSeconds(500)){ + Sleep(TDuration::MilliSeconds(100)); + } + + channelPool.DeleteExpiredStubsHolders(); + + bool allDeleted = true; + for (auto i = ChannelInterfacesWeak.begin(); i != ChannelInterfacesWeak.end(); ++i) { + allDeleted = allDeleted && i->expired(); + } + + // assertion is made for channel interfaces instead of stubs, because after stub deletion + // TStubsHolder has the only shared_ptr for channel interface. + UNIT_ASSERT_C(allDeleted, "expired stubsHolders were not deleted after timeout"); + + } +} // ChannelPoolTests ut suite
\ No newline at end of file diff --git a/library/cpp/grpc/client/ut/ya.make b/library/cpp/grpc/client/ut/ya.make new file mode 100644 index 00000000000..eac779a99e4 --- /dev/null +++ b/library/cpp/grpc/client/ut/ya.make @@ -0,0 +1,11 @@ +UNITTEST_FOR(library/cpp/grpc/client) + +OWNER( + g:kikimr +) + +SRCS( + grpc_client_low_ut.cpp +) + +END() diff --git a/library/cpp/grpc/client/ya.make b/library/cpp/grpc/client/ya.make new file mode 100644 index 00000000000..a4e74b067cf --- /dev/null +++ b/library/cpp/grpc/client/ya.make @@ -0,0 +1,20 @@ +LIBRARY() + +OWNER( + ddoarn + g:kikimr +) + +SRCS( + grpc_client_low.cpp +) + +PEERDIR( + contrib/libs/grpc +) + +END() + +RECURSE_FOR_TESTS( + ut +)
\ No newline at end of file |