author     monster <monster@ydb.tech>  2022-07-07 14:41:37 +0300
committer  monster <monster@ydb.tech>  2022-07-07 14:41:37 +0300
commit     06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree       75efcbc6854ef9bd476eb8bf00cc5c900da436a2  /library/cpp/neh/tcp2.cpp
parent     03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
download   ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/neh/tcp2.cpp')
 library/cpp/neh/tcp2.cpp | 1656 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 1656 insertions(+), 0 deletions(-)
diff --git a/library/cpp/neh/tcp2.cpp b/library/cpp/neh/tcp2.cpp
new file mode 100644
index 00000000000..3dad055af1f
--- /dev/null
+++ b/library/cpp/neh/tcp2.cpp
@@ -0,0 +1,1656 @@
+#include "tcp2.h"
+
+#include "details.h"
+#include "factory.h"
+#include "http_common.h"
+#include "neh.h"
+#include "utils.h"
+
+#include <library/cpp/dns/cache.h>
+#include <library/cpp/neh/asio/executor.h>
+#include <library/cpp/threading/atomic/bool.h>
+
+#include <util/generic/buffer.h>
+#include <util/generic/hash.h>
+#include <util/generic/singleton.h>
+#include <util/network/endpoint.h>
+#include <util/network/init.h>
+#include <util/network/iovec.h>
+#include <util/network/socket.h>
+#include <util/string/cast.h>
+
+#include <atomic>
+
+//#define DEBUG_TCP2 1
+#ifdef DEBUG_TCP2
+TSpinLock OUT_LOCK;
+#define DBGOUT(args) \
+ { \
+ TGuard<TSpinLock> m(OUT_LOCK); \
+ Cout << TInstant::Now().GetValue() << " " << args << Endl; \
+ }
+#else
+#define DBGOUT(args)
+#endif
+
+using namespace std::placeholders;
+
+namespace NNeh {
+ TDuration TTcp2Options::ConnectTimeout = TDuration::MilliSeconds(300);
+ size_t TTcp2Options::InputBufferSize = 16000;
+ size_t TTcp2Options::AsioClientThreads = 4;
+ size_t TTcp2Options::AsioServerThreads = 4;
+ int TTcp2Options::Backlog = 100;
+ bool TTcp2Options::ClientUseDirectWrite = true;
+ bool TTcp2Options::ServerUseDirectWrite = true;
+ TDuration TTcp2Options::ServerInputDeadline = TDuration::Seconds(3600);
+ TDuration TTcp2Options::ServerOutputDeadline = TDuration::Seconds(10);
+
+ bool TTcp2Options::Set(TStringBuf name, TStringBuf value) {
+#define TCP2_TRY_SET(optType, optName) \
+ if (name == TStringBuf(#optName)) { \
+ optName = FromString<optType>(value); \
+ }
+
+ TCP2_TRY_SET(TDuration, ConnectTimeout)
+ else TCP2_TRY_SET(size_t, InputBufferSize)
+ else TCP2_TRY_SET(size_t, AsioClientThreads)
+ else TCP2_TRY_SET(size_t, AsioServerThreads)
+ else TCP2_TRY_SET(int, Backlog)
+ else TCP2_TRY_SET(bool, ClientUseDirectWrite)
+ else TCP2_TRY_SET(bool, ServerUseDirectWrite)
+ else TCP2_TRY_SET(TDuration, ServerInputDeadline)
+ else TCP2_TRY_SET(TDuration, ServerOutputDeadline)
+ else {
+ return false;
+ }
+ return true;
+ }
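+
+ //options arrive as strings via IProtocol::SetOption (see TProtocol below); e.g. a hypothetical
+ //Set("ConnectTimeout", "150ms") call parses the value with FromString<TDuration>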
+}
+
+namespace {
+ namespace NNehTcp2 {
+ using namespace NAsio;
+ using namespace NDns;
+ using namespace NNeh;
+
+ const TString canceled = "canceled";
+ const TString emptyReply = "empty reply";
+
+ inline void PrepareSocket(SOCKET s) {
+ SetNoDelay(s, true);
+ }
+
+ typedef ui64 TRequestId;
+
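+ //wire format: every message starts with a TBaseHeader (Id, HeaderLength, Version, Type);
+ //HeaderLength covers the whole header, including the service-path tail of a request header,
+ //and Request/Response headers are followed by ContentLength bytes of payload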
+#pragma pack(push, 1) //disable struct member alignment (these structs map directly onto data transmitted over the network)
+ struct TBaseHeader {
+ enum TMessageType {
+ Request = 1,
+ Response = 2,
+ Cancel = 3,
+ MaxMessageType
+ };
+
+ TBaseHeader(TRequestId id, ui32 headerLength, ui8 version, ui8 mType)
+ : Id(id)
+ , HeaderLength(headerLength)
+ , Version(version)
+ , Type(mType)
+ {
+ }
+
+ TRequestId Id; //message id: monotonically increasing sequence (the nil value 0 is skipped)
+ ui32 HeaderLength;
+ ui8 Version; //current version: 1
+ ui8 Type; //<- TMessageType (future versions may add ForceResponse, etc.)
+ };
+
+ struct TRequestHeader: public TBaseHeader {
+ TRequestHeader(TRequestId reqId, size_t servicePathLength, size_t dataSize)
+ : TBaseHeader(reqId, sizeof(TRequestHeader) + servicePathLength, 1, (ui8)Request)
+ , ContentLength(dataSize)
+ {
+ }
+
+ ui32 ContentLength;
+ };
+
+ struct TResponseHeader: public TBaseHeader {
+ enum TErrorCode {
+ Success = 0,
+ EmptyReply = 1, //no such service, or the service did not send a response
+ MaxErrorCode
+ };
+
+ TResponseHeader(TRequestId reqId, TErrorCode code, size_t dataSize)
+ : TBaseHeader(reqId, sizeof(TResponseHeader), 1, (ui8)Response)
+ , ErrorCode((ui16)code)
+ , ContentLength(dataSize)
+ {
+ }
+
+ TString ErrorDescription() const {
+ if (ErrorCode == (ui16)EmptyReply) {
+ return emptyReply;
+ }
+
+ TStringStream ss;
+ ss << TStringBuf("tcp2 err_code=") << ErrorCode;
+ return ss.Str();
+ }
+
+ ui16 ErrorCode;
+ ui32 ContentLength;
+ };
+
+ struct TCancelHeader: public TBaseHeader {
+ TCancelHeader(TRequestId reqId)
+ : TBaseHeader(reqId, sizeof(TCancelHeader), 1, (ui8)Cancel)
+ {
+ }
+ };
+#pragma pack(pop)
+
+ static const size_t maxHeaderSize = sizeof(TResponseHeader);
+
+ //buffer for reading input data: header + message content
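+ //parsing is a three-stage state machine (Loader_): LoadBaseHeader -> LoadHeader -> LoadContent;
+ //each stage consumes bytes until its piece is complete, then chains to the next within the same LoadFrom call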
+ struct TTcp2Message {
+ TTcp2Message()
+ : Loader_(&TTcp2Message::LoadBaseHeader)
+ , RequireBytesForComplete_(sizeof(TBaseHeader))
+ , Header_(sizeof(TBaseHeader))
+ {
+ }
+
+ void Clear() {
+ Loader_ = &TTcp2Message::LoadBaseHeader;
+ RequireBytesForComplete_ = sizeof(TBaseHeader);
+ Header_.Clear();
+ Content_.clear();
+ }
+
+ TBuffer& Header() noexcept {
+ return Header_;
+ }
+
+ const TString& Content() const noexcept {
+ return Content_;
+ }
+
+ bool IsComplete() const noexcept {
+ return RequireBytesForComplete_ == 0;
+ }
+
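+ //feeds raw bytes to the current parsing stage; returns how many bytes were consumed,
+ //so the caller loops over its buffer until it is exhausted, checking IsComplete() after each call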
+ size_t LoadFrom(const char* buf, size_t len) {
+ return (this->*Loader_)(buf, len);
+ }
+
+ const TBaseHeader& BaseHeader() const {
+ return *reinterpret_cast<const TBaseHeader*>(Header_.Data());
+ }
+
+ const TRequestHeader& RequestHeader() const {
+ return *reinterpret_cast<const TRequestHeader*>(Header_.Data());
+ }
+
+ const TResponseHeader& ResponseHeader() const {
+ return *reinterpret_cast<const TResponseHeader*>(Header_.Data());
+ }
+
+ private:
+ size_t LoadBaseHeader(const char* buf, size_t len) {
+ size_t useBytes = Min<size_t>(sizeof(TBaseHeader) - Header_.Size(), len);
+ Header_.Append(buf, useBytes);
+ if (Y_UNLIKELY(sizeof(TBaseHeader) > Header_.Size())) {
+ //base header is not yet complete
+ return useBytes;
+ }
+ {
+ const TBaseHeader& hdr = BaseHeader();
+ if (hdr.HeaderLength > 32000) { //heuristic header size limit
+ throw yexception() << TStringBuf("too large neh/tcp2 header size: ") << hdr.HeaderLength;
+ }
+ //header completed
+ Header_.Reserve(hdr.HeaderLength);
+ }
+ const TBaseHeader& hdr = BaseHeader(); //Reserve() may reallocate and move Header_ data, so re-take the reference
+ if (Y_UNLIKELY(hdr.Version != 1)) {
+ throw yexception() << TStringBuf("unsupported protocol version: ") << static_cast<unsigned>(hdr.Version);
+ }
+ RequireBytesForComplete_ = hdr.HeaderLength - sizeof(TBaseHeader);
+ return useBytes + LoadHeader(buf + useBytes, len - useBytes);
+ }
+
+ size_t LoadHeader(const char* buf, size_t len) {
+ size_t useBytes = Min<size_t>(RequireBytesForComplete_, len);
+ Header_.Append(buf, useBytes);
+ RequireBytesForComplete_ -= useBytes;
+ if (RequireBytesForComplete_) {
+ //continue loading the header
+ Loader_ = &TTcp2Message::LoadHeader;
+ return useBytes;
+ }
+
+ const TBaseHeader& hdr = *reinterpret_cast<const TBaseHeader*>(Header_.Data());
+
+ if (hdr.Type == TBaseHeader::Request) {
+ if (Header_.Size() < sizeof(TRequestHeader)) {
+ throw yexception() << TStringBuf("invalid request header size");
+ }
+ InitContentLoading(RequestHeader().ContentLength);
+ } else if (hdr.Type == TBaseHeader::Response) {
+ if (Header_.Size() < sizeof(TResponseHeader)) {
+ throw yexception() << TStringBuf("invalid response header size");
+ }
+ InitContentLoading(ResponseHeader().ContentLength);
+ } else if (hdr.Type == TBaseHeader::Cancel) {
+ if (Header_.Size() < sizeof(TCancelHeader)) {
+ throw yexception() << TStringBuf("invalid cancel header size");
+ }
+ return useBytes;
+ } else {
+ throw yexception() << TStringBuf("unsupported request type: ") << static_cast<unsigned>(hdr.Type);
+ }
+ return useBytes + (this->*Loader_)(buf + useBytes, len - useBytes);
+ }
+
+ void InitContentLoading(size_t contentLength) {
+ RequireBytesForComplete_ = contentLength;
+ Content_.ReserveAndResize(contentLength);
+ Loader_ = &TTcp2Message::LoadContent;
+ }
+
+ size_t LoadContent(const char* buf, size_t len) {
+ size_t curContentSize = Content_.size() - RequireBytesForComplete_;
+ size_t useBytes = Min<size_t>(RequireBytesForComplete_, len);
+ memcpy(Content_.begin() + curContentSize, buf, useBytes);
+ RequireBytesForComplete_ -= useBytes;
+ return useBytes;
+ }
+
+ private:
+ typedef size_t (TTcp2Message::*TLoader)(const char*, size_t);
+
+ TLoader Loader_; //current loader (stages - base-header/header/content)
+ size_t RequireBytesForComplete_;
+ TBuffer Header_;
+ TString Content_;
+ };
+
+ //base storage for output data
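+ //headers are placement-new'ed into a fixed MemPool_, payloads are referenced (not copied) as iovec
+ //parts, and everything is written out through a single TContIOVector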
+ class TMultiBuffers {
+ public:
+ TMultiBuffers()
+ : IOVec_(nullptr, 0)
+ , DataSize_(0)
+ , PoolBytes_(0)
+ {
+ }
+
+ void Clear() noexcept {
+ Parts_.clear();
+ DataSize_ = 0;
+ PoolBytes_ = 0;
+ }
+
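+ //callers flush as soon as this returns false; the maxHeaderSize slack (plus the MemPoolReserve_
+ //tail used by AllocatePlus) keeps one more header allocation within bounds after a passing check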
+ bool HasFreeSpace() const noexcept {
+ return DataSize_ < 64000 && (PoolBytes_ < (MemPoolSize_ - maxHeaderSize));
+ }
+
+ bool HasData() const noexcept {
+ return Parts_.size();
+ }
+
+ TContIOVector* GetIOvec() noexcept {
+ return &IOVec_;
+ }
+
+ protected:
+ void AddPart(const void* buf, size_t len) {
+ Parts_.push_back(IOutputStream::TPart(buf, len));
+ DataSize_ += len;
+ }
+
+ //used to allocate a header (T MUST be a POD type)
+ template <typename T>
+ inline T* Allocate() noexcept {
+ size_t poolBytes = PoolBytes_;
+ PoolBytes_ += sizeof(T);
+ return (T*)(MemPool_ + poolBytes);
+ }
+
+ //used to allocate a header (T MUST be a POD type) plus a short tail
+ template <typename T>
+ inline T* AllocatePlus(size_t tailSize) noexcept {
+ Y_ASSERT(tailSize <= MemPoolReserve_);
+ size_t poolBytes = PoolBytes_;
+ PoolBytes_ += sizeof(T) + tailSize;
+ return (T*)(MemPool_ + poolBytes);
+ }
+
+ protected:
+ TContIOVector IOVec_;
+ TVector<IOutputStream::TPart> Parts_;
+ static const size_t MemPoolSize_ = maxHeaderSize * 100;
+ static const size_t MemPoolReserve_ = 32;
+ size_t DataSize_;
+ size_t PoolBytes_;
+ char MemPool_[MemPoolSize_ + MemPoolReserve_];
+ };
+
+ //lock limiting use of the tcp connection output (and the data it references) to one thread at a time
+ class TOutputLock {
+ public:
+ TOutputLock() noexcept
+ : Lock_(0)
+ {
+ }
+
+ bool TryAquire() noexcept {
+ do {
+ if (AtomicTryLock(&Lock_)) {
+ return true;
+ }
+ } while (!AtomicGet(Lock_)); //retry while the lock looks free: TryLock may lose a race with a holder that releases right away, and returning false for a free lock would strand queued work
+ return false;
+ }
+
+ void Release() noexcept {
+ AtomicUnlock(&Lock_);
+ }
+
+ bool IsFree() const noexcept {
+ return !AtomicGet(Lock_);
+ }
+
+ private:
+ TAtomic Lock_;
+ };
+
+ class TClient {
+ class TRequest;
+ class TConnection;
+ typedef TIntrusivePtr<TRequest> TRequestRef;
+ typedef TIntrusivePtr<TConnection> TConnectionRef;
+
+ class TRequest: public TThrRefBase, public TNonCopyable {
+ public:
+ class THandle: public TSimpleHandle {
+ public:
+ THandle(IOnRecv* f, const TMessage& msg, TStatCollector* s) noexcept
+ : TSimpleHandle(f, msg, s)
+ {
+ }
+
+ bool MessageSendedCompletely() const noexcept override {
+ if (TSimpleHandle::MessageSendedCompletely()) {
+ return true;
+ }
+
+ TRequestRef req = GetRequest();
+ if (!!req && req->RequestSendedCompletely()) {
+ const_cast<THandle*>(this)->SetSendComplete();
+ }
+
+ return TSimpleHandle::MessageSendedCompletely();
+ }
+
+ void Cancel() noexcept override {
+ if (TSimpleHandle::Canceled()) {
+ return;
+ }
+
+ TRequestRef req = GetRequest();
+ if (!!req) {
+ req->Cancel();
+ TSimpleHandle::Cancel();
+ }
+ }
+
+ void NotifyResponse(const TString& resp) {
+ TNotifyHandle::NotifyResponse(resp);
+
+ ReleaseRequest();
+ }
+
+ void NotifyError(const TString& error) {
+ TNotifyHandle::NotifyError(error);
+
+ ReleaseRequest();
+ }
+
+ void NotifyError(TErrorRef error) {
+ TNotifyHandle::NotifyError(error);
+
+ ReleaseRequest();
+ }
+
+ //not thread safe!
+ void SetRequest(const TRequestRef& r) noexcept {
+ Req_ = r;
+ }
+
+ void ReleaseRequest() noexcept {
+ TRequestRef tmp;
+ TGuard<TSpinLock> g(SP_);
+ tmp.Swap(Req_);
+ }
+
+ private:
+ TRequestRef GetRequest() const noexcept {
+ TGuard<TSpinLock> g(SP_);
+ return Req_;
+ }
+
+ mutable TSpinLock SP_;
+ TRequestRef Req_;
+ };
+
+ typedef TIntrusivePtr<THandle> THandleRef;
+
+ static void Run(THandleRef& h, const TMessage& msg, TClient& clnt) {
+ TRequestRef req(new TRequest(h, msg, clnt));
+ h->SetRequest(req);
+ req->Run(req);
+ }
+
+ ~TRequest() override {
+ DBGOUT("TClient::~TRequest()");
+ }
+
+ private:
+ TRequest(THandleRef& h, TMessage msg, TClient& clnt)
+ : Hndl_(h)
+ , Clnt_(clnt)
+ , Msg_(std::move(msg))
+ , Loc_(Msg_.Addr)
+ , Addr_(CachedResolve(TResolveInfo(Loc_.Host, Loc_.GetPort())))
+ , Canceled_(false)
+ , Id_(0)
+ {
+ DBGOUT("TClient::TRequest()");
+ }
+
+ void Run(TRequestRef& req) {
+ TDestination& dest = Clnt_.Dest_.Get(Addr_->Id);
+ dest.Run(req);
+ }
+
+ public:
+ void OnResponse(TTcp2Message& msg) {
+ DBGOUT("TRequest::OnResponse: " << msg.ResponseHeader().Id);
+ THandleRef h = ReleaseHandler();
+ if (!h) {
+ return;
+ }
+
+ const TResponseHeader& respHdr = msg.ResponseHeader();
+ if (Y_LIKELY(!respHdr.ErrorCode)) {
+ h->NotifyResponse(msg.Content());
+ } else {
+ h->NotifyError(new TError(respHdr.ErrorDescription(), TError::ProtocolSpecific, respHdr.ErrorCode));
+ }
+ ReleaseConn();
+ }
+
+ void OnError(const TString& err, const i32 systemCode = 0) {
+ DBGOUT("TRequest::OnError: " << Id_.load(std::memory_order_acquire));
+ THandleRef h = ReleaseHandler();
+ if (!h) {
+ return;
+ }
+
+ h->NotifyError(new TError(err, TError::UnknownType, 0, systemCode));
+ ReleaseConn();
+ }
+
+ void SetConnection(TConnection* conn) noexcept {
+ auto g = Guard(AL_);
+ Conn_ = conn;
+ }
+
+ bool Canceled() const noexcept {
+ return Canceled_;
+ }
+
+ const TResolvedHost* Addr() const noexcept {
+ return Addr_;
+ }
+
+ TStringBuf Service() const noexcept {
+ return Loc_.Service;
+ }
+
+ const TString& Data() const noexcept {
+ return Msg_.Data;
+ }
+
+ TClient& Client() noexcept {
+ return Clnt_;
+ }
+
+ bool RequestSendedCompletely() const noexcept {
+ if (Id_.load(std::memory_order_acquire) == 0) {
+ return false;
+ }
+
+ TConnectionRef conn = GetConn();
+ if (!conn) {
+ return false;
+ }
+
+ TRequestId lastSendedReqId = conn->LastSendedRequestId();
+ if (lastSendedReqId >= Id_.load(std::memory_order_acquire)) {
+ return true;
+ } else if (Y_UNLIKELY((Id_.load(std::memory_order_acquire) - lastSendedReqId) > (Max<TRequestId>() - Max<ui32>()))) {
+ //the req-id counter wrapped around, so Id_ is far ahead of lastSendedReqId; treat as sent
+ return true;
+ }
+ return false;
+ }
+
+ void Cancel() noexcept {
+ Canceled_ = true;
+ THandleRef h = ReleaseHandler();
+ if (!h) {
+ return;
+ }
+
+ TConnectionRef conn = ReleaseConn();
+ if (!!conn && Id_.load(std::memory_order_acquire)) {
+ conn->Cancel(Id_.load(std::memory_order_acquire));
+ }
+ h->NotifyError(new TError(canceled, TError::Cancelled));
+ }
+
+ void SetReqId(TRequestId reqId) noexcept {
+ auto guard = Guard(IdLock_);
+ Id_.store(reqId, std::memory_order_release);
+ }
+
+ TRequestId ReqId() const noexcept {
+ return Id_.load(std::memory_order_acquire);
+ }
+
+ private:
+ inline THandleRef ReleaseHandler() noexcept {
+ THandleRef h;
+ {
+ auto g = Guard(AL_);
+ h.Swap(Hndl_);
+ }
+ return h;
+ }
+
+ inline TConnectionRef GetConn() const noexcept {
+ auto g = Guard(AL_);
+ return Conn_;
+ }
+
+ inline TConnectionRef ReleaseConn() noexcept {
+ TConnectionRef c;
+ {
+ auto g = Guard(AL_);
+ c.Swap(Conn_);
+ }
+ return c;
+ }
+
+ mutable TAdaptiveLock AL_; //guarantees notify() is called only once (prevents a race between the asio thread and the caller)
+ THandleRef Hndl_;
+ TClient& Clnt_;
+ const TMessage Msg_;
+ const TParsedLocation Loc_;
+ const TResolvedHost* Addr_;
+ TConnectionRef Conn_;
+ NAtomic::TBool Canceled_;
+ TSpinLock IdLock_;
+ std::atomic<TRequestId> Id_;
+ };
+
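+ //client connection: a state machine Init -> Connecting -> Connected -> Closed; requests and cancels
+ //are enqueued lock-free from user threads, the winner of OutputLock_ drains and writes them, and
+ //ReqsInFly_ (id -> request) is maintained strictly on the asio thread for matching responses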
+ class TConnection: public TThrRefBase {
+ enum TState {
+ Init,
+ Connecting,
+ Connected,
+ Closed,
+ MaxState
+ };
+ typedef THashMap<TRequestId, TRequestRef> TReqsInFly;
+
+ public:
+ class TOutputBuffers: public TMultiBuffers {
+ public:
+ void AddRequest(const TRequestRef& req) {
+ Requests_.push_back(req);
+ if (req->Service().size() > MemPoolReserve_) {
+ TRequestHeader* hdr = new (Allocate<TRequestHeader>()) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size());
+ AddPart(hdr, sizeof(TRequestHeader));
+ AddPart(req->Service().data(), req->Service().size());
+ } else {
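+ //short service names are copied into the pool right behind the header, forming one contiguous part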
+ TRequestHeader* hdr = new (AllocatePlus<TRequestHeader>(req->Service().size())) TRequestHeader(req->ReqId(), req->Service().size(), req->Data().size());
+ AddPart(hdr, sizeof(TRequestHeader) + req->Service().size());
+ memmove(++hdr, req->Service().data(), req->Service().size());
+ }
+ AddPart(req->Data().data(), req->Data().size());
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void AddCancelRequest(TRequestId reqId) {
+ TCancelHeader* hdr = new (Allocate<TCancelHeader>()) TCancelHeader(reqId);
+ AddPart(hdr, sizeof(TCancelHeader));
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void Clear() {
+ TMultiBuffers::Clear();
+ Requests_.clear();
+ }
+
+ private:
+ TVector<TRequestRef> Requests_;
+ };
+
+ TConnection(TIOService& srv)
+ : AS_(srv)
+ , State_(Init)
+ , BuffSize_(TTcp2Options::InputBufferSize)
+ , Buff_(new char[BuffSize_])
+ , NeedCheckReqsQueue_(0)
+ , NeedCheckCancelsQueue_(0)
+ , GenReqId_(0)
+ , LastSendedReqId_(0)
+ {
+ }
+
+ ~TConnection() override {
+ try {
+ DBGOUT("TClient::~TConnection()");
+ OnError("~");
+ } catch (...) {
+ Cdbg << "tcp2::~cln_conn: " << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ //called from client thread
+ bool Run(TRequestRef& req) {
+ if (Y_UNLIKELY(AtomicGet(State_) == Closed)) {
+ return false;
+ }
+
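+ //the lock-free queue holds raw pointers, so take a manual reference for the queue's lifetime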
+ req->Ref();
+ try {
+ Reqs_.Enqueue(req.Get());
+ } catch (...) {
+ req->UnRef();
+ throw;
+ }
+
+ AtomicSet(NeedCheckReqsQueue_, 1);
+ req->SetConnection(this);
+ TAtomicBase state = AtomicGet(State_);
+ if (Y_LIKELY(state == Connected)) {
+ ProcessOutputReqsQueue();
+ return true;
+ }
+
+ if (state == Init) {
+ if (AtomicCas(&State_, Connecting, Init)) {
+ try {
+ TEndpoint addr(new NAddr::TAddrInfo(&*req->Addr()->Addr.Begin()));
+ AS_.AsyncConnect(addr, std::bind(&TConnection::OnConnect, TConnectionRef(this), _1, _2), TTcp2Options::ConnectTimeout);
+ } catch (...) {
+ AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCallback, TConnectionRef(this), CurrentExceptionMessage()));
+ }
+ return true;
+ }
+ }
+ state = AtomicGet(State_);
+ if (state == Connected) {
+ ProcessOutputReqsQueue();
+ } else if (state == Closed) {
+ SafeOnError();
+ }
+ return true;
+ }
+
+ //called from client thread
+ void Cancel(TRequestId id) {
+ Cancels_.Enqueue(id);
+ AtomicSet(NeedCheckCancelsQueue_, 1);
+ if (Y_LIKELY(AtomicGet(State_) == Connected)) {
+ ProcessOutputCancelsQueue();
+ }
+ }
+
+ void ProcessOutputReqsQueue() {
+ if (OutputLock_.TryAquire()) {
+ SendMessages(false);
+ }
+ }
+
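+ //cancel handling touches ReqsInFly_, which is asio-thread-only, so the work is posted to the asio thread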
+ void ProcessOutputCancelsQueue() {
+ if (OutputLock_.TryAquire()) {
+ AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true));
+ return;
+ }
+ }
+
+ //must be called only from asio thread
+ void ProcessReqsInFlyQueue() {
+ if (AtomicGet(State_) == Closed) {
+ return;
+ }
+
+ TRequest* reqPtr;
+
+ while (ReqsInFlyQueue_.Dequeue(&reqPtr)) {
+ TRequestRef reqTmp(reqPtr);
+ reqPtr->UnRef();
+ ReqsInFly_[reqPtr->ReqId()].Swap(reqTmp);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnConnect(const TErrorCode& ec, IHandlingContext&) {
+ DBGOUT("TConnect::OnConnect: " << ec.Value());
+ if (Y_UNLIKELY(ec)) {
+ if (ec.Value() == EIO) {
+ //try to get more detailed error info
+ char buf[1];
+ TErrorCode errConnect;
+ AS_.ReadSome(buf, 1, errConnect);
+ OnErrorCode(errConnect.Value() ? errConnect : ec);
+ } else {
+ OnErrorCode(ec);
+ }
+ } else {
+ try {
+ PrepareSocket(AS_.Native());
+ AtomicSet(State_, Connected);
+ AS_.AsyncPollRead(std::bind(&TConnection::OnCanRead, TConnectionRef(this), _1, _2));
+ if (OutputLock_.TryAquire()) {
+ SendMessages(true);
+ return;
+ }
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+ }
+ }
+
+ //must be called only after successfully acquiring the output lock
+ void SendMessages(bool asioThread) {
+ //DBGOUT("SendMessages");
+ if (Y_UNLIKELY(AtomicGet(State_) == Closed)) {
+ if (asioThread) {
+ OnError(Error_);
+ } else {
+ SafeOnError();
+ }
+ return;
+ }
+
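+ //drain loop: on the asio thread first apply queued cancels, then pull queued requests into
+ //ReqsInFly_ and the output buffers, flushing whenever the buffers fill; after releasing the
+ //lock, re-check the Need* flags so work enqueued during the release window is not lost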
+ do {
+ if (asioThread) {
+ AtomicSet(NeedCheckCancelsQueue_, 0);
+ TRequestId reqId;
+
+ ProcessReqsInFlyQueue();
+ while (Cancels_.Dequeue(&reqId)) {
+ TReqsInFly::iterator it = ReqsInFly_.find(reqId);
+ if (it == ReqsInFly_.end()) {
+ continue;
+ }
+
+ ReqsInFly_.erase(it);
+ OutputBuffers_.AddCancelRequest(reqId);
+ if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) {
+ if (!FlushOutputBuffers(asioThread, 0)) {
+ return;
+ }
+ }
+ }
+ } else if (AtomicGet(NeedCheckCancelsQueue_)) {
+ AS_.GetIOService().Post(std::bind(&TConnection::SendMessages, TConnectionRef(this), true));
+ return;
+ }
+
+ TRequestId lastReqId = 0;
+ {
+ AtomicSet(NeedCheckReqsQueue_, 0);
+ TRequest* reqPtr;
+
+ while (Reqs_.Dequeue(&reqPtr)) {
+ TRequestRef reqTmp(reqPtr);
+ reqPtr->UnRef();
+ reqPtr->SetReqId(GenerateReqId());
+ if (reqPtr->Canceled()) {
+ continue;
+ }
+ lastReqId = reqPtr->ReqId();
+ if (asioThread) {
+ TRequestRef& req = ReqsInFly_[(TRequestId)reqPtr->ReqId()];
+ req.Swap(reqTmp);
+ OutputBuffers_.AddRequest(req);
+ } else { //ReqsInFly_ may be touched only from the asio thread, so enqueue the request for a deferred ReqsInFly_ update
+ try {
+ reqTmp->Ref();
+ ReqsInFlyQueue_.Enqueue(reqPtr);
+ } catch (...) {
+ reqTmp->UnRef();
+ throw;
+ }
+ OutputBuffers_.AddRequest(reqTmp);
+ }
+ if (Y_UNLIKELY(!OutputBuffers_.HasFreeSpace())) {
+ if (!FlushOutputBuffers(asioThread, lastReqId)) {
+ return;
+ }
+ }
+ }
+ }
+
+ if (OutputBuffers_.HasData()) {
+ if (!FlushOutputBuffers(asioThread, lastReqId)) {
+ return;
+ }
+ }
+
+ OutputLock_.Release();
+
+ if (!AtomicGet(NeedCheckReqsQueue_) && !AtomicGet(NeedCheckCancelsQueue_)) {
+ DBGOUT("TClient::SendMessages(exit2)");
+ return;
+ }
+ } while (OutputLock_.TryAquire());
+ DBGOUT("TClient::SendMessages(exit1)");
+ }
+
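+ //id 0 is reserved as "not yet assigned" (see RequestSendedCompletely), so regenerate on wraparound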
+ TRequestId GenerateReqId() noexcept {
+ TRequestId reqId;
+ {
+ auto guard = Guard(GenReqIdLock_);
+ reqId = ++GenReqId_;
+ }
+ return Y_LIKELY(reqId) ? reqId : GenerateReqId();
+ }
+
+ //not thread-safe; called both from the asio thread and directly from user threads (direct-write path)
+ bool FlushOutputBuffers(bool asioThread, TRequestId reqId) {
+ if (asioThread || TTcp2Options::ClientUseDirectWrite) {
+ TContIOVector& vec = *OutputBuffers_.GetIOvec();
+ TErrorCode err;
+ vec.Proceed(AS_.WriteSome(vec, err));
+
+ if (Y_UNLIKELY(err)) {
+ if (asioThread) {
+ OnErrorCode(err);
+ } else {
+ AS_.GetIOService().Post(std::bind(&TConnection::OnErrorCode, TConnectionRef(this), err));
+ }
+ return false;
+ }
+
+ if (vec.Complete()) {
+ LastSendedReqId_.store(reqId, std::memory_order_release);
+ DBGOUT("Client::FlushOutputBuffers(" << reqId << ")");
+ OutputBuffers_.Clear();
+ return true;
+ }
+ }
+
+ DBGOUT("Client::AsyncWrite(" << reqId << ")");
+ AS_.AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), reqId, _1, _2, _3), TTcp2Options::ServerOutputDeadline);
+ return false;
+ }
+
+ //must be called only from asio thread
+ void OnSend(TRequestId reqId, const TErrorCode& ec, size_t amount, IHandlingContext&) {
+ Y_UNUSED(amount);
+ if (Y_UNLIKELY(ec)) {
+ OnErrorCode(ec);
+ } else {
+ if (Y_LIKELY(reqId)) {
+ DBGOUT("Client::OnSend(" << reqId << ")");
+ LastSendedReqId_.store(reqId, std::memory_order_release);
+ }
+ //output lock is still held and we are on the asio thread
+ OutputBuffers_.Clear();
+ SendMessages(true);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) {
+ //DBGOUT("OnCanRead(" << ec.Value() << ")");
+ if (Y_UNLIKELY(ec)) {
+ OnErrorCode(ec);
+ } else {
+ TErrorCode ec2;
+ OnReadSome(ec2, AS_.ReadSome(Buff_.Get(), BuffSize_, ec2), ctx);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) {
+ //DBGOUT("OnReadSome(" << ec.Value() << ", " << amount << ")");
+ if (Y_UNLIKELY(ec)) {
+ OnErrorCode(ec);
+
+ return;
+ }
+
+ while (1) {
+ if (Y_UNLIKELY(!amount)) {
+ OnError("tcp conn. closed");
+
+ return;
+ }
+
+ try {
+ const char* buff = Buff_.Get();
+ size_t leftBytes = amount;
+ do {
+ size_t useBytes = Msg_.LoadFrom(buff, leftBytes);
+ leftBytes -= useBytes;
+ buff += useBytes;
+ if (Msg_.IsComplete()) {
+ //DBGOUT("OnReceiveMessage(" << Msg_.BaseHeader().Id << "): " << leftBytes);
+ OnReceiveMessage();
+ Msg_.Clear();
+ }
+ } while (leftBytes);
+
+ if (amount == BuffSize_) {
+ //try to reduce system calls: re-run ReadSome if the buffer was completely filled
+ TErrorCode ecR;
+ amount = AS_.ReadSome(Buff_.Get(), BuffSize_, ecR);
+ if (!ecR) {
+ continue; //process next input data
+ }
+ if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) {
+ ctx.ContinueUseHandler();
+ } else {
+ OnErrorCode(ecR); //report the re-read error ('ec' here is the earlier, successful code)
+ }
+ } else {
+ ctx.ContinueUseHandler();
+ }
+ } catch (...) {
+ OnError(CurrentExceptionMessage());
+ }
+
+ return;
+ }
+ }
+
+ //must be called only from asio thread
+ void OnErrorCode(TErrorCode ec) {
+ OnError(ec.Text(), ec.Value());
+ }
+
+ //must be called only from asio thread
+ void OnErrorCallback(TString err) {
+ OnError(err);
+ }
+
+ //must be called only from asio thread
+ void OnError(const TString& err, const i32 systemCode = 0) {
+ if (AtomicGet(State_) != Closed) {
+ Error_ = err;
+ SystemCode_ = systemCode;
+ AtomicSet(State_, Closed);
+ AS_.AsyncCancel();
+ }
+ SafeOnError();
+ for (auto& it : ReqsInFly_) {
+ it.second->OnError(err);
+ }
+ ReqsInFly_.clear();
+ }
+
+ void SafeOnError() {
+ TRequest* reqPtr;
+
+ while (Reqs_.Dequeue(&reqPtr)) {
+ TRequestRef req(reqPtr);
+ reqPtr->UnRef();
+ //DBGOUT("err queue(" << AS_.Native() << "):" << size_t(reqPtr));
+ req->OnError(Error_, SystemCode_);
+ }
+
+ while (ReqsInFlyQueue_.Dequeue(&reqPtr)) {
+ TRequestRef req(reqPtr);
+ reqPtr->UnRef();
+ //DBGOUT("err fly queue(" << AS_.Native() << "):" << size_t(reqPtr));
+ req->OnError(Error_, SystemCode_);
+ }
+ }
+
+ //must be called only from asio thread
+ void OnReceiveMessage() {
+ //DBGOUT("OnReceiveMessage");
+ const TBaseHeader& hdr = Msg_.BaseHeader();
+
+ if (hdr.Type == TBaseHeader::Response) {
+ ProcessReqsInFlyQueue();
+ TReqsInFly::iterator it = ReqsInFly_.find(hdr.Id);
+ if (it == ReqsInFly_.end()) {
+ DBGOUT("ignore response: " << hdr.Id);
+ return;
+ }
+
+ it->second->OnResponse(Msg_);
+ ReqsInFly_.erase(it);
+ } else {
+ throw yexception() << TStringBuf("unsupported message type: ") << hdr.Type;
+ }
+ }
+
+ TRequestId LastSendedRequestId() const noexcept {
+ return LastSendedReqId_.load(std::memory_order_acquire);
+ }
+
+ private:
+ NAsio::TTcpSocket AS_;
+ TAtomic State_; //state machine status (TState)
+ TString Error_;
+ i32 SystemCode_ = 0;
+
+ //input
+ size_t BuffSize_;
+ TArrayHolder<char> Buff_;
+ TTcp2Message Msg_;
+
+ //output
+ TOutputLock OutputLock_;
+ TAtomic NeedCheckReqsQueue_;
+ TLockFreeQueue<TRequest*> Reqs_;
+ TAtomic NeedCheckCancelsQueue_;
+ TLockFreeQueue<TRequestId> Cancels_;
+ TAdaptiveLock GenReqIdLock_;
+ std::atomic<TRequestId> GenReqId_;
+ std::atomic<TRequestId> LastSendedReqId_;
+ TLockFreeQueue<TRequest*> ReqsInFlyQueue_;
+ TReqsInFly ReqsInFly_;
+ TOutputBuffers OutputBuffers_;
+ };
+
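+ //per-destination cache of a single connection; when Run() on it fails, the connection is
+ //replaced (compare-and-swap style in CreateConnection) and the request is retried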
+ class TDestination {
+ public:
+ void Run(TRequestRef& req) {
+ while (1) {
+ TConnectionRef conn = GetConnection();
+ if (!!conn && conn->Run(req)) {
+ return;
+ }
+
+ DBGOUT("TDestination CreateConnection");
+ CreateConnection(conn, req->Client().ExecutorsPool().GetExecutor().GetIOService());
+ }
+ }
+
+ private:
+ TConnectionRef GetConnection() {
+ TGuard<TSpinLock> g(L_);
+ return Conn_;
+ }
+
+ void CreateConnection(TConnectionRef& oldConn, TIOService& srv) {
+ TConnectionRef conn(new TConnection(srv));
+ TGuard<TSpinLock> g(L_);
+ if (Conn_ == oldConn) {
+ Conn_.Swap(conn);
+ }
+ }
+
+ TSpinLock L_;
+ TConnectionRef Conn_;
+ };
+
+ //////////// TClient /////////
+
+ public:
+ TClient()
+ : EP_(TTcp2Options::AsioClientThreads)
+ {
+ }
+
+ ~TClient() {
+ EP_.SyncShutdown();
+ }
+
+ THandleRef Schedule(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) {
+ //find an existing connection or create a new one
+ TRequest::THandleRef hndl(new TRequest::THandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
+ try {
+ TRequest::Run(hndl, msg, *this);
+ } catch (...) {
+ hndl->ResetOnRecv();
+ hndl->ReleaseRequest();
+ throw;
+ }
+ return hndl.Get();
+ }
+
+ TExecutorsPool& ExecutorsPool() {
+ return EP_;
+ }
+
+ private:
+ NNeh::NHttp::TLockFreeSequence<TDestination> Dest_;
+ TExecutorsPool EP_;
+ };
+
+ ////////// server side ////////////////////////////////////////////////////////////////////////////////////////////
+
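+ //server side: EA_ accepts connections, per-connection reads parse requests that are handed to the
+ //IOnRequest callback; replies come back through TRequest::SendReply, are queued per connection,
+ //and are flushed under the same TOutputLock scheme as on the client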
+ class TServer: public IRequester {
+ typedef TAutoPtr<TTcpAcceptor> TTcpAcceptorPtr;
+ typedef TAtomicSharedPtr<TTcpSocket> TTcpSocketRef;
+ class TConnection;
+ typedef TIntrusivePtr<TConnection> TConnectionRef;
+
+ struct TRequest: public IRequest {
+ struct TState: public TThrRefBase {
+ TState()
+ : Canceled(false)
+ {
+ }
+
+ TAtomicBool Canceled;
+ };
+ typedef TIntrusivePtr<TState> TStateRef;
+
+ TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content);
+ ~TRequest() override;
+
+ TStringBuf Scheme() const override {
+ return TStringBuf("tcp2");
+ }
+
+ TString RemoteHost() const override;
+
+ TStringBuf Service() const override {
+ return TStringBuf(Buf.Data() + sizeof(TRequestHeader), Buf.End());
+ }
+
+ TStringBuf Data() const override {
+ return TStringBuf(Content_);
+ }
+
+ TStringBuf RequestId() const override {
+ return TStringBuf();
+ }
+
+ bool Canceled() const override {
+ return State->Canceled;
+ }
+
+ void SendReply(TData& data) override;
+
+ void SendError(TResponseError, const TString&) override {
+ // TODO
+ }
+
+ const TRequestHeader& RequestHeader() const noexcept {
+ return *reinterpret_cast<const TRequestHeader*>(Buf.Data());
+ }
+
+ private:
+ TConnectionRef Conn;
+ TBuffer Buf; //service-name + message-data
+ TString Content_;
+ TAtomic Replied_;
+
+ public:
+ TIntrusivePtr<TState> State;
+ };
+
+ class TConnection: public TThrRefBase {
+ private:
+ TConnection(TServer& srv, const TTcpSocketRef& sock)
+ : Srv_(srv)
+ , AS_(sock)
+ , Canceled_(false)
+ , RemoteHost_(NNeh::PrintHostByRfc(*AS_->RemoteEndpoint().Addr()))
+ , BuffSize_(TTcp2Options::InputBufferSize)
+ , Buff_(new char[BuffSize_])
+ , NeedCheckOutputQueue_(0)
+ {
+ DBGOUT("TServer::TConnection()");
+ }
+
+ public:
+ class TOutputBuffers: public TMultiBuffers {
+ public:
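+ //response payloads are moved (swapped) into ResponseData_ so the bytes stay alive until the iovec is sent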
+ void AddResponse(TRequestId reqId, TData& data) {
+ TResponseHeader* hdr = new (Allocate<TResponseHeader>()) TResponseHeader(reqId, TResponseHeader::Success, data.size());
+ ResponseData_.push_back(TAutoPtr<TData>(new TData()));
+ TData& movedData = *ResponseData_.back();
+ movedData.swap(data);
+ AddPart(hdr, sizeof(TResponseHeader));
+ AddPart(movedData.data(), movedData.size());
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void AddError(TRequestId reqId, TResponseHeader::TErrorCode errCode) {
+ TResponseHeader* hdr = new (Allocate<TResponseHeader>()) TResponseHeader(reqId, errCode, 0);
+ AddPart(hdr, sizeof(TResponseHeader));
+ IOVec_ = TContIOVector(Parts_.data(), Parts_.size());
+ }
+
+ void Clear() {
+ TMultiBuffers::Clear();
+ ResponseData_.clear();
+ }
+
+ private:
+ TVector<TAutoPtr<TData>> ResponseData_;
+ };
+
+ static void Create(TServer& srv, const TTcpSocketRef& sock) {
+ TConnectionRef conn(new TConnection(srv, sock));
+ conn->AS_->AsyncPollRead(std::bind(&TConnection::OnCanRead, conn, _1, _2), TTcp2Options::ServerInputDeadline);
+ }
+
+ ~TConnection() override {
+ DBGOUT("~TServer::TConnection(" << (!AS_ ? -666 : AS_->Native()) << ")");
+ }
+
+ private:
+ void OnCanRead(const TErrorCode& ec, IHandlingContext& ctx) {
+ if (ec) {
+ OnError();
+ } else {
+ TErrorCode ec2;
+ OnReadSome(ec2, AS_->ReadSome(Buff_.Get(), BuffSize_, ec2), ctx);
+ }
+ }
+
+ void OnError() {
+ DBGOUT("Srv OnError(" << (!AS_ ? -666 : AS_->Native()) << ")"
+ << " c=" << (size_t)this);
+ Canceled_ = true;
+ AS_->AsyncCancel();
+ }
+
+ void OnReadSome(const TErrorCode& ec, size_t amount, IHandlingContext& ctx) {
+ while (1) {
+ if (ec || !amount) {
+ OnError();
+ return;
+ }
+
+ try {
+ const char* buff = Buff_.Get();
+ size_t leftBytes = amount;
+ do {
+ size_t useBytes = Msg_.LoadFrom(buff, leftBytes);
+ leftBytes -= useBytes;
+ buff += useBytes;
+ if (Msg_.IsComplete()) {
+ OnReceiveMessage();
+ }
+ } while (leftBytes);
+
+ if (amount == BuffSize_) {
+ //try to reduce system calls: re-run ReadSome if the buffer was completely filled
+ TErrorCode ecR;
+ amount = AS_->ReadSome(Buff_.Get(), BuffSize_, ecR);
+ if (!ecR) {
+ continue;
+ }
+ if (ecR.Value() == EAGAIN || ecR.Value() == EWOULDBLOCK) {
+ ctx.ContinueUseHandler();
+ } else {
+ OnError();
+ }
+ } else {
+ ctx.ContinueUseHandler();
+ }
+ } catch (...) {
+ DBGOUT("exc. " << CurrentExceptionMessage());
+ OnError();
+ }
+ return;
+ }
+ }
+
+ void OnReceiveMessage() {
+ DBGOUT("OnReceiveMessage()");
+ const TBaseHeader& hdr = Msg_.BaseHeader();
+
+ if (hdr.Type == TBaseHeader::Request) {
+ TRequest* reqPtr = new TRequest(TConnectionRef(this), Msg_.Header(), Msg_.Content());
+ IRequestRef req(reqPtr);
+ ReqsState_[reqPtr->RequestHeader().Id] = reqPtr->State;
+ OnRequest(req);
+ } else if (hdr.Type == TBaseHeader::Cancel) {
+ OnCancelRequest(hdr.Id);
+ } else {
+ throw yexception() << "unsupported message type: " << (ui32)hdr.Type;
+ }
+ Msg_.Clear();
+ {
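+ //garbage-collect state of requests that have already been answered or destroyed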
+ TRequestId reqId;
+ while (FinReqs_.Dequeue(&reqId)) {
+ ReqsState_.erase(reqId);
+ }
+ }
+ }
+
+ void OnRequest(IRequestRef& r) {
+ DBGOUT("OnRequest()");
+ Srv_.OnRequest(r);
+ }
+
+ void OnCancelRequest(TRequestId reqId) {
+ THashMap<TRequestId, TRequest::TStateRef>::iterator it = ReqsState_.find(reqId);
+ if (it == ReqsState_.end()) {
+ return;
+ }
+
+ it->second->Canceled = true;
+ }
+
+ public:
+ class TOutputData {
+ public:
+ TOutputData(TRequestId reqId)
+ : ReqId(reqId)
+ {
+ }
+
+ virtual ~TOutputData() {
+ }
+
+ virtual void MoveTo(TOutputBuffers& bufs) = 0;
+
+ TRequestId ReqId;
+ };
+
+ class TResponseData: public TOutputData {
+ public:
+ TResponseData(TRequestId reqId, TData& data)
+ : TOutputData(reqId)
+ {
+ Data.swap(data);
+ }
+
+ void MoveTo(TOutputBuffers& bufs) override {
+ bufs.AddResponse(ReqId, Data);
+ }
+
+ TData Data;
+ };
+
+ class TResponseErrorData: public TOutputData {
+ public:
+ TResponseErrorData(TRequestId reqId, TResponseHeader::TErrorCode errorCode)
+ : TOutputData(reqId)
+ , ErrorCode(errorCode)
+ {
+ }
+
+ void MoveTo(TOutputBuffers& bufs) override {
+ bufs.AddError(ReqId, ErrorCode);
+ }
+
+ TResponseHeader::TErrorCode ErrorCode;
+ };
+
+ //not thread-safe (called from the request-handling thread)
+ void SendResponse(TRequestId reqId, TData& data) {
+ DBGOUT("SendResponse: " << reqId << " " << (size_t)~data << " c=" << (size_t)this);
+ TAutoPtr<TOutputData> od(new TResponseData(reqId, data));
+ OutputData_.Enqueue(od);
+ ProcessOutputQueue();
+ }
+
+ //not thread-safe (called from an outside thread)
+ void SendError(TRequestId reqId, TResponseHeader::TErrorCode err) {
+ DBGOUT("SendResponseError: " << reqId << " c=" << (size_t)this);
+ TAutoPtr<TOutputData> od(new TResponseErrorData(reqId, err));
+ OutputData_.Enqueue(od);
+ ProcessOutputQueue();
+ }
+
+ void ProcessOutputQueue() {
+ AtomicSet(NeedCheckOutputQueue_, 1);
+ if (OutputLock_.TryAquire()) {
+ SendMessages(false);
+ return;
+ }
+ DBGOUT("ProcessOutputQueue: !AquireOutputOwnership: " << (int)OutputLock_.IsFree());
+ }
+
+ //must be called only after successfully acquiring the output lock
+ void SendMessages(bool asioThread) {
+ DBGOUT("TServer::SendMessages(enter)");
+ try {
+ do {
+ AtomicUnlock(&NeedCheckOutputQueue_);
+ TAutoPtr<TOutputData> d;
+ while (OutputData_.Dequeue(&d)) {
+ d->MoveTo(OutputBuffers_);
+ if (!OutputBuffers_.HasFreeSpace()) {
+ if (!FlushOutputBuffers(asioThread)) {
+ return;
+ }
+ }
+ }
+
+ if (OutputBuffers_.HasData()) {
+ if (!FlushOutputBuffers(asioThread)) {
+ return;
+ }
+ }
+
+ OutputLock_.Release();
+
+ if (!AtomicGet(NeedCheckOutputQueue_)) {
+ DBGOUT("Server::SendMessages(exit2): " << (int)OutputLock_.IsFree());
+ return;
+ }
+ } while (OutputLock_.TryAquire());
+ DBGOUT("Server::SendMessages(exit1)");
+ } catch (...) {
+ OnError();
+ }
+ }
+
+ bool FlushOutputBuffers(bool asioThread) {
+ DBGOUT("FlushOutputBuffers: cnt=" << OutputBuffers_.GetIOvec()->Count() << " c=" << (size_t)this);
+ //TODO: research direct write efficiency
+ if (asioThread || TTcp2Options::ServerUseDirectWrite) {
+ TContIOVector& vec = *OutputBuffers_.GetIOvec();
+
+ vec.Proceed(AS_->WriteSome(vec));
+
+ if (vec.Complete()) {
+ OutputBuffers_.Clear();
+ //DBGOUT("WriteResponse: " << " c=" << (size_t)this);
+ return true;
+ }
+ }
+
+ //socket buffer is full: finish sending the remaining data with an async write
+ DBGOUT("AsyncWriteResponse: "
+ << " [" << OutputBuffers_.GetIOvec()->Bytes() << "]"
+ << " c=" << (size_t)this);
+ AS_->AsyncWrite(OutputBuffers_.GetIOvec(), std::bind(&TConnection::OnSend, TConnectionRef(this), _1, _2, _3), TTcp2Options::ServerOutputDeadline);
+ return false;
+ }
+
+ void OnFinishRequest(TRequestId reqId) {
+ if (Y_LIKELY(!Canceled_)) {
+ FinReqs_.Enqueue(reqId);
+ }
+ }
+
+ private:
+ void OnSend(const TErrorCode& ec, size_t amount, IHandlingContext&) {
+ Y_UNUSED(amount);
+ DBGOUT("TServer::OnSend(" << ec.Value() << ", " << amount << ")");
+ if (ec) {
+ OnError();
+ } else {
+ OutputBuffers_.Clear();
+ SendMessages(true);
+ }
+ }
+
+ public:
+ bool IsCanceled() const noexcept {
+ return Canceled_;
+ }
+
+ const TString& RemoteHost() const noexcept {
+ return RemoteHost_;
+ }
+
+ private:
+ TServer& Srv_;
+ TTcpSocketRef AS_;
+ NAtomic::TBool Canceled_;
+ TString RemoteHost_;
+
+ //input
+ size_t BuffSize_;
+ TArrayHolder<char> Buff_;
+ TTcp2Message Msg_;
+ THashMap<TRequestId, TRequest::TStateRef> ReqsState_;
+ TLockFreeQueue<TRequestId> FinReqs_;
+
+ //output
+ TOutputLock OutputLock_; //protects socket/buffers from simultaneous access by multiple threads
+ TAtomic NeedCheckOutputQueue_;
+ NNeh::TAutoLockFreeQueue<TOutputData> OutputData_;
+ TOutputBuffers OutputBuffers_;
+ };
+
+ //////////// TServer /////////
+ public:
+ TServer(IOnRequest* cb, ui16 port)
+ : EP_(TTcp2Options::AsioServerThreads)
+ , CB_(cb)
+ {
+ TNetworkAddress addr(port);
+
+ for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) {
+ TEndpoint ep(new NAddr::TAddrInfo(&*it));
+ TTcpAcceptorPtr a(new TTcpAcceptor(EA_.GetIOService()));
+ //DBGOUT("bind:" << ep.IpToString() << ":" << ep.Port());
+ a->Bind(ep);
+ a->Listen(TTcp2Options::Backlog);
+ StartAccept(a.Get());
+ A_.push_back(a);
+ }
+ }
+
+ ~TServer() override {
+ EA_.SyncShutdown(); //cancel accepting connections
+ A_.clear(); //stop listening
+ EP_.SyncShutdown(); //close all existing connections
+ }
+
+ void StartAccept(TTcpAcceptor* a) {
+ const auto s = MakeAtomicShared<TTcpSocket>(EP_.Size() ? EP_.GetExecutor().GetIOService() : EA_.GetIOService());
+ a->AsyncAccept(*s, std::bind(&TServer::OnAccept, this, a, s, _1, _2));
+ }
+
+ void OnAccept(TTcpAcceptor* a, TTcpSocketRef s, const TErrorCode& ec, IHandlingContext&) {
+ if (Y_UNLIKELY(ec)) {
+ if (ec.Value() == ECANCELED) {
+ return;
+ } else if (ec.Value() == EMFILE || ec.Value() == ENFILE || ec.Value() == ENOMEM || ec.Value() == ENOBUFS) {
+ //hit an os limit; suspend accepting to avoid a busy loop (100% cpu usage)
+ TSimpleSharedPtr<TDeadlineTimer> dt(new TDeadlineTimer(a->GetIOService()));
+ dt->AsyncWaitExpireAt(TDuration::Seconds(30), std::bind(&TServer::OnTimeoutSuspendAccept, this, a, dt, _1, _2));
+ } else {
+ Cdbg << "acc: " << ec.Text() << Endl;
+ }
+ } else {
+ SetNonBlock(s->Native());
+ PrepareSocket(s->Native());
+ TConnection::Create(*this, s);
+ }
+ StartAccept(a); //continue accepting
+ }
+
+ void OnTimeoutSuspendAccept(TTcpAcceptor* a, TSimpleSharedPtr<TDeadlineTimer>, const TErrorCode& ec, IHandlingContext&) {
+ if (!ec) {
+ DBGOUT("resume acceptor");
+ StartAccept(a);
+ }
+ }
+
+ void OnRequest(IRequestRef& r) {
+ try {
+ CB_->OnRequest(r);
+ } catch (...) {
+ Cdbg << CurrentExceptionMessage() << Endl;
+ }
+ }
+
+ private:
+ TVector<TTcpAcceptorPtr> A_;
+ TIOServiceExecutor EA_; //thread where incoming tcp connections are accepted
+ TExecutorsPool EP_; //threads that read/write data from/to tcp connections (if empty, EA_ is used for r/w)
+ IOnRequest* CB_;
+ };
+
+ TServer::TRequest::TRequest(const TConnectionRef& conn, TBuffer& buf, const TString& content)
+ : Conn(conn)
+ , Content_(content)
+ , Replied_(0)
+ , State(new TState())
+ {
+ DBGOUT("TServer::TRequest()");
+ Buf.Swap(buf);
+ }
+
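+ //a request destroyed without SendReply() automatically answers with EmptyReply,
+ //so a client never waits forever on a handled-but-unanswered request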
+ TServer::TRequest::~TRequest() {
+ DBGOUT("TServer::~TRequest()");
+ if (!AtomicGet(Replied_)) {
+ Conn->SendError(RequestHeader().Id, TResponseHeader::EmptyReply);
+ }
+ Conn->OnFinishRequest(RequestHeader().Id);
+ }
+
+ TString TServer::TRequest::RemoteHost() const {
+ return Conn->RemoteHost();
+ }
+
+ void TServer::TRequest::SendReply(TData& data) {
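+ //the Replied_ CAS lets only the first reply through and tells the destructor to skip EmptyReply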
+ do {
+ if (AtomicCas(&Replied_, 1, 0)) {
+ Conn->SendResponse(RequestHeader().Id, data);
+ return;
+ }
+ } while (AtomicGet(Replied_) == 0);
+ }
+
+ class TProtocol: public IProtocol {
+ public:
+ inline TProtocol() {
+ InitNetworkSubSystem();
+ }
+
+ IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
+ return new TServer(cb, loc.GetPort());
+ }
+
+ THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
+ return Singleton<TClient>()->Schedule(msg, fallback, ss);
+ }
+
+ TStringBuf Scheme() const noexcept override {
+ return TStringBuf("tcp2");
+ }
+
+ bool SetOption(TStringBuf name, TStringBuf value) override {
+ return TTcp2Options::Set(name, value);
+ }
+ };
+ }
+}
+
+NNeh::IProtocol* NNeh::Tcp2Protocol() {
+ return Singleton<NNehTcp2::TProtocol>();
+}
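+
+//the "tcp2" scheme returned by Scheme() together with TParsedLocation suggests addresses of the
+//form "tcp2://host:port/service", resolved through the protocol factory (factory.h)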