#pragma once
#include <library/cpp/binsaver/bin_saver.h>
namespace NNetliba {
// Initialization parameters passed to CreateCollective().
// NOTE(review): Size / Rank presumably mean "number of participants" and
// "index of this participant" - confirm against the implementation.
struct TCollectiveInit {
    int Size;
    int Rank;

    // Both fields are registered with the bin_saver SAVELOAD macro, in this order.
    SAVELOAD(Size, Rank);
};
// Serializable description of the links of a collective: host identifiers,
// host grouping and one TLinkInfo record per link.
struct TCollectiveLinkSet {
    // Per-link pair of integers.
    // NOTE(review): QPN / PSN look like InfiniBand "queue pair number" and
    // "packet sequence number" - confirm.
    struct TLinkInfo {
        int QPN;
        int PSN;
    };

    TVector<int> Hosts;              // host LIDs
    TVector<TVector<int>> HostGroup; // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch
    TVector<TLinkInfo> Links;

    // All three containers are registered with SAVELOAD, in this order.
    SAVELOAD(Hosts, HostGroup, Links);
};
// Base interface exposing a raw data buffer (pointer + size in bytes) together
// with Sync() and Flush() operations, plus typed convenience accessors.
struct IAllDataSync: public TThrRefBase {
    virtual void* GetRawData() = 0;
    virtual size_t GetRawDataSize() = 0;
    virtual void Sync() = 0;
    virtual void Flush() = 0;

    // Returns the raw data pointer reinterpreted as a pointer to T.
    template <class T>
    T* GetData() {
        void* const raw = GetRawData();
        return static_cast<T*>(raw);
    }

    // Returns the raw data size expressed in whole elements of type T
    // (integer division, any remainder is dropped).
    template <class T>
    size_t GetSize() {
        const size_t bytes = GetRawDataSize();
        return bytes / sizeof(T);
    }
};
// IAllDataSync extension that adds Resize() taking a single size.
// NOTE(review): the unit of dataSize and the meaning of the bool result are
// not visible in this header - confirm with the implementation.
struct IAllReduce: public IAllDataSync {
virtual bool Resize(size_t dataSize) = 0;
};
// IAllDataSync extension that adds Resize() taking a vector of sizes.
// NOTE(review): szPerRank presumably holds one size per rank; the meaning of
// the bool result is not visible in this header - confirm.
struct IAllGather: public IAllDataSync {
virtual bool Resize(const TVector<size_t>& szPerRank) = 0;
};
// Type-erased reduction operation over raw buffers.
// Reduce() receives a mutable destination buffer, a read-only source buffer
// and a size; TAllReduceOp below treats dataSize as a byte count.
struct IReduceOp: public TThrRefBase {
virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0;
};
// Adapts an element-wise functor T to the byte-oriented IReduceOp interface.
//
// T must be callable on a const object as Op(TElem* dst, <const TElem lvalue>).
// TElem defaults to T::TElem.
template <class T, class TElem = typename T::TElem>
class TAllReduceOp: public IReduceOp {
    T Op;

public:
    TAllReduceOp() {
    }
    TAllReduceOp(T op)
        : Op(op)
    {
    }

    // Applies Op to every whole TElem in the buffers.
    //   dst      - destination buffer, updated in place by Op
    //   add      - source buffer, read only
    //   dataSize - size of each buffer in bytes
    //
    // Only dataSize / sizeof(TElem) complete elements are processed. A trailing
    // partial element is ignored instead of being read/written past the end of
    // the buffers.
    void Reduce(void* dst, const void* add, size_t dataSize) const override {
        TElem* const dstElems = static_cast<TElem*>(dst);
        const TElem* const addElems = static_cast<const TElem*>(add);
        const size_t elemCount = dataSize / sizeof(TElem);
        for (size_t i = 0; i < elemCount; ++i) {
            Op(dstElems + i, addElems[i]);
        }
    }
};
// table of active peers for micro send/recv
class TIBMicroPeerTable {
TVector<ui8> Table; // == 0 means accept mesages from this qpn
int TableSize;
bool ParsePending;
public:
TIBMicroPeerTable()
: ParsePending(true)
{
Init(0);
}
void Init(int tableSizeLog) {
TableSize = 1 << tableSizeLog;
ParsePending = true;
Table.resize(0);
Table.resize(TableSize, 0);
}
bool NeedParsePending() const {
return ParsePending;
}
void StopParsePending() {
ParsePending = false;
}
void StopQPN(int qpn, ui8 mask) {
Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0);
Table[qpn & (TableSize - 1)] |= mask;
}
void StopQPN(int qpn) {
Y_ASSERT(Table[qpn & (TableSize - 1)] == 0);
Table[qpn & (TableSize - 1)] = 0xff;
}
bool NeedQPN(int qpn) const {
return Table[qpn & (TableSize - 1)] != 0xff;
}
};
struct IIBCollective;
class TIBCollective;

// Non-copyable object built from a collective and a peer table; exposes a
// data pointer and a QPN. Constructor and destructor are defined elsewhere.
// NOTE(review): presumably represents one received micro message whose
// resources are released in the destructor - confirm with the implementation.
class TIBRecvMicro: public TNonCopyable {
    TIBCollective& IB;
    ui64 Id;
    int QPN;
    void* Data;

public:
    TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable);
    ~TIBRecvMicro();

    // Untyped access to the stored data pointer.
    void* GetRawData() const {
        return Data;
    }

    // Stored data pointer cast to T*.
    template <class T>
    T* GetData() {
        return static_cast<T*>(Data);
    }

    int GetQPN() const {
        return QPN;
    }
};
// Abstract interface of a collective-communication object (derives from
// TThrRefBase). Instances are obtained from CreateCollective().
// NOTE(review): the per-method notes below are derived from the signatures
// only; the implementation is not visible in this header.
struct IIBCollective: public TThrRefBase {
// One RDMA request: destination rank, remote/local addresses with their
// keys, and a size.
struct TRdmaRequest {
int DstRank;
ui64 RemoteAddr, LocalAddr;
ui32 RemoteKey, LocalKey;
ui64 Size;
};
// Integer queries. NOTE(review): presumably this participant's rank and the
// total number of participants (cf. TCollectiveInit) - confirm.
virtual int GetRank() = 0;
virtual int GetSize() = 0;
virtual int GetGroupTypeCount() = 0;
virtual int GetQPN(int rank) = 0;
// Non-blocking (bool result) and blocking-looking variants of completion
// waiting. NOTE(review): exact blocking behavior is not visible here.
virtual bool TryWaitCompletion() = 0;
virtual void WaitCompletion() = 0;
virtual void Start(const TCollectiveLinkSet& links) = 0;
// Factory methods returning raw interface pointers; the ownership convention
// is not visible in this header. The size_t overload of CreateAllGather
// presumably applies the same size to every rank - confirm.
virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0;
virtual IAllGather* CreateAllGather(size_t szPerRank) = 0;
virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0;
// targetRank and res are pointer parameters, presumably outputs - confirm.
virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0;
virtual void Fence() = 0;
// Receives a caller-provided TIBMicroPeerTable through a pointer.
virtual void InitPeerTable(TIBMicroPeerTable* res) = 0;
// NOTE(review): the bool result presumably reports whether the send was
// accepted - confirm.
virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0;
virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0;
};
// Factory for IIBCollective; defined elsewhere.
// Takes the init parameters by const reference and a TCollectiveLinkSet pointer.
// NOTE(review): resLinks is presumably an output filled by the call, and the
// returned raw pointer is presumably meant to be held in a ref-counting
// pointer (IIBCollective derives from TThrRefBase) - confirm.
IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks);
}