diff options
author | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
---|---|---|
committer | monster <monster@ydb.tech> | 2022-07-07 14:41:37 +0300 |
commit | 06e5c21a835c0e923506c4ff27929f34e00761c2 (patch) | |
tree | 75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/netliba/v6/ib_collective.h | |
parent | 03f024c4412e3aa613bb543cf1660176320ba8f4 (diff) | |
download | ydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz |
fix ya.make
Diffstat (limited to 'library/cpp/netliba/v6/ib_collective.h')
-rw-r--r-- | library/cpp/netliba/v6/ib_collective.h | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/library/cpp/netliba/v6/ib_collective.h b/library/cpp/netliba/v6/ib_collective.h new file mode 100644 index 0000000000..48ffd29b34 --- /dev/null +++ b/library/cpp/netliba/v6/ib_collective.h @@ -0,0 +1,160 @@ +#pragma once + +#include <library/cpp/binsaver/bin_saver.h> + +namespace NNetliba { + struct TCollectiveInit { + int Size, Rank; + + SAVELOAD(Size, Rank) + }; + + struct TCollectiveLinkSet { + struct TLinkInfo { + int QPN, PSN; + }; + TVector<int> Hosts; // host LIDs + TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch + TVector<TLinkInfo> Links; + + SAVELOAD(Hosts, HostGroup, Links) + }; + + struct IAllDataSync: public TThrRefBase { + virtual void* GetRawData() = 0; + virtual size_t GetRawDataSize() = 0; + virtual void Sync() = 0; + virtual void Flush() = 0; + + template <class T> + T* GetData() { + return static_cast<T*>(GetRawData()); + } + template <class T> + size_t GetSize() { + return GetRawDataSize() / sizeof(T); + } + }; + + struct IAllReduce: public IAllDataSync { + virtual bool Resize(size_t dataSize) = 0; + }; + + struct IAllGather: public IAllDataSync { + virtual bool Resize(const TVector<size_t>& szPerRank) = 0; + }; + + struct IReduceOp: public TThrRefBase { + virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0; + }; + + template <class T, class TElem = typename T::TElem> + class TAllReduceOp: public IReduceOp { + T Op; + + public: + TAllReduceOp() { + } + TAllReduceOp(T op) + : Op(op) + { + } + void Reduce(void* dst, const void* add, size_t dataSize) const override { + TElem* dstPtr = (TElem*)(dst); + const TElem* addPtr = (const TElem*)(add); + TElem* finPtr = (TElem*)(((char*)dst) + dataSize); + while (dstPtr < finPtr) { + Op(dstPtr, *addPtr); + ++dstPtr; + ++addPtr; + } + } + }; + + // table of active peers for micro send/recv + class TIBMicroPeerTable { + TVector<ui8> Table; // == 0 means accept mesages from this qpn + int TableSize; + bool ParsePending; + + public: + TIBMicroPeerTable() + : ParsePending(true) + { + Init(0); + } + void Init(int tableSizeLog) { + TableSize = 1 << tableSizeLog; + ParsePending = true; + Table.resize(0); + Table.resize(TableSize, 0); + } + bool NeedParsePending() const { + return ParsePending; + } + void StopParsePending() { + ParsePending = false; + } + void StopQPN(int qpn, ui8 mask) { + Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0); + Table[qpn & (TableSize - 1)] |= mask; + } + void StopQPN(int qpn) { + Y_ASSERT(Table[qpn & (TableSize - 1)] == 0); + Table[qpn & (TableSize - 1)] = 0xff; + } + bool NeedQPN(int qpn) const { + return Table[qpn & (TableSize - 1)] != 0xff; + } + }; + + struct IIBCollective; + class TIBCollective; + class TIBRecvMicro: public TNonCopyable { + TIBCollective& IB; + ui64 Id; + int QPN; + void* Data; + + public: + TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable); + ~TIBRecvMicro(); + void* GetRawData() const { + return Data; + } + template <class T> + T* GetData() { + return static_cast<T*>(GetRawData()); + } + int GetQPN() const { + return QPN; + } + }; + + struct IIBCollective: public TThrRefBase { + struct TRdmaRequest { + int DstRank; + ui64 RemoteAddr, LocalAddr; + ui32 RemoteKey, LocalKey; + ui64 Size; + }; + + virtual int GetRank() = 0; + virtual int GetSize() = 0; + virtual int GetGroupTypeCount() = 0; + virtual int GetQPN(int rank) = 0; + virtual bool TryWaitCompletion() = 0; + virtual void WaitCompletion() = 0; + virtual void Start(const TCollectiveLinkSet& links) = 0; + virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0; + virtual IAllGather* CreateAllGather(size_t szPerRank) = 0; + virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0; + virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0; + virtual void Fence() = 0; + virtual void InitPeerTable(TIBMicroPeerTable* res) = 0; + virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0; + virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0; + }; + + IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks); +} |