#include "tcp.h"
#include "details.h"
#include "factory.h"
#include "location.h"
#include "pipequeue.h"
#include "utils.h"
#include <library/cpp/coroutine/listener/listen.h>
#include <library/cpp/coroutine/engine/events.h>
#include <library/cpp/coroutine/engine/sockpool.h>
#include <library/cpp/dns/cache.h>
#include <util/ysaveload.h>
#include <util/generic/buffer.h>
#include <util/generic/guid.h>
#include <util/generic/hash.h>
#include <util/generic/intrlist.h>
#include <util/generic/ptr.h>
#include <util/generic/vector.h>
#include <util/system/yassert.h>
#include <util/system/unaligned_mem.h>
#include <util/stream/buffered.h>
#include <util/stream/mem.h>
using namespace NDns;
using namespace NNeh;
using TNehMessage = TMessage;
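// TGUID is (de)serialized as its raw 16 bytes (the four ui32 words of dw),
// which is exactly how request ids travel on the wire.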
template <>
struct TSerializer<TGUID> {
static inline void Save(IOutputStream* out, const TGUID& g) {
out->Write(&g.dw, sizeof(g.dw));
}
static inline void Load(IInputStream* in, TGUID& g) {
in->Load(&g.dw, sizeof(g.dw));
}
};
namespace {
namespace NNehTCP {
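// Wire format, both directions, one frame per message:
//   ui32 len        - size of everything that follows
//   char guid[16]   - request id (TGUID::dw)
// Request frames continue with:
//   ui32 service_len, char service[service_len], char data[...]
// Response frames continue with:
//   char data[len - 16]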
typedef IOutputStream::TPart TPart;
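// Requests and responses are correlated by the first 64 bits of the GUID.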
static inline ui64 LocalGuid(const TGUID& g) {
return ReadUnaligned<ui64>(g.dw);
}
static inline TString LoadStroka(IInputStream& input, size_t len) {
TString tmp;
tmp.ReserveAndResize(len);
input.Load(tmp.begin(), tmp.size());
return tmp;
}
struct TParts: public TVector<TPart> {
template <class T>
inline void Push(const T& t) {
Push(TPart(t));
}
inline void Push(const TPart& part) {
if (part.len) {
push_back(part);
}
}
inline void Clear() noexcept {
clear();
}
};
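// Coroutine-local message queue: Enqueue signals the event, Dequeue blocks
// the calling coroutine until an item arrives or the executor is cancelled.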
template <class T>
struct TMessageQueue {
inline TMessageQueue(TContExecutor* e)
: Ev(e)
{
}
template <class TPtr>
inline void Enqueue(TPtr p) noexcept {
L.PushBack(p.Release());
Ev.Signal();
}
template <class TPtr>
inline bool Dequeue(TPtr& p) noexcept {
do {
if (TryDequeue(p)) {
return true;
}
} while (Ev.WaitI() != ECANCELED);
return false;
}
template <class TPtr>
inline bool TryDequeue(TPtr& p) noexcept {
if (L.Empty()) {
return false;
}
p.Reset(L.PopFront());
return true;
}
inline TContExecutor* Executor() const noexcept {
return Ev.Executor();
}
TIntrusiveListWithAutoDelete<T, TDelete> L;
TContSimpleEvent Ev;
};
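// Batched dequeue: block for the first item, then greedily take more until at
// least len bytes are collected, so the caller can flush one big vectored
// write instead of many small ones (callers use ~7000-byte batches).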
template <class Q, class C>
inline bool Dequeue(Q& q, C& c, size_t len) {
typename C::value_type t;
size_t slen = 0;
if (q.Dequeue(t)) {
slen += t->Length();
c.push_back(t);
while (slen < len && q.TryDequeue(t)) {
slen += t->Length();
c.push_back(t);
}
return true;
}
return false;
}
struct TServer: public IRequester, public TContListener::ICallBack {
struct TLink;
typedef TIntrusivePtr<TLink> TLinkRef;
struct TResponse: public TIntrusiveListItem<TResponse> {
inline TResponse(const TLinkRef& link, TData& data, TStringBuf reqid)
: Link(link)
{
Data.swap(data);
TMemoryOutput out(Buf, sizeof(Buf));
Y_ASSERT(reqid.size() == 16);
::Save(&out, (ui32)(reqid.size() + Data.size()));
out.Write(reqid.data(), reqid.size());
Len = out.Buf() - Buf;
}
inline void Serialize(TParts& parts) {
parts.Push(TStringBuf(Buf, Len));
parts.Push(TStringBuf(Data.data(), Data.size()));
}
inline size_t Length() const noexcept {
return Len + Data.size();
}
TLinkRef Link;
TData Data;
char Buf[32];
size_t Len;
};
typedef TAutoPtr<TResponse> TResponsePtr;
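// A parsed incoming frame; Buf holds the raw payload:
// [guid(16)][service_len(ui32)][service][data].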
struct TRequest: public IRequest {
inline TRequest(const TLinkRef& link, IInputStream& in, size_t len)
: Link(link)
{
Buf.Proceed(len);
in.Load(Buf.Data(), Buf.Size());
if (Buf.Size() < 16 + sizeof(ui32)) {
throw yexception() << "invalid request (missing header)";
}
if ((ServiceBegin() - Buf.Data()) + ServiceLen() > Buf.Size()) {
throw yexception() << "invalid request (service len)";
}
}
TStringBuf Scheme() const override {
return TStringBuf("tcp");
}
TString RemoteHost() const override {
return Link->RemoteHost;
}
TStringBuf Service() const override {
return TStringBuf(ServiceBegin(), ServiceLen());
}
TStringBuf Data() const override {
return TStringBuf(Service().end(), Buf.End());
}
TStringBuf RequestId() const override {
return TStringBuf(Buf.Data(), 16);
}
bool Canceled() const override {
//TODO
return false;
}
void SendReply(TData& data) override {
Link->P->Schedule(new TResponse(Link, data, RequestId()));
}
void SendError(TResponseError, const TString&) override {
// TODO
}
size_t ServiceLen() const noexcept {
// the length field after the 16-byte request id is not necessarily aligned
return ReadUnaligned<ui32>(RequestId().end());
}
const char* ServiceBegin() const noexcept {
return RequestId().end() + sizeof(ui32);
}
TBuffer Buf;
TLinkRef Link;
};
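// One accepted connection. The recv coroutine parses request frames and hands
// them to the request callback; the send coroutine batches queued responses.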
struct TLink: public TAtomicRefCount<TLink> {
inline TLink(TServer* parent, const TAcceptFull& a)
: P(parent)
, MQ(Executor())
{
S.Swap(*a.S);
SetNoDelay(S, true);
RemoteHost = PrintHostByRfc(*GetPeerAddr(S));
TLinkRef self(this);
Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv");
Executor()->Create<TLink, &TLink::SendCycle>(this, "send");
Executor()->Running()->Yield();
}
inline void Enqueue(TResponsePtr res) {
MQ.Enqueue(res);
}
inline TContExecutor* Executor() const noexcept {
return P->E.Get();
}
void SendCycle(TCont* c) {
TLinkRef self(this);
try {
DoSendCycle(c);
} catch (...) {
Cdbg << "neh/tcp/1: " << CurrentExceptionMessage() << Endl;
}
}
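// Drain the response queue in ~7000-byte batches and flush each batch with a
// single vectored write to reduce syscall count.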
inline void DoSendCycle(TCont* c) {
TVector<TResponcePtr> responses;
TParts parts;
while (Dequeue(MQ, responses, 7000)) {
for (size_t i = 0; i < responses.size(); ++i) {
responses[i]->Serialize(parts);
}
{
TContIOVector iovec(parts.data(), parts.size());
NCoro::WriteVectorI(c, S, &iovec);
}
parts.Clear();
responses.clear();
}
}
void RecvCycle(TCont* c) {
TLinkRef self(this);
try {
DoRecvCycle(c);
} catch (...) {
if (!c->Cancelled()) {
Cdbg << "neh/tcp/2: " << CurrentExceptionMessage() << Endl;
}
}
}
inline void DoRecvCycle(TCont* c) {
TContIO io(S, c);
TBufferedInput input(&io, 8192 * 4);
while (true) {
ui32 len;
try {
::Load(&input, len);
} catch (TLoadEOF&) {
return;
}
P->CB->OnRequest(new TRequest(this, input, len));
}
}
TServer* P;
TMessageQueue<TResponse> MQ;
TSocketHolder S;
TString RemoteHost;
};
inline TServer(IOnRequest* cb, ui16 port)
: CB(cb)
, Addr(port)
{
Thrs.push_back(Spawn<TServer, &TServer::Run>(this));
}
~TServer() override {
Schedule(nullptr);
for (size_t i = 0; i < Thrs.size(); ++i) {
Thrs[i]->Join();
}
}
void Run() {
E = MakeHolder<TContExecutor>(RealStackSize(32000));
THolder<TContListener> L(new TContListener(this, E.Get(), TContListener::TOptions().SetDeferAccept(true)));
//SetHighestThreadPriority();
L->Bind(Addr);
E->Create<TServer, &TServer::RunDispatcher>(this, "dispatcher");
L->Listen();
E->Execute();
}
void OnAcceptFull(const TAcceptFull& a) override {
// the link owns itself: its recv/send coroutines each hold a TLinkRef,
// so the object is released once both coroutines finish
new TLink(this, a);
}
void OnError() override {
Cerr << CurrentExceptionMessage() << Endl;
}
inline void Schedule(TResponsePtr res) {
PQ.EnqueueSafe(res);
}
void RunDispatcher(TCont* c) {
while (true) {
TResponsePtr res;
PQ.DequeueSafe(c, res);
if (!res) {
break;
}
TLinkRef link = res->Link;
link->Enqueue(res);
}
c->Executor()->Abort();
}
THolder<TContExecutor> E;
IOnRequest* CB;
TNetworkAddress Addr;
TOneConsumerPipeQueue<TResponse> PQ;
TVector<TThreadRef> Thrs;
};
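// Client side: one thread runs a coroutine executor; requests are routed to a
// per-host TChannel which multiplexes them over a single connection.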
struct TClient {
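// One pending request: resolved address, wire-ready header buffer, and the
// handle used to notify the caller of a response or an error.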
struct TRequest: public TIntrusiveListItem<TRequest> {
inline TRequest(const TSimpleHandleRef& hndl, const TNehMessage& msg)
: Hndl(hndl)
, Msg(msg)
, Loc(Msg.Addr)
, RI(CachedThrResolve(TResolveInfo(Loc.Host, Loc.GetPort())))
{
CreateGuid(&Guid);
}
inline void Serialize(TParts& parts) {
TMemoryOutput out(Buf, sizeof(Buf));
::Save(&out, (ui32)MsgLen());
::Save(&out, Guid);
::Save(&out, (ui32)Loc.Service.size());
if (Loc.Service.size() > out.Avail()) {
parts.Push(TStringBuf(Buf, out.Buf()));
parts.Push(Loc.Service);
} else {
out.Write(Loc.Service.data(), Loc.Service.size());
parts.Push(TStringBuf(Buf, out.Buf()));
}
parts.Push(Msg.Data);
}
inline size_t Length() const noexcept {
return sizeof(ui32) + MsgLen();
}
inline size_t MsgLen() const noexcept {
return sizeof(Guid.dw) + sizeof(ui32) + Loc.Service.size() + Msg.Data.size();
}
void OnError(const TString& errText) {
Hndl->NotifyError(errText);
}
TSimpleHandleRef Hndl;
TNehMessage Msg;
TGUID Guid;
const TParsedLocation Loc;
const TResolvedHost* RI;
char Buf[128];
};
typedef TAutoPtr<TRequest> TRequestPtr;
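// Per-host channel: a message queue plus the links (connections) serving it.
// A link is spawned lazily on the first enqueue and re-spawned after errors.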
struct TChannel {
struct TLink: public TIntrusiveListItem<TLink>, public TSimpleRefCount<TLink> {
inline TLink(TChannel* parent)
: P(parent)
{
Executor()->Create<TLink, &TLink::SendCycle>(this, "send");
}
void SendCycle(TCont* c) {
TIntrusivePtr<TLink> self(this);
try {
DoSendCycle(c);
OnError("shutdown");
} catch (...) {
OnError(CurrentExceptionMessage());
}
Unlink();
}
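// Connect, then pump requests: each batch is registered in InFly (keyed by
// the low 64 bits of the GUID) before being flushed with one vectored write.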
inline void DoSendCycle(TCont* c) {
if (int ret = NCoro::ConnectI(c, S, P->RI->Addr)) {
ythrow TSystemError(ret) << "can't connect";
}
SetNoDelay(S, true);
Executor()->Create<TLink, &TLink::RecvCycle>(this, "recv");
TVector<TRequestPtr> reqs;
TParts parts;
while (Dequeue(P->Q, reqs, 7000)) {
for (size_t i = 0; i < reqs.size(); ++i) {
TRequestPtr& req = reqs[i];
req->Serialize(parts);
InFly[LocalGuid(req->Guid)] = req;
}
{
TContIOVector vec(parts.data(), parts.size());
NCoro::WriteVectorI(c, S, &vec);
}
reqs.clear();
parts.Clear();
}
}
void RecvCycle(TCont* c) {
TIntrusivePtr<TLink> self(this);
try {
DoRecvCycle(c);
OnError("service close connection");
} catch (...) {
OnError(CurrentExceptionMessage());
}
}
inline void DoRecvCycle(TCont* c) {
TContIO io(S, c);
TBufferedInput input(&io, 8192 * 4);
while (true) {
ui32 len;
TGUID g;
try {
::Load(&input, len);
} catch (TLoadEOF&) {
return;
}
::Load(&input, g);
if (len < sizeof(g.dw)) {
throw yexception() << "invalid response (frame shorter than request id)";
}
const TString data(LoadStroka(input, len - sizeof(g.dw)));
TInFly::iterator it = InFly.find(LocalGuid(g));
if (it == InFly.end()) {
continue;
}
TRequestPtr req = it->second;
InFly.erase(it);
req->Hndl->NotifyResponse(data);
}
}
inline TContExecutor* Executor() const noexcept {
return P->Q.Executor();
}
void OnError(const TString& errText) {
for (auto& it : InFly) {
it.second->OnError(errText);
}
InFly.clear();
TRequestPtr req;
while (P->Q.TryDequeue(req)) {
req->OnError(errText);
}
}
TChannel* P;
TSocketHolder S;
typedef THashMap<ui64, TRequestPtr> TInFly;
TInFly InFly;
};
inline TChannel(TContExecutor* e, const TResolvedHost* ri)
: Q(e)
, RI(ri)
{
}
inline void Enqueue(TRequestPtr req) {
Q.Enqueue(req);
if (Links.Empty()) {
SpawnLink();
}
}
inline void SpawnLink() {
Links.PushBack(new TLink(this));
}
TMessageQueue<TRequest> Q;
TIntrusiveList<TLink> Links;
const TResolvedHost* RI;
};
typedef TAutoPtr<TChannel> TChannelPtr;
inline TClient() {
Thr = Spawn<TClient, &TClient::RunExecutor>(this);
}
inline ~TClient() {
Reqs.Enqueue(nullptr);
Thr->Join();
}
inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) {
TSimpleHandleRef ret(new TSimpleHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));
Reqs.Enqueue(new TRequest(ret, msg));
return ret.Get();
}
void RunExecutor() {
//SetHighestThreadPriority();
TContExecutor e(RealStackSize(32000));
e.Create<TClient, &TClient::RunDispatcher>(this, "dispatcher");
e.Execute();
}
void RunDispatcher(TCont* c) {
TRequestPtr req;
while (true) {
Reqs.DequeueSafe(c, req);
if (!req) {
break;
}
TChannelPtr& ch = Channels.Get(req->RI->Id);
if (!ch) {
ch.Reset(new TChannel(c->Executor(), req->RI));
}
ch->Enqueue(req);
}
c->Executor()->Abort();
}
TThreadRef Thr;
TOneConsumerPipeQueue<TRequest> Reqs;
TSocketMap<TChannelPtr> Channels;
};
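// Optional sharding: spread requests across several independent clients (and
// therefore several executor threads) in round-robin order.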
struct TMultiClient {
inline TMultiClient()
: Next(0)
{
for (size_t i = 0; i < 2; ++i) {
Clients.push_back(new TClient());
}
}
inline THandleRef Schedule(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) {
return Clients[AtomicIncrement(Next) % Clients.size()]->Schedule(msg, fallback, ss);
}
TVector<TAutoPtr<TClient>> Clients;
TAtomic Next;
};
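// A single client thread is used by default; flip the condition below to
// shard requests across TMultiClient's threads instead.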
#if 0
static inline TMultiClient* Client() {
return Singleton<NNehTCP::TMultiClient>();
}
#else
static inline TClient* Client() {
return Singleton<NNehTCP::TClient>();
}
#endif
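// Protocol factory glue: binds the "tcp" scheme to the server and client
// defined above.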
class TTcpProtocol: public IProtocol {
public:
inline TTcpProtocol() {
InitNetworkSubSystem();
}
IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
return new TServer(cb, loc.GetPort());
}
THandleRef ScheduleRequest(const TNehMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
return Client()->Schedule(msg, fallback, ss);
}
TStringBuf Scheme() const noexcept override {
return TStringBuf("tcp");
}
};
}
}
IProtocol* NNeh::TcpProtocol() {
return Singleton<NNehTCP::TTcpProtocol>();
}
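// Rough usage sketch (illustrative only; assumes the generic request API from
// library/cpp/neh/neh.h rather than anything defined in this file):
//
//   NNeh::TMessage msg("tcp://localhost:2021/echo", "ping");
//   NNeh::THandleRef h = NNeh::Request(msg);
//   NNeh::TResponseRef r = h->Wait();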