#include "stdafx.h"
#include <util/datetime/cputimer.h>
#include <util/draft/holder_vector.h>
#include <util/generic/utility.h>
#include <util/generic/vector.h>
#include <util/network/init.h>
#include <util/network/poller.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/system/byteorder.h>
#include <util/system/defaults.h>
#include <util/system/error.h>
#include <util/system/event.h>
#include <util/system/thread.h>
#include <util/system/yassert.h>
#include <util/system/rwlock.h>
#include <util/system/env.h>

#include "socket.h"
#include "packet_queue.h"
#include "udp_recv_packet.h"

#include <array>
#include <atomic>
#include <cstring>
#include <stdlib.h>

///////////////////////////////////////////////////////////////////////////////

#ifndef _win_
#include <netinet/in.h>
#endif

#ifdef _linux_
#include <dlfcn.h> // dlsym
#endif

// Looks up the symbol `name` in the already-loaded images (dlsym(RTLD_DEFAULT)).
// Only done on Linux and only when the DISABLE_MMSG environment variable is unset;
// otherwise returns nullptr, which callers treat as "not supported".
template <class T>
static T GetAddressOf(const char* name) {
#ifdef _linux_
    if (!GetEnv("DISABLE_MMSG")) {
        return (T)dlsym(RTLD_DEFAULT, name);
    }
#endif
    Y_UNUSED(name);
    return nullptr;
}

///////////////////////////////////////////////////////////////////////////////

namespace NNetlibaSocket {
///////////////////////////////////////////////////////////////////////////////

struct timespec; // we use it only as NULL pointer
typedef int (*TSendMMsgFunc)(SOCKET, TMMsgHdr*, unsigned int, unsigned int);
typedef int (*TRecvMMsgFunc)(SOCKET, TMMsgHdr*, unsigned int, unsigned int, timespec*);

// Resolved once at static-initialization time; nullptr means the call is unavailable.
static const TSendMMsgFunc SendMMsgFunc = GetAddressOf<TSendMMsgFunc>("sendmmsg");
static const TRecvMMsgFunc RecvMMsgFunc = GetAddressOf<TRecvMMsgFunc>("recvmmsg");

///////////////////////////////////////////////////////////////////////////////

// Reads a TOS / traffic-class value from the FIRST control message of msgHdr.
// Succeeds only when that message carries exactly one int payload; always fails on Windows.
bool ReadTos(const TMsgHdr& msgHdr, ui8* tos) {
#ifdef _win_
    Y_UNUSED(msgHdr);
    Y_UNUSED(tos);
    return false;
#else
    cmsghdr* cmsg = CMSG_FIRSTHDR(&msgHdr);
    if (!cmsg)
        return false;
    //Y_ASSERT(cmsg->cmsg_level == IPPROTO_IPV6);
    //Y_ASSERT(cmsg->cmsg_type == IPV6_TCLASS);
    if (cmsg->cmsg_len != CMSG_LEN(sizeof(int)))
        return false;
    // The payload is an int. Copy it out and truncate rather than reading its first
    // byte: the first byte is the most significant one on big-endian hosts, and
    // memcpy also avoids an unaligned int access.
    int value = 0;
    memcpy(&value, CMSG_DATA(cmsg), sizeof(value));
    *tos = (ui8)value;
    return true;
#endif
}

// Zeroes addrBuf, then fills its address and family from the first IPV6_PKTINFO
// control message in msgHdr. Returns false if none is found (always on Windows).
bool ExtractDestinationAddress(TMsgHdr& msgHdr, sockaddr_in6* addrBuf) {
    Zero(*addrBuf);
#ifdef _win_
    Y_UNUSED(msgHdr);
    Y_UNUSED(addrBuf);
    return false;
#else
    cmsghdr* cmsg;
    for (cmsg = CMSG_FIRSTHDR(&msgHdr); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msgHdr, cmsg)) {
        if ((cmsg->cmsg_level == IPPROTO_IPV6) && (cmsg->cmsg_type == IPV6_PKTINFO)) {
            in6_pktinfo* i = (in6_pktinfo*)CMSG_DATA(cmsg);
            addrBuf->sin6_addr = i->ipi6_addr;
            addrBuf->sin6_family = AF_INET6;
            return true;
        }
    }
    return false;
#endif
}

// all send and recv methods are thread safe!
class TAbstractSocket: public ISocket {
private:
    SOCKET S;
    mutable TSocketPoller Poller;
    sockaddr_in6 SelfAddress;

    int SendSysSocketSize;
    int SendSysSocketSizePrev;

    int CreateSocket(int netPort);
    int DetectSelfAddress();

    // Aborts if DF is currently forced on the socket. `checked` caches a successful
    // check (one flag per call site) so the getsockopt() stops once it has passed.
    // Atomic because concurrent senders reach this under a shared (read) lock.
    void VerifyFragmentationAllowed(std::atomic<bool>& checked);

protected:
    int SetSockOpt(int level, int option_name, const void* option_value, socklen_t option_len);

    int OpenImpl(int port);
    void CloseImpl();

    void WaitImpl(float timeoutSec) const;
    void CancelWaitImpl(const sockaddr_in6* address = nullptr); // NULL means "self"

    ssize_t RecvMsgImpl(TMsgHdr* hdr, int flags);
    TUdpRecvPacket* RecvImpl(TUdpHostRecvBufAlloc* buf, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr);
    int RecvMMsgImpl(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags, timespec* timeout);

    bool IsFragmentationForbiden();
    void ForbidFragmentation();
    void EnableFragmentation();

    //Shared state for setsockopt. Forbid simultaneous transfer while sender asking for specific options (i.e. DONOT_FRAG)
    TRWMutex Mutex;
    TAtomic RecvLag = 0;

public:
    TAbstractSocket();
    ~TAbstractSocket() override;
#ifdef _unix_
    void Reset(const TAbstractSocket& rhv);
#endif

    bool IsValid() const override;

    const sockaddr_in6& GetSelfAddress() const override;
    int GetNetworkOrderPort() const override;
    int GetPort() const override;

    int GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) override;

    // send all packets to this and only this address by default
    int Connect(const struct sockaddr* address, socklen_t address_len) override;

    void CancelWaitHost(const sockaddr_in6 addr) override;

    bool IsSendMMsgSupported() const override;
    int SendMMsg(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) override;
    ssize_t SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) override;
    bool IncreaseSendBuff() override;
    int GetSendSysSocketSize() override;
    void SetRecvLagTime(NHPTimer::STime time) override;
};

TAbstractSocket::TAbstractSocket()
    : S(INVALID_SOCKET)
    , SendSysSocketSize(0)
    , SendSysSocketSizePrev(0)
{
    Zero(SelfAddress);
}

TAbstractSocket::~TAbstractSocket() {
    CloseImpl();
}

#ifdef _unix_
// Closes this socket and replaces it with a dup() of rhv's descriptor.
void TAbstractSocket::Reset(const TAbstractSocket& rhv) {
    Close();
    S = dup(rhv.S);
    SelfAddress = rhv.SelfAddress;
    // CreateSocket() registers every socket it creates with Poller; do the same for the
    // duplicate, otherwise WaitImpl() can never be woken by data arriving on it.
    if (IsValid()) {
        Poller.WaitRead(S, nullptr);
    }
}
#endif

// Creates a non-blocking dual-stack UDP socket bound to netPort (network byte order),
// registers it with Poller and records the initial SO_SNDBUF. Returns 0 or -1.
int TAbstractSocket::CreateSocket(int netPort) {
    if (IsValid()) {
        Y_ASSERT(0);
        return 0;
    }
    S = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
    if (S == INVALID_SOCKET) {
        return -1;
    }
    {
        int flag = 0;
        Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&flag, sizeof(flag)) == 0, "IPV6_V6ONLY failed");
    }
    {
        int flag = 1;
        Y_VERIFY(SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (const char*)&flag, sizeof(flag)) == 0, "SO_REUSEADDR failed");
    }
#if defined(_win_)
    unsigned long dummy = 1;
    ioctlsocket(S, FIONBIO, &dummy);
#else
    Y_VERIFY(fcntl(S, F_SETFL, O_NONBLOCK) == 0, "fnctl failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
    Y_VERIFY(fcntl(S, F_SETFD, FD_CLOEXEC) == 0, "fnctl failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
    {
        int flag = 1;
#ifndef IPV6_RECVPKTINFO /* Darwin platforms require this */
        Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_PKTINFO, (const char*)&flag, sizeof(flag)) == 0, "IPV6_PKTINFO failed");
#else
        Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_RECVPKTINFO, (const char*)&flag, sizeof(flag)) == 0, "IPV6_RECVPKTINFO failed");
#endif
    }
#endif

    Poller.WaitRead(S, nullptr);

    {
        // bind socket
        sockaddr_in6 name;
        Zero(name);
        name.sin6_family = AF_INET6;
        name.sin6_addr = in6addr_any;
        name.sin6_port = netPort;
        if (bind(S, (sockaddr*)&name, sizeof(name)) != 0) {
            fprintf(stderr, "netliba_socket could not bind to port %d: %s (errno = %d)\n", InetToHost((ui16)netPort), LastSystemErrorText(), LastSystemError());
            CloseImpl(); // we call this CloseImpl after Poller initialization
            return -1;
        }
    }
    //Default behavior is allowing fragmentation (according to netliba v6 behavior)
    //If we want to sent packet with DF flag we have to use SendMsg()
    EnableFragmentation();

    {
        socklen_t sz = sizeof(SendSysSocketSize);
        if (GetSockOpt(SOL_SOCKET, SO_SNDBUF, &SendSysSocketSize, &sz)) {
            fprintf(stderr, "Can`t get SO_SNDBUF\n");
        }
    }
    return 0;
}

bool TAbstractSocket::IsValid() const {
    return S != INVALID_SOCKET;
}

// Reads the bound address via getsockname(); only the port is kept meaningful,
// the address part is overwritten with the IPv6 loopback.
int TAbstractSocket::DetectSelfAddress() {
    socklen_t nameLen = sizeof(SelfAddress);
    if (getsockname(S, (sockaddr*)&SelfAddress, &nameLen) != 0) { // actually we use only sin6_port
        return -1;
    }
    Y_ASSERT(SelfAddress.sin6_family == AF_INET6);
    SelfAddress.sin6_addr = in6addr_loopback;
    return 0;
}

const sockaddr_in6& TAbstractSocket::GetSelfAddress() const {
    return SelfAddress;
}

int TAbstractSocket::GetNetworkOrderPort() const {
    return SelfAddress.sin6_port;
}

int TAbstractSocket::GetPort() const {
    return InetToHost((ui16)SelfAddress.sin6_port);
}

int TAbstractSocket::SetSockOpt(int level, int option_name, const void* option_value, socklen_t option_len) {
    const int rv = setsockopt(S, level, option_name, (const char*)option_value, option_len);
    Y_VERIFY_DEBUG(rv == 0, "SetSockOpt failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
    return rv;
}

int TAbstractSocket::GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) {
    const int rv = getsockopt(S, level, option_name, (char*)option_value, option_len);
    Y_VERIFY_DEBUG(rv == 0, "GetSockOpt failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
    return rv;
}

// True when the socket currently forces the DF bit (always false on Darwin).
bool TAbstractSocket::IsFragmentationForbiden() {
#if defined(_win_)
    DWORD flag = 0;
    socklen_t sz = sizeof(flag);
    Y_VERIFY(GetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (char*)&flag, &sz) == 0, "");
    return flag;
#elif defined(_linux_)
    int flag = 0;
    socklen_t sz = sizeof(flag);
    Y_VERIFY(GetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (char*)&flag, &sz) == 0, "");
    return flag == IPV6_PMTUDISC_DO;
#elif !defined(_darwin_)
    int flag = 0;
    socklen_t sz = sizeof(flag);
    Y_VERIFY(GetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (char*)&flag, &sz) == 0, "");
    return flag;
#endif
    return false;
}

void TAbstractSocket::ForbidFragmentation() {
    // do not fragment ping packets
#if defined(_win_)
    DWORD flag = 1;
    SetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (const char*)&flag, sizeof(flag));
#elif defined(_linux_)
    int flag = IP_PMTUDISC_DO;
    SetSockOpt(IPPROTO_IP, IP_MTU_DISCOVER, (const char*)&flag, sizeof(flag));

    flag = IPV6_PMTUDISC_DO;
    SetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (const char*)&flag, sizeof(flag));
#elif !defined(_darwin_)
    int flag = 1;
    //SetSockOpt(IPPROTO_IP, IP_DONTFRAG, (const char*)&flag, sizeof(flag));
    SetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (const char*)&flag, sizeof(flag));
#endif
}

void TAbstractSocket::EnableFragmentation() {
#if defined(_win_)
    DWORD flag = 0;
    SetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (const char*)&flag, sizeof(flag));
#elif defined(_linux_)
    int flag = IP_PMTUDISC_WANT;
    SetSockOpt(IPPROTO_IP, IP_MTU_DISCOVER, (const char*)&flag, sizeof(flag));

    flag = IPV6_PMTUDISC_WANT;
    SetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (const char*)&flag, sizeof(flag));
#elif !defined(_darwin_)
    int flag = 0;
    //SetSockOpt(IPPROTO_IP, IP_DONTFRAG, (const char*)&flag, sizeof(flag));
    SetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (const char*)&flag, sizeof(flag));
#endif
}

int TAbstractSocket::Connect(const sockaddr* address, socklen_t address_len) {
    Y_ASSERT(IsValid());
    return connect(S, address, address_len);
}

void TAbstractSocket::CancelWaitHost(const sockaddr_in6 addr) {
    CancelWaitImpl(&addr);
}

bool TAbstractSocket::IsSendMMsgSupported() const {
    return SendMMsgFunc != nullptr;
}

void TAbstractSocket::VerifyFragmentationAllowed(std::atomic<bool>& checked) {
    if (checked.load(std::memory_order_relaxed)) {
        return;
    }
    Y_VERIFY(!IsFragmentationForbiden(), "Send methods of this class expect default EnableFragmentation behavior");
    checked.store(true, std::memory_order_relaxed);
}

int TAbstractSocket::SendMMsg(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) {
    Y_ASSERT(IsValid());
    Y_VERIFY(SendMMsgFunc, "sendmmsg is not supported!");
    TReadGuard rg(Mutex);
    static std::atomic<bool> checked{false};
    VerifyFragmentationAllowed(checked);
    return SendMMsgFunc(S, msgvec, vlen, flags);
}

// Sends one datagram. FF_DONT_FRAG (and, on Windows, a non-zero Tos) temporarily
// changes socket options under the exclusive lock and restores them afterwards;
// every other send runs under the shared lock.
ssize_t TAbstractSocket::SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) {
    Y_ASSERT(IsValid());
#ifdef _win32_
    static std::atomic<bool> checked{false};
    // Only a single iovec is supported: the data is sent with sendto() of msg_iov[0].
    Y_VERIFY(hdr->msg_iovlen == 1, "Scatter/gather is currenly not supported on Windows");
    if (hdr->Tos || frag == FF_DONT_FRAG) {
        TWriteGuard wg(Mutex);
        if (frag == FF_DONT_FRAG) {
            ForbidFragmentation();
        } else {
            VerifyFragmentationAllowed(checked);
        }
        int originalTos;
        if (hdr->Tos) {
            socklen_t sz = sizeof(originalTos);
            Y_VERIFY(GetSockOpt(IPPROTO_IP, IP_TOS, (char*)&originalTos, &sz) == 0, "");
            Y_VERIFY(SetSockOpt(IPPROTO_IP, IP_TOS, (char*)&hdr->Tos, sizeof(hdr->Tos)) == 0, "");
        }
        const ssize_t rv = sendto(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, hdr->msg_namelen);
        if (hdr->Tos) {
            Y_VERIFY(SetSockOpt(IPPROTO_IP, IP_TOS, (char*)&originalTos, sizeof(originalTos)) == 0, "");
        }
        if (frag == FF_DONT_FRAG) {
            EnableFragmentation();
        }
        return rv;
    }
    TReadGuard rg(Mutex);
    VerifyFragmentationAllowed(checked);
    return sendto(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, hdr->msg_namelen);
#else
    if (frag == FF_DONT_FRAG) {
        TWriteGuard wg(Mutex);
        ForbidFragmentation();
        const ssize_t rv = sendmsg(S, hdr, flags);
        EnableFragmentation();
        return rv;
    }

    TReadGuard rg(Mutex);
#ifndef _darwin_
    static std::atomic<bool> checked{false};
    VerifyFragmentationAllowed(checked);
#endif
    return sendmsg(S, hdr, flags);
#endif
}

bool TAbstractSocket::IncreaseSendBuff() {
    int buffSize;
    socklen_t sz = sizeof(buffSize);
    if (GetSockOpt(SOL_SOCKET, SO_SNDBUF, &buffSize, &sz)) {
        return false;
    }
    // worst case: 200000 pps * 8k * 0.01sec = 16Mb so 32Mb hard limit is reasonable value
    if (buffSize < 0 || buffSize > (1 << 25)) {
        fprintf(stderr, "GetSockOpt returns wrong or too big value for SO_SNDBUF: %d\n", buffSize);
        return false;
    }
    //linux returns the doubled value. man 7 socket:
    //
    // SO_SNDBUF
    //        Sets or gets the maximum socket send buffer in bytes.  The ker-
    //        nel doubles this value (to allow space for bookkeeping overhead)
    //        when it is set using setsockopt(), and this doubled value is
    //        returned by getsockopt().  The default value is set by the
    //        wmem_default sysctl and the maximum allowed value is set by the
    //        wmem_max sysctl.  The minimum (doubled) value for this option is
    //        2048.
    //
#ifndef _linux_
    buffSize += buffSize;
#endif

    // false if previous value was less than current value.
    // It means setsockopt was not successful. (for example: system limits)
    // we will try to set it again but return false
    const bool rv = !(buffSize <= SendSysSocketSizePrev);
    if (SetSockOpt(SOL_SOCKET, SO_SNDBUF, &buffSize, sz) == 0) {
        SendSysSocketSize = buffSize;
        SendSysSocketSizePrev = buffSize;
        return rv;
    }

    return false;
}

int TAbstractSocket::GetSendSysSocketSize() {
    return SendSysSocketSize;
}

void TAbstractSocket::SetRecvLagTime(NHPTimer::STime time) {
    AtomicSet(RecvLag, time);
}

// Opens the socket on `port` (host byte order, 0 = any port) and detects the bound port.
int TAbstractSocket::OpenImpl(int port) {
    Y_ASSERT(!IsValid());
    const int netPort = port ? htons((u_short)port) : 0;

#ifdef _freebsd_
    // alternative OS
    if (netPort == 0) {
        static ui64 pp = GetCycleCount();
        for (int attempt = 0; attempt < 100; ++attempt) {
            const int tryPort = htons((pp & 0x3fff) + 0xc000);
            ++pp;
            if (CreateSocket(tryPort) != 0) {
                Y_ASSERT(!IsValid());
                continue;
            }

            if (DetectSelfAddress() != 0 || tryPort != SelfAddress.sin6_port) {
                // FreeBSD suck!
                CloseImpl();
                Y_ASSERT(!IsValid());
                continue;
            }
            break;
        }
        if (!IsValid()) {
            return -1;
        }
    } else {
        if (CreateSocket(netPort) != 0) {
            Y_ASSERT(!IsValid());
            return -1;
        }
    }
#else
    // regular OS
    if (CreateSocket(netPort) != 0) {
        Y_ASSERT(!IsValid());
        return -1;
    }
#endif

    if (IsValid() && DetectSelfAddress() != 0) {
        CloseImpl();
        Y_ASSERT(!IsValid());
        return -1;
    }

    Y_ASSERT(IsValid());
    return 0;
}

void TAbstractSocket::CloseImpl() {
    if (IsValid()) {
        Poller.Unwait(S);
        Y_VERIFY(closesocket(S) == 0, "closesocket failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
    }
    S = INVALID_SOCKET;
}

void TAbstractSocket::WaitImpl(float timeoutSec) const {
    Y_VERIFY(IsValid(), "something went wrong");
    Poller.WaitT(TDuration::Seconds(timeoutSec));
}

// Wakes a waiter by sending an empty datagram to `address` (or to ourselves).
void TAbstractSocket::CancelWaitImpl(const sockaddr_in6* address) {
    Y_ASSERT(IsValid());

    // darwin ignores packets with msg_iovlen == 0, also windows implementation uses sendto of first iovec.
    TIoVec v = CreateIoVec(nullptr, 0);
    TMsgHdr hdr = CreateSendMsgHdr((address ? *address : SelfAddress), v, nullptr);

    // send self fake packet
    TAbstractSocket::SendMsg(&hdr, 0, FF_ALLOW_FRAG);
}

ssize_t TAbstractSocket::RecvMsgImpl(TMsgHdr* hdr, int flags) {
    Y_ASSERT(IsValid());

#ifdef _win32_
    // Only a single iovec is supported: the data is received with recvfrom() into msg_iov[0].
    Y_VERIFY(hdr->msg_iovlen == 1, "Scatter/gather is currenly not supported on Windows");
    return recvfrom(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, &hdr->msg_namelen);
#else
    return recvmsg(S, hdr, flags);
#endif
}

// Receives one datagram into buf. Returns nullptr if nothing is pending; otherwise
// extracts the filled packet from buf (so buf gets a fresh buffer) and returns it.
TUdpRecvPacket* TAbstractSocket::RecvImpl(TUdpHostRecvBufAlloc* buf, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) {
    Y_ASSERT(IsValid());

    const TIoVec iov = CreateIoVec(buf->GetDataPtr(), buf->GetBufSize());
    char controllBuffer[CTRL_BUFFER_SIZE]; //used to get dst address from socket
    TMsgHdr hdr = CreateRecvMsgHdr(srcAddr, iov, controllBuffer);

    const ssize_t rv = TAbstractSocket::RecvMsgImpl(&hdr, 0);
    if (rv < 0) {
        Y_ASSERT(LastSystemError() == EAGAIN || LastSystemError() == EWOULDBLOCK);
        return nullptr;
    }
    if (dstAddr && !ExtractDestinationAddress(hdr, dstAddr)) {
        //fprintf(stderr, "can`t get destination ip\n");
    }

    // we extract packet and allocate new buffer only if packet arrived
    TUdpRecvPacket* result = buf->ExtractPacket();
    result->DataStart = 0;
    result->DataSize = (int)rv;
    return result;
}

// thread-safe
int TAbstractSocket::RecvMMsgImpl(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags, timespec* timeout) {
    Y_ASSERT(IsValid());
    Y_VERIFY(RecvMMsgFunc, "recvmmsg is not supported!");
    return RecvMMsgFunc(S, msgvec, vlen, flags, timeout);
}

///////////////////////////////////////////////////////////////////////////////

// Plain socket: one recvmsg() per Recv() call.
class TSocket: public TAbstractSocket {
public:
    int Open(int port) override;
    void Close() override;

    void Wait(float timeoutSec, int netlibaVersion) const override;
    void CancelWait(int netlibaVersion) override;

    bool IsRecvMsgSupported() const override;
    ssize_t RecvMsg(TMsgHdr* hdr, int flags) override;
    TUdpRecvPacket* Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) override;

private:
    TUdpHostRecvBufAlloc RecvBuf;
};

int TSocket::Open(int port) {
    return OpenImpl(port);
}

void TSocket::Close() {
    CloseImpl();
}

void TSocket::Wait(float timeoutSec, int netlibaVersion) const {
    Y_UNUSED(netlibaVersion);
    WaitImpl(timeoutSec);
}

void TSocket::CancelWait(int netlibaVersion) {
    Y_UNUSED(netlibaVersion);
    CancelWaitImpl();
}

bool TSocket::IsRecvMsgSupported() const {
    return true;
}

ssize_t TSocket::RecvMsg(TMsgHdr* hdr, int flags) {
    return RecvMsgImpl(hdr, flags);
}

TUdpRecvPacket* TSocket::Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) {
    Y_UNUSED(netlibaVersion);
    return RecvImpl(&RecvBuf, srcAddr, dstAddr);
}

///////////////////////////////////////////////////////////////////////////////

// Receives in batches with recvmmsg() when available and hands packets out one by one;
// falls back to TSocket-like single recvmsg() otherwise.
class TTryToRecvMMsgSocket: public TAbstractSocket {
private:
    THolderVector<TUdpHostRecvBufAlloc> RecvPackets;
    TVector<sockaddr_in6> RecvPacketsSrcAddresses;
    TVector<TIoVec> RecvPacketsIoVecs;
    size_t RecvPacketsBegin;      // first non returned to user
    size_t RecvPacketsHeadersEnd; // next after last one with data
    TVector<TMMsgHdr> RecvPacketsHeaders;
    TVector<std::array<char, CTRL_BUFFER_SIZE>> RecvPacketsCtrlBuffers;

    int FillRecvBuffers();

public:
    static bool IsRecvMMsgSupported();

    // Tests showed best performance on queue size 128 (+7%).
    // If memory is limited you can use 12 - it gives +4%.
    // Do not use lower values - for example recvmmsg with 1 element is 3% slower that recvmsg!
    // (tested with junk/f0b0s/neTBasicSocket_queue_test).
    TTryToRecvMMsgSocket(const size_t recvQueueSize = 128);
    ~TTryToRecvMMsgSocket() override;

    int Open(int port) override;
    void Close() override;

    void Wait(float timeoutSec, int netlibaVersion) const override;
    void CancelWait(int netlibaVersion) override;

    bool IsRecvMsgSupported() const override {
        return false;
    }
    ssize_t RecvMsg(TMsgHdr* hdr, int flags) override {
        Y_UNUSED(hdr);
        Y_UNUSED(flags);
        Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TRecvMMsgSocket implementation must use memcpy which is suboptimal and thus forbidden!");
    }
    TUdpRecvPacket* Recv(sockaddr_in6* addr, sockaddr_in6* dstAddr, int netlibaVersion) override;
};

TTryToRecvMMsgSocket::TTryToRecvMMsgSocket(const size_t recvQueueSize)
    : RecvPacketsBegin(0)
    , RecvPacketsHeadersEnd(0)
{
    // recvmmsg is not supported - will act like TSocket,
    // we can't just VERIFY - TTryToRecvMMsgSocket is used as base class for TDualStackSocket.
    if (!IsRecvMMsgSupported()) {
        RecvPackets.reserve(1);
        RecvPackets.PushBack(new TUdpHostRecvBufAlloc);
        return;
    }

    RecvPackets.reserve(recvQueueSize);
    for (size_t i = 0; i != recvQueueSize; ++i) {
        RecvPackets.PushBack(new TUdpHostRecvBufAlloc);
    }

    RecvPacketsSrcAddresses.resize(recvQueueSize);
    RecvPacketsIoVecs.resize(recvQueueSize);
    RecvPacketsHeaders.resize(recvQueueSize);
    RecvPacketsCtrlBuffers.resize(recvQueueSize);

    for (size_t i = 0; i != recvQueueSize; ++i) {
        TMMsgHdr& mhdr = RecvPacketsHeaders[i];
        Zero(mhdr);

        RecvPacketsIoVecs[i] = CreateIoVec(RecvPackets[i]->GetDataPtr(), RecvPackets[i]->GetBufSize());
        char* buf = RecvPacketsCtrlBuffers[i].data();
        memset(buf, 0, CTRL_BUFFER_SIZE);
        mhdr.msg_hdr = CreateRecvMsgHdr(&RecvPacketsSrcAddresses[i], RecvPacketsIoVecs[i], buf);
    }
}

TTryToRecvMMsgSocket::~TTryToRecvMMsgSocket() {
    Close();
}

int TTryToRecvMMsgSocket::Open(int port) {
    return OpenImpl(port);
}

void TTryToRecvMMsgSocket::Close() {
    CloseImpl();
}

void TTryToRecvMMsgSocket::Wait(float timeoutSec, int netlibaVersion) const {
    Y_UNUSED(netlibaVersion);
    Y_ASSERT(RecvPacketsBegin == RecvPacketsHeadersEnd || IsRecvMMsgSupported());
    if (RecvPacketsBegin == RecvPacketsHeadersEnd) {
        WaitImpl(timeoutSec);
    }
}

void TTryToRecvMMsgSocket::CancelWait(int netlibaVersion) {
    Y_UNUSED(netlibaVersion);
    CancelWaitImpl();
}

bool TTryToRecvMMsgSocket::IsRecvMMsgSupported() {
    return RecvMMsgFunc != nullptr;
}

// Returns the number of buffered packets not yet handed out; when none are left,
// issues a new recvmmsg() and returns its result (negative on error).
int TTryToRecvMMsgSocket::FillRecvBuffers() {
    Y_ASSERT(IsRecvMMsgSupported());
    Y_ASSERT(RecvPacketsBegin <= RecvPacketsHeadersEnd);
    if (RecvPacketsBegin < RecvPacketsHeadersEnd) {
        return RecvPacketsHeadersEnd - RecvPacketsBegin;
    }

    // no packets left from last recvmmsg call
    for (size_t i = 0; i != RecvPacketsHeadersEnd; ++i) { // reinit only used by last recvmmsg call headers
        RecvPacketsIoVecs[i] = CreateIoVec(RecvPackets[i]->GetDataPtr(), RecvPackets[i]->GetBufSize());
    }
    RecvPacketsBegin = RecvPacketsHeadersEnd = 0;

    const int r = RecvMMsgImpl(&RecvPacketsHeaders[0], (unsigned int)RecvPacketsHeaders.size(), 0, nullptr);
    if (r >= 0) {
        RecvPacketsHeadersEnd = r;
    } else {
        Y_ASSERT(LastSystemError() == EAGAIN || LastSystemError() == EWOULDBLOCK);
    }
    return r;
}

// not thread-safe
TUdpRecvPacket* TTryToRecvMMsgSocket::Recv(sockaddr_in6* fromAddress, sockaddr_in6* dstAddr, int) {
    // act like TSocket
    if (!IsRecvMMsgSupported()) {
        return RecvImpl(RecvPackets[0], fromAddress, dstAddr);
    }

    if (FillRecvBuffers() <= 0) {
        return nullptr;
    }

    TUdpRecvPacket* result = RecvPackets[RecvPacketsBegin]->ExtractPacket();
    TMMsgHdr& mmsgHdr = RecvPacketsHeaders[RecvPacketsBegin];
    // Same packet layout as RecvImpl(): data starts at offset 0 of the buffer.
    result->DataStart = 0;
    result->DataSize = (ssize_t)mmsgHdr.msg_len;
    if (dstAddr && !ExtractDestinationAddress(mmsgHdr.msg_hdr, dstAddr)) {
        // fprintf(stderr, "can`t get destination ip\n");
    }
    *fromAddress = RecvPacketsSrcAddresses[RecvPacketsBegin];
    //we must clean ctrlbuffer to be able to use it later
#ifndef _win_
    memset(mmsgHdr.msg_hdr.msg_control, 0, CTRL_BUFFER_SIZE);
    mmsgHdr.msg_hdr.msg_controllen = CTRL_BUFFER_SIZE;
#endif
    RecvPacketsBegin++;

    return result;
}

///////////////////////////////////////////////////////////////////////////////

/* TODO: too slow, needs to be optimized
template<size_t TTNumRecvThreads>
class TMTRecvSocket: public TAbstractSocket {
private:
    typedef TLockFreePacketQueue<TTNumRecvThreads> TPacketQueue;

    static void* RecvThreadFunc(void* that) {
        static_cast<TMTRecvSocket*>(that)->RecvLoop();
        return NULL;
    }

    void RecvLoop() {
        TBestUnixRecvSocket impl;
        impl.Reset(*this);

        while (AtomicAdd(NumThreadsToDie, 0) == -1) {
            sockaddr_in6 addr;
            TUdpRecvPacket* packet = impl.Recv(&addr, NETLIBA_ANY_VERSION);
            if (!packet) {
                impl.Wait(0.0001, NETLIBA_ANY_VERSION); // so small tiomeout because we can't guarantee that 1 thread won't get all packets
                continue;
            }
            Queue.Push(packet, addr);
        }

        if (AtomicDecrement(NumThreadsToDie)) {
            impl.CancelWait(NETLIBA_ANY_VERSION);
        } else {
            AllThreadsAreDead.Signal();
        }
    }

    THolderVector<TThread> RecvThreads;
    TAtomic NumThreadsToDie;
    TSystemEvent AllThreadsAreDead;

    TPacketQueue Queue;

public:
    TMTRecvSocket()
        : NumThreadsToDie(-1) {}

    ~TMTRecvSocket() {
        Close();
    }

    int Open(int port) {
        if (OpenImpl(port) != 0) {
            Y_ASSERT(!IsValid());
            return -1;
        }

        NumThreadsToDie = -1;
        RecvThreads.reserve(TTNumRecvThreads);
        for (size_t i = 0; i != TTNumRecvThreads; ++i) {
            RecvThreads.PushBack(new TThread(TThread::TParams(RecvThreadFunc, this).SetName("nl12_recv_skt")));
            RecvThreads.back()->Start();
            RecvThreads.back()->Detach();
        }
        return 0;
    }

    void Close() {
        if (!IsValid()) {
            return;
        }

        AtomicSwap(&NumThreadsToDie, (int)RecvThreads.size());
        CancelWaitImpl();
        Y_VERIFY(AllThreadsAreDead.WaitT(TDuration::Seconds(30)), "TMTRecvSocket destruction failed");

        CloseImpl();
    }

    void Wait(float timeoutSec, int netlibaVersion) const {
        Y_UNUSED(netlibaVersion);
        Queue.GetEvent().WaitT(TDuration::Seconds(timeoutSec));
    }
    void CancelWait(int netlibaVersion) {
        Y_UNUSED(netlibaVersion);
        Queue.GetEvent().Signal();
    }

    TUdpRecvPacket* Recv(sockaddr_in6 *addr, int netlibaVersion) {
        Y_UNUSED(netlibaVersion);
        TUdpRecvPacket* result;
        if (!Queue.Pop(&result, addr)) {
            return NULL;
        }
        return result;
    }

    bool IsRecvMsgSupported() const { return false; }
    ssize_t RecvMsg(TMsgHdr* hdr, int flags) { Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TMTRecvSocket implementation must use memcpy which is suboptimal and thus forbidden!"); }
};
*/

///////////////////////////////////////////////////////////////////////////////

// Send.*, Recv, Wait and CancelWait are thread-safe.
class TDualStackSocket: public TTryToRecvMMsgSocket { private: typedef TTryToRecvMMsgSocket TBase; typedef TLockFreePacketQueue<1> TPacketQueue; static void* RecvThreadFunc(void* that); void RecvLoop(); struct TFilteredPacketQueue { enum EPushResult { PR_FULL = 0, PR_OK = 1, PR_FILTERED = 2 }; const ui8 F1; const ui8 F2; const ui8 CmdPos; TFilteredPacketQueue(ui8 f1, ui8 f2, ui8 cmdPos) : F1(f1) , F2(f2) , CmdPos(cmdPos) { } bool Pop(TUdpRecvPacket** packet, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) { return Queue.Pop(packet, srcAddr, dstAddr); } ui8 Push(TUdpRecvPacket* packet, const TPacketMeta& meta) { if (Queue.IsDataPartFull()) { const ui8 cmd = packet->Data.get()[CmdPos]; if (cmd == F1 || cmd == F2) return PR_FILTERED; } return Queue.Push(packet, meta); //false - PR_FULL, true - PR_OK } TPacketQueue Queue; }; TFilteredPacketQueue& GetRecvQueue(int netlibaVersion) const; TSystemEvent& GetQueueEvent(const TFilteredPacketQueue& queue) const; TThread RecvThread; TAtomic ShouldDie; TSystemEvent DieEvent; mutable TFilteredPacketQueue RecvQueue6; mutable TFilteredPacketQueue RecvQueue12; public: TDualStackSocket(); ~TDualStackSocket() override; int Open(int port) override; void Close() override; void Wait(float timeoutSec, int netlibaVersion) const override; void CancelWait(int netlibaVersion) override; bool IsRecvMsgSupported() const override { return false; } ssize_t RecvMsg(TMsgHdr* hdr, int flags) override { Y_UNUSED(hdr); Y_UNUSED(flags); Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! 
TDualStackSocket implementation must use memcpy which is suboptimal and thus forbidden!"); } TUdpRecvPacket* Recv(sockaddr_in6* addr, sockaddr_in6* dstAddr, int netlibaVersion) override; }; TDualStackSocket::TDualStackSocket() : RecvThread(TThread::TParams(RecvThreadFunc, this).SetName("nl12_dual_stack")) , ShouldDie(0) , RecvQueue6(NNetliba::DATA, NNetliba::DATA_SMALL, NNetliba::CMD_POS) , RecvQueue12(NNetliba_v12::DATA, NNetliba_v12::DATA_SMALL, NNetliba_v12::CMD_POS) { } // virtual functions don't work in dtors! TDualStackSocket::~TDualStackSocket() { Close(); sockaddr_in6 srcAdd; sockaddr_in6 dstAddr; TUdpRecvPacket* ptr = nullptr; while (GetRecvQueue(NETLIBA_ANY_VERSION).Pop(&ptr, &srcAdd, &dstAddr)) { delete ptr; } while (GetRecvQueue(NETLIBA_V12_VERSION).Pop(&ptr, &srcAdd, &dstAddr)) { delete ptr; } } int TDualStackSocket::Open(int port) { if (TBase::Open(port) != 0) { Y_ASSERT(!IsValid()); return -1; } AtomicSet(ShouldDie, 0); DieEvent.Reset(); RecvThread.Start(); RecvThread.Detach(); return 0; } void TDualStackSocket::Close() { if (!IsValid()) { return; } AtomicSwap(&ShouldDie, 1); CancelWaitImpl(); Y_VERIFY(DieEvent.WaitT(TDuration::Seconds(30)), "TDualStackSocket::Close failed"); TBase::Close(); } TDualStackSocket::TFilteredPacketQueue& TDualStackSocket::GetRecvQueue(int netlibaVersion) const { return netlibaVersion == NETLIBA_V12_VERSION ? 
RecvQueue12 : RecvQueue6; } TSystemEvent& TDualStackSocket::GetQueueEvent(const TFilteredPacketQueue& queue) const { return queue.Queue.GetEvent(); } void* TDualStackSocket::RecvThreadFunc(void* that) { SetHighestThreadPriority(); static_cast<TDualStackSocket*>(that)->RecvLoop(); return nullptr; } void TDualStackSocket::RecvLoop() { for (;;) { TUdpRecvPacket* p = nullptr; sockaddr_in6 srcAddr; sockaddr_in6 dstAddr; while (AtomicAdd(ShouldDie, 0) == 0 && (p = TBase::Recv(&srcAddr, &dstAddr, NETLIBA_ANY_VERSION))) { Y_ASSERT(p->DataStart == 0); if (p->DataSize < 12) { continue; } TFilteredPacketQueue& q = GetRecvQueue(p->Data.get()[8]); const ui8 res = q.Push(p, {srcAddr, dstAddr}); if (res == TFilteredPacketQueue::PR_OK) { GetQueueEvent(q).Signal(); } else { // simulate OS behavior on buffer overflow - drop packets. const NHPTimer::STime time = AtomicGet(RecvLag); const float sec = NHPTimer::GetSeconds(time); fprintf(stderr, "TDualStackSocket::RecvLoop netliba v%d queue overflow, recv lag: %f sec, dropping packet, res: %u\n", &q == &RecvQueue12 ? 
12 : 6, sec, res); delete p; } } if (AtomicAdd(ShouldDie, 0)) { DieEvent.Signal(); return; } TBase::Wait(0.1f, NETLIBA_ANY_VERSION); } } void TDualStackSocket::Wait(float timeoutSec, int netlibaVersion) const { TFilteredPacketQueue& q = GetRecvQueue(netlibaVersion); if (q.Queue.IsEmpty()) { GetQueueEvent(q).Reset(); if (q.Queue.IsEmpty()) { GetQueueEvent(q).WaitT(TDuration::Seconds(timeoutSec)); } } } void TDualStackSocket::CancelWait(int netlibaVersion) { GetQueueEvent(GetRecvQueue(netlibaVersion)).Signal(); } // thread-safe TUdpRecvPacket* TDualStackSocket::Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) { TUdpRecvPacket* result = nullptr; if (!GetRecvQueue(netlibaVersion).Pop(&result, srcAddr, dstAddr)) { return nullptr; } return result; } /////////////////////////////////////////////////////////////////////////////// TIntrusivePtr<ISocket> CreateSocket() { return new TSocket(); } TIntrusivePtr<ISocket> CreateDualStackSocket() { return new TDualStackSocket(); } TIntrusivePtr<ISocket> CreateBestRecvSocket() { // TSocket is faster than TRecvMMsgFunc in case of unsupported recvmmsg if (!TTryToRecvMMsgSocket::IsRecvMMsgSupported()) { return new TSocket(); } return new TTryToRecvMMsgSocket(); } }