aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/netliba/v6/ib_collective.h
diff options
context:
space:
mode:
authormonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
committermonster <monster@ydb.tech>2022-07-07 14:41:37 +0300
commit06e5c21a835c0e923506c4ff27929f34e00761c2 (patch)
tree75efcbc6854ef9bd476eb8bf00cc5c900da436a2 /library/cpp/netliba/v6/ib_collective.h
parent03f024c4412e3aa613bb543cf1660176320ba8f4 (diff)
downloadydb-06e5c21a835c0e923506c4ff27929f34e00761c2.tar.gz
fix ya.make
Diffstat (limited to 'library/cpp/netliba/v6/ib_collective.h')
-rw-r--r--library/cpp/netliba/v6/ib_collective.h160
1 files changed, 160 insertions, 0 deletions
diff --git a/library/cpp/netliba/v6/ib_collective.h b/library/cpp/netliba/v6/ib_collective.h
new file mode 100644
index 0000000000..48ffd29b34
--- /dev/null
+++ b/library/cpp/netliba/v6/ib_collective.h
@@ -0,0 +1,160 @@
+#pragma once
+
+#include <library/cpp/binsaver/bin_saver.h>
+
+namespace NNetliba {
+ struct TCollectiveInit {
+ int Size, Rank;
+
+ SAVELOAD(Size, Rank)
+ };
+
+ struct TCollectiveLinkSet {
+ struct TLinkInfo {
+ int QPN, PSN;
+ };
+ TVector<int> Hosts; // host LIDs
+ TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch
+ TVector<TLinkInfo> Links;
+
+ SAVELOAD(Hosts, HostGroup, Links)
+ };
+
+ struct IAllDataSync: public TThrRefBase {
+ virtual void* GetRawData() = 0;
+ virtual size_t GetRawDataSize() = 0;
+ virtual void Sync() = 0;
+ virtual void Flush() = 0;
+
+ template <class T>
+ T* GetData() {
+ return static_cast<T*>(GetRawData());
+ }
+ template <class T>
+ size_t GetSize() {
+ return GetRawDataSize() / sizeof(T);
+ }
+ };
+
+ struct IAllReduce: public IAllDataSync {
+ virtual bool Resize(size_t dataSize) = 0;
+ };
+
+ struct IAllGather: public IAllDataSync {
+ virtual bool Resize(const TVector<size_t>& szPerRank) = 0;
+ };
+
+ struct IReduceOp: public TThrRefBase {
+ virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0;
+ };
+
+ template <class T, class TElem = typename T::TElem>
+ class TAllReduceOp: public IReduceOp {
+ T Op;
+
+ public:
+ TAllReduceOp() {
+ }
+ TAllReduceOp(T op)
+ : Op(op)
+ {
+ }
+ void Reduce(void* dst, const void* add, size_t dataSize) const override {
+ TElem* dstPtr = (TElem*)(dst);
+ const TElem* addPtr = (const TElem*)(add);
+ TElem* finPtr = (TElem*)(((char*)dst) + dataSize);
+ while (dstPtr < finPtr) {
+ Op(dstPtr, *addPtr);
+ ++dstPtr;
+ ++addPtr;
+ }
+ }
+ };
+
+ // table of active peers for micro send/recv
+ class TIBMicroPeerTable {
+ TVector<ui8> Table; // == 0 means accept mesages from this qpn
+ int TableSize;
+ bool ParsePending;
+
+ public:
+ TIBMicroPeerTable()
+ : ParsePending(true)
+ {
+ Init(0);
+ }
+ void Init(int tableSizeLog) {
+ TableSize = 1 << tableSizeLog;
+ ParsePending = true;
+ Table.resize(0);
+ Table.resize(TableSize, 0);
+ }
+ bool NeedParsePending() const {
+ return ParsePending;
+ }
+ void StopParsePending() {
+ ParsePending = false;
+ }
+ void StopQPN(int qpn, ui8 mask) {
+ Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0);
+ Table[qpn & (TableSize - 1)] |= mask;
+ }
+ void StopQPN(int qpn) {
+ Y_ASSERT(Table[qpn & (TableSize - 1)] == 0);
+ Table[qpn & (TableSize - 1)] = 0xff;
+ }
+ bool NeedQPN(int qpn) const {
+ return Table[qpn & (TableSize - 1)] != 0xff;
+ }
+ };
+
+ struct IIBCollective;
+ class TIBCollective;
+ class TIBRecvMicro: public TNonCopyable {
+ TIBCollective& IB;
+ ui64 Id;
+ int QPN;
+ void* Data;
+
+ public:
+ TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable);
+ ~TIBRecvMicro();
+ void* GetRawData() const {
+ return Data;
+ }
+ template <class T>
+ T* GetData() {
+ return static_cast<T*>(GetRawData());
+ }
+ int GetQPN() const {
+ return QPN;
+ }
+ };
+
+ struct IIBCollective: public TThrRefBase {
+ struct TRdmaRequest {
+ int DstRank;
+ ui64 RemoteAddr, LocalAddr;
+ ui32 RemoteKey, LocalKey;
+ ui64 Size;
+ };
+
+ virtual int GetRank() = 0;
+ virtual int GetSize() = 0;
+ virtual int GetGroupTypeCount() = 0;
+ virtual int GetQPN(int rank) = 0;
+ virtual bool TryWaitCompletion() = 0;
+ virtual void WaitCompletion() = 0;
+ virtual void Start(const TCollectiveLinkSet& links) = 0;
+ virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0;
+ virtual IAllGather* CreateAllGather(size_t szPerRank) = 0;
+ virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0;
+ virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0;
+ virtual void Fence() = 0;
+ virtual void InitPeerTable(TIBMicroPeerTable* res) = 0;
+ virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0;
+ virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0;
+ };
+
+ IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks);
+}