diff options
author | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
---|---|---|
committer | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
commit | 06e5c21a835c0e923506c4ff27929f34e00761c2 (patch) | |
tree | 75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/netliba/socket/socket.cpp | |
parent | 03f024c4412e3aa613bb543cf1660176320ba8f4 (diff) | |
download | ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz |
fix ya.make
Diffstat (limited to 'library/cpp/netliba/socket/socket.cpp')
-rw-r--r-- | library/cpp/netliba/socket/socket.cpp | 1086 |
1 files changed, 1086 insertions, 0 deletions
diff --git a/library/cpp/netliba/socket/socket.cpp b/library/cpp/netliba/socket/socket.cpp new file mode 100644 index 0000000000..c10236229b --- /dev/null +++ b/library/cpp/netliba/socket/socket.cpp @@ -0,0 +1,1086 @@ +#include "stdafx.h" +#include <util/datetime/cputimer.h> +#include <util/draft/holder_vector.h> +#include <util/generic/utility.h> +#include <util/generic/vector.h> +#include <util/network/init.h> +#include <util/network/poller.h> +#include <library/cpp/deprecated/atomic/atomic.h> +#include <util/system/byteorder.h> +#include <util/system/defaults.h> +#include <util/system/error.h> +#include <util/system/event.h> +#include <util/system/thread.h> +#include <util/system/yassert.h> +#include <util/system/rwlock.h> +#include <util/system/env.h> + +#include "socket.h" +#include "packet_queue.h" +#include "udp_recv_packet.h" + +#include <array> +#include <stdlib.h> + +/////////////////////////////////////////////////////////////////////////////// + +#ifndef _win_ +#include <netinet/in.h> +#endif + +#ifdef _linux_ +#include <dlfcn.h> // dlsym +#endif + +template <class T> +static T GetAddressOf(const char* name) { +#ifdef _linux_ + if (!GetEnv("DISABLE_MMSG")) { + return (T)dlsym(RTLD_DEFAULT, name); + } +#endif + Y_UNUSED(name); + return nullptr; +} + +/////////////////////////////////////////////////////////////////////////////// + +namespace NNetlibaSocket { + /////////////////////////////////////////////////////////////////////////////// + + struct timespec; // we use it only as NULL pointer + typedef int (*TSendMMsgFunc)(SOCKET, TMMsgHdr*, unsigned int, unsigned int); + typedef int (*TRecvMMsgFunc)(SOCKET, TMMsgHdr*, unsigned int, unsigned int, timespec*); + + static const TSendMMsgFunc SendMMsgFunc = GetAddressOf<TSendMMsgFunc>("sendmmsg"); + static const TRecvMMsgFunc RecvMMsgFunc = GetAddressOf<TRecvMMsgFunc>("recvmmsg"); + + /////////////////////////////////////////////////////////////////////////////// + + bool ReadTos(const TMsgHdr& msgHdr, ui8* tos) { +#ifdef _win_ + Y_UNUSED(msgHdr); + Y_UNUSED(tos); + return false; +#else + cmsghdr* cmsg = CMSG_FIRSTHDR(&msgHdr); + if (!cmsg) + return false; + //Y_ASSERT(cmsg->cmsg_level == IPPROTO_IPV6); + //Y_ASSERT(cmsg->cmsg_type == IPV6_TCLASS); + if (cmsg->cmsg_len != CMSG_LEN(sizeof(int))) + return false; + *tos = *(ui8*)CMSG_DATA(cmsg); + return true; +#endif + } + + bool ExtractDestinationAddress(TMsgHdr& msgHdr, sockaddr_in6* addrBuf) { + Zero(*addrBuf); +#ifdef _win_ + Y_UNUSED(msgHdr); + Y_UNUSED(addrBuf); + return false; +#else + cmsghdr* cmsg; + for (cmsg = CMSG_FIRSTHDR(&msgHdr); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msgHdr, cmsg)) { + if ((cmsg->cmsg_level == IPPROTO_IPV6) && (cmsg->cmsg_type == IPV6_PKTINFO)) { + in6_pktinfo* i = (in6_pktinfo*)CMSG_DATA(cmsg); + addrBuf->sin6_addr = i->ipi6_addr; + addrBuf->sin6_family = AF_INET6; + return true; + } + } + return false; +#endif + } + + // all send and recv methods are thread safe! + class TAbstractSocket: public ISocket { + private: + SOCKET S; + mutable TSocketPoller Poller; + sockaddr_in6 SelfAddress; + + int SendSysSocketSize; + int SendSysSocketSizePrev; + + int CreateSocket(int netPort); + int DetectSelfAddress(); + + protected: + int SetSockOpt(int level, int option_name, const void* option_value, socklen_t option_len); + + int OpenImpl(int port); + void CloseImpl(); + + void WaitImpl(float timeoutSec) const; + void CancelWaitImpl(const sockaddr_in6* address = nullptr); // NULL means "self" + + ssize_t RecvMsgImpl(TMsgHdr* hdr, int flags); + TUdpRecvPacket* RecvImpl(TUdpHostRecvBufAlloc* buf, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr); + int RecvMMsgImpl(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags, timespec* timeout); + + bool IsFragmentationForbiden(); + void ForbidFragmentation(); + void EnableFragmentation(); + + //Shared state for setsockopt. Forbid simultaneous transfer while sender asking for specific options (i.e. DONOT_FRAG) + TRWMutex Mutex; + TAtomic RecvLag = 0; + + public: + TAbstractSocket(); + ~TAbstractSocket() override; +#ifdef _unix_ + void Reset(const TAbstractSocket& rhv); +#endif + + bool IsValid() const override; + + const sockaddr_in6& GetSelfAddress() const override; + int GetNetworkOrderPort() const override; + int GetPort() const override; + + int GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) override; + + // send all packets to this and only this address by default + int Connect(const struct sockaddr* address, socklen_t address_len) override; + + void CancelWaitHost(const sockaddr_in6 addr) override; + + bool IsSendMMsgSupported() const override; + int SendMMsg(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) override; + ssize_t SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) override; + bool IncreaseSendBuff() override; + int GetSendSysSocketSize() override; + void SetRecvLagTime(NHPTimer::STime time) override; + }; + + TAbstractSocket::TAbstractSocket() + : S(INVALID_SOCKET) + , SendSysSocketSize(0) + , SendSysSocketSizePrev(0) + { + Zero(SelfAddress); + } + + TAbstractSocket::~TAbstractSocket() { + CloseImpl(); + } + +#ifdef _unix_ + void TAbstractSocket::Reset(const TAbstractSocket& rhv) { + Close(); + S = dup(rhv.S); + SelfAddress = rhv.SelfAddress; + } +#endif + + int TAbstractSocket::CreateSocket(int netPort) { + if (IsValid()) { + Y_ASSERT(0); + return 0; + } + S = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP); + if (S == INVALID_SOCKET) { + return -1; + } + { + int flag = 0; + Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&flag, sizeof(flag)) == 0, "IPV6_V6ONLY failed"); + } + { + int flag = 1; + Y_VERIFY(SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (const char*)&flag, sizeof(flag)) == 0, "SO_REUSEADDR failed"); + } +#if defined(_win_) + unsigned long dummy = 1; + ioctlsocket(S, FIONBIO, &dummy); +#else + Y_VERIFY(fcntl(S, F_SETFL, O_NONBLOCK) == 0, "fnctl failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError()); + Y_VERIFY(fcntl(S, F_SETFD, FD_CLOEXEC) == 0, "fnctl failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError()); + { + int flag = 1; +#ifndef IPV6_RECVPKTINFO /* Darwin platforms require this */ + Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_PKTINFO, (const char*)&flag, sizeof(flag)) == 0, "IPV6_PKTINFO failed"); +#else + Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_RECVPKTINFO, (const char*)&flag, sizeof(flag)) == 0, "IPV6_RECVPKTINFO failed"); +#endif + } +#endif + + Poller.WaitRead(S, nullptr); + + { + // bind socket + sockaddr_in6 name; + Zero(name); + name.sin6_family = AF_INET6; + name.sin6_addr = in6addr_any; + name.sin6_port = netPort; + if (bind(S, (sockaddr*)&name, sizeof(name)) != 0) { + fprintf(stderr, "netliba_socket could not bind to port %d: %s (errno = %d)\n", InetToHost((ui16)netPort), LastSystemErrorText(), LastSystemError()); + CloseImpl(); // we call this CloseImpl after Poller initialization + return -1; + } + } + //Default behavior is allowing fragmentation (according to netliba v6 behavior) + //If we want to sent packet with DF flag we have to use SendMsg() + EnableFragmentation(); + + { + socklen_t sz = sizeof(SendSysSocketSize); + if (GetSockOpt(SOL_SOCKET, SO_SNDBUF, &SendSysSocketSize, &sz)) { + fprintf(stderr, "Can`t get SO_SNDBUF"); + } + } + return 0; + } + + bool TAbstractSocket::IsValid() const { + return S != INVALID_SOCKET; + } + + int TAbstractSocket::DetectSelfAddress() { + socklen_t nameLen = sizeof(SelfAddress); + if (getsockname(S, (sockaddr*)&SelfAddress, &nameLen) != 0) { // actually we use only sin6_port + return -1; + } + Y_ASSERT(SelfAddress.sin6_family == AF_INET6); + SelfAddress.sin6_addr = in6addr_loopback; + return 0; + } + + const sockaddr_in6& TAbstractSocket::GetSelfAddress() const { + return SelfAddress; + } + + int TAbstractSocket::GetNetworkOrderPort() const { + return SelfAddress.sin6_port; + } + + int TAbstractSocket::GetPort() const { + return InetToHost((ui16)SelfAddress.sin6_port); + } + + int TAbstractSocket::SetSockOpt(int level, int option_name, const void* option_value, socklen_t option_len) { + const int rv = setsockopt(S, level, option_name, (const char*)option_value, option_len); + Y_VERIFY_DEBUG(rv == 0, "SetSockOpt failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError()); + return rv; + } + + int TAbstractSocket::GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) { + const int rv = getsockopt(S, level, option_name, (char*)option_value, option_len); + Y_VERIFY_DEBUG(rv == 0, "GetSockOpt failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError()); + return rv; + } + + bool TAbstractSocket::IsFragmentationForbiden() { +#if defined(_win_) + DWORD flag = 0; + socklen_t sz = sizeof(flag); + Y_VERIFY(GetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (char*)&flag, &sz) == 0, ""); + return flag; +#elif defined(_linux_) + int flag = 0; + socklen_t sz = sizeof(flag); + Y_VERIFY(GetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (char*)&flag, &sz) == 0, ""); + return flag == IPV6_PMTUDISC_DO; +#elif !defined(_darwin_) + int flag = 0; + socklen_t sz = sizeof(flag); + Y_VERIFY(GetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (char*)&flag, &sz) == 0, ""); + return flag; +#endif + return false; + } + + void TAbstractSocket::ForbidFragmentation() { + // do not fragment ping packets +#if defined(_win_) + DWORD flag = 1; + SetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (const char*)&flag, sizeof(flag)); +#elif defined(_linux_) + int flag = IP_PMTUDISC_DO; + SetSockOpt(IPPROTO_IP, IP_MTU_DISCOVER, (const char*)&flag, sizeof(flag)); + + flag = IPV6_PMTUDISC_DO; + SetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (const char*)&flag, sizeof(flag)); +#elif !defined(_darwin_) + int flag = 1; + //SetSockOpt(IPPROTO_IP, IP_DONTFRAG, (const char*)&flag, sizeof(flag)); + SetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (const char*)&flag, sizeof(flag)); +#endif + } + + void TAbstractSocket::EnableFragmentation() { +#if defined(_win_) + DWORD flag = 0; + SetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (const char*)&flag, sizeof(flag)); +#elif defined(_linux_) + int flag = IP_PMTUDISC_WANT; + SetSockOpt(IPPROTO_IP, IP_MTU_DISCOVER, (const char*)&flag, sizeof(flag)); + + flag = IPV6_PMTUDISC_WANT; + SetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (const char*)&flag, sizeof(flag)); +#elif !defined(_darwin_) + int flag = 0; + //SetSockOpt(IPPROTO_IP, IP_DONTFRAG, (const char*)&flag, sizeof(flag)); + SetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (const char*)&flag, sizeof(flag)); +#endif + } + + int TAbstractSocket::Connect(const sockaddr* address, socklen_t address_len) { + Y_ASSERT(IsValid()); + return connect(S, address, address_len); + } + + void TAbstractSocket::CancelWaitHost(const sockaddr_in6 addr) { + CancelWaitImpl(&addr); + } + + bool TAbstractSocket::IsSendMMsgSupported() const { + return SendMMsgFunc != nullptr; + } + + int TAbstractSocket::SendMMsg(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) { + Y_ASSERT(IsValid()); + Y_VERIFY(SendMMsgFunc, "sendmmsg is not supported!"); + TReadGuard rg(Mutex); + static bool checked = 0; + Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior"); + return SendMMsgFunc(S, msgvec, vlen, flags); + } + + ssize_t TAbstractSocket::SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) { + Y_ASSERT(IsValid()); +#ifdef _win32_ + static bool checked = 0; + Y_VERIFY(hdr->msg_iov->iov_len == 1, "Scatter/gather is currenly not supported on Windows"); + if (hdr->Tos || frag == FF_DONT_FRAG) { + TWriteGuard wg(Mutex); + if (frag == FF_DONT_FRAG) { + ForbidFragmentation(); + } else { + Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior"); + } + int originalTos; + if (hdr->Tos) { + socklen_t sz = sizeof(originalTos); + Y_VERIFY(GetSockOpt(IPPROTO_IP, IP_TOS, (char*)&originalTos, &sz) == 0, ""); + Y_VERIFY(SetSockOpt(IPPROTO_IP, IP_TOS, (char*)&hdr->Tos, sizeof(hdr->Tos)) == 0, ""); + } + const ssize_t rv = sendto(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, hdr->msg_namelen); + if (hdr->Tos) { + Y_VERIFY(SetSockOpt(IPPROTO_IP, IP_TOS, (char*)&originalTos, sizeof(originalTos)) == 0, ""); + } + if (frag == FF_DONT_FRAG) { + EnableFragmentation(); + } + return rv; + } + TReadGuard rg(Mutex); + Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior"); + return sendto(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, hdr->msg_namelen); +#else + if (frag == FF_DONT_FRAG) { + TWriteGuard wg(Mutex); + ForbidFragmentation(); + const ssize_t rv = sendmsg(S, hdr, flags); + EnableFragmentation(); + return rv; + } + + TReadGuard rg(Mutex); +#ifndef _darwin_ + static bool checked = 0; + Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior"); +#endif + return sendmsg(S, hdr, flags); +#endif + } + + bool TAbstractSocket::IncreaseSendBuff() { + int buffSize; + socklen_t sz = sizeof(buffSize); + if (GetSockOpt(SOL_SOCKET, SO_SNDBUF, &buffSize, &sz)) { + return false; + } + // worst case: 200000 pps * 8k * 0.01sec = 16Mb so 32Mb hard limit is reasonable value + if (buffSize < 0 || buffSize > (1 << 25)) { + fprintf(stderr, "GetSockOpt returns wrong or too big value for SO_SNDBUF: %d\n", buffSize); + return false; + } + //linux returns the doubled value. man 7 socket: + // + // SO_SNDBUF + // Sets or gets the maximum socket send buffer in bytes. The ker- + // nel doubles this value (to allow space for bookkeeping overhead) + // when it is set using setsockopt(), and this doubled value is + // returned by getsockopt(). The default value is set by the + // wmem_default sysctl and the maximum allowed value is set by the + // wmem_max sysctl. The minimum (doubled) value for this option is + // 2048. + // +#ifndef _linux_ + buffSize += buffSize; +#endif + + // false if previous value was less than current value. + // It means setsockopt was not successful. (for example: system limits) + // we will try to set it again but return false + const bool rv = !(buffSize <= SendSysSocketSizePrev); + if (SetSockOpt(SOL_SOCKET, SO_SNDBUF, &buffSize, sz) == 0) { + SendSysSocketSize = buffSize; + SendSysSocketSizePrev = buffSize; + return rv; + } + return false; + } + + int TAbstractSocket::GetSendSysSocketSize() { + return SendSysSocketSize; + } + + void TAbstractSocket::SetRecvLagTime(NHPTimer::STime time) { + AtomicSet(RecvLag, time); + } + + int TAbstractSocket::OpenImpl(int port) { + Y_ASSERT(!IsValid()); + const int netPort = port ? htons((u_short)port) : 0; + +#ifdef _freebsd_ + // alternative OS + if (netPort == 0) { + static ui64 pp = GetCycleCount(); + for (int attempt = 0; attempt < 100; ++attempt) { + const int tryPort = htons((pp & 0x3fff) + 0xc000); + ++pp; + if (CreateSocket(tryPort) != 0) { + Y_ASSERT(!IsValid()); + continue; + } + + if (DetectSelfAddress() != 0 || tryPort != SelfAddress.sin6_port) { + // FreeBSD suck! + CloseImpl(); + Y_ASSERT(!IsValid()); + continue; + } + break; + } + if (!IsValid()) { + return -1; + } + } else { + if (CreateSocket(netPort) != 0) { + Y_ASSERT(!IsValid()); + return -1; + } + } +#else + // regular OS + if (CreateSocket(netPort) != 0) { + Y_ASSERT(!IsValid()); + return -1; + } +#endif + + if (IsValid() && DetectSelfAddress() != 0) { + CloseImpl(); + Y_ASSERT(!IsValid()); + return -1; + } + + Y_ASSERT(IsValid()); + return 0; + } + + void TAbstractSocket::CloseImpl() { + if (IsValid()) { + Poller.Unwait(S); + Y_VERIFY(closesocket(S) == 0, "closesocket failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError()); + } + S = INVALID_SOCKET; + } + + void TAbstractSocket::WaitImpl(float timeoutSec) const { + Y_VERIFY(IsValid(), "something went wrong"); + Poller.WaitT(TDuration::Seconds(timeoutSec)); + } + + void TAbstractSocket::CancelWaitImpl(const sockaddr_in6* address) { + Y_ASSERT(IsValid()); + + // darwin ignores packets with msg_iovlen == 0, also windows implementation uses sendto of first iovec. + TIoVec v = CreateIoVec(nullptr, 0); + TMsgHdr hdr = CreateSendMsgHdr((address ? *address : SelfAddress), v, nullptr); + + // send self fake packet + TAbstractSocket::SendMsg(&hdr, 0, FF_ALLOW_FRAG); + } + + ssize_t TAbstractSocket::RecvMsgImpl(TMsgHdr* hdr, int flags) { + Y_ASSERT(IsValid()); + +#ifdef _win32_ + Y_VERIFY(hdr->msg_iov->iov_len == 1, "Scatter/gather is currenly not supported on Windows"); + return recvfrom(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, &hdr->msg_namelen); +#else + return recvmsg(S, hdr, flags); +#endif + } + + TUdpRecvPacket* TAbstractSocket::RecvImpl(TUdpHostRecvBufAlloc* buf, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) { + Y_ASSERT(IsValid()); + + const TIoVec iov = CreateIoVec(buf->GetDataPtr(), buf->GetBufSize()); + char controllBuffer[CTRL_BUFFER_SIZE]; //used to get dst address from socket + TMsgHdr hdr = CreateRecvMsgHdr(srcAddr, iov, controllBuffer); + + const ssize_t rv = TAbstractSocket::RecvMsgImpl(&hdr, 0); + if (rv < 0) { + Y_ASSERT(LastSystemError() == EAGAIN || LastSystemError() == EWOULDBLOCK); + return nullptr; + } + if (dstAddr && !ExtractDestinationAddress(hdr, dstAddr)) { + //fprintf(stderr, "can`t get destination ip\n"); + } + + // we extract packet and allocate new buffer only if packet arrived + TUdpRecvPacket* result = buf->ExtractPacket(); + result->DataStart = 0; + result->DataSize = (int)rv; + return result; + } + + // thread-safe + int TAbstractSocket::RecvMMsgImpl(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags, timespec* timeout) { + Y_ASSERT(IsValid()); + Y_VERIFY(RecvMMsgFunc, "recvmmsg is not supported!"); + return RecvMMsgFunc(S, msgvec, vlen, flags, timeout); + } + + /////////////////////////////////////////////////////////////////////////////// + + class TSocket: public TAbstractSocket { + public: + int Open(int port) override; + void Close() override; + + void Wait(float timeoutSec, int netlibaVersion) const override; + void CancelWait(int netlibaVersion) override; + + bool IsRecvMsgSupported() const override; + ssize_t RecvMsg(TMsgHdr* hdr, int flags) override; + TUdpRecvPacket* Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) override; + + private: + TUdpHostRecvBufAlloc RecvBuf; + }; + + int TSocket::Open(int port) { + return OpenImpl(port); + } + + void TSocket::Close() { + CloseImpl(); + } + + void TSocket::Wait(float timeoutSec, int netlibaVersion) const { + Y_UNUSED(netlibaVersion); + WaitImpl(timeoutSec); + } + + void TSocket::CancelWait(int netlibaVersion) { + Y_UNUSED(netlibaVersion); + CancelWaitImpl(); + } + + bool TSocket::IsRecvMsgSupported() const { + return true; + } + + ssize_t TSocket::RecvMsg(TMsgHdr* hdr, int flags) { + return RecvMsgImpl(hdr, flags); + } + + TUdpRecvPacket* TSocket::Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) { + Y_UNUSED(netlibaVersion); + return RecvImpl(&RecvBuf, srcAddr, dstAddr); + } + + /////////////////////////////////////////////////////////////////////////////// + + class TTryToRecvMMsgSocket: public TAbstractSocket { + private: + THolderVector<TUdpHostRecvBufAlloc> RecvPackets; + TVector<sockaddr_in6> RecvPacketsSrcAddresses; + TVector<TIoVec> RecvPacketsIoVecs; + size_t RecvPacketsBegin; // first non returned to user + size_t RecvPacketsHeadersEnd; // next after last one with data + TVector<TMMsgHdr> RecvPacketsHeaders; + TVector<std::array<char, CTRL_BUFFER_SIZE>> RecvPacketsCtrlBuffers; + + int FillRecvBuffers(); + + public: + static bool IsRecvMMsgSupported(); + + // Tests showed best performance on queue size 128 (+7%). + // If memory is limited you can use 12 - it gives +4%. + // Do not use lower values - for example recvmmsg with 1 element is 3% slower that recvmsg! + // (tested with junk/f0b0s/neTBasicSocket_queue_test). + TTryToRecvMMsgSocket(const size_t recvQueueSize = 128); + ~TTryToRecvMMsgSocket() override; + + int Open(int port) override; + void Close() override; + + void Wait(float timeoutSec, int netlibaVersion) const override; + void CancelWait(int netlibaVersion) override; + + bool IsRecvMsgSupported() const override { + return false; + } + ssize_t RecvMsg(TMsgHdr* hdr, int flags) override { + Y_UNUSED(hdr); + Y_UNUSED(flags); + Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TRecvMMsgSocket implementation must use memcpy which is suboptimal and thus forbidden!"); + } + TUdpRecvPacket* Recv(sockaddr_in6* addr, sockaddr_in6* dstAddr, int netlibaVersion) override; + }; + + TTryToRecvMMsgSocket::TTryToRecvMMsgSocket(const size_t recvQueueSize) + : RecvPacketsBegin(0) + , RecvPacketsHeadersEnd(0) + { + // recvmmsg is not supported - will act like TSocket, + // we can't just VERIFY - TTryToRecvMMsgSocket is used as base class for TDualStackSocket. + if (!IsRecvMMsgSupported()) { + RecvPackets.reserve(1); + RecvPackets.PushBack(new TUdpHostRecvBufAlloc); + return; + } + + RecvPackets.reserve(recvQueueSize); + for (size_t i = 0; i != recvQueueSize; ++i) { + RecvPackets.PushBack(new TUdpHostRecvBufAlloc); + } + + RecvPacketsSrcAddresses.resize(recvQueueSize); + RecvPacketsIoVecs.resize(recvQueueSize); + RecvPacketsHeaders.resize(recvQueueSize); + RecvPacketsCtrlBuffers.resize(recvQueueSize); + + for (size_t i = 0; i != recvQueueSize; ++i) { + TMMsgHdr& mhdr = RecvPacketsHeaders[i]; + Zero(mhdr); + + RecvPacketsIoVecs[i] = CreateIoVec(RecvPackets[i]->GetDataPtr(), RecvPackets[i]->GetBufSize()); + char* buf = RecvPacketsCtrlBuffers[i].data(); + memset(buf, 0, CTRL_BUFFER_SIZE); + mhdr.msg_hdr = CreateRecvMsgHdr(&RecvPacketsSrcAddresses[i], RecvPacketsIoVecs[i], buf); + } + } + + TTryToRecvMMsgSocket::~TTryToRecvMMsgSocket() { + Close(); + } + + int TTryToRecvMMsgSocket::Open(int port) { + return OpenImpl(port); + } + + void TTryToRecvMMsgSocket::Close() { + CloseImpl(); + } + + void TTryToRecvMMsgSocket::Wait(float timeoutSec, int netlibaVersion) const { + Y_UNUSED(netlibaVersion); + Y_ASSERT(RecvPacketsBegin == RecvPacketsHeadersEnd || IsRecvMMsgSupported()); + if (RecvPacketsBegin == RecvPacketsHeadersEnd) { + WaitImpl(timeoutSec); + } + } + + void TTryToRecvMMsgSocket::CancelWait(int netlibaVersion) { + Y_UNUSED(netlibaVersion); + CancelWaitImpl(); + } + + bool TTryToRecvMMsgSocket::IsRecvMMsgSupported() { + return RecvMMsgFunc != nullptr; + } + + int TTryToRecvMMsgSocket::FillRecvBuffers() { + Y_ASSERT(IsRecvMMsgSupported()); + Y_ASSERT(RecvPacketsBegin <= RecvPacketsHeadersEnd); + if (RecvPacketsBegin < RecvPacketsHeadersEnd) { + return RecvPacketsHeadersEnd - RecvPacketsBegin; + } + + // no packets left from last recvmmsg call + for (size_t i = 0; i != RecvPacketsHeadersEnd; ++i) { // reinit only used by last recvmmsg call headers + RecvPacketsIoVecs[i] = CreateIoVec(RecvPackets[i]->GetDataPtr(), RecvPackets[i]->GetBufSize()); + } + RecvPacketsBegin = RecvPacketsHeadersEnd = 0; + + const int r = RecvMMsgImpl(&RecvPacketsHeaders[0], (unsigned int)RecvPacketsHeaders.size(), 0, nullptr); + if (r >= 0) { + RecvPacketsHeadersEnd = r; + } else { + Y_ASSERT(LastSystemError() == EAGAIN || LastSystemError() == EWOULDBLOCK); + } + return r; + } + + // not thread-safe + TUdpRecvPacket* TTryToRecvMMsgSocket::Recv(sockaddr_in6* fromAddress, sockaddr_in6* dstAddr, int) { + // act like TSocket + if (!IsRecvMMsgSupported()) { + return RecvImpl(RecvPackets[0], fromAddress, dstAddr); + } + + if (FillRecvBuffers() <= 0) { + return nullptr; + } + + TUdpRecvPacket* result = RecvPackets[RecvPacketsBegin]->ExtractPacket(); + TMMsgHdr& mmsgHdr = RecvPacketsHeaders[RecvPacketsBegin]; + result->DataSize = (ssize_t)mmsgHdr.msg_len; + if (dstAddr && !ExtractDestinationAddress(mmsgHdr.msg_hdr, dstAddr)) { + // fprintf(stderr, "can`t get destination ip\n"); + } + *fromAddress = RecvPacketsSrcAddresses[RecvPacketsBegin]; + //we must clean ctrlbuffer to be able to use it later +#ifndef _win_ + memset(mmsgHdr.msg_hdr.msg_control, 0, CTRL_BUFFER_SIZE); + mmsgHdr.msg_hdr.msg_controllen = CTRL_BUFFER_SIZE; +#endif + RecvPacketsBegin++; + + return result; + } + + /////////////////////////////////////////////////////////////////////////////// + + /* TODO: too slow, needs to be optimized +template<size_t TTNumRecvThreads> +class TMTRecvSocket: public TAbstractSocket +{ +private: + typedef TLockFreePacketQueue<TTNumRecvThreads> TPacketQueue; + + static void* RecvThreadFunc(void* that) + { + static_cast<TMTRecvSocket*>(that)->RecvLoop(); + return NULL; + } + + void RecvLoop() + { + TBestUnixRecvSocket impl; + impl.Reset(*this); + + while (AtomicAdd(NumThreadsToDie, 0) == -1) { + sockaddr_in6 addr; + TUdpRecvPacket* packet = impl.Recv(&addr, NETLIBA_ANY_VERSION); + if (!packet) { + impl.Wait(0.0001, NETLIBA_ANY_VERSION); // so small tiomeout because we can't guarantee that 1 thread won't get all packets + continue; + } + Queue.Push(packet, addr); + } + + if (AtomicDecrement(NumThreadsToDie)) { + impl.CancelWait(NETLIBA_ANY_VERSION); + } else { + AllThreadsAreDead.Signal(); + } + } + + THolderVector<TThread> RecvThreads; + TAtomic NumThreadsToDie; + TSystemEvent AllThreadsAreDead; + + TPacketQueue Queue; + +public: + TMTRecvSocket() + : NumThreadsToDie(-1) {} + + ~TMTRecvSocket() + { + Close(); + } + + int Open(int port) + { + if (OpenImpl(port) != 0) { + Y_ASSERT(!IsValid()); + return -1; + } + + NumThreadsToDie = -1; + RecvThreads.reserve(TTNumRecvThreads); + for (size_t i = 0; i != TTNumRecvThreads; ++i) { + RecvThreads.PushBack(new TThread(TThread::TParams(RecvThreadFunc, this).SetName("nl12_recv_skt"))); + RecvThreads.back()->Start(); + RecvThreads.back()->Detach(); + } + return 0; + } + + void Close() + { + if (!IsValid()) { + return; + } + + AtomicSwap(&NumThreadsToDie, (int)RecvThreads.size()); + CancelWaitImpl(); + Y_VERIFY(AllThreadsAreDead.WaitT(TDuration::Seconds(30)), "TMTRecvSocket destruction failed"); + + CloseImpl(); + } + + void Wait(float timeoutSec, int netlibaVersion) const + { + Y_UNUSED(netlibaVersion); + Queue.GetEvent().WaitT(TDuration::Seconds(timeoutSec)); + } + void CancelWait(int netlibaVersion) + { + Y_UNUSED(netlibaVersion); + Queue.GetEvent().Signal(); + } + + TUdpRecvPacket* Recv(sockaddr_in6 *addr, int netlibaVersion) + { + Y_UNUSED(netlibaVersion); + TUdpRecvPacket* result; + if (!Queue.Pop(&result, addr)) { + return NULL; + } + return result; + } + + bool IsRecvMsgSupported() const { return false; } + ssize_t RecvMsg(TMsgHdr* hdr, int flags) { Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TMTRecvSocket implementation must use memcpy which is suboptimal and thus forbidden!"); } +}; +*/ + + /////////////////////////////////////////////////////////////////////////////// + + // Send.*, Recv, Wait and CancelWait are thread-safe. + class TDualStackSocket: public TTryToRecvMMsgSocket { + private: + typedef TTryToRecvMMsgSocket TBase; + typedef TLockFreePacketQueue<1> TPacketQueue; + + static void* RecvThreadFunc(void* that); + void RecvLoop(); + + struct TFilteredPacketQueue { + enum EPushResult { + PR_FULL = 0, + PR_OK = 1, + PR_FILTERED = 2 + }; + const ui8 F1; + const ui8 F2; + const ui8 CmdPos; + TFilteredPacketQueue(ui8 f1, ui8 f2, ui8 cmdPos) + : F1(f1) + , F2(f2) + , CmdPos(cmdPos) + { + } + bool Pop(TUdpRecvPacket** packet, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) { + return Queue.Pop(packet, srcAddr, dstAddr); + } + ui8 Push(TUdpRecvPacket* packet, const TPacketMeta& meta) { + if (Queue.IsDataPartFull()) { + const ui8 cmd = packet->Data.get()[CmdPos]; + if (cmd == F1 || cmd == F2) + return PR_FILTERED; + } + return Queue.Push(packet, meta); //false - PR_FULL, true - PR_OK + } + TPacketQueue Queue; + }; + + TFilteredPacketQueue& GetRecvQueue(int netlibaVersion) const; + TSystemEvent& GetQueueEvent(const TFilteredPacketQueue& queue) const; + + TThread RecvThread; + TAtomic ShouldDie; + TSystemEvent DieEvent; + + mutable TFilteredPacketQueue RecvQueue6; + mutable TFilteredPacketQueue RecvQueue12; + + public: + TDualStackSocket(); + ~TDualStackSocket() override; + + int Open(int port) override; + void Close() override; + + void Wait(float timeoutSec, int netlibaVersion) const override; + void CancelWait(int netlibaVersion) override; + + bool IsRecvMsgSupported() const override { + return false; + } + ssize_t RecvMsg(TMsgHdr* hdr, int flags) override { + Y_UNUSED(hdr); + Y_UNUSED(flags); + Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TDualStackSocket implementation must use memcpy which is suboptimal and thus forbidden!"); + } + + TUdpRecvPacket* Recv(sockaddr_in6* addr, sockaddr_in6* dstAddr, int netlibaVersion) override; + }; + + TDualStackSocket::TDualStackSocket() + : RecvThread(TThread::TParams(RecvThreadFunc, this).SetName("nl12_dual_stack")) + , ShouldDie(0) + , RecvQueue6(NNetliba::DATA, NNetliba::DATA_SMALL, NNetliba::CMD_POS) + , RecvQueue12(NNetliba_v12::DATA, NNetliba_v12::DATA_SMALL, NNetliba_v12::CMD_POS) + { + } + + // virtual functions don't work in dtors! + TDualStackSocket::~TDualStackSocket() { + Close(); + + sockaddr_in6 srcAdd; + sockaddr_in6 dstAddr; + TUdpRecvPacket* ptr = nullptr; + + while (GetRecvQueue(NETLIBA_ANY_VERSION).Pop(&ptr, &srcAdd, &dstAddr)) { + delete ptr; + } + while (GetRecvQueue(NETLIBA_V12_VERSION).Pop(&ptr, &srcAdd, &dstAddr)) { + delete ptr; + } + } + + int TDualStackSocket::Open(int port) { + if (TBase::Open(port) != 0) { + Y_ASSERT(!IsValid()); + return -1; + } + + AtomicSet(ShouldDie, 0); + DieEvent.Reset(); + RecvThread.Start(); + RecvThread.Detach(); + return 0; + } + + void TDualStackSocket::Close() { + if (!IsValid()) { + return; + } + + AtomicSwap(&ShouldDie, 1); + CancelWaitImpl(); + Y_VERIFY(DieEvent.WaitT(TDuration::Seconds(30)), "TDualStackSocket::Close failed"); + + TBase::Close(); + } + + TDualStackSocket::TFilteredPacketQueue& TDualStackSocket::GetRecvQueue(int netlibaVersion) const { + return netlibaVersion == NETLIBA_V12_VERSION ? RecvQueue12 : RecvQueue6; + } + + TSystemEvent& TDualStackSocket::GetQueueEvent(const TFilteredPacketQueue& queue) const { + return queue.Queue.GetEvent(); + } + + void* TDualStackSocket::RecvThreadFunc(void* that) { + SetHighestThreadPriority(); + static_cast<TDualStackSocket*>(that)->RecvLoop(); + return nullptr; + } + + void TDualStackSocket::RecvLoop() { + for (;;) { + TUdpRecvPacket* p = nullptr; + sockaddr_in6 srcAddr; + sockaddr_in6 dstAddr; + while (AtomicAdd(ShouldDie, 0) == 0 && (p = TBase::Recv(&srcAddr, &dstAddr, NETLIBA_ANY_VERSION))) { + Y_ASSERT(p->DataStart == 0); + if (p->DataSize < 12) { + continue; + } + + TFilteredPacketQueue& q = GetRecvQueue(p->Data.get()[8]); + const ui8 res = q.Push(p, {srcAddr, dstAddr}); + if (res == TFilteredPacketQueue::PR_OK) { + GetQueueEvent(q).Signal(); + } else { + // simulate OS behavior on buffer overflow - drop packets. + const NHPTimer::STime time = AtomicGet(RecvLag); + const float sec = NHPTimer::GetSeconds(time); + fprintf(stderr, "TDualStackSocket::RecvLoop netliba v%d queue overflow, recv lag: %f sec, dropping packet, res: %u\n", + &q == &RecvQueue12 ? 12 : 6, sec, res); + delete p; + } + } + + if (AtomicAdd(ShouldDie, 0)) { + DieEvent.Signal(); + return; + } + + TBase::Wait(0.1f, NETLIBA_ANY_VERSION); + } + } + + void TDualStackSocket::Wait(float timeoutSec, int netlibaVersion) const { + TFilteredPacketQueue& q = GetRecvQueue(netlibaVersion); + if (q.Queue.IsEmpty()) { + GetQueueEvent(q).Reset(); + if (q.Queue.IsEmpty()) { + GetQueueEvent(q).WaitT(TDuration::Seconds(timeoutSec)); + } + } + } + + void TDualStackSocket::CancelWait(int netlibaVersion) { + GetQueueEvent(GetRecvQueue(netlibaVersion)).Signal(); + } + + // thread-safe + TUdpRecvPacket* TDualStackSocket::Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) { + TUdpRecvPacket* result = nullptr; + if (!GetRecvQueue(netlibaVersion).Pop(&result, srcAddr, dstAddr)) { + return nullptr; + } + return result; + } + + /////////////////////////////////////////////////////////////////////////////// + + TIntrusivePtr<ISocket> CreateSocket() { + return new TSocket(); + } + + TIntrusivePtr<ISocket> CreateDualStackSocket() { + return new TDualStackSocket(); + } + + TIntrusivePtr<ISocket> CreateBestRecvSocket() { + // TSocket is faster than TRecvMMsgFunc in case of unsupported recvmmsg + if (!TTryToRecvMMsgSocket::IsRecvMMsgSupported()) { + return new TSocket(); + } + return new TTryToRecvMMsgSocket(); + } + +} |