aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/testing/common/network.cpp
blob: 230c50ee6dea6c08f5f46b6b8137dd8e13e61c91 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#include "network.h"

#include <util/folder/dirut.h>
#include <util/folder/path.h>
#include <util/generic/singleton.h>
#include <util/generic/utility.h>
#include <util/generic/vector.h>
#include <util/generic/ylimits.h>
#include <util/network/address.h>
#include <util/network/sock.h>
#include <util/random/random.h>
#include <util/stream/file.h>
#include <util/string/split.h>
#include <util/system/env.h>
#include <util/system/error.h>
#include <util/system/file_lock.h>
#include <util/system/fs.h>

#ifdef _darwin_
#include <sys/types.h>
#include <sys/sysctl.h>
#endif

namespace {
// Aborts via Y_FAIL when `expr` is false, reporting the stringified expression
// together with the current LastSystemError() value.
#define Y_VERIFY_SYSERROR(expr)                                           \
    do {                                                                  \
        if (expr) {                                                       \
            break;                                                        \
        }                                                                 \
        Y_FAIL(#expr ", errno=%d", LastSystemError());                    \
    } while (false)

    class TPortGuard : public NTesting::IPort {
    public:
        TPortGuard(ui16 port, THolder<TFileLock> lock)
            : Lock_(std::move(lock))
            , Port_(port)
        {
        }

        ~TPortGuard() override {
            Y_VERIFY_SYSERROR(NFs::Remove(Lock_->GetName()));
        }

        ui16 Get() override {
            return Port_;
        }

    private:
        THolder<TFileLock> Lock_;
        ui16 Port_;
    };

    // Returns the OS ephemeral (dynamic) port range as an inclusive (first, last) pair.
    //
    // Starts from the IANA-suggested range 49152..65535 and overrides it with
    // /proc/sys/net/ipv4/ip_local_port_range on Linux, or with the
    // net.inet.ip.portrange.{first,last} sysctls on Darwin, when those are available.
    std::pair<ui16, ui16> GetEphemeralRange() {
        // IANA suggestion: 49152..65535
        std::pair<ui16, ui16> pair{(1 << 15) + (1 << 14), (1 << 16) - 1};
    #ifdef _linux_
        if (NFs::Exists("/proc/sys/net/ipv4/ip_local_port_range")) {
            TIFStream fileStream("/proc/sys/net/ipv4/ip_local_port_range");
            fileStream >> pair.first >> pair.second;
        }
    #endif
    #ifdef _darwin_
        // sysctlbyname's `oldlenp` is in/out: on input it must hold the size of the
        // output buffer. The results are used only if both lookups succeed; otherwise
        // the IANA default above is kept instead of uninitialized values.
        ui32 first = 0;
        ui32 last = 0;
        size_t firstSize = sizeof(first);
        size_t lastSize = sizeof(last);
        if (sysctlbyname("net.inet.ip.portrange.first", &first, &firstSize, nullptr, 0) == 0 &&
            sysctlbyname("net.inet.ip.portrange.last", &last, &lastSize, nullptr, 0) == 0)
        {
            pair.first = first;
            pair.second = last;
        }
    #endif
        return pair;
    }

    // Returns the port ranges that TPortManager may hand out, as (first, last) pairs.
    //
    // If VALID_PORT_RANGE is set to "<first>:<last>", that single range is used as given
    // (a reversed range aborts). Otherwise the result is 1025..65535 with the OS ephemeral
    // range from GetEphemeralRange() cut out, so that allocated ports do not collide with
    // ports the kernel may pick for outgoing connections.
    //
    // NOTE: TPortManager iterates each pair as half-open [first, last), so the `last`
    // port of every range is never handed out.
    TVector<std::pair<ui16, ui16>> GetPortRanges() {
        const TString givenRange = GetEnv("VALID_PORT_RANGE");
        TVector<std::pair<ui16, ui16>> ranges;
        if (givenRange.Contains(':')) {
            auto res = StringSplitter(givenRange).Split(':').Limit(2).ToList<TString>();
            const ui16 first = FromString<ui16>(res.front());
            const ui16 last = FromString<ui16>(res.back());
            // A reversed range would make `last - first` negative downstream and wrap
            // the unsigned total port count.
            Y_VERIFY(first <= last, "invalid VALID_PORT_RANGE: %s", givenRange.c_str());
            ranges.emplace_back(first, last);
        } else {
            const ui16 firstValid = 1025;
            const ui16 lastValid = Max<ui16>();

            auto [firstEphemeral, lastEphemeral] = GetEphemeralRange();
            const ui16 firstInvalid = Max(firstEphemeral, firstValid);
            const ui16 lastInvalid = Min(lastEphemeral, lastValid);

            if (firstInvalid > lastInvalid) {
                // The ephemeral range does not intersect [firstValid, lastValid] (or is
                // malformed): there is nothing to exclude. Without this branch the code
                // below would emit ports under firstValid or overlapping ranges.
                ranges.emplace_back(firstValid, lastValid);
            } else {
                if (firstInvalid > firstValid) {
                    ranges.emplace_back(firstValid, firstInvalid - 1);
                }
                if (lastInvalid < lastValid) {
                    ranges.emplace_back(lastInvalid + 1, lastValid);
                }
            }
        }
        return ranges;
    }

    // Hands out TCP ports that are not currently taken on this host.
    //
    // A candidate port is reserved when (a) the per-port lock file <SyncDir_>/<port> can be
    // locked without blocking and (b) an IPv6 stream socket can be bound to [::]:<port>.
    // The lock file coordinates processes that share the same sync directory; the bind
    // check detects ports already used by anybody else. The probe socket is local to
    // TryAcquirePort, so only the file lock (owned by the returned TPortGuard) keeps the
    // reservation alive.
    //
    // Every (left, right) pair in Ranges_ is treated as the half-open range [left, right).
    class TPortManager {
        static constexpr size_t Retries = 20;
    public:
        // The sync directory comes from PORT_SYNC_PATH, falling back to
        // <system temp dir>/yandex_port_locks; it is created if missing.
        TPortManager()
            : SyncDir_(GetEnv("PORT_SYNC_PATH"))
            , Ranges_(GetPortRanges())
            , TotalCount_(0)
        {
            if (!SyncDir_.IsDefined()) {
                SyncDir_ = TFsPath(GetSystemTempDir()) / "yandex_port_locks";
            }
            Y_VERIFY(SyncDir_.IsDefined());
            NFs::MakeDirectoryRecursive(SyncDir_);

            for (auto [left, right] : Ranges_) {
                // A reversed range would make `right - left` negative and wrap the
                // unsigned TotalCount_, breaking every probe loop below.
                Y_VERIFY(left <= right, "invalid port range %u:%u", unsigned(left), unsigned(right));
                TotalCount_ += right - left;
            }
            Y_VERIFY(0 != TotalCount_);
        }

        // Returns one free port. Candidates are probed sequentially (with wrap-around)
        // starting from a random offset, covering all TotalCount_ candidates once.
        // Aborts if none of them can be acquired.
        NTesting::TPortHolder GetFreePort() const {
            const ui16 salt = RandomNumber<ui16>();
            // size_t counter: it is compared against the size_t TotalCount_ and must not wrap.
            for (size_t attempt = 0; attempt < TotalCount_; ++attempt) {
                // Map the linear index in [0, TotalCount_) onto a concrete port number.
                ui16 probe = static_cast<ui16>((salt + attempt) % TotalCount_);
                for (auto [left, right] : Ranges_) {
                    if (probe >= right - left) {
                        probe -= right - left;
                    } else {
                        probe += left;
                        break;
                    }
                }

                auto port = TryAcquirePort(probe);
                if (port) {
                    return NTesting::TPortHolder{std::move(port)};
                }
            }

            Y_FAIL("Cannot get free port!");
        }

        // Returns `count` consecutive free ports taken from a single range.
        // Each of the Retries passes scans every range wide enough for `count` ports,
        // starting at a random offset within the range's first half. A port that cannot
        // be acquired discards the run collected so far. Aborts when all passes fail.
        TVector<NTesting::TPortHolder> GetFreePortsRange(size_t count) const {
            Y_VERIFY(count > 0);
            TVector<NTesting::TPortHolder> ports(Reserve(count));
            for (size_t i = 0; i < Retries; ++i) {
                for (auto [left, right] : Ranges_) {
                    if (static_cast<size_t>(right - left) < count) {
                        continue;
                    }
                    // RandomNumber(max) draws from [0, max); a range of width 1 has no
                    // room for a random offset, so it must start at `left` instead of
                    // passing a zero bound.
                    const ui16 halfWidth = (right - left) / 2;
                    const ui16 start = left + (halfWidth > 0 ? RandomNumber<ui16>(halfWidth) : 0);
                    if (static_cast<size_t>(right - start) < count) {
                        continue;
                    }
                    for (ui16 probe = start; probe < right; ++probe) {
                        auto port = TryAcquirePort(probe);
                        if (port) {
                            ports.emplace_back(std::move(port));
                        } else {
                            ports.clear();
                        }
                        if (ports.size() == count) {
                            return ports;
                        }
                    }
                    // Can't find required number of ports without gap in the current range
                    ports.clear();
                }
            }
            Y_FAIL("Cannot get range of %zu ports!", count);
        }

    private:
        // Tries to reserve `port`. It takes a non-blocking lock on the per-port lock
        // file, then checks that [::]:port can be bound.
        // Returns nullptr if the lock is held elsewhere or the bind fails with EADDRINUSE.
        // Aborts on socket creation failure or any other bind error.
        THolder<NTesting::IPort> TryAcquirePort(ui16 port) const {
            auto lock = MakeHolder<TFileLock>(TString(SyncDir_ / ::ToString(port)));
            if (!lock->TryAcquire()) {
                return nullptr;
            }

            TInet6StreamSocket sock;
            Y_VERIFY_SYSERROR(INVALID_SOCKET != static_cast<SOCKET>(sock));

            TSockAddrInet6 addr("::", port);
            if (sock.Bind(&addr) != 0) {
                // Capture the bind error before Release(): the unlock syscall may
                // overwrite the thread's last system error.
                const int bindError = LastSystemError();
                lock->Release();
                Y_VERIFY(EADDRINUSE == bindError, "unexpected error: %d", bindError);
                return nullptr;
            }
            return MakeHolder<TPortGuard>(port, std::move(lock));
        }

    private:
        TFsPath SyncDir_;
        TVector<std::pair<ui16, ui16>> Ranges_;
        size_t TotalCount_;
    };
}

namespace NTesting {
    // Reserves one free port through the process-wide TPortManager singleton.
    TPortHolder GetFreePort() {
        auto* const manager = Singleton<TPortManager>();
        return manager->GetFreePort();
    }

    namespace NLegacy {
        // Reserves `count` consecutive free ports through the process-wide TPortManager singleton.
        TVector<TPortHolder> GetFreePortsRange(size_t count) {
            auto* const manager = Singleton<TPortManager>();
            return manager->GetFreePortsRange(count);
        }
    }

    // Writes the numeric port value held by `port` to `out`.
    IOutputStream& operator<<(IOutputStream& out, const TPortHolder& port) {
        const ui16 value = static_cast<ui16>(port);
        out << value;
        return out;
    }
}