diff options
author | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
---|---|---|
committer | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
commit | 06e5c21a835c0e923506c4ff27929f34e00761c2 (patch) | |
tree | 75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/neh/tcp.cpp | |
parent | 03f024c4412e3aa613bb543cf1660176320ba8f4 (diff) | |
download | ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz |
fix ya.make
Diffstat (limited to 'library/cpp/neh/tcp.cpp')
-rw-r--r-- | library/cpp/neh/tcp.cpp | 676 |
1 files changed, 676 insertions, 0 deletions
diff --git a/library/cpp/neh/tcp.cpp b/library/cpp/neh/tcp.cpp new file mode 100644 index 00000000000..80f464dac25 --- /dev/null +++ b/library/cpp/neh/tcp.cpp @@ -0,0 +1,676 @@ +#include "tcp.h" + +#include "details.h" +#include "factory.h" +#include "location.h" +#include "pipequeue.h" +#include "utils.h" + +#include <library/cpp/coroutine/listener/listen.h> +#include <library/cpp/coroutine/engine/events.h> +#include <library/cpp/coroutine/engine/sockpool.h> +#include <library/cpp/dns/cache.h> + +#include <util/ysaveload.h> +#include <util/generic/buffer.h> +#include <util/generic/guid.h> +#include <util/generic/hash.h> +#include <util/generic/intrlist.h> +#include <util/generic/ptr.h> +#include <util/generic/vector.h> +#include <util/system/yassert.h> +#include <util/system/unaligned_mem.h> +#include <util/stream/buffered.h> +#include <util/stream/mem.h> + +using namespace NDns; +using namespace NNeh; + +using TNehMessage = TMessage; + +template <> +struct TSerializer<TGUID> { + static inline void Save(IOutputStream* out, const TGUID& g) { + out->Write(&g.dw, sizeof(g.dw)); + } + + static inline void Load(IInputStream* in, TGUID& g) { + in->Load(&g.dw, sizeof(g.dw)); + } +}; + +namespace { + namespace NNehTCP { + typedef IOutputStream::TPart TPart; + + static inline ui64 LocalGuid(const TGUID& g) { + return ReadUnaligned<ui64>(g.dw); + } + + static inline TString LoadStroka(IInputStream& input, size_t len) { + TString tmp; + + tmp.ReserveAndResize(len); + input.Load(tmp.begin(), tmp.size()); + + return tmp; + } + + struct TParts: public TVector<TPart> { + template <class T> + inline void Push(const T& t) { + Push(TPart(t)); + } + + inline void Push(const TPart& part) { + if (part.len) { + push_back(part); + } + } + + inline void Clear() noexcept { + clear(); + } + }; + + template <class T> + struct TMessageQueue { + inline TMessageQueue(TContExecutor* e) + : Ev(e) + { + } + + template <class TPtr> + inline void Enqueue(TPtr p) noexcept { + L.PushBack(p.Release()); + Ev.Signal(); + } + + template <class TPtr> + inline bool Dequeue(TPtr& p) noexcept { + do { + if (TryDequeue(p)) { + return true; + } + } while (Ev.WaitI() != ECANCELED); + + return false; + } + + template <class TPtr> + inline bool TryDequeue(TPtr& p) noexcept { + if (L.Empty()) { + return false; + } + + p.Reset(L.PopFront()); + + return true; + } + + inline TContExecutor* Executor() const noexcept { + return Ev.Executor(); + } + + TIntrusiveListWithAutoDelete<T, TDelete> L; + TContSimpleEvent Ev; + }; + + template <class Q, class C> + inline bool Dequeue(Q& q, C& c, size_t len) { + typename C::value_type t; + size_t slen = 0; + + if (q.Dequeue(t)) { + slen += t->Length(); + c.push_back(t); + + while (slen < len && q.TryDequeue(t)) { + slen += t->Length(); + c.push_back(t); + } + + return true; + } + + return false; + } + + struct TServer: public IRequester, public TContListener::ICallBack { + struct TLink; + typedef TIntrusivePtr<TLink> TLinkRef; + + struct TResponce: public TIntrusiveListItem<TResponce> { + inline TResponce(const TLinkRef& link, TData& data, TStringBuf reqid) + : Link(link) + { + Data.swap(data); + + TMemoryOutput out(Buf, sizeof(Buf)); + + ::Save(&out, (ui32)(reqid.size() + Data.size())); + out.Write(reqid.data(), reqid.size()); + + Y_ASSERT(reqid.size() == 16); + + Len = out.Buf() - Buf; + } + + inline void Serialize(TParts& parts) { + parts.Push(TStringBuf(Buf, Len)); + parts.Push(TStringBuf(Data.data(), Data.size())); + } + + inline size_t Length() const noexcept { + return Len + Data.size(); + } + + TLinkRef Link; + TData Data; + char Buf[32]; + size_t Len; + }; + + typedef TAutoPtr<TResponce> TResponcePtr; + + struct TRequest: public IRequest { + inline TRequest(const TLinkRef& link, IInputStream& in, size_t len) + : Link(link) + { + Buf.Proceed(len); + in.Load(Buf.Data(), Buf.Size()); + if ((ServiceBegin() - Buf.Data()) + ServiceLen() > Buf.Size()) { + throw yexception() << "invalid request (service len)"; + } + } + + TStringBuf Scheme() const override { + return TStringBuf("tcp"); + } + + TString RemoteHost() const override { + return Link->RemoteHost; + } + + TStringBuf Service() const override { + return TStringBuf(ServiceBegin(), ServiceLen()); + } + + TStringBuf Data() const override { + return TStringBuf(Service().end(), Buf.End()); + } + + TStringBuf RequestId() const override { + return TStringBuf(Buf.Data(), 16); + } + + bool Canceled() const override { + //TODO + return false; + } + + void SendReply(TData& data) override { + Link->P->Schedule(new TResponce(Link, data, RequestId())); + } + + void SendError(TResponseError, const TString&) override { + // TODO + } + + size_t ServiceLen() const noexcept { + const char* ptr = RequestId().end(); + return *(ui32*)ptr; + } + + const char* ServiceBegin() const noexcept { + return RequestId().end() + sizeof(ui32); + } + + TBuffer Buf; + TLinkRef Link; + }; + + struct TLink: public TAtomicRefCount<TLink> { + inline TLink(TServer* parent, const TAcceptFull& a) + : P(parent) + , MQ(Executor()) + { + S.Swap(*a.S); + SetNoDelay(S, true); + + RemoteHost = PrintHostByRfc(*GetPeerAddr(S)); + + TLinkRef self(this); + + Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv"); + Executor()->Create<TLink, &TLink::SendCycle>(this, "send"); + + Executor()->Running()->Yield(); + } + + inline void Enqueue(TResponcePtr res) { + MQ.Enqueue(res); + } + + inline TContExecutor* Executor() const noexcept { + return P->E.Get(); + } + + void SendCycle(TCont* c) { + TLinkRef self(this); + + try { + DoSendCycle(c); + } catch (...) { + Cdbg << "neh/tcp/1: " << CurrentExceptionMessage() << Endl; + } + } + + inline void DoSendCycle(TCont* c) { + TVector<TResponcePtr> responses; + TParts parts; + + while (Dequeue(MQ, responses, 7000)) { + for (size_t i = 0; i < responses.size(); ++i) { + responses[i]->Serialize(parts); + } + + { + TContIOVector iovec(parts.data(), parts.size()); + NCoro::WriteVectorI(c, S, &iovec); + } + + parts.Clear(); + responses.clear(); + } + } + + void RecvCycle(TCont* c) { + TLinkRef self(this); + + try { + DoRecvCycle(c); + } catch (...) { + if (!c->Cancelled()) { + Cdbg << "neh/tcp/2: " << CurrentExceptionMessage() << Endl; + } + } + } + + inline void DoRecvCycle(TCont* c) { + TContIO io(S, c); + TBufferedInput input(&io, 8192 * 4); + + while (true) { + ui32 len; + + try { + ::Load(&input, len); + } catch (TLoadEOF&) { + return; + } + + P->CB->OnRequest(new TRequest(this, input, len)); + } + } + + TServer* P; + TMessageQueue<TResponce> MQ; + TSocketHolder S; + TString RemoteHost; + }; + + inline TServer(IOnRequest* cb, ui16 port) + : CB(cb) + , Addr(port) + { + Thrs.push_back(Spawn<TServer, &TServer::Run>(this)); + } + + ~TServer() override { + Schedule(nullptr); + + for (size_t i = 0; i < Thrs.size(); ++i) { + Thrs[i]->Join(); + } + } + + void Run() { + E = MakeHolder<TContExecutor>(RealStackSize(32000)); + THolder<TContListener> L(new TContListener(this, E.Get(), TContListener::TOptions().SetDeferAccept(true))); + //SetHighestThreadPriority(); + L->Bind(Addr); + E->Create<TServer, &TServer::RunDispatcher>(this, "dispatcher"); + L->Listen(); + E->Execute(); + } + + void OnAcceptFull(const TAcceptFull& a) override { + //I love such code + new TLink(this, a); + } + + void OnError() override { + Cerr << CurrentExceptionMessage() << Endl; + } + + inline void Schedule(TResponcePtr res) { + PQ.EnqueueSafe(res); + } + + void RunDispatcher(TCont* c) { + while (true) { + TResponcePtr res; + + PQ.DequeueSafe(c, res); + + if (!res) { + break; + } + + TLinkRef link = res->Link; + + link->Enqueue(res); + } + + c->Executor()->Abort(); + } + THolder<TContExecutor> E; + IOnRequest* CB; + TNetworkAddress Addr; + TOneConsumerPipeQueue<TResponce> PQ; + TVector<TThreadRef> Thrs; + }; + + struct TClient { + struct TRequest: public TIntrusiveListItem<TRequest> { + inline TRequest(const TSimpleHandleRef& hndl, const TNehMessage& msg) + : Hndl(hndl) + , Msg(msg) + , Loc(Msg.Addr) + , RI(CachedThrResolve(TResolveInfo(Loc.Host, Loc.GetPort()))) + { + CreateGuid(&Guid); + } + + inline void Serialize(TParts& parts) { + TMemoryOutput out(Buf, sizeof(Buf)); + + ::Save(&out, (ui32)MsgLen()); + ::Save(&out, Guid); + ::Save(&out, (ui32) Loc.Service.size()); + + if (Loc.Service.size() > out.Avail()) { + parts.Push(TStringBuf(Buf, out.Buf())); + parts.Push(Loc.Service); + } else { + out.Write(Loc.Service.data(), Loc.Service.size()); + parts.Push(TStringBuf(Buf, out.Buf())); + } + + parts.Push(Msg.Data); + } + + inline size_t Length() const noexcept { + return sizeof(ui32) + MsgLen(); + } + + inline size_t MsgLen() const noexcept { + return sizeof(Guid.dw) + sizeof(ui32) + Loc.Service.size() + Msg.Data.size(); + } + + void OnError(const TString& errText) { + Hndl->NotifyError(errText); + } + + TSimpleHandleRef Hndl; + TNehMessage Msg; + TGUID Guid; + const TParsedLocation Loc; + const TResolvedHost* RI; + char Buf[128]; + }; + + typedef TAutoPtr<TRequest> TRequestPtr; + + struct TChannel { + struct TLink: public TIntrusiveListItem<TLink>, public TSimpleRefCount<TLink> { + inline TLink(TChannel* parent) + : P(parent) + { + Executor()->Create<TLink, &TLink::SendCycle>(this, "send"); + } + + void SendCycle(TCont* c) { + TIntrusivePtr<TLink> self(this); + + try { + DoSendCycle(c); + OnError("shutdown"); + } catch (...) { + OnError(CurrentExceptionMessage()); + } + + Unlink(); + } + + inline void DoSendCycle(TCont* c) { + if (int ret = NCoro::ConnectI(c, S, P->RI->Addr)) { + ythrow TSystemError(ret) << "can't connect"; + } + SetNoDelay(S, true); + Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv"); + + TVector<TRequestPtr> reqs; + TParts parts; + + while (Dequeue(P->Q, reqs, 7000)) { + for (size_t i = 0; i < reqs.size(); ++i) { + TRequestPtr& req = reqs[i]; + + req->Serialize(parts); + InFly[LocalGuid(req->Guid)] = req; + } + + { + TContIOVector vec(parts.data(), parts.size()); + NCoro::WriteVectorI(c, S, &vec); + } + + reqs.clear(); + parts.Clear(); + } + } + + void RecvCycle(TCont* c) { + TIntrusivePtr<TLink> self(this); + + try { + DoRecvCycle(c); + OnError("service close connection"); + } catch (...) { + OnError(CurrentExceptionMessage()); + } + } + + inline void DoRecvCycle(TCont* c) { + TContIO io(S, c); + TBufferedInput input(&io, 8192 * 4); + + while (true) { + ui32 len; + TGUID g; + + try { + ::Load(&input, len); + } catch (TLoadEOF&) { + return; + } + ::Load(&input, g); + const TString data(LoadStroka(input, len - sizeof(g.dw))); + + TInFly::iterator it = InFly.find(LocalGuid(g)); + + if (it == InFly.end()) { + continue; + } + + TRequestPtr req = it->second; + + InFly.erase(it); + req->Hndl->NotifyResponse(data); + } + } + + inline TContExecutor* Executor() const noexcept { + return P->Q.Executor(); + } + + void OnError(const TString& errText) { + for (auto& it : InFly) { + it.second->OnError(errText); + } + InFly.clear(); + + TRequestPtr req; + while (P->Q.TryDequeue(req)) { + req->OnError(errText); + } + } + + TChannel* P; + TSocketHolder S; + typedef THashMap<ui64, TRequestPtr> TInFly; + TInFly InFly; + }; + + inline TChannel(TContExecutor* e, const TResolvedHost* ri) + : Q(e) + , RI(ri) + { + } + + inline void Enqueue(TRequestPtr req) { + Q.Enqueue(req); + + if (Links.Empty()) { + for (size_t i = 0; i < 1; ++i) { + SpawnLink(); + } + } + } + + inline void SpawnLink() { + Links.PushBack(new TLink(this)); + } + + TMessageQueue<TRequest> Q; + TIntrusiveList<TLink> Links; + const TResolvedHost* RI; + }; + + typedef TAutoPtr<TChannel> TChannelPtr; + + inline TClient() { + Thr = Spawn<TClient, &TClient::RunExecutor>(this); + } + + inline ~TClient() { + Reqs.Enqueue(nullptr); + Thr->Join(); + } + + inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) { + TSimpleHandleRef ret(new TSimpleHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + + Reqs.Enqueue(new TRequest(ret, msg)); + + return ret.Get(); + } + + void RunExecutor() { + //SetHighestThreadPriority(); + TContExecutor e(RealStackSize(32000)); + + e.Create<TClient, &TClient::RunDispatcher>(this, "dispatcher"); + e.Execute(); + } + + void RunDispatcher(TCont* c) { + TRequestPtr req; + + while (true) { + Reqs.DequeueSafe(c, req); + + if (!req) { + break; + } + + TChannelPtr& ch = Channels.Get(req->RI->Id); + + if (!ch) { + ch.Reset(new TChannel(c->Executor(), req->RI)); + } + + ch->Enqueue(req); + } + + c->Executor()->Abort(); + } + + TThreadRef Thr; + TOneConsumerPipeQueue<TRequest> Reqs; + TSocketMap<TChannelPtr> Channels; + }; + + struct TMultiClient { + inline TMultiClient() + : Next(0) + { + for (size_t i = 0; i < 2; ++i) { + Clients.push_back(new TClient()); + } + } + + inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) { + return Clients[AtomicIncrement(Next) % Clients.size()]->Schedule(msg, fallback, ss); + } + + TVector<TAutoPtr<TClient>> Clients; + TAtomic Next; + }; + +#if 0 + static inline TMultiClient* Client() { + return Singleton<NNehTCP::TMultiClient>(); + } +#else + static inline TClient* Client() { + return Singleton<NNehTCP::TClient>(); + } +#endif + + class TTcpProtocol: public IProtocol { + public: + inline TTcpProtocol() { + InitNetworkSubSystem(); + } + + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new TServer(cb, loc.GetPort()); + } + + THandleRef ScheduleRequest(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + return Client()->Schedule(msg, fallback, ss); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("tcp"); + } + }; + } +} + +IProtocol* NNeh::TcpProtocol() { + return Singleton<NNehTCP::TTcpProtocol>(); +} |