aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/actors/interconnect/channel_scheduler.h
blob: 551a4cb61a1c5d47a2ddff36a18e93e2a4c7b29a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#pragma once

#include "interconnect_channel.h"
#include "event_holder_pool.h"

#include <algorithm>
#include <array>
#include <memory>
#include <optional>
#include <vector>

namespace NActors {

    class TChannelScheduler {
        const ui32 PeerNodeId;
        std::array<std::optional<TEventOutputChannel>, 16> ChannelArray;
        THashMap<ui16, TEventOutputChannel> ChannelMap;
        std::shared_ptr<IInterconnectMetrics> Metrics;
        TEventHolderPool& Pool;
        const ui32 MaxSerializedEventSize;
        const TSessionParams Params;

        struct THeapItem {
            TEventOutputChannel *Channel;
            ui64 WeightConsumed = 0;

            friend bool operator <(const THeapItem& x, const THeapItem& y) {
                return x.WeightConsumed > y.WeightConsumed;
            }
        };

        std::vector<THeapItem> Heap;

    public:
        TChannelScheduler(ui32 peerNodeId, const TChannelsConfig& predefinedChannels,
                std::shared_ptr<IInterconnectMetrics> metrics, TEventHolderPool& pool, ui32 maxSerializedEventSize,
                TSessionParams params)
            : PeerNodeId(peerNodeId)
            , Metrics(std::move(metrics))
            , Pool(pool)
            , MaxSerializedEventSize(maxSerializedEventSize)
            , Params(std::move(params))
        {
            for (const auto& item : predefinedChannels) {
                GetOutputChannel(item.first);
            }
        }

        TEventOutputChannel *PickChannelWithLeastConsumedWeight() {
            Y_VERIFY(!Heap.empty());
            return Heap.front().Channel;
        }

        void AddToHeap(TEventOutputChannel& channel, ui64 counter) {
            if (channel.IsWorking()) {
                ui64 weight = channel.WeightConsumedOnPause;
                weight -= Min(weight, counter - channel.EqualizeCounterOnPause);
                Heap.push_back(THeapItem{&channel, weight});
                std::push_heap(Heap.begin(), Heap.end());
            }
        }

        void FinishPick(ui64 weightConsumed, ui64 counter) {
            std::pop_heap(Heap.begin(), Heap.end());
            auto& item = Heap.back();
            item.WeightConsumed += weightConsumed;
            if (item.Channel->IsWorking()) { // reschedule
                std::push_heap(Heap.begin(), Heap.end());
            } else { // remove from heap
                item.Channel->EqualizeCounterOnPause = counter;
                item.Channel->WeightConsumedOnPause = item.WeightConsumed;
                Heap.pop_back();
            }
        }

        TEventOutputChannel& GetOutputChannel(ui16 channel) {
            if (channel < ChannelArray.size()) {
                auto& res = ChannelArray[channel];
                if (Y_UNLIKELY(!res)) {
                    res.emplace(Pool, channel, PeerNodeId, MaxSerializedEventSize, Metrics,
                        Params);
                }
                return *res;
            } else {
                auto it = ChannelMap.find(channel);
                if (Y_UNLIKELY(it == ChannelMap.end())) {
                    it = ChannelMap.emplace(std::piecewise_construct, std::forward_as_tuple(channel),
                        std::forward_as_tuple(Pool, channel, PeerNodeId, MaxSerializedEventSize,
                        Metrics, Params)).first;
                }
                return it->second;
            }
        }

        ui64 Equalize() {
            if (Heap.empty()) {
                return 0; // nothing to do here -- no working channels
            }

            // find the minimum consumed weight among working channels and then adjust weights
            ui64 min = Max<ui64>();
            for (THeapItem& item : Heap) {
                min = Min(min, item.WeightConsumed);
            }
            for (THeapItem& item : Heap) {
                item.WeightConsumed -= min;
            }
            return min;
        }

        template<typename TCallback>
        void ForEach(TCallback&& callback) {
            for (auto& channel : ChannelArray) {
                if (channel) {
                    callback(*channel);
                }
            }
            for (auto& [id, channel] : ChannelMap) {
                callback(channel);
            }
        }
    };

} // NActors