diff options
author | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
---|---|---|
committer | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
commit | 06e5c21a835c0e923506c4ff27929f34e00761c2 (patch) | |
tree | 75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/neh/udp.cpp | |
parent | 03f024c4412e3aa613bb543cf1660176320ba8f4 (diff) | |
download | ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz |
fix ya.make
Diffstat (limited to 'library/cpp/neh/udp.cpp')
-rw-r--r-- | library/cpp/neh/udp.cpp | 691 |
1 files changed, 691 insertions, 0 deletions
diff --git a/library/cpp/neh/udp.cpp b/library/cpp/neh/udp.cpp new file mode 100644 index 0000000000..13250a2493 --- /dev/null +++ b/library/cpp/neh/udp.cpp @@ -0,0 +1,691 @@ +#include "udp.h" +#include "details.h" +#include "neh.h" +#include "location.h" +#include "utils.h" +#include "factory.h" + +#include <library/cpp/dns/cache.h> + +#include <util/network/socket.h> +#include <util/network/address.h> +#include <util/generic/deque.h> +#include <util/generic/hash.h> +#include <util/generic/string.h> +#include <util/generic/buffer.h> +#include <util/generic/singleton.h> +#include <util/digest/murmur.h> +#include <util/random/random.h> +#include <util/ysaveload.h> +#include <util/system/thread.h> +#include <util/system/pipe.h> +#include <util/system/error.h> +#include <util/stream/mem.h> +#include <util/stream/buffer.h> +#include <util/string/cast.h> + +using namespace NNeh; +using namespace NDns; +using namespace NAddr; + +namespace { + namespace NUdp { + enum EPacketType { + PT_REQUEST = 1, + PT_RESPONSE = 2, + PT_STOP = 3, + PT_TIMEOUT = 4 + }; + + struct TUdpHandle: public TNotifyHandle { + inline TUdpHandle(IOnRecv* r, const TMessage& msg, TStatCollector* sc) noexcept + : TNotifyHandle(r, msg, sc) + { + } + + void Cancel() noexcept override { + THandle::Cancel(); //inform stat collector + } + + bool MessageSendedCompletely() const noexcept override { + //TODO + return true; + } + }; + + static inline IRemoteAddrPtr GetSendAddr(SOCKET s) { + IRemoteAddrPtr local = GetSockAddr(s); + const sockaddr* addr = local->Addr(); + + switch (addr->sa_family) { + case AF_INET: { + const TIpAddress a = *(const sockaddr_in*)addr; + + return MakeHolder<TIPv4Addr>(TIpAddress(InetToHost(INADDR_LOOPBACK), a.Port())); + } + + case AF_INET6: { + sockaddr_in6 a = *(const sockaddr_in6*)addr; + + a.sin6_addr = in6addr_loopback; + + return MakeHolder<TIPv6Addr>(a); + } + } + + ythrow yexception() << "unsupported"; + } + + typedef ui32 TCheckSum; + + static inline TString GenerateGuid() { + const ui64 res[2] = { + RandomNumber<ui64>(), RandomNumber<ui64>()}; + + return TString((const char*)res, sizeof(res)); + } + + static inline TCheckSum Sum(const TStringBuf& s) noexcept { + return HostToInet(MurmurHash<TCheckSum>(s.data(), s.size())); + } + + struct TPacket; + + template <class T> + static inline void Serialize(TPacket& p, const T& t); + + struct TPacket { + inline TPacket(IRemoteAddrPtr addr) + : Addr(std::move(addr)) + { + } + + template <class T> + inline TPacket(const T& t, IRemoteAddrPtr addr) + : Addr(std::move(addr)) + { + NUdp::Serialize(*this, t); + } + + inline TPacket(TSocketHolder& s, TBuffer& tmp) { + TAutoPtr<TOpaqueAddr> addr(new TOpaqueAddr()); + + retry_on_intr : { + const int rv = recvfrom(s, tmp.Data(), tmp.size(), MSG_WAITALL, addr->MutableAddr(), addr->LenPtr()); + + if (rv < 0) { + int err = LastSystemError(); + if (err == EAGAIN || err == EWOULDBLOCK) { + Data.Resize(sizeof(TCheckSum) + 1); + *(Data.data() + sizeof(TCheckSum)) = static_cast<char>(PT_TIMEOUT); + } else if (err == EINTR) { + goto retry_on_intr; + } else { + ythrow TSystemError() << "recv failed"; + } + } else { + Data.Append(tmp.Data(), (size_t)rv); + Addr.Reset(addr.Release()); + CheckSign(); + } + } + } + + inline void SendTo(TSocketHolder& s) { + Sign(); + + if (sendto(s, Data.data(), Data.size(), 0, Addr->Addr(), Addr->Len()) < 0) { + Cdbg << LastSystemErrorText() << Endl; + } + } + + IRemoteAddrPtr Addr; + TBuffer Data; + + inline void Sign() { + const TCheckSum sum = CalcSign(); + + memcpy(Data.Data(), &sum, sizeof(sum)); + } + + inline char Type() const noexcept { + return *(Data.data() + sizeof(TCheckSum)); + } + + inline void CheckSign() const { + if (Data.size() < 16) { + ythrow yexception() << "small packet"; + } + + if (StoredSign() != CalcSign()) { + ythrow yexception() << "bad checksum"; + } + } + + inline TCheckSum StoredSign() const noexcept { + TCheckSum sum; + + memcpy(&sum, Data.Data(), sizeof(sum)); + + return sum; + } + + inline TCheckSum CalcSign() const noexcept { + return Sum(Body()); + } + + inline TStringBuf Body() const noexcept { + return TStringBuf(Data.data() + sizeof(TCheckSum), Data.End()); + } + }; + + typedef TAutoPtr<TPacket> TPacketRef; + + class TPacketInput: public TMemoryInput { + public: + inline TPacketInput(const TPacket& p) + : TMemoryInput(p.Body().data(), p.Body().size()) + { + } + }; + + class TPacketOutput: public TBufferOutput { + public: + inline TPacketOutput(TPacket& p) + : TBufferOutput(p.Data) + { + p.Data.Proceed(sizeof(TCheckSum)); + } + }; + + template <class T> + static inline void Serialize(TPacketOutput* out, const T& t) { + Save(out, t.Type()); + t.Serialize(out); + } + + template <class T> + static inline void Serialize(TPacket& p, const T& t) { + TPacketOutput out(p); + + NUdp::Serialize(&out, t); + } + + namespace NPrivate { + template <class T> + static inline void Deserialize(TPacketInput* in, T& t) { + char type; + Load(in, type); + + if (type != t.Type()) { + ythrow yexception() << "unsupported packet"; + } + + t.Deserialize(in); + } + + template <class T> + static inline void Deserialize(const TPacket& p, T& t) { + TPacketInput in(p); + + Deserialize(&in, t); + } + } + + struct TRequestPacket { + TString Guid; + TString Service; + TString Data; + + inline TRequestPacket(const TPacket& p) { + NPrivate::Deserialize(p, *this); + } + + inline TRequestPacket(const TString& srv, const TString& data) + : Guid(GenerateGuid()) + , Service(srv) + , Data(data) + { + } + + inline char Type() const noexcept { + return static_cast<char>(PT_REQUEST); + } + + inline void Serialize(TPacketOutput* out) const { + Save(out, Guid); + Save(out, Service); + Save(out, Data); + } + + inline void Deserialize(TPacketInput* in) { + Load(in, Guid); + Load(in, Service); + Load(in, Data); + } + }; + + template <class TStore> + struct TResponsePacket { + TString Guid; + TStore Data; + + inline TResponsePacket(const TString& guid, TStore& data) + : Guid(guid) + { + Data.swap(data); + } + + inline TResponsePacket(const TPacket& p) { + NPrivate::Deserialize(p, *this); + } + + inline char Type() const noexcept { + return static_cast<char>(PT_RESPONSE); + } + + inline void Serialize(TPacketOutput* out) const { + Save(out, Guid); + Save(out, Data); + } + + inline void Deserialize(TPacketInput* in) { + Load(in, Guid); + Load(in, Data); + } + }; + + struct TStopPacket { + inline char Type() const noexcept { + return static_cast<char>(PT_STOP); + } + + inline void Serialize(TPacketOutput* out) const { + Save(out, TString("stop packet")); + } + }; + + struct TBindError: public TSystemError { + }; + + struct TSocketDescr { + inline TSocketDescr(TSocketHolder& s, int family) + : S(s.Release()) + , Family(family) + { + } + + TSocketHolder S; + int Family; + }; + + typedef TAutoPtr<TSocketDescr> TSocketRef; + typedef TVector<TSocketRef> TSockets; + + static inline void CreateSocket(TSocketHolder& s, const IRemoteAddr& addr) { + TSocketHolder res(socket(addr.Addr()->sa_family, SOCK_DGRAM, IPPROTO_UDP)); + + if (!res) { + ythrow TSystemError() << "can not create socket"; + } + + FixIPv6ListenSocket(res); + + if (bind(res, addr.Addr(), addr.Len()) != 0) { + ythrow TBindError() << "can not bind " << PrintHostAndPort(addr); + } + + res.Swap(s); + } + + static inline void CreateSockets(TSockets& s, ui16 port) { + TNetworkAddress addr(port); + + for (TNetworkAddress::TIterator it = addr.Begin(); it != addr.End(); ++it) { + TSocketHolder res; + + CreateSocket(res, TAddrInfo(&*it)); + + s.push_back(new TSocketDescr(res, it->ai_family)); + } + } + + static inline void CreateSocketsOnRandomPort(TSockets& s) { + while (true) { + try { + TSockets tmp; + + CreateSockets(tmp, 5000 + (RandomNumber<ui16>() % 1000)); + tmp.swap(s); + + return; + } catch (const TBindError&) { + } + } + } + + typedef ui64 TTimeStamp; + + static inline TTimeStamp TimeStamp() noexcept { + return GetCycleCount() >> 31; + } + + struct TRequestDescr: public TIntrusiveListItem<TRequestDescr> { + inline TRequestDescr(const TString& guid, const TNotifyHandleRef& hndl, const TMessage& msg) + : Guid(guid) + , Hndl(hndl) + , Msg(msg) + , TS(TimeStamp()) + { + } + + TString Guid; + TNotifyHandleRef Hndl; + TMessage Msg; + TTimeStamp TS; + }; + + typedef TAutoPtr<TRequestDescr> TRequestDescrRef; + + class TProto { + class TRequest: public IRequest, public TRequestPacket { + public: + inline TRequest(TPacket& p, TProto* parent) + : TRequestPacket(p) + , Addr_(std::move(p.Addr)) + , H_(PrintHostByRfc(*Addr_)) + , P_(parent) + { + } + + TStringBuf Scheme() const override { + return TStringBuf("udp"); + } + + TString RemoteHost() const override { + return H_; + } + + TStringBuf Service() const override { + return ((TRequestPacket&)(*this)).Service; + } + + TStringBuf Data() const override { + return ((TRequestPacket&)(*this)).Data; + } + + TStringBuf RequestId() const override { + return ((TRequestPacket&)(*this)).Guid; + } + + bool Canceled() const override { + //TODO ? + return false; + } + + void SendReply(TData& data) override { + P_->Schedule(new TPacket(TResponsePacket<TData>(Guid, data), std::move(Addr_))); + } + + void SendError(TResponseError, const TString&) override { + // TODO + } + + private: + IRemoteAddrPtr Addr_; + TString H_; + TProto* P_; + }; + + public: + inline TProto(IOnRequest* cb, TSocketHolder& s) + : CB_(cb) + , ToSendEv_(TSystemEvent::rAuto) + , S_(s.Release()) + { + SetSocketTimeout(S_, 10); + Thrs_.push_back(Spawn<TProto, &TProto::ExecuteRecv>(this)); + Thrs_.push_back(Spawn<TProto, &TProto::ExecuteSend>(this)); + } + + inline ~TProto() { + Schedule(new TPacket(TStopPacket(), GetSendAddr(S_))); + + for (size_t i = 0; i < Thrs_.size(); ++i) { + Thrs_[i]->Join(); + } + } + + inline TPacketRef Recv() { + TBuffer tmp; + + tmp.Resize(128 * 1024); + + while (true) { + try { + return new TPacket(S_, tmp); + } catch (...) { + Cdbg << CurrentExceptionMessage() << Endl; + + continue; + } + } + } + + typedef THashMap<TString, TRequestDescrRef> TInFlyBase; + + struct TInFly: public TInFlyBase, public TIntrusiveList<TRequestDescr> { + typedef TInFlyBase::iterator TIter; + typedef TInFlyBase::const_iterator TContsIter; + + inline void Insert(TRequestDescrRef& d) { + PushBack(d.Get()); + (*this)[d->Guid] = d; + } + + inline void EraseStale() noexcept { + const TTimeStamp now = TimeStamp(); + + for (TIterator it = Begin(); (it != End()) && (it->TS < now) && ((now - it->TS) > 120);) { + it->Hndl->NotifyError("request timeout"); + TString safe_key = (it++)->Guid; + erase(safe_key); + } + } + }; + + inline void ExecuteRecv() { + SetHighestThreadPriority(); + + TInFly infly; + + while (true) { + TPacketRef p = Recv(); + + switch (static_cast<EPacketType>(p->Type())) { + case PT_REQUEST: + if (CB_) { + CB_->OnRequest(new TRequest(*p, this)); + } else { + //skip request in case of client + } + + break; + + case PT_RESPONSE: { + CancelStaleRequests(infly); + + TResponsePacket<TString> rp(*p); + + TInFly::TIter it = static_cast<TInFlyBase&>(infly).find(rp.Guid); + + if (it == static_cast<TInFlyBase&>(infly).end()) { + break; + } + + const TRequestDescrRef& d = it->second; + d->Hndl->NotifyResponse(rp.Data); + + infly.erase(it); + + break; + } + + case PT_STOP: + Schedule(nullptr); + + return; + + case PT_TIMEOUT: + CancelStaleRequests(infly); + + break; + } + } + } + + inline void ExecuteSend() { + SetHighestThreadPriority(); + + while (true) { + TPacketRef p; + + while (!ToSend_.Dequeue(&p)) { + ToSendEv_.Wait(); + } + + //shutdown + if (!p) { + return; + } + + p->SendTo(S_); + } + } + + inline void Schedule(TPacketRef p) { + ToSend_.Enqueue(p); + ToSendEv_.Signal(); + } + + inline void Schedule(TRequestDescrRef dsc, TPacketRef p) { + ScheduledReqs_.Enqueue(dsc); + Schedule(p); + } + + protected: + void CancelStaleRequests(TInFly& infly) { + TRequestDescrRef d; + + while (ScheduledReqs_.Dequeue(&d)) { + infly.Insert(d); + } + + infly.EraseStale(); + } + + IOnRequest* CB_; + NNeh::TAutoLockFreeQueue<TPacket> ToSend_; + NNeh::TAutoLockFreeQueue<TRequestDescr> ScheduledReqs_; + TSystemEvent ToSendEv_; + TSocketHolder S_; + TVector<TThreadRef> Thrs_; + }; + + class TProtos { + public: + inline TProtos() { + TSockets s; + + CreateSocketsOnRandomPort(s); + Init(nullptr, s); + } + + inline TProtos(IOnRequest* cb, ui16 port) { + TSockets s; + + CreateSockets(s, port); + Init(cb, s); + } + + static inline TProtos* Instance() { + return Singleton<TProtos>(); + } + + inline void Schedule(const TMessage& msg, const TNotifyHandleRef& hndl) { + TParsedLocation loc(msg.Addr); + const TNetworkAddress* addr = &CachedThrResolve(TResolveInfo(loc.Host, loc.GetPort()))->Addr; + + for (TNetworkAddress::TIterator ai = addr->Begin(); ai != addr->End(); ++ai) { + TProto* proto = Find(ai->ai_family); + + if (proto) { + TRequestPacket rp(ToString(loc.Service), msg.Data); + TRequestDescrRef rd(new TRequestDescr(rp.Guid, hndl, msg)); + IRemoteAddrPtr raddr(new TAddrInfo(&*ai)); + TPacketRef p(new TPacket(rp, std::move(raddr))); + + proto->Schedule(rd, p); + + return; + } + } + + ythrow yexception() << "unsupported protocol family"; + } + + private: + inline void Init(IOnRequest* cb, TSockets& s) { + for (auto& it : s) { + P_[it->Family] = new TProto(cb, it->S); + } + } + + inline TProto* Find(int family) const { + TProtoStorage::const_iterator it = P_.find(family); + + if (it == P_.end()) { + return nullptr; + } + + return it->second.Get(); + } + + private: + typedef TAutoPtr<TProto> TProtoRef; + typedef THashMap<int, TProtoRef> TProtoStorage; + TProtoStorage P_; + }; + + class TRequester: public IRequester, public TProtos { + public: + inline TRequester(IOnRequest* cb, ui16 port) + : TProtos(cb, port) + { + } + }; + + class TProtocol: public IProtocol { + public: + IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override { + return new TRequester(cb, loc.GetPort()); + } + + THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override { + TNotifyHandleRef ret(new TUdpHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss))); + + TProtos::Instance()->Schedule(msg, ret); + + return ret.Get(); + } + + TStringBuf Scheme() const noexcept override { + return TStringBuf("udp"); + } + }; + } +} + +IProtocol* NNeh::UdpProtocol() { + return Singleton<NUdp::TProtocol>(); +} |