aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/netliba
diff options
context:
space:
mode:
authormonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
committermonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
commit06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/netliba
parent03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
downloadydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/netliba')
-rw-r--r--library/cpp/netliba/socket/allocator.h14
-rw-r--r--library/cpp/netliba/socket/creators.cpp141
-rw-r--r--library/cpp/netliba/socket/packet_queue.h97
-rw-r--r--library/cpp/netliba/socket/protocols.h48
-rw-r--r--library/cpp/netliba/socket/socket.cpp1086
-rw-r--r--library/cpp/netliba/socket/socket.h126
-rw-r--r--library/cpp/netliba/socket/stdafx.cpp1
-rw-r--r--library/cpp/netliba/socket/stdafx.h16
-rw-r--r--library/cpp/netliba/socket/udp_recv_packet.h79
-rw-r--r--library/cpp/netliba/v6/block_chain.cpp90
-rw-r--r--library/cpp/netliba/v6/block_chain.h319
-rw-r--r--library/cpp/netliba/v6/cpu_affinity.cpp138
-rw-r--r--library/cpp/netliba/v6/cpu_affinity.h5
-rw-r--r--library/cpp/netliba/v6/cstdafx.h41
-rw-r--r--library/cpp/netliba/v6/ib_buffers.cpp5
-rw-r--r--library/cpp/netliba/v6/ib_buffers.h181
-rw-r--r--library/cpp/netliba/v6/ib_collective.cpp1317
-rw-r--r--library/cpp/netliba/v6/ib_collective.h160
-rw-r--r--library/cpp/netliba/v6/ib_cs.cpp776
-rw-r--r--library/cpp/netliba/v6/ib_cs.h57
-rw-r--r--library/cpp/netliba/v6/ib_low.cpp114
-rw-r--r--library/cpp/netliba/v6/ib_low.h797
-rw-r--r--library/cpp/netliba/v6/ib_mem.cpp167
-rw-r--r--library/cpp/netliba/v6/ib_mem.h178
-rw-r--r--library/cpp/netliba/v6/ib_memstream.cpp122
-rw-r--r--library/cpp/netliba/v6/ib_memstream.h95
-rw-r--r--library/cpp/netliba/v6/ib_test.cpp232
-rw-r--r--library/cpp/netliba/v6/ib_test.h5
-rw-r--r--library/cpp/netliba/v6/net_acks.cpp194
-rw-r--r--library/cpp/netliba/v6/net_acks.h528
-rw-r--r--library/cpp/netliba/v6/net_queue_stat.h9
-rw-r--r--library/cpp/netliba/v6/net_request.cpp5
-rw-r--r--library/cpp/netliba/v6/net_request.h15
-rw-r--r--library/cpp/netliba/v6/net_test.cpp50
-rw-r--r--library/cpp/netliba/v6/net_test.h28
-rw-r--r--library/cpp/netliba/v6/stdafx.cpp1
-rw-r--r--library/cpp/netliba/v6/stdafx.h25
-rw-r--r--library/cpp/netliba/v6/udp_address.cpp300
-rw-r--r--library/cpp/netliba/v6/udp_address.h48
-rw-r--r--library/cpp/netliba/v6/udp_client_server.cpp1321
-rw-r--r--library/cpp/netliba/v6/udp_client_server.h62
-rw-r--r--library/cpp/netliba/v6/udp_debug.cpp2
-rw-r--r--library/cpp/netliba/v6/udp_debug.h21
-rw-r--r--library/cpp/netliba/v6/udp_http.cpp1354
-rw-r--r--library/cpp/netliba/v6/udp_http.h148
-rw-r--r--library/cpp/netliba/v6/udp_socket.cpp292
-rw-r--r--library/cpp/netliba/v6/udp_socket.h59
-rw-r--r--library/cpp/netliba/v6/udp_test.cpp161
-rw-r--r--library/cpp/netliba/v6/udp_test.h5
49 files changed, 11035 insertions, 0 deletions
diff --git a/library/cpp/netliba/socket/allocator.h b/library/cpp/netliba/socket/allocator.h
new file mode 100644
index 0000000000..f09b0dabcf
--- /dev/null
+++ b/library/cpp/netliba/socket/allocator.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#ifdef NETLIBA_WITH_NALF
+#include <library/cpp/malloc/nalf/alloc_helpers.h>
+using TWithCustomAllocator = TWithNalfIncrementalAlloc;
+template <typename T>
+using TCustomAllocator = TNalfIncrementalAllocator<T>;
+#else
+#include <memory>
+typedef struct {
+} TWithCustomAllocator;
+template <typename T>
+using TCustomAllocator = std::allocator<T>;
+#endif
diff --git a/library/cpp/netliba/socket/creators.cpp b/library/cpp/netliba/socket/creators.cpp
new file mode 100644
index 0000000000..3821bf55b9
--- /dev/null
+++ b/library/cpp/netliba/socket/creators.cpp
@@ -0,0 +1,141 @@
+#include "stdafx.h"
+#include <string.h>
+#include <util/generic/utility.h>
+#include <util/network/init.h>
+#include <util/system/defaults.h>
+#include <util/system/yassert.h>
+#include "socket.h"
+
+namespace NNetlibaSocket {
+ void* CreateTos(const ui8 tos, void* buffer) {
+#ifdef _win_
+ *(int*)buffer = (int)tos;
+#else
+ // glibc bug: http://sourceware.org/bugzilla/show_bug.cgi?id=13500
+ memset(buffer, 0, TOS_BUFFER_SIZE);
+
+ msghdr dummy;
+ Zero(dummy);
+ dummy.msg_control = buffer;
+ dummy.msg_controllen = TOS_BUFFER_SIZE;
+
+ // TODO: in FreeBSD setting TOS for dual stack sockets does not affect ipv4 frames
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&dummy);
+ cmsg->cmsg_level = IPPROTO_IPV6;
+ cmsg->cmsg_type = IPV6_TCLASS;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+ memcpy(CMSG_DATA(cmsg), &tos, sizeof(tos)); // memcpy shut ups alias restrict warning
+
+ Y_ASSERT(CMSG_NXTHDR(&dummy, cmsg) == nullptr);
+#endif
+ return buffer;
+ }
+
+ TMsgHdr* AddSockAuxData(TMsgHdr* header, const ui8 tos, const sockaddr_in6& myAddr, void* buffer, size_t bufferSize) {
+#ifdef _win_
+ Y_UNUSED(header);
+ Y_UNUSED(tos);
+ Y_UNUSED(myAddr);
+ Y_UNUSED(buffer);
+ Y_UNUSED(bufferSize);
+ return nullptr;
+#else
+ header->msg_control = buffer;
+ header->msg_controllen = bufferSize;
+
+ size_t totalLen = 0;
+#ifdef _cygwin_
+ Y_UNUSED(tos);
+#else
+ // Cygwin does not support IPV6_TCLASS, so we ignore it
+ cmsghdr* cmsgTos = CMSG_FIRSTHDR(header);
+ if (cmsgTos == nullptr) {
+ header->msg_control = nullptr;
+ header->msg_controllen = 0;
+ return nullptr;
+ }
+ cmsgTos->cmsg_level = IPPROTO_IPV6;
+ cmsgTos->cmsg_type = IPV6_TCLASS;
+ cmsgTos->cmsg_len = CMSG_LEN(sizeof(int));
+ totalLen = CMSG_SPACE(sizeof(int));
+ *(ui8*)CMSG_DATA(cmsgTos) = tos;
+#endif
+
+ if (*(ui64*)myAddr.sin6_addr.s6_addr != 0u) {
+ in6_pktinfo* pktInfo;
+#ifdef _cygwin_
+ cmsghdr* cmsgAddr = CMSG_FIRSTHDR(header);
+#else
+ cmsghdr* cmsgAddr = CMSG_NXTHDR(header, cmsgTos);
+#endif
+ if (cmsgAddr == nullptr) {
+ // leave only previous record
+ header->msg_controllen = totalLen;
+ return nullptr;
+ }
+ cmsgAddr->cmsg_level = IPPROTO_IPV6;
+ cmsgAddr->cmsg_type = IPV6_PKTINFO;
+ cmsgAddr->cmsg_len = CMSG_LEN(sizeof(*pktInfo));
+ totalLen += CMSG_SPACE(sizeof(*pktInfo));
+ pktInfo = (in6_pktinfo*)CMSG_DATA(cmsgAddr);
+
+ pktInfo->ipi6_addr = myAddr.sin6_addr;
+ pktInfo->ipi6_ifindex = 0; /* 0 = use interface specified in routing table */
+ }
+ header->msg_controllen = totalLen; //write right len
+
+ return header;
+#endif
+ }
+
+ TIoVec CreateIoVec(char* data, const size_t dataSize) {
+ TIoVec result;
+ Zero(result);
+
+ result.iov_base = data;
+ result.iov_len = dataSize;
+
+ return result;
+ }
+
+ TMsgHdr CreateSendMsgHdr(const sockaddr_in6& addr, const TIoVec& iov, void* tosBuffer) {
+ TMsgHdr result;
+ Zero(result);
+
+ result.msg_name = (void*)&addr;
+ result.msg_namelen = sizeof(addr);
+ result.msg_iov = (TIoVec*)&iov;
+ result.msg_iovlen = 1;
+
+ if (tosBuffer) {
+#ifdef _win_
+ result.Tos = *(int*)tosBuffer;
+#else
+ result.msg_control = tosBuffer;
+ result.msg_controllen = TOS_BUFFER_SIZE;
+#endif
+ }
+
+ return result;
+ }
+
+ TMsgHdr CreateRecvMsgHdr(sockaddr_in6* addrBuf, const TIoVec& iov, void* controllBuffer) {
+ TMsgHdr result;
+ Zero(result);
+
+ Zero(*addrBuf);
+ result.msg_name = addrBuf;
+ result.msg_namelen = sizeof(*addrBuf);
+
+ result.msg_iov = (TIoVec*)&iov;
+ result.msg_iovlen = 1;
+#ifndef _win_
+ if (controllBuffer) {
+ memset(controllBuffer, 0, CTRL_BUFFER_SIZE);
+ result.msg_control = controllBuffer;
+ result.msg_controllen = CTRL_BUFFER_SIZE;
+ }
+#endif
+ return result;
+ }
+}
diff --git a/library/cpp/netliba/socket/packet_queue.h b/library/cpp/netliba/socket/packet_queue.h
new file mode 100644
index 0000000000..58a84709c2
--- /dev/null
+++ b/library/cpp/netliba/socket/packet_queue.h
@@ -0,0 +1,97 @@
+#pragma once
+
+#include "udp_recv_packet.h"
+
+#include <library/cpp/threading/chunk_queue/queue.h>
+
+#include <util/network/init.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <util/system/event.h>
+#include <util/system/yassert.h>
+#include <library/cpp/deprecated/atomic/atomic_ops.h>
+#include <utility>
+
+namespace NNetlibaSocket {
+ struct TPacketMeta {
+ sockaddr_in6 RemoteAddr;
+ sockaddr_in6 MyAddr;
+ };
+
+ template <size_t TTNumWriterThreads>
+ class TLockFreePacketQueue {
+ private:
+ enum { MAX_PACKETS_IN_QUEUE = INT_MAX,
+ CMD_QUEUE_RESERVE = 1 << 20,
+ MAX_DATA_IN_QUEUE = 32 << 20 };
+
+ typedef std::pair<TUdpRecvPacket*, TPacketMeta> TPacket;
+ typedef std::conditional_t<TTNumWriterThreads == 1, NThreading::TOneOneQueue<TPacket>, NThreading::TManyOneQueue<TPacket, TTNumWriterThreads>> TImpl;
+
+ mutable TImpl Queue;
+ mutable TSystemEvent QueueEvent;
+
+ mutable TAtomic NumPackets;
+ TAtomic DataSize;
+
+ public:
+ TLockFreePacketQueue()
+ : NumPackets(0)
+ , DataSize(0)
+ {
+ }
+
+ ~TLockFreePacketQueue() {
+ TPacket packet;
+ while (Queue.Dequeue(packet)) {
+ delete packet.first;
+ }
+ }
+
+ bool IsDataPartFull() const {
+ return (AtomicGet(NumPackets) >= MAX_PACKETS_IN_QUEUE || AtomicGet(DataSize) >= MAX_DATA_IN_QUEUE - CMD_QUEUE_RESERVE);
+ }
+
+ bool Push(TUdpRecvPacket* packet, const TPacketMeta& meta) {
+ // simulate OS behavior on buffer overflow - drop packets.
+ // yeah it contains small data race (we can add little bit more packets, but nobody cares)
+ if (AtomicGet(NumPackets) >= MAX_PACKETS_IN_QUEUE || AtomicGet(DataSize) >= MAX_DATA_IN_QUEUE) {
+ return false;
+ }
+ AtomicAdd(NumPackets, 1);
+ AtomicAdd(DataSize, packet->DataSize);
+ Y_ASSERT(packet->DataStart == 0);
+
+ Queue.Enqueue(TPacket(std::make_pair(packet, meta)));
+ QueueEvent.Signal();
+ return true;
+ }
+
+ bool Pop(TUdpRecvPacket** packet, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) {
+ TPacket p;
+ if (!Queue.Dequeue(p)) {
+ QueueEvent.Reset();
+ if (!Queue.Dequeue(p)) {
+ return false;
+ }
+ QueueEvent.Signal();
+ }
+ *packet = p.first;
+ *srcAddr = p.second.RemoteAddr;
+ *dstAddr = p.second.MyAddr;
+
+ AtomicSub(NumPackets, 1);
+ AtomicSub(DataSize, (*packet)->DataSize);
+ Y_ASSERT(AtomicGet(NumPackets) >= 0 && AtomicGet(DataSize) >= 0);
+
+ return true;
+ }
+
+ bool IsEmpty() const {
+ return AtomicAdd(NumPackets, 0) == 0;
+ }
+
+ TSystemEvent& GetEvent() const {
+ return QueueEvent;
+ }
+ };
+}
diff --git a/library/cpp/netliba/socket/protocols.h b/library/cpp/netliba/socket/protocols.h
new file mode 100644
index 0000000000..ec6896ab9b
--- /dev/null
+++ b/library/cpp/netliba/socket/protocols.h
@@ -0,0 +1,48 @@
+#pragma once
+
+namespace NNetlibaSocket {
+ namespace NNetliba_v12 {
+ const ui8 CMD_POS = 11;
+ enum EUdpCmd {
+ CMD_BEGIN = 1,
+
+ DATA = CMD_BEGIN,
+ DATA_SMALL, // no jumbo-packets
+ DO_NOT_USE_1, //just reserved
+ DO_NOT_USE_2, //just reserved
+
+ CANCEL_TRANSFER,
+
+ ACK,
+ ACK_COMPLETE,
+ ACK_CANCELED,
+ ACK_RESEND_NOSHMEM,
+
+ PING,
+ PONG,
+ PONG_IB,
+
+ KILL,
+
+ CMD_END,
+ };
+ }
+
+ namespace NNetliba {
+ const ui8 CMD_POS = 8;
+ enum EUdpCmd {
+ DATA,
+ ACK,
+ ACK_COMPLETE,
+ ACK_RESEND,
+ DATA_SMALL, // no jumbo-packets
+ PING,
+ PONG,
+ DATA_SHMEM,
+ DATA_SMALL_SHMEM,
+ KILL,
+ ACK_RESEND_NOSHMEM,
+ };
+ }
+
+}
diff --git a/library/cpp/netliba/socket/socket.cpp b/library/cpp/netliba/socket/socket.cpp
new file mode 100644
index 0000000000..c10236229b
--- /dev/null
+++ b/library/cpp/netliba/socket/socket.cpp
@@ -0,0 +1,1086 @@
+#include "stdafx.h"
+#include <util/datetime/cputimer.h>
+#include <util/draft/holder_vector.h>
+#include <util/generic/utility.h>
+#include <util/generic/vector.h>
+#include <util/network/init.h>
+#include <util/network/poller.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <util/system/byteorder.h>
+#include <util/system/defaults.h>
+#include <util/system/error.h>
+#include <util/system/event.h>
+#include <util/system/thread.h>
+#include <util/system/yassert.h>
+#include <util/system/rwlock.h>
+#include <util/system/env.h>
+
+#include "socket.h"
+#include "packet_queue.h"
+#include "udp_recv_packet.h"
+
+#include <array>
+#include <stdlib.h>
+
+///////////////////////////////////////////////////////////////////////////////
+
+#ifndef _win_
+#include <netinet/in.h>
+#endif
+
+#ifdef _linux_
+#include <dlfcn.h> // dlsym
+#endif
+
+template <class T>
+static T GetAddressOf(const char* name) {
+#ifdef _linux_
+ if (!GetEnv("DISABLE_MMSG")) {
+ return (T)dlsym(RTLD_DEFAULT, name);
+ }
+#endif
+ Y_UNUSED(name);
+ return nullptr;
+}
+
+///////////////////////////////////////////////////////////////////////////////
+
+namespace NNetlibaSocket {
+ ///////////////////////////////////////////////////////////////////////////////
+
+ struct timespec; // we use it only as NULL pointer
+ typedef int (*TSendMMsgFunc)(SOCKET, TMMsgHdr*, unsigned int, unsigned int);
+ typedef int (*TRecvMMsgFunc)(SOCKET, TMMsgHdr*, unsigned int, unsigned int, timespec*);
+
+ static const TSendMMsgFunc SendMMsgFunc = GetAddressOf<TSendMMsgFunc>("sendmmsg");
+ static const TRecvMMsgFunc RecvMMsgFunc = GetAddressOf<TRecvMMsgFunc>("recvmmsg");
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ bool ReadTos(const TMsgHdr& msgHdr, ui8* tos) {
+#ifdef _win_
+ Y_UNUSED(msgHdr);
+ Y_UNUSED(tos);
+ return false;
+#else
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&msgHdr);
+ if (!cmsg)
+ return false;
+ //Y_ASSERT(cmsg->cmsg_level == IPPROTO_IPV6);
+ //Y_ASSERT(cmsg->cmsg_type == IPV6_TCLASS);
+ if (cmsg->cmsg_len != CMSG_LEN(sizeof(int)))
+ return false;
+ *tos = *(ui8*)CMSG_DATA(cmsg);
+ return true;
+#endif
+ }
+
+ bool ExtractDestinationAddress(TMsgHdr& msgHdr, sockaddr_in6* addrBuf) {
+ Zero(*addrBuf);
+#ifdef _win_
+ Y_UNUSED(msgHdr);
+ Y_UNUSED(addrBuf);
+ return false;
+#else
+ cmsghdr* cmsg;
+ for (cmsg = CMSG_FIRSTHDR(&msgHdr); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msgHdr, cmsg)) {
+ if ((cmsg->cmsg_level == IPPROTO_IPV6) && (cmsg->cmsg_type == IPV6_PKTINFO)) {
+ in6_pktinfo* i = (in6_pktinfo*)CMSG_DATA(cmsg);
+ addrBuf->sin6_addr = i->ipi6_addr;
+ addrBuf->sin6_family = AF_INET6;
+ return true;
+ }
+ }
+ return false;
+#endif
+ }
+
+ // all send and recv methods are thread safe!
+ class TAbstractSocket: public ISocket {
+ private:
+ SOCKET S;
+ mutable TSocketPoller Poller;
+ sockaddr_in6 SelfAddress;
+
+ int SendSysSocketSize;
+ int SendSysSocketSizePrev;
+
+ int CreateSocket(int netPort);
+ int DetectSelfAddress();
+
+ protected:
+ int SetSockOpt(int level, int option_name, const void* option_value, socklen_t option_len);
+
+ int OpenImpl(int port);
+ void CloseImpl();
+
+ void WaitImpl(float timeoutSec) const;
+ void CancelWaitImpl(const sockaddr_in6* address = nullptr); // NULL means "self"
+
+ ssize_t RecvMsgImpl(TMsgHdr* hdr, int flags);
+ TUdpRecvPacket* RecvImpl(TUdpHostRecvBufAlloc* buf, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr);
+ int RecvMMsgImpl(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags, timespec* timeout);
+
+ bool IsFragmentationForbiden();
+ void ForbidFragmentation();
+ void EnableFragmentation();
+
+ //Shared state for setsockopt. Forbid simultaneous transfer while sender asking for specific options (i.e. DONOT_FRAG)
+ TRWMutex Mutex;
+ TAtomic RecvLag = 0;
+
+ public:
+ TAbstractSocket();
+ ~TAbstractSocket() override;
+#ifdef _unix_
+ void Reset(const TAbstractSocket& rhv);
+#endif
+
+ bool IsValid() const override;
+
+ const sockaddr_in6& GetSelfAddress() const override;
+ int GetNetworkOrderPort() const override;
+ int GetPort() const override;
+
+ int GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) override;
+
+ // send all packets to this and only this address by default
+ int Connect(const struct sockaddr* address, socklen_t address_len) override;
+
+ void CancelWaitHost(const sockaddr_in6 addr) override;
+
+ bool IsSendMMsgSupported() const override;
+ int SendMMsg(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) override;
+ ssize_t SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) override;
+ bool IncreaseSendBuff() override;
+ int GetSendSysSocketSize() override;
+ void SetRecvLagTime(NHPTimer::STime time) override;
+ };
+
+ TAbstractSocket::TAbstractSocket()
+ : S(INVALID_SOCKET)
+ , SendSysSocketSize(0)
+ , SendSysSocketSizePrev(0)
+ {
+ Zero(SelfAddress);
+ }
+
+ TAbstractSocket::~TAbstractSocket() {
+ CloseImpl();
+ }
+
+#ifdef _unix_
+ void TAbstractSocket::Reset(const TAbstractSocket& rhv) {
+ Close();
+ S = dup(rhv.S);
+ SelfAddress = rhv.SelfAddress;
+ }
+#endif
+
+ int TAbstractSocket::CreateSocket(int netPort) {
+ if (IsValid()) {
+ Y_ASSERT(0);
+ return 0;
+ }
+ S = socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
+ if (S == INVALID_SOCKET) {
+ return -1;
+ }
+ {
+ int flag = 0;
+ Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_V6ONLY, (const char*)&flag, sizeof(flag)) == 0, "IPV6_V6ONLY failed");
+ }
+ {
+ int flag = 1;
+ Y_VERIFY(SetSockOpt(SOL_SOCKET, SO_REUSEADDR, (const char*)&flag, sizeof(flag)) == 0, "SO_REUSEADDR failed");
+ }
+#if defined(_win_)
+ unsigned long dummy = 1;
+ ioctlsocket(S, FIONBIO, &dummy);
+#else
+ Y_VERIFY(fcntl(S, F_SETFL, O_NONBLOCK) == 0, "fnctl failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
+ Y_VERIFY(fcntl(S, F_SETFD, FD_CLOEXEC) == 0, "fnctl failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
+ {
+ int flag = 1;
+#ifndef IPV6_RECVPKTINFO /* Darwin platforms require this */
+ Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_PKTINFO, (const char*)&flag, sizeof(flag)) == 0, "IPV6_PKTINFO failed");
+#else
+ Y_VERIFY(SetSockOpt(IPPROTO_IPV6, IPV6_RECVPKTINFO, (const char*)&flag, sizeof(flag)) == 0, "IPV6_RECVPKTINFO failed");
+#endif
+ }
+#endif
+
+ Poller.WaitRead(S, nullptr);
+
+ {
+ // bind socket
+ sockaddr_in6 name;
+ Zero(name);
+ name.sin6_family = AF_INET6;
+ name.sin6_addr = in6addr_any;
+ name.sin6_port = netPort;
+ if (bind(S, (sockaddr*)&name, sizeof(name)) != 0) {
+ fprintf(stderr, "netliba_socket could not bind to port %d: %s (errno = %d)\n", InetToHost((ui16)netPort), LastSystemErrorText(), LastSystemError());
+ CloseImpl(); // we call this CloseImpl after Poller initialization
+ return -1;
+ }
+ }
+ //Default behavior is allowing fragmentation (according to netliba v6 behavior)
+ //If we want to sent packet with DF flag we have to use SendMsg()
+ EnableFragmentation();
+
+ {
+ socklen_t sz = sizeof(SendSysSocketSize);
+ if (GetSockOpt(SOL_SOCKET, SO_SNDBUF, &SendSysSocketSize, &sz)) {
+ fprintf(stderr, "Can`t get SO_SNDBUF");
+ }
+ }
+ return 0;
+ }
+
+ bool TAbstractSocket::IsValid() const {
+ return S != INVALID_SOCKET;
+ }
+
+ int TAbstractSocket::DetectSelfAddress() {
+ socklen_t nameLen = sizeof(SelfAddress);
+ if (getsockname(S, (sockaddr*)&SelfAddress, &nameLen) != 0) { // actually we use only sin6_port
+ return -1;
+ }
+ Y_ASSERT(SelfAddress.sin6_family == AF_INET6);
+ SelfAddress.sin6_addr = in6addr_loopback;
+ return 0;
+ }
+
+ const sockaddr_in6& TAbstractSocket::GetSelfAddress() const {
+ return SelfAddress;
+ }
+
+ int TAbstractSocket::GetNetworkOrderPort() const {
+ return SelfAddress.sin6_port;
+ }
+
+ int TAbstractSocket::GetPort() const {
+ return InetToHost((ui16)SelfAddress.sin6_port);
+ }
+
+ int TAbstractSocket::SetSockOpt(int level, int option_name, const void* option_value, socklen_t option_len) {
+ const int rv = setsockopt(S, level, option_name, (const char*)option_value, option_len);
+ Y_VERIFY_DEBUG(rv == 0, "SetSockOpt failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
+ return rv;
+ }
+
+ int TAbstractSocket::GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) {
+ const int rv = getsockopt(S, level, option_name, (char*)option_value, option_len);
+ Y_VERIFY_DEBUG(rv == 0, "GetSockOpt failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
+ return rv;
+ }
+
+ bool TAbstractSocket::IsFragmentationForbiden() {
+#if defined(_win_)
+ DWORD flag = 0;
+ socklen_t sz = sizeof(flag);
+ Y_VERIFY(GetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (char*)&flag, &sz) == 0, "");
+ return flag;
+#elif defined(_linux_)
+ int flag = 0;
+ socklen_t sz = sizeof(flag);
+ Y_VERIFY(GetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (char*)&flag, &sz) == 0, "");
+ return flag == IPV6_PMTUDISC_DO;
+#elif !defined(_darwin_)
+ int flag = 0;
+ socklen_t sz = sizeof(flag);
+ Y_VERIFY(GetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (char*)&flag, &sz) == 0, "");
+ return flag;
+#endif
+ return false;
+ }
+
+ void TAbstractSocket::ForbidFragmentation() {
+ // do not fragment ping packets
+#if defined(_win_)
+ DWORD flag = 1;
+ SetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (const char*)&flag, sizeof(flag));
+#elif defined(_linux_)
+ int flag = IP_PMTUDISC_DO;
+ SetSockOpt(IPPROTO_IP, IP_MTU_DISCOVER, (const char*)&flag, sizeof(flag));
+
+ flag = IPV6_PMTUDISC_DO;
+ SetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (const char*)&flag, sizeof(flag));
+#elif !defined(_darwin_)
+ int flag = 1;
+ //SetSockOpt(IPPROTO_IP, IP_DONTFRAG, (const char*)&flag, sizeof(flag));
+ SetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (const char*)&flag, sizeof(flag));
+#endif
+ }
+
+ void TAbstractSocket::EnableFragmentation() {
+#if defined(_win_)
+ DWORD flag = 0;
+ SetSockOpt(IPPROTO_IP, IP_DONTFRAGMENT, (const char*)&flag, sizeof(flag));
+#elif defined(_linux_)
+ int flag = IP_PMTUDISC_WANT;
+ SetSockOpt(IPPROTO_IP, IP_MTU_DISCOVER, (const char*)&flag, sizeof(flag));
+
+ flag = IPV6_PMTUDISC_WANT;
+ SetSockOpt(IPPROTO_IPV6, IPV6_MTU_DISCOVER, (const char*)&flag, sizeof(flag));
+#elif !defined(_darwin_)
+ int flag = 0;
+ //SetSockOpt(IPPROTO_IP, IP_DONTFRAG, (const char*)&flag, sizeof(flag));
+ SetSockOpt(IPPROTO_IPV6, IPV6_DONTFRAG, (const char*)&flag, sizeof(flag));
+#endif
+ }
+
+ int TAbstractSocket::Connect(const sockaddr* address, socklen_t address_len) {
+ Y_ASSERT(IsValid());
+ return connect(S, address, address_len);
+ }
+
+ void TAbstractSocket::CancelWaitHost(const sockaddr_in6 addr) {
+ CancelWaitImpl(&addr);
+ }
+
+ bool TAbstractSocket::IsSendMMsgSupported() const {
+ return SendMMsgFunc != nullptr;
+ }
+
+ int TAbstractSocket::SendMMsg(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) {
+ Y_ASSERT(IsValid());
+ Y_VERIFY(SendMMsgFunc, "sendmmsg is not supported!");
+ TReadGuard rg(Mutex);
+ static bool checked = 0;
+ Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior");
+ return SendMMsgFunc(S, msgvec, vlen, flags);
+ }
+
+ ssize_t TAbstractSocket::SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) {
+ Y_ASSERT(IsValid());
+#ifdef _win32_
+ static bool checked = 0;
+ Y_VERIFY(hdr->msg_iov->iov_len == 1, "Scatter/gather is currenly not supported on Windows");
+ if (hdr->Tos || frag == FF_DONT_FRAG) {
+ TWriteGuard wg(Mutex);
+ if (frag == FF_DONT_FRAG) {
+ ForbidFragmentation();
+ } else {
+ Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior");
+ }
+ int originalTos;
+ if (hdr->Tos) {
+ socklen_t sz = sizeof(originalTos);
+ Y_VERIFY(GetSockOpt(IPPROTO_IP, IP_TOS, (char*)&originalTos, &sz) == 0, "");
+ Y_VERIFY(SetSockOpt(IPPROTO_IP, IP_TOS, (char*)&hdr->Tos, sizeof(hdr->Tos)) == 0, "");
+ }
+ const ssize_t rv = sendto(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, hdr->msg_namelen);
+ if (hdr->Tos) {
+ Y_VERIFY(SetSockOpt(IPPROTO_IP, IP_TOS, (char*)&originalTos, sizeof(originalTos)) == 0, "");
+ }
+ if (frag == FF_DONT_FRAG) {
+ EnableFragmentation();
+ }
+ return rv;
+ }
+ TReadGuard rg(Mutex);
+ Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior");
+ return sendto(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, hdr->msg_namelen);
+#else
+ if (frag == FF_DONT_FRAG) {
+ TWriteGuard wg(Mutex);
+ ForbidFragmentation();
+ const ssize_t rv = sendmsg(S, hdr, flags);
+ EnableFragmentation();
+ return rv;
+ }
+
+ TReadGuard rg(Mutex);
+#ifndef _darwin_
+ static bool checked = 0;
+ Y_VERIFY(checked || (checked = !IsFragmentationForbiden()), "Send methods of this class expect default EnableFragmentation behavior");
+#endif
+ return sendmsg(S, hdr, flags);
+#endif
+ }
+
+ bool TAbstractSocket::IncreaseSendBuff() {
+ int buffSize;
+ socklen_t sz = sizeof(buffSize);
+ if (GetSockOpt(SOL_SOCKET, SO_SNDBUF, &buffSize, &sz)) {
+ return false;
+ }
+ // worst case: 200000 pps * 8k * 0.01sec = 16Mb so 32Mb hard limit is reasonable value
+ if (buffSize < 0 || buffSize > (1 << 25)) {
+ fprintf(stderr, "GetSockOpt returns wrong or too big value for SO_SNDBUF: %d\n", buffSize);
+ return false;
+ }
+ //linux returns the doubled value. man 7 socket:
+ //
+ // SO_SNDBUF
+ // Sets or gets the maximum socket send buffer in bytes. The ker-
+ // nel doubles this value (to allow space for bookkeeping overhead)
+ // when it is set using setsockopt(), and this doubled value is
+ // returned by getsockopt(). The default value is set by the
+ // wmem_default sysctl and the maximum allowed value is set by the
+ // wmem_max sysctl. The minimum (doubled) value for this option is
+ // 2048.
+ //
+#ifndef _linux_
+ buffSize += buffSize;
+#endif
+
+ // false if previous value was less than current value.
+ // It means setsockopt was not successful. (for example: system limits)
+ // we will try to set it again but return false
+ const bool rv = !(buffSize <= SendSysSocketSizePrev);
+ if (SetSockOpt(SOL_SOCKET, SO_SNDBUF, &buffSize, sz) == 0) {
+ SendSysSocketSize = buffSize;
+ SendSysSocketSizePrev = buffSize;
+ return rv;
+ }
+ return false;
+ }
+
+ int TAbstractSocket::GetSendSysSocketSize() {
+ return SendSysSocketSize;
+ }
+
+ void TAbstractSocket::SetRecvLagTime(NHPTimer::STime time) {
+ AtomicSet(RecvLag, time);
+ }
+
+ int TAbstractSocket::OpenImpl(int port) {
+ Y_ASSERT(!IsValid());
+ const int netPort = port ? htons((u_short)port) : 0;
+
+#ifdef _freebsd_
+ // alternative OS
+ if (netPort == 0) {
+ static ui64 pp = GetCycleCount();
+ for (int attempt = 0; attempt < 100; ++attempt) {
+ const int tryPort = htons((pp & 0x3fff) + 0xc000);
+ ++pp;
+ if (CreateSocket(tryPort) != 0) {
+ Y_ASSERT(!IsValid());
+ continue;
+ }
+
+ if (DetectSelfAddress() != 0 || tryPort != SelfAddress.sin6_port) {
+ // FreeBSD suck!
+ CloseImpl();
+ Y_ASSERT(!IsValid());
+ continue;
+ }
+ break;
+ }
+ if (!IsValid()) {
+ return -1;
+ }
+ } else {
+ if (CreateSocket(netPort) != 0) {
+ Y_ASSERT(!IsValid());
+ return -1;
+ }
+ }
+#else
+ // regular OS
+ if (CreateSocket(netPort) != 0) {
+ Y_ASSERT(!IsValid());
+ return -1;
+ }
+#endif
+
+ if (IsValid() && DetectSelfAddress() != 0) {
+ CloseImpl();
+ Y_ASSERT(!IsValid());
+ return -1;
+ }
+
+ Y_ASSERT(IsValid());
+ return 0;
+ }
+
+ void TAbstractSocket::CloseImpl() {
+ if (IsValid()) {
+ Poller.Unwait(S);
+ Y_VERIFY(closesocket(S) == 0, "closesocket failed: %s (errno = %d)", LastSystemErrorText(), LastSystemError());
+ }
+ S = INVALID_SOCKET;
+ }
+
+ void TAbstractSocket::WaitImpl(float timeoutSec) const {
+ Y_VERIFY(IsValid(), "something went wrong");
+ Poller.WaitT(TDuration::Seconds(timeoutSec));
+ }
+
+ void TAbstractSocket::CancelWaitImpl(const sockaddr_in6* address) {
+ Y_ASSERT(IsValid());
+
+ // darwin ignores packets with msg_iovlen == 0, also windows implementation uses sendto of first iovec.
+ TIoVec v = CreateIoVec(nullptr, 0);
+ TMsgHdr hdr = CreateSendMsgHdr((address ? *address : SelfAddress), v, nullptr);
+
+ // send self fake packet
+ TAbstractSocket::SendMsg(&hdr, 0, FF_ALLOW_FRAG);
+ }
+
+ ssize_t TAbstractSocket::RecvMsgImpl(TMsgHdr* hdr, int flags) {
+ Y_ASSERT(IsValid());
+
+#ifdef _win32_
+ Y_VERIFY(hdr->msg_iov->iov_len == 1, "Scatter/gather is currenly not supported on Windows");
+ return recvfrom(S, hdr->msg_iov->iov_base, hdr->msg_iov->iov_len, flags, (sockaddr*)hdr->msg_name, &hdr->msg_namelen);
+#else
+ return recvmsg(S, hdr, flags);
+#endif
+ }
+
+ TUdpRecvPacket* TAbstractSocket::RecvImpl(TUdpHostRecvBufAlloc* buf, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) {
+ Y_ASSERT(IsValid());
+
+ const TIoVec iov = CreateIoVec(buf->GetDataPtr(), buf->GetBufSize());
+ char controllBuffer[CTRL_BUFFER_SIZE]; //used to get dst address from socket
+ TMsgHdr hdr = CreateRecvMsgHdr(srcAddr, iov, controllBuffer);
+
+ const ssize_t rv = TAbstractSocket::RecvMsgImpl(&hdr, 0);
+ if (rv < 0) {
+ Y_ASSERT(LastSystemError() == EAGAIN || LastSystemError() == EWOULDBLOCK);
+ return nullptr;
+ }
+ if (dstAddr && !ExtractDestinationAddress(hdr, dstAddr)) {
+ //fprintf(stderr, "can`t get destination ip\n");
+ }
+
+ // we extract packet and allocate new buffer only if packet arrived
+ TUdpRecvPacket* result = buf->ExtractPacket();
+ result->DataStart = 0;
+ result->DataSize = (int)rv;
+ return result;
+ }
+
+ // thread-safe
+ int TAbstractSocket::RecvMMsgImpl(TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags, timespec* timeout) {
+ Y_ASSERT(IsValid());
+ Y_VERIFY(RecvMMsgFunc, "recvmmsg is not supported!");
+ return RecvMMsgFunc(S, msgvec, vlen, flags, timeout);
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ class TSocket: public TAbstractSocket {
+ public:
+ int Open(int port) override;
+ void Close() override;
+
+ void Wait(float timeoutSec, int netlibaVersion) const override;
+ void CancelWait(int netlibaVersion) override;
+
+ bool IsRecvMsgSupported() const override;
+ ssize_t RecvMsg(TMsgHdr* hdr, int flags) override;
+ TUdpRecvPacket* Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) override;
+
+ private:
+ TUdpHostRecvBufAlloc RecvBuf;
+ };
+
+ int TSocket::Open(int port) {
+ return OpenImpl(port);
+ }
+
+ void TSocket::Close() {
+ CloseImpl();
+ }
+
+ void TSocket::Wait(float timeoutSec, int netlibaVersion) const {
+ Y_UNUSED(netlibaVersion);
+ WaitImpl(timeoutSec);
+ }
+
+ void TSocket::CancelWait(int netlibaVersion) {
+ Y_UNUSED(netlibaVersion);
+ CancelWaitImpl();
+ }
+
+ bool TSocket::IsRecvMsgSupported() const {
+ return true;
+ }
+
+ ssize_t TSocket::RecvMsg(TMsgHdr* hdr, int flags) {
+ return RecvMsgImpl(hdr, flags);
+ }
+
+ TUdpRecvPacket* TSocket::Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) {
+ Y_UNUSED(netlibaVersion);
+ return RecvImpl(&RecvBuf, srcAddr, dstAddr);
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ class TTryToRecvMMsgSocket: public TAbstractSocket {
+ private:
+ THolderVector<TUdpHostRecvBufAlloc> RecvPackets;
+ TVector<sockaddr_in6> RecvPacketsSrcAddresses;
+ TVector<TIoVec> RecvPacketsIoVecs;
+ size_t RecvPacketsBegin; // first non returned to user
+ size_t RecvPacketsHeadersEnd; // next after last one with data
+ TVector<TMMsgHdr> RecvPacketsHeaders;
+ TVector<std::array<char, CTRL_BUFFER_SIZE>> RecvPacketsCtrlBuffers;
+
+ int FillRecvBuffers();
+
+ public:
+ static bool IsRecvMMsgSupported();
+
+ // Tests showed best performance on queue size 128 (+7%).
+ // If memory is limited you can use 12 - it gives +4%.
+ // Do not use lower values - for example recvmmsg with 1 element is 3% slower that recvmsg!
+ // (tested with junk/f0b0s/neTBasicSocket_queue_test).
+ TTryToRecvMMsgSocket(const size_t recvQueueSize = 128);
+ ~TTryToRecvMMsgSocket() override;
+
+ int Open(int port) override;
+ void Close() override;
+
+ void Wait(float timeoutSec, int netlibaVersion) const override;
+ void CancelWait(int netlibaVersion) override;
+
+ bool IsRecvMsgSupported() const override {
+ return false;
+ }
+ ssize_t RecvMsg(TMsgHdr* hdr, int flags) override {
+ Y_UNUSED(hdr);
+ Y_UNUSED(flags);
+ Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TRecvMMsgSocket implementation must use memcpy which is suboptimal and thus forbidden!");
+ }
+ TUdpRecvPacket* Recv(sockaddr_in6* addr, sockaddr_in6* dstAddr, int netlibaVersion) override;
+ };
+
+ TTryToRecvMMsgSocket::TTryToRecvMMsgSocket(const size_t recvQueueSize)
+ : RecvPacketsBegin(0)
+ , RecvPacketsHeadersEnd(0)
+ {
+ // recvmmsg is not supported - will act like TSocket,
+ // we can't just VERIFY - TTryToRecvMMsgSocket is used as base class for TDualStackSocket.
+ if (!IsRecvMMsgSupported()) {
+ RecvPackets.reserve(1);
+ RecvPackets.PushBack(new TUdpHostRecvBufAlloc);
+ return;
+ }
+
+ RecvPackets.reserve(recvQueueSize);
+ for (size_t i = 0; i != recvQueueSize; ++i) {
+ RecvPackets.PushBack(new TUdpHostRecvBufAlloc);
+ }
+
+ RecvPacketsSrcAddresses.resize(recvQueueSize);
+ RecvPacketsIoVecs.resize(recvQueueSize);
+ RecvPacketsHeaders.resize(recvQueueSize);
+ RecvPacketsCtrlBuffers.resize(recvQueueSize);
+
+ for (size_t i = 0; i != recvQueueSize; ++i) {
+ TMMsgHdr& mhdr = RecvPacketsHeaders[i];
+ Zero(mhdr);
+
+ RecvPacketsIoVecs[i] = CreateIoVec(RecvPackets[i]->GetDataPtr(), RecvPackets[i]->GetBufSize());
+ char* buf = RecvPacketsCtrlBuffers[i].data();
+ memset(buf, 0, CTRL_BUFFER_SIZE);
+ mhdr.msg_hdr = CreateRecvMsgHdr(&RecvPacketsSrcAddresses[i], RecvPacketsIoVecs[i], buf);
+ }
+ }
+
+ TTryToRecvMMsgSocket::~TTryToRecvMMsgSocket() {
+ Close();
+ }
+
+ int TTryToRecvMMsgSocket::Open(int port) {
+ return OpenImpl(port);
+ }
+
+ void TTryToRecvMMsgSocket::Close() {
+ CloseImpl();
+ }
+
+ void TTryToRecvMMsgSocket::Wait(float timeoutSec, int netlibaVersion) const {
+ Y_UNUSED(netlibaVersion);
+ Y_ASSERT(RecvPacketsBegin == RecvPacketsHeadersEnd || IsRecvMMsgSupported());
+ if (RecvPacketsBegin == RecvPacketsHeadersEnd) {
+ WaitImpl(timeoutSec);
+ }
+ }
+
+ void TTryToRecvMMsgSocket::CancelWait(int netlibaVersion) {
+ Y_UNUSED(netlibaVersion);
+ CancelWaitImpl();
+ }
+
+ bool TTryToRecvMMsgSocket::IsRecvMMsgSupported() {
+ return RecvMMsgFunc != nullptr;
+ }
+
+ int TTryToRecvMMsgSocket::FillRecvBuffers() {
+ Y_ASSERT(IsRecvMMsgSupported());
+ Y_ASSERT(RecvPacketsBegin <= RecvPacketsHeadersEnd);
+ if (RecvPacketsBegin < RecvPacketsHeadersEnd) {
+ return RecvPacketsHeadersEnd - RecvPacketsBegin;
+ }
+
+ // no packets left from last recvmmsg call
+ for (size_t i = 0; i != RecvPacketsHeadersEnd; ++i) { // reinit only used by last recvmmsg call headers
+ RecvPacketsIoVecs[i] = CreateIoVec(RecvPackets[i]->GetDataPtr(), RecvPackets[i]->GetBufSize());
+ }
+ RecvPacketsBegin = RecvPacketsHeadersEnd = 0;
+
+ const int r = RecvMMsgImpl(&RecvPacketsHeaders[0], (unsigned int)RecvPacketsHeaders.size(), 0, nullptr);
+ if (r >= 0) {
+ RecvPacketsHeadersEnd = r;
+ } else {
+ Y_ASSERT(LastSystemError() == EAGAIN || LastSystemError() == EWOULDBLOCK);
+ }
+ return r;
+ }
+
+ // not thread-safe
+ TUdpRecvPacket* TTryToRecvMMsgSocket::Recv(sockaddr_in6* fromAddress, sockaddr_in6* dstAddr, int) {
+ // act like TSocket
+ if (!IsRecvMMsgSupported()) {
+ return RecvImpl(RecvPackets[0], fromAddress, dstAddr);
+ }
+
+ if (FillRecvBuffers() <= 0) {
+ return nullptr;
+ }
+
+ TUdpRecvPacket* result = RecvPackets[RecvPacketsBegin]->ExtractPacket();
+ TMMsgHdr& mmsgHdr = RecvPacketsHeaders[RecvPacketsBegin];
+ result->DataSize = (ssize_t)mmsgHdr.msg_len;
+ if (dstAddr && !ExtractDestinationAddress(mmsgHdr.msg_hdr, dstAddr)) {
+ // fprintf(stderr, "can`t get destination ip\n");
+ }
+ *fromAddress = RecvPacketsSrcAddresses[RecvPacketsBegin];
+ //we must clean ctrlbuffer to be able to use it later
+#ifndef _win_
+ memset(mmsgHdr.msg_hdr.msg_control, 0, CTRL_BUFFER_SIZE);
+ mmsgHdr.msg_hdr.msg_controllen = CTRL_BUFFER_SIZE;
+#endif
+ RecvPacketsBegin++;
+
+ return result;
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ /* TODO: too slow, needs to be optimized
+template<size_t TTNumRecvThreads>
+class TMTRecvSocket: public TAbstractSocket
+{
+private:
+ typedef TLockFreePacketQueue<TTNumRecvThreads> TPacketQueue;
+
+ static void* RecvThreadFunc(void* that)
+ {
+ static_cast<TMTRecvSocket*>(that)->RecvLoop();
+ return NULL;
+ }
+
+ void RecvLoop()
+ {
+ TBestUnixRecvSocket impl;
+ impl.Reset(*this);
+
+ while (AtomicAdd(NumThreadsToDie, 0) == -1) {
+ sockaddr_in6 addr;
+ TUdpRecvPacket* packet = impl.Recv(&addr, NETLIBA_ANY_VERSION);
+ if (!packet) {
+            impl.Wait(0.0001, NETLIBA_ANY_VERSION); // such a small timeout because we can't guarantee that 1 thread won't get all packets
+ continue;
+ }
+ Queue.Push(packet, addr);
+ }
+
+ if (AtomicDecrement(NumThreadsToDie)) {
+ impl.CancelWait(NETLIBA_ANY_VERSION);
+ } else {
+ AllThreadsAreDead.Signal();
+ }
+ }
+
+ THolderVector<TThread> RecvThreads;
+ TAtomic NumThreadsToDie;
+ TSystemEvent AllThreadsAreDead;
+
+ TPacketQueue Queue;
+
+public:
+ TMTRecvSocket()
+ : NumThreadsToDie(-1) {}
+
+ ~TMTRecvSocket()
+ {
+ Close();
+ }
+
+ int Open(int port)
+ {
+ if (OpenImpl(port) != 0) {
+ Y_ASSERT(!IsValid());
+ return -1;
+ }
+
+ NumThreadsToDie = -1;
+ RecvThreads.reserve(TTNumRecvThreads);
+ for (size_t i = 0; i != TTNumRecvThreads; ++i) {
+ RecvThreads.PushBack(new TThread(TThread::TParams(RecvThreadFunc, this).SetName("nl12_recv_skt")));
+ RecvThreads.back()->Start();
+ RecvThreads.back()->Detach();
+ }
+ return 0;
+ }
+
+ void Close()
+ {
+ if (!IsValid()) {
+ return;
+ }
+
+ AtomicSwap(&NumThreadsToDie, (int)RecvThreads.size());
+ CancelWaitImpl();
+ Y_VERIFY(AllThreadsAreDead.WaitT(TDuration::Seconds(30)), "TMTRecvSocket destruction failed");
+
+ CloseImpl();
+ }
+
+ void Wait(float timeoutSec, int netlibaVersion) const
+ {
+ Y_UNUSED(netlibaVersion);
+ Queue.GetEvent().WaitT(TDuration::Seconds(timeoutSec));
+ }
+ void CancelWait(int netlibaVersion)
+ {
+ Y_UNUSED(netlibaVersion);
+ Queue.GetEvent().Signal();
+ }
+
+ TUdpRecvPacket* Recv(sockaddr_in6 *addr, int netlibaVersion)
+ {
+ Y_UNUSED(netlibaVersion);
+ TUdpRecvPacket* result;
+ if (!Queue.Pop(&result, addr)) {
+ return NULL;
+ }
+ return result;
+ }
+
+ bool IsRecvMsgSupported() const { return false; }
+ ssize_t RecvMsg(TMsgHdr* hdr, int flags) { Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TMTRecvSocket implementation must use memcpy which is suboptimal and thus forbidden!"); }
+};
+*/
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ // Send.*, Recv, Wait and CancelWait are thread-safe.
+ class TDualStackSocket: public TTryToRecvMMsgSocket {
+ private:
+ typedef TTryToRecvMMsgSocket TBase;
+ typedef TLockFreePacketQueue<1> TPacketQueue;
+
+ static void* RecvThreadFunc(void* that);
+ void RecvLoop();
+
+ struct TFilteredPacketQueue {
+ enum EPushResult {
+ PR_FULL = 0,
+ PR_OK = 1,
+ PR_FILTERED = 2
+ };
+ const ui8 F1;
+ const ui8 F2;
+ const ui8 CmdPos;
+ TFilteredPacketQueue(ui8 f1, ui8 f2, ui8 cmdPos)
+ : F1(f1)
+ , F2(f2)
+ , CmdPos(cmdPos)
+ {
+ }
+ bool Pop(TUdpRecvPacket** packet, sockaddr_in6* srcAddr, sockaddr_in6* dstAddr) {
+ return Queue.Pop(packet, srcAddr, dstAddr);
+ }
+ ui8 Push(TUdpRecvPacket* packet, const TPacketMeta& meta) {
+ if (Queue.IsDataPartFull()) {
+ const ui8 cmd = packet->Data.get()[CmdPos];
+ if (cmd == F1 || cmd == F2)
+ return PR_FILTERED;
+ }
+ return Queue.Push(packet, meta); //false - PR_FULL, true - PR_OK
+ }
+ TPacketQueue Queue;
+ };
+
+ TFilteredPacketQueue& GetRecvQueue(int netlibaVersion) const;
+ TSystemEvent& GetQueueEvent(const TFilteredPacketQueue& queue) const;
+
+ TThread RecvThread;
+ TAtomic ShouldDie;
+ TSystemEvent DieEvent;
+
+ mutable TFilteredPacketQueue RecvQueue6;
+ mutable TFilteredPacketQueue RecvQueue12;
+
+ public:
+ TDualStackSocket();
+ ~TDualStackSocket() override;
+
+ int Open(int port) override;
+ void Close() override;
+
+ void Wait(float timeoutSec, int netlibaVersion) const override;
+ void CancelWait(int netlibaVersion) override;
+
+ bool IsRecvMsgSupported() const override {
+ return false;
+ }
+ ssize_t RecvMsg(TMsgHdr* hdr, int flags) override {
+ Y_UNUSED(hdr);
+ Y_UNUSED(flags);
+ Y_VERIFY(false, "Use TBasicSocket for RecvMsg call! TDualStackSocket implementation must use memcpy which is suboptimal and thus forbidden!");
+ }
+
+ TUdpRecvPacket* Recv(sockaddr_in6* addr, sockaddr_in6* dstAddr, int netlibaVersion) override;
+ };
+
+ TDualStackSocket::TDualStackSocket()
+ : RecvThread(TThread::TParams(RecvThreadFunc, this).SetName("nl12_dual_stack"))
+ , ShouldDie(0)
+ , RecvQueue6(NNetliba::DATA, NNetliba::DATA_SMALL, NNetliba::CMD_POS)
+ , RecvQueue12(NNetliba_v12::DATA, NNetliba_v12::DATA_SMALL, NNetliba_v12::CMD_POS)
+ {
+ }
+
+ // virtual functions don't work in dtors!
+ TDualStackSocket::~TDualStackSocket() {
+ Close();
+
+ sockaddr_in6 srcAdd;
+ sockaddr_in6 dstAddr;
+ TUdpRecvPacket* ptr = nullptr;
+
+ while (GetRecvQueue(NETLIBA_ANY_VERSION).Pop(&ptr, &srcAdd, &dstAddr)) {
+ delete ptr;
+ }
+ while (GetRecvQueue(NETLIBA_V12_VERSION).Pop(&ptr, &srcAdd, &dstAddr)) {
+ delete ptr;
+ }
+ }
+
+ int TDualStackSocket::Open(int port) {
+ if (TBase::Open(port) != 0) {
+ Y_ASSERT(!IsValid());
+ return -1;
+ }
+
+ AtomicSet(ShouldDie, 0);
+ DieEvent.Reset();
+ RecvThread.Start();
+ RecvThread.Detach();
+ return 0;
+ }
+
+ void TDualStackSocket::Close() {
+ if (!IsValid()) {
+ return;
+ }
+
+ AtomicSwap(&ShouldDie, 1);
+ CancelWaitImpl();
+ Y_VERIFY(DieEvent.WaitT(TDuration::Seconds(30)), "TDualStackSocket::Close failed");
+
+ TBase::Close();
+ }
+
+ TDualStackSocket::TFilteredPacketQueue& TDualStackSocket::GetRecvQueue(int netlibaVersion) const {
+ return netlibaVersion == NETLIBA_V12_VERSION ? RecvQueue12 : RecvQueue6;
+ }
+
+ TSystemEvent& TDualStackSocket::GetQueueEvent(const TFilteredPacketQueue& queue) const {
+ return queue.Queue.GetEvent();
+ }
+
+ void* TDualStackSocket::RecvThreadFunc(void* that) {
+ SetHighestThreadPriority();
+ static_cast<TDualStackSocket*>(that)->RecvLoop();
+ return nullptr;
+ }
+
+ void TDualStackSocket::RecvLoop() {
+ for (;;) {
+ TUdpRecvPacket* p = nullptr;
+ sockaddr_in6 srcAddr;
+ sockaddr_in6 dstAddr;
+ while (AtomicAdd(ShouldDie, 0) == 0 && (p = TBase::Recv(&srcAddr, &dstAddr, NETLIBA_ANY_VERSION))) {
+ Y_ASSERT(p->DataStart == 0);
+ if (p->DataSize < 12) {
+ continue;
+ }
+
+ TFilteredPacketQueue& q = GetRecvQueue(p->Data.get()[8]);
+ const ui8 res = q.Push(p, {srcAddr, dstAddr});
+ if (res == TFilteredPacketQueue::PR_OK) {
+ GetQueueEvent(q).Signal();
+ } else {
+ // simulate OS behavior on buffer overflow - drop packets.
+ const NHPTimer::STime time = AtomicGet(RecvLag);
+ const float sec = NHPTimer::GetSeconds(time);
+ fprintf(stderr, "TDualStackSocket::RecvLoop netliba v%d queue overflow, recv lag: %f sec, dropping packet, res: %u\n",
+ &q == &RecvQueue12 ? 12 : 6, sec, res);
+ delete p;
+ }
+ }
+
+ if (AtomicAdd(ShouldDie, 0)) {
+ DieEvent.Signal();
+ return;
+ }
+
+ TBase::Wait(0.1f, NETLIBA_ANY_VERSION);
+ }
+ }
+
+ void TDualStackSocket::Wait(float timeoutSec, int netlibaVersion) const {
+ TFilteredPacketQueue& q = GetRecvQueue(netlibaVersion);
+ if (q.Queue.IsEmpty()) {
+ GetQueueEvent(q).Reset();
+ if (q.Queue.IsEmpty()) {
+ GetQueueEvent(q).WaitT(TDuration::Seconds(timeoutSec));
+ }
+ }
+ }
+
+ void TDualStackSocket::CancelWait(int netlibaVersion) {
+ GetQueueEvent(GetRecvQueue(netlibaVersion)).Signal();
+ }
+
+ // thread-safe
+ TUdpRecvPacket* TDualStackSocket::Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion) {
+ TUdpRecvPacket* result = nullptr;
+ if (!GetRecvQueue(netlibaVersion).Pop(&result, srcAddr, dstAddr)) {
+ return nullptr;
+ }
+ return result;
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ TIntrusivePtr<ISocket> CreateSocket() {
+ return new TSocket();
+ }
+
+ TIntrusivePtr<ISocket> CreateDualStackSocket() {
+ return new TDualStackSocket();
+ }
+
+ TIntrusivePtr<ISocket> CreateBestRecvSocket() {
+ // TSocket is faster than TRecvMMsgFunc in case of unsupported recvmmsg
+ if (!TTryToRecvMMsgSocket::IsRecvMMsgSupported()) {
+ return new TSocket();
+ }
+ return new TTryToRecvMMsgSocket();
+ }
+
+}
diff --git a/library/cpp/netliba/socket/socket.h b/library/cpp/netliba/socket/socket.h
new file mode 100644
index 0000000000..c1da3c145f
--- /dev/null
+++ b/library/cpp/netliba/socket/socket.h
@@ -0,0 +1,126 @@
+#pragma once
+
+#include <util/system/platform.h>
+#include <util/generic/noncopyable.h>
+#include <util/generic/ptr.h>
+#include <util/network/init.h>
+#include <util/system/defaults.h>
+#include <util/system/hp_timer.h>
+#include "udp_recv_packet.h"
+#include "protocols.h"
+
+#include <sys/uio.h>
+
+namespace NNetlibaSocket {
+ typedef iovec TIoVec;
+
+#ifdef _win32_
+ struct TMsgHdr {
+ void* msg_name; /* optional address */
+ int msg_namelen; /* size of address */
+ TIoVec* msg_iov; /* scatter/gather array */
+ int msg_iovlen; /* # elements in msg_iov */
+
+ int Tos; // netlib_socket extension
+ };
+#else
+#include <sys/socket.h>
+ typedef msghdr TMsgHdr;
+#endif
+
+ // equal to glibc 2.14 mmsghdr definition, defined for windows and darwin compatibility
+ struct TMMsgHdr {
+ TMsgHdr msg_hdr;
+ unsigned int msg_len;
+ };
+
+#if defined(_linux_)
+#include <linux/version.h>
+#include <features.h>
+// sendmmsg was added in glibc 2.14 and linux 3.0
+#if __GLIBC__ >= 2 && __GLIBC_MINOR__ >= 14 && LINUX_VERSION_CODE >= KERNEL_VERSION(3, 0, 0)
+#include <sys/socket.h> // sendmmsg
+ static_assert(sizeof(TMMsgHdr) == sizeof(mmsghdr), "expect sizeof(TMMsgHdr) == sizeof(mmsghdr)");
+#endif
+#endif
+
+#ifdef _win32_
+ const size_t TOS_BUFFER_SIZE = sizeof(int);
+ const size_t CTRL_BUFFER_SIZE = 32;
+#else
+#if defined(_darwin_)
+#define Y_DARWIN_ALIGN32(p) ((__darwin_size_t)((__darwin_size_t)(p) + __DARWIN_ALIGNBYTES32) & ~__DARWIN_ALIGNBYTES32)
+#define Y_CMSG_SPACE(l) (Y_DARWIN_ALIGN32(sizeof(struct cmsghdr)) + Y_DARWIN_ALIGN32(l))
+#else
+#define Y_CMSG_SPACE(l) CMSG_SPACE(l)
+#endif
+
+ constexpr size_t TOS_BUFFER_SIZE = Y_CMSG_SPACE(sizeof(int));
+ constexpr size_t CTRL_BUFFER_SIZE = Y_CMSG_SPACE(sizeof(int)) + Y_CMSG_SPACE(sizeof(struct in6_pktinfo));
+#endif
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Warning: every variable (tosBuffer, data, addr, iov) passed and returned from these functions must exist until actual send!!!
+ void* CreateTos(const ui8 tos, void* tosBuffer);
+ TIoVec CreateIoVec(char* data, const size_t dataSize);
+ TMsgHdr CreateSendMsgHdr(const sockaddr_in6& addr, const TIoVec& iov, void* tosBuffer);
+ TMsgHdr CreateRecvMsgHdr(sockaddr_in6* addrBuf, const TIoVec& iov, void* ctrlBuffer = nullptr);
+ TMsgHdr* AddSockAuxData(TMsgHdr* header, const ui8 tos, const sockaddr_in6& addr, void* buffer, size_t bufferSize);
+ ///////////////////////////////////////////////////////////////////////////////
+    //returns false if TOS wasn't read and does not touch *tos
+ bool ReadTos(const TMsgHdr& msgHdr, ui8* tos);
+ bool ExtractDestinationAddress(TMsgHdr& msgHdr, sockaddr_in6* addrBuf);
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+    // currently netliba v6 version id is any number which is not equal to NETLIBA_V12_VERSION
+ enum { NETLIBA_ANY_VERSION = -1,
+ NETLIBA_V12_VERSION = 112 };
+
+ enum EFragFlag {
+ FF_ALLOW_FRAG,
+ FF_DONT_FRAG
+ };
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ class ISocket: public TNonCopyable, public TThrRefBase {
+ public:
+ ~ISocket() override {
+ }
+
+ virtual int Open(int port) = 0;
+ virtual void Close() = 0;
+ virtual bool IsValid() const = 0;
+
+ virtual const sockaddr_in6& GetSelfAddress() const = 0;
+ virtual int GetNetworkOrderPort() const = 0;
+ virtual int GetPort() const = 0;
+
+ virtual int GetSockOpt(int level, int option_name, void* option_value, socklen_t* option_len) = 0;
+
+ // send all packets to this and only this address by default
+ virtual int Connect(const struct sockaddr* address, socklen_t address_len) = 0;
+
+ virtual void Wait(float timeoutSec, int netlibaVersion = NETLIBA_ANY_VERSION) const = 0;
+ virtual void CancelWait(int netlibaVersion = NETLIBA_ANY_VERSION) = 0;
+ virtual void CancelWaitHost(const sockaddr_in6 address) = 0;
+
+ virtual bool IsSendMMsgSupported() const = 0;
+ virtual int SendMMsg(struct TMMsgHdr* msgvec, unsigned int vlen, unsigned int flags) = 0;
+ virtual ssize_t SendMsg(const TMsgHdr* hdr, int flags, const EFragFlag frag) = 0;
+
+ virtual bool IsRecvMsgSupported() const = 0;
+ virtual ssize_t RecvMsg(TMsgHdr* hdr, int flags) = 0;
+ virtual TUdpRecvPacket* Recv(sockaddr_in6* srcAddr, sockaddr_in6* dstAddr, int netlibaVersion = NETLIBA_ANY_VERSION) = 0;
+ virtual bool IncreaseSendBuff() = 0;
+ virtual int GetSendSysSocketSize() = 0;
+ virtual void SetRecvLagTime(NHPTimer::STime time) = 0;
+ };
+
+ TIntrusivePtr<ISocket> CreateSocket(); // not thread safe!
+ TIntrusivePtr<ISocket> CreateDualStackSocket(); // has thread safe send/recv methods
+
+ // this function was added mostly for testing
+ TIntrusivePtr<ISocket> CreateBestRecvSocket();
+}
diff --git a/library/cpp/netliba/socket/stdafx.cpp b/library/cpp/netliba/socket/stdafx.cpp
new file mode 100644
index 0000000000..fd4f341c7b
--- /dev/null
+++ b/library/cpp/netliba/socket/stdafx.cpp
@@ -0,0 +1 @@
+#include "stdafx.h"
diff --git a/library/cpp/netliba/socket/stdafx.h b/library/cpp/netliba/socket/stdafx.h
new file mode 100644
index 0000000000..7d99e5dc10
--- /dev/null
+++ b/library/cpp/netliba/socket/stdafx.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <util/system/platform.h>
+#if defined(_darwin_)
+#define __APPLE_USE_RFC_2292
+#endif
+
+#include <util/system/compat.h>
+#include <util/network/init.h>
+#if defined(_unix_)
+#include <netdb.h>
+#include <fcntl.h>
+#elif defined(_win_)
+#include <winsock2.h>
+using socklen_t = int;
+#endif
diff --git a/library/cpp/netliba/socket/udp_recv_packet.h b/library/cpp/netliba/socket/udp_recv_packet.h
new file mode 100644
index 0000000000..a2777fbcbf
--- /dev/null
+++ b/library/cpp/netliba/socket/udp_recv_packet.h
@@ -0,0 +1,79 @@
+#pragma once
+
+#include <util/generic/noncopyable.h>
+#include <util/system/defaults.h>
+
+#include <memory>
+#include "allocator.h"
+
+namespace NNetlibaSocket {
+ enum { UDP_MAX_PACKET_SIZE = 8900 };
+
+ class TUdpHostRecvBufAlloc;
+ struct TUdpRecvPacket: public TWithCustomAllocator {
+ friend class TUdpHostRecvBufAlloc;
+ int DataStart = 0, DataSize = 0;
+ std::shared_ptr<char> Data;
+
+ private:
+ int ArraySize_ = 0;
+ };
+
+ ///////////////////////////////////////////////////////////////////////////////
+
+ class TUdpHostRecvBufAlloc: public TNonCopyable {
+ private:
+ mutable TUdpRecvPacket* RecvPktBuf;
+
+ static TUdpRecvPacket* Alloc() {
+ return new TUdpRecvPacket();
+ }
+
+ public:
+ static TUdpRecvPacket* Create(const int dataSize) {
+ TUdpRecvPacket* result = Alloc();
+ result->Data.reset(TCustomAllocator<char>().allocate(dataSize), [=](char* p) { TCustomAllocator<char>().deallocate(p, dataSize); }, TCustomAllocator<char>());
+ result->ArraySize_ = dataSize;
+ return result;
+ }
+ void SetNewPacket() const {
+ RecvPktBuf = CreateNewPacket();
+ }
+
+ public:
+ static TUdpRecvPacket* CreateNewSmallPacket(int dataSize) {
+ return Create(dataSize);
+ }
+ static TUdpRecvPacket* CreateNewPacket() {
+ return Create(UDP_MAX_PACKET_SIZE);
+ }
+ static TUdpRecvPacket* Clone(const TUdpRecvPacket* pkt) {
+ TUdpRecvPacket* result = Alloc();
+ result->DataStart = pkt->DataStart;
+ result->DataSize = pkt->DataSize;
+ result->Data = pkt->Data;
+ result->ArraySize_ = pkt->ArraySize_;
+ return result;
+ }
+
+ TUdpHostRecvBufAlloc() {
+ SetNewPacket();
+ }
+ ~TUdpHostRecvBufAlloc() {
+ delete RecvPktBuf;
+ }
+
+ TUdpRecvPacket* ExtractPacket() {
+ TUdpRecvPacket* res = RecvPktBuf;
+ SetNewPacket();
+ return res;
+ }
+
+ int GetBufSize() const {
+ return RecvPktBuf->ArraySize_;
+ }
+ char* GetDataPtr() const {
+ return RecvPktBuf->Data.get();
+ }
+ };
+}
diff --git a/library/cpp/netliba/v6/block_chain.cpp b/library/cpp/netliba/v6/block_chain.cpp
new file mode 100644
index 0000000000..cbb56d9a5e
--- /dev/null
+++ b/library/cpp/netliba/v6/block_chain.cpp
@@ -0,0 +1,90 @@
+#include "stdafx.h"
+#include "block_chain.h"
+
+#include <util/system/unaligned_mem.h>
+
+namespace NNetliba {
+ ui32 CalcChecksum(const void* p, int size) {
+ //return 0;
+ //return CalcCrc32(p, size);
+ i64 sum = 0;
+ const unsigned char *pp = (const unsigned char*)p, *pend = pp + size;
+ for (const unsigned char* pend4 = pend - 3; pp < pend4; pp += 4)
+ sum += *(const ui32*)pp;
+
+ ui32 left = 0, pos = 0;
+ for (; pp < pend; ++pp) {
+ pos += ((ui32)*pp) << left;
+ left += 8;
+ }
+
+ sum += pos;
+ sum = (sum & 0xffffffff) + (sum >> 32);
+ sum += sum >> 32;
+ return (ui32)~sum;
+ }
+
+ ui32 CalcChecksum(const TBlockChain& chain) {
+ TIncrementalChecksumCalcer ics;
+ AddChain(&ics, chain);
+ return ics.CalcChecksum();
+ }
+
+ void TIncrementalChecksumCalcer::AddBlock(const void* p, int size) {
+ ui32 sum = CalcBlockSum(p, size);
+ AddBlockSum(sum, size);
+ }
+
+ void TIncrementalChecksumCalcer::AddBlockSum(ui32 sum, int size) {
+ for (int k = 0; k < Offset; ++k)
+ sum = (sum >> 24) + ((sum & 0xffffff) << 8);
+ TotalSum += sum;
+
+ Offset = (Offset + size) & 3;
+ }
+
+ ui32 TIncrementalChecksumCalcer::CalcBlockSum(const void* p, int size) {
+ i64 sum = 0;
+ const unsigned char *pp = (const unsigned char*)p, *pend = pp + size;
+ for (const unsigned char* pend4 = pend - 3; pp < pend4; pp += 4)
+ sum += ReadUnaligned<ui32>(pp);
+
+ ui32 left = 0, pos = 0;
+ for (; pp < pend; ++pp) {
+ pos += ((ui32)*pp) << left;
+ left += 8;
+ }
+ sum += pos;
+ sum = (sum & 0xffffffff) + (sum >> 32);
+ sum += sum >> 32;
+ return (ui32)sum;
+ }
+
+ ui32 TIncrementalChecksumCalcer::CalcChecksum() {
+ TotalSum = (TotalSum & 0xffffffff) + (TotalSum >> 32);
+ TotalSum += TotalSum >> 32;
+ return (ui32)~TotalSum;
+ }
+
+ //void TestChainChecksum()
+ //{
+ // TVector<char> data;
+ // data.resize(10);
+ // for (int i = 0; i < data.ysize(); ++i)
+ // data[i] = rand();
+ // int crc1 = CalcChecksum(&data[0], data.size());
+ //
+ // TBlockChain chain;
+ // TIncrementalChecksumCalcer incCS;
+ // for (int offset = 0; offset < data.ysize();) {
+ // int sz = Min(rand() % 10, data.ysize() - offset);
+ // chain.AddBlock(&data[offset], sz);
+ // incCS.AddBlock(&data[offset], sz);
+ // offset += sz;
+ // }
+ // int crc2 = CalcChecksum(chain);
+ // Y_ASSERT(crc1 == crc2);
+ // int crc3 = incCS.CalcChecksum();
+ // Y_ASSERT(crc1 == crc3);
+ //}
+}
diff --git a/library/cpp/netliba/v6/block_chain.h b/library/cpp/netliba/v6/block_chain.h
new file mode 100644
index 0000000000..25247ec05f
--- /dev/null
+++ b/library/cpp/netliba/v6/block_chain.h
@@ -0,0 +1,319 @@
+#pragma once
+
+#include <util/generic/algorithm.h>
+#include <util/generic/list.h>
+#include <util/system/shmat.h>
+#include <util/generic/noncopyable.h>
+
+namespace NNetliba {
+ class TBlockChain {
+ public:
+ struct TBlock {
+ const char* Data;
+ int Offset, Size; // Offset in whole chain
+
+ TBlock()
+ : Data(nullptr)
+ , Offset(0)
+ , Size(0)
+ {
+ }
+ TBlock(const char* data, int offset, int sz)
+ : Data(data)
+ , Offset(offset)
+ , Size(sz)
+ {
+ }
+ };
+
+ private:
+ typedef TVector<TBlock> TBlockVector;
+ TBlockVector Blocks;
+ int Size;
+ struct TBlockLess {
+ bool operator()(const TBlock& b, int offset) const {
+ return b.Offset < offset;
+ }
+ };
+
+ public:
+ TBlockChain()
+ : Size(0)
+ {
+ }
+ void AddBlock(const void* data, int sz) {
+ Blocks.push_back(TBlock((const char*)data, Size, sz));
+ Size += sz;
+ }
+ int GetSize() const {
+ return Size;
+ }
+ const TBlock& GetBlock(int i) const {
+ return Blocks[i];
+ }
+ int GetBlockCount() const {
+ return Blocks.ysize();
+ }
+ int GetBlockIdByOffset(int offset) const {
+ TBlockVector::const_iterator i = LowerBound(Blocks.begin(), Blocks.end(), offset, TBlockLess());
+ if (i == Blocks.end())
+ return Blocks.ysize() - 1;
+ if (i->Offset == offset)
+ return (int)(i - Blocks.begin());
+ return (int)(i - Blocks.begin() - 1);
+ }
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ class TBlockChainIterator {
+ const TBlockChain& Chain;
+ int Pos, BlockPos, BlockId;
+ bool Failed;
+
+ public:
+ TBlockChainIterator(const TBlockChain& chain)
+ : Chain(chain)
+ , Pos(0)
+ , BlockPos(0)
+ , BlockId(0)
+ , Failed(false)
+ {
+ }
+ void Read(void* dst, int sz) {
+ char* dstBuf = (char*)dst;
+ while (sz > 0) {
+ if (BlockId >= Chain.GetBlockCount()) {
+ // JACKPOT!
+ fprintf(stderr, "reading beyond chain end: BlockId %d, Chain.GetBlockCount() %d, Pos %d, BlockPos %d\n",
+ BlockId, Chain.GetBlockCount(), Pos, BlockPos);
+ Y_ASSERT(0 && "reading beyond chain end");
+ memset(dstBuf, 0, sz);
+ Failed = true;
+ return;
+ }
+ const TBlockChain::TBlock& blk = Chain.GetBlock(BlockId);
+ int copySize = Min(blk.Size - BlockPos, sz);
+ memcpy(dstBuf, blk.Data + BlockPos, copySize);
+ dstBuf += copySize;
+ Pos += copySize;
+ BlockPos += copySize;
+ sz -= copySize;
+ if (BlockPos == blk.Size) {
+ BlockPos = 0;
+ ++BlockId;
+ }
+ }
+ }
+ void Seek(int pos) {
+ if (pos < 0 || pos > Chain.GetSize()) {
+ Y_ASSERT(0);
+ Pos = 0;
+ BlockPos = 0;
+ BlockId = 0;
+ return;
+ }
+ BlockId = Chain.GetBlockIdByOffset(pos);
+ const TBlockChain::TBlock& blk = Chain.GetBlock(BlockId);
+ Pos = pos;
+ BlockPos = Pos - blk.Offset;
+ }
+ int GetPos() const {
+ return Pos;
+ }
+ int GetSize() const {
+ return Chain.GetSize();
+ }
+ bool HasFailed() const {
+ return Failed;
+ }
+ void Fail() {
+ Failed = true;
+ }
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ class TRopeDataPacket: public TNonCopyable {
+ TBlockChain Chain;
+ TVector<char*> Buf;
+ char *Block, *BlockEnd;
+ TList<TVector<char>> DataVectors;
+ TIntrusivePtr<TSharedMemory> SharedData;
+ TVector<TIntrusivePtr<TThrRefBase>> AttachedStorage;
+ char DefaultBuf[128]; // prevent allocs in most cases
+ enum {
+ N_DEFAULT_BLOCK_SIZE = 1024
+ };
+
+ char* Alloc(int sz) {
+ char* res = nullptr;
+ if (BlockEnd - Block < sz) {
+ int bufSize = Max((int)N_DEFAULT_BLOCK_SIZE, sz);
+ char* newBlock = AllocBuf(bufSize);
+ Block = newBlock;
+ BlockEnd = Block + bufSize;
+ Buf.push_back(newBlock);
+ }
+ res = Block;
+ Block += sz;
+ Y_ASSERT(Block <= BlockEnd);
+ return res;
+ }
+
+ public:
+ TRopeDataPacket()
+ : Block(DefaultBuf)
+ , BlockEnd(DefaultBuf + Y_ARRAY_SIZE(DefaultBuf))
+ {
+ }
+ ~TRopeDataPacket() {
+ for (size_t i = 0; i < Buf.size(); ++i)
+ FreeBuf(Buf[i]);
+ }
+ static char* AllocBuf(int sz) {
+ return new char[sz];
+ }
+ static void FreeBuf(char* buf) {
+ delete[] buf;
+ }
+
+ // buf - pointer to buffer which will be freed with FreeBuf()
+ // data - pointer to data start within buf
+ // sz - size of useful data
+ void AddBlock(char* buf, const char* data, int sz) {
+ Buf.push_back(buf);
+ Chain.AddBlock(data, sz);
+ }
+ void AddBlock(TThrRefBase* buf, const char* data, int sz) {
+ AttachedStorage.push_back(buf);
+ Chain.AddBlock(data, sz);
+ }
+ //
+ void Write(const void* data, int sz) {
+ char* buf = Alloc(sz);
+ memcpy(buf, data, sz);
+ Chain.AddBlock(buf, sz);
+ }
+ template <class T>
+ void Write(const T& data) {
+ Write(&data, sizeof(T));
+ }
+        //// caller guarantees that data will persist for the whole lifetime of *this
+        //// in this case so we don't have to copy data to a locally held buffer
+ //template<class T>
+ //void WriteNoCopy(const T *data)
+ //{
+ // Chain.AddBlock(data, sizeof(T));
+ //}
+ // write some array like TVector<>
+ //template<class T>
+ //void WriteArr(const T &sz)
+ //{
+ // int n = (int)sz.size();
+ // Write(n);
+ // if (n > 0)
+ // Write(&sz[0], n * sizeof(sz[0]));
+ //}
+ void WriteStroka(const TString& sz) {
+ int n = (int)sz.size();
+ Write(n);
+ if (n > 0)
+ Write(sz.c_str(), n * sizeof(sz[0]));
+ }
+ // will take *data ownership, saves copy
+ void WriteDestructive(TVector<char>* data) {
+ int n = data ? data->ysize() : 0;
+ Write(n);
+ if (n > 0) {
+ TVector<char>& local = DataVectors.emplace_back(std::move(*data));
+ Chain.AddBlock(&local[0], local.ysize());
+ }
+ }
+ void AttachSharedData(TIntrusivePtr<TSharedMemory> shm) {
+ SharedData = shm;
+ }
+ TSharedMemory* GetSharedData() const {
+ return SharedData.Get();
+ }
+ const TBlockChain& GetChain() {
+ return Chain;
+ }
+ int GetSize() {
+ return Chain.GetSize();
+ }
+ };
+
+ template <class T>
+ inline void ReadArr(TBlockChainIterator* res, T* dst) {
+ int n;
+ res->Read(&n, sizeof(n));
+ if (n >= 0) {
+ dst->resize(n);
+ if (n > 0)
+ res->Read(&(*dst)[0], n * sizeof((*dst)[0]));
+ } else {
+ dst->resize(0);
+ res->Fail();
+ }
+ }
+
+ template <>
+ inline void ReadArr<TString>(TBlockChainIterator* res, TString* dst) {
+ int n;
+ res->Read(&n, sizeof(n));
+ if (n >= 0) {
+ dst->resize(n);
+ if (n > 0)
+ res->Read(dst->begin(), n * sizeof(TString::value_type));
+ } else {
+ dst->resize(0);
+ res->Fail();
+ }
+ }
+
+ // saves on zeroing *dst with yresize()
+ template <class T>
+ static void ReadYArr(TBlockChainIterator* res, TVector<T>* dst) {
+ int n;
+ res->Read(&n, sizeof(n));
+ if (n >= 0) {
+ dst->yresize(n);
+ if (n > 0)
+ res->Read(&(*dst)[0], n * sizeof((*dst)[0]));
+ } else {
+ dst->yresize(0);
+ res->Fail();
+ }
+ }
+
+ template <class T>
+ static void Read(TBlockChainIterator* res, T* dst) {
+ res->Read(dst, sizeof(T));
+ }
+
+ ui32 CalcChecksum(const void* p, int size);
+ ui32 CalcChecksum(const TBlockChain& chain);
+
+ class TIncrementalChecksumCalcer {
+ i64 TotalSum;
+ int Offset;
+
+ public:
+ TIncrementalChecksumCalcer()
+ : TotalSum(0)
+ , Offset(0)
+ {
+ }
+ void AddBlock(const void* p, int size);
+ void AddBlockSum(ui32 sum, int size);
+ ui32 CalcChecksum();
+
+ static ui32 CalcBlockSum(const void* p, int size);
+ };
+
+ inline void AddChain(TIncrementalChecksumCalcer* ics, const TBlockChain& chain) {
+ for (int k = 0; k < chain.GetBlockCount(); ++k) {
+ const TBlockChain::TBlock& blk = chain.GetBlock(k);
+ ics->AddBlock(blk.Data, blk.Size);
+ }
+ }
+}
diff --git a/library/cpp/netliba/v6/cpu_affinity.cpp b/library/cpp/netliba/v6/cpu_affinity.cpp
new file mode 100644
index 0000000000..7808092a72
--- /dev/null
+++ b/library/cpp/netliba/v6/cpu_affinity.cpp
@@ -0,0 +1,138 @@
+#include "stdafx.h"
+#include "cpu_affinity.h"
+
+#if defined(__FreeBSD__) && (__FreeBSD__ >= 7)
+#include <sys/param.h>
+#include <sys/cpuset.h>
+#elif defined(_linux_)
+#include <pthread.h>
+#include <util/stream/file.h>
+#include <util/string/printf.h>
+#endif
+
+namespace NNetliba {
+ class TCPUSet {
+ public:
+ enum { MAX_SIZE = 128 };
+
+ private:
+#if defined(__FreeBSD__) && (__FreeBSD__ >= 7)
+#define NUMCPU ((CPU_MAXSIZE > MAX_SIZE) ? 1 : (MAX_SIZE / CPU_MAXSIZE))
+ cpuset_t CpuInfo[NUMCPU];
+
+ public:
+ bool GetAffinity() {
+ int error = cpuset_getaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuInfo), CpuInfo);
+ return error == 0;
+ }
+ bool SetAffinity() {
+ int error = cpuset_setaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuInfo), CpuInfo);
+ return error == 0;
+ }
+ bool IsSet(size_t i) {
+ return CPU_ISSET(i, CpuInfo);
+ }
+ void Set(size_t i) {
+ CPU_SET(i, CpuInfo);
+ }
+#elif defined(_linux_)
+ public:
+#define NUMCPU ((CPU_SETSIZE > MAX_SIZE) ? 1 : (MAX_SIZE / CPU_SETSIZE))
+ cpu_set_t CpuInfo[NUMCPU];
+
+ public:
+ bool GetAffinity() {
+ int error = pthread_getaffinity_np(pthread_self(), sizeof(CpuInfo), CpuInfo);
+ return error == 0;
+ }
+ bool SetAffinity() {
+ int error = pthread_setaffinity_np(pthread_self(), sizeof(CpuInfo), CpuInfo);
+ return error == 0;
+ }
+ bool IsSet(size_t i) {
+ return CPU_ISSET(i, CpuInfo);
+ }
+ void Set(size_t i) {
+ CPU_SET(i, CpuInfo);
+ }
+#else
+ public:
+ bool GetAffinity() {
+ return true;
+ }
+ bool SetAffinity() {
+ return true;
+ }
+ bool IsSet(size_t i) {
+ Y_UNUSED(i);
+ return true;
+ }
+ void Set(size_t i) {
+ Y_UNUSED(i);
+ }
+#endif
+
+ TCPUSet() {
+ Clear();
+ }
+ void Clear() {
+ memset(this, 0, sizeof(*this));
+ }
+ };
+
+ static TMutex CPUSetsLock;
+ struct TCPUSetInfo {
+ TCPUSet CPUSet;
+ bool IsOk;
+
+ TCPUSetInfo()
+ : IsOk(false)
+ {
+ }
+ };
+ static THashMap<int, TCPUSetInfo> CPUSets;
+
+ void BindToSocket(int n) {
+ TGuard<TMutex> gg(CPUSetsLock);
+ if (CPUSets.find(n) == CPUSets.end()) {
+ TCPUSetInfo& res = CPUSets[n];
+
+ bool foundCPU = false;
+#ifdef _linux_
+ for (int cpuId = 0; cpuId < TCPUSet::MAX_SIZE; ++cpuId) {
+ try { // I just wanna check if file exists, I don't want your stinking exceptions :/
+ TIFStream f(Sprintf("/sys/devices/system/cpu/cpu%d/topology/physical_package_id", cpuId).c_str());
+ TString s;
+ if (f.ReadLine(s) && !s.empty()) {
+ //printf("cpu%d - %s\n", cpuId, s.c_str());
+ int physCPU = atoi(s.c_str());
+ if (physCPU == 0) {
+ res.IsOk = true;
+ res.CPUSet.Set(cpuId);
+ foundCPU = true;
+ }
+ } else {
+ break;
+ }
+ } catch (const TFileError&) {
+ break;
+ }
+ }
+#endif
+ if (!foundCPU && n == 0) {
+ for (int i = 0; i < 6; ++i) {
+ res.CPUSet.Set(i);
+ }
+ res.IsOk = true;
+ foundCPU = true;
+ }
+ }
+ {
+ TCPUSetInfo& cc = CPUSets[n];
+ if (cc.IsOk) {
+ cc.CPUSet.SetAffinity();
+ }
+ }
+ }
+
+}
diff --git a/library/cpp/netliba/v6/cpu_affinity.h b/library/cpp/netliba/v6/cpu_affinity.h
new file mode 100644
index 0000000000..a580edd829
--- /dev/null
+++ b/library/cpp/netliba/v6/cpu_affinity.h
@@ -0,0 +1,5 @@
+#pragma once
+
+namespace NNetliba {
+ void BindToSocket(int n);
+}
diff --git a/library/cpp/netliba/v6/cstdafx.h b/library/cpp/netliba/v6/cstdafx.h
new file mode 100644
index 0000000000..e11fe54cc4
--- /dev/null
+++ b/library/cpp/netliba/v6/cstdafx.h
@@ -0,0 +1,41 @@
+#pragma once
+
+#ifdef _WIN32
+#pragma warning(disable : 4530 4244 4996)
+#include <malloc.h>
+#include <util/system/winint.h>
+#endif
+
+#include <util/system/defaults.h>
+#include <util/system/mutex.h>
+#include <util/system/event.h>
+#include <library/cpp/deprecated/atomic/atomic.h>
+#include <util/system/yassert.h>
+#include <util/system/compat.h>
+
+#include <util/ysafeptr.h>
+
+#include <util/stream/output.h>
+
+#include <library/cpp/string_utils/url/url.h>
+
+#include <library/cpp/charset/codepage.h>
+#include <library/cpp/charset/recyr.hh>
+
+#include <util/generic/vector.h>
+#include <util/generic/hash.h>
+#include <util/generic/list.h>
+#include <util/generic/hash_set.h>
+#include <util/generic/ptr.h>
+#include <util/generic/ymath.h>
+#include <util/generic/utility.h>
+#include <util/generic/algorithm.h>
+
+#include <array>
+#include <cstdlib>
+#include <cstdio>
+
+namespace NNetliba {
+ typedef unsigned char byte;
+ typedef ssize_t yint;
+}
diff --git a/library/cpp/netliba/v6/ib_buffers.cpp b/library/cpp/netliba/v6/ib_buffers.cpp
new file mode 100644
index 0000000000..323846b633
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_buffers.cpp
@@ -0,0 +1,5 @@
+#include "stdafx.h"
+#include "ib_buffers.h"
+
+namespace NNetliba {
+}
diff --git a/library/cpp/netliba/v6/ib_buffers.h b/library/cpp/netliba/v6/ib_buffers.h
new file mode 100644
index 0000000000..96d83ff654
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_buffers.h
@@ -0,0 +1,181 @@
+#pragma once
+
+#include "ib_low.h"
+
+namespace NNetliba {
+ // buffer id 0 is special, it is used when data is sent inline and should not be returned
+ const size_t SMALL_PKT_SIZE = 4096;
+
+ const ui64 BP_AH_USED_FLAG = 0x1000000000000000ul;
+ const ui64 BP_BUF_ID_MASK = 0x00000000fffffffful;
+
+ // single thread version
+ class TIBBufferPool: public TThrRefBase, TNonCopyable {
+ enum {
+ BLOCK_SIZE_LN = 11,
+ BLOCK_SIZE = 1 << BLOCK_SIZE_LN,
+ BLOCK_COUNT = 1024
+ };
+ struct TSingleBlock {
+ TIntrusivePtr<TMemoryRegion> Mem;
+ TVector<ui8> BlkRefCounts;
+ TVector<TIntrusivePtr<TAddressHandle>> AHHolder;
+
+ void Alloc(TPtrArg<TIBContext> ctx) {
+ size_t dataSize = SMALL_PKT_SIZE * BLOCK_SIZE;
+ Mem = new TMemoryRegion(ctx, dataSize);
+ BlkRefCounts.resize(BLOCK_SIZE, 0);
+ AHHolder.resize(BLOCK_SIZE);
+ }
+ char* GetBufData(ui64 idArg) {
+ char* data = Mem->GetData();
+ return data + (idArg & (BLOCK_SIZE - 1)) * SMALL_PKT_SIZE;
+ }
+ };
+
+ TIntrusivePtr<TIBContext> IBCtx;
+ TVector<int> FreeList;
+ TVector<TSingleBlock> Blocks;
+ size_t FirstFreeBlock;
+ int PostRecvDeficit;
+ TIntrusivePtr<TSharedReceiveQueue> SRQ;
+
+ void AddBlock() {
+ if (FirstFreeBlock == Blocks.size()) {
+ Y_VERIFY(0, "run out of buffers");
+ }
+ Blocks[FirstFreeBlock].Alloc(IBCtx);
+ size_t start = (FirstFreeBlock == 0) ? 1 : FirstFreeBlock * BLOCK_SIZE;
+ size_t finish = FirstFreeBlock * BLOCK_SIZE + BLOCK_SIZE;
+ for (size_t i = start; i < finish; ++i) {
+ FreeList.push_back(i);
+ }
+ ++FirstFreeBlock;
+ }
+
+ public:
+ TIBBufferPool(TPtrArg<TIBContext> ctx, int maxSRQWorkRequests)
+ : IBCtx(ctx)
+ , FirstFreeBlock(0)
+ , PostRecvDeficit(maxSRQWorkRequests)
+ {
+ Blocks.resize(BLOCK_COUNT);
+ AddBlock();
+ SRQ = new TSharedReceiveQueue(ctx, maxSRQWorkRequests);
+
+ PostRecv();
+ }
+ TSharedReceiveQueue* GetSRQ() const {
+ return SRQ.Get();
+ }
+ int AllocBuf() {
+ if (FreeList.empty()) {
+ AddBlock();
+ }
+ int id = FreeList.back();
+ FreeList.pop_back();
+ Y_ASSERT(++Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 1);
+ return id;
+ }
+ void FreeBuf(ui64 idArg) {
+ ui64 id = idArg & BP_BUF_ID_MASK;
+ if (id == 0) {
+ return;
+ }
+ Y_ASSERT(id > 0 && id < (ui64)(FirstFreeBlock * BLOCK_SIZE));
+ FreeList.push_back(id);
+ Y_ASSERT(--Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 0);
+ if (idArg & BP_AH_USED_FLAG) {
+ Blocks[id >> BLOCK_SIZE_LN].AHHolder[id & (BLOCK_SIZE - 1)] = nullptr;
+ }
+ }
+ char* GetBufData(ui64 idArg) {
+ ui64 id = idArg & BP_BUF_ID_MASK;
+ return Blocks[id >> BLOCK_SIZE_LN].GetBufData(id);
+ }
+ int PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
+ if (len > SMALL_PKT_SIZE) {
+ Y_VERIFY(0, "buffer overrun");
+ }
+ if (len <= MAX_INLINE_DATA_SIZE) {
+ qp->PostSend(nullptr, 0, data, len);
+ return 0;
+ } else {
+ int id = AllocBuf();
+ TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
+ char* buf = blk.GetBufData(id);
+ memcpy(buf, data, len);
+ qp->PostSend(blk.Mem, id, buf, len);
+ return id;
+ }
+ }
+ void PostSend(TPtrArg<TUDQueuePair> qp, TPtrArg<TAddressHandle> ah, int remoteQPN, int remoteQKey,
+ const void* data, size_t len) {
+ if (len > SMALL_PKT_SIZE - 40) {
+ Y_VERIFY(0, "buffer overrun");
+ }
+ ui64 id = AllocBuf();
+ TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
+ int ptr = id & (BLOCK_SIZE - 1);
+ blk.AHHolder[ptr] = ah.Get();
+ id |= BP_AH_USED_FLAG;
+ if (len <= MAX_INLINE_DATA_SIZE) {
+ qp->PostSend(ah, remoteQPN, remoteQKey, nullptr, id, data, len);
+ } else {
+ char* buf = blk.GetBufData(id);
+ memcpy(buf, data, len);
+ qp->PostSend(ah, remoteQPN, remoteQKey, blk.Mem, id, buf, len);
+ }
+ }
+ void RequestPostRecv() {
+ ++PostRecvDeficit;
+ }
+ void PostRecv() {
+ for (int i = 0; i < PostRecvDeficit; ++i) {
+ int id = AllocBuf();
+ TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
+ char* buf = blk.GetBufData(id);
+ SRQ->PostReceive(blk.Mem, id, buf, SMALL_PKT_SIZE);
+ }
+ PostRecvDeficit = 0;
+ }
+ };
+
+ class TIBRecvPacketProcess: public TNonCopyable {
+ TIBBufferPool& BP;
+ ui64 Id;
+ char* Data;
+
+ public:
+ TIBRecvPacketProcess(TIBBufferPool& bp, const ibv_wc& wc)
+ : BP(bp)
+ , Id(wc.wr_id)
+ {
+ Y_ASSERT(wc.opcode & IBV_WC_RECV);
+ BP.RequestPostRecv();
+ Data = BP.GetBufData(Id);
+ }
+ // intended for postponed packet processing
+ // with this call RequestPostRecv() should be called outside (and PostRecv() too in order to avoid rnr situation)
+ TIBRecvPacketProcess(TIBBufferPool& bp, const ui64 wr_id)
+ : BP(bp)
+ , Id(wr_id)
+ {
+ Data = BP.GetBufData(Id);
+ }
+ ~TIBRecvPacketProcess() {
+ BP.FreeBuf(Id);
+ BP.PostRecv();
+ }
+ char* GetData() const {
+ return Data;
+ }
+ char* GetUDData() const {
+ return Data + 40;
+ }
+ ibv_grh* GetGRH() const {
+ return (ibv_grh*)Data;
+ }
+ };
+
+}
diff --git a/library/cpp/netliba/v6/ib_collective.cpp b/library/cpp/netliba/v6/ib_collective.cpp
new file mode 100644
index 0000000000..87ea166366
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_collective.cpp
@@ -0,0 +1,1317 @@
+#include "stdafx.h"
+#include "ib_collective.h"
+#include "ib_mem.h"
+#include "ib_buffers.h"
+#include "ib_low.h"
+#include "udp_http.h"
+#include "udp_address.h"
+#include <util/generic/deque.h>
+#include <util/system/hp_timer.h>
+
+namespace NNetliba {
+ const int COL_SERVICE_LEVEL = 2;
+ const int COL_DATA_SERVICE_LEVEL = 2; // base level
+ const int COL_DATA_SERVICE_LEVEL_COUNT = 6; // level count
+ const int MAX_REQS_PER_PEER = 32;
+ const int MAX_TOTAL_RDMA = 20;
+ const int SEND_COUNT_TABLE_SIZE = 1 << 12; // must be power of 2
+
+ struct TMergeRecord {
+ struct TTransfer {
+ int DstRank;
+ int SL;
+ int RangeBeg, RangeFin;
+ int Id;
+
+ TTransfer()
+ : DstRank(-1)
+ , SL(0)
+ , RangeBeg(0)
+ , RangeFin(0)
+ , Id(0)
+ {
+ }
+ TTransfer(int dstRank, int sl, int rangeBeg, int rangeFin, int id)
+ : DstRank(dstRank)
+ , SL(sl)
+ , RangeBeg(rangeBeg)
+ , RangeFin(rangeFin)
+ , Id(id)
+ {
+ }
+ };
+ struct TInTransfer {
+ int SrcRank;
+ int SL;
+
+ TInTransfer()
+ : SrcRank(-1)
+ , SL(0)
+ {
+ }
+ TInTransfer(int srcRank, int sl)
+ : SrcRank(srcRank)
+ , SL(sl)
+ {
+ }
+ };
+
+ TVector<TTransfer> OutList;
+ TVector<TInTransfer> InList;
+ ui64 RecvMask;
+
+ TMergeRecord()
+ : RecvMask(0)
+ {
+ }
+ };
+
+ struct TMergeIteration {
+ TVector<TMergeRecord> Ops;
+
+ void Init(int colSize) {
+ Ops.resize(colSize);
+ }
+ void Transfer(int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin, int id) {
+ Y_VERIFY(id < 64, "recv mask overflow");
+ Ops[srcRank].OutList.push_back(TMergeRecord::TTransfer(dstRank, sl, rangeBeg, rangeFin, id));
+ Ops[dstRank].InList.push_back(TMergeRecord::TInTransfer(srcRank, sl));
+ Ops[dstRank].RecvMask |= ui64(1) << id;
+ }
+ };
+
+ struct TMergePlan {
+ TVector<TMergeIteration> Iterations;
+ TVector<int> RankReceiveCount;
+ int ColSize;
+ int MaxRankReceiveCount;
+
+ TMergePlan()
+ : ColSize(0)
+ , MaxRankReceiveCount(0)
+ {
+ }
+ void Init(int colSize) {
+ Iterations.resize(0);
+ RankReceiveCount.resize(0);
+ RankReceiveCount.resize(colSize, 0);
+ ColSize = colSize;
+ }
+ void Transfer(int iter, int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin) {
+ while (iter >= Iterations.ysize()) {
+ TMergeIteration& res = Iterations.emplace_back();
+ res.Init(ColSize);
+ }
+ int id = RankReceiveCount[dstRank]++;
+ MaxRankReceiveCount = Max(MaxRankReceiveCount, id + 1);
+ Y_ASSERT(id < 64);
+ Iterations[iter].Transfer(srcRank, dstRank, sl, rangeBeg, rangeFin, id);
+ }
+ };
+
+ struct TSRTransfer {
+ int SrcRank, DstRank;
+ int RangeBeg, RangeFin;
+
+ TSRTransfer() {
+ Zero(*this);
+ }
+ TSRTransfer(int srcRank, int dstRank, int rangeBeg, int rangeFin)
+ : SrcRank(srcRank)
+ , DstRank(dstRank)
+ , RangeBeg(rangeBeg)
+ , RangeFin(rangeFin)
+ {
+ }
+ };
+
+ static int SplitRange(THashMap<int, TVector<TSRTransfer>>* res, int iter, int beg, int fin) {
+ int mid = (beg + fin + 1) / 2;
+ if (mid == fin) {
+ return iter;
+ }
+ for (int i = 0; i < fin - mid; ++i) {
+ (*res)[iter].push_back(TSRTransfer(beg + i, mid + i, beg, mid));
+ (*res)[iter].push_back(TSRTransfer(mid + i, beg + i, mid, fin));
+ }
+ if (fin - mid < mid - beg) {
+ // [mid - 1] did not receive [mid;fin)
+ (*res)[iter].push_back(TSRTransfer(mid, mid - 1, mid, fin));
+ }
+ int rv1 = SplitRange(res, iter + 1, beg, mid);
+ int rv2 = SplitRange(res, iter + 1, mid, fin);
+ return Max(rv1, rv2);
+ }
+
+ static void CreatePow2Merge(TMergePlan* plan, int colSize) {
+ // finally everybody has full range [0;ColSize)
+ // construct plan recursively, on each iteration split some range
+ plan->Init(colSize);
+
+ THashMap<int, TVector<TSRTransfer>> allTransfers;
+ int maxIter = SplitRange(&allTransfers, 0, 0, colSize);
+
+ for (int iter = 0; iter < maxIter; ++iter) {
+ const TVector<TSRTransfer>& arr = allTransfers[maxIter - iter - 1]; // reverse order
+ for (int i = 0; i < arr.ysize(); ++i) {
+ const TSRTransfer& sr = arr[i];
+ plan->Transfer(iter, sr.SrcRank, sr.DstRank, 0, sr.RangeBeg, sr.RangeFin);
+ }
+ }
+ }
+
+ struct TCoverInterval {
+ int Beg, Fin; // [Beg;Fin)
+
+ TCoverInterval()
+ : Beg(0)
+ , Fin(0)
+ {
+ }
+ TCoverInterval(int b, int f)
+ : Beg(b)
+ , Fin(f)
+ {
+ }
+ };
+
+ enum EAllToAllMode {
+ AA_POW2,
+ AA_CIRCLE,
+ AA_STAR,
+ AA_POW2_MERGE,
+ };
+ static int AllToAll(TMergePlan* plan, int iter, int sl, EAllToAllMode mode, const TVector<int>& myGroup, TVector<TCoverInterval>* cover) {
+ TVector<TCoverInterval>& hostCoverage = *cover;
+ int groupSize = myGroup.ysize();
+
+ for (int k = 1; k < groupSize; ++k) {
+ int h1 = myGroup[k - 1];
+ int h2 = myGroup[k];
+ Y_VERIFY(hostCoverage[h1].Fin == hostCoverage[h2].Beg, "Invalid host order in CreateGroupMerge()");
+ }
+
+ switch (mode) {
+ case AA_POW2: {
+ for (int delta = 1; delta < groupSize; delta *= 2) {
+ int sz = Min(delta, groupSize - delta);
+ for (int offset = 0; offset < groupSize; ++offset) {
+ int srcRank = myGroup[offset];
+ int dstRank = myGroup[(offset + delta) % groupSize];
+
+ int start = offset + 1 - sz;
+ int finish = offset + 1;
+ if (start < 0) {
+ // [start; myGroup.size())
+ int dataBeg = hostCoverage[myGroup[start + groupSize]].Beg;
+ int dataFin = hostCoverage[myGroup.back()].Fin;
+ plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
+ // [0; finish)
+ dataBeg = hostCoverage[myGroup[0]].Beg;
+ dataFin = hostCoverage[myGroup[finish - 1]].Fin;
+ plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
+ } else {
+ // [start;finish)
+ int dataBeg = hostCoverage[myGroup[start]].Beg;
+ int dataFin = hostCoverage[myGroup[finish - 1]].Fin;
+ plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
+ }
+ }
+ ++iter;
+ }
+ } break;
+ case AA_CIRCLE: {
+ for (int dataDelta = 1; dataDelta < groupSize; ++dataDelta) {
+ for (int offset = 0; offset < groupSize; ++offset) {
+ int srcRank = myGroup[offset];
+ int dstRank = myGroup[(offset + 1) % groupSize];
+
+ int dataRank = myGroup[(offset + 1 - dataDelta + groupSize) % groupSize];
+ int dataBeg = hostCoverage[dataRank].Beg;
+ int dataFin = hostCoverage[dataRank].Fin;
+
+ plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
+ }
+ ++iter;
+ }
+ } break;
+ case AA_STAR: {
+ for (int offset = 0; offset < groupSize; ++offset) {
+ for (int delta = 1; delta < groupSize; ++delta) {
+ int srcRank = myGroup[offset];
+ int dstRank = myGroup[(offset + delta) % groupSize];
+
+ int dataRank = myGroup[offset];
+ int dataBeg = hostCoverage[dataRank].Beg;
+ int dataFin = hostCoverage[dataRank].Fin;
+
+ plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
+ }
+ }
+ ++iter;
+ } break;
+ case AA_POW2_MERGE: {
+ TMergePlan pp;
+ CreatePow2Merge(&pp, groupSize);
+ for (int z = 0; z < pp.Iterations.ysize(); ++z) {
+ const TMergeIteration& mm = pp.Iterations[z];
+ for (int src = 0; src < mm.Ops.ysize(); ++src) {
+ const TMergeRecord& mr = mm.Ops[src];
+ int srcRank = myGroup[src];
+ for (int i = 0; i < mr.OutList.ysize(); ++i) {
+ int dstRank = myGroup[mr.OutList[i].DstRank];
+ plan->Transfer(iter, srcRank, dstRank, sl, 0, 1);
+ }
+ }
+ ++iter;
+ }
+ } break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ {
+ TCoverInterval cc(hostCoverage[myGroup[0]].Beg, hostCoverage[myGroup.back()].Fin);
+ for (int k = 0; k < groupSize; ++k) {
+ hostCoverage[myGroup[k]] = cc;
+ }
+ }
+ return iter;
+ }
+
+ // fully populated matrix
+ static void CreateGroupMerge(TMergePlan* plan, EAllToAllMode mode, const TVector<TVector<int>>& hostGroup) {
+ int hostCount = hostGroup[0].ysize();
+ int groupTypeCount = hostGroup.ysize();
+
+ plan->Init(hostCount);
+
+ TVector<int> gcount;
+ gcount.resize(groupTypeCount, 0);
+ for (int hostId = 0; hostId < hostCount; ++hostId) {
+ for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
+ int val = hostGroup[groupType][hostId];
+ gcount[groupType] = Max(gcount[groupType], val + 1);
+ }
+ }
+ for (int hostId = 1; hostId < hostCount; ++hostId) {
+ bool isIncrement = true;
+ for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
+ int prev = hostGroup[groupType][hostId - 1];
+ int cur = hostGroup[groupType][hostId];
+ if (isIncrement) {
+ if (cur == prev + 1) {
+ isIncrement = false;
+ } else {
+ Y_VERIFY(cur == 0, "ib_hosts, wrapped to non-zero");
+ Y_VERIFY(prev == gcount[groupType] - 1, "ib_hosts, structure is irregular");
+ isIncrement = true;
+ }
+ } else {
+ Y_VERIFY(prev == cur, "ib_hosts, structure is irregular");
+ }
+ }
+ }
+
+ TVector<TCoverInterval> hostCoverage;
+ for (int i = 0; i < hostCount; ++i) {
+ hostCoverage.push_back(TCoverInterval(i, i + 1));
+ }
+
+ int baseIter = 0;
+ for (int groupType = hostGroup.ysize() - 1; groupType >= 0; --groupType) {
+ Y_ASSERT(hostGroup[groupType].ysize() == hostCount);
+ TVector<TVector<int>> hh;
+ hh.resize(gcount[groupType]);
+ for (int rank = 0; rank < hostGroup[groupType].ysize(); ++rank) {
+ int groupId = hostGroup[groupType][rank];
+ hh[groupId].push_back(rank);
+ }
+ int newIter = 0;
+ for (int groupId = 0; groupId < hh.ysize(); ++groupId) {
+ int nn = AllToAll(plan, baseIter, 0, mode, hh[groupId], &hostCoverage); // seems to be fastest
+ if (newIter == 0) {
+ newIter = nn;
+ } else {
+ Y_VERIFY(newIter == nn, "groups should be of same size");
+ }
+ }
+ baseIter = newIter;
+ }
+ //printf("%d iterations symmetrical plan\n", baseIter);
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ struct TAllDataSync {
+ enum {
+ WR_COUNT = 64 * 2
+ };
+
+ int CurrentBuffer;
+ TIntrusivePtr<TIBMemBlock> MemBlock[2];
+ TIntrusivePtr<TComplectionQueue> CQ;
+ TIntrusivePtr<TSharedReceiveQueue> SRQ;
+ TIntrusivePtr<TIBMemBlock> FakeRecvMem;
+ size_t DataSize, BufSize;
+ size_t CurrentOffset, ReadyOffset;
+ bool WasFlushed;
+ int ActiveRDMACount;
+ ui64 FutureRecvMask;
+ TIntrusivePtr<IReduceOp> ReduceOp;
+
+ struct TBlockInfo {
+ ui64 Addr;
+ ui32 Key;
+ };
+ struct TSend {
+ TBlockInfo RemoteBlocks[2];
+ TIntrusivePtr<TRCQueuePair> QP;
+ size_t SrcOffset;
+ size_t DstOffset;
+ size_t Length;
+ ui32 ImmData;
+ int DstRank;
+ union {
+ struct {
+ int RangeBeg, RangeFin;
+ } Gather;
+ struct {
+ int SrcIndex, DstIndex;
+ } Reduce;
+ };
+ };
+ struct TRecv {
+ TIntrusivePtr<TRCQueuePair> QP;
+ int SrcRank;
+ };
+ struct TReduce {
+ size_t DstOffset, SrcOffset;
+ int DstIndex, SrcIndex;
+ };
+ struct TIteration {
+ TVector<TSend> OutList;
+ TVector<TRecv> InList;
+ TVector<TReduce> ReduceList;
+ ui64 RecvMask;
+ };
+ TVector<TIteration> Iterations;
+
+ public:
+ void* GetRawData() {
+ char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
+ return myData + CurrentOffset;
+ }
+ size_t GetRawDataSize() {
+ return DataSize;
+ }
+ void PostRecv() {
+ SRQ->PostReceive(FakeRecvMem->GetMemRegion(), 0, FakeRecvMem->GetData(), FakeRecvMem->GetSize());
+ }
+ void Sync() {
+ Y_ASSERT(WasFlushed && "Have to call Flush() before data fill & Sync()");
+ char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
+
+ ui64 recvMask = FutureRecvMask;
+ FutureRecvMask = 0;
+ int recvDebt = 0;
+ for (int z = 0; z < Iterations.ysize(); ++z) {
+ const TIteration& iter = Iterations[z];
+ for (int k = 0; k < iter.OutList.ysize(); ++k) {
+ const TSend& ss = iter.OutList[k];
+ const TBlockInfo& remoteBlk = ss.RemoteBlocks[CurrentBuffer];
+ ss.QP->PostRDMAWriteImm(remoteBlk.Addr + ss.DstOffset, remoteBlk.Key, ss.ImmData,
+ MemBlock[CurrentBuffer]->GetMemRegion(), 0, myData + ss.SrcOffset, ss.Length);
+ ++ActiveRDMACount;
+ //printf("-> %d, imm %d (%" PRId64 " bytes)\n", ss.DstRank, ss.ImmData, ss.Length);
+ //printf("send %d\n", ss.SrcOffset);
+ }
+ ibv_wc wc;
+ while ((recvMask & iter.RecvMask) != iter.RecvMask) {
+ int rv = CQ->Poll(&wc, 1);
+ if (rv > 0) {
+ Y_VERIFY(wc.status == IBV_WC_SUCCESS, "AllGather::Sync fail, status %d", (int)wc.status);
+ if (wc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
+ //printf("Got %d\n", wc.imm_data);
+ ++recvDebt;
+ ui64 newBit = ui64(1) << wc.imm_data;
+ if (recvMask & newBit) {
+ Y_VERIFY((FutureRecvMask & newBit) == 0, "data from 2 Sync() ahead is impossible");
+ FutureRecvMask |= newBit;
+ } else {
+ recvMask |= newBit;
+ }
+ } else if (wc.opcode == IBV_WC_RDMA_WRITE) {
+ --ActiveRDMACount;
+ } else {
+ Y_ASSERT(0);
+ }
+ } else {
+ if (recvDebt > 0) {
+ PostRecv();
+ --recvDebt;
+ }
+ }
+ }
+ for (int k = 0; k < iter.ReduceList.ysize(); ++k) {
+ const TReduce& rr = iter.ReduceList[k];
+ ReduceOp->Reduce(myData + rr.DstOffset, myData + rr.SrcOffset, DataSize);
+ //printf("Merge %d -> %d (%d bytes)\n", rr.SrcOffset, rr.DstOffset, DataSize);
+ }
+ //printf("Iteration %d done\n", z);
+ }
+ while (recvDebt > 0) {
+ PostRecv();
+ --recvDebt;
+ }
+ CurrentOffset = ReadyOffset;
+ WasFlushed = false;
+ //printf("new cur offset %g\n", (double)CurrentOffset);
+ //printf("Sync complete\n");
+ }
+ void Flush() {
+ Y_ASSERT(!WasFlushed);
+ CurrentBuffer = 1 - CurrentBuffer;
+ CurrentOffset = 0;
+ WasFlushed = true;
+ }
+
+ public:
+ TAllDataSync(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
+ : CurrentBuffer(0)
+ , DataSize(0)
+ , BufSize(bufSize)
+ , CurrentOffset(0)
+ , ReadyOffset(0)
+ , WasFlushed(false)
+ , ActiveRDMACount(0)
+ , FutureRecvMask(0)
+ , ReduceOp(reduceOp)
+ {
+ if (memPool) {
+ MemBlock[0] = memPool->Alloc(BufSize);
+ MemBlock[1] = memPool->Alloc(BufSize);
+ CQ = new TComplectionQueue(memPool->GetIBContext(), WR_COUNT);
+ SRQ = new TSharedReceiveQueue(memPool->GetIBContext(), WR_COUNT);
+ FakeRecvMem = memPool->Alloc(4096);
+ } else {
+ MemBlock[0] = new TIBMemBlock(BufSize);
+ MemBlock[1] = new TIBMemBlock(BufSize);
+ CQ = new TComplectionQueue(nullptr, WR_COUNT);
+ SRQ = new TSharedReceiveQueue(nullptr, WR_COUNT);
+ FakeRecvMem = new TIBMemBlock(4096);
+ }
+ for (int i = 0; i < WR_COUNT; ++i) {
+ PostRecv();
+ }
+ }
+ ~TAllDataSync() {
+ while (ActiveRDMACount > 0) {
+ ibv_wc wc;
+ int rv = CQ->Poll(&wc, 1);
+ if (rv > 0) {
+ if (wc.opcode == IBV_WC_RDMA_WRITE) {
+ --ActiveRDMACount;
+ } else {
+ Y_ASSERT(0);
+ }
+ }
+ }
+ }
+ };
+
+ class TAllReduce: public IAllReduce {
+ TAllDataSync DataSync;
+ size_t BufSizeMult;
+ size_t ReadyOffsetMult;
+
+ public:
+ TAllReduce(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
+ : DataSync(bufSize, memPool, reduceOp)
+ , BufSizeMult(0)
+ , ReadyOffsetMult(0)
+ {
+ }
+ TAllDataSync& GetDataSync() {
+ return DataSync;
+ }
+ void* GetRawData() override {
+ return DataSync.GetRawData();
+ }
+ size_t GetRawDataSize() override {
+ return DataSync.GetRawDataSize();
+ }
+ void Sync() override {
+ DataSync.Sync();
+ }
+ void Flush() override {
+ DataSync.Flush();
+ }
+
+ bool Resize(size_t dataSize) override {
+ size_t repSize = (dataSize + 63) & (~63ull);
+ size_t bufSize = repSize * BufSizeMult;
+
+ if (bufSize > DataSync.BufSize) {
+ return false;
+ }
+
+ for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
+ TAllDataSync::TIteration& iter = DataSync.Iterations[z];
+ for (int i = 0; i < iter.OutList.ysize(); ++i) {
+ TAllDataSync::TSend& snd = iter.OutList[i];
+ snd.Length = dataSize;
+ snd.SrcOffset = snd.Reduce.SrcIndex * repSize;
+ snd.DstOffset = snd.Reduce.DstIndex * repSize;
+ }
+
+ for (int i = 0; i < iter.ReduceList.ysize(); ++i) {
+ TAllDataSync::TReduce& red = iter.ReduceList[i];
+ red.SrcOffset = red.SrcIndex * repSize;
+ red.DstOffset = red.DstIndex * repSize;
+ }
+ }
+ DataSync.ReadyOffset = ReadyOffsetMult * repSize;
+ DataSync.DataSize = dataSize;
+ return true;
+ }
+ friend class TIBCollective;
+ };
+
+ class TAllGather: public IAllGather {
+ TAllDataSync DataSync;
+ int ColSize;
+
+ public:
+ TAllGather(int colSize, size_t bufSize, TPtrArg<TIBMemPool> memPool)
+ : DataSync(bufSize, memPool, nullptr)
+ , ColSize(colSize)
+ {
+ }
+ TAllDataSync& GetDataSync() {
+ return DataSync;
+ }
+ void* GetRawData() override {
+ return DataSync.GetRawData();
+ }
+ size_t GetRawDataSize() override {
+ return DataSync.GetRawDataSize();
+ }
+ void Sync() override {
+ DataSync.Sync();
+ }
+ void Flush() override {
+ DataSync.Flush();
+ }
+
+ bool Resize(const TVector<size_t>& szPerRank) override {
+ Y_VERIFY(szPerRank.ysize() == ColSize, "Invalid size array");
+
+ TVector<size_t> offsets;
+ offsets.push_back(0);
+ for (int rank = 0; rank < ColSize; ++rank) {
+ offsets.push_back(offsets.back() + szPerRank[rank]);
+ }
+
+ size_t dataSize = offsets.back();
+ if (dataSize > DataSync.BufSize) {
+ return false;
+ }
+
+ for (int z = 0; z < DataSync.Iterations.ysize(); ++z) {
+ TAllDataSync::TIteration& iter = DataSync.Iterations[z];
+ for (int i = 0; i < iter.OutList.ysize(); ++i) {
+ TAllDataSync::TSend& snd = iter.OutList[i];
+ int rangeBeg = snd.Gather.RangeBeg;
+ int rangeFin = snd.Gather.RangeFin;
+ snd.Length = offsets[rangeFin] - offsets[rangeBeg];
+ snd.SrcOffset = offsets[rangeBeg];
+ snd.DstOffset = snd.SrcOffset;
+ }
+ }
+ DataSync.DataSize = dataSize;
+ return true;
+ }
+ };
+
+ struct TIBAddr {
+ int LID, SL;
+
+ TIBAddr()
+ : LID(0)
+ , SL(0)
+ {
+ }
+ TIBAddr(int lid, int sl)
+ : LID(lid)
+ , SL(sl)
+ {
+ }
+ };
+ inline bool operator==(const TIBAddr& a, const TIBAddr& b) {
+ return a.LID == b.LID && a.SL == b.SL;
+ }
+ inline bool operator<(const TIBAddr& a, const TIBAddr& b) {
+ if (a.LID == b.LID) {
+ return a.SL < b.SL;
+ }
+ return a.LID < b.LID;
+ }
+
+ struct TIBAddrHash {
+ int operator()(const TIBAddr& a) const {
+ return a.LID + a.SL * 4254515;
+ }
+ };
+
+ class TIBCollective: public IIBCollective {
+ struct TPendingMessage {
+ int QPN;
+ ui64 WorkId;
+
+ TPendingMessage() {
+ Zero(*this);
+ }
+ TPendingMessage(int qpn, ui64 wid)
+ : QPN(qpn)
+ , WorkId(wid)
+ {
+ }
+ };
+
+ struct TBlockInform {
+ TAllDataSync::TBlockInfo RemoteBlocks[2];
+ int PSN, QPN;
+ };
+
+ struct TPeerConnection {
+ TAllDataSync::TBlockInfo RemoteBlocks[2];
+ TIntrusivePtr<TRCQueuePair> QP;
+ };
+
+ struct TBWTest {
+ ui64 Addr;
+ ui32 RKey;
+ };
+
+ TIntrusivePtr<TIBPort> Port;
+ TIntrusivePtr<TIBMemPool> MemPool;
+ int ColSize, ColRank;
+ TVector<int> Hosts; // host LIDs
+ TVector<TVector<int>> HostGroup;
+ TVector<TIntrusivePtr<TRCQueuePair>> Peers;
+ TIntrusivePtr<TComplectionQueue> CQ;
+ TIntrusivePtr<TIBBufferPool> BP;
+ ui8 SendCountTable[SEND_COUNT_TABLE_SIZE];
+ ui8 RDMACountTable[SEND_COUNT_TABLE_SIZE];
+ TDeque<TPendingMessage> Pending;
+ TMergePlan MergePlan, ReducePlan;
+ int QPNTableSizeLog;
+
+ void WriteCompleted(const ibv_wc& wc) {
+ --SendCountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
+ if (wc.opcode == IBV_WC_RDMA_WRITE) {
+ --RDMACountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)];
+ }
+ BP->FreeBuf(wc.wr_id);
+ }
+ bool GetMsg(ui64* resWorkId, int* resQPN, TIBMicroPeerTable* tbl) {
+ if (tbl->NeedParsePending()) {
+ for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
+ if (!tbl->NeedQPN(z->QPN)) {
+ continue;
+ }
+ *resWorkId = z->WorkId;
+ *resQPN = z->QPN;
+ Pending.erase(z);
+ return true;
+ }
+ //printf("Stop parse pending\n");
+ tbl->StopParsePending();
+ }
+ for (;;) {
+ ibv_wc wc;
+ int rv = CQ->Poll(&wc, 1);
+ if (rv > 0) {
+ Y_VERIFY(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
+ if (wc.opcode & IBV_WC_RECV) {
+ BP->RequestPostRecv();
+ if (tbl->NeedQPN(wc.qp_num)) {
+ *resWorkId = wc.wr_id;
+ *resQPN = wc.qp_num;
+ return true;
+ } else {
+ Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
+ BP->PostRecv();
+ }
+ } else {
+ WriteCompleted(wc);
+ }
+ } else {
+ return false;
+ }
+ }
+ }
+
+ bool ProcessSendCompletion(const ibv_wc& wc) {
+ Y_VERIFY(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
+ if (wc.opcode & IBV_WC_RECV) {
+ BP->RequestPostRecv();
+ Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
+ BP->PostRecv();
+ } else {
+ WriteCompleted(wc);
+ return true;
+ }
+ return false;
+ }
+
+ void WaitCompletion(ibv_wc* res) {
+ ibv_wc& wc = *res;
+ for (;;) {
+ int rv = CQ->Poll(&wc, 1);
+ if (rv > 0 && ProcessSendCompletion(wc)) {
+ break;
+ }
+ }
+ }
+
+ bool TryWaitCompletion() override {
+ ibv_wc wc;
+ for (;;) {
+ int rv = CQ->Poll(&wc, 1);
+ if (rv > 0) {
+ if (ProcessSendCompletion(wc)) {
+ return true;
+ }
+ } else {
+ return false;
+ }
+ }
+ }
+
+ void WaitCompletion() override {
+ ibv_wc wc;
+ WaitCompletion(&wc);
+ }
+
+ ui64 WaitForMsg(int qpn) {
+ for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) {
+ if (z->QPN == qpn) {
+ ui64 workId = z->WorkId;
+ Pending.erase(z);
+ return workId;
+ }
+ }
+ ibv_wc wc;
+ for (;;) {
+ int rv = CQ->Poll(&wc, 1);
+ if (rv > 0) {
+ Y_VERIFY(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status);
+ if (wc.opcode & IBV_WC_RECV) {
+ BP->RequestPostRecv();
+ if ((int)wc.qp_num == qpn) {
+ return wc.wr_id;
+ } else {
+ Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id));
+ BP->PostRecv();
+ }
+ } else {
+ WriteCompleted(wc);
+ }
+ }
+ }
+ }
+
+ bool AllocOperationSlot(TPtrArg<TRCQueuePair> qp) {
+ int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
+ if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
+ return false;
+ }
+ ++SendCountTable[way];
+ return true;
+ }
+ bool AllocRDMAWriteSlot(TPtrArg<TRCQueuePair> qp) {
+ int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1);
+ if (SendCountTable[way] >= MAX_REQS_PER_PEER) {
+ return false;
+ }
+ if (RDMACountTable[way] >= MAX_OUTSTANDING_RDMA) {
+ return false;
+ }
+ ++SendCountTable[way];
+ ++RDMACountTable[way];
+ return true;
+ }
+ bool TryPostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
+ if (AllocOperationSlot(qp)) {
+ BP->PostSend(qp, data, len);
+ return true;
+ }
+ return false;
+ }
+ void PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
+ while (!TryPostSend(qp, data, len)) {
+ WaitCompletion();
+ }
+ }
+ int GetRank() override {
+ return ColRank;
+ }
+ int GetSize() override {
+ return ColSize;
+ }
+ int GetGroupTypeCount() override {
+ return HostGroup.ysize();
+ }
+ int GetQPN(int rank) override {
+ if (rank == ColRank) {
+ Y_ASSERT(0 && "there is no qpn connected to localhost");
+ return 0;
+ }
+ return Peers[rank]->GetQPN();
+ }
+
+ void Start(const TCollectiveLinkSet& links) override {
+ Hosts = links.Hosts;
+ HostGroup = links.HostGroup;
+ for (int k = 0; k < ColSize; ++k) {
+ if (k == ColRank) {
+ continue;
+ }
+ const TCollectiveLinkSet::TLinkInfo& lnk = links.Links[k];
+ ibv_ah_attr peerAddr;
+ MakeAH(&peerAddr, Port, Hosts[k], COL_SERVICE_LEVEL);
+ Peers[k]->Init(peerAddr, lnk.QPN, lnk.PSN);
+ }
+
+ //CreatePow2Merge(&MergePlan, ColSize);
+ //CreatePow2Merge(&ReducePlan, ColSize);
+ CreateGroupMerge(&MergePlan, AA_STAR, HostGroup);
+ CreateGroupMerge(&ReducePlan, AA_POW2_MERGE, HostGroup);
+ }
+
+ void CreateDataSyncQPs(
+ TPtrArg<TComplectionQueue> cq,
+ TPtrArg<TSharedReceiveQueue> srq,
+ TPtrArg<TIBMemBlock> memBlock0,
+ TPtrArg<TIBMemBlock> memBlock1,
+ const TMergePlan& plan,
+ THashMap<TIBAddr, TPeerConnection, TIBAddrHash>* res) {
+ THashMap<TIBAddr, TPeerConnection, TIBAddrHash>& connections = *res;
+
+ TIBMemBlock* memBlock[2] = {memBlock0, memBlock1};
+
+ // make full peer list
+ TVector<TIBAddr> peerList;
+ for (int z = 0; z < plan.Iterations.ysize(); ++z) {
+ const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
+ for (int i = 0; i < rr.OutList.ysize(); ++i) {
+ const TMergeRecord::TTransfer& tr = rr.OutList[i];
+ peerList.push_back(TIBAddr(tr.DstRank, tr.SL));
+ }
+ for (int i = 0; i < rr.InList.ysize(); ++i) {
+ const TMergeRecord::TInTransfer& tr = rr.InList[i];
+ peerList.push_back(TIBAddr(tr.SrcRank, tr.SL));
+ }
+ }
+ Sort(peerList.begin(), peerList.end());
+ peerList.erase(Unique(peerList.begin(), peerList.end()), peerList.end());
+
+ // establish QPs and exchange mem block handlers
+ for (int z = 0; z < peerList.ysize(); ++z) {
+ const TIBAddr& ibAddr = peerList[z];
+ int dstRank = ibAddr.LID;
+ TPeerConnection& dst = connections[ibAddr];
+
+ dst.QP = new TRCQueuePair(Port->GetCtx(), cq, srq, TAllDataSync::WR_COUNT);
+
+ TBlockInform myBlock;
+ for (int k = 0; k < 2; ++k) {
+ myBlock.RemoteBlocks[k].Addr = memBlock[k]->GetAddr();
+ myBlock.RemoteBlocks[k].Key = memBlock[k]->GetMemRegion()->GetRKey();
+ }
+ myBlock.PSN = dst.QP->GetPSN();
+ myBlock.QPN = dst.QP->GetQPN();
+ PostSend(Peers[dstRank], &myBlock, sizeof(myBlock));
+ }
+
+ for (int z = 0; z < peerList.ysize(); ++z) {
+ const TIBAddr& ibAddr = peerList[z];
+ int dstRank = ibAddr.LID;
+ int sl = COL_DATA_SERVICE_LEVEL + ClampVal(ibAddr.SL, 0, COL_DATA_SERVICE_LEVEL_COUNT);
+
+ TPeerConnection& dst = connections[ibAddr];
+
+ ui64 wr_id = WaitForMsg(Peers[dstRank]->GetQPN());
+ TIBRecvPacketProcess pkt(*BP, wr_id);
+ const TBlockInform& info = *(TBlockInform*)pkt.GetData();
+ ibv_ah_attr peerAddr;
+ MakeAH(&peerAddr, Port, Hosts[dstRank], COL_DATA_SERVICE_LEVEL + sl);
+ dst.QP->Init(peerAddr, info.QPN, info.PSN);
+ dst.RemoteBlocks[0] = info.RemoteBlocks[0];
+ dst.RemoteBlocks[1] = info.RemoteBlocks[1];
+ }
+ Fence();
+ }
+
+ // Builds an all-gather collective: sizes a shared power-of-two buffer,
+ // establishes per-peer RC QPs for the AllGather merge plan, and translates
+ // the per-rank plan into TAllDataSync send/recv iterations.
+ IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) override {
+ const TMergePlan& plan = MergePlan;
+
+ Y_VERIFY(szPerRank.ysize() == ColSize, "Invalid size array");
+
+ // Total payload across all ranks determines the buffer size.
+ size_t totalSize = 0;
+ for (int i = 0; i < szPerRank.ysize(); ++i) {
+ totalSize += szPerRank[i];
+ }
+ // Round up to the smallest power of two strictly greater than totalSize.
+ size_t bufSize = 4096;
+ while (totalSize >= bufSize) {
+ bufSize *= 2;
+ }
+
+ TAllGather* res = new TAllGather(ColSize, bufSize, MemPool);
+ TAllDataSync& ds = res->GetDataSync();
+
+ // One connection per distinct (rank, service level) pair used by the plan.
+ THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
+ CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);
+
+ // build plan
+ for (int z = 0; z < plan.Iterations.ysize(); ++z) {
+ const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
+ if (rr.OutList.empty() && rr.InList.empty()) {
+ continue;
+ }
+ TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
+ // Outgoing transfers: each sends a byte range to a peer's remote blocks.
+ for (int i = 0; i < rr.OutList.ysize(); ++i) {
+ const TMergeRecord::TTransfer& tr = rr.OutList[i];
+ TAllDataSync::TSend& snd = iter.OutList.emplace_back();
+ TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];
+
+ snd.ImmData = tr.Id;
+ snd.Gather.RangeBeg = tr.RangeBeg;
+ snd.Gather.RangeFin = tr.RangeFin;
+ snd.QP = pc.QP;
+ snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
+ snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
+ snd.DstRank = tr.DstRank;
+ }
+ // Incoming transfers: we only need to know which QP to expect data on.
+ for (int i = 0; i < rr.InList.ysize(); ++i) {
+ const TMergeRecord::TInTransfer& tr = rr.InList[i];
+ TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
+ TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
+ rcv.QP = pc.QP;
+ rcv.SrcRank = tr.SrcRank;
+ }
+ iter.RecvMask = rr.RecvMask;
+ }
+ bool rv = res->Resize(szPerRank);
+ Y_VERIFY(rv, "oops");
+
+ return res;
+ }
+ // Uniform-size overload: every rank contributes the same byte count.
+ IAllGather* CreateAllGather(size_t szPerRank) override {
+ TVector<size_t> sizes(ColSize, szPerRank);
+ return CreateAllGather(sizes);
+ }
+
+ // Builds an all-reduce collective over the reduce plan. The data buffer is
+ // split into bufSizeMult slots of (dataSize + 64) bytes: slot 0 holds the
+ // local contribution, slot (tr.Id + 1) receives peer data, and reduce steps
+ // fold slots together as iterations progress.
+ IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) override {
+ const TMergePlan& plan = ReducePlan;
+
+ size_t bufSizeMult = plan.MaxRankReceiveCount + 1;
+ size_t bufSize = 4096;
+ {
+ // Round the multi-slot buffer size up to a power of two.
+ size_t sz = (dataSize + 64) * bufSizeMult;
+ while (sz > bufSize) {
+ bufSize *= 2;
+ }
+ }
+
+ TAllReduce* res = new TAllReduce(bufSize, MemPool, reduceOp);
+ TAllDataSync& ds = res->GetDataSync();
+
+ THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections;
+ CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections);
+
+ // build plan
+ // currentDataOffset tracks the slot holding the freshest reduced value.
+ int currentDataOffset = 0;
+ for (int z = 0; z < plan.Iterations.ysize(); ++z) {
+ const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank];
+ if (rr.OutList.empty() && rr.InList.empty()) {
+ continue;
+ }
+ TAllDataSync::TIteration& iter = ds.Iterations.emplace_back();
+ for (int i = 0; i < rr.OutList.ysize(); ++i) {
+ const TMergeRecord::TTransfer& tr = rr.OutList[i];
+ TAllDataSync::TSend& snd = iter.OutList.emplace_back();
+ TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)];
+
+ snd.ImmData = tr.Id;
+ // Send our current partial result into the peer's slot (tr.Id + 1).
+ snd.Reduce.SrcIndex = currentDataOffset;
+ snd.Reduce.DstIndex = tr.Id + 1;
+ snd.QP = pc.QP;
+ snd.RemoteBlocks[0] = pc.RemoteBlocks[0];
+ snd.RemoteBlocks[1] = pc.RemoteBlocks[1];
+ snd.DstRank = tr.DstRank;
+ }
+
+ for (int i = 0; i < rr.InList.ysize(); ++i) {
+ const TMergeRecord::TInTransfer& tr = rr.InList[i];
+ TAllDataSync::TRecv& rcv = iter.InList.emplace_back();
+ TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)];
+ rcv.QP = pc.QP;
+ rcv.SrcRank = tr.SrcRank;
+ }
+ iter.RecvMask = rr.RecvMask;
+
+ // Collect the slots that now hold data: our own plus one per set bit
+ // in RecvMask (bit i maps to slot i + 1).
+ TVector<int> inputOffset;
+ inputOffset.push_back(currentDataOffset);
+ int newDataOffset = currentDataOffset;
+ for (int i = 0; i < 64; ++i) {
+ if (rr.RecvMask & (1ull << i)) {
+ int offset = i + 1;
+ inputOffset.push_back(offset);
+ newDataOffset = Max(offset, newDataOffset);
+ }
+ }
+ // Fold every other occupied slot into the highest one, which becomes
+ // the new current slot for the next iteration.
+ for (int i = 0; i < inputOffset.ysize(); ++i) {
+ if (inputOffset[i] == newDataOffset) {
+ continue;
+ }
+ TAllDataSync::TReduce& red = iter.ReduceList.emplace_back();
+ red.SrcIndex = inputOffset[i];
+ red.DstIndex = newDataOffset;
+ }
+ currentDataOffset = newDataOffset;
+ }
+ res->BufSizeMult = bufSizeMult;
+ res->ReadyOffsetMult = currentDataOffset;
+
+ bool rv = res->Resize(dataSize);
+ Y_VERIFY(rv, "oops");
+
+ return res;
+ }
+
+ // Barrier: walks the reduce plan exchanging one-byte tokens with every
+ // plan neighbour, so no rank leaves until all its neighbours arrive.
+ void Fence() override {
+ const TMergePlan& plan = ReducePlan;
+
+ const int stepCount = plan.Iterations.ysize();
+ for (int step = 0; step < stepCount; ++step) {
+ const TMergeRecord& rec = plan.Iterations[step].Ops[ColRank];
+ // notify every outgoing neighbour of this step
+ for (int k = 0; k < rec.OutList.ysize(); ++k) {
+ char token = 0;
+ PostSend(Peers[rec.OutList[k].DstRank], &token, sizeof(token));
+ }
+ // absorb one token from each incoming neighbour
+ for (int k = 0; k < rec.InList.ysize(); ++k) {
+ const int srcRank = rec.InList[k].SrcRank;
+ TIBRecvPacketProcess pkt(*BP, WaitForMsg(Peers[srcRank]->GetQPN()));
+ }
+ }
+ }
+ // Measures RDMA-write bandwidth to the rank `delta` positions ahead of us
+ // within our host group (ring order), while the rank `delta` behind writes
+ // to us. Reports the median of 5 runs in GB/sec via *res and the partner
+ // rank via *targetRank (-1 if the group is smaller than delta).
+ void RunBWTest(int groupType, int delta, int* targetRank, float* res) override {
+ const int BUF_SIZE = 8 * 1024 * 1024;
+ TIntrusivePtr<TIBMemBlock> sendMem, recvMem;
+ sendMem = MemPool->Alloc(BUF_SIZE);
+ recvMem = MemPool->Alloc(BUF_SIZE);
+
+ // Collect the ranks in our group and find our position among them.
+ int myGroup = HostGroup[groupType][ColRank];
+ int myGroupPos = 0;
+ TVector<int> gg;
+ Y_ASSERT(HostGroup[groupType].ysize() == ColSize);
+ for (int rank = 0; rank < ColSize; ++rank) {
+ if (HostGroup[groupType][rank] == myGroup) {
+ if (rank == ColRank) {
+ myGroupPos = gg.ysize();
+ }
+ gg.push_back(rank);
+ }
+ }
+ if (delta >= gg.ysize()) {
+ *targetRank = -1;
+ *res = 0;
+ return;
+ }
+
+ // Ring topology: we write to sendRank, recvRank writes to us.
+ int sendRank = gg[(myGroupPos + delta) % gg.ysize()];
+ int recvRank = gg[(myGroupPos + gg.ysize() - delta) % gg.ysize()];
+ *targetRank = sendRank;
+ TIntrusivePtr<TRCQueuePair> sendRC = Peers[sendRank];
+ TIntrusivePtr<TRCQueuePair> recvRC = Peers[recvRank];
+ {
+ // Publish our receive buffer (addr + rkey) to the rank writing to us.
+ TBWTest bw;
+ bw.Addr = recvMem->GetAddr();
+ bw.RKey = recvMem->GetMemRegion()->GetRKey();
+ PostSend(recvRC, &bw, sizeof(bw));
+ }
+ TBWTest dstMem;
+ {
+ ui64 wr_id = WaitForMsg(sendRC->GetQPN());
+ TIBRecvPacketProcess pkt(*BP, wr_id);
+ dstMem = *(TBWTest*)pkt.GetData();
+ }
+ // run
+ TVector<double> score;
+ for (int iter = 0; iter < 5; ++iter) {
+ while (!AllocRDMAWriteSlot(sendRC)) {
+ WaitCompletion();
+ Y_ASSERT(0 && "measurements are imprecise");
+ }
+ NHPTimer::STime t;
+ NHPTimer::GetTime(&t);
+ sendRC->PostRDMAWrite(dstMem.Addr, dstMem.RKey, sendMem->GetMemRegion(), 0, sendMem->GetData(), BUF_SIZE);
+ // Spin until the completion for our own RDMA write shows up; other
+ // completions (e.g. incoming traffic) are drained along the way.
+ for (;;) {
+ ibv_wc wc;
+ WaitCompletion(&wc);
+ if (wc.opcode == IBV_WC_RDMA_WRITE) {
+ if (wc.qp_num != (ui32)sendRC->GetQPN()) {
+ abort();
+ }
+ break;
+ }
+ }
+ double tPassed = NHPTimer::GetTimePassed(&t);
+ double speed = BUF_SIZE / tPassed / 1000000000.0; // G/sec
+ score.push_back(speed);
+ }
+ Sort(score.begin(), score.end());
+ // signal completion & wait for signal
+ *res = score[score.size() / 2];
+ {
+ char bb;
+ PostSend(sendRC, &bb, sizeof(bb));
+ ui64 wr_id = WaitForMsg(recvRC->GetQPN());
+ TIBRecvPacketProcess pkt(*BP, wr_id);
+ }
+ }
+ // Best-effort tiny-message send; returns false when no send slot is free.
+ bool TrySendMicro(int dstRank, const void* data, int dataSize) override {
+ return TryPostSend(Peers[dstRank], data, dataSize);
+ }
+ // Sizes the QPN lookup table so all peer QPNs hash to distinct slots
+ // (QPNTableSizeLog is computed in the constructor).
+ void InitPeerTable(TIBMicroPeerTable* res) override {
+ res->Init(QPNTableSizeLog);
+ }
+ // Issues all requested RDMA writes, keeping at most MAX_TOTAL_RDMA in
+ // flight and rotating the starting rank to spread load across peers.
+ void RdmaWrite(const TVector<TRdmaRequest>& reqs) override {
+ // Bucket request indices by destination rank.
+ TVector<TVector<int>> reqPerRank;
+ reqPerRank.resize(ColSize);
+ int reqCount = reqs.ysize();
+ for (int i = 0; i < reqCount; ++i) {
+ reqPerRank[reqs[i].DstRank].push_back(i);
+ }
+ int inFlight = 0; // IB congestion control sucks :/ so we limit number of simultaneous rdmas
+ int startRank = ColRank;
+ while (reqCount > 0) {
+ if (inFlight < MAX_TOTAL_RDMA) {
+ // Round-robin over ranks starting after startRank, posting one
+ // write per rank that has a free RDMA slot.
+ for (int z = 0; z < ColSize; ++z) {
+ int dstRank = (startRank + 1 + z) % ColSize;
+ if (reqPerRank[dstRank].empty()) {
+ continue;
+ }
+ Y_ASSERT(dstRank != ColRank && "sending self is meaningless");
+ TRCQueuePair* qp = Peers[dstRank].Get();
+ if (AllocRDMAWriteSlot(qp)) {
+ const TRdmaRequest& rr = reqs[reqPerRank[dstRank].back()];
+ qp->PostRDMAWrite(rr.RemoteAddr, rr.RemoteKey, rr.LocalAddr, rr.LocalKey, 0, rr.Size);
+ reqPerRank[dstRank].pop_back();
+ if (++inFlight >= MAX_TOTAL_RDMA) {
+ // Remember where we stopped so the next round resumes
+ // from the following rank.
+ startRank = dstRank;
+ break;
+ }
+ }
+ }
+ }
+ {
+ // Block for one completion; only RDMA-write completions retire
+ // requests (other traffic may share the CQ).
+ ibv_wc wc;
+ WaitCompletion(&wc);
+ if (wc.opcode == IBV_WC_RDMA_WRITE) {
+ --inFlight;
+ --reqCount;
+ }
+ }
+ }
+ }
+
+ public:
+ // Creates one RC queue pair per remote rank, fills resLinks with our
+ // QPN/PSN/LID so the caller can exchange them out-of-band, and sizes the
+ // QPN hash table used for micro-message peer lookup.
+ TIBCollective(TPtrArg<TIBPort> port, TPtrArg<TIBMemPool> memPool,
+ const TCollectiveInit& params,
+ TCollectiveLinkSet* resLinks)
+ : Port(port)
+ , MemPool(memPool)
+ , QPNTableSizeLog(0)
+ {
+ ColSize = params.Size;
+ ColRank = params.Rank;
+
+ int maxOutstandingQueries = MAX_REQS_PER_PEER * ColSize + 10;
+ CQ = new TComplectionQueue(Port->GetCtx(), maxOutstandingQueries * 2);
+ BP = new TIBBufferPool(Port->GetCtx(), maxOutstandingQueries);
+
+ Peers.resize(ColSize);
+ resLinks->Links.resize(ColSize);
+ TVector<int> qpnArr;
+ for (int k = 0; k < ColSize; ++k) {
+ if (k == ColRank) {
+ continue; // no QP to ourselves
+ }
+ TRCQueuePair* rc = new TRCQueuePair(Port->GetCtx(), CQ, BP->GetSRQ(), MAX_REQS_PER_PEER);
+ Peers[k] = rc;
+ TCollectiveLinkSet::TLinkInfo& lnk = resLinks->Links[k];
+ lnk.PSN = rc->GetPSN();
+ lnk.QPN = rc->GetQPN();
+
+ qpnArr.push_back(lnk.QPN);
+ }
+ resLinks->Hosts.resize(ColSize);
+ resLinks->Hosts[ColRank] = Port->GetLID();
+
+ static_assert(MAX_REQS_PER_PEER < 256, "expect MAX_REQS_PER_PEER < 256"); // sent count will fit into SendCountTable[]
+ Zero(SendCountTable);
+ Zero(RDMACountTable);
+
+ // Grow the table (power of two) until the low QPNTableSizeLog bits of
+ // every peer QPN are unique, i.e. the table is collision-free.
+ if (!qpnArr.empty()) {
+ for (;;) {
+ TVector<ui8> qpnTable;
+ int qpnTableSize = 1 << QPNTableSizeLog;
+ qpnTable.resize(qpnTableSize, 0);
+ bool ok = true;
+ for (int i = 0; i < qpnArr.ysize(); ++i) {
+ int idx = qpnArr[i] & (qpnTableSize - 1);
+ if (++qpnTable[idx] == 2) {
+ ok = false;
+ break;
+ }
+ }
+ if (ok) {
+ break;
+ }
+ ++QPNTableSizeLog;
+ }
+ //printf("QPN table, size_log %d\n", QPNTableSizeLog);
+ }
+ }
+
+ friend class TIBRecvMicro;
+ };
+
+ // RAII receive: grabs the next pending micro message (if any) and exposes
+ // its buffer via GetRawData(); Data stays null when nothing was pending.
+ TIBRecvMicro::TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable)
+ : IB(*(TIBCollective*)col)
+ {
+ // Collectives handed out by CreateCollective() are always TIBCollective.
+ Y_ASSERT(typeid(IB) == typeid(TIBCollective));
+ Data = IB.GetMsg(&Id, &QPN, peerTable) ? IB.BP->GetBufData(Id) : nullptr;
+ }
+
+ TIBRecvMicro::~TIBRecvMicro() {
+ if (!Data) {
+ return;
+ }
+ // Return the buffer to the pool and repost a receive in its place.
+ IB.BP->FreeBuf(Id);
+ IB.BP->PostRecv();
+ }
+
+ // Factory: builds a collective over the default IB device and memory pool.
+ IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks) {
+ return new TIBCollective(GetIBDevice(), GetIBMemPool(), params, resLinks);
+ }
+}
diff --git a/library/cpp/netliba/v6/ib_collective.h b/library/cpp/netliba/v6/ib_collective.h
new file mode 100644
index 0000000000..48ffd29b34
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_collective.h
@@ -0,0 +1,160 @@
+#pragma once
+
+#include <library/cpp/binsaver/bin_saver.h>
+
+namespace NNetliba {
+ // Parameters for joining a collective: total participant count and our rank.
+ struct TCollectiveInit {
+ int Size, Rank;
+
+ SAVELOAD(Size, Rank)
+ };
+
+ // Out-of-band connection info exchanged between ranks before Start():
+ // per-rank LIDs, host grouping, and per-link QPN/PSN pairs.
+ struct TCollectiveLinkSet {
+ struct TLinkInfo {
+ int QPN, PSN; // queue pair number / initial packet sequence number
+ };
+ TVector<int> Hosts; // host LIDs
+ TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch
+ TVector<TLinkInfo> Links;
+
+ SAVELOAD(Hosts, HostGroup, Links)
+ };
+
+ // Common interface of collective data operations: exposes the shared data
+ // buffer plus Sync()/Flush() to drive the exchange.
+ struct IAllDataSync: public TThrRefBase {
+ virtual void* GetRawData() = 0;
+ virtual size_t GetRawDataSize() = 0;
+ virtual void Sync() = 0;
+ virtual void Flush() = 0;
+
+ // Typed views over the raw buffer.
+ template <class T>
+ T* GetData() {
+ return static_cast<T*>(GetRawData());
+ }
+ template <class T>
+ size_t GetSize() {
+ return GetRawDataSize() / sizeof(T);
+ }
+ };
+
+ // All-reduce operation; Resize() sets the per-rank payload size.
+ struct IAllReduce: public IAllDataSync {
+ virtual bool Resize(size_t dataSize) = 0;
+ };
+
+ // All-gather operation; Resize() sets each rank's contribution size.
+ struct IAllGather: public IAllDataSync {
+ virtual bool Resize(const TVector<size_t>& szPerRank) = 0;
+ };
+
+ // Byte-oriented reduction callback: folds `add` into `dst` in place.
+ struct IReduceOp: public TThrRefBase {
+ virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0;
+ };
+
+ // Adapter that exposes an element-wise callable T (operating on TElem) as a
+ // byte-oriented IReduceOp.
+ template <class T, class TElem = typename T::TElem>
+ class TAllReduceOp: public IReduceOp {
+ T Op;
+
+ public:
+ TAllReduceOp() {
+ }
+ TAllReduceOp(T op)
+ : Op(op)
+ {
+ }
+ // dst[i] = Op(dst[i], add[i]). Count rounds up so a trailing partial
+ // element is still processed, matching pointer-walk-to-end semantics.
+ void Reduce(void* dst, const void* add, size_t dataSize) const override {
+ const size_t count = (dataSize + sizeof(TElem) - 1) / sizeof(TElem);
+ TElem* dstArr = (TElem*)dst;
+ const TElem* addArr = (const TElem*)add;
+ for (size_t i = 0; i < count; ++i) {
+ Op(&dstArr[i], addArr[i]);
+ }
+ }
+ };
+
+ // Table of active peers for micro send/recv. Slots are indexed by the low
+ // bits of the QPN; a slot value of 0 means messages from QPNs mapping to it
+ // are still accepted, 0xff means the peer is fully stopped.
+ class TIBMicroPeerTable {
+ TVector<ui8> Table; // == 0 means accept messages from this qpn
+ int TableSize;
+ bool ParsePending;
+
+ public:
+ TIBMicroPeerTable()
+ : ParsePending(true)
+ {
+ Init(0);
+ }
+ void Init(int tableSizeLog) {
+ TableSize = 1 << tableSizeLog;
+ ParsePending = true;
+ Table.assign(TableSize, 0);
+ }
+ bool NeedParsePending() const {
+ return ParsePending;
+ }
+ void StopParsePending() {
+ ParsePending = false;
+ }
+ void StopQPN(int qpn, ui8 mask) {
+ ui8& slot = Table[qpn & (TableSize - 1)];
+ Y_ASSERT((slot & mask) == 0);
+ slot |= mask;
+ }
+ void StopQPN(int qpn) {
+ ui8& slot = Table[qpn & (TableSize - 1)];
+ Y_ASSERT(slot == 0);
+ slot = 0xff;
+ }
+ bool NeedQPN(int qpn) const {
+ return Table[qpn & (TableSize - 1)] != 0xff;
+ }
+ };
+
+ struct IIBCollective;
+ class TIBCollective;
+ // RAII wrapper around one received micro message; the buffer is released
+ // (and a receive reposted) in the destructor. GetRawData() is null when no
+ // message was pending at construction time.
+ class TIBRecvMicro: public TNonCopyable {
+ TIBCollective& IB;
+ ui64 Id; // buffer id inside the collective's buffer pool
+ int QPN; // QPN of the sending peer
+ void* Data;
+
+ public:
+ TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable);
+ ~TIBRecvMicro();
+ void* GetRawData() const {
+ return Data;
+ }
+ template <class T>
+ T* GetData() {
+ return static_cast<T*>(GetRawData());
+ }
+ int GetQPN() const {
+ return QPN;
+ }
+ };
+
+ // Rank-to-rank collective communication over InfiniBand: factory methods
+ // for all-gather/all-reduce, a barrier, bandwidth probing, micro messages
+ // and bulk RDMA writes.
+ struct IIBCollective: public TThrRefBase {
+ // One remote write: push Size bytes from (LocalAddr, LocalKey) to
+ // (RemoteAddr, RemoteKey) on DstRank.
+ struct TRdmaRequest {
+ int DstRank;
+ ui64 RemoteAddr, LocalAddr;
+ ui32 RemoteKey, LocalKey;
+ ui64 Size;
+ };
+
+ virtual int GetRank() = 0;
+ virtual int GetSize() = 0;
+ virtual int GetGroupTypeCount() = 0;
+ virtual int GetQPN(int rank) = 0;
+ virtual bool TryWaitCompletion() = 0;
+ virtual void WaitCompletion() = 0;
+ virtual void Start(const TCollectiveLinkSet& links) = 0;
+ virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0;
+ virtual IAllGather* CreateAllGather(size_t szPerRank) = 0;
+ virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0;
+ virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0;
+ virtual void Fence() = 0;
+ virtual void InitPeerTable(TIBMicroPeerTable* res) = 0;
+ virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0;
+ virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0;
+ };
+
+ IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks);
+}
diff --git a/library/cpp/netliba/v6/ib_cs.cpp b/library/cpp/netliba/v6/ib_cs.cpp
new file mode 100644
index 0000000000..6dbe7bb0e5
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_cs.cpp
@@ -0,0 +1,776 @@
+#include "stdafx.h"
+#include "ib_cs.h"
+#include "ib_buffers.h"
+#include "ib_mem.h"
+#include <util/generic/deque.h>
+#include <util/digest/murmur.h>
+
+/*
+Questions
+ does rdma work?
+ what is RC latency?
+ 3us if measured by completion event arrival
+ 2.3us if bind to socket 0 & use inline send
+ memory region - can we use memory from some offset?
+ yes
+ is send_inplace supported and is it faster?
+ yes, supported, 1024 bytes limit, inline is faster (2.4 vs 2.9)
+ is srq a penalty compared to regular rq?
+ rdma is faster anyway, so why bother
+
+collective ops
+ support asymmetric configurations by additional transfers (overlap 1 or 2 hosts is allowed)
+
+remove commented stuff all around
+
+next gen
+ shared+registered large mem blocks for easy transfer
+ no crc calcs
+ direct channel exposure
+ make ui64 packet id? otherwise we could get duplicate id (highly improbable but possible)
+ lock free allocation in ib_mem
+*/
+
+namespace NNetliba {
+ // QKey shared by all welcome (UD) QPs; handshakes with other keys are dropped.
+ const int WELCOME_QKEY = 0x13081976;
+
+ // Per-peer cap on in-flight sends; excess sends wait in PendingSendQueue.
+ const int MAX_SEND_COUNT = (128 - 10) / 4;
+ const int QP_SEND_QUEUE_SIZE = (MAX_SEND_COUNT * 2 + 10) + 10;
+ const int WELCOME_QP_SEND_SIZE = 10000;
+
+ const int MAX_SRQ_WORK_REQUESTS = 10000;
+ const int MAX_CQ_EVENTS = MAX_SRQ_WORK_REQUESTS; //1000;
+
+ // Seconds between keep-alive / dead-peer sweeps in Step().
+ const double CHANNEL_CHECK_INTERVAL = 1.;
+
+ const int TRAFFIC_SL = 4; // 4 is mandatory for RoCE to work, it's the only lossless priority(?)
+ const int CONNECT_SL = 1;
+
+ class TIBClientServer: public IIBClientServer {
+ // Wire protocol command ids (always the first byte of a message).
+ enum ECmd {
+ CMD_HANDSHAKE, // over welcome UD QP: request an RC channel (QPN/PSN/address)
+ CMD_HANDSHAKE_ACK, // over welcome UD QP: responder's QPN/PSN reply
+ CMD_CONFIRM, // first message on the new RC channel, marks it OK
+ CMD_DATA_TINY, // payload inlined in the packet (fits SMALL_PKT_SIZE)
+ CMD_DATA_INIT, // announce a large transfer, receiver allocates a buffer
+ CMD_BUFFER_READY, // receiver publishes RDMA target addr + rkey
+ CMD_DATA_COMPLETE, // sender finished the RDMA write (hash in debug builds)
+ CMD_KEEP_ALIVE, // liveness probe, no payload semantics
+ };
+#pragma pack(1)
+ // Packed wire formats; each begins with a Command byte from ECmd.
+ struct TCmdHandshake {
+ char Command;
+ int QPN, PSN;
+ TGUID SocketId;
+ TUdpAddress MyAddress; // address of the handshake sender as viewed from receiver
+ };
+ struct TCmdHandshakeAck {
+ char Command;
+ int QPN, PSN;
+ int YourQPN; // echoes the requester's QPN so it can find its peer record
+ };
+ struct TCmdConfirm {
+ char Command;
+ };
+ // Small message with the payload carried inline.
+ struct TCmdDataTiny {
+ struct THeader {
+ char Command;
+ ui16 Size;
+ TGUID PacketGuid;
+ } Header;
+ typedef char TDataVec[SMALL_PKT_SIZE - sizeof(THeader)];
+ TDataVec Data;
+
+ static int GetMaxDataSize() {
+ return sizeof(TDataVec);
+ }
+ };
+ // Announces a large transfer of Size bytes identified by PacketGuid.
+ struct TCmdDataInit {
+ char Command;
+ size_t Size;
+ TGUID PacketGuid;
+ };
+ // Receiver's RDMA target for the announced transfer.
+ struct TCmdBufferReady {
+ char Command;
+ TGUID PacketGuid;
+ ui64 RemoteAddr;
+ ui32 RemoteKey;
+ };
+ // Sender's notification that the RDMA write landed; DataHash is non-zero
+ // only in debug builds.
+ struct TCmdDataComplete {
+ char Command;
+ TGUID PacketGuid;
+ ui64 DataHash;
+ };
+ struct TCmdKeepAlive {
+ char Command;
+ };
+
+ // Bookkeeping for one posted send: what to do when its work completion
+ // arrives (completions are matched to OutMsgs in FIFO order).
+ struct TCompleteInfo {
+ enum {
+ CI_DATA_TINY, // tiny message: report send result on completion
+ CI_RDMA_COMPLETE, // RDMA write finished: follow up with CMD_DATA_COMPLETE
+ CI_DATA_SENT, // CMD_DATA_COMPLETE delivered: finish the queued send
+ CI_KEEP_ALIVE,
+ CI_IGNORE, // control message, nothing to do
+ };
+ int Type;
+ int BufId; // buffer pool id to release after completion
+ TIBMsgHandle MsgHandle;
+
+ TCompleteInfo(int t, int bufId, TIBMsgHandle msg)
+ : Type(t)
+ , BufId(bufId)
+ , MsgHandle(msg)
+ {
+ }
+ };
+ // A send deferred because the peer already has MAX_SEND_COUNT messages in
+ // flight; drained by SendCompleted() once a slot frees up.
+ struct TPendingQueuedSend {
+ TGUID PacketGuid;
+ TIBMsgHandle MsgHandle;
+ TRopeDataPacket* Data;
+
+ TPendingQueuedSend()
+ : MsgHandle(0)
+ , Data(nullptr) // was uninitialized: a default-constructed entry held a garbage pointer
+ {
+ }
+ TPendingQueuedSend(const TGUID& packetGuid, TIBMsgHandle msgHandle, TRopeDataPacket* data)
+ : PacketGuid(packetGuid)
+ , MsgHandle(msgHandle)
+ , Data(data)
+ {
+ }
+ };
+ // An in-progress large send: filled in as CMD_BUFFER_READY (RemoteAddr/Key)
+ // and the local copy (MemBlock) arrive; the RDMA write is posted once both
+ // are present.
+ struct TQueuedSend {
+ TGUID PacketGuid;
+ // In-class initializers: the defaulted constructor previously left
+ // MsgHandle/RemoteAddr/RemoteKey indeterminate, while the code relies on
+ // RemoteAddr == 0 meaning "buffer not announced yet".
+ TIBMsgHandle MsgHandle = 0;
+ TIntrusivePtr<TIBMemBlock> MemBlock;
+ ui64 RemoteAddr = 0;
+ ui32 RemoteKey = 0;
+
+ TQueuedSend() = default;
+ TQueuedSend(const TGUID& packetGuid, TIBMsgHandle msgHandle)
+ : PacketGuid(packetGuid)
+ , MsgHandle(msgHandle)
+ {
+ }
+ };
+ // An in-progress large receive: the buffer we allocated for the announced
+ // transfer, kept until CMD_DATA_COMPLETE arrives.
+ struct TQueuedRecv {
+ TGUID PacketGuid;
+ TIntrusivePtr<TIBMemBlock> Data;
+
+ TQueuedRecv() = default;
+ TQueuedRecv(const TGUID& packetGuid, TPtrArg<TIBMemBlock> data)
+ : PacketGuid(packetGuid)
+ , Data(data)
+ {
+ }
+ };
+ // Per-peer connection state: the RC queue pair, in-flight accounting and
+ // the send/recv queues for large-transfer bookkeeping.
+ struct TIBPeer: public IIBPeer {
+ TUdpAddress PeerAddress;
+ TIntrusivePtr<TRCQueuePair> QP;
+ EState State; // CONNECTING -> OK, or FAILED
+ int SendCount; // messages currently in flight, capped by MAX_SEND_COUNT
+ NHPTimer::STime LastRecv; // last activity, used for timeout/keep-alive
+ TDeque<TPendingQueuedSend> PendingSendQueue;
+ // these lists have limited size and potentially just circle buffers
+ TDeque<TQueuedSend> SendQueue;
+ TDeque<TQueuedRecv> RecvQueue;
+ TDeque<TCompleteInfo> OutMsgs; // FIFO, matched against send completions
+
+ TIBPeer(const TUdpAddress& peerAddress, TPtrArg<TRCQueuePair> qp)
+ : PeerAddress(peerAddress)
+ , QP(qp)
+ , State(CONNECTING)
+ , SendCount(0)
+ {
+ NHPTimer::GetTime(&LastRecv);
+ }
+ ~TIBPeer() override {
+ //printf("IBPeer destroyed\n");
+ }
+ EState GetState() override {
+ return State;
+ }
+ // Linear-scan lookups; the queues are small. Y_VERIFY fires on a miss.
+ TDeque<TQueuedSend>::iterator GetSend(const TGUID& packetGuid) {
+ for (TDeque<TQueuedSend>::iterator z = SendQueue.begin(); z != SendQueue.end(); ++z) {
+ if (z->PacketGuid == packetGuid) {
+ return z;
+ }
+ }
+ Y_VERIFY(0, "no send by guid");
+ return SendQueue.begin();
+ }
+ TDeque<TQueuedSend>::iterator GetSend(TIBMsgHandle msgHandle) {
+ for (TDeque<TQueuedSend>::iterator z = SendQueue.begin(); z != SendQueue.end(); ++z) {
+ if (z->MsgHandle == msgHandle) {
+ return z;
+ }
+ }
+ Y_VERIFY(0, "no send by handle");
+ return SendQueue.begin();
+ }
+ TDeque<TQueuedRecv>::iterator GetRecv(const TGUID& packetGuid) {
+ for (TDeque<TQueuedRecv>::iterator z = RecvQueue.begin(); z != RecvQueue.end(); ++z) {
+ if (z->PacketGuid == packetGuid) {
+ return z;
+ }
+ }
+ Y_VERIFY(0, "no recv by guid");
+ return RecvQueue.begin();
+ }
+ // Writes the whole MemBlock into the remote buffer announced via
+ // CMD_BUFFER_READY; completion is tracked as CI_RDMA_COMPLETE.
+ void PostRDMA(TQueuedSend& qs) {
+ Y_ASSERT(qs.RemoteAddr != 0 && qs.MemBlock.Get() != nullptr);
+ QP->PostRDMAWrite(qs.RemoteAddr, qs.RemoteKey,
+ qs.MemBlock->GetMemRegion(), 0, qs.MemBlock->GetData(), qs.MemBlock->GetSize());
+ OutMsgs.push_back(TCompleteInfo(TCompleteInfo::CI_RDMA_COMPLETE, 0, qs.MsgHandle));
+ //printf("Post rdma write, size %d\n", qs.Data->GetSize());
+ }
+ void PostSend(TIBBufferPool& bp, const void* data, size_t len, int t, TIBMsgHandle msgHandle) {
+ int bufId = bp.PostSend(QP, data, len);
+ OutMsgs.push_back(TCompleteInfo(t, bufId, msgHandle));
+ }
+ };
+
+ TIntrusivePtr<TIBPort> Port;
+ TIntrusivePtr<TIBMemPool> MemPool;
+ TIntrusivePtr<TIBMemPool::TCopyResultStorage> CopyResults; // async copy-to-IB-memory results
+ TIntrusivePtr<TComplectionQueue> CQ; // shared CQ for all QPs
+ TIBBufferPool BP;
+ TIntrusivePtr<TUDQueuePair> WelcomeQP; // UD QP handling handshakes
+ int WelcomeQPN;
+ TIBConnectInfo ConnectInfo;
+ TDeque<TIBSendResult> SendResults; // completed sends, drained via GetSendResult()
+ TDeque<TRequest*> ReceivedList; // received requests, drained via GetRequest()
+ typedef THashMap<int, TIntrusivePtr<TIBPeer>> TPeerChannelHash;
+ TPeerChannelHash Channels; // active peers keyed by their RC QPN
+ TIBMsgHandle MsgCounter; // source of unique send handles
+ NHPTimer::STime LastCheckTime; // last keep-alive/timeout sweep
+
+ ~TIBClientServer() override {
+ // Requests never handed out through GetRequest() are still owned here.
+ while (!ReceivedList.empty()) {
+ delete ReceivedList.front();
+ ReceivedList.pop_front();
+ }
+ }
+ // Returns the peer owning the given QPN, or nullptr if unknown/closed.
+ TIBPeer* GetChannelByQPN(int qpn) {
+ TPeerChannelHash::iterator it = Channels.find(qpn);
+ return (it == Channels.end()) ? nullptr : it->second.Get();
+ }
+
+ // IIBClientServer
+ // Pops the oldest received request; caller takes ownership. Null if none.
+ TRequest* GetRequest() override {
+ if (!ReceivedList.empty()) {
+ TRequest* req = ReceivedList.front();
+ ReceivedList.pop_front();
+ return req;
+ }
+ return nullptr;
+ }
+ // Pops the oldest send result into *res; false when none are pending.
+ bool GetSendResult(TIBSendResult* res) override {
+ const bool haveResult = !SendResults.empty();
+ if (haveResult) {
+ *res = SendResults.front();
+ SendResults.pop_front();
+ }
+ return haveResult;
+ }
+ // Begins transmission: tiny payloads go inline as CMD_DATA_TINY; larger
+ // ones start the RDMA protocol (async copy into IB memory + CMD_DATA_INIT
+ // asking the receiver for a target buffer).
+ void StartSend(TPtrArg<TIBPeer> peer, const TGUID& packetGuid, TIBMsgHandle msgHandle, TRopeDataPacket* data) {
+ int sz = data->GetSize();
+ if (sz <= TCmdDataTiny::GetMaxDataSize()) {
+ TCmdDataTiny dataTiny;
+ dataTiny.Header.Command = CMD_DATA_TINY;
+ dataTiny.Header.Size = (ui16)sz;
+ dataTiny.Header.PacketGuid = packetGuid;
+ TBlockChainIterator bc(data->GetChain());
+ bc.Read(dataTiny.Data, sz);
+
+ peer->PostSend(BP, &dataTiny, sizeof(dataTiny.Header) + sz, TCompleteInfo::CI_DATA_TINY, msgHandle);
+ //printf("Send CMD_DATA_TINY\n");
+ } else {
+ // Copy result arrives later via CopyResults and is matched back to
+ // the queued send by msgHandle in Step().
+ MemPool->CopyData(data, msgHandle, peer, CopyResults);
+ peer->SendQueue.push_back(TQueuedSend(packetGuid, msgHandle));
+ {
+ TQueuedSend& msg = peer->SendQueue.back();
+ TCmdDataInit dataInit;
+ dataInit.Command = CMD_DATA_INIT;
+ dataInit.PacketGuid = msg.PacketGuid;
+ dataInit.Size = data->GetSize();
+ peer->PostSend(BP, &dataInit, sizeof(dataInit), TCompleteInfo::CI_IGNORE, 0);
+ //printf("Send CMD_DATA_INIT\n");
+ }
+ }
+ ++peer->SendCount;
+ }
+ // Reports success to the caller and, once the peer drops below the
+ // in-flight cap, kicks off one deferred send.
+ void SendCompleted(TPtrArg<TIBPeer> peer, TIBMsgHandle msgHandle) {
+ SendResults.push_back(TIBSendResult(msgHandle, true));
+ --peer->SendCount;
+ if (peer->SendCount >= MAX_SEND_COUNT || peer->PendingSendQueue.empty()) {
+ return;
+ }
+ TPendingQueuedSend pending = peer->PendingSendQueue.front();
+ peer->PendingSendQueue.pop_front();
+ StartSend(peer, pending.PacketGuid, pending.MsgHandle, pending.Data);
+ }
+ // Reports a failed send to the caller and releases its in-flight slot.
+ void SendFailed(TPtrArg<TIBPeer> peer, TIBMsgHandle msgHandle) {
+ --peer->SendCount;
+ SendResults.push_back(TIBSendResult(msgHandle, false));
+ }
+ // Tears down a peer channel: fails every queued/in-flight send back to the
+ // caller, releases send buffers and removes the peer from Channels.
+ void PeerFailed(TPtrArg<TIBPeer> peer) {
+ //printf("PeerFailed(), peer %p, state %d (%d pending, %d queued, %d out, %d sendcount)\n",
+ // peer.Get(), peer->State,
+ // (int)peer->PendingSendQueue.size(),
+ // (int)peer->SendQueue.size(),
+ // (int)peer->OutMsgs.size(),
+ // peer->SendCount);
+ peer->State = IIBPeer::FAILED;
+ // Pending sends never started, so they only need a failure result.
+ while (!peer->PendingSendQueue.empty()) {
+ TPendingQueuedSend& qs = peer->PendingSendQueue.front();
+ SendResults.push_back(TIBSendResult(qs.MsgHandle, false));
+ peer->PendingSendQueue.pop_front();
+ }
+ // Started sends also hold an in-flight slot; SendFailed releases it.
+ while (!peer->SendQueue.empty()) {
+ TQueuedSend& qs = peer->SendQueue.front();
+ SendFailed(peer, qs.MsgHandle);
+ peer->SendQueue.pop_front();
+ }
+ // Outstanding completions will never arrive; free their buffers now.
+ while (!peer->OutMsgs.empty()) {
+ TCompleteInfo& cc = peer->OutMsgs.front();
+ //printf("Don't wait completion for sent packet (QPN %d), bufId %d\n", peer->QP->GetQPN(), cc.BufId);
+ if (cc.Type == TCompleteInfo::CI_DATA_TINY) {
+ SendFailed(peer, cc.MsgHandle);
+ }
+ BP.FreeBuf(cc.BufId);
+ peer->OutMsgs.pop_front();
+ }
+ {
+ Y_ASSERT(peer->SendCount == 0);
+ //printf("Remove peer %p from hash (QPN %d)\n", peer.Get(), peer->QP->GetQPN());
+ TPeerChannelHash::iterator z = Channels.find(peer->QP->GetQPN());
+ if (z == Channels.end()) {
+ Y_VERIFY(0, "peer failed for unregistered peer");
+ }
+ Channels.erase(z);
+ }
+ }
+ // Queues a message to the peer; returns its handle, or -1 when the peer is
+ // missing or not in the OK state.
+ TIBMsgHandle Send(TPtrArg<IIBPeer> peerArg, TRopeDataPacket* data, const TGUID& packetGuid) override {
+ // Downcast is safe: every peer handed out by this server is a TIBPeer.
+ TIBPeer* peer = static_cast<TIBPeer*>(peerArg.Get());
+ if (!peer || peer->State != IIBPeer::OK) {
+ return -1;
+ }
+ Y_ASSERT(Channels.find(peer->QP->GetQPN())->second == peer);
+ const TIBMsgHandle msgHandle = ++MsgCounter;
+ if (peer->SendCount < MAX_SEND_COUNT) {
+ StartSend(peer, packetGuid, msgHandle, data);
+ } else {
+ // Over the in-flight cap: defer until SendCompleted frees a slot.
+ peer->PendingSendQueue.push_back(TPendingQueuedSend(packetGuid, msgHandle, data));
+ }
+ return msgHandle;
+ }
+ // Handles one receive completion on an RC channel: dispatches on the
+ // command byte and drives the tiny/large transfer protocol. Failed
+ // completions tear the peer down; packets for unknown QPNs are dropped.
+ void ParsePacket(ibv_wc* wc, NHPTimer::STime tCurrent) {
+ if (wc->status != IBV_WC_SUCCESS) {
+ TIBPeer* peer = GetChannelByQPN(wc->qp_num);
+ if (peer) {
+ //printf("failed recv packet (status %d)\n", wc->status);
+ PeerFailed(peer);
+ } else {
+ //printf("Ignoring recv error for closed/non existing QPN %d\n", wc->qp_num);
+ }
+ return;
+ }
+
+ // RAII: frees the buffer and reposts a receive when it goes out of scope.
+ TIBRecvPacketProcess pkt(BP, *wc);
+
+ TIBPeer* peer = GetChannelByQPN(wc->qp_num);
+ if (peer) {
+ Y_ASSERT(peer->State != IIBPeer::FAILED);
+ peer->LastRecv = tCurrent;
+ char cmdId = *(const char*)pkt.GetData();
+ switch (cmdId) {
+ case CMD_CONFIRM:
+ // Handshake finished from the acceptor's point of view.
+ //printf("got confirm\n");
+ Y_ASSERT(peer->State == IIBPeer::CONNECTING);
+ peer->State = IIBPeer::OK;
+ break;
+ case CMD_DATA_TINY:
+ // Whole payload is inline; hand it straight to ReceivedList.
+ //printf("Recv CMD_DATA_TINY\n");
+ {
+ const TCmdDataTiny& dataTiny = *(TCmdDataTiny*)pkt.GetData();
+ TRequest* req = new TRequest;
+ req->Address = peer->PeerAddress;
+ req->Guid = dataTiny.Header.PacketGuid;
+ req->Data = new TRopeDataPacket;
+ req->Data->Write(dataTiny.Data, dataTiny.Header.Size);
+ ReceivedList.push_back(req);
+ }
+ break;
+ case CMD_DATA_INIT:
+ // Large transfer announced: allocate a buffer and publish its
+ // address + rkey back to the sender.
+ //printf("Recv CMD_DATA_INIT\n");
+ {
+ const TCmdDataInit& data = *(TCmdDataInit*)pkt.GetData();
+ TIntrusivePtr<TIBMemBlock> blk = MemPool->Alloc(data.Size);
+ peer->RecvQueue.push_back(TQueuedRecv(data.PacketGuid, blk));
+ TCmdBufferReady ready;
+ ready.Command = CMD_BUFFER_READY;
+ ready.PacketGuid = data.PacketGuid;
+ ready.RemoteAddr = blk->GetData() - (char*)nullptr;
+ ready.RemoteKey = blk->GetMemRegion()->GetRKey();
+
+ peer->PostSend(BP, &ready, sizeof(ready), TCompleteInfo::CI_IGNORE, 0);
+ //printf("Send CMD_BUFFER_READY\n");
+ }
+ break;
+ case CMD_BUFFER_READY:
+ // Remote target known; start the RDMA write only if the local
+ // copy has already finished (otherwise Step() will start it).
+ //printf("Recv CMD_BUFFER_READY\n");
+ {
+ const TCmdBufferReady& ready = *(TCmdBufferReady*)pkt.GetData();
+ TDeque<TQueuedSend>::iterator z = peer->GetSend(ready.PacketGuid);
+ TQueuedSend& qs = *z;
+ qs.RemoteAddr = ready.RemoteAddr;
+ qs.RemoteKey = ready.RemoteKey;
+ if (qs.MemBlock.Get()) {
+ peer->PostRDMA(qs);
+ }
+ }
+ break;
+ case CMD_DATA_COMPLETE:
+ // RDMA write landed; verify hash (debug only) and surface the
+ // request to the caller.
+ //printf("Recv CMD_DATA_COMPLETE\n");
+ {
+ const TCmdDataComplete& cmd = *(TCmdDataComplete*)pkt.GetData();
+ TDeque<TQueuedRecv>::iterator z = peer->GetRecv(cmd.PacketGuid);
+ TQueuedRecv& qr = *z;
+#ifdef _DEBUG
+ Y_VERIFY(MurmurHash<ui64>(qr.Data->GetData(), qr.Data->GetSize()) == cmd.DataHash || cmd.DataHash == 0, "RDMA data hash mismatch");
+#endif
+ TRequest* req = new TRequest;
+ req->Address = peer->PeerAddress;
+ req->Guid = qr.PacketGuid;
+ req->Data = new TRopeDataPacket;
+ req->Data->AddBlock(qr.Data.Get(), qr.Data->GetData(), qr.Data->GetSize());
+ ReceivedList.push_back(req);
+ peer->RecvQueue.erase(z);
+ }
+ break;
+ case CMD_KEEP_ALIVE:
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ } else {
+ // can get here
+ //printf("Ignoring packet for closed/non existing QPN %d\n", wc->qp_num);
+ }
+ }
+ // Handles one send-side completion on an RC channel. Completions are
+ // matched FIFO against the peer's OutMsgs queue and advance the tiny/large
+ // transfer state machines.
+ void OnComplete(ibv_wc* wc, NHPTimer::STime tCurrent) {
+ TIBPeer* peer = GetChannelByQPN(wc->qp_num);
+ if (peer) {
+ if (!peer->OutMsgs.empty()) {
+ peer->LastRecv = tCurrent;
+ if (wc->status != IBV_WC_SUCCESS) {
+ //printf("completed with status %d\n", wc->status);
+ PeerFailed(peer);
+ } else {
+ const TCompleteInfo& cc = peer->OutMsgs.front();
+ switch (cc.Type) {
+ case TCompleteInfo::CI_DATA_TINY:
+ //printf("Completed data_tiny\n");
+ SendCompleted(peer, cc.MsgHandle);
+ break;
+ case TCompleteInfo::CI_RDMA_COMPLETE:
+ // RDMA write done: tell the receiver via CMD_DATA_COMPLETE.
+ //printf("Completed rdma_complete\n");
+ {
+ TDeque<TQueuedSend>::iterator z = peer->GetSend(cc.MsgHandle);
+ TQueuedSend& qs = *z;
+
+ TCmdDataComplete complete;
+ complete.Command = CMD_DATA_COMPLETE;
+ complete.PacketGuid = qs.PacketGuid;
+#ifdef _DEBUG
+ complete.DataHash = MurmurHash<ui64>(qs.MemBlock->GetData(), qs.MemBlock->GetSize());
+#else
+ complete.DataHash = 0;
+#endif
+
+ peer->PostSend(BP, &complete, sizeof(complete), TCompleteInfo::CI_DATA_SENT, qs.MsgHandle);
+ //printf("Send CMD_DATA_COMPLETE\n");
+ }
+ break;
+ case TCompleteInfo::CI_DATA_SENT:
+ // CMD_DATA_COMPLETE delivered: the large send is finished.
+ //printf("Completed data_sent\n");
+ {
+ TDeque<TQueuedSend>::iterator z = peer->GetSend(cc.MsgHandle);
+ TIBMsgHandle msgHandle = z->MsgHandle;
+ peer->SendQueue.erase(z);
+ SendCompleted(peer, msgHandle);
+ }
+ break;
+ case TCompleteInfo::CI_KEEP_ALIVE:
+ break;
+ case TCompleteInfo::CI_IGNORE:
+ //printf("Completed ignored\n");
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ peer->OutMsgs.pop_front();
+ BP.FreeBuf(wc->wr_id);
+ }
+ } else {
+ Y_VERIFY(0, "got completion without outstanding messages");
+ }
+ } else {
+ //printf("Got completion for non existing qpn %d, bufId %d (status %d)\n", wc->qp_num, (int)wc->wr_id, (int)wc->status);
+ if (wc->status == IBV_WC_SUCCESS) {
+ Y_VERIFY(0, "only errors should go unmatched");
+ }
+ // no need to free buf since it has to be freed in PeerFailed()
+ }
+ }
+ // Handles a packet on the welcome UD QP: either an incoming handshake
+ // (create an RC QP, register the peer, reply with an ack) or an ack for a
+ // handshake we initiated (finish RC setup and confirm the channel).
+ void ParseWelcomePacket(ibv_wc* wc) {
+ TIBRecvPacketProcess pkt(BP, *wc);
+
+ char cmdId = *(const char*)pkt.GetUDData();
+ switch (cmdId) {
+ case CMD_HANDSHAKE: {
+ //printf("got handshake\n");
+ const TCmdHandshake& handshake = *(TCmdHandshake*)pkt.GetUDData();
+ if (handshake.SocketId != ConnectInfo.SocketId) {
+ // connection attempt from wrong IB subnet
+ break;
+ }
+ TIntrusivePtr<TRCQueuePair> rcQP;
+ rcQP = new TRCQueuePair(Port->GetCtx(), CQ, BP.GetSRQ(), QP_SEND_QUEUE_SIZE);
+
+ int qpn = rcQP->GetQPN();
+ Y_ASSERT(Channels.find(qpn) == Channels.end());
+ TIntrusivePtr<TIBPeer>& peer = Channels[qpn];
+ peer = new TIBPeer(handshake.MyAddress, rcQP);
+
+ // Derive the peer's address from the received packet's GRH/wc.
+ ibv_ah_attr peerAddr;
+ TIntrusivePtr<TAddressHandle> ahPeer;
+ Port->GetAHAttr(wc, pkt.GetGRH(), &peerAddr);
+ ahPeer = new TAddressHandle(Port->GetCtx(), &peerAddr);
+
+ peerAddr.sl = TRAFFIC_SL;
+ rcQP->Init(peerAddr, handshake.QPN, handshake.PSN);
+
+ TCmdHandshakeAck handshakeAck;
+ handshakeAck.Command = CMD_HANDSHAKE_ACK;
+ handshakeAck.PSN = rcQP->GetPSN();
+ handshakeAck.QPN = rcQP->GetQPN();
+ handshakeAck.YourQPN = handshake.QPN;
+ // if ack gets lost we'll create new Peer Channel
+ // and this one will be erased in Step() by timeout counted from LastRecv
+ BP.PostSend(WelcomeQP, ahPeer, wc->src_qp, WELCOME_QKEY, &handshakeAck, sizeof(handshakeAck));
+ //printf("send handshake_ack\n");
+ } break;
+ case CMD_HANDSHAKE_ACK: {
+ //printf("got handshake_ack\n");
+ const TCmdHandshakeAck& handshakeAck = *(TCmdHandshakeAck*)pkt.GetUDData();
+ // YourQPN echoes our QPN, letting us find the connecting peer.
+ TIBPeer* peer = GetChannelByQPN(handshakeAck.YourQPN);
+ if (peer) {
+ ibv_ah_attr peerAddr;
+ Port->GetAHAttr(wc, pkt.GetGRH(), &peerAddr);
+
+ peerAddr.sl = TRAFFIC_SL;
+ peer->QP->Init(peerAddr, handshakeAck.QPN, handshakeAck.PSN);
+
+ peer->State = IIBPeer::OK;
+
+ TCmdConfirm confirm;
+ confirm.Command = CMD_CONFIRM;
+ peer->PostSend(BP, &confirm, sizeof(confirm), TCompleteInfo::CI_IGNORE, 0);
+ //printf("send confirm\n");
+ } else {
+ // respective QPN was deleted or never existed
+ // silently ignore and peer channel on remote side
+ // will not get into confirmed state and will be deleted
+ }
+ } break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ }
+ // Drives all progress for the IB client/server: drains the completion
+ // queue, dispatches received packets, finalizes memcpy jobs from the
+ // copy pipeline, and runs periodic keep-alive / connect-timeout checks.
+ // Returns true if any work was performed (caller uses this to decide
+ // whether to sleep).
+ bool Step(NHPTimer::STime tCurrent) override {
+ bool rv = false;
+ // only have to process completions, everything is done on completion of something
+ ibv_wc wcArr[10];
+ for (;;) {
+ int wcCount = CQ->Poll(wcArr, Y_ARRAY_SIZE(wcArr));
+ if (wcCount == 0) {
+ break;
+ }
+ rv = true;
+ for (int z = 0; z < wcCount; ++z) {
+ ibv_wc& wc = wcArr[z];
+ if (wc.opcode & IBV_WC_RECV) {
+ // received msg
+ if ((int)wc.qp_num == WelcomeQPN) {
+ if (wc.status != IBV_WC_SUCCESS) {
+ Y_VERIFY(0, "ud recv op completed with error %d\n", (int)wc.status);
+ }
+ // BUGFIX: was `wc.opcode == IBV_WC_RECV | IBV_WC_SEND`, which due to
+ // operator precedence parses as `(wc.opcode == IBV_WC_RECV) | IBV_WC_SEND`
+ // and is always true. Parenthesized to express the intended check.
+ Y_ASSERT(wc.opcode == (IBV_WC_RECV | IBV_WC_SEND));
+ ParseWelcomePacket(&wc);
+ } else {
+ ParsePacket(&wc, tCurrent);
+ }
+ } else {
+ // send completion
+ if ((int)wc.qp_num == WelcomeQPN) {
+ // ok
+ BP.FreeBuf(wc.wr_id);
+ } else {
+ OnComplete(&wc, tCurrent);
+ }
+ }
+ }
+ }
+ {
+ // Completed background copies: attach the registered memory block to the
+ // queued send and launch the RDMA write if the remote address is known.
+ TIntrusivePtr<TIBMemBlock> memBlock;
+ i64 msgHandle;
+ TIntrusivePtr<TIBPeer> peer;
+ while (CopyResults->GetCopyResult(&memBlock, &msgHandle, &peer)) {
+ if (peer->GetState() != IIBPeer::OK) {
+ continue;
+ }
+ TDeque<TQueuedSend>::iterator z = peer->GetSend(msgHandle);
+ if (z == peer->SendQueue.end()) {
+ Y_VERIFY(0, "peer %p, copy completed, msg %d not found?\n", peer.Get(), (int)msgHandle);
+ continue;
+ }
+ TQueuedSend& qs = *z;
+ qs.MemBlock = memBlock;
+ if (qs.RemoteAddr != 0) {
+ peer->PostRDMA(qs);
+ }
+ rv = true;
+ }
+ }
+ {
+ // Periodic liveness pass: drop peers stuck in handshake, keep-alive idle ones.
+ NHPTimer::STime t1 = LastCheckTime;
+ if (NHPTimer::GetTimePassed(&t1) > CHANNEL_CHECK_INTERVAL) {
+ for (TPeerChannelHash::iterator z = Channels.begin(); z != Channels.end();) {
+ TIntrusivePtr<TIBPeer> peer = z->second;
+ ++z; // peer can be removed from Channels
+ Y_ASSERT(peer->State != IIBPeer::FAILED);
+ NHPTimer::STime t2 = peer->LastRecv;
+ double timeSinceLastRecv = NHPTimer::GetTimePassed(&t2);
+ if (timeSinceLastRecv > CHANNEL_CHECK_INTERVAL) {
+ if (peer->State == IIBPeer::CONNECTING) {
+ Y_ASSERT(peer->OutMsgs.empty() && peer->SendCount == 0);
+ // if handshake does not seem to work out - close connection
+ //printf("IB connecting timed out\n");
+ PeerFailed(peer);
+ } else {
+ // if we have outmsg we hope that IB will report us if there are any problems
+ // with connectivity
+ if (peer->OutMsgs.empty()) {
+ //printf("Sending keep alive\n");
+ TCmdKeepAlive keep;
+ keep.Command = CMD_KEEP_ALIVE;
+ peer->PostSend(BP, &keep, sizeof(keep), TCompleteInfo::CI_KEEP_ALIVE, 0);
+ }
+ }
+ }
+ }
+ LastCheckTime = t1;
+ }
+ }
+ return rv;
+ }
+ // Establishes (or reuses) an RC channel to the peer described by `info`.
+ // Returns an existing channel if one is already keyed by `peerAddr`;
+ // otherwise creates a new RC queue pair, registers it in Channels (state
+ // CONNECTING) and sends a CMD_HANDSHAKE over the UD welcome QP.
+ // Returns nullptr only when no route to the peer can be resolved.
+ IIBPeer* ConnectPeer(const TIBConnectInfo& info, const TUdpAddress& peerAddr, const TUdpAddress& myAddr) override {
+ for (auto& channel : Channels) {
+ TIntrusivePtr<TIBPeer> peer = channel.second;
+ if (peer->PeerAddress == peerAddr) {
+ return peer.Get();
+ }
+ }
+ TIntrusivePtr<TRCQueuePair> rcQP;
+ rcQP = new TRCQueuePair(Port->GetCtx(), CQ, BP.GetSRQ(), QP_SEND_QUEUE_SIZE);
+
+ // Channels are keyed by our local QP number; QPNs are unique per device.
+ int qpn = rcQP->GetQPN();
+ Y_ASSERT(Channels.find(qpn) == Channels.end());
+ TIntrusivePtr<TIBPeer>& peer = Channels[qpn];
+ peer = new TIBPeer(peerAddr, rcQP);
+
+ TCmdHandshake handshake;
+ handshake.Command = CMD_HANDSHAKE;
+ handshake.PSN = rcQP->GetPSN();
+ handshake.QPN = rcQP->GetQPN();
+ handshake.SocketId = info.SocketId;
+ handshake.MyAddress = myAddr;
+
+ // LID != 0 means native IB fabric; LID == 0 means RoCE, where the route
+ // is resolved from the GID pair instead.
+ TIntrusivePtr<TAddressHandle> serverAH;
+ if (info.LID != 0) {
+ serverAH = new TAddressHandle(Port, info.LID, CONNECT_SL);
+ } else {
+ //ibv_gid addr;
+ //addr.global.subnet_prefix = info.Subnet;
+ //addr.global.interface_id = info.Interface;
+ //serverAH = new TAddressHandle(Port, addr, CONNECT_SL);
+
+ TUdpAddress local = myAddr;
+ local.Port = 0;
+ TUdpAddress remote = peerAddr;
+ remote.Port = 0;
+ //printf("local Addr %s\n", GetAddressAsString(local).c_str());
+ //printf("remote Addr %s\n", GetAddressAsString(remote).c_str());
+ // CRAP - somehow prevent connecting machines from different RoCE isles
+ serverAH = new TAddressHandle(Port, remote, local, CONNECT_SL);
+ if (!serverAH->IsValid()) {
+ // NOTE(review): the CONNECTING peer stays in Channels here; it is
+ // presumably reaped later by the Step() handshake timeout - confirm.
+ return nullptr;
+ }
+ }
+ BP.PostSend(WelcomeQP, serverAH, info.QPN, WELCOME_QKEY, &handshake, sizeof(handshake));
+ //printf("send handshake\n");
+
+ return peer.Get();
+ }
+ // Returns the local connection credentials (socket id, GID, LID, welcome QPN)
+ // that remote peers pass to ConnectPeer().
+ const TIBConnectInfo& GetConnectInfo() override {
+ return ConnectInfo;
+ }
+
+ public:
+ // Builds the full IB endpoint on the given port: one completion queue
+ // shared by all QPs, a shared receive queue (inside BP), and the UD
+ // "welcome" QP used for out-of-band handshakes.
+ TIBClientServer(TPtrArg<TIBPort> port)
+ : Port(port)
+ , MemPool(GetIBMemPool())
+ , CQ(new TComplectionQueue(port->GetCtx(), MAX_CQ_EVENTS))
+ , BP(port->GetCtx(), MAX_SRQ_WORK_REQUESTS)
+ , WelcomeQP(new TUDQueuePair(port, CQ, BP.GetSRQ(), WELCOME_QP_SEND_SIZE))
+ , WelcomeQPN(WelcomeQP->GetQPN())
+ , MsgCounter(1)
+ {
+ CopyResults = new TIBMemPool::TCopyResultStorage;
+ // Advertised credentials: random socket id + this port's primary GID/LID
+ // and the welcome QP number.
+ CreateGuid(&ConnectInfo.SocketId);
+ ibv_gid addr;
+ port->GetGID(&addr);
+ ConnectInfo.Interface = addr.global.interface_id;
+ ConnectInfo.Subnet = addr.global.subnet_prefix;
+ //printf("connect addr subnet %lx, iface %lx\n", addr.global.subnet_prefix, addr.global.interface_id);
+ ConnectInfo.LID = port->GetLID();
+ ConnectInfo.QPN = WelcomeQPN;
+
+ WelcomeQP->Init(WELCOME_QKEY);
+
+ NHPTimer::GetTime(&LastCheckTime);
+ }
+ };
+
+ // Factory: returns a new IB client/server bound to the first usable IB port,
+ // or nullptr when the host has no (usable) InfiniBand device.
+ IIBClientServer* CreateIBClientServer() {
+ TIntrusivePtr<TIBPort> port = GetIBDevice();
+ if (port.Get() == nullptr) {
+ return nullptr;
+ }
+ return new TIBClientServer(port);
+ }
+}
diff --git a/library/cpp/netliba/v6/ib_cs.h b/library/cpp/netliba/v6/ib_cs.h
new file mode 100644
index 0000000000..932f2880dc
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_cs.h
@@ -0,0 +1,57 @@
+#pragma once
+
+#include "udp_address.h"
+#include "block_chain.h"
+#include "net_request.h"
+#include <util/generic/guid.h>
+#include <util/system/hp_timer.h>
+
+namespace NNetliba {
+ // Credentials a server advertises so clients can reach its UD welcome QP:
+ // random socket id, RoCE GID halves, IB LID (0 on RoCE) and the QP number.
+ struct TIBConnectInfo {
+ TGUID SocketId;
+ ui64 Subnet, Interface;
+ int LID;
+ int QPN;
+ };
+
+ // QP number + initial packet sequence number exchanged to bring up an RC pair.
+ struct TRCQueuePairHandshake {
+ int QPN, PSN;
+ };
+
+ using TIBMsgHandle = i64;
+
+ // Outcome of an asynchronous Send(): which message completed and whether
+ // it was delivered successfully.
+ struct TIBSendResult {
+ TIBMsgHandle Handle = 0;
+ bool Success = false;
+
+ TIBSendResult() = default;
+ TIBSendResult(TIBMsgHandle handle, bool success)
+ : Handle(handle)
+ , Success(success)
+ {
+ }
+ };
+
+ // Opaque handle to a remote IB peer channel. CONNECTING while the
+ // handshake is in flight, OK once confirmed, FAILED after timeout/error.
+ struct IIBPeer: public TThrRefBase {
+ enum EState {
+ CONNECTING,
+ OK,
+ FAILED,
+ };
+ virtual EState GetState() = 0;
+ };
+
+ // Combined IB client/server endpoint. Step() must be called regularly to
+ // make progress; Send()/GetSendResult() are asynchronous.
+ struct IIBClientServer: public TThrRefBase {
+ virtual TRequest* GetRequest() = 0;
+ virtual TIBMsgHandle Send(TPtrArg<IIBPeer> peer, TRopeDataPacket* data, const TGUID& packetGuid) = 0;
+ virtual bool GetSendResult(TIBSendResult* res) = 0;
+ virtual bool Step(NHPTimer::STime tCurrent) = 0;
+ virtual IIBPeer* ConnectPeer(const TIBConnectInfo& info, const TUdpAddress& peerAddr, const TUdpAddress& myAddr) = 0;
+ virtual const TIBConnectInfo& GetConnectInfo() = 0;
+ };
+
+ IIBClientServer* CreateIBClientServer();
+}
diff --git a/library/cpp/netliba/v6/ib_low.cpp b/library/cpp/netliba/v6/ib_low.cpp
new file mode 100644
index 0000000000..455a5f3512
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_low.cpp
@@ -0,0 +1,114 @@
+#include "stdafx.h"
+#include "ib_low.h"
+
+namespace NNetliba {
+ static bool EnableROCEFlag = false;
+
+ // Globally allows binding to RoCE ports (lid == 0) in GetIBDevice().
+ // NOTE(review): takes effect only if called before the first GetIBDevice()
+ // call - the chosen port is cached; confirm callers honor this ordering.
+ void EnableROCE(bool f) {
+ EnableROCEFlag = f;
+ }
+
+#if defined _linux_ && !defined CATBOOST_OPENSOURCE
+ static TMutex IBPortMutex;
+ static TIntrusivePtr<TIBPort> IBPort;
+ static bool IBWasInitialized;
+
+ // Lazily opens the single IB device on the host and returns its first
+ // active port (RoCE ports only when EnableROCE(true) was called).
+ // The result - possibly nullptr - is computed once and cached.
+ TIntrusivePtr<TIBPort> GetIBDevice() {
+ TGuard<TMutex> gg(IBPortMutex);
+ if (IBWasInitialized) {
+ return IBPort;
+ }
+ IBWasInitialized = true;
+
+ try {
+ int rv = ibv_fork_init();
+ if (rv != 0) {
+ //printf("ibv_fork_init() failed");
+ return nullptr;
+ }
+ } catch (...) {
+ //we can not load ib interface, so no ib
+ return nullptr;
+ }
+
+ TIntrusivePtr<TIBContext> ctx;
+ TIntrusivePtr<TIBPort> resPort;
+ int numDevices;
+ ibv_device** deviceList = ibv_get_device_list(&numDevices);
+ // BUGFIX: ibv_get_device_list() returns nullptr on failure; the previous
+ // code then read an indeterminate numDevices and passed nullptr to
+ // ibv_free_device_list().
+ if (deviceList == nullptr) {
+ return nullptr;
+ }
+ //for (int i = 0; i < numDevices; ++i) {
+ // ibv_device *dev = deviceList[i];
+
+ // printf("Dev %d\n", i);
+ // printf("name:%s\ndev_name:%s\ndev_path:%s\nibdev_path:%s\n",
+ // dev->name,
+ // dev->dev_name,
+ // dev->dev_path,
+ // dev->ibdev_path);
+ // printf("get_device_name(): %s\n", ibv_get_device_name(dev));
+ // ui64 devGuid = ibv_get_device_guid(dev);
+ // printf("ibv_get_device_guid: %" PRIx64 "\n", devGuid);
+ // printf("node type: %s\n", ibv_node_type_str(dev->node_type));
+ // printf("\n");
+ //}
+ if (numDevices == 1) {
+ ctx = new TIBContext(deviceList[0]);
+ TIBContext::TLock ibContext(ctx);
+ ibv_device_attr devAttrs;
+ CHECK_Z(ibv_query_device(ibContext.GetContext(), &devAttrs));
+
+ for (int port = 1; port <= devAttrs.phys_port_cnt; ++port) {
+ ibv_port_attr portAttrs;
+ CHECK_Z(ibv_query_port(ibContext.GetContext(), port, &portAttrs));
+ //ibv_gid myAddress; // ipv6 address of this port;
+ //CHECK_Z(ibv_query_gid(ibContext.GetContext(), port, 0, &myAddress));
+ //{
+ // ibv_gid p = myAddress;
+ // for (int k = 0; k < 4; ++k) {
+ // DoSwap(p.raw[k], p.raw[7 - k]);
+ // DoSwap(p.raw[8 + k], p.raw[15 - k]);
+ // }
+ // printf("Port %d, address %" PRIx64 ":%" PRIx64 "\n",
+ // port,
+ // p.global.subnet_prefix,
+ // p.global.interface_id);
+ //}
+
+ // skip ROCE if flag is not set
+ if (portAttrs.lid == 0 && EnableROCEFlag == false) {
+ continue;
+ }
+ // bind to first active port
+ if (portAttrs.state == IBV_PORT_ACTIVE) {
+ resPort = new TIBPort(ctx, port);
+ break;
+ }
+ }
+ } else {
+ //printf("%d IB devices found, fail\n", numDevices);
+ ctx = nullptr;
+ }
+ ibv_free_device_list(deviceList);
+ IBPort = resPort;
+ return IBPort;
+ }
+
+ // Fills a RoCE (global-routing) address-handle attribute: destination GID is
+ // built from the remote address, the source GID index is resolved from the
+ // local address against the port's GID table.
+ void MakeAH(ibv_ah_attr* res, TPtrArg<TIBPort> port, const TUdpAddress& remoteAddr, const TUdpAddress& localAddr, int serviceLevel) {
+ ibv_gid localGid, remoteGid;
+ localGid.global.subnet_prefix = localAddr.Network;
+ localGid.global.interface_id = localAddr.Interface;
+ remoteGid.global.subnet_prefix = remoteAddr.Network;
+ remoteGid.global.interface_id = remoteAddr.Interface;
+
+ Zero(*res);
+ res->is_global = 1;
+ res->port_num = port->GetPort();
+ res->sl = serviceLevel;
+ res->grh.dgid = remoteGid;
+ //res->grh.flow_label = 0;
+ res->grh.sgid_index = port->GetGIDIndex(localGid);
+ res->grh.hop_limit = 7;
+ //res->grh.traffic_class = 0;
+ }
+
+#endif
+}
diff --git a/library/cpp/netliba/v6/ib_low.h b/library/cpp/netliba/v6/ib_low.h
new file mode 100644
index 0000000000..b2a3e341d2
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_low.h
@@ -0,0 +1,797 @@
+#pragma once
+
+#include "udp_address.h"
+
+#if defined(_linux_) && !defined(CATBOOST_OPENSOURCE)
+#include <contrib/libs/ibdrv/include/infiniband/verbs.h>
+#include <contrib/libs/ibdrv/include/rdma/rdma_cma.h>
+#endif
+
+namespace NNetliba {
+#define CHECK_Z(x) \
+ { \
+ int rv = (x); \
+ if (rv != 0) { \
+ fprintf(stderr, "check_z failed, errno = %d\n", errno); \
+ Y_VERIFY(0, "check_z"); \
+ } \
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ const int MAX_SGE = 1;
+ const size_t MAX_INLINE_DATA_SIZE = 16;
+ const int MAX_OUTSTANDING_RDMA = 10;
+
+#if defined(_linux_) && !defined(CATBOOST_OPENSOURCE)
+ // RAII wrapper around an opened verbs device context and its protection
+ // domain. All verbs calls that need the context go through TLock, which
+ // serializes access via the internal mutex.
+ class TIBContext: public TThrRefBase, TNonCopyable {
+ ibv_context* Context;
+ ibv_pd* ProtDomain;
+ TMutex Lock;
+
+ ~TIBContext() override {
+ // BUGFIX: guard ProtDomain separately - ibv_alloc_pd() can fail,
+ // leaving Context valid but ProtDomain null; the old code then
+ // called ibv_dealloc_pd(nullptr).
+ if (ProtDomain) {
+ CHECK_Z(ibv_dealloc_pd(ProtDomain));
+ }
+ if (Context) {
+ CHECK_Z(ibv_close_device(Context));
+ }
+ }
+
+ public:
+ TIBContext(ibv_device* device)
+ : Context(nullptr)
+ , ProtDomain(nullptr)
+ {
+ // BUGFIX: members are now null-initialized; previously ProtDomain
+ // stayed uninitialized when ibv_open_device() failed, so IsValid()
+ // read an indeterminate pointer.
+ Context = ibv_open_device(device);
+ if (Context) {
+ ProtDomain = ibv_alloc_pd(Context);
+ }
+ }
+ bool IsValid() const {
+ return Context != nullptr && ProtDomain != nullptr;
+ }
+
+ // Scoped, mutex-guarded access to the raw context/protection domain.
+ class TLock {
+ TIntrusivePtr<TIBContext> Ptr;
+ TGuard<TMutex> Guard;
+
+ public:
+ TLock(TPtrArg<TIBContext> ctx)
+ : Ptr(ctx)
+ , Guard(ctx->Lock)
+ {
+ }
+ ibv_context* GetContext() {
+ return Ptr->Context;
+ }
+ ibv_pd* GetProtDomain() {
+ return Ptr->ProtDomain;
+ }
+ };
+ };
+
+ // One physical port of an IB device: caches its LID and the first MAX_GID
+ // entries of the GID table for address resolution.
+ class TIBPort: public TThrRefBase, TNonCopyable {
+ int Port;
+ int LID;
+ TIntrusivePtr<TIBContext> IBCtx;
+ enum {
+ MAX_GID = 16
+ };
+ ibv_gid MyGidArr[MAX_GID];
+
+ public:
+ TIBPort(TPtrArg<TIBContext> ctx, int port)
+ : IBCtx(ctx)
+ {
+ ibv_port_attr portAttrs;
+ TIBContext::TLock ibContext(IBCtx);
+ CHECK_Z(ibv_query_port(ibContext.GetContext(), port, &portAttrs));
+ Port = port;
+ LID = portAttrs.lid;
+ // Failures are ignored on purpose: entries are zeroed first, so an
+ // unset GID slot stays all-zero.
+ for (int i = 0; i < MAX_GID; ++i) {
+ ibv_gid& dst = MyGidArr[i];
+ Zero(dst);
+ ibv_query_gid(ibContext.GetContext(), Port, i, &dst);
+ }
+ }
+ int GetPort() const {
+ return Port;
+ }
+ int GetLID() const {
+ return LID;
+ }
+ TIBContext* GetCtx() {
+ return IBCtx.Get();
+ }
+ // Primary (index 0) GID of the port.
+ void GetGID(ibv_gid* res) const {
+ *res = MyGidArr[0];
+ }
+ // Index of `arg` in the cached GID table; falls back to 0 when not found.
+ int GetGIDIndex(const ibv_gid& arg) const {
+ for (int i = 0; i < MAX_GID; ++i) {
+ const ibv_gid& chk = MyGidArr[i];
+ if (memcmp(&chk, &arg, sizeof(chk)) == 0) {
+ return i;
+ }
+ }
+ return 0;
+ }
+ // Builds address-handle attributes for replying to the sender of `wc`/`grh`.
+ void GetAHAttr(ibv_wc* wc, ibv_grh* grh, ibv_ah_attr* res) {
+ TIBContext::TLock ibContext(IBCtx);
+ CHECK_Z(ibv_init_ah_from_wc(ibContext.GetContext(), Port, wc, grh, res));
+ }
+ };
+
+ // RAII wrapper around an ibv completion queue.
+ // (Name keeps the historical "Complection" typo - it is part of the public API.)
+ class TComplectionQueue: public TThrRefBase, TNonCopyable {
+ ibv_cq* CQ;
+ TIntrusivePtr<TIBContext> IBCtx;
+
+ ~TComplectionQueue() override {
+ if (CQ) {
+ CHECK_Z(ibv_destroy_cq(CQ));
+ }
+ }
+
+ public:
+ TComplectionQueue(TPtrArg<TIBContext> ctx, int maxCQEcount)
+ : IBCtx(ctx)
+ {
+ TIBContext::TLock ibContext(IBCtx);
+ /* ibv_cq_init_attr_ex attr;
+ Zero(attr);
+ attr.cqe = maxCQEcount;
+ attr.cq_create_flags = 0;
+ ibv_cq_ex *vcq = ibv_create_cq_ex(ibContext.GetContext(), &attr);
+ if (vcq) {
+ CQ = (ibv_cq*)vcq; // doubtful trick but that's life
+ } else {*/
+ // no completion channel
+ // no completion vector
+ CQ = ibv_create_cq(ibContext.GetContext(), maxCQEcount, nullptr, nullptr, 0);
+ // }
+ }
+ ibv_cq* GetCQ() {
+ return CQ;
+ }
+ // Non-blocking poll: writes up to bufSize completions into res and returns
+ // how many were retrieved (0 when the queue is empty); aborts on poll error.
+ int Poll(ibv_wc* res, int bufSize) {
+ Y_ASSERT(bufSize >= 1);
+ //struct ibv_wc
+ //{
+ // ui64 wr_id;
+ // enum ibv_wc_status status;
+ // enum ibv_wc_opcode opcode;
+ // ui32 vendor_err;
+ // ui32 byte_len;
+ // ui32 imm_data;/* network byte order */
+ // ui32 qp_num;
+ // ui32 src_qp;
+ // enum ibv_wc_flags wc_flags;
+ // ui16 pkey_index;
+ // ui16 slid;
+ // ui8 sl;
+ // ui8 dlid_path_bits;
+ //};
+ int rv = ibv_poll_cq(CQ, bufSize, res);
+ if (rv < 0) {
+ Y_VERIFY(0, "ibv_poll_cq failed");
+ }
+ if (rv > 0) {
+ //printf("Completed wr\n");
+ //printf(" wr_id = %" PRIx64 "\n", wc.wr_id);
+ //printf(" status = %d\n", wc.status);
+ //printf(" opcode = %d\n", wc.opcode);
+ //printf(" byte_len = %d\n", wc.byte_len);
+ //printf(" imm_data = %d\n", wc.imm_data);
+ //printf(" qp_num = %d\n", wc.qp_num);
+ //printf(" src_qp = %d\n", wc.src_qp);
+ //printf(" wc_flags = %x\n", wc.wc_flags);
+ //printf(" slid = %d\n", wc.slid);
+ }
+ //rv = number_of_toggled_wc;
+ return rv;
+ }
+ };
+
+ //struct ibv_mr
+ //{
+ // struct ibv_context *context;
+ // struct ibv_pd *pd;
+ // void *addr;
+ // size_t length;
+ // ui32 handle;
+ // ui32 lkey;
+ // ui32 rkey;
+ //};
+ // RAII wrapper around a registered memory region. The region is allocated
+ // and pinned by the driver (addr = 0 form of ibv_reg_mr); GetData() exposes
+ // the driver-chosen base address.
+ class TMemoryRegion: public TThrRefBase, TNonCopyable {
+ ibv_mr* MR;
+ TIntrusivePtr<TIBContext> IBCtx;
+
+ ~TMemoryRegion() override {
+ if (MR) {
+ CHECK_Z(ibv_dereg_mr(MR));
+ }
+ }
+
+ public:
+ TMemoryRegion(TPtrArg<TIBContext> ctx, size_t len)
+ : IBCtx(ctx)
+ {
+ TIBContext::TLock ibContext(IBCtx);
+ int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; // TODO: IBV_ACCESS_ALLOCATE_MR
+ MR = ibv_reg_mr(ibContext.GetProtDomain(), 0, len, access);
+ Y_ASSERT(MR);
+ }
+ // Local key, used in sge entries for local work requests.
+ ui32 GetLKey() const {
+ static_assert(sizeof(ui32) == sizeof(MR->lkey), "expect sizeof(ui32) == sizeof(MR->lkey)");
+ return MR->lkey;
+ }
+ // Remote key, handed to peers for RDMA access to this region.
+ ui32 GetRKey() const {
+ // BUGFIX: returned MR->lkey before (copy-paste); remote peers need rkey.
+ static_assert(sizeof(ui32) == sizeof(MR->rkey), "expect sizeof(ui32) == sizeof(MR->rkey)");
+ return MR->rkey;
+ }
+ char* GetData() {
+ return MR ? (char*)MR->addr : nullptr;
+ }
+ // True when [data, data+len) lies entirely inside the registered region.
+ bool IsCovered(const void* data, size_t len) const {
+ size_t dataAddr = (const char*)data - (const char*)nullptr;
+ size_t bufAddr = (const char*)MR->addr - (const char*)nullptr;
+ return (dataAddr >= bufAddr) && (dataAddr + len <= bufAddr + MR->length);
+ }
+ };
+
+ // RAII wrapper around a shared receive queue; all QPs in this module post
+ // their receives through one SRQ.
+ class TSharedReceiveQueue: public TThrRefBase, TNonCopyable {
+ ibv_srq* SRQ;
+ TIntrusivePtr<TIBContext> IBCtx;
+
+ ~TSharedReceiveQueue() override {
+ if (SRQ) {
+ ibv_destroy_srq(SRQ);
+ }
+ }
+
+ public:
+ TSharedReceiveQueue(TPtrArg<TIBContext> ctx, int maxWR)
+ : IBCtx(ctx)
+ {
+ ibv_srq_init_attr attr;
+ Zero(attr);
+ attr.srq_context = this;
+ attr.attr.max_sge = MAX_SGE;
+ attr.attr.max_wr = maxWR;
+
+ TIBContext::TLock ibContext(IBCtx);
+ SRQ = ibv_create_srq(ibContext.GetProtDomain(), &attr);
+ Y_ASSERT(SRQ);
+ }
+ ibv_srq* GetSRQ() {
+ return SRQ;
+ }
+ // Posts one receive buffer (must lie inside mem); id comes back in wc.wr_id.
+ void PostReceive(TPtrArg<TMemoryRegion> mem, ui64 id, const void* buf, size_t len) {
+ Y_ASSERT(mem->IsCovered(buf, len));
+ ibv_recv_wr wr, *bad;
+ ibv_sge sg;
+ sg.addr = (const char*)buf - (const char*)nullptr;
+ sg.length = len;
+ sg.lkey = mem->GetLKey();
+ Zero(wr);
+ wr.wr_id = id;
+ wr.sg_list = &sg;
+ wr.num_sge = 1;
+ CHECK_Z(ibv_post_srq_recv(SRQ, &wr, &bad));
+ }
+ };
+
+ // Fills address-handle attributes for a native IB destination addressed by LID.
+ inline void MakeAH(ibv_ah_attr* res, TPtrArg<TIBPort> port, int lid, int serviceLevel) {
+ Zero(*res);
+ res->dlid = lid;
+ res->port_num = port->GetPort();
+ res->sl = serviceLevel;
+ }
+
+ void MakeAH(ibv_ah_attr* res, TPtrArg<TIBPort> port, const TUdpAddress& remoteAddr, const TUdpAddress& localAddr, int serviceLevel);
+
+ // RAII wrapper around an ibv address handle (destination descriptor for UD
+ // sends). Creation failure is reported via IsValid(), not an exception.
+ class TAddressHandle: public TThrRefBase, TNonCopyable {
+ ibv_ah* AH;
+ TIntrusivePtr<TIBContext> IBCtx;
+
+ ~TAddressHandle() override {
+ if (AH) {
+ CHECK_Z(ibv_destroy_ah(AH));
+ }
+ AH = nullptr;
+ IBCtx = nullptr;
+ }
+
+ public:
+ // From ready-made attributes (e.g. produced by TIBPort::GetAHAttr()).
+ TAddressHandle(TPtrArg<TIBContext> ctx, ibv_ah_attr* attr)
+ : IBCtx(ctx)
+ {
+ TIBContext::TLock ibContext(IBCtx);
+ AH = ibv_create_ah(ibContext.GetProtDomain(), attr);
+ Y_ASSERT(AH != nullptr);
+ }
+ // Native IB destination by LID.
+ TAddressHandle(TPtrArg<TIBPort> port, int lid, int serviceLevel)
+ : IBCtx(port->GetCtx())
+ {
+ ibv_ah_attr attr;
+ MakeAH(&attr, port, lid, serviceLevel);
+ TIBContext::TLock ibContext(IBCtx);
+ AH = ibv_create_ah(ibContext.GetProtDomain(), &attr);
+ Y_ASSERT(AH != nullptr);
+ }
+ // RoCE destination resolved from remote/local addresses.
+ TAddressHandle(TPtrArg<TIBPort> port, const TUdpAddress& remoteAddr, const TUdpAddress& localAddr, int serviceLevel)
+ : IBCtx(port->GetCtx())
+ {
+ ibv_ah_attr attr;
+ MakeAH(&attr, port, remoteAddr, localAddr, serviceLevel);
+ TIBContext::TLock ibContext(IBCtx);
+ AH = ibv_create_ah(ibContext.GetProtDomain(), &attr);
+ Y_ASSERT(AH != nullptr);
+ }
+ ibv_ah* GetAH() {
+ return AH;
+ }
+ bool IsValid() const {
+ return AH != nullptr;
+ }
+ };
+
+ // GRH + wc -> address_handle_attr
+ //int ibv_init_ah_from_wc(struct ibv_context *context, ui8 port_num,
+ //struct ibv_wc *wc, struct ibv_grh *grh,
+ //struct ibv_ah_attr *ah_attr)
+ //ibv_create_ah_from_wc(struct ibv_pd *pd, struct ibv_wc *wc, struct ibv_grh
+ // *grh, ui8 port_num)
+
+ // Common base for RC/UD queue pairs: creates the QP against the shared CQ
+ // and SRQ, picks a pseudo-random 24-bit starting PSN, and provides helpers
+ // to fill send work requests (small payloads are sent inline).
+ class TQueuePair: public TThrRefBase, TNonCopyable {
+ protected:
+ ibv_qp* QP;
+ int MyPSN; // start packet sequence number
+ TIntrusivePtr<TIBContext> IBCtx;
+ TIntrusivePtr<TComplectionQueue> CQ;
+ TIntrusivePtr<TSharedReceiveQueue> SRQ;
+
+ TQueuePair(TPtrArg<TIBContext> ctx, TPtrArg<TComplectionQueue> cq, TPtrArg<TSharedReceiveQueue> srq,
+ int sendQueueSize,
+ ibv_qp_type qp_type)
+ : IBCtx(ctx)
+ , CQ(cq)
+ , SRQ(srq)
+ {
+ MyPSN = GetCycleCount() & 0xffffff; // should be random and different on different runs, 24bit
+
+ ibv_qp_init_attr attr;
+ Zero(attr);
+ attr.qp_context = this; // not really useful
+ attr.send_cq = cq->GetCQ();
+ attr.recv_cq = cq->GetCQ();
+ attr.srq = srq->GetSRQ();
+ attr.cap.max_send_wr = sendQueueSize;
+ attr.cap.max_recv_wr = 0; // we are using srq, no need for qp's rq
+ attr.cap.max_send_sge = MAX_SGE;
+ attr.cap.max_recv_sge = MAX_SGE;
+ attr.cap.max_inline_data = MAX_INLINE_DATA_SIZE;
+ attr.qp_type = qp_type;
+ attr.sq_sig_all = 1; // inline sends need not be signaled, but if they are not work queue overflows
+
+ TIBContext::TLock ibContext(IBCtx);
+ QP = ibv_create_qp(ibContext.GetProtDomain(), &attr);
+ Y_ASSERT(QP);
+
+ //struct ibv_qp {
+ // struct ibv_context *context;
+ // void *qp_context;
+ // struct ibv_pd *pd;
+ // struct ibv_cq *send_cq;
+ // struct ibv_cq *recv_cq;
+ // struct ibv_srq *srq;
+ // ui32 handle;
+ // ui32 qp_num;
+ // enum ibv_qp_state state;
+ // enum ibv_qp_type qp_type;
+
+ // pthread_mutex_t mutex;
+ // pthread_cond_t cond;
+ // ui32 events_completed;
+ //};
+ //qp_context The value qp_context that was provided to ibv_create_qp()
+ //qp_num The number of this Queue Pair
+ //state The last known state of this Queue Pair. The actual state may be different from this state (in the RDMA device transitioned the state into other state)
+ //qp_type The Transport Service Type of this Queue Pair
+ }
+ ~TQueuePair() override {
+ if (QP) {
+ CHECK_Z(ibv_destroy_qp(QP));
+ }
+ }
+ // Fills one work request + scatter/gather entry; payloads up to
+ // MAX_INLINE_DATA_SIZE are flagged IBV_SEND_INLINE (no lkey needed).
+ void FillSendAttrs(ibv_send_wr* wr, ibv_sge* sg,
+ ui64 localAddr, ui32 lKey, ui64 id, size_t len) {
+ sg->addr = localAddr;
+ sg->length = len;
+ sg->lkey = lKey;
+ Zero(*wr);
+ wr->wr_id = id;
+ wr->sg_list = sg;
+ wr->num_sge = 1;
+ if (len <= MAX_INLINE_DATA_SIZE) {
+ wr->send_flags = IBV_SEND_INLINE;
+ }
+ }
+ // Same, resolving the local key from a memory region (mem may be null for
+ // inline-only payloads).
+ void FillSendAttrs(ibv_send_wr* wr, ibv_sge* sg,
+ TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) {
+ ui64 localAddr = (const char*)data - (const char*)nullptr;
+ ui32 lKey = 0;
+ if (mem) {
+ Y_ASSERT(mem->IsCovered(data, len));
+ lKey = mem->GetLKey();
+ } else {
+ Y_ASSERT(len <= MAX_INLINE_DATA_SIZE);
+ }
+ FillSendAttrs(wr, sg, localAddr, lKey, id, len);
+ }
+
+ public:
+ int GetQPN() const {
+ if (QP)
+ return QP->qp_num;
+ return 0;
+ }
+ int GetPSN() const {
+ return MyPSN;
+ }
+ // we are using srq
+ //void PostReceive(const TMemoryRegion &mem)
+ //{
+ // ibv_recv_wr wr, *bad;
+ // ibv_sge sg;
+ // sg.addr = mem.Addr;
+ // sg.length = mem.Length;
+ // sg.lkey = mem.lkey;
+ // Zero(wr);
+ // wr.wr_id = 13;
+ // wr.sg_list = sg;
+ // wr.num_sge = 1;
+ // CHECK_Z(ibv_post_recv(QP, &wr, &bad));
+ //}
+ };
+
+ // Reliable-connected queue pair. Init() walks the mandatory verbs state
+ // machine RESET -> INIT -> RTR -> RTS against the given peer; the attribute
+ // masks and ordering below are prescribed by the verbs spec and must not be
+ // reordered.
+ class TRCQueuePair: public TQueuePair {
+ public:
+ TRCQueuePair(TPtrArg<TIBContext> ctx, TPtrArg<TComplectionQueue> cq, TPtrArg<TSharedReceiveQueue> srq, int sendQueueSize)
+ : TQueuePair(ctx, cq, srq, sendQueueSize, IBV_QPT_RC)
+ {
+ }
+ // SRQ should have receive posted
+ void Init(const ibv_ah_attr& peerAddr, int peerQPN, int peerPSN) {
+ Y_ASSERT(QP->qp_type == IBV_QPT_RC);
+ ibv_qp_attr attr;
+ //{
+ // enum ibv_qp_state qp_state;
+ // enum ibv_qp_state cur_qp_state;
+ // enum ibv_mtu path_mtu;
+ // enum ibv_mig_state path_mig_state;
+ // ui32 qkey;
+ // ui32 rq_psn;
+ // ui32 sq_psn;
+ // ui32 dest_qp_num;
+ // int qp_access_flags;
+ // struct ibv_qp_cap cap;
+ // struct ibv_ah_attr ah_attr;
+ // struct ibv_ah_attr alt_ah_attr;
+ // ui16 pkey_index;
+ // ui16 alt_pkey_index;
+ // ui8 en_sqd_async_notify;
+ // ui8 sq_draining;
+ // ui8 max_rd_atomic;
+ // ui8 max_dest_rd_atomic;
+ // ui8 min_rnr_timer;
+ // ui8 port_num;
+ // ui8 timeout;
+ // ui8 retry_cnt;
+ // ui8 rnr_retry;
+ // ui8 alt_port_num;
+ // ui8 alt_timeout;
+ //};
+ // RESET -> INIT
+ Zero(attr);
+ attr.qp_state = IBV_QPS_INIT;
+ attr.pkey_index = 0;
+ attr.port_num = peerAddr.port_num;
+ // for connected QP
+ attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC;
+ CHECK_Z(ibv_modify_qp(QP, &attr,
+ IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS));
+
+ // INIT -> ReadyToReceive
+ //PostReceive(mem);
+ attr.qp_state = IBV_QPS_RTR;
+ attr.path_mtu = IBV_MTU_512; // allows more fine grained VL arbitration
+ // for connected QP
+ attr.ah_attr = peerAddr;
+ attr.dest_qp_num = peerQPN;
+ attr.rq_psn = peerPSN;
+ attr.max_dest_rd_atomic = MAX_OUTSTANDING_RDMA; // number of outstanding RDMA requests
+ attr.min_rnr_timer = 12; // recommended
+ CHECK_Z(ibv_modify_qp(QP, &attr,
+ IBV_QP_STATE | IBV_QP_PATH_MTU |
+ IBV_QP_AV | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER));
+
+ // ReadyToReceive -> ReadyToTransmit
+ attr.qp_state = IBV_QPS_RTS;
+ // for connected QP
+ attr.timeout = 14; // increased to 18 for sometime, 14 recommended
+ //attr.retry_cnt = 0; // for debug purposes
+ //attr.rnr_retry = 0; // for debug purposes
+ attr.retry_cnt = 7; // release configuration
+ attr.rnr_retry = 7; // release configuration (try forever)
+ attr.sq_psn = MyPSN;
+ attr.max_rd_atomic = MAX_OUTSTANDING_RDMA; // number of outstanding RDMA requests
+ CHECK_Z(ibv_modify_qp(QP, &attr,
+ IBV_QP_STATE |
+ IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC));
+ }
+ // Two-sided send; completion is reported on the shared CQ with wr_id == id.
+ void PostSend(TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) {
+ ibv_send_wr wr, *bad;
+ ibv_sge sg;
+ FillSendAttrs(&wr, &sg, mem, id, data, len);
+ wr.opcode = IBV_WR_SEND;
+ //IBV_WR_RDMA_WRITE
+ //IBV_WR_RDMA_WRITE_WITH_IMM
+ //IBV_WR_SEND
+ //IBV_WR_SEND_WITH_IMM
+ //IBV_WR_RDMA_READ
+ //wr.imm_data = xz;
+
+ CHECK_Z(ibv_post_send(QP, &wr, &bad));
+ }
+ // One-sided RDMA write into remoteAddr/remoteKey from a registered region.
+ void PostRDMAWrite(ui64 remoteAddr, ui32 remoteKey,
+ TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) {
+ ibv_send_wr wr, *bad;
+ ibv_sge sg;
+ FillSendAttrs(&wr, &sg, mem, id, data, len);
+ wr.opcode = IBV_WR_RDMA_WRITE;
+ wr.wr.rdma.remote_addr = remoteAddr;
+ wr.wr.rdma.rkey = remoteKey;
+
+ CHECK_Z(ibv_post_send(QP, &wr, &bad));
+ }
+ // Same, with the local buffer given as raw address + lkey.
+ void PostRDMAWrite(ui64 remoteAddr, ui32 remoteKey,
+ ui64 localAddr, ui32 localKey, ui64 id, size_t len) {
+ ibv_send_wr wr, *bad;
+ ibv_sge sg;
+ FillSendAttrs(&wr, &sg, localAddr, localKey, id, len);
+ wr.opcode = IBV_WR_RDMA_WRITE;
+ wr.wr.rdma.remote_addr = remoteAddr;
+ wr.wr.rdma.rkey = remoteKey;
+
+ CHECK_Z(ibv_post_send(QP, &wr, &bad));
+ }
+ // RDMA write that also delivers immData to the receiver's CQ.
+ void PostRDMAWriteImm(ui64 remoteAddr, ui32 remoteKey, ui32 immData,
+ TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) {
+ ibv_send_wr wr, *bad;
+ ibv_sge sg;
+ FillSendAttrs(&wr, &sg, mem, id, data, len);
+ wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
+ wr.imm_data = immData;
+ wr.wr.rdma.remote_addr = remoteAddr;
+ wr.wr.rdma.rkey = remoteKey;
+
+ CHECK_Z(ibv_post_send(QP, &wr, &bad));
+ }
+ };
+
+ // Unreliable-datagram queue pair (used for the out-of-band welcome/handshake
+ // traffic). Init() walks RESET -> INIT -> RTR -> RTS with the given qkey.
+ class TUDQueuePair: public TQueuePair {
+ TIntrusivePtr<TIBPort> Port;
+
+ public:
+ TUDQueuePair(TPtrArg<TIBPort> port, TPtrArg<TComplectionQueue> cq, TPtrArg<TSharedReceiveQueue> srq, int sendQueueSize)
+ : TQueuePair(port->GetCtx(), cq, srq, sendQueueSize, IBV_QPT_UD)
+ , Port(port)
+ {
+ }
+ // SRQ should have receive posted
+ void Init(int qkey) {
+ Y_ASSERT(QP->qp_type == IBV_QPT_UD);
+ ibv_qp_attr attr;
+ // RESET -> INIT
+ Zero(attr);
+ attr.qp_state = IBV_QPS_INIT;
+ attr.pkey_index = 0;
+ attr.port_num = Port->GetPort();
+ // for unconnected qp
+ attr.qkey = qkey;
+ CHECK_Z(ibv_modify_qp(QP, &attr,
+ IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY));
+
+ // INIT -> ReadyToReceive
+ //PostReceive(mem);
+ attr.qp_state = IBV_QPS_RTR;
+ CHECK_Z(ibv_modify_qp(QP, &attr, IBV_QP_STATE));
+
+ // ReadyToReceive -> ReadyToTransmit
+ attr.qp_state = IBV_QPS_RTS;
+ attr.sq_psn = 0;
+ CHECK_Z(ibv_modify_qp(QP, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN));
+ }
+ // Datagram send to the destination described by ah/remoteQPN/remoteQKey.
+ void PostSend(TPtrArg<TAddressHandle> ah, int remoteQPN, int remoteQKey,
+ TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) {
+ ibv_send_wr wr, *bad;
+ ibv_sge sg;
+ FillSendAttrs(&wr, &sg, mem, id, data, len);
+ wr.opcode = IBV_WR_SEND;
+ wr.wr.ud.ah = ah->GetAH();
+ wr.wr.ud.remote_qpn = remoteQPN;
+ wr.wr.ud.remote_qkey = remoteQKey;
+ //IBV_WR_SEND_WITH_IMM
+ //wr.imm_data = xz;
+
+ CHECK_Z(ibv_post_send(QP, &wr, &bad));
+ }
+ };
+
+ TIntrusivePtr<TIBPort> GetIBDevice();
+
+#else
+ //////////////////////////////////////////////////////////////////////////
+ // stub for OS without IB support
+ //////////////////////////////////////////////////////////////////////////
+ // No-op stand-ins mirroring the real API above so that higher-level code
+ // compiles unchanged on platforms without libibverbs. Poll() always returns
+ // 0 and GetIBDevice() returns null, so none of the other stubs ever run.
+ enum ibv_wc_opcode {
+ IBV_WC_SEND,
+ IBV_WC_RDMA_WRITE,
+ IBV_WC_RDMA_READ,
+ IBV_WC_COMP_SWAP,
+ IBV_WC_FETCH_ADD,
+ IBV_WC_BIND_MW,
+ IBV_WC_RECV = 1 << 7,
+ IBV_WC_RECV_RDMA_WITH_IMM
+ };
+
+ enum ibv_wc_status {
+ IBV_WC_SUCCESS,
+ // lots of errors follow
+ };
+ //struct ibv_device;
+ //struct ibv_pd;
+ union ibv_gid {
+ ui8 raw[16];
+ struct {
+ ui64 subnet_prefix;
+ ui64 interface_id;
+ } global;
+ };
+
+ struct ibv_wc {
+ ui64 wr_id;
+ enum ibv_wc_status status;
+ enum ibv_wc_opcode opcode;
+ ui32 imm_data; /* in network byte order */
+ ui32 qp_num;
+ ui32 src_qp;
+ };
+ struct ibv_grh {};
+ struct ibv_ah_attr {
+ ui8 sl;
+ };
+ //struct ibv_cq;
+ class TIBContext: public TThrRefBase, TNonCopyable {
+ public:
+ bool IsValid() const {
+ return false;
+ }
+ //ibv_context *GetContext() { return 0; }
+ //ibv_pd *GetProtDomain() { return 0; }
+ };
+
+ class TIBPort: public TThrRefBase, TNonCopyable {
+ public:
+ TIBPort(TPtrArg<TIBContext>, int) {
+ }
+ int GetPort() const {
+ return 1;
+ }
+ int GetLID() const {
+ return 1;
+ }
+ TIBContext* GetCtx() {
+ return 0;
+ }
+ void GetGID(ibv_gid* res) {
+ Zero(*res);
+ }
+ void GetAHAttr(ibv_wc*, ibv_grh*, ibv_ah_attr*) {
+ }
+ };
+
+ class TComplectionQueue: public TThrRefBase, TNonCopyable {
+ public:
+ TComplectionQueue(TPtrArg<TIBContext>, int) {
+ }
+ //ibv_cq *GetCQ() { return 0; }
+ int Poll(ibv_wc*, int) {
+ return 0;
+ }
+ };
+
+ class TMemoryRegion: public TThrRefBase, TNonCopyable {
+ public:
+ TMemoryRegion(TPtrArg<TIBContext>, size_t) {
+ }
+ ui32 GetLKey() const {
+ return 0;
+ }
+ ui32 GetRKey() const {
+ return 0;
+ }
+ char* GetData() {
+ return 0;
+ }
+ bool IsCovered(const void*, size_t) const {
+ return false;
+ }
+ };
+
+ class TSharedReceiveQueue: public TThrRefBase, TNonCopyable {
+ public:
+ TSharedReceiveQueue(TPtrArg<TIBContext>, int) {
+ }
+ //ibv_srq *GetSRQ() { return SRQ; }
+ void PostReceive(TPtrArg<TMemoryRegion>, ui64, const void*, size_t) {
+ }
+ };
+
+ inline void MakeAH(ibv_ah_attr*, TPtrArg<TIBPort>, int, int) {
+ }
+
+ class TAddressHandle: public TThrRefBase, TNonCopyable {
+ public:
+ TAddressHandle(TPtrArg<TIBContext>, ibv_ah_attr*) {
+ }
+ TAddressHandle(TPtrArg<TIBPort>, int, int) {
+ }
+ TAddressHandle(TPtrArg<TIBPort>, const TUdpAddress&, const TUdpAddress&, int) {
+ }
+ //ibv_ah *GetAH() { return AH; }
+ bool IsValid() {
+ return true;
+ }
+ };
+
+ class TQueuePair: public TThrRefBase, TNonCopyable {
+ public:
+ int GetQPN() const {
+ return 0;
+ }
+ int GetPSN() const {
+ return 0;
+ }
+ };
+
+ class TRCQueuePair: public TQueuePair {
+ public:
+ TRCQueuePair(TPtrArg<TIBContext>, TPtrArg<TComplectionQueue>, TPtrArg<TSharedReceiveQueue>, int) {
+ }
+ // SRQ should have receive posted
+ void Init(const ibv_ah_attr&, int, int) {
+ }
+ void PostSend(TPtrArg<TMemoryRegion>, ui64, const void*, size_t) {
+ }
+ void PostRDMAWrite(ui64, ui32, TPtrArg<TMemoryRegion>, ui64, const void*, size_t) {
+ }
+ void PostRDMAWrite(ui64, ui32, ui64, ui32, ui64, size_t) {
+ }
+ void PostRDMAWriteImm(ui64, ui32, ui32, TPtrArg<TMemoryRegion>, ui64, const void*, size_t) {
+ }
+ };
+
+ class TUDQueuePair: public TQueuePair {
+ TIntrusivePtr<TIBPort> Port;
+
+ public:
+ TUDQueuePair(TPtrArg<TIBPort>, TPtrArg<TComplectionQueue>, TPtrArg<TSharedReceiveQueue>, int) {
+ }
+ // SRQ should have receive posted
+ void Init(int) {
+ }
+ void PostSend(TPtrArg<TAddressHandle>, int, int, TPtrArg<TMemoryRegion>, ui64, const void*, size_t) {
+ }
+ };
+
+ inline TIntrusivePtr<TIBPort> GetIBDevice() {
+ return 0;
+ }
+#endif
+}
diff --git a/library/cpp/netliba/v6/ib_mem.cpp b/library/cpp/netliba/v6/ib_mem.cpp
new file mode 100644
index 0000000000..1e7f55ac57
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_mem.cpp
@@ -0,0 +1,167 @@
+#include "stdafx.h"
+#include "ib_mem.h"
+#include "ib_low.h"
+#include "cpu_affinity.h"
+
+#if defined(_unix_)
+#include <pthread.h>
+#endif
+
+namespace NNetliba {
    // Allocates one 2^szLog-byte region registered with the IB context of
    // `pool`. UseCount counts outstanding TIBMemBlock references (see DecRef).
    TIBMemSuperBlock::TIBMemSuperBlock(TIBMemPool* pool, size_t szLog)
        : Pool(pool)
        , SzLog(szLog)
        , UseCount(0)
    {
        size_t sz = GetSize();
        MemRegion = new TMemoryRegion(pool->GetIBContext(), sz);
        //printf("Alloc super block, size %" PRId64 "\n", sz);
    }
+
    TIBMemSuperBlock::~TIBMemSuperBlock() {
        // A super block may only die once no sub-block references it.
        Y_ASSERT(AtomicGet(UseCount) == 0);
    }
+
    // Base address of the registered memory region backing this super block.
    char* TIBMemSuperBlock::GetData() {
        return MemRegion->GetData();
    }
+
    void TIBMemSuperBlock::DecRef() {
        // AtomicAdd returns the new value: the thread that drops the counter
        // to zero hands the block back to the pool cache for reuse.
        if (AtomicAdd(UseCount, -1) == 0) {
            Pool->Return(this);
        }
    }
+
    TIBMemBlock::~TIBMemBlock() {
        if (Super.Get()) {
            // Sub-allocation of a super block: release our reference.
            Super->DecRef();
        } else {
            // Plain-heap fallback block (constructed without IB, see header).
            delete[] Data;
        }
    }
+
+ //////////////////////////////////////////////////////////////////////////
    // Starts the background copy thread and blocks until it has signalled
    // startup, so CopyData() may be called immediately after construction.
    TIBMemPool::TIBMemPool(TPtrArg<TIBContext> ctx)
        : IBCtx(ctx)
        , AllocCacheSize(0)
        , CurrentOffset(IB_MEM_LARGE_BLOCK) // forces a fresh super block on the first small Alloc()
        , WorkThread(TThread::TParams(ThreadFunc, (void*)this).SetName("nl6_ib_mem_pool"))
        , KeepRunning(true)
    {
        WorkThread.Start();
        HasStarted.Wait();
    }
+
    TIBMemPool::~TIBMemPool() {
        Y_ASSERT(WorkThread.Running());
        // Ask the worker to exit, wake it, wait for it, and only then drain
        // any requests it never got to process.
        KeepRunning = false;
        HasWork.Signal();
        WorkThread.Join();
        {
            TJobItem* work = nullptr;
            while (Requests.Dequeue(&work)) {
                delete work;
            }
        }
    }
+
+ TIntrusivePtr<TIBMemSuperBlock> TIBMemPool::AllocSuper(size_t szArg) {
+ // assume CacheLock is taken
+ size_t szLog = 12;
+ while ((((size_t)1) << szLog) < szArg) {
+ ++szLog;
+ }
+ TIntrusivePtr<TIBMemSuperBlock> super;
+ {
+ TVector<TIntrusivePtr<TIBMemSuperBlock>>& cc = AllocCache[szLog];
+ if (!cc.empty()) {
+ super = cc.back();
+ cc.resize(cc.size() - 1);
+ AllocCacheSize -= 1ll << super->SzLog;
+ }
+ }
+ if (super.Get() == nullptr) {
+ super = new TIBMemSuperBlock(this, szLog);
+ }
+ return super;
+ }
+
    // Allocates sz bytes of IB-registered memory. Requests above
    // IB_MEM_LARGE_BLOCK get a dedicated super block; smaller requests are
    // carved sequentially out of the shared CurrentBlk bump allocator.
    TIBMemBlock* TIBMemPool::Alloc(size_t sz) {
        TGuard<TMutex> gg(CacheLock);
        if (sz > IB_MEM_LARGE_BLOCK) {
            TIntrusivePtr<TIBMemSuperBlock> super = AllocSuper(sz);
            return new TIBMemBlock(super, super->GetData(), sz);
        } else {
            if (CurrentOffset + sz > IB_MEM_LARGE_BLOCK) {
                // Not enough room left in the current super block; start a new one.
                CurrentBlk.Assign(AllocSuper(IB_MEM_LARGE_BLOCK));
                CurrentOffset = 0;
            }
            CurrentOffset += sz;
            return new TIBMemBlock(CurrentBlk.Get(), CurrentBlk.Get()->GetData() + CurrentOffset - sz, sz);
        }
    }
+
    // Called from TIBMemSuperBlock::DecRef when the last block inside `blk`
    // dies: caches the super block for reuse. If caching it would push the
    // cache past IB_MEM_POOL_SIZE, the whole cache is flushed first.
    void TIBMemPool::Return(TPtrArg<TIBMemSuperBlock> blk) {
        TGuard<TMutex> gg(CacheLock);
        Y_ASSERT(AtomicGet(blk->UseCount) == 0);
        size_t sz = 1ull << blk->SzLog;
        if (sz + AllocCacheSize > IB_MEM_POOL_SIZE) {
            AllocCache.clear();
            AllocCacheSize = 0;
        }
        {
            TVector<TIntrusivePtr<TIBMemSuperBlock>>& cc = AllocCache[blk->SzLog];
            cc.push_back(blk.Get());
            AllocCacheSize += sz;
        }
    }
+
    // Background worker: pins itself to CPU socket 0, then loops dequeueing
    // copy requests, materializing each rope into a freshly allocated IB
    // block and publishing the finished item to its TCopyResultStorage.
    void* TIBMemPool::ThreadFunc(void* param) {
        BindToSocket(0);
        SetHighestThreadPriority();
        TIBMemPool* pThis = (TIBMemPool*)param;
        pThis->HasStarted.Signal();

        while (pThis->KeepRunning) {
            TJobItem* work = nullptr;
            if (!pThis->Requests.Dequeue(&work)) {
                // Reset-then-recheck avoids sleeping through a Signal() that
                // arrives between the failed Dequeue and the Wait.
                pThis->HasWork.Reset();
                if (!pThis->Requests.Dequeue(&work)) {
                    pThis->HasWork.Wait();
                }
            }
            if (work) {
                //printf("mem copy got work\n");
                int sz = work->Data->GetSize();
                work->Block = pThis->Alloc(sz);
                TBlockChainIterator bc(work->Data->GetChain());
                bc.Read(work->Block->GetData(), sz);
                TIntrusivePtr<TCopyResultStorage> dst = work->ResultStorage;
                work->ResultStorage = nullptr;
                // Ownership of `work` passes to the result storage here.
                dst->Results.Enqueue(work);
                //printf("mem copy completed\n");
            }
        }
        return nullptr;
    }
+
+ //////////////////////////////////////////////////////////////////////////
    // Process-wide singleton state for GetIBMemPool(), guarded by IBMemMutex.
    static TMutex IBMemMutex;
    static TIntrusivePtr<TIBMemPool> IBMemPool;
    static bool IBWasInitialized;
+
+ TIntrusivePtr<TIBMemPool> GetIBMemPool() {
+ TGuard<TMutex> gg(IBMemMutex);
+ if (IBWasInitialized) {
+ return IBMemPool;
+ }
+ IBWasInitialized = true;
+
+ TIntrusivePtr<TIBPort> ibPort = GetIBDevice();
+ if (ibPort.Get() == nullptr) {
+ return nullptr;
+ }
+ IBMemPool = new TIBMemPool(ibPort->GetCtx());
+ return IBMemPool;
+ }
+}
diff --git a/library/cpp/netliba/v6/ib_mem.h b/library/cpp/netliba/v6/ib_mem.h
new file mode 100644
index 0000000000..dfa5b9cd5f
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_mem.h
@@ -0,0 +1,178 @@
+#pragma once
+
+#include "block_chain.h"
+#include <util/thread/lfqueue.h>
+#include <util/system/thread.h>
+
+namespace NNetliba {
+ // registered memory blocks
+ class TMemoryRegion;
+ class TIBContext;
+
+ class TIBMemPool;
+ struct TIBMemSuperBlock: public TThrRefBase, TNonCopyable {
+ TIntrusivePtr<TIBMemPool> Pool;
+ size_t SzLog;
+ TAtomic UseCount;
+ TIntrusivePtr<TMemoryRegion> MemRegion;
+
+ TIBMemSuperBlock(TIBMemPool* pool, size_t szLog);
+ ~TIBMemSuperBlock() override;
+ char* GetData();
+ size_t GetSize() {
+ return ((ui64)1) << SzLog;
+ }
+ void IncRef() {
+ AtomicAdd(UseCount, 1);
+ }
+ void DecRef();
+ };
+
+ class TIBMemBlock: public TThrRefBase, TNonCopyable {
+ TIntrusivePtr<TIBMemSuperBlock> Super;
+ char* Data;
+ size_t Size;
+
+ ~TIBMemBlock() override;
+
+ public:
+ TIBMemBlock(TPtrArg<TIBMemSuperBlock> super, char* data, size_t sz)
+ : Super(super)
+ , Data(data)
+ , Size(sz)
+ {
+ Super->IncRef();
+ }
+ TIBMemBlock(size_t sz)
+ : Super(nullptr)
+ , Size(sz)
+ {
+ // not really IB mem block, but useful IB code debug without IB
+ Data = new char[sz];
+ }
+ char* GetData() {
+ return Data;
+ }
+ ui64 GetAddr() {
+ return Data - (char*)nullptr;
+ }
+ size_t GetSize() {
+ return Size;
+ }
+ TMemoryRegion* GetMemRegion() {
+ return Super.Get() ? Super->MemRegion.Get() : nullptr;
+ }
+ };
+
+ const size_t IB_MEM_LARGE_BLOCK_LN = 20;
+ const size_t IB_MEM_LARGE_BLOCK = 1ul << IB_MEM_LARGE_BLOCK_LN;
+ const size_t IB_MEM_POOL_SIZE = 1024 * 1024 * 1024;
+
+ class TIBMemPool: public TThrRefBase, TNonCopyable {
+ public:
+ struct TCopyResultStorage;
+
+ private:
+ class TIBMemSuperBlockPtr {
+ TIntrusivePtr<TIBMemSuperBlock> Blk;
+
+ public:
+ ~TIBMemSuperBlockPtr() {
+ Detach();
+ }
+ void Assign(TIntrusivePtr<TIBMemSuperBlock> p) {
+ Detach();
+ Blk = p;
+ if (p.Get()) {
+ AtomicAdd(p->UseCount, 1);
+ }
+ }
+ void Detach() {
+ if (Blk.Get()) {
+ Blk->DecRef();
+ Blk = nullptr;
+ }
+ }
+ TIBMemSuperBlock* Get() {
+ return Blk.Get();
+ }
+ };
+
+ TIntrusivePtr<TIBContext> IBCtx;
+ THashMap<size_t, TVector<TIntrusivePtr<TIBMemSuperBlock>>> AllocCache;
+ size_t AllocCacheSize;
+ TIBMemSuperBlockPtr CurrentBlk;
+ int CurrentOffset;
+ TMutex CacheLock;
+ TThread WorkThread;
+ TSystemEvent HasStarted;
+ bool KeepRunning;
+
+ struct TJobItem {
+ TRopeDataPacket* Data;
+ i64 MsgHandle;
+ TIntrusivePtr<TThrRefBase> Context;
+ TIntrusivePtr<TIBMemBlock> Block;
+ TIntrusivePtr<TCopyResultStorage> ResultStorage;
+
+ TJobItem(TRopeDataPacket* data, i64 msgHandle, TThrRefBase* context, TPtrArg<TCopyResultStorage> dst)
+ : Data(data)
+ , MsgHandle(msgHandle)
+ , Context(context)
+ , ResultStorage(dst)
+ {
+ }
+ };
+
+ TLockFreeQueue<TJobItem*> Requests;
+ TSystemEvent HasWork;
+
+ static void* ThreadFunc(void* param);
+
+ void Return(TPtrArg<TIBMemSuperBlock> blk);
+ TIntrusivePtr<TIBMemSuperBlock> AllocSuper(size_t sz);
+ ~TIBMemPool() override;
+
+ public:
+ struct TCopyResultStorage: public TThrRefBase {
+ TLockFreeStack<TJobItem*> Results;
+
+ ~TCopyResultStorage() override {
+ TJobItem* work;
+ while (Results.Dequeue(&work)) {
+ delete work;
+ }
+ }
+ template <class T>
+ bool GetCopyResult(TIntrusivePtr<TIBMemBlock>* resBlock, i64* resMsgHandle, TIntrusivePtr<T>* context) {
+ TJobItem* work;
+ if (Results.Dequeue(&work)) {
+ *resBlock = work->Block;
+ *resMsgHandle = work->MsgHandle;
+ *context = static_cast<T*>(work->Context.Get()); // caller responsibility to make sure this makes sense
+ delete work;
+ return true;
+ } else {
+ return false;
+ }
+ }
+ };
+
+ public:
+ TIBMemPool(TPtrArg<TIBContext> ctx);
+ TIBContext* GetIBContext() {
+ return IBCtx.Get();
+ }
+ TIBMemBlock* Alloc(size_t sz);
+
+ void CopyData(TRopeDataPacket* data, i64 msgHandle, TThrRefBase* context, TPtrArg<TCopyResultStorage> dst) {
+ Requests.Enqueue(new TJobItem(data, msgHandle, context, dst));
+ HasWork.Signal();
+ }
+
+ friend class TIBMemBlock;
+ friend struct TIBMemSuperBlock;
+ };
+
+ extern TIntrusivePtr<TIBMemPool> GetIBMemPool();
+}
diff --git a/library/cpp/netliba/v6/ib_memstream.cpp b/library/cpp/netliba/v6/ib_memstream.cpp
new file mode 100644
index 0000000000..ffed8f1dba
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_memstream.cpp
@@ -0,0 +1,122 @@
+#include "stdafx.h"
+#include "ib_mem.h"
+#include "ib_memstream.h"
+#include "ib_low.h"
+
+namespace NNetliba {
+ int TIBMemStream::WriteImpl(const void* userBuffer, int sizeArg) {
+ const char* srcData = (const char*)userBuffer;
+ int size = sizeArg;
+ for (;;) {
+ if (size == 0)
+ return sizeArg;
+ if (CurBlock == Blocks.ysize()) {
+ // add new block
+ TBlock& blk = Blocks.emplace_back();
+ blk.StartOffset = GetLength();
+ int szLog = 17 + Min(Blocks.ysize() / 2, 13);
+ blk.BufSize = 1 << szLog;
+ blk.DataSize = 0;
+ blk.Mem = MemPool->Alloc(blk.BufSize);
+ Y_ASSERT(CurBlockOffset == 0);
+ }
+ TBlock& curBlk = Blocks[CurBlock];
+ int leftSpace = curBlk.BufSize - CurBlockOffset;
+ int copySize = Min(size, leftSpace);
+ memcpy(curBlk.Mem->GetData() + CurBlockOffset, srcData, copySize);
+ size -= copySize;
+ CurBlockOffset += copySize;
+ srcData += copySize;
+ curBlk.DataSize = Max(curBlk.DataSize, CurBlockOffset);
+ if (CurBlockOffset == curBlk.BufSize) {
+ ++CurBlock;
+ CurBlockOffset = 0;
+ }
+ }
+ }
+
+ int TIBMemStream::ReadImpl(void* userBuffer, int sizeArg) {
+ char* dstData = (char*)userBuffer;
+ int size = sizeArg;
+ for (;;) {
+ if (size == 0)
+ return sizeArg;
+ if (CurBlock == Blocks.ysize()) {
+ //memset(dstData, 0, size);
+ size = 0;
+ continue;
+ }
+ TBlock& curBlk = Blocks[CurBlock];
+ int leftSpace = curBlk.DataSize - CurBlockOffset;
+ int copySize = Min(size, leftSpace);
+ memcpy(dstData, curBlk.Mem->GetData() + CurBlockOffset, copySize);
+ size -= copySize;
+ CurBlockOffset += copySize;
+ dstData += copySize;
+ if (CurBlockOffset == curBlk.DataSize) {
+ ++CurBlock;
+ CurBlockOffset = 0;
+ }
+ }
+ }
+
+ i64 TIBMemStream::GetLength() {
+ i64 res = 0;
+ for (int i = 0; i < Blocks.ysize(); ++i) {
+ res += Blocks[i].DataSize;
+ }
+ return res;
+ }
+
    // Positions the cursor at absolute offset `pos`. Positions past the end
    // are clamped to the end of the stream; returns the resulting position.
    i64 TIBMemStream::Seek(i64 pos) {
        for (int resBlockId = 0; resBlockId < Blocks.ysize(); ++resBlockId) {
            const TBlock& blk = Blocks[resBlockId];
            if (pos < blk.StartOffset + blk.DataSize) {
                // pos falls inside this block
                CurBlock = resBlockId;
                CurBlockOffset = pos - blk.StartOffset;
                return pos;
            }
        }
        // past the last byte: park the cursor at end-of-stream
        CurBlock = Blocks.ysize();
        CurBlockOffset = 0;
        return GetLength();
    }
+
+ void TIBMemStream::GetBlocks(TVector<TBlockDescr>* res) const {
+ int blockCount = Blocks.ysize();
+ res->resize(blockCount);
+ for (int i = 0; i < blockCount; ++i) {
+ const TBlock& blk = Blocks[i];
+ TBlockDescr& dst = (*res)[i];
+ dst.Addr = blk.Mem->GetAddr();
+ dst.BufSize = blk.BufSize;
+ dst.DataSize = blk.DataSize;
+ TMemoryRegion* mem = blk.Mem->GetMemRegion();
+ dst.LocalKey = mem->GetLKey();
+ dst.RemoteKey = mem->GetRKey();
+ }
+ }
+
+ void TIBMemStream::CreateBlocks(const TVector<TBlockSizes>& arr) {
+ int blockCount = arr.ysize();
+ Blocks.resize(blockCount);
+ i64 offset = 0;
+ for (int i = 0; i < blockCount; ++i) {
+ const TBlockSizes& src = arr[i];
+ TBlock& blk = Blocks[i];
+ blk.BufSize = src.BufSize;
+ blk.DataSize = src.DataSize;
+ blk.Mem = MemPool->Alloc(blk.BufSize);
+ blk.StartOffset = offset;
+ offset += blk.DataSize;
+ }
+ CurBlock = 0;
+ CurBlockOffset = 0;
+ }
+
+ void TIBMemStream::Clear() {
+ Blocks.resize(0);
+ CurBlock = 0;
+ CurBlockOffset = 0;
+ }
+}
diff --git a/library/cpp/netliba/v6/ib_memstream.h b/library/cpp/netliba/v6/ib_memstream.h
new file mode 100644
index 0000000000..67eb2386de
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_memstream.h
@@ -0,0 +1,95 @@
+#pragma once
+
+#include "ib_mem.h"
+#include <library/cpp/binsaver/bin_saver.h>
+#include <library/cpp/binsaver/buffered_io.h>
+
+namespace NNetliba {
+ class TIBMemStream: public IBinaryStream {
+ struct TBlock {
+ TIntrusivePtr<TIBMemBlock> Mem;
+ i64 StartOffset;
+ int BufSize, DataSize;
+
+ TBlock()
+ : StartOffset(0)
+ , BufSize(0)
+ , DataSize(0)
+ {
+ }
+ TBlock(const TBlock& x) {
+ Copy(x);
+ }
+ void operator=(const TBlock& x) {
+ Copy(x);
+ }
+ void Copy(const TBlock& x) {
+ if (x.BufSize > 0) {
+ Mem = GetIBMemPool()->Alloc(x.BufSize);
+ memcpy(Mem->GetData(), x.Mem->GetData(), x.DataSize);
+ StartOffset = x.StartOffset;
+ BufSize = x.BufSize;
+ DataSize = x.DataSize;
+ } else {
+ Mem = nullptr;
+ StartOffset = 0;
+ BufSize = 0;
+ DataSize = 0;
+ }
+ }
+ };
+
+ TIntrusivePtr<TIBMemPool> MemPool;
+ TVector<TBlock> Blocks;
+ int CurBlock;
+ int CurBlockOffset;
+
+ public:
+ struct TBlockDescr {
+ ui64 Addr;
+ int BufSize, DataSize;
+ ui32 RemoteKey, LocalKey;
+ };
+ struct TBlockSizes {
+ int BufSize, DataSize;
+ };
+
+ public:
+ TIBMemStream()
+ : MemPool(GetIBMemPool())
+ , CurBlock(0)
+ , CurBlockOffset(0)
+ {
+ }
+ ~TIBMemStream() override {
+ } // keep gcc happy
+
+ bool IsValid() const override {
+ return true;
+ }
+ bool IsFailed() const override {
+ return false;
+ }
+ void Flush() {
+ }
+
+ i64 GetLength();
+ i64 Seek(i64 pos);
+
+ void GetBlocks(TVector<TBlockDescr>* res) const;
+ void CreateBlocks(const TVector<TBlockSizes>& arr);
+
+ void Clear();
+
+ private:
+ int WriteImpl(const void* userBuffer, int size) override;
+ int ReadImpl(void* userBuffer, int size) override;
+ };
+
    // Convenience helper: (de)serializes `c` through the stream via binsaver.
    // bRead == true reads `c` from `ms`; false writes `c` into it.
    template <class T>
    inline void Serialize(bool bRead, TIBMemStream& ms, T& c) {
        IBinSaver bs(ms, bRead);
        bs.Add(1, &c);
    }
+
+}
diff --git a/library/cpp/netliba/v6/ib_test.cpp b/library/cpp/netliba/v6/ib_test.cpp
new file mode 100644
index 0000000000..178691e837
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_test.cpp
@@ -0,0 +1,232 @@
+#include "stdafx.h"
+#include "ib_test.h"
+#include "ib_buffers.h"
+#include "udp_socket.h"
+#include "udp_address.h"
+#include <util/system/hp_timer.h>
+
+namespace NNetliba {
    // Address of the server's welcome UD queue pair, exchanged over plain
    // UDP during bootstrap.
    struct TWelcomeSocketAddr {
        int LID;
        int QPN;
    };
+
    // Parameters each side must learn about the peer's RC queue pair.
    struct TRCQueuePairHandshake {
        int QPN, PSN;
    };
+
    // Minimal UDP helper used to bootstrap the IB connection: the client
    // sends "Hello_IB" datagrams and the server answers with the address of
    // its welcome queue pair.
    class TIPSocket {
        TNetSocket s;
        enum {
            HDR_SIZE = UDP_LOW_LEVEL_HEADER_SIZE
        };

    public:
        void Init(int port) {
            if (!InitLocalIPList()) {
                Y_ASSERT(0 && "Can not determine self IP address");
                return;
            }
            s.Open(port);
        }
        bool IsValid() {
            return s.IsValid();
        }
        // Server side: answer one pending "Hello_IB" datagram with `info`.
        void Respond(const TWelcomeSocketAddr& info) {
            char buf[10000];
            int sz = sizeof(buf);
            sockaddr_in6 fromAddress;
            bool rv = s.RecvFrom(buf, &sz, &fromAddress);
            // NOTE(review): assumes the received payload is NUL-terminated;
            // a malformed datagram could let strcmp read past `sz` — confirm.
            if (rv && strcmp(buf + HDR_SIZE, "Hello_IB") == 0) {
                printf("send welcome info\n");
                memcpy(buf + HDR_SIZE, &info, sizeof(info));
                TNetSocket::ESendError err = s.SendTo(buf, sizeof(info) + HDR_SIZE, fromAddress, FF_ALLOW_FRAG);
                if (err != TNetSocket::SEND_OK) {
                    printf("SendTo() fail %d\n", err);
                }
            }
        }
        // Client side: ping hostName:port with "Hello_IB" every 100ms until
        // the welcome address arrives; fills *res on success.
        void Request(const char* hostName, int port, TWelcomeSocketAddr* res) {
            TUdpAddress addr = CreateAddress(hostName, port);
            printf("addr = %s\n", GetAddressAsString(addr).c_str());

            sockaddr_in6 sockAddr;
            GetWinsockAddr(&sockAddr, addr);

            for (;;) {
                char buf[10000];
                int sz = sizeof(buf);
                sockaddr_in6 fromAddress;
                bool rv = s.RecvFrom(buf, &sz, &fromAddress);
                if (rv) {
                    if (sz == sizeof(TWelcomeSocketAddr) + HDR_SIZE) {
                        *res = *(TWelcomeSocketAddr*)(buf + HDR_SIZE);
                        break;
                    }
                    printf("Get unexpected %d bytes from somewhere?\n", sz);
                }

                // NOTE(review): strlen(buf) below measures from the start of
                // the buffer, not from buf + HDR_SIZE — looks suspicious but
                // is kept as-is; confirm intended payload length.
                strcpy(buf + HDR_SIZE, "Hello_IB");
                TNetSocket::ESendError err = s.SendTo(buf, strlen(buf) + 1 + HDR_SIZE, sockAddr, FF_ALLOW_FRAG);
                if (err != TNetSocket::SEND_OK) {
                    printf("SendTo() fail %d\n", err);
                }

                Sleep(TDuration::MilliSeconds(100));
            }
        }
    };
+
    // Busy-polls the completion queue until a receive completion arrives,
    // releasing the buffers of any non-receive completions along the way.
    // can hang if opposite side exits, but it's only basic test
    static void WaitForRecv(TIBBufferPool* bp, TPtrArg<TComplectionQueue> cq, ibv_wc* wc) {
        for (;;) {
            if (cq->Poll(wc, 1) == 1) {
                if (wc->opcode & IBV_WC_RECV) {
                    break;
                }
                bp->FreeBuf(wc->wr_id);
            }
        }
    }
+
    // End-to-end smoke test of the IB stack: peers discover each other over
    // plain UDP, exchange RC queue-pair handshakes through a UD "welcome"
    // queue pair, then ping-pong a few messages over the RC connection.
    // The server (isClient == false) loops forever serving clients; the
    // client performs one handshake + ping sequence against serverName.
    void RunIBTest(bool isClient, const char* serverName) {
        TIntrusivePtr<TIBPort> port = GetIBDevice();
        if (port.Get() == nullptr) {
            printf("No IB device found\n");
            return;
        }

        const int IP_PORT = 13666;
        const int WELCOME_QKEY = 0x1113013;
        const int MAX_SRQ_WORK_REQUESTS = 100;
        const int MAX_CQ_EVENTS = 1000;
        const int QP_SEND_QUEUE_SIZE = 3;

        TIntrusivePtr<TComplectionQueue> cq = new TComplectionQueue(port->GetCtx(), MAX_CQ_EVENTS);

        TIBBufferPool bp(port->GetCtx(), MAX_SRQ_WORK_REQUESTS);

        if (!isClient) {
            // server
            TIPSocket ipSocket;
            ipSocket.Init(IP_PORT);
            if (!ipSocket.IsValid()) {
                printf("UDP port %d is not available\n", IP_PORT);
                return;
            }

            TIntrusivePtr<TComplectionQueue> cqRC = new TComplectionQueue(port->GetCtx(), MAX_CQ_EVENTS);

            // UD queue pair whose address is advertised to clients over UDP
            TIntrusivePtr<TUDQueuePair> welcomeQP = new TUDQueuePair(port, cq, bp.GetSRQ(), QP_SEND_QUEUE_SIZE);
            welcomeQP->Init(WELCOME_QKEY);

            TWelcomeSocketAddr info;
            info.LID = port->GetLID();
            info.QPN = welcomeQP->GetQPN();

            TIntrusivePtr<TAddressHandle> ahPeer1;
            for (;;) {
                ipSocket.Respond(info);
                // poll srq
                ibv_wc wc;
                if (cq->Poll(&wc, 1) == 1 && (wc.opcode & IBV_WC_RECV)) {
                    printf("Got IB handshake\n");

                    // extract the client's RC parameters and return address
                    TRCQueuePairHandshake remoteHandshake;
                    ibv_ah_attr clientAddr;
                    {
                        TIBRecvPacketProcess pkt(bp, wc);
                        remoteHandshake = *(TRCQueuePairHandshake*)pkt.GetUDData();
                        port->GetAHAttr(&wc, pkt.GetGRH(), &clientAddr);
                    }

                    TIntrusivePtr<TAddressHandle> ahPeer2;
                    ahPeer2 = new TAddressHandle(port->GetCtx(), &clientAddr);

                    // connect our RC queue pair to the client's and reply
                    // with our own handshake via the welcome QP
                    TIntrusivePtr<TRCQueuePair> rcTest = new TRCQueuePair(port->GetCtx(), cqRC, bp.GetSRQ(), QP_SEND_QUEUE_SIZE);
                    rcTest->Init(clientAddr, remoteHandshake.QPN, remoteHandshake.PSN);

                    TRCQueuePairHandshake handshake;
                    handshake.PSN = rcTest->GetPSN();
                    handshake.QPN = rcTest->GetQPN();
                    bp.PostSend(welcomeQP, ahPeer2, wc.src_qp, WELCOME_QKEY, &handshake, sizeof(handshake));

                    WaitForRecv(&bp, cqRC, &wc);

                    {
                        TIBRecvPacketProcess pkt(bp, wc);
                        printf("Got RC ping: %s\n", pkt.GetData());
                        const char* ret = "Puk";
                        bp.PostSend(rcTest, ret, strlen(ret) + 1);
                    }

                    for (int i = 0; i < 5; ++i) {
                        WaitForRecv(&bp, cqRC, &wc);
                        TIBRecvPacketProcess pkt(bp, wc);
                        printf("Got RC ping: %s\n", pkt.GetData());
                        const char* ret = "Fine";
                        bp.PostSend(rcTest, ret, strlen(ret) + 1);
                    }
                }
            }
        } else {
            // client
            ibv_wc wc;

            TIPSocket ipSocket;
            ipSocket.Init(0);
            if (!ipSocket.IsValid()) {
                printf("Failed to create UDP socket\n");
                return;
            }

            printf("Connecting to %s\n", serverName);
            TWelcomeSocketAddr info;
            ipSocket.Request(serverName, IP_PORT, &info);
            printf("Got welcome info, lid %d, qpn %d\n", info.LID, info.QPN);

            TIntrusivePtr<TUDQueuePair> welcomeQP = new TUDQueuePair(port, cq, bp.GetSRQ(), QP_SEND_QUEUE_SIZE);
            welcomeQP->Init(WELCOME_QKEY);

            TIntrusivePtr<TRCQueuePair> rcTest = new TRCQueuePair(port->GetCtx(), cq, bp.GetSRQ(), QP_SEND_QUEUE_SIZE);

            // send our RC handshake to the server's welcome queue pair
            TRCQueuePairHandshake handshake;
            handshake.PSN = rcTest->GetPSN();
            handshake.QPN = rcTest->GetQPN();
            TIntrusivePtr<TAddressHandle> serverAH = new TAddressHandle(port, info.LID, 0);
            bp.PostSend(welcomeQP, serverAH, info.QPN, WELCOME_QKEY, &handshake, sizeof(handshake));

            WaitForRecv(&bp, cq, &wc);

            // receive the server's handshake and wire up the RC connection
            ibv_ah_attr serverAddr;
            TRCQueuePairHandshake remoteHandshake;
            {
                TIBRecvPacketProcess pkt(bp, wc);
                printf("Got handshake response\n");
                remoteHandshake = *(TRCQueuePairHandshake*)pkt.GetUDData();
                port->GetAHAttr(&wc, pkt.GetGRH(), &serverAddr);
            }

            rcTest->Init(serverAddr, remoteHandshake.QPN, remoteHandshake.PSN);

            char hiAndy[] = "Hi, Andy";
            bp.PostSend(rcTest, hiAndy, sizeof(hiAndy));
            WaitForRecv(&bp, cq, &wc);
            {
                TIBRecvPacketProcess pkt(bp, wc);
                printf("Got RC pong: %s\n", pkt.GetData());
            }

            for (int i = 0; i < 5; ++i) {
                char howAreYou[] = "How are you?";
                bp.PostSend(rcTest, howAreYou, sizeof(howAreYou));

                WaitForRecv(&bp, cq, &wc);
                {
                    TIBRecvPacketProcess pkt(bp, wc);
                    printf("Got RC pong: %s\n", pkt.GetData());
                }
            }
        }
    }
+}
diff --git a/library/cpp/netliba/v6/ib_test.h b/library/cpp/netliba/v6/ib_test.h
new file mode 100644
index 0000000000..b72c95d7a9
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_test.h
@@ -0,0 +1,5 @@
+#pragma once
+
+namespace NNetliba {
+ void RunIBTest(bool isClient, const char* serverName);
+}
diff --git a/library/cpp/netliba/v6/net_acks.cpp b/library/cpp/netliba/v6/net_acks.cpp
new file mode 100644
index 0000000000..5f4690c264
--- /dev/null
+++ b/library/cpp/netliba/v6/net_acks.cpp
@@ -0,0 +1,194 @@
+#include "stdafx.h"
+#include "net_acks.h"
+#include <util/datetime/cputimer.h>
+
+#include <atomic>
+
+namespace NNetliba {
    // number of samples in TPingTracker's running RTT average window
    const float RTT_AVERAGE_OVER = 15;

    float TCongestionControl::StartWindowSize = 3;
    float TCongestionControl::MaxPacketRate = 0; // unlimited

    bool UseTOSforAcks = false; //true;//
+
    // Toggles sending acks with a distinct TOS (see UseTOSforAcks in the header).
    void EnableUseTOSforAcks(bool enable) {
        UseTOSforAcks = enable;
    }
+
    // Multiplier applied to the congestion-control channel estimate.
    float CONG_CTRL_CHANNEL_INFLATE = 1;

    void SetCongCtrlChannelInflate(float inflate) {
        CONG_CTRL_CHANNEL_INFLATE = inflate;
    }
+
+ //////////////////////////////////////////////////////////////////////////
    // Starts from the pessimistic CONG_CTRL_INITIAL_RTT estimate; an RTTCount
    // of zero makes the first real sample dominate the averages.
    TPingTracker::TPingTracker()
        : AvrgRTT(CONG_CTRL_INITIAL_RTT)
        , AvrgRTT2(CONG_CTRL_INITIAL_RTT * CONG_CTRL_INITIAL_RTT)
        , RTTCount(0)
    {
    }
+
    // Folds one RTT sample into running averages of RTT and RTT^2 (the
    // latter backs the SKO/deviation estimate). The effective averaging
    // window ramps up to RTT_AVERAGE_OVER samples.
    void TPingTracker::RegisterRTT(float rtt) {
        Y_ASSERT(rtt > 0);
        float keep = RTTCount / (RTTCount + 1);
        AvrgRTT *= keep;
        AvrgRTT += (1 - keep) * rtt;
        AvrgRTT2 *= keep;
        AvrgRTT2 += (1 - keep) * Sqr(rtt);
        RTTCount = Min(RTTCount + 1, RTT_AVERAGE_OVER);
        //static int n;
        //if ((++n % 1024) == 0)
        //    printf("Average RTT = %g (sko = %g)\n", GetRTT() * 1000, GetRTTSKO() * 1000);
    }
+
    // Inflates the RTT estimate (and its square, keeping the SKO consistent)
    // when the link degrades, backing off the timeout without a new sample.
    void TPingTracker::IncreaseRTT() {
        const float F_RTT_DECAY_RATE = 1.1f;
        AvrgRTT *= F_RTT_DECAY_RATE;
        AvrgRTT2 *= Sqr(F_RTT_DECAY_RATE);
    }
+
+ //////////////////////////////////////////////////////////////////////////
+ void TAckTracker::Resend() {
+ CurrentPacket = 0;
+ for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i)
+ Congestion->Failure(); // not actually correct but simplifies logic a lot
+ PacketsInFly.clear();
+ DroppedPackets.clear();
+ ResendQueue.clear();
+ for (size_t i = 0; i < AckReceived.size(); ++i)
+ AckReceived[i] = false;
+ }
+
+ int TAckTracker::SelectPacket() {
+ if (!ResendQueue.empty()) {
+ int res = ResendQueue.back();
+ ResendQueue.pop_back();
+ //printf("resending packet %d\n", res);
+ return res;
+ }
+ if (CurrentPacket == PacketCount) {
+ return -1;
+ }
+ return CurrentPacket++;
+ }
+
    TAckTracker::~TAckTracker() {
        // charge every still-unacked in-flight packet as a failure
        for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i)
            Congestion->Failure();
        // object will be incorrect state after this (failed packets are not added to resend queue), but who cares
    }
+
    // Returns the id of the next packet to transmit (or -1 when nothing is
    // pending) and registers it as in flight with the congestion control.
    int TAckTracker::GetPacketToSend(float deltaT) {
        int res = SelectPacket();
        if (res == -1) {
            // needed to count time even if we don't have anything to send
            Congestion->HasTriedToSend();
            return res;
        }
        Congestion->LaunchPacket();
        PacketsInFly[res] = -deltaT; // deltaT is time since last Step(), so for the timing to be correct we should subtract it
        return res;
    }
+
+ // called on SendTo() failure
+ void TAckTracker::AddToResend(int pkt) {
+ //printf("AddToResend(%d)\n", pkt);
+ TPacketHash::iterator i = PacketsInFly.find(pkt);
+ if (i != PacketsInFly.end()) {
+ PacketsInFly.erase(i);
+ Congestion->FailureOnSend();
+ ResendQueue.push_back(pkt);
+ } else
+ Y_ASSERT(0);
+ }
+
    // Processes an acknowledgement for packet `pkt`. deltaT is the time since
    // the last Step() and corrects the stored send timestamps; updateRTT
    // controls whether this sample feeds the RTT estimator.
    void TAckTracker::Ack(int pkt, float deltaT, bool updateRTT) {
        Y_ASSERT(pkt >= 0 && pkt < PacketCount);
        if (AckReceived[pkt])
            return; // duplicate ack
        AckReceived[pkt] = true;
        //printf("Ack received for %d\n", pkt);
        TPacketHash::iterator i = PacketsInFly.find(pkt);
        if (i == PacketsInFly.end()) {
            // not in flight: the packet already timed out — cancel any
            // pending resend of it
            for (size_t k = 0; k < ResendQueue.size(); ++k) {
                if (ResendQueue[k] == pkt) {
                    ResendQueue[k] = ResendQueue.back();
                    ResendQueue.pop_back();
                    break;
                }
            }
            TPacketHash::iterator z = DroppedPackets.find(pkt);
            if (z != DroppedPackets.end()) {
                // late packet arrived
                if (updateRTT) {
                    float ping = z->second + deltaT;
                    Congestion->RegisterRTT(ping);
                }
                DroppedPackets.erase(z);
            } else {
                // Y_ASSERT(0); // ack on nonsent packet, possible in resend scenario
            }
            return;
        }
        if (updateRTT) {
            float ping = i->second + deltaT;
            //printf("Register RTT %g\n", ping * 1000);
            Congestion->RegisterRTT(ping);
        }
        PacketsInFly.erase(i);
        Congestion->Success();
    }
+
+ void TAckTracker::AckAll() {
+ for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i) {
+ int pkt = i->first;
+ AckReceived[pkt] = true;
+ Congestion->Success();
+ }
+ PacketsInFly.clear();
+ }
+
    // Advances timers by deltaT: packets in flight longer than the current
    // timeout move to the resend queue (and to DroppedPackets, so a late ack
    // can still update the RTT), and TimeToNextPacketTimeout is recomputed.
    void TAckTracker::Step(float deltaT) {
        float timeoutVal = Congestion->GetTimeout();

        //static int n;
        //if ((++n % 1024) == 0)
        //    printf("timeout = %g, window = %g, fail_rate %g, pkt_rate = %g\n", timeoutVal * 1000, Congestion->GetWindow(), Congestion->GetFailRate(), (1 - Congestion->GetFailRate()) * Congestion->GetWindow() / Congestion->GetRTT());

        TimeToNextPacketTimeout = 1000;
        // for windows below one we roll the dice once per RTT to decide
        // whether a packet may be launched, so we can wait at most one RTT
        // before a new random value has to be drawn
        if (Congestion->GetWindow() < 1)
            TimeToNextPacketTimeout = Congestion->GetRTT();

        // age the late-arrival records alongside the in-flight packets
        for (auto& droppedPacket : DroppedPackets) {
            float& t = droppedPacket.second;
            t += deltaT;
        }

        for (TPacketHash::iterator i = PacketsInFly.begin(); i != PacketsInFly.end();) {
            float& t = i->second;
            t += deltaT;
            if (t > timeoutVal) {
                //printf("packet %d timed out (timeout = %g)\n", i->first, timeoutVal);
                // keep the send time in DroppedPackets so a late ack can
                // still feed the RTT estimator (see Ack())
                ResendQueue.push_back(i->first);
                DroppedPackets[i->first] = i->second;
                TPacketHash::iterator k = i++;
                PacketsInFly.erase(k);
                Congestion->Failure();
            } else {
                TimeToNextPacketTimeout = Min(TimeToNextPacketTimeout, timeoutVal - t);
                ++i;
            }
        }
    }
+
    // PRNG state seeded from the CPU cycle counter.
    static std::atomic<ui32> netAckRndVal = (ui32)GetCycleCount();
    // Multiplicative congruential generator modulo the prime 4294967291.
    // NOTE(review): the load/store pair is not a single atomic RMW, so
    // concurrent callers may observe the same value — presumably acceptable
    // for this best-effort randomness; confirm if stronger guarantees needed.
    ui32 NetAckRnd() {
        const auto nextNetAckRndVal = static_cast<ui32>(((ui64)netAckRndVal.load(std::memory_order_acquire) * 279470273) % 4294967291);
        netAckRndVal.store(nextNetAckRndVal, std::memory_order_release);
        return nextNetAckRndVal;
    }
+}
diff --git a/library/cpp/netliba/v6/net_acks.h b/library/cpp/netliba/v6/net_acks.h
new file mode 100644
index 0000000000..a497eaf95c
--- /dev/null
+++ b/library/cpp/netliba/v6/net_acks.h
@@ -0,0 +1,528 @@
+#pragma once
+
+#include "net_test.h"
+#include "net_queue_stat.h"
+
+#include <util/system/spinlock.h>
+
+namespace NNetliba {
+ const float MIN_PACKET_RTT_SKO = 0.001f; // avoid drops due to small hiccups
+
+ const float CONG_CTRL_INITIAL_RTT = 0.24f; //0.01f; // taking into account Las Vegas 10ms estimate is too optimistic
+
+ const float CONG_CTRL_WINDOW_GROW = 0.005f;
+ const float CONG_CTRL_WINDOW_SHRINK = 0.9f;
+ const float CONG_CTRL_WINDOW_SHRINK_RTT = 0.95f;
+ const float CONG_CTRL_RTT_MIX_RATE = 0.9f;
+ const int CONG_CTRL_RTT_SEQ_COUNT = 8;
+ const float CONG_CTRL_MIN_WINDOW = 0.01f;
+ const float CONG_CTRL_LARGE_TIME_WINDOW = 10000.0f;
+ const float CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD = 0.4f; // in seconds
+ const float CONG_CTRL_MINIMAL_SEND_INTERVAL = 1;
+ const float CONG_CTRL_MIN_FAIL_INTERVAL = 0.001f;
+ const float CONG_CTRL_ALLOWED_BURST_SIZE = 3;
+ const float CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION = 0.002f;
+
+ const float LAME_MTU_TIMEOUT = 0.3f;
+ const float LAME_MTU_INTERVAL = 0.05f;
+
+ const float START_CHECK_PORT_DELAY = 0.5;
+ const float FINISH_CHECK_PORT_DELAY = 10;
+ const int N_PORT_TEST_COUNT_LIMIT = 256; // or 512
+
+ // if enabled all acks are sent with different TOS, so they end up in different queue
+ // this allows us to limit window based on minimal RTT observed and 1G link assumption
+ extern bool UseTOSforAcks;
+
+    // Smoothed RTT estimator; the mixing logic (RegisterRTT/IncreaseRTT) is
+    // implemented in net_acks.cpp.
+    class TPingTracker {
+        float AvrgRTT, AvrgRTT2; // smoothed RTT and smoothed squared RTT
+        float RTTCount;          // sample counter used by the mixing logic (see net_acks.cpp)
+
+    public:
+        TPingTracker();
+        // Current smoothed round-trip time estimate, in seconds.
+        float GetRTT() const {
+            return AvrgRTT;
+        }
+        // RTT standard deviation, clamped from below by MIN_PACKET_RTT_SKO and
+        // by 5% of the average so small hiccups do not register as drops.
+        float GetRTTSKO() const {
+            float sko = sqrt(fabs(Sqr(AvrgRTT) - AvrgRTT2));
+            float minSKO = Max(MIN_PACKET_RTT_SKO, AvrgRTT * 0.05f);
+            return Max(minSKO, sko);
+        }
+        // Retransmit timeout: mean RTT plus three standard deviations.
+        float GetTimeout() const {
+            return GetRTT() + GetRTTSKO() * 3;
+        }
+        void RegisterRTT(float rtt); // mix a fresh RTT sample into the estimate
+        void IncreaseRTT();          // back the estimate off (used to ping dead hosts less often)
+    };
+
+ ui32 NetAckRnd();
+
+    // Crude MTU probing state machine: allow one ping, then wait
+    // LAME_MTU_INTERVAL before permitting the next; the whole discovery gives
+    // up after LAME_MTU_TIMEOUT of accumulated time.
+    class TLameMTUDiscovery: public TThrRefBase {
+        enum EState {
+            NEED_PING,
+            WAIT,
+        };
+
+        float TimePassed = 0;        // total time spent probing
+        float TimeSinceLastPing = 0; // time since the last probe went out
+        EState State = NEED_PING;
+
+    public:
+        TLameMTUDiscovery() {
+        }
+        bool CanSend() {
+            return State == NEED_PING;
+        }
+        void PingSent() {
+            TimeSinceLastPing = 0;
+            State = WAIT;
+        }
+        bool IsTimedOut() const {
+            return TimePassed > LAME_MTU_TIMEOUT;
+        }
+        void Step(float deltaT) {
+            TimePassed += deltaT;
+            TimeSinceLastPing += deltaT;
+            if (TimeSinceLastPing > LAME_MTU_INTERVAL) {
+                State = NEED_PING;
+            }
+        }
+    };
+
+    // IPeerQueueStats backed by a plain counter; TCongestionControlPtr keeps
+    // Count synchronized with the active transfer count.
+    struct TPeerQueueStats: public IPeerQueueStats {
+        int Count = 0;
+
+        int GetPacketCount() override {
+            return Count;
+        }
+    };
+
+ // pretend we have multiple channels in parallel
+ // not exact approximation since N channels should have N distinct windows
+ extern float CONG_CTRL_CHANNEL_INFLATE;
+
+    // Per-peer congestion controller. Maintains a send window in packets
+    // (grown ~sqrt(Window) per ack at full utilization, shrunk on timeouts and
+    // on sustained RTT growth), a smoothed RTT (TPingTracker), an optional
+    // packet-rate budget (TimeWindow) and peer liveness tracking backed by
+    // TPortUnreachableTester.
+    class TCongestionControl: public TThrRefBase {
+        float Window, PacketsInFly, FailRate;
+        float MinRTT, MaxWindow;
+        bool FullSpeed, DoCountTime;
+        TPingTracker PingTracker;
+        double TimeSinceLastRecv;
+        TAdaptiveLock PortTesterLock;
+        TIntrusivePtr<TPortUnreachableTester> PortTester;
+        int ActiveTransferCount;
+        float AvrgRTT;
+        int HighRTTCounter;
+        float WindowFraction, FractionRecalc;
+        float TimeWindow;
+        double TimeSinceLastFail;
+        float VirtualPackets;
+        int MTU;
+        TIntrusivePtr<TLameMTUDiscovery> MTUDiscovery;
+        TIntrusivePtr<TPeerQueueStats> QueueStats;
+
+        // Window cap from the 1G-link assumption: 125000000 bytes/s * minimal
+        // observed RTT / MTU. Applied only when UseTOSforAcks is set.
+        void CalcMaxWindow() {
+            if (MTU == 0)
+                return;
+            MaxWindow = 125000000 / MTU * Max(0.001f, MinRTT);
+        }
+
+    public:
+        static float StartWindowSize, MaxPacketRate;
+
+    public:
+        TCongestionControl()
+            : Window(StartWindowSize * CONG_CTRL_CHANNEL_INFLATE)
+            , PacketsInFly(0)
+            , FailRate(0)
+            , MinRTT(10)
+            , MaxWindow(10000)
+            , FullSpeed(false)
+            , DoCountTime(false)
+            , TimeSinceLastRecv(0)
+            , ActiveTransferCount(0)
+            , AvrgRTT(0)
+            , HighRTTCounter(0)
+            , WindowFraction(0)
+            , FractionRecalc(0)
+            , TimeWindow(CONG_CTRL_LARGE_TIME_WINDOW)
+            , TimeSinceLastFail(0)
+            , MTU(0)
+        {
+            // start with virtual packets occupying all but a small burst of the window
+            VirtualPackets = Max(Window - CONG_CTRL_ALLOWED_BURST_SIZE, 0.f);
+        }
+        // NOTE: not const — records whether the window is fully utilized (FullSpeed).
+        bool CanSend() {
+            bool res = VirtualPackets + PacketsInFly + WindowFraction <= Window;
+            FullSpeed |= !res;
+            res &= TimeWindow > 0;
+            return res;
+        }
+        // Account one packet leaving: occupies window and rate budget.
+        void LaunchPacket() {
+            PacketsInFly += 1.0f;
+            TimeWindow -= 1.0f;
+        }
+        // Feed one RTT measurement; shrinks the window if RTT stays above
+        // average for CONG_CTRL_RTT_SEQ_COUNT consecutive samples.
+        void RegisterRTT(float RTT) {
+            if (RTT < 0)
+                return;
+            RTT = ClampVal(RTT, 0.0001f, 1.0f);
+            if (RTT < MinRTT && MTU != 0) {
+                MinRTT = RTT;
+                CalcMaxWindow();
+            }
+            MinRTT = Min(MinRTT, RTT);
+
+            PingTracker.RegisterRTT(RTT);
+            if (AvrgRTT == 0)
+                AvrgRTT = RTT;
+            if (RTT > AvrgRTT) {
+                ++HighRTTCounter;
+                if (HighRTTCounter >= CONG_CTRL_RTT_SEQ_COUNT) {
+                    //printf("Too many high RTT in a row\n");
+                    if (FullSpeed) {
+                        float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK_RTT) / CONG_CTRL_CHANNEL_INFLATE);
+                        Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract);
+                        VirtualPackets = Max(0.f, VirtualPackets - windowSubtract);
+                        //printf("reducing window by RTT , new window %g\n", Window);
+                    }
+                    // reduce no more than twice per RTT: restart the counter
+                    // from a negative value proportional to the window size
+                    HighRTTCounter = Min(0, CONG_CTRL_RTT_SEQ_COUNT - (int)(Window * 0.5));
+                }
+            } else {
+                // a low sample clears accumulated positives but keeps the negative cooldown
+                HighRTTCounter = Min(0, HighRTTCounter);
+            }
+
+            float rttMixRate = CONG_CTRL_RTT_MIX_RATE;
+            AvrgRTT = AvrgRTT * rttMixRate + RTT * (1 - rttMixRate);
+        }
+        // Packet acknowledged: free window space, possibly grow the window.
+        void Success() {
+            PacketsInFly -= 1;
+            Y_ASSERT(PacketsInFly >= 0);
+            // FullSpeed should be correct at this point
+            // we assume that after UpdateAlive() we send all packets first then we listen for acks and call Success()
+            // FullSpeed is set in CanSend() during send if we are using full window
+            // do not increase window while send rate is limited by virtual packets (ie start of transfer)
+            if (FullSpeed && VirtualPackets == 0) {
+                // there are 2 requirements for window growth
+                // 1) growth should be proportional to window size to ensure constant FailRate
+                // 2) growth should be constant to ensure fairness among different flows
+                // so lets make it square root :)
+                Window += sqrt(Window / CONG_CTRL_CHANNEL_INFLATE) * CONG_CTRL_WINDOW_GROW;
+                if (UseTOSforAcks) {
+                    Window = Min(Window, MaxWindow);
+                }
+            }
+            FailRate *= 0.99f;
+        }
+        // Local send failure (socket refused the packet): free window space only.
+        void FailureOnSend() {
+            //printf("Failure on send\n");
+            PacketsInFly -= 1;
+            Y_ASSERT(PacketsInFly >= 0);
+            // not a congestion event, do not modify Window
+            // do not set FullSpeed since we are not using full Window
+        }
+        // Packet timed out: congestion event, shrink the window.
+        void Failure() {
+            //printf("Congestion failure\n");
+            PacketsInFly -= 1;
+            Y_ASSERT(PacketsInFly >= 0);
+            // account limited number of fails per segment
+            if (TimeSinceLastFail > CONG_CTRL_MIN_FAIL_INTERVAL) {
+                TimeSinceLastFail = 0;
+                if (Window <= CONG_CTRL_MIN_WINDOW) {
+                    // ping dead hosts less frequently
+                    if (PingTracker.GetRTT() / CONG_CTRL_MIN_WINDOW < CONG_CTRL_MINIMAL_SEND_INTERVAL)
+                        PingTracker.IncreaseRTT();
+                    Window = CONG_CTRL_MIN_WINDOW;
+                    VirtualPackets = 0;
+                } else {
+                    float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK) / CONG_CTRL_CHANNEL_INFLATE);
+                    Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract);
+                    VirtualPackets = Max(0.f, VirtualPackets - windowSubtract);
+                }
+            }
+            FailRate = FailRate * 0.99f + 0.01f;
+        }
+        bool HasPacketsInFly() const {
+            return PacketsInFly > 0;
+        }
+        float GetTimeout() const {
+            return PingTracker.GetTimeout();
+        }
+        float GetWindow() const {
+            return Window;
+        }
+        float GetRTT() const {
+            return PingTracker.GetRTT();
+        }
+        float GetFailRate() const {
+            return FailRate;
+        }
+        float GetTimeSinceLastRecv() const {
+            return TimeSinceLastRecv;
+        }
+        int GetTransferCount() const {
+            return ActiveTransferCount;
+        }
+        // -1 means "no cap" (cap only applies in UseTOSforAcks mode).
+        float GetMaxWindow() const {
+            return UseTOSforAcks ? MaxWindow : -1;
+        }
+        // Peer responded: reset the dead-peer timer and drop any port tester.
+        void MarkAlive() {
+            TimeSinceLastRecv = 0;
+
+            with_lock (PortTesterLock) {
+                PortTester = nullptr;
+            }
+
+        }
+        void HasTriedToSend() {
+            DoCountTime = true;
+        }
+        bool IsAlive() const {
+            return TimeSinceLastRecv < 1e6f;
+        }
+        void Kill() {
+            TimeSinceLastRecv = 1e6f;
+        }
+        // Per-tick housekeeping: virtual-packet decay, window fraction reroll,
+        // rate budget refill, dead-peer detection. Returns false when the peer
+        // is declared dead (timeout or port-unreachable).
+        bool UpdateAlive(const TUdpAddress& toAddress, float deltaT, float timeout, float* resMaxWaitTime) {
+            if (!FullSpeed) {
+                // create virtual packets during idle to avoid burst on transmit start
+                if (AvrgRTT > CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION) {
+                    VirtualPackets = Max(VirtualPackets, Window - PacketsInFly - CONG_CTRL_ALLOWED_BURST_SIZE);
+                }
+            } else {
+                if (VirtualPackets > 0) {
+                    if (Window <= CONG_CTRL_ALLOWED_BURST_SIZE) {
+                        VirtualPackets = 0;
+                    }
+                    float xRTT = AvrgRTT == 0 ? CONG_CTRL_INITIAL_RTT : AvrgRTT;
+                    float virtualPktsPerSecond = Window / xRTT;
+                    VirtualPackets = Max(0.f, VirtualPackets - deltaT * virtualPktsPerSecond);
+                    *resMaxWaitTime = Min(*resMaxWaitTime, 0.001f); // need to update virtual packets counter regularly
+                }
+            }
+            float currentRTT = GetRTT();
+            FractionRecalc += deltaT;
+            if (FractionRecalc > currentRTT) {
+                // reroll the fractional window part once per RTT
+                int cycleCount = (int)(FractionRecalc / currentRTT);
+                FractionRecalc -= currentRTT * cycleCount;
+                WindowFraction = (NetAckRnd() & 1023) * (1 / 1023.0f) / cycleCount;
+            }
+
+            if (MaxPacketRate > 0 && AvrgRTT > 0) {
+                float maxTimeWindow = CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD * MaxPacketRate;
+                TimeWindow = Min(maxTimeWindow, TimeWindow + MaxPacketRate * deltaT);
+            } else
+                TimeWindow = CONG_CTRL_LARGE_TIME_WINDOW;
+
+            // guarantee minimal send rate
+            if (currentRTT > CONG_CTRL_MINIMAL_SEND_INTERVAL * Window) {
+                Window = Max(CONG_CTRL_MIN_WINDOW, currentRTT / CONG_CTRL_MINIMAL_SEND_INTERVAL);
+                VirtualPackets = 0;
+            }
+
+            TimeSinceLastFail += deltaT;
+
+            //static int n;
+            //if ((++n & 127) == 0)
+            //    printf("window = %g, fly = %g, VirtualPkts = %g, deltaT = %g, FailRate = %g FullSpeed = %d AvrgRTT = %g\n",
+            //        Window, PacketsInFly, VirtualPackets, deltaT * 1000, FailRate, (int)FullSpeed, AvrgRTT * 1000);
+
+            if (PacketsInFly > 0 || FullSpeed || DoCountTime) {
+                // count time only while there are packets in flight
+                TimeSinceLastRecv += deltaT;
+                if (TimeSinceLastRecv > START_CHECK_PORT_DELAY) {
+                    if (TimeSinceLastRecv < FINISH_CHECK_PORT_DELAY) {
+                        TIntrusivePtr<TPortUnreachableTester> portTester;
+                        with_lock (PortTesterLock) {
+                            portTester = PortTester;
+                        }
+
+                        if (!portTester && AtomicGet(ActivePortTestersCount) < N_PORT_TEST_COUNT_LIMIT) {
+                            portTester = new TPortUnreachableTester();
+                            with_lock (PortTesterLock) {
+                                PortTester = portTester;
+                            }
+
+                            if (portTester->IsValid()) {
+                                portTester->Connect(toAddress);
+                            } else {
+                                with_lock (PortTesterLock) {
+                                    PortTester = nullptr;
+                                }
+                            }
+                        }
+                        if (portTester && !portTester->Test(deltaT)) {
+                            Kill();
+                            return false;
+                        }
+                    } else {
+                        with_lock (PortTesterLock) {
+                            PortTester = nullptr;
+                        }
+                    }
+                }
+                if (TimeSinceLastRecv > timeout) {
+                    Kill();
+                    return false;
+                }
+            }
+
+            FullSpeed = false;
+            DoCountTime = false;
+
+            if (MTUDiscovery.Get())
+                MTUDiscovery->Step(deltaT);
+
+            return true;
+        }
+        bool IsKnownMTU() const {
+            return MTU != 0;
+        }
+        int GetMTU() const {
+            return MTU;
+        }
+        // Lazily creates the discovery object on first use.
+        TLameMTUDiscovery* GetMTUDiscovery() {
+            if (MTUDiscovery.Get() == nullptr)
+                MTUDiscovery = new TLameMTUDiscovery;
+            return MTUDiscovery.Get();
+        }
+        // Fixes the MTU, ends discovery and recomputes the window cap.
+        void SetMTU(int sz) {
+            MTU = sz;
+            MTUDiscovery = nullptr;
+            CalcMaxWindow();
+        }
+        void AttachQueueStats(TIntrusivePtr<TPeerQueueStats> s) {
+            if (s.Get()) {
+                s->Count = ActiveTransferCount;
+            }
+            Y_ASSERT(QueueStats.Get() == nullptr);
+            QueueStats = s;
+        }
+        friend class TCongestionControlPtr;
+    };
+
+    // Smart pointer to TCongestionControl that additionally counts active
+    // transfers: each instance holding a non-null target contributes one to
+    // ActiveTransferCount (mirrored into the attached TPeerQueueStats).
+    class TCongestionControlPtr {
+        TIntrusivePtr<TCongestionControl> Ptr;
+
+        // Register this pointer as one active transfer.
+        void Inc() {
+            if (Ptr.Get()) {
+                ++Ptr->ActiveTransferCount;
+                if (Ptr->QueueStats.Get()) {
+                    Ptr->QueueStats->Count = Ptr->ActiveTransferCount;
+                }
+            }
+        }
+        // Unregister this pointer.
+        void Dec() {
+            if (Ptr.Get()) {
+                --Ptr->ActiveTransferCount;
+                if (Ptr->QueueStats.Get()) {
+                    Ptr->QueueStats->Count = Ptr->ActiveTransferCount;
+                }
+            }
+        }
+
+    public:
+        TCongestionControlPtr() {
+        }
+        ~TCongestionControlPtr() {
+            Dec();
+        }
+        TCongestionControlPtr(TCongestionControl* p)
+            : Ptr(p)
+        {
+            Inc();
+        }
+        // BUGFIX: the implicitly-generated copy constructor copied Ptr without
+        // calling Inc(), so the copy's destructor later ran Dec() and
+        // underflowed ActiveTransferCount; define copying explicitly.
+        TCongestionControlPtr(const TCongestionControlPtr& a)
+            : Ptr(a.Ptr)
+        {
+            Inc();
+        }
+        TCongestionControlPtr& operator=(const TCongestionControlPtr& a) {
+            Dec();
+            Ptr = a.Ptr;
+            Inc();
+            return *this;
+        }
+        TCongestionControlPtr& operator=(TCongestionControl* a) {
+            Dec();
+            Ptr = a;
+            Inc();
+            return *this;
+        }
+        operator TCongestionControl*() const {
+            return Ptr.Get();
+        }
+        TCongestionControl* operator->() const {
+            return Ptr.Get();
+        }
+        TIntrusivePtr<TCongestionControl> Get() const {
+            return Ptr;
+        }
+    };
+
+    // Sender-side per-transfer acknowledgement tracker: which packets are in
+    // flight, which timed out and await resend, which are acked. Timeouts and
+    // acks are charged to the attached TCongestionControl.
+    class TAckTracker {
+        // NOTE(review): TFlyingPacket appears unused inside this class —
+        // candidate for removal; verify against the .cpp before deleting.
+        struct TFlyingPacket {
+            float T;
+            int PktId;
+            TFlyingPacket()
+                : T(0)
+                , PktId(-1)
+            {
+            }
+            TFlyingPacket(float t, int pktId)
+                : T(t)
+                , PktId(pktId)
+            {
+            }
+        };
+        int PacketCount, CurrentPacket;
+        typedef THashMap<int, float> TPacketHash; // packet id -> time in flight, seconds
+        TPacketHash PacketsInFly, DroppedPackets;
+        TVector<int> ResendQueue;          // ids queued for retransmission
+        TCongestionControlPtr Congestion;  // also counts this tracker as an active transfer
+        TVector<bool> AckReceived;         // per-packet ack flags, sized by SetPacketCount()
+        float TimeToNextPacketTimeout;     // recomputed in Step()
+
+        int SelectPacket();
+
+    public:
+        TAckTracker()
+            : PacketCount(0)
+            , CurrentPacket(0)
+            , TimeToNextPacketTimeout(1000)
+        {
+        }
+        ~TAckTracker();
+        void AttachCongestionControl(TCongestionControl* p) {
+            Congestion = p;
+        }
+        TIntrusivePtr<TCongestionControl> GetCongestionControl() const {
+            return Congestion.Get();
+        }
+        // One-shot initialization of the number of packets in the transfer.
+        void SetPacketCount(int n) {
+            Y_ASSERT(PacketCount == 0);
+            PacketCount = n;
+            AckReceived.resize(n, false);
+        }
+        void Resend();
+        bool IsInitialized() {
+            return PacketCount != 0;
+        }
+        int GetPacketToSend(float deltaT);
+        void AddToResend(int pkt); // called when failed to send packet
+        void Ack(int pkt, float deltaT, bool updateRTT);
+        void AckAll();
+        void MarkAlive() {
+            Congestion->MarkAlive();
+        }
+        bool IsAlive() const {
+            return Congestion->IsAlive();
+        }
+        void Step(float deltaT);
+        bool CanSend() const {
+            return Congestion->CanSend();
+        }
+        float GetTimeToNextPacketTimeout() const {
+            return TimeToNextPacketTimeout;
+        }
+    };
+}
diff --git a/library/cpp/netliba/v6/net_queue_stat.h b/library/cpp/netliba/v6/net_queue_stat.h
new file mode 100644
index 0000000000..75e4b10e35
--- /dev/null
+++ b/library/cpp/netliba/v6/net_queue_stat.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+
+namespace NNetliba {
+    // Minimal ref-counted stats interface: implementations report a packet /
+    // transfer count for a peer queue.
+    struct IPeerQueueStats: public TThrRefBase {
+        virtual int GetPacketCount() = 0;
+    };
+}
diff --git a/library/cpp/netliba/v6/net_request.cpp b/library/cpp/netliba/v6/net_request.cpp
new file mode 100644
index 0000000000..e3688f2835
--- /dev/null
+++ b/library/cpp/netliba/v6/net_request.cpp
@@ -0,0 +1,5 @@
+#include "stdafx.h"
+#include "net_request.h"
+
+namespace NNetliba {
+}
diff --git a/library/cpp/netliba/v6/net_request.h b/library/cpp/netliba/v6/net_request.h
new file mode 100644
index 0000000000..15bc47ae21
--- /dev/null
+++ b/library/cpp/netliba/v6/net_request.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <util/generic/guid.h>
+#include "udp_address.h"
+
+namespace NNetliba {
+ class TRopeDataPacket;
+
+    // A request in transit: the peer address it is associated with, its
+    // globally unique identifier and the owned payload.
+    struct TRequest {
+        TUdpAddress Address;            // peer address associated with the request
+        TGUID Guid;                     // request identifier
+        TAutoPtr<TRopeDataPacket> Data; // owned payload (TRopeDataPacket defined elsewhere)
+    };
+
+}
diff --git a/library/cpp/netliba/v6/net_test.cpp b/library/cpp/netliba/v6/net_test.cpp
new file mode 100644
index 0000000000..63fd7c4067
--- /dev/null
+++ b/library/cpp/netliba/v6/net_test.cpp
@@ -0,0 +1,50 @@
+#include "stdafx.h"
+#include "net_test.h"
+#include "udp_address.h"
+
+#ifndef _win_
+#include <errno.h>
+#endif
+
+namespace NNetliba {
+ TAtomic ActivePortTestersCount;
+
+ const float PING_DELAY = 0.5f;
+
+    // Opens a throwaway UDP socket (port 0 = ephemeral) and, if that worked,
+    // registers this tester in the global active-tester counter.
+    TPortUnreachableTester::TPortUnreachableTester()
+        : TimePassed(0)
+        , ConnectOk(false)
+    {
+        s.Open(0);
+        if (s.IsValid())
+            AtomicAdd(ActivePortTestersCount, 1);
+    }
+
+    // Binds the probe socket to a single destination; per net_test.h, FreeBSD
+    // only reports port-unreachable errors on connected sockets.
+    void TPortUnreachableTester::Connect(const TUdpAddress& addr) {
+        Y_ASSERT(IsValid());
+        sockaddr_in6 target;
+        GetWinsockAddr(&target, addr);
+        TimePassed = 0;
+        ConnectOk = s.Connect(target);
+    }
+
+    // Unregister from the global counter iff the constructor registered us.
+    TPortUnreachableTester::~TPortUnreachableTester() {
+        if (s.IsValid()) {
+            AtomicAdd(ActivePortTestersCount, -1);
+        }
+    }
+
+    // Returns false once the peer is known unreachable (failed connect or an
+    // ICMP port-unreachable); otherwise keeps the probe alive, re-pinging
+    // every PING_DELAY seconds.
+    bool TPortUnreachableTester::Test(float deltaT) {
+        if (!ConnectOk || s.IsHostUnreachable()) {
+            return false;
+        }
+        TimePassed += deltaT;
+        if (TimePassed > PING_DELAY) {
+            s.SendEmptyPacket();
+            TimePassed = 0;
+        }
+        return true;
+    }
+}
diff --git a/library/cpp/netliba/v6/net_test.h b/library/cpp/netliba/v6/net_test.h
new file mode 100644
index 0000000000..cfff7409f5
--- /dev/null
+++ b/library/cpp/netliba/v6/net_test.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "udp_socket.h"
+
+namespace NNetliba {
+ struct TUdpAddress;
+
+ // needed to limit simultaneous port testers to avoid limit on open handles count
+ extern TAtomic ActivePortTestersCount;
+
+ // need separate socket for each destination
+ // FreeBSD can not return port unreachable error for unconnected socket
+    // Probes a single peer for ICMP port-unreachable responses. One socket per
+    // destination (see comment above): FreeBSD can not return the error for an
+    // unconnected socket. Ref-counted; destructor is private so instances are
+    // destroyed only through TIntrusivePtr.
+    class TPortUnreachableTester: public TThrRefBase {
+        TNetSocket s;     // dedicated probe socket
+        float TimePassed; // time since the last ping was sent
+        bool ConnectOk;   // whether Connect() succeeded
+
+        ~TPortUnreachableTester() override;
+
+    public:
+        TPortUnreachableTester();
+        bool IsValid() const {
+            return s.IsValid();
+        }
+        void Connect(const TUdpAddress& addr);
+        bool Test(float deltaT); // false => peer is unreachable
+    };
+}
diff --git a/library/cpp/netliba/v6/stdafx.cpp b/library/cpp/netliba/v6/stdafx.cpp
new file mode 100644
index 0000000000..fd4f341c7b
--- /dev/null
+++ b/library/cpp/netliba/v6/stdafx.cpp
@@ -0,0 +1 @@
+#include "stdafx.h"
diff --git a/library/cpp/netliba/v6/stdafx.h b/library/cpp/netliba/v6/stdafx.h
new file mode 100644
index 0000000000..42ea5c0894
--- /dev/null
+++ b/library/cpp/netliba/v6/stdafx.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "cstdafx.h"
+
+#include <util/system/compat.h>
+#include <util/network/init.h>
+#if defined(_unix_)
+#include <netdb.h>
+#include <fcntl.h>
+#elif defined(_win_)
+#include <winsock2.h>
+using socklen_t = int;
+#endif
+
+#include <util/generic/ptr.h>
+
+// Reinterpret a raw byte pointer as T* (const and mutable overloads).
+// NOTE(review): despite the name, this is plain cast-based type punning and is
+// still formally undefined behavior under strict aliasing — callers (e.g.
+// udp_address.cpp) rely on the compiler being permissive here; a memcpy-based
+// alternative would be the standards-clean fix, but would change the
+// write-through-pointer usage pattern.
+template <class T>
+static const T* BreakAliasing(const void* f) {
+    return (const T*)f;
+}
+
+template <class T>
+static T* BreakAliasing(void* f) {
+    return (T*)f;
+}
diff --git a/library/cpp/netliba/v6/udp_address.cpp b/library/cpp/netliba/v6/udp_address.cpp
new file mode 100644
index 0000000000..17540602e9
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_address.cpp
@@ -0,0 +1,300 @@
+#include "stdafx.h"
+#include "udp_address.h"
+
+#include <util/system/mutex.h>
+#include <util/system/spinlock.h>
+
+#ifdef _win_
+#include <iphlpapi.h>
+#pragma comment(lib, "Iphlpapi.lib")
+#else
+#include <errno.h>
+#include <ifaddrs.h>
+#endif
+
+namespace NNetliba {
+    // Quick syntactic check for IPv6 literals: hex groups of at most 4 digits
+    // separated by ':', at most one "::", exactly 7 separators unless "::" is
+    // present, and an optional "%scope" suffix (scope text is not validated).
+    // Used to distinguish bare IPv6 from host:port in ParseInetName, so it is
+    // deliberately permissive rather than a full RFC validator.
+    static bool IsValidIPv6(const char* sz) {
+        enum {
+            S1,        // inside (or before) a hex group
+            SEMICOLON, // just consumed a ':'
+            SCOPE      // past '%', consuming the scope id
+        };
+        // hasDoubleSemicolon is an int used as a bool
+        int state = S1, scCount = 0, digitCount = 0, hasDoubleSemicolon = false;
+        while (*sz) {
+            if (state == S1) {
+                switch (*sz) {
+                    case '0':
+                    case '1':
+                    case '2':
+                    case '3':
+                    case '4':
+                    case '5':
+                    case '6':
+                    case '7':
+                    case '8':
+                    case '9':
+                    case 'A':
+                    case 'B':
+                    case 'C':
+                    case 'D':
+                    case 'E':
+                    case 'F':
+                    case 'a':
+                    case 'b':
+                    case 'c':
+                    case 'd':
+                    case 'e':
+                    case 'f':
+                        ++digitCount;
+                        if (digitCount > 4)
+                            return false;
+                        break;
+                    case ':':
+                        state = SEMICOLON;
+                        ++scCount;
+                        break;
+                    case '%':
+                        state = SCOPE;
+                        break;
+                    default:
+                        return false;
+                }
+                ++sz;
+            } else if (state == SEMICOLON) {
+                if (*sz == ':') {
+                    // second ':' in a row => "::", allowed at most once
+                    if (hasDoubleSemicolon)
+                        return false;
+                    hasDoubleSemicolon = true;
+                    ++scCount;
+                    digitCount = 0;
+                    state = S1;
+                    ++sz;
+                } else {
+                    // single ':' — start a new group without consuming the char
+                    digitCount = 0;
+                    state = S1;
+                }
+            } else if (state == SCOPE) {
+                // arbitrary string is allowed as scope id
+                ++sz;
+            }
+        }
+        if (!hasDoubleSemicolon && scCount != 7)
+            return false;
+        return scCount <= 7;
+    }
+
+    // Parses "host", "host:port", "1.2.3.4:80", "fe34::12" and "[fe34::12]:80"
+    // into a TUdpAddress, resolving the host through getaddrinfo (retrying on
+    // EAI_AGAIN) and taking the first result matching addressType.
+    // Returns false when the name does not parse or resolve.
+    static bool ParseInetName(TUdpAddress* pRes, const char* name, int nDefaultPort, EUdpAddressType addressType) {
+        int nPort = nDefaultPort;
+
+        TString host;
+        if (name[0] == '[') {
+            // bracketed IPv6 literal, optionally followed by ":port"
+            ++name;
+            const char* nameFin = name;
+            for (; *nameFin; ++nameFin) {
+                if (nameFin[0] == ']')
+                    break;
+            }
+            host.assign(name, nameFin);
+            Y_ASSERT(IsValidIPv6(host.c_str()));
+            name = *nameFin ? nameFin + 1 : nameFin;
+            if (name[0] == ':') {
+                char* endPtr = nullptr;
+                nPort = strtol(name + 1, &endPtr, 10);
+                if (!endPtr || *endPtr != '\0')
+                    return false;
+            }
+        } else {
+            host = name;
+            // a bare IPv6 literal contains ':' itself, so only split off the
+            // port when the string is not valid IPv6
+            if (!IsValidIPv6(name)) {
+                size_t nIdx = host.find(':');
+                if (nIdx != (size_t)TString::npos) {
+                    const char* pszPort = host.c_str() + nIdx + 1;
+                    char* endPtr = nullptr;
+                    nPort = strtol(pszPort, &endPtr, 10);
+                    if (!endPtr || *endPtr != '\0')
+                        return false;
+                    host.resize(nIdx);
+                }
+            }
+        }
+
+        addrinfo aiHints;
+        Zero(aiHints);
+        aiHints.ai_family = AF_UNSPEC;
+        aiHints.ai_socktype = SOCK_DGRAM;
+        aiHints.ai_protocol = IPPROTO_UDP;
+
+        // Do not use TMutex here: it has a non-trivial destructor which will be called before
+        // destruction of current thread, if its TThread declared as global/static variable.
+        static TAdaptiveLock cs;
+        TGuard lock(cs);
+
+        addrinfo* aiList = nullptr;
+        bool resolved = false;
+        for (int attempt = 0; attempt < 1000; ++attempt) {
+            int rv = getaddrinfo(host.c_str(), "1313", &aiHints, &aiList);
+            if (rv == 0) {
+                resolved = true;
+                break;
+            }
+            if (aiList) {
+                freeaddrinfo(aiList);
+                // BUGFIX: reset the pointer — previously a stale (freed) aiList
+                // survived here; if all 1000 attempts returned EAI_AGAIN the
+                // code below would iterate and double-free dangling memory.
+                aiList = nullptr;
+            }
+            if (rv != EAI_AGAIN) {
+                return false;
+            }
+            usleep(100 * 1000);
+        }
+        if (!resolved) {
+            return false; // retry budget exhausted
+        }
+        for (addrinfo* ptr = aiList; ptr; ptr = ptr->ai_next) {
+            sockaddr* addr = ptr->ai_addr;
+            if (addr == nullptr)
+                continue;
+            switch (addressType) {
+                case UAT_ANY: {
+                    if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6)
+                        continue;
+                    break;
+                }
+                case UAT_IPV4: {
+                    if (addr->sa_family != AF_INET)
+                        continue;
+                    break;
+                }
+                case UAT_IPV6: {
+                    if (addr->sa_family != AF_INET6)
+                        continue;
+                    break;
+                }
+            }
+
+            // port "1313" above was only a placeholder for resolution; the
+            // parsed/default port wins
+            GetUdpAddress(pRes, *(sockaddr_in6*)addr);
+            pRes->Port = nPort;
+            freeaddrinfo(aiList);
+            return true;
+        }
+        freeaddrinfo(aiList);
+        return false;
+    }
+
+    // Appends every usable local IPv4/IPv6 unicast address to *addrs.
+    // Windows: GetAdaptersAddresses, skipping tunnels and down adapters.
+    // Unix: getifaddrs. Returns false only when the OS enumeration call fails.
+    bool GetLocalAddresses(TVector<TUdpAddress>* addrs) {
+#ifdef _win_
+        TVector<char> buf;
+        buf.resize(1000000);
+        PIP_ADAPTER_ADDRESSES adapterBuf = (PIP_ADAPTER_ADDRESSES)&buf[0];
+        ULONG bufSize = buf.ysize();
+
+        ULONG rv = GetAdaptersAddresses(AF_UNSPEC, 0, NULL, adapterBuf, &bufSize);
+        if (rv != ERROR_SUCCESS)
+            return false;
+        for (PIP_ADAPTER_ADDRESSES ptr = adapterBuf; ptr; ptr = ptr->Next) {
+            if ((ptr->Flags & (IP_ADAPTER_IPV4_ENABLED | IP_ADAPTER_IPV6_ENABLED)) == 0) {
+                continue;
+            }
+            if (ptr->IfType == IF_TYPE_TUNNEL) {
+                // ignore tunnels
+                continue;
+            }
+            if (ptr->OperStatus != IfOperStatusUp) {
+                // ignore disabled adapters
+                continue;
+            }
+            if (ptr->Mtu < 1280) {
+                // warn but still include the adapter
+                fprintf(stderr, "WARNING: MTU %d is less then ipv6 minimum", ptr->Mtu);
+            }
+            for (IP_ADAPTER_UNICAST_ADDRESS* addr = ptr->FirstUnicastAddress; addr; addr = addr->Next) {
+                sockaddr* x = (sockaddr*)addr->Address.lpSockaddr;
+                if (x == 0)
+                    continue;
+                if (x->sa_family == AF_INET || x->sa_family == AF_INET6) {
+                    TUdpAddress address;
+                    sockaddr_in6* xx = (sockaddr_in6*)x;
+                    GetUdpAddress(&address, *xx);
+                    addrs->push_back(address);
+                }
+            }
+        }
+        return true;
+#else
+        ifaddrs* ifap;
+        if (getifaddrs(&ifap) != -1) {
+            for (ifaddrs* ifa = ifap; ifa; ifa = ifa->ifa_next) {
+                sockaddr* sa = (sockaddr*)ifa->ifa_addr;
+                if (sa == nullptr)
+                    continue;
+                if (sa->sa_family == AF_INET || sa->sa_family == AF_INET6) {
+                    TUdpAddress address;
+                    sockaddr_in6* xx = (sockaddr_in6*)sa;
+                    GetUdpAddress(&address, *xx);
+                    addrs->push_back(address);
+                }
+            }
+            freeifaddrs(ifap);
+            return true;
+        }
+        return false;
+#endif
+    }
+
+    // Converts a sockaddr (either sockaddr_in, smuggled in as sockaddr_in6, or
+    // a real sockaddr_in6) into TUdpAddress. IPv4 is stored in mapped form:
+    // Network = 0, Interface = 0xffff0000 | (s_addr << 32) — the pattern that
+    // TUdpAddress::IsIPv4() checks for. Unknown families leave *res untouched.
+    void GetUdpAddress(TUdpAddress* res, const sockaddr_in6& addr) {
+        if (addr.sin6_family == AF_INET) {
+            const sockaddr_in& addr4 = *(const sockaddr_in*)&addr;
+            res->Network = 0;
+            res->Interface = 0xffff0000ll + (((ui64)(ui32)addr4.sin_addr.s_addr) << 32);
+            res->Scope = 0;
+            res->Port = ntohs(addr4.sin_port);
+        } else if (addr.sin6_family == AF_INET6) {
+            // raw 16-byte IPv6 address split into two ui64 halves (network byte order)
+            res->Network = *BreakAliasing<ui64>(addr.sin6_addr.s6_addr + 0);
+            res->Interface = *BreakAliasing<ui64>(addr.sin6_addr.s6_addr + 8);
+            res->Scope = addr.sin6_scope_id;
+            res->Port = ntohs(addr.sin6_port);
+        }
+    }
+
+    // Converts TUdpAddress back into a sockaddr_in6. Always emits AF_INET6:
+    // IPv4 addresses are carried in their mapped form as stored by
+    // GetUdpAddress. The disabled if(0) branch keeps the never-used plain-IPv4
+    // encoding around for reference.
+    void GetWinsockAddr(sockaddr_in6* res, const TUdpAddress& addr) {
+        if (0) { //addr.IsIPv4()) {
+            // use ipv4 to ipv6 mapping
+            //// ipv4
+            //sockaddr_in &toAddress = *(sockaddr_in*)res;
+            //Zero(toAddress);
+            //toAddress.sin_family = AF_INET;
+            //toAddress.sin_addr.s_addr = addr.GetIPv4();
+            //toAddress.sin_port = htons((u_short)addr.Port);
+        } else {
+            // ipv6
+            sockaddr_in6& toAddress = *(sockaddr_in6*)res;
+            Zero(toAddress);
+            toAddress.sin6_family = AF_INET6;
+            *BreakAliasing<ui64>(toAddress.sin6_addr.s6_addr + 0) = addr.Network;
+            *BreakAliasing<ui64>(toAddress.sin6_addr.s6_addr + 8) = addr.Interface;
+            toAddress.sin6_scope_id = addr.Scope;
+            toAddress.sin6_port = htons((u_short)addr.Port);
+        }
+    }
+
+    // Resolves "server" (any format ParseInetName accepts) to a TUdpAddress.
+    // NOTE: a parse/resolve failure is silently ignored — callers get a
+    // default-constructed (all-zero) address in that case.
+    TUdpAddress CreateAddress(const TString& server, int defaultPort, EUdpAddressType addressType) {
+        TUdpAddress res;
+        ParseInetName(&res, server.c_str(), defaultPort, addressType);
+        return res;
+    }
+
+    // Renders the address as "a.b.c.d:port" for IPv4 or "[x:x:...:x%scope]:port"
+    // for IPv6. buf/suffix are comfortably oversized for either format.
+    TString GetAddressAsString(const TUdpAddress& addr) {
+        char buf[1000];
+        if (addr.IsIPv4()) {
+            int ip = addr.GetIPv4();
+            // s_addr is in network byte order, hence the low byte first
+            sprintf(buf, "%d.%d.%d.%d:%d",
+                    (ip >> 0) & 0xff, (ip >> 8) & 0xff,
+                    (ip >> 16) & 0xff, (ip >> 24) & 0xff,
+                    addr.Port);
+        } else {
+            // reassemble the 8 16-bit groups and byte-swap each for printing
+            ui16 ipv6[8];
+            *BreakAliasing<ui64>(ipv6) = addr.Network;
+            *BreakAliasing<ui64>(ipv6 + 4) = addr.Interface;
+            char suffix[100] = "";
+            if (addr.Scope != 0) {
+                sprintf(suffix, "%%%d", addr.Scope);
+            }
+            sprintf(buf, "[%x:%x:%x:%x:%x:%x:%x:%x%s]:%d",
+                    ntohs(ipv6[0]), ntohs(ipv6[1]), ntohs(ipv6[2]), ntohs(ipv6[3]),
+                    ntohs(ipv6[4]), ntohs(ipv6[5]), ntohs(ipv6[6]), ntohs(ipv6[7]),
+                    suffix, addr.Port);
+        }
+        return buf;
+    }
diff --git a/library/cpp/netliba/v6/udp_address.h b/library/cpp/netliba/v6/udp_address.h
new file mode 100644
index 0000000000..3e283fe545
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_address.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <util/generic/string.h>
+#include <util/generic/vector.h>
+#include <util/system/defaults.h>
+
+struct sockaddr_in6;
+
+namespace NNetliba {
+    // Value type holding an IPv6 (or mapped IPv4) address: the 16 address
+    // bytes split into two ui64 halves (Network = first 8 bytes, Interface =
+    // last 8, network byte order), plus scope id and port. IPv4 addresses are
+    // stored with Network == 0 and Interface == 0xffff0000 | (ip << 32) — see
+    // GetUdpAddress() in udp_address.cpp.
+    struct TUdpAddress {
+        ui64 Network, Interface;
+        int Scope, Port;
+
+        TUdpAddress()
+            : Network(0)
+            , Interface(0)
+            , Scope(0)
+            , Port(0)
+        {
+        }
+        // True when this holds a mapped IPv4 address.
+        bool IsIPv4() const {
+            return (Network == 0 && (Interface & 0xffffffffll) == 0xffff0000ll);
+        }
+        // The raw 32-bit IPv4 address (network byte order); only meaningful when IsIPv4().
+        ui32 GetIPv4() const {
+            return Interface >> 32;
+        }
+    };
+
+ inline bool operator==(const TUdpAddress& a, const TUdpAddress& b) {
+ return a.Network == b.Network && a.Interface == b.Interface && a.Scope == b.Scope && a.Port == b.Port;
+ }
+
+ enum EUdpAddressType {
+ UAT_ANY,
+ UAT_IPV4,
+ UAT_IPV6,
+ };
+
+ // accepts sockaddr_in & sockaddr_in6
+ void GetUdpAddress(TUdpAddress* res, const sockaddr_in6& addr);
+ // generates sockaddr_in6 for both ipv4 & ipv6
+ void GetWinsockAddr(sockaddr_in6* res, const TUdpAddress& addr);
+ // supports formats like hostname, hostname:124, 127.0.0.1, 127.0.0.1:80, fe34::12, [fe34::12]:80
+ TUdpAddress CreateAddress(const TString& server, int defaultPort, EUdpAddressType type = UAT_ANY);
+ TString GetAddressAsString(const TUdpAddress& addr);
+
+ bool GetLocalAddresses(TVector<TUdpAddress>* addrs);
+}
diff --git a/library/cpp/netliba/v6/udp_client_server.cpp b/library/cpp/netliba/v6/udp_client_server.cpp
new file mode 100644
index 0000000000..3eaf6e5e96
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_client_server.cpp
@@ -0,0 +1,1321 @@
+#include "stdafx.h"
+#include "udp_client_server.h"
+#include "net_acks.h"
+#include <util/generic/guid.h>
+#include <util/system/hp_timer.h>
+#include <util/datetime/cputimer.h>
+#include <util/system/yield.h>
+#include <util/system/unaligned_mem.h>
+#include "block_chain.h"
+#include <util/system/shmat.h>
+#include "udp_debug.h"
+#include "udp_socket.h"
+#include "ib_cs.h"
+
+#include <library/cpp/netliba/socket/socket.h>
+
+#include <util/random/random.h>
+#include <util/system/sanitizers.h>
+
+#include <atomic>
+
+namespace NNetliba {
+ // rely on UDP checksum in packets, check crc only for complete packets
+ // UPDATE: looks like UDP checksum is not enough, network errors do happen, we saw 600+ retransmits of a ~1MB data packet
+
+ const float UDP_TRANSFER_TIMEOUT = 90.0f;
+ const float DEFAULT_MAX_WAIT_TIME = 1;
+ const float UDP_KEEP_PEER_INFO = 600;
+ // traffic may keep flowing while no new data arrives for a particular packet.
+ // this can happen when we interrupt the process mid-transfer and restart it on the same port;
+ // the packet then hangs on the receiver side, and we kill it via this timeout
+ const float UDP_MAX_INPUT_DATA_WAIT = UDP_TRANSFER_TIMEOUT * 2;
+
+ enum {
+ UDP_PACKET_SIZE_FULL = 8900, // used for ping to detect jumbo-frame support
+ UDP_PACKET_SIZE = 8800, // max data in packet
+ UDP_PACKET_SIZE_SMALL = 1350, // 1180 would be better taking into account that 1280 is guaranteed ipv6 minimum MTU
+ UDP_PACKET_BUF_SIZE = UDP_PACKET_SIZE + 100,
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // Record of a fully received transfer. Kept around (see RecvCompleted in
+ // TUdpHost) so duplicate packets of an already completed transfer can be
+ // answered with ACK_COMPLETE instead of being received again.
+ struct TUdpCompleteInTransfer {
+ TGUID PacketGuid; // guid of the completed message; used to detect transferId reuse
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+ // One received UDP packet. The payload lives at Data[DataStart..DataStart+DataSize);
+ // BlockSum caches the per-block checksum so the whole-message crc can be
+ // computed incrementally (see TIncrementalChecksumCalcer usage in RecvCycle).
+ struct TUdpRecvPacket {
+ int DataStart, DataSize;
+ ui32 BlockSum;
+ // Data[] should be last member in struct, this fact is used to create truncated TUdpRecvPacket in CreateNewSmallPacket()
+ char Data[UDP_PACKET_BUF_SIZE];
+ };
+
+ // State of one incoming transfer: the array of packets received so far plus
+ // everything needed to ack, request resends, and rebuild the message.
+ // Packet buffers are owned by this object until ExtractPacket() hands them off.
+ struct TUdpInTransfer {
+ private:
+ TVector<TUdpRecvPacket*> Packets; // slot per packet id; nullptr = not yet received
+
+ public:
+ sockaddr_in6 ToAddress; // peer address acks/resend requests are sent to
+ int PacketSize, LastPacketSize;
+ bool HasLastPacket; // true once the short (final) packet has arrived
+ TVector<int> NewPacketsToAck; // packet ids received since the last ack was written
+ TCongestionControlPtr Congestion;
+ float TimeSinceLastRecv; // seconds; transfer is dropped after UDP_MAX_INPUT_DATA_WAIT
+ int Attempt; // sender's attempt number; mismatched data is re-requested
+ TGUID PacketGuid;
+ int Crc32; // expected crc of the complete message (from packet 0)
+ TIntrusivePtr<TSharedMemory> SharedData;
+ TRequesterPendingDataStats* Stats; // optional global stats, updated on every alloc/free
+
+ TUdpInTransfer()
+ : PacketSize(0)
+ , LastPacketSize(0)
+ , HasLastPacket(false)
+ , TimeSinceLastRecv(0)
+ , Attempt(0)
+ , Crc32(0)
+ , Stats(nullptr)
+ {
+ Zero(ToAddress);
+ }
+ ~TUdpInTransfer() {
+ if (Stats) {
+ Stats->InpCount -= 1;
+ }
+ EraseAllPackets();
+ }
+ // frees every packet buffer and resets the "all data present" marker
+ void EraseAllPackets() {
+ for (int i = 0; i < Packets.ysize(); ++i) {
+ ErasePacket(i);
+ }
+ Packets.clear();
+ HasLastPacket = false;
+ }
+ // must be called before any packets are stored - accounting assumes an empty transfer
+ void AttachStats(TRequesterPendingDataStats* stats) {
+ Stats = stats;
+ Stats->InpCount += 1;
+ Y_ASSERT(Packets.empty());
+ }
+ void ErasePacket(int id) {
+ TUdpRecvPacket* pkt = Packets[id];
+ if (pkt) {
+ if (Stats) {
+ // NOTE: stats are tracked in units of PacketSize, not the actual buffer size
+ Stats->InpDataSize -= PacketSize;
+ }
+ TRopeDataPacket::FreeBuf((char*)pkt);
+ Packets[id] = nullptr;
+ }
+ }
+ // stores pkt at slot id, releasing whatever was there before (handles duplicates)
+ void AssignPacket(int id, TUdpRecvPacket* pkt) {
+ ErasePacket(id);
+ if (pkt && Stats) {
+ Stats->InpDataSize += PacketSize;
+ }
+ Packets[id] = pkt;
+ }
+ int GetPacketCount() const {
+ return Packets.ysize();
+ }
+ void SetPacketCount(int n) {
+ Packets.resize(n, nullptr);
+ }
+ const TUdpRecvPacket* GetPacket(int id) const {
+ return Packets[id];
+ }
+ // transfers ownership of the buffer to the caller (used when assembling the message)
+ TUdpRecvPacket* ExtractPacket(int id) {
+ TUdpRecvPacket* res = Packets[id];
+ if (res) {
+ if (Stats) {
+ Stats->InpDataSize -= PacketSize;
+ }
+ Packets[id] = nullptr;
+ }
+ return res;
+ }
+ };
+
+ // State of one outgoing transfer: the message data plus ack bookkeeping.
+ // Created by TUdpHost::Send(), destroyed when the transfer succeeds or fails.
+ struct TUdpOutTransfer {
+ sockaddr_in6 ToAddress;
+ TAutoPtr<TRopeDataPacket> Data; // the message being sent (owned)
+ int PacketCount;
+ int PacketSize, LastPacketSize; // chosen from the negotiated MTU in SendData()
+ TAckTracker AckTracker; // which packets are acked / pending / to resend
+ int Attempt; // bumped by the receiver via ACK_RESEND on crc failure
+ TGUID PacketGuid;
+ int Crc32;
+ EPacketPriority PacketPriority; // selects SendOrderLow / SendOrder / SendOrderHighPrior
+ TRequesterPendingDataStats* Stats; // optional global stats
+
+ TUdpOutTransfer()
+ : PacketCount(0)
+ , PacketSize(0)
+ , LastPacketSize(0)
+ , Attempt(0)
+ , Crc32(0)
+ , PacketPriority(PP_LOW)
+ , Stats(nullptr)
+ {
+ Zero(ToAddress);
+ }
+ ~TUdpOutTransfer() {
+ if (Stats) {
+ // NOTE(review): assumes Data is still set when stats are attached - AttachStats()
+ // dereferences Data unconditionally, so it must be called after Data is assigned
+ Stats->OutCount -= 1;
+ Stats->OutDataSize -= Data->GetSize();
+ }
+ }
+ void AttachStats(TRequesterPendingDataStats* stats) {
+ Stats = stats;
+ Stats->OutCount += 1;
+ Stats->OutDataSize += Data->GetSize();
+ }
+ };
+
+ // Identifies a transfer: (peer address, per-sender transfer id).
+ struct TTransferKey {
+ TUdpAddress Address;
+ int Id;
+ };
+ inline bool operator==(const TTransferKey& a, const TTransferKey& b) {
+ return a.Address == b.Address && a.Id == b.Id;
+ }
+ // Cheap hash mixing the low address bits, port and transfer id.
+ // Note: Network/Scope are deliberately ignored - collisions are resolved by operator==.
+ struct TTransferKeyHash {
+ int operator()(const TTransferKey& k) const {
+ return (ui32)k.Address.Interface + (ui32)k.Address.Port * (ui32)389461 + (ui32)k.Id;
+ }
+ };
+
+ // Same hash without the transfer id, for per-peer hash maps.
+ struct TUdpAddressHash {
+ int operator()(const TUdpAddress& addr) const {
+ return (ui32)addr.Interface + (ui32)addr.Port * (ui32)389461;
+ }
+ };
+
+ // Receive-buffer allocator: always keeps one full-size TUdpRecvPacket ready
+ // so RecvFrom() can write straight into it; ExtractPacket() hands the filled
+ // buffer to the caller and immediately allocates a fresh one.
+ class TUdpHostRevBufAlloc: public TNonCopyable {
+ TUdpRecvPacket* RecvPktBuf; // the currently prepared buffer (never null)
+
+ void AllocNewBuf() {
+ RecvPktBuf = (TUdpRecvPacket*)TRopeDataPacket::AllocBuf(sizeof(TUdpRecvPacket));
+ }
+
+ public:
+ TUdpHostRevBufAlloc() {
+ AllocNewBuf();
+ }
+ ~TUdpHostRevBufAlloc() {
+ FreeBuf(RecvPktBuf);
+ }
+ void FreeBuf(TUdpRecvPacket* pkt) {
+ TRopeDataPacket::FreeBuf((char*)pkt);
+ }
+ // gives ownership of the current buffer to the caller and prepares a new one
+ TUdpRecvPacket* ExtractPacket() {
+ TUdpRecvPacket* res = RecvPktBuf;
+ AllocNewBuf();
+ return res;
+ }
+ // allocates a TUdpRecvPacket truncated to sz payload bytes;
+ // relies on Data[] being the last member of TUdpRecvPacket
+ TUdpRecvPacket* CreateNewSmallPacket(int sz) {
+ int pktStructSz = sizeof(TUdpRecvPacket) - Y_ARRAY_SIZE(RecvPktBuf->Data) + sz;
+ TUdpRecvPacket* pkt = (TUdpRecvPacket*)TRopeDataPacket::AllocBuf(pktStructSz);
+ return pkt;
+ }
+ int GetBufSize() const {
+ return Y_ARRAY_SIZE(RecvPktBuf->Data);
+ }
+ char* GetDataPtr() const {
+ return RecvPktBuf->Data;
+ }
+ };
+
+ // Transfer-id generator; seeded from the cycle counter so ids differ across
+ // process restarts. Ids must stay non-negative (negative ids signal errors).
+ static TAtomic transferIdCounter = (long)(GetCycleCount() & 0x1fffffff);
+ inline int GetTransferId() {
+ int res = AtomicAdd(transferIdCounter, 1);
+ while (res < 0) {
+ // negative transfer ids are treated as errors, so wrap transfer id
+ // NOTE(review): the expected value for the CAS is read non-atomically; a lost
+ // race is harmless here because the loop re-checks via the following AtomicAdd
+ AtomicCas(&transferIdCounter, 0, transferIdCounter);
+ res = AtomicAdd(transferIdCounter, 1);
+ }
+ return res;
+ }
+
+ static bool IBDetection = true; // when set, Start() also brings up the InfiniBand client/server
+ // Single-threaded UDP transfer engine: maintains send/receive queues keyed by
+ // TTransferKey, per-peer congestion state, and an optional InfiniBand fast path
+ // for PP_NORMAL traffic. All state is driven from Step()/Wait() on one thread;
+ // only CancelWait() is intended to be called from other threads.
+ class TUdpHost: public IUdpHost {
+ // Per-peer state: congestion control plus an optional live IB connection.
+ struct TPeerLink {
+ TIntrusivePtr<TCongestionControl> UdpCongestion;
+ TIntrusivePtr<IIBPeer> IBPeer;
+ double TimeNoActiveTransfers; // seconds spent in CongestionTrackHistory with no transfers
+
+ TPeerLink()
+ : TimeNoActiveTransfers(0)
+ {
+ }
+ bool Update(float deltaT, const TUdpAddress& toAddress, float* maxWaitTime) {
+ bool updateOk = UdpCongestion->UpdateAlive(toAddress, deltaT, UDP_TRANSFER_TIMEOUT, maxWaitTime);
+ return updateOk;
+ }
+ // called when the peer moves to the inactive (history) list
+ void StartSleep(const TUdpAddress& toAddress, float* maxWaitTime) {
+ //printf("peer_link start sleep, IBPeer = %p, refs = %d\n", IBPeer.Get(), (int)IBPeer.RefCount());
+ UdpCongestion->UpdateAlive(toAddress, 0, UDP_TRANSFER_TIMEOUT, maxWaitTime);
+ UdpCongestion->MarkAlive();
+ TimeNoActiveTransfers = 0;
+ }
+ // returns false when the idle peer record should be dropped entirely
+ bool UpdateSleep(float deltaT) {
+ TimeNoActiveTransfers += deltaT;
+ if (IBPeer.Get()) {
+ //printf("peer_link update sleep, IBPeer = %p, refs = %d\n", IBPeer.Get(), (int)IBPeer.RefCount());
+ if (IBPeer->GetState() == IIBPeer::OK) {
+ return true; // keep the record alive while the IB connection is healthy
+ }
+ //printf("Drop broken IB connection\n");
+ IBPeer = nullptr;
+ }
+ return (TimeNoActiveTransfers < UDP_KEEP_PEER_INFO);
+ }
+ };
+
+ TNetSocket s;
+ typedef THashMap<TTransferKey, TUdpInTransfer, TTransferKeyHash> TUdpInXferHash;
+ typedef THashMap<TTransferKey, TUdpOutTransfer, TTransferKeyHash> TUdpOutXferHash;
+ // congestion control per peer
+ typedef THashMap<TUdpAddress, TPeerLink, TUdpAddressHash> TPeerLinkHash;
+ typedef THashMap<TTransferKey, TUdpCompleteInTransfer, TTransferKeyHash> TUdpCompleteInXferHash;
+ typedef THashMap<TUdpAddress, TIntrusivePtr<TPeerQueueStats>, TUdpAddressHash> TQueueStatsHash;
+ TUdpInXferHash RecvQueue;
+ TUdpCompleteInXferHash RecvCompleted; // recently completed transfers, for duplicate suppression
+ TUdpOutXferHash SendQueue;
+ TPeerLinkHash CongestionTrack, CongestionTrackHistory; // active peers / idle peers
+ TList<TRequest*> ReceivedList; // fully assembled incoming messages, owned until GetRequest()
+ NHPTimer::STime CurrentT;
+ TList<TSendResult> SendResults;
+ TList<TTransferKey> SendOrderLow, SendOrder, SendOrderHighPrior; // send queues per priority
+ TAtomic IsWaiting; // set while blocked in Wait(); lets CancelWait() wake us via a self packet
+ float MaxWaitTime;
+ std::atomic<float> MaxWaitTime2; // cross-thread cap on the next Wait(); reset each Wait()
+ float IBIdleTime;
+ TVector<TTransferKey> RecvCompletedQueue, KeepCompletedQueue;
+ float TimeSinceCompletedQueueClean, TimeSinceCongestionHistoryUpdate;
+ TRequesterPendingDataStats PendingDataStats;
+ TQueueStatsHash PeerQueueStats;
+ TIntrusivePtr<IIBClientServer> IB; // null when IB is disabled/unavailable
+ typedef THashMap<TIBMsgHandle, TTransferKey> TIBtoTransferKeyHash;
+ TIBtoTransferKeyHash IBKeyToTransferKey; // maps in-flight IB sends back to transfers
+
+ char PktBuf[UDP_PACKET_BUF_SIZE]; // scratch buffer for outgoing packets
+ TUdpHostRevBufAlloc RecvBuf;
+
+ // Finds (or creates) peer state, resurrecting it from the idle history if present.
+ TPeerLink& GetPeerLink(const TUdpAddress& ip) {
+ TPeerLinkHash::iterator z = CongestionTrack.find(ip);
+ if (z == CongestionTrack.end()) {
+ z = CongestionTrackHistory.find(ip);
+ if (z == CongestionTrackHistory.end()) {
+ // brand new peer
+ TPeerLink& res = CongestionTrack[ip];
+ Y_ASSERT(res.UdpCongestion.Get() == nullptr);
+ res.UdpCongestion = new TCongestionControl;
+ TQueueStatsHash::iterator zq = PeerQueueStats.find(ip);
+ if (zq != PeerQueueStats.end()) {
+ res.UdpCongestion->AttachQueueStats(zq->second);
+ }
+ return res;
+ } else {
+ // move the peer back from history to the active table
+ TPeerLink& res = CongestionTrack[z->first];
+ res = z->second;
+ CongestionTrackHistory.erase(z);
+ return res;
+ }
+ } else {
+ Y_ASSERT(CongestionTrackHistory.find(ip) == CongestionTrackHistory.end());
+ return z->second;
+ }
+ }
+ void SucceededSend(int id) {
+ SendResults.push_back(TSendResult(id, true));
+ }
+ void FailedSend(int id) {
+ SendResults.push_back(TSendResult(id, false));
+ }
+ void SendData(TList<TTransferKey>* order, float deltaT, bool needCheckAlive);
+ void RecvCycle();
+
+ public:
+ TUdpHost()
+ : CurrentT(0)
+ , IsWaiting(0)
+ , MaxWaitTime(DEFAULT_MAX_WAIT_TIME)
+ , MaxWaitTime2(DEFAULT_MAX_WAIT_TIME)
+ , IBIdleTime(0)
+ , TimeSinceCompletedQueueClean(0)
+ , TimeSinceCongestionHistoryUpdate(0)
+ {
+ }
+ ~TUdpHost() override {
+ // drop any messages nobody picked up via GetRequest()
+ for (TList<TRequest*>::const_iterator i = ReceivedList.begin(); i != ReceivedList.end(); ++i)
+ delete *i;
+ }
+
+ bool Start(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket);
+
+ // Returns the next fully received message (caller takes ownership), or null.
+ TRequest* GetRequest() override {
+ if (ReceivedList.empty()) {
+ if (IB.Get()) {
+ return IB->GetRequest();
+ }
+ return nullptr;
+ }
+ TRequest* res = ReceivedList.front();
+ ReceivedList.pop_front();
+ return res;
+ }
+
+ void AddToSendOrder(const TTransferKey& transferKey, EPacketPriority pp) {
+ if (pp == PP_LOW)
+ SendOrderLow.push_back(transferKey);
+ else if (pp == PP_NORMAL)
+ SendOrder.push_back(transferKey);
+ else if (pp == PP_HIGH)
+ SendOrderHighPrior.push_back(transferKey);
+ else
+ Y_ASSERT(0);
+
+ CancelWait();
+ }
+
+ // Queues data for sending; returns the transfer id. PP_NORMAL transfers may
+ // go over InfiniBand when a live IB peer exists and no shared memory is attached.
+ int Send(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data, int crc32, TGUID* packetGuid, EPacketPriority pp) override {
+ if (addr.Port == 0) {
+ // shortcut for broken addresses
+ if (packetGuid && packetGuid->IsEmpty())
+ CreateGuid(packetGuid);
+ int reqId = GetTransferId();
+ FailedSend(reqId);
+ return reqId;
+ }
+ TTransferKey transferKey;
+ transferKey.Address = addr;
+ transferKey.Id = GetTransferId();
+ Y_ASSERT(SendQueue.find(transferKey) == SendQueue.end());
+
+ TPeerLink& peerInfo = GetPeerLink(transferKey.Address);
+
+ TUdpOutTransfer& xfer = SendQueue[transferKey];
+ GetWinsockAddr(&xfer.ToAddress, transferKey.Address);
+ xfer.Crc32 = crc32;
+ xfer.PacketPriority = pp;
+ if (!packetGuid || packetGuid->IsEmpty()) {
+ CreateGuid(&xfer.PacketGuid);
+ if (packetGuid)
+ *packetGuid = xfer.PacketGuid;
+ } else {
+ xfer.PacketGuid = *packetGuid;
+ }
+ xfer.Data.Reset(data.Release());
+ xfer.AttachStats(&PendingDataStats);
+ xfer.AckTracker.AttachCongestionControl(peerInfo.UdpCongestion.Get());
+
+ bool isSentOverIB = false;
+ // we don't support priorities (=service levels in IB terms) currently
+ // so send only PP_NORMAL traffic over IB
+ if (pp == PP_NORMAL && peerInfo.IBPeer.Get() && xfer.Data->GetSharedData() == nullptr) {
+ TIBMsgHandle hndl = IB->Send(peerInfo.IBPeer, xfer.Data.Get(), xfer.PacketGuid);
+ if (hndl >= 0) {
+ IBKeyToTransferKey[hndl] = transferKey;
+ isSentOverIB = true;
+ } else {
+ // so we failed to use IB, ibPeer is either not connected yet or failed
+ if (peerInfo.IBPeer->GetState() == IIBPeer::FAILED) {
+ //printf("Disconnect failed IB peer\n");
+ peerInfo.IBPeer = nullptr;
+ }
+ }
+ }
+ if (!isSentOverIB) {
+ AddToSendOrder(transferKey, pp);
+ }
+
+ return transferKey.Id;
+ }
+
+ // Pops one send result; IB completions are folded in here, falling back to
+ // the UDP path when an IB send failed.
+ bool GetSendResult(TSendResult* res) override {
+ if (SendResults.empty()) {
+ if (IB.Get()) {
+ TIBSendResult sr;
+ if (IB->GetSendResult(&sr)) {
+ TIBtoTransferKeyHash::iterator z = IBKeyToTransferKey.find(sr.Handle);
+ if (z == IBKeyToTransferKey.end()) {
+ Y_VERIFY(0, "unknown handle returned from IB");
+ }
+ TTransferKey transferKey = z->second;
+ IBKeyToTransferKey.erase(z);
+
+ TUdpOutXferHash::iterator i = SendQueue.find(transferKey);
+ if (i == SendQueue.end()) {
+ Y_VERIFY(0, "IBKeyToTransferKey refers nonexisting xfer");
+ }
+ if (sr.Success) {
+ TUdpOutTransfer& xfer = i->second;
+ xfer.AckTracker.MarkAlive(); // do we really need this?
+ *res = TSendResult(transferKey.Id, sr.Success);
+ SendQueue.erase(i);
+ return true;
+ } else {
+ //printf("IB send failed, fall back to regular network\n");
+ // Houston, we got a problem
+ // IB failed to send, try to use regular network
+ TUdpOutTransfer& xfer = i->second;
+ AddToSendOrder(transferKey, xfer.PacketPriority);
+ }
+ }
+ }
+ return false;
+ }
+ *res = SendResults.front();
+ SendResults.pop_front();
+ return true;
+ }
+
+ void Step() override;
+ void IBStep() override;
+
+ // Sleeps up to 'seconds' (bounded by MaxWaitTime and the cross-thread
+ // MaxWaitTime2), waking early on socket activity or IB progress.
+ void Wait(float seconds) override {
+ if (seconds < 1e-3)
+ seconds = 0;
+ if (seconds > MaxWaitTime)
+ seconds = MaxWaitTime;
+ if (IBIdleTime < 0.010) {
+ seconds = 0; // IB was active very recently - keep polling
+ }
+ if (seconds == 0) {
+ ThreadYield();
+ } else {
+ AtomicAdd(IsWaiting, 1);
+ if (seconds > MaxWaitTime2)
+ seconds = MaxWaitTime2;
+ MaxWaitTime2 = DEFAULT_MAX_WAIT_TIME;
+
+ if (seconds == 0) {
+ ThreadYield();
+ } else {
+ if (IB.Get()) {
+ // with IB we sleep in small slices so IB->Step() keeps getting called
+ for (float done = 0; done < seconds;) {
+ float deltaSleep = Min(seconds - done, 0.002f);
+ s.Wait(deltaSleep);
+ NHPTimer::STime tChk;
+ NHPTimer::GetTime(&tChk);
+ if (IB->Step(tChk)) {
+ IBIdleTime = 0;
+ break;
+ }
+ done += deltaSleep;
+ }
+ } else {
+ s.Wait(seconds);
+ }
+ }
+ AtomicAdd(IsWaiting, -1);
+ }
+ }
+
+ // May be called from other threads; wakes a concurrent Wait() with a self packet.
+ void CancelWait() override {
+ MaxWaitTime2 = 0;
+ if (AtomicAdd(IsWaiting, 0) == 1) {
+ s.SendSelfFakePacket();
+ }
+ }
+
+ void GetPendingDataSize(TRequesterPendingDataStats* res) override {
+ *res = PendingDataStats;
+#ifndef NDEBUG
+ // debug builds recompute the stats from scratch and verify the incremental accounting
+ TRequesterPendingDataStats chk;
+ for (TUdpOutXferHash::const_iterator i = SendQueue.begin(); i != SendQueue.end(); ++i) {
+ TRopeDataPacket* pckt = i->second.Data.Get();
+ if (pckt) {
+ chk.OutDataSize += pckt->GetSize();
+ ++chk.OutCount;
+ }
+ }
+ for (TUdpInXferHash::const_iterator i = RecvQueue.begin(); i != RecvQueue.end(); ++i) {
+ const TUdpInTransfer& tr = i->second;
+ for (int p = 0; p < tr.GetPacketCount(); ++p) {
+ if (tr.GetPacket(p)) {
+ chk.InpDataSize += tr.PacketSize;
+ }
+ }
+ ++chk.InpCount;
+ }
+ Y_ASSERT(memcmp(&chk, res, sizeof(chk)) == 0);
+#endif
+ }
+ TString GetDebugInfo() override;
+ TString GetPeerLinkDebug(const TPeerLinkHash& ch);
+ void Kill(const TUdpAddress& addr) override;
+ TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) override;
+ };
+
+ // Binds the host to an already created socket and initializes the timer base.
+ // Returns false if the host was already started or the socket is invalid.
+ // Also brings up the InfiniBand client/server when IBDetection is enabled.
+ bool TUdpHost::Start(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) {
+ if (s.IsValid()) {
+ Y_ASSERT(0); // double Start() is a programming error
+ return false;
+ }
+ s.Open(socket);
+ if (!s.IsValid())
+ return false;
+
+ if (IBDetection)
+ IB = CreateIBClientServer();
+
+ NHPTimer::GetTime(&CurrentT);
+ return true;
+ }
+
+ // True when every packet slot is filled AND the short final packet has been
+ // seen (without it we don't know the true packet count).
+ static bool HasAllPackets(const TUdpInTransfer& res) {
+ if (!res.HasLastPacket)
+ return false;
+ for (int i = res.GetPacketCount() - 1; i >= 0; --i) {
+ if (!res.GetPacket(i))
+ return false;
+ }
+ return true;
+ }
+
+ // grouped acks, first int - packet_id, second int - bit mask for 32 packets preceding packet_id
+ const int SIZEOF_ACK = 8;
+ // Serializes up to maxAcks grouped acks from p->NewPacketsToAck into dst and
+ // clears the list. Packets closer than 30 ids apart are folded into the
+ // bitmask of a later ack, so bursts compress to a single (id, mask) pair.
+ // Returns the number of pairs written.
+ static int WriteAck(TUdpInTransfer* p, int* dst, int maxAcks) {
+ int ackCount = 0;
+ if (p->NewPacketsToAck.size() > 1)
+ Sort(p->NewPacketsToAck.begin(), p->NewPacketsToAck.end());
+ int lastAcked = 0;
+ for (size_t idx = 0; idx < p->NewPacketsToAck.size(); ++idx) {
+ int pkt = p->NewPacketsToAck[idx];
+ // emit an ack for the last packet in the list or whenever the gap is too
+ // large for the 32-bit lookback mask to cover
+ if (idx == p->NewPacketsToAck.size() - 1 || pkt > lastAcked + 30) {
+ *dst++ = pkt;
+ int bitMask = 0;
+ int backPackets = Min(pkt, 32);
+ for (int k = 0; k < backPackets; ++k) {
+ if (p->GetPacket(pkt - k - 1))
+ bitMask |= 1 << k; // bit k = packet (pkt - k - 1) was received
+ }
+ *dst++ = bitMask;
+ if (++ackCount >= maxAcks)
+ break;
+ lastAcked = pkt;
+ //printf("sending ack %d (mask %x)\n", pkt, bitMask);
+ }
+ }
+ p->NewPacketsToAck.clear();
+ return ackCount;
+ }
+
+ // Marks one packet as acknowledged; ignores out-of-range ids coming from a
+ // corrupted or stale ack packet.
+ static void AckPacket(TUdpOutTransfer* p, int pkt, float deltaT, bool updateRTT) {
+ if (pkt < 0 || pkt >= p->PacketCount) {
+ Y_ASSERT(0);
+ return;
+ }
+ p->AckTracker.Ack(pkt, deltaT, updateRTT);
+ }
+
+ // Applies ackCount grouped acks (see WriteAck for the wire format): each pair
+ // acknowledges one packet plus up to 32 preceding packets from its bitmask.
+ static void ReadAcks(TUdpOutTransfer* p, const int* acks, int ackCount, float deltaT) {
+ for (int i = 0; i < ackCount; ++i) {
+ int pkt = *acks++;
+ int bitMask = *acks++;
+ bool updateRTT = i == ackCount - 1; // update RTT using only last packet in the pack
+ AckPacket(p, pkt, deltaT, updateRTT);
+ for (int k = 0; k < 32; ++k) {
+ if (bitMask & (1 << k))
+ AckPacket(p, pkt - k - 1, deltaT, false);
+ }
+ }
+ }
+
+ using namespace NNetlibaSocket::NNetliba;
+
+ // secret passphrase pair required in a KILL packet before abort() is honored
+ const ui64 KILL_PASSPHRASE1 = 0x98ff9cefb11d9a4cul;
+ const ui64 KILL_PASSPHRASE2 = 0xf7754c29e0be95eaul;
+
+ // Reads a T from *data and advances the cursor; unaligned-safe.
+ template <class T>
+ inline T Read(char** data) {
+ T res = ReadUnaligned<T>(*data);
+ *data += sizeof(T);
+ return res;
+ }
+ // Writes a T to *data and advances the cursor; unaligned-safe.
+ template <class T>
+ inline void Write(char** data, T res) {
+ WriteUnaligned<T>(*data, res);
+ *data += sizeof(T);
+ }
+
+ // Asks the sender to restart the transfer with the given attempt number
+ // (sent when data for a mismatched attempt arrives or the crc check fails).
+ static void RequireResend(const TNetSocket& s, const sockaddr_in6& toAddress, int transferId, int attempt) {
+ char buf[100], *pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, transferId);
+ Write(&pktData, (char)ACK_RESEND);
+ Write(&pktData, attempt);
+ s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
+ }
+
+ // Asks the sender to retransmit without shared memory (sent when the receiver
+ // failed to open the shmem segment announced in packet 0).
+ static void RequireResendNoShmem(const TNetSocket& s, const sockaddr_in6& toAddress, int transferId, int attempt) {
+ char buf[100], *pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, transferId);
+ Write(&pktData, (char)ACK_RESEND_NOSHMEM);
+ Write(&pktData, attempt);
+ s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
+ }
+
+ // Tells the sender the whole transfer was received; carries the guid so the
+ // sender can detect transferId reuse, and the last packetId for RTT updates.
+ static void AckComplete(const TNetSocket& s, const sockaddr_in6& toAddress, int transferId, const TGUID& packetGuid, int packetId) {
+ char buf[100], *pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, transferId);
+ Write(&pktData, (char)ACK_COMPLETE);
+ Write(&pktData, packetGuid);
+ Write(&pktData, packetId); // we need packetId to update RTT
+ s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
+ }
+
+ // MTU-probing ping: sends a full-size (jumbo) packet with fragmentation
+ // disabled, so a PONG reply proves the path supports UDP_PACKET_SIZE frames.
+ // Carries our network-order port so the peer can answer to the right socket.
+ static void SendPing(TNetSocket& s, const sockaddr_in6& toAddress, int selfNetworkOrderPort) {
+ char pktBuf[UDP_PACKET_SIZE_FULL];
+ char* pktData = pktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
+ if (NSan::MSanIsOn()) {
+ Zero(pktBuf); // keep MSan quiet about the uninitialized padding we send
+ }
+ Write(&pktData, (int)0);
+ Write(&pktData, (char)PING);
+ Write(&pktData, selfNetworkOrderPort);
+ s.SendTo(pktBuf, UDP_PACKET_SIZE_FULL, toAddress, FF_DONT_FRAG);
+ }
+
+ // not MTU discovery, just figure out IB address of the peer
+ // (small packet, fragmentation allowed - we only want the PONG payload back)
+ static void SendFakePing(TNetSocket& s, const sockaddr_in6& toAddress, int selfNetworkOrderPort) {
+ char buf[100];
+ char* pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, (int)0);
+ Write(&pktData, (char)PING);
+ Write(&pktData, selfNetworkOrderPort);
+ s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
+ }
+
+ // One pass over a priority queue: for every live transfer, finish MTU
+ // discovery if needed, then pump as many packets as the ack tracker allows.
+ // Stops the whole pass on send-buffer overflow (MaxWaitTime is zeroed so the
+ // caller retries immediately). Stale keys are lazily removed from 'order'.
+ void TUdpHost::SendData(TList<TTransferKey>* order, float deltaT1, bool needCheckAlive) {
+ for (TList<TTransferKey>::iterator z = order->begin(); z != order->end();) {
+ // pick connection to send
+ const TTransferKey& transferKey = *z;
+ TUdpOutXferHash::iterator i = SendQueue.find(transferKey);
+ if (i == SendQueue.end()) {
+ z = order->erase(z);
+ continue;
+ }
+ ++z;
+
+ // perform sending
+ int transferId = transferKey.Id;
+ TUdpOutTransfer& xfer = i->second;
+
+ // lazy initialization: packet sizing waits until the peer's MTU is known
+ if (!xfer.AckTracker.IsInitialized()) {
+ TIntrusivePtr<TCongestionControl> congestion = xfer.AckTracker.GetCongestionControl();
+ Y_ASSERT(congestion.Get() != nullptr);
+ if (!congestion->IsKnownMTU()) {
+ TLameMTUDiscovery* md = congestion->GetMTUDiscovery();
+ if (md->IsTimedOut()) {
+ // no jumbo PONG came back - fall back to the small, safe MTU
+ congestion->SetMTU(UDP_PACKET_SIZE_SMALL);
+ } else {
+ if (md->CanSend()) {
+ SendPing(s, xfer.ToAddress, s.GetNetworkOrderPort());
+ md->PingSent();
+ }
+ continue; // can't size packets yet
+ }
+ }
+ // try to use large mtu, we could have selected small mtu due to connectivity problems
+ if (congestion->GetMTU() == UDP_PACKET_SIZE_SMALL || IB.Get() != nullptr) {
+ // recheck every ~50mb
+ int chkDenom = (50000000 / xfer.Data->GetSize()) | 1;
+ if ((NetAckRnd() % chkDenom) == 0) {
+ //printf("send rechecking ping\n");
+ if (congestion->GetMTU() == UDP_PACKET_SIZE_SMALL) {
+ SendPing(s, xfer.ToAddress, s.GetNetworkOrderPort());
+ } else {
+ SendFakePing(s, xfer.ToAddress, s.GetNetworkOrderPort());
+ }
+ }
+ }
+ xfer.PacketSize = congestion->GetMTU();
+ xfer.LastPacketSize = xfer.Data->GetSize() % xfer.PacketSize;
+ xfer.PacketCount = xfer.Data->GetSize() / xfer.PacketSize + 1;
+ xfer.AckTracker.SetPacketCount(xfer.PacketCount);
+ }
+
+ xfer.AckTracker.Step(deltaT1);
+ MaxWaitTime = Min(MaxWaitTime, xfer.AckTracker.GetTimeToNextPacketTimeout());
+ if (needCheckAlive && !xfer.AckTracker.IsAlive()) {
+ FailedSend(transferId);
+ SendQueue.erase(i);
+ continue;
+ }
+ bool sendBufferOverflow = false;
+ while (xfer.AckTracker.CanSend()) {
+ NHPTimer::STime tCopy = CurrentT;
+ float deltaT2 = (float)NHPTimer::GetTimePassed(&tCopy);
+ deltaT2 = ClampVal(deltaT2, 0.0f, UDP_TRANSFER_TIMEOUT / 3);
+
+ int pkt = xfer.AckTracker.GetPacketToSend(deltaT2);
+ if (pkt == -1) {
+ break;
+ }
+
+ int dataSize = xfer.PacketSize;
+ if (pkt == xfer.PacketCount - 1)
+ dataSize = xfer.LastPacketSize;
+
+ // packet layout: [header][transferId][type][attempt][pkt]
+ // packet 0 additionally carries: guid, crc32, and shmem id/size if attached
+ char* pktData = PktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, transferId);
+ char pktType = xfer.PacketSize == UDP_PACKET_SIZE ? DATA : DATA_SMALL;
+ TSharedMemory* shm = xfer.Data->GetSharedData();
+ if (shm) {
+ if (pktType == DATA)
+ pktType = DATA_SHMEM;
+ else
+ pktType = DATA_SMALL_SHMEM;
+ }
+ Write(&pktData, pktType);
+ Write(&pktData, xfer.Attempt);
+ Write(&pktData, pkt);
+ if (pkt == 0) {
+ Write(&pktData, xfer.PacketGuid);
+ Write(&pktData, xfer.Crc32);
+ if (shm) {
+ Write(&pktData, shm->GetId());
+ Write(&pktData, shm->GetSize());
+ }
+ }
+ TBlockChainIterator dataReader(xfer.Data->GetChain());
+ dataReader.Seek(pkt * xfer.PacketSize);
+ dataReader.Read(pktData, dataSize);
+ pktData += dataSize;
+ int sendSize = (int)(pktData - PktBuf);
+ TNetSocket::ESendError sendErr = s.SendTo(PktBuf, sendSize, xfer.ToAddress, FF_ALLOW_FRAG);
+ if (sendErr != TNetSocket::SEND_OK) {
+ if (sendErr == TNetSocket::SEND_NO_ROUTE_TO_HOST) {
+ // NOTE: xfer/i are dead after this erase; we only break out, never touch them again
+ FailedSend(transferId);
+ SendQueue.erase(i);
+ break;
+ } else {
+ // most probably out of send buffer space (or something terrible has happened)
+ xfer.AckTracker.AddToResend(pkt);
+ sendBufferOverflow = true;
+ MaxWaitTime = 0;
+ //printf("failed send\n");
+ break;
+ }
+ }
+ }
+ if (sendBufferOverflow)
+ break; // socket is saturated - no point trying the remaining transfers now
+ }
+ }
+
+ // Drains the socket, dispatching each datagram by its command byte:
+ // DATA* - store the packet, assemble and crc-check the message when complete;
+ // ACK/ACK_COMPLETE/ACK_RESEND* - advance or reset the matching outgoing transfer;
+ // PING/PONG - MTU probing and IB peer discovery; KILL - passphrase-guarded abort().
+ void TUdpHost::RecvCycle() {
+ for (;;) {
+ sockaddr_in6 fromAddress;
+ int rv = RecvBuf.GetBufSize();
+ bool recvOk = s.RecvFrom(RecvBuf.GetDataPtr(), &rv, &fromAddress);
+ if (!recvOk)
+ break; // socket drained
+
+ NHPTimer::STime tCopy = CurrentT;
+ float deltaT = (float)NHPTimer::GetTimePassed(&tCopy);
+ deltaT = ClampVal(deltaT, 0.0f, UDP_TRANSFER_TIMEOUT / 3);
+
+ //int fromIP = fromAddress.sin_addr.s_addr;
+
+ // common packet prefix: transferId (int) + command byte
+ TTransferKey k;
+ char* pktData = RecvBuf.GetDataPtr() + UDP_LOW_LEVEL_HEADER_SIZE;
+ GetUdpAddress(&k.Address, fromAddress);
+ k.Id = Read<int>(&pktData);
+ int transferId = k.Id;
+ int cmd = Read<char>(&pktData);
+ Y_ASSERT(cmd == (int)*(RecvBuf.GetDataPtr() + CMD_POS));
+ switch (cmd) {
+ case DATA:
+ case DATA_SMALL:
+ case DATA_SHMEM:
+ case DATA_SMALL_SHMEM: {
+ int attempt = Read<int>(&pktData);
+ int packetId = Read<int>(&pktData);
+ //printf("data packet %d (trans ID = %d)\n", packetId, transferId);
+ // duplicate of an already completed transfer? just re-ack it -
+ // unless the guid differs, which means the transferId was reused
+ TUdpCompleteInXferHash::iterator itCompl = RecvCompleted.find(k);
+ if (itCompl != RecvCompleted.end()) {
+ Y_ASSERT(RecvQueue.find(k) == RecvQueue.end());
+ const TUdpCompleteInTransfer& complete = itCompl->second;
+ bool sendAckComplete = true;
+ if (packetId == 0) {
+ // check packet GUID
+ char* tmpPktData = pktData;
+ TGUID packetGuid;
+ packetGuid = Read<TGUID>(&tmpPktData);
+ if (packetGuid != complete.PacketGuid) {
+ // we are receiving new data with the same transferId
+ // in this case we have to flush all the information about previous transfer
+ // and start over
+ //printf("same transferId for a different packet\n");
+ RecvCompleted.erase(itCompl);
+ sendAckComplete = false;
+ }
+ }
+ if (sendAckComplete) {
+ AckComplete(s, fromAddress, transferId, complete.PacketGuid, packetId);
+ break;
+ }
+ }
+ // find or create the incoming transfer record
+ TUdpInXferHash::iterator rq = RecvQueue.find(k);
+ if (rq == RecvQueue.end()) {
+ //printf("new input transfer\n");
+ TUdpInTransfer& res = RecvQueue[k];
+ res.ToAddress = fromAddress;
+ res.Attempt = attempt;
+ res.Congestion = GetPeerLink(k.Address).UdpCongestion.Get();
+ res.PacketSize = 0;
+ res.HasLastPacket = false;
+ res.AttachStats(&PendingDataStats);
+ rq = RecvQueue.find(k);
+ Y_ASSERT(rq != RecvQueue.end());
+ }
+ TUdpInTransfer& res = rq->second;
+ res.Congestion->MarkAlive();
+ res.TimeSinceLastRecv = 0;
+
+ // packet 0 carries the message guid, crc and optional shmem descriptor
+ if (packetId == 0) {
+ TGUID packetGuid;
+ packetGuid = Read<TGUID>(&pktData);
+ int crc32 = Read<int>(&pktData);
+ res.Crc32 = crc32;
+ res.PacketGuid = packetGuid;
+ if (cmd == DATA_SHMEM || cmd == DATA_SMALL_SHMEM) {
+ // link to attached shared memory
+ TGUID shmemId = Read<TGUID>(&pktData);
+ int shmemSize = Read<int>(&pktData);
+ if (res.SharedData.Get() == nullptr) {
+ res.SharedData = new TSharedMemory;
+ if (!res.SharedData->Open(shmemId, shmemSize)) {
+ res.SharedData = nullptr;
+ RequireResendNoShmem(s, res.ToAddress, transferId, res.Attempt);
+ break;
+ }
+ }
+ }
+ }
+ if (attempt != res.Attempt) {
+ // stale attempt - ask the sender to restart with the current one
+ RequireResend(s, res.ToAddress, transferId, res.Attempt);
+ break;
+ } else {
+ if (res.PacketSize == 0) {
+ res.PacketSize = (cmd == DATA || cmd == DATA_SHMEM ? UDP_PACKET_SIZE : UDP_PACKET_SIZE_SMALL);
+ } else {
+ // check that all data is of same size
+ Y_ASSERT(cmd == DATA || cmd == DATA_SMALL);
+ Y_ASSERT(res.PacketSize == (cmd == DATA ? UDP_PACKET_SIZE : UDP_PACKET_SIZE_SMALL));
+ }
+
+ int dataSize = (int)(RecvBuf.GetDataPtr() + rv - pktData);
+
+ Y_ASSERT(dataSize <= res.PacketSize);
+ if (dataSize > res.PacketSize)
+ break; // mem overrun protection
+ if (packetId >= res.GetPacketCount())
+ res.SetPacketCount(packetId + 1);
+ {
+ TUdpRecvPacket* pkt = nullptr;
+ if (res.PacketSize == UDP_PACKET_SIZE_SMALL) {
+ // save memory by using smaller buffer at the cost of additional memcpy
+ pkt = RecvBuf.CreateNewSmallPacket(dataSize);
+ memcpy(pkt->Data, pktData, dataSize);
+ pkt->DataStart = 0;
+ pkt->DataSize = dataSize;
+ } else {
+ int dataStart = (int)(pktData - RecvBuf.GetDataPtr()); // data offset in the packet
+ pkt = RecvBuf.ExtractPacket();
+ pkt->DataStart = dataStart;
+ pkt->DataSize = dataSize;
+ }
+ // calc packet sum, will be used to calc whole message crc
+ pkt->BlockSum = TIncrementalChecksumCalcer::CalcBlockSum(pkt->Data + pkt->DataStart, pkt->DataSize);
+ res.AssignPacket(packetId, pkt);
+ }
+
+ // a short packet marks the end of the message
+ if (dataSize != res.PacketSize) {
+ res.LastPacketSize = dataSize;
+ res.HasLastPacket = true;
+ }
+
+ if (HasAllPackets(res)) {
+ //printf("received\n");
+ // assemble the message, verify crc, then either deliver or restart
+ TRequest* out = new TRequest;
+ out->Address = k.Address;
+ out->Guid = res.PacketGuid;
+ TIncrementalChecksumCalcer incCS;
+ int packetCount = res.GetPacketCount();
+ out->Data.Reset(new TRopeDataPacket);
+ for (int i = 0; i < packetCount; ++i) {
+ TUdpRecvPacket* pkt = res.ExtractPacket(i);
+ Y_ASSERT(pkt->DataSize == ((i == packetCount - 1) ? res.LastPacketSize : res.PacketSize));
+ out->Data->AddBlock((char*)pkt, pkt->Data + pkt->DataStart, pkt->DataSize);
+ incCS.AddBlockSum(pkt->BlockSum, pkt->DataSize);
+ }
+ out->Data->AttachSharedData(res.SharedData);
+ res.EraseAllPackets();
+
+ int crc32 = incCS.CalcChecksum(); // CalcChecksum(out->Data->GetChain());
+#ifdef SIMULATE_NETWORK_FAILURES
+ bool crcOk = crc32 == res.Crc32 ? (RandomNumber<size_t>() % 10) != 0 : false;
+#else
+ bool crcOk = crc32 == res.Crc32;
+#endif
+ if (crcOk) {
+ ReceivedList.push_back(out);
+ Y_ASSERT(RecvCompleted.find(k) == RecvCompleted.end());
+ TUdpCompleteInTransfer& complete = RecvCompleted[k];
+ RecvCompletedQueue.push_back(k);
+ complete.PacketGuid = res.PacketGuid;
+ AckComplete(s, res.ToAddress, transferId, complete.PacketGuid, packetId);
+ RecvQueue.erase(rq);
+ } else {
+ //printf("crc failed, require resend\n");
+ delete out;
+ ++res.Attempt;
+ res.NewPacketsToAck.clear();
+ RequireResend(s, res.ToAddress, transferId, res.Attempt);
+ }
+ } else {
+ res.NewPacketsToAck.push_back(packetId);
+ }
+ }
+ } break;
+ case ACK: {
+ TUdpOutXferHash::iterator i = SendQueue.find(k);
+ if (i == SendQueue.end())
+ break;
+ TUdpOutTransfer& xfer = i->second;
+ if (!xfer.AckTracker.IsInitialized())
+ break;
+ xfer.AckTracker.MarkAlive();
+ int attempt = Read<int>(&pktData);
+ Y_ASSERT(attempt <= xfer.Attempt);
+ if (attempt != xfer.Attempt)
+ break; // acks from a previous attempt are meaningless
+ ReadAcks(&xfer, (int*)pktData, (int)(RecvBuf.GetDataPtr() + rv - pktData) / SIZEOF_ACK, deltaT);
+ break;
+ }
+ case ACK_COMPLETE: {
+ TUdpOutXferHash::iterator i = SendQueue.find(k);
+ if (i == SendQueue.end())
+ break;
+ TUdpOutTransfer& xfer = i->second;
+ xfer.AckTracker.MarkAlive();
+ TGUID packetGuid;
+ packetGuid = Read<TGUID>(&pktData);
+ int packetId = Read<int>(&pktData);
+ if (packetGuid == xfer.PacketGuid) {
+ xfer.AckTracker.Ack(packetId, deltaT, true); // update RTT
+ xfer.AckTracker.AckAll(); // acking packets is required, otherwise they will be treated as lost (look AckTracker destructor)
+ SucceededSend(transferId);
+ SendQueue.erase(i);
+ } else {
+ // peer asserts that he has received this packet but packetGuid is wrong
+ // try to resend everything
+ // ++xfer.Attempt; // should not do this, only sender can modify attempt number, otherwise cycle is possible with out of order packets
+ xfer.AckTracker.Resend();
+ }
+ break;
+ } break; // NOTE(review): this trailing break is unreachable - the case always exits above
+ case ACK_RESEND: {
+ TUdpOutXferHash::iterator i = SendQueue.find(k);
+ if (i == SendQueue.end())
+ break;
+ TUdpOutTransfer& xfer = i->second;
+ xfer.AckTracker.MarkAlive();
+ int attempt = Read<int>(&pktData);
+ if (xfer.Attempt != attempt) {
+ // reset current tranfser & initialize new one
+ xfer.Attempt = attempt;
+ xfer.AckTracker.Resend();
+ }
+ break;
+ }
+ case ACK_RESEND_NOSHMEM: {
+ // abort execution here
+ // failed to open shmem on recv side, need to transmit data without using shmem
+ Y_VERIFY(0, "not implemented yet");
+ break;
+ }
+ case PING: {
+ sockaddr_in6 trueFromAddress = fromAddress;
+ int port = Read<int>(&pktData);
+ Y_ASSERT(trueFromAddress.sin6_family == AF_INET6);
+ // NOTE(review): the peer's network-order port is sent as an int and truncated
+ // to the 16-bit sin6_port here - presumably intentional, confirm on big-endian
+ trueFromAddress.sin6_port = port;
+ // can not set MTU for fromAddress here since asymmetrical mtu is possible
+ // reply with PONG; include our IB connect info so the peer can dial us over IB
+ char* pktData2 = PktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData2, (int)0);
+ Write(&pktData2, (char)PONG);
+ if (IB.Get()) {
+ const TIBConnectInfo& ibConnectInfo = IB->GetConnectInfo();
+ Write(&pktData2, ibConnectInfo);
+ Write(&pktData2, trueFromAddress);
+ }
+ s.SendTo(PktBuf, pktData2 - PktBuf, trueFromAddress, FF_ALLOW_FRAG);
+ break;
+ }
+ case PONG: {
+ // our (possibly jumbo) PING made it through - enable the large MTU
+ TPeerLink& peerInfo = GetPeerLink(k.Address);
+ peerInfo.UdpCongestion->SetMTU(UDP_PACKET_SIZE);
+ int dataSize = (int)(RecvBuf.GetDataPtr() + rv - pktData);
+ if (dataSize == sizeof(TIBConnectInfo) + sizeof(sockaddr_in6)) {
+ if (IB.Get() != nullptr && peerInfo.IBPeer.Get() == nullptr) {
+ TIBConnectInfo info = Read<TIBConnectInfo>(&pktData);
+ sockaddr_in6 myAddress = Read<sockaddr_in6>(&pktData);
+ TUdpAddress myUdpAddress;
+ GetUdpAddress(&myUdpAddress, myAddress);
+ peerInfo.IBPeer = IB->ConnectPeer(info, k.Address, myUdpAddress);
+ }
+ }
+ break;
+ }
+ case KILL: {
+ // remote kill switch: aborts the process only with both passphrases and no trailing bytes
+ ui64 p1 = Read<ui64>(&pktData);
+ ui64 p2 = Read<ui64>(&pktData);
+ int restSize = (int)(RecvBuf.GetDataPtr() + rv - pktData);
+ if (restSize == 0 && p1 == KILL_PASSPHRASE1 && p2 == KILL_PASSPHRASE2) {
+ abort();
+ }
+ break;
+ }
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ }
+ }
+
+ void TUdpHost::IBStep() {
+ if (IB.Get()) {
+ NHPTimer::STime tChk = CurrentT;
+ float chkDeltaT = (float)NHPTimer::GetTimePassed(&tChk);
+ if (IB->Step(tChk)) {
+ IBIdleTime = -chkDeltaT;
+ }
+ }
+ }
+
+ void TUdpHost::Step() {
+ if (IB.Get()) {
+ NHPTimer::STime tChk = CurrentT;
+ float chkDeltaT = (float)NHPTimer::GetTimePassed(&tChk);
+ if (IB->Step(tChk)) {
+ IBIdleTime = -chkDeltaT;
+ }
+ if (chkDeltaT < 0.0005) {
+ return;
+ }
+ }
+
+ if (UseTOSforAcks) {
+ s.SetTOS(0x20);
+ } else {
+ s.SetTOS(0);
+ }
+
+ RecvCycle();
+
+ float deltaT = (float)NHPTimer::GetTimePassed(&CurrentT);
+ deltaT = ClampVal(deltaT, 0.0f, UDP_TRANSFER_TIMEOUT / 3);
+
+ MaxWaitTime = DEFAULT_MAX_WAIT_TIME;
+ IBIdleTime += deltaT;
+
+ bool needCheckAlive = false;
+
+ // update alive ports
+ const float INACTIVE_CONGESTION_UPDATE_INTERVAL = 1;
+ TimeSinceCongestionHistoryUpdate += deltaT;
+ if (TimeSinceCongestionHistoryUpdate > INACTIVE_CONGESTION_UPDATE_INTERVAL) {
+ for (TPeerLinkHash::iterator i = CongestionTrackHistory.begin(); i != CongestionTrackHistory.end();) {
+ TPeerLink& pl = i->second;
+ if (!pl.UpdateSleep(TimeSinceCongestionHistoryUpdate)) {
+ TPeerLinkHash::iterator k = i++;
+ CongestionTrackHistory.erase(k);
+ needCheckAlive = true;
+ } else {
+ ++i;
+ }
+ }
+ TimeSinceCongestionHistoryUpdate = 0;
+ }
+ for (TPeerLinkHash::iterator i = CongestionTrack.begin(); i != CongestionTrack.end();) {
+ const TUdpAddress& addr = i->first;
+ TPeerLink& pl = i->second;
+ if (pl.UdpCongestion->GetTransferCount() == 0) {
+ pl.StartSleep(addr, &MaxWaitTime);
+ CongestionTrackHistory[i->first] = i->second;
+ TPeerLinkHash::iterator k = i++;
+ CongestionTrack.erase(k);
+ } else if (!pl.Update(deltaT, addr, &MaxWaitTime)) {
+ TPeerLinkHash::iterator k = i++;
+ CongestionTrack.erase(k);
+ needCheckAlive = true;
+ } else {
+ ++i;
+ }
+ }
+
+ // send acks on received data
+ for (TUdpInXferHash::iterator i = RecvQueue.begin(); i != RecvQueue.end();) {
+ const TTransferKey& transKey = i->first;
+ int transferId = transKey.Id;
+ TUdpInTransfer& xfer = i->second;
+ xfer.TimeSinceLastRecv += deltaT;
+ if (xfer.TimeSinceLastRecv > UDP_MAX_INPUT_DATA_WAIT || (needCheckAlive && !xfer.Congestion->IsAlive())) {
+ TUdpInXferHash::iterator k = i++;
+ RecvQueue.erase(k);
+ continue;
+ }
+ Y_ASSERT(RecvCompleted.find(i->first) == RecvCompleted.end()); // state "Complete & incomplete" is incorrect
+ if (!xfer.NewPacketsToAck.empty()) {
+ char* pktData = PktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, transferId);
+ Write(&pktData, (char)ACK);
+ Write(&pktData, xfer.Attempt);
+ int acks = WriteAck(&xfer, (int*)pktData, (int)(xfer.PacketSize - (pktData - PktBuf)) / SIZEOF_ACK);
+ pktData += acks * SIZEOF_ACK;
+ s.SendTo(PktBuf, (int)(pktData - PktBuf), xfer.ToAddress, FF_ALLOW_FRAG);
+ }
+ ++i;
+ }
+
+ if (UseTOSforAcks) {
+ s.SetTOS(0x60);
+ }
+
+ // send data for outbound connections
+ SendData(&SendOrderHighPrior, deltaT, needCheckAlive);
+ SendData(&SendOrder, deltaT, needCheckAlive);
+ SendData(&SendOrderLow, deltaT, needCheckAlive);
+
+ // roll send order to avoid exotic problems with lots of peers and high traffic
+ SendOrderHighPrior.splice(SendOrderHighPrior.end(), SendOrderHighPrior, SendOrderHighPrior.begin());
+ //SendOrder.splice(SendOrder.end(), SendOrder, SendOrder.begin()); // sending data in order has lower delay and shorter queue
+
+ // clean completed queue
+ TimeSinceCompletedQueueClean += deltaT;
+ if (TimeSinceCompletedQueueClean > UDP_TRANSFER_TIMEOUT * 1.5) {
+ for (size_t i = 0; i < KeepCompletedQueue.size(); ++i) {
+ TUdpCompleteInXferHash::iterator k = RecvCompleted.find(KeepCompletedQueue[i]);
+ if (k != RecvCompleted.end())
+ RecvCompleted.erase(k);
+ }
+ KeepCompletedQueue.clear();
+ KeepCompletedQueue.swap(RecvCompletedQueue);
+ TimeSinceCompletedQueueClean = 0;
+ }
+ }
+
+ TString TUdpHost::GetPeerLinkDebug(const TPeerLinkHash& ch) {
+ TString res;
+ char buf[1000];
+ for (const auto& i : ch) {
+ const TUdpAddress& ip = i.first;
+ const TCongestionControl& cc = *i.second.UdpCongestion;
+ IIBPeer* ibPeer = i.second.IBPeer.Get();
+ sprintf(buf, "%s\tIB: %d, RTT: %g Timeout: %g Window: %g MaxWin: %g FailRate: %g TimeSinceLastRecv: %g Transfers: %d MTU: %d\n",
+ GetAddressAsString(ip).c_str(),
+ ibPeer ? ibPeer->GetState() : -1,
+ cc.GetRTT() * 1000, cc.GetTimeout() * 1000, cc.GetWindow(), cc.GetMaxWindow(), cc.GetFailRate(),
+ cc.GetTimeSinceLastRecv() * 1000, cc.GetTransferCount(), cc.GetMTU());
+ res += buf;
+ }
+ return res;
+ }
+
+ TString TUdpHost::GetDebugInfo() {
+ TString res;
+ char buf[1000];
+ sprintf(buf, "Receiving %d msgs, sending %d high prior, %d regular msgs, %d low prior msgs\n",
+ RecvQueue.ysize(), (int)SendOrderHighPrior.size(), (int)SendOrder.size(), (int)SendOrderLow.size());
+ res += buf;
+
+ TRequesterPendingDataStats pds;
+ GetPendingDataSize(&pds);
+ sprintf(buf, "Pending data size: %" PRIu64 "\n", pds.InpDataSize + pds.OutDataSize);
+ res += buf;
+ sprintf(buf, " in packets: %d, size %" PRIu64 "\n", pds.InpCount, pds.InpDataSize);
+ res += buf;
+ sprintf(buf, " out packets: %d, size %" PRIu64 "\n", pds.OutCount, pds.OutDataSize);
+ res += buf;
+
+ res += "\nCongestion info:\n";
+ res += GetPeerLinkDebug(CongestionTrack);
+ res += "\nCongestion info history:\n";
+ res += GetPeerLinkDebug(CongestionTrackHistory);
+
+ return res;
+ }
+
+ static void SendKill(const TNetSocket& s, const sockaddr_in6& toAddress) {
+ char buf[100];
+ char* pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
+ Write(&pktData, (int)0);
+ Write(&pktData, (char)KILL);
+ Write(&pktData, KILL_PASSPHRASE1);
+ Write(&pktData, KILL_PASSPHRASE2);
+ s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
+ }
+
+ void TUdpHost::Kill(const TUdpAddress& addr) {
+ sockaddr_in6 target;
+ GetWinsockAddr(&target, addr);
+ SendKill(s, target);
+ }
+
+ TIntrusivePtr<IPeerQueueStats> TUdpHost::GetQueueStats(const TUdpAddress& addr) {
+ TQueueStatsHash::iterator zq = PeerQueueStats.find(addr);
+ if (zq != PeerQueueStats.end()) {
+ return zq->second.Get();
+ }
+ TPeerQueueStats* res = new TPeerQueueStats;
+ PeerQueueStats[addr] = res;
+ // attach to existing congestion tracker
+ TPeerLinkHash::iterator z;
+ z = CongestionTrack.find(addr);
+ if (z != CongestionTrack.end()) {
+ z->second.UdpCongestion->AttachQueueStats(res);
+ }
+ z = CongestionTrackHistory.find(addr);
+ if (z != CongestionTrackHistory.end()) {
+ z->second.UdpCongestion->AttachQueueStats(res);
+ }
+ return res;
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+
+ TIntrusivePtr<IUdpHost> CreateUdpHost(int port) {
+ TIntrusivePtr<NNetlibaSocket::ISocket> socket = NNetlibaSocket::CreateBestRecvSocket();
+ socket->Open(port);
+ if (!socket->IsValid())
+ return nullptr;
+ return CreateUdpHost(socket);
+ }
+
+ TIntrusivePtr<IUdpHost> CreateUdpHost(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) {
+ if (!InitLocalIPList()) {
+ Y_ASSERT(0 && "Can not determine self IP address");
+ return nullptr;
+ }
+ TIntrusivePtr<TUdpHost> res = new TUdpHost;
+ if (!res->Start(socket))
+ return nullptr;
+ return res.Get();
+ }
+
+ void SetUdpMaxBandwidthPerIP(float f) {
+ f = Max(0.0f, f);
+ TCongestionControl::MaxPacketRate = f / UDP_PACKET_SIZE;
+ }
+
+ void SetUdpSlowStart(bool enable) {
+ TCongestionControl::StartWindowSize = enable ? 0.5f : 3;
+ }
+
+ void DisableIBDetection() {
+ IBDetection = false;
+ }
+
+}
diff --git a/library/cpp/netliba/v6/udp_client_server.h b/library/cpp/netliba/v6/udp_client_server.h
new file mode 100644
index 0000000000..23e0209405
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_client_server.h
@@ -0,0 +1,62 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+#include <util/generic/guid.h>
+#include <library/cpp/netliba/socket/socket.h>
+
+#include "udp_address.h"
+#include "net_request.h"
+
+namespace NNetliba {
+ class TRopeDataPacket;
+ struct TRequesterPendingDataStats;
+ struct IPeerQueueStats;
+
+ struct TSendResult {
+ int TransferId;
+ bool Success;
+ TSendResult()
+ : TransferId(-1)
+ , Success(false)
+ {
+ }
+ TSendResult(int transferId, bool success)
+ : TransferId(transferId)
+ , Success(success)
+ {
+ }
+ };
+
+ enum EPacketPriority {
+ PP_LOW,
+ PP_NORMAL,
+ PP_HIGH
+ };
+
+ // Step should be called from one and the same thread
+ // thread safety is caller responsibility
+ struct IUdpHost: public TThrRefBase {
+ virtual TRequest* GetRequest() = 0;
+        // returns transferId
+ // Send() needs correctly computed crc32
+ // crc32 is expected to be computed outside of the thread talking to IUdpHost to avoid crc32 computation delays
+ // packetGuid provides packet guid, if packetGuid is empty then guid is generated
+ virtual int Send(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data, int crc32, TGUID* packetGuid, EPacketPriority pp) = 0;
+ virtual bool GetSendResult(TSendResult* res) = 0;
+ virtual void Step() = 0;
+ virtual void IBStep() = 0;
+ virtual void Wait(float seconds) = 0; // does not use UdpHost
+ virtual void CancelWait() = 0; // thread safe
+ virtual void GetPendingDataSize(TRequesterPendingDataStats* res) = 0;
+ virtual TString GetDebugInfo() = 0;
+ virtual void Kill(const TUdpAddress& addr) = 0;
+ virtual TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) = 0;
+ };
+
+ TIntrusivePtr<IUdpHost> CreateUdpHost(int port);
+ TIntrusivePtr<IUdpHost> CreateUdpHost(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket);
+
+ void SetUdpMaxBandwidthPerIP(float f);
+ void SetUdpSlowStart(bool enable);
+ void DisableIBDetection();
+}
diff --git a/library/cpp/netliba/v6/udp_debug.cpp b/library/cpp/netliba/v6/udp_debug.cpp
new file mode 100644
index 0000000000..1e02fa9e2b
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_debug.cpp
@@ -0,0 +1,2 @@
+#include "stdafx.h"
+#include "udp_debug.h"
diff --git a/library/cpp/netliba/v6/udp_debug.h b/library/cpp/netliba/v6/udp_debug.h
new file mode 100644
index 0000000000..77eea61c23
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_debug.h
@@ -0,0 +1,21 @@
+#pragma once
+
+namespace NNetliba {
+ struct TRequesterPendingDataStats {
+ int InpCount, OutCount;
+ ui64 InpDataSize, OutDataSize;
+
+ TRequesterPendingDataStats() {
+ memset(this, 0, sizeof(*this));
+ }
+ };
+
+ struct TRequesterQueueStats {
+ int ReqCount, RespCount;
+ ui64 ReqQueueSize, RespQueueSize;
+
+ TRequesterQueueStats() {
+ memset(this, 0, sizeof(*this));
+ }
+ };
+}
diff --git a/library/cpp/netliba/v6/udp_http.cpp b/library/cpp/netliba/v6/udp_http.cpp
new file mode 100644
index 0000000000..9fa0b07818
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_http.cpp
@@ -0,0 +1,1354 @@
+#include "stdafx.h"
+#include "udp_http.h"
+#include "udp_client_server.h"
+#include "udp_socket.h"
+#include "cpu_affinity.h"
+
+#include <library/cpp/threading/atomic/bool.h>
+
+#include <util/system/hp_timer.h>
+#include <util/thread/lfqueue.h>
+#include <util/system/thread.h>
+#include <util/system/spinlock.h>
+#if !defined(_win_)
+#include <signal.h>
+#include <pthread.h>
+#endif
+#include "block_chain.h"
+#include <util/system/shmat.h>
+
+#include <atomic>
+
+namespace NNetliba {
+ const float HTTP_TIMEOUT = 15.0f;
+ const int MIN_SHARED_MEM_PACKET = 1000;
+
+ static ::NAtomic::TBool PanicAttack;
+ static std::atomic<NHPTimer::STime> LastHeartbeat;
+ static std::atomic<double> HeartbeatTimeout;
+
+ static int GetPacketSize(TRequest* req) {
+ if (req && req->Data.Get())
+ return req->Data->GetSize();
+ return 0;
+ }
+
+ static bool IsLocalFast(const TUdpAddress& addr) {
+ if (addr.IsIPv4()) {
+ return IsLocalIPv4(addr.GetIPv4());
+ } else {
+ return IsLocalIPv6(addr.Network, addr.Interface);
+ }
+ }
+
+ bool IsLocal(const TUdpAddress& addr) {
+ InitLocalIPList();
+ return IsLocalFast(addr);
+ }
+
+ TUdpHttpRequest::~TUdpHttpRequest() {
+ }
+
+ TUdpHttpResponse::~TUdpHttpResponse() {
+ }
+
+ class TRequesterUserQueueSizes: public TThrRefBase {
+ public:
+ TAtomic ReqCount, RespCount;
+ TAtomic ReqQueueSize, RespQueueSize;
+
+ TRequesterUserQueueSizes()
+ : ReqCount(0)
+ , RespCount(0)
+ , ReqQueueSize(0)
+ , RespQueueSize(0)
+ {
+ }
+ };
+
+ template <class T>
+ void EraseList(TLockFreeQueue<T*>* data) {
+ T* ptr = nullptr;
+ while (data->Dequeue(&ptr)) {
+ delete ptr;
+ }
+ }
+
+ class TRequesterUserQueues: public TThrRefBase {
+ TIntrusivePtr<TRequesterUserQueueSizes> QueueSizes;
+ TLockFreeQueue<TUdpHttpRequest*> ReqList;
+ TLockFreeQueue<TUdpHttpResponse*> ResponseList;
+ TLockFreeStack<TGUID> CancelList, SendRequestAccList; // any order will do
+ TMuxEvent AsyncEvent;
+
+ void UpdateAsyncSignalState() {
+ // not sure about this one. Idea is that AsyncEvent.Reset() is a memory barrier
+ if (ReqList.IsEmpty() && ResponseList.IsEmpty() && CancelList.IsEmpty() && SendRequestAccList.IsEmpty()) {
+ AsyncEvent.Reset();
+ if (!ReqList.IsEmpty() || !ResponseList.IsEmpty() || !CancelList.IsEmpty() || !SendRequestAccList.IsEmpty())
+ AsyncEvent.Signal();
+ }
+ }
+ ~TRequesterUserQueues() override {
+ EraseList(&ReqList);
+ EraseList(&ResponseList);
+ }
+
+ public:
+ TRequesterUserQueues(TRequesterUserQueueSizes* queueSizes)
+ : QueueSizes(queueSizes)
+ {
+ }
+ TUdpHttpRequest* GetRequest();
+ TUdpHttpResponse* GetResponse();
+ bool GetRequestCancel(TGUID* req) {
+ bool res = CancelList.Dequeue(req);
+ UpdateAsyncSignalState();
+ return res;
+ }
+ bool GetSendRequestAcc(TGUID* req) {
+ bool res = SendRequestAccList.Dequeue(req);
+ UpdateAsyncSignalState();
+ return res;
+ }
+
+ void AddRequest(TUdpHttpRequest* res) {
+ AtomicAdd(QueueSizes->ReqCount, 1);
+ AtomicAdd(QueueSizes->ReqQueueSize, GetPacketSize(res->DataHolder.Get()));
+ ReqList.Enqueue(res);
+ AsyncEvent.Signal();
+ }
+ void AddResponse(TUdpHttpResponse* res) {
+ AtomicAdd(QueueSizes->RespCount, 1);
+ AtomicAdd(QueueSizes->RespQueueSize, GetPacketSize(res->DataHolder.Get()));
+ ResponseList.Enqueue(res);
+ AsyncEvent.Signal();
+ }
+ void AddCancel(const TGUID& req) {
+ CancelList.Enqueue(req);
+ AsyncEvent.Signal();
+ }
+ void AddSendRequestAcc(const TGUID& req) {
+ SendRequestAccList.Enqueue(req);
+ AsyncEvent.Signal();
+ }
+ TMuxEvent& GetAsyncEvent() {
+ return AsyncEvent;
+ }
+ void AsyncSignal() {
+ AsyncEvent.Signal();
+ }
+ };
+
+ struct TOutRequestState {
+ enum EState {
+ S_SENDING,
+ S_WAITING,
+ S_WAITING_PING_SENDING,
+ S_WAITING_PING_SENT,
+ S_CANCEL_AFTER_SENDING
+ };
+ EState State;
+ TUdpAddress Address;
+ double TimePassed;
+ int PingTransferId;
+ TIntrusivePtr<TRequesterUserQueues> UserQueues;
+
+ TOutRequestState()
+ : State(S_SENDING)
+ , TimePassed(0)
+ , PingTransferId(-1)
+ {
+ }
+ };
+
+ struct TInRequestState {
+ enum EState {
+ S_WAITING,
+ S_RESPONSE_SENDING,
+ S_CANCELED,
+ };
+ EState State;
+ TUdpAddress Address;
+
+ TInRequestState()
+ : State(S_WAITING)
+ {
+ }
+ TInRequestState(const TUdpAddress& address)
+ : State(S_WAITING)
+ , Address(address)
+ {
+ }
+ };
+
+ enum EHttpPacket {
+ PKT_REQUEST,
+ PKT_PING,
+ PKT_PING_RESPONSE,
+ PKT_RESPONSE,
+ PKT_GETDEBUGINFO,
+ PKT_LOCAL_REQUEST,
+ PKT_LOCAL_RESPONSE,
+ PKT_CANCEL,
+ };
+
+ class TUdpHttp: public IRequester {
+ enum EDir {
+ DIR_OUT,
+ DIR_IN
+ };
+ struct TTransferPurpose {
+ EDir Dir;
+ TGUID Guid;
+ TTransferPurpose()
+ : Dir(DIR_OUT)
+ {
+ }
+ TTransferPurpose(EDir dir, TGUID guid)
+ : Dir(dir)
+ , Guid(guid)
+ {
+ }
+ };
+
+ struct TSendRequest {
+ TUdpAddress Addr;
+ TAutoPtr<TRopeDataPacket> Data;
+ TGUID ReqGuid;
+ TIntrusivePtr<TWaitResponse> WR;
+ TIntrusivePtr<TRequesterUserQueues> UserQueues;
+ ui32 Crc32;
+
+ TSendRequest()
+ : Crc32(0)
+ {
+ }
+ TSendRequest(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket>* data, const TGUID& reqguid, TWaitResponse* wr, TRequesterUserQueues* userQueues)
+ : Addr(addr)
+ , Data(*data)
+ , ReqGuid(reqguid)
+ , WR(wr)
+ , UserQueues(userQueues)
+ , Crc32(CalcChecksum(Data->GetChain()))
+ {
+ }
+ };
+ struct TSendResponse {
+ TVector<char> Data;
+ TGUID ReqGuid;
+ ui32 DataCrc32;
+ EPacketPriority Priority;
+
+ TSendResponse()
+ : DataCrc32(0)
+ , Priority(PP_NORMAL)
+ {
+ }
+ TSendResponse(const TGUID& reqguid, EPacketPriority prior, TVector<char>* data)
+ : ReqGuid(reqguid)
+ , DataCrc32(0)
+ , Priority(prior)
+ {
+ if (data && !data->empty()) {
+ data->swap(Data);
+ DataCrc32 = TIncrementalChecksumCalcer::CalcBlockSum(&Data[0], Data.ysize());
+ }
+ }
+ };
+ struct TCancelRequest {
+ TGUID ReqGuid;
+
+ TCancelRequest() = default;
+ TCancelRequest(const TGUID& reqguid)
+ : ReqGuid(reqguid)
+ {
+ }
+ };
+ struct TBreakRequest {
+ TGUID ReqGuid;
+
+ TBreakRequest() = default;
+ TBreakRequest(const TGUID& reqguid)
+ : ReqGuid(reqguid)
+ {
+ }
+ };
+
+ TThread myThread;
+ bool KeepRunning, AbortTransactions;
+ TSpinLock cs;
+ TSystemEvent HasStarted;
+
+ NHPTimer::STime PingsSendT;
+
+ TIntrusivePtr<IUdpHost> Host;
+ TIntrusivePtr<NNetlibaSocket::ISocket> Socket;
+ typedef THashMap<TGUID, TOutRequestState, TGUIDHash> TOutRequestHash;
+ typedef THashMap<TGUID, TInRequestState, TGUIDHash> TInRequestHash;
+ TOutRequestHash OutRequests;
+ TInRequestHash InRequests;
+
+ typedef THashMap<int, TTransferPurpose> TTransferHash;
+ TTransferHash TransferHash;
+
+ typedef THashMap<TGUID, TIntrusivePtr<TWaitResponse>, TGUIDHash> TSyncRequests;
+ TSyncRequests SyncRequests;
+
+ // hold it here to not construct on every DoSends()
+ typedef THashSet<TGUID, TGUIDHash> TAnticipateCancels;
+ TAnticipateCancels AnticipateCancels;
+
+ TLockFreeQueue<TSendRequest*> SendReqList;
+ TLockFreeQueue<TSendResponse*> SendRespList;
+ TLockFreeQueue<TCancelRequest> CancelReqList;
+ TLockFreeQueue<TBreakRequest> BreakReqList;
+
+ TIntrusivePtr<TRequesterUserQueueSizes> QueueSizes;
+ TIntrusivePtr<TRequesterUserQueues> UserQueues;
+
+ struct TStatsRequest: public TThrRefBase {
+ enum EReq {
+ PENDING_SIZE,
+ DEBUG_INFO,
+ HAS_IN_REQUEST,
+ GET_PEER_ADDRESS,
+ GET_PEER_QUEUE_STATS,
+ };
+ EReq Req;
+ TRequesterPendingDataStats PendingDataSize;
+ TString DebugInfo;
+ TGUID RequestId;
+ TUdpAddress PeerAddress;
+ TIntrusivePtr<IPeerQueueStats> QueueStats;
+ bool RequestFound;
+ TSystemEvent Complete;
+
+ TStatsRequest(EReq req)
+ : Req(req)
+ , RequestFound(false)
+ {
+ }
+ };
+ TLockFreeQueue<TIntrusivePtr<TStatsRequest>> StatsReqList;
+
+ bool ReportRequestCancel;
+ bool ReportSendRequestAcc;
+
+ void FinishRequest(TOutRequestHash::iterator i, TUdpHttpResponse::EResult ok, TAutoPtr<TRequest> data, const char* error = nullptr) {
+ TOutRequestState& s = i->second;
+ TUdpHttpResponse* res = new TUdpHttpResponse;
+ res->DataHolder = data;
+ res->ReqId = i->first;
+ res->PeerAddress = s.Address;
+ res->Ok = ok;
+ if (ok == TUdpHttpResponse::FAILED)
+ res->Error = error ? error : "request failed";
+ else if (ok == TUdpHttpResponse::CANCELED)
+ res->Error = error ? error : "request cancelled";
+ TSyncRequests::iterator k = SyncRequests.find(res->ReqId);
+ if (k != SyncRequests.end()) {
+ TIntrusivePtr<TWaitResponse>& wr = k->second;
+ wr->SetResponse(res);
+ SyncRequests.erase(k);
+ } else {
+ s.UserQueues->AddResponse(res);
+ }
+
+ OutRequests.erase(i);
+ }
+ int SendWithHighPriority(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data) {
+ ui32 crc32 = CalcChecksum(data->GetChain());
+ return Host->Send(addr, data.Release(), crc32, nullptr, PP_HIGH);
+ }
+ void ProcessIncomingPackets() {
+ TVector<TGUID, TCustomAllocator<TGUID>> failedRequests;
+ for (;;) {
+ TAutoPtr<TRequest> req = Host->GetRequest();
+ if (req.Get() == nullptr) {
+ if (!failedRequests.empty()) {
+ // we want to handle following sequence of events
+ // <- send ping
+ // -> send response over IB
+ // -> send ping response (no such request) over UDP
+ // Now if we are lucky enough we can get IB response waiting in the IB receive queue
+ // at the same time response sender will receive "send complete" from IB
+ // indeed, IB delivered message (but it was not parsed by ib_cs.cpp yet)
+                    // so after receiving "send response complete" event response sender can legally respond
+ // to pings with "no such request"
+ // but ping responses can be sent over UDP
+ // So we can run into situation with negative ping response in
+ // UDP receive queue and response waiting unprocessed in IB receive queue
+ // to check that there is no response in the IB queue we have to process IB queues
+ // so we call IBStep()
+ Host->IBStep();
+ req = Host->GetRequest();
+ if (req.Get() == nullptr) {
+ break;
+ }
+ } else {
+ break;
+ }
+ }
+
+ TBlockChainIterator reqData(req->Data->GetChain());
+ char pktType;
+ reqData.Read(&pktType, 1);
+ switch (pktType) {
+ case PKT_REQUEST:
+ case PKT_LOCAL_REQUEST: {
+ //printf("recv PKT_REQUEST or PKT_LOCAL_REQUEST\n");
+ TGUID reqId = req->Guid;
+ TInRequestHash::iterator z = InRequests.find(reqId);
+ if (z != InRequests.end()) {
+ // oops, this request already exists!
+ // might happen if request can be stored in single packet
+ // and this packet had source IP broken during transmission and managed to pass crc checks
+ // since we already reported wrong source address for this request to the user
+ // the best thing we can do is to stop the program to avoid further complications
+ // but we just report the accident to stderr
+ fprintf(stderr, "Jackpot, same request %s received twice from %s and earlier from %s\n",
+ GetGuidAsString(reqId).c_str(), GetAddressAsString(z->second.Address).c_str(),
+ GetAddressAsString(req->Address).c_str());
+ } else {
+ InRequests[reqId] = TInRequestState(req->Address);
+
+ //printf("InReq %s PKT_REQUEST recv ... -> S_WAITING\n", GetGuidAsString(reqId).c_str());
+
+ TUdpHttpRequest* res = new TUdpHttpRequest;
+ res->ReqId = reqId;
+ res->PeerAddress = req->Address;
+ res->DataHolder = req;
+
+ UserQueues->AddRequest(res);
+ }
+ } break;
+ case PKT_PING: {
+ //printf("recv PKT_PING\n");
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ bool ok = InRequests.find(guid) != InRequests.end();
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write((char)PKT_PING_RESPONSE);
+ ms->Write(guid);
+ ms->Write(ok);
+ SendWithHighPriority(req->Address, ms.Release());
+ //printf("InReq %s PKT_PING recv Sending PKT_PING_RESPONSE\n", GetGuidAsString(guid).c_str());
+ //printf("got PKT_PING, responding %d\n", (int)ok);
+ } break;
+ case PKT_PING_RESPONSE: {
+ //printf("recv PKT_PING_RESPONSE\n");
+ TGUID guid;
+ bool ok;
+ reqData.Read(&guid, sizeof(guid));
+ reqData.Read(&ok, sizeof(ok));
+ TOutRequestHash::iterator i = OutRequests.find(guid);
+ if (i == OutRequests.end()) {
+ ; //Y_ASSERT(0); // actually possible with some packet orders
+ } else {
+ if (!ok) {
+ // can not delete request at this point
+ // since we can receive failed ping and response at the same moment
+ // consider sequence: client sends ping, server sends response
+ // and replies false to ping as reply is sent
+                                // we can not receive failed ping_response earlier than response itself
+ // but we can receive them simultaneously
+ failedRequests.push_back(guid);
+ //printf("OutReq %s PKT_PING_RESPONSE recv no such query -> failed\n", GetGuidAsString(guid).c_str());
+ } else {
+ TOutRequestState& s = i->second;
+ switch (s.State) {
+ case TOutRequestState::S_WAITING_PING_SENDING: {
+ Y_ASSERT(s.PingTransferId >= 0);
+ TTransferHash::iterator k = TransferHash.find(s.PingTransferId);
+ if (k != TransferHash.end())
+ TransferHash.erase(k);
+ else
+ Y_ASSERT(0);
+ s.PingTransferId = -1;
+ s.TimePassed = 0;
+ s.State = TOutRequestState::S_WAITING;
+ //printf("OutReq %s PKT_PING_RESPONSE recv S_WAITING_PING_SENDING -> S_WAITING\n", GetGuidAsString(guid).c_str());
+ } break;
+ case TOutRequestState::S_WAITING_PING_SENT:
+ s.TimePassed = 0;
+ s.State = TOutRequestState::S_WAITING;
+ //printf("OutReq %s PKT_PING_RESPONSE recv S_WAITING_PING_SENT -> S_WAITING\n", GetGuidAsString(guid).c_str());
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ }
+ }
+ } break;
+ case PKT_RESPONSE:
+ case PKT_LOCAL_RESPONSE: {
+ //printf("recv PKT_RESPONSE or PKT_LOCAL_RESPONSE\n");
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ TOutRequestHash::iterator i = OutRequests.find(guid);
+ if (i == OutRequests.end()) {
+ ; //Y_ASSERT(0); // does happen
+ //printf("OutReq %s PKT_RESPONSE recv for non-existing req\n", GetGuidAsString(guid).c_str());
+ } else {
+ FinishRequest(i, TUdpHttpResponse::OK, req);
+ //printf("OutReq %s PKT_RESPONSE recv ... -> ok\n", GetGuidAsString(guid).c_str());
+ }
+ } break;
+ case PKT_CANCEL: {
+ //printf("recv PKT_CANCEL\n");
+ TGUID guid;
+ reqData.Read(&guid, sizeof(guid));
+ TInRequestHash::iterator i = InRequests.find(guid);
+ if (i == InRequests.end()) {
+ ; //Y_ASSERT(0); // may happen
+ //printf("InReq %s PKT_CANCEL recv for non-existing req\n", GetGuidAsString(guid).c_str());
+ } else {
+ TInRequestState& s = i->second;
+ if (s.State != TInRequestState::S_CANCELED && ReportRequestCancel)
+ UserQueues->AddCancel(guid);
+ s.State = TInRequestState::S_CANCELED;
+ //printf("InReq %s PKT_CANCEL recv\n", GetGuidAsString(guid).c_str());
+ }
+ } break;
+ case PKT_GETDEBUGINFO: {
+ //printf("recv PKT_GETDEBUGINFO\n");
+ TString dbgInfo = GetDebugInfoLocked();
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write(dbgInfo.c_str(), (int)dbgInfo.size());
+ SendWithHighPriority(req->Address, ms);
+ } break;
+ default:
+ Y_ASSERT(0);
+ }
+ }
+ // cleanup failed requests
+ for (size_t k = 0; k < failedRequests.size(); ++k) {
+ const TGUID& guid = failedRequests[k];
+ TOutRequestHash::iterator i = OutRequests.find(guid);
+ if (i != OutRequests.end())
+ FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: recv no such query");
+ }
+ }
+ void AnalyzeSendResults() {
+ TSendResult res;
+ while (Host->GetSendResult(&res)) {
+ //printf("Send result received\n");
+ TTransferHash::iterator k1 = TransferHash.find(res.TransferId);
+ if (k1 != TransferHash.end()) {
+ const TTransferPurpose& tp = k1->second;
+ switch (tp.Dir) {
+ case DIR_OUT: {
+ TOutRequestHash::iterator i = OutRequests.find(tp.Guid);
+ if (i != OutRequests.end()) {
+ const TGUID& reqId = i->first;
+ TOutRequestState& s = i->second;
+ switch (s.State) {
+ case TOutRequestState::S_SENDING:
+ if (!res.Success) {
+ FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_SENDING");
+ //printf("OutReq %s AnalyzeSendResults() S_SENDING -> failed\n", GetGuidAsString(reqId).c_str());
+ } else {
+ if (ReportSendRequestAcc) {
+ if (s.UserQueues.Get()) {
+ s.UserQueues->AddSendRequestAcc(reqId);
+ } else {
+ // waitable request?
+ TSyncRequests::iterator k2 = SyncRequests.find(reqId);
+ if (k2 != SyncRequests.end()) {
+ TIntrusivePtr<TWaitResponse>& wr = k2->second;
+ wr->SetRequestSent();
+ }
+ }
+ }
+ s.State = TOutRequestState::S_WAITING;
+ //printf("OutReq %s AnalyzeSendResults() S_SENDING -> S_WAITING\n", GetGuidAsString(reqId).c_str());
+ s.TimePassed = 0;
+ }
+ break;
+ case TOutRequestState::S_CANCEL_AFTER_SENDING:
+ DoSendCancel(s.Address, reqId);
+ FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request failed: state S_CANCEL_AFTER_SENDING");
+ break;
+ case TOutRequestState::S_WAITING:
+ case TOutRequestState::S_WAITING_PING_SENT:
+ Y_ASSERT(0);
+ break;
+ case TOutRequestState::S_WAITING_PING_SENDING:
+ Y_ASSERT(s.PingTransferId >= 0 && s.PingTransferId == res.TransferId);
+ if (!res.Success) {
+ FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_WAITING_PING_SENDING");
+ //printf("OutReq %s AnalyzeSendResults() S_WAITING_PING_SENDING -> failed\n", GetGuidAsString(reqId).c_str());
+ } else {
+ s.PingTransferId = -1;
+ s.State = TOutRequestState::S_WAITING_PING_SENT;
+ //printf("OutReq %s AnalyzeSendResults() S_WAITING_PING_SENDING -> S_WAITING_PING_SENT\n", GetGuidAsString(reqId).c_str());
+ s.TimePassed = 0;
+ }
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ }
+ } break;
+ case DIR_IN: {
+ TInRequestHash::iterator i = InRequests.find(tp.Guid);
+ if (i != InRequests.end()) {
+ Y_ASSERT(i->second.State == TInRequestState::S_RESPONSE_SENDING || i->second.State == TInRequestState::S_CANCELED);
+ InRequests.erase(i);
+ //if (res.Success)
+ // printf("InReq %s AnalyzeSendResults() ... -> finished\n", GetGuidAsString(tp.Guid).c_str());
+ //else
+ // printf("InReq %s AnalyzeSendResults() ... -> failed response send\n", GetGuidAsString(tp.Guid).c_str());
+ }
+ } break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ TransferHash.erase(k1);
+ }
+ }
+ }
+ void SendPingsIfNeeded() {
+ NHPTimer::STime tChk = PingsSendT;
+ float deltaT = (float)NHPTimer::GetTimePassed(&tChk);
+ if (deltaT < 0.05) {
+ return;
+ }
+ PingsSendT = tChk;
+ deltaT = ClampVal(deltaT, 0.0f, HTTP_TIMEOUT / 3);
+
+ {
+ for (TOutRequestHash::iterator i = OutRequests.begin(); i != OutRequests.end();) {
+ TOutRequestHash::iterator curIt = i++;
+ TOutRequestState& s = curIt->second;
+ const TGUID& guid = curIt->first;
+ switch (s.State) {
+ case TOutRequestState::S_WAITING:
+ s.TimePassed += deltaT;
+ if (s.TimePassed > HTTP_TIMEOUT) {
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write((char)PKT_PING);
+ ms->Write(guid);
+ int transId = SendWithHighPriority(s.Address, ms.Release());
+ TransferHash[transId] = TTransferPurpose(DIR_OUT, guid);
+ s.State = TOutRequestState::S_WAITING_PING_SENDING;
+ //printf("OutReq %s SendPingsIfNeeded() S_WAITING -> S_WAITING_PING_SENDING\n", GetGuidAsString(guid).c_str());
+ s.PingTransferId = transId;
+ }
+ break;
+ case TOutRequestState::S_WAITING_PING_SENT:
+ s.TimePassed += deltaT;
+ if (s.TimePassed > HTTP_TIMEOUT) {
+ //printf("OutReq %s SendPingsIfNeeded() S_WAITING_PING_SENT -> failed\n", GetGuidAsString(guid).c_str());
+ FinishRequest(curIt, TUdpHttpResponse::FAILED, nullptr, "request failed: http timeout in state S_WAITING_PING_SENT");
+ }
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ }
+ void Step() {
+ {
+ TGuard<TSpinLock> lock(cs);
+ DoSends();
+ }
+ Host->Step();
+ for (TIntrusivePtr<TStatsRequest> req; StatsReqList.Dequeue(&req);) {
+ switch (req->Req) {
+ case TStatsRequest::PENDING_SIZE:
+ Host->GetPendingDataSize(&req->PendingDataSize);
+ break;
+ case TStatsRequest::DEBUG_INFO: {
+ TGuard<TSpinLock> lock(cs);
+ req->DebugInfo = GetDebugInfoLocked();
+ } break;
+ case TStatsRequest::HAS_IN_REQUEST: {
+ TGuard<TSpinLock> lock(cs);
+ req->RequestFound = (InRequests.find(req->RequestId) != InRequests.end());
+ } break;
+ case TStatsRequest::GET_PEER_ADDRESS: {
+ TGuard<TSpinLock> lock(cs);
+ TInRequestHash::const_iterator i = InRequests.find(req->RequestId);
+ if (i != InRequests.end()) {
+ req->PeerAddress = i->second.Address;
+ } else {
+ TOutRequestHash::const_iterator o = OutRequests.find(req->RequestId);
+ if (o != OutRequests.end()) {
+ req->PeerAddress = o->second.Address;
+ } else {
+ req->PeerAddress = TUdpAddress();
+ }
+ }
+ } break;
+ case TStatsRequest::GET_PEER_QUEUE_STATS:
+ req->QueueStats = Host->GetQueueStats(req->PeerAddress);
+ break;
+ default:
+ Y_ASSERT(0);
+ break;
+ }
+ req->Complete.Signal();
+ }
+ {
+ TGuard<TSpinLock> lock(cs);
+ DoSends();
+ ProcessIncomingPackets();
+ AnalyzeSendResults();
+ SendPingsIfNeeded();
+ }
+ }
+ void Wait() {
+ Host->Wait(0.1f);
+ }
+ void DoSendCancel(const TUdpAddress& addr, const TGUID& req) {
+ TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+ ms->Write((char)PKT_CANCEL);
+ ms->Write(req);
+ SendWithHighPriority(addr, ms);
+ }
+        // Drains all user-thread -> server-thread queues: break requests, cancels,
+        // queued responses and queued requests, in that order.
+        // NOTE(review): appears to require `cs` to be held by the caller — confirm.
+        void DoSends() {
+            {
+                // broken requests are simply forgotten, the peer is not notified
+                TBreakRequest rb;
+                while (BreakReqList.Dequeue(&rb)) {
+                    InRequests.erase(rb.ReqGuid);
+                }
+            }
+            {
+                // cancelling requests
+                TCancelRequest rc;
+                while (CancelReqList.Dequeue(&rc)) {
+                    TOutRequestHash::iterator i = OutRequests.find(rc.ReqGuid);
+                    if (i == OutRequests.end()) {
+                        // remember the guid: the request may still be queued in SendReqList
+                        // and will be cancelled when it is dequeued below
+                        AnticipateCancels.insert(rc.ReqGuid);
+                        continue; // cancelling non existing request is ok
+                    }
+                    TOutRequestState& s = i->second;
+                    if (s.State == TOutRequestState::S_SENDING) {
+                        // we are in trouble - have not sent request and we already have to cancel it, wait send
+                        s.State = TOutRequestState::S_CANCEL_AFTER_SENDING;
+                    } else {
+                        DoSendCancel(s.Address, rc.ReqGuid);
+                        FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request canceled: notify requested side");
+                    }
+                }
+            }
+            {
+                // sending replies
+                for (TSendResponse* rd = nullptr; SendRespList.Dequeue(&rd); delete rd) {
+                    TInRequestHash::iterator i = InRequests.find(rd->ReqGuid);
+                    if (i == InRequests.end()) {
+                        Y_ASSERT(0); // response for a request we do not know about
+                        continue;
+                    }
+                    TInRequestState& s = i->second;
+                    if (s.State == TInRequestState::S_CANCELED) {
+                        // need not send response for the canceled request
+                        InRequests.erase(i);
+                        continue;
+                    }
+
+                    Y_ASSERT(s.State == TInRequestState::S_WAITING);
+                    s.State = TInRequestState::S_RESPONSE_SENDING;
+                    //printf("InReq %s SendResponse() ... -> S_RESPONSE_SENDING (pkt %s)\n", GetGuidAsString(reqId).c_str(), GetGuidAsString(lowPktGuid).c_str());
+
+                    TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+                    ui32 crc32 = 0;
+                    int dataSize = rd->Data.ysize();
+                    // large responses to local peers go through shared memory instead of UDP
+                    if (rd->Data.ysize() > MIN_SHARED_MEM_PACKET && IsLocalFast(s.Address)) {
+                        TIntrusivePtr<TSharedMemory> shm = new TSharedMemory;
+                        if (shm->Create(dataSize)) {
+                            ms->Write((char)PKT_LOCAL_RESPONSE);
+                            ms->Write(rd->ReqGuid);
+                            memcpy(shm->GetPtr(), &rd->Data[0], dataSize);
+                            TVector<char> empty;
+                            rd->Data.swap(empty); // release the copied payload immediately
+                            ms->AttachSharedData(shm);
+                            crc32 = CalcChecksum(ms->GetChain());
+                        }
+                    }
+                    if (ms->GetSharedData() == nullptr) {
+                        // regular (non shared-memory) response path
+                        ms->Write((char)PKT_RESPONSE);
+                        ms->Write(rd->ReqGuid);
+
+                        // to offload crc calcs from inner thread, crc of data[] is calced outside and passed in DataCrc32
+                        // this means that we are calculating crc when shared memory is used
+                        // it is hard to avoid since in SendResponse() we don't know if shared mem will be used (peer address is not available there)
+                        TIncrementalChecksumCalcer csCalcer;
+                        AddChain(&csCalcer, ms->GetChain());
+                        // here we are replicating the way WriteDestructive serializes data
+                        csCalcer.AddBlock(&dataSize, sizeof(dataSize));
+                        csCalcer.AddBlockSum(rd->DataCrc32, dataSize);
+                        crc32 = csCalcer.CalcChecksum();
+
+                        ms->WriteDestructive(&rd->Data);
+                        //ui32 chkCrc = CalcChecksum(ms->GetChain()); // can not use since its slow for large responses
+                        //Y_ASSERT(chkCrc == crc32);
+                    }
+
+                    int transId = Host->Send(s.Address, ms.Release(), crc32, nullptr, rd->Priority);
+                    TransferHash[transId] = TTransferPurpose(DIR_IN, rd->ReqGuid);
+                }
+            }
+            {
+                // sending requests
+                for (TSendRequest* rd = nullptr; SendReqList.Dequeue(&rd); delete rd) {
+                    Y_ASSERT(OutRequests.find(rd->ReqGuid) == OutRequests.end());
+
+                    {
+                        TOutRequestState& s = OutRequests[rd->ReqGuid];
+                        s.State = TOutRequestState::S_SENDING;
+                        s.Address = rd->Addr;
+                        s.UserQueues = rd->UserQueues;
+                        //printf("OutReq %s SendRequest() ... -> S_SENDING\n", GetGuidAsString(guid).c_str());
+                    }
+
+                    if (rd->WR.Get())
+                        SyncRequests[rd->ReqGuid] = rd->WR;
+
+                    // the request may have been cancelled before we dequeued it (see AnticipateCancels above)
+                    if (AnticipateCancels.find(rd->ReqGuid) != AnticipateCancels.end()) {
+                        FinishRequest(OutRequests.find(rd->ReqGuid), TUdpHttpResponse::CANCELED, nullptr, "request canceled before transmitting");
+                    } else {
+                        TGUID pktGuid = rd->ReqGuid; // request packet id should match request id
+                        int transId = Host->Send(rd->Addr, rd->Data.Release(), rd->Crc32, &pktGuid, PP_NORMAL);
+                        TransferHash[transId] = TTransferPurpose(DIR_OUT, rd->ReqGuid);
+                    }
+                }
+            }
+            // cancels anticipated this pass are only valid for requests already queued
+            if (!AnticipateCancels.empty()) {
+                AnticipateCancels.clear();
+            }
+        }
+
+ public:
+        // Serializes and enqueues an outgoing request.
+        // `data` is consumed (cleared) on success; `wr` (optional) is the waiter
+        // for the synchronous API; `userQueues` (optional) receives the response.
+        // Large payloads to local peers are passed via shared memory.
+        // Aborts the process if the payload exceeds MAX_PACKET_SIZE.
+        void SendRequestImpl(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId,
+                             TWaitResponse* wr, TRequesterUserQueues* userQueues) {
+            if (data && data->size() > MAX_PACKET_SIZE) {
+                Y_VERIFY(0, "data size is too large");
+            }
+            //printf("SendRequest(%s)\n", url.c_str());
+            if (wr)
+                wr->SetReqId(reqId);
+
+            TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket;
+            if (data && data->ysize() > MIN_SHARED_MEM_PACKET && IsLocalFast(addr)) {
+                // local peer + large payload: ship through shared memory
+                int dataSize = data->ysize();
+                TIntrusivePtr<TSharedMemory> shm = new TSharedMemory;
+                if (shm->Create(dataSize)) {
+                    ms->Write((char)PKT_LOCAL_REQUEST);
+                    ms->WriteStroka(url);
+                    memcpy(shm->GetPtr(), &(*data)[0], dataSize);
+                    TVector<char> empty;
+                    data->swap(empty); // caller's buffer is consumed
+                    ms->AttachSharedData(shm);
+                }
+            }
+            if (ms->GetSharedData() == nullptr) {
+                // regular UDP path
+                ms->Write((char)PKT_REQUEST);
+                ms->WriteStroka(url);
+                ms->WriteDestructive(data);
+            }
+
+            SendReqList.Enqueue(new TSendRequest(addr, &ms, reqId, wr, userQueues));
+            Host->CancelWait(); // wake the server thread to pick up the new request
+        }
+
+        // Async API: each call enqueues work for the server thread and wakes it.
+        void SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId) override {
+            SendRequestImpl(addr, url, data, reqId, nullptr, UserQueues.Get());
+        }
+        // Cancels an outgoing request; the peer is notified (PKT_CANCEL).
+        void CancelRequest(const TGUID& reqId) override {
+            CancelReqList.Enqueue(TCancelRequest(reqId));
+            Host->CancelWait();
+        }
+        // Drops an incoming request without replying; the peer is not notified.
+        void BreakRequest(const TGUID& reqId) override {
+            BreakReqList.Enqueue(TBreakRequest(reqId));
+            Host->CancelWait();
+        }
+
+        // Enqueues a response for incoming request `reqId` with the given priority.
+        // `data` is consumed by the server thread. Aborts if the payload exceeds
+        // MAX_PACKET_SIZE.
+        void SendResponseImpl(const TGUID& reqId, EPacketPriority prior, TVector<char>* data) // non-virtual, for direct call from TRequestOps
+        {
+            if (data && data->size() > MAX_PACKET_SIZE) {
+                Y_VERIFY(0, "data size is too large");
+            }
+            SendRespList.Enqueue(new TSendResponse(reqId, prior, data));
+            Host->CancelWait();
+        }
+        void SendResponse(const TGUID& reqId, TVector<char>* data) override {
+            SendResponseImpl(reqId, PP_NORMAL, data);
+        }
+        void SendResponseLowPriority(const TGUID& reqId, TVector<char>* data) override {
+            SendResponseImpl(reqId, PP_LOW, data);
+        }
+        // Non-blocking pollers over the default user queues; each returns
+        // nullptr/false when the corresponding queue is empty.
+        TUdpHttpRequest* GetRequest() override {
+            return UserQueues->GetRequest();
+        }
+        TUdpHttpResponse* GetResponse() override {
+            return UserQueues->GetResponse();
+        }
+        bool GetRequestCancel(TGUID* req) override {
+            return UserQueues->GetRequestCancel(req);
+        }
+        bool GetSendRequestAcc(TGUID* req) override {
+            return UserQueues->GetSendRequestAcc(req);
+        }
+        // Synchronous request: blocks until the response (or failure) arrives.
+        TUdpHttpResponse* Request(const TUdpAddress& addr, const TString& url, TVector<char>* data) override {
+            TIntrusivePtr<TWaitResponse> wr = WaitableRequest(addr, url, data);
+            wr->Wait();
+            return wr->GetResponse();
+        }
+        // Fire-and-wait-later variant: returns a waiter the caller can poll/wait on.
+        TIntrusivePtr<TWaitResponse> WaitableRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) override {
+            TIntrusivePtr<TWaitResponse> wr = new TWaitResponse;
+            TGUID reqId;
+            CreateGuid(&reqId);
+            SendRequestImpl(addr, url, data, reqId, wr.Get(), nullptr);
+            return wr;
+        }
+        TMuxEvent& GetAsyncEvent() override {
+            return UserQueues->GetAsyncEvent();
+        }
+        // Returns the bound UDP port, or 0 when no socket is attached.
+        int GetPort() override {
+            return Socket.Get() ? Socket->GetPort() : 0;
+        }
+        // Initiates shutdown without waiting for in-flight transfers: stops the
+        // server loop, wakes async waiters and fails all outgoing requests.
+        void StopNoWait() override {
+            AbortTransactions = true;
+            KeepRunning = false;
+            UserQueues->AsyncSignal();
+            // cancel all outgoing requests
+            TGuard<TSpinLock> lock(cs);
+            while (!OutRequests.empty()) {
+                // cancel without informing peer that we are cancelling the request
+                FinishRequest(OutRequests.begin(), TUdpHttpResponse::CANCELED, nullptr, "request canceled: inside TUdpHttp::StopNoWait()");
+            }
+        }
+        // Posts a stats query to the server thread and blocks until it is served
+        // (see the StatsReqList loop in Step()).
+        void ExecStatsRequest(TIntrusivePtr<TStatsRequest> req) {
+            StatsReqList.Enqueue(req);
+            Host->CancelWait();
+            req->Complete.Wait();
+        }
+        // Returns the peer address of a known in/out request, or an empty address.
+        TUdpAddress GetPeerAddress(const TGUID& reqId) override {
+            TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::GET_PEER_ADDRESS);
+            req->RequestId = reqId;
+            ExecStatsRequest(req);
+            return req->PeerAddress;
+        }
+        void GetPendingDataSize(TRequesterPendingDataStats* res) override {
+            TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::PENDING_SIZE);
+            ExecStatsRequest(req);
+            *res = req->PendingDataSize;
+        }
+        // True iff `reqId` is a currently known incoming request.
+        bool HasRequest(const TGUID& reqId) override {
+            TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::HAS_IN_REQUEST);
+            req->RequestId = reqId;
+            ExecStatsRequest(req);
+            return req->RequestFound;
+        }
+
+ private:
+        // Graceful-shutdown helper: keeps stepping until every pending in/out
+        // request and queued send has drained (or a global panic is raised).
+        // Incoming requests still in the user queue are discarded unanswered.
+        void FinishOutstandingTransactions() {
+            // wait all pending requests, all new requests are canceled
+            while ((!OutRequests.empty() || !InRequests.empty() || !SendRespList.IsEmpty() || !SendReqList.IsEmpty()) && !PanicAttack) {
+                while (TUdpHttpRequest* req = GetRequest()) {
+                    TInRequestHash::iterator i = InRequests.find(req->ReqId);
+                    //printf("dropping request(%s) (thread %d)\n", req->Url.c_str(), ThreadId());
+                    delete req;
+                    if (i == InRequests.end()) {
+                        Y_ASSERT(0);
+                        continue;
+                    }
+                    InRequests.erase(i);
+                }
+                Step();
+                sleep(0); // yield the CPU between iterations
+            }
+        }
+        // Server thread entry point: creates the UDP host, then loops
+        // Step()/Wait() until stopped, a panic, or a missed heartbeat.
+        // `param` is the owning TUdpHttp instance.
+        static void* ExecServerThread(void* param) {
+            BindToSocket(0);
+            SetHighestThreadPriority();
+            TUdpHttp* pThis = (TUdpHttp*)param;
+            pThis->Host = CreateUdpHost(pThis->Socket);
+            pThis->HasStarted.Signal(); // unblock Start(); Host==nullptr signals failure
+            if (!pThis->Host) {
+                pThis->Socket.Drop();
+                return nullptr;
+            }
+            NHPTimer::GetTime(&pThis->PingsSendT);
+            while (pThis->KeepRunning && !PanicAttack) {
+                // heartbeat watchdog: if enabled and NetLibaHeartbeat() has not been
+                // called within the timeout, kill the whole process group
+                if (HeartbeatTimeout.load(std::memory_order_acquire) > 0) {
+                    NHPTimer::STime chk = LastHeartbeat.load(std::memory_order_acquire);
+                    double passed = NHPTimer::GetTimePassed(&chk);
+                    if (passed > HeartbeatTimeout.load(std::memory_order_acquire)) {
+                        StopAllNetLibaThreads();
+                        fprintf(stderr, "%s\tTUdpHttp\tWaiting for %0.2f, time limit %0.2f, commit a suicide!11\n", Now().ToStringUpToSeconds().c_str(), passed, HeartbeatTimeout.load(std::memory_order_acquire));
+                        fflush(stderr);
+#ifndef _win_
+                        killpg(0, SIGKILL);
+#endif
+                        abort();
+                        break;
+                    }
+                }
+                pThis->Step();
+                pThis->Wait();
+            }
+            if (!pThis->AbortTransactions && !PanicAttack)
+                pThis->FinishOutstandingTransactions();
+            pThis->Host = nullptr;
+            return nullptr;
+        }
+        // Stops the server thread (if running) and releases any threads still
+        // blocked in ExecStatsRequest() so they do not wait forever.
+        ~TUdpHttp() override {
+            if (myThread.Running()) {
+                KeepRunning = false;
+                myThread.Join();
+            }
+            for (TIntrusivePtr<TStatsRequest> req; StatsReqList.Dequeue(&req);) {
+                req->Complete.Signal();
+            }
+        }
+
+ public:
+        // Constructs the requester in a stopped state; Start() launches the
+        // server thread on a given socket.
+        TUdpHttp()
+            : myThread(TThread::TParams(ExecServerThread, (void*)this).SetName("nl6_udp_host"))
+            , KeepRunning(true)
+            , AbortTransactions(false)
+            , PingsSendT(0)
+            , ReportRequestCancel(false)
+            , ReportSendRequestAcc(false)
+        {
+            NHPTimer::GetTime(&PingsSendT);
+            QueueSizes = new TRequesterUserQueueSizes;
+            UserQueues = new TRequesterUserQueues(QueueSizes.Get());
+        }
+        // Starts the server thread on `socket`.
+        // Returns false (and drops the socket) if the UDP host could not be created.
+        bool Start(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) {
+            Y_ASSERT(Host.Get() == nullptr);
+            Socket = socket;
+            myThread.Start();
+            HasStarted.Wait(); // Host is set (or left null) by ExecServerThread
+
+            if (Host.Get()) {
+                return true;
+            }
+            Socket.Drop();
+            return false;
+        }
+        // Builds a human-readable dump of the requester state: host info, queue
+        // sizes and every in/out request with its state. Must be called with
+        // `cs` held (done by the DEBUG_INFO branch of Step()).
+        // Uses snprintf instead of sprintf: address/guid strings have variable
+        // length, so unbounded sprintf into buf[1000] could overflow the stack.
+        TString GetDebugInfoLocked() {
+            TString res = KeepRunning ? "State: running\n" : "State: stopping\n";
+            res += Host->GetDebugInfo();
+
+            char buf[1000];
+            TRequesterUserQueueSizes* qs = QueueSizes.Get();
+            snprintf(buf, sizeof(buf), "\nRequest queue %d (%d bytes)\n", (int)AtomicGet(qs->ReqCount), (int)AtomicGet(qs->ReqQueueSize));
+            res += buf;
+            snprintf(buf, sizeof(buf), "Response queue %d (%d bytes)\n", (int)AtomicGet(qs->RespCount), (int)AtomicGet(qs->RespQueueSize));
+            res += buf;
+
+            // must match the order of TOutRequestState::EState / TInRequestState::EState
+            const char* outReqStateNames[] = {
+                "S_SENDING",
+                "S_WAITING",
+                "S_WAITING_PING_SENDING",
+                "S_WAITING_PING_SENT",
+                "S_CANCEL_AFTER_SENDING"};
+            const char* inReqStateNames[] = {
+                "S_WAITING",
+                "S_RESPONSE_SENDING",
+                "S_CANCELED"};
+            res += "\nOut requests:\n";
+            for (TOutRequestHash::const_iterator i = OutRequests.begin(); i != OutRequests.end(); ++i) {
+                const TGUID& gg = i->first;
+                const TOutRequestState& s = i->second;
+                bool isSync = SyncRequests.find(gg) != SyncRequests.end();
+                snprintf(buf, sizeof(buf), "%s\t%s %s TimePassed: %g %s\n",
+                         GetAddressAsString(s.Address).c_str(), GetGuidAsString(gg).c_str(), outReqStateNames[s.State],
+                         s.TimePassed * 1000,
+                         isSync ? "isSync" : "");
+                res += buf;
+            }
+            res += "\nIn requests:\n";
+            for (TInRequestHash::const_iterator i = InRequests.begin(); i != InRequests.end(); ++i) {
+                const TGUID& gg = i->first;
+                const TInRequestState& s = i->second;
+                snprintf(buf, sizeof(buf), "%s\t%s %s\n",
+                         GetAddressAsString(s.Address).c_str(), GetGuidAsString(gg).c_str(), inReqStateNames[s.State]);
+                res += buf;
+            }
+            return res;
+        }
+        // Thread-safe debug dump: routed through the server thread via a stats request.
+        TString GetDebugInfo() override {
+            TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::DEBUG_INFO);
+            ExecStatsRequest(req);
+            return req->DebugInfo;
+        }
+        // Snapshot of user queue counters (atomics; no locking needed).
+        void GetRequestQueueSize(TRequesterQueueStats* res) override {
+            TRequesterUserQueueSizes* qs = QueueSizes.Get();
+            res->ReqCount = (int)AtomicGet(qs->ReqCount);
+            res->RespCount = (int)AtomicGet(qs->RespCount);
+            res->ReqQueueSize = (int)AtomicGet(qs->ReqQueueSize);
+            res->RespQueueSize = (int)AtomicGet(qs->RespQueueSize);
+        }
+        // Shared counters, used by sub-requesters (TRequestOps).
+        TRequesterUserQueueSizes* GetQueueSizes() const {
+            return QueueSizes.Get();
+        }
+        IRequestOps* CreateSubRequester() override;
+        void EnableReportRequestCancel() override {
+            ReportRequestCancel = true;
+        }
+        void EnableReportSendRequestAcc() override {
+            ReportSendRequestAcc = true;
+        }
+        // Per-peer queue statistics, served by the server thread.
+        TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) override {
+            TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::GET_PEER_QUEUE_STATS);
+            req->PeerAddress = addr;
+            ExecStatsRequest(req);
+            return req->QueueStats;
+        }
+ };
+
+ //////////////////////////////////////////////////////////////////////////
+    // Copies the entire shared-memory segment into `data` (resized to fit).
+    static void ReadShm(TSharedMemory* shm, TVector<char>* data) {
+        Y_ASSERT(shm);
+        int dataSize = shm->GetSize();
+        data->yresize(dataSize);
+        memcpy(&(*data)[0], shm->GetPtr(), dataSize);
+    }
+
+    // Deserializes Url/Data from the raw packet held in res->DataHolder and
+    // releases the holder. Handles both the UDP (PKT_REQUEST) and the local
+    // shared-memory (PKT_LOCAL_REQUEST) wire formats. No-op on nullptr.
+    static void LoadRequestData(TUdpHttpRequest* res) {
+        if (!res)
+            return;
+        {
+            TBlockChainIterator reqData(res->DataHolder->Data->GetChain());
+            char pktType;
+            reqData.Read(&pktType, 1);
+            ReadArr(&reqData, &res->Url);
+            if (pktType == PKT_REQUEST) {
+                ReadYArr(&reqData, &res->Data);
+            } else if (pktType == PKT_LOCAL_REQUEST) {
+                ReadShm(res->DataHolder->Data->GetSharedData(), &res->Data);
+            } else
+                Y_ASSERT(0);
+            if (reqData.HasFailed()) {
+                // truncated/garbled packet: clear outputs rather than expose partial data
+                Y_ASSERT(0 && "wrong format, memory corruption suspected");
+                res->Url = "";
+                res->Data.clear();
+            }
+        }
+        res->DataHolder.Reset(nullptr);
+    }
+
+    // Deserializes Data from the raw packet held in res->DataHolder and releases
+    // the holder. Handles PKT_RESPONSE (UDP) and PKT_LOCAL_RESPONSE (shared
+    // memory). On a malformed packet, marks the response FAILED.
+    static void LoadResponseData(TUdpHttpResponse* res) {
+        if (!res || res->DataHolder.Get() == nullptr)
+            return;
+        {
+            TBlockChainIterator reqData(res->DataHolder->Data->GetChain());
+            char pktType;
+            reqData.Read(&pktType, 1);
+            TGUID guid;
+            reqData.Read(&guid, sizeof(guid));
+            Y_ASSERT(res->ReqId == guid); // packet must belong to this request
+            if (pktType == PKT_RESPONSE) {
+                ReadYArr(&reqData, &res->Data);
+            } else if (pktType == PKT_LOCAL_RESPONSE) {
+                ReadShm(res->DataHolder->Data->GetSharedData(), &res->Data);
+            } else
+                Y_ASSERT(0);
+            if (reqData.HasFailed()) {
+                Y_ASSERT(0 && "wrong format, memory corruption suspected");
+                res->Ok = TUdpHttpResponse::FAILED;
+                res->Data.clear();
+                res->Error = "wrong response format";
+            }
+        }
+        res->DataHolder.Reset(nullptr);
+    }
+
+ //////////////////////////////////////////////////////////////////////////
+ // IRequestOps::TWaitResponse
+    // Transfers ownership of the stored response to the caller (or nullptr if
+    // none has arrived); payload is deserialized lazily here.
+    TUdpHttpResponse* IRequestOps::TWaitResponse::GetResponse() {
+        if (!Response)
+            return nullptr;
+        TUdpHttpResponse* res = Response;
+        Response = nullptr; // caller now owns the response
+        LoadResponseData(res);
+        return res;
+    }
+
+    // Stores the arrived response (if any) and wakes the waiter.
+    // A null `r` still signals completion (e.g. cancel/failure paths).
+    void IRequestOps::TWaitResponse::SetResponse(TUdpHttpResponse* r) {
+        Y_ASSERT(Response == nullptr || r == nullptr); // never overwrite a pending response
+        if (r)
+            Response = r;
+        CompleteEvent.Signal();
+    }
+
+ //////////////////////////////////////////////////////////////////////////
+ // TRequesterUserQueues
+    // Pops one pending incoming request (nullptr if queue is empty), updates the
+    // shared size counters and the async event, and deserializes the payload.
+    TUdpHttpRequest* TRequesterUserQueues::GetRequest() {
+        TUdpHttpRequest* res = nullptr;
+        ReqList.Dequeue(&res);
+        if (res) {
+            AtomicAdd(QueueSizes->ReqCount, -1);
+            AtomicAdd(QueueSizes->ReqQueueSize, -GetPacketSize(res->DataHolder.Get()));
+        }
+        UpdateAsyncSignalState();
+        LoadRequestData(res);
+        return res;
+    }
+
+    // Same as GetRequest(), but for responses to our outgoing requests.
+    TUdpHttpResponse* TRequesterUserQueues::GetResponse() {
+        TUdpHttpResponse* res = nullptr;
+        ResponseList.Dequeue(&res);
+        if (res) {
+            AtomicAdd(QueueSizes->RespCount, -1);
+            AtomicAdd(QueueSizes->RespQueueSize, -GetPacketSize(res->DataHolder.Get()));
+        }
+        UpdateAsyncSignalState();
+        LoadResponseData(res);
+        return res;
+    }
+
+ //////////////////////////////////////////////////////////////////////////
+    // Sub-requester sharing one TUdpHttp: requests/cancels are forwarded to the
+    // main requester, while responses to requests sent through this object land
+    // in its private UserQueues (so independent consumers do not steal each
+    // other's responses). Incoming requests and cancels are NOT routed here.
+    class TRequestOps: public IRequestOps {
+        TIntrusivePtr<TUdpHttp> Requester;
+        TIntrusivePtr<TRequesterUserQueues> UserQueues; // private response queue
+
+    public:
+        TRequestOps(TUdpHttp* req)
+            : Requester(req)
+        {
+            UserQueues = new TRequesterUserQueues(req->GetQueueSizes());
+        }
+        void SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId) override {
+            // route the response into our private queue
+            Requester->SendRequestImpl(addr, url, data, reqId, nullptr, UserQueues.Get());
+        }
+        void CancelRequest(const TGUID& reqId) override {
+            Requester->CancelRequest(reqId);
+        }
+        void BreakRequest(const TGUID& reqId) override {
+            Requester->BreakRequest(reqId);
+        }
+
+        void SendResponse(const TGUID& reqId, TVector<char>* data) override {
+            Requester->SendResponseImpl(reqId, PP_NORMAL, data);
+        }
+        void SendResponseLowPriority(const TGUID& reqId, TVector<char>* data) override {
+            Requester->SendResponseImpl(reqId, PP_LOW, data);
+        }
+        TUdpHttpRequest* GetRequest() override {
+            Y_ASSERT(0);
+            //return UserQueues.GetRequest();
+            return nullptr; // all requests are routed to the main requester
+        }
+        TUdpHttpResponse* GetResponse() override {
+            return UserQueues->GetResponse();
+        }
+        bool GetRequestCancel(TGUID*) override {
+            Y_ASSERT(0);
+            return false; // all request cancels are routed to the main requester
+        }
+        bool GetSendRequestAcc(TGUID* req) override {
+            return UserQueues->GetSendRequestAcc(req);
+        }
+        // sync mode
+        TUdpHttpResponse* Request(const TUdpAddress& addr, const TString& url, TVector<char>* data) override {
+            return Requester->Request(addr, url, data);
+        }
+        TIntrusivePtr<TWaitResponse> WaitableRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) override {
+            return Requester->WaitableRequest(addr, url, data);
+        }
+        //
+        TMuxEvent& GetAsyncEvent() override {
+            return UserQueues->GetAsyncEvent();
+        }
+    };
+
+    // Creates an independent request interface that shares this requester's
+    // transport and counters (see TRequestOps above for routing rules).
+    IRequestOps* TUdpHttp::CreateSubRequester() {
+        return new TRequestOps(this);
+    }
+
+    //////////////////////////////////////////////////////////////////////////
+    // Debug helper: prints peer address and error, then aborts the process
+    // if `answer` is a failed response. Non-failed/null answers pass through.
+    void AbortOnFailedRequest(TUdpHttpResponse* answer) {
+        if (answer && answer->Ok == TUdpHttpResponse::FAILED) {
+            fprintf(stderr, "Failed request to host %s\n", GetAddressAsString(answer->PeerAddress).data());
+            fprintf(stderr, "Error description: %s\n", answer->Error.data());
+            fflush(nullptr);
+            Y_ASSERT(0);
+            abort();
+        }
+    }
+
+    // Queries a remote netliba peer for its debug info (PKT_GETDEBUGINFO) using
+    // a temporary UDP host. Returns an empty string on timeout (seconds).
+    TString GetDebugInfo(const TUdpAddress& addr, double timeout) {
+        NHPTimer::STime start;
+        NHPTimer::GetTime(&start);
+        TIntrusivePtr<IUdpHost> host = CreateUdpHost(0); // ephemeral port
+        {
+            TAutoPtr<TRopeDataPacket> rq = new TRopeDataPacket;
+            rq->Write((char)PKT_GETDEBUGINFO);
+            ui32 crc32 = CalcChecksum(rq->GetChain());
+            host->Send(addr, rq.Release(), crc32, nullptr, PP_HIGH);
+        }
+        for (;;) {
+            TAutoPtr<TRequest> ptr = host->GetRequest();
+            if (ptr.Get()) {
+                // reply payload is the raw debug-info text
+                TBlockChainIterator reqData(ptr->Data->GetChain());
+                int sz = reqData.GetSize();
+                TString res;
+                res.resize(sz);
+                reqData.Read(res.begin(), sz);
+                return res;
+            }
+            host->Step();
+            host->Wait(0.1f);
+
+            NHPTimer::STime now;
+            NHPTimer::GetTime(&now);
+            if (NHPTimer::GetSeconds(now - start) > timeout) {
+                return TString();
+            }
+        }
+    }
+
+    // Sends a kill notification to a remote netliba peer via a temporary host.
+    void Kill(const TUdpAddress& addr) {
+        TIntrusivePtr<IUdpHost> host = CreateUdpHost(0);
+        host->Kill(addr);
+    }
+
+    // Global emergency stop: every netliba loop polls PanicAttack and exits.
+    void StopAllNetLibaThreads() {
+        PanicAttack = true; // AAAA!!!!
+    }
+
+    // Arms the watchdog in ExecServerThread(): if NetLibaHeartbeat() is not
+    // called within `timeoutSec`, all netliba threads stop and the process dies.
+    void SetNetLibaHeartbeatTimeout(double timeoutSec) {
+        NetLibaHeartbeat(); // reset the clock so we do not fire immediately
+        HeartbeatTimeout.store(timeoutSec, std::memory_order_release);
+    }
+
+    // Marks the process alive for the heartbeat watchdog.
+    void NetLibaHeartbeat() {
+        NHPTimer::STime now;
+        NHPTimer::GetTime(&now);
+        LastHeartbeat.store(now, std::memory_order_release);
+    }
+
+    // Creates a requester listening on `port` (0 = ephemeral).
+    // Returns nullptr if the socket cannot be opened or a panic is in effect.
+    IRequester* CreateHttpUdpRequester(int port) {
+        if (PanicAttack)
+            return nullptr;
+
+        TIntrusivePtr<NNetlibaSocket::ISocket> socket = NNetlibaSocket::CreateSocket();
+        socket->Open(port);
+        if (!socket->IsValid())
+            return nullptr;
+
+        return CreateHttpUdpRequester(socket);
+    }
+
+    // Creates a requester over an already-opened socket.
+    // Returns nullptr if the server thread fails to start or a panic is in effect.
+    IRequester* CreateHttpUdpRequester(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) {
+        if (PanicAttack)
+            return nullptr;
+
+        TIntrusivePtr<TUdpHttp> res(new TUdpHttp);
+        if (!res->Start(socket))
+            return nullptr;
+        return res.Release();
+    }
+
+}
diff --git a/library/cpp/netliba/v6/udp_http.h b/library/cpp/netliba/v6/udp_http.h
new file mode 100644
index 0000000000..1084e7affa
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_http.h
@@ -0,0 +1,148 @@
+#pragma once
+
+#include "udp_address.h"
+#include "udp_debug.h"
+#include "net_queue_stat.h"
+
+#include <util/network/init.h>
+#include <util/generic/ptr.h>
+#include <util/generic/guid.h>
+#include <library/cpp/threading/mux_event/mux_event.h>
+#include <library/cpp/netliba/socket/socket.h>
+
+namespace NNetliba {
+ const ui64 MAX_PACKET_SIZE = 0x70000000;
+
+ struct TRequest;
+    // An incoming request as seen by the user. Data is deserialized from
+    // DataHolder lazily (see LoadRequestData in udp_http.cpp).
+    struct TUdpHttpRequest {
+        TAutoPtr<TRequest> DataHolder; // raw packet; reset after deserialization
+        TGUID ReqId;                   // unique id, used to address the reply
+        TString Url;
+        TUdpAddress PeerAddress;
+        TVector<char> Data;            // request payload
+
+        ~TUdpHttpRequest();
+    };
+
+    // A response to one of our outgoing requests (or a synthetic FAILED /
+    // CANCELED record; see EResult).
+    struct TUdpHttpResponse {
+        enum EResult {
+            FAILED = 0,
+            OK = 1,
+            CANCELED = 2
+        };
+        TAutoPtr<TRequest> DataHolder; // raw packet; reset after deserialization
+        TGUID ReqId;                   // id of the request this answers
+        TUdpAddress PeerAddress;
+        TVector<char> Data;            // response payload (valid when Ok == OK)
+        EResult Ok;
+        TString Error;                 // human-readable reason when Ok != OK
+
+        ~TUdpHttpResponse();
+    };
+
+ // vector<char> *data - vector will be cleared upon call
+    // Request/response operations interface.
+    // Convention: every TVector<char>* data argument is consumed (cleared) by
+    // the call.
+    struct IRequestOps: public TThrRefBase {
+        // Waiter handle for the synchronous request API: holds the response
+        // until the caller collects it via GetResponse().
+        class TWaitResponse: public TThrRefBase, public TNonCopyable {
+            TGUID ReqId;
+            TMuxEvent CompleteEvent;    // signalled when the response arrives
+            TUdpHttpResponse* Response; // owned until handed out by GetResponse()
+            bool RequestSent;
+
+            ~TWaitResponse() override {
+                delete GetResponse(); // free an uncollected response
+            }
+
+        public:
+            TWaitResponse()
+                : Response(nullptr)
+                , RequestSent(false)
+            {
+            }
+            void Wait() {
+                CompleteEvent.Wait();
+            }
+            // Returns false on timeout (milliseconds).
+            bool Wait(int ms) {
+                return CompleteEvent.Wait(ms);
+            }
+            // Transfers response ownership to the caller; nullptr if none yet.
+            TUdpHttpResponse* GetResponse();
+            bool IsRequestSent() const {
+                return RequestSent;
+            }
+            void SetResponse(TUdpHttpResponse* r);
+            void SetReqId(const TGUID& reqId) {
+                ReqId = reqId;
+            }
+            const TGUID& GetReqId() {
+                return ReqId;
+            }
+            void SetRequestSent() {
+                RequestSent = true;
+            }
+        };
+
+        // async
+        virtual void SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId) = 0;
+        // Convenience overload: generates and returns a fresh request id.
+        TGUID SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) {
+            TGUID reqId;
+            CreateGuid(&reqId);
+            SendRequest(addr, url, data, reqId);
+            return reqId;
+        }
+        virtual void CancelRequest(const TGUID& reqId) = 0; //cancel request from requester side
+        virtual void BreakRequest(const TGUID& reqId) = 0;  //break request-response from requester side
+
+        virtual void SendResponse(const TGUID& reqId, TVector<char>* data) = 0;
+        virtual void SendResponseLowPriority(const TGUID& reqId, TVector<char>* data) = 0;
+        virtual TUdpHttpRequest* GetRequest() = 0;
+        virtual TUdpHttpResponse* GetResponse() = 0;
+        virtual bool GetRequestCancel(TGUID* req) = 0;
+        virtual bool GetSendRequestAcc(TGUID* req) = 0;
+        // sync mode
+        virtual TUdpHttpResponse* Request(const TUdpAddress& addr, const TString& url, TVector<char>* data) = 0;
+        virtual TIntrusivePtr<TWaitResponse> WaitableRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) = 0;
+        // signalled while any user queue is non-empty
+        virtual TMuxEvent& GetAsyncEvent() = 0;
+    };
+
+    // Full requester interface: IRequestOps plus lifecycle, introspection and
+    // sub-requester management. Implemented by TUdpHttp (udp_http.cpp).
+    struct IRequester: public IRequestOps {
+        virtual int GetPort() = 0;     // bound UDP port, 0 if none
+        virtual void StopNoWait() = 0; // stop without draining in-flight transfers
+        virtual TUdpAddress GetPeerAddress(const TGUID& reqId) = 0;
+        virtual void GetPendingDataSize(TRequesterPendingDataStats* res) = 0;
+        virtual bool HasRequest(const TGUID& reqId) = 0;
+        virtual TString GetDebugInfo() = 0;
+        virtual void GetRequestQueueSize(TRequesterQueueStats* res) = 0;
+        virtual IRequestOps* CreateSubRequester() = 0;
+        virtual void EnableReportRequestCancel() = 0;
+        virtual void EnableReportSendRequestAcc() = 0;
+        virtual TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) = 0;
+
+        // Total bytes pending in both directions.
+        ui64 GetPendingDataSize() {
+            TRequesterPendingDataStats pds;
+            GetPendingDataSize(&pds);
+            return pds.InpDataSize + pds.OutDataSize;
+        }
+    };
+
+ IRequester* CreateHttpUdpRequester(int port);
+ IRequester* CreateHttpUdpRequester(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket);
+
+ void SetUdpMaxBandwidthPerIP(float f);
+ void SetUdpSlowStart(bool enable);
+ void SetCongCtrlChannelInflate(float inflate);
+
+ void EnableUseTOSforAcks(bool enable);
+ void EnableROCE(bool f);
+
+ void AbortOnFailedRequest(TUdpHttpResponse* answer);
+ TString GetDebugInfo(const TUdpAddress& addr, double timeout = 60);
+ void Kill(const TUdpAddress& addr);
+ void StopAllNetLibaThreads();
+
+ // if heartbeat timeout is set and NetLibaHeartbeat() is not called for timeoutSec
+ // then StopAllNetLibaThreads() will be called
+ void SetNetLibaHeartbeatTimeout(double timeoutSec);
+ void NetLibaHeartbeat();
+
+ bool IsLocal(const TUdpAddress& addr);
+}
diff --git a/library/cpp/netliba/v6/udp_socket.cpp b/library/cpp/netliba/v6/udp_socket.cpp
new file mode 100644
index 0000000000..fd85ef4d00
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_socket.cpp
@@ -0,0 +1,292 @@
+#include "stdafx.h"
+#include "udp_socket.h"
+#include "block_chain.h"
+#include "udp_address.h"
+
+#include <util/datetime/cputimer.h>
+#include <util/system/spinlock.h>
+#include <util/random/random.h>
+
+#include <library/cpp/netliba/socket/socket.h>
+
+#include <errno.h>
+
+//#define SIMULATE_NETWORK_FAILURES
+// there is no explicit bit in the packet header for last packet of transfer
+// last packet is just smaller then maximum size
+
+namespace NNetliba {
+    static bool LocalHostFound; // set once by InitLocalIPList()
+    // Indexes into the per-family local-address tables below.
+    enum {
+        IPv4 = 0,
+        IPv6 = 1
+    };
+
+    // 128-bit IPv6 address split into network and interface halves.
+    struct TIPv6Addr {
+        ui64 Network, Interface;
+
+        TIPv6Addr() {
+            Zero(*this);
+        }
+        TIPv6Addr(ui64 n, ui64 i)
+            : Network(n)
+            , Interface(i)
+        {
+        }
+    };
+    inline bool operator==(const TIPv6Addr& a, const TIPv6Addr& b) {
+        return a.Interface == b.Interface && a.Network == b.Network;
+    }
+
+    static ui32 LocalHostIP[2];                    // most recently matched local IP crc, per family
+    static TVector<ui32> LocalHostIPList[2];       // all local IP crcs, per family
+    static TVector<TIPv6Addr> LocalHostIPv6List;   // full local IPv6 addresses
+
+    // Struct sockaddr_in6 does not have ui64-array representation
+    // so we add it here. This avoids "strict aliasing" warnings
+    typedef union {
+        in6_addr Addr;
+        ui64 Addr64[2];
+    } TIPv6AddrUnion;
+
+    // Folds the low 64 bits (interface part) of an IPv6 address into a ui32,
+    // used as a cheap address fingerprint in packet checksums.
+    static ui32 GetIPv6SuffixCrc(const sockaddr_in6& addr) {
+        TIPv6AddrUnion a;
+        a.Addr = addr.sin6_addr;
+        ui64 suffix = a.Addr64[1];
+        return (suffix & 0xffffffffll) + (suffix >> 32);
+    }
+
+    // Enumerates local interface addresses into the tables above.
+    // Idempotent and thread-safe; returns false if enumeration fails.
+    bool InitLocalIPList() {
+        // Do not use TMutex here: it has a non-trivial destructor which will be called before
+        // destruction of current thread, if its TThread declared as global/static variable.
+        static TAdaptiveLock cs;
+        TGuard lock(cs);
+
+        if (LocalHostFound)
+            return true; // already initialized
+
+        TVector<TUdpAddress> addrs;
+        if (!GetLocalAddresses(&addrs))
+            return false;
+        for (int i = 0; i < addrs.ysize(); ++i) {
+            const TUdpAddress& addr = addrs[i];
+            if (addr.IsIPv4()) {
+                LocalHostIPList[IPv4].push_back(addr.GetIPv4());
+                LocalHostIP[IPv4] = addr.GetIPv4();
+            } else {
+                // IPv6: store the suffix crc (for checksums) and the full address
+                sockaddr_in6 addr6;
+                GetWinsockAddr(&addr6, addr);
+
+                LocalHostIPList[IPv6].push_back(GetIPv6SuffixCrc(addr6));
+                LocalHostIP[IPv6] = GetIPv6SuffixCrc(addr6);
+                LocalHostIPv6List.push_back(TIPv6Addr(addr.Network, addr.Interface));
+            }
+        }
+        LocalHostFound = true;
+        return true;
+    }
+
+    // Returns true iff container `c` holds an element equal to `e` (linear scan).
+    template <class T, class TElem>
+    inline bool IsInSet(const T& c, const TElem& e) {
+        for (const auto& item : c) {
+            if (item == e) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // True iff `ip` is one of this host's IPv4 addresses (requires InitLocalIPList()).
+    bool IsLocalIPv4(ui32 ip) {
+        return IsInSet(LocalHostIPList[IPv4], ip);
+    }
+    // True iff the (network, interface) pair is one of this host's IPv6 addresses.
+    bool IsLocalIPv6(ui64 network, ui64 iface) {
+        return IsInSet(LocalHostIPv6List, TIPv6Addr(network, iface));
+    }
+
+ //////////////////////////////////////////////////////////////////////////
+    // Opens a fresh netliba socket on `port` and attaches it.
+    void TNetSocket::Open(int port) {
+        TIntrusivePtr<NNetlibaSocket::ISocket> theSocket = NNetlibaSocket::CreateSocket();
+        theSocket->Open(port);
+        Open(theSocket);
+    }
+
+    // Attaches an existing socket; caches the local port for checksum salting.
+    void TNetSocket::Open(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) {
+        s = socket;
+        if (IsValid()) {
+            PortCrc = s->GetSelfAddress().sin6_port;
+        }
+    }
+
+    void TNetSocket::Close() {
+        if (IsValid()) {
+            s->Close();
+        }
+    }
+
+    // Wakes a thread blocked in Wait() by posting a loopback no-op.
+    void TNetSocket::SendSelfFakePacket() const {
+        s->CancelWait();
+    }
+
+    // Address fingerprint used to salt packet checksums: the raw IPv4 address
+    // for v4-mapped addresses, otherwise the IPv6 suffix crc.
+    inline ui32 CalcAddressCrc(const sockaddr_in6& addr) {
+        Y_ASSERT(addr.sin6_family == AF_INET6);
+        const ui64* addr64 = (const ui64*)addr.sin6_addr.s6_addr;
+        const ui32* addr32 = (const ui32*)addr.sin6_addr.s6_addr;
+        // v4-mapped form ::ffff:a.b.c.d (byte-order handling matches CrcMatches)
+        if (addr64[0] == 0 && addr32[2] == 0xffff0000ll) {
+            // ipv4
+            return addr32[3];
+        } else {
+            // ipv6
+            return GetIPv6SuffixCrc(addr);
+        }
+    }
+
+    // Sends one datagram. The first 4 bytes of `buf` are overwritten with a
+    // checksum of the payload salted with the destination address and port, so
+    // the receiver can detect misrouted packets (see CrcMatches in RecvFrom).
+    // Returns SEND_OK, SEND_NO_ROUTE_TO_HOST, or SEND_BUFFER_OVERFLOW.
+    TNetSocket::ESendError TNetSocket::SendTo(const char* buf, int size, const sockaddr_in6& toAddress, const EFragFlag frag) const {
+        Y_ASSERT(size >= UDP_LOW_LEVEL_HEADER_SIZE);
+        ui32 crc = CalcChecksum(buf + UDP_LOW_LEVEL_HEADER_SIZE, size - UDP_LOW_LEVEL_HEADER_SIZE);
+        ui32 ipCrc = CalcAddressCrc(toAddress);
+        ui32 portCrc = toAddress.sin6_port;
+        *(ui32*)buf = crc + ipCrc + portCrc;
+#ifdef SIMULATE_NETWORK_FAILURES
+        // fix: this function returns ESendError, not bool — the previous
+        // `return true;` did not compile with the macro enabled
+        if ((RandomNumber<size_t>() % 3) == 0)
+            return SEND_OK; // packet lost
+        if ((RandomNumber<size_t>() % 3) == 0)
+            (char&)(buf[RandomNumber<size_t>() % size]) += RandomNumber<size_t>(); // packet broken
+#endif
+
+        char tosBuffer[NNetlibaSocket::TOS_BUFFER_SIZE];
+        void* t = NNetlibaSocket::CreateTos(Tos, tosBuffer);
+        const NNetlibaSocket::TIoVec iov = NNetlibaSocket::CreateIoVec((char*)buf, size);
+        NNetlibaSocket::TMsgHdr hdr = NNetlibaSocket::CreateSendMsgHdr(toAddress, iov, t);
+
+        const int rv = s->SendMsg(&hdr, 0, frag);
+        if (rv < 0) {
+            if (errno == EHOSTUNREACH || errno == ENETUNREACH) {
+                return SEND_NO_ROUTE_TO_HOST;
+            } else {
+                // NOTE(review): any other error (e.g. EAGAIN) is reported as
+                // buffer overflow — callers presumably retry later
+                return SEND_BUFFER_OVERFLOW;
+            }
+        }
+        Y_ASSERT(rv == size);
+        return SEND_OK;
+    }
+
+    // Validates a received packet checksum. The sender salted the crc with the
+    // destination IP it used; try our current cached local IP first, then fall
+    // back to all local IPs (the peer may address us by any of them).
+    inline bool CrcMatches(ui32 pktCrc, ui32 crc, const sockaddr_in6& addr) {
+        Y_ASSERT(LocalHostFound);
+        Y_ASSERT(addr.sin6_family == AF_INET6);
+        // determine our ip address family based on the sender address
+        // address family can not change in network, so sender address type determines type of our address used
+        const ui64* addr64 = (const ui64*)addr.sin6_addr.s6_addr;
+        const ui32* addr32 = (const ui32*)addr.sin6_addr.s6_addr;
+        yint ipType;
+        if (addr64[0] == 0 && addr32[2] == 0xffff0000ll) {
+            // ipv4
+            ipType = IPv4;
+        } else {
+            // ipv6
+            ipType = IPv6;
+        }
+        if (crc + LocalHostIP[ipType] == pktCrc) {
+            return true;
+        }
+        // crc failed
+        // check if packet was sent to different IP address
+        for (int idx = 0; idx < LocalHostIPList[ipType].ysize(); ++idx) {
+            ui32 otherIP = LocalHostIPList[ipType][idx];
+            if (crc + otherIP == pktCrc) {
+                LocalHostIP[ipType] = otherIP; // cache the matching IP for next time
+                return true;
+            }
+        }
+        // crc is really failed, discard packet
+        return false;
+    }
+
+    // Receives the next valid datagram into `buf`.
+    // In: *size = buffer capacity; out: *size = payload length.
+    // Returns false when the socket has no more data; silently skips empty,
+    // undersized and checksum-failing packets.
+    bool TNetSocket::RecvFrom(char* buf, int* size, sockaddr_in6* fromAddress) const {
+        for (;;) {
+            int rv;
+            if (s->IsRecvMsgSupported()) {
+                const NNetlibaSocket::TIoVec v = NNetlibaSocket::CreateIoVec(buf, *size);
+                NNetlibaSocket::TMsgHdr hdr = NNetlibaSocket::CreateRecvMsgHdr(fromAddress, v);
+                rv = s->RecvMsg(&hdr, 0);
+
+            } else {
+                // fallback path: copy out of the socket's own packet buffer
+                sockaddr_in6 dummy;
+                TAutoPtr<NNetlibaSocket::TUdpRecvPacket> pkt = s->Recv(fromAddress, &dummy, -1);
+                rv = !!pkt ? pkt->DataSize - pkt->DataStart : -1;
+                if (rv > 0) {
+                    memcpy(buf, pkt->Data.get() + pkt->DataStart, rv);
+                }
+            }
+
+            if (rv < 0)
+                return false;
+            // ignore empty packets
+            if (rv == 0)
+                continue;
+            // skip small packets
+            if (rv < UDP_LOW_LEVEL_HEADER_SIZE)
+                continue;
+            *size = rv;
+            // first 4 bytes carry the salted checksum written by SendTo()
+            ui32 pktCrc = *(ui32*)buf;
+            ui32 crc = CalcChecksum(buf + UDP_LOW_LEVEL_HEADER_SIZE, rv - UDP_LOW_LEVEL_HEADER_SIZE);
+            if (!CrcMatches(pktCrc, crc + PortCrc, *fromAddress)) {
+                // crc is really failed, discard packet
+                continue;
+            }
+            return true;
+        }
+    }
+
+    // Blocks until data arrives or `timeoutSec` elapses.
+    void TNetSocket::Wait(float timeoutSec) const {
+        s->Wait(timeoutSec);
+    }
+
+    // Sets the TOS/DSCP value applied to subsequent SendTo() calls.
+    void TNetSocket::SetTOS(int n) const {
+        Tos = n;
+    }
+
+    // Connects the UDP socket to `addr` so ICMP "port unreachable" errors get
+    // routed back to us (see IsHostUnreachable()).
+    // Returns false only when the host/network is unreachable right away.
+    bool TNetSocket::Connect(const sockaddr_in6& addr) {
+        // "connect" - meaningless operation
+        // needed since port unreachable is routed only to "connected" udp sockets in ingenious FreeBSD
+        if (s->Connect((sockaddr*)&addr, sizeof(addr)) < 0) {
+            if (errno == EHOSTUNREACH || errno == ENETUNREACH) {
+                return false;
+            } else {
+                Y_ASSERT(0);
+            }
+        }
+        return true;
+    }
+
+    // Sends a zero-payload datagram to the connected address, provoking an ICMP
+    // error if the peer port is closed.
+    void TNetSocket::SendEmptyPacket() {
+        NNetlibaSocket::TIoVec v;
+        Zero(v);
+
+        // darwin ignores packets with msg_iovlen == 0, also windows implementation uses sendto of first iovec.
+        NNetlibaSocket::TMsgHdr hdr;
+        Zero(hdr);
+        hdr.msg_iov = &v;
+        hdr.msg_iovlen = 1;
+
+        s->SendMsg(&hdr, 0, FF_ALLOW_FRAG); // sends empty packet to connected address
+    }
+
+    // Checks whether the connected peer reported "unreachable" via ICMP.
+    // Windows: surfaces as WSAECONNRESET from a recv; POSIX: as ECONNREFUSED in
+    // the socket's pending error (SO_ERROR).
+    bool TNetSocket::IsHostUnreachable() {
+#ifdef _win_
+        char buf[10000];
+        sockaddr_in6 fromAddress;
+
+        const NNetlibaSocket::TIoVec v = NNetlibaSocket::CreateIoVec(buf, Y_ARRAY_SIZE(buf));
+        NNetlibaSocket::TMsgHdr hdr = NNetlibaSocket::CreateRecvMsgHdr(&fromAddress, v);
+
+        const ssize_t rv = s->RecvMsg(&hdr, 0);
+        if (rv < 0) {
+            int err = WSAGetLastError();
+            if (err == WSAECONNRESET)
+                return true;
+        }
+#else
+        int err = 0;
+        socklen_t bufSize = sizeof(err);
+        s->GetSockOpt(SOL_SOCKET, SO_ERROR, (char*)&err, &bufSize);
+        if (err == ECONNREFUSED)
+            return true;
+#endif
+        return false;
+    }
+}
diff --git a/library/cpp/netliba/v6/udp_socket.h b/library/cpp/netliba/v6/udp_socket.h
new file mode 100644
index 0000000000..bd95b8bcd0
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_socket.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+#include <util/generic/utility.h>
+#include <library/cpp/netliba/socket/socket.h>
+
+namespace NNetliba {
+ // Helpers for checking whether an address belongs to this host.
+ // NOTE(review): presumably InitLocalIPList() must run before the IsLocal*
+ // predicates are meaningful — confirm against the implementation.
+ bool IsLocalIPv4(ui32 ip);
+ bool IsLocalIPv6(ui64 network, ui64 iface);
+ bool InitLocalIPList();
+
+ enum {
+ // Every low-level packet begins with a 4-byte checksum header
+ // (see the CRC handling in TNetSocket::RecvFrom).
+ UDP_LOW_LEVEL_HEADER_SIZE = 4,
+ };
+
+ // Re-export fragmentation flags from the socket layer for convenience.
+ using NNetlibaSocket::EFragFlag;
+ using NNetlibaSocket::FF_ALLOW_FRAG;
+ using NNetlibaSocket::FF_DONT_FRAG;
+
+ // Thin wrapper over NNetlibaSocket::ISocket adding the netliba low-level
+ // framing (4-byte CRC header salted with PortCrc), TOS selection and a
+ // roundabout mechanism for detecting ICMP "host unreachable".
+ class TNetSocket: public TNonCopyable {
+ TIntrusivePtr<NNetlibaSocket::ISocket> s; // underlying shared socket
+ ui32 PortCrc; // CRC salt; mixed into packet checksums
+ mutable int Tos; // last value passed to SetTOS()
+
+ public:
+ // Outcome of SendTo().
+ enum ESendError {
+ SEND_OK,
+ SEND_BUFFER_OVERFLOW,
+ SEND_NO_ROUTE_TO_HOST,
+ };
+ TNetSocket()
+ : PortCrc(0)
+ , Tos(0)
+ {
+ }
+ ~TNetSocket() {
+ }
+
+ void Open(int port);
+ void Open(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket);
+ void Close();
+ void SendSelfFakePacket() const;
+ // True when a socket has been attached and it is still usable.
+ bool IsValid() const {
+ return s.Get() ? s->IsValid() : false;
+ }
+ // Local port in network byte order, as returned by the socket layer.
+ int GetNetworkOrderPort() const {
+ return s->GetNetworkOrderPort();
+ }
+ ESendError SendTo(const char* buf, int size, const sockaddr_in6& toAddress, const EFragFlag frag) const;
+ bool RecvFrom(char* buf, int* size, sockaddr_in6* fromAddress) const;
+ void Wait(float timeoutSec) const;
+ void SetTOS(int n) const;
+
+ // obtaining icmp host unreachable in convoluted way:
+ // Connect() the socket, SendEmptyPacket(), then poll IsHostUnreachable().
+ bool Connect(const sockaddr_in6& addr);
+ void SendEmptyPacket();
+ bool IsHostUnreachable();
+ };
+}
diff --git a/library/cpp/netliba/v6/udp_test.cpp b/library/cpp/netliba/v6/udp_test.cpp
new file mode 100644
index 0000000000..d0af51a368
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_test.cpp
@@ -0,0 +1,161 @@
+#include "stdafx.h"
+#include "udp_test.h"
+#include "udp_client_server.h"
+#include "udp_http.h"
+#include "cpu_affinity.h"
+#include <util/system/hp_timer.h>
+#include <util/datetime/cputimer.h>
+#include <util/random/random.h>
+#include <util/random/fast.h>
+
+namespace NNetliba {
+ //static void PacketLevelTest(bool client)
+ //{
+ // int port = client ? 0 : 13013;
+ // TIntrusivePtr<IUdpHost> host = CreateUdpHost(&port);
+ //
+ // if(host == 0) {
+ // exit(-1);
+ // }
+ // TUdpAddress serverAddr = CreateAddress("localhost", 13013);
+ // vector<char> dummyPacket;
+ // dummyPacket.resize(10000);
+ // srand(GetCycleCount());
+ //
+ // for (int i = 0; i < dummyPacket.size(); ++i)
+ // dummyPacket[i] = rand();
+ // bool cont = true, hasReply = true;
+ // int reqCount = 1;
+ // for (int i = 0; cont; ++i) {
+ // host->Step();
+ // if (client) {
+ // //while (host->HasPendingData(serverAddr))
+ // // Sleep(0);
+ // if (hasReply) {
+ // printf("request %d\n", reqCount);
+ // *(int*)&dummyPacket[0] = reqCount;
+ // host->Send(serverAddr, dummyPacket, 0, PP_NORMAL);
+ // hasReply = false;
+ // ++reqCount;
+ // }else
+ // sleep(0);
+ //
+ // TRequest *req;
+ // while (req = host->GetRequest()) {
+ // int n = *(int*)&req->Data[0];
+ // printf("received response %d\n", n);
+ // Y_ASSERT(memcmp(&req->Data[4], &dummyPacket[4], dummyPacket.size() - 4) == 0);
+ // delete req;
+ // hasReply = true;
+ // }
+ // TSendResult sr;
+ // while (host->GetSendResult(&sr)) {
+ // if (!sr.Success) {
+ // printf("Send failed!\n");
+ // //Sleep(INFINITE);
+ // hasReply = true;
+ // }
+ // }
+ // } else {
+ // while (TRequest *req = host->GetRequest()) {
+ // int n = *(int*)&req->Data[0];
+ // printf("responding %d\n", n);
+ // host->Send(req->Address, req->Data, 0, PP_NORMAL);
+ // delete req;
+ // }
+ // TSendResult sr;
+ // while (host->GetSendResult(&sr)) {
+ // if (!sr.Success) {
+ // printf("Send failed!\n");
+ // sleep(0);
+ // }
+ // }
+ // sleep(0);
+ // }
+ // }
+ //}
+
+ // Request/response throughput test over the netliba HTTP-like UDP layer.
+ //   client == true:  keeps up to packetsInFly random requests of packetSize
+ //                    bytes outstanding against serverName:13013 and prints
+ //                    the achieved packet/byte rate about once per second.
+ //   client == false: echo server — loops forever, answering every request
+ //                    with a 10-byte response (this branch never returns).
+ static void SessionLevelTest(bool client, const char* serverName, int packetSize, int packetsInFly, int srcPort) {
+ BindToSocket(0); // pin to CPU 0 for stable timing
+ TIntrusivePtr<IRequester> reqHost;
+ // reqHost = CreateHttpUdpRequester(13013);
+ reqHost = CreateHttpUdpRequester(client ? srcPort : 13013);
+ TUdpAddress serverAddr = CreateAddress(serverName, 13013);
+ // Fill the request payload with random bytes.
+ TVector<char> dummyPacket;
+ dummyPacket.resize(packetSize);
+ TReallyFastRng32 rr((unsigned int)GetCycleCount());
+ for (size_t i = 0; i < dummyPacket.size(); ++i)
+ dummyPacket[i] = (char)rr.Uniform(256);
+ bool cont = true;
+ NHPTimer::STime t;
+ NHPTimer::GetTime(&t);
+ THashMap<TGUID, bool, TGUIDHash> seenReqs; // server-side duplicate detector
+ if (client) {
+ THashMap<TGUID, bool, TGUIDHash> reqList; // GUIDs of in-flight requests
+ int packetsSentCount = 0;
+ TUdpHttpRequest* udpReq;
+ for (int i = 1; cont; ++i) {
+ // The client also answers any incoming requests (symmetric peers).
+ for (;;) {
+ udpReq = reqHost->GetRequest();
+ if (udpReq == nullptr)
+ break;
+ udpReq->Data.resize(10);
+ reqHost->SendResponse(udpReq->ReqId, &udpReq->Data);
+ delete udpReq;
+ }
+ // Drain completed responses and report the rate once per second.
+ while (TUdpHttpResponse* res = reqHost->GetResponse()) {
+ THashMap<TGUID, bool, TGUIDHash>::iterator z = reqList.find(res->ReqId);
+ if (z == reqList.end()) {
+ printf("Unexpected response\n");
+ abort();
+ }
+ reqList.erase(z);
+ if (res->Ok) {
+ ++packetsSentCount;
+ //Y_ASSERT(res->Data == dummyPacket);
+ // Peek at elapsed time without resetting t (tChk is a copy).
+ NHPTimer::STime tChk = t;
+ if (NHPTimer::GetTimePassed(&tChk) > 1) {
+ printf("packet size = %d\n", dummyPacket.ysize());
+ double passedTime = NHPTimer::GetTimePassed(&t);
+ double rate = packetsSentCount / passedTime;
+ printf("packet rate %g, transfer %gmb\n", rate, rate * dummyPacket.size() / 1000000);
+ packetsSentCount = 0;
+ }
+ } else {
+ printf("Failed request!\n");
+ //Sleep(INFINITE);
+ }
+ delete res;
+ }
+ // Top the pipeline back up to packetsInFly outstanding requests.
+ while (reqList.ysize() < packetsInFly) {
+ *(int*)&dummyPacket[0] = i; // tag payload with the iteration number
+ TVector<char> fakePacket = dummyPacket;
+ TGUID req2 = reqHost->SendRequest(serverAddr, "blaxuz", &fakePacket);
+ reqList[req2]; // insert GUID with default value
+ }
+ reqHost->GetAsyncEvent().Wait();
+ }
+ } else {
+ // Server: echo loop, never exits.
+ TUdpHttpRequest* req;
+ for (;;) {
+ req = reqHost->GetRequest();
+ if (req) {
+ // Warn about redelivered requests (should be unique GUIDs).
+ if (seenReqs.find(req->ReqId) != seenReqs.end()) {
+ printf("Request %s recieved twice!\n", GetGuidAsString(req->ReqId).c_str());
+ }
+ seenReqs[req->ReqId]; // remember GUID (default-inserted value)
+ req->Data.resize(10);
+ reqHost->SendResponse(req->ReqId, &req->Data);
+ delete req;
+ } else {
+ // No pending work — sleep until the requester signals activity.
+ reqHost->GetAsyncEvent().Wait();
+ }
+ }
+ }
+ }
+
+ // Public entry point: runs the session-level ping-pong test as either
+ // client or server (the older packet-level test is kept disabled above).
+ void RunUdpTest(bool client, const char* serverName, int packetSize, int packetsInFly, int srcPort) {
+ //PacketLevelTest(client);
+ SessionLevelTest(client, serverName, packetSize, packetsInFly, srcPort);
+ }
+}
diff --git a/library/cpp/netliba/v6/udp_test.h b/library/cpp/netliba/v6/udp_test.h
new file mode 100644
index 0000000000..68435fee5c
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_test.h
@@ -0,0 +1,5 @@
+#pragma once
+
+namespace NNetliba {
+ // Runs the UDP throughput test (see udp_test.cpp). client selects the
+ // client or server role; the server side of the test never returns.
+ void RunUdpTest(bool client, const char* serverName, int packetSize, int packetsInFly, int srcPort = 0);
+}