aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/netliba/v6/ib_collective.h
blob: 2ce790e9605265405399eaf75df3b3225ced2e22 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#pragma once

#include <library/cpp/binsaver/bin_saver.h>

namespace NNetliba {
    // Parameters handed to CreateCollective().
    // Size/Rank presumably follow the usual collective convention (number of
    // ranks, and this process's index among them) -- confirm against the
    // CreateCollective implementation.
    // Both members default to 0 so a default-constructed instance never carries
    // indeterminate values (e.g. into serialization through SAVELOAD below).
    struct TCollectiveInit {
        int Size = 0;
        int Rank = 0;

        SAVELOAD(Size, Rank);
    };

    // Link description filled by CreateCollective() (through its resLinks
    // out-parameter) and consumed by IIBCollective::Start(). Serializable via
    // SAVELOAD.
    struct TCollectiveLinkSet {
        // Per-link connection parameters. QPN/PSN are presumably the InfiniBand
        // queue pair number and packet sequence number -- confirm in the .cpp.
        struct TLinkInfo {
            int QPN;
            int PSN;
        };

        // Host LIDs.
        TVector<int> Hosts;
        // HostGroup[0] - switchId, HostGroup[1] - hostId within the switch.
        TVector<TVector<int>> HostGroup;
        TVector<TLinkInfo> Links;

        SAVELOAD(Hosts, HostGroup, Links);
    };

    // Common interface of the data buffers produced by IIBCollective
    // (CreateAllReduce / CreateAllGather). The raw-buffer accessors and
    // Sync()/Flush() are supplied by the concrete implementation.
    struct IAllDataSync: public TThrRefBase {
        virtual void* GetRawData() = 0;
        virtual size_t GetRawDataSize() = 0;
        virtual void Sync() = 0;
        virtual void Flush() = 0;

        // GetRawData() viewed as an array of T (no checking is performed).
        template <class T>
        T* GetData() {
            void* const raw = GetRawData();
            return static_cast<T*>(raw);
        }

        // Number of whole T elements in the raw buffer; trailing bytes that do
        // not form a complete T are not counted.
        template <class T>
        size_t GetSize() {
            const size_t byteCount = GetRawDataSize();
            return byteCount / sizeof(T);
        }
    };

    // Buffer returned by IIBCollective::CreateAllReduce().
    // Resize() takes the new data size; its bool result presumably reports
    // success -- confirm in the implementation.
    struct IAllReduce: public IAllDataSync {
        virtual bool Resize(size_t dataSize) = 0;
    };

    // Buffer returned by IIBCollective::CreateAllGather().
    // Resize() takes one size per rank (szPerRank[i] for rank i); its bool
    // result presumably reports success -- confirm in the implementation.
    struct IAllGather: public IAllDataSync {
        virtual bool Resize(const TVector<size_t>& szPerRank) = 0;
    };

    // Reduction operator used by IIBCollective::CreateAllReduce().
    // Reduce() folds the contents of `add` into `dst` in place; `dataSize` is
    // a byte count (that is how TAllReduceOp interprets it).
    struct IReduceOp: public TThrRefBase {
        virtual void Reduce(void* dst, const void* add, size_t dataSize) const = 0;
    };

    // IReduceOp adapter that applies a per-element functor.
    //
    // T must be callable on a const instance as Op(TElem* dst, const TElem& add)
    // and is expected to fold `add` into `*dst`. TElem defaults to T::TElem.
    template <class T, class TElem = typename T::TElem>
    class TAllReduceOp: public IReduceOp {
        T Op;

    public:
        TAllReduceOp() {
        }
        TAllReduceOp(T op)
            : Op(op)
        {
        }

        // Combines the first `dataSize` bytes of `add` into `dst`, one TElem at
        // a time. Only whole elements are processed: iterating by element count
        // (rather than comparing against dst + dataSize bytes) guarantees that a
        // dataSize which is not a multiple of sizeof(TElem) never causes a read
        // or write past the end of either buffer.
        void Reduce(void* dst, const void* add, size_t dataSize) const override {
            TElem* const dstPtr = static_cast<TElem*>(dst);
            const TElem* const addPtr = static_cast<const TElem*>(add);
            const size_t count = dataSize / sizeof(TElem);
            for (size_t i = 0; i < count; ++i) {
                Op(dstPtr + i, addPtr[i]);
            }
        }
    };

    // table of active peers for micro send/recv
    class TIBMicroPeerTable {
        TVector<ui8> Table; // == 0 means accept mesages from this qpn
        int TableSize;
        bool ParsePending;

    public:
        TIBMicroPeerTable()
            : ParsePending(true)
        {
            Init(0);
        }
        void Init(int tableSizeLog) {
            TableSize = 1 << tableSizeLog;
            ParsePending = true;
            Table.resize(0);
            Table.resize(TableSize, 0);
        }
        bool NeedParsePending() const {
            return ParsePending;
        }
        void StopParsePending() {
            ParsePending = false;
        }
        void StopQPN(int qpn, ui8 mask) {
            Y_ASSERT((Table[qpn & (TableSize - 1)] & mask) == 0);
            Table[qpn & (TableSize - 1)] |= mask;
        }
        void StopQPN(int qpn) {
            Y_ASSERT(Table[qpn & (TableSize - 1)] == 0);
            Table[qpn & (TableSize - 1)] = 0xff;
        }
        bool NeedQPN(int qpn) const {
            return Table[qpn & (TableSize - 1)] != 0xff;
        }
    };

    struct IIBCollective;
    class TIBCollective;

    // Handle to one received micro message. The constructor and destructor are
    // defined out of line (not visible in this header). The handle is
    // non-copyable.
    class TIBRecvMicro: public TNonCopyable {
        TIBCollective& IB;
        ui64 Id;
        int QPN;
        void* Data; // payload exposed through GetRawData()/GetData()

    public:
        TIBRecvMicro(IIBCollective* col, TIBMicroPeerTable* peerTable);
        ~TIBRecvMicro();

        // Untyped pointer to the payload.
        void* GetRawData() const {
            return Data;
        }

        // Payload viewed as T; the caller picks the correct T.
        template <class T>
        T* GetData() {
            return static_cast<T*>(Data);
        }

        // QPN recorded for this message (set by the out-of-line constructor).
        int GetQPN() const {
            return QPN;
        }
    };

    // Collective-communication endpoint. Instances are obtained from
    // CreateCollective() below. The order of the virtual functions is part of
    // the ABI; do not reorder.
    struct IIBCollective: public TThrRefBase {
        // One entry of an RdmaWrite() batch. The address/key/size semantics are
        // presumably the usual RDMA ones (virtual addresses, memory-region
        // keys, byte count) -- confirm in the implementation.
        struct TRdmaRequest {
            int DstRank;
            ui64 RemoteAddr;
            ui64 LocalAddr;
            ui32 RemoteKey;
            ui32 LocalKey;
            ui64 Size;
        };

        // Topology of the collective.
        virtual int GetRank() = 0;
        virtual int GetSize() = 0;
        virtual int GetGroupTypeCount() = 0;
        virtual int GetQPN(int rank) = 0;

        // Completion handling.
        virtual bool TryWaitCompletion() = 0;
        virtual void WaitCompletion() = 0;

        virtual void Start(const TCollectiveLinkSet& links) = 0;

        // Factories for collective buffers.
        virtual IAllGather* CreateAllGather(const TVector<size_t>& szPerRank) = 0;
        virtual IAllGather* CreateAllGather(size_t szPerRank) = 0;
        virtual IAllReduce* CreateAllReduce(size_t dataSize, TPtrArg<IReduceOp> reduceOp) = 0;

        virtual void RunBWTest(int groupType, int delta, int* targetRank, float* res) = 0;
        virtual void Fence() = 0;

        // Micro-message and RDMA primitives.
        virtual void InitPeerTable(TIBMicroPeerTable* res) = 0;
        virtual bool TrySendMicro(int dstRank, const void* data, int dataSize) = 0;
        virtual void RdmaWrite(const TVector<TRdmaRequest>& reqs) = 0;
    };

    // Creates a collective for `params`. `resLinks` is an out-parameter of type
    // TCollectiveLinkSet, the same type IIBCollective::Start() accepts --
    // presumably it receives this rank's link description (confirm in the .cpp).
    IIBCollective* CreateCollective(const TCollectiveInit& params, TCollectiveLinkSet* resLinks);
}