aboutsummaryrefslogblamecommitdiffstats
path: root/library/cpp/netliba/v6/ib_buffers.h
blob: ae1857cdcbc6f9288eb35274cc81d47da5ab2a19 (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13











                                                                                             


                                                             
























                                                                          
                                                        


















































                                                                                             
                                                    














                                                                                                          
                                                    
































































                                                                                                                         
#pragma once

#include "ib_low.h"

namespace NNetliba {
    // buffer id 0 is special, it is used when data is sent inline and should not be returned
    const size_t SMALL_PKT_SIZE = 4096;

    const ui64 BP_AH_USED_FLAG = 0x1000000000000000ul;
    const ui64 BP_BUF_ID_MASK = 0x00000000fffffffful;

    // single thread version
    class TIBBufferPool: public TThrRefBase, TNonCopyable {
        static constexpr int BLOCK_SIZE_LN = 11;
        static constexpr int BLOCK_SIZE = 1 << BLOCK_SIZE_LN;
        static constexpr int BLOCK_COUNT = 1024;

        struct TSingleBlock {
            TIntrusivePtr<TMemoryRegion> Mem;
            TVector<ui8> BlkRefCounts;
            TVector<TIntrusivePtr<TAddressHandle>> AHHolder;

            void Alloc(TPtrArg<TIBContext> ctx) {
                size_t dataSize = SMALL_PKT_SIZE * BLOCK_SIZE;
                Mem = new TMemoryRegion(ctx, dataSize);
                BlkRefCounts.resize(BLOCK_SIZE, 0);
                AHHolder.resize(BLOCK_SIZE);
            }
            char* GetBufData(ui64 idArg) {
                char* data = Mem->GetData();
                return data + (idArg & (BLOCK_SIZE - 1)) * SMALL_PKT_SIZE;
            }
        };

        TIntrusivePtr<TIBContext> IBCtx;
        TVector<int> FreeList;
        TVector<TSingleBlock> Blocks;
        size_t FirstFreeBlock;
        int PostRecvDeficit;
        TIntrusivePtr<TSharedReceiveQueue> SRQ;

        void AddBlock() {
            if (FirstFreeBlock == Blocks.size()) {
                Y_ABORT_UNLESS(0, "run out of buffers");
            }
            Blocks[FirstFreeBlock].Alloc(IBCtx);
            size_t start = (FirstFreeBlock == 0) ? 1 : FirstFreeBlock * BLOCK_SIZE;
            size_t finish = FirstFreeBlock * BLOCK_SIZE + BLOCK_SIZE;
            for (size_t i = start; i < finish; ++i) {
                FreeList.push_back(i);
            }
            ++FirstFreeBlock;
        }

    public:
        TIBBufferPool(TPtrArg<TIBContext> ctx, int maxSRQWorkRequests)
            : IBCtx(ctx)
            , FirstFreeBlock(0)
            , PostRecvDeficit(maxSRQWorkRequests)
        {
            Blocks.resize(BLOCK_COUNT);
            AddBlock();
            SRQ = new TSharedReceiveQueue(ctx, maxSRQWorkRequests);

            PostRecv();
        }
        TSharedReceiveQueue* GetSRQ() const {
            return SRQ.Get();
        }
        int AllocBuf() {
            if (FreeList.empty()) {
                AddBlock();
            }
            int id = FreeList.back();
            FreeList.pop_back();
            Y_ASSERT(++Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 1);
            return id;
        }
        void FreeBuf(ui64 idArg) {
            ui64 id = idArg & BP_BUF_ID_MASK;
            if (id == 0) {
                return;
            }
            Y_ASSERT(id > 0 && id < (ui64)(FirstFreeBlock * BLOCK_SIZE));
            FreeList.push_back(id);
            Y_ASSERT(--Blocks[id >> BLOCK_SIZE_LN].BlkRefCounts[id & (BLOCK_SIZE - 1)] == 0);
            if (idArg & BP_AH_USED_FLAG) {
                Blocks[id >> BLOCK_SIZE_LN].AHHolder[id & (BLOCK_SIZE - 1)] = nullptr;
            }
        }
        char* GetBufData(ui64 idArg) {
            ui64 id = idArg & BP_BUF_ID_MASK;
            return Blocks[id >> BLOCK_SIZE_LN].GetBufData(id);
        }
        int PostSend(TPtrArg<TRCQueuePair> qp, const void* data, size_t len) {
            if (len > SMALL_PKT_SIZE) {
                Y_ABORT_UNLESS(0, "buffer overrun");
            }
            if (len <= MAX_INLINE_DATA_SIZE) {
                qp->PostSend(nullptr, 0, data, len);
                return 0;
            } else {
                int id = AllocBuf();
                TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
                char* buf = blk.GetBufData(id);
                memcpy(buf, data, len);
                qp->PostSend(blk.Mem, id, buf, len);
                return id;
            }
        }
        void PostSend(TPtrArg<TUDQueuePair> qp, TPtrArg<TAddressHandle> ah, int remoteQPN, int remoteQKey,
                      const void* data, size_t len) {
            if (len > SMALL_PKT_SIZE - 40) {
                Y_ABORT_UNLESS(0, "buffer overrun");
            }
            ui64 id = AllocBuf();
            TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
            int ptr = id & (BLOCK_SIZE - 1);
            blk.AHHolder[ptr] = ah.Get();
            id |= BP_AH_USED_FLAG;
            if (len <= MAX_INLINE_DATA_SIZE) {
                qp->PostSend(ah, remoteQPN, remoteQKey, nullptr, id, data, len);
            } else {
                char* buf = blk.GetBufData(id);
                memcpy(buf, data, len);
                qp->PostSend(ah, remoteQPN, remoteQKey, blk.Mem, id, buf, len);
            }
        }
        void RequestPostRecv() {
            ++PostRecvDeficit;
        }
        void PostRecv() {
            for (int i = 0; i < PostRecvDeficit; ++i) {
                int id = AllocBuf();
                TSingleBlock& blk = Blocks[id >> BLOCK_SIZE_LN];
                char* buf = blk.GetBufData(id);
                SRQ->PostReceive(blk.Mem, id, buf, SMALL_PKT_SIZE);
            }
            PostRecvDeficit = 0;
        }
    };

    class TIBRecvPacketProcess: public TNonCopyable {
        TIBBufferPool& BP;
        ui64 Id;
        char* Data;

    public:
        TIBRecvPacketProcess(TIBBufferPool& bp, const ibv_wc& wc)
            : BP(bp)
            , Id(wc.wr_id)
        {
            Y_ASSERT(wc.opcode & IBV_WC_RECV);
            BP.RequestPostRecv();
            Data = BP.GetBufData(Id);
        }
        // intended for postponed packet processing
        // with this call RequestPostRecv() should be called outside (and PostRecv() too in order to avoid rnr situation)
        TIBRecvPacketProcess(TIBBufferPool& bp, const ui64 wr_id)
            : BP(bp)
            , Id(wr_id)
        {
            Data = BP.GetBufData(Id);
        }
        ~TIBRecvPacketProcess() {
            BP.FreeBuf(Id);
            BP.PostRecv();
        }
        char* GetData() const {
            return Data;
        }
        char* GetUDData() const {
            return Data + 40;
        }
        ibv_grh* GetGRH() const {
            return (ibv_grh*)Data;
        }
    };

}