author    | Devtools Arcadia <arcadia-devtools@yandex-team.ru>           | 2022-02-07 18:08:42 +0300
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300
commit    | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
tree      | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/grpc
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/grpc')
26 files changed, 4293 insertions, 0 deletions
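The diff below introduces a low-level asynchronous gRPC client (`TGRpcClientLow`, `TServiceConnection`, channel/stub pooling) plus server-side logger glue. As orientation before the raw diff, here is a minimal usage sketch of the unary-call path based on the API declared in `grpc_client_low.h`; `TMyService`, `TMyRequest`, `TMyResponse`, and `AsyncMyMethod` are hypothetical protobuf-generated names used only for illustration, and the endpoint is an arbitrary example.

```cpp
#include <library/cpp/grpc/client/grpc_client_low.h>

// TMyService, TMyRequest, TMyResponse and AsyncMyMethod stand in for any
// grpc++-generated service; they are not part of this diff.
void RunExample() {
    // Spawns worker threads that pull completion-queue events (see PullEvents).
    NGrpc::TGRpcClientLow client(/* numWorkerThread = */ 2);

    NGrpc::TGRpcClientConfig config("localhost:2135"); // example endpoint
    auto connection = client.CreateGRpcServiceConnection<TMyService>(config);

    // The callback runs on a worker thread once the call finishes.
    NGrpc::TResponseCallback<TMyResponse> callback =
        [](NGrpc::TGrpcStatus&& status, TMyResponse&& response) {
            if (status.Ok()) {
                // consume response
            } else {
                // status.GRpcStatusCode / status.Msg describe the failure
            }
        };

    connection->DoRequest(TMyRequest(), std::move(callback),
                          &TMyService::Stub::AsyncMyMethod);

    client.Stop(/* wait = */ true); // cancel in-flight calls, join workers
}
```

Passing `useCompletionQueuePerThread = true` to the constructor gives each worker thread its own `grpc::CompletionQueue`; contexts created via `CreateContext()` are then assigned a random queue, which the diff uses to spread event-polling load.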
diff --git a/library/cpp/grpc/client/grpc_client_low.cpp b/library/cpp/grpc/client/grpc_client_low.cpp new file mode 100644 index 0000000000..73cc908ef8 --- /dev/null +++ b/library/cpp/grpc/client/grpc_client_low.cpp @@ -0,0 +1,586 @@ +#include "grpc_client_low.h" +#include <contrib/libs/grpc/src/core/lib/iomgr/socket_mutator.h> +#include <contrib/libs/grpc/include/grpc/support/log.h> + +#include <library/cpp/containers/stack_vector/stack_vec.h> + +#include <util/string/printf.h> +#include <util/system/thread.h> +#include <util/random/random.h> + +#if !defined(_WIN32) && !defined(_WIN64) +#include <sys/types.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#endif + +namespace NGrpc { + +void EnableGRpcTracing() { + grpc_tracer_set_enabled("tcp", true); + grpc_tracer_set_enabled("client_channel", true); + grpc_tracer_set_enabled("channel", true); + grpc_tracer_set_enabled("api", true); + grpc_tracer_set_enabled("connectivity_state", true); + grpc_tracer_set_enabled("handshaker", true); + grpc_tracer_set_enabled("http", true); + grpc_tracer_set_enabled("http2_stream_state", true); + grpc_tracer_set_enabled("op_failure", true); + grpc_tracer_set_enabled("timer", true); + gpr_set_log_verbosity(GPR_LOG_SEVERITY_DEBUG); +} + +class TGRpcKeepAliveSocketMutator : public grpc_socket_mutator { +public: + TGRpcKeepAliveSocketMutator(int idle, int count, int interval) + : Idle_(idle) + , Count_(count) + , Interval_(interval) + { + grpc_socket_mutator_init(this, &VTable); + } +private: + static TGRpcKeepAliveSocketMutator* Cast(grpc_socket_mutator* mutator) { + return static_cast<TGRpcKeepAliveSocketMutator*>(mutator); + } + + template<typename TVal> + bool SetOption(int fd, int level, int optname, const TVal& value) { + return setsockopt(fd, level, optname, reinterpret_cast<const char*>(&value), sizeof(value)) == 0; + } + bool SetOption(int fd) { + if (!SetOption(fd, SOL_SOCKET, SO_KEEPALIVE, 1)) { + Cerr << Sprintf("Failed to set SO_KEEPALIVE option: %s", strerror(errno)) << Endl; + return false; + } +#ifdef _linux_ + if (Idle_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPIDLE, Idle_)) { + Cerr << Sprintf("Failed to set TCP_KEEPIDLE option: %s", strerror(errno)) << Endl; + return false; + } + if (Count_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPCNT, Count_)) { + Cerr << Sprintf("Failed to set TCP_KEEPCNT option: %s", strerror(errno)) << Endl; + return false; + } + if (Interval_ && !SetOption(fd, IPPROTO_TCP, TCP_KEEPINTVL, Interval_)) { + Cerr << Sprintf("Failed to set TCP_KEEPINTVL option: %s", strerror(errno)) << Endl; + return false; + } +#endif + return true; + } + static bool Mutate(int fd, grpc_socket_mutator* mutator) { + auto self = Cast(mutator); + return self->SetOption(fd); + } + static int Compare(grpc_socket_mutator* a, grpc_socket_mutator* b) { + const auto* selfA = Cast(a); + const auto* selfB = Cast(b); + auto tupleA = std::make_tuple(selfA->Idle_, selfA->Count_, selfA->Interval_); + auto tupleB = std::make_tuple(selfB->Idle_, selfB->Count_, selfB->Interval_); + return tupleA < tupleB ? -1 : tupleA > tupleB ? 
1 : 0; + } + static void Destroy(grpc_socket_mutator* mutator) { + delete Cast(mutator); + } + + static grpc_socket_mutator_vtable VTable; + const int Idle_; + const int Count_; + const int Interval_; +}; + +grpc_socket_mutator_vtable TGRpcKeepAliveSocketMutator::VTable = + { + &TGRpcKeepAliveSocketMutator::Mutate, + &TGRpcKeepAliveSocketMutator::Compare, + &TGRpcKeepAliveSocketMutator::Destroy + }; + +TChannelPool::TChannelPool(const TTcpKeepAliveSettings& tcpKeepAliveSettings, const TDuration& expireTime) + : TcpKeepAliveSettings_(tcpKeepAliveSettings) + , ExpireTime_(expireTime) + , UpdateReUseTime_(ExpireTime_ * 0.3 < TDuration::Seconds(20) ? ExpireTime_ * 0.3 : TDuration::Seconds(20)) +{} + +void TChannelPool::GetStubsHolderLocked( + const TString& channelId, + const TGRpcClientConfig& config, + std::function<void(TStubsHolder&)> cb) +{ + { + std::shared_lock readGuard(RWMutex_); + const auto it = Pool_.find(channelId); + if (it != Pool_.end()) { + if (!it->second.IsChannelBroken() && !(Now() > it->second.GetLastUseTime() + UpdateReUseTime_)) { + return cb(it->second); + } + } + } + { + std::unique_lock writeGuard(RWMutex_); + { + auto it = Pool_.find(channelId); + if (it != Pool_.end()) { + if (!it->second.IsChannelBroken()) { + EraseFromQueueByTime(it->second.GetLastUseTime(), channelId); + auto now = Now(); + LastUsedQueue_.emplace(now, channelId); + it->second.SetLastUseTime(now); + return cb(it->second); + } else { + // This channel can't be used. Remove from pool to create new one + EraseFromQueueByTime(it->second.GetLastUseTime(), channelId); + Pool_.erase(it); + } + } + } + TGRpcKeepAliveSocketMutator* mutator = nullptr; + // will be destroyed inside grpc + if (TcpKeepAliveSettings_.Enabled) { + mutator = new TGRpcKeepAliveSocketMutator( + TcpKeepAliveSettings_.Idle, + TcpKeepAliveSettings_.Count, + TcpKeepAliveSettings_.Interval + ); + } + cb(Pool_.emplace(channelId, CreateChannelInterface(config, mutator)).first->second); + LastUsedQueue_.emplace(Pool_.at(channelId).GetLastUseTime(), channelId); + } +} + +void TChannelPool::DeleteChannel(const TString& channelId) { + std::unique_lock writeLock(RWMutex_); + auto poolIt = Pool_.find(channelId); + if (poolIt != Pool_.end()) { + EraseFromQueueByTime(poolIt->second.GetLastUseTime(), channelId); + Pool_.erase(poolIt); + } +} + +void TChannelPool::DeleteExpiredStubsHolders() { + std::unique_lock writeLock(RWMutex_); + auto lastExpired = LastUsedQueue_.lower_bound(Now() - ExpireTime_); + for (auto i = LastUsedQueue_.begin(); i != lastExpired; ++i){ + Pool_.erase(i->second); + } + LastUsedQueue_.erase(LastUsedQueue_.begin(), lastExpired); +} + +void TChannelPool::EraseFromQueueByTime(const TInstant& lastUseTime, const TString& channelId) { + auto [begin, end] = LastUsedQueue_.equal_range(lastUseTime); + auto pos = std::find_if(begin, end, [&](auto a){return a.second == channelId;}); + Y_VERIFY(pos != LastUsedQueue_.end(), "data corruption at TChannelPool"); + LastUsedQueue_.erase(pos); +} + +static void PullEvents(grpc::CompletionQueue* cq) { + TThread::SetCurrentThreadName("grpc_client"); + while (true) { + void* tag; + bool ok; + + if (!cq->Next(&tag, &ok)) { + break; + } + + if (auto* ev = static_cast<IQueueClientEvent*>(tag)) { + if (!ev->Execute(ok)) { + ev->Destroy(); + } + } + } +} + +class TGRpcClientLow::TContextImpl final + : public std::enable_shared_from_this<TContextImpl> + , public IQueueClientContext +{ + friend class TGRpcClientLow; + + using TCallback = std::function<void()>; + using TContextPtr = 
std::shared_ptr<TContextImpl>; + +public: + ~TContextImpl() override { + Y_VERIFY(CountChildren() == 0, + "Destructor called with non-empty children"); + + if (Parent) { + Parent->ForgetContext(this); + } else if (Y_LIKELY(Owner)) { + Owner->ForgetContext(this); + } + } + + /** + * Helper for locking child pointer from a parent container + */ + static TContextPtr LockChildPtr(TContextImpl* ptr) { + if (ptr) { + // N.B. it is safe to do as long as it's done under a mutex and + // pointer is among valid children. When that's the case we + // know that TContextImpl destructor has not finished yet, so + // the object is valid. The lock() method may return nullptr + // though, if the object is being destructed right now. + return ptr->weak_from_this().lock(); + } else { + return nullptr; + } + } + + void ForgetContext(TContextImpl* child) { + std::unique_lock<std::mutex> guard(Mutex); + + auto removed = RemoveChild(child); + Y_VERIFY(removed, "Unexpected ForgetContext(%p)", child); + } + + IQueueClientContextPtr CreateContext() override { + auto self = shared_from_this(); + auto child = std::make_shared<TContextImpl>(); + + { + std::unique_lock<std::mutex> guard(Mutex); + + AddChild(child.get()); + + // It's now safe to initialize parent and owner + child->Parent = std::move(self); + child->Owner = Owner; + child->CQ = CQ; + + // Propagate cancellation to a child context + if (Cancelled.load(std::memory_order_relaxed)) { + child->Cancelled.store(true, std::memory_order_relaxed); + } + } + + return child; + } + + grpc::CompletionQueue* CompletionQueue() override { + Y_VERIFY(Owner, "Uninitialized context"); + return CQ; + } + + bool IsCancelled() const override { + return Cancelled.load(std::memory_order_acquire); + } + + bool Cancel() override { + TStackVec<TCallback, 1> callbacks; + TStackVec<TContextPtr, 2> children; + + { + std::unique_lock<std::mutex> guard(Mutex); + + if (Cancelled.load(std::memory_order_relaxed)) { + // Already cancelled in another thread + return false; + } + + callbacks.reserve(Callbacks.size()); + children.reserve(CountChildren()); + + for (auto& callback : Callbacks) { + callbacks.emplace_back().swap(callback); + } + Callbacks.clear(); + + // Collect all children we need to cancel + // N.B. we don't clear children links (cleared by destructors) + // N.B. 
some children may be stuck in destructors at the moment + for (TContextImpl* ptr : InlineChildren) { + if (auto child = LockChildPtr(ptr)) { + children.emplace_back(std::move(child)); + } + } + for (auto* ptr : Children) { + if (auto child = LockChildPtr(ptr)) { + children.emplace_back(std::move(child)); + } + } + + Cancelled.store(true, std::memory_order_release); + } + + // Call directly subscribed callbacks + if (callbacks) { + RunCallbacksNoExcept(callbacks); + } + + // Cancel all children + for (auto& child : children) { + child->Cancel(); + child.reset(); + } + + return true; + } + + void SubscribeCancel(TCallback callback) override { + Y_VERIFY(callback, "SubscribeCancel called with an empty callback"); + + { + std::unique_lock<std::mutex> guard(Mutex); + + if (!Cancelled.load(std::memory_order_relaxed)) { + Callbacks.emplace_back().swap(callback); + return; + } + } + + // Already cancelled, run immediately + callback(); + } + +private: + void AddChild(TContextImpl* child) { + for (TContextImpl*& slot : InlineChildren) { + if (!slot) { + slot = child; + return; + } + } + + Children.insert(child); + } + + bool RemoveChild(TContextImpl* child) { + for (TContextImpl*& slot : InlineChildren) { + if (slot == child) { + slot = nullptr; + return true; + } + } + + return Children.erase(child); + } + + size_t CountChildren() { + size_t count = 0; + + for (TContextImpl* ptr : InlineChildren) { + if (ptr) { + ++count; + } + } + + return count + Children.size(); + } + + template<class TCallbacks> + static void RunCallbacksNoExcept(TCallbacks& callbacks) noexcept { + for (auto& callback : callbacks) { + if (callback) { + callback(); + callback = nullptr; + } + } + } + +private: + // We want a simple lock here, without extra memory allocations + std::mutex Mutex; + + // These fields are initialized on successful registration + TContextPtr Parent; + TGRpcClientLow* Owner = nullptr; + grpc::CompletionQueue* CQ = nullptr; + + // Some children are stored inline, others are in a set + std::array<TContextImpl*, 2> InlineChildren{ { nullptr, nullptr } }; + std::unordered_set<TContextImpl*> Children; + + // Single callback is stored without extra allocations + TStackVec<TCallback, 1> Callbacks; + + // Atomic flag for a faster IsCancelled() implementation + std::atomic<bool> Cancelled; +}; + +TGRpcClientLow::TGRpcClientLow(size_t numWorkerThread, bool useCompletionQueuePerThread) + : UseCompletionQueuePerThread_(useCompletionQueuePerThread) +{ + Init(numWorkerThread); +} + +void TGRpcClientLow::Init(size_t numWorkerThread) { + SetCqState(WORKING); + if (UseCompletionQueuePerThread_) { + for (size_t i = 0; i < numWorkerThread; i++) { + CQS_.push_back(std::make_unique<grpc::CompletionQueue>()); + auto* cq = CQS_.back().get(); + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } + } else { + CQS_.push_back(std::make_unique<grpc::CompletionQueue>()); + auto* cq = CQS_.back().get(); + for (size_t i = 0; i < numWorkerThread; i++) { + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } + } +} + +void TGRpcClientLow::AddWorkerThreadForTest() { + if (UseCompletionQueuePerThread_) { + CQS_.push_back(std::make_unique<grpc::CompletionQueue>()); + auto* cq = CQS_.back().get(); + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); + } else { + auto* cq = CQS_.back().get(); + WorkerThreads_.emplace_back(SystemThreadFactory()->Run([cq]() { + PullEvents(cq); + }).Release()); 
+ } +} + +TGRpcClientLow::~TGRpcClientLow() { + StopInternal(true); + WaitInternal(); +} + +void TGRpcClientLow::Stop(bool wait) { + StopInternal(false); + + if (wait) { + WaitInternal(); + } +} + +void TGRpcClientLow::StopInternal(bool silent) { + bool shutdown; + + TVector<TContextImpl::TContextPtr> cancelQueue; + + { + std::unique_lock<std::mutex> guard(Mtx_); + + auto allowStateChange = [&]() { + switch (GetCqState()) { + case WORKING: + return true; + case STOP_SILENT: + return !silent; + case STOP_EXPLICIT: + return false; + } + + Y_UNREACHABLE(); + }; + + if (!allowStateChange()) { + // Completion queue is already stopping + return; + } + + SetCqState(silent ? STOP_SILENT : STOP_EXPLICIT); + + if (!silent && !Contexts_.empty()) { + cancelQueue.reserve(Contexts_.size()); + for (auto* ptr : Contexts_) { + // N.B. some contexts may be stuck in destructors + if (auto context = TContextImpl::LockChildPtr(ptr)) { + cancelQueue.emplace_back(std::move(context)); + } + } + } + + shutdown = Contexts_.empty(); + } + + for (auto& context : cancelQueue) { + context->Cancel(); + context.reset(); + } + + if (shutdown) { + for (auto& cq : CQS_) { + cq->Shutdown(); + } + } +} + +void TGRpcClientLow::WaitInternal() { + std::unique_lock<std::mutex> guard(JoinMutex_); + + for (auto& ti : WorkerThreads_) { + ti->Join(); + } +} + +void TGRpcClientLow::WaitIdle() { + std::unique_lock<std::mutex> guard(Mtx_); + + while (!Contexts_.empty()) { + ContextsEmpty_.wait(guard); + } +} + +std::shared_ptr<IQueueClientContext> TGRpcClientLow::CreateContext() { + std::unique_lock<std::mutex> guard(Mtx_); + + auto allowCreateContext = [&]() { + switch (GetCqState()) { + case WORKING: + return true; + case STOP_SILENT: + case STOP_EXPLICIT: + return false; + } + + Y_UNREACHABLE(); + }; + + if (!allowCreateContext()) { + // New context creation is forbidden + return nullptr; + } + + auto context = std::make_shared<TContextImpl>(); + Contexts_.insert(context.get()); + context->Owner = this; + if (UseCompletionQueuePerThread_) { + context->CQ = CQS_[RandomNumber(CQS_.size())].get(); + } else { + context->CQ = CQS_[0].get(); + } + return context; +} + +void TGRpcClientLow::ForgetContext(TContextImpl* context) { + bool shutdown = false; + + { + std::unique_lock<std::mutex> guard(Mtx_); + + if (!Contexts_.erase(context)) { + Y_FAIL("Unexpected ForgetContext(%p)", context); + } + + if (Contexts_.empty()) { + if (IsStopping()) { + shutdown = true; + } + + ContextsEmpty_.notify_all(); + } + } + + if (shutdown) { + // This was the last context, shutdown CQ + for (auto& cq : CQS_) { + cq->Shutdown(); + } + } +} + +} // namespace NGRpc diff --git a/library/cpp/grpc/client/grpc_client_low.h b/library/cpp/grpc/client/grpc_client_low.h new file mode 100644 index 0000000000..ab0a0627be --- /dev/null +++ b/library/cpp/grpc/client/grpc_client_low.h @@ -0,0 +1,1399 @@ +#pragma once + +#include "grpc_common.h" + +#include <util/thread/factory.h> +#include <grpc++/grpc++.h> +#include <grpc++/support/async_stream.h> +#include <grpc++/support/async_unary_call.h> + +#include <deque> +#include <typeindex> +#include <typeinfo> +#include <variant> +#include <vector> +#include <unordered_map> +#include <unordered_set> +#include <mutex> +#include <shared_mutex> + +/* + * This file contains low level logic for grpc + * This file should not be used in high level code without special reason + */ +namespace NGrpc { + +const size_t DEFAULT_NUM_THREADS = 2; + +//////////////////////////////////////////////////////////////////////////////// + +void 
EnableGRpcTracing(); + +//////////////////////////////////////////////////////////////////////////////// + +struct TTcpKeepAliveSettings { + bool Enabled; + size_t Idle; + size_t Count; + size_t Interval; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Common interface used to execute action from grpc cq routine +class IQueueClientEvent { +public: + virtual ~IQueueClientEvent() = default; + + //! Execute an action defined by implementation + virtual bool Execute(bool ok) = 0; + + //! Finish and destroy event + virtual void Destroy() = 0; +}; + +// Implementation of IQueueClientEvent that reduces allocations +template<class TSelf> +class TQueueClientFixedEvent : private IQueueClientEvent { + using TCallback = void (TSelf::*)(bool); + +public: + TQueueClientFixedEvent(TSelf* self, TCallback callback) + : Self(self) + , Callback(callback) + { } + + IQueueClientEvent* Prepare() { + Self->Ref(); + return this; + } + +private: + bool Execute(bool ok) override { + ((*Self).*Callback)(ok); + return false; + } + + void Destroy() override { + Self->UnRef(); + } + +private: + TSelf* const Self; + TCallback const Callback; +}; + +class IQueueClientContext; +using IQueueClientContextPtr = std::shared_ptr<IQueueClientContext>; + +// Provider of IQueueClientContext instances +class IQueueClientContextProvider { +public: + virtual ~IQueueClientContextProvider() = default; + + virtual IQueueClientContextPtr CreateContext() = 0; +}; + +// Activity context for a low-level client +class IQueueClientContext : public IQueueClientContextProvider { +public: + virtual ~IQueueClientContext() = default; + + //! Returns CompletionQueue associated with the client + virtual grpc::CompletionQueue* CompletionQueue() = 0; + + //! Returns true if context has been cancelled + virtual bool IsCancelled() const = 0; + + //! Tries to cancel context, calling all registered callbacks + virtual bool Cancel() = 0; + + //! Subscribes callback to cancellation + // + // Note there's no way to unsubscribe, if subscription is temporary + // make sure you create a new context with CreateContext and release + // it as soon as it's no longer needed. + virtual void SubscribeCancel(std::function<void()> callback) = 0; + + //! Subscribes callback to cancellation + // + // This alias is for compatibility with older code. 
+ void SubscribeStop(std::function<void()> callback) { + SubscribeCancel(std::move(callback)); + } +}; + +// Represents grpc status and error message string +struct TGrpcStatus { + TString Msg; + TString Details; + int GRpcStatusCode; + bool InternalError; + + TGrpcStatus() + : GRpcStatusCode(grpc::StatusCode::OK) + , InternalError(false) + { } + + TGrpcStatus(TString msg, int statusCode, bool internalError) + : Msg(std::move(msg)) + , GRpcStatusCode(statusCode) + , InternalError(internalError) + { } + + TGrpcStatus(grpc::StatusCode status, TString msg, TString details = {}) + : Msg(std::move(msg)) + , Details(std::move(details)) + , GRpcStatusCode(status) + , InternalError(false) + { } + + TGrpcStatus(const grpc::Status& status) + : TGrpcStatus(status.error_code(), TString(status.error_message()), TString(status.error_details())) + { } + + TGrpcStatus& operator=(const grpc::Status& status) { + Msg = TString(status.error_message()); + Details = TString(status.error_details()); + GRpcStatusCode = status.error_code(); + InternalError = false; + return *this; + } + + static TGrpcStatus Internal(TString msg) { + return { std::move(msg), -1, true }; + } + + bool Ok() const { + return !InternalError && GRpcStatusCode == grpc::StatusCode::OK; + } +}; + +bool inline IsGRpcStatusGood(const TGrpcStatus& status) { + return status.Ok(); +} + +// Response callback type - this callback will be called when request is finished +// (or after getting each chunk in case of streaming mode) +template<typename TResponse> +using TResponseCallback = std::function<void (TGrpcStatus&&, TResponse&&)>; + +template<typename TResponse> +using TAdvancedResponseCallback = std::function<void (const grpc::ClientContext&, TGrpcStatus&&, TResponse&&)>; + +// Call associated metadata +struct TCallMeta { + std::shared_ptr<grpc::CallCredentials> CallCredentials; + std::vector<std::pair<TString, TString>> Aux; + std::variant<TDuration, TInstant> Timeout; // timeout as duration from now or time point in future +}; + +class TGRpcRequestProcessorCommon { +protected: + void ApplyMeta(const TCallMeta& meta) { + for (const auto& rec : meta.Aux) { + Context.AddMetadata(rec.first, rec.second); + } + if (meta.CallCredentials) { + Context.set_credentials(meta.CallCredentials); + } + if (const TDuration* timeout = std::get_if<TDuration>(&meta.Timeout)) { + if (*timeout) { + auto deadline = gpr_time_add( + gpr_now(GPR_CLOCK_MONOTONIC), + gpr_time_from_micros(timeout->MicroSeconds(), GPR_TIMESPAN)); + Context.set_deadline(deadline); + } + } else if (const TInstant* deadline = std::get_if<TInstant>(&meta.Timeout)) { + if (*deadline) { + Context.set_deadline(gpr_time_from_micros(deadline->MicroSeconds(), GPR_CLOCK_MONOTONIC)); + } + } + } + + void GetInitialMetadata(std::unordered_multimap<TString, TString>* metadata) { + for (const auto& [key, value] : Context.GetServerInitialMetadata()) { + metadata->emplace( + TString(key.begin(), key.end()), + TString(value.begin(), value.end()) + ); + } + } + + grpc::Status Status; + grpc::ClientContext Context; + std::shared_ptr<IQueueClientContext> LocalContext; +}; + +template<typename TStub, typename TRequest, typename TResponse> +class TSimpleRequestProcessor + : public TThrRefBase + , public IQueueClientEvent + , public TGRpcRequestProcessorCommon { + using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>; + template<typename> friend class TServiceConnection; +public: + using TPtr = TIntrusivePtr<TSimpleRequestProcessor>; + using TAsyncRequest = TAsyncReaderPtr 
(TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*); + + explicit TSimpleRequestProcessor(TResponseCallback<TResponse>&& callback) + : Callback_(std::move(callback)) + { } + + ~TSimpleRequestProcessor() { + if (!Replied_ && Callback_) { + Callback_(TGrpcStatus::Internal("request left unhandled"), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + } + } + + bool Execute(bool ok) override { + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext.reset(); + } + TGrpcStatus status; + if (ok) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + Replied_ = true; + Callback_(std::move(status), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + return false; + } + + void Destroy() override { + UnRef(); + } + +private: + IQueueClientEvent* FinishedEvent() { + Ref(); + return this; + } + + void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + Replied_ = true; + Callback_(TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_)); + Callback_ = nullptr; + return; + } + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext = context; + Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue()); + Reader_->Finish(&Reply_, &Status, FinishedEvent()); + } + context->SubscribeStop([self = TPtr(this)] { + self->Stop(); + }); + } + + void Stop() { + Context.TryCancel(); + } + + TResponseCallback<TResponse> Callback_; + TResponse Reply_; + std::mutex Mutex_; + TAsyncReaderPtr Reader_; + + bool Replied_ = false; +}; + +template<typename TStub, typename TRequest, typename TResponse> +class TAdvancedRequestProcessor + : public TThrRefBase + , public IQueueClientEvent + , public TGRpcRequestProcessorCommon { + using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncResponseReader<TResponse>>; + template<typename> friend class TServiceConnection; +public: + using TPtr = TIntrusivePtr<TAdvancedRequestProcessor>; + using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*); + + explicit TAdvancedRequestProcessor(TAdvancedResponseCallback<TResponse>&& callback) + : Callback_(std::move(callback)) + { } + + ~TAdvancedRequestProcessor() { + if (!Replied_ && Callback_) { + Callback_(Context, TGrpcStatus::Internal("request left unhandled"), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + } + } + + bool Execute(bool ok) override { + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext.reset(); + } + TGrpcStatus status; + if (ok) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + Replied_ = true; + Callback_(Context, std::move(status), std::move(Reply_)); + Callback_ = nullptr; // free resources as early as possible + return false; + } + + void Destroy() override { + UnRef(); + } + +private: + IQueueClientEvent* FinishedEvent() { + Ref(); + return this; + } + + void Start(TStub& stub, TAsyncRequest asyncRequest, const TRequest& request, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + Replied_ = true; + Callback_(Context, TGrpcStatus(grpc::StatusCode::CANCELLED, "Client is shutting down"), std::move(Reply_)); + Callback_ = nullptr; + return; + } + { + std::unique_lock<std::mutex> guard(Mutex_); + LocalContext = 
context; + Reader_ = (stub.*asyncRequest)(&Context, request, context->CompletionQueue()); + Reader_->Finish(&Reply_, &Status, FinishedEvent()); + } + context->SubscribeStop([self = TPtr(this)] { + self->Stop(); + }); + } + + void Stop() { + Context.TryCancel(); + } + + TAdvancedResponseCallback<TResponse> Callback_; + TResponse Reply_; + std::mutex Mutex_; + TAsyncReaderPtr Reader_; + + bool Replied_ = false; +}; + +template<class TResponse> +class IStreamRequestReadProcessor : public TThrRefBase { +public: + using TPtr = TIntrusivePtr<IStreamRequestReadProcessor>; + using TReadCallback = std::function<void(TGrpcStatus&&)>; + + /** + * Asynchronously cancel the request + */ + virtual void Cancel() = 0; + + /** + * Scheduled initial server metadata read from the stream + */ + virtual void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) = 0; + + /** + * Scheduled response read from the stream + * Callback will be called with the status if it failed + * Only one Read or Finish call may be active at a time + */ + virtual void Read(TResponse* response, TReadCallback callback) = 0; + + /** + * Stop reading and gracefully finish the stream + * Only one Read or Finish call may be active at a time + */ + virtual void Finish(TReadCallback callback) = 0; + + /** + * Additional callback to be called when stream has finished + */ + virtual void AddFinishedCallback(TReadCallback callback) = 0; +}; + +template<class TRequest, class TResponse> +class IStreamRequestReadWriteProcessor : public IStreamRequestReadProcessor<TResponse> { +public: + using TPtr = TIntrusivePtr<IStreamRequestReadWriteProcessor>; + using TWriteCallback = std::function<void(TGrpcStatus&&)>; + + /** + * Scheduled request write to the stream + */ + virtual void Write(TRequest&& request, TWriteCallback callback = { }) = 0; +}; + +class TGRpcKeepAliveSocketMutator; + +// Class to hold stubs allocated on channel. +// It is poor documented part of grpc. 
See KIKIMR-6109 and comment to this commit + +// Stub holds shared_ptr<ChannelInterface>, so we can destroy this holder even if +// request processor using stub +class TStubsHolder : public TNonCopyable { + using TypeInfoRef = std::reference_wrapper<const std::type_info>; + + struct THasher { + std::size_t operator()(TypeInfoRef code) const { + return code.get().hash_code(); + } + }; + + struct TEqualTo { + bool operator()(TypeInfoRef lhs, TypeInfoRef rhs) const { + return lhs.get() == rhs.get(); + } + }; +public: + TStubsHolder(std::shared_ptr<grpc::ChannelInterface> channel) + : ChannelInterface_(channel) + {} + + // Returns true if channel can't be used to perform request now + bool IsChannelBroken() const { + auto state = ChannelInterface_->GetState(false); + return state == GRPC_CHANNEL_SHUTDOWN || + state == GRPC_CHANNEL_TRANSIENT_FAILURE; + } + + template<typename TStub> + std::shared_ptr<TStub> GetOrCreateStub() { + const auto& stubId = typeid(TStub); + { + std::shared_lock readGuard(RWMutex_); + const auto it = Stubs_.find(stubId); + if (it != Stubs_.end()) { + return std::static_pointer_cast<TStub>(it->second); + } + } + { + std::unique_lock writeGuard(RWMutex_); + auto it = Stubs_.emplace(stubId, nullptr); + if (!it.second) { + return std::static_pointer_cast<TStub>(it.first->second); + } else { + it.first->second = std::make_shared<TStub>(ChannelInterface_); + return std::static_pointer_cast<TStub>(it.first->second); + } + } + } + + const TInstant& GetLastUseTime() const { + return LastUsed_; + } + + void SetLastUseTime(const TInstant& time) { + LastUsed_ = time; + } +private: + TInstant LastUsed_ = Now(); + std::shared_mutex RWMutex_; + std::unordered_map<TypeInfoRef, std::shared_ptr<void>, THasher, TEqualTo> Stubs_; + std::shared_ptr<grpc::ChannelInterface> ChannelInterface_; +}; + +class TChannelPool { +public: + TChannelPool(const TTcpKeepAliveSettings& tcpKeepAliveSettings, const TDuration& expireTime = TDuration::Minutes(6)); + //Allows to CreateStub from TStubsHolder under lock + //The callback will be called just during GetStubsHolderLocked call + void GetStubsHolderLocked(const TString& channelId, const TGRpcClientConfig& config, std::function<void(TStubsHolder&)> cb); + void DeleteChannel(const TString& channelId); + void DeleteExpiredStubsHolders(); +private: + std::shared_mutex RWMutex_; + std::unordered_map<TString, TStubsHolder> Pool_; + std::multimap<TInstant, TString> LastUsedQueue_; + TTcpKeepAliveSettings TcpKeepAliveSettings_; + TDuration ExpireTime_; + TDuration UpdateReUseTime_; + void EraseFromQueueByTime(const TInstant& lastUseTime, const TString& channelId); +}; + +template<class TResponse> +using TStreamReaderCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadProcessor<TResponse>::TPtr)>; + +template<typename TStub, typename TRequest, typename TResponse> +class TStreamRequestReadProcessor + : public IStreamRequestReadProcessor<TResponse> + , public TGRpcRequestProcessorCommon { + template<typename> friend class TServiceConnection; +public: + using TSelf = TStreamRequestReadProcessor; + using TAsyncReaderPtr = std::unique_ptr<grpc::ClientAsyncReader<TResponse>>; + using TAsyncRequest = TAsyncReaderPtr (TStub::*)(grpc::ClientContext*, const TRequest&, grpc::CompletionQueue*, void*); + using TReaderCallback = TStreamReaderCallback<TResponse>; + using TPtr = TIntrusivePtr<TSelf>; + using TBase = IStreamRequestReadProcessor<TResponse>; + using TReadCallback = typename TBase::TReadCallback; + + explicit 
TStreamRequestReadProcessor(TReaderCallback&& callback) + : Callback(std::move(callback)) + { + Y_VERIFY(Callback, "Missing connected callback"); + } + + void Cancel() override { + Context.TryCancel(); + + { + std::unique_lock<std::mutex> guard(Mutex); + Cancelled = true; + if (Started && !ReadFinished) { + if (!ReadActive) { + ReadFinished = true; + } + if (ReadFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + } + } + + void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished && !HasInitialMetadata) { + ReadActive = true; + ReadCallback = std::move(callback); + InitialMetadata = metadata; + if (!ReadFinished) { + Stream->ReadInitialMetadata(OnReadDoneTag.Prepare()); + } + return; + } + if (!HasInitialMetadata) { + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } else { + GetInitialMetadata(metadata); + } + } + + callback(std::move(status)); + } + + void Read(TResponse* message, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + ReadCallback = std::move(callback); + if (!ReadFinished) { + Stream->Read(message, OnReadDoneTag.Prepare()); + } + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + + callback(std::move(status)); + } + + void Finish(TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + FinishCallback = std::move(callback); + if (!ReadFinished) { + ReadFinished = true; + } + Stream->Finish(&Status, OnFinishedTag.Prepare()); + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + + void AddFinishedCallback(TReadCallback callback) override { + Y_VERIFY(callback, "Unexpected empty callback"); + + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + if (!Finished) { + FinishedCallbacks.emplace_back().swap(callback); + return; + } + + if (FinishedOk) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + +private: + void Start(TStub& stub, const TRequest& request, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + auto callback = std::move(Callback); + TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down"); + callback(std::move(status), nullptr); + return; + } + + { + std::unique_lock<std::mutex> guard(Mutex); + LocalContext = context; + Stream = (stub.*asyncRequest)(&Context, request, context->CompletionQueue(), OnStartDoneTag.Prepare()); + } + + context->SubscribeStop([self = TPtr(this)] { + self->Cancel(); + }); + } + + void OnReadDone(bool ok) { + TGrpcStatus status; + TReadCallback callback; + 
std::unordered_multimap<TString, TString>* initialMetadata = nullptr; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(ReadActive, "Unexpected Read done callback"); + Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag"); + + if (!ok || Cancelled) { + ReadFinished = true; + + Stream->Finish(&Status, OnFinishedTag.Prepare()); + if (!ok) { + // Keep ReadActive=true, so callback is called + // after the call is finished with an error + return; + } + } + + callback = std::move(ReadCallback); + ReadCallback = nullptr; + ReadActive = false; + initialMetadata = InitialMetadata; + InitialMetadata = nullptr; + HasInitialMetadata = true; + } + + if (initialMetadata) { + GetInitialMetadata(initialMetadata); + } + + callback(std::move(status)); + } + + void OnStartDone(bool ok) { + TReaderCallback callback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Started = true; + if (!ok || Cancelled) { + ReadFinished = true; + Stream->Finish(&Status, OnFinishedTag.Prepare()); + return; + } + callback = std::move(Callback); + Callback = nullptr; + } + + callback({ }, typename TBase::TPtr(this)); + } + + void OnFinished(bool ok) { + TGrpcStatus status; + std::vector<TReadCallback> finishedCallbacks; + TReaderCallback startCallback; + TReadCallback readCallback; + TReadCallback finishCallback; + + { + std::unique_lock<std::mutex> guard(Mutex); + + Finished = true; + FinishedOk = ok; + LocalContext.reset(); + + if (ok) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + + finishedCallbacks.swap(FinishedCallbacks); + + if (Callback) { + Y_VERIFY(!ReadActive); + startCallback = std::move(Callback); + Callback = nullptr; + } else if (ReadActive) { + if (ReadCallback) { + readCallback = std::move(ReadCallback); + ReadCallback = nullptr; + } else { + finishCallback = std::move(FinishCallback); + FinishCallback = nullptr; + } + ReadActive = false; + } + } + + for (auto& finishedCallback : finishedCallbacks) { + auto statusCopy = status; + finishedCallback(std::move(statusCopy)); + } + + if (startCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure"); + } + startCallback(std::move(status), nullptr); + } else if (readCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + readCallback(std::move(status)); + } else if (finishCallback) { + finishCallback(std::move(status)); + } + } + + TReaderCallback Callback; + TAsyncReaderPtr Stream; + using TFixedEvent = TQueueClientFixedEvent<TSelf>; + std::mutex Mutex; + TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone }; + TFixedEvent OnStartDoneTag = { this, &TSelf::OnStartDone }; + TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished }; + + TReadCallback ReadCallback; + TReadCallback FinishCallback; + std::vector<TReadCallback> FinishedCallbacks; + std::unordered_multimap<TString, TString>* InitialMetadata = nullptr; + bool Started = false; + bool HasInitialMetadata = false; + bool ReadActive = false; + bool ReadFinished = false; + bool Finished = false; + bool Cancelled = false; + bool FinishedOk = false; +}; + +template<class TRequest, class TResponse> +using TStreamConnectedCallback = std::function<void(TGrpcStatus&&, typename IStreamRequestReadWriteProcessor<TRequest, TResponse>::TPtr)>; + +template<class TStub, class TRequest, class TResponse> +class TStreamRequestReadWriteProcessor + : public 
IStreamRequestReadWriteProcessor<TRequest, TResponse> + , public TGRpcRequestProcessorCommon { +public: + using TSelf = TStreamRequestReadWriteProcessor; + using TBase = IStreamRequestReadWriteProcessor<TRequest, TResponse>; + using TPtr = TIntrusivePtr<TSelf>; + using TConnectedCallback = TStreamConnectedCallback<TRequest, TResponse>; + using TReadCallback = typename TBase::TReadCallback; + using TWriteCallback = typename TBase::TWriteCallback; + using TAsyncReaderWriterPtr = std::unique_ptr<grpc::ClientAsyncReaderWriter<TRequest, TResponse>>; + using TAsyncRequest = TAsyncReaderWriterPtr (TStub::*)(grpc::ClientContext*, grpc::CompletionQueue*, void*); + + explicit TStreamRequestReadWriteProcessor(TConnectedCallback&& callback) + : ConnectedCallback(std::move(callback)) + { + Y_VERIFY(ConnectedCallback, "Missing connected callback"); + } + + void Cancel() override { + Context.TryCancel(); + + { + std::unique_lock<std::mutex> guard(Mutex); + Cancelled = true; + if (Started && !(ReadFinished && WriteFinished)) { + if (!ReadActive) { + ReadFinished = true; + } + if (!WriteActive) { + WriteFinished = true; + } + if (ReadFinished && WriteFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + } + } + + void Write(TRequest&& request, TWriteCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + if (Cancelled || ReadFinished || WriteFinished) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped"); + } else if (WriteActive) { + auto& item = WriteQueue.emplace_back(); + item.Callback.swap(callback); + item.Request.Swap(&request); + } else { + WriteActive = true; + WriteCallback.swap(callback); + Stream->Write(request, OnWriteDoneTag.Prepare()); + } + } + + if (!status.Ok() && callback) { + callback(std::move(status)); + } + } + + void ReadInitialMetadata(std::unordered_multimap<TString, TString>* metadata, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished && !HasInitialMetadata) { + ReadActive = true; + ReadCallback = std::move(callback); + InitialMetadata = metadata; + if (!ReadFinished) { + Stream->ReadInitialMetadata(OnReadDoneTag.Prepare()); + } + return; + } + if (!HasInitialMetadata) { + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } else { + GetInitialMetadata(metadata); + } + } + + callback(std::move(status)); + } + + void Read(TResponse* message, TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + ReadCallback = std::move(callback); + if (!ReadFinished) { + Stream->Read(message, OnReadDoneTag.Prepare()); + } + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + + callback(std::move(status)); + } + + void Finish(TReadCallback callback) override { + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(!ReadActive, "Multiple Read/Finish calls detected"); + if (!Finished) { + ReadActive = true; + FinishCallback = std::move(callback); + if (!ReadFinished) { + ReadFinished = true; + if (!WriteActive) { + WriteFinished = true; + } + if (WriteFinished) { + 
Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + return; + } + if (FinishedOk) { + status = Status; + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + + void AddFinishedCallback(TReadCallback callback) override { + Y_VERIFY(callback, "Unexpected empty callback"); + + TGrpcStatus status; + + { + std::unique_lock<std::mutex> guard(Mutex); + if (!Finished) { + FinishedCallbacks.emplace_back().swap(callback); + return; + } + + if (FinishedOk) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + } + + callback(std::move(status)); + } + +private: + template<typename> friend class TServiceConnection; + + void Start(TStub& stub, TAsyncRequest asyncRequest, IQueueClientContextProvider* provider) { + auto context = provider->CreateContext(); + if (!context) { + auto callback = std::move(ConnectedCallback); + TGrpcStatus status(grpc::StatusCode::CANCELLED, "Client is shutting down"); + callback(std::move(status), nullptr); + return; + } + + { + std::unique_lock<std::mutex> guard(Mutex); + LocalContext = context; + Stream = (stub.*asyncRequest)(&Context, context->CompletionQueue(), OnConnectedTag.Prepare()); + } + + context->SubscribeStop([self = TPtr(this)] { + self->Cancel(); + }); + } + +private: + void OnConnected(bool ok) { + TConnectedCallback callback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Started = true; + if (!ok || Cancelled) { + ReadFinished = true; + WriteFinished = true; + Stream->Finish(&Status, OnFinishedTag.Prepare()); + return; + } + + callback = std::move(ConnectedCallback); + ConnectedCallback = nullptr; + } + + callback({ }, typename TBase::TPtr(this)); + } + + void OnReadDone(bool ok) { + TGrpcStatus status; + TReadCallback callback; + std::unordered_multimap<TString, TString>* initialMetadata = nullptr; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(ReadActive, "Unexpected Read done callback"); + Y_VERIFY(!ReadFinished, "Unexpected ReadFinished flag"); + + if (!ok || Cancelled || WriteFinished) { + ReadFinished = true; + if (!WriteActive) { + WriteFinished = true; + } + if (WriteFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + if (!ok) { + // Keep ReadActive=true, so callback is called + // after the call is finished with an error + return; + } + } + + callback = std::move(ReadCallback); + ReadCallback = nullptr; + ReadActive = false; + initialMetadata = InitialMetadata; + InitialMetadata = nullptr; + HasInitialMetadata = true; + } + + if (initialMetadata) { + GetInitialMetadata(initialMetadata); + } + + callback(std::move(status)); + } + + void OnWriteDone(bool ok) { + TWriteCallback okCallback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Y_VERIFY(WriteActive, "Unexpected Write done callback"); + Y_VERIFY(!WriteFinished, "Unexpected WriteFinished flag"); + + if (ok) { + okCallback.swap(WriteCallback); + } else if (WriteCallback) { + // Put callback back on the queue until OnFinished + auto& item = WriteQueue.emplace_front(); + item.Callback.swap(WriteCallback); + } + + if (!ok || Cancelled) { + WriteActive = false; + WriteFinished = true; + if (!ReadActive) { + ReadFinished = true; + } + if (ReadFinished) { + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } else if (!WriteQueue.empty()) { + WriteCallback.swap(WriteQueue.front().Callback); + Stream->Write(WriteQueue.front().Request, 
OnWriteDoneTag.Prepare()); + WriteQueue.pop_front(); + } else { + WriteActive = false; + if (ReadFinished) { + WriteFinished = true; + Stream->Finish(&Status, OnFinishedTag.Prepare()); + } + } + } + + if (okCallback) { + okCallback(TGrpcStatus()); + } + } + + void OnFinished(bool ok) { + TGrpcStatus status; + std::deque<TWriteItem> writesDropped; + std::vector<TReadCallback> finishedCallbacks; + TConnectedCallback connectedCallback; + TReadCallback readCallback; + TReadCallback finishCallback; + + { + std::unique_lock<std::mutex> guard(Mutex); + Finished = true; + FinishedOk = ok; + LocalContext.reset(); + + if (ok) { + status = Status; + } else if (Cancelled) { + status = TGrpcStatus(grpc::StatusCode::CANCELLED, "Stream cancelled"); + } else { + status = TGrpcStatus::Internal("Unexpected error"); + } + + writesDropped.swap(WriteQueue); + finishedCallbacks.swap(FinishedCallbacks); + + if (ConnectedCallback) { + Y_VERIFY(!ReadActive); + connectedCallback = std::move(ConnectedCallback); + ConnectedCallback = nullptr; + } else if (ReadActive) { + if (ReadCallback) { + readCallback = std::move(ReadCallback); + ReadCallback = nullptr; + } else { + finishCallback = std::move(FinishCallback); + FinishCallback = nullptr; + } + ReadActive = false; + } + } + + for (auto& item : writesDropped) { + if (item.Callback) { + TGrpcStatus writeStatus = status; + if (writeStatus.Ok()) { + writeStatus = TGrpcStatus(grpc::StatusCode::CANCELLED, "Write request dropped"); + } + item.Callback(std::move(writeStatus)); + } + } + + for (auto& finishedCallback : finishedCallbacks) { + TGrpcStatus statusCopy = status; + finishedCallback(std::move(statusCopy)); + } + + if (connectedCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::UNKNOWN, "Unknown stream failure"); + } + connectedCallback(std::move(status), nullptr); + } else if (readCallback) { + if (status.Ok()) { + status = TGrpcStatus(grpc::StatusCode::OUT_OF_RANGE, "Read EOF"); + } + readCallback(std::move(status)); + } else if (finishCallback) { + finishCallback(std::move(status)); + } + } + +private: + struct TWriteItem { + TWriteCallback Callback; + TRequest Request; + }; + +private: + using TFixedEvent = TQueueClientFixedEvent<TSelf>; + + TFixedEvent OnConnectedTag = { this, &TSelf::OnConnected }; + TFixedEvent OnReadDoneTag = { this, &TSelf::OnReadDone }; + TFixedEvent OnWriteDoneTag = { this, &TSelf::OnWriteDone }; + TFixedEvent OnFinishedTag = { this, &TSelf::OnFinished }; + +private: + std::mutex Mutex; + TAsyncReaderWriterPtr Stream; + TConnectedCallback ConnectedCallback; + TReadCallback ReadCallback; + TReadCallback FinishCallback; + std::vector<TReadCallback> FinishedCallbacks; + std::deque<TWriteItem> WriteQueue; + TWriteCallback WriteCallback; + std::unordered_multimap<TString, TString>* InitialMetadata = nullptr; + bool Started = false; + bool HasInitialMetadata = false; + bool ReadActive = false; + bool ReadFinished = false; + bool WriteActive = false; + bool WriteFinished = false; + bool Finished = false; + bool Cancelled = false; + bool FinishedOk = false; +}; + +class TGRpcClientLow; + +template<typename TGRpcService> +class TServiceConnection { + using TStub = typename TGRpcService::Stub; + friend class TGRpcClientLow; + +public: + /* + * Start simple request + */ + template<typename TRequest, typename TResponse> + void DoRequest(const TRequest& request, + TResponseCallback<TResponse> callback, + typename TSimpleRequestProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + 
IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TSimpleRequestProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_); + } + + /* + * Start simple request + */ + template<typename TRequest, typename TResponse> + void DoAdvancedRequest(const TRequest& request, + TAdvancedResponseCallback<TResponse> callback, + typename TAdvancedRequestProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TAdvancedRequestProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, asyncRequest, request, provider ? provider : Provider_); + } + + /* + * Start bidirectional streamming + */ + template<typename TRequest, typename TResponse> + void DoStreamRequest(TStreamConnectedCallback<TRequest, TResponse> callback, + typename TStreamRequestReadWriteProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TStreamRequestReadWriteProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, std::move(asyncRequest), provider ? provider : Provider_); + } + + /* + * Start streaming response reading (one request, many responses) + */ + template<typename TRequest, typename TResponse> + void DoStreamRequest(const TRequest& request, + TStreamReaderCallback<TResponse> callback, + typename TStreamRequestReadProcessor<TStub, TRequest, TResponse>::TAsyncRequest asyncRequest, + const TCallMeta& metas = { }, + IQueueClientContextProvider* provider = nullptr) + { + auto processor = MakeIntrusive<TStreamRequestReadProcessor<TStub, TRequest, TResponse>>(std::move(callback)); + processor->ApplyMeta(metas); + processor->Start(*Stub_, request, std::move(asyncRequest), provider ? 
provider : Provider_); + } + +private: + TServiceConnection(std::shared_ptr<grpc::ChannelInterface> ci, + IQueueClientContextProvider* provider) + : Stub_(TGRpcService::NewStub(ci)) + , Provider_(provider) + { + Y_VERIFY(Provider_, "Connection does not have a queue provider"); + } + + TServiceConnection(TStubsHolder& holder, + IQueueClientContextProvider* provider) + : Stub_(holder.GetOrCreateStub<TStub>()) + , Provider_(provider) + { + Y_VERIFY(Provider_, "Connection does not have a queue provider"); + } + + std::shared_ptr<TStub> Stub_; + IQueueClientContextProvider* Provider_; +}; + +class TGRpcClientLow + : public IQueueClientContextProvider +{ + class TContextImpl; + friend class TContextImpl; + + enum ECqState : TAtomicBase { + WORKING = 0, + STOP_SILENT = 1, + STOP_EXPLICIT = 2, + }; + +public: + explicit TGRpcClientLow(size_t numWorkerThread = DEFAULT_NUM_THREADS, bool useCompletionQueuePerThread = false); + ~TGRpcClientLow(); + + // Tries to stop all currently running requests (via their stop callbacks) + // Will shutdown CQ and drain events once all requests have finished + // No new requests may be started after this call + void Stop(bool wait = false); + + // Waits until all currently running requests finish execution + void WaitIdle(); + + inline bool IsStopping() const { + switch (GetCqState()) { + case WORKING: + return false; + case STOP_SILENT: + case STOP_EXPLICIT: + return true; + } + + Y_UNREACHABLE(); + } + + IQueueClientContextPtr CreateContext() override; + + template<typename TGRpcService> + std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(const TGRpcClientConfig& config) { + return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(CreateChannelInterface(config), this)); + } + + template<typename TGRpcService> + std::unique_ptr<TServiceConnection<TGRpcService>> CreateGRpcServiceConnection(TStubsHolder& holder) { + return std::unique_ptr<TServiceConnection<TGRpcService>>(new TServiceConnection<TGRpcService>(holder, this)); + } + + // Tests only, not thread-safe + void AddWorkerThreadForTest(); + +private: + using IThreadRef = std::unique_ptr<IThreadFactory::IThread>; + using CompletionQueueRef = std::unique_ptr<grpc::CompletionQueue>; + void Init(size_t numWorkerThread); + + inline ECqState GetCqState() const { return (ECqState) AtomicGet(CqState_); } + inline void SetCqState(ECqState state) { AtomicSet(CqState_, state); } + + void StopInternal(bool silent); + void WaitInternal(); + + void ForgetContext(TContextImpl* context); + +private: + bool UseCompletionQueuePerThread_; + std::vector<CompletionQueueRef> CQS_; + std::vector<IThreadRef> WorkerThreads_; + TAtomic CqState_ = -1; + + std::mutex Mtx_; + std::condition_variable ContextsEmpty_; + std::unordered_set<TContextImpl*> Contexts_; + + std::mutex JoinMutex_; +}; + +} // namespace NGRpc diff --git a/library/cpp/grpc/client/grpc_common.h b/library/cpp/grpc/client/grpc_common.h new file mode 100644 index 0000000000..ffcdafe045 --- /dev/null +++ b/library/cpp/grpc/client/grpc_common.h @@ -0,0 +1,84 @@ +#pragma once + +#include <grpc++/grpc++.h> +#include <grpc++/resource_quota.h> + +#include <util/datetime/base.h> +#include <unordered_map> +#include <util/generic/string.h> + +constexpr ui64 DEFAULT_GRPC_MESSAGE_SIZE_LIMIT = 64000000; + +namespace NGrpc { + +struct TGRpcClientConfig { + TString Locator; // format host:port + TDuration Timeout = TDuration::Max(); // request timeout + ui64 MaxMessageSize = DEFAULT_GRPC_MESSAGE_SIZE_LIMIT; // Max request 
and response size + ui64 MaxInboundMessageSize = 0; // overrides MaxMessageSize for incoming requests + ui64 MaxOutboundMessageSize = 0; // overrides MaxMessageSize for outgoing requests + ui32 MaxInFlight = 0; + bool EnableSsl = false; + TString SslCaCert; //Implicitly enables Ssl if not empty + grpc_compression_algorithm CompressionAlgoritm = GRPC_COMPRESS_NONE; + ui64 MemQuota = 0; + std::unordered_map<TString, TString> StringChannelParams; + std::unordered_map<TString, int> IntChannelParams; + TString LoadBalancingPolicy = { }; + TString SslTargetNameOverride = { }; + + TGRpcClientConfig() = default; + TGRpcClientConfig(const TGRpcClientConfig&) = default; + TGRpcClientConfig(TGRpcClientConfig&&) = default; + TGRpcClientConfig& operator=(const TGRpcClientConfig&) = default; + TGRpcClientConfig& operator=(TGRpcClientConfig&&) = default; + + TGRpcClientConfig(const TString& locator, TDuration timeout = TDuration::Max(), + ui64 maxMessageSize = DEFAULT_GRPC_MESSAGE_SIZE_LIMIT, ui32 maxInFlight = 0, TString caCert = "", + grpc_compression_algorithm compressionAlgorithm = GRPC_COMPRESS_NONE, bool enableSsl = false) + : Locator(locator) + , Timeout(timeout) + , MaxMessageSize(maxMessageSize) + , MaxInFlight(maxInFlight) + , EnableSsl(enableSsl) + , SslCaCert(caCert) + , CompressionAlgoritm(compressionAlgorithm) + {} +}; + +inline std::shared_ptr<grpc::ChannelInterface> CreateChannelInterface(const TGRpcClientConfig& config, grpc_socket_mutator* mutator = nullptr){ + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(config.MaxInboundMessageSize ? config.MaxInboundMessageSize : config.MaxMessageSize); + args.SetMaxSendMessageSize(config.MaxOutboundMessageSize ? config.MaxOutboundMessageSize : config.MaxMessageSize); + args.SetCompressionAlgorithm(config.CompressionAlgoritm); + + for (const auto& kvp: config.StringChannelParams) { + args.SetString(kvp.first, kvp.second); + } + + for (const auto& kvp: config.IntChannelParams) { + args.SetInt(kvp.first, kvp.second); + } + + if (config.MemQuota) { + grpc::ResourceQuota quota; + quota.Resize(config.MemQuota); + args.SetResourceQuota(quota); + } + if (mutator) { + args.SetSocketMutator(mutator); + } + if (!config.LoadBalancingPolicy.empty()) { + args.SetLoadBalancingPolicyName(config.LoadBalancingPolicy); + } + if (!config.SslTargetNameOverride.empty()) { + args.SetSslTargetNameOverride(config.SslTargetNameOverride); + } + if (config.EnableSsl || config.SslCaCert) { + return grpc::CreateCustomChannel(config.Locator, grpc::SslCredentials(grpc::SslCredentialsOptions{config.SslCaCert, "", ""}), args); + } else { + return grpc::CreateCustomChannel(config.Locator, grpc::InsecureChannelCredentials(), args); + } +} + +} // namespace NGRpc diff --git a/library/cpp/grpc/client/ut/grpc_client_low_ut.cpp b/library/cpp/grpc/client/ut/grpc_client_low_ut.cpp new file mode 100644 index 0000000000..b8af2a518f --- /dev/null +++ b/library/cpp/grpc/client/ut/grpc_client_low_ut.cpp @@ -0,0 +1,61 @@ +#include <library/cpp/grpc/client/grpc_client_low.h> + +#include <library/cpp/testing/unittest/registar.h> + +using namespace NGrpc; + +class TTestStub { +public: + std::shared_ptr<grpc::ChannelInterface> ChannelInterface; + TTestStub(std::shared_ptr<grpc::ChannelInterface> channelInterface) + : ChannelInterface(channelInterface) + {} +}; + +Y_UNIT_TEST_SUITE(ChannelPoolTests) { + Y_UNIT_TEST(UnusedStubsHoldersDeletion) { + TGRpcClientConfig clientConfig("invalid_host:invalid_port"); + TTcpKeepAliveSettings tcpKeepAliveSettings = + { + true, + 30, // 
NYdb::TCP_KEEPALIVE_IDLE, unused in UT, but required by the constructor + 5, // NYdb::TCP_KEEPALIVE_COUNT, unused in UT, but required by the constructor + 10 // NYdb::TCP_KEEPALIVE_INTERVAL, unused in UT, but required by the constructor + }; + auto channelPool = TChannelPool(tcpKeepAliveSettings, TDuration::MilliSeconds(250)); + std::vector<std::weak_ptr<grpc::ChannelInterface>> ChannelInterfacesWeak; + + { + std::vector<std::shared_ptr<TTestStub>> stubsHoldersShared; + auto storeStubsHolders = [&](TStubsHolder& stubsHolder) { + stubsHoldersShared.emplace_back(stubsHolder.GetOrCreateStub<TTestStub>()); + ChannelInterfacesWeak.emplace_back((*stubsHoldersShared.rbegin())->ChannelInterface); + return; + }; + for (int i = 0; i < 10; ++i) { + channelPool.GetStubsHolderLocked( + ToString(i), + clientConfig, + storeStubsHolders + ); + } + } + + auto now = Now(); + while (Now() < now + TDuration::MilliSeconds(500)) { + Sleep(TDuration::MilliSeconds(100)); + } + + channelPool.DeleteExpiredStubsHolders(); + + bool allDeleted = true; + for (auto i = ChannelInterfacesWeak.begin(); i != ChannelInterfacesWeak.end(); ++i) { + allDeleted = allDeleted && i->expired(); + } + + // the assertion checks channel interfaces instead of stubs because, after the stubs are deleted, + // TStubsHolder holds the only shared_ptr to each channel interface. + UNIT_ASSERT_C(allDeleted, "expired stubsHolders were not deleted after timeout"); + + } +} // ChannelPoolTests ut suite
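As a side note, the low-level client above is typically wired up along the following lines (a sketch only: TGreeterService stands in for a proto-generated service type, the locator is illustrative, and issuing the actual RPC through the connection is omitted):

    NGrpc::TGRpcClientLow clientLow(2);                // two worker threads
    NGrpc::TGRpcClientConfig config("localhost:2135"); // locator in host:port form
    auto connection = clientLow.CreateGRpcServiceConnection<TGreeterService>(config);
    // ... issue requests through *connection ...
    clientLow.Stop(true);                              // wait until in-flight requests finish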
\ No newline at end of file diff --git a/library/cpp/grpc/client/ut/ya.make b/library/cpp/grpc/client/ut/ya.make new file mode 100644 index 0000000000..eac779a99e --- /dev/null +++ b/library/cpp/grpc/client/ut/ya.make @@ -0,0 +1,11 @@ +UNITTEST_FOR(library/cpp/grpc/client) + +OWNER( + g:kikimr +) + +SRCS( + grpc_client_low_ut.cpp +) + +END() diff --git a/library/cpp/grpc/client/ya.make b/library/cpp/grpc/client/ya.make new file mode 100644 index 0000000000..a4e74b067c --- /dev/null +++ b/library/cpp/grpc/client/ya.make @@ -0,0 +1,20 @@ +LIBRARY() + +OWNER( + ddoarn + g:kikimr +) + +SRCS( + grpc_client_low.cpp +) + +PEERDIR( + contrib/libs/grpc +) + +END() + +RECURSE_FOR_TESTS( + ut +)
\ No newline at end of file diff --git a/library/cpp/grpc/server/actors/logger.cpp b/library/cpp/grpc/server/actors/logger.cpp new file mode 100644 index 0000000000..d8b2042576 --- /dev/null +++ b/library/cpp/grpc/server/actors/logger.cpp @@ -0,0 +1,45 @@ +#include "logger.h" + +namespace NGrpc { +namespace { + +static_assert( + ui16(TLOG_EMERG) == ui16(NActors::NLog::PRI_EMERG) && + ui16(TLOG_DEBUG) == ui16(NActors::NLog::PRI_DEBUG), + "log levels in the library/log and library/cpp/actors don't match"); + +class TActorSystemLogger final: public TLogger { +public: + TActorSystemLogger(NActors::TActorSystem& as, NActors::NLog::EComponent component) noexcept + : ActorSystem_{as} + , Component_{component} + { + } + +protected: + bool DoIsEnabled(ELogPriority p) const noexcept override { + const auto* settings = static_cast<::NActors::NLog::TSettings*>(ActorSystem_.LoggerSettings()); + const auto priority = static_cast<::NActors::NLog::EPriority>(p); + + return settings && settings->Satisfies(priority, Component_, 0); + } + + void DoWrite(ELogPriority p, const char* format, va_list args) noexcept override { + Y_VERIFY_DEBUG(DoIsEnabled(p)); + + const auto priority = static_cast<::NActors::NLog::EPriority>(p); + ::NActors::MemLogAdapter(ActorSystem_, priority, Component_, format, args); + } + +private: + NActors::TActorSystem& ActorSystem_; + NActors::NLog::EComponent Component_; +}; + +} // namespace + +TLoggerPtr CreateActorSystemLogger(NActors::TActorSystem& as, NActors::NLog::EComponent component) { + return MakeIntrusive<TActorSystemLogger>(as, component); +} + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/actors/logger.h b/library/cpp/grpc/server/actors/logger.h new file mode 100644 index 0000000000..abf9270f7b --- /dev/null +++ b/library/cpp/grpc/server/actors/logger.h @@ -0,0 +1,11 @@ +#pragma once + +#include <library/cpp/actors/core/actorsystem.h> +#include <library/cpp/actors/core/log.h> +#include <library/cpp/grpc/server/logger.h> + +namespace NGrpc { + +TLoggerPtr CreateActorSystemLogger(NActors::TActorSystem& as, NActors::NLog::EComponent component); + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/actors/ya.make b/library/cpp/grpc/server/actors/ya.make new file mode 100644 index 0000000000..6c9d80aa45 --- /dev/null +++ b/library/cpp/grpc/server/actors/ya.make @@ -0,0 +1,13 @@ +LIBRARY() + +OWNER(g:kikimr g:solomon) + +SRCS( + logger.cpp +) + +PEERDIR( + library/cpp/actors/core +) + +END() diff --git a/library/cpp/grpc/server/event_callback.cpp b/library/cpp/grpc/server/event_callback.cpp new file mode 100644 index 0000000000..f423836bd6 --- /dev/null +++ b/library/cpp/grpc/server/event_callback.cpp @@ -0,0 +1 @@ +#include "event_callback.h" diff --git a/library/cpp/grpc/server/event_callback.h b/library/cpp/grpc/server/event_callback.h new file mode 100644 index 0000000000..d0b700b3c9 --- /dev/null +++ b/library/cpp/grpc/server/event_callback.h @@ -0,0 +1,80 @@ +#pragma once + +#include "grpc_server.h" + +namespace NGrpc { + +enum class EQueueEventStatus { + OK, + ERROR +}; + +template<class TCallback> +class TQueueEventCallback: public IQueueEvent { +public: + TQueueEventCallback(const TCallback& callback) + : Callback(callback) + {} + + TQueueEventCallback(TCallback&& callback) + : Callback(std::move(callback)) + {} + + bool Execute(bool ok) override { + Callback(ok ? 
EQueueEventStatus::OK : EQueueEventStatus::ERROR); + return false; + } + + void DestroyRequest() override { + delete this; + } + +private: + TCallback Callback; +}; + +// Implementation of IQueueEvent that reduces allocations +template<class TSelf> +class TQueueFixedEvent: private IQueueEvent { + using TCallback = void (TSelf::*)(EQueueEventStatus); + +public: + TQueueFixedEvent(TSelf* self, TCallback callback) + : Self(self) + , Callback(callback) + { } + + IQueueEvent* Prepare() { + Self->Ref(); + return this; + } + +private: + bool Execute(bool ok) override { + ((*Self).*Callback)(ok ? EQueueEventStatus::OK : EQueueEventStatus::ERROR); + return false; + } + + void DestroyRequest() override { + Self->UnRef(); + } + +private: + TSelf* const Self; + TCallback const Callback; +}; + +template<class TCallback> +inline IQueueEvent* MakeQueueEventCallback(TCallback&& callback) { + return new TQueueEventCallback<TCallback>(std::forward<TCallback>(callback)); +} + +template<class T> +inline IQueueEvent* MakeQueueEventCallback(T* self, void (T::*method)(EQueueEventStatus)) { + using TPtr = TIntrusivePtr<T>; + return MakeQueueEventCallback([self = TPtr(self), method] (EQueueEventStatus status) { + ((*self).*method)(status); + }); +} + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_async_ctx_base.h b/library/cpp/grpc/server/grpc_async_ctx_base.h new file mode 100644 index 0000000000..51356d4ce5 --- /dev/null +++ b/library/cpp/grpc/server/grpc_async_ctx_base.h @@ -0,0 +1,94 @@ +#pragma once + +#include "grpc_server.h" + +#include <util/generic/vector.h> +#include <util/generic/string.h> +#include <util/system/yassert.h> +#include <util/generic/set.h> + +#include <grpc++/server.h> +#include <grpc++/server_context.h> + +#include <chrono> + +namespace NGrpc { + +template<typename TService> +class TBaseAsyncContext: public ICancelableContext { +public: + TBaseAsyncContext(typename TService::TCurrentGRpcService::AsyncService* service, grpc::ServerCompletionQueue* cq) + : Service(service) + , CQ(cq) + { + } + + TString GetPeerName() const { + return TString(Context.peer()); + } + + TInstant Deadline() const { + // The timeout is transferred in the "grpc-timeout" header [1]; it is calculated + // from the deadline right before the request is sent. + // 1. https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md + // + // On the server side this timeout is converted back to a deadline using the + // server's grpc GPR_CLOCK_MONOTONIC time (the raw_deadline() method). + // The deadline() method converts this to an epoch-related GPR_CLOCK_REALTIME deadline. + // + + std::chrono::system_clock::time_point t = Context.deadline(); + if (t == std::chrono::system_clock::time_point::max()) { + return TInstant::Max(); + } + auto us = std::chrono::time_point_cast<std::chrono::microseconds>(t); + return TInstant::MicroSeconds(us.time_since_epoch().count()); + } + + TSet<TStringBuf> GetPeerMetaKeys() const { + TSet<TStringBuf> keys; + for (const auto& [key, _]: Context.client_metadata()) { + keys.emplace(key.data(), key.size()); + } + return keys; + } + + TVector<TStringBuf> GetPeerMetaValues(TStringBuf key) const { + const auto& clientMetadata = Context.client_metadata(); + const auto range = clientMetadata.equal_range(grpc::string_ref{key.data(), key.size()}); + if (range.first == range.second) { + return {}; + } + + TVector<TStringBuf> values; + values.reserve(std::distance(range.first, range.second)); + + for (auto it = range.first; it != range.second; ++it) { + values.emplace_back(it->second.data(), it->second.size()); + } + return values; + } + + grpc_compression_level GetCompressionLevel() const { + return Context.compression_level(); + } + + void Shutdown() override { + // Shutdown may only be called after request has started successfully + if (Context.c_call()) + Context.TryCancel(); + } + +protected: + //! The means of communication with the gRPC runtime for an asynchronous + //! server. + typename TService::TCurrentGRpcService::AsyncService* const Service; + //! The producer-consumer queue for asynchronous server notifications. + grpc::ServerCompletionQueue* const CQ; + //! Context for the rpc, allowing to tweak aspects of it such as the use + //! of compression, authentication, as well as to send metadata back to the + //! client.
+ grpc::ServerContext Context; +}; + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_counters.cpp b/library/cpp/grpc/server/grpc_counters.cpp new file mode 100644 index 0000000000..fa96e0100b --- /dev/null +++ b/library/cpp/grpc/server/grpc_counters.cpp @@ -0,0 +1,45 @@ +#include "grpc_counters.h" + +namespace NGrpc { +namespace { + +class TFakeCounterBlock final: public ICounterBlock { +private: + void CountNotOkRequest() override { + } + + void CountNotOkResponse() override { + } + + void CountNotAuthenticated() override { + } + + void CountResourceExhausted() override { + } + + void CountRequestBytes(ui32 /*requestSize*/) override { + } + + void CountResponseBytes(ui32 /*responseSize*/) override { + } + + void StartProcessing(ui32 /*requestSize*/) override { + } + + void FinishProcessing( + ui32 /*requestSize*/, + ui32 /*responseSize*/, + bool /*ok*/, + ui32 /*status*/, + TDuration /*requestDuration*/) override + { + } +}; + +} // namespace + +ICounterBlockPtr FakeCounterBlock() { + return MakeIntrusive<TFakeCounterBlock>(); +} + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_counters.h b/library/cpp/grpc/server/grpc_counters.h new file mode 100644 index 0000000000..0b6c36c84c --- /dev/null +++ b/library/cpp/grpc/server/grpc_counters.h @@ -0,0 +1,136 @@ +#pragma once + +#include <library/cpp/monlib/dynamic_counters/percentile/percentile.h> +#include <library/cpp/monlib/dynamic_counters/counters.h> +#include <util/generic/ptr.h> + +namespace NGrpc { + +struct ICounterBlock : public TThrRefBase { + virtual void CountNotOkRequest() = 0; + virtual void CountNotOkResponse() = 0; + virtual void CountNotAuthenticated() = 0; + virtual void CountResourceExhausted() = 0; + virtual void CountRequestBytes(ui32 requestSize) = 0; + virtual void CountResponseBytes(ui32 responseSize) = 0; + virtual void StartProcessing(ui32 requestSize) = 0; + virtual void FinishProcessing(ui32 requestSize, ui32 responseSize, bool ok, ui32 status, TDuration requestDuration) = 0; + virtual void CountRequestsWithoutDatabase() {} + virtual void CountRequestsWithoutToken() {} + virtual void CountRequestWithoutTls() {} + + virtual TIntrusivePtr<ICounterBlock> Clone() { return this; } + virtual void UseDatabase(const TString& database) { Y_UNUSED(database); } +}; + +using ICounterBlockPtr = TIntrusivePtr<ICounterBlock>; + +class TCounterBlock final : public ICounterBlock { + NMonitoring::TDynamicCounters::TCounterPtr TotalCounter; + NMonitoring::TDynamicCounters::TCounterPtr InflyCounter; + NMonitoring::TDynamicCounters::TCounterPtr NotOkRequestCounter; + NMonitoring::TDynamicCounters::TCounterPtr NotOkResponseCounter; + NMonitoring::TDynamicCounters::TCounterPtr RequestBytes; + NMonitoring::TDynamicCounters::TCounterPtr InflyRequestBytes; + NMonitoring::TDynamicCounters::TCounterPtr ResponseBytes; + NMonitoring::TDynamicCounters::TCounterPtr NotAuthenticated; + NMonitoring::TDynamicCounters::TCounterPtr ResourceExhausted; + bool Percentile = false; + NMonitoring::TPercentileTracker<4, 512, 15> RequestHistMs; + std::array<NMonitoring::TDynamicCounters::TCounterPtr, 2> GRpcStatusCounters; + +public: + TCounterBlock(NMonitoring::TDynamicCounters::TCounterPtr totalCounter, + NMonitoring::TDynamicCounters::TCounterPtr inflyCounter, + NMonitoring::TDynamicCounters::TCounterPtr notOkRequestCounter, + NMonitoring::TDynamicCounters::TCounterPtr notOkResponseCounter, + NMonitoring::TDynamicCounters::TCounterPtr requestBytes, + NMonitoring::TDynamicCounters::TCounterPtr inflyRequestBytes, + 
NMonitoring::TDynamicCounters::TCounterPtr responseBytes, + NMonitoring::TDynamicCounters::TCounterPtr notAuthenticated, + NMonitoring::TDynamicCounters::TCounterPtr resourceExhausted, + TIntrusivePtr<NMonitoring::TDynamicCounters> group) + : TotalCounter(std::move(totalCounter)) + , InflyCounter(std::move(inflyCounter)) + , NotOkRequestCounter(std::move(notOkRequestCounter)) + , NotOkResponseCounter(std::move(notOkResponseCounter)) + , RequestBytes(std::move(requestBytes)) + , InflyRequestBytes(std::move(inflyRequestBytes)) + , ResponseBytes(std::move(responseBytes)) + , NotAuthenticated(std::move(notAuthenticated)) + , ResourceExhausted(std::move(resourceExhausted)) + { + if (group) { + RequestHistMs.Initialize(group, "event", "request", "ms", {0.5f, 0.9f, 0.99f, 0.999f, 1.0f}); + Percentile = true; + } + } + + void CountNotOkRequest() override { + NotOkRequestCounter->Inc(); + } + + void CountNotOkResponse() override { + NotOkResponseCounter->Inc(); + } + + void CountNotAuthenticated() override { + NotAuthenticated->Inc(); + } + + void CountResourceExhausted() override { + ResourceExhausted->Inc(); + } + + void CountRequestBytes(ui32 requestSize) override { + *RequestBytes += requestSize; + } + + void CountResponseBytes(ui32 responseSize) override { + *ResponseBytes += responseSize; + } + + void StartProcessing(ui32 requestSize) override { + TotalCounter->Inc(); + InflyCounter->Inc(); + *RequestBytes += requestSize; + *InflyRequestBytes += requestSize; + } + + void FinishProcessing(ui32 requestSize, ui32 responseSize, bool ok, ui32 status, + TDuration requestDuration) override + { + Y_UNUSED(status); + + InflyCounter->Dec(); + *InflyRequestBytes -= requestSize; + *ResponseBytes += responseSize; + if (!ok) { + NotOkResponseCounter->Inc(); + } + if (Percentile) { + RequestHistMs.Increment(requestDuration.MilliSeconds()); + } + } + + ICounterBlockPtr Clone() override { + return this; + } + + void Update() { + if (Percentile) { + RequestHistMs.Update(); + } + } +}; + +using TCounterBlockPtr = TIntrusivePtr<TCounterBlock>; + +/** + * Creates new instance of ICounterBlock implementation which does nothing. + * + * @return new instance + */ +ICounterBlockPtr FakeCounterBlock(); + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_request.cpp b/library/cpp/grpc/server/grpc_request.cpp new file mode 100644 index 0000000000..d18a32776f --- /dev/null +++ b/library/cpp/grpc/server/grpc_request.cpp @@ -0,0 +1,59 @@ +#include "grpc_request.h" + +namespace NGrpc { + +const char* GRPC_USER_AGENT_HEADER = "user-agent"; + +class TStreamAdaptor: public IStreamAdaptor { +public: + TStreamAdaptor() + : StreamIsReady_(true) + {} + + void Enqueue(std::function<void()>&& fn, bool urgent) override { + with_lock(Mtx_) { + if (!UrgentQueue_.empty() || !NormalQueue_.empty()) { + Y_VERIFY(!StreamIsReady_); + } + auto& queue = urgent ? UrgentQueue_ : NormalQueue_; + if (StreamIsReady_ && queue.empty()) { + StreamIsReady_ = false; + } else { + queue.push_back(std::move(fn)); + return; + } + } + fn(); + } + + size_t ProcessNext() override { + size_t left = 0; + std::function<void()> fn; + with_lock(Mtx_) { + Y_VERIFY(!StreamIsReady_); + auto& queue = UrgentQueue_.empty() ? 
NormalQueue_ : UrgentQueue_; + if (queue.empty()) { + // Both queues are empty + StreamIsReady_ = true; + } else { + fn = std::move(queue.front()); + queue.pop_front(); + left = UrgentQueue_.size() + NormalQueue_.size(); + } + } + if (fn) + fn(); + return left; + } +private: + bool StreamIsReady_; + TList<std::function<void()>> NormalQueue_; + TList<std::function<void()>> UrgentQueue_; + TMutex Mtx_; +}; + +IStreamAdaptor::TPtr CreateStreamAdaptor() { + return std::make_unique<TStreamAdaptor>(); +} + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_request.h b/library/cpp/grpc/server/grpc_request.h new file mode 100644 index 0000000000..5bd8d3902b --- /dev/null +++ b/library/cpp/grpc/server/grpc_request.h @@ -0,0 +1,543 @@ +#pragma once + +#include <google/protobuf/text_format.h> +#include <google/protobuf/arena.h> +#include <google/protobuf/message.h> + +#include <library/cpp/monlib/dynamic_counters/counters.h> +#include <library/cpp/logger/priority.h> + +#include "grpc_response.h" +#include "event_callback.h" +#include "grpc_async_ctx_base.h" +#include "grpc_counters.h" +#include "grpc_request_base.h" +#include "grpc_server.h" +#include "logger.h" + +#include <util/system/hp_timer.h> + +#include <grpc++/server.h> +#include <grpc++/server_context.h> +#include <grpc++/support/async_stream.h> +#include <grpc++/support/async_unary_call.h> +#include <grpc++/support/byte_buffer.h> +#include <grpc++/impl/codegen/async_stream.h> + +namespace NGrpc { + +class IStreamAdaptor { +public: + using TPtr = std::unique_ptr<IStreamAdaptor>; + virtual void Enqueue(std::function<void()>&& fn, bool urgent) = 0; + virtual size_t ProcessNext() = 0; + virtual ~IStreamAdaptor() = default; +}; + +IStreamAdaptor::TPtr CreateStreamAdaptor(); + +/////////////////////////////////////////////////////////////////////////////// +template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter, typename TOutProtoPrinter> +class TGRpcRequestImpl + : public TBaseAsyncContext<TService> + , public IQueueEvent + , public IRequestContextBase +{ + using TThis = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>; + +public: + using TOnRequest = std::function<void (IRequestContextBase* ctx)>; + using TRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*, + grpc::ServerAsyncResponseWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*); + using TStreamRequestCallback = void (TService::TCurrentGRpcService::AsyncService::*)(grpc::ServerContext*, TIn*, + grpc::ServerAsyncWriter<TOut>*, grpc::CompletionQueue*, grpc::ServerCompletionQueue*, void*); + + TGRpcRequestImpl(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + TOnRequest cb, + TRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters, + IGRpcRequestLimiterPtr limiter) + : TBaseAsyncContext<TService>(service, cq) + , Server_(server) + , Cb_(cb) + , RequestCallback_(requestCallback) + , StreamRequestCallback_(nullptr) + , Name_(name) + , Logger_(std::move(logger)) + , Counters_(std::move(counters)) + , RequestLimiter_(std::move(limiter)) + , Writer_(new grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>(&this->Context)) + , StateFunc_(&TThis::SetRequestDone) + { + AuthState_ = Server_->NeedAuth() ? 
TAuthState(true) : TAuthState(false); + Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_); + Y_VERIFY(Request_); + GRPC_LOG_DEBUG(Logger_, "[%p] created request Name# %s", this, Name_); + FinishPromise_ = NThreading::NewPromise<EFinishStatus>(); + } + + TGRpcRequestImpl(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + TOnRequest cb, + TStreamRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters, + IGRpcRequestLimiterPtr limiter) + : TBaseAsyncContext<TService>(service, cq) + , Server_(server) + , Cb_(cb) + , RequestCallback_(nullptr) + , StreamRequestCallback_(requestCallback) + , Name_(name) + , Logger_(std::move(logger)) + , Counters_(std::move(counters)) + , RequestLimiter_(std::move(limiter)) + , StreamWriter_(new grpc::ServerAsyncWriter<TUniversalResponse<TOut>>(&this->Context)) + , StateFunc_(&TThis::SetRequestDone) + { + AuthState_ = Server_->NeedAuth() ? TAuthState(true) : TAuthState(false); + Request_ = google::protobuf::Arena::CreateMessage<TIn>(&Arena_); + Y_VERIFY(Request_); + GRPC_LOG_DEBUG(Logger_, "[%p] created streaming request Name# %s", this, Name_); + FinishPromise_ = NThreading::NewPromise<EFinishStatus>(); + StreamAdaptor_ = CreateStreamAdaptor(); + } + + TAsyncFinishResult GetFinishFuture() override { + return FinishPromise_.GetFuture(); + } + + TString GetPeer() const override { + return TString(this->Context.peer()); + } + + bool SslServer() const override { + return Server_->SslServer(); + } + + void Run() { + // Start request unless server is shutting down + if (auto guard = Server_->ProtectShutdown()) { + Ref(); //For grpc c runtime + this->Context.AsyncNotifyWhenDone(OnFinishTag.Prepare()); + if (RequestCallback_) { + (this->Service->*RequestCallback_) + (&this->Context, Request_, + reinterpret_cast<grpc::ServerAsyncResponseWriter<TOut>*>(Writer_.Get()), this->CQ, this->CQ, GetGRpcTag()); + } else { + (this->Service->*StreamRequestCallback_) + (&this->Context, Request_, + reinterpret_cast<grpc::ServerAsyncWriter<TOut>*>(StreamWriter_.Get()), this->CQ, this->CQ, GetGRpcTag()); + } + } + } + + ~TGRpcRequestImpl() { + // No direct dtor call allowed + Y_ASSERT(RefCount() == 0); + } + + bool Execute(bool ok) override { + return (this->*StateFunc_)(ok); + } + + void DestroyRequest() override { + if (RequestRegistered_) { + Server_->DeregisterRequestCtx(this); + RequestRegistered_ = false; + } + UnRef(); + } + + TInstant Deadline() const override { + return TBaseAsyncContext<TService>::Deadline(); + } + + TSet<TStringBuf> GetPeerMetaKeys() const override { + return TBaseAsyncContext<TService>::GetPeerMetaKeys(); + } + + TVector<TStringBuf> GetPeerMetaValues(TStringBuf key) const override { + return TBaseAsyncContext<TService>::GetPeerMetaValues(key); + } + + grpc_compression_level GetCompressionLevel() const override { + return TBaseAsyncContext<TService>::GetCompressionLevel(); + } + + //! Get pointer to the request's message. 
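+ //! The message is created on the request's protobuf arena (see GetArena()), so it + //! remains valid for the lifetime of this request context.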
+ const NProtoBuf::Message* GetRequest() const override { + return Request_; + } + + TAuthState& GetAuthState() override { + return AuthState_; + } + + void Reply(NProtoBuf::Message* resp, ui32 status) override { + ResponseStatus = status; + WriteDataOk(resp); + } + + void Reply(grpc::ByteBuffer* resp, ui32 status) override { + ResponseStatus = status; + WriteByteDataOk(resp); + } + + void ReplyError(grpc::StatusCode code, const TString& msg) override { + FinishGrpcStatus(code, msg, false); + } + + void ReplyUnauthenticated(const TString& in) override { + const TString message = in.empty() ? TString("unauthenticated") : TString("unauthenticated, ") + in; + FinishGrpcStatus(grpc::StatusCode::UNAUTHENTICATED, message, false); + } + + void SetNextReplyCallback(TOnNextReply&& cb) override { + NextReplyCb_ = cb; + } + + void AddTrailingMetadata(const TString& key, const TString& value) override { + this->Context.AddTrailingMetadata(key, value); + } + + void FinishStreamingOk() override { + GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (enqueued)", this, Name_, + this->Context.peer().c_str()); + auto cb = [this]() { + StateFunc_ = &TThis::SetFinishDone; + GRPC_LOG_DEBUG(Logger_, "[%p] finished streaming Name# %s peer# %s (pushed to grpc)", this, Name_, + this->Context.peer().c_str()); + + StreamWriter_->Finish(grpc::Status::OK, GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), false); + } + + google::protobuf::Arena* GetArena() override { + return &Arena_; + } + + void UseDatabase(const TString& database) override { + Counters_->UseDatabase(database); + } + +private: + void Clone() { + if (!Server_->IsShuttingDown()) { + if (RequestCallback_) { + MakeIntrusive<TThis>( + Server_, this->Service, this->CQ, Cb_, RequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run(); + } else { + MakeIntrusive<TThis>( + Server_, this->Service, this->CQ, Cb_, StreamRequestCallback_, Name_, Logger_, Counters_->Clone(), RequestLimiter_)->Run(); + } + } + } + + void WriteDataOk(NProtoBuf::Message* resp) { + auto makeResponseString = [&] { + TString x; + TOutProtoPrinter printer; + printer.SetSingleLineMode(true); + printer.PrintToString(*resp, &x); + return x; + }; + + auto sz = (size_t)resp->ByteSize(); + if (Writer_) { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s", this, Name_, + makeResponseString().data(), this->Context.peer().c_str()); + StateFunc_ = &TThis::SetFinishDone; + ResponseSize = sz; + Y_VERIFY(this->Context.c_call()); + Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag()); + } else { + // std::function cannot hold a move-only captured object, so we allocate the + // shared response object on the heap to avoid copying the message. The log + // string is captured by value: resp is swapped into uResp and this stack + // frame may already be gone when the enqueued callback runs. + const TString respStr = Logger_ && Logger_->IsEnabled(ELogPriority::TLOG_DEBUG) ? makeResponseString() : TString(); + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (enqueued)", + this, Name_, respStr.data(), this->Context.peer().c_str()); + auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp); + auto cb = [this, uResp = std::move(uResp), sz, respStr]() { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# %s peer# %s (pushed to grpc)", + this, Name_, respStr.data(), this->Context.peer().c_str()); + StateFunc_ = &TThis::NextReply; + ResponseSize += sz; + StreamWriter_->Write(*uResp, GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), false); + } + } + + void WriteByteDataOk(grpc::ByteBuffer* resp) { + auto sz = resp->Length(); + if (Writer_) { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data#
byteString peer# %s", this, Name_, + this->Context.peer().c_str()); + StateFunc_ = &TThis::SetFinishDone; + ResponseSize = sz; + Writer_->Finish(TUniversalResponseRef<TOut>(resp), grpc::Status::OK, GetGRpcTag()); + } else { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (enqueued)", this, Name_, + this->Context.peer().c_str()); + + // because of std::function cannot hold move-only captured object + // we allocate shared object on heap to avoid buffer copy + auto uResp = MakeIntrusive<TUniversalResponse<TOut>>(resp); + auto cb = [this, uResp = std::move(uResp), sz]() { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s data# byteString peer# %s (pushed to grpc)", + this, Name_, this->Context.peer().c_str()); + StateFunc_ = &TThis::NextReply; + ResponseSize += sz; + StreamWriter_->Write(*uResp, GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), false); + } + } + + void FinishGrpcStatus(grpc::StatusCode code, const TString& msg, bool urgent) { + Y_VERIFY(code != grpc::OK); + if (code == grpc::StatusCode::UNAUTHENTICATED) { + Counters_->CountNotAuthenticated(); + } else if (code == grpc::StatusCode::RESOURCE_EXHAUSTED) { + Counters_->CountResourceExhausted(); + } + + if (Writer_) { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)", this, + Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); + StateFunc_ = &TThis::SetFinishError; + TOut resp; + Writer_->Finish(TUniversalResponseRef<TOut>(&resp), grpc::Status(code, msg), GetGRpcTag()); + } else { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)" + " (enqueued)", this, Name_, msg.c_str(), this->Context.peer().c_str(), (int)code); + auto cb = [this, code, msg]() { + GRPC_LOG_DEBUG(Logger_, "[%p] issuing response Name# %s nodata (%s) peer# %s, grpc status# (%d)" + " (pushed to grpc)", this, Name_, msg.c_str(), + this->Context.peer().c_str(), (int)code); + StateFunc_ = &TThis::SetFinishError; + StreamWriter_->Finish(grpc::Status(code, msg), GetGRpcTag()); + }; + StreamAdaptor_->Enqueue(std::move(cb), urgent); + } + } + + bool SetRequestDone(bool ok) { + auto makeRequestString = [&] { + TString resp; + if (ok) { + TInProtoPrinter printer; + printer.SetSingleLineMode(true); + printer.PrintToString(*Request_, &resp); + } else { + resp = "<not ok>"; + } + return resp; + }; + GRPC_LOG_DEBUG(Logger_, "[%p] received request Name# %s ok# %s data# %s peer# %s", this, Name_, + ok ? "true" : "false", makeRequestString().data(), this->Context.peer().c_str()); + + if (this->Context.c_call() == nullptr) { + Y_VERIFY(!ok); + // One ref by OnFinishTag, grpc will not call this tag if no request received + UnRef(); + } else if (!(RequestRegistered_ = Server_->RegisterRequestCtx(this))) { + // Request cannot be registered due to shutdown + // It's unsafe to continue, so drop this request without processing + GRPC_LOG_DEBUG(Logger_, "[%p] dropping request Name# %s due to shutdown", this, Name_); + this->Context.TryCancel(); + return false; + } + + Clone(); // TODO: Request pool? + if (!ok) { + Counters_->CountNotOkRequest(); + return false; + } + + if (IncRequest()) { + // Adjust counters. 
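+ // StartProcessing() below increments the total/in-flight counters and byte gauges; + // the matching FinishProcessing() is issued from SetFinishDone/SetFinishError (or from + // NextReply when a stream breaks), using RequestTimer for the request duration.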
+ RequestSize = Request_->ByteSize(); + Counters_->StartProcessing(RequestSize); + RequestTimer.Reset(); + + if (!SslServer()) { + Counters_->CountRequestWithoutTls(); + } + + //TODO: Move this in to grpc_request_proxy + auto maybeDatabase = GetPeerMetaValues(TStringBuf("x-ydb-database")); + if (maybeDatabase.empty()) { + Counters_->CountRequestsWithoutDatabase(); + } + auto maybeToken = GetPeerMetaValues(TStringBuf("x-ydb-auth-ticket")); + if (maybeToken.empty() || maybeToken[0].empty()) { + TString db{maybeDatabase ? maybeDatabase[0] : TStringBuf{}}; + Counters_->CountRequestsWithoutToken(); + GRPC_LOG_DEBUG(Logger_, "[%p] received request without user token " + "Name# %s data# %s peer# %s database# %s", this, Name_, + makeRequestString().data(), this->Context.peer().c_str(), db.c_str()); + } + + // Handle current request. + Cb_(this); + } else { + //This request has not been counted + SkipUpdateCountersOnError = true; + FinishGrpcStatus(grpc::StatusCode::RESOURCE_EXHAUSTED, "no resource", true); + } + return true; + } + + bool NextReply(bool ok) { + auto logCb = [this, ok](int left) { + GRPC_LOG_DEBUG(Logger_, "[%p] ready for next reply Name# %s ok# %s peer# %s left# %d", this, Name_, + ok ? "true" : "false", this->Context.peer().c_str(), left); + }; + + if (!ok) { + logCb(-1); + DecRequest(); + Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, + TDuration::Seconds(RequestTimer.Passed())); + return false; + } + + Ref(); // To prevent destroy during this call in case of execution Finish + size_t left = StreamAdaptor_->ProcessNext(); + logCb(left); + if (NextReplyCb_) { + NextReplyCb_(left); + } + // Now it is safe to destroy even if Finish was called + UnRef(); + return true; + } + + bool SetFinishDone(bool ok) { + GRPC_LOG_DEBUG(Logger_, "[%p] finished request Name# %s ok# %s peer# %s", this, Name_, + ok ? "true" : "false", this->Context.peer().c_str()); + //PrintBackTrace(); + DecRequest(); + Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, + TDuration::Seconds(RequestTimer.Passed())); + return false; + } + + bool SetFinishError(bool ok) { + GRPC_LOG_DEBUG(Logger_, "[%p] finished request with error Name# %s ok# %s peer# %s", this, Name_, + ok ? "true" : "false", this->Context.peer().c_str()); + if (!SkipUpdateCountersOnError) { + DecRequest(); + Counters_->FinishProcessing(RequestSize, ResponseSize, ok, ResponseStatus, + TDuration::Seconds(RequestTimer.Passed())); + } + return false; + } + + // Returns pointer to IQueueEvent to pass into grpc c runtime + // Implicit C style cast from this to void* is wrong due to multiple inheritance + void* GetGRpcTag() { + return static_cast<IQueueEvent*>(this); + } + + void OnFinish(EQueueEventStatus evStatus) { + if (this->Context.IsCancelled()) { + FinishPromise_.SetValue(EFinishStatus::CANCEL); + } else { + FinishPromise_.SetValue(evStatus == EQueueEventStatus::OK ? 
EFinishStatus::OK : EFinishStatus::ERROR); + } + } + + bool IncRequest() { + if (!Server_->IncRequest()) + return false; + + if (!RequestLimiter_) + return true; + + if (!RequestLimiter_->IncRequest()) { + Server_->DecRequest(); + return false; + } + + return true; + } + + void DecRequest() { + if (RequestLimiter_) { + RequestLimiter_->DecRequest(); + } + Server_->DecRequest(); + } + + using TStateFunc = bool (TThis::*)(bool); + TService* Server_; + TOnRequest Cb_; + TRequestCallback RequestCallback_; + TStreamRequestCallback StreamRequestCallback_; + const char* const Name_; + TLoggerPtr Logger_; + ICounterBlockPtr Counters_; + IGRpcRequestLimiterPtr RequestLimiter_; + + THolder<grpc::ServerAsyncResponseWriter<TUniversalResponseRef<TOut>>> Writer_; + THolder<grpc::ServerAsyncWriterInterface<TUniversalResponse<TOut>>> StreamWriter_; + TStateFunc StateFunc_; + TIn* Request_; + + google::protobuf::Arena Arena_; + TOnNextReply NextReplyCb_; + ui32 RequestSize = 0; + ui32 ResponseSize = 0; + ui32 ResponseStatus = 0; + THPTimer RequestTimer; + TAuthState AuthState_ = 0; + bool RequestRegistered_ = false; + + using TFixedEvent = TQueueFixedEvent<TGRpcRequestImpl>; + TFixedEvent OnFinishTag = { this, &TGRpcRequestImpl::OnFinish }; + NThreading::TPromise<EFinishStatus> FinishPromise_; + bool SkipUpdateCountersOnError = false; + IStreamAdaptor::TPtr StreamAdaptor_; +}; + +template<typename TIn, typename TOut, typename TService, typename TInProtoPrinter=google::protobuf::TextFormat::Printer, typename TOutProtoPrinter=google::protobuf::TextFormat::Printer> +class TGRpcRequest: public TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter> { + using TBase = TGRpcRequestImpl<TIn, TOut, TService, TInProtoPrinter, TOutProtoPrinter>; +public: + TGRpcRequest(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + typename TBase::TOnRequest cb, + typename TBase::TRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters, + IGRpcRequestLimiterPtr limiter = nullptr) + : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), std::move(limiter)} + { + } + + TGRpcRequest(TService* server, + typename TService::TCurrentGRpcService::AsyncService* service, + grpc::ServerCompletionQueue* cq, + typename TBase::TOnRequest cb, + typename TBase::TStreamRequestCallback requestCallback, + const char* name, + TLoggerPtr logger, + ICounterBlockPtr counters) + : TBase{server, service, cq, std::move(cb), std::move(requestCallback), name, std::move(logger), std::move(counters), nullptr} + { + } +}; + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_request_base.h b/library/cpp/grpc/server/grpc_request_base.h new file mode 100644 index 0000000000..fcfce1c181 --- /dev/null +++ b/library/cpp/grpc/server/grpc_request_base.h @@ -0,0 +1,116 @@ +#pragma once + +#include <google/protobuf/message.h> +#include <library/cpp/threading/future/future.h> + +#include <grpc++/server_context.h> + +namespace grpc { +class ByteBuffer; +} + +namespace NGrpc { + +extern const char* GRPC_USER_AGENT_HEADER; + +struct TAuthState { + enum EAuthState { + AS_NOT_PERFORMED, + AS_OK, + AS_FAIL, + AS_UNAVAILABLE + }; + TAuthState(bool needAuth) + : NeedAuth(needAuth) + , State(AS_NOT_PERFORMED) + {} + bool NeedAuth; + EAuthState State; +}; + + +//! 
An interface that may be used to limit concurrency of requests +class IGRpcRequestLimiter: public TThrRefBase { +public: + virtual bool IncRequest() = 0; + virtual void DecRequest() = 0; +}; + +using IGRpcRequestLimiterPtr = TIntrusivePtr<IGRpcRequestLimiter>; + +//! State of current request +class IRequestContextBase: public TThrRefBase { +public: + enum class EFinishStatus { + OK, + ERROR, + CANCEL + }; + using TAsyncFinishResult = NThreading::TFuture<EFinishStatus>; + + using TOnNextReply = std::function<void (size_t left)>; + + //! Get pointer to the request's message. + virtual const NProtoBuf::Message* GetRequest() const = 0; + + //! Get current auth state + virtual TAuthState& GetAuthState() = 0; + + //! Send common response (the request should be created for a protobuf response type) + //! Implementation can swap protobuf message + virtual void Reply(NProtoBuf::Message* resp, ui32 status = 0) = 0; + + //! Send serialised response (the request should be created for a bytes response type) + //! Implementation can swap ByteBuffer + virtual void Reply(grpc::ByteBuffer* resp, ui32 status = 0) = 0; + + //! Send grpc UNAUTHENTICATED status + virtual void ReplyUnauthenticated(const TString& in) = 0; + + //! Send grpc error + virtual void ReplyError(grpc::StatusCode code, const TString& msg) = 0; + + //! Returns the deadline (server epoch related) if the peer set it on its side, or TInstant::Max() otherwise + virtual TInstant Deadline() const = 0; + + //! Returns available peer metadata keys + virtual TSet<TStringBuf> GetPeerMetaKeys() const = 0; + + //! Returns the peer's optional metadata values for the given key + virtual TVector<TStringBuf> GetPeerMetaValues(TStringBuf key) const = 0; + + //! Returns request compression level + virtual grpc_compression_level GetCompressionLevel() const = 0; + + //! Returns protobuf arena allocator associated with current request + //! Lifetime of the arena is lifetime of the context + virtual google::protobuf::Arena* GetArena() = 0; + + //! Add trailing metadata into the grpc context + //! The metadata will be sent when the rpc finishes + virtual void AddTrailingMetadata(const TString& key, const TString& value) = 0; + + //! Use validated database name for counters + virtual void UseDatabase(const TString& database) = 0; + + // Streaming part + + //! Set callback. The callback will be called when the response is delivered to the client; + //! after that we can call Reply again in streaming mode. Yes, GRpc says there is only one + //! reply in flight + virtual void SetNextReplyCallback(TOnNextReply&& cb) = 0; + + //! Finish streaming reply + virtual void FinishStreamingOk() = 0; + + //! Returns a future that delivers the cancel or finish notification + virtual TAsyncFinishResult GetFinishFuture() = 0; + + //! Returns peer address + virtual TString GetPeer() const = 0; + + //! Returns true if server is using ssl + virtual bool SslServer() const = 0; +}; + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_response.h b/library/cpp/grpc/server/grpc_response.h new file mode 100644 index 0000000000..8e9afe44d5 --- /dev/null +++ b/library/cpp/grpc/server/grpc_response.h @@ -0,0 +1,90 @@ +#pragma once + +#include <grpc++/impl/codegen/byte_buffer.h> +#include <grpc++/impl/codegen/proto_utils.h> + +#include <variant> + +namespace NGrpc { + +/** + * Universal response that owns underlying message or buffer.
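+ * The Swap-based constructors take ownership of the payload without copying it.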
+ */ +template <typename TMsg> +class TUniversalResponse: public TAtomicRefCount<TUniversalResponse<TMsg>>, public TMoveOnly { + friend class grpc::SerializationTraits<TUniversalResponse<TMsg>>; + +public: + explicit TUniversalResponse(NProtoBuf::Message* msg) noexcept + : Data_{TMsg{}} + { + std::get<TMsg>(Data_).Swap(static_cast<TMsg*>(msg)); + } + + explicit TUniversalResponse(grpc::ByteBuffer* buffer) noexcept + : Data_{grpc::ByteBuffer{}} + { + std::get<grpc::ByteBuffer>(Data_).Swap(buffer); + } + +private: + std::variant<TMsg, grpc::ByteBuffer> Data_; +}; + +/** + * Universal response that only keeps reference to underlying message or buffer. + */ +template <typename TMsg> +class TUniversalResponseRef: private TMoveOnly { + friend class grpc::SerializationTraits<TUniversalResponseRef<TMsg>>; + +public: + explicit TUniversalResponseRef(const NProtoBuf::Message* msg) + : Data_{msg} + { + } + + explicit TUniversalResponseRef(const grpc::ByteBuffer* buffer) + : Data_{buffer} + { + } + +private: + std::variant<const NProtoBuf::Message*, const grpc::ByteBuffer*> Data_; +}; + +} // namespace NGrpc + +namespace grpc { + +template <typename TMsg> +class SerializationTraits<NGrpc::TUniversalResponse<TMsg>> { +public: + static Status Serialize( + const NGrpc::TUniversalResponse<TMsg>& resp, + ByteBuffer* buffer, + bool* ownBuffer) + { + return std::visit([&](const auto& data) { + using T = std::decay_t<decltype(data)>; + return SerializationTraits<T>::Serialize(data, buffer, ownBuffer); + }, resp.Data_); + } +}; + +template <typename TMsg> +class SerializationTraits<NGrpc::TUniversalResponseRef<TMsg>> { +public: + static Status Serialize( + const NGrpc::TUniversalResponseRef<TMsg>& resp, + ByteBuffer* buffer, + bool* ownBuffer) + { + return std::visit([&](const auto* data) { + using T = std::decay_t<std::remove_pointer_t<decltype(data)>>; + return SerializationTraits<T>::Serialize(*data, buffer, ownBuffer); + }, resp.Data_); + } +}; + +} // namespace grpc diff --git a/library/cpp/grpc/server/grpc_server.cpp b/library/cpp/grpc/server/grpc_server.cpp new file mode 100644 index 0000000000..7437b7a8f5 --- /dev/null +++ b/library/cpp/grpc/server/grpc_server.cpp @@ -0,0 +1,240 @@ +#include "grpc_server.h" + +#include <util/string/join.h> +#include <util/generic/yexception.h> +#include <util/system/thread.h> + +#include <grpc++/resource_quota.h> +#include <contrib/libs/grpc/src/core/lib/iomgr/socket_mutator.h> + +#if !defined(_WIN32) && !defined(_WIN64) + +#include <sys/socket.h> +#include <netinet/in.h> +#include <netinet/tcp.h> + +#endif + +namespace NGrpc { + +using NThreading::TFuture; + +static void PullEvents(grpc::ServerCompletionQueue* cq) { + TThread::SetCurrentThreadName("grpc_server"); + while (true) { + void* tag; // uniquely identifies a request. 
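+ // 'ok' is the completion queue's success flag for the dequeued operation; + // Execute(ok) returning false means the event is finished and must be + // destroyed via DestroyRequest().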
+ bool ok; + + if (cq->Next(&tag, &ok)) { + IQueueEvent* const ev(static_cast<IQueueEvent*>(tag)); + + if (!ev->Execute(ok)) { + ev->DestroyRequest(); + } + } else { + break; + } + } +} + +TGRpcServer::TGRpcServer(const TServerOptions& opts) + : Options_(opts) + , Limiter_(Options_.MaxGlobalRequestInFlight) + {} + +TGRpcServer::~TGRpcServer() { + Y_VERIFY(Ts.empty()); + Services_.clear(); +} + +void TGRpcServer::AddService(IGRpcServicePtr service) { + Services_.push_back(service); +} + +void TGRpcServer::Start() { + TString server_address(Join(":", Options_.Host, Options_.Port)); // https://st.yandex-team.ru/DTCC-695 + using grpc::ServerBuilder; + using grpc::ResourceQuota; + ServerBuilder builder; + auto credentials = grpc::InsecureServerCredentials(); + if (Options_.SslData) { + grpc::SslServerCredentialsOptions::PemKeyCertPair keycert; + keycert.cert_chain = std::move(Options_.SslData->Cert); + keycert.private_key = std::move(Options_.SslData->Key); + grpc::SslServerCredentialsOptions sslOps; + sslOps.pem_root_certs = std::move(Options_.SslData->Root); + sslOps.pem_key_cert_pairs.push_back(keycert); + credentials = grpc::SslServerCredentials(sslOps); + } + if (Options_.ExternalListener) { + Options_.ExternalListener->Init(builder.experimental().AddExternalConnectionAcceptor( + ServerBuilder::experimental_type::ExternalConnectionType::FROM_FD, + credentials + )); + } else { + builder.AddListeningPort(server_address, credentials); + } + builder.SetMaxReceiveMessageSize(Options_.MaxMessageSize); + builder.SetMaxSendMessageSize(Options_.MaxMessageSize); + for (IGRpcServicePtr service : Services_) { + service->SetServerOptions(Options_); + builder.RegisterService(service->GetService()); + service->SetGlobalLimiterHandle(&Limiter_); + } + + class TKeepAliveOption: public grpc::ServerBuilderOption { + public: + TKeepAliveOption(int idle, int interval) + : Idle(idle) + , Interval(interval) + , KeepAliveEnabled(true) + {} + + TKeepAliveOption() + : Idle(0) + , Interval(0) + , KeepAliveEnabled(false) + {} + + void UpdateArguments(grpc::ChannelArguments *args) override { + args->SetInt(GRPC_ARG_HTTP2_MAX_PING_STRIKES, 0); + args->SetInt(GRPC_ARG_HTTP2_MIN_RECV_PING_INTERVAL_WITHOUT_DATA_MS, 1000); + if (KeepAliveEnabled) { + args->SetInt(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0); + args->SetInt(GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS, 1); + args->SetInt(GRPC_ARG_HTTP2_MIN_SENT_PING_INTERVAL_WITHOUT_DATA_MS, Idle * 1000); + args->SetInt(GRPC_ARG_KEEPALIVE_TIME_MS, Idle * 1000); + args->SetInt(GRPC_ARG_KEEPALIVE_TIMEOUT_MS, Interval * 1000); + } + } + + void UpdatePlugins(std::vector<std::unique_ptr<grpc::ServerBuilderPlugin>>* /*plugins*/) override + {} + private: + const int Idle; + const int Interval; + const bool KeepAliveEnabled; + }; + + if (Options_.KeepAliveEnable) { + builder.SetOption(std::make_unique<TKeepAliveOption>( + Options_.KeepAliveIdleTimeoutTriggerSec, + Options_.KeepAliveProbeIntervalSec)); + } else { + builder.SetOption(std::make_unique<TKeepAliveOption>()); + } + + if (Options_.UseCompletionQueuePerThread) { + for (size_t i = 0; i < Options_.WorkerThreads; ++i) { + CQS_.push_back(builder.AddCompletionQueue()); + } + } else { + CQS_.push_back(builder.AddCompletionQueue()); + } + + if (Options_.GRpcMemoryQuotaBytes) { + // See details KIKIMR-6932 + /* + grpc::ResourceQuota quota("memory_bound"); + quota.Resize(Options_.GRpcMemoryQuotaBytes); + + builder.SetResourceQuota(quota); + */ + Cerr << "GRpc memory quota temporarily disabled due to issues with grpc quoter" << Endl; + } + 
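+ // Give the user-supplied ServerBuilderMutator a chance to tweak the builder before the server is built.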
Options_.ServerBuilderMutator(builder); + builder.SetDefaultCompressionLevel(Options_.DefaultCompressionLevel); + + Server_ = builder.BuildAndStart(); + if (!Server_) { + ythrow yexception() << "can't start grpc server on " << server_address; + } + + size_t index = 0; + for (IGRpcServicePtr service : Services_) { + // TODO: provide something else for services instead of ServerCompletionQueue + service->InitService(CQS_[index++ % CQS_.size()].get(), Options_.Logger); + } + + if (Options_.UseCompletionQueuePerThread) { + for (size_t i = 0; i < Options_.WorkerThreads; ++i) { + auto* cq = &CQS_[i]; + Ts.push_back(SystemThreadFactory()->Run([cq] { + PullEvents(cq->get()); + })); + } + } else { + for (size_t i = 0; i < Options_.WorkerThreads; ++i) { + auto* cq = &CQS_[0]; + Ts.push_back(SystemThreadFactory()->Run([cq] { + PullEvents(cq->get()); + })); + } + } + + if (Options_.ExternalListener) { + Options_.ExternalListener->Start(); + } +} + +void TGRpcServer::Stop() { + for (auto& service : Services_) { + service->StopService(); + } + + auto now = TInstant::Now(); + + if (Server_) { + i64 sec = Options_.GRpcShutdownDeadline.Seconds(); + Y_VERIFY(Options_.GRpcShutdownDeadline.NanoSecondsOfSecond() <= Max<i32>()); + i32 nanosecOfSec = Options_.GRpcShutdownDeadline.NanoSecondsOfSecond(); + Server_->Shutdown(gpr_timespec{sec, nanosecOfSec, GPR_TIMESPAN}); + } + + for (ui64 attempt = 0; ; ++attempt) { + bool unsafe = false; + size_t infly = 0; + for (auto& service : Services_) { + unsafe |= service->IsUnsafeToShutdown(); + infly += service->RequestsInProgress(); + } + + if (!unsafe && !infly) + break; + + auto spent = (TInstant::Now() - now).SecondsFloat(); + if (attempt % 300 == 0) { + // don't log too much + Cerr << "GRpc shutdown warning: left infly: " << infly << ", spent: " << spent << " sec" << Endl; + } + + if (!unsafe && spent > Options_.GRpcShutdownDeadline.SecondsFloat()) + break; + Sleep(TDuration::MilliSeconds(10)); + } + + // Always shutdown the completion queue after the server. + for (auto& cq : CQS_) { + cq->Shutdown(); + } + + for (auto ti = Ts.begin(); ti != Ts.end(); ++ti) { + (*ti)->Join(); + } + + Ts.clear(); + + if (Options_.ExternalListener) { + Options_.ExternalListener->Stop(); + } +} + +ui16 TGRpcServer::GetPort() const { + return Options_.Port; +} + +TString TGRpcServer::GetHost() const { + return Options_.Host; +} + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/grpc_server.h b/library/cpp/grpc/server/grpc_server.h new file mode 100644 index 0000000000..d6814a90a0 --- /dev/null +++ b/library/cpp/grpc/server/grpc_server.h @@ -0,0 +1,356 @@ +#pragma once + +#include "grpc_request_base.h" +#include "logger.h" + +#include <library/cpp/threading/future/future.h> + +#include <util/generic/ptr.h> +#include <util/generic/string.h> +#include <util/generic/vector.h> +#include <util/generic/maybe.h> +#include <util/generic/queue.h> +#include <util/generic/hash_set.h> +#include <util/system/types.h> +#include <util/system/mutex.h> +#include <util/thread/factory.h> + +#include <grpc++/grpc++.h> + +namespace NGrpc { + +constexpr ui64 DEFAULT_GRPC_MESSAGE_SIZE_LIMIT = 64000000; + +struct TSslData { + TString Cert; + TString Key; + TString Root; +}; + +struct IExternalListener + : public TThrRefBase +{ + using TPtr = TIntrusivePtr<IExternalListener>; + virtual void Init(std::unique_ptr<grpc::experimental::ExternalConnectionAcceptor> acceptor) = 0; + virtual void Start() = 0; + virtual void Stop() = 0; +}; + +//! Server's options. 
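+//! A typical setup, as a sketch (the port and thread count are illustrative, not defaults): +//! +//! TGRpcServer server(TServerOptions() +//! .SetHost("[::]") +//! .SetPort(2135) +//! .SetWorkerThreads(4)); +//! server.AddService(service); // service: TIntrusivePtr<IGRpcService> +//! server.Start(); +//! // ... serve ... +//! server.Stop(); // must be called before the server is destroyed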
+struct TServerOptions { +#define DECLARE_FIELD(name, type, default) \ + type name{default}; \ + inline TServerOptions& Set##name(const type& value) { \ + name = value; \ + return *this; \ + } + + //! Hostname of server to bind to. + DECLARE_FIELD(Host, TString, "[::]"); + //! Service port. + DECLARE_FIELD(Port, ui16, 0); + + //! Number of worker threads. + DECLARE_FIELD(WorkerThreads, size_t, 2); + + //! Create one completion queue per thread + DECLARE_FIELD(UseCompletionQueuePerThread, bool, false); + + //! Memory quota size for grpc server in bytes. Zero means unlimited. + DECLARE_FIELD(GRpcMemoryQuotaBytes, size_t, 0); + + //! How long to wait until pending rpcs are forcefully terminated. + DECLARE_FIELD(GRpcShutdownDeadline, TDuration, TDuration::Seconds(30)); + + //! In/Out message size limit + DECLARE_FIELD(MaxMessageSize, size_t, DEFAULT_GRPC_MESSAGE_SIZE_LIMIT); + + //! Use GRpc keepalive + DECLARE_FIELD(KeepAliveEnable, TMaybe<bool>, TMaybe<bool>()); + + //! GRPC_ARG_KEEPALIVE_TIME_MS setting + DECLARE_FIELD(KeepAliveIdleTimeoutTriggerSec, int, 0); + + //! Deprecated, this option is ignored. Will be removed soon. + DECLARE_FIELD(KeepAliveMaxProbeCount, int, 0); + + //! GRPC_ARG_KEEPALIVE_TIMEOUT_MS setting + DECLARE_FIELD(KeepAliveProbeIntervalSec, int, 0); + + //! Max number of requests being processed by services (global limit for grpc server) + DECLARE_FIELD(MaxGlobalRequestInFlight, size_t, 100000); + + //! SSL server data + DECLARE_FIELD(SslData, TMaybe<TSslData>, TMaybe<TSslData>()); + + //! GRPC auth + DECLARE_FIELD(UseAuth, bool, false); + + //! Default compression level. Used when no compression options provided by client. + // Mapping to particular compression algorithm depends on client. + DECLARE_FIELD(DefaultCompressionLevel, grpc_compression_level, GRPC_COMPRESS_LEVEL_NONE); + + //! Custom configurator for ServerBuilder. + DECLARE_FIELD(ServerBuilderMutator, std::function<void(grpc::ServerBuilder&)>, [](grpc::ServerBuilder&){}); + + DECLARE_FIELD(ExternalListener, IExternalListener::TPtr, nullptr); + + //! Logger which will be used to write logs about request handling (iff appropriate log level is enabled). + DECLARE_FIELD(Logger, TLoggerPtr, nullptr); + +#undef DECLARE_FIELD +}; + +class IQueueEvent { +public: + virtual ~IQueueEvent() = default; + + //! Execute an action defined by implementation. + virtual bool Execute(bool ok) = 0; + + //! It is time to perform action requested by AcquireToken server method. It will be called under lock which is also + // used in ReturnToken/AcquireToken methods. Default implementation does nothing assuming that request processor does + // not implement in flight management. + virtual void Process() {} + + //! Finish and destroy request.
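+ //! Called exactly once when the event is done with the completion queue; implementations + //! typically delete this (TQueueEventCallback) or drop a reference (TQueueFixedEvent, TGRpcRequestImpl).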
+ virtual void DestroyRequest() = 0; +}; + +class ICancelableContext { +public: + virtual void Shutdown() = 0; + virtual ~ICancelableContext() = default; +}; + +template <class TLimit> +class TInFlightLimiterImpl { +public: + explicit TInFlightLimiterImpl(const TLimit& limit) + : Limit_(limit) + {} + + bool Inc() { + i64 newVal; + i64 prev; + do { + prev = AtomicGet(CurInFlightReqs_); + Y_VERIFY(prev >= 0); + if (Limit_ && prev > Limit_) { + return false; + } + newVal = prev + 1; + } while (!AtomicCas(&CurInFlightReqs_, newVal, prev)); + return true; + } + + void Dec() { + i64 newVal = AtomicDecrement(CurInFlightReqs_); + Y_VERIFY(newVal >= 0); + } + + i64 GetCurrentInFlight() const { + return AtomicGet(CurInFlightReqs_); + } + +private: + const TLimit Limit_; + TAtomic CurInFlightReqs_ = 0; +}; + +using TGlobalLimiter = TInFlightLimiterImpl<i64>; + + +class IGRpcService: public TThrRefBase { +public: + virtual grpc::Service* GetService() = 0; + virtual void StopService() noexcept = 0; + virtual void InitService(grpc::ServerCompletionQueue* cq, TLoggerPtr logger) = 0; + virtual void SetGlobalLimiterHandle(TGlobalLimiter* limiter) = 0; + virtual bool IsUnsafeToShutdown() const = 0; + virtual size_t RequestsInProgress() const = 0; + + /** + * Called before service is added to the server builder. This allows + * service to inspect server options and initialize accordingly. + */ + virtual void SetServerOptions(const TServerOptions& options) = 0; +}; + +template<typename T> +class TGrpcServiceBase: public IGRpcService { +public: + class TShutdownGuard { + using TOwner = TGrpcServiceBase<T>; + friend class TGrpcServiceBase<T>; + + public: + TShutdownGuard() + : Owner(nullptr) + { } + + ~TShutdownGuard() { + Release(); + } + + TShutdownGuard(TShutdownGuard&& other) + : Owner(other.Owner) + { + other.Owner = nullptr; + } + + TShutdownGuard& operator=(TShutdownGuard&& other) { + if (Y_LIKELY(this != &other)) { + Release(); + Owner = other.Owner; + other.Owner = nullptr; + } + return *this; + } + + explicit operator bool() const { + return bool(Owner); + } + + void Release() { + if (Owner) { + AtomicDecrement(Owner->GuardCount_); + Owner = nullptr; + } + } + + TShutdownGuard(const TShutdownGuard&) = delete; + TShutdownGuard& operator=(const TShutdownGuard&) = delete; + + private: + explicit TShutdownGuard(TOwner* owner) + : Owner(owner) + { } + + private: + TOwner* Owner; + }; + +public: + using TCurrentGRpcService = T; + + void StopService() noexcept override { + with_lock(Lock_) { + AtomicSet(ShuttingDown_, 1); + + // Send TryCancel to each event (it may be sent after finishing). + // Actual dtors will be called from the grpc thread, so a deadlock is impossible + for (auto* request : Requests_) { + request->Shutdown(); + } + } + } + + TShutdownGuard ProtectShutdown() noexcept { + AtomicIncrement(GuardCount_); + if (IsShuttingDown()) { + AtomicDecrement(GuardCount_); + return { }; + } + + return TShutdownGuard(this); + }; + + bool IsUnsafeToShutdown() const override { + return AtomicGet(GuardCount_) > 0; + } + + size_t RequestsInProgress() const override { + size_t c = 0; + with_lock(Lock_) { + c = Requests_.size(); + } + return c; + } + + void SetServerOptions(const TServerOptions& options) override { + SslServer_ = bool(options.SslData); + NeedAuth_ = options.UseAuth; + } + + void SetGlobalLimiterHandle(TGlobalLimiter* /*limiter*/) override {} + + //! Check if the server is going to shut down.
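+ //! Works together with ProtectShutdown(): requests may only be started while a TShutdownGuard + //! is held, and outstanding guards keep IsUnsafeToShutdown() returning true.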
+ bool IsShuttingDown() const { + return AtomicGet(ShuttingDown_); + } + + bool SslServer() const { + return SslServer_; + } + + bool NeedAuth() const { + return NeedAuth_; + } + + bool RegisterRequestCtx(ICancelableContext* req) { + with_lock(Lock_) { + auto r = Requests_.emplace(req); + Y_VERIFY(r.second, "Ctx already registered"); + + if (IsShuttingDown()) { + // Server is already shutting down + Requests_.erase(r.first); + return false; + } + } + + return true; + } + + void DeregisterRequestCtx(ICancelableContext* req) { + with_lock(Lock_) { + Y_VERIFY(Requests_.erase(req), "Ctx is not registered"); + } + } + +protected: + using TGrpcAsyncService = typename TCurrentGRpcService::AsyncService; + TGrpcAsyncService Service_; + + TGrpcAsyncService* GetService() override { + return &Service_; + } + +private: + TAtomic ShuttingDown_ = 0; + TAtomic GuardCount_ = 0; + + bool SslServer_ = false; + bool NeedAuth_ = false; + + THashSet<ICancelableContext*> Requests_; + TAdaptiveLock Lock_; +}; + +class TGRpcServer { +public: + using IGRpcServicePtr = TIntrusivePtr<IGRpcService>; + TGRpcServer(const TServerOptions& opts); + ~TGRpcServer(); + void AddService(IGRpcServicePtr service); + void Start(); + // Send stop to registered services and call Shutdown on the grpc server + // This method MUST be called before destroying TGRpcServer + void Stop(); + ui16 GetPort() const; + TString GetHost() const; + +private: + using IThreadRef = TAutoPtr<IThreadFactory::IThread>; + + const TServerOptions Options_; + std::unique_ptr<grpc::Server> Server_; + std::vector<std::unique_ptr<grpc::ServerCompletionQueue>> CQS_; + TVector<IThreadRef> Ts; + + TVector<IGRpcServicePtr> Services_; + TGlobalLimiter Limiter_; +}; + +} // namespace NGrpc diff --git a/library/cpp/grpc/server/logger.h b/library/cpp/grpc/server/logger.h new file mode 100644 index 0000000000..53af26be9c --- /dev/null +++ b/library/cpp/grpc/server/logger.h @@ -0,0 +1,43 @@ +#pragma once + +#include <library/cpp/logger/priority.h> + +#include <util/generic/ptr.h> + +namespace NGrpc { + +class TLogger: public TThrRefBase { +protected: + TLogger() = default; + +public: + [[nodiscard]] + bool IsEnabled(ELogPriority priority) const noexcept { + return DoIsEnabled(priority); + } + + void Y_PRINTF_FORMAT(3, 4) Write(ELogPriority priority, const char* format, ...) noexcept { + va_list args; + va_start(args, format); + DoWrite(priority, format, args); + va_end(args); + } + +protected: + virtual bool DoIsEnabled(ELogPriority priority) const noexcept = 0; + virtual void DoWrite(ELogPriority p, const char* format, va_list args) noexcept = 0; +}; + +using TLoggerPtr = TIntrusivePtr<TLogger>; + +#define GRPC_LOG_DEBUG(logger, format, ...) \ + if (logger && logger->IsEnabled(ELogPriority::TLOG_DEBUG)) { \ + logger->Write(ELogPriority::TLOG_DEBUG, format, __VA_ARGS__); \ + } else { } + +#define GRPC_LOG_INFO(logger, format, ...)
diff --git a/library/cpp/grpc/server/ut/grpc_response_ut.cpp b/library/cpp/grpc/server/ut/grpc_response_ut.cpp
new file mode 100644
index 0000000000..8abc4e4e0e
--- /dev/null
+++ b/library/cpp/grpc/server/ut/grpc_response_ut.cpp
@@ -0,0 +1,88 @@
+#include <library/cpp/grpc/server/grpc_response.h>
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <google/protobuf/duration.pb.h>
+#include <grpc++/impl/codegen/proto_utils.h>
+#include <grpc++/impl/grpc_library.h>
+
+static ::grpc::internal::GrpcLibraryInitializer grpcInitializer;
+
+using namespace NGrpc;
+
+using google::protobuf::Duration;
+
+Y_UNIT_TEST_SUITE(ResponseTest) {
+
+    template <typename T>
+    grpc::ByteBuffer Serialize(T resp) {
+        grpc::ByteBuffer buf;
+        bool ownBuf = false;
+        grpc::Status status = grpc::SerializationTraits<T>::Serialize(resp, &buf, &ownBuf);
+        UNIT_ASSERT(status.ok());
+        return buf;
+    }
+
+    template <typename T>
+    T Deserialize(grpc::ByteBuffer* buf) {
+        T message;
+        auto status = grpc::SerializationTraits<T>::Deserialize(buf, &message);
+        UNIT_ASSERT(status.ok());
+        return message;
+    }
+
+    Y_UNIT_TEST(UniversalResponseMsg) {
+        Duration d1;
+        d1.set_seconds(12345);
+        d1.set_nanos(67890);
+
+        auto buf = Serialize(TUniversalResponse<Duration>(&d1));
+        Duration d2 = Deserialize<Duration>(&buf);
+
+        UNIT_ASSERT_VALUES_EQUAL(d2.seconds(), 12345);
+        UNIT_ASSERT_VALUES_EQUAL(d2.nanos(), 67890);
+    }
+
+    Y_UNIT_TEST(UniversalResponseBuf) {
+        Duration d1;
+        d1.set_seconds(123);
+        d1.set_nanos(456);
+
+        TString data = d1.SerializeAsString();
+        grpc::Slice dataSlice{data.data(), data.size()};
+        grpc::ByteBuffer dataBuf{&dataSlice, 1};
+
+        auto buf = Serialize(TUniversalResponse<Duration>(&dataBuf));
+        Duration d2 = Deserialize<Duration>(&buf);
+
+        UNIT_ASSERT_VALUES_EQUAL(d2.seconds(), 123);
+        UNIT_ASSERT_VALUES_EQUAL(d2.nanos(), 456);
+    }
+
+    Y_UNIT_TEST(UniversalResponseRefMsg) {
+        Duration d1;
+        d1.set_seconds(12345);
+        d1.set_nanos(67890);
+
+        auto buf = Serialize(TUniversalResponseRef<Duration>(&d1));
+        Duration d2 = Deserialize<Duration>(&buf);
+
+        UNIT_ASSERT_VALUES_EQUAL(d2.seconds(), 12345);
+        UNIT_ASSERT_VALUES_EQUAL(d2.nanos(), 67890);
+    }
+
+    Y_UNIT_TEST(UniversalResponseRefBuf) {
+        Duration d1;
+        d1.set_seconds(123);
+        d1.set_nanos(456);
+
+        TString data = d1.SerializeAsString();
+        grpc::Slice dataSlice{data.data(), data.size()};
+        grpc::ByteBuffer dataBuf{&dataSlice, 1};
+
+        auto buf = Serialize(TUniversalResponseRef<Duration>(&dataBuf));
+        Duration d2 = Deserialize<Duration>(&buf);
+
+        UNIT_ASSERT_VALUES_EQUAL(d2.seconds(), 123);
+        UNIT_ASSERT_VALUES_EQUAL(d2.nanos(), 456);
+    }
+}
diff --git a/library/cpp/grpc/server/ut/stream_adaptor_ut.cpp b/library/cpp/grpc/server/ut/stream_adaptor_ut.cpp
new file mode 100644
index 0000000000..c34d3b8c2b
--- /dev/null
+++ b/library/cpp/grpc/server/ut/stream_adaptor_ut.cpp
@@ -0,0 +1,121 @@
+#include <library/cpp/grpc/server/grpc_request.h>
+#include <library/cpp/testing/unittest/registar.h>
+#include <library/cpp/testing/unittest/tests_data.h>
+
+#include <util/system/thread.h>
+#include <util/thread/pool.h>
+
+using namespace NGrpc;
+
+// Here we emulate a stream data producer
+class TOrderedProducer: public TThread {
+public:
+    TOrderedProducer(IStreamAdaptor* adaptor, ui64 max, bool withSleep, std::function<void(ui64)>&& consumerOp)
+        : TThread(&ThreadProc, this)
+        , Adaptor_(adaptor)
+        , Max_(max)
+        , WithSleep_(withSleep)
+        , ConsumerOp_(std::move(consumerOp))
+    {}
+
+    static void* ThreadProc(void* _this) {
+        SetCurrentThreadName("OrderedProducerThread");
+        static_cast<TOrderedProducer*>(_this)->Exec();
+        return nullptr;
+    }
+
+    void Exec() {
+        for (ui64 i = 0; i < Max_; i++) {
+            auto cb = [i, this]() mutable {
+                ConsumerOp_(i);
+            };
+            Adaptor_->Enqueue(std::move(cb), false);
+            if (WithSleep_ && (i % 256 == 0)) {
+                Sleep(TDuration::MilliSeconds(10));
+            }
+        }
+    }
+
+private:
+    IStreamAdaptor* Adaptor_;
+    const ui64 Max_;
+    const bool WithSleep_;
+    std::function<void(ui64)> ConsumerOp_;
+};
+
+Y_UNIT_TEST_SUITE(StreamAdaptor) {
+    static void OrderingTest(size_t threads, bool withSleep) {
+
+        auto adaptor = CreateStreamAdaptor();
+
+        const i64 max = 10000;
+
+        // Here we emulate a grpc stream (NextReply is called after writing)
+        std::unique_ptr<IThreadPool> consumerQueue(new TThreadPool(TThreadPool::TParams().SetBlocking(false).SetCatching(false)));
+        // And make sure only one request is in flight (see the UNIT_ASSERT when adding to the queue)
+        consumerQueue->Start(threads, 1);
+
+        // Not atomic!!! The stream adaptor must protect us
+        ui64 curVal = 0;
+
+        // Used just to wait in the main thread
+        TAtomic finished = false;
+        auto consumerOp = [&finished, &curVal, ptr{adaptor.get()}, queue{consumerQueue.get()}](ui64 i) {
+            // Check that there is no reordering inside the stream adaptor
+            // and no simultaneous consumer op calls
+            UNIT_ASSERT_VALUES_EQUAL(curVal, i);
+            curVal++;
+            // We must set the finished flag after the last ProcessNext, but we can't compare
+            // curVal and max after ProcessNext, so compare here and set it afterwards
+            bool tmp = curVal == max;
+            bool res = queue->AddFunc([ptr, &finished, tmp, &curVal, i]() {
+                // Additionally check that the value is still the same;
+                // a run under tsan makes sure no consumer op is called before we call ProcessNext
+                UNIT_ASSERT_VALUES_EQUAL(curVal, i + 1);
+                ptr->ProcessNext();
+                // Reordering after ProcessNext is possible, so check tmp and set finished to true
+                if (tmp)
+                    AtomicSet(finished, true);
+            });
+            UNIT_ASSERT(res);
+        };
+
+        TOrderedProducer producer(adaptor.get(), max, withSleep, std::move(consumerOp));
+
+        producer.Start();
+        producer.Join();
+
+        while (!AtomicGet(finished))
+        {
+            Sleep(TDuration::MilliSeconds(100));
+        }
+
+        consumerQueue->Stop();
+
+        UNIT_ASSERT_VALUES_EQUAL(curVal, max);
+    }
+
+    Y_UNIT_TEST(OrderingOneThread) {
+        OrderingTest(1, false);
+    }
+
+    Y_UNIT_TEST(OrderingTwoThreads) {
+        OrderingTest(2, false);
+    }
+
+    Y_UNIT_TEST(OrderingManyThreads) {
+        OrderingTest(10, false);
+    }
+
+    Y_UNIT_TEST(OrderingOneThreadWithSleep) {
+        OrderingTest(1, true);
+    }
+
+    Y_UNIT_TEST(OrderingTwoThreadsWithSleep) {
+        OrderingTest(2, true);
+    }
+
+    Y_UNIT_TEST(OrderingManyThreadsWithSleep) {
+        OrderingTest(10, true);
+    }
+}
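The tests above pin down the IStreamAdaptor contract; condensed to a sketch (the write step stands in for a real grpc stream write, and the boolean flag simply mirrors the tests, since its meaning is not shown in this diff):

    auto adaptor = NGrpc::CreateStreamAdaptor();
    // Producer side: enqueued callbacks run strictly in order, one at a time.
    adaptor->Enqueue([]() {
        // write the next message to the grpc stream here
    }, false);
    // Consumer side: once the write completes, release the next queued callback.
    adaptor->ProcessNext();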
diff --git a/library/cpp/grpc/server/ut/ya.make b/library/cpp/grpc/server/ut/ya.make
new file mode 100644
index 0000000000..feb3291af9
--- /dev/null
+++ b/library/cpp/grpc/server/ut/ya.make
@@ -0,0 +1,21 @@
+UNITTEST_FOR(library/cpp/grpc/server)
+
+OWNER(
+    dcherednik
+    g:kikimr
+)
+
+TIMEOUT(600)
+SIZE(MEDIUM)
+
+PEERDIR(
+    library/cpp/grpc/server
+)
+
+SRCS(
+    grpc_response_ut.cpp
+    stream_adaptor_ut.cpp
+)
+
+END()
+
diff --git a/library/cpp/grpc/server/ya.make b/library/cpp/grpc/server/ya.make
new file mode 100644
index 0000000000..356a1b6793
--- /dev/null
+++ b/library/cpp/grpc/server/ya.make
@@ -0,0 +1,25 @@
+LIBRARY()
+
+OWNER(
+    dcherednik
+    g:kikimr
+)
+
+SRCS(
+    event_callback.cpp
+    grpc_request.cpp
+    grpc_server.cpp
+    grpc_counters.cpp
+)
+
+GENERATE_ENUM_SERIALIZATION(grpc_request_base.h)
+
+PEERDIR(
+    contrib/libs/grpc
+    library/cpp/monlib/dynamic_counters/percentile
+)
+
+END()
+
+RECURSE_FOR_TESTS(ut)
+
diff --git a/library/cpp/grpc/ya.make b/library/cpp/grpc/ya.make
new file mode 100644
index 0000000000..3635124115
--- /dev/null
+++ b/library/cpp/grpc/ya.make
@@ -0,0 +1,5 @@
+RECURSE(
+    client
+    common
+    server
+)
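Finally, a bootstrap sketch for TGRpcServer as declared in grpc_server.h above. The option fields in the comment are assumptions (only SslData and UseAuth appear in this diff), and TMyService is the hypothetical service from the earlier sketch.

    NGrpc::TServerOptions opts;
    // opts.Host = "[::]"; opts.Port = 2135; // hypothetical option fields
    NGrpc::TGRpcServer server(opts);
    server.AddService(MakeIntrusive<TMyService>());
    server.Start();
    // ... serve traffic ...
    server.Stop(); // MUST be called before the server object is destroyed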