diff options
author | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
---|---|---|
committer | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
commit | 06e5c21a835c0e923506c4ff27929f34e00761c2 (patch) | |
tree | 75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/netliba/v6 | |
parent | 03f024c4412e3aa613bb543cf1660176320ba8f4 (diff) | |
download | ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz |
fix ya.make
Diffstat (limited to 'library/cpp/netliba/v6')
40 files changed, 9427 insertions, 0 deletions
diff --git a/library/cpp/netliba/v6/block_chain.cpp b/library/cpp/netliba/v6/block_chain.cpp new file mode 100644 index 0000000000..cbb56d9a5e --- /dev/null +++ b/library/cpp/netliba/v6/block_chain.cpp @@ -0,0 +1,90 @@ +#include "stdafx.h" +#include "block_chain.h" + +#include <util/system/unaligned_mem.h> + +namespace NNetliba { + ui32 CalcChecksum(const void* p, int size) { + //return 0; + //return CalcCrc32(p, size); + i64 sum = 0; + const unsigned char *pp = (const unsigned char*)p, *pend = pp + size; + for (const unsigned char* pend4 = pend - 3; pp < pend4; pp += 4) + sum += *(const ui32*)pp; + + ui32 left = 0, pos = 0; + for (; pp < pend; ++pp) { + pos += ((ui32)*pp) << left; + left += 8; + } + + sum += pos; + sum = (sum & 0xffffffff) + (sum >> 32); + sum += sum >> 32; + return (ui32)~sum; + } + + ui32 CalcChecksum(const TBlockChain& chain) { + TIncrementalChecksumCalcer ics; + AddChain(&ics, chain); + return ics.CalcChecksum(); + } + + void TIncrementalChecksumCalcer::AddBlock(const void* p, int size) { + ui32 sum = CalcBlockSum(p, size); + AddBlockSum(sum, size); + } + + void TIncrementalChecksumCalcer::AddBlockSum(ui32 sum, int size) { + for (int k = 0; k < Offset; ++k) + sum = (sum >> 24) + ((sum & 0xffffff) << 8); + TotalSum += sum; + + Offset = (Offset + size) & 3; + } + + ui32 TIncrementalChecksumCalcer::CalcBlockSum(const void* p, int size) { + i64 sum = 0; + const unsigned char *pp = (const unsigned char*)p, *pend = pp + size; + for (const unsigned char* pend4 = pend - 3; pp < pend4; pp += 4) + sum += ReadUnaligned<ui32>(pp); + + ui32 left = 0, pos = 0; + for (; pp < pend; ++pp) { + pos += ((ui32)*pp) << left; + left += 8; + } + sum += pos; + sum = (sum & 0xffffffff) + (sum >> 32); + sum += sum >> 32; + return (ui32)sum; + } + + ui32 TIncrementalChecksumCalcer::CalcChecksum() { + TotalSum = (TotalSum & 0xffffffff) + (TotalSum >> 32); + TotalSum += TotalSum >> 32; + return (ui32)~TotalSum; + } + + //void TestChainChecksum() + //{ + // TVector<char> 
data; + // data.resize(10); + // for (int i = 0; i < data.ysize(); ++i) + // data[i] = rand(); + // int crc1 = CalcChecksum(&data[0], data.size()); + // + // TBlockChain chain; + // TIncrementalChecksumCalcer incCS; + // for (int offset = 0; offset < data.ysize();) { + // int sz = Min(rand() % 10, data.ysize() - offset); + // chain.AddBlock(&data[offset], sz); + // incCS.AddBlock(&data[offset], sz); + // offset += sz; + // } + // int crc2 = CalcChecksum(chain); + // Y_ASSERT(crc1 == crc2); + // int crc3 = incCS.CalcChecksum(); + // Y_ASSERT(crc1 == crc3); + //} +} diff --git a/library/cpp/netliba/v6/block_chain.h b/library/cpp/netliba/v6/block_chain.h new file mode 100644 index 0000000000..25247ec05f --- /dev/null +++ b/library/cpp/netliba/v6/block_chain.h @@ -0,0 +1,319 @@ +#pragma once + +#include <util/generic/algorithm.h> +#include <util/generic/list.h> +#include <util/system/shmat.h> +#include <util/generic/noncopyable.h> + +namespace NNetliba { + class TBlockChain { + public: + struct TBlock { + const char* Data; + int Offset, Size; // Offset in whole chain + + TBlock() + : Data(nullptr) + , Offset(0) + , Size(0) + { + } + TBlock(const char* data, int offset, int sz) + : Data(data) + , Offset(offset) + , Size(sz) + { + } + }; + + private: + typedef TVector<TBlock> TBlockVector; + TBlockVector Blocks; + int Size; + struct TBlockLess { + bool operator()(const TBlock& b, int offset) const { + return b.Offset < offset; + } + }; + + public: + TBlockChain() + : Size(0) + { + } + void AddBlock(const void* data, int sz) { + Blocks.push_back(TBlock((const char*)data, Size, sz)); + Size += sz; + } + int GetSize() const { + return Size; + } + const TBlock& GetBlock(int i) const { + return Blocks[i]; + } + int GetBlockCount() const { + return Blocks.ysize(); + } + int GetBlockIdByOffset(int offset) const { + TBlockVector::const_iterator i = LowerBound(Blocks.begin(), Blocks.end(), offset, TBlockLess()); + if (i == Blocks.end()) + return Blocks.ysize() - 1; + if 
(i->Offset == offset)
                return (int)(i - Blocks.begin());
            // LowerBound returned the first block starting strictly past `offset`,
            // so the containing block is the previous one.
            return (int)(i - Blocks.begin() - 1);
        }
    };

    //////////////////////////////////////////////////////////////////////////
    // Sequential reader over a TBlockChain. Reading past the end zero-fills the
    // destination and latches the Failed flag instead of throwing.
    class TBlockChainIterator {
        const TBlockChain& Chain;
        int Pos, BlockPos, BlockId; // absolute pos, pos inside current block, current block index
        bool Failed;

    public:
        TBlockChainIterator(const TBlockChain& chain)
            : Chain(chain)
            , Pos(0)
            , BlockPos(0)
            , BlockId(0)
            , Failed(false)
        {
        }
        // Copy `sz` bytes into dst, spanning block boundaries as needed.
        void Read(void* dst, int sz) {
            char* dstBuf = (char*)dst;
            while (sz > 0) {
                if (BlockId >= Chain.GetBlockCount()) {
                    // JACKPOT!
                    // Out-of-data: report, zero-fill the rest and mark the stream failed.
                    fprintf(stderr, "reading beyond chain end: BlockId %d, Chain.GetBlockCount() %d, Pos %d, BlockPos %d\n",
                            BlockId, Chain.GetBlockCount(), Pos, BlockPos);
                    Y_ASSERT(0 && "reading beyond chain end");
                    memset(dstBuf, 0, sz);
                    Failed = true;
                    return;
                }
                const TBlockChain::TBlock& blk = Chain.GetBlock(BlockId);
                int copySize = Min(blk.Size - BlockPos, sz);
                memcpy(dstBuf, blk.Data + BlockPos, copySize);
                dstBuf += copySize;
                Pos += copySize;
                BlockPos += copySize;
                sz -= copySize;
                if (BlockPos == blk.Size) {
                    BlockPos = 0;
                    ++BlockId;
                }
            }
        }
        // Reposition to absolute chain offset `pos`; invalid positions reset to 0
        // (assert in debug builds only).
        void Seek(int pos) {
            if (pos < 0 || pos > Chain.GetSize()) {
                Y_ASSERT(0);
                Pos = 0;
                BlockPos = 0;
                BlockId = 0;
                return;
            }
            BlockId = Chain.GetBlockIdByOffset(pos);
            const TBlockChain::TBlock& blk = Chain.GetBlock(BlockId);
            Pos = pos;
            BlockPos = Pos - blk.Offset;
        }
        int GetPos() const {
            return Pos;
        }
        int GetSize() const {
            return Chain.GetSize();
        }
        bool HasFailed() const {
            return Failed;
        }
        void Fail() {
            Failed = true;
        }
    };

    //////////////////////////////////////////////////////////////////////////
    // Write-side companion of TBlockChain: serializes values into internally
    // allocated buffers (128-byte inline buffer first, then N_DEFAULT_BLOCK_SIZE
    // blocks) and records them as chain blocks. Can also adopt caller buffers
    // (AddBlock / WriteDestructive) to avoid copies.
    class TRopeDataPacket: public TNonCopyable {
        TBlockChain Chain;
        TVector<char*> Buf;                              // buffers owned by this packet, freed in dtor
        char *Block, *BlockEnd;                          // current bump-allocation window
        TList<TVector<char>> DataVectors;                // vectors adopted by WriteDestructive
        TIntrusivePtr<TSharedMemory> SharedData;
        TVector<TIntrusivePtr<TThrRefBase>> AttachedStorage; // ref-counted buffers kept alive
        char DefaultBuf[128]; // prevent allocs in most cases
        enum {
            N_DEFAULT_BLOCK_SIZE = 1024
        };

        // Bump allocator: carve `sz` bytes out of the current block, starting a
        // new block when the remainder is too small.
        char* Alloc(int sz) {
            char* res = nullptr;
            if (BlockEnd - Block < sz) {
                int bufSize = Max((int)N_DEFAULT_BLOCK_SIZE, sz);
                char* newBlock = AllocBuf(bufSize);
                Block = newBlock;
                BlockEnd = Block + bufSize;
                Buf.push_back(newBlock);
            }
            res = Block;
            Block += sz;
            Y_ASSERT(Block <= BlockEnd);
            return res;
        }

    public:
        TRopeDataPacket()
            : Block(DefaultBuf)
            , BlockEnd(DefaultBuf + Y_ARRAY_SIZE(DefaultBuf))
        {
        }
        ~TRopeDataPacket() {
            for (size_t i = 0; i < Buf.size(); ++i)
                FreeBuf(Buf[i]);
        }
        static char* AllocBuf(int sz) {
            return new char[sz];
        }
        static void FreeBuf(char* buf) {
            delete[] buf;
        }

        // buf - pointer to buffer which will be freed with FreeBuf()
        // data - pointer to data start within buf
        // sz - size of useful data
        void AddBlock(char* buf, const char* data, int sz) {
            Buf.push_back(buf);
            Chain.AddBlock(data, sz);
        }
        // Ref-counted variant: keeps `buf` alive instead of taking delete[] ownership.
        void AddBlock(TThrRefBase* buf, const char* data, int sz) {
            AttachedStorage.push_back(buf);
            Chain.AddBlock(data, sz);
        }
        //
        // Copy `sz` bytes into packet-owned storage.
        void Write(const void* data, int sz) {
            char* buf = Alloc(sz);
            memcpy(buf, data, sz);
            Chain.AddBlock(buf, sz);
        }
        template <class T>
        void Write(const T& data) {
            Write(&data, sizeof(T));
        }
        //// caller guarantees that data will persist all *this lifetime
        //// int this case so we don`t have to copy data to locally held buffer
        //template<class T>
        //void WriteNoCopy(const T *data)
        //{
        //    Chain.AddBlock(data, sizeof(T));
        //}
        // write some array like TVector<>
        //template<class T>
        //void WriteArr(const T &sz)
        //{
        //    int n = (int)sz.size();
        //    Write(n);
        //    if (n > 0)
        //        Write(&sz[0], n * sizeof(sz[0]));
        //}
        // Length-prefixed string: int32 byte count, then the bytes (no NUL).
        void WriteStroka(const TString& sz) {
            int n = (int)sz.size();
            Write(n);
            if (n > 0)
                Write(sz.c_str(), n * sizeof(sz[0]));
        }
        // will take *data ownership, saves copy
        void WriteDestructive(TVector<char>* data) {
            int n = data ? data->ysize() : 0;
            Write(n);
            if (n > 0) {
                // Move the vector into DataVectors (a TList, so `local` stays at a
                // stable address) and reference its storage from the chain.
                TVector<char>& local = DataVectors.emplace_back(std::move(*data));
                Chain.AddBlock(&local[0], local.ysize());
            }
        }
        void AttachSharedData(TIntrusivePtr<TSharedMemory> shm) {
            SharedData = shm;
        }
        TSharedMemory* GetSharedData() const {
            return SharedData.Get();
        }
        const TBlockChain& GetChain() {
            return Chain;
        }
        int GetSize() {
            return Chain.GetSize();
        }
    };

    // Read a length-prefixed array written by WriteArr/WriteDestructive.
    // A negative count marks the iterator failed instead of resizing.
    template <class T>
    inline void ReadArr(TBlockChainIterator* res, T* dst) {
        int n;
        res->Read(&n, sizeof(n));
        if (n >= 0) {
            dst->resize(n);
            if (n > 0)
                res->Read(&(*dst)[0], n * sizeof((*dst)[0]));
        } else {
            dst->resize(0);
            res->Fail();
        }
    }

    template <>
    inline void ReadArr<TString>(TBlockChainIterator* res, TString* dst) {
        int n;
        res->Read(&n, sizeof(n));
        if (n >= 0) {
            dst->resize(n);
            if (n > 0)
                res->Read(dst->begin(), n * sizeof(TString::value_type));
        } else {
            dst->resize(0);
            res->Fail();
        }
    }

    // saves on zeroing *dst with yresize()
    template <class T>
    static void ReadYArr(TBlockChainIterator* res, TVector<T>* dst) {
        int n;
        res->Read(&n, sizeof(n));
        if (n >= 0) {
            dst->yresize(n);
            if (n > 0)
                res->Read(&(*dst)[0], n * sizeof((*dst)[0]));
        } else {
            dst->yresize(0);
            res->Fail();
        }
    }

    template <class T>
    static void Read(TBlockChainIterator* res, T* dst) {
        res->Read(dst, sizeof(T));
    }

    ui32 CalcChecksum(const void* p, int size);
    ui32 CalcChecksum(const TBlockChain& chain);

    // Accumulates the chain checksum block by block (see block_chain.cpp);
    // Offset tracks the byte phase (0..3) so block boundaries don't matter.
    class TIncrementalChecksumCalcer {
        i64 TotalSum;
        int Offset;

    public:
        TIncrementalChecksumCalcer()
            : TotalSum(0)
            , Offset(0)
        {
        }
        void AddBlock(const void* p, int size);
        void AddBlockSum(ui32 sum, int size);
        ui32 CalcChecksum();

        static ui32 CalcBlockSum(const void* p, int size);
    };

    inline void AddChain(TIncrementalChecksumCalcer* ics, const TBlockChain& chain) {
        for (int k = 0; k < chain.GetBlockCount(); ++k) {
            const TBlockChain::TBlock& blk =
chain.GetBlock(k);
            ics->AddBlock(blk.Data, blk.Size);
        }
    }
}
// --- rendered diff header (page text, preserved) ---
// diff --git a/library/cpp/netliba/v6/cpu_affinity.cpp b/library/cpp/netliba/v6/cpu_affinity.cpp
// new file mode 100644  index 0000000000..7808092a72  --- /dev/null  +++ b/library/cpp/netliba/v6/cpu_affinity.cpp
// @@ -0,0 +1,138 @@
#include "stdafx.h"
#include "cpu_affinity.h"

#if defined(__FreeBSD__) && (__FreeBSD__ >= 7)
#include <sys/param.h>
#include <sys/cpuset.h>
#elif defined(_linux_)
#include <pthread.h>
#include <util/stream/file.h>
#include <util/string/printf.h>
#endif

namespace NNetliba {
    // Thin portable wrapper over a per-thread CPU affinity mask
    // (cpuset on FreeBSD >= 7, pthread affinity on Linux, no-op elsewhere).
    class TCPUSet {
    public:
        enum { MAX_SIZE = 128 };

    private:
#if defined(__FreeBSD__) && (__FreeBSD__ >= 7)
// Enough native sets to hold MAX_SIZE CPUs (at least one).
#define NUMCPU ((CPU_MAXSIZE > MAX_SIZE) ? 1 : (MAX_SIZE / CPU_MAXSIZE))
        cpuset_t CpuInfo[NUMCPU];

    public:
        bool GetAffinity() {
            int error = cpuset_getaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuInfo), CpuInfo);
            return error == 0;
        }
        bool SetAffinity() {
            int error = cpuset_setaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuInfo), CpuInfo);
            return error == 0;
        }
        bool IsSet(size_t i) {
            return CPU_ISSET(i, CpuInfo);
        }
        void Set(size_t i) {
            CPU_SET(i, CpuInfo);
        }
#elif defined(_linux_)
    public:
#define NUMCPU ((CPU_SETSIZE > MAX_SIZE) ? 1 : (MAX_SIZE / CPU_SETSIZE))
        cpu_set_t CpuInfo[NUMCPU];

    public:
        bool GetAffinity() {
            // pthread_*affinity_np return 0 on success, an errno value otherwise.
            int error = pthread_getaffinity_np(pthread_self(), sizeof(CpuInfo), CpuInfo);
            return error == 0;
        }
        bool SetAffinity() {
            int error = pthread_setaffinity_np(pthread_self(), sizeof(CpuInfo), CpuInfo);
            return error == 0;
        }
        bool IsSet(size_t i) {
            return CPU_ISSET(i, CpuInfo);
        }
        void Set(size_t i) {
            CPU_SET(i, CpuInfo);
        }
#else
    public:
        // Unsupported platform: affinity calls pretend to succeed.
        bool GetAffinity() {
            return true;
        }
        bool SetAffinity() {
            return true;
        }
        bool IsSet(size_t i) {
            Y_UNUSED(i);
            return true;
        }
        void Set(size_t i) {
            Y_UNUSED(i);
        }
#endif

        TCPUSet() {
            Clear();
        }
        void Clear() {
            // NOTE(review): memset(this, ...) relies on the class holding only the
            // trivially-copyable mask arrays above — keep it that way if members are added.
            memset(this, 0, sizeof(*this));
        }
    };

    static TMutex CPUSetsLock;
    struct TCPUSetInfo {
        TCPUSet CPUSet;
        bool IsOk;

        TCPUSetInfo()
            : IsOk(false)
        {
        }
    };
    // Cache of per-socket masks, built lazily under CPUSetsLock.
    static THashMap<int, TCPUSetInfo> CPUSets;

    // Bind the calling thread to the CPUs of physical package (socket) `n`.
    // On Linux the mask is read from sysfs; if nothing was found and n == 0,
    // falls back to CPUs 0..5 (hard-coded guess — presumably sized for the
    // deployment hardware of the day; TODO confirm).
    void BindToSocket(int n) {
        TGuard<TMutex> gg(CPUSetsLock);
        if (CPUSets.find(n) == CPUSets.end()) {
            TCPUSetInfo& res = CPUSets[n];

            bool foundCPU = false;
#ifdef _linux_
            for (int cpuId = 0; cpuId < TCPUSet::MAX_SIZE; ++cpuId) {
                try { // I just wanna check if file exists, I don't want your stinking exceptions :/
                    TIFStream f(Sprintf("/sys/devices/system/cpu/cpu%d/topology/physical_package_id", cpuId).c_str());
                    TString s;
                    if (f.ReadLine(s) && !s.empty()) {
                        //printf("cpu%d - %s\n", cpuId, s.c_str());
                        int physCPU = atoi(s.c_str());
                        if (physCPU == 0) {
                            res.IsOk = true;
                            res.CPUSet.Set(cpuId);
                            foundCPU = true;
                        }
                    } else {
                        break;
                    }
                } catch (const TFileError&) {
                    break;
                }
            }
#endif
            if (!foundCPU && n == 0) {
                for (int i = 0; i < 6; ++i) {
                    res.CPUSet.Set(i);
                }
                res.IsOk = true;
                foundCPU = true;
            }
        }
        {
            TCPUSetInfo& cc = CPUSets[n];
            if (cc.IsOk) {
                cc.CPUSet.SetAffinity();
            }
        }
    }

}
diff --git a/library/cpp/netliba/v6/cpu_affinity.h b/library/cpp/netliba/v6/cpu_affinity.h new file mode 100644 index
0000000000..a580edd829 --- /dev/null +++ b/library/cpp/netliba/v6/cpu_affinity.h @@ -0,0 +1,5 @@
#pragma once

namespace NNetliba {
    void BindToSocket(int n);
}
// --- rendered diff header (page text, preserved) ---
// diff --git a/library/cpp/netliba/v6/cstdafx.h b/library/cpp/netliba/v6/cstdafx.h
// new file mode 100644  index 0000000000..e11fe54cc4  --- /dev/null  +++ b/library/cpp/netliba/v6/cstdafx.h
// @@ -0,0 +1,41 @@
// Precompiled-header style umbrella include for the library.
#pragma once

#ifdef _WIN32
#pragma warning(disable : 4530 4244 4996)
#include <malloc.h>
#include <util/system/winint.h>
#endif

#include <util/system/defaults.h>
#include <util/system/mutex.h>
#include <util/system/event.h>
#include <library/cpp/deprecated/atomic/atomic.h>
#include <util/system/yassert.h>
#include <util/system/compat.h>

#include <util/ysafeptr.h>

#include <util/stream/output.h>

#include <library/cpp/string_utils/url/url.h>

#include <library/cpp/charset/codepage.h>
#include <library/cpp/charset/recyr.hh>

#include <util/generic/vector.h>
#include <util/generic/hash.h>
#include <util/generic/list.h>
#include <util/generic/hash_set.h>
#include <util/generic/ptr.h>
#include <util/generic/ymath.h>
#include <util/generic/utility.h>
#include <util/generic/algorithm.h>

#include <array>
#include <cstdlib>
#include <cstdio>

namespace NNetliba {
    typedef unsigned char byte;
    typedef ssize_t yint;
}
// --- rendered diff header (page text, preserved) ---
// diff --git a/library/cpp/netliba/v6/ib_buffers.cpp b/library/cpp/netliba/v6/ib_buffers.cpp
// new file mode 100644  index 0000000000..323846b633  --- /dev/null  +++ b/library/cpp/netliba/v6/ib_buffers.cpp
// @@ -0,0 +1,5 @@
#include "stdafx.h"
#include "ib_buffers.h"

namespace NNetliba {
}
// --- rendered diff header (page text, preserved) ---
// diff --git a/library/cpp/netliba/v6/ib_buffers.h b/library/cpp/netliba/v6/ib_buffers.h
// new file mode 100644  index 0000000000..96d83ff654  --- /dev/null  +++ b/library/cpp/netliba/v6/ib_buffers.h
// @@ -0,0 +1,181 @@
#pragma once

#include "ib_low.h"

namespace NNetliba {
    // buffer id 0 is special, it is used when data is sent inline and should not be returned
    const size_t SMALL_PKT_SIZE = 4096;

    // Work-request id layout: low 32 bits = buffer id, high flag marks that an
    // address-handle reference is parked in AHHolder and must be dropped on free.
    const ui64 BP_AH_USED_FLAG = 0x1000000000000000ul;
    const ui64 BP_BUF_ID_MASK = 0x00000000fffffffful;

    // single thread version
    // Pool of fixed-size (SMALL_PKT_SIZE) registered buffers for IB send/receive,
    // allocated lazily in blocks of BLOCK_SIZE and indexed by integer id.
    class TIBBufferPool: public TThrRefBase, TNonCopyable {
        enum {
            BLOCK_SIZE_LN = 11,
            BLOCK_SIZE = 1 << BLOCK_SIZE_LN,
            BLOCK_COUNT = 1024
        };
        struct TSingleBlock {
            TIntrusivePtr<TMemoryRegion> Mem;
            TVector<ui8> BlkRefCounts;                       // debug-only per-buffer refcount (Y_ASSERT)
            TVector<TIntrusivePtr<TAddressHandle>> AHHolder; // keeps UD address handles alive per buffer

            void Alloc(TPtrArg<TIBContext> ctx) {
                size_t dataSize = SMALL_PKT_SIZE * BLOCK_SIZE;
                Mem = new TMemoryRegion(ctx, dataSize);
                BlkRefCounts.resize(BLOCK_SIZE, 0);
                AHHolder.resize(BLOCK_SIZE);
            }
            char* GetBufData(ui64 idArg) {
                char* data = Mem->GetData();
                return data + (idArg & (BLOCK_SIZE - 1)) * SMALL_PKT_SIZE;
            }
        };

        TIntrusivePtr<TIBContext> IBCtx;
        TVector<int> FreeList;
        TVector<TSingleBlock> Blocks;
        size_t FirstFreeBlock;    // index of the next unallocated block
        int PostRecvDeficit;      // receives consumed but not yet re-posted
        TIntrusivePtr<TSharedReceiveQueue> SRQ;

        // Allocate and register the next block, pushing its buffer ids onto the
        // free list. Buffer id 0 is never handed out (reserved for inline sends).
        void AddBlock() {
            if (FirstFreeBlock == Blocks.size()) {
                Y_VERIFY(0, "run out of buffers");
            }
            Blocks[FirstFreeBlock].Alloc(IBCtx);
            size_t start = (FirstFreeBlock == 0) ? 1 : FirstFreeBlock * BLOCK_SIZE;
            size_t finish = FirstFreeBlock * BLOCK_SIZE + BLOCK_SIZE;
            for (size_t i = start; i < finish; ++i) {
                FreeList.push_back(i);
            }
            ++FirstFreeBlock;
        }

    public:
        TIBBufferPool(TPtrArg<TIBContext> ctx, int maxSRQWorkRequests)
            : IBCtx(ctx)
            , FirstFreeBlock(0)
            , PostRecvDeficit(maxSRQWorkRequests)
        {
            Blocks.resize(BLOCK_COUNT);
            AddBlock();
            SRQ = new TSharedReceiveQueue(ctx, maxSRQWorkRequests);

            PostRecv();
        }
        TSharedReceiveQueue* GetSRQ() const {
            return SRQ.Get();
        }
        int AllocBuf() {
            if (FreeList.empty()) {
                AddBlock();
            }
            int id = FreeList.back();
            FreeList.pop_back();
            // Refcount bookkeeping compiles away in release builds (Y_ASSERT).
            Y_ASSERT(++Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 1);
            return id;
        }
        // Return a buffer to the pool; id 0 (inline send) is a silent no-op.
        // Also releases the parked address handle when BP_AH_USED_FLAG is set.
        void FreeBuf(ui64 idArg) {
            ui64 id = idArg & BP_BUF_ID_MASK;
            if (id == 0) {
                return;
            }
            Y_ASSERT(id > 0 && id < (ui64)(FirstFreeBlock * BLOCK_SIZE));
            FreeList.push_back(id);
            Y_ASSERT(--Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 0);
            if (idArg & BP_AH_USED_FLAG) {
                Blocks[id >> BLOCK_SIZE_LN].AHHolder[id & (BLOCK_SIZE - 1)] = nullptr;
            }
        }
        char* GetBufData(ui64 idArg) {
            ui64 id = idArg & BP_BUF_ID_MASK;
            return Blocks[id >> BLOCK_SIZE_LN].GetBufData(id);
        }
        // RC send. Small payloads go inline (no buffer, returns 0); larger ones
        // are copied into a pool buffer whose id is returned as the wr id.
        int PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
            if (len > SMALL_PKT_SIZE) {
                Y_VERIFY(0, "buffer overrun");
            }
            if (len <= MAX_INLINE_DATA_SIZE) {
                qp->PostSend(nullptr, 0, data, len);
                return 0;
            } else {
                int id = AllocBuf();
                TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
                char* buf = blk.GetBufData(id);
                memcpy(buf, data, len);
                qp->PostSend(blk.Mem, id, buf, len);
                return id;
            }
        }
        // UD send. A 40-byte GRH precedes UD payloads, hence the reduced limit.
        void PostSend(TPtrArg<TUDQueuePair> qp, TPtrArg<TAddressHandle> ah, int remoteQPN, int remoteQKey,
                      const void* data, size_t len) {
            if (len > SMALL_PKT_SIZE - 40) {
                Y_VERIFY(0, "buffer overrun");
            }
            ui64 id = AllocBuf();
            TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
            int ptr = id &
(BLOCK_SIZE - 1);
            // Park the address handle so it outlives the async send; FreeBuf()
            // releases it when the completion carries BP_AH_USED_FLAG.
            blk.AHHolder[ptr] = ah.Get();
            id |= BP_AH_USED_FLAG;
            if (len <= MAX_INLINE_DATA_SIZE) {
                qp->PostSend(ah, remoteQPN, remoteQKey, nullptr, id, data, len);
            } else {
                char* buf = blk.GetBufData(id);
                memcpy(buf, data, len);
                qp->PostSend(ah, remoteQPN, remoteQKey, blk.Mem, id, buf, len);
            }
        }
        void RequestPostRecv() {
            ++PostRecvDeficit;
        }
        // Re-post one receive per accumulated deficit onto the SRQ.
        void PostRecv() {
            for (int i = 0; i < PostRecvDeficit; ++i) {
                int id = AllocBuf();
                TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
                char* buf = blk.GetBufData(id);
                SRQ->PostReceive(blk.Mem, id, buf, SMALL_PKT_SIZE);
            }
            PostRecvDeficit = 0;
        }
    };

    // RAII helper around one received packet: grabs the buffer for the scope,
    // frees it and replenishes the SRQ on destruction.
    class TIBRecvPacketProcess: public TNonCopyable {
        TIBBufferPool& BP;
        ui64 Id;
        char* Data;

    public:
        TIBRecvPacketProcess(TIBBufferPool& bp, const ibv_wc& wc)
            : BP(bp)
            , Id(wc.wr_id)
        {
            Y_ASSERT(wc.opcode & IBV_WC_RECV);
            BP.RequestPostRecv();
            Data = BP.GetBufData(Id);
        }
        // intended for postponed packet processing
        // with this call RequestPostRecv() should be called outside (and PostRecv() too in order to avoid rnr situation)
        TIBRecvPacketProcess(TIBBufferPool& bp, const ui64 wr_id)
            : BP(bp)
            , Id(wr_id)
        {
            Data = BP.GetBufData(Id);
        }
        ~TIBRecvPacketProcess() {
            BP.FreeBuf(Id);
            BP.PostRecv();
        }
        char* GetData() const {
            return Data;
        }
        // UD payload starts after the 40-byte GRH.
        char* GetUDData() const {
            return Data + 40;
        }
        ibv_grh* GetGRH() const {
            return (ibv_grh*)Data;
        }
    };

}
// --- rendered diff header (page text, preserved) ---
// diff --git a/library/cpp/netliba/v6/ib_collective.cpp b/library/cpp/netliba/v6/ib_collective.cpp
// new file mode 100644  index 0000000000..87ea166366  --- /dev/null  +++ b/library/cpp/netliba/v6/ib_collective.cpp
// @@ -0,0 +1,1317 @@
#include "stdafx.h"
#include "ib_collective.h"
#include "ib_mem.h"
#include "ib_buffers.h"
#include "ib_low.h"
#include "udp_http.h"
#include "udp_address.h"
#include <util/generic/deque.h>
#include <util/system/hp_timer.h>

namespace NNetliba {
    const int COL_SERVICE_LEVEL = 2;
    const int COL_DATA_SERVICE_LEVEL = 2;       // base level
    const int COL_DATA_SERVICE_LEVEL_COUNT = 6; // level count
    const int MAX_REQS_PER_PEER = 32;
    const int MAX_TOTAL_RDMA = 20;
    const int SEND_COUNT_TABLE_SIZE = 1 << 12; // must be power of 2

    // Per-rank description of one merge iteration: outgoing transfers and the
    // bitmask of incoming transfer ids that must arrive before the iteration ends.
    struct TMergeRecord {
        struct TTransfer {
            int DstRank;
            int SL;                 // service level (IB traffic class)
            int RangeBeg, RangeFin; // data range [RangeBeg;RangeFin) being sent
            int Id;                 // bit index in the receiver's RecvMask

            TTransfer()
                : DstRank(-1)
                , SL(0)
                , RangeBeg(0)
                , RangeFin(0)
                , Id(0)
            {
            }
            TTransfer(int dstRank, int sl, int rangeBeg, int rangeFin, int id)
                : DstRank(dstRank)
                , SL(sl)
                , RangeBeg(rangeBeg)
                , RangeFin(rangeFin)
                , Id(id)
            {
            }
        };
        struct TInTransfer {
            int SrcRank;
            int SL;

            TInTransfer()
                : SrcRank(-1)
                , SL(0)
            {
            }
            TInTransfer(int srcRank, int sl)
                : SrcRank(srcRank)
                , SL(sl)
            {
            }
        };

        TVector<TTransfer> OutList;
        TVector<TInTransfer> InList;
        ui64 RecvMask;

        TMergeRecord()
            : RecvMask(0)
        {
        }
    };

    struct TMergeIteration {
        TVector<TMergeRecord> Ops; // indexed by rank

        void Init(int colSize) {
            Ops.resize(colSize);
        }
        // Register a transfer on both endpoints; `id` must fit the 64-bit RecvMask.
        void Transfer(int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin, int id) {
            Y_VERIFY(id < 64, "recv mask overflow");
            Ops[srcRank].OutList.push_back(TMergeRecord::TTransfer(dstRank, sl, rangeBeg, rangeFin, id));
            Ops[dstRank].InList.push_back(TMergeRecord::TInTransfer(srcRank, sl));
            Ops[dstRank].RecvMask |= ui64(1) << id;
        }
    };

    // Whole collective-merge schedule: a sequence of iterations over ColSize ranks.
    struct TMergePlan {
        TVector<TMergeIteration> Iterations;
        TVector<int> RankReceiveCount; // receives scheduled per rank so far (= next id)
        int ColSize;
        int MaxRankReceiveCount;

        TMergePlan()
            : ColSize(0)
            , MaxRankReceiveCount(0)
        {
        }
        void Init(int colSize) {
            Iterations.resize(0);
            RankReceiveCount.resize(0);
            RankReceiveCount.resize(colSize, 0);
            ColSize = colSize;
        }
        // Append a transfer at iteration `iter`, growing Iterations as needed.
        void Transfer(int iter, int srcRank, int dstRank, int sl, int rangeBeg, int rangeFin) {
            while (iter >= Iterations.ysize()) {
                TMergeIteration& res = Iterations.emplace_back();
                res.Init(ColSize);
            }
            int id = RankReceiveCount[dstRank]++;
            MaxRankReceiveCount = Max(MaxRankReceiveCount, id + 1);
            Y_ASSERT(id < 64);
            Iterations[iter].Transfer(srcRank, dstRank, sl, rangeBeg, rangeFin, id);
        }
    };

    struct TSRTransfer {
        int SrcRank, DstRank;
        int RangeBeg, RangeFin;

        TSRTransfer() {
            Zero(*this);
        }
        TSRTransfer(int srcRank, int dstRank, int rangeBeg, int rangeFin)
            : SrcRank(srcRank)
            , DstRank(dstRank)
            , RangeBeg(rangeBeg)
            , RangeFin(rangeFin)
        {
        }
    };

    // Recursively halve [beg;fin), exchanging halves between the two sides at
    // each level; returns the number of iterations used. Results are grouped by
    // iteration index in *res.
    static int SplitRange(THashMap<int, TVector<TSRTransfer>>* res, int iter, int beg, int fin) {
        int mid = (beg + fin + 1) / 2;
        if (mid == fin) {
            return iter;
        }
        for (int i = 0; i < fin - mid; ++i) {
            (*res)[iter].push_back(TSRTransfer(beg + i, mid + i, beg, mid));
            (*res)[iter].push_back(TSRTransfer(mid + i, beg + i, mid, fin));
        }
        if (fin - mid < mid - beg) {
            // [mid - 1] did not receive [mid;fin)
            (*res)[iter].push_back(TSRTransfer(mid, mid - 1, mid, fin));
        }
        int rv1 = SplitRange(res, iter + 1, beg, mid);
        int rv2 = SplitRange(res, iter + 1, mid, fin);
        return Max(rv1, rv2);
    }

    static void CreatePow2Merge(TMergePlan* plan, int colSize) {
        // finally everybody has full range [0;ColSize)
        // construct plan recursively, on each iteration split some range
        plan->Init(colSize);

        THashMap<int, TVector<TSRTransfer>> allTransfers;
        int maxIter = SplitRange(&allTransfers, 0, 0, colSize);

        for (int iter = 0; iter < maxIter; ++iter) {
            const TVector<TSRTransfer>& arr = allTransfers[maxIter - iter - 1]; // reverse order
            for (int i = 0; i < arr.ysize(); ++i) {
                const TSRTransfer& sr = arr[i];
                plan->Transfer(iter, sr.SrcRank, sr.DstRank, 0, sr.RangeBeg, sr.RangeFin);
            }
        }
    }

    // Half-open data range [Beg;Fin) currently held by a host.
    struct TCoverInterval {
        int Beg, Fin; // [Beg;Fin)

        TCoverInterval()
            : Beg(0)
            , Fin(0)
        {
        }
        TCoverInterval(int b, int f)
            : Beg(b)
            , Fin(f)
        {
        }
    };

    enum EAllToAllMode {
        AA_POW2,
        AA_CIRCLE,
        AA_STAR,
        AA_POW2_MERGE,
    };
    static int AllToAll(TMergePlan* plan, int iter, int sl,
EAllToAllMode mode, const TVector<int>& myGroup, TVector<TCoverInterval>* cover) {
        // Schedule an all-to-all exchange within `myGroup` starting at iteration
        // `iter`, so that afterwards every member covers the group's full range.
        // Returns the first unused iteration index. `cover` is updated in place.
        TVector<TCoverInterval>& hostCoverage = *cover;
        int groupSize = myGroup.ysize();

        // Group members must hold adjacent, ordered ranges.
        for (int k = 1; k < groupSize; ++k) {
            int h1 = myGroup[k - 1];
            int h2 = myGroup[k];
            Y_VERIFY(hostCoverage[h1].Fin == hostCoverage[h2].Beg, "Invalid host order in CreateGroupMerge()");
        }

        switch (mode) {
            case AA_POW2: {
                // Doubling exchange: log2(groupSize) iterations, each rank sends a
                // window of `sz` consecutive members' data `delta` positions ahead.
                for (int delta = 1; delta < groupSize; delta *= 2) {
                    int sz = Min(delta, groupSize - delta);
                    for (int offset = 0; offset < groupSize; ++offset) {
                        int srcRank = myGroup[offset];
                        int dstRank = myGroup[(offset + delta) % groupSize];

                        int start = offset + 1 - sz;
                        int finish = offset + 1;
                        if (start < 0) {
                            // Window wraps around the ring: send the two pieces separately.
                            // [start; myGroup.size())
                            int dataBeg = hostCoverage[myGroup[start + groupSize]].Beg;
                            int dataFin = hostCoverage[myGroup.back()].Fin;
                            plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                            // [0; finish)
                            dataBeg = hostCoverage[myGroup[0]].Beg;
                            dataFin = hostCoverage[myGroup[finish - 1]].Fin;
                            plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                        } else {
                            // [start;finish)
                            int dataBeg = hostCoverage[myGroup[start]].Beg;
                            int dataFin = hostCoverage[myGroup[finish - 1]].Fin;
                            plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                        }
                    }
                    ++iter;
                }
            } break;
            case AA_CIRCLE: {
                // Ring: groupSize-1 iterations, each rank forwards one member's
                // data to its successor.
                for (int dataDelta = 1; dataDelta < groupSize; ++dataDelta) {
                    for (int offset = 0; offset < groupSize; ++offset) {
                        int srcRank = myGroup[offset];
                        int dstRank = myGroup[(offset + 1) % groupSize];

                        int dataRank = myGroup[(offset + 1 - dataDelta + groupSize) % groupSize];
                        int dataBeg = hostCoverage[dataRank].Beg;
                        int dataFin = hostCoverage[dataRank].Fin;

                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                    }
                    ++iter;
                }
            } break;
            case AA_STAR: {
                // Single iteration: every rank sends its own data directly to all
                // other group members.
                for (int offset = 0; offset < groupSize; ++offset) {
                    for (int delta = 1; delta < groupSize; ++delta) {
                        int srcRank = myGroup[offset];
                        int dstRank = myGroup[(offset + delta) % groupSize];

                        int dataRank = myGroup[offset];
                        int dataBeg = hostCoverage[dataRank].Beg;
                        int dataFin = hostCoverage[dataRank].Fin;

                        plan->Transfer(iter, srcRank, dstRank, sl, dataBeg, dataFin);
                    }
                }
                ++iter;
            } break;
            case AA_POW2_MERGE: {
                // Reuse the recursive pow2 plan, remapped onto this group's ranks.
                // NOTE(review): the range arguments are fixed (0,1) here, unlike the
                // other modes — presumably this mode ignores ranges; confirm at call site.
                TMergePlan pp;
                CreatePow2Merge(&pp, groupSize);
                for (int z = 0; z < pp.Iterations.ysize(); ++z) {
                    const TMergeIteration& mm = pp.Iterations[z];
                    for (int src = 0; src < mm.Ops.ysize(); ++src) {
                        const TMergeRecord& mr = mm.Ops[src];
                        int srcRank = myGroup[src];
                        for (int i = 0; i < mr.OutList.ysize(); ++i) {
                            int dstRank = myGroup[mr.OutList[i].DstRank];
                            plan->Transfer(iter, srcRank, dstRank, sl, 0, 1);
                        }
                    }
                    ++iter;
                }
            } break;
            default:
                Y_ASSERT(0);
                break;
        }
        {
            // After the exchange every member covers the group's united range.
            TCoverInterval cc(hostCoverage[myGroup[0]].Beg, hostCoverage[myGroup.back()].Fin);
            for (int k = 0; k < groupSize; ++k) {
                hostCoverage[myGroup[k]] = cc;
            }
        }
        return iter;
    }

    // fully populated matrix
    // hostGroup[groupType][hostId] = group id of host within that grouping level;
    // builds a hierarchical merge plan, innermost grouping level last in the array.
    static void CreateGroupMerge(TMergePlan* plan, EAllToAllMode mode, const TVector<TVector<int>>& hostGroup) {
        int hostCount = hostGroup[0].ysize();
        int groupTypeCount = hostGroup.ysize();

        plan->Init(hostCount);

        // gcount[t] = number of distinct groups at level t.
        TVector<int> gcount;
        gcount.resize(groupTypeCount, 0);
        for (int hostId = 0; hostId < hostCount; ++hostId) {
            for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
                int val = hostGroup[groupType][hostId];
                gcount[groupType] = Max(gcount[groupType], val + 1);
            }
        }
        // Validate that host ordering is a regular mixed-radix enumeration of the
        // group ids (odometer pattern): exactly one level increments per step.
        for (int hostId = 1; hostId < hostCount; ++hostId) {
            bool isIncrement = true;
            for (int groupType = 0; groupType < groupTypeCount; ++groupType) {
                int prev = hostGroup[groupType][hostId - 1];
                int cur = hostGroup[groupType][hostId];
                if (isIncrement) {
                    if (cur == prev + 1) {
                        isIncrement = false;
                    } else {
                        Y_VERIFY(cur == 0, "ib_hosts, wrapped to non-zero");
                        Y_VERIFY(prev == gcount[groupType] - 1, "ib_hosts, structure is irregular");
                        isIncrement = true;
                    }
                } else {
Y_VERIFY(prev == cur, "ib_hosts, structure is irregular");
                }
            }
        }

        // Initially every host covers only its own slot.
        TVector<TCoverInterval> hostCoverage;
        for (int i = 0; i < hostCount; ++i) {
            hostCoverage.push_back(TCoverInterval(i, i + 1));
        }

        // Merge innermost groups first (last grouping level), then widen outwards;
        // all groups at one level must take the same number of iterations.
        int baseIter = 0;
        for (int groupType = hostGroup.ysize() - 1; groupType >= 0; --groupType) {
            Y_ASSERT(hostGroup[groupType].ysize() == hostCount);
            TVector<TVector<int>> hh;
            hh.resize(gcount[groupType]);
            for (int rank = 0; rank < hostGroup[groupType].ysize(); ++rank) {
                int groupId = hostGroup[groupType][rank];
                hh[groupId].push_back(rank);
            }
            int newIter = 0;
            for (int groupId = 0; groupId < hh.ysize(); ++groupId) {
                int nn = AllToAll(plan, baseIter, 0, mode, hh[groupId], &hostCoverage); // seems to be fastest
                if (newIter == 0) {
                    newIter = nn;
                } else {
                    Y_VERIFY(newIter == nn, "groups should be of same size");
                }
            }
            baseIter = newIter;
        }
        //printf("%d iterations symmetrical plan\n", baseIter);
    }

    //////////////////////////////////////////////////////////////////////////
    // Executes a precompiled schedule (Iterations) over RDMA-write-with-immediate:
    // double-buffered (MemBlock[2]), the immediate value carries the transfer id
    // that is collected into a receive bitmask per iteration.
    struct TAllDataSync {
        enum {
            WR_COUNT = 64 * 2
        };

        int CurrentBuffer; // 0/1 — active half of the double buffer
        TIntrusivePtr<TIBMemBlock> MemBlock[2];
        TIntrusivePtr<TComplectionQueue> CQ;
        TIntrusivePtr<TSharedReceiveQueue> SRQ;
        TIntrusivePtr<TIBMemBlock> FakeRecvMem; // dummy target: payload arrives via RDMA, recv only carries imm
        size_t DataSize, BufSize;
        size_t CurrentOffset, ReadyOffset;
        bool WasFlushed;
        int ActiveRDMACount;  // RDMA writes posted but not yet completed
        ui64 FutureRecvMask;  // imm bits that arrived one Sync() early
        TIntrusivePtr<IReduceOp> ReduceOp;

        struct TBlockInfo {
            ui64 Addr;
            ui32 Key; // remote rkey
        };
        struct TSend {
            TBlockInfo RemoteBlocks[2]; // one per double-buffer half
            TIntrusivePtr<TRCQueuePair> QP;
            size_t SrcOffset;
            size_t DstOffset;
            size_t Length;
            ui32 ImmData; // transfer id delivered via immediate data
            int DstRank;
            union {
                struct {
                    int RangeBeg, RangeFin;
                } Gather;
                struct {
                    int SrcIndex, DstIndex;
                } Reduce;
            };
        };
        struct TRecv {
            TIntrusivePtr<TRCQueuePair> QP;
            int SrcRank;
        };
        struct TReduce {
            size_t DstOffset, SrcOffset;
            int DstIndex, SrcIndex;
        };
        struct TIteration {
            TVector<TSend> OutList;
            TVector<TRecv> InList;
            TVector<TReduce> ReduceList;
            ui64 RecvMask; // bits that must arrive before this iteration completes
        };
        TVector<TIteration> Iterations;

    public:
        void* GetRawData() {
            char* myData = (char*)MemBlock[CurrentBuffer]->GetData();
            return myData + CurrentOffset;
        }
        size_t GetRawDataSize() {
            return DataSize;
        }
        void PostRecv() {
            SRQ->PostReceive(FakeRecvMem->GetMemRegion(), 0, FakeRecvMem->GetData(), FakeRecvMem->GetSize());
        }
        // Run one full schedule: per iteration, post RDMA writes, busy-poll the CQ
        // until this iteration's RecvMask is satisfied, then apply local reduces.
        // Bits for a future Sync() are stashed in FutureRecvMask.
        void Sync() {
            Y_ASSERT(WasFlushed && "Have to call Flush() before data fill & Sync()");
            char* myData = (char*)MemBlock[CurrentBuffer]->GetData();

            ui64 recvMask = FutureRecvMask;
            FutureRecvMask = 0;
            int recvDebt = 0; // receives consumed, re-posted while polling idles
            for (int z = 0; z < Iterations.ysize(); ++z) {
                const TIteration& iter = Iterations[z];
                for (int k = 0; k < iter.OutList.ysize(); ++k) {
                    const TSend& ss = iter.OutList[k];
                    const TBlockInfo& remoteBlk = ss.RemoteBlocks[CurrentBuffer];
                    ss.QP->PostRDMAWriteImm(remoteBlk.Addr + ss.DstOffset, remoteBlk.Key, ss.ImmData,
                                            MemBlock[CurrentBuffer]->GetMemRegion(), 0, myData + ss.SrcOffset, ss.Length);
                    ++ActiveRDMACount;
                    //printf("-> %d, imm %d (%" PRId64 " bytes)\n", ss.DstRank, ss.ImmData, ss.Length);
                    //printf("send %d\n", ss.SrcOffset);
                }
                ibv_wc wc;
                while ((recvMask & iter.RecvMask) != iter.RecvMask) {
                    int rv = CQ->Poll(&wc, 1);
                    if (rv > 0) {
                        Y_VERIFY(wc.status == IBV_WC_SUCCESS, "AllGather::Sync fail, status %d", (int)wc.status);
                        if (wc.opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
                            //printf("Got %d\n", wc.imm_data);
                            ++recvDebt;
                            ui64 newBit = ui64(1) << wc.imm_data;
                            if (recvMask & newBit) {
                                // Duplicate bit: it belongs to the next Sync() round.
                                Y_VERIFY((FutureRecvMask & newBit) == 0, "data from 2 Sync() ahead is impossible");
                                FutureRecvMask |= newBit;
                            } else {
                                recvMask |= newBit;
                            }
                        } else if (wc.opcode == IBV_WC_RDMA_WRITE) {
                            --ActiveRDMACount;
                        } else {
                            Y_ASSERT(0);
                        }
                    } else {
                        // CQ idle: use the spare cycles to repay the receive debt.
                        if (recvDebt > 0) {
                            PostRecv();
                            --recvDebt;
                        }
                    }
                }
                for (int k = 0; k < iter.ReduceList.ysize(); ++k) {
                    const TReduce& rr = iter.ReduceList[k];
                    ReduceOp->Reduce(myData + rr.DstOffset, myData + rr.SrcOffset, DataSize);
                    //printf("Merge %d -> %d (%d bytes)\n", rr.SrcOffset, rr.DstOffset, DataSize);
                }
                //printf("Iteration %d done\n", z);
            }
            while (recvDebt > 0) {
                PostRecv();
                --recvDebt;
            }
            CurrentOffset = ReadyOffset;
            WasFlushed = false;
            //printf("new cur offset %g\n", (double)CurrentOffset);
            //printf("Sync complete\n");
        }
        // Switch to the other buffer half; must precede filling data and Sync().
        void Flush() {
            Y_ASSERT(!WasFlushed);
            CurrentBuffer = 1 - CurrentBuffer;
            CurrentOffset = 0;
            WasFlushed = true;
        }

    public:
        TAllDataSync(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
            : CurrentBuffer(0)
            , DataSize(0)
            , BufSize(bufSize)
            , CurrentOffset(0)
            , ReadyOffset(0)
            , WasFlushed(false)
            , ActiveRDMACount(0)
            , FutureRecvMask(0)
            , ReduceOp(reduceOp)
        {
            // With no mem pool, fall back to standalone blocks/queues
            // (presumably a loopback/test configuration — confirm with callers).
            if (memPool) {
                MemBlock[0] = memPool->Alloc(BufSize);
                MemBlock[1] = memPool->Alloc(BufSize);
                CQ = new TComplectionQueue(memPool->GetIBContext(), WR_COUNT);
                SRQ = new TSharedReceiveQueue(memPool->GetIBContext(), WR_COUNT);
                FakeRecvMem = memPool->Alloc(4096);
            } else {
                MemBlock[0] = new TIBMemBlock(BufSize);
                MemBlock[1] = new TIBMemBlock(BufSize);
                CQ = new TComplectionQueue(nullptr, WR_COUNT);
                SRQ = new TSharedReceiveQueue(nullptr, WR_COUNT);
                FakeRecvMem = new TIBMemBlock(4096);
            }
            for (int i = 0; i < WR_COUNT; ++i) {
                PostRecv();
            }
        }
        ~TAllDataSync() {
            // Drain outstanding RDMA writes so buffers are not freed mid-flight.
            while (ActiveRDMACount > 0) {
                ibv_wc wc;
                int rv = CQ->Poll(&wc, 1);
                if (rv > 0) {
                    if (wc.opcode == IBV_WC_RDMA_WRITE) {
                        --ActiveRDMACount;
                    } else {
                        Y_ASSERT(0);
                    }
                }
            }
        }
    };

    // NOTE(review): definition truncated in this chunk — TAllReduce continues
    // beyond the visible text.
    class TAllReduce: public IAllReduce {
        TAllDataSync DataSync;
        size_t BufSizeMult;
        size_t ReadyOffsetMult;

    public:
        TAllReduce(size_t bufSize, TPtrArg<TIBMemPool> memPool, TPtrArg<IReduceOp> reduceOp)
            : DataSync(bufSize, memPool, reduceOp)
            , BufSizeMult(0)
            , ReadyOffsetMult(0)
        {
        }
        TAllDataSync& GetDataSync() {
            return DataSync;
        }
        void* GetRawData() override
{ + return DataSync.GetRawData(); + } + size_t GetRawDataSize() override { + return DataSync.GetRawDataSize(); + } + void Sync() override { + DataSync.Sync(); + } + void Flush() override { + DataSync.Flush(); + } + + bool Resize(size_t dataSize) override { + size_t repSize = (dataSize + 63) & (~63ull); + size_t bufSize = repSize * BufSizeMult; + + if (bufSize > DataSync.BufSize) { + return false; + } + + for (int z = 0; z < DataSync.Iterations.ysize(); ++z) { + TAllDataSync::TIteration& iter = DataSync.Iterations[z]; + for (int i = 0; i < iter.OutList.ysize(); ++i) { + TAllDataSync::TSend& snd = iter.OutList[i]; + snd.Length = dataSize; + snd.SrcOffset = snd.Reduce.SrcIndex * repSize; + snd.DstOffset = snd.Reduce.DstIndex * repSize; + } + + for (int i = 0; i < iter.ReduceList.ysize(); ++i) { + TAllDataSync::TReduce& red = iter.ReduceList[i]; + red.SrcOffset = red.SrcIndex * repSize; + red.DstOffset = red.DstIndex * repSize; + } + } + DataSync.ReadyOffset = ReadyOffsetMult * repSize; + DataSync.DataSize = dataSize; + return true; + } + friend class TIBCollective; + }; + + class TAllGather: public IAllGather { + TAllDataSync DataSync; + int ColSize; + + public: + TAllGather(int colSize, size_t bufSize, TPtrArg<TIBMemPool> memPool) + : DataSync(bufSize, memPool, nullptr) + , ColSize(colSize) + { + } + TAllDataSync& GetDataSync() { + return DataSync; + } + void* GetRawData() override { + return DataSync.GetRawData(); + } + size_t GetRawDataSize() override { + return DataSync.GetRawDataSize(); + } + void Sync() override { + DataSync.Sync(); + } + void Flush() override { + DataSync.Flush(); + } + + bool Resize(const TVector<size_t>& szPerRank) override { + Y_VERIFY(szPerRank.ysize() == ColSize, "Invalid size array"); + + TVector<size_t> offsets; + offsets.push_back(0); + for (int rank = 0; rank < ColSize; ++rank) { + offsets.push_back(offsets.back() + szPerRank[rank]); + } + + size_t dataSize = offsets.back(); + if (dataSize > DataSync.BufSize) { + return false; + } + + 
for (int z = 0; z < DataSync.Iterations.ysize(); ++z) { + TAllDataSync::TIteration& iter = DataSync.Iterations[z]; + for (int i = 0; i < iter.OutList.ysize(); ++i) { + TAllDataSync::TSend& snd = iter.OutList[i]; + int rangeBeg = snd.Gather.RangeBeg; + int rangeFin = snd.Gather.RangeFin; + snd.Length = offsets[rangeFin] - offsets[rangeBeg]; + snd.SrcOffset = offsets[rangeBeg]; + snd.DstOffset = snd.SrcOffset; + } + } + DataSync.DataSize = dataSize; + return true; + } + }; + + struct TIBAddr { + int LID, SL; + + TIBAddr() + : LID(0) + , SL(0) + { + } + TIBAddr(int lid, int sl) + : LID(lid) + , SL(sl) + { + } + }; + inline bool operator==(const TIBAddr& a, const TIBAddr& b) { + return a.LID == b.LID && a.SL == b.SL; + } + inline bool operator<(const TIBAddr& a, const TIBAddr& b) { + if (a.LID == b.LID) { + return a.SL < b.SL; + } + return a.LID < b.LID; + } + + struct TIBAddrHash { + int operator()(const TIBAddr& a) const { + return a.LID + a.SL * 4254515; + } + }; + + class TIBCollective: public IIBCollective { + struct TPendingMessage { + int QPN; + ui64 WorkId; + + TPendingMessage() { + Zero(*this); + } + TPendingMessage(int qpn, ui64 wid) + : QPN(qpn) + , WorkId(wid) + { + } + }; + + struct TBlockInform { + TAllDataSync::TBlockInfo RemoteBlocks[2]; + int PSN, QPN; + }; + + struct TPeerConnection { + TAllDataSync::TBlockInfo RemoteBlocks[2]; + TIntrusivePtr<TRCQueuePair> QP; + }; + + struct TBWTest { + ui64 Addr; + ui32 RKey; + }; + + TIntrusivePtr<TIBPort> Port; + TIntrusivePtr<TIBMemPool> MemPool; + int ColSize, ColRank; + TVector<int> Hosts; // host LIDs + TVector<TVector<int>> HostGroup; + TVector<TIntrusivePtr<TRCQueuePair>> Peers; + TIntrusivePtr<TComplectionQueue> CQ; + TIntrusivePtr<TIBBufferPool> BP; + ui8 SendCountTable[SEND_COUNT_TABLE_SIZE]; + ui8 RDMACountTable[SEND_COUNT_TABLE_SIZE]; + TDeque<TPendingMessage> Pending; + TMergePlan MergePlan, ReducePlan; + int QPNTableSizeLog; + + void WriteCompleted(const ibv_wc& wc) { + --SendCountTable[wc.qp_num & 
(SEND_COUNT_TABLE_SIZE - 1)]; + if (wc.opcode == IBV_WC_RDMA_WRITE) { + --RDMACountTable[wc.qp_num & (SEND_COUNT_TABLE_SIZE - 1)]; + } + BP->FreeBuf(wc.wr_id); + } + bool GetMsg(ui64* resWorkId, int* resQPN, TIBMicroPeerTable* tbl) { + if (tbl->NeedParsePending()) { + for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) { + if (!tbl->NeedQPN(z->QPN)) { + continue; + } + *resWorkId = z->WorkId; + *resQPN = z->QPN; + Pending.erase(z); + return true; + } + //printf("Stop parse pending\n"); + tbl->StopParsePending(); + } + for (;;) { + ibv_wc wc; + int rv = CQ->Poll(&wc, 1); + if (rv > 0) { + Y_VERIFY(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status); + if (wc.opcode & IBV_WC_RECV) { + BP->RequestPostRecv(); + if (tbl->NeedQPN(wc.qp_num)) { + *resWorkId = wc.wr_id; + *resQPN = wc.qp_num; + return true; + } else { + Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id)); + BP->PostRecv(); + } + } else { + WriteCompleted(wc); + } + } else { + return false; + } + } + } + + bool ProcessSendCompletion(const ibv_wc& wc) { + Y_VERIFY(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status); + if (wc.opcode & IBV_WC_RECV) { + BP->RequestPostRecv(); + Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id)); + BP->PostRecv(); + } else { + WriteCompleted(wc); + return true; + } + return false; + } + + void WaitCompletion(ibv_wc* res) { + ibv_wc& wc = *res; + for (;;) { + int rv = CQ->Poll(&wc, 1); + if (rv > 0 && ProcessSendCompletion(wc)) { + break; + } + } + } + + bool TryWaitCompletion() override { + ibv_wc wc; + for (;;) { + int rv = CQ->Poll(&wc, 1); + if (rv > 0) { + if (ProcessSendCompletion(wc)) { + return true; + } + } else { + return false; + } + } + } + + void WaitCompletion() override { + ibv_wc wc; + WaitCompletion(&wc); + } + + ui64 WaitForMsg(int qpn) { + for (TDeque<TPendingMessage>::iterator z = Pending.begin(); z != Pending.end(); ++z) { + if (z->QPN == qpn) { + ui64 workId = 
z->WorkId; + Pending.erase(z); + return workId; + } + } + ibv_wc wc; + for (;;) { + int rv = CQ->Poll(&wc, 1); + if (rv > 0) { + Y_VERIFY(wc.status == IBV_WC_SUCCESS, "WaitForMsg() fail, status %d", (int)wc.status); + if (wc.opcode & IBV_WC_RECV) { + BP->RequestPostRecv(); + if ((int)wc.qp_num == qpn) { + return wc.wr_id; + } else { + Pending.push_back(TPendingMessage(wc.qp_num, wc.wr_id)); + BP->PostRecv(); + } + } else { + WriteCompleted(wc); + } + } + } + } + + bool AllocOperationSlot(TPtrArg<TRCQueuePair> qp) { + int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1); + if (SendCountTable[way] >= MAX_REQS_PER_PEER) { + return false; + } + ++SendCountTable[way]; + return true; + } + bool AllocRDMAWriteSlot(TPtrArg<TRCQueuePair> qp) { + int way = qp->GetQPN() & (SEND_COUNT_TABLE_SIZE - 1); + if (SendCountTable[way] >= MAX_REQS_PER_PEER) { + return false; + } + if (RDMACountTable[way] >= MAX_OUTSTANDING_RDMA) { + return false; + } + ++SendCountTable[way]; + ++RDMACountTable[way]; + return true; + } + bool TryPostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) { + if (AllocOperationSlot(qp)) { + BP->PostSend(qp, data, len); + return true; + } + return false; + } + void PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) { + while (!TryPostSend(qp, data, len)) { + WaitCompletion(); + } + } + int GetRank() override { + return ColRank; + } + int GetSize() override { + return ColSize; + } + int GetGroupTypeCount() override { + return HostGroup.ysize(); + } + int GetQPN(int rank) override { + if (rank == ColRank) { + Y_ASSERT(0 && "there is no qpn connected to localhost"); + return 0; + } + return Peers[rank]->GetQPN(); + } + + void Start(const TCollectiveLinkSet& links) override { + Hosts = links.Hosts; + HostGroup = links.HostGroup; + for (int k = 0; k < ColSize; ++k) { + if (k == ColRank) { + continue; + } + const TCollectiveLinkSet::TLinkInfo& lnk = links.Links[k]; + ibv_ah_attr peerAddr; + MakeAH(&peerAddr, Port, Hosts[k], 
COL_SERVICE_LEVEL); + Peers[k]->Init(peerAddr, lnk.QPN, lnk.PSN); + } + + //CreatePow2Merge(&MergePlan, ColSize); + //CreatePow2Merge(&ReducePlan, ColSize); + CreateGroupMerge(&MergePlan, AA_STAR, HostGroup); + CreateGroupMerge(&ReducePlan, AA_POW2_MERGE, HostGroup); + } + + void CreateDataSyncQPs( + TPtrArg<TComplectionQueue> cq, + TPtrArg<TSharedReceiveQueue> srq, + TPtrArg<TIBMemBlock> memBlock0, + TPtrArg<TIBMemBlock> memBlock1, + const TMergePlan& plan, + THashMap<TIBAddr, TPeerConnection, TIBAddrHash>* res) { + THashMap<TIBAddr, TPeerConnection, TIBAddrHash>& connections = *res; + + TIBMemBlock* memBlock[2] = {memBlock0, memBlock1}; + + // make full peer list + TVector<TIBAddr> peerList; + for (int z = 0; z < plan.Iterations.ysize(); ++z) { + const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank]; + for (int i = 0; i < rr.OutList.ysize(); ++i) { + const TMergeRecord::TTransfer& tr = rr.OutList[i]; + peerList.push_back(TIBAddr(tr.DstRank, tr.SL)); + } + for (int i = 0; i < rr.InList.ysize(); ++i) { + const TMergeRecord::TInTransfer& tr = rr.InList[i]; + peerList.push_back(TIBAddr(tr.SrcRank, tr.SL)); + } + } + Sort(peerList.begin(), peerList.end()); + peerList.erase(Unique(peerList.begin(), peerList.end()), peerList.end()); + + // establish QPs and exchange mem block handlers + for (int z = 0; z < peerList.ysize(); ++z) { + const TIBAddr& ibAddr = peerList[z]; + int dstRank = ibAddr.LID; + TPeerConnection& dst = connections[ibAddr]; + + dst.QP = new TRCQueuePair(Port->GetCtx(), cq, srq, TAllDataSync::WR_COUNT); + + TBlockInform myBlock; + for (int k = 0; k < 2; ++k) { + myBlock.RemoteBlocks[k].Addr = memBlock[k]->GetAddr(); + myBlock.RemoteBlocks[k].Key = memBlock[k]->GetMemRegion()->GetRKey(); + } + myBlock.PSN = dst.QP->GetPSN(); + myBlock.QPN = dst.QP->GetQPN(); + PostSend(Peers[dstRank], &myBlock, sizeof(myBlock)); + } + + for (int z = 0; z < peerList.ysize(); ++z) { + const TIBAddr& ibAddr = peerList[z]; + int dstRank = ibAddr.LID; + int sl = 
COL_DATA_SERVICE_LEVEL + ClampVal(ibAddr.SL, 0, COL_DATA_SERVICE_LEVEL_COUNT); + + TPeerConnection& dst = connections[ibAddr]; + + ui64 wr_id = WaitForMsg(Peers[dstRank]->GetQPN()); + TIBRecvPacketProcess pkt(*BP, wr_id); + const TBlockInform& info = *(TBlockInform*)pkt.GetData(); + ibv_ah_attr peerAddr; + MakeAH(&peerAddr, Port, Hosts[dstRank], COL_DATA_SERVICE_LEVEL + sl); + dst.QP->Init(peerAddr, info.QPN, info.PSN); + dst.RemoteBlocks[0] = info.RemoteBlocks[0]; + dst.RemoteBlocks[1] = info.RemoteBlocks[1]; + } + Fence(); + } + + IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) override { + const TMergePlan& plan = MergePlan; + + Y_VERIFY(szPerRank.ysize() == ColSize, "Invalid size array"); + + size_t totalSize = 0; + for (int i = 0; i < szPerRank.ysize(); ++i) { + totalSize += szPerRank[i]; + } + size_t bufSize = 4096; + while (totalSize >= bufSize) { + bufSize *= 2; + } + + TAllGather* res = new TAllGather(ColSize, bufSize, MemPool); + TAllDataSync& ds = res->GetDataSync(); + + THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections; + CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections); + + // build plan + for (int z = 0; z < plan.Iterations.ysize(); ++z) { + const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank]; + if (rr.OutList.empty() && rr.InList.empty()) { + continue; + } + TAllDataSync::TIteration& iter = ds.Iterations.emplace_back(); + for (int i = 0; i < rr.OutList.ysize(); ++i) { + const TMergeRecord::TTransfer& tr = rr.OutList[i]; + TAllDataSync::TSend& snd = iter.OutList.emplace_back(); + TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)]; + + snd.ImmData = tr.Id; + snd.Gather.RangeBeg = tr.RangeBeg; + snd.Gather.RangeFin = tr.RangeFin; + snd.QP = pc.QP; + snd.RemoteBlocks[0] = pc.RemoteBlocks[0]; + snd.RemoteBlocks[1] = pc.RemoteBlocks[1]; + snd.DstRank = tr.DstRank; + } + for (int i = 0; i < rr.InList.ysize(); ++i) { + const TMergeRecord::TInTransfer& tr = rr.InList[i]; + 
TAllDataSync::TRecv& rcv = iter.InList.emplace_back(); + TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)]; + rcv.QP = pc.QP; + rcv.SrcRank = tr.SrcRank; + } + iter.RecvMask = rr.RecvMask; + } + bool rv = res->Resize(szPerRank); + Y_VERIFY(rv, "oops"); + + return res; + } + IAllGather* CreateAllGather(size_t szPerRank) override { + TVector<size_t> arr; + arr.resize(ColSize, szPerRank); + return CreateAllGather(arr); + } + + IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) override { + const TMergePlan& plan = ReducePlan; + + size_t bufSizeMult = plan.MaxRankReceiveCount + 1; + size_t bufSize = 4096; + { + size_t sz = (dataSize + 64) * bufSizeMult; + while (sz > bufSize) { + bufSize *= 2; + } + } + + TAllReduce* res = new TAllReduce(bufSize, MemPool, reduceOp); + TAllDataSync& ds = res->GetDataSync(); + + THashMap<TIBAddr, TPeerConnection, TIBAddrHash> connections; + CreateDataSyncQPs(ds.CQ, ds.SRQ, ds.MemBlock[0], ds.MemBlock[1], plan, &connections); + + // build plan + int currentDataOffset = 0; + for (int z = 0; z < plan.Iterations.ysize(); ++z) { + const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank]; + if (rr.OutList.empty() && rr.InList.empty()) { + continue; + } + TAllDataSync::TIteration& iter = ds.Iterations.emplace_back(); + for (int i = 0; i < rr.OutList.ysize(); ++i) { + const TMergeRecord::TTransfer& tr = rr.OutList[i]; + TAllDataSync::TSend& snd = iter.OutList.emplace_back(); + TPeerConnection& pc = connections[TIBAddr(tr.DstRank, tr.SL)]; + + snd.ImmData = tr.Id; + snd.Reduce.SrcIndex = currentDataOffset; + snd.Reduce.DstIndex = tr.Id + 1; + snd.QP = pc.QP; + snd.RemoteBlocks[0] = pc.RemoteBlocks[0]; + snd.RemoteBlocks[1] = pc.RemoteBlocks[1]; + snd.DstRank = tr.DstRank; + } + + for (int i = 0; i < rr.InList.ysize(); ++i) { + const TMergeRecord::TInTransfer& tr = rr.InList[i]; + TAllDataSync::TRecv& rcv = iter.InList.emplace_back(); + TPeerConnection& pc = connections[TIBAddr(tr.SrcRank, tr.SL)]; + rcv.QP = 
pc.QP; + rcv.SrcRank = tr.SrcRank; + } + iter.RecvMask = rr.RecvMask; + + TVector<int> inputOffset; + inputOffset.push_back(currentDataOffset); + int newDataOffset = currentDataOffset; + for (int i = 0; i < 64; ++i) { + if (rr.RecvMask & (1ull << i)) { + int offset = i + 1; + inputOffset.push_back(offset); + newDataOffset = Max(offset, newDataOffset); + } + } + for (int i = 0; i < inputOffset.ysize(); ++i) { + if (inputOffset[i] == newDataOffset) { + continue; + } + TAllDataSync::TReduce& red = iter.ReduceList.emplace_back(); + red.SrcIndex = inputOffset[i]; + red.DstIndex = newDataOffset; + } + currentDataOffset = newDataOffset; + } + res->BufSizeMult = bufSizeMult; + res->ReadyOffsetMult = currentDataOffset; + + bool rv = res->Resize(dataSize); + Y_VERIFY(rv, "oops"); + + return res; + } + + void Fence() override { + const TMergePlan& plan = ReducePlan; + + for (int z = 0; z < plan.Iterations.ysize(); ++z) { + const TMergeRecord& rr = plan.Iterations[z].Ops[ColRank]; + for (int i = 0; i < rr.OutList.ysize(); ++i) { + const TMergeRecord::TTransfer& tr = rr.OutList[i]; + char c; + PostSend(Peers[tr.DstRank], &c, sizeof(c)); + } + + for (int i = 0; i < rr.InList.ysize(); ++i) { + const TMergeRecord::TInTransfer& tr = rr.InList[i]; + ui64 wr_id = WaitForMsg(Peers[tr.SrcRank]->GetQPN()); + TIBRecvPacketProcess pkt(*BP, wr_id); + } + } + } + void RunBWTest(int groupType, int delta, int* targetRank, float* res) override { + const int BUF_SIZE = 8 * 1024 * 1024; + TIntrusivePtr<TIBMemBlock> sendMem, recvMem; + sendMem = MemPool->Alloc(BUF_SIZE); + recvMem = MemPool->Alloc(BUF_SIZE); + + int myGroup = HostGroup[groupType][ColRank]; + int myGroupPos = 0; + TVector<int> gg; + Y_ASSERT(HostGroup[groupType].ysize() == ColSize); + for (int rank = 0; rank < ColSize; ++rank) { + if (HostGroup[groupType][rank] == myGroup) { + if (rank == ColRank) { + myGroupPos = gg.ysize(); + } + gg.push_back(rank); + } + } + if (delta >= gg.ysize()) { + *targetRank = -1; + *res = 0; + return; + 
} + + int sendRank = gg[(myGroupPos + delta) % gg.ysize()]; + int recvRank = gg[(myGroupPos + gg.ysize() - delta) % gg.ysize()]; + *targetRank = sendRank; + TIntrusivePtr<TRCQueuePair> sendRC = Peers[sendRank]; + TIntrusivePtr<TRCQueuePair> recvRC = Peers[recvRank]; + { + TBWTest bw; + bw.Addr = recvMem->GetAddr(); + bw.RKey = recvMem->GetMemRegion()->GetRKey(); + PostSend(recvRC, &bw, sizeof(bw)); + } + TBWTest dstMem; + { + ui64 wr_id = WaitForMsg(sendRC->GetQPN()); + TIBRecvPacketProcess pkt(*BP, wr_id); + dstMem = *(TBWTest*)pkt.GetData(); + } + // run + TVector<double> score; + for (int iter = 0; iter < 5; ++iter) { + while (!AllocRDMAWriteSlot(sendRC)) { + WaitCompletion(); + Y_ASSERT(0 && "measurements are imprecise"); + } + NHPTimer::STime t; + NHPTimer::GetTime(&t); + sendRC->PostRDMAWrite(dstMem.Addr, dstMem.RKey, sendMem->GetMemRegion(), 0, sendMem->GetData(), BUF_SIZE); + for (;;) { + ibv_wc wc; + WaitCompletion(&wc); + if (wc.opcode == IBV_WC_RDMA_WRITE) { + if (wc.qp_num != (ui32)sendRC->GetQPN()) { + abort(); + } + break; + } + } + double tPassed = NHPTimer::GetTimePassed(&t); + double speed = BUF_SIZE / tPassed / 1000000000.0; // G/sec + score.push_back(speed); + } + Sort(score.begin(), score.end()); + // signal completion & wait for signal + *res = score[score.size() / 2]; + { + char bb; + PostSend(sendRC, &bb, sizeof(bb)); + ui64 wr_id = WaitForMsg(recvRC->GetQPN()); + TIBRecvPacketProcess pkt(*BP, wr_id); + } + } + bool TrySendMicro(int dstRank, const void* data, int dataSize) override { + return TryPostSend(Peers[dstRank], data, dataSize); + } + void InitPeerTable(TIBMicroPeerTable* res) override { + res->Init(QPNTableSizeLog); + } + void RdmaWrite(const TVector<TRdmaRequest>& reqs) override { + TVector<TVector<int>> reqPerRank; + reqPerRank.resize(ColSize); + int reqCount = reqs.ysize(); + for (int i = 0; i < reqCount; ++i) { + reqPerRank[reqs[i].DstRank].push_back(i); + } + int inFlight = 0; // IB congestion control sucks :/ so we limit number 
of simultaneous rdmas + int startRank = ColRank; + while (reqCount > 0) { + if (inFlight < MAX_TOTAL_RDMA) { + for (int z = 0; z < ColSize; ++z) { + int dstRank = (startRank + 1 + z) % ColSize; + if (reqPerRank[dstRank].empty()) { + continue; + } + Y_ASSERT(dstRank != ColRank && "sending self is meaningless"); + TRCQueuePair* qp = Peers[dstRank].Get(); + if (AllocRDMAWriteSlot(qp)) { + const TRdmaRequest& rr = reqs[reqPerRank[dstRank].back()]; + qp->PostRDMAWrite(rr.RemoteAddr, rr.RemoteKey, rr.LocalAddr, rr.LocalKey, 0, rr.Size); + reqPerRank[dstRank].pop_back(); + if (++inFlight >= MAX_TOTAL_RDMA) { + startRank = dstRank; + break; + } + } + } + } + { + ibv_wc wc; + WaitCompletion(&wc); + if (wc.opcode == IBV_WC_RDMA_WRITE) { + --inFlight; + --reqCount; + } + } + } + } + + public: + TIBCollective(TPtrArg<TIBPort> port, TPtrArg<TIBMemPool> memPool, + const TCollectiveInit& params, + TCollectiveLinkSet* resLinks) + : Port(port) + , MemPool(memPool) + , QPNTableSizeLog(0) + { + ColSize = params.Size; + ColRank = params.Rank; + + int maxOutstandingQueries = MAX_REQS_PER_PEER * ColSize + 10; + CQ = new TComplectionQueue(Port->GetCtx(), maxOutstandingQueries * 2); + BP = new TIBBufferPool(Port->GetCtx(), maxOutstandingQueries); + + Peers.resize(ColSize); + resLinks->Links.resize(ColSize); + TVector<int> qpnArr; + for (int k = 0; k < ColSize; ++k) { + if (k == ColRank) { + continue; + } + TRCQueuePair* rc = new TRCQueuePair(Port->GetCtx(), CQ, BP->GetSRQ(), MAX_REQS_PER_PEER); + Peers[k] = rc; + TCollectiveLinkSet::TLinkInfo& lnk = resLinks->Links[k]; + lnk.PSN = rc->GetPSN(); + lnk.QPN = rc->GetQPN(); + + qpnArr.push_back(lnk.QPN); + } + resLinks->Hosts.resize(ColSize); + resLinks->Hosts[ColRank] = Port->GetLID(); + + static_assert(MAX_REQS_PER_PEER < 256, "expect MAX_REQS_PER_PEER < 256"); // sent count will fit into SendCountTable[] + Zero(SendCountTable); + Zero(RDMACountTable); + + if (!qpnArr.empty()) { + for (;;) { + TVector<ui8> qpnTable; + int qpnTableSize = 1 
<< QPNTableSizeLog; + qpnTable.resize(qpnTableSize, 0); + bool ok = true; + for (int i = 0; i < qpnArr.ysize(); ++i) { + int idx = qpnArr[i] & (qpnTableSize - 1); + if (++qpnTable[idx] == 2) { + ok = false; + break; + } + } + if (ok) { + break; + } + ++QPNTableSizeLog; + } + //printf("QPN table, size_log %d\n", QPNTableSizeLog); + } + } + + friend class TIBRecvMicro; + }; + + TIBRecvMicro::TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable) + : IB(*(TIBCollective*)col) + { + Y_ASSERT(typeid(IB) == typeid(TIBCollective)); + if (IB.GetMsg(&Id, &QPN, peerTable)) { + Data = IB.BP->GetBufData(Id); + } else { + Data = nullptr; + } + } + + TIBRecvMicro::~TIBRecvMicro() { + if (Data) { + IB.BP->FreeBuf(Id); + IB.BP->PostRecv(); + } + } + + IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks) { + return new TIBCollective(GetIBDevice(), GetIBMemPool(), params, resLinks); + } +} diff --git a/library/cpp/netliba/v6/ib_collective.h b/library/cpp/netliba/v6/ib_collective.h new file mode 100644 index 0000000000..48ffd29b34 --- /dev/null +++ b/library/cpp/netliba/v6/ib_collective.h @@ -0,0 +1,160 @@ +#pragma once + +#include <library/cpp/binsaver/bin_saver.h> + +namespace NNetliba { + struct TCollectiveInit { + int Size, Rank; + + SAVELOAD(Size, Rank) + }; + + struct TCollectiveLinkSet { + struct TLinkInfo { + int QPN, PSN; + }; + TVector<int> Hosts; // host LIDs + TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch + TVector<TLinkInfo> Links; + + SAVELOAD(Hosts, HostGroup, Links) + }; + + struct IAllDataSync: public TThrRefBase { + virtual void* GetRawData() = 0; + virtual size_t GetRawDataSize() = 0; + virtual void Sync() = 0; + virtual void Flush() = 0; + + template <class T> + T* GetData() { + return static_cast<T*>(GetRawData()); + } + template <class T> + size_t GetSize() { + return GetRawDataSize() / sizeof(T); + } + }; + + struct IAllReduce: public IAllDataSync { + virtual 
bool Resize(size_t dataSize) = 0; + }; + + struct IAllGather: public IAllDataSync { + virtual bool Resize(const TVector<size_t>& szPerRank) = 0; + }; + + struct IReduceOp: public TThrRefBase { + virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0; + }; + + template <class T, class TElem = typename T::TElem> + class TAllReduceOp: public IReduceOp { + T Op; + + public: + TAllReduceOp() { + } + TAllReduceOp(T op) + : Op(op) + { + } + void Reduce(void* dst, const void* add, size_t dataSize) const override { + TElem* dstPtr = (TElem*)(dst); + const TElem* addPtr = (const TElem*)(add); + TElem* finPtr = (TElem*)(((char*)dst) + dataSize); + while (dstPtr < finPtr) { + Op(dstPtr, *addPtr); + ++dstPtr; + ++addPtr; + } + } + }; + + // table of active peers for micro send/recv + class TIBMicroPeerTable { + TVector<ui8> Table; // == 0 means accept mesages from this qpn + int TableSize; + bool ParsePending; + + public: + TIBMicroPeerTable() + : ParsePending(true) + { + Init(0); + } + void Init(int tableSizeLog) { + TableSize = 1 << tableSizeLog; + ParsePending = true; + Table.resize(0); + Table.resize(TableSize, 0); + } + bool NeedParsePending() const { + return ParsePending; + } + void StopParsePending() { + ParsePending = false; + } + void StopQPN(int qpn, ui8 mask) { + Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0); + Table[qpn & (TableSize - 1)] |= mask; + } + void StopQPN(int qpn) { + Y_ASSERT(Table[qpn & (TableSize - 1)] == 0); + Table[qpn & (TableSize - 1)] = 0xff; + } + bool NeedQPN(int qpn) const { + return Table[qpn & (TableSize - 1)] != 0xff; + } + }; + + struct IIBCollective; + class TIBCollective; + class TIBRecvMicro: public TNonCopyable { + TIBCollective& IB; + ui64 Id; + int QPN; + void* Data; + + public: + TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable); + ~TIBRecvMicro(); + void* GetRawData() const { + return Data; + } + template <class T> + T* GetData() { + return static_cast<T*>(GetRawData()); + } + int GetQPN() 
const { + return QPN; + } + }; + + struct IIBCollective: public TThrRefBase { + struct TRdmaRequest { + int DstRank; + ui64 RemoteAddr, LocalAddr; + ui32 RemoteKey, LocalKey; + ui64 Size; + }; + + virtual int GetRank() = 0; + virtual int GetSize() = 0; + virtual int GetGroupTypeCount() = 0; + virtual int GetQPN(int rank) = 0; + virtual bool TryWaitCompletion() = 0; + virtual void WaitCompletion() = 0; + virtual void Start(const TCollectiveLinkSet& links) = 0; + virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0; + virtual IAllGather* CreateAllGather(size_t szPerRank) = 0; + virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0; + virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0; + virtual void Fence() = 0; + virtual void InitPeerTable(TIBMicroPeerTable* res) = 0; + virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0; + virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0; + }; + + IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks); +} diff --git a/library/cpp/netliba/v6/ib_cs.cpp b/library/cpp/netliba/v6/ib_cs.cpp new file mode 100644 index 0000000000..6dbe7bb0e5 --- /dev/null +++ b/library/cpp/netliba/v6/ib_cs.cpp @@ -0,0 +1,776 @@ +#include "stdafx.h" +#include "ib_cs.h" +#include "ib_buffers.h" +#include "ib_mem.h" +#include <util/generic/deque.h> +#include <util/digest/murmur.h> + +/* +Questions + does rdma work? + what is RC latency? + 3us if measured by completion event arrival + 2.3us if bind to socket 0 & use inline send + memory region - can we use memory from some offset? + yes + is send_inplace supported and is it faster? + yes, supported, 1024 bytes limit, inline is faster (2.4 vs 2.9) + is srq a penalty compared to regular rq? 
+ rdma is faster anyway, so why bother + +collective ops + support asymmetric configurations by additional transfers (overlap 1 or 2 hosts is allowed) + +remove commented stuff all around + +next gen + shared+registered large mem blocks for easy transfer + no crc calcs + direct channel exposure + make ui64 packet id? otherwise we could get duplicate id (highly improbable but possible) + lock free allocation in ib_mem +*/ + +namespace NNetliba { + const int WELCOME_QKEY = 0x13081976; + + const int MAX_SEND_COUNT = (128 - 10) / 4; + const int QP_SEND_QUEUE_SIZE = (MAX_SEND_COUNT * 2 + 10) + 10; + const int WELCOME_QP_SEND_SIZE = 10000; + + const int MAX_SRQ_WORK_REQUESTS = 10000; + const int MAX_CQ_EVENTS = MAX_SRQ_WORK_REQUESTS; //1000; + + const double CHANNEL_CHECK_INTERVAL = 1.; + + const int TRAFFIC_SL = 4; // 4 is mandatory for RoCE to work, it's the only lossless priority(?) + const int CONNECT_SL = 1; + + class TIBClientServer: public IIBClientServer { + enum ECmd { + CMD_HANDSHAKE, + CMD_HANDSHAKE_ACK, + CMD_CONFIRM, + CMD_DATA_TINY, + CMD_DATA_INIT, + CMD_BUFFER_READY, + CMD_DATA_COMPLETE, + CMD_KEEP_ALIVE, + }; +#pragma pack(1) + struct TCmdHandshake { + char Command; + int QPN, PSN; + TGUID SocketId; + TUdpAddress MyAddress; // address of the handshake sender as viewed from receiver + }; + struct TCmdHandshakeAck { + char Command; + int QPN, PSN; + int YourQPN; + }; + struct TCmdConfirm { + char Command; + }; + struct TCmdDataTiny { + struct THeader { + char Command; + ui16 Size; + TGUID PacketGuid; + } Header; + typedef char TDataVec[SMALL_PKT_SIZE - sizeof(THeader)]; + TDataVec Data; + + static int GetMaxDataSize() { + return sizeof(TDataVec); + } + }; + struct TCmdDataInit { + char Command; + size_t Size; + TGUID PacketGuid; + }; + struct TCmdBufferReady { + char Command; + TGUID PacketGuid; + ui64 RemoteAddr; + ui32 RemoteKey; + }; + struct TCmdDataComplete { + char Command; + TGUID PacketGuid; + ui64 DataHash; + }; + struct TCmdKeepAlive { + char 
Command; + }; +#pragma pack() + + struct TCompleteInfo { + enum { + CI_DATA_TINY, + CI_RDMA_COMPLETE, + CI_DATA_SENT, + CI_KEEP_ALIVE, + CI_IGNORE, + }; + int Type; + int BufId; + TIBMsgHandle MsgHandle; + + TCompleteInfo(int t, int bufId, TIBMsgHandle msg) + : Type(t) + , BufId(bufId) + , MsgHandle(msg) + { + } + }; + struct TPendingQueuedSend { + TGUID PacketGuid; + TIBMsgHandle MsgHandle; + TRopeDataPacket* Data; + + TPendingQueuedSend() + : MsgHandle(0) + { + } + TPendingQueuedSend(const TGUID& packetGuid, TIBMsgHandle msgHandle, TRopeDataPacket* data) + : PacketGuid(packetGuid) + , MsgHandle(msgHandle) + , Data(data) + { + } + }; + struct TQueuedSend { + TGUID PacketGuid; + TIBMsgHandle MsgHandle; + TIntrusivePtr<TIBMemBlock> MemBlock; + ui64 RemoteAddr; + ui32 RemoteKey; + + TQueuedSend() = default; + TQueuedSend(const TGUID& packetGuid, TIBMsgHandle msgHandle) + : PacketGuid(packetGuid) + , MsgHandle(msgHandle) + , RemoteAddr(0) + , RemoteKey(0) + { + } + }; + struct TQueuedRecv { + TGUID PacketGuid; + TIntrusivePtr<TIBMemBlock> Data; + + TQueuedRecv() = default; + TQueuedRecv(const TGUID& packetGuid, TPtrArg<TIBMemBlock> data) + : PacketGuid(packetGuid) + , Data(data) + { + } + }; + struct TIBPeer: public IIBPeer { + TUdpAddress PeerAddress; + TIntrusivePtr<TRCQueuePair> QP; + EState State; + int SendCount; + NHPTimer::STime LastRecv; + TDeque<TPendingQueuedSend> PendingSendQueue; + // these lists have limited size and potentially just circle buffers + TDeque<TQueuedSend> SendQueue; + TDeque<TQueuedRecv> RecvQueue; + TDeque<TCompleteInfo> OutMsgs; + + TIBPeer(const TUdpAddress& peerAddress, TPtrArg<TRCQueuePair> qp) + : PeerAddress(peerAddress) + , QP(qp) + , State(CONNECTING) + , SendCount(0) + { + NHPTimer::GetTime(&LastRecv); + } + ~TIBPeer() override { + //printf("IBPeer destroyed\n"); + } + EState GetState() override { + return State; + } + TDeque<TQueuedSend>::iterator GetSend(const TGUID& packetGuid) { + for (TDeque<TQueuedSend>::iterator z = 
SendQueue.begin(); z != SendQueue.end(); ++z) { + if (z->PacketGuid == packetGuid) { + return z; + } + } + Y_VERIFY(0, "no send by guid"); + return SendQueue.begin(); + } + TDeque<TQueuedSend>::iterator GetSend(TIBMsgHandle msgHandle) { + for (TDeque<TQueuedSend>::iterator z = SendQueue.begin(); z != SendQueue.end(); ++z) { + if (z->MsgHandle == msgHandle) { + return z; + } + } + Y_VERIFY(0, "no send by handle"); + return SendQueue.begin(); + } + TDeque<TQueuedRecv>::iterator GetRecv(const TGUID& packetGuid) { + for (TDeque<TQueuedRecv>::iterator z = RecvQueue.begin(); z != RecvQueue.end(); ++z) { + if (z->PacketGuid == packetGuid) { + return z; + } + } + Y_VERIFY(0, "no recv by guid"); + return RecvQueue.begin(); + } + void PostRDMA(TQueuedSend& qs) { + Y_ASSERT(qs.RemoteAddr != 0 && qs.MemBlock.Get() != nullptr); + QP->PostRDMAWrite(qs.RemoteAddr, qs.RemoteKey, + qs.MemBlock->GetMemRegion(), 0, qs.MemBlock->GetData(), qs.MemBlock->GetSize()); + OutMsgs.push_back(TCompleteInfo(TCompleteInfo::CI_RDMA_COMPLETE, 0, qs.MsgHandle)); + //printf("Post rdma write, size %d\n", qs.Data->GetSize()); + } + void PostSend(TIBBufferPool& bp, const void* data, size_t len, int t, TIBMsgHandle msgHandle) { + int bufId = bp.PostSend(QP, data, len); + OutMsgs.push_back(TCompleteInfo(t, bufId, msgHandle)); + } + }; + + TIntrusivePtr<TIBPort> Port; + TIntrusivePtr<TIBMemPool> MemPool; + TIntrusivePtr<TIBMemPool::TCopyResultStorage> CopyResults; + TIntrusivePtr<TComplectionQueue> CQ; + TIBBufferPool BP; + TIntrusivePtr<TUDQueuePair> WelcomeQP; + int WelcomeQPN; + TIBConnectInfo ConnectInfo; + TDeque<TIBSendResult> SendResults; + TDeque<TRequest*> ReceivedList; + typedef THashMap<int, TIntrusivePtr<TIBPeer>> TPeerChannelHash; + TPeerChannelHash Channels; + TIBMsgHandle MsgCounter; + NHPTimer::STime LastCheckTime; + + ~TIBClientServer() override { + for (auto& z : ReceivedList) { + delete z; + } + } + TIBPeer* GetChannelByQPN(int qpn) { + TPeerChannelHash::iterator z = Channels.find(qpn); 
+ if (z == Channels.end()) { + return nullptr; + } + return z->second.Get(); + } + + // IIBClientServer + TRequest* GetRequest() override { + if (ReceivedList.empty()) { + return nullptr; + } + TRequest* res = ReceivedList.front(); + ReceivedList.pop_front(); + return res; + } + bool GetSendResult(TIBSendResult* res) override { + if (SendResults.empty()) { + return false; + } + *res = SendResults.front(); + SendResults.pop_front(); + return true; + } + void StartSend(TPtrArg<TIBPeer> peer, const TGUID& packetGuid, TIBMsgHandle msgHandle, TRopeDataPacket* data) { + int sz = data->GetSize(); + if (sz <= TCmdDataTiny::GetMaxDataSize()) { + TCmdDataTiny dataTiny; + dataTiny.Header.Command = CMD_DATA_TINY; + dataTiny.Header.Size = (ui16)sz; + dataTiny.Header.PacketGuid = packetGuid; + TBlockChainIterator bc(data->GetChain()); + bc.Read(dataTiny.Data, sz); + + peer->PostSend(BP, &dataTiny, sizeof(dataTiny.Header) + sz, TCompleteInfo::CI_DATA_TINY, msgHandle); + //printf("Send CMD_DATA_TINY\n"); + } else { + MemPool->CopyData(data, msgHandle, peer, CopyResults); + peer->SendQueue.push_back(TQueuedSend(packetGuid, msgHandle)); + { + TQueuedSend& msg = peer->SendQueue.back(); + TCmdDataInit dataInit; + dataInit.Command = CMD_DATA_INIT; + dataInit.PacketGuid = msg.PacketGuid; + dataInit.Size = data->GetSize(); + peer->PostSend(BP, &dataInit, sizeof(dataInit), TCompleteInfo::CI_IGNORE, 0); + //printf("Send CMD_DATA_INIT\n"); + } + } + ++peer->SendCount; + } + void SendCompleted(TPtrArg<TIBPeer> peer, TIBMsgHandle msgHandle) { + SendResults.push_back(TIBSendResult(msgHandle, true)); + if (--peer->SendCount < MAX_SEND_COUNT) { + if (!peer->PendingSendQueue.empty()) { + TPendingQueuedSend& qs = peer->PendingSendQueue.front(); + StartSend(peer, qs.PacketGuid, qs.MsgHandle, qs.Data); + //printf("Sending pending %d\n", qs.MsgHandle); + peer->PendingSendQueue.pop_front(); + } + } + } + void SendFailed(TPtrArg<TIBPeer> peer, TIBMsgHandle msgHandle) { + //printf("IB SendFailed()\n"); 
+ SendResults.push_back(TIBSendResult(msgHandle, false)); + --peer->SendCount; + } + void PeerFailed(TPtrArg<TIBPeer> peer) { + //printf("PeerFailed(), peer %p, state %d (%d pending, %d queued, %d out, %d sendcount)\n", + // peer.Get(), peer->State, + // (int)peer->PendingSendQueue.size(), + // (int)peer->SendQueue.size(), + // (int)peer->OutMsgs.size(), + // peer->SendCount); + peer->State = IIBPeer::FAILED; + while (!peer->PendingSendQueue.empty()) { + TPendingQueuedSend& qs = peer->PendingSendQueue.front(); + SendResults.push_back(TIBSendResult(qs.MsgHandle, false)); + peer->PendingSendQueue.pop_front(); + } + while (!peer->SendQueue.empty()) { + TQueuedSend& qs = peer->SendQueue.front(); + SendFailed(peer, qs.MsgHandle); + peer->SendQueue.pop_front(); + } + while (!peer->OutMsgs.empty()) { + TCompleteInfo& cc = peer->OutMsgs.front(); + //printf("Don't wait completion for sent packet (QPN %d), bufId %d\n", peer->QP->GetQPN(), cc.BufId); + if (cc.Type == TCompleteInfo::CI_DATA_TINY) { + SendFailed(peer, cc.MsgHandle); + } + BP.FreeBuf(cc.BufId); + peer->OutMsgs.pop_front(); + } + { + Y_ASSERT(peer->SendCount == 0); + //printf("Remove peer %p from hash (QPN %d)\n", peer.Get(), peer->QP->GetQPN()); + TPeerChannelHash::iterator z = Channels.find(peer->QP->GetQPN()); + if (z == Channels.end()) { + Y_VERIFY(0, "peer failed for unregistered peer"); + } + Channels.erase(z); + } + } + TIBMsgHandle Send(TPtrArg<IIBPeer> peerArg, TRopeDataPacket* data, const TGUID& packetGuid) override { + TIBPeer* peer = static_cast<TIBPeer*>(peerArg.Get()); // trust me, I'm professional + if (peer == nullptr || peer->State != IIBPeer::OK) { + return -1; + } + Y_ASSERT(Channels.find(peer->QP->GetQPN())->second == peer); + TIBMsgHandle msgHandle = ++MsgCounter; + if (peer->SendCount >= MAX_SEND_COUNT) { + peer->PendingSendQueue.push_back(TPendingQueuedSend(packetGuid, msgHandle, data)); + } else { + //printf("Sending direct %d\n", msgHandle); + StartSend(peer, packetGuid, msgHandle, data); 
+ } + return msgHandle; + } + void ParsePacket(ibv_wc* wc, NHPTimer::STime tCurrent) { + if (wc->status != IBV_WC_SUCCESS) { + TIBPeer* peer = GetChannelByQPN(wc->qp_num); + if (peer) { + //printf("failed recv packet (status %d)\n", wc->status); + PeerFailed(peer); + } else { + //printf("Ignoring recv error for closed/non existing QPN %d\n", wc->qp_num); + } + return; + } + + TIBRecvPacketProcess pkt(BP, *wc); + + TIBPeer* peer = GetChannelByQPN(wc->qp_num); + if (peer) { + Y_ASSERT(peer->State != IIBPeer::FAILED); + peer->LastRecv = tCurrent; + char cmdId = *(const char*)pkt.GetData(); + switch (cmdId) { + case CMD_CONFIRM: + //printf("got confirm\n"); + Y_ASSERT(peer->State == IIBPeer::CONNECTING); + peer->State = IIBPeer::OK; + break; + case CMD_DATA_TINY: + //printf("Recv CMD_DATA_TINY\n"); + { + const TCmdDataTiny& dataTiny = *(TCmdDataTiny*)pkt.GetData(); + TRequest* req = new TRequest; + req->Address = peer->PeerAddress; + req->Guid = dataTiny.Header.PacketGuid; + req->Data = new TRopeDataPacket; + req->Data->Write(dataTiny.Data, dataTiny.Header.Size); + ReceivedList.push_back(req); + } + break; + case CMD_DATA_INIT: + //printf("Recv CMD_DATA_INIT\n"); + { + const TCmdDataInit& data = *(TCmdDataInit*)pkt.GetData(); + TIntrusivePtr<TIBMemBlock> blk = MemPool->Alloc(data.Size); + peer->RecvQueue.push_back(TQueuedRecv(data.PacketGuid, blk)); + TCmdBufferReady ready; + ready.Command = CMD_BUFFER_READY; + ready.PacketGuid = data.PacketGuid; + ready.RemoteAddr = blk->GetData() - (char*)nullptr; + ready.RemoteKey = blk->GetMemRegion()->GetRKey(); + + peer->PostSend(BP, &ready, sizeof(ready), TCompleteInfo::CI_IGNORE, 0); + //printf("Send CMD_BUFFER_READY\n"); + } + break; + case CMD_BUFFER_READY: + //printf("Recv CMD_BUFFER_READY\n"); + { + const TCmdBufferReady& ready = *(TCmdBufferReady*)pkt.GetData(); + TDeque<TQueuedSend>::iterator z = peer->GetSend(ready.PacketGuid); + TQueuedSend& qs = *z; + qs.RemoteAddr = ready.RemoteAddr; + qs.RemoteKey = ready.RemoteKey; 
+ if (qs.MemBlock.Get()) { + peer->PostRDMA(qs); + } + } + break; + case CMD_DATA_COMPLETE: + //printf("Recv CMD_DATA_COMPLETE\n"); + { + const TCmdDataComplete& cmd = *(TCmdDataComplete*)pkt.GetData(); + TDeque<TQueuedRecv>::iterator z = peer->GetRecv(cmd.PacketGuid); + TQueuedRecv& qr = *z; +#ifdef _DEBUG + Y_VERIFY(MurmurHash<ui64>(qr.Data->GetData(), qr.Data->GetSize()) == cmd.DataHash || cmd.DataHash == 0, "RDMA data hash mismatch"); +#endif + TRequest* req = new TRequest; + req->Address = peer->PeerAddress; + req->Guid = qr.PacketGuid; + req->Data = new TRopeDataPacket; + req->Data->AddBlock(qr.Data.Get(), qr.Data->GetData(), qr.Data->GetSize()); + ReceivedList.push_back(req); + peer->RecvQueue.erase(z); + } + break; + case CMD_KEEP_ALIVE: + break; + default: + Y_ASSERT(0); + break; + } + } else { + // can get here + //printf("Ignoring packet for closed/non existing QPN %d\n", wc->qp_num); + } + } + void OnComplete(ibv_wc* wc, NHPTimer::STime tCurrent) { + TIBPeer* peer = GetChannelByQPN(wc->qp_num); + if (peer) { + if (!peer->OutMsgs.empty()) { + peer->LastRecv = tCurrent; + if (wc->status != IBV_WC_SUCCESS) { + //printf("completed with status %d\n", wc->status); + PeerFailed(peer); + } else { + const TCompleteInfo& cc = peer->OutMsgs.front(); + switch (cc.Type) { + case TCompleteInfo::CI_DATA_TINY: + //printf("Completed data_tiny\n"); + SendCompleted(peer, cc.MsgHandle); + break; + case TCompleteInfo::CI_RDMA_COMPLETE: + //printf("Completed rdma_complete\n"); + { + TDeque<TQueuedSend>::iterator z = peer->GetSend(cc.MsgHandle); + TQueuedSend& qs = *z; + + TCmdDataComplete complete; + complete.Command = CMD_DATA_COMPLETE; + complete.PacketGuid = qs.PacketGuid; +#ifdef _DEBUG + complete.DataHash = MurmurHash<ui64>(qs.MemBlock->GetData(), qs.MemBlock->GetSize()); +#else + complete.DataHash = 0; +#endif + + peer->PostSend(BP, &complete, sizeof(complete), TCompleteInfo::CI_DATA_SENT, qs.MsgHandle); + //printf("Send CMD_DATA_COMPLETE\n"); + } + break; + case 
TCompleteInfo::CI_DATA_SENT: + //printf("Completed data_sent\n"); + { + TDeque<TQueuedSend>::iterator z = peer->GetSend(cc.MsgHandle); + TIBMsgHandle msgHandle = z->MsgHandle; + peer->SendQueue.erase(z); + SendCompleted(peer, msgHandle); + } + break; + case TCompleteInfo::CI_KEEP_ALIVE: + break; + case TCompleteInfo::CI_IGNORE: + //printf("Completed ignored\n"); + break; + default: + Y_ASSERT(0); + break; + } + peer->OutMsgs.pop_front(); + BP.FreeBuf(wc->wr_id); + } + } else { + Y_VERIFY(0, "got completion without outstanding messages"); + } + } else { + //printf("Got completion for non existing qpn %d, bufId %d (status %d)\n", wc->qp_num, (int)wc->wr_id, (int)wc->status); + if (wc->status == IBV_WC_SUCCESS) { + Y_VERIFY(0, "only errors should go unmatched"); + } + // no need to free buf since it has to be freed in PeerFailed() + } + } + void ParseWelcomePacket(ibv_wc* wc) { + TIBRecvPacketProcess pkt(BP, *wc); + + char cmdId = *(const char*)pkt.GetUDData(); + switch (cmdId) { + case CMD_HANDSHAKE: { + //printf("got handshake\n"); + const TCmdHandshake& handshake = *(TCmdHandshake*)pkt.GetUDData(); + if (handshake.SocketId != ConnectInfo.SocketId) { + // connection attempt from wrong IB subnet + break; + } + TIntrusivePtr<TRCQueuePair> rcQP; + rcQP = new TRCQueuePair(Port->GetCtx(), CQ, BP.GetSRQ(), QP_SEND_QUEUE_SIZE); + + int qpn = rcQP->GetQPN(); + Y_ASSERT(Channels.find(qpn) == Channels.end()); + TIntrusivePtr<TIBPeer>& peer = Channels[qpn]; + peer = new TIBPeer(handshake.MyAddress, rcQP); + + ibv_ah_attr peerAddr; + TIntrusivePtr<TAddressHandle> ahPeer; + Port->GetAHAttr(wc, pkt.GetGRH(), &peerAddr); + ahPeer = new TAddressHandle(Port->GetCtx(), &peerAddr); + + peerAddr.sl = TRAFFIC_SL; + rcQP->Init(peerAddr, handshake.QPN, handshake.PSN); + + TCmdHandshakeAck handshakeAck; + handshakeAck.Command = CMD_HANDSHAKE_ACK; + handshakeAck.PSN = rcQP->GetPSN(); + handshakeAck.QPN = rcQP->GetQPN(); + handshakeAck.YourQPN = handshake.QPN; + // if ack gets lost we'll 
create new Peer Channel + // and this one will be erased in Step() by timeout counted from LastRecv + BP.PostSend(WelcomeQP, ahPeer, wc->src_qp, WELCOME_QKEY, &handshakeAck, sizeof(handshakeAck)); + //printf("send handshake_ack\n"); + } break; + case CMD_HANDSHAKE_ACK: { + //printf("got handshake_ack\n"); + const TCmdHandshakeAck& handshakeAck = *(TCmdHandshakeAck*)pkt.GetUDData(); + TIBPeer* peer = GetChannelByQPN(handshakeAck.YourQPN); + if (peer) { + ibv_ah_attr peerAddr; + Port->GetAHAttr(wc, pkt.GetGRH(), &peerAddr); + + peerAddr.sl = TRAFFIC_SL; + peer->QP->Init(peerAddr, handshakeAck.QPN, handshakeAck.PSN); + + peer->State = IIBPeer::OK; + + TCmdConfirm confirm; + confirm.Command = CMD_CONFIRM; + peer->PostSend(BP, &confirm, sizeof(confirm), TCompleteInfo::CI_IGNORE, 0); + //printf("send confirm\n"); + } else { + // respective QPN was deleted or never existed + // silently ignore and peer channel on remote side + // will not get into confirmed state and will be deleted + } + } break; + default: + Y_ASSERT(0); + break; + } + } + bool Step(NHPTimer::STime tCurrent) override { + bool rv = false; + // only have to process completions, everything is done on completion of something + ibv_wc wcArr[10]; + for (;;) { + int wcCount = CQ->Poll(wcArr, Y_ARRAY_SIZE(wcArr)); + if (wcCount == 0) { + break; + } + rv = true; + for (int z = 0; z < wcCount; ++z) { + ibv_wc& wc = wcArr[z]; + if (wc.opcode & IBV_WC_RECV) { + // received msg + if ((int)wc.qp_num == WelcomeQPN) { + if (wc.status != IBV_WC_SUCCESS) { + Y_VERIFY(0, "ud recv op completed with error %d\n", (int)wc.status); + } + Y_ASSERT(wc.opcode == IBV_WC_RECV | IBV_WC_SEND); + ParseWelcomePacket(&wc); + } else { + ParsePacket(&wc, tCurrent); + } + } else { + // send completion + if ((int)wc.qp_num == WelcomeQPN) { + // ok + BP.FreeBuf(wc.wr_id); + } else { + OnComplete(&wc, tCurrent); + } + } + } + } + { + TIntrusivePtr<TIBMemBlock> memBlock; + i64 msgHandle; + TIntrusivePtr<TIBPeer> peer; + while 
(CopyResults->GetCopyResult(&memBlock, &msgHandle, &peer)) { + if (peer->GetState() != IIBPeer::OK) { + continue; + } + TDeque<TQueuedSend>::iterator z = peer->GetSend(msgHandle); + if (z == peer->SendQueue.end()) { + Y_VERIFY(0, "peer %p, copy completed, msg %d not found?\n", peer.Get(), (int)msgHandle); + continue; + } + TQueuedSend& qs = *z; + qs.MemBlock = memBlock; + if (qs.RemoteAddr != 0) { + peer->PostRDMA(qs); + } + rv = true; + } + } + { + NHPTimer::STime t1 = LastCheckTime; + if (NHPTimer::GetTimePassed(&t1) > CHANNEL_CHECK_INTERVAL) { + for (TPeerChannelHash::iterator z = Channels.begin(); z != Channels.end();) { + TIntrusivePtr<TIBPeer> peer = z->second; + ++z; // peer can be removed from Channels + Y_ASSERT(peer->State != IIBPeer::FAILED); + NHPTimer::STime t2 = peer->LastRecv; + double timeSinceLastRecv = NHPTimer::GetTimePassed(&t2); + if (timeSinceLastRecv > CHANNEL_CHECK_INTERVAL) { + if (peer->State == IIBPeer::CONNECTING) { + Y_ASSERT(peer->OutMsgs.empty() && peer->SendCount == 0); + // if handshake does not seem to work out - close connection + //printf("IB connecting timed out\n"); + PeerFailed(peer); + } else { + // if we have outmsg we hope that IB will report us if there are any problems + // with connectivity + if (peer->OutMsgs.empty()) { + //printf("Sending keep alive\n"); + TCmdKeepAlive keep; + keep.Command = CMD_KEEP_ALIVE; + peer->PostSend(BP, &keep, sizeof(keep), TCompleteInfo::CI_KEEP_ALIVE, 0); + } + } + } + } + LastCheckTime = t1; + } + } + return rv; + } + IIBPeer* ConnectPeer(const TIBConnectInfo& info, const TUdpAddress& peerAddr, const TUdpAddress& myAddr) override { + for (auto& channel : Channels) { + TIntrusivePtr<TIBPeer> peer = channel.second; + if (peer->PeerAddress == peerAddr) { + return peer.Get(); + } + } + TIntrusivePtr<TRCQueuePair> rcQP; + rcQP = new TRCQueuePair(Port->GetCtx(), CQ, BP.GetSRQ(), QP_SEND_QUEUE_SIZE); + + int qpn = rcQP->GetQPN(); + Y_ASSERT(Channels.find(qpn) == Channels.end()); + 
TIntrusivePtr<TIBPeer>& peer = Channels[qpn]; + peer = new TIBPeer(peerAddr, rcQP); + + TCmdHandshake handshake; + handshake.Command = CMD_HANDSHAKE; + handshake.PSN = rcQP->GetPSN(); + handshake.QPN = rcQP->GetQPN(); + handshake.SocketId = info.SocketId; + handshake.MyAddress = myAddr; + + TIntrusivePtr<TAddressHandle> serverAH; + if (info.LID != 0) { + serverAH = new TAddressHandle(Port, info.LID, CONNECT_SL); + } else { + //ibv_gid addr; + //addr.global.subnet_prefix = info.Subnet; + //addr.global.interface_id = info.Interface; + //serverAH = new TAddressHandle(Port, addr, CONNECT_SL); + + TUdpAddress local = myAddr; + local.Port = 0; + TUdpAddress remote = peerAddr; + remote.Port = 0; + //printf("local Addr %s\n", GetAddressAsString(local).c_str()); + //printf("remote Addr %s\n", GetAddressAsString(remote).c_str()); + // CRAP - somehow prevent connecting machines from different RoCE isles + serverAH = new TAddressHandle(Port, remote, local, CONNECT_SL); + if (!serverAH->IsValid()) { + return nullptr; + } + } + BP.PostSend(WelcomeQP, serverAH, info.QPN, WELCOME_QKEY, &handshake, sizeof(handshake)); + //printf("send handshake\n"); + + return peer.Get(); + } + const TIBConnectInfo& GetConnectInfo() override { + return ConnectInfo; + } + + public: + TIBClientServer(TPtrArg<TIBPort> port) + : Port(port) + , MemPool(GetIBMemPool()) + , CQ(new TComplectionQueue(port->GetCtx(), MAX_CQ_EVENTS)) + , BP(port->GetCtx(), MAX_SRQ_WORK_REQUESTS) + , WelcomeQP(new TUDQueuePair(port, CQ, BP.GetSRQ(), WELCOME_QP_SEND_SIZE)) + , WelcomeQPN(WelcomeQP->GetQPN()) + , MsgCounter(1) + { + CopyResults = new TIBMemPool::TCopyResultStorage; + CreateGuid(&ConnectInfo.SocketId); + ibv_gid addr; + port->GetGID(&addr); + ConnectInfo.Interface = addr.global.interface_id; + ConnectInfo.Subnet = addr.global.subnet_prefix; + //printf("connect addr subnet %lx, iface %lx\n", addr.global.subnet_prefix, addr.global.interface_id); + ConnectInfo.LID = port->GetLID(); + ConnectInfo.QPN = WelcomeQPN; + 
+ WelcomeQP->Init(WELCOME_QKEY); + + NHPTimer::GetTime(&LastCheckTime); + } + }; + + IIBClientServer* CreateIBClientServer() { + TIntrusivePtr<TIBPort> port = GetIBDevice(); + if (port.Get() == nullptr) { + return nullptr; + } + return new TIBClientServer(port); + } +} diff --git a/library/cpp/netliba/v6/ib_cs.h b/library/cpp/netliba/v6/ib_cs.h new file mode 100644 index 0000000000..932f2880dc --- /dev/null +++ b/library/cpp/netliba/v6/ib_cs.h @@ -0,0 +1,57 @@ +#pragma once + +#include "udp_address.h" +#include "block_chain.h" +#include "net_request.h" +#include <util/generic/guid.h> +#include <util/system/hp_timer.h> + +namespace NNetliba { + struct TIBConnectInfo { + TGUID SocketId; + ui64 Subnet, Interface; + int LID; + int QPN; + }; + + struct TRCQueuePairHandshake { + int QPN, PSN; + }; + + using TIBMsgHandle = i64; + + struct TIBSendResult { + TIBMsgHandle Handle; + bool Success; + TIBSendResult() + : Handle(0) + , Success(false) + { + } + TIBSendResult(TIBMsgHandle handle, bool success) + : Handle(handle) + , Success(success) + { + } + }; + + struct IIBPeer: public TThrRefBase { + enum EState { + CONNECTING, + OK, + FAILED, + }; + virtual EState GetState() = 0; + }; + + struct IIBClientServer: public TThrRefBase { + virtual TRequest* GetRequest() = 0; + virtual TIBMsgHandle Send(TPtrArg<IIBPeer> peer, TRopeDataPacket* data, const TGUID& packetGuid) = 0; + virtual bool GetSendResult(TIBSendResult* res) = 0; + virtual bool Step(NHPTimer::STime tCurrent) = 0; + virtual IIBPeer* ConnectPeer(const TIBConnectInfo& info, const TUdpAddress& peerAddr, const TUdpAddress& myAddr) = 0; + virtual const TIBConnectInfo& GetConnectInfo() = 0; + }; + + IIBClientServer* CreateIBClientServer(); +} diff --git a/library/cpp/netliba/v6/ib_low.cpp b/library/cpp/netliba/v6/ib_low.cpp new file mode 100644 index 0000000000..455a5f3512 --- /dev/null +++ b/library/cpp/netliba/v6/ib_low.cpp @@ -0,0 +1,114 @@ +#include "stdafx.h" +#include "ib_low.h" + +namespace NNetliba { + static bool 
EnableROCEFlag = false; + + void EnableROCE(bool f) { + EnableROCEFlag = f; + } + +#if defined _linux_ && !defined CATBOOST_OPENSOURCE + static TMutex IBPortMutex; + static TIntrusivePtr<TIBPort> IBPort; + static bool IBWasInitialized; + + TIntrusivePtr<TIBPort> GetIBDevice() { + TGuard<TMutex> gg(IBPortMutex); + if (IBWasInitialized) { + return IBPort; + } + IBWasInitialized = true; + + try { + int rv = ibv_fork_init(); + if (rv != 0) { + //printf("ibv_fork_init() failed"); + return nullptr; + } + } catch (...) { + //we can not load ib interface, so no ib + return nullptr; + } + + TIntrusivePtr<TIBContext> ctx; + TIntrusivePtr<TIBPort> resPort; + int numDevices; + ibv_device** deviceList = ibv_get_device_list(&numDevices); + //for (int i = 0; i < numDevices; ++i) { + // ibv_device *dev = deviceList[i]; + + // printf("Dev %d\n", i); + // printf("name:%s\ndev_name:%s\ndev_path:%s\nibdev_path:%s\n", + // dev->name, + // dev->dev_name, + // dev->dev_path, + // dev->ibdev_path); + // printf("get_device_name(): %s\n", ibv_get_device_name(dev)); + // ui64 devGuid = ibv_get_device_guid(dev); + // printf("ibv_get_device_guid: %" PRIx64 "\n", devGuid); + // printf("node type: %s\n", ibv_node_type_str(dev->node_type)); + // printf("\n"); + //} + if (numDevices == 1) { + ctx = new TIBContext(deviceList[0]); + TIBContext::TLock ibContext(ctx); + ibv_device_attr devAttrs; + CHECK_Z(ibv_query_device(ibContext.GetContext(), &devAttrs)); + + for (int port = 1; port <= devAttrs.phys_port_cnt; ++port) { + ibv_port_attr portAttrs; + CHECK_Z(ibv_query_port(ibContext.GetContext(), port, &portAttrs)); + //ibv_gid myAddress; // ipv6 address of this port; + //CHECK_Z(ibv_query_gid(ibContext.GetContext(), port, 0, &myAddress)); + //{ + // ibv_gid p = myAddress; + // for (int k = 0; k < 4; ++k) { + // DoSwap(p.raw[k], p.raw[7 - k]); + // DoSwap(p.raw[8 + k], p.raw[15 - k]); + // } + // printf("Port %d, address %" PRIx64 ":%" PRIx64 "\n", + // port, + // p.global.subnet_prefix, + // 
p.global.interface_id); + //} + + // skip ROCE if flag is not set + if (portAttrs.lid == 0 && EnableROCEFlag == false) { + continue; + } + // bind to first active port + if (portAttrs.state == IBV_PORT_ACTIVE) { + resPort = new TIBPort(ctx, port); + break; + } + } + } else { + //printf("%d IB devices found, fail\n", numDevices); + ctx = nullptr; + } + ibv_free_device_list(deviceList); + IBPort = resPort; + return IBPort; + } + + void MakeAH(ibv_ah_attr* res, TPtrArg<TIBPort> port, const TUdpAddress& remoteAddr, const TUdpAddress& localAddr, int serviceLevel) { + ibv_gid localGid, remoteGid; + localGid.global.subnet_prefix = localAddr.Network; + localGid.global.interface_id = localAddr.Interface; + remoteGid.global.subnet_prefix = remoteAddr.Network; + remoteGid.global.interface_id = remoteAddr.Interface; + + Zero(*res); + res->is_global = 1; + res->port_num = port->GetPort(); + res->sl = serviceLevel; + res->grh.dgid = remoteGid; + //res->grh.flow_label = 0; + res->grh.sgid_index = port->GetGIDIndex(localGid); + res->grh.hop_limit = 7; + //res->grh.traffic_class = 0; + } + +#endif +} diff --git a/library/cpp/netliba/v6/ib_low.h b/library/cpp/netliba/v6/ib_low.h new file mode 100644 index 0000000000..b2a3e341d2 --- /dev/null +++ b/library/cpp/netliba/v6/ib_low.h @@ -0,0 +1,797 @@ +#pragma once + +#include "udp_address.h" + +#if defined(_linux_) && !defined(CATBOOST_OPENSOURCE) +#include <contrib/libs/ibdrv/include/infiniband/verbs.h> +#include <contrib/libs/ibdrv/include/rdma/rdma_cma.h> +#endif + +namespace NNetliba { +#define CHECK_Z(x) \ + { \ + int rv = (x); \ + if (rv != 0) { \ + fprintf(stderr, "check_z failed, errno = %d\n", errno); \ + Y_VERIFY(0, "check_z"); \ + } \ + } + + ////////////////////////////////////////////////////////////////////////// + const int MAX_SGE = 1; + const size_t MAX_INLINE_DATA_SIZE = 16; + const int MAX_OUTSTANDING_RDMA = 10; + +#if defined(_linux_) && !defined(CATBOOST_OPENSOURCE) + class TIBContext: public TThrRefBase, 
TNonCopyable { + ibv_context* Context; + ibv_pd* ProtDomain; + TMutex Lock; + + ~TIBContext() override { + if (Context) { + CHECK_Z(ibv_dealloc_pd(ProtDomain)); + CHECK_Z(ibv_close_device(Context)); + } + } + + public: + TIBContext(ibv_device* device) { + Context = ibv_open_device(device); + if (Context) { + ProtDomain = ibv_alloc_pd(Context); + } + } + bool IsValid() const { + return Context != nullptr && ProtDomain != nullptr; + } + + class TLock { + TIntrusivePtr<TIBContext> Ptr; + TGuard<TMutex> Guard; + + public: + TLock(TPtrArg<TIBContext> ctx) + : Ptr(ctx) + , Guard(ctx->Lock) + { + } + ibv_context* GetContext() { + return Ptr->Context; + } + ibv_pd* GetProtDomain() { + return Ptr->ProtDomain; + } + }; + }; + + class TIBPort: public TThrRefBase, TNonCopyable { + int Port; + int LID; + TIntrusivePtr<TIBContext> IBCtx; + enum { + MAX_GID = 16 + }; + ibv_gid MyGidArr[MAX_GID]; + + public: + TIBPort(TPtrArg<TIBContext> ctx, int port) + : IBCtx(ctx) + { + ibv_port_attr portAttrs; + TIBContext::TLock ibContext(IBCtx); + CHECK_Z(ibv_query_port(ibContext.GetContext(), port, &portAttrs)); + Port = port; + LID = portAttrs.lid; + for (int i = 0; i < MAX_GID; ++i) { + ibv_gid& dst = MyGidArr[i]; + Zero(dst); + ibv_query_gid(ibContext.GetContext(), Port, i, &dst); + } + } + int GetPort() const { + return Port; + } + int GetLID() const { + return LID; + } + TIBContext* GetCtx() { + return IBCtx.Get(); + } + void GetGID(ibv_gid* res) const { + *res = MyGidArr[0]; + } + int GetGIDIndex(const ibv_gid& arg) const { + for (int i = 0; i < MAX_GID; ++i) { + const ibv_gid& chk = MyGidArr[i]; + if (memcmp(&chk, &arg, sizeof(chk)) == 0) { + return i; + } + } + return 0; + } + void GetAHAttr(ibv_wc* wc, ibv_grh* grh, ibv_ah_attr* res) { + TIBContext::TLock ibContext(IBCtx); + CHECK_Z(ibv_init_ah_from_wc(ibContext.GetContext(), Port, wc, grh, res)); + } + }; + + class TComplectionQueue: public TThrRefBase, TNonCopyable { + ibv_cq* CQ; + TIntrusivePtr<TIBContext> IBCtx; + + 
~TComplectionQueue() override { + if (CQ) { + CHECK_Z(ibv_destroy_cq(CQ)); + } + } + + public: + TComplectionQueue(TPtrArg<TIBContext> ctx, int maxCQEcount) + : IBCtx(ctx) + { + TIBContext::TLock ibContext(IBCtx); + /* ibv_cq_init_attr_ex attr; + Zero(attr); + attr.cqe = maxCQEcount; + attr.cq_create_flags = 0; + ibv_cq_ex *vcq = ibv_create_cq_ex(ibContext.GetContext(), &attr); + if (vcq) { + CQ = (ibv_cq*)vcq; // doubtful trick but that's life + } else {*/ + // no completion channel + // no completion vector + CQ = ibv_create_cq(ibContext.GetContext(), maxCQEcount, nullptr, nullptr, 0); + // } + } + ibv_cq* GetCQ() { + return CQ; + } + int Poll(ibv_wc* res, int bufSize) { + Y_ASSERT(bufSize >= 1); + //struct ibv_wc + //{ + // ui64 wr_id; + // enum ibv_wc_status status; + // enum ibv_wc_opcode opcode; + // ui32 vendor_err; + // ui32 byte_len; + // ui32 imm_data;/* network byte order */ + // ui32 qp_num; + // ui32 src_qp; + // enum ibv_wc_flags wc_flags; + // ui16 pkey_index; + // ui16 slid; + // ui8 sl; + // ui8 dlid_path_bits; + //}; + int rv = ibv_poll_cq(CQ, bufSize, res); + if (rv < 0) { + Y_VERIFY(0, "ibv_poll_cq failed"); + } + if (rv > 0) { + //printf("Completed wr\n"); + //printf(" wr_id = %" PRIx64 "\n", wc.wr_id); + //printf(" status = %d\n", wc.status); + //printf(" opcode = %d\n", wc.opcode); + //printf(" byte_len = %d\n", wc.byte_len); + //printf(" imm_data = %d\n", wc.imm_data); + //printf(" qp_num = %d\n", wc.qp_num); + //printf(" src_qp = %d\n", wc.src_qp); + //printf(" wc_flags = %x\n", wc.wc_flags); + //printf(" slid = %d\n", wc.slid); + } + //rv = number_of_toggled_wc; + return rv; + } + }; + + //struct ibv_mr + //{ + // struct ibv_context *context; + // struct ibv_pd *pd; + // void *addr; + // size_t length; + // ui32 handle; + // ui32 lkey; + // ui32 rkey; + //}; + class TMemoryRegion: public TThrRefBase, TNonCopyable { + ibv_mr* MR; + TIntrusivePtr<TIBContext> IBCtx; + + ~TMemoryRegion() override { + if (MR) { + CHECK_Z(ibv_dereg_mr(MR)); + } 
+ } + + public: + TMemoryRegion(TPtrArg<TIBContext> ctx, size_t len) + : IBCtx(ctx) + { + TIBContext::TLock ibContext(IBCtx); + int access = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ; // TODO: IBV_ACCESS_ALLOCATE_MR + MR = ibv_reg_mr(ibContext.GetProtDomain(), 0, len, access); + Y_ASSERT(MR); + } + ui32 GetLKey() const { + static_assert(sizeof(ui32) == sizeof(MR->lkey), "expect sizeof(ui32) == sizeof(MR->lkey)"); + return MR->lkey; + } + ui32 GetRKey() const { + static_assert(sizeof(ui32) == sizeof(MR->lkey), "expect sizeof(ui32) == sizeof(MR->lkey)"); + return MR->lkey; + } + char* GetData() { + return MR ? (char*)MR->addr : nullptr; + } + bool IsCovered(const void* data, size_t len) const { + size_t dataAddr = (const char*)data - (const char*)nullptr; + size_t bufAddr = (const char*)MR->addr - (const char*)nullptr; + return (dataAddr >= bufAddr) && (dataAddr + len <= bufAddr + MR->length); + } + }; + + class TSharedReceiveQueue: public TThrRefBase, TNonCopyable { + ibv_srq* SRQ; + TIntrusivePtr<TIBContext> IBCtx; + + ~TSharedReceiveQueue() override { + if (SRQ) { + ibv_destroy_srq(SRQ); + } + } + + public: + TSharedReceiveQueue(TPtrArg<TIBContext> ctx, int maxWR) + : IBCtx(ctx) + { + ibv_srq_init_attr attr; + Zero(attr); + attr.srq_context = this; + attr.attr.max_sge = MAX_SGE; + attr.attr.max_wr = maxWR; + + TIBContext::TLock ibContext(IBCtx); + SRQ = ibv_create_srq(ibContext.GetProtDomain(), &attr); + Y_ASSERT(SRQ); + } + ibv_srq* GetSRQ() { + return SRQ; + } + void PostReceive(TPtrArg<TMemoryRegion> mem, ui64 id, const void* buf, size_t len) { + Y_ASSERT(mem->IsCovered(buf, len)); + ibv_recv_wr wr, *bad; + ibv_sge sg; + sg.addr = (const char*)buf - (const char*)nullptr; + sg.length = len; + sg.lkey = mem->GetLKey(); + Zero(wr); + wr.wr_id = id; + wr.sg_list = &sg; + wr.num_sge = 1; + CHECK_Z(ibv_post_srq_recv(SRQ, &wr, &bad)); + } + }; + + inline void MakeAH(ibv_ah_attr* res, TPtrArg<TIBPort> port, int lid, int serviceLevel) { + 
Zero(*res); + res->dlid = lid; + res->port_num = port->GetPort(); + res->sl = serviceLevel; + } + + void MakeAH(ibv_ah_attr* res, TPtrArg<TIBPort> port, const TUdpAddress& remoteAddr, const TUdpAddress& localAddr, int serviceLevel); + + class TAddressHandle: public TThrRefBase, TNonCopyable { + ibv_ah* AH; + TIntrusivePtr<TIBContext> IBCtx; + + ~TAddressHandle() override { + if (AH) { + CHECK_Z(ibv_destroy_ah(AH)); + } + AH = nullptr; + IBCtx = nullptr; + } + + public: + TAddressHandle(TPtrArg<TIBContext> ctx, ibv_ah_attr* attr) + : IBCtx(ctx) + { + TIBContext::TLock ibContext(IBCtx); + AH = ibv_create_ah(ibContext.GetProtDomain(), attr); + Y_ASSERT(AH != nullptr); + } + TAddressHandle(TPtrArg<TIBPort> port, int lid, int serviceLevel) + : IBCtx(port->GetCtx()) + { + ibv_ah_attr attr; + MakeAH(&attr, port, lid, serviceLevel); + TIBContext::TLock ibContext(IBCtx); + AH = ibv_create_ah(ibContext.GetProtDomain(), &attr); + Y_ASSERT(AH != nullptr); + } + TAddressHandle(TPtrArg<TIBPort> port, const TUdpAddress& remoteAddr, const TUdpAddress& localAddr, int serviceLevel) + : IBCtx(port->GetCtx()) + { + ibv_ah_attr attr; + MakeAH(&attr, port, remoteAddr, localAddr, serviceLevel); + TIBContext::TLock ibContext(IBCtx); + AH = ibv_create_ah(ibContext.GetProtDomain(), &attr); + Y_ASSERT(AH != nullptr); + } + ibv_ah* GetAH() { + return AH; + } + bool IsValid() const { + return AH != nullptr; + } + }; + + // GRH + wc -> address_handle_attr + //int ibv_init_ah_from_wc(struct ibv_context *context, ui8 port_num, + //struct ibv_wc *wc, struct ibv_grh *grh, + //struct ibv_ah_attr *ah_attr) + //ibv_create_ah_from_wc(struct ibv_pd *pd, struct ibv_wc *wc, struct ibv_grh + // *grh, ui8 port_num) + + class TQueuePair: public TThrRefBase, TNonCopyable { + protected: + ibv_qp* QP; + int MyPSN; // start packet sequence number + TIntrusivePtr<TIBContext> IBCtx; + TIntrusivePtr<TComplectionQueue> CQ; + TIntrusivePtr<TSharedReceiveQueue> SRQ; + + TQueuePair(TPtrArg<TIBContext> ctx, 
TPtrArg<TComplectionQueue> cq, TPtrArg<TSharedReceiveQueue> srq, + int sendQueueSize, + ibv_qp_type qp_type) + : IBCtx(ctx) + , CQ(cq) + , SRQ(srq) + { + MyPSN = GetCycleCount() & 0xffffff; // should be random and different on different runs, 24bit + + ibv_qp_init_attr attr; + Zero(attr); + attr.qp_context = this; // not really useful + attr.send_cq = cq->GetCQ(); + attr.recv_cq = cq->GetCQ(); + attr.srq = srq->GetSRQ(); + attr.cap.max_send_wr = sendQueueSize; + attr.cap.max_recv_wr = 0; // we are using srq, no need for qp's rq + attr.cap.max_send_sge = MAX_SGE; + attr.cap.max_recv_sge = MAX_SGE; + attr.cap.max_inline_data = MAX_INLINE_DATA_SIZE; + attr.qp_type = qp_type; + attr.sq_sig_all = 1; // inline sends need not be signaled, but if they are not work queue overflows + + TIBContext::TLock ibContext(IBCtx); + QP = ibv_create_qp(ibContext.GetProtDomain(), &attr); + Y_ASSERT(QP); + + //struct ibv_qp { + // struct ibv_context *context; + // void *qp_context; + // struct ibv_pd *pd; + // struct ibv_cq *send_cq; + // struct ibv_cq *recv_cq; + // struct ibv_srq *srq; + // ui32 handle; + // ui32 qp_num; + // enum ibv_qp_state state; + // enum ibv_qp_type qp_type; + + // pthread_mutex_t mutex; + // pthread_cond_t cond; + // ui32 events_completed; + //}; + //qp_context The value qp_context that was provided to ibv_create_qp() + //qp_num The number of this Queue Pair + //state The last known state of this Queue Pair. 
The actual state may be different from this state (in the RDMA device transitioned the state into other state) + //qp_type The Transport Service Type of this Queue Pair + } + ~TQueuePair() override { + if (QP) { + CHECK_Z(ibv_destroy_qp(QP)); + } + } + void FillSendAttrs(ibv_send_wr* wr, ibv_sge* sg, + ui64 localAddr, ui32 lKey, ui64 id, size_t len) { + sg->addr = localAddr; + sg->length = len; + sg->lkey = lKey; + Zero(*wr); + wr->wr_id = id; + wr->sg_list = sg; + wr->num_sge = 1; + if (len <= MAX_INLINE_DATA_SIZE) { + wr->send_flags = IBV_SEND_INLINE; + } + } + void FillSendAttrs(ibv_send_wr* wr, ibv_sge* sg, + TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) { + ui64 localAddr = (const char*)data - (const char*)nullptr; + ui32 lKey = 0; + if (mem) { + Y_ASSERT(mem->IsCovered(data, len)); + lKey = mem->GetLKey(); + } else { + Y_ASSERT(len <= MAX_INLINE_DATA_SIZE); + } + FillSendAttrs(wr, sg, localAddr, lKey, id, len); + } + + public: + int GetQPN() const { + if (QP) + return QP->qp_num; + return 0; + } + int GetPSN() const { + return MyPSN; + } + // we are using srq + //void PostReceive(const TMemoryRegion &mem) + //{ + // ibv_recv_wr wr, *bad; + // ibv_sge sg; + // sg.addr = mem.Addr; + // sg.length = mem.Length; + // sg.lkey = mem.lkey; + // Zero(wr); + // wr.wr_id = 13; + // wr.sg_list = sg; + // wr.num_sge = 1; + // CHECK_Z(ibv_post_recv(QP, &wr, &bad)); + //} + }; + + class TRCQueuePair: public TQueuePair { + public: + TRCQueuePair(TPtrArg<TIBContext> ctx, TPtrArg<TComplectionQueue> cq, TPtrArg<TSharedReceiveQueue> srq, int sendQueueSize) + : TQueuePair(ctx, cq, srq, sendQueueSize, IBV_QPT_RC) + { + } + // SRQ should have receive posted + void Init(const ibv_ah_attr& peerAddr, int peerQPN, int peerPSN) { + Y_ASSERT(QP->qp_type == IBV_QPT_RC); + ibv_qp_attr attr; + //{ + // enum ibv_qp_state qp_state; + // enum ibv_qp_state cur_qp_state; + // enum ibv_mtu path_mtu; + // enum ibv_mig_state path_mig_state; + // ui32 qkey; + // ui32 rq_psn; + 
// ui32 sq_psn; + // ui32 dest_qp_num; + // int qp_access_flags; + // struct ibv_qp_cap cap; + // struct ibv_ah_attr ah_attr; + // struct ibv_ah_attr alt_ah_attr; + // ui16 pkey_index; + // ui16 alt_pkey_index; + // ui8 en_sqd_async_notify; + // ui8 sq_draining; + // ui8 max_rd_atomic; + // ui8 max_dest_rd_atomic; + // ui8 min_rnr_timer; + // ui8 port_num; + // ui8 timeout; + // ui8 retry_cnt; + // ui8 rnr_retry; + // ui8 alt_port_num; + // ui8 alt_timeout; + //}; + // RESET -> INIT + Zero(attr); + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = peerAddr.port_num; + // for connected QP + attr.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_REMOTE_ATOMIC; + CHECK_Z(ibv_modify_qp(QP, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + + // INIT -> ReadyToReceive + //PostReceive(mem); + attr.qp_state = IBV_QPS_RTR; + attr.path_mtu = IBV_MTU_512; // allows more fine grained VL arbitration + // for connected QP + attr.ah_attr = peerAddr; + attr.dest_qp_num = peerQPN; + attr.rq_psn = peerPSN; + attr.max_dest_rd_atomic = MAX_OUTSTANDING_RDMA; // number of outstanding RDMA requests + attr.min_rnr_timer = 12; // recommended + CHECK_Z(ibv_modify_qp(QP, &attr, + IBV_QP_STATE | IBV_QP_PATH_MTU | + IBV_QP_AV | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)); + + // ReadyToReceive -> ReadyToTransmit + attr.qp_state = IBV_QPS_RTS; + // for connected QP + attr.timeout = 14; // increased to 18 for sometime, 14 recommended + //attr.retry_cnt = 0; // for debug purposes + //attr.rnr_retry = 0; // for debug purposes + attr.retry_cnt = 7; // release configuration + attr.rnr_retry = 7; // release configuration (try forever) + attr.sq_psn = MyPSN; + attr.max_rd_atomic = MAX_OUTSTANDING_RDMA; // number of outstanding RDMA requests + CHECK_Z(ibv_modify_qp(QP, &attr, + IBV_QP_STATE | + IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | 
IBV_QP_MAX_QP_RD_ATOMIC)); + } + void PostSend(TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) { + ibv_send_wr wr, *bad; + ibv_sge sg; + FillSendAttrs(&wr, &sg, mem, id, data, len); + wr.opcode = IBV_WR_SEND; + //IBV_WR_RDMA_WRITE + //IBV_WR_RDMA_WRITE_WITH_IMM + //IBV_WR_SEND + //IBV_WR_SEND_WITH_IMM + //IBV_WR_RDMA_READ + //wr.imm_data = xz; + + CHECK_Z(ibv_post_send(QP, &wr, &bad)); + } + void PostRDMAWrite(ui64 remoteAddr, ui32 remoteKey, + TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) { + ibv_send_wr wr, *bad; + ibv_sge sg; + FillSendAttrs(&wr, &sg, mem, id, data, len); + wr.opcode = IBV_WR_RDMA_WRITE; + wr.wr.rdma.remote_addr = remoteAddr; + wr.wr.rdma.rkey = remoteKey; + + CHECK_Z(ibv_post_send(QP, &wr, &bad)); + } + void PostRDMAWrite(ui64 remoteAddr, ui32 remoteKey, + ui64 localAddr, ui32 localKey, ui64 id, size_t len) { + ibv_send_wr wr, *bad; + ibv_sge sg; + FillSendAttrs(&wr, &sg, localAddr, localKey, id, len); + wr.opcode = IBV_WR_RDMA_WRITE; + wr.wr.rdma.remote_addr = remoteAddr; + wr.wr.rdma.rkey = remoteKey; + + CHECK_Z(ibv_post_send(QP, &wr, &bad)); + } + void PostRDMAWriteImm(ui64 remoteAddr, ui32 remoteKey, ui32 immData, + TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) { + ibv_send_wr wr, *bad; + ibv_sge sg; + FillSendAttrs(&wr, &sg, mem, id, data, len); + wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; + wr.imm_data = immData; + wr.wr.rdma.remote_addr = remoteAddr; + wr.wr.rdma.rkey = remoteKey; + + CHECK_Z(ibv_post_send(QP, &wr, &bad)); + } + }; + + class TUDQueuePair: public TQueuePair { + TIntrusivePtr<TIBPort> Port; + + public: + TUDQueuePair(TPtrArg<TIBPort> port, TPtrArg<TComplectionQueue> cq, TPtrArg<TSharedReceiveQueue> srq, int sendQueueSize) + : TQueuePair(port->GetCtx(), cq, srq, sendQueueSize, IBV_QPT_UD) + , Port(port) + { + } + // SRQ should have receive posted + void Init(int qkey) { + Y_ASSERT(QP->qp_type == IBV_QPT_UD); + ibv_qp_attr attr; + // RESET -> INIT + Zero(attr); + 
attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = 0; + attr.port_num = Port->GetPort(); + // for unconnected qp + attr.qkey = qkey; + CHECK_Z(ibv_modify_qp(QP, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)); + + // INIT -> ReadyToReceive + //PostReceive(mem); + attr.qp_state = IBV_QPS_RTR; + CHECK_Z(ibv_modify_qp(QP, &attr, IBV_QP_STATE)); + + // ReadyToReceive -> ReadyToTransmit + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = 0; + CHECK_Z(ibv_modify_qp(QP, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)); + } + void PostSend(TPtrArg<TAddressHandle> ah, int remoteQPN, int remoteQKey, + TPtrArg<TMemoryRegion> mem, ui64 id, const void* data, size_t len) { + ibv_send_wr wr, *bad; + ibv_sge sg; + FillSendAttrs(&wr, &sg, mem, id, data, len); + wr.opcode = IBV_WR_SEND; + wr.wr.ud.ah = ah->GetAH(); + wr.wr.ud.remote_qpn = remoteQPN; + wr.wr.ud.remote_qkey = remoteQKey; + //IBV_WR_SEND_WITH_IMM + //wr.imm_data = xz; + + CHECK_Z(ibv_post_send(QP, &wr, &bad)); + } + }; + + TIntrusivePtr<TIBPort> GetIBDevice(); + +#else + ////////////////////////////////////////////////////////////////////////// + // stub for OS without IB support + ////////////////////////////////////////////////////////////////////////// + enum ibv_wc_opcode { + IBV_WC_SEND, + IBV_WC_RDMA_WRITE, + IBV_WC_RDMA_READ, + IBV_WC_COMP_SWAP, + IBV_WC_FETCH_ADD, + IBV_WC_BIND_MW, + IBV_WC_RECV = 1 << 7, + IBV_WC_RECV_RDMA_WITH_IMM + }; + + enum ibv_wc_status { + IBV_WC_SUCCESS, + // lots of errors follow + }; + //struct ibv_device; + //struct ibv_pd; + union ibv_gid { + ui8 raw[16]; + struct { + ui64 subnet_prefix; + ui64 interface_id; + } global; + }; + + struct ibv_wc { + ui64 wr_id; + enum ibv_wc_status status; + enum ibv_wc_opcode opcode; + ui32 imm_data; /* in network byte order */ + ui32 qp_num; + ui32 src_qp; + }; + struct ibv_grh {}; + struct ibv_ah_attr { + ui8 sl; + }; + //struct ibv_cq; + class TIBContext: public TThrRefBase, TNonCopyable { + public: + bool IsValid() const { + return false; 
+ } + //ibv_context *GetContext() { return 0; } + //ibv_pd *GetProtDomain() { return 0; } + }; + + class TIBPort: public TThrRefBase, TNonCopyable { + public: + TIBPort(TPtrArg<TIBContext>, int) { + } + int GetPort() const { + return 1; + } + int GetLID() const { + return 1; + } + TIBContext* GetCtx() { + return 0; + } + void GetGID(ibv_gid* res) { + Zero(*res); + } + void GetAHAttr(ibv_wc*, ibv_grh*, ibv_ah_attr*) { + } + }; + + class TComplectionQueue: public TThrRefBase, TNonCopyable { + public: + TComplectionQueue(TPtrArg<TIBContext>, int) { + } + //ibv_cq *GetCQ() { return 0; } + int Poll(ibv_wc*, int) { + return 0; + } + }; + + class TMemoryRegion: public TThrRefBase, TNonCopyable { + public: + TMemoryRegion(TPtrArg<TIBContext>, size_t) { + } + ui32 GetLKey() const { + return 0; + } + ui32 GetRKey() const { + return 0; + } + char* GetData() { + return 0; + } + bool IsCovered(const void*, size_t) const { + return false; + } + }; + + class TSharedReceiveQueue: public TThrRefBase, TNonCopyable { + public: + TSharedReceiveQueue(TPtrArg<TIBContext>, int) { + } + //ibv_srq *GetSRQ() { return SRQ; } + void PostReceive(TPtrArg<TMemoryRegion>, ui64, const void*, size_t) { + } + }; + + inline void MakeAH(ibv_ah_attr*, TPtrArg<TIBPort>, int, int) { + } + + class TAddressHandle: public TThrRefBase, TNonCopyable { + public: + TAddressHandle(TPtrArg<TIBContext>, ibv_ah_attr*) { + } + TAddressHandle(TPtrArg<TIBPort>, int, int) { + } + TAddressHandle(TPtrArg<TIBPort>, const TUdpAddress&, const TUdpAddress&, int) { + } + //ibv_ah *GetAH() { return AH; } + bool IsValid() { + return true; + } + }; + + class TQueuePair: public TThrRefBase, TNonCopyable { + public: + int GetQPN() const { + return 0; + } + int GetPSN() const { + return 0; + } + }; + + class TRCQueuePair: public TQueuePair { + public: + TRCQueuePair(TPtrArg<TIBContext>, TPtrArg<TComplectionQueue>, TPtrArg<TSharedReceiveQueue>, int) { + } + // SRQ should have receive posted + void Init(const ibv_ah_attr&, int, int) 
{ + } + void PostSend(TPtrArg<TMemoryRegion>, ui64, const void*, size_t) { + } + void PostRDMAWrite(ui64, ui32, TPtrArg<TMemoryRegion>, ui64, const void*, size_t) { + } + void PostRDMAWrite(ui64, ui32, ui64, ui32, ui64, size_t) { + } + void PostRDMAWriteImm(ui64, ui32, ui32, TPtrArg<TMemoryRegion>, ui64, const void*, size_t) { + } + }; + + class TUDQueuePair: public TQueuePair { + TIntrusivePtr<TIBPort> Port; + + public: + TUDQueuePair(TPtrArg<TIBPort>, TPtrArg<TComplectionQueue>, TPtrArg<TSharedReceiveQueue>, int) { + } + // SRQ should have receive posted + void Init(int) { + } + void PostSend(TPtrArg<TAddressHandle>, int, int, TPtrArg<TMemoryRegion>, ui64, const void*, size_t) { + } + }; + + inline TIntrusivePtr<TIBPort> GetIBDevice() { + return 0; + } +#endif +} diff --git a/library/cpp/netliba/v6/ib_mem.cpp b/library/cpp/netliba/v6/ib_mem.cpp new file mode 100644 index 0000000000..1e7f55ac57 --- /dev/null +++ b/library/cpp/netliba/v6/ib_mem.cpp @@ -0,0 +1,167 @@ +#include "stdafx.h" +#include "ib_mem.h" +#include "ib_low.h" +#include "cpu_affinity.h" + +#if defined(_unix_) +#include <pthread.h> +#endif + +namespace NNetliba { + TIBMemSuperBlock::TIBMemSuperBlock(TIBMemPool* pool, size_t szLog) + : Pool(pool) + , SzLog(szLog) + , UseCount(0) + { + size_t sz = GetSize(); + MemRegion = new TMemoryRegion(pool->GetIBContext(), sz); + //printf("Alloc super block, size %" PRId64 "\n", sz); + } + + TIBMemSuperBlock::~TIBMemSuperBlock() { + Y_ASSERT(AtomicGet(UseCount) == 0); + } + + char* TIBMemSuperBlock::GetData() { + return MemRegion->GetData(); + } + + void TIBMemSuperBlock::DecRef() { + if (AtomicAdd(UseCount, -1) == 0) { + Pool->Return(this); + } + } + + TIBMemBlock::~TIBMemBlock() { + if (Super.Get()) { + Super->DecRef(); + } else { + delete[] Data; + } + } + + ////////////////////////////////////////////////////////////////////////// + TIBMemPool::TIBMemPool(TPtrArg<TIBContext> ctx) + : IBCtx(ctx) + , AllocCacheSize(0) + , CurrentOffset(IB_MEM_LARGE_BLOCK) + , 
WorkThread(TThread::TParams(ThreadFunc, (void*)this).SetName("nl6_ib_mem_pool")) + , KeepRunning(true) + { + WorkThread.Start(); + HasStarted.Wait(); + } + + TIBMemPool::~TIBMemPool() { + Y_ASSERT(WorkThread.Running()); + KeepRunning = false; + HasWork.Signal(); + WorkThread.Join(); + { + TJobItem* work = nullptr; + while (Requests.Dequeue(&work)) { + delete work; + } + } + } + + TIntrusivePtr<TIBMemSuperBlock> TIBMemPool::AllocSuper(size_t szArg) { + // assume CacheLock is taken + size_t szLog = 12; + while ((((size_t)1) << szLog) < szArg) { + ++szLog; + } + TIntrusivePtr<TIBMemSuperBlock> super; + { + TVector<TIntrusivePtr<TIBMemSuperBlock>>& cc = AllocCache[szLog]; + if (!cc.empty()) { + super = cc.back(); + cc.resize(cc.size() - 1); + AllocCacheSize -= 1ll << super->SzLog; + } + } + if (super.Get() == nullptr) { + super = new TIBMemSuperBlock(this, szLog); + } + return super; + } + + TIBMemBlock* TIBMemPool::Alloc(size_t sz) { + TGuard<TMutex> gg(CacheLock); + if (sz > IB_MEM_LARGE_BLOCK) { + TIntrusivePtr<TIBMemSuperBlock> super = AllocSuper(sz); + return new TIBMemBlock(super, super->GetData(), sz); + } else { + if (CurrentOffset + sz > IB_MEM_LARGE_BLOCK) { + CurrentBlk.Assign(AllocSuper(IB_MEM_LARGE_BLOCK)); + CurrentOffset = 0; + } + CurrentOffset += sz; + return new TIBMemBlock(CurrentBlk.Get(), CurrentBlk.Get()->GetData() + CurrentOffset - sz, sz); + } + } + + void TIBMemPool::Return(TPtrArg<TIBMemSuperBlock> blk) { + TGuard<TMutex> gg(CacheLock); + Y_ASSERT(AtomicGet(blk->UseCount) == 0); + size_t sz = 1ull << blk->SzLog; + if (sz + AllocCacheSize > IB_MEM_POOL_SIZE) { + AllocCache.clear(); + AllocCacheSize = 0; + } + { + TVector<TIntrusivePtr<TIBMemSuperBlock>>& cc = AllocCache[blk->SzLog]; + cc.push_back(blk.Get()); + AllocCacheSize += sz; + } + } + + void* TIBMemPool::ThreadFunc(void* param) { + BindToSocket(0); + SetHighestThreadPriority(); + TIBMemPool* pThis = (TIBMemPool*)param; + pThis->HasStarted.Signal(); + + while (pThis->KeepRunning) { + 
TJobItem* work = nullptr; + if (!pThis->Requests.Dequeue(&work)) { + pThis->HasWork.Reset(); + if (!pThis->Requests.Dequeue(&work)) { + pThis->HasWork.Wait(); + } + } + if (work) { + //printf("mem copy got work\n"); + int sz = work->Data->GetSize(); + work->Block = pThis->Alloc(sz); + TBlockChainIterator bc(work->Data->GetChain()); + bc.Read(work->Block->GetData(), sz); + TIntrusivePtr<TCopyResultStorage> dst = work->ResultStorage; + work->ResultStorage = nullptr; + dst->Results.Enqueue(work); + //printf("mem copy completed\n"); + } + } + return nullptr; + } + + ////////////////////////////////////////////////////////////////////////// + static TMutex IBMemMutex; + static TIntrusivePtr<TIBMemPool> IBMemPool; + static bool IBWasInitialized; + + TIntrusivePtr<TIBMemPool> GetIBMemPool() { + TGuard<TMutex> gg(IBMemMutex); + if (IBWasInitialized) { + return IBMemPool; + } + IBWasInitialized = true; + + TIntrusivePtr<TIBPort> ibPort = GetIBDevice(); + if (ibPort.Get() == nullptr) { + return nullptr; + } + IBMemPool = new TIBMemPool(ibPort->GetCtx()); + return IBMemPool; + } +} diff --git a/library/cpp/netliba/v6/ib_mem.h b/library/cpp/netliba/v6/ib_mem.h new file mode 100644 index 0000000000..dfa5b9cd5f --- /dev/null +++ b/library/cpp/netliba/v6/ib_mem.h @@ -0,0 +1,178 @@ +#pragma once + +#include "block_chain.h" +#include <util/thread/lfqueue.h> +#include <util/system/thread.h> + +namespace NNetliba { + // registered memory blocks + class TMemoryRegion; + class TIBContext; + + class TIBMemPool; + struct TIBMemSuperBlock: public TThrRefBase, TNonCopyable { + TIntrusivePtr<TIBMemPool> Pool; + size_t SzLog; + TAtomic UseCount; + TIntrusivePtr<TMemoryRegion> MemRegion; + + TIBMemSuperBlock(TIBMemPool* pool, size_t szLog); + ~TIBMemSuperBlock() override; + char* GetData(); + size_t GetSize() { + return ((ui64)1) << SzLog; + } + void IncRef() { + AtomicAdd(UseCount, 1); + } + void DecRef(); + }; + + class TIBMemBlock: public TThrRefBase, TNonCopyable { + 
TIntrusivePtr<TIBMemSuperBlock> Super; + char* Data; + size_t Size; + + ~TIBMemBlock() override; + + public: + TIBMemBlock(TPtrArg<TIBMemSuperBlock> super, char* data, size_t sz) + : Super(super) + , Data(data) + , Size(sz) + { + Super->IncRef(); + } + TIBMemBlock(size_t sz) + : Super(nullptr) + , Size(sz) + { + // not really IB mem block, but useful IB code debug without IB + Data = new char[sz]; + } + char* GetData() { + return Data; + } + ui64 GetAddr() { + return Data - (char*)nullptr; + } + size_t GetSize() { + return Size; + } + TMemoryRegion* GetMemRegion() { + return Super.Get() ? Super->MemRegion.Get() : nullptr; + } + }; + + const size_t IB_MEM_LARGE_BLOCK_LN = 20; + const size_t IB_MEM_LARGE_BLOCK = 1ul << IB_MEM_LARGE_BLOCK_LN; + const size_t IB_MEM_POOL_SIZE = 1024 * 1024 * 1024; + + class TIBMemPool: public TThrRefBase, TNonCopyable { + public: + struct TCopyResultStorage; + + private: + class TIBMemSuperBlockPtr { + TIntrusivePtr<TIBMemSuperBlock> Blk; + + public: + ~TIBMemSuperBlockPtr() { + Detach(); + } + void Assign(TIntrusivePtr<TIBMemSuperBlock> p) { + Detach(); + Blk = p; + if (p.Get()) { + AtomicAdd(p->UseCount, 1); + } + } + void Detach() { + if (Blk.Get()) { + Blk->DecRef(); + Blk = nullptr; + } + } + TIBMemSuperBlock* Get() { + return Blk.Get(); + } + }; + + TIntrusivePtr<TIBContext> IBCtx; + THashMap<size_t, TVector<TIntrusivePtr<TIBMemSuperBlock>>> AllocCache; + size_t AllocCacheSize; + TIBMemSuperBlockPtr CurrentBlk; + int CurrentOffset; + TMutex CacheLock; + TThread WorkThread; + TSystemEvent HasStarted; + bool KeepRunning; + + struct TJobItem { + TRopeDataPacket* Data; + i64 MsgHandle; + TIntrusivePtr<TThrRefBase> Context; + TIntrusivePtr<TIBMemBlock> Block; + TIntrusivePtr<TCopyResultStorage> ResultStorage; + + TJobItem(TRopeDataPacket* data, i64 msgHandle, TThrRefBase* context, TPtrArg<TCopyResultStorage> dst) + : Data(data) + , MsgHandle(msgHandle) + , Context(context) + , ResultStorage(dst) + { + } + }; + + 
TLockFreeQueue<TJobItem*> Requests; + TSystemEvent HasWork; + + static void* ThreadFunc(void* param); + + void Return(TPtrArg<TIBMemSuperBlock> blk); + TIntrusivePtr<TIBMemSuperBlock> AllocSuper(size_t sz); + ~TIBMemPool() override; + + public: + struct TCopyResultStorage: public TThrRefBase { + TLockFreeStack<TJobItem*> Results; + + ~TCopyResultStorage() override { + TJobItem* work; + while (Results.Dequeue(&work)) { + delete work; + } + } + template <class T> + bool GetCopyResult(TIntrusivePtr<TIBMemBlock>* resBlock, i64* resMsgHandle, TIntrusivePtr<T>* context) { + TJobItem* work; + if (Results.Dequeue(&work)) { + *resBlock = work->Block; + *resMsgHandle = work->MsgHandle; + *context = static_cast<T*>(work->Context.Get()); // caller responsibility to make sure this makes sense + delete work; + return true; + } else { + return false; + } + } + }; + + public: + TIBMemPool(TPtrArg<TIBContext> ctx); + TIBContext* GetIBContext() { + return IBCtx.Get(); + } + TIBMemBlock* Alloc(size_t sz); + + void CopyData(TRopeDataPacket* data, i64 msgHandle, TThrRefBase* context, TPtrArg<TCopyResultStorage> dst) { + Requests.Enqueue(new TJobItem(data, msgHandle, context, dst)); + HasWork.Signal(); + } + + friend class TIBMemBlock; + friend struct TIBMemSuperBlock; + }; + + extern TIntrusivePtr<TIBMemPool> GetIBMemPool(); +} diff --git a/library/cpp/netliba/v6/ib_memstream.cpp b/library/cpp/netliba/v6/ib_memstream.cpp new file mode 100644 index 0000000000..ffed8f1dba --- /dev/null +++ b/library/cpp/netliba/v6/ib_memstream.cpp @@ -0,0 +1,122 @@ +#include "stdafx.h" +#include "ib_mem.h" +#include "ib_memstream.h" +#include "ib_low.h" + +namespace NNetliba { + int TIBMemStream::WriteImpl(const void* userBuffer, int sizeArg) { + const char* srcData = (const char*)userBuffer; + int size = sizeArg; + for (;;) { + if (size == 0) + return sizeArg; + if (CurBlock == Blocks.ysize()) { + // add new block + TBlock& blk = Blocks.emplace_back(); + blk.StartOffset = GetLength(); + int szLog = 17 + 
Min(Blocks.ysize() / 2, 13); + blk.BufSize = 1 << szLog; + blk.DataSize = 0; + blk.Mem = MemPool->Alloc(blk.BufSize); + Y_ASSERT(CurBlockOffset == 0); + } + TBlock& curBlk = Blocks[CurBlock]; + int leftSpace = curBlk.BufSize - CurBlockOffset; + int copySize = Min(size, leftSpace); + memcpy(curBlk.Mem->GetData() + CurBlockOffset, srcData, copySize); + size -= copySize; + CurBlockOffset += copySize; + srcData += copySize; + curBlk.DataSize = Max(curBlk.DataSize, CurBlockOffset); + if (CurBlockOffset == curBlk.BufSize) { + ++CurBlock; + CurBlockOffset = 0; + } + } + } + + int TIBMemStream::ReadImpl(void* userBuffer, int sizeArg) { + char* dstData = (char*)userBuffer; + int size = sizeArg; + for (;;) { + if (size == 0) + return sizeArg; + if (CurBlock == Blocks.ysize()) { + //memset(dstData, 0, size); + size = 0; + continue; + } + TBlock& curBlk = Blocks[CurBlock]; + int leftSpace = curBlk.DataSize - CurBlockOffset; + int copySize = Min(size, leftSpace); + memcpy(dstData, curBlk.Mem->GetData() + CurBlockOffset, copySize); + size -= copySize; + CurBlockOffset += copySize; + dstData += copySize; + if (CurBlockOffset == curBlk.DataSize) { + ++CurBlock; + CurBlockOffset = 0; + } + } + } + + i64 TIBMemStream::GetLength() { + i64 res = 0; + for (int i = 0; i < Blocks.ysize(); ++i) { + res += Blocks[i].DataSize; + } + return res; + } + + i64 TIBMemStream::Seek(i64 pos) { + for (int resBlockId = 0; resBlockId < Blocks.ysize(); ++resBlockId) { + const TBlock& blk = Blocks[resBlockId]; + if (pos < blk.StartOffset + blk.DataSize) { + CurBlock = resBlockId; + CurBlockOffset = pos - blk.StartOffset; + return pos; + } + } + CurBlock = Blocks.ysize(); + CurBlockOffset = 0; + return GetLength(); + } + + void TIBMemStream::GetBlocks(TVector<TBlockDescr>* res) const { + int blockCount = Blocks.ysize(); + res->resize(blockCount); + for (int i = 0; i < blockCount; ++i) { + const TBlock& blk = Blocks[i]; + TBlockDescr& dst = (*res)[i]; + dst.Addr = blk.Mem->GetAddr(); + dst.BufSize = 
blk.BufSize; + dst.DataSize = blk.DataSize; + TMemoryRegion* mem = blk.Mem->GetMemRegion(); + dst.LocalKey = mem->GetLKey(); + dst.RemoteKey = mem->GetRKey(); + } + } + + void TIBMemStream::CreateBlocks(const TVector<TBlockSizes>& arr) { + int blockCount = arr.ysize(); + Blocks.resize(blockCount); + i64 offset = 0; + for (int i = 0; i < blockCount; ++i) { + const TBlockSizes& src = arr[i]; + TBlock& blk = Blocks[i]; + blk.BufSize = src.BufSize; + blk.DataSize = src.DataSize; + blk.Mem = MemPool->Alloc(blk.BufSize); + blk.StartOffset = offset; + offset += blk.DataSize; + } + CurBlock = 0; + CurBlockOffset = 0; + } + + void TIBMemStream::Clear() { + Blocks.resize(0); + CurBlock = 0; + CurBlockOffset = 0; + } +} diff --git a/library/cpp/netliba/v6/ib_memstream.h b/library/cpp/netliba/v6/ib_memstream.h new file mode 100644 index 0000000000..67eb2386de --- /dev/null +++ b/library/cpp/netliba/v6/ib_memstream.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ib_mem.h" +#include <library/cpp/binsaver/bin_saver.h> +#include <library/cpp/binsaver/buffered_io.h> + +namespace NNetliba { + class TIBMemStream: public IBinaryStream { + struct TBlock { + TIntrusivePtr<TIBMemBlock> Mem; + i64 StartOffset; + int BufSize, DataSize; + + TBlock() + : StartOffset(0) + , BufSize(0) + , DataSize(0) + { + } + TBlock(const TBlock& x) { + Copy(x); + } + void operator=(const TBlock& x) { + Copy(x); + } + void Copy(const TBlock& x) { + if (x.BufSize > 0) { + Mem = GetIBMemPool()->Alloc(x.BufSize); + memcpy(Mem->GetData(), x.Mem->GetData(), x.DataSize); + StartOffset = x.StartOffset; + BufSize = x.BufSize; + DataSize = x.DataSize; + } else { + Mem = nullptr; + StartOffset = 0; + BufSize = 0; + DataSize = 0; + } + } + }; + + TIntrusivePtr<TIBMemPool> MemPool; + TVector<TBlock> Blocks; + int CurBlock; + int CurBlockOffset; + + public: + struct TBlockDescr { + ui64 Addr; + int BufSize, DataSize; + ui32 RemoteKey, LocalKey; + }; + struct TBlockSizes { + int BufSize, DataSize; + }; + + public: + 
TIBMemStream() + : MemPool(GetIBMemPool()) + , CurBlock(0) + , CurBlockOffset(0) + { + } + ~TIBMemStream() override { + } // keep gcc happy + + bool IsValid() const override { + return true; + } + bool IsFailed() const override { + return false; + } + void Flush() { + } + + i64 GetLength(); + i64 Seek(i64 pos); + + void GetBlocks(TVector<TBlockDescr>* res) const; + void CreateBlocks(const TVector<TBlockSizes>& arr); + + void Clear(); + + private: + int WriteImpl(const void* userBuffer, int size) override; + int ReadImpl(void* userBuffer, int size) override; + }; + + template <class T> + inline void Serialize(bool bRead, TIBMemStream& ms, T& c) { + IBinSaver bs(ms, bRead); + bs.Add(1, &c); + } + +} diff --git a/library/cpp/netliba/v6/ib_test.cpp b/library/cpp/netliba/v6/ib_test.cpp new file mode 100644 index 0000000000..178691e837 --- /dev/null +++ b/library/cpp/netliba/v6/ib_test.cpp @@ -0,0 +1,232 @@ +#include "stdafx.h" +#include "ib_test.h" +#include "ib_buffers.h" +#include "udp_socket.h" +#include "udp_address.h" +#include <util/system/hp_timer.h> + +namespace NNetliba { + struct TWelcomeSocketAddr { + int LID; + int QPN; + }; + + struct TRCQueuePairHandshake { + int QPN, PSN; + }; + + class TIPSocket { + TNetSocket s; + enum { + HDR_SIZE = UDP_LOW_LEVEL_HEADER_SIZE + }; + + public: + void Init(int port) { + if (!InitLocalIPList()) { + Y_ASSERT(0 && "Can not determine self IP address"); + return; + } + s.Open(port); + } + bool IsValid() { + return s.IsValid(); + } + void Respond(const TWelcomeSocketAddr& info) { + char buf[10000]; + int sz = sizeof(buf); + sockaddr_in6 fromAddress; + bool rv = s.RecvFrom(buf, &sz, &fromAddress); + if (rv && strcmp(buf + HDR_SIZE, "Hello_IB") == 0) { + printf("send welcome info\n"); + memcpy(buf + HDR_SIZE, &info, sizeof(info)); + TNetSocket::ESendError err = s.SendTo(buf, sizeof(info) + HDR_SIZE, fromAddress, FF_ALLOW_FRAG); + if (err != TNetSocket::SEND_OK) { + printf("SendTo() fail %d\n", err); + } + } + } + void 
Request(const char* hostName, int port, TWelcomeSocketAddr* res) { + TUdpAddress addr = CreateAddress(hostName, port); + printf("addr = %s\n", GetAddressAsString(addr).c_str()); + + sockaddr_in6 sockAddr; + GetWinsockAddr(&sockAddr, addr); + + for (;;) { + char buf[10000]; + int sz = sizeof(buf); + sockaddr_in6 fromAddress; + bool rv = s.RecvFrom(buf, &sz, &fromAddress); + if (rv) { + if (sz == sizeof(TWelcomeSocketAddr) + HDR_SIZE) { + *res = *(TWelcomeSocketAddr*)(buf + HDR_SIZE); + break; + } + printf("Get unexpected %d bytes from somewhere?\n", sz); + } + + strcpy(buf + HDR_SIZE, "Hello_IB"); + TNetSocket::ESendError err = s.SendTo(buf, strlen(buf) + 1 + HDR_SIZE, sockAddr, FF_ALLOW_FRAG); + if (err != TNetSocket::SEND_OK) { + printf("SendTo() fail %d\n", err); + } + + Sleep(TDuration::MilliSeconds(100)); + } + } + }; + + // can hang if opposite side exits, but it's only basic test + static void WaitForRecv(TIBBufferPool* bp, TPtrArg<TComplectionQueue> cq, ibv_wc* wc) { + for (;;) { + if (cq->Poll(wc, 1) == 1) { + if (wc->opcode & IBV_WC_RECV) { + break; + } + bp->FreeBuf(wc->wr_id); + } + } + } + + void RunIBTest(bool isClient, const char* serverName) { + TIntrusivePtr<TIBPort> port = GetIBDevice(); + if (port.Get() == nullptr) { + printf("No IB device found\n"); + return; + } + + const int IP_PORT = 13666; + const int WELCOME_QKEY = 0x1113013; + const int MAX_SRQ_WORK_REQUESTS = 100; + const int MAX_CQ_EVENTS = 1000; + const int QP_SEND_QUEUE_SIZE = 3; + + TIntrusivePtr<TComplectionQueue> cq = new TComplectionQueue(port->GetCtx(), MAX_CQ_EVENTS); + + TIBBufferPool bp(port->GetCtx(), MAX_SRQ_WORK_REQUESTS); + + if (!isClient) { + // server + TIPSocket ipSocket; + ipSocket.Init(IP_PORT); + if (!ipSocket.IsValid()) { + printf("UDP port %d is not available\n", IP_PORT); + return; + } + + TIntrusivePtr<TComplectionQueue> cqRC = new TComplectionQueue(port->GetCtx(), MAX_CQ_EVENTS); + + TIntrusivePtr<TUDQueuePair> welcomeQP = new TUDQueuePair(port, cq, bp.GetSRQ(), 
QP_SEND_QUEUE_SIZE); + welcomeQP->Init(WELCOME_QKEY); + + TWelcomeSocketAddr info; + info.LID = port->GetLID(); + info.QPN = welcomeQP->GetQPN(); + + TIntrusivePtr<TAddressHandle> ahPeer1; + for (;;) { + ipSocket.Respond(info); + // poll srq + ibv_wc wc; + if (cq->Poll(&wc, 1) == 1 && (wc.opcode & IBV_WC_RECV)) { + printf("Got IB handshake\n"); + + TRCQueuePairHandshake remoteHandshake; + ibv_ah_attr clientAddr; + { + TIBRecvPacketProcess pkt(bp, wc); + remoteHandshake = *(TRCQueuePairHandshake*)pkt.GetUDData(); + port->GetAHAttr(&wc, pkt.GetGRH(), &clientAddr); + } + + TIntrusivePtr<TAddressHandle> ahPeer2; + ahPeer2 = new TAddressHandle(port->GetCtx(), &clientAddr); + + TIntrusivePtr<TRCQueuePair> rcTest = new TRCQueuePair(port->GetCtx(), cqRC, bp.GetSRQ(), QP_SEND_QUEUE_SIZE); + rcTest->Init(clientAddr, remoteHandshake.QPN, remoteHandshake.PSN); + + TRCQueuePairHandshake handshake; + handshake.PSN = rcTest->GetPSN(); + handshake.QPN = rcTest->GetQPN(); + bp.PostSend(welcomeQP, ahPeer2, wc.src_qp, WELCOME_QKEY, &handshake, sizeof(handshake)); + + WaitForRecv(&bp, cqRC, &wc); + + { + TIBRecvPacketProcess pkt(bp, wc); + printf("Got RC ping: %s\n", pkt.GetData()); + const char* ret = "Puk"; + bp.PostSend(rcTest, ret, strlen(ret) + 1); + } + + for (int i = 0; i < 5; ++i) { + WaitForRecv(&bp, cqRC, &wc); + TIBRecvPacketProcess pkt(bp, wc); + printf("Got RC ping: %s\n", pkt.GetData()); + const char* ret = "Fine"; + bp.PostSend(rcTest, ret, strlen(ret) + 1); + } + } + } + } else { + // client + ibv_wc wc; + + TIPSocket ipSocket; + ipSocket.Init(0); + if (!ipSocket.IsValid()) { + printf("Failed to create UDP socket\n"); + return; + } + + printf("Connecting to %s\n", serverName); + TWelcomeSocketAddr info; + ipSocket.Request(serverName, IP_PORT, &info); + printf("Got welcome info, lid %d, qpn %d\n", info.LID, info.QPN); + + TIntrusivePtr<TUDQueuePair> welcomeQP = new TUDQueuePair(port, cq, bp.GetSRQ(), QP_SEND_QUEUE_SIZE); + welcomeQP->Init(WELCOME_QKEY); + + 
TIntrusivePtr<TRCQueuePair> rcTest = new TRCQueuePair(port->GetCtx(), cq, bp.GetSRQ(), QP_SEND_QUEUE_SIZE); + + TRCQueuePairHandshake handshake; + handshake.PSN = rcTest->GetPSN(); + handshake.QPN = rcTest->GetQPN(); + TIntrusivePtr<TAddressHandle> serverAH = new TAddressHandle(port, info.LID, 0); + bp.PostSend(welcomeQP, serverAH, info.QPN, WELCOME_QKEY, &handshake, sizeof(handshake)); + + WaitForRecv(&bp, cq, &wc); + + ibv_ah_attr serverAddr; + TRCQueuePairHandshake remoteHandshake; + { + TIBRecvPacketProcess pkt(bp, wc); + printf("Got handshake response\n"); + remoteHandshake = *(TRCQueuePairHandshake*)pkt.GetUDData(); + port->GetAHAttr(&wc, pkt.GetGRH(), &serverAddr); + } + + rcTest->Init(serverAddr, remoteHandshake.QPN, remoteHandshake.PSN); + + char hiAndy[] = "Hi, Andy"; + bp.PostSend(rcTest, hiAndy, sizeof(hiAndy)); + WaitForRecv(&bp, cq, &wc); + { + TIBRecvPacketProcess pkt(bp, wc); + printf("Got RC pong: %s\n", pkt.GetData()); + } + + for (int i = 0; i < 5; ++i) { + char howAreYou[] = "How are you?"; + bp.PostSend(rcTest, howAreYou, sizeof(howAreYou)); + + WaitForRecv(&bp, cq, &wc); + { + TIBRecvPacketProcess pkt(bp, wc); + printf("Got RC pong: %s\n", pkt.GetData()); + } + } + } + } +} diff --git a/library/cpp/netliba/v6/ib_test.h b/library/cpp/netliba/v6/ib_test.h new file mode 100644 index 0000000000..b72c95d7a9 --- /dev/null +++ b/library/cpp/netliba/v6/ib_test.h @@ -0,0 +1,5 @@ +#pragma once + +namespace NNetliba { + void RunIBTest(bool isClient, const char* serverName); +} diff --git a/library/cpp/netliba/v6/net_acks.cpp b/library/cpp/netliba/v6/net_acks.cpp new file mode 100644 index 0000000000..5f4690c264 --- /dev/null +++ b/library/cpp/netliba/v6/net_acks.cpp @@ -0,0 +1,194 @@ +#include "stdafx.h" +#include "net_acks.h" +#include <util/datetime/cputimer.h> + +#include <atomic> + +namespace NNetliba { + const float RTT_AVERAGE_OVER = 15; + + float TCongestionControl::StartWindowSize = 3; + float TCongestionControl::MaxPacketRate = 0; // unlimited + 
+ bool UseTOSforAcks = false; //true;// + + void EnableUseTOSforAcks(bool enable) { + UseTOSforAcks = enable; + } + + float CONG_CTRL_CHANNEL_INFLATE = 1; + + void SetCongCtrlChannelInflate(float inflate) { + CONG_CTRL_CHANNEL_INFLATE = inflate; + } + + ////////////////////////////////////////////////////////////////////////// + TPingTracker::TPingTracker() + : AvrgRTT(CONG_CTRL_INITIAL_RTT) + , AvrgRTT2(CONG_CTRL_INITIAL_RTT * CONG_CTRL_INITIAL_RTT) + , RTTCount(0) + { + } + + void TPingTracker::RegisterRTT(float rtt) { + Y_ASSERT(rtt > 0); + float keep = RTTCount / (RTTCount + 1); + AvrgRTT *= keep; + AvrgRTT += (1 - keep) * rtt; + AvrgRTT2 *= keep; + AvrgRTT2 += (1 - keep) * Sqr(rtt); + RTTCount = Min(RTTCount + 1, RTT_AVERAGE_OVER); + //static int n; + //if ((++n % 1024) == 0) + // printf("Average RTT = %g (sko = %g)\n", GetRTT() * 1000, GetRTTSKO() * 1000); + } + + void TPingTracker::IncreaseRTT() { + const float F_RTT_DECAY_RATE = 1.1f; + AvrgRTT *= F_RTT_DECAY_RATE; + AvrgRTT2 *= Sqr(F_RTT_DECAY_RATE); + } + + ////////////////////////////////////////////////////////////////////////// + void TAckTracker::Resend() { + CurrentPacket = 0; + for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i) + Congestion->Failure(); // not actually correct but simplifies logic a lot + PacketsInFly.clear(); + DroppedPackets.clear(); + ResendQueue.clear(); + for (size_t i = 0; i < AckReceived.size(); ++i) + AckReceived[i] = false; + } + + int TAckTracker::SelectPacket() { + if (!ResendQueue.empty()) { + int res = ResendQueue.back(); + ResendQueue.pop_back(); + //printf("resending packet %d\n", res); + return res; + } + if (CurrentPacket == PacketCount) { + return -1; + } + return CurrentPacket++; + } + + TAckTracker::~TAckTracker() { + for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i) + Congestion->Failure(); + // object will be incorrect state after this (failed packets are not added to resend queue), but 
who cares + } + + int TAckTracker::GetPacketToSend(float deltaT) { + int res = SelectPacket(); + if (res == -1) { + // needed to count time even if we don't have anything to send + Congestion->HasTriedToSend(); + return res; + } + Congestion->LaunchPacket(); + PacketsInFly[res] = -deltaT; // deltaT is time since last Step(), so for the timing to be correct we should subtract it + return res; + } + + // called on SendTo() failure + void TAckTracker::AddToResend(int pkt) { + //printf("AddToResend(%d)\n", pkt); + TPacketHash::iterator i = PacketsInFly.find(pkt); + if (i != PacketsInFly.end()) { + PacketsInFly.erase(i); + Congestion->FailureOnSend(); + ResendQueue.push_back(pkt); + } else + Y_ASSERT(0); + } + + void TAckTracker::Ack(int pkt, float deltaT, bool updateRTT) { + Y_ASSERT(pkt >= 0 && pkt < PacketCount); + if (AckReceived[pkt]) + return; + AckReceived[pkt] = true; + //printf("Ack received for %d\n", pkt); + TPacketHash::iterator i = PacketsInFly.find(pkt); + if (i == PacketsInFly.end()) { + for (size_t k = 0; k < ResendQueue.size(); ++k) { + if (ResendQueue[k] == pkt) { + ResendQueue[k] = ResendQueue.back(); + ResendQueue.pop_back(); + break; + } + } + TPacketHash::iterator z = DroppedPackets.find(pkt); + if (z != DroppedPackets.end()) { + // late packet arrived + if (updateRTT) { + float ping = z->second + deltaT; + Congestion->RegisterRTT(ping); + } + DroppedPackets.erase(z); + } else { + // Y_ASSERT(0); // ack on nonsent packet, possible in resend scenario + } + return; + } + if (updateRTT) { + float ping = i->second + deltaT; + //printf("Register RTT %g\n", ping * 1000); + Congestion->RegisterRTT(ping); + } + PacketsInFly.erase(i); + Congestion->Success(); + } + + void TAckTracker::AckAll() { + for (TPacketHash::const_iterator i = PacketsInFly.begin(); i != PacketsInFly.end(); ++i) { + int pkt = i->first; + AckReceived[pkt] = true; + Congestion->Success(); + } + PacketsInFly.clear(); + } + + void TAckTracker::Step(float deltaT) { + float timeoutVal = 
Congestion->GetTimeout(); + + //static int n; + //if ((++n % 1024) == 0) + // printf("timeout = %g, window = %g, fail_rate %g, pkt_rate = %g\n", timeoutVal * 1000, Congestion->GetWindow(), Congestion->GetFailRate(), (1 - Congestion->GetFailRate()) * Congestion->GetWindow() / Congestion->GetRTT()); + + TimeToNextPacketTimeout = 1000; + // для окон меньше единицы мы кидаем рандом один раз за RTT на то, можно ли пускать пакет + // поэтому можно ждать максимум RTT, после этого надо кинуть новый random + if (Congestion->GetWindow() < 1) + TimeToNextPacketTimeout = Congestion->GetRTT(); + + for (auto& droppedPacket : DroppedPackets) { + float& t = droppedPacket.second; + t += deltaT; + } + + for (TPacketHash::iterator i = PacketsInFly.begin(); i != PacketsInFly.end();) { + float& t = i->second; + t += deltaT; + if (t > timeoutVal) { + //printf("packet %d timed out (timeout = %g)\n", i->first, timeoutVal); + ResendQueue.push_back(i->first); + DroppedPackets[i->first] = i->second; + TPacketHash::iterator k = i++; + PacketsInFly.erase(k); + Congestion->Failure(); + } else { + TimeToNextPacketTimeout = Min(TimeToNextPacketTimeout, timeoutVal - t); + ++i; + } + } + } + + static std::atomic<ui32> netAckRndVal = (ui32)GetCycleCount(); + ui32 NetAckRnd() { + const auto nextNetAckRndVal = static_cast<ui32>(((ui64)netAckRndVal.load(std::memory_order_acquire) * 279470273) % 4294967291); + netAckRndVal.store(nextNetAckRndVal, std::memory_order_release); + return nextNetAckRndVal; + } +} diff --git a/library/cpp/netliba/v6/net_acks.h b/library/cpp/netliba/v6/net_acks.h new file mode 100644 index 0000000000..a497eaf95c --- /dev/null +++ b/library/cpp/netliba/v6/net_acks.h @@ -0,0 +1,528 @@ +#pragma once + +#include "net_test.h" +#include "net_queue_stat.h" + +#include <util/system/spinlock.h> + +namespace NNetliba { + const float MIN_PACKET_RTT_SKO = 0.001f; // avoid drops due to small hiccups + + const float CONG_CTRL_INITIAL_RTT = 0.24f; //0.01f; // taking into account Las Vegas 
10ms estimate is too optimistic + + const float CONG_CTRL_WINDOW_GROW = 0.005f; + const float CONG_CTRL_WINDOW_SHRINK = 0.9f; + const float CONG_CTRL_WINDOW_SHRINK_RTT = 0.95f; + const float CONG_CTRL_RTT_MIX_RATE = 0.9f; + const int CONG_CTRL_RTT_SEQ_COUNT = 8; + const float CONG_CTRL_MIN_WINDOW = 0.01f; + const float CONG_CTRL_LARGE_TIME_WINDOW = 10000.0f; + const float CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD = 0.4f; // in seconds + const float CONG_CTRL_MINIMAL_SEND_INTERVAL = 1; + const float CONG_CTRL_MIN_FAIL_INTERVAL = 0.001f; + const float CONG_CTRL_ALLOWED_BURST_SIZE = 3; + const float CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION = 0.002f; + + const float LAME_MTU_TIMEOUT = 0.3f; + const float LAME_MTU_INTERVAL = 0.05f; + + const float START_CHECK_PORT_DELAY = 0.5; + const float FINISH_CHECK_PORT_DELAY = 10; + const int N_PORT_TEST_COUNT_LIMIT = 256; // or 512 + + // if enabled all acks are sent with different TOS, so they end up in different queue + // this allows us to limit window based on minimal RTT observed and 1G link assumption + extern bool UseTOSforAcks; + + class TPingTracker { + float AvrgRTT, AvrgRTT2; // RTT statistics + float RTTCount; + + public: + TPingTracker(); + float GetRTT() const { + return AvrgRTT; + } + float GetRTTSKO() const { + float sko = sqrt(fabs(Sqr(AvrgRTT) - AvrgRTT2)); + float minSKO = Max(MIN_PACKET_RTT_SKO, AvrgRTT * 0.05f); + return Max(minSKO, sko); + } + float GetTimeout() const { + return GetRTT() + GetRTTSKO() * 3; + } + void RegisterRTT(float rtt); + void IncreaseRTT(); + }; + + ui32 NetAckRnd(); + + class TLameMTUDiscovery: public TThrRefBase { + enum EState { + NEED_PING, + WAIT, + }; + + float TimePassed, TimeSinceLastPing; + EState State; + + public: + TLameMTUDiscovery() + : TimePassed(0) + , TimeSinceLastPing(0) + , State(NEED_PING) + { + } + bool CanSend() { + return State == NEED_PING; + } + void PingSent() { + State = WAIT; + TimeSinceLastPing = 0; + } + bool IsTimedOut() const { + return TimePassed > 
LAME_MTU_TIMEOUT; + } + void Step(float deltaT) { + TimePassed += deltaT; + TimeSinceLastPing += deltaT; + if (TimeSinceLastPing > LAME_MTU_INTERVAL) + State = NEED_PING; + } + }; + + struct TPeerQueueStats: public IPeerQueueStats { + int Count; + + TPeerQueueStats() + : Count(0) + { + } + int GetPacketCount() override { + return Count; + } + }; + + // pretend we have multiple channels in parallel + // not exact approximation since N channels should have N distinct windows + extern float CONG_CTRL_CHANNEL_INFLATE; + + class TCongestionControl: public TThrRefBase { + float Window, PacketsInFly, FailRate; + float MinRTT, MaxWindow; + bool FullSpeed, DoCountTime; + TPingTracker PingTracker; + double TimeSinceLastRecv; + TAdaptiveLock PortTesterLock; + TIntrusivePtr<TPortUnreachableTester> PortTester; + int ActiveTransferCount; + float AvrgRTT; + int HighRTTCounter; + float WindowFraction, FractionRecalc; + float TimeWindow; + double TimeSinceLastFail; + float VirtualPackets; + int MTU; + TIntrusivePtr<TLameMTUDiscovery> MTUDiscovery; + TIntrusivePtr<TPeerQueueStats> QueueStats; + + void CalcMaxWindow() { + if (MTU == 0) + return; + MaxWindow = 125000000 / MTU * Max(0.001f, MinRTT); + } + + public: + static float StartWindowSize, MaxPacketRate; + + public: + TCongestionControl() + : Window(StartWindowSize * CONG_CTRL_CHANNEL_INFLATE) + , PacketsInFly(0) + , FailRate(0) + , MinRTT(10) + , MaxWindow(10000) + , FullSpeed(false) + , DoCountTime(false) + , TimeSinceLastRecv(0) + , ActiveTransferCount(0) + , AvrgRTT(0) + , HighRTTCounter(0) + , WindowFraction(0) + , FractionRecalc(0) + , TimeWindow(CONG_CTRL_LARGE_TIME_WINDOW) + , TimeSinceLastFail(0) + , MTU(0) + { + VirtualPackets = Max(Window - CONG_CTRL_ALLOWED_BURST_SIZE, 0.f); + } + bool CanSend() { + bool res = VirtualPackets + PacketsInFly + WindowFraction <= Window; + FullSpeed |= !res; + res &= TimeWindow > 0; + return res; + } + void LaunchPacket() { + PacketsInFly += 1.0f; + TimeWindow -= 1.0f; + } + void 
RegisterRTT(float RTT) { + if (RTT < 0) + return; + RTT = ClampVal(RTT, 0.0001f, 1.0f); + if (RTT < MinRTT && MTU != 0) { + MinRTT = RTT; + CalcMaxWindow(); + } + MinRTT = Min(MinRTT, RTT); + + PingTracker.RegisterRTT(RTT); + if (AvrgRTT == 0) + AvrgRTT = RTT; + if (RTT > AvrgRTT) { + ++HighRTTCounter; + if (HighRTTCounter >= CONG_CTRL_RTT_SEQ_COUNT) { + //printf("Too many high RTT in a row\n"); + if (FullSpeed) { + float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK_RTT) / CONG_CTRL_CHANNEL_INFLATE); + Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract); + VirtualPackets = Max(0.f, VirtualPackets - windowSubtract); + //printf("reducing window by RTT , new window %g\n", Window); + } + // reduce no more then twice per RTT + HighRTTCounter = Min(0, CONG_CTRL_RTT_SEQ_COUNT - (int)(Window * 0.5)); + } + } else { + HighRTTCounter = Min(0, HighRTTCounter); + } + + float rttMixRate = CONG_CTRL_RTT_MIX_RATE; + AvrgRTT = AvrgRTT * rttMixRate + RTT * (1 - rttMixRate); + } + void Success() { + PacketsInFly -= 1; + Y_ASSERT(PacketsInFly >= 0); + // FullSpeed should be correct at this point + // we assume that after UpdateAlive() we send all packets first then we listen for acks and call Success() + // FullSpeed is set in CanSend() during send if we are using full window + // do not increaese window while send rate is limited by virtual packets (ie start of transfer) + if (FullSpeed && VirtualPackets == 0) { + // there are 2 requirements for window growth + // 1) growth should be proportional to window size to ensure constant FailRate + // 2) growth should be constant to ensure fairness among different flows + // so lets make it square root :) + Window += sqrt(Window / CONG_CTRL_CHANNEL_INFLATE) * CONG_CTRL_WINDOW_GROW; + if (UseTOSforAcks) { + Window = Min(Window, MaxWindow); + } + } + FailRate *= 0.99f; + } + void FailureOnSend() { + //printf("Failure on send\n"); + PacketsInFly -= 1; + Y_ASSERT(PacketsInFly >= 0); + // not a congestion event, do not modify 
Window + // do not set FullSpeed since we are not using full Window + } + void Failure() { + //printf("Congestion failure\n"); + PacketsInFly -= 1; + Y_ASSERT(PacketsInFly >= 0); + // account limited number of fails per segment + if (TimeSinceLastFail > CONG_CTRL_MIN_FAIL_INTERVAL) { + TimeSinceLastFail = 0; + if (Window <= CONG_CTRL_MIN_WINDOW) { + // ping dead hosts less frequently + if (PingTracker.GetRTT() / CONG_CTRL_MIN_WINDOW < CONG_CTRL_MINIMAL_SEND_INTERVAL) + PingTracker.IncreaseRTT(); + Window = CONG_CTRL_MIN_WINDOW; + VirtualPackets = 0; + } else { + float windowSubtract = Window * ((1 - CONG_CTRL_WINDOW_SHRINK) / CONG_CTRL_CHANNEL_INFLATE); + Window = Max(CONG_CTRL_MIN_WINDOW, Window - windowSubtract); + VirtualPackets = Max(0.f, VirtualPackets - windowSubtract); + } + } + FailRate = FailRate * 0.99f + 0.01f; + } + bool HasPacketsInFly() const { + return PacketsInFly > 0; + } + float GetTimeout() const { + return PingTracker.GetTimeout(); + } + float GetWindow() const { + return Window; + } + float GetRTT() const { + return PingTracker.GetRTT(); + } + float GetFailRate() const { + return FailRate; + } + float GetTimeSinceLastRecv() const { + return TimeSinceLastRecv; + } + int GetTransferCount() const { + return ActiveTransferCount; + } + float GetMaxWindow() const { + return UseTOSforAcks ? 
MaxWindow : -1; + } + void MarkAlive() { + TimeSinceLastRecv = 0; + + with_lock (PortTesterLock) { + PortTester = nullptr; + } + + } + void HasTriedToSend() { + DoCountTime = true; + } + bool IsAlive() const { + return TimeSinceLastRecv < 1e6f; + } + void Kill() { + TimeSinceLastRecv = 1e6f; + } + bool UpdateAlive(const TUdpAddress& toAddress, float deltaT, float timeout, float* resMaxWaitTime) { + if (!FullSpeed) { + // create virtual packets during idle to avoid burst on transmit start + if (AvrgRTT > CONG_CTRL_MIN_RTT_FOR_BURST_REDUCTION) { + VirtualPackets = Max(VirtualPackets, Window - PacketsInFly - CONG_CTRL_ALLOWED_BURST_SIZE); + } + } else { + if (VirtualPackets > 0) { + if (Window <= CONG_CTRL_ALLOWED_BURST_SIZE) { + VirtualPackets = 0; + } + float xRTT = AvrgRTT == 0 ? CONG_CTRL_INITIAL_RTT : AvrgRTT; + float virtualPktsPerSecond = Window / xRTT; + VirtualPackets = Max(0.f, VirtualPackets - deltaT * virtualPktsPerSecond); + *resMaxWaitTime = Min(*resMaxWaitTime, 0.001f); // need to update virtual packets counter regularly + } + } + float currentRTT = GetRTT(); + FractionRecalc += deltaT; + if (FractionRecalc > currentRTT) { + int cycleCount = (int)(FractionRecalc / currentRTT); + FractionRecalc -= currentRTT * cycleCount; + WindowFraction = (NetAckRnd() & 1023) * (1 / 1023.0f) / cycleCount; + } + + if (MaxPacketRate > 0 && AvrgRTT > 0) { + float maxTimeWindow = CONG_CTRL_TIME_WINDOW_LIMIT_PERIOD * MaxPacketRate; + TimeWindow = Min(maxTimeWindow, TimeWindow + MaxPacketRate * deltaT); + } else + TimeWindow = CONG_CTRL_LARGE_TIME_WINDOW; + + // guarantee minimal send rate + if (currentRTT > CONG_CTRL_MINIMAL_SEND_INTERVAL * Window) { + Window = Max(CONG_CTRL_MIN_WINDOW, currentRTT / CONG_CTRL_MINIMAL_SEND_INTERVAL); + VirtualPackets = 0; + } + + TimeSinceLastFail += deltaT; + + //static int n; + //if ((++n & 127) == 0) + // printf("window = %g, fly = %g, VirtualPkts = %g, deltaT = %g, FailRate = %g FullSpeed = %d AvrgRTT = %g\n", + // Window, PacketsInFly, 
VirtualPackets, deltaT * 1000, FailRate, (int)FullSpeed, AvrgRTT * 1000); + + if (PacketsInFly > 0 || FullSpeed || DoCountTime) { + // считаем время только когда есть пакеты в полете + TimeSinceLastRecv += deltaT; + if (TimeSinceLastRecv > START_CHECK_PORT_DELAY) { + if (TimeSinceLastRecv < FINISH_CHECK_PORT_DELAY) { + TIntrusivePtr<TPortUnreachableTester> portTester; + with_lock (PortTesterLock) { + portTester = PortTester; + } + + if (!portTester && AtomicGet(ActivePortTestersCount) < N_PORT_TEST_COUNT_LIMIT) { + portTester = new TPortUnreachableTester(); + with_lock (PortTesterLock) { + PortTester = portTester; + } + + if (portTester->IsValid()) { + portTester->Connect(toAddress); + } else { + with_lock (PortTesterLock) { + PortTester = nullptr; + } + } + } + if (portTester && !portTester->Test(deltaT)) { + Kill(); + return false; + } + } else { + with_lock (PortTesterLock) { + PortTester = nullptr; + } + } + } + if (TimeSinceLastRecv > timeout) { + Kill(); + return false; + } + } + + FullSpeed = false; + DoCountTime = false; + + if (MTUDiscovery.Get()) + MTUDiscovery->Step(deltaT); + + return true; + } + bool IsKnownMTU() const { + return MTU != 0; + } + int GetMTU() const { + return MTU; + } + TLameMTUDiscovery* GetMTUDiscovery() { + if (MTUDiscovery.Get() == nullptr) + MTUDiscovery = new TLameMTUDiscovery; + return MTUDiscovery.Get(); + } + void SetMTU(int sz) { + MTU = sz; + MTUDiscovery = nullptr; + CalcMaxWindow(); + } + void AttachQueueStats(TIntrusivePtr<TPeerQueueStats> s) { + if (s.Get()) { + s->Count = ActiveTransferCount; + } + Y_ASSERT(QueueStats.Get() == nullptr); + QueueStats = s; + } + friend class TCongestionControlPtr; + }; + + class TCongestionControlPtr { + TIntrusivePtr<TCongestionControl> Ptr; + + void Inc() { + if (Ptr.Get()) { + ++Ptr->ActiveTransferCount; + if (Ptr->QueueStats.Get()) { + Ptr->QueueStats->Count = Ptr->ActiveTransferCount; + } + } + } + void Dec() { + if (Ptr.Get()) { + --Ptr->ActiveTransferCount; + if 
(Ptr->QueueStats.Get()) { + Ptr->QueueStats->Count = Ptr->ActiveTransferCount; + } + } + } + + public: + TCongestionControlPtr() { + } + ~TCongestionControlPtr() { + Dec(); + } + TCongestionControlPtr(TCongestionControl* p) + : Ptr(p) + { + Inc(); + } + TCongestionControlPtr& operator=(const TCongestionControlPtr& a) { + Dec(); + Ptr = a.Ptr; + Inc(); + return *this; + } + TCongestionControlPtr& operator=(TCongestionControl* a) { + Dec(); + Ptr = a; + Inc(); + return *this; + } + operator TCongestionControl*() const { + return Ptr.Get(); + } + TCongestionControl* operator->() const { + return Ptr.Get(); + } + TIntrusivePtr<TCongestionControl> Get() const { + return Ptr; + } + }; + + class TAckTracker { + struct TFlyingPacket { + float T; + int PktId; + TFlyingPacket() + : T(0) + , PktId(-1) + { + } + TFlyingPacket(float t, int pktId) + : T(t) + , PktId(pktId) + { + } + }; + int PacketCount, CurrentPacket; + typedef THashMap<int, float> TPacketHash; + TPacketHash PacketsInFly, DroppedPackets; + TVector<int> ResendQueue; + TCongestionControlPtr Congestion; + TVector<bool> AckReceived; + float TimeToNextPacketTimeout; + + int SelectPacket(); + + public: + TAckTracker() + : PacketCount(0) + , CurrentPacket(0) + , TimeToNextPacketTimeout(1000) + { + } + ~TAckTracker(); + void AttachCongestionControl(TCongestionControl* p) { + Congestion = p; + } + TIntrusivePtr<TCongestionControl> GetCongestionControl() const { + return Congestion.Get(); + } + void SetPacketCount(int n) { + Y_ASSERT(PacketCount == 0); + PacketCount = n; + AckReceived.resize(n, false); + } + void Resend(); + bool IsInitialized() { + return PacketCount != 0; + } + int GetPacketToSend(float deltaT); + void AddToResend(int pkt); // called when failed to send packet + void Ack(int pkt, float deltaT, bool updateRTT); + void AckAll(); + void MarkAlive() { + Congestion->MarkAlive(); + } + bool IsAlive() const { + return Congestion->IsAlive(); + } + void Step(float deltaT); + bool CanSend() const { + return 
Congestion->CanSend();
+        }
+        // Time budget until the earliest in-fly packet would exceed the
+        // congestion timeout; refreshed by Step(), used by callers to size
+        // their poll/wait interval.
+        float GetTimeToNextPacketTimeout() const {
+            return TimeToNextPacketTimeout;
+        }
+    };
+}
diff --git a/library/cpp/netliba/v6/net_queue_stat.h b/library/cpp/netliba/v6/net_queue_stat.h
new file mode 100644
index 0000000000..75e4b10e35
--- /dev/null
+++ b/library/cpp/netliba/v6/net_queue_stat.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <util/generic/ptr.h>
+
+namespace NNetliba {
+    // Ref-counted interface reporting the number of pending packets queued
+    // for a peer (implemented by TPeerQueueStats in net_acks.h).
+    struct IPeerQueueStats: public TThrRefBase {
+        virtual int GetPacketCount() = 0;
+    };
+}
diff --git a/library/cpp/netliba/v6/net_request.cpp b/library/cpp/netliba/v6/net_request.cpp
new file mode 100644
index 0000000000..e3688f2835
--- /dev/null
+++ b/library/cpp/netliba/v6/net_request.cpp
@@ -0,0 +1,5 @@
+#include "stdafx.h"
+#include "net_request.h"
+
+namespace NNetliba {
+}
diff --git a/library/cpp/netliba/v6/net_request.h b/library/cpp/netliba/v6/net_request.h
new file mode 100644
index 0000000000..15bc47ae21
--- /dev/null
+++ b/library/cpp/netliba/v6/net_request.h
@@ -0,0 +1,15 @@
+#pragma once
+
+#include <util/generic/guid.h>
+#include "udp_address.h"
+
+namespace NNetliba {
+    class TRopeDataPacket;
+
+    // A received request: the sender's address, the request's GUID and the
+    // owned payload (TAutoPtr transfers ownership to whoever takes Data).
+    struct TRequest {
+        TUdpAddress Address;
+        TGUID Guid;
+        TAutoPtr<TRopeDataPacket> Data;
+    };
+
+}
diff --git a/library/cpp/netliba/v6/net_test.cpp b/library/cpp/netliba/v6/net_test.cpp
new file mode 100644
index 0000000000..63fd7c4067
--- /dev/null
+++ b/library/cpp/netliba/v6/net_test.cpp
@@ -0,0 +1,50 @@
+#include "stdafx.h"
+#include "net_test.h"
+#include "udp_address.h"
+
+#ifndef _win_
+#include <errno.h>
+#endif
+
+namespace NNetliba {
+    // Global count of live, successfully opened tester sockets; consulted in
+    // net_acks.h against N_PORT_TEST_COUNT_LIMIT to bound open handles.
+    TAtomic ActivePortTestersCount;
+
+    // Interval between empty probe datagrams sent from Test(), in seconds.
+    const float PING_DELAY = 0.5f;
+
+    TPortUnreachableTester::TPortUnreachableTester()
+        : TimePassed(0)
+        , ConnectOk(false)
+
+    {
+        // Bind to any free local port; only a valid socket is counted.
+        s.Open(0);
+        if (s.IsValid()) {
+            AtomicAdd(ActivePortTestersCount, 1);
+        }
+    }
+
+    // Connects the socket to the peer (a connected UDP socket is required for
+    // port-unreachable errors to be delivered, see note in net_test.h) and
+    // restarts the probe timer.
+    void TPortUnreachableTester::Connect(const TUdpAddress& addr) {
+        Y_ASSERT(IsValid());
+        sockaddr_in6 toAddress;
+        GetWinsockAddr(&toAddress, addr);
+        ConnectOk = s.Connect(toAddress);
+        TimePassed = 0;
+    }
+
+    TPortUnreachableTester::~TPortUnreachableTester() {
+        if (s.IsValid())
+            AtomicAdd(ActivePortTestersCount, -1);
+    }
+
+    // Advances the timer by deltaT and sends an empty probe every PING_DELAY.
+    // Returns false once the peer is known unreachable (failed connect or an
+    // ICMP port-unreachable reported on the socket); true means keep waiting.
+    bool TPortUnreachableTester::Test(float deltaT) {
+        if (!ConnectOk)
+            return false;
+        if (s.IsHostUnreachable())
+            return false;
+        TimePassed += deltaT;
+        if (TimePassed > PING_DELAY) {
+            TimePassed = 0;
+            s.SendEmptyPacket();
+        }
+        return true;
+    }
+}
diff --git a/library/cpp/netliba/v6/net_test.h b/library/cpp/netliba/v6/net_test.h
new file mode 100644
index 0000000000..cfff7409f5
--- /dev/null
+++ b/library/cpp/netliba/v6/net_test.h
@@ -0,0 +1,28 @@
+#pragma once
+
+#include "udp_socket.h"
+
+namespace NNetliba {
+    struct TUdpAddress;
+
+    // needed to limit simultaneous port testers to avoid limit on open handles count
+    extern TAtomic ActivePortTestersCount;
+
+    // need separate socket for each destination
+    // FreeBSD can not return port unreachable error for unconnected socket
+    class TPortUnreachableTester: public TThrRefBase {
+        TNetSocket s;
+        float TimePassed;
+        bool ConnectOk;
+
+        // private: instances are destroyed only through TThrRefBase ref-counting
+        ~TPortUnreachableTester() override;
+
+    public:
+        TPortUnreachableTester();
+        bool IsValid() const {
+            return s.IsValid();
+        }
+        void Connect(const TUdpAddress& addr);
+        bool Test(float deltaT);
+    };
+}
diff --git a/library/cpp/netliba/v6/stdafx.cpp b/library/cpp/netliba/v6/stdafx.cpp
new file mode 100644
index 0000000000..fd4f341c7b
--- /dev/null
+++ b/library/cpp/netliba/v6/stdafx.cpp
@@ -0,0 +1 @@
+#include "stdafx.h"
diff --git a/library/cpp/netliba/v6/stdafx.h b/library/cpp/netliba/v6/stdafx.h
new file mode 100644
index 0000000000..42ea5c0894
--- /dev/null
+++ b/library/cpp/netliba/v6/stdafx.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "cstdafx.h"
+
+#include <util/system/compat.h>
+#include <util/network/init.h>
+#if defined(_unix_)
+#include <netdb.h>
+#include <fcntl.h>
+#elif defined(_win_)
+#include <winsock2.h>
+using socklen_t = int;
+#endif
+
+#include <util/generic/ptr.h>
+
+template <class T>
+static const T* BreakAliasing(const void* f) { + return (const T*)f; +} + +template <class T> +static T* BreakAliasing(void* f) { + return (T*)f; +} diff --git a/library/cpp/netliba/v6/udp_address.cpp b/library/cpp/netliba/v6/udp_address.cpp new file mode 100644 index 0000000000..17540602e9 --- /dev/null +++ b/library/cpp/netliba/v6/udp_address.cpp @@ -0,0 +1,300 @@ +#include "stdafx.h" +#include "udp_address.h" + +#include <util/system/mutex.h> +#include <util/system/spinlock.h> + +#ifdef _win_ +#include <iphlpapi.h> +#pragma comment(lib, "Iphlpapi.lib") +#else +#include <errno.h> +#include <ifaddrs.h> +#endif + +namespace NNetliba { + static bool IsValidIPv6(const char* sz) { + enum { + S1, + SEMICOLON, + SCOPE + }; + int state = S1, scCount = 0, digitCount = 0, hasDoubleSemicolon = false; + while (*sz) { + if (state == S1) { + switch (*sz) { + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case 'A': + case 'B': + case 'C': + case 'D': + case 'E': + case 'F': + case 'a': + case 'b': + case 'c': + case 'd': + case 'e': + case 'f': + ++digitCount; + if (digitCount > 4) + return false; + break; + case ':': + state = SEMICOLON; + ++scCount; + break; + case '%': + state = SCOPE; + break; + default: + return false; + } + ++sz; + } else if (state == SEMICOLON) { + if (*sz == ':') { + if (hasDoubleSemicolon) + return false; + hasDoubleSemicolon = true; + ++scCount; + digitCount = 0; + state = S1; + ++sz; + } else { + digitCount = 0; + state = S1; + } + } else if (state == SCOPE) { + // arbitrary string is allowed as scope id + ++sz; + } + } + if (!hasDoubleSemicolon && scCount != 7) + return false; + return scCount <= 7; + } + + static bool ParseInetName(TUdpAddress* pRes, const char* name, int nDefaultPort, EUdpAddressType addressType) { + int nPort = nDefaultPort; + + TString host; + if (name[0] == '[') { + ++name; + const char* nameFin = name; + for (; *nameFin; ++nameFin) { + if (nameFin[0] == 
']') + break; + } + host.assign(name, nameFin); + Y_ASSERT(IsValidIPv6(host.c_str())); + name = *nameFin ? nameFin + 1 : nameFin; + if (name[0] == ':') { + char* endPtr = nullptr; + nPort = strtol(name + 1, &endPtr, 10); + if (!endPtr || *endPtr != '\0') + return false; + } + } else { + host = name; + if (!IsValidIPv6(name)) { + size_t nIdx = host.find(':'); + if (nIdx != (size_t)TString::npos) { + const char* pszPort = host.c_str() + nIdx + 1; + char* endPtr = nullptr; + nPort = strtol(pszPort, &endPtr, 10); + if (!endPtr || *endPtr != '\0') + return false; + host.resize(nIdx); + } + } + } + + addrinfo aiHints; + Zero(aiHints); + aiHints.ai_family = AF_UNSPEC; + aiHints.ai_socktype = SOCK_DGRAM; + aiHints.ai_protocol = IPPROTO_UDP; + + // Do not use TMutex here: it has a non-trivial destructor which will be called before + // destruction of current thread, if its TThread declared as global/static variable. + static TAdaptiveLock cs; + TGuard lock(cs); + + addrinfo* aiList = nullptr; + for (int attempt = 0; attempt < 1000; ++attempt) { + int rv = getaddrinfo(host.c_str(), "1313", &aiHints, &aiList); + if (rv == 0) + break; + if (aiList) { + freeaddrinfo(aiList); + } + if (rv != EAI_AGAIN) { + return false; + } + usleep(100 * 1000); + } + for (addrinfo* ptr = aiList; ptr; ptr = ptr->ai_next) { + sockaddr* addr = ptr->ai_addr; + if (addr == nullptr) + continue; + switch (addressType) { + case UAT_ANY: { + if (addr->sa_family != AF_INET && addr->sa_family != AF_INET6) + continue; + break; + } + case UAT_IPV4: { + if (addr->sa_family != AF_INET) + continue; + break; + } + case UAT_IPV6: { + if (addr->sa_family != AF_INET6) + continue; + break; + } + } + + GetUdpAddress(pRes, *(sockaddr_in6*)addr); + pRes->Port = nPort; + freeaddrinfo(aiList); + return true; + } + freeaddrinfo(aiList); + return false; + } + + bool GetLocalAddresses(TVector<TUdpAddress>* addrs) { +#ifdef _win_ + TVector<char> buf; + buf.resize(1000000); + PIP_ADAPTER_ADDRESSES adapterBuf = 
(PIP_ADAPTER_ADDRESSES)&buf[0]; + ULONG bufSize = buf.ysize(); + + ULONG rv = GetAdaptersAddresses(AF_UNSPEC, 0, NULL, adapterBuf, &bufSize); + if (rv != ERROR_SUCCESS) + return false; + for (PIP_ADAPTER_ADDRESSES ptr = adapterBuf; ptr; ptr = ptr->Next) { + if ((ptr->Flags & (IP_ADAPTER_IPV4_ENABLED | IP_ADAPTER_IPV6_ENABLED)) == 0) { + continue; + } + if (ptr->IfType == IF_TYPE_TUNNEL) { + // ignore tunnels + continue; + } + if (ptr->OperStatus != IfOperStatusUp) { + // ignore disable adapters + continue; + } + if (ptr->Mtu < 1280) { + fprintf(stderr, "WARNING: MTU %d is less then ipv6 minimum", ptr->Mtu); + } + for (IP_ADAPTER_UNICAST_ADDRESS* addr = ptr->FirstUnicastAddress; addr; addr = addr->Next) { + sockaddr* x = (sockaddr*)addr->Address.lpSockaddr; + if (x == 0) + continue; + if (x->sa_family == AF_INET || x->sa_family == AF_INET6) { + TUdpAddress address; + sockaddr_in6* xx = (sockaddr_in6*)x; + GetUdpAddress(&address, *xx); + addrs->push_back(address); + } + } + } + return true; +#else + ifaddrs* ifap; + if (getifaddrs(&ifap) != -1) { + for (ifaddrs* ifa = ifap; ifa; ifa = ifa->ifa_next) { + sockaddr* sa = (sockaddr*)ifa->ifa_addr; + if (sa == nullptr) + continue; + if (sa->sa_family == AF_INET || sa->sa_family == AF_INET6) { + TUdpAddress address; + sockaddr_in6* xx = (sockaddr_in6*)sa; + GetUdpAddress(&address, *xx); + addrs->push_back(address); + } + } + freeifaddrs(ifap); + return true; + } + return false; +#endif + } + + void GetUdpAddress(TUdpAddress* res, const sockaddr_in6& addr) { + if (addr.sin6_family == AF_INET) { + const sockaddr_in& addr4 = *(const sockaddr_in*)&addr; + res->Network = 0; + res->Interface = 0xffff0000ll + (((ui64)(ui32)addr4.sin_addr.s_addr) << 32); + res->Scope = 0; + res->Port = ntohs(addr4.sin_port); + } else if (addr.sin6_family == AF_INET6) { + res->Network = *BreakAliasing<ui64>(addr.sin6_addr.s6_addr + 0); + res->Interface = *BreakAliasing<ui64>(addr.sin6_addr.s6_addr + 8); + res->Scope = addr.sin6_scope_id; + 
res->Port = ntohs(addr.sin6_port);
+        }
+    }
+
+    // Fills a sockaddr_in6 from a TUdpAddress. Always emits AF_INET6: IPv4
+    // peers travel as IPv4-mapped addresses (the dedicated AF_INET branch is
+    // deliberately disabled below).
+    void GetWinsockAddr(sockaddr_in6* res, const TUdpAddress& addr) {
+        if (0) { //addr.IsIPv4()) {
+            // use ipv4 to ipv6 mapping
+            //// ipv4
+            //sockaddr_in &toAddress = *(sockaddr_in*)res;
+            //Zero(toAddress);
+            //toAddress.sin_family = AF_INET;
+            //toAddress.sin_addr.s_addr = addr.GetIPv4();
+            //toAddress.sin_port = htons((u_short)addr.Port);
+        } else {
+            // ipv6
+            sockaddr_in6& toAddress = *(sockaddr_in6*)res;
+            Zero(toAddress);
+            toAddress.sin6_family = AF_INET6;
+            *BreakAliasing<ui64>(toAddress.sin6_addr.s6_addr + 0) = addr.Network;
+            *BreakAliasing<ui64>(toAddress.sin6_addr.s6_addr + 8) = addr.Interface;
+            toAddress.sin6_scope_id = addr.Scope;
+            toAddress.sin6_port = htons((u_short)addr.Port);
+        }
+    }
+
+    // Resolves "server" (hostname[:port], ip[:port] or [ipv6]:port) into a
+    // TUdpAddress. NOTE(review): ParseInetName's failure is ignored, so a
+    // resolution error yields a zero-initialized address — confirm callers
+    // treat that as invalid.
+    TUdpAddress CreateAddress(const TString& server, int defaultPort, EUdpAddressType addressType) {
+        TUdpAddress res;
+        ParseInetName(&res, server.c_str(), defaultPort, addressType);
+        return res;
+    }
+
+    // Renders "a.b.c.d:port" for IPv4-mapped addresses, otherwise
+    // "[x:x:x:x:x:x:x:x%scope]:port" (scope suffix only when Scope != 0).
+    TString GetAddressAsString(const TUdpAddress& addr) {
+        char buf[1000];
+        if (addr.IsIPv4()) {
+            int ip = addr.GetIPv4();
+            sprintf(buf, "%d.%d.%d.%d:%d",
+                    (ip >> 0) & 0xff, (ip >> 8) & 0xff,
+                    (ip >> 16) & 0xff, (ip >> 24) & 0xff,
+                    addr.Port);
+        } else {
+            ui16 ipv6[8];
+            *BreakAliasing<ui64>(ipv6) = addr.Network;
+            *BreakAliasing<ui64>(ipv6 + 4) = addr.Interface;
+            char suffix[100] = "";
+            if (addr.Scope != 0) {
+                sprintf(suffix, "%%%d", addr.Scope);
+            }
+            sprintf(buf, "[%x:%x:%x:%x:%x:%x:%x:%x%s]:%d",
+                    ntohs(ipv6[0]), ntohs(ipv6[1]), ntohs(ipv6[2]), ntohs(ipv6[3]),
+                    ntohs(ipv6[4]), ntohs(ipv6[5]), ntohs(ipv6[6]), ntohs(ipv6[7]),
+                    suffix, addr.Port);
+        }
+        return buf;
+    }
+}
diff --git a/library/cpp/netliba/v6/udp_address.h b/library/cpp/netliba/v6/udp_address.h
new file mode 100644
index 0000000000..3e283fe545
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_address.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <util/generic/string.h>
+#include <util/generic/vector.h>
+#include <util/system/defaults.h>
+
+struct sockaddr_in6;
+
+namespace NNetliba {
+    // A 16-byte IPv6 address stored as two 64-bit halves (Network =
+    // bytes 0..7, Interface = bytes 8..15) plus scope id and UDP port.
+    // IPv4 peers are kept IPv4-mapped; see IsIPv4()/GetIPv4().
+    struct TUdpAddress {
+        ui64 Network, Interface;
+        int Scope, Port;
+
+        TUdpAddress()
+            : Network(0)
+            , Interface(0)
+            , Scope(0)
+            , Port(0)
+        {
+        }
+        bool IsIPv4() const {
+            return (Network == 0 && (Interface & 0xffffffffll) == 0xffff0000ll);
+        }
+        ui32 GetIPv4() const {
+            return Interface >> 32;
+        }
+    };
+
+    inline bool operator==(const TUdpAddress& a, const TUdpAddress& b) {
+        return a.Network == b.Network && a.Interface == b.Interface && a.Scope == b.Scope && a.Port == b.Port;
+    }
+
+    enum EUdpAddressType {
+        UAT_ANY,
+        UAT_IPV4,
+        UAT_IPV6,
+    };
+
+    // accepts sockaddr_in & sockaddr_in6
+    void GetUdpAddress(TUdpAddress* res, const sockaddr_in6& addr);
+    // generates sockaddr_in6 for both ipv4 & ipv6
+    void GetWinsockAddr(sockaddr_in6* res, const TUdpAddress& addr);
+    // supports formats like hostname, hostname:124, 127.0.0.1, 127.0.0.1:80, fe34::12, [fe34::12]:80
+    TUdpAddress CreateAddress(const TString& server, int defaultPort, EUdpAddressType type = UAT_ANY);
+    TString GetAddressAsString(const TUdpAddress& addr);
+
+    bool GetLocalAddresses(TVector<TUdpAddress>* addrs);
+}
diff --git a/library/cpp/netliba/v6/udp_client_server.cpp b/library/cpp/netliba/v6/udp_client_server.cpp
new file mode 100644
index 0000000000..3eaf6e5e96
--- /dev/null
+++ b/library/cpp/netliba/v6/udp_client_server.cpp
@@ -0,0 +1,1321 @@
+#include "stdafx.h"
+#include "udp_client_server.h"
+#include "net_acks.h"
+#include <util/generic/guid.h>
+#include <util/system/hp_timer.h>
+#include <util/datetime/cputimer.h>
+#include <util/system/yield.h>
+#include <util/system/unaligned_mem.h>
+#include "block_chain.h"
+#include <util/system/shmat.h>
+#include "udp_debug.h"
+#include "udp_socket.h"
+#include "ib_cs.h"
+
+#include <library/cpp/netliba/socket/socket.h>
+
+#include <util/random/random.h>
+#include <util/system/sanitizers.h>
+
+#include <atomic>
+
+namespace NNetliba {
+    // rely on UDP checksum in packets, check crc only for complete 
packets + // UPDATE: looks like UDP checksum is not enough, network errors do happen, we saw 600+ retransmits of a ~1MB data packet + + const float UDP_TRANSFER_TIMEOUT = 90.0f; + const float DEFAULT_MAX_WAIT_TIME = 1; + const float UDP_KEEP_PEER_INFO = 600; + // траффик может идти, а новых данных для конкретного пакета может не добавляться. + // это возможно когда мы прерываем процесс в момент передачи и перезапускаем его на том же порту, + // тогда на приемнике повиснет пакет. Этот пакет мы зашибем по этому таймауту + const float UDP_MAX_INPUT_DATA_WAIT = UDP_TRANSFER_TIMEOUT * 2; + + enum { + UDP_PACKET_SIZE_FULL = 8900, // used for ping to detect jumbo-frame support + UDP_PACKET_SIZE = 8800, // max data in packet + UDP_PACKET_SIZE_SMALL = 1350, // 1180 would be better taking into account that 1280 is guaranteed ipv6 minimum MTU + UDP_PACKET_BUF_SIZE = UDP_PACKET_SIZE + 100, + }; + + ////////////////////////////////////////////////////////////////////////// + struct TUdpCompleteInTransfer { + TGUID PacketGuid; + }; + + ////////////////////////////////////////////////////////////////////////// + struct TUdpRecvPacket { + int DataStart, DataSize; + ui32 BlockSum; + // Data[] should be last member in struct, this fact is used to create truncated TUdpRecvPacket in CreateNewSmallPacket() + char Data[UDP_PACKET_BUF_SIZE]; + }; + + struct TUdpInTransfer { + private: + TVector<TUdpRecvPacket*> Packets; + + public: + sockaddr_in6 ToAddress; + int PacketSize, LastPacketSize; + bool HasLastPacket; + TVector<int> NewPacketsToAck; + TCongestionControlPtr Congestion; + float TimeSinceLastRecv; + int Attempt; + TGUID PacketGuid; + int Crc32; + TIntrusivePtr<TSharedMemory> SharedData; + TRequesterPendingDataStats* Stats; + + TUdpInTransfer() + : PacketSize(0) + , LastPacketSize(0) + , HasLastPacket(false) + , TimeSinceLastRecv(0) + , Attempt(0) + , Crc32(0) + , Stats(nullptr) + { + Zero(ToAddress); + } + ~TUdpInTransfer() { + if (Stats) { + Stats->InpCount -= 1; + } + 
EraseAllPackets(); + } + void EraseAllPackets() { + for (int i = 0; i < Packets.ysize(); ++i) { + ErasePacket(i); + } + Packets.clear(); + HasLastPacket = false; + } + void AttachStats(TRequesterPendingDataStats* stats) { + Stats = stats; + Stats->InpCount += 1; + Y_ASSERT(Packets.empty()); + } + void ErasePacket(int id) { + TUdpRecvPacket* pkt = Packets[id]; + if (pkt) { + if (Stats) { + Stats->InpDataSize -= PacketSize; + } + TRopeDataPacket::FreeBuf((char*)pkt); + Packets[id] = nullptr; + } + } + void AssignPacket(int id, TUdpRecvPacket* pkt) { + ErasePacket(id); + if (pkt && Stats) { + Stats->InpDataSize += PacketSize; + } + Packets[id] = pkt; + } + int GetPacketCount() const { + return Packets.ysize(); + } + void SetPacketCount(int n) { + Packets.resize(n, nullptr); + } + const TUdpRecvPacket* GetPacket(int id) const { + return Packets[id]; + } + TUdpRecvPacket* ExtractPacket(int id) { + TUdpRecvPacket* res = Packets[id]; + if (res) { + if (Stats) { + Stats->InpDataSize -= PacketSize; + } + Packets[id] = nullptr; + } + return res; + } + }; + + struct TUdpOutTransfer { + sockaddr_in6 ToAddress; + TAutoPtr<TRopeDataPacket> Data; + int PacketCount; + int PacketSize, LastPacketSize; + TAckTracker AckTracker; + int Attempt; + TGUID PacketGuid; + int Crc32; + EPacketPriority PacketPriority; + TRequesterPendingDataStats* Stats; + + TUdpOutTransfer() + : PacketCount(0) + , PacketSize(0) + , LastPacketSize(0) + , Attempt(0) + , Crc32(0) + , PacketPriority(PP_LOW) + , Stats(nullptr) + { + Zero(ToAddress); + } + ~TUdpOutTransfer() { + if (Stats) { + Stats->OutCount -= 1; + Stats->OutDataSize -= Data->GetSize(); + } + } + void AttachStats(TRequesterPendingDataStats* stats) { + Stats = stats; + Stats->OutCount += 1; + Stats->OutDataSize += Data->GetSize(); + } + }; + + struct TTransferKey { + TUdpAddress Address; + int Id; + }; + inline bool operator==(const TTransferKey& a, const TTransferKey& b) { + return a.Address == b.Address && a.Id == b.Id; + } + struct 
// Hash functor for TTransferKey (hash-map key for per-transfer tables).
struct TTransferKeyHash {
    int operator()(const TTransferKey& k) const {
        return (ui32)k.Address.Interface + (ui32)k.Address.Port * (ui32)389461 + (ui32)k.Id;
    }
};

// Hash functor for TUdpAddress (hash-map key for per-peer tables).
struct TUdpAddressHash {
    int operator()(const TUdpAddress& addr) const {
        return (ui32)addr.Interface + (ui32)addr.Port * (ui32)389461;
    }
};

// Recycling allocator for receive buffers: always keeps one full-size
// TUdpRecvPacket ready for the next RecvFrom(); when the host decides to keep a
// packet it calls ExtractPacket() (ownership moves out, a fresh buffer is
// allocated), or CreateNewSmallPacket() to copy the payload into a right-sized
// buffer instead.
class TUdpHostRevBufAlloc: public TNonCopyable {
    TUdpRecvPacket* RecvPktBuf; // the currently available receive buffer

    void AllocNewBuf() {
        RecvPktBuf = (TUdpRecvPacket*)TRopeDataPacket::AllocBuf(sizeof(TUdpRecvPacket));
    }

public:
    TUdpHostRevBufAlloc() {
        AllocNewBuf();
    }
    ~TUdpHostRevBufAlloc() {
        FreeBuf(RecvPktBuf);
    }
    void FreeBuf(TUdpRecvPacket* pkt) {
        TRopeDataPacket::FreeBuf((char*)pkt);
    }
    // Hands the current buffer to the caller and immediately allocates a new one.
    TUdpRecvPacket* ExtractPacket() {
        TUdpRecvPacket* res = RecvPktBuf;
        AllocNewBuf();
        return res;
    }
    // Allocates a truncated TUdpRecvPacket whose Data[] holds exactly sz bytes;
    // relies on Data[] being the last member of TUdpRecvPacket.
    TUdpRecvPacket* CreateNewSmallPacket(int sz) {
        int pktStructSz = sizeof(TUdpRecvPacket) - Y_ARRAY_SIZE(RecvPktBuf->Data) + sz;
        TUdpRecvPacket* pkt = (TUdpRecvPacket*)TRopeDataPacket::AllocBuf(pktStructSz);
        return pkt;
    }
    int GetBufSize() const {
        return Y_ARRAY_SIZE(RecvPktBuf->Data);
    }
    char* GetDataPtr() const {
        return RecvPktBuf->Data;
    }
};

// Transfer ids start from a pseudo-random value so that a restarted process on
// the same port does not reuse recent ids.
static TAtomic transferIdCounter = (long)(GetCycleCount() & 0x1fffffff);
inline int GetTransferId() {
    int res = AtomicAdd(transferIdCounter, 1);
    while (res < 0) {
        // negative transfer ids are treated as errors, so wrap transfer id
        // NOTE(review): the compare value below re-reads the counter non-atomically,
        // so this CAS may spuriously fail under contention — benign, since the
        // surrounding loop retries until a non-negative id is produced.
        AtomicCas(&transferIdCounter, 0, transferIdCounter);
        res = AtomicAdd(transferIdCounter, 1);
    }
    return res;
}

// Global switch: attempt to establish InfiniBand connections (disabled via DisableIBDetection()).
static bool IBDetection = true;

// The UDP requester implementation: owns the socket, all in/out transfer state,
// per-peer congestion tracking and the optional InfiniBand fast path.
class TUdpHost: public IUdpHost {
    // Per-peer state: UDP congestion control plus the (optional) IB connection.
    struct TPeerLink {
        TIntrusivePtr<TCongestionControl> UdpCongestion;
        TIntrusivePtr<IIBPeer> IBPeer;
        double TimeNoActiveTransfers; // how long this peer has been idle (history mode)

        TPeerLink()
            : TimeNoActiveTransfers(0)
        {
        }
        // Returns false when the peer is considered dead (UDP_TRANSFER_TIMEOUT exceeded).
        bool Update(float deltaT, const TUdpAddress& toAddress, float* maxWaitTime) {
            bool updateOk = UdpCongestion->UpdateAlive(toAddress, deltaT, UDP_TRANSFER_TIMEOUT, maxWaitTime);
            return updateOk;
        }
        // Called when the last active transfer finishes: park this link in history.
        void StartSleep(const TUdpAddress& toAddress, float* maxWaitTime) {
            //printf("peer_link start sleep, IBPeer = %p, refs = %d\n", IBPeer.Get(), (int)IBPeer.RefCount());
            UdpCongestion->UpdateAlive(toAddress, 0, UDP_TRANSFER_TIMEOUT, maxWaitTime);
            UdpCongestion->MarkAlive();
            TimeNoActiveTransfers = 0;
        }
        // Returns false when the idle link should be dropped from history.
        bool UpdateSleep(float deltaT) {
            TimeNoActiveTransfers += deltaT;
            if (IBPeer.Get()) {
                //printf("peer_link update sleep, IBPeer = %p, refs = %d\n", IBPeer.Get(), (int)IBPeer.RefCount());
                if (IBPeer->GetState() == IIBPeer::OK) {
                    return true; // keep as long as a healthy IB connection exists
                }
                //printf("Drop broken IB connection\n");
                IBPeer = nullptr;
            }
            return (TimeNoActiveTransfers < UDP_KEEP_PEER_INFO);
        }
    };

    TNetSocket s;
    typedef THashMap<TTransferKey, TUdpInTransfer, TTransferKeyHash> TUdpInXferHash;
    typedef THashMap<TTransferKey, TUdpOutTransfer, TTransferKeyHash> TUdpOutXferHash;
    // congestion control per peer
    typedef THashMap<TUdpAddress, TPeerLink, TUdpAddressHash> TPeerLinkHash;
    typedef THashMap<TTransferKey, TUdpCompleteInTransfer, TTransferKeyHash> TUdpCompleteInXferHash;
    typedef THashMap<TUdpAddress, TIntrusivePtr<TPeerQueueStats>, TUdpAddressHash> TQueueStatsHash;
    TUdpInXferHash RecvQueue;                 // in-progress incoming transfers
    TUdpCompleteInXferHash RecvCompleted;     // recently completed incoming transfers (for duplicate re-ack)
    TUdpOutXferHash SendQueue;                // in-progress outgoing transfers
    TPeerLinkHash CongestionTrack, CongestionTrackHistory; // active peers / idle peers
    TList<TRequest*> ReceivedList;            // fully assembled requests awaiting GetRequest()
    NHPTimer::STime CurrentT;
    TList<TSendResult> SendResults;           // completions awaiting GetSendResult()
    TList<TTransferKey> SendOrderLow, SendOrder, SendOrderHighPrior; // round-robin send lists per priority
    TAtomic IsWaiting;                        // set while blocked in Wait(); checked by CancelWait()
    float MaxWaitTime;
    std::atomic<float> MaxWaitTime2;          // cross-thread override for Wait(); reset each Wait()
    float IBIdleTime;
    TVector<TTransferKey> RecvCompletedQueue, KeepCompletedQueue; // two-generation aging of RecvCompleted
    float TimeSinceCompletedQueueClean, TimeSinceCongestionHistoryUpdate;
    TRequesterPendingDataStats PendingDataStats;
    TQueueStatsHash PeerQueueStats;
    TIntrusivePtr<IIBClientServer> IB;        // null when IB is unavailable/disabled
    typedef THashMap<TIBMsgHandle, TTransferKey> TIBtoTransferKeyHash;
    TIBtoTransferKeyHash IBKeyToTransferKey;  // maps IB send handles back to transfers
    char PktBuf[UDP_PACKET_BUF_SIZE]; // scratch buffer for building outgoing datagrams
    TUdpHostRevBufAlloc RecvBuf;      // recycling receive-buffer allocator

    // Returns the peer link for ip, creating it (or reviving it from history) on demand.
    TPeerLink& GetPeerLink(const TUdpAddress& ip) {
        TPeerLinkHash::iterator z = CongestionTrack.find(ip);
        if (z == CongestionTrack.end()) {
            z = CongestionTrackHistory.find(ip);
            if (z == CongestionTrackHistory.end()) {
                // brand new peer
                TPeerLink& res = CongestionTrack[ip];
                Y_ASSERT(res.UdpCongestion.Get() == nullptr);
                res.UdpCongestion = new TCongestionControl;
                TQueueStatsHash::iterator zq = PeerQueueStats.find(ip);
                if (zq != PeerQueueStats.end()) {
                    res.UdpCongestion->AttachQueueStats(zq->second);
                }
                return res;
            } else {
                // move the idle link back to the active table
                TPeerLink& res = CongestionTrack[z->first];
                res = z->second;
                CongestionTrackHistory.erase(z);
                return res;
            }
        } else {
            Y_ASSERT(CongestionTrackHistory.find(ip) == CongestionTrackHistory.end());
            return z->second;
        }
    }
    void SucceededSend(int id) {
        SendResults.push_back(TSendResult(id, true));
    }
    void FailedSend(int id) {
        SendResults.push_back(TSendResult(id, false));
    }
    void SendData(TList<TTransferKey>* order, float deltaT, bool needCheckAlive);
    void RecvCycle();

public:
    TUdpHost()
        : CurrentT(0)
        , IsWaiting(0)
        , MaxWaitTime(DEFAULT_MAX_WAIT_TIME)
        , MaxWaitTime2(DEFAULT_MAX_WAIT_TIME)
        , IBIdleTime(0)
        , TimeSinceCompletedQueueClean(0)
        , TimeSinceCongestionHistoryUpdate(0)
    {
    }
    ~TUdpHost() override {
        // requests never handed out via GetRequest() are owned here
        for (TList<TRequest*>::const_iterator i = ReceivedList.begin(); i != ReceivedList.end(); ++i)
            delete *i;
    }

    bool Start(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket);

    // Pops the next fully received request (UDP first, then IB); null when none.
    TRequest* GetRequest() override {
        if (ReceivedList.empty()) {
            if (IB.Get()) {
                return IB->GetRequest();
            }
            return nullptr;
        }
        TRequest* res = ReceivedList.front();
        ReceivedList.pop_front();
        return res;
    }

    // Queues the transfer into the list matching its priority and wakes Wait().
    void AddToSendOrder(const TTransferKey& transferKey, EPacketPriority pp) {
        if (pp == PP_LOW)
            SendOrderLow.push_back(transferKey);
        else if (pp == PP_NORMAL)
            SendOrder.push_back(transferKey);
        else if (pp == PP_HIGH)
            SendOrderHighPrior.push_back(transferKey);
        else
            Y_ASSERT(0);

        CancelWait();
    }

    // Starts a new outgoing transfer; returns its transfer id. crc32 must be
    // precomputed by the caller; a guid is generated when *packetGuid is empty.
    int Send(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data, int crc32, TGUID* packetGuid, EPacketPriority pp) override {
        if (addr.Port == 0) {
            // shortcut for broken addresses: report an immediate failure
            if (packetGuid && packetGuid->IsEmpty())
                CreateGuid(packetGuid);
            int reqId = GetTransferId();
            FailedSend(reqId);
            return reqId;
        }
        TTransferKey transferKey;
        transferKey.Address = addr;
        transferKey.Id = GetTransferId();
        Y_ASSERT(SendQueue.find(transferKey) == SendQueue.end());

        TPeerLink& peerInfo = GetPeerLink(transferKey.Address);

        TUdpOutTransfer& xfer = SendQueue[transferKey];
        GetWinsockAddr(&xfer.ToAddress, transferKey.Address);
        xfer.Crc32 = crc32;
        xfer.PacketPriority = pp;
        if (!packetGuid || packetGuid->IsEmpty()) {
            CreateGuid(&xfer.PacketGuid);
            if (packetGuid)
                *packetGuid = xfer.PacketGuid;
        } else {
            xfer.PacketGuid = *packetGuid;
        }
        xfer.Data.Reset(data.Release());
        xfer.AttachStats(&PendingDataStats);
        xfer.AckTracker.AttachCongestionControl(peerInfo.UdpCongestion.Get());

        bool isSentOverIB = false;
        // we don't support priorities (=service levels in IB terms) currently
        // so send only PP_NORMAL traffic over IB
        if (pp == PP_NORMAL && peerInfo.IBPeer.Get() && xfer.Data->GetSharedData() == nullptr) {
            TIBMsgHandle hndl = IB->Send(peerInfo.IBPeer, xfer.Data.Get(), xfer.PacketGuid);
            if (hndl >= 0) {
                IBKeyToTransferKey[hndl] = transferKey;
                isSentOverIB = true;
            } else {
                // so we failed to use IB, ibPeer is either not connected yet or failed
                if (peerInfo.IBPeer->GetState() == IIBPeer::FAILED) {
                    //printf("Disconnect failed IB peer\n");
                    peerInfo.IBPeer = nullptr;
                }
            }
        }
        if (!isSentOverIB) {
            AddToSendOrder(transferKey, pp);
        }

        return transferKey.Id;
    }

    // Pops one completed-send notification; IB completions are folded in here,
    // and a failed IB send falls back to the regular UDP path instead of failing.
    bool GetSendResult(TSendResult* res) override {
        if (SendResults.empty()) {
            if (IB.Get()) {
                TIBSendResult sr;
                if (IB->GetSendResult(&sr)) {
                    TIBtoTransferKeyHash::iterator z = IBKeyToTransferKey.find(sr.Handle);
                    if (z == IBKeyToTransferKey.end()) {
                        Y_VERIFY(0, "unknown handle returned from IB");
                    }
                    TTransferKey transferKey = z->second;
                    IBKeyToTransferKey.erase(z);

                    TUdpOutXferHash::iterator i = SendQueue.find(transferKey);
                    if (i == SendQueue.end()) {
                        Y_VERIFY(0, "IBKeyToTransferKey refers nonexisting xfer");
                    }
                    if (sr.Success) {
                        TUdpOutTransfer& xfer = i->second;
                        xfer.AckTracker.MarkAlive(); // do we really need this?
                        *res = TSendResult(transferKey.Id, sr.Success);
                        SendQueue.erase(i);
                        return true;
                    } else {
                        //printf("IB send failed, fall back to regular network\n");
                        // Houston, we got a problem
                        // IB failed to send, try to use regular network
                        TUdpOutTransfer& xfer = i->second;
                        AddToSendOrder(transferKey, xfer.PacketPriority);
                    }
                }
            }
            return false;
        }
        *res = SendResults.front();
        SendResults.pop_front();
        return true;
    }

    void Step() override;
    void IBStep() override;

    // Sleeps up to `seconds`, bounded by MaxWaitTime/MaxWaitTime2, polling IB
    // (if present) in 2 ms slices so IB events are not delayed by a long sleep.
    void Wait(float seconds) override {
        if (seconds < 1e-3)
            seconds = 0;
        if (seconds > MaxWaitTime)
            seconds = MaxWaitTime;
        if (IBIdleTime < 0.010) {
            seconds = 0; // IB was active very recently — keep spinning
        }
        if (seconds == 0) {
            ThreadYield();
        } else {
            AtomicAdd(IsWaiting, 1);
            // MaxWaitTime2 may have been zeroed by CancelWait() from another thread
            if (seconds > MaxWaitTime2)
                seconds = MaxWaitTime2;
            MaxWaitTime2 = DEFAULT_MAX_WAIT_TIME;

            if (seconds == 0) {
                ThreadYield();
            } else {
                if (IB.Get()) {
                    for (float done = 0; done < seconds;) {
                        float deltaSleep = Min(seconds - done, 0.002f);
                        s.Wait(deltaSleep);
                        NHPTimer::STime tChk;
                        NHPTimer::GetTime(&tChk);
                        if (IB->Step(tChk)) {
                            IBIdleTime = 0;
                            break; // IB did work — return to the caller promptly
                        }
                        done += deltaSleep;
                    }
                } else {
                    s.Wait(seconds);
                }
            }
            AtomicAdd(IsWaiting, -1);
        }
    }

    // Thread safe: interrupts a concurrent Wait() by shortening its deadline and
    // poking the socket with a self-addressed packet.
    void CancelWait() override {
        MaxWaitTime2 = 0;
        if (AtomicAdd(IsWaiting, 0) == 1) {
            s.SendSelfFakePacket();
        }
    }
    // Copies the running pending-data counters; in debug builds recomputes them
    // from the queues and asserts the incremental accounting is consistent.
    void GetPendingDataSize(TRequesterPendingDataStats* res) override {
        *res = PendingDataStats;
#ifndef NDEBUG
        TRequesterPendingDataStats chk;
        for (TUdpOutXferHash::const_iterator i = SendQueue.begin(); i != SendQueue.end(); ++i) {
            TRopeDataPacket* pckt = i->second.Data.Get();
            if (pckt) {
                chk.OutDataSize += pckt->GetSize();
                ++chk.OutCount;
            }
        }
        for (TUdpInXferHash::const_iterator i = RecvQueue.begin(); i != RecvQueue.end(); ++i) {
            const TUdpInTransfer& tr = i->second;
            for (int p = 0; p < tr.GetPacketCount(); ++p) {
                if (tr.GetPacket(p)) {
                    chk.InpDataSize += tr.PacketSize;
                }
            }
            ++chk.InpCount;
        }
        Y_ASSERT(memcmp(&chk, res, sizeof(chk)) == 0);
#endif
    }
    TString GetDebugInfo() override;
    TString GetPeerLinkDebug(const TPeerLinkHash& ch);
    void Kill(const TUdpAddress& addr) override;
    TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) override;
};

// Binds the host to an already opened socket and (optionally) starts IB detection.
bool TUdpHost::Start(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) {
    if (s.IsValid()) {
        Y_ASSERT(0); // Start() must be called exactly once
        return false;
    }
    s.Open(socket);
    if (!s.IsValid())
        return false;

    if (IBDetection)
        IB = CreateIBClientServer();

    NHPTimer::GetTime(&CurrentT);
    return true;
}

// True when every packet slot is filled and the short final packet has arrived.
static bool HasAllPackets(const TUdpInTransfer& res) {
    if (!res.HasLastPacket)
        return false;
    for (int i = res.GetPacketCount() - 1; i >= 0; --i) {
        if (!res.GetPacket(i))
            return false;
    }
    return true;
}

// grouped acks, first int - packet_id, second int - bit mask for 32 packets preceding packet_id
const int SIZEOF_ACK = 8;
// Serializes up to maxAcks grouped acks from p->NewPacketsToAck into dst and
// clears the pending list; returns the number of ack pairs written.
static int WriteAck(TUdpInTransfer* p, int* dst, int maxAcks) {
    int ackCount = 0;
    if (p->NewPacketsToAck.size() > 1)
        Sort(p->NewPacketsToAck.begin(), p->NewPacketsToAck.end());
    int lastAcked = 0;
    for (size_t idx = 0; idx < p->NewPacketsToAck.size(); ++idx) {
        int pkt = p->NewPacketsToAck[idx];
        // emit an ack pair for the last pending id, and whenever the gap since the
        // previously emitted id is too large for one 32-bit back-mask to cover
        if (idx == p->NewPacketsToAck.size() - 1 || pkt > lastAcked + 30) {
            *dst++ = pkt;
            int bitMask = 0;
            int backPackets = Min(pkt, 32);
            for (int k = 0; k < backPackets; ++k) {
                if (p->GetPacket(pkt - k - 1))
                    bitMask |= 1 << k;
            }
            *dst++ = bitMask;
            if (++ackCount >= maxAcks)
                break;
            lastAcked = pkt;
            //printf("sending ack %d (mask %x)\n", pkt, bitMask);
        }
    }
    p->NewPacketsToAck.clear();
    return ackCount;
}

// Marks one packet as acknowledged, ignoring out-of-range ids from a stale peer.
static void AckPacket(TUdpOutTransfer* p, int pkt, float deltaT, bool updateRTT) {
    if (pkt < 0 || pkt >= p->PacketCount) {
        Y_ASSERT(0);
        return;
    }
    p->AckTracker.Ack(pkt, deltaT, updateRTT);
}

// Applies a sequence of grouped acks (id + 32-packet back-mask pairs).
static void ReadAcks(TUdpOutTransfer* p, const int* acks, int ackCount, float deltaT) {
    for (int i = 0; i < ackCount; ++i) {
        int pkt = *acks++;
        int bitMask = *acks++;
        bool updateRTT = i == ackCount - 1; // update RTT using only last packet in the pack
        AckPacket(p, pkt, deltaT, updateRTT);
        for (int k = 0; k < 32; ++k) {
            if (bitMask & (1 << k))
                AckPacket(p, pkt - k - 1, deltaT, false);
        }
    }
}

using namespace NNetlibaSocket::NNetliba;

// Shared secret for the remote-kill command (see Kill()/case KILL).
const ui64 KILL_PASSPHRASE1 = 0x98ff9cefb11d9a4cul;
const ui64 KILL_PASSPHRASE2 = 0xf7754c29e0be95eaul;

// Unaligned cursor-style readers/writers used for wire (de)serialization.
template <class T>
inline T Read(char** data) {
    T res = ReadUnaligned<T>(*data);
    *data += sizeof(T);
    return res;
}
template <class T>
inline void Write(char** data, T res) {
    WriteUnaligned<T>(*data, res);
    *data += sizeof(T);
}

// Asks the sender to restart the transfer with the given attempt number.
static void RequireResend(const TNetSocket& s, const sockaddr_in6& toAddress, int transferId, int attempt) {
    char buf[100], *pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
    Write(&pktData, transferId);
    Write(&pktData, (char)ACK_RESEND);
    Write(&pktData, attempt);
    s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
}

// Same as RequireResend, but also tells the sender shared memory could not be
// opened on the receive side.
static void RequireResendNoShmem(const TNetSocket& s, const sockaddr_in6& toAddress, int transferId, int attempt) {
    char buf[100], *pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
    Write(&pktData, transferId);
    Write(&pktData, (char)ACK_RESEND_NOSHMEM);
    Write(&pktData, attempt);
    s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
}
// Notifies the sender that the whole transfer was received; packetId is echoed
// so the sender can update its RTT estimate.
static void AckComplete(const TNetSocket& s, const sockaddr_in6& toAddress, int transferId, const TGUID& packetGuid, int packetId) {
    char buf[100], *pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
    Write(&pktData, transferId);
    Write(&pktData, (char)ACK_COMPLETE);
    Write(&pktData, packetGuid);
    Write(&pktData, packetId); // we need packetId to update RTT
    s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
}

// MTU probe: a full-size, don't-fragment datagram; a PONG reply proves the path
// supports jumbo frames.
static void SendPing(TNetSocket& s, const sockaddr_in6& toAddress, int selfNetworkOrderPort) {
    char pktBuf[UDP_PACKET_SIZE_FULL];
    char* pktData = pktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
    if (NSan::MSanIsOn()) {
        Zero(pktBuf); // keep MSan quiet about the uninitialized padding we send
    }
    Write(&pktData, (int)0);
    Write(&pktData, (char)PING);
    Write(&pktData, selfNetworkOrderPort);
    s.SendTo(pktBuf, UDP_PACKET_SIZE_FULL, toAddress, FF_DONT_FRAG);
}

// not MTU discovery, just figure out IB address of the peer
static void SendFakePing(TNetSocket& s, const sockaddr_in6& toAddress, int selfNetworkOrderPort) {
    char buf[100];
    char* pktData = buf + UDP_LOW_LEVEL_HEADER_SIZE;
    Write(&pktData, (int)0);
    Write(&pktData, (char)PING);
    Write(&pktData, selfNetworkOrderPort);
    s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG);
}

// Pumps one priority class of outgoing transfers: finishes MTU discovery, steps
// ack tracking, fails dead transfers (when needCheckAlive), and sends as many
// data packets as congestion control allows. Stops early on send-buffer overflow.
void TUdpHost::SendData(TList<TTransferKey>* order, float deltaT1, bool needCheckAlive) {
    for (TList<TTransferKey>::iterator z = order->begin(); z != order->end();) {
        // pick connection to send
        const TTransferKey& transferKey = *z;
        TUdpOutXferHash::iterator i = SendQueue.find(transferKey);
        if (i == SendQueue.end()) {
            z = order->erase(z); // transfer finished elsewhere — drop the stale order entry
            continue;
        }
        ++z;

        // perform sending
        int transferId = transferKey.Id;
        TUdpOutTransfer& xfer = i->second;

        if (!xfer.AckTracker.IsInitialized()) {
            // first touch: the packet size cannot be chosen until MTU is known
            TIntrusivePtr<TCongestionControl> congestion = xfer.AckTracker.GetCongestionControl();
            Y_ASSERT(congestion.Get() != nullptr);
            if (!congestion->IsKnownMTU()) {
                TLameMTUDiscovery* md = congestion->GetMTUDiscovery();
                if (md->IsTimedOut()) {
                    congestion->SetMTU(UDP_PACKET_SIZE_SMALL); // jumbo ping went unanswered
                } else {
                    if (md->CanSend()) {
                        SendPing(s, xfer.ToAddress, s.GetNetworkOrderPort());
                        md->PingSent();
                    }
                    continue; // wait for PONG or discovery timeout
                }
            }
            // try to use large mtu, we could have selected small mtu due to connectivity problems
            if (congestion->GetMTU() == UDP_PACKET_SIZE_SMALL || IB.Get() != nullptr) {
                // recheck every ~50mb
                int chkDenom = (50000000 / xfer.Data->GetSize()) | 1;
                if ((NetAckRnd() % chkDenom) == 0) {
                    //printf("send rechecking ping\n");
                    if (congestion->GetMTU() == UDP_PACKET_SIZE_SMALL) {
                        SendPing(s, xfer.ToAddress, s.GetNetworkOrderPort());
                    } else {
                        SendFakePing(s, xfer.ToAddress, s.GetNetworkOrderPort());
                    }
                }
            }
            // NOTE: with an exact multiple, LastPacketSize == 0 and an empty final
            // packet is still sent — the receiver relies on it to set HasLastPacket.
            xfer.PacketSize = congestion->GetMTU();
            xfer.LastPacketSize = xfer.Data->GetSize() % xfer.PacketSize;
            xfer.PacketCount = xfer.Data->GetSize() / xfer.PacketSize + 1;
            xfer.AckTracker.SetPacketCount(xfer.PacketCount);
        }

        xfer.AckTracker.Step(deltaT1);
        MaxWaitTime = Min(MaxWaitTime, xfer.AckTracker.GetTimeToNextPacketTimeout());
        if (needCheckAlive && !xfer.AckTracker.IsAlive()) {
            FailedSend(transferId);
            SendQueue.erase(i);
            continue;
        }
        bool sendBufferOverflow = false;
        while (xfer.AckTracker.CanSend()) {
            NHPTimer::STime tCopy = CurrentT;
            float deltaT2 = (float)NHPTimer::GetTimePassed(&tCopy);
            deltaT2 = ClampVal(deltaT2, 0.0f, UDP_TRANSFER_TIMEOUT / 3);

            int pkt = xfer.AckTracker.GetPacketToSend(deltaT2);
            if (pkt == -1) {
                break; // nothing (re)sendable right now
            }

            int dataSize = xfer.PacketSize;
            if (pkt == xfer.PacketCount - 1)
                dataSize = xfer.LastPacketSize;

            // wire format: [header][transferId][pktType][attempt][pktId]
            //              {pkt 0 only: [guid][crc32]{shmem: [shmemId][shmemSize]}}[payload]
            char* pktData = PktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
            Write(&pktData, transferId);
            char pktType = xfer.PacketSize == UDP_PACKET_SIZE ? DATA : DATA_SMALL;
            TSharedMemory* shm = xfer.Data->GetSharedData();
            if (shm) {
                if (pktType == DATA)
                    pktType = DATA_SHMEM;
                else
                    pktType = DATA_SMALL_SHMEM;
            }
            Write(&pktData, pktType);
            Write(&pktData, xfer.Attempt);
            Write(&pktData, pkt);
            if (pkt == 0) {
                Write(&pktData, xfer.PacketGuid);
                Write(&pktData, xfer.Crc32);
                if (shm) {
                    Write(&pktData, shm->GetId());
                    Write(&pktData, shm->GetSize());
                }
            }
            TBlockChainIterator dataReader(xfer.Data->GetChain());
            dataReader.Seek(pkt * xfer.PacketSize);
            dataReader.Read(pktData, dataSize);
            pktData += dataSize;
            int sendSize = (int)(pktData - PktBuf);
            TNetSocket::ESendError sendErr = s.SendTo(PktBuf, sendSize, xfer.ToAddress, FF_ALLOW_FRAG);
            if (sendErr != TNetSocket::SEND_OK) {
                if (sendErr == TNetSocket::SEND_NO_ROUTE_TO_HOST) {
                    FailedSend(transferId);
                    SendQueue.erase(i);
                    break;
                } else {
                    // most probably out of send buffer space (or something terrible has happened)
                    xfer.AckTracker.AddToResend(pkt);
                    sendBufferOverflow = true;
                    MaxWaitTime = 0; // retry on the very next Step()
                    //printf("failed send\n");
                    break;
                }
            }
        }
        if (sendBufferOverflow)
            break; // the socket buffer is full — no point trying other transfers now
    }
}
// Drains the socket and dispatches every datagram by command byte:
// DATA* packets feed incoming transfers (assembling the message, verifying crc,
// emitting ACK/ACK_COMPLETE/ACK_RESEND), ACK* packets drive outgoing transfers,
// PING/PONG implement MTU + IB discovery, KILL aborts the process.
void TUdpHost::RecvCycle() {
    for (;;) {
        sockaddr_in6 fromAddress;
        int rv = RecvBuf.GetBufSize();
        bool recvOk = s.RecvFrom(RecvBuf.GetDataPtr(), &rv, &fromAddress);
        if (!recvOk)
            break; // socket drained

        NHPTimer::STime tCopy = CurrentT;
        float deltaT = (float)NHPTimer::GetTimePassed(&tCopy);
        deltaT = ClampVal(deltaT, 0.0f, UDP_TRANSFER_TIMEOUT / 3);

        //int fromIP = fromAddress.sin_addr.s_addr;

        TTransferKey k;
        char* pktData = RecvBuf.GetDataPtr() + UDP_LOW_LEVEL_HEADER_SIZE;
        GetUdpAddress(&k.Address, fromAddress);
        k.Id = Read<int>(&pktData);
        int transferId = k.Id;
        int cmd = Read<char>(&pktData);
        Y_ASSERT(cmd == (int)*(RecvBuf.GetDataPtr() + CMD_POS));
        switch (cmd) {
            case DATA:
            case DATA_SMALL:
            case DATA_SHMEM:
            case DATA_SMALL_SHMEM: {
                int attempt = Read<int>(&pktData);
                int packetId = Read<int>(&pktData);
                //printf("data packet %d (trans ID = %d)\n", packetId, transferId);
                TUdpCompleteInXferHash::iterator itCompl = RecvCompleted.find(k);
                if (itCompl != RecvCompleted.end()) {
                    // duplicate of an already completed transfer — usually just re-ack
                    Y_ASSERT(RecvQueue.find(k) == RecvQueue.end());
                    const TUdpCompleteInTransfer& complete = itCompl->second;
                    bool sendAckComplete = true;
                    if (packetId == 0) {
                        // check packet GUID
                        char* tmpPktData = pktData;
                        TGUID packetGuid;
                        packetGuid = Read<TGUID>(&tmpPktData);
                        if (packetGuid != complete.PacketGuid) {
                            // we are receiving new data with the same transferId
                            // in this case we have to flush all the information about previous transfer
                            // and start over
                            //printf("same transferId for a different packet\n");
                            RecvCompleted.erase(itCompl);
                            sendAckComplete = false;
                        }
                    }
                    if (sendAckComplete) {
                        AckComplete(s, fromAddress, transferId, complete.PacketGuid, packetId);
                        break;
                    }
                }
                TUdpInXferHash::iterator rq = RecvQueue.find(k);
                if (rq == RecvQueue.end()) {
                    //printf("new input transfer\n");
                    TUdpInTransfer& res = RecvQueue[k];
                    res.ToAddress = fromAddress;
                    res.Attempt = attempt;
                    res.Congestion = GetPeerLink(k.Address).UdpCongestion.Get();
                    res.PacketSize = 0; // decided by the first accepted data packet below
                    res.HasLastPacket = false;
                    res.AttachStats(&PendingDataStats);
                    rq = RecvQueue.find(k);
                    Y_ASSERT(rq != RecvQueue.end());
                }
                TUdpInTransfer& res = rq->second;
                res.Congestion->MarkAlive();
                res.TimeSinceLastRecv = 0;

                if (packetId == 0) {
                    // packet 0 carries the transfer header: guid, crc and optional shmem link
                    TGUID packetGuid;
                    packetGuid = Read<TGUID>(&pktData);
                    int crc32 = Read<int>(&pktData);
                    res.Crc32 = crc32;
                    res.PacketGuid = packetGuid;
                    if (cmd == DATA_SHMEM || cmd == DATA_SMALL_SHMEM) {
                        // link to attached shared memory
                        TGUID shmemId = Read<TGUID>(&pktData);
                        int shmemSize = Read<int>(&pktData);
                        if (res.SharedData.Get() == nullptr) {
                            res.SharedData = new TSharedMemory;
                            if (!res.SharedData->Open(shmemId, shmemSize)) {
                                res.SharedData = nullptr;
                                RequireResendNoShmem(s, res.ToAddress, transferId, res.Attempt);
                                break;
                            }
                        }
                    }
                }
                if (attempt != res.Attempt) {
                    // stale attempt — force the sender to resynchronize
                    RequireResend(s, res.ToAddress, transferId, res.Attempt);
                    break;
                } else {
                    if (res.PacketSize == 0) {
                        res.PacketSize = (cmd == DATA || cmd == DATA_SHMEM ? UDP_PACKET_SIZE : UDP_PACKET_SIZE_SMALL);
                    } else {
                        // check that all data is of same size
                        Y_ASSERT(cmd == DATA || cmd == DATA_SMALL);
                        Y_ASSERT(res.PacketSize == (cmd == DATA ? UDP_PACKET_SIZE : UDP_PACKET_SIZE_SMALL));
                    }

                    int dataSize = (int)(RecvBuf.GetDataPtr() + rv - pktData);

                    Y_ASSERT(dataSize <= res.PacketSize);
                    if (dataSize > res.PacketSize)
                        break; // mem overrun protection
                    if (packetId >= res.GetPacketCount())
                        res.SetPacketCount(packetId + 1);
                    {
                        TUdpRecvPacket* pkt = nullptr;
                        if (res.PacketSize == UDP_PACKET_SIZE_SMALL) {
                            // save memory by using smaller buffer at the cost of additional memcpy
                            pkt = RecvBuf.CreateNewSmallPacket(dataSize);
                            memcpy(pkt->Data, pktData, dataSize);
                            pkt->DataStart = 0;
                            pkt->DataSize = dataSize;
                        } else {
                            int dataStart = (int)(pktData - RecvBuf.GetDataPtr()); // data offset in the packet
                            pkt = RecvBuf.ExtractPacket(); // steal the whole receive buffer, no copy
                            pkt->DataStart = dataStart;
                            pkt->DataSize = dataSize;
                        }
                        // calc packet sum, will be used to calc whole message crc
                        pkt->BlockSum = TIncrementalChecksumCalcer::CalcBlockSum(pkt->Data + pkt->DataStart, pkt->DataSize);
                        res.AssignPacket(packetId, pkt);
                    }

                    if (dataSize != res.PacketSize) {
                        // a short packet is by construction the final packet of the transfer
                        res.LastPacketSize = dataSize;
                        res.HasLastPacket = true;
                    }

                    if (HasAllPackets(res)) {
                        //printf("received\n");
                        // assemble the message, verify crc, hand it to ReceivedList
                        TRequest* out = new TRequest;
                        out->Address = k.Address;
                        out->Guid = res.PacketGuid;
                        TIncrementalChecksumCalcer incCS;
                        int packetCount = res.GetPacketCount();
                        out->Data.Reset(new TRopeDataPacket);
                        for (int i = 0; i < packetCount; ++i) {
                            TUdpRecvPacket* pkt = res.ExtractPacket(i);
                            Y_ASSERT(pkt->DataSize == ((i == packetCount - 1) ? res.LastPacketSize : res.PacketSize));
                            // the rope takes ownership of the raw buffer (first arg)
                            out->Data->AddBlock((char*)pkt, pkt->Data + pkt->DataStart, pkt->DataSize);
                            incCS.AddBlockSum(pkt->BlockSum, pkt->DataSize);
                        }
                        out->Data->AttachSharedData(res.SharedData);
                        res.EraseAllPackets();

                        int crc32 = incCS.CalcChecksum(); // CalcChecksum(out->Data->GetChain());
#ifdef SIMULATE_NETWORK_FAILURES
                        bool crcOk = crc32 == res.Crc32 ? (RandomNumber<size_t>() % 10) != 0 : false;
#else
                        bool crcOk = crc32 == res.Crc32;
#endif
                        if (crcOk) {
                            ReceivedList.push_back(out);
                            Y_ASSERT(RecvCompleted.find(k) == RecvCompleted.end());
                            TUdpCompleteInTransfer& complete = RecvCompleted[k];
                            RecvCompletedQueue.push_back(k);
                            complete.PacketGuid = res.PacketGuid;
                            AckComplete(s, res.ToAddress, transferId, complete.PacketGuid, packetId);
                            RecvQueue.erase(rq);
                        } else {
                            //printf("crc failed, require resend\n");
                            delete out;
                            ++res.Attempt;
                            res.NewPacketsToAck.clear();
                            RequireResend(s, res.ToAddress, transferId, res.Attempt);
                        }
                    } else {
                        res.NewPacketsToAck.push_back(packetId);
                    }
                }
            } break;
            case ACK: {
                TUdpOutXferHash::iterator i = SendQueue.find(k);
                if (i == SendQueue.end())
                    break; // transfer already finished — stale ack
                TUdpOutTransfer& xfer = i->second;
                if (!xfer.AckTracker.IsInitialized())
                    break;
                xfer.AckTracker.MarkAlive();
                int attempt = Read<int>(&pktData);
                Y_ASSERT(attempt <= xfer.Attempt);
                if (attempt != xfer.Attempt)
                    break; // acks from a previous attempt are meaningless
                ReadAcks(&xfer, (int*)pktData, (int)(RecvBuf.GetDataPtr() + rv - pktData) / SIZEOF_ACK, deltaT);
                break;
            }
            case ACK_COMPLETE: {
                TUdpOutXferHash::iterator i = SendQueue.find(k);
                if (i == SendQueue.end())
                    break;
                TUdpOutTransfer& xfer = i->second;
                xfer.AckTracker.MarkAlive();
                TGUID packetGuid;
                packetGuid = Read<TGUID>(&pktData);
                int packetId = Read<int>(&pktData);
                if (packetGuid == xfer.PacketGuid) {
                    xfer.AckTracker.Ack(packetId, deltaT, true); // update RTT
                    xfer.AckTracker.AckAll(); // acking packets is required, otherwise they will be treated as lost (look AckTracker destructor)
                    SucceededSend(transferId);
                    SendQueue.erase(i);
                } else {
                    // peer asserts that he has received this packet but packetGuid is wrong
                    // try to resend everything
                    // ++xfer.Attempt; // should not do this, only sender can modify attempt number, otherwise cycle is possible with out of order packets
                    xfer.AckTracker.Resend();
                }
                break;
            } break;
            case ACK_RESEND: {
                TUdpOutXferHash::iterator i = SendQueue.find(k);
                if (i == SendQueue.end())
                    break;
                TUdpOutTransfer& xfer = i->second;
                xfer.AckTracker.MarkAlive();
                int attempt = Read<int>(&pktData);
                if (xfer.Attempt != attempt) {
                    // reset current tranfser & initialize new one
                    xfer.Attempt = attempt;
                    xfer.AckTracker.Resend();
                }
                break;
            }
            case ACK_RESEND_NOSHMEM: {
                // abort execution here
                // failed to open shmem on recv side, need to transmit data without using shmem
                Y_VERIFY(0, "not implemented yet");
                break;
            }
            case PING: {
                sockaddr_in6 trueFromAddress = fromAddress;
                int port = Read<int>(&pktData);
                Y_ASSERT(trueFromAddress.sin6_family == AF_INET6);
                // reply to the listening port advertised in the ping, not the ephemeral source port
                trueFromAddress.sin6_port = port;
                // can not set MTU for fromAddress here since asymmetrical mtu is possible
                char* pktData2 = PktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
                Write(&pktData2, (int)0);
                Write(&pktData2, (char)PONG);
                if (IB.Get()) {
                    // advertise our IB endpoint so the peer can connect over InfiniBand
                    const TIBConnectInfo& ibConnectInfo = IB->GetConnectInfo();
                    Write(&pktData2, ibConnectInfo);
                    Write(&pktData2, trueFromAddress);
                }
                s.SendTo(PktBuf, pktData2 - PktBuf, trueFromAddress, FF_ALLOW_FRAG);
                break;
            }
            case PONG: {
                // the jumbo ping made it through — the path supports the full packet size
                TPeerLink& peerInfo = GetPeerLink(k.Address);
                peerInfo.UdpCongestion->SetMTU(UDP_PACKET_SIZE);
                int dataSize = (int)(RecvBuf.GetDataPtr() + rv - pktData);
                if (dataSize == sizeof(TIBConnectInfo) + sizeof(sockaddr_in6)) {
                    if (IB.Get() != nullptr && peerInfo.IBPeer.Get() == nullptr) {
                        TIBConnectInfo info = Read<TIBConnectInfo>(&pktData);
                        sockaddr_in6 myAddress = Read<sockaddr_in6>(&pktData);
                        TUdpAddress myUdpAddress;
                        GetUdpAddress(&myUdpAddress, myAddress);
                        peerInfo.IBPeer = IB->ConnectPeer(info, k.Address, myUdpAddress);
                    }
                }
                break;
            }
            case KILL: {
                // remote abort, accepted only with both passphrases and exact packet length
                ui64 p1 = Read<ui64>(&pktData);
                ui64 p2 = Read<ui64>(&pktData);
                int restSize = (int)(RecvBuf.GetDataPtr() + rv - pktData);
                if (restSize == 0 && p1 == KILL_PASSPHRASE1 && p2 == KILL_PASSPHRASE2) {
                    abort();
                }
                break;
            }
            default:
                Y_ASSERT(0);
                break;
        }
    }
}

// Gives the IB subsystem a chance to make progress without running a full Step().
void TUdpHost::IBStep() {
    if (IB.Get()) {
        NHPTimer::STime tChk = CurrentT;
        float chkDeltaT = (float)NHPTimer::GetTimePassed(&tChk);
        if (IB->Step(tChk)) {
            IBIdleTime = -chkDeltaT;
        }
    }
}

// One scheduler tick: pump IB, receive, expire peers/transfers, flush acks and
// send outgoing data by priority. Must be called from a single thread.
void TUdpHost::Step() {
    if (IB.Get()) {
        NHPTimer::STime tChk = CurrentT;
        float chkDeltaT = (float)NHPTimer::GetTimePassed(&tChk);
        if (IB->Step(tChk)) {
            IBIdleTime = -chkDeltaT;
        }
        if (chkDeltaT < 0.0005) {
            return; // called again too soon — skip the heavy UDP work this tick
        }
    }

    // NOTE(review): UseTOSforAcks is defined elsewhere in this file (not visible
    // in this chunk); it selects DSCP marking so acks get priority treatment.
    if (UseTOSforAcks) {
        s.SetTOS(0x20);
    } else {
        s.SetTOS(0);
    }

    RecvCycle();

    float deltaT = (float)NHPTimer::GetTimePassed(&CurrentT);
    deltaT = ClampVal(deltaT, 0.0f, UDP_TRANSFER_TIMEOUT / 3);

    MaxWaitTime = DEFAULT_MAX_WAIT_TIME;
    IBIdleTime += deltaT;

    bool needCheckAlive = false;

    // update alive ports
    const float INACTIVE_CONGESTION_UPDATE_INTERVAL = 1;
    TimeSinceCongestionHistoryUpdate += deltaT;
    if (TimeSinceCongestionHistoryUpdate > INACTIVE_CONGESTION_UPDATE_INTERVAL) {
        for (TPeerLinkHash::iterator i = CongestionTrackHistory.begin(); i != CongestionTrackHistory.end();) {
            TPeerLink& pl = i->second;
            if (!pl.UpdateSleep(TimeSinceCongestionHistoryUpdate)) {
                TPeerLinkHash::iterator k = i++;
                CongestionTrackHistory.erase(k);
                needCheckAlive = true;
            } else {
                ++i;
            }
        }
        TimeSinceCongestionHistoryUpdate = 0;
    }
    for (TPeerLinkHash::iterator i = CongestionTrack.begin(); i != CongestionTrack.end();) {
        const TUdpAddress& addr = i->first;
        TPeerLink& pl = i->second;
        if (pl.UdpCongestion->GetTransferCount() == 0) {
            // no active transfers — park the link in history
            pl.StartSleep(addr, &MaxWaitTime);
            CongestionTrackHistory[i->first] = i->second;
            TPeerLinkHash::iterator k = i++;
            CongestionTrack.erase(k);
        } else if (!pl.Update(deltaT, addr, &MaxWaitTime)) {
            // peer timed out — dependent transfers will be failed below
            TPeerLinkHash::iterator k = i++;
            CongestionTrack.erase(k);
            needCheckAlive = true;
        } else {
            ++i;
        }
    }

    // send acks on received data
    for (TUdpInXferHash::iterator i = RecvQueue.begin(); i != RecvQueue.end();) {
        const TTransferKey& transKey = i->first;
        int transferId = transKey.Id;
        TUdpInTransfer& xfer = i->second;
        xfer.TimeSinceLastRecv += deltaT;
        if (xfer.TimeSinceLastRecv > UDP_MAX_INPUT_DATA_WAIT || (needCheckAlive && !xfer.Congestion->IsAlive())) {
            TUdpInXferHash::iterator k = i++;
            RecvQueue.erase(k);
            continue;
        }
        Y_ASSERT(RecvCompleted.find(i->first) == RecvCompleted.end()); // state "Complete & incomplete" is incorrect
        if (!xfer.NewPacketsToAck.empty()) {
            char* pktData = PktBuf + UDP_LOW_LEVEL_HEADER_SIZE;
            Write(&pktData, transferId);
            Write(&pktData, (char)ACK);
            Write(&pktData, xfer.Attempt);
            int acks = WriteAck(&xfer, (int*)pktData, (int)(xfer.PacketSize - (pktData - PktBuf)) / SIZEOF_ACK);
            pktData += acks * SIZEOF_ACK;
            s.SendTo(PktBuf, (int)(pktData - PktBuf), xfer.ToAddress, FF_ALLOW_FRAG);
        }
        ++i;
    }

    if (UseTOSforAcks) {
        s.SetTOS(0x60); // data gets a different DSCP class than acks
    }

    // send data for outbound connections
    SendData(&SendOrderHighPrior, deltaT, needCheckAlive);
    SendData(&SendOrder, deltaT, needCheckAlive);
    SendData(&SendOrderLow, deltaT, needCheckAlive);

    // roll send order to avoid exotic problems with lots of peers and high traffic
    SendOrderHighPrior.splice(SendOrderHighPrior.end(), SendOrderHighPrior, SendOrderHighPrior.begin());
    //SendOrder.splice(SendOrder.end(), SendOrder, SendOrder.begin()); // sending data in order has lower delay and shorter queue

    // clean completed queue: two-generation aging so an entry survives at least
    // one full cleaning interval before being forgotten
    TimeSinceCompletedQueueClean += deltaT;
    if (TimeSinceCompletedQueueClean > UDP_TRANSFER_TIMEOUT * 1.5) {
        for (size_t i = 0; i < KeepCompletedQueue.size(); ++i) {
            TUdpCompleteInXferHash::iterator k = RecvCompleted.find(KeepCompletedQueue[i]);
            if (k != RecvCompleted.end())
                RecvCompleted.erase(k);
        }
        KeepCompletedQueue.clear();
        KeepCompletedQueue.swap(RecvCompletedQueue);
        TimeSinceCompletedQueueClean = 0;
    }
}

// Formats one peer-link table as human-readable text (one line per peer).
TString TUdpHost::GetPeerLinkDebug(const TPeerLinkHash& ch) {
    TString res;
    char buf[1000];
    for (const auto& i : ch) {
        const TUdpAddress& ip = i.first;
        const TCongestionControl& cc = *i.second.UdpCongestion;
        IIBPeer* ibPeer = i.second.IBPeer.Get();
        sprintf(buf, "%s\tIB: %d, RTT: %g Timeout: %g Window: %g MaxWin: %g FailRate: %g TimeSinceLastRecv: %g Transfers: %d MTU: %d\n",
                GetAddressAsString(ip).c_str(),
                ibPeer ? ibPeer->GetState() : -1,
                cc.GetRTT() * 1000, cc.GetTimeout() * 1000, cc.GetWindow(), cc.GetMaxWindow(), cc.GetFailRate(),
                cc.GetTimeSinceLastRecv() * 1000, cc.GetTransferCount(), cc.GetMTU());
        res += buf;
    }
    return res;
}

// Human-readable dump of queue sizes, pending data and per-peer congestion state.
TString TUdpHost::GetDebugInfo() {
    TString res;
    char buf[1000];
    sprintf(buf, "Receiving %d msgs, sending %d high prior, %d regular msgs, %d low prior msgs\n",
            RecvQueue.ysize(), (int)SendOrderHighPrior.size(), (int)SendOrder.size(), (int)SendOrderLow.size());
    res += buf;

    TRequesterPendingDataStats pds;
    GetPendingDataSize(&pds);
    sprintf(buf, "Pending data size: %" PRIu64 "\n", pds.InpDataSize + pds.OutDataSize);
    res += buf;
    sprintf(buf, "  in packets: %d, size %" PRIu64 "\n", pds.InpCount, pds.InpDataSize);
    res += buf;
    sprintf(buf, "  out packets: %d, size %" PRIu64 "\n", pds.OutCount, pds.OutDataSize);
    res += buf;

    res += "\nCongestion info:\n";
    res += GetPeerLinkDebug(CongestionTrack);
    res += "\nCongestion info history:\n";
    res += GetPeerLinkDebug(CongestionTrackHistory);

    return res;
}
Write(&pktData, KILL_PASSPHRASE1); + Write(&pktData, KILL_PASSPHRASE2); + s.SendTo(buf, (int)(pktData - buf), toAddress, FF_ALLOW_FRAG); + } + + void TUdpHost::Kill(const TUdpAddress& addr) { + sockaddr_in6 target; + GetWinsockAddr(&target, addr); + SendKill(s, target); + } + + TIntrusivePtr<IPeerQueueStats> TUdpHost::GetQueueStats(const TUdpAddress& addr) { + TQueueStatsHash::iterator zq = PeerQueueStats.find(addr); + if (zq != PeerQueueStats.end()) { + return zq->second.Get(); + } + TPeerQueueStats* res = new TPeerQueueStats; + PeerQueueStats[addr] = res; + // attach to existing congestion tracker + TPeerLinkHash::iterator z; + z = CongestionTrack.find(addr); + if (z != CongestionTrack.end()) { + z->second.UdpCongestion->AttachQueueStats(res); + } + z = CongestionTrackHistory.find(addr); + if (z != CongestionTrackHistory.end()) { + z->second.UdpCongestion->AttachQueueStats(res); + } + return res; + } + + ////////////////////////////////////////////////////////////////////////// + + TIntrusivePtr<IUdpHost> CreateUdpHost(int port) { + TIntrusivePtr<NNetlibaSocket::ISocket> socket = NNetlibaSocket::CreateBestRecvSocket(); + socket->Open(port); + if (!socket->IsValid()) + return nullptr; + return CreateUdpHost(socket); + } + + TIntrusivePtr<IUdpHost> CreateUdpHost(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) { + if (!InitLocalIPList()) { + Y_ASSERT(0 && "Can not determine self IP address"); + return nullptr; + } + TIntrusivePtr<TUdpHost> res = new TUdpHost; + if (!res->Start(socket)) + return nullptr; + return res.Get(); + } + + void SetUdpMaxBandwidthPerIP(float f) { + f = Max(0.0f, f); + TCongestionControl::MaxPacketRate = f / UDP_PACKET_SIZE; + } + + void SetUdpSlowStart(bool enable) { + TCongestionControl::StartWindowSize = enable ? 
0.5f : 3; + } + + void DisableIBDetection() { + IBDetection = false; + } + +} diff --git a/library/cpp/netliba/v6/udp_client_server.h b/library/cpp/netliba/v6/udp_client_server.h new file mode 100644 index 0000000000..23e0209405 --- /dev/null +++ b/library/cpp/netliba/v6/udp_client_server.h @@ -0,0 +1,62 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/generic/guid.h> +#include <library/cpp/netliba/socket/socket.h> + +#include "udp_address.h" +#include "net_request.h" + +namespace NNetliba { + class TRopeDataPacket; + struct TRequesterPendingDataStats; + struct IPeerQueueStats; + + struct TSendResult { + int TransferId; + bool Success; + TSendResult() + : TransferId(-1) + , Success(false) + { + } + TSendResult(int transferId, bool success) + : TransferId(transferId) + , Success(success) + { + } + }; + + enum EPacketPriority { + PP_LOW, + PP_NORMAL, + PP_HIGH + }; + + // Step should be called from one and the same thread + // thread safety is caller responsibility + struct IUdpHost: public TThrRefBase { + virtual TRequest* GetRequest() = 0; + // returns trasferId + // Send() needs correctly computed crc32 + // crc32 is expected to be computed outside of the thread talking to IUdpHost to avoid crc32 computation delays + // packetGuid provides packet guid, if packetGuid is empty then guid is generated + virtual int Send(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data, int crc32, TGUID* packetGuid, EPacketPriority pp) = 0; + virtual bool GetSendResult(TSendResult* res) = 0; + virtual void Step() = 0; + virtual void IBStep() = 0; + virtual void Wait(float seconds) = 0; // does not use UdpHost + virtual void CancelWait() = 0; // thread safe + virtual void GetPendingDataSize(TRequesterPendingDataStats* res) = 0; + virtual TString GetDebugInfo() = 0; + virtual void Kill(const TUdpAddress& addr) = 0; + virtual TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) = 0; + }; + + TIntrusivePtr<IUdpHost> CreateUdpHost(int port); + 
TIntrusivePtr<IUdpHost> CreateUdpHost(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket); + + void SetUdpMaxBandwidthPerIP(float f); + void SetUdpSlowStart(bool enable); + void DisableIBDetection(); +} diff --git a/library/cpp/netliba/v6/udp_debug.cpp b/library/cpp/netliba/v6/udp_debug.cpp new file mode 100644 index 0000000000..1e02fa9e2b --- /dev/null +++ b/library/cpp/netliba/v6/udp_debug.cpp @@ -0,0 +1,2 @@ +#include "stdafx.h" +#include "udp_debug.h" diff --git a/library/cpp/netliba/v6/udp_debug.h b/library/cpp/netliba/v6/udp_debug.h new file mode 100644 index 0000000000..77eea61c23 --- /dev/null +++ b/library/cpp/netliba/v6/udp_debug.h @@ -0,0 +1,21 @@ +#pragma once + +namespace NNetliba { + struct TRequesterPendingDataStats { + int InpCount, OutCount; + ui64 InpDataSize, OutDataSize; + + TRequesterPendingDataStats() { + memset(this, 0, sizeof(*this)); + } + }; + + struct TRequesterQueueStats { + int ReqCount, RespCount; + ui64 ReqQueueSize, RespQueueSize; + + TRequesterQueueStats() { + memset(this, 0, sizeof(*this)); + } + }; +} diff --git a/library/cpp/netliba/v6/udp_http.cpp b/library/cpp/netliba/v6/udp_http.cpp new file mode 100644 index 0000000000..9fa0b07818 --- /dev/null +++ b/library/cpp/netliba/v6/udp_http.cpp @@ -0,0 +1,1354 @@ +#include "stdafx.h" +#include "udp_http.h" +#include "udp_client_server.h" +#include "udp_socket.h" +#include "cpu_affinity.h" + +#include <library/cpp/threading/atomic/bool.h> + +#include <util/system/hp_timer.h> +#include <util/thread/lfqueue.h> +#include <util/system/thread.h> +#include <util/system/spinlock.h> +#if !defined(_win_) +#include <signal.h> +#include <pthread.h> +#endif +#include "block_chain.h" +#include <util/system/shmat.h> + +#include <atomic> + +namespace NNetliba { + const float HTTP_TIMEOUT = 15.0f; + const int MIN_SHARED_MEM_PACKET = 1000; + + static ::NAtomic::TBool PanicAttack; + static std::atomic<NHPTimer::STime> LastHeartbeat; + static std::atomic<double> HeartbeatTimeout; + + static int 
GetPacketSize(TRequest* req) { + if (req && req->Data.Get()) + return req->Data->GetSize(); + return 0; + } + + static bool IsLocalFast(const TUdpAddress& addr) { + if (addr.IsIPv4()) { + return IsLocalIPv4(addr.GetIPv4()); + } else { + return IsLocalIPv6(addr.Network, addr.Interface); + } + } + + bool IsLocal(const TUdpAddress& addr) { + InitLocalIPList(); + return IsLocalFast(addr); + } + + TUdpHttpRequest::~TUdpHttpRequest() { + } + + TUdpHttpResponse::~TUdpHttpResponse() { + } + + class TRequesterUserQueueSizes: public TThrRefBase { + public: + TAtomic ReqCount, RespCount; + TAtomic ReqQueueSize, RespQueueSize; + + TRequesterUserQueueSizes() + : ReqCount(0) + , RespCount(0) + , ReqQueueSize(0) + , RespQueueSize(0) + { + } + }; + + template <class T> + void EraseList(TLockFreeQueue<T*>* data) { + T* ptr = nullptr; + while (data->Dequeue(&ptr)) { + delete ptr; + } + } + + class TRequesterUserQueues: public TThrRefBase { + TIntrusivePtr<TRequesterUserQueueSizes> QueueSizes; + TLockFreeQueue<TUdpHttpRequest*> ReqList; + TLockFreeQueue<TUdpHttpResponse*> ResponseList; + TLockFreeStack<TGUID> CancelList, SendRequestAccList; // any order will do + TMuxEvent AsyncEvent; + + void UpdateAsyncSignalState() { + // not sure about this one. 
Idea is that AsyncEvent.Reset() is a memory barrier + if (ReqList.IsEmpty() && ResponseList.IsEmpty() && CancelList.IsEmpty() && SendRequestAccList.IsEmpty()) { + AsyncEvent.Reset(); + if (!ReqList.IsEmpty() || !ResponseList.IsEmpty() || !CancelList.IsEmpty() || !SendRequestAccList.IsEmpty()) + AsyncEvent.Signal(); + } + } + ~TRequesterUserQueues() override { + EraseList(&ReqList); + EraseList(&ResponseList); + } + + public: + TRequesterUserQueues(TRequesterUserQueueSizes* queueSizes) + : QueueSizes(queueSizes) + { + } + TUdpHttpRequest* GetRequest(); + TUdpHttpResponse* GetResponse(); + bool GetRequestCancel(TGUID* req) { + bool res = CancelList.Dequeue(req); + UpdateAsyncSignalState(); + return res; + } + bool GetSendRequestAcc(TGUID* req) { + bool res = SendRequestAccList.Dequeue(req); + UpdateAsyncSignalState(); + return res; + } + + void AddRequest(TUdpHttpRequest* res) { + AtomicAdd(QueueSizes->ReqCount, 1); + AtomicAdd(QueueSizes->ReqQueueSize, GetPacketSize(res->DataHolder.Get())); + ReqList.Enqueue(res); + AsyncEvent.Signal(); + } + void AddResponse(TUdpHttpResponse* res) { + AtomicAdd(QueueSizes->RespCount, 1); + AtomicAdd(QueueSizes->RespQueueSize, GetPacketSize(res->DataHolder.Get())); + ResponseList.Enqueue(res); + AsyncEvent.Signal(); + } + void AddCancel(const TGUID& req) { + CancelList.Enqueue(req); + AsyncEvent.Signal(); + } + void AddSendRequestAcc(const TGUID& req) { + SendRequestAccList.Enqueue(req); + AsyncEvent.Signal(); + } + TMuxEvent& GetAsyncEvent() { + return AsyncEvent; + } + void AsyncSignal() { + AsyncEvent.Signal(); + } + }; + + struct TOutRequestState { + enum EState { + S_SENDING, + S_WAITING, + S_WAITING_PING_SENDING, + S_WAITING_PING_SENT, + S_CANCEL_AFTER_SENDING + }; + EState State; + TUdpAddress Address; + double TimePassed; + int PingTransferId; + TIntrusivePtr<TRequesterUserQueues> UserQueues; + + TOutRequestState() + : State(S_SENDING) + , TimePassed(0) + , PingTransferId(-1) + { + } + }; + + struct TInRequestState { + enum 
EState { + S_WAITING, + S_RESPONSE_SENDING, + S_CANCELED, + }; + EState State; + TUdpAddress Address; + + TInRequestState() + : State(S_WAITING) + { + } + TInRequestState(const TUdpAddress& address) + : State(S_WAITING) + , Address(address) + { + } + }; + + enum EHttpPacket { + PKT_REQUEST, + PKT_PING, + PKT_PING_RESPONSE, + PKT_RESPONSE, + PKT_GETDEBUGINFO, + PKT_LOCAL_REQUEST, + PKT_LOCAL_RESPONSE, + PKT_CANCEL, + }; + + class TUdpHttp: public IRequester { + enum EDir { + DIR_OUT, + DIR_IN + }; + struct TTransferPurpose { + EDir Dir; + TGUID Guid; + TTransferPurpose() + : Dir(DIR_OUT) + { + } + TTransferPurpose(EDir dir, TGUID guid) + : Dir(dir) + , Guid(guid) + { + } + }; + + struct TSendRequest { + TUdpAddress Addr; + TAutoPtr<TRopeDataPacket> Data; + TGUID ReqGuid; + TIntrusivePtr<TWaitResponse> WR; + TIntrusivePtr<TRequesterUserQueues> UserQueues; + ui32 Crc32; + + TSendRequest() + : Crc32(0) + { + } + TSendRequest(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket>* data, const TGUID& reqguid, TWaitResponse* wr, TRequesterUserQueues* userQueues) + : Addr(addr) + , Data(*data) + , ReqGuid(reqguid) + , WR(wr) + , UserQueues(userQueues) + , Crc32(CalcChecksum(Data->GetChain())) + { + } + }; + struct TSendResponse { + TVector<char> Data; + TGUID ReqGuid; + ui32 DataCrc32; + EPacketPriority Priority; + + TSendResponse() + : DataCrc32(0) + , Priority(PP_NORMAL) + { + } + TSendResponse(const TGUID& reqguid, EPacketPriority prior, TVector<char>* data) + : ReqGuid(reqguid) + , DataCrc32(0) + , Priority(prior) + { + if (data && !data->empty()) { + data->swap(Data); + DataCrc32 = TIncrementalChecksumCalcer::CalcBlockSum(&Data[0], Data.ysize()); + } + } + }; + struct TCancelRequest { + TGUID ReqGuid; + + TCancelRequest() = default; + TCancelRequest(const TGUID& reqguid) + : ReqGuid(reqguid) + { + } + }; + struct TBreakRequest { + TGUID ReqGuid; + + TBreakRequest() = default; + TBreakRequest(const TGUID& reqguid) + : ReqGuid(reqguid) + { + } + }; + + TThread myThread; + 
bool KeepRunning, AbortTransactions; + TSpinLock cs; + TSystemEvent HasStarted; + + NHPTimer::STime PingsSendT; + + TIntrusivePtr<IUdpHost> Host; + TIntrusivePtr<NNetlibaSocket::ISocket> Socket; + typedef THashMap<TGUID, TOutRequestState, TGUIDHash> TOutRequestHash; + typedef THashMap<TGUID, TInRequestState, TGUIDHash> TInRequestHash; + TOutRequestHash OutRequests; + TInRequestHash InRequests; + + typedef THashMap<int, TTransferPurpose> TTransferHash; + TTransferHash TransferHash; + + typedef THashMap<TGUID, TIntrusivePtr<TWaitResponse>, TGUIDHash> TSyncRequests; + TSyncRequests SyncRequests; + + // hold it here to not construct on every DoSends() + typedef THashSet<TGUID, TGUIDHash> TAnticipateCancels; + TAnticipateCancels AnticipateCancels; + + TLockFreeQueue<TSendRequest*> SendReqList; + TLockFreeQueue<TSendResponse*> SendRespList; + TLockFreeQueue<TCancelRequest> CancelReqList; + TLockFreeQueue<TBreakRequest> BreakReqList; + + TIntrusivePtr<TRequesterUserQueueSizes> QueueSizes; + TIntrusivePtr<TRequesterUserQueues> UserQueues; + + struct TStatsRequest: public TThrRefBase { + enum EReq { + PENDING_SIZE, + DEBUG_INFO, + HAS_IN_REQUEST, + GET_PEER_ADDRESS, + GET_PEER_QUEUE_STATS, + }; + EReq Req; + TRequesterPendingDataStats PendingDataSize; + TString DebugInfo; + TGUID RequestId; + TUdpAddress PeerAddress; + TIntrusivePtr<IPeerQueueStats> QueueStats; + bool RequestFound; + TSystemEvent Complete; + + TStatsRequest(EReq req) + : Req(req) + , RequestFound(false) + { + } + }; + TLockFreeQueue<TIntrusivePtr<TStatsRequest>> StatsReqList; + + bool ReportRequestCancel; + bool ReportSendRequestAcc; + + void FinishRequest(TOutRequestHash::iterator i, TUdpHttpResponse::EResult ok, TAutoPtr<TRequest> data, const char* error = nullptr) { + TOutRequestState& s = i->second; + TUdpHttpResponse* res = new TUdpHttpResponse; + res->DataHolder = data; + res->ReqId = i->first; + res->PeerAddress = s.Address; + res->Ok = ok; + if (ok == TUdpHttpResponse::FAILED) + res->Error = error ? 
error : "request failed"; + else if (ok == TUdpHttpResponse::CANCELED) + res->Error = error ? error : "request cancelled"; + TSyncRequests::iterator k = SyncRequests.find(res->ReqId); + if (k != SyncRequests.end()) { + TIntrusivePtr<TWaitResponse>& wr = k->second; + wr->SetResponse(res); + SyncRequests.erase(k); + } else { + s.UserQueues->AddResponse(res); + } + + OutRequests.erase(i); + } + int SendWithHighPriority(const TUdpAddress& addr, TAutoPtr<TRopeDataPacket> data) { + ui32 crc32 = CalcChecksum(data->GetChain()); + return Host->Send(addr, data.Release(), crc32, nullptr, PP_HIGH); + } + void ProcessIncomingPackets() { + TVector<TGUID, TCustomAllocator<TGUID>> failedRequests; + for (;;) { + TAutoPtr<TRequest> req = Host->GetRequest(); + if (req.Get() == nullptr) { + if (!failedRequests.empty()) { + // we want to handle following sequence of events + // <- send ping + // -> send response over IB + // -> send ping response (no such request) over UDP + // Now if we are lucky enough we can get IB response waiting in the IB receive queue + // at the same time response sender will receive "send complete" from IB + // indeed, IB delivered message (but it was not parsed by ib_cs.cpp yet) + // so after receiving "send response complete" event resposne sender can legally response + // to pings with "no such request" + // but ping responses can be sent over UDP + // So we can run into situation with negative ping response in + // UDP receive queue and response waiting unprocessed in IB receive queue + // to check that there is no response in the IB queue we have to process IB queues + // so we call IBStep() + Host->IBStep(); + req = Host->GetRequest(); + if (req.Get() == nullptr) { + break; + } + } else { + break; + } + } + + TBlockChainIterator reqData(req->Data->GetChain()); + char pktType; + reqData.Read(&pktType, 1); + switch (pktType) { + case PKT_REQUEST: + case PKT_LOCAL_REQUEST: { + //printf("recv PKT_REQUEST or PKT_LOCAL_REQUEST\n"); + TGUID reqId = req->Guid; + 
TInRequestHash::iterator z = InRequests.find(reqId); + if (z != InRequests.end()) { + // oops, this request already exists! + // might happen if request can be stored in single packet + // and this packet had source IP broken during transmission and managed to pass crc checks + // since we already reported wrong source address for this request to the user + // the best thing we can do is to stop the program to avoid further complications + // but we just report the accident to stderr + fprintf(stderr, "Jackpot, same request %s received twice from %s and earlier from %s\n", + GetGuidAsString(reqId).c_str(), GetAddressAsString(z->second.Address).c_str(), + GetAddressAsString(req->Address).c_str()); + } else { + InRequests[reqId] = TInRequestState(req->Address); + + //printf("InReq %s PKT_REQUEST recv ... -> S_WAITING\n", GetGuidAsString(reqId).c_str()); + + TUdpHttpRequest* res = new TUdpHttpRequest; + res->ReqId = reqId; + res->PeerAddress = req->Address; + res->DataHolder = req; + + UserQueues->AddRequest(res); + } + } break; + case PKT_PING: { + //printf("recv PKT_PING\n"); + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + bool ok = InRequests.find(guid) != InRequests.end(); + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write((char)PKT_PING_RESPONSE); + ms->Write(guid); + ms->Write(ok); + SendWithHighPriority(req->Address, ms.Release()); + //printf("InReq %s PKT_PING recv Sending PKT_PING_RESPONSE\n", GetGuidAsString(guid).c_str()); + //printf("got PKT_PING, responding %d\n", (int)ok); + } break; + case PKT_PING_RESPONSE: { + //printf("recv PKT_PING_RESPONSE\n"); + TGUID guid; + bool ok; + reqData.Read(&guid, sizeof(guid)); + reqData.Read(&ok, sizeof(ok)); + TOutRequestHash::iterator i = OutRequests.find(guid); + if (i == OutRequests.end()) { + ; //Y_ASSERT(0); // actually possible with some packet orders + } else { + if (!ok) { + // can not delete request at this point + // since we can receive failed ping and response at the same moment + // 
consider sequence: client sends ping, server sends response + // and replies false to ping as reply is sent + // we can not receive failed ping_response earlier then response itself + // but we can receive them simultaneously + failedRequests.push_back(guid); + //printf("OutReq %s PKT_PING_RESPONSE recv no such query -> failed\n", GetGuidAsString(guid).c_str()); + } else { + TOutRequestState& s = i->second; + switch (s.State) { + case TOutRequestState::S_WAITING_PING_SENDING: { + Y_ASSERT(s.PingTransferId >= 0); + TTransferHash::iterator k = TransferHash.find(s.PingTransferId); + if (k != TransferHash.end()) + TransferHash.erase(k); + else + Y_ASSERT(0); + s.PingTransferId = -1; + s.TimePassed = 0; + s.State = TOutRequestState::S_WAITING; + //printf("OutReq %s PKT_PING_RESPONSE recv S_WAITING_PING_SENDING -> S_WAITING\n", GetGuidAsString(guid).c_str()); + } break; + case TOutRequestState::S_WAITING_PING_SENT: + s.TimePassed = 0; + s.State = TOutRequestState::S_WAITING; + //printf("OutReq %s PKT_PING_RESPONSE recv S_WAITING_PING_SENT -> S_WAITING\n", GetGuidAsString(guid).c_str()); + break; + default: + Y_ASSERT(0); + break; + } + } + } + } break; + case PKT_RESPONSE: + case PKT_LOCAL_RESPONSE: { + //printf("recv PKT_RESPONSE or PKT_LOCAL_RESPONSE\n"); + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + TOutRequestHash::iterator i = OutRequests.find(guid); + if (i == OutRequests.end()) { + ; //Y_ASSERT(0); // does happen + //printf("OutReq %s PKT_RESPONSE recv for non-existing req\n", GetGuidAsString(guid).c_str()); + } else { + FinishRequest(i, TUdpHttpResponse::OK, req); + //printf("OutReq %s PKT_RESPONSE recv ... 
-> ok\n", GetGuidAsString(guid).c_str()); + } + } break; + case PKT_CANCEL: { + //printf("recv PKT_CANCEL\n"); + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + TInRequestHash::iterator i = InRequests.find(guid); + if (i == InRequests.end()) { + ; //Y_ASSERT(0); // may happen + //printf("InReq %s PKT_CANCEL recv for non-existing req\n", GetGuidAsString(guid).c_str()); + } else { + TInRequestState& s = i->second; + if (s.State != TInRequestState::S_CANCELED && ReportRequestCancel) + UserQueues->AddCancel(guid); + s.State = TInRequestState::S_CANCELED; + //printf("InReq %s PKT_CANCEL recv\n", GetGuidAsString(guid).c_str()); + } + } break; + case PKT_GETDEBUGINFO: { + //printf("recv PKT_GETDEBUGINFO\n"); + TString dbgInfo = GetDebugInfoLocked(); + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write(dbgInfo.c_str(), (int)dbgInfo.size()); + SendWithHighPriority(req->Address, ms); + } break; + default: + Y_ASSERT(0); + } + } + // cleanup failed requests + for (size_t k = 0; k < failedRequests.size(); ++k) { + const TGUID& guid = failedRequests[k]; + TOutRequestHash::iterator i = OutRequests.find(guid); + if (i != OutRequests.end()) + FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: recv no such query"); + } + } + void AnalyzeSendResults() { + TSendResult res; + while (Host->GetSendResult(&res)) { + //printf("Send result received\n"); + TTransferHash::iterator k1 = TransferHash.find(res.TransferId); + if (k1 != TransferHash.end()) { + const TTransferPurpose& tp = k1->second; + switch (tp.Dir) { + case DIR_OUT: { + TOutRequestHash::iterator i = OutRequests.find(tp.Guid); + if (i != OutRequests.end()) { + const TGUID& reqId = i->first; + TOutRequestState& s = i->second; + switch (s.State) { + case TOutRequestState::S_SENDING: + if (!res.Success) { + FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_SENDING"); + //printf("OutReq %s AnalyzeSendResults() S_SENDING -> failed\n", GetGuidAsString(reqId).c_str()); + } 
else { + if (ReportSendRequestAcc) { + if (s.UserQueues.Get()) { + s.UserQueues->AddSendRequestAcc(reqId); + } else { + // waitable request? + TSyncRequests::iterator k2 = SyncRequests.find(reqId); + if (k2 != SyncRequests.end()) { + TIntrusivePtr<TWaitResponse>& wr = k2->second; + wr->SetRequestSent(); + } + } + } + s.State = TOutRequestState::S_WAITING; + //printf("OutReq %s AnalyzeSendResults() S_SENDING -> S_WAITING\n", GetGuidAsString(reqId).c_str()); + s.TimePassed = 0; + } + break; + case TOutRequestState::S_CANCEL_AFTER_SENDING: + DoSendCancel(s.Address, reqId); + FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request failed: state S_CANCEL_AFTER_SENDING"); + break; + case TOutRequestState::S_WAITING: + case TOutRequestState::S_WAITING_PING_SENT: + Y_ASSERT(0); + break; + case TOutRequestState::S_WAITING_PING_SENDING: + Y_ASSERT(s.PingTransferId >= 0 && s.PingTransferId == res.TransferId); + if (!res.Success) { + FinishRequest(i, TUdpHttpResponse::FAILED, nullptr, "request failed: state S_WAITING_PING_SENDING"); + //printf("OutReq %s AnalyzeSendResults() S_WAITING_PING_SENDING -> failed\n", GetGuidAsString(reqId).c_str()); + } else { + s.PingTransferId = -1; + s.State = TOutRequestState::S_WAITING_PING_SENT; + //printf("OutReq %s AnalyzeSendResults() S_WAITING_PING_SENDING -> S_WAITING_PING_SENT\n", GetGuidAsString(reqId).c_str()); + s.TimePassed = 0; + } + break; + default: + Y_ASSERT(0); + break; + } + } + } break; + case DIR_IN: { + TInRequestHash::iterator i = InRequests.find(tp.Guid); + if (i != InRequests.end()) { + Y_ASSERT(i->second.State == TInRequestState::S_RESPONSE_SENDING || i->second.State == TInRequestState::S_CANCELED); + InRequests.erase(i); + //if (res.Success) + // printf("InReq %s AnalyzeSendResults() ... -> finished\n", GetGuidAsString(tp.Guid).c_str()); + //else + // printf("InReq %s AnalyzeSendResults() ... 
-> failed response send\n", GetGuidAsString(tp.Guid).c_str()); + } + } break; + default: + Y_ASSERT(0); + break; + } + TransferHash.erase(k1); + } + } + } + void SendPingsIfNeeded() { + NHPTimer::STime tChk = PingsSendT; + float deltaT = (float)NHPTimer::GetTimePassed(&tChk); + if (deltaT < 0.05) { + return; + } + PingsSendT = tChk; + deltaT = ClampVal(deltaT, 0.0f, HTTP_TIMEOUT / 3); + + { + for (TOutRequestHash::iterator i = OutRequests.begin(); i != OutRequests.end();) { + TOutRequestHash::iterator curIt = i++; + TOutRequestState& s = curIt->second; + const TGUID& guid = curIt->first; + switch (s.State) { + case TOutRequestState::S_WAITING: + s.TimePassed += deltaT; + if (s.TimePassed > HTTP_TIMEOUT) { + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write((char)PKT_PING); + ms->Write(guid); + int transId = SendWithHighPriority(s.Address, ms.Release()); + TransferHash[transId] = TTransferPurpose(DIR_OUT, guid); + s.State = TOutRequestState::S_WAITING_PING_SENDING; + //printf("OutReq %s SendPingsIfNeeded() S_WAITING -> S_WAITING_PING_SENDING\n", GetGuidAsString(guid).c_str()); + s.PingTransferId = transId; + } + break; + case TOutRequestState::S_WAITING_PING_SENT: + s.TimePassed += deltaT; + if (s.TimePassed > HTTP_TIMEOUT) { + //printf("OutReq %s SendPingsIfNeeded() S_WAITING_PING_SENT -> failed\n", GetGuidAsString(guid).c_str()); + FinishRequest(curIt, TUdpHttpResponse::FAILED, nullptr, "request failed: http timeout in state S_WAITING_PING_SENT"); + } + break; + default: + break; + } + } + } + } + void Step() { + { + TGuard<TSpinLock> lock(cs); + DoSends(); + } + Host->Step(); + for (TIntrusivePtr<TStatsRequest> req; StatsReqList.Dequeue(&req);) { + switch (req->Req) { + case TStatsRequest::PENDING_SIZE: + Host->GetPendingDataSize(&req->PendingDataSize); + break; + case TStatsRequest::DEBUG_INFO: { + TGuard<TSpinLock> lock(cs); + req->DebugInfo = GetDebugInfoLocked(); + } break; + case TStatsRequest::HAS_IN_REQUEST: { + TGuard<TSpinLock> lock(cs); + 
req->RequestFound = (InRequests.find(req->RequestId) != InRequests.end()); + } break; + case TStatsRequest::GET_PEER_ADDRESS: { + TGuard<TSpinLock> lock(cs); + TInRequestHash::const_iterator i = InRequests.find(req->RequestId); + if (i != InRequests.end()) { + req->PeerAddress = i->second.Address; + } else { + TOutRequestHash::const_iterator o = OutRequests.find(req->RequestId); + if (o != OutRequests.end()) { + req->PeerAddress = o->second.Address; + } else { + req->PeerAddress = TUdpAddress(); + } + } + } break; + case TStatsRequest::GET_PEER_QUEUE_STATS: + req->QueueStats = Host->GetQueueStats(req->PeerAddress); + break; + default: + Y_ASSERT(0); + break; + } + req->Complete.Signal(); + } + { + TGuard<TSpinLock> lock(cs); + DoSends(); + ProcessIncomingPackets(); + AnalyzeSendResults(); + SendPingsIfNeeded(); + } + } + void Wait() { + Host->Wait(0.1f); + } + void DoSendCancel(const TUdpAddress& addr, const TGUID& req) { + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ms->Write((char)PKT_CANCEL); + ms->Write(req); + SendWithHighPriority(addr, ms); + } + void DoSends() { + { + TBreakRequest rb; + while (BreakReqList.Dequeue(&rb)) { + InRequests.erase(rb.ReqGuid); + } + } + { + // cancelling requests + TCancelRequest rc; + while (CancelReqList.Dequeue(&rc)) { + TOutRequestHash::iterator i = OutRequests.find(rc.ReqGuid); + if (i == OutRequests.end()) { + AnticipateCancels.insert(rc.ReqGuid); + continue; // cancelling non existing request is ok + } + TOutRequestState& s = i->second; + if (s.State == TOutRequestState::S_SENDING) { + // we are in trouble - have not sent request and we already have to cancel it, wait send + s.State = TOutRequestState::S_CANCEL_AFTER_SENDING; + } else { + DoSendCancel(s.Address, rc.ReqGuid); + FinishRequest(i, TUdpHttpResponse::CANCELED, nullptr, "request canceled: notify requested side"); + } + } + } + { + // sending replies + for (TSendResponse* rd = nullptr; SendRespList.Dequeue(&rd); delete rd) { + TInRequestHash::iterator i = 
InRequests.find(rd->ReqGuid); + if (i == InRequests.end()) { + Y_ASSERT(0); + continue; + } + TInRequestState& s = i->second; + if (s.State == TInRequestState::S_CANCELED) { + // need not send response for the canceled request + InRequests.erase(i); + continue; + } + + Y_ASSERT(s.State == TInRequestState::S_WAITING); + s.State = TInRequestState::S_RESPONSE_SENDING; + //printf("InReq %s SendResponse() ... -> S_RESPONSE_SENDING (pkt %s)\n", GetGuidAsString(reqId).c_str(), GetGuidAsString(lowPktGuid).c_str()); + + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + ui32 crc32 = 0; + int dataSize = rd->Data.ysize(); + if (rd->Data.ysize() > MIN_SHARED_MEM_PACKET && IsLocalFast(s.Address)) { + TIntrusivePtr<TSharedMemory> shm = new TSharedMemory; + if (shm->Create(dataSize)) { + ms->Write((char)PKT_LOCAL_RESPONSE); + ms->Write(rd->ReqGuid); + memcpy(shm->GetPtr(), &rd->Data[0], dataSize); + TVector<char> empty; + rd->Data.swap(empty); + ms->AttachSharedData(shm); + crc32 = CalcChecksum(ms->GetChain()); + } + } + if (ms->GetSharedData() == nullptr) { + ms->Write((char)PKT_RESPONSE); + ms->Write(rd->ReqGuid); + + // to offload crc calcs from inner thread, crc of data[] is calced outside and passed in DataCrc32 + // this means that we are calculating crc when shared memory is used + // it is hard to avoid since in SendResponse() we don't know if shared mem will be used (peer address is not available there) + TIncrementalChecksumCalcer csCalcer; + AddChain(&csCalcer, ms->GetChain()); + // here we are replicating the way WriteDestructive serializes data + csCalcer.AddBlock(&dataSize, sizeof(dataSize)); + csCalcer.AddBlockSum(rd->DataCrc32, dataSize); + crc32 = csCalcer.CalcChecksum(); + + ms->WriteDestructive(&rd->Data); + //ui32 chkCrc = CalcChecksum(ms->GetChain()); // can not use since its slow for large responses + //Y_ASSERT(chkCrc == crc32); + } + + int transId = Host->Send(s.Address, ms.Release(), crc32, nullptr, rd->Priority); + TransferHash[transId] = 
TTransferPurpose(DIR_IN, rd->ReqGuid); + } + } + { + // sending requests + for (TSendRequest* rd = nullptr; SendReqList.Dequeue(&rd); delete rd) { + Y_ASSERT(OutRequests.find(rd->ReqGuid) == OutRequests.end()); + + { + TOutRequestState& s = OutRequests[rd->ReqGuid]; + s.State = TOutRequestState::S_SENDING; + s.Address = rd->Addr; + s.UserQueues = rd->UserQueues; + //printf("OutReq %s SendRequest() ... -> S_SENDING\n", GetGuidAsString(guid).c_str()); + } + + if (rd->WR.Get()) + SyncRequests[rd->ReqGuid] = rd->WR; + + if (AnticipateCancels.find(rd->ReqGuid) != AnticipateCancels.end()) { + FinishRequest(OutRequests.find(rd->ReqGuid), TUdpHttpResponse::CANCELED, nullptr, "request canceled before transmitting"); + } else { + TGUID pktGuid = rd->ReqGuid; // request packet id should match request id + int transId = Host->Send(rd->Addr, rd->Data.Release(), rd->Crc32, &pktGuid, PP_NORMAL); + TransferHash[transId] = TTransferPurpose(DIR_OUT, rd->ReqGuid); + } + } + } + if (!AnticipateCancels.empty()) { + AnticipateCancels.clear(); + } + } + + public: + void SendRequestImpl(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId, + TWaitResponse* wr, TRequesterUserQueues* userQueues) { + if (data && data->size() > MAX_PACKET_SIZE) { + Y_VERIFY(0, "data size is too large"); + } + //printf("SendRequest(%s)\n", url.c_str()); + if (wr) + wr->SetReqId(reqId); + + TAutoPtr<TRopeDataPacket> ms = new TRopeDataPacket; + if (data && data->ysize() > MIN_SHARED_MEM_PACKET && IsLocalFast(addr)) { + int dataSize = data->ysize(); + TIntrusivePtr<TSharedMemory> shm = new TSharedMemory; + if (shm->Create(dataSize)) { + ms->Write((char)PKT_LOCAL_REQUEST); + ms->WriteStroka(url); + memcpy(shm->GetPtr(), &(*data)[0], dataSize); + TVector<char> empty; + data->swap(empty); + ms->AttachSharedData(shm); + } + } + if (ms->GetSharedData() == nullptr) { + ms->Write((char)PKT_REQUEST); + ms->WriteStroka(url); + ms->WriteDestructive(data); + } + + SendReqList.Enqueue(new 
TSendRequest(addr, &ms, reqId, wr, userQueues)); + Host->CancelWait(); + } + + void SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId) override { + SendRequestImpl(addr, url, data, reqId, nullptr, UserQueues.Get()); + } + void CancelRequest(const TGUID& reqId) override { + CancelReqList.Enqueue(TCancelRequest(reqId)); + Host->CancelWait(); + } + void BreakRequest(const TGUID& reqId) override { + BreakReqList.Enqueue(TBreakRequest(reqId)); + Host->CancelWait(); + } + + void SendResponseImpl(const TGUID& reqId, EPacketPriority prior, TVector<char>* data) // non-virtual, for direct call from TRequestOps + { + if (data && data->size() > MAX_PACKET_SIZE) { + Y_VERIFY(0, "data size is too large"); + } + SendRespList.Enqueue(new TSendResponse(reqId, prior, data)); + Host->CancelWait(); + } + void SendResponse(const TGUID& reqId, TVector<char>* data) override { + SendResponseImpl(reqId, PP_NORMAL, data); + } + void SendResponseLowPriority(const TGUID& reqId, TVector<char>* data) override { + SendResponseImpl(reqId, PP_LOW, data); + } + TUdpHttpRequest* GetRequest() override { + return UserQueues->GetRequest(); + } + TUdpHttpResponse* GetResponse() override { + return UserQueues->GetResponse(); + } + bool GetRequestCancel(TGUID* req) override { + return UserQueues->GetRequestCancel(req); + } + bool GetSendRequestAcc(TGUID* req) override { + return UserQueues->GetSendRequestAcc(req); + } + TUdpHttpResponse* Request(const TUdpAddress& addr, const TString& url, TVector<char>* data) override { + TIntrusivePtr<TWaitResponse> wr = WaitableRequest(addr, url, data); + wr->Wait(); + return wr->GetResponse(); + } + TIntrusivePtr<TWaitResponse> WaitableRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) override { + TIntrusivePtr<TWaitResponse> wr = new TWaitResponse; + TGUID reqId; + CreateGuid(&reqId); + SendRequestImpl(addr, url, data, reqId, wr.Get(), nullptr); + return wr; + } + TMuxEvent& GetAsyncEvent() 
override { + return UserQueues->GetAsyncEvent(); + } + int GetPort() override { + return Socket.Get() ? Socket->GetPort() : 0; + } + void StopNoWait() override { + AbortTransactions = true; + KeepRunning = false; + UserQueues->AsyncSignal(); + // calcel all outgoing requests + TGuard<TSpinLock> lock(cs); + while (!OutRequests.empty()) { + // cancel without informing peer that we are cancelling the request + FinishRequest(OutRequests.begin(), TUdpHttpResponse::CANCELED, nullptr, "request canceled: inside TUdpHttp::StopNoWait()"); + } + } + void ExecStatsRequest(TIntrusivePtr<TStatsRequest> req) { + StatsReqList.Enqueue(req); + Host->CancelWait(); + req->Complete.Wait(); + } + TUdpAddress GetPeerAddress(const TGUID& reqId) override { + TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::GET_PEER_ADDRESS); + req->RequestId = reqId; + ExecStatsRequest(req); + return req->PeerAddress; + } + void GetPendingDataSize(TRequesterPendingDataStats* res) override { + TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::PENDING_SIZE); + ExecStatsRequest(req); + *res = req->PendingDataSize; + } + bool HasRequest(const TGUID& reqId) override { + TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::HAS_IN_REQUEST); + req->RequestId = reqId; + ExecStatsRequest(req); + return req->RequestFound; + } + + private: + void FinishOutstandingTransactions() { + // wait all pending requests, all new requests are canceled + while ((!OutRequests.empty() || !InRequests.empty() || !SendRespList.IsEmpty() || !SendReqList.IsEmpty()) && !PanicAttack) { + while (TUdpHttpRequest* req = GetRequest()) { + TInRequestHash::iterator i = InRequests.find(req->ReqId); + //printf("dropping request(%s) (thread %d)\n", req->Url.c_str(), ThreadId()); + delete req; + if (i == InRequests.end()) { + Y_ASSERT(0); + continue; + } + InRequests.erase(i); + } + Step(); + sleep(0); + } + } + static void* ExecServerThread(void* param) { + BindToSocket(0); + 
SetHighestThreadPriority(); + TUdpHttp* pThis = (TUdpHttp*)param; + pThis->Host = CreateUdpHost(pThis->Socket); + pThis->HasStarted.Signal(); + if (!pThis->Host) { + pThis->Socket.Drop(); + return nullptr; + } + NHPTimer::GetTime(&pThis->PingsSendT); + while (pThis->KeepRunning && !PanicAttack) { + if (HeartbeatTimeout.load(std::memory_order_acquire) > 0) { + NHPTimer::STime chk = LastHeartbeat.load(std::memory_order_acquire); + double passed = NHPTimer::GetTimePassed(&chk); + if (passed > HeartbeatTimeout.load(std::memory_order_acquire)) { + StopAllNetLibaThreads(); + fprintf(stderr, "%s\tTUdpHttp\tWaiting for %0.2f, time limit %0.2f, commit a suicide!11\n", Now().ToStringUpToSeconds().c_str(), passed, HeartbeatTimeout.load(std::memory_order_acquire)); + fflush(stderr); +#ifndef _win_ + killpg(0, SIGKILL); +#endif + abort(); + break; + } + } + pThis->Step(); + pThis->Wait(); + } + if (!pThis->AbortTransactions && !PanicAttack) + pThis->FinishOutstandingTransactions(); + pThis->Host = nullptr; + return nullptr; + } + ~TUdpHttp() override { + if (myThread.Running()) { + KeepRunning = false; + myThread.Join(); + } + for (TIntrusivePtr<TStatsRequest> req; StatsReqList.Dequeue(&req);) { + req->Complete.Signal(); + } + } + + public: + TUdpHttp() + : myThread(TThread::TParams(ExecServerThread, (void*)this).SetName("nl6_udp_host")) + , KeepRunning(true) + , AbortTransactions(false) + , PingsSendT(0) + , ReportRequestCancel(false) + , ReportSendRequestAcc(false) + { + NHPTimer::GetTime(&PingsSendT); + QueueSizes = new TRequesterUserQueueSizes; + UserQueues = new TRequesterUserQueues(QueueSizes.Get()); + } + bool Start(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) { + Y_ASSERT(Host.Get() == nullptr); + Socket = socket; + myThread.Start(); + HasStarted.Wait(); + + if (Host.Get()) { + return true; + } + Socket.Drop(); + return false; + } + TString GetDebugInfoLocked() { + TString res = KeepRunning ? 
"State: running\n" : "State: stopping\n"; + res += Host->GetDebugInfo(); + + char buf[1000]; + TRequesterUserQueueSizes* qs = QueueSizes.Get(); + sprintf(buf, "\nRequest queue %d (%d bytes)\n", (int)AtomicGet(qs->ReqCount), (int)AtomicGet(qs->ReqQueueSize)); + res += buf; + sprintf(buf, "Response queue %d (%d bytes)\n", (int)AtomicGet(qs->RespCount), (int)AtomicGet(qs->RespQueueSize)); + res += buf; + + const char* outReqStateNames[] = { + "S_SENDING", + "S_WAITING", + "S_WAITING_PING_SENDING", + "S_WAITING_PING_SENT", + "S_CANCEL_AFTER_SENDING"}; + const char* inReqStateNames[] = { + "S_WAITING", + "S_RESPONSE_SENDING", + "S_CANCELED"}; + res += "\nOut requests:\n"; + for (TOutRequestHash::const_iterator i = OutRequests.begin(); i != OutRequests.end(); ++i) { + const TGUID& gg = i->first; + const TOutRequestState& s = i->second; + bool isSync = SyncRequests.find(gg) != SyncRequests.end(); + sprintf(buf, "%s\t%s %s TimePassed: %g %s\n", + GetAddressAsString(s.Address).c_str(), GetGuidAsString(gg).c_str(), outReqStateNames[s.State], + s.TimePassed * 1000, + isSync ? 
"isSync" : ""); + res += buf; + } + res += "\nIn requests:\n"; + for (TInRequestHash::const_iterator i = InRequests.begin(); i != InRequests.end(); ++i) { + const TGUID& gg = i->first; + const TInRequestState& s = i->second; + sprintf(buf, "%s\t%s %s\n", + GetAddressAsString(s.Address).c_str(), GetGuidAsString(gg).c_str(), inReqStateNames[s.State]); + res += buf; + } + return res; + } + TString GetDebugInfo() override { + TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::DEBUG_INFO); + ExecStatsRequest(req); + return req->DebugInfo; + } + void GetRequestQueueSize(TRequesterQueueStats* res) override { + TRequesterUserQueueSizes* qs = QueueSizes.Get(); + res->ReqCount = (int)AtomicGet(qs->ReqCount); + res->RespCount = (int)AtomicGet(qs->RespCount); + res->ReqQueueSize = (int)AtomicGet(qs->ReqQueueSize); + res->RespQueueSize = (int)AtomicGet(qs->RespQueueSize); + } + TRequesterUserQueueSizes* GetQueueSizes() const { + return QueueSizes.Get(); + } + IRequestOps* CreateSubRequester() override; + void EnableReportRequestCancel() override { + ReportRequestCancel = true; + } + void EnableReportSendRequestAcc() override { + ReportSendRequestAcc = true; + } + TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) override { + TIntrusivePtr<TStatsRequest> req = new TStatsRequest(TStatsRequest::GET_PEER_QUEUE_STATS); + req->PeerAddress = addr; + ExecStatsRequest(req); + return req->QueueStats; + } + }; + + ////////////////////////////////////////////////////////////////////////// + static void ReadShm(TSharedMemory* shm, TVector<char>* data) { + Y_ASSERT(shm); + int dataSize = shm->GetSize(); + data->yresize(dataSize); + memcpy(&(*data)[0], shm->GetPtr(), dataSize); + } + + static void LoadRequestData(TUdpHttpRequest* res) { + if (!res) + return; + { + TBlockChainIterator reqData(res->DataHolder->Data->GetChain()); + char pktType; + reqData.Read(&pktType, 1); + ReadArr(&reqData, &res->Url); + if (pktType == PKT_REQUEST) { + ReadYArr(&reqData, 
&res->Data); + } else if (pktType == PKT_LOCAL_REQUEST) { + ReadShm(res->DataHolder->Data->GetSharedData(), &res->Data); + } else + Y_ASSERT(0); + if (reqData.HasFailed()) { + Y_ASSERT(0 && "wrong format, memory corruption suspected"); + res->Url = ""; + res->Data.clear(); + } + } + res->DataHolder.Reset(nullptr); + } + + static void LoadResponseData(TUdpHttpResponse* res) { + if (!res || res->DataHolder.Get() == nullptr) + return; + { + TBlockChainIterator reqData(res->DataHolder->Data->GetChain()); + char pktType; + reqData.Read(&pktType, 1); + TGUID guid; + reqData.Read(&guid, sizeof(guid)); + Y_ASSERT(res->ReqId == guid); + if (pktType == PKT_RESPONSE) { + ReadYArr(&reqData, &res->Data); + } else if (pktType == PKT_LOCAL_RESPONSE) { + ReadShm(res->DataHolder->Data->GetSharedData(), &res->Data); + } else + Y_ASSERT(0); + if (reqData.HasFailed()) { + Y_ASSERT(0 && "wrong format, memory corruption suspected"); + res->Ok = TUdpHttpResponse::FAILED; + res->Data.clear(); + res->Error = "wrong response format"; + } + } + res->DataHolder.Reset(nullptr); + } + + ////////////////////////////////////////////////////////////////////////// + // IRequestOps::TWaitResponse + TUdpHttpResponse* IRequestOps::TWaitResponse::GetResponse() { + if (!Response) + return nullptr; + TUdpHttpResponse* res = Response; + Response = nullptr; + LoadResponseData(res); + return res; + } + + void IRequestOps::TWaitResponse::SetResponse(TUdpHttpResponse* r) { + Y_ASSERT(Response == nullptr || r == nullptr); + if (r) + Response = r; + CompleteEvent.Signal(); + } + + ////////////////////////////////////////////////////////////////////////// + // TRequesterUserQueues + TUdpHttpRequest* TRequesterUserQueues::GetRequest() { + TUdpHttpRequest* res = nullptr; + ReqList.Dequeue(&res); + if (res) { + AtomicAdd(QueueSizes->ReqCount, -1); + AtomicAdd(QueueSizes->ReqQueueSize, -GetPacketSize(res->DataHolder.Get())); + } + UpdateAsyncSignalState(); + LoadRequestData(res); + return res; + } + + 
TUdpHttpResponse* TRequesterUserQueues::GetResponse() { + TUdpHttpResponse* res = nullptr; + ResponseList.Dequeue(&res); + if (res) { + AtomicAdd(QueueSizes->RespCount, -1); + AtomicAdd(QueueSizes->RespQueueSize, -GetPacketSize(res->DataHolder.Get())); + } + UpdateAsyncSignalState(); + LoadResponseData(res); + return res; + } + + ////////////////////////////////////////////////////////////////////////// + class TRequestOps: public IRequestOps { + TIntrusivePtr<TUdpHttp> Requester; + TIntrusivePtr<TRequesterUserQueues> UserQueues; + + public: + TRequestOps(TUdpHttp* req) + : Requester(req) + { + UserQueues = new TRequesterUserQueues(req->GetQueueSizes()); + } + void SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId) override { + Requester->SendRequestImpl(addr, url, data, reqId, nullptr, UserQueues.Get()); + } + void CancelRequest(const TGUID& reqId) override { + Requester->CancelRequest(reqId); + } + void BreakRequest(const TGUID& reqId) override { + Requester->BreakRequest(reqId); + } + + void SendResponse(const TGUID& reqId, TVector<char>* data) override { + Requester->SendResponseImpl(reqId, PP_NORMAL, data); + } + void SendResponseLowPriority(const TGUID& reqId, TVector<char>* data) override { + Requester->SendResponseImpl(reqId, PP_LOW, data); + } + TUdpHttpRequest* GetRequest() override { + Y_ASSERT(0); + //return UserQueues.GetRequest(); + return nullptr; // all requests are routed to the main requester + } + TUdpHttpResponse* GetResponse() override { + return UserQueues->GetResponse(); + } + bool GetRequestCancel(TGUID*) override { + Y_ASSERT(0); + return false; // all request cancels are routed to the main requester + } + bool GetSendRequestAcc(TGUID* req) override { + return UserQueues->GetSendRequestAcc(req); + } + // sync mode + TUdpHttpResponse* Request(const TUdpAddress& addr, const TString& url, TVector<char>* data) override { + return Requester->Request(addr, url, data); + } + 
TIntrusivePtr<TWaitResponse> WaitableRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) override { + return Requester->WaitableRequest(addr, url, data); + } + // + TMuxEvent& GetAsyncEvent() override { + return UserQueues->GetAsyncEvent(); + } + }; + + IRequestOps* TUdpHttp::CreateSubRequester() { + return new TRequestOps(this); + } + + ////////////////////////////////////////////////////////////////////////// + void AbortOnFailedRequest(TUdpHttpResponse* answer) { + if (answer && answer->Ok == TUdpHttpResponse::FAILED) { + fprintf(stderr, "Failed request to host %s\n", GetAddressAsString(answer->PeerAddress).data()); + fprintf(stderr, "Error description: %s\n", answer->Error.data()); + fflush(nullptr); + Y_ASSERT(0); + abort(); + } + } + + TString GetDebugInfo(const TUdpAddress& addr, double timeout) { + NHPTimer::STime start; + NHPTimer::GetTime(&start); + TIntrusivePtr<IUdpHost> host = CreateUdpHost(0); + { + TAutoPtr<TRopeDataPacket> rq = new TRopeDataPacket; + rq->Write((char)PKT_GETDEBUGINFO); + ui32 crc32 = CalcChecksum(rq->GetChain()); + host->Send(addr, rq.Release(), crc32, nullptr, PP_HIGH); + } + for (;;) { + TAutoPtr<TRequest> ptr = host->GetRequest(); + if (ptr.Get()) { + TBlockChainIterator reqData(ptr->Data->GetChain()); + int sz = reqData.GetSize(); + TString res; + res.resize(sz); + reqData.Read(res.begin(), sz); + return res; + } + host->Step(); + host->Wait(0.1f); + + NHPTimer::STime now; + NHPTimer::GetTime(&now); + if (NHPTimer::GetSeconds(now - start) > timeout) { + return TString(); + } + } + } + + void Kill(const TUdpAddress& addr) { + TIntrusivePtr<IUdpHost> host = CreateUdpHost(0); + host->Kill(addr); + } + + void StopAllNetLibaThreads() { + PanicAttack = true; // AAAA!!!! 
+ } + + void SetNetLibaHeartbeatTimeout(double timeoutSec) { + NetLibaHeartbeat(); + HeartbeatTimeout.store(timeoutSec, std::memory_order_release); + } + + void NetLibaHeartbeat() { + NHPTimer::STime now; + NHPTimer::GetTime(&now); + LastHeartbeat.store(now, std::memory_order_release); + } + + IRequester* CreateHttpUdpRequester(int port) { + if (PanicAttack) + return nullptr; + + TIntrusivePtr<NNetlibaSocket::ISocket> socket = NNetlibaSocket::CreateSocket(); + socket->Open(port); + if (!socket->IsValid()) + return nullptr; + + return CreateHttpUdpRequester(socket); + } + + IRequester* CreateHttpUdpRequester(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) { + if (PanicAttack) + return nullptr; + + TIntrusivePtr<TUdpHttp> res(new TUdpHttp); + if (!res->Start(socket)) + return nullptr; + return res.Release(); + } + +} diff --git a/library/cpp/netliba/v6/udp_http.h b/library/cpp/netliba/v6/udp_http.h new file mode 100644 index 0000000000..1084e7affa --- /dev/null +++ b/library/cpp/netliba/v6/udp_http.h @@ -0,0 +1,148 @@ +#pragma once + +#include "udp_address.h" +#include "udp_debug.h" +#include "net_queue_stat.h" + +#include <util/network/init.h> +#include <util/generic/ptr.h> +#include <util/generic/guid.h> +#include <library/cpp/threading/mux_event/mux_event.h> +#include <library/cpp/netliba/socket/socket.h> + +namespace NNetliba { + const ui64 MAX_PACKET_SIZE = 0x70000000; + + struct TRequest; + struct TUdpHttpRequest { + TAutoPtr<TRequest> DataHolder; + TGUID ReqId; + TString Url; + TUdpAddress PeerAddress; + TVector<char> Data; + + ~TUdpHttpRequest(); + }; + + struct TUdpHttpResponse { + enum EResult { + FAILED = 0, + OK = 1, + CANCELED = 2 + }; + TAutoPtr<TRequest> DataHolder; + TGUID ReqId; + TUdpAddress PeerAddress; + TVector<char> Data; + EResult Ok; + TString Error; + + ~TUdpHttpResponse(); + }; + + // vector<char> *data - vector will be cleared upon call + struct IRequestOps: public TThrRefBase { + class TWaitResponse: public TThrRefBase, public 
TNonCopyable { + TGUID ReqId; + TMuxEvent CompleteEvent; + TUdpHttpResponse* Response; + bool RequestSent; + + ~TWaitResponse() override { + delete GetResponse(); + } + + public: + TWaitResponse() + : Response(nullptr) + , RequestSent(false) + { + } + void Wait() { + CompleteEvent.Wait(); + } + bool Wait(int ms) { + return CompleteEvent.Wait(ms); + } + TUdpHttpResponse* GetResponse(); + bool IsRequestSent() const { + return RequestSent; + } + void SetResponse(TUdpHttpResponse* r); + void SetReqId(const TGUID& reqId) { + ReqId = reqId; + } + const TGUID& GetReqId() { + return ReqId; + } + void SetRequestSent() { + RequestSent = true; + } + }; + + // async + virtual void SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data, const TGUID& reqId) = 0; + TGUID SendRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) { + TGUID reqId; + CreateGuid(&reqId); + SendRequest(addr, url, data, reqId); + return reqId; + } + virtual void CancelRequest(const TGUID& reqId) = 0; //cancel request from requester side + virtual void BreakRequest(const TGUID& reqId) = 0; //break request-response from requester side + + virtual void SendResponse(const TGUID& reqId, TVector<char>* data) = 0; + virtual void SendResponseLowPriority(const TGUID& reqId, TVector<char>* data) = 0; + virtual TUdpHttpRequest* GetRequest() = 0; + virtual TUdpHttpResponse* GetResponse() = 0; + virtual bool GetRequestCancel(TGUID* req) = 0; + virtual bool GetSendRequestAcc(TGUID* req) = 0; + // sync mode + virtual TUdpHttpResponse* Request(const TUdpAddress& addr, const TString& url, TVector<char>* data) = 0; + virtual TIntrusivePtr<TWaitResponse> WaitableRequest(const TUdpAddress& addr, const TString& url, TVector<char>* data) = 0; + // + virtual TMuxEvent& GetAsyncEvent() = 0; + }; + + struct IRequester: public IRequestOps { + virtual int GetPort() = 0; + virtual void StopNoWait() = 0; + virtual TUdpAddress GetPeerAddress(const TGUID& reqId) = 0; + virtual void 
GetPendingDataSize(TRequesterPendingDataStats* res) = 0; + virtual bool HasRequest(const TGUID& reqId) = 0; + virtual TString GetDebugInfo() = 0; + virtual void GetRequestQueueSize(TRequesterQueueStats* res) = 0; + virtual IRequestOps* CreateSubRequester() = 0; + virtual void EnableReportRequestCancel() = 0; + virtual void EnableReportSendRequestAcc() = 0; + virtual TIntrusivePtr<IPeerQueueStats> GetQueueStats(const TUdpAddress& addr) = 0; + + ui64 GetPendingDataSize() { + TRequesterPendingDataStats pds; + GetPendingDataSize(&pds); + return pds.InpDataSize + pds.OutDataSize; + } + }; + + IRequester* CreateHttpUdpRequester(int port); + IRequester* CreateHttpUdpRequester(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket); + + void SetUdpMaxBandwidthPerIP(float f); + void SetUdpSlowStart(bool enable); + void SetCongCtrlChannelInflate(float inflate); + + void EnableUseTOSforAcks(bool enable); + void EnableROCE(bool f); + + void AbortOnFailedRequest(TUdpHttpResponse* answer); + TString GetDebugInfo(const TUdpAddress& addr, double timeout = 60); + void Kill(const TUdpAddress& addr); + void StopAllNetLibaThreads(); + + // if heartbeat timeout is set and NetLibaHeartbeat() is not called for timeoutSec + // then StopAllNetLibaThreads() will be called + void SetNetLibaHeartbeatTimeout(double timeoutSec); + void NetLibaHeartbeat(); + + bool IsLocal(const TUdpAddress& addr); +} diff --git a/library/cpp/netliba/v6/udp_socket.cpp b/library/cpp/netliba/v6/udp_socket.cpp new file mode 100644 index 0000000000..fd85ef4d00 --- /dev/null +++ b/library/cpp/netliba/v6/udp_socket.cpp @@ -0,0 +1,292 @@ +#include "stdafx.h" +#include "udp_socket.h" +#include "block_chain.h" +#include "udp_address.h" + +#include <util/datetime/cputimer.h> +#include <util/system/spinlock.h> +#include <util/random/random.h> + +#include <library/cpp/netliba/socket/socket.h> + +#include <errno.h> + +//#define SIMULATE_NETWORK_FAILURES +// there is no explicit bit in the packet header for last packet of 
transfer +// last packet is just smaller then maximum size + +namespace NNetliba { + static bool LocalHostFound; + enum { + IPv4 = 0, + IPv6 = 1 + }; + + struct TIPv6Addr { + ui64 Network, Interface; + + TIPv6Addr() { + Zero(*this); + } + TIPv6Addr(ui64 n, ui64 i) + : Network(n) + , Interface(i) + { + } + }; + inline bool operator==(const TIPv6Addr& a, const TIPv6Addr& b) { + return a.Interface == b.Interface && a.Network == b.Network; + } + + static ui32 LocalHostIP[2]; + static TVector<ui32> LocalHostIPList[2]; + static TVector<TIPv6Addr> LocalHostIPv6List; + + // Struct sockaddr_in6 does not have ui64-array representation + // so we add it here. This avoids "strict aliasing" warnings + typedef union { + in6_addr Addr; + ui64 Addr64[2]; + } TIPv6AddrUnion; + + static ui32 GetIPv6SuffixCrc(const sockaddr_in6& addr) { + TIPv6AddrUnion a; + a.Addr = addr.sin6_addr; + ui64 suffix = a.Addr64[1]; + return (suffix & 0xffffffffll) + (suffix >> 32); + } + + bool InitLocalIPList() { + // Do not use TMutex here: it has a non-trivial destructor which will be called before + // destruction of current thread, if its TThread declared as global/static variable. 
+ static TAdaptiveLock cs; + TGuard lock(cs); + + if (LocalHostFound) + return true; + + TVector<TUdpAddress> addrs; + if (!GetLocalAddresses(&addrs)) + return false; + for (int i = 0; i < addrs.ysize(); ++i) { + const TUdpAddress& addr = addrs[i]; + if (addr.IsIPv4()) { + LocalHostIPList[IPv4].push_back(addr.GetIPv4()); + LocalHostIP[IPv4] = addr.GetIPv4(); + } else { + sockaddr_in6 addr6; + GetWinsockAddr(&addr6, addr); + + LocalHostIPList[IPv6].push_back(GetIPv6SuffixCrc(addr6)); + LocalHostIP[IPv6] = GetIPv6SuffixCrc(addr6); + LocalHostIPv6List.push_back(TIPv6Addr(addr.Network, addr.Interface)); + } + } + LocalHostFound = true; + return true; + } + + template <class T, class TElem> + inline bool IsInSet(const T& c, const TElem& e) { + return Find(c.begin(), c.end(), e) != c.end(); + } + + bool IsLocalIPv4(ui32 ip) { + return IsInSet(LocalHostIPList[IPv4], ip); + } + bool IsLocalIPv6(ui64 network, ui64 iface) { + return IsInSet(LocalHostIPv6List, TIPv6Addr(network, iface)); + } + + ////////////////////////////////////////////////////////////////////////// + void TNetSocket::Open(int port) { + TIntrusivePtr<NNetlibaSocket::ISocket> theSocket = NNetlibaSocket::CreateSocket(); + theSocket->Open(port); + Open(theSocket); + } + + void TNetSocket::Open(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket) { + s = socket; + if (IsValid()) { + PortCrc = s->GetSelfAddress().sin6_port; + } + } + + void TNetSocket::Close() { + if (IsValid()) { + s->Close(); + } + } + + void TNetSocket::SendSelfFakePacket() const { + s->CancelWait(); + } + + inline ui32 CalcAddressCrc(const sockaddr_in6& addr) { + Y_ASSERT(addr.sin6_family == AF_INET6); + const ui64* addr64 = (const ui64*)addr.sin6_addr.s6_addr; + const ui32* addr32 = (const ui32*)addr.sin6_addr.s6_addr; + if (addr64[0] == 0 && addr32[2] == 0xffff0000ll) { + // ipv4 + return addr32[3]; + } else { + // ipv6 + return GetIPv6SuffixCrc(addr); + } + } + + TNetSocket::ESendError TNetSocket::SendTo(const char* buf, int size, const 
sockaddr_in6& toAddress, const EFragFlag frag) const { + Y_ASSERT(size >= UDP_LOW_LEVEL_HEADER_SIZE); + ui32 crc = CalcChecksum(buf + UDP_LOW_LEVEL_HEADER_SIZE, size - UDP_LOW_LEVEL_HEADER_SIZE); + ui32 ipCrc = CalcAddressCrc(toAddress); + ui32 portCrc = toAddress.sin6_port; + *(ui32*)buf = crc + ipCrc + portCrc; +#ifdef SIMULATE_NETWORK_FAILURES + if ((RandomNumber<size_t>() % 3) == 0) + return true; // packet lost + if ((RandomNumber<size_t>() % 3) == 0) + (char&)(buf[RandomNumber<size_t>() % size]) += RandomNumber<size_t>(); // packet broken +#endif + + char tosBuffer[NNetlibaSocket::TOS_BUFFER_SIZE]; + void* t = NNetlibaSocket::CreateTos(Tos, tosBuffer); + const NNetlibaSocket::TIoVec iov = NNetlibaSocket::CreateIoVec((char*)buf, size); + NNetlibaSocket::TMsgHdr hdr = NNetlibaSocket::CreateSendMsgHdr(toAddress, iov, t); + + const int rv = s->SendMsg(&hdr, 0, frag); + if (rv < 0) { + if (errno == EHOSTUNREACH || errno == ENETUNREACH) { + return SEND_NO_ROUTE_TO_HOST; + } else { + return SEND_BUFFER_OVERFLOW; + } + } + Y_ASSERT(rv == size); + return SEND_OK; + } + + inline bool CrcMatches(ui32 pktCrc, ui32 crc, const sockaddr_in6& addr) { + Y_ASSERT(LocalHostFound); + Y_ASSERT(addr.sin6_family == AF_INET6); + // determine our ip address family based on the sender address + // address family can not change in network, so sender address type determines type of our address used + const ui64* addr64 = (const ui64*)addr.sin6_addr.s6_addr; + const ui32* addr32 = (const ui32*)addr.sin6_addr.s6_addr; + yint ipType; + if (addr64[0] == 0 && addr32[2] == 0xffff0000ll) { + // ipv4 + ipType = IPv4; + } else { + // ipv6 + ipType = IPv6; + } + if (crc + LocalHostIP[ipType] == pktCrc) { + return true; + } + // crc failed + // check if packet was sent to different IP address + for (int idx = 0; idx < LocalHostIPList[ipType].ysize(); ++idx) { + ui32 otherIP = LocalHostIPList[ipType][idx]; + if (crc + otherIP == pktCrc) { + LocalHostIP[ipType] = otherIP; + return true; + } + } + // 
crc is really failed, discard packet + return false; + } + + bool TNetSocket::RecvFrom(char* buf, int* size, sockaddr_in6* fromAddress) const { + for (;;) { + int rv; + if (s->IsRecvMsgSupported()) { + const NNetlibaSocket::TIoVec v = NNetlibaSocket::CreateIoVec(buf, *size); + NNetlibaSocket::TMsgHdr hdr = NNetlibaSocket::CreateRecvMsgHdr(fromAddress, v); + rv = s->RecvMsg(&hdr, 0); + + } else { + sockaddr_in6 dummy; + TAutoPtr<NNetlibaSocket::TUdpRecvPacket> pkt = s->Recv(fromAddress, &dummy, -1); + rv = !!pkt ? pkt->DataSize - pkt->DataStart : -1; + if (rv > 0) { + memcpy(buf, pkt->Data.get() + pkt->DataStart, rv); + } + } + + if (rv < 0) + return false; + // ignore empty packets + if (rv == 0) + continue; + // skip small packets + if (rv < UDP_LOW_LEVEL_HEADER_SIZE) + continue; + *size = rv; + ui32 pktCrc = *(ui32*)buf; + ui32 crc = CalcChecksum(buf + UDP_LOW_LEVEL_HEADER_SIZE, rv - UDP_LOW_LEVEL_HEADER_SIZE); + if (!CrcMatches(pktCrc, crc + PortCrc, *fromAddress)) { + // crc is really failed, discard packet + continue; + } + return true; + } + } + + void TNetSocket::Wait(float timeoutSec) const { + s->Wait(timeoutSec); + } + + void TNetSocket::SetTOS(int n) const { + Tos = n; + } + + bool TNetSocket::Connect(const sockaddr_in6& addr) { + // "connect" - meaningless operation + // needed since port unreachable is routed only to "connected" udp sockets in ingenious FreeBSD + if (s->Connect((sockaddr*)&addr, sizeof(addr)) < 0) { + if (errno == EHOSTUNREACH || errno == ENETUNREACH) { + return false; + } else { + Y_ASSERT(0); + } + } + return true; + } + + void TNetSocket::SendEmptyPacket() { + NNetlibaSocket::TIoVec v; + Zero(v); + + // darwin ignores packets with msg_iovlen == 0, also windows implementation uses sendto of first iovec. 
+ NNetlibaSocket::TMsgHdr hdr; + Zero(hdr); + hdr.msg_iov = &v; + hdr.msg_iovlen = 1; + + s->SendMsg(&hdr, 0, FF_ALLOW_FRAG); // sends empty packet to connected address + } + + bool TNetSocket::IsHostUnreachable() { +#ifdef _win_ + char buf[10000]; + sockaddr_in6 fromAddress; + + const NNetlibaSocket::TIoVec v = NNetlibaSocket::CreateIoVec(buf, Y_ARRAY_SIZE(buf)); + NNetlibaSocket::TMsgHdr hdr = NNetlibaSocket::CreateRecvMsgHdr(&fromAddress, v); + + const ssize_t rv = s->RecvMsg(&hdr, 0); + if (rv < 0) { + int err = WSAGetLastError(); + if (err == WSAECONNRESET) + return true; + } +#else + int err = 0; + socklen_t bufSize = sizeof(err); + s->GetSockOpt(SOL_SOCKET, SO_ERROR, (char*)&err, &bufSize); + if (err == ECONNREFUSED) + return true; +#endif + return false; + } +} diff --git a/library/cpp/netliba/v6/udp_socket.h b/library/cpp/netliba/v6/udp_socket.h new file mode 100644 index 0000000000..bd95b8bcd0 --- /dev/null +++ b/library/cpp/netliba/v6/udp_socket.h @@ -0,0 +1,59 @@ +#pragma once + +#include <util/generic/ptr.h> +#include <util/generic/utility.h> +#include <library/cpp/netliba/socket/socket.h> + +namespace NNetliba { + bool IsLocalIPv4(ui32 ip); + bool IsLocalIPv6(ui64 network, ui64 iface); + bool InitLocalIPList(); + + enum { + UDP_LOW_LEVEL_HEADER_SIZE = 4, + }; + + using NNetlibaSocket::EFragFlag; + using NNetlibaSocket::FF_ALLOW_FRAG; + using NNetlibaSocket::FF_DONT_FRAG; + + class TNetSocket: public TNonCopyable { + TIntrusivePtr<NNetlibaSocket::ISocket> s; + ui32 PortCrc; + mutable int Tos; + + public: + enum ESendError { + SEND_OK, + SEND_BUFFER_OVERFLOW, + SEND_NO_ROUTE_TO_HOST, + }; + TNetSocket() + : PortCrc(0) + , Tos(0) + { + } + ~TNetSocket() { + } + + void Open(int port); + void Open(const TIntrusivePtr<NNetlibaSocket::ISocket>& socket); + void Close(); + void SendSelfFakePacket() const; + bool IsValid() const { + return s.Get() ? 
s->IsValid() : false; + } + int GetNetworkOrderPort() const { + return s->GetNetworkOrderPort(); + } + ESendError SendTo(const char* buf, int size, const sockaddr_in6& toAddress, const EFragFlag frag) const; + bool RecvFrom(char* buf, int* size, sockaddr_in6* fromAddress) const; + void Wait(float timeoutSec) const; + void SetTOS(int n) const; + + // obtaining icmp host unreachable in convoluted way + bool Connect(const sockaddr_in6& addr); + void SendEmptyPacket(); + bool IsHostUnreachable(); + }; +} diff --git a/library/cpp/netliba/v6/udp_test.cpp b/library/cpp/netliba/v6/udp_test.cpp new file mode 100644 index 0000000000..d0af51a368 --- /dev/null +++ b/library/cpp/netliba/v6/udp_test.cpp @@ -0,0 +1,161 @@ +#include "stdafx.h" +#include "udp_test.h" +#include "udp_client_server.h" +#include "udp_http.h" +#include "cpu_affinity.h" +#include <util/system/hp_timer.h> +#include <util/datetime/cputimer.h> +#include <util/random/random.h> +#include <util/random/fast.h> + +namespace NNetliba { + //static void PacketLevelTest(bool client) + //{ + // int port = client ? 
0 : 13013; + // TIntrusivePtr<IUdpHost> host = CreateUdpHost(&port); + // + // if(host == 0) { + // exit(-1); + // } + // TUdpAddress serverAddr = CreateAddress("localhost", 13013); + // vector<char> dummyPacket; + // dummyPacket.resize(10000); + // srand(GetCycleCount()); + // + // for (int i = 0; i < dummyPacket.size(); ++i) + // dummyPacket[i] = rand(); + // bool cont = true, hasReply = true; + // int reqCount = 1; + // for (int i = 0; cont; ++i) { + // host->Step(); + // if (client) { + // //while (host->HasPendingData(serverAddr)) + // // Sleep(0); + // if (hasReply) { + // printf("request %d\n", reqCount); + // *(int*)&dummyPacket[0] = reqCount; + // host->Send(serverAddr, dummyPacket, 0, PP_NORMAL); + // hasReply = false; + // ++reqCount; + // }else + // sleep(0); + // + // TRequest *req; + // while (req = host->GetRequest()) { + // int n = *(int*)&req->Data[0]; + // printf("received response %d\n", n); + // Y_ASSERT(memcmp(&req->Data[4], &dummyPacket[4], dummyPacket.size() - 4) == 0); + // delete req; + // hasReply = true; + // } + // TSendResult sr; + // while (host->GetSendResult(&sr)) { + // if (!sr.Success) { + // printf("Send failed!\n"); + // //Sleep(INFINITE); + // hasReply = true; + // } + // } + // } else { + // while (TRequest *req = host->GetRequest()) { + // int n = *(int*)&req->Data[0]; + // printf("responding %d\n", n); + // host->Send(req->Address, req->Data, 0, PP_NORMAL); + // delete req; + // } + // TSendResult sr; + // while (host->GetSendResult(&sr)) { + // if (!sr.Success) { + // printf("Send failed!\n"); + // sleep(0); + // } + // } + // sleep(0); + // } + // } + //} + + static void SessionLevelTest(bool client, const char* serverName, int packetSize, int packetsInFly, int srcPort) { + BindToSocket(0); + TIntrusivePtr<IRequester> reqHost; + // reqHost = CreateHttpUdpRequester(13013); + reqHost = CreateHttpUdpRequester(client ? 
srcPort : 13013); + TUdpAddress serverAddr = CreateAddress(serverName, 13013); + TVector<char> dummyPacket; + dummyPacket.resize(packetSize); + TReallyFastRng32 rr((unsigned int)GetCycleCount()); + for (size_t i = 0; i < dummyPacket.size(); ++i) + dummyPacket[i] = (char)rr.Uniform(256); + bool cont = true; + NHPTimer::STime t; + NHPTimer::GetTime(&t); + THashMap<TGUID, bool, TGUIDHash> seenReqs; + if (client) { + THashMap<TGUID, bool, TGUIDHash> reqList; + int packetsSentCount = 0; + TUdpHttpRequest* udpReq; + for (int i = 1; cont; ++i) { + for (;;) { + udpReq = reqHost->GetRequest(); + if (udpReq == nullptr) + break; + udpReq->Data.resize(10); + reqHost->SendResponse(udpReq->ReqId, &udpReq->Data); + delete udpReq; + } + while (TUdpHttpResponse* res = reqHost->GetResponse()) { + THashMap<TGUID, bool, TGUIDHash>::iterator z = reqList.find(res->ReqId); + if (z == reqList.end()) { + printf("Unexpected response\n"); + abort(); + } + reqList.erase(z); + if (res->Ok) { + ++packetsSentCount; + //Y_ASSERT(res->Data == dummyPacket); + NHPTimer::STime tChk = t; + if (NHPTimer::GetTimePassed(&tChk) > 1) { + printf("packet size = %d\n", dummyPacket.ysize()); + double passedTime = NHPTimer::GetTimePassed(&t); + double rate = packetsSentCount / passedTime; + printf("packet rate %g, transfer %gmb\n", rate, rate * dummyPacket.size() / 1000000); + packetsSentCount = 0; + } + } else { + printf("Failed request!\n"); + //Sleep(INFINITE); + } + delete res; + } + while (reqList.ysize() < packetsInFly) { + *(int*)&dummyPacket[0] = i; + TVector<char> fakePacket = dummyPacket; + TGUID req2 = reqHost->SendRequest(serverAddr, "blaxuz", &fakePacket); + reqList[req2]; + } + reqHost->GetAsyncEvent().Wait(); + } + } else { + TUdpHttpRequest* req; + for (;;) { + req = reqHost->GetRequest(); + if (req) { + if (seenReqs.find(req->ReqId) != seenReqs.end()) { + printf("Request %s recieved twice!\n", GetGuidAsString(req->ReqId).c_str()); + } + seenReqs[req->ReqId]; + req->Data.resize(10); + 
reqHost->SendResponse(req->ReqId, &req->Data); + delete req; + } else { + reqHost->GetAsyncEvent().Wait(); + } + } + } + } + + void RunUdpTest(bool client, const char* serverName, int packetSize, int packetsInFly, int srcPort) { + //PacketLevelTest(client); + SessionLevelTest(client, serverName, packetSize, packetsInFly, srcPort); + } +} diff --git a/library/cpp/netliba/v6/udp_test.h b/library/cpp/netliba/v6/udp_test.h new file mode 100644 index 0000000000..68435fee5c --- /dev/null +++ b/library/cpp/netliba/v6/udp_test.h @@ -0,0 +1,5 @@ +#pragma once + +namespace NNetliba { + void RunUdpTest(bool client, const char* serverName, int packetSize, int packetsInFly, int srcPort = 0); +} |