about | summary | refs | log | tree | commit | diff | stats
path: root/util/network
diff options
context:
space:
mode:
author: Devtools Arcadia <arcadia-devtools@yandex-team.ru> 2022-02-07 18:08:42 +0300
committer: Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> 2022-02-07 18:08:42 +0300
commit: 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
tree: e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/network
download: ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/network')
-rw-r--r-- util/network/address.cpp 204
-rw-r--r-- util/network/address.h 136
-rw-r--r-- util/network/address_ut.cpp 39
-rw-r--r-- util/network/endpoint.cpp 67
-rw-r--r-- util/network/endpoint.h 61
-rw-r--r-- util/network/endpoint_ut.cpp 123
-rw-r--r-- util/network/hostip.cpp 76
-rw-r--r-- util/network/hostip.h 16
-rw-r--r-- util/network/init.cpp 34
-rw-r--r-- util/network/init.h 60
-rw-r--r-- util/network/interface.cpp 79
-rw-r--r-- util/network/interface.h 17
-rw-r--r-- util/network/iovec.cpp 1
-rw-r--r-- util/network/iovec.h 65
-rw-r--r-- util/network/ip.cpp 1
-rw-r--r-- util/network/ip.h 119
-rw-r--r-- util/network/ip_ut.cpp 63
-rw-r--r-- util/network/nonblock.cpp 104
-rw-r--r-- util/network/nonblock.h 8
-rw-r--r-- util/network/pair.cpp 97
-rw-r--r-- util/network/pair.h 9
-rw-r--r-- util/network/poller.cpp 86
-rw-r--r-- util/network/poller.h 58
-rw-r--r-- util/network/poller_ut.cpp 236
-rw-r--r-- util/network/pollerimpl.cpp 1
-rw-r--r-- util/network/pollerimpl.h 706
-rw-r--r-- util/network/sock.cpp 1
-rw-r--r-- util/network/sock.h 608
-rw-r--r-- util/network/sock_ut.cpp 168
-rw-r--r-- util/network/socket.cpp 1288
-rw-r--r-- util/network/socket.h 431
-rw-r--r-- util/network/socket_ut.cpp 341
-rw-r--r-- util/network/ut/ya.make 23
-rw-r--r-- util/network/ya.make 6
34 files changed, 5332 insertions, 0 deletions
diff --git a/util/network/address.cpp b/util/network/address.cpp
new file mode 100644
index 0000000000..a81a9e6994
--- /dev/null
+++ b/util/network/address.cpp
@@ -0,0 +1,204 @@
+#include <util/stream/str.h>
+
+#include "address.h"
+
+#include <cstring>
+
+#if defined(_unix_)
+ #include <sys/types.h>
+ #include <sys/un.h>
+#endif
+
+using namespace NAddr;
+
+// Writes a human-readable form of addr to out.
+//
+// Output by address family:
+//  - AF_INET: dotted-quad host, followed by ":<port>" when printPort is true;
+//  - AF_INET6: the inet_ntop() text, or "[<text>]:<port>" when printPort is true;
+//  - AF_UNIX (when the platform defines it): the socket path;
+//  - any other family: "(raw all zeros)" if all addr.Len() bytes are zero,
+//    otherwise "(raw <family> <byte> <byte> ...)" with every byte printed as an int.
+//
+// Throws TSystemError if an IPv4/IPv6 address cannot be converted to text.
+template <bool printPort>
+static inline void PrintAddr(IOutputStream& out, const IRemoteAddr& addr) {
+    const sockaddr* a = addr.Addr();
+    char buf[INET6_ADDRSTRLEN + 10];
+
+    switch (a->sa_family) {
+        case AF_INET: {
+            const TIpAddress sa(*(const sockaddr_in*)a);
+
+            out << IpToString(sa.Host(), buf, sizeof(buf));
+
+            if (printPort) {
+                out << ":" << sa.Port();
+            }
+
+            break;
+        }
+
+        case AF_INET6: {
+            const sockaddr_in6* sa = (const sockaddr_in6*)a;
+
+            if (!inet_ntop(AF_INET6, (void*)&sa->sin6_addr.s6_addr, buf, sizeof(buf))) {
+                ythrow TSystemError() << "inet_ntop() failed";
+            }
+
+            if (printPort) {
+                // brackets keep the port separator distinguishable from the colons of the address
+                out << "[" << buf << "]"
+                    << ":" << InetToHost(sa->sin6_port);
+            } else {
+                out << buf;
+            }
+
+            break;
+        }
+
+#if defined(AF_UNIX)
+        case AF_UNIX: {
+            const sockaddr_un* sa = (const sockaddr_un*)a;
+
+            // sun_path is not guaranteed to be NUL-terminated when the path fills the
+            // whole array, so never scan past the end of the field.
+            out << TStringBuf(sa->sun_path, strnlen(sa->sun_path, sizeof(sa->sun_path)));
+
+            break;
+        }
+#endif
+
+        default: {
+            size_t len = addr.Len();
+
+            const char* b = (const char*)a;
+            const char* e = b + len;
+
+            bool allZeros = true;
+            for (size_t i = 0; i < len; ++i) {
+                if (b[i] != 0) {
+                    allZeros = false;
+                    break;
+                }
+            }
+
+            if (allZeros) {
+                out << "(raw all zeros)";
+            } else {
+                out << "(raw " << (int)a->sa_family << " ";
+
+                while (b != e) {
+                    //just print raw bytes
+                    out << (int)*b++;
+                    if (b != e) {
+                        out << " ";
+                    }
+                }
+
+                out << ")";
+            }
+
+            break;
+        }
+    }
+}
+
+template <> // Out<> specialization for IRemoteAddr: forwards to PrintAddr with printPort enabled
+void Out<IRemoteAddr>(IOutputStream& out, const IRemoteAddr& addr) {
+ PrintAddr<true>(out, addr);
+}
+
+template <> // same output as for IRemoteAddr
+void Out<NAddr::TAddrInfo>(IOutputStream& out, const NAddr::TAddrInfo& addr) {
+ PrintAddr<true>(out, addr);
+}
+
+template <> // same output as for IRemoteAddr
+void Out<NAddr::TIPv4Addr>(IOutputStream& out, const NAddr::TIPv4Addr& addr) {
+ PrintAddr<true>(out, addr);
+}
+
+template <> // same output as for IRemoteAddr
+void Out<NAddr::TIPv6Addr>(IOutputStream& out, const NAddr::TIPv6Addr& addr) {
+ PrintAddr<true>(out, addr);
+}
+
+template <> // same output as for IRemoteAddr
+void Out<NAddr::TOpaqueAddr>(IOutputStream& out, const NAddr::TOpaqueAddr& addr) {
+ PrintAddr<true>(out, addr);
+}
+
+// Streams addr into out without the port part.
+void NAddr::PrintHost(IOutputStream& out, const IRemoteAddr& addr) {
+    PrintAddr</*printPort=*/false>(out, addr);
+}
+
+// Returns addr formatted without the port part.
+TString NAddr::PrintHost(const IRemoteAddr& addr) {
+    TStringStream buffer;
+
+    PrintAddr</*printPort=*/false>(buffer, addr);
+
+    return buffer.Str();
+}
+
+// Returns addr formatted with printPort enabled (host plus port for IPv4/IPv6).
+TString NAddr::PrintHostAndPort(const IRemoteAddr& addr) {
+    TStringStream buffer;
+
+    PrintAddr</*printPort=*/true>(buffer, addr);
+
+    return buffer.Str();
+}
+
+// Returns the address reported by getsockname() for socket s.
+// Throws TSystemError when getsockname() fails.
+IRemoteAddrPtr NAddr::GetSockAddr(SOCKET s) {
+    THolder<TOpaqueAddr> local = MakeHolder<TOpaqueAddr>();
+    const int rc = getsockname(s, local->MutableAddr(), local->LenPtr());
+
+    if (rc < 0) {
+        ythrow TSystemError() << "getsockname() failed";
+    }
+
+    return local;
+}
+
+// Returns the address reported by getpeername() for socket s.
+// Throws TSystemError when getpeername() fails.
+IRemoteAddrPtr NAddr::GetPeerAddr(SOCKET s) {
+    THolder<TOpaqueAddr> peer = MakeHolder<TOpaqueAddr>();
+    const int rc = getpeername(s, peer->MutableAddr(), peer->LenPtr());
+
+    if (rc < 0) {
+        ythrow TSystemError() << "getpeername() failed";
+    }
+
+    return peer;
+}
+
+// Reinterprets addr as sockaddr_in and returns its address field.
+// The family is not checked here; callers test sa_family first.
+static const in_addr& InAddr(const IRemoteAddr& addr) {
+    const sockaddr_in* v4 = (const sockaddr_in*)addr.Addr();
+
+    return v4->sin_addr;
+}
+
+// Reinterprets addr as sockaddr_in6 and returns its address field.
+// The family is not checked here; callers test sa_family first.
+static const in6_addr& In6Addr(const IRemoteAddr& addr) {
+    const sockaddr_in6* v6 = (const sockaddr_in6*)addr.Addr();
+
+    return v6->sin6_addr;
+}
+
+// Tells whether addr is a loopback address:
+// any IPv4 address whose first octet is 127, or the IPv6 address equal to in6addr_loopback.
+// Every other family yields false.
+bool NAddr::IsLoopback(const IRemoteAddr& addr) {
+    switch (addr.Addr()->sa_family) {
+        case AF_INET: {
+            const auto firstOctet = (ntohl(InAddr(addr).s_addr) >> 24) & 0xff;
+
+            return firstOctet == 127;
+        }
+
+        case AF_INET6:
+            return memcmp(&In6Addr(addr), &in6addr_loopback, sizeof(in6_addr)) == 0;
+
+        default:
+            return false;
+    }
+}
+
+// Compares the host parts of two addresses (ports are not looked at).
+// Addresses of different families are never the same.
+// Throws yexception when both share a family other than AF_INET / AF_INET6.
+bool NAddr::IsSame(const IRemoteAddr& lhs, const IRemoteAddr& rhs) {
+    const sockaddr* left = lhs.Addr();
+    const sockaddr* right = rhs.Addr();
+
+    if (left->sa_family != right->sa_family) {
+        return false;
+    }
+
+    switch (left->sa_family) {
+        case AF_INET:
+            return InAddr(lhs).s_addr == InAddr(rhs).s_addr;
+
+        case AF_INET6:
+            return memcmp(&In6Addr(lhs), &In6Addr(rhs), sizeof(in6_addr)) == 0;
+
+        default:
+            break;
+    }
+
+    ythrow yexception() << "unsupported addr family: " << lhs.Addr()->sa_family;
+}
+
+// Returns the size of the concrete sockaddr structure that matches addr->sa_family
+// (sockaddr_in, sockaddr_in6 or, where AF_LOCAL is defined, sockaddr_un).
+// Throws yexception for any other family.
+socklen_t NAddr::SockAddrLength(const sockaddr* addr) {
+    if (addr->sa_family == AF_INET) {
+        return sizeof(sockaddr_in);
+    }
+
+    if (addr->sa_family == AF_INET6) {
+        return sizeof(sockaddr_in6);
+    }
+
+#if defined(AF_LOCAL)
+    if (addr->sa_family == AF_LOCAL) {
+        return sizeof(sockaddr_un);
+    }
+#endif
+
+    ythrow yexception() << "unsupported address family: " << addr->sa_family;
+}
diff --git a/util/network/address.h b/util/network/address.h
new file mode 100644
index 0000000000..448fcac0c9
--- /dev/null
+++ b/util/network/address.h
@@ -0,0 +1,136 @@
+#pragma once
+
+#include "ip.h"
+#include "socket.h"
+
+#include <util/generic/ptr.h>
+#include <util/generic/string.h>
+
+namespace NAddr {
+ class IRemoteAddr { // abstract read-only view of a socket address
+ public:
+ virtual ~IRemoteAddr() = default;
+
+ virtual const sockaddr* Addr() const = 0; // pointer to the underlying sockaddr
+ virtual socklen_t Len() const = 0; // length of that sockaddr
+ };
+
+ using IRemoteAddrPtr = THolder<IRemoteAddr>;
+ using IRemoteAddrRef = TAtomicSharedPtr<NAddr::IRemoteAddr>;
+
+ IRemoteAddrPtr GetSockAddr(SOCKET s);
+ IRemoteAddrPtr GetPeerAddr(SOCKET s);
+ void PrintHost(IOutputStream& out, const IRemoteAddr& addr);
+
+ TString PrintHost(const IRemoteAddr& addr);
+ TString PrintHostAndPort(const IRemoteAddr& addr);
+
+ bool IsLoopback(const IRemoteAddr& addr);
+ bool IsSame(const IRemoteAddr& lhs, const IRemoteAddr& rhs);
+
+ socklen_t SockAddrLength(const sockaddr* addr);
+
+ // Address stored by value inside a sockaddr_storage buffer.
+ // A default-constructed object is zero-filled and reports Len() == sizeof(sockaddr_storage),
+ // so MutableAddr()/LenPtr() can be handed straight to calls such as accept, recvfrom,
+ // getsockname or getpeername, which write the address and update the length in place.
+ class TOpaqueAddr: public IRemoteAddr {
+ public:
+     inline TOpaqueAddr() noexcept
+         : L_(sizeof(S_))
+     {
+         Zero(S_);
+     }
+
+     // Copies the bytes of another address (at most sizeof(sockaddr_storage) of them).
+     inline TOpaqueAddr(const IRemoteAddr* addr) noexcept {
+         Assign(addr->Addr(), addr->Len());
+     }
+
+     // Copies a raw sockaddr; its length is derived from sa_family by SockAddrLength(),
+     // which throws for unsupported families.
+     inline TOpaqueAddr(const sockaddr* addr) {
+         Assign(addr, SockAddrLength(addr));
+     }
+
+     const sockaddr* Addr() const override {
+         return MutableAddr();
+     }
+
+     socklen_t Len() const override {
+         return L_;
+     }
+
+     // Writable pointer to the storage; available on const objects as well.
+     inline sockaddr* MutableAddr() const noexcept {
+         return (sockaddr*)&S_;
+     }
+
+     // Writable pointer to the stored length, for in/out length parameters.
+     inline socklen_t* LenPtr() noexcept {
+         return &L_;
+     }
+
+ private:
+     inline void Assign(const sockaddr* addr, socklen_t len) noexcept {
+         // Start from zeroed storage so bytes past the copied prefix are well defined.
+         Zero(S_);
+
+         // Never copy more than the storage can hold: a source reporting a larger
+         // length would otherwise overflow S_.
+         const socklen_t capacity = (socklen_t)sizeof(S_);
+
+         L_ = len > capacity ? capacity : len;
+         memcpy(MutableAddr(), addr, L_);
+     }
+
+ private:
+     sockaddr_storage S_;
+     socklen_t L_;
+ };
+
+ // IRemoteAddr view over a single addrinfo entry (for TNetworkAddress).
+ // Only the pointer is kept, so the addrinfo has to outlive this object.
+ class TAddrInfo: public IRemoteAddr {
+ public:
+     inline TAddrInfo(const addrinfo* ai) noexcept
+         : AI_(ai)
+     {
+     }
+
+     const sockaddr* Addr() const override {
+         const addrinfo& entry = *AI_;
+
+         return entry.ai_addr;
+     }
+
+     socklen_t Len() const override {
+         return static_cast<socklen_t>(AI_->ai_addrlen);
+     }
+
+ private:
+     const addrinfo* const AI_;
+ };
+
+ //compat, for TIpAddress: IRemoteAddr wrapper that stores its own copy of a TIpAddress
+ class TIPv4Addr: public IRemoteAddr {
+ public:
+ inline TIPv4Addr(const TIpAddress& addr) noexcept
+ : A_(addr)
+ {
+ }
+
+ const sockaddr* Addr() const override {
+ return A_; // implicit conversion of A_ to a sockaddr pointer
+ }
+
+ socklen_t Len() const override {
+ return A_; // implicit conversion of A_ to socklen_t
+ }
+
+ private:
+ const TIpAddress A_;
+ };
+
+ // IRemoteAddr wrapper that stores its own copy of a sockaddr_in6.
+ class TIPv6Addr: public IRemoteAddr {
+ public:
+     inline TIPv6Addr(const sockaddr_in6& a) noexcept
+         : A_(a)
+     {
+     }
+
+     const sockaddr* Addr() const override {
+         return reinterpret_cast<const sockaddr*>(&A_);
+     }
+
+     socklen_t Len() const override {
+         return sizeof(sockaddr_in6);
+     }
+
+ private:
+     const sockaddr_in6 A_;
+ };
+}
diff --git a/util/network/address_ut.cpp b/util/network/address_ut.cpp
new file mode 100644
index 0000000000..28f45172ff
--- /dev/null
+++ b/util/network/address_ut.cpp
@@ -0,0 +1,39 @@
+#include <library/cpp/testing/unittest/registar.h>
+
+#include "address.h"
+
+using namespace NAddr;
+
+Y_UNIT_TEST_SUITE(IRemoteAddr_ToString) { // string formatting and loopback detection for IRemoteAddr implementations
+ Y_UNIT_TEST(Raw) { // a default TOpaqueAddr has family 0, so the "(raw ...)" fallback is printed
+ THolder<TOpaqueAddr> opaque(new TOpaqueAddr);
+ IRemoteAddr* addr = opaque.Get();
+
+ TString s = ToString(*addr);
+ UNIT_ASSERT_VALUES_EQUAL("(raw all zeros)", s);
+
+ opaque->MutableAddr()->sa_data[10] = 17; // one non-zero byte switches the output to the byte dump
+
+ TString t = ToString(*addr);
+
+ UNIT_ASSERT_C(t.StartsWith("(raw 0 0"), t);
+ UNIT_ASSERT_C(t.EndsWith(')'), t);
+ }
+
+ Y_UNIT_TEST(Ipv6) { // IPv6 with a port is printed as "[host]:port"
+ TNetworkAddress address("::1", 22);
+ TNetworkAddress::TIterator it = address.Begin();
+ UNIT_ASSERT(it != address.End());
+ UNIT_ASSERT(it->ai_family == AF_INET6);
+ TString toString = ToString((const IRemoteAddr&)TAddrInfo(&*it));
+ UNIT_ASSERT_VALUES_EQUAL(TString("[::1]:22"), toString);
+ }
+
+ Y_UNIT_TEST(Loopback) { // every 127.x.x.x address counts as loopback, not only 127.0.0.1
+ TNetworkAddress localAddress("127.70.0.1", 22);
+ UNIT_ASSERT_VALUES_EQUAL(NAddr::IsLoopback(TAddrInfo(&*localAddress.Begin())), true);
+
+ TNetworkAddress localAddress2("127.0.0.1", 22);
+ UNIT_ASSERT_VALUES_EQUAL(NAddr::IsLoopback(TAddrInfo(&*localAddress2.Begin())), true);
+ }
+}
diff --git a/util/network/endpoint.cpp b/util/network/endpoint.cpp
new file mode 100644
index 0000000000..9acdd06940
--- /dev/null
+++ b/util/network/endpoint.cpp
@@ -0,0 +1,67 @@
+#include "endpoint.h"
+#include "sock.h"
+
+// Wraps an existing address.
+// Throws yexception unless the address family is AF_INET, AF_INET6 or AF_UNIX.
+TEndpoint::TEndpoint(const TEndpoint::TAddrRef& addr)
+    : Addr_(addr)
+{
+    const auto family = Addr_->Addr()->sa_family;
+    const bool supported = family == AF_INET || family == AF_INET6 || family == AF_UNIX;
+
+    if (!supported) {
+        ythrow yexception() << TStringBuf("endpoint can contain only ipv4, ipv6 or unix address");
+    }
+}
+
+TEndpoint::TEndpoint() // default endpoint: IPv4 address with host 0 and port 0
+ : Addr_(new NAddr::TIPv4Addr(TIpAddress(TIpHost(0), TIpPort(0))))
+{
+}
+
+// Sets the port of an IPv4/IPv6 endpoint (port is given in host byte order).
+// Nothing happens for unix endpoints or when the port already has that value.
+// Otherwise the address is first copied into a new TOpaqueAddr and the port is patched
+// in that copy, so the previously referenced address object is not modified.
+void TEndpoint::SetPort(ui16 port) {
+    if (Addr_->Addr()->sa_family == AF_UNIX) {
+        return;
+    }
+
+    if (Port() == port) {
+        return;
+    }
+
+    NAddr::TOpaqueAddr* copy = new NAddr::TOpaqueAddr(Addr_.Get());
+    Addr_.Reset(copy);
+
+    sockaddr* raw = copy->MutableAddr();
+    const auto netPort = HostToInet(port);
+
+    if (raw->sa_family == AF_INET) {
+        ((sockaddr_in*)raw)->sin_port = netPort;
+    } else {
+        ((sockaddr_in6*)raw)->sin6_port = netPort;
+    }
+}
+
+// Returns the port in host byte order; unix endpoints always report 0.
+ui16 TEndpoint::Port() const noexcept {
+    const sockaddr* sa = Addr_->Addr();
+
+    switch (sa->sa_family) {
+        case AF_UNIX:
+            return 0;
+
+        case AF_INET:
+            return InetToHost(((const sockaddr_in*)sa)->sin_port);
+
+        default:
+            return InetToHost(((const sockaddr_in6*)sa)->sin6_port);
+    }
+}
+
+// Hash of the endpoint:
+//  - IPv4: IntHash over the address shifted left by 16 bits, xor-ed with the port;
+//  - IPv6: IntHash over the xor of both 64-bit halves of the address and the port;
+//  - otherwise (unix): string hash of the socket path.
+// Address and port are used exactly as stored in the sockaddr, without byte-order conversion.
+size_t TEndpoint::Hash() const {
+    const sockaddr* sa = Addr_->Addr();
+
+    if (sa->sa_family == AF_INET) {
+        const sockaddr_in* sa4 = (const sockaddr_in*)sa;
+
+        return IntHash((((ui64)sa4->sin_addr.s_addr) << 16) ^ sa4->sin_port);
+    } else if (sa->sa_family == AF_INET6) {
+        const sockaddr_in6* sa6 = (const sockaddr_in6*)sa;
+        ui64 halves[2];
+
+        static_assert(sizeof(halves) == sizeof(sa6->sin6_addr), "in6_addr must be 16 bytes");
+
+        // Copy the address out instead of reading it through a ui64 pointer:
+        // in6_addr is not declared as ui64 and is not guaranteed to be 8-byte aligned.
+        memcpy(halves, &sa6->sin6_addr, sizeof(halves));
+
+        return IntHash(halves[0] ^ halves[1] ^ sa6->sin6_port);
+    } else {
+        const sockaddr_un* un = (const sockaddr_un*)sa;
+        THash<TString> strHash;
+
+        return strHash(un->sun_path);
+    }
+}
diff --git a/util/network/endpoint.h b/util/network/endpoint.h
new file mode 100644
index 0000000000..a3e59b4925
--- /dev/null
+++ b/util/network/endpoint.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#include "address.h"
+
+#include <util/str_stl.h>
+
+//rough equivalent of boost::asio::ip::endpoint (convenient wrapper for an ip:port pair)
+class TEndpoint {
+public:
+ using TAddrRef = NAddr::IRemoteAddrRef;
+
+ TEndpoint(const TAddrRef& addr); // wraps an existing address
+ TEndpoint(); // default endpoint
+
+ inline const TAddrRef& Addr() const noexcept { // the shared address object
+ return Addr_;
+ }
+ inline const sockaddr* SockAddr() const { // raw sockaddr of the address
+ return Addr_->Addr();
+ }
+ inline socklen_t SockAddrLen() const { // length of the raw sockaddr
+ return Addr_->Len();
+ }
+
+ inline bool IsIpV4() const { // family is AF_INET
+ return Addr_->Addr()->sa_family == AF_INET;
+ }
+ inline bool IsIpV6() const { // family is AF_INET6
+ return Addr_->Addr()->sa_family == AF_INET6;
+ }
+ inline bool IsUnix() const { // family is AF_UNIX
+ return Addr_->Addr()->sa_family == AF_UNIX;
+ }
+
+ inline TString IpToString() const { // address formatted by NAddr::PrintHost
+ return NAddr::PrintHost(*Addr_);
+ }
+
+ void SetPort(ui16 port);
+ ui16 Port() const noexcept;
+
+ size_t Hash() const; // used by THash<TEndpoint>
+
+private:
+ TAddrRef Addr_;
+};
+
+template <> // THash specialization for TEndpoint: forwards to TEndpoint::Hash()
+struct THash<TEndpoint> {
+ inline size_t operator()(const TEndpoint& ep) const {
+ return ep.Hash();
+ }
+};
+
+// Two endpoints are equal when they have the same address family, the same host and the
+// same port. Never throws: any exception raised while comparing yields false.
+//
+// Unix endpoints are compared by socket path. NAddr::IsSame() only understands
+// AF_INET/AF_INET6 and throws for other families, so without this branch a unix endpoint
+// would not even compare equal to itself, which breaks hash-container lookups.
+inline bool operator==(const TEndpoint& l, const TEndpoint& r) {
+    try {
+        if (l.IsUnix() && r.IsUnix()) {
+            // for unix addresses IpToString() yields the socket path
+            return l.IpToString() == r.IpToString();
+        }
+
+        return NAddr::IsSame(*l.Addr(), *r.Addr()) && l.Port() == r.Port();
+    } catch (...) {
+        return false;
+    }
+}
diff --git a/util/network/endpoint_ut.cpp b/util/network/endpoint_ut.cpp
new file mode 100644
index 0000000000..d5e40dd6e1
--- /dev/null
+++ b/util/network/endpoint_ut.cpp
@@ -0,0 +1,123 @@
+#include "endpoint.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/generic/hash_set.h>
+#include <util/generic/strbuf.h>
+
+Y_UNIT_TEST_SUITE(TEndpointTest) { // construction, port handling, hashing and equality of TEndpoint
+ Y_UNIT_TEST(TestSimple) {
+ TVector<TNetworkAddress> addrs; // keeps the resolved addresses alive for the endpoints built from them
+
+ TEndpoint ep0;
+
+ UNIT_ASSERT(ep0.IsIpV4());
+ UNIT_ASSERT_VALUES_EQUAL(0, ep0.Port());
+ UNIT_ASSERT_VALUES_EQUAL("0.0.0.0", ep0.IpToString());
+
+ TEndpoint ep1;
+
+ try {
+ TNetworkAddress na1("25.26.27.28", 24242);
+
+ addrs.push_back(na1);
+
+ ep1 = TEndpoint(new NAddr::TAddrInfo(&*na1.Begin()));
+
+ UNIT_ASSERT(ep1.IsIpV4());
+ UNIT_ASSERT_VALUES_EQUAL("25.26.27.28", ep1.IpToString());
+ UNIT_ASSERT_VALUES_EQUAL(24242, ep1.Port());
+ } catch (const TNetworkResolutionError&) { // fall back to an IPv6 address if the IPv4 one cannot be resolved
+ TNetworkAddress n("2a02:6b8:0:1420:0::5f6c:f3c2", 11111);
+
+ addrs.push_back(n);
+
+ ep1 = TEndpoint(new NAddr::TAddrInfo(&*n.Begin()));
+ }
+
+ ep0.SetPort(12345);
+
+ TEndpoint ep2(ep0);
+
+ ep0.SetPort(0); // must not affect the copy taken above
+
+ UNIT_ASSERT_VALUES_EQUAL(12345, ep2.Port());
+
+ TEndpoint ep2_;
+
+ ep2_.SetPort(12345);
+
+ UNIT_ASSERT(ep2 == ep2_);
+
+ TNetworkAddress na3("2a02:6b8:0:1410::5f6c:f3c2", 54321);
+ TEndpoint ep3(new NAddr::TAddrInfo(&*na3.Begin()));
+
+ UNIT_ASSERT(ep3.IsIpV6());
+ UNIT_ASSERT(ep3.IpToString().StartsWith(TStringBuf("2a02:6b8:0:1410:")));
+ UNIT_ASSERT(ep3.IpToString().EndsWith(TStringBuf(":5f6c:f3c2")));
+ UNIT_ASSERT_VALUES_EQUAL(54321, ep3.Port());
+
+ TNetworkAddress na4("2a02:6b8:0:1410:0::5f6c:f3c2", 1);
+ TEndpoint ep4(new NAddr::TAddrInfo(&*na4.Begin()));
+
+ TEndpoint ep3_ = ep4;
+
+ ep3_.SetPort(54321); // same host as ep3 and now the same port
+
+ THashSet<TEndpoint> he;
+
+ he.insert(ep0);
+ he.insert(ep1);
+ he.insert(ep2);
+
+ UNIT_ASSERT_VALUES_EQUAL(3u, he.size());
+
+ he.insert(ep2_); // equal to ep2, so the set does not grow
+
+ UNIT_ASSERT_VALUES_EQUAL(3u, he.size());
+
+ he.insert(ep3);
+ he.insert(ep3_); // equal to ep3, so only one element is added
+
+ UNIT_ASSERT_VALUES_EQUAL(4u, he.size());
+
+ he.insert(ep4); // same host as ep3 but a different port
+
+ UNIT_ASSERT_VALUES_EQUAL(5u, he.size());
+ }
+
+ Y_UNIT_TEST(TestEqual) { // equality takes both the host and the port into account
+ const TString ip1 = "2a02:6b8:0:1410::5f6c:f3c2";
+ const TString ip2 = "2a02:6b8:0:1410::5f6c:f3c3";
+
+ TNetworkAddress na1(ip1, 24242);
+ TEndpoint ep1(new NAddr::TAddrInfo(&*na1.Begin()));
+
+ TNetworkAddress na2(ip1, 24242);
+ TEndpoint ep2(new NAddr::TAddrInfo(&*na2.Begin()));
+
+ TNetworkAddress na3(ip2, 24242);
+ TEndpoint ep3(new NAddr::TAddrInfo(&*na3.Begin()));
+
+ TNetworkAddress na4(ip2, 24243);
+ TEndpoint ep4(new NAddr::TAddrInfo(&*na4.Begin()));
+
+ UNIT_ASSERT(ep1 == ep2);
+ UNIT_ASSERT(!(ep1 == ep3));
+ UNIT_ASSERT(!(ep1 == ep4));
+ }
+
+ Y_UNIT_TEST(TestIsUnixSocket) { // IsUnix() is true only for AF_UNIX endpoints
+ TNetworkAddress na1(TUnixSocketPath("/tmp/unixsocket"));
+ TEndpoint ep1(new NAddr::TAddrInfo(&*na1.Begin()));
+
+ TNetworkAddress na2("2a02:6b8:0:1410::5f6c:f3c2", 24242);
+ TEndpoint ep2(new NAddr::TAddrInfo(&*na2.Begin()));
+
+ UNIT_ASSERT(ep1.IsUnix());
+ UNIT_ASSERT(ep1.SockAddr()->sa_family == AF_UNIX);
+
+ UNIT_ASSERT(!ep2.IsUnix());
+ UNIT_ASSERT(ep2.SockAddr()->sa_family != AF_UNIX);
+ }
+}
diff --git a/util/network/hostip.cpp b/util/network/hostip.cpp
new file mode 100644
index 0000000000..cb8d43bf90
--- /dev/null
+++ b/util/network/hostip.cpp
@@ -0,0 +1,76 @@
+#include "socket.h"
+#include "hostip.h"
+
+#include <util/system/defaults.h>
+#include <util/system/byteorder.h>
+
+#if defined(_unix_) || defined(_cygwin_)
+ #include <netdb.h>
+#endif
+
+#if !defined(BIND_LIB)
+ #if !defined(__FreeBSD__) && !defined(_win32_) && !defined(_cygwin_)
+ #define AGENT_USE_GETADDRINFO
+ #endif
+
+ #if defined(__FreeBSD__)
+ #define AGENT_USE_GETADDRINFO
+ #endif
+#endif
+
+// Resolves hostname to IPv4 addresses.
+//
+// ip    - output array with room for at least *slots entries; addresses are stored
+//         in host byte order.
+// slots - in: capacity of ip; out (on success only): number of entries written.
+//
+// Returns 0 on success, otherwise a non-zero resolver error code (HOST_NOT_FOUND,
+// TRY_AGAIN, NO_RECOVERY or the value reported by GetDnsError()).
+int NResolver::GetHostIP(const char* hostname, ui32* ip, size_t* slots) {
+    size_t i = 0;
+    size_t ipsFound = 0;
+
+#ifdef AGENT_USE_GETADDRINFO
+    int ret = 0;
+    struct addrinfo hints;
+    memset(&hints, 0, sizeof(hints));
+    hints.ai_family = AF_INET;
+    hints.ai_socktype = SOCK_STREAM;
+    struct addrinfo* gai_res = nullptr;
+    int gai_ret = getaddrinfo(hostname, nullptr, &hints, &gai_res);
+    if (gai_ret == 0 && gai_res->ai_addr) {
+        struct addrinfo* cur = gai_res;
+        for (i = 0; i < *slots && cur; i++, cur = cur->ai_next, ipsFound++) {
+            ip[i] = *(ui32*)(&((sockaddr_in*)(cur->ai_addr))->sin_addr);
+        }
+    } else {
+        if (gai_ret == EAI_NONAME || gai_ret == EAI_SERVICE) {
+            ret = HOST_NOT_FOUND;
+        } else {
+            ret = GetDnsError();
+
+            // getaddrinfo() reports failures through its return value and is not required
+            // to set h_errno, so h_errno may still be 0 here. Never report success when no
+            // address was produced.
+            if (ret == 0) {
+                if (gai_ret == 0) {
+                    ret = HOST_NOT_FOUND;
+                } else if (gai_ret == EAI_AGAIN) {
+                    ret = TRY_AGAIN;
+                } else {
+                    ret = NO_RECOVERY;
+                }
+            }
+        }
+    }
+
+    if (gai_res) {
+        freeaddrinfo(gai_res);
+    }
+
+    if (ret) {
+        return ret;
+    }
+#else
+    hostent* hostent = gethostbyname(hostname);
+
+    if (!hostent)
+        return GetDnsError();
+
+    if (hostent->h_addrtype != AF_INET || (unsigned)hostent->h_length < sizeof(ui32))
+        return HOST_NOT_FOUND;
+
+    char** cur = hostent->h_addr_list;
+    for (i = 0; i < *slots && *cur; i++, cur++, ipsFound++)
+        ip[i] = *(ui32*)*cur;
+#endif
+    // the resolver returns addresses in network byte order; callers get host byte order
+    for (i = 0; i < ipsFound; i++) {
+        ip[i] = InetToHost(ip[i]);
+    }
+    *slots = ipsFound;
+
+    return 0;
+}
+
+int NResolver::GetDnsError() { // returns the current value of h_errno
+ return h_errno;
+}
diff --git a/util/network/hostip.h b/util/network/hostip.h
new file mode 100644
index 0000000000..cf63e4846a
--- /dev/null
+++ b/util/network/hostip.h
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <util/system/defaults.h>
+
+namespace NResolver {
+ // resolves hostname and fills up to *slots entries of the ip array;
+ // the number of entries actually filled is returned in *slots;
+ int GetHostIP(const char* hostname, ui32* ip, size_t* slots); // returns 0 on success
+ int GetDnsError();
+
+ inline int GetHostIP(const char* hostname, ui32* ip) { // single-address variant: asks for one entry only
+ size_t slots = 1;
+
+ return GetHostIP(hostname, ip, &slots);
+ }
+}
diff --git a/util/network/init.cpp b/util/network/init.cpp
new file mode 100644
index 0000000000..366e65682c
--- /dev/null
+++ b/util/network/init.cpp
@@ -0,0 +1,34 @@
+#include "init.h"
+
+#include <util/system/compat.h>
+#include <util/system/yassert.h>
+#include <util/system/defaults.h>
+#include <util/generic/singleton.h>
+
+#include <cstdio>
+#include <cstdlib>
+
+namespace {
+ class TNetworkInit { // performs the process-wide network setup in its constructor
+ public:
+ inline TNetworkInit() {
+#ifndef ROBOT_SIGPIPE
+ signal(SIGPIPE, SIG_IGN); // ignore SIGPIPE unless ROBOT_SIGPIPE is defined
+#endif
+
+#if defined(_win_)
+ #pragma comment(lib, "ws2_32.lib")
+ WSADATA wsaData;
+ int result = WSAStartup(MAKEWORD(2, 2), &wsaData); // request Winsock version 2.2
+ Y_ASSERT(!result);
+ if (result) {
+ exit(-1); // Winsock could not be started: terminate the process
+ }
+#endif
+ }
+ };
+}
+
+void InitNetworkSubSystem() { // obtains the TNetworkInit singleton; its constructor does the actual setup
+ (void)Singleton<TNetworkInit>();
+}
diff --git a/util/network/init.h b/util/network/init.h
new file mode 100644
index 0000000000..08a79c0fca
--- /dev/null
+++ b/util/network/init.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include <util/system/error.h>
+
+#if defined(_unix_)
+ #include <fcntl.h>
+ #include <netdb.h>
+ #include <time.h>
+ #include <unistd.h>
+ #include <poll.h>
+
+ #include <sys/uio.h>
+ #include <sys/time.h>
+ #include <sys/types.h>
+ #include <sys/socket.h>
+
+ #include <netinet/in.h>
+ #include <netinet/tcp.h>
+ #include <arpa/inet.h>
+
+using SOCKET = int;
+
+ #define closesocket(s) close(s)
+ #define SOCKET_ERROR -1
+ #define INVALID_SOCKET -1
+ #define WSAGetLastError() errno
+#elif defined(_win_)
+ #include <util/system/winint.h>
+ #include <io.h>
+ #include <winsock2.h>
+ #include <ws2tcpip.h>
+
+using nfds_t = ULONG;
+
+ #undef Yield
+
+struct sockaddr_un { // declared here for the _win_ branch only; the _unix_ branch does not define it
+ short sun_family;
+ char sun_path[108];
+};
+
+ #define PF_LOCAL AF_UNIX
+ #define NETDB_INTERNAL -1
+ #define NETDB_SUCCESS 0
+
+#endif
+
+#if defined(_win_) || defined(_darwin_)
+ #ifndef MSG_NOSIGNAL
+ #define MSG_NOSIGNAL 0
+ #endif
+#endif // _win_ or _darwin_
+
+void InitNetworkSubSystem();
+
+static struct TNetworkInitializer { // file-static object: each translation unit including this header gets its own copy
+ inline TNetworkInitializer() {
+ InitNetworkSubSystem(); // runs when the static object is constructed
+ }
+} NetworkInitializerObject;
diff --git a/util/network/interface.cpp b/util/network/interface.cpp
new file mode 100644
index 0000000000..256776c6d3
--- /dev/null
+++ b/util/network/interface.cpp
@@ -0,0 +1,79 @@
+#include "interface.h"
+
+#if defined(_unix_)
+ #include <ifaddrs.h>
+#endif
+
+#ifdef _win_
+ #include <iphlpapi.h>
+ #pragma comment(lib, "Iphlpapi.lib")
+#endif
+
+namespace NAddr {
+ // True when addr is non-null and its family is AF_INET or AF_INET6.
+ static bool IsInetAddress(sockaddr* addr) {
+     if (addr == nullptr) {
+         return false;
+     }
+
+     return addr->sa_family == AF_INET || addr->sa_family == AF_INET6;
+ }
+
+ // Lists the IPv4/IPv6 addresses assigned to the network interfaces of this host.
+ //
+ // Windows: only adapters that are up, have IPv4 or IPv6 enabled and are not tunnels are
+ // reported; Mask is left unset. Other platforms: getifaddrs() is used and Mask is filled
+ // when the netmask is itself an inet address.
+ // The result is empty when the system query fails.
+ TNetworkInterfaceList GetNetworkInterfaces() {
+     TNetworkInterfaceList result;
+
+#ifdef _win_
+     TVector<char> buf;
+     buf.resize(1000000);
+     PIP_ADAPTER_ADDRESSES adapterBuf = (PIP_ADAPTER_ADDRESSES)&buf[0];
+     ULONG bufSize = buf.ysize();
+
+     if (GetAdaptersAddresses(AF_UNSPEC, 0, nullptr, adapterBuf, &bufSize) == ERROR_SUCCESS) {
+         for (PIP_ADAPTER_ADDRESSES ptr = adapterBuf; ptr != 0; ptr = ptr->Next) {
+             // The check below makes code working on Vista+
+             if ((ptr->Flags & (IP_ADAPTER_IPV4_ENABLED | IP_ADAPTER_IPV6_ENABLED)) == 0) {
+                 continue;
+             }
+             if (ptr->IfType == IF_TYPE_TUNNEL) {
+                 // ignore tunnels
+                 continue;
+             }
+             if (ptr->OperStatus != IfOperStatusUp) {
+                 // ignore disabled adapters
+                 continue;
+             }
+
+             for (IP_ADAPTER_UNICAST_ADDRESS* addr = ptr->FirstUnicastAddress; addr != 0; addr = addr->Next) {
+                 sockaddr* a = (sockaddr*)addr->Address.lpSockaddr;
+                 if (IsInetAddress(a)) {
+                     TNetworkInterface networkInterface;
+
+                     // Not very efficient but straightforward.
+                     // FriendlyName is a wide string, so each element is read as a whole
+                     // wide character; anything outside 7-bit ASCII becomes '?'.
+                     for (size_t i = 0; ptr->FriendlyName[i] != 0; i++) {
+                         WCHAR w = ptr->FriendlyName[i];
+                         char c = (w < 0x80) ? char(w) : '?';
+                         networkInterface.Name.append(1, c);
+                     }
+
+                     networkInterface.Address = new TOpaqueAddr(a);
+                     result.push_back(networkInterface);
+                 }
+             }
+         }
+     }
+#else
+     ifaddrs* ifap;
+     if (getifaddrs(&ifap) != -1) {
+         for (ifaddrs* ifa = ifap; ifa != nullptr; ifa = ifa->ifa_next) {
+             if (IsInetAddress(ifa->ifa_addr)) {
+                 TNetworkInterface interface;
+                 interface.Name = ifa->ifa_name;
+                 interface.Address = new TOpaqueAddr(ifa->ifa_addr);
+                 if (IsInetAddress(ifa->ifa_netmask)) {
+                     interface.Mask = new TOpaqueAddr(ifa->ifa_netmask);
+                 }
+                 result.push_back(interface);
+             }
+         }
+         freeifaddrs(ifap);
+     }
+#endif
+
+     return result;
+ }
+}
diff --git a/util/network/interface.h b/util/network/interface.h
new file mode 100644
index 0000000000..dda4555021
--- /dev/null
+++ b/util/network/interface.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include "address.h"
+
+#include <util/generic/vector.h>
+
+namespace NAddr {
+ struct TNetworkInterface { // one address assigned to a network interface
+ TString Name; // interface name
+ IRemoteAddrRef Address; // address of the interface
+ IRemoteAddrRef Mask; // network mask; may be left unset
+ };
+
+ using TNetworkInterfaceList = TVector<TNetworkInterface>;
+
+ TNetworkInterfaceList GetNetworkInterfaces(); // one entry per IPv4/IPv6 interface address
+}
diff --git a/util/network/iovec.cpp b/util/network/iovec.cpp
new file mode 100644
index 0000000000..7251038848
--- /dev/null
+++ b/util/network/iovec.cpp
@@ -0,0 +1 @@
+#include "iovec.h"
diff --git a/util/network/iovec.h b/util/network/iovec.h
new file mode 100644
index 0000000000..ac15a41f54
--- /dev/null
+++ b/util/network/iovec.h
@@ -0,0 +1,65 @@
+#pragma once
+
+#include <util/stream/output.h>
+#include <util/system/types.h>
+#include <util/system/yassert.h>
+
+// Cursor over an array of output parts that tracks how much of it has been consumed.
+// The parts array is not owned and is modified in place while the cursor advances.
+class TContIOVector {
+    using TPart = IOutputStream::TPart;
+
+public:
+    inline TContIOVector(TPart* parts, size_t count)
+        : Parts_(parts)
+        , Count_(count)
+    {
+    }
+
+    // Marks len bytes as consumed: fully consumed parts are dropped from the front,
+    // a partially consumed part gets its buffer advanced and its length reduced.
+    inline void Proceed(size_t len) noexcept {
+        for (; Count_ != 0; ++Parts_, --Count_) {
+            TPart& head = *Parts_;
+
+            if (len < head.len) {
+                head.len -= len;
+                head.buf = (const char*)head.buf + len;
+
+                return;
+            }
+
+            len -= head.len;
+        }
+
+        // more bytes were reported than the parts contained
+        if (len) {
+            Y_ASSERT(0 && "non zero length left");
+        }
+    }
+
+    // First part that is not fully consumed yet.
+    inline const TPart* Parts() const noexcept {
+        return Parts_;
+    }
+
+    // Number of parts that are not fully consumed yet.
+    inline size_t Count() const noexcept {
+        return Count_;
+    }
+
+    // Total length of count parts starting at parts.
+    static inline size_t Bytes(const TPart* parts, size_t count) noexcept {
+        size_t total = 0;
+
+        for (const TPart* cur = parts; cur != parts + count; ++cur) {
+            total += cur->len;
+        }
+
+        return total;
+    }
+
+    // Total length that is still left to consume.
+    inline size_t Bytes() const noexcept {
+        return Bytes(Parts_, Count_);
+    }
+
+    // True once every part has been consumed.
+    inline bool Complete() const noexcept {
+        return Count() == 0;
+    }
+
+private:
+    TPart* Parts_;
+    size_t Count_;
+};
diff --git a/util/network/ip.cpp b/util/network/ip.cpp
new file mode 100644
index 0000000000..a43bcdadcf
--- /dev/null
+++ b/util/network/ip.cpp
@@ -0,0 +1 @@
+#include "ip.h"
diff --git a/util/network/ip.h b/util/network/ip.h
new file mode 100644
index 0000000000..dc7c2d24a0
--- /dev/null
+++ b/util/network/ip.h
@@ -0,0 +1,119 @@
+#pragma once
+
+#include "socket.h"
+#include "hostip.h"
+
+#include <util/system/error.h>
+#include <util/system/byteorder.h>
+#include <util/generic/string.h>
+#include <util/generic/yexception.h>
+
+/// IPv4 address in network format
+using TIpHost = ui32;
+
+/// Port number in host format
+using TIpPort = ui16;
+
+/*
+ * Parses ipStr ('ddd.ddd.ddd.ddd' format) with inet_aton().
+ * Returns the IPv4 address in inet format; throws TSystemError if parsing fails.
+ */
+static inline TIpHost IpFromString(const char* ipStr) {
+    in_addr parsed;
+    const bool ok = inet_aton(ipStr, &parsed) != 0;
+
+    if (!ok) {
+        ythrow TSystemError() << "Failed to convert (" << ipStr << ") to ip address";
+    }
+
+    return (ui32)parsed.s_addr;
+}
+
+// Formats ip into the caller-provided buffer of len bytes and returns buf.
+// Throws TSystemError when inet_ntop() fails.
+static inline char* IpToString(TIpHost ip, char* buf, size_t len) {
+    const char* text = inet_ntop(AF_INET, (void*)&ip, buf, (socklen_t)len);
+
+    if (text == nullptr) {
+        ythrow TSystemError() << "Failed to get ip address string";
+    }
+
+    return buf;
+}
+
+// Formats ip as a TString; throws TSystemError when the conversion fails.
+static inline TString IpToString(TIpHost ip) {
+    char storage[INET_ADDRSTRLEN];
+    char* text = IpToString(ip, storage, sizeof(storage));
+
+    return TString(text);
+}
+
+// Resolves the host name given by data/len (it does not have to be NUL-terminated)
+// to a single IPv4 address and returns it in network byte order.
+// Throws TSystemError when the name cannot be resolved.
+static inline TIpHost ResolveHost(const char* data, size_t len) {
+    TIpHost ret = 0;
+    size_t slots = 1;
+    const TString s(data, len);
+
+    // A zero return code with no address filled in is treated as a failure too,
+    // so an unset value is never returned.
+    if (NResolver::GetHostIP(s.data(), &ret, &slots) != 0 || slots == 0) {
+        ythrow TSystemError(NResolver::GetDnsError()) << "can not resolve(" << s << ")";
+    }
+
+    // GetHostIP() reports host byte order; convert back to network byte order
+    return HostToInet(ret);
+}
+
+/// IPv4 socket address: a sockaddr_in with convenience constructors, accessors and
+/// implicit conversions to the pointer/length types taken by the socket API.
+struct TIpAddress: public sockaddr_in {
+    /// Creates an all-zero address.
+    inline TIpAddress() noexcept {
+        Clear();
+    }
+
+    inline TIpAddress(const sockaddr_in& addr) noexcept
+        : sockaddr_in(addr)
+        , tmp(0)
+    {
+    }
+
+    /// ip is in network format, port is in host format.
+    inline TIpAddress(TIpHost ip, TIpPort port) noexcept {
+        Set(ip, port);
+    }
+
+    /// Resolves ip through ResolveHost(); throws if it cannot be resolved.
+    inline TIpAddress(TStringBuf ip, TIpPort port) {
+        Set(ResolveHost(ip.data(), ip.size()), port);
+    }
+
+    /// Resolves ip through ResolveHost(); throws if it cannot be resolved.
+    inline TIpAddress(const char* ip, TIpPort port) {
+        Set(ResolveHost(ip, strlen(ip)), port);
+    }
+
+    inline operator sockaddr*() const noexcept {
+        return (sockaddr*)(sockaddr_in*)this;
+    }
+
+    /// Pointer to a scratch length preset to sizeof(sockaddr_in), for in/out length parameters.
+    inline operator socklen_t*() const noexcept {
+        tmp = sizeof(sockaddr_in);
+
+        return (socklen_t*)&tmp;
+    }
+
+    inline operator socklen_t() const noexcept {
+        tmp = sizeof(sockaddr_in);
+
+        return tmp;
+    }
+
+    /// Zeroes the sockaddr_in part of the object.
+    inline void Clear() noexcept {
+        Zero((sockaddr_in&)(*this));
+    }
+
+    /// ip is in network format, port is in host format.
+    inline void Set(TIpHost ip, TIpPort port) noexcept {
+        Clear();
+
+        sin_family = AF_INET;
+        sin_addr.s_addr = ip;
+        sin_port = HostToInet(port);
+    }
+
+    /// Address in network format.
+    inline TIpHost Host() const noexcept {
+        return sin_addr.s_addr;
+    }
+
+    /// Port in host format.
+    inline TIpPort Port() const noexcept {
+        return InetToHost(sin_port);
+    }
+
+private:
+    // required for "operator socklen_t*()"
+    // Initialized in-class so that every constructor leaves it defined and copying a
+    // TIpAddress never reads an indeterminate value.
+    mutable socklen_t tmp = 0;
+};
diff --git a/util/network/ip_ut.cpp b/util/network/ip_ut.cpp
new file mode 100644
index 0000000000..6716c6a699
--- /dev/null
+++ b/util/network/ip_ut.cpp
@@ -0,0 +1,63 @@
+#include "ip.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/generic/yexception.h>
+
+class TSysIpTest: public TTestBase { // tests for IpFromString() and IpToString()
+ UNIT_TEST_SUITE(TSysIpTest);
+ UNIT_TEST(TestIpFromString);
+ UNIT_TEST_EXCEPTION(TestIpFromString2, yexception); // malformed inputs must throw
+ UNIT_TEST_EXCEPTION(TestIpFromString3, yexception);
+ UNIT_TEST_EXCEPTION(TestIpFromString4, yexception);
+ UNIT_TEST_EXCEPTION(TestIpFromString5, yexception);
+ UNIT_TEST(TestIpToString);
+ UNIT_TEST_SUITE_END();
+
+private:
+ void TestIpFromString();
+ void TestIpFromString2();
+ void TestIpFromString3();
+ void TestIpFromString4();
+ void TestIpFromString5();
+ void TestIpToString();
+};
+
+UNIT_TEST_SUITE_REGISTRATION(TSysIpTest);
+
+// Parsed addresses must match the expected octets byte for byte, in textual order.
+void TSysIpTest::TestIpFromString() {
+    const char* ipStr[] = {"192.168.0.1", "87.255.18.167", "255.255.0.31", "188.225.124.255"};
+    ui8 ipArr[][4] = {{192, 168, 0, 1}, {87, 255, 18, 167}, {255, 255, 0, 31}, {188, 225, 124, 255}};
+
+    size_t idx = 0;
+
+    while (idx < Y_ARRAY_SIZE(ipStr)) {
+        const ui32 parsed = IpFromString(ipStr[idx]);
+        const bool sameBytes = memcmp(&parsed, ipArr[idx], sizeof(ui32)) == 0;
+
+        UNIT_ASSERT(sameBytes);
+        ++idx;
+    }
+}
+
+void TSysIpTest::TestIpFromString2() { // not an address at all; registered as expecting yexception
+ IpFromString("XXXXXXWXW");
+}
+
+void TSysIpTest::TestIpFromString3() { // first octet out of range
+ IpFromString("986.0.37.255");
+}
+
+void TSysIpTest::TestIpFromString4() { // first and last octets out of range
+ IpFromString("256.0.22.365");
+}
+
+void TSysIpTest::TestIpFromString5() { // empty octet
+ IpFromString("245.12..0");
+}
+
+// Formatting the raw octets must give back the expected dotted-quad text.
+void TSysIpTest::TestIpToString() {
+    ui8 ipArr[][4] = {{192, 168, 0, 1}, {87, 255, 18, 167}, {255, 255, 0, 31}, {188, 225, 124, 255}};
+
+    const char* ipStr[] = {"192.168.0.1", "87.255.18.167", "255.255.0.31", "188.225.124.255"};
+
+    size_t idx = 0;
+
+    while (idx < Y_ARRAY_SIZE(ipStr)) {
+        TIpHost* raw = reinterpret_cast<TIpHost*>(&(ipArr[idx]));
+
+        UNIT_ASSERT(IpToString(*raw) == ipStr[idx]);
+        ++idx;
+    }
+}
diff --git a/util/network/nonblock.cpp b/util/network/nonblock.cpp
new file mode 100644
index 0000000000..e515c27cc5
--- /dev/null
+++ b/util/network/nonblock.cpp
@@ -0,0 +1,104 @@
+#include "nonblock.h"
+
+#include <util/system/platform.h>
+
+#include <util/generic/singleton.h>
+
+#if defined(_unix_)
+ #include <dlfcn.h>
+#endif
+
+#if defined(_linux_)
+ #if !defined(SOCK_NONBLOCK)
+ #define SOCK_NONBLOCK 04000
+ #endif
+#endif
+
+namespace {
+ struct TFeatureCheck { // probes once which non-blocking socket features are available and wraps accept()/socket()
+ inline TFeatureCheck()
+ : Accept4(nullptr)
+ , HaveSockNonBlock(false)
+ {
+#if defined(_unix_) && defined(SOCK_NONBLOCK)
+ {
+ Accept4 = reinterpret_cast<TAccept4>(dlsym(RTLD_DEFAULT, "accept4")); // look the symbol up at run time
+
+ #if defined(_musl_)
+ //musl always statically linked
+ if (!Accept4) {
+ Accept4 = accept4;
+ }
+ #endif
+
+ if (Accept4) {
+ Accept4(-1, nullptr, nullptr, SOCK_NONBLOCK); // probe call on an invalid descriptor
+
+ if (errno == ENOSYS) { // the function exists but the call is not implemented
+ Accept4 = nullptr;
+ }
+ }
+ }
+#endif
+
+#if defined(SOCK_NONBLOCK)
+ {
+ TSocketHolder tmp(socket(PF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0)); // probe: does socket() accept SOCK_NONBLOCK?
+
+ HaveSockNonBlock = !tmp.Closed();
+ }
+#endif
+ }
+
+ inline SOCKET FastAccept(SOCKET s, struct sockaddr* addr, socklen_t* addrlen) const { // accept() that yields a non-blocking socket
+#if defined(SOCK_NONBLOCK)
+ if (Accept4) {
+ return Accept4(s, addr, addrlen, SOCK_NONBLOCK);
+ }
+#endif
+
+ const SOCKET ret = accept(s, addr, addrlen); // fallback: plain accept, then switch the mode by hand
+
+#if !defined(_freebsd_)
+ //freebsd inherit O_NONBLOCK flag
+ if (ret != INVALID_SOCKET) {
+ SetNonBlock(ret);
+ }
+#endif
+
+ return ret;
+ }
+
+ inline SOCKET FastSocket(int domain, int type, int protocol) const { // socket() that yields a non-blocking socket
+#if defined(SOCK_NONBLOCK)
+ if (HaveSockNonBlock) {
+ return socket(domain, type | SOCK_NONBLOCK, protocol);
+ }
+#endif
+
+ const SOCKET ret = socket(domain, type, protocol); // fallback: plain socket, then switch the mode by hand
+
+ if (ret != INVALID_SOCKET) {
+ SetNonBlock(ret);
+ }
+
+ return ret;
+ }
+
+ static inline const TFeatureCheck* Instance() noexcept { // singleton accessor
+ return Singleton<TFeatureCheck>();
+ }
+
+ using TAccept4 = int (*)(int sockfd, struct sockaddr* addr, socklen_t* addrlen, int flags);
+ TAccept4 Accept4; // nullptr when accept4 is not usable
+ bool HaveSockNonBlock; // true when the SOCK_NONBLOCK probe socket could be created
+ };
+}
+
+// Accepts a connection on s through the feature-probing singleton;
+// the returned socket is non-blocking (see the declaration in nonblock.h).
+SOCKET Accept4(SOCKET s, struct sockaddr* addr, socklen_t* addrlen) {
+    const TFeatureCheck* features = TFeatureCheck::Instance();
+
+    return features->FastAccept(s, addr, addrlen);
+}
+
+// Creates a non-blocking socket through the feature-probing singleton.
+SOCKET Socket4(int domain, int type, int protocol) {
+    const TFeatureCheck* features = TFeatureCheck::Instance();
+
+    return features->FastSocket(domain, type, protocol);
+}
diff --git a/util/network/nonblock.h b/util/network/nonblock.h
new file mode 100644
index 0000000000..54e5e44ae3
--- /dev/null
+++ b/util/network/nonblock.h
@@ -0,0 +1,8 @@
+#pragma once
+
+#include "socket.h"
+
+// Accepts a connection on s and returns the new socket in non-blocking mode; s itself is assumed to be non-blocking.
+SOCKET Accept4(SOCKET s, struct sockaddr* addr, socklen_t* addrlen);
+// Creates a socket in non-blocking mode.
+SOCKET Socket4(int domain, int type, int protocol);
diff --git a/util/network/pair.cpp b/util/network/pair.cpp
new file mode 100644
index 0000000000..9751ef5c96
--- /dev/null
+++ b/util/network/pair.cpp
@@ -0,0 +1,97 @@
+#include "pair.h"
+
+#if !defined(_win_)
+    #include <cerrno>
+
+    #include <unistd.h>
+#endif
+
+// Creates a pair of connected stream sockets in socks[0] and socks[1].
+//
+// Windows: the pair is emulated with a TCP connection over 127.0.0.1 made through a
+// temporary listening socket. `overlapped` and `cloexec` become WSA_FLAG_OVERLAPPED and
+// WSA_FLAG_NO_HANDLE_INHERIT of the connecting socket (socks[0]).
+// Other platforms: socketpair(AF_LOCAL, SOCK_STREAM) is used and `overlapped` is ignored.
+//
+// Returns 0 on success. On failure returns SOCKET_ERROR (Windows) or -1 (elsewhere) and
+// leaves the error code in WSAGetLastError() / errno.
+int SocketPair(SOCKET socks[2], bool overlapped, bool cloexec) {
+#if defined(_win_)
+    struct sockaddr_in addr;
+    SOCKET listener;
+    int e;
+    int addrlen = sizeof(addr);
+    DWORD flags = (overlapped ? WSA_FLAG_OVERLAPPED : 0) | (cloexec ? WSA_FLAG_NO_HANDLE_INHERIT : 0);
+
+    if (socks == 0) {
+        WSASetLastError(WSAEINVAL);
+
+        return SOCKET_ERROR;
+    }
+
+    socks[0] = INVALID_SOCKET;
+    socks[1] = INVALID_SOCKET;
+
+    if ((listener = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) {
+        return SOCKET_ERROR;
+    }
+
+    // Bind the listener to the loopback address with port 0
+    memset(&addr, 0, sizeof(addr));
+    addr.sin_family = AF_INET;
+    addr.sin_addr.s_addr = htonl(0x7f000001);
+    addr.sin_port = 0;
+
+    e = bind(listener, (const struct sockaddr*)&addr, sizeof(addr));
+
+    if (e == SOCKET_ERROR) {
+        // closesocket() may change the last error, so save it and restore it afterwards
+        e = WSAGetLastError();
+        closesocket(listener);
+        WSASetLastError(e);
+
+        return SOCKET_ERROR;
+    }
+
+    // Read back the address the listener is actually bound to; connect() below uses it
+    e = getsockname(listener, (struct sockaddr*)&addr, &addrlen);
+
+    if (e == SOCKET_ERROR) {
+        e = WSAGetLastError();
+        closesocket(listener);
+        WSASetLastError(e);
+
+        return SOCKET_ERROR;
+    }
+
+    // Single-pass loop: any `break` jumps to the common cleanup below
+    do {
+        if (listen(listener, 1) == SOCKET_ERROR)
+            break;
+
+        if ((socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, nullptr, 0, flags)) == INVALID_SOCKET)
+            break;
+
+        if (connect(socks[0], (const struct sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR)
+            break;
+
+        if ((socks[1] = accept(listener, nullptr, nullptr)) == INVALID_SOCKET)
+            break;
+
+        closesocket(listener);
+
+        return 0;
+    } while (0);
+
+    e = WSAGetLastError();
+    closesocket(listener);
+    closesocket(socks[0]);
+    closesocket(socks[1]);
+    WSASetLastError(e);
+
+    return SOCKET_ERROR;
+#else
+    (void)overlapped;
+
+    #if defined(_linux_)
+    // SOCK_CLOEXEC sets close-on-exec as part of the socketpair() call itself
+    return socketpair(AF_LOCAL, SOCK_STREAM | (cloexec ? SOCK_CLOEXEC : 0), 0, socks);
+    #else
+    const int r = socketpair(AF_LOCAL, SOCK_STREAM, 0, socks);
+
+    if (r != 0 || !cloexec) {
+        return r;
+    }
+
+    // Non-atomic wrt exec
+    for (int i = 0; i < 2; ++i) {
+        const int flags = fcntl(socks[i], F_GETFD, 0);
+
+        if (flags < 0 || fcntl(socks[i], F_SETFD, flags | FD_CLOEXEC) < 0) {
+            // The call is reported as failed, so the caller will not close the descriptors:
+            // release both here, keeping the errno of the failed fcntl().
+            const int savedErrno = errno;
+
+            close(socks[0]);
+            close(socks[1]);
+            errno = savedErrno;
+
+            return -1;
+        }
+    }
+
+    return 0;
+    #endif
+#endif
+}
diff --git a/util/network/pair.h b/util/network/pair.h
new file mode 100644
index 0000000000..0d4506f880
--- /dev/null
+++ b/util/network/pair.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "init.h"
+
+int SocketPair(SOCKET socks[2], bool overlapped, bool cloexec = false);
+
+static inline int SocketPair(SOCKET socks[2]) {
+    // Convenience overload: neither option is requested
+    const bool overlapped = false;
+    const bool cloexec = false;
+
+    return SocketPair(socks, overlapped, cloexec);
+}
diff --git a/util/network/poller.cpp b/util/network/poller.cpp
new file mode 100644
index 0000000000..7954d0e8b5
--- /dev/null
+++ b/util/network/poller.cpp
@@ -0,0 +1,86 @@
+#include "poller.h"
+#include "pollerimpl.h"
+
+#include <util/memory/tempbuf.h>
+
+namespace {
+ struct TMutexLocking { // lock policy passed to TPollerImpl by TSocketPoller::TImpl below
+ using TMyMutex = TMutex; // the policy's mutex type is a real TMutex
+ };
+}
+
+// Implementation object of TSocketPoller: the platform poller plus the glue that turns
+// poller events into the cookies registered by the user.
+class TSocketPoller::TImpl: public TPollerImpl<TMutexLocking> {
+public:
+    // Waits until deadLine using `events` as scratch space for up to `len` events, then
+    // stores the cookie of every reported event into `ev`. Returns the number of cookies.
+    inline size_t DoWaitReal(void** ev, TEvent* events, size_t len, const TInstant& deadLine) {
+        const size_t ready = WaitD(events, len, deadLine);
+        void** out = ev;
+
+        for (TEvent* cur = events; cur != events + ready; ++cur) {
+            *out++ = ExtractEvent(cur);
+        }
+
+        return ready;
+    }
+
+    // Same as DoWaitReal(), but provides the scratch event buffer itself.
+    inline size_t DoWait(void** ev, size_t len, const TInstant& deadLine) {
+        if (len != 1) {
+            TTempArray<TEvent> scratch(len);
+
+            return DoWaitReal(ev, scratch.Data(), len, deadLine);
+        }
+
+        // A single event fits into a local variable, no temporary array is needed
+        TEvent single;
+
+        return DoWaitReal(ev, &single, 1, deadLine);
+    }
+};
+
+TSocketPoller::TSocketPoller() // allocates the implementation object held in Impl_
+ : Impl_(new TImpl())
+{
+}
+
+TSocketPoller::~TSocketPoller() = default; // defined in this file, after TImpl has been fully declared
+
+void TSocketPoller::WaitRead(SOCKET sock, void* cookie) {
+    // Interest: readability
+    const int what = CONT_POLL_READ;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitWrite(SOCKET sock, void* cookie) {
+    // Interest: writability
+    const int what = CONT_POLL_WRITE;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitReadWrite(SOCKET sock, void* cookie) {
+    // Interest: readability and writability
+    const int what = CONT_POLL_READ | CONT_POLL_WRITE;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitRdhup(SOCKET sock, void* cookie) {
+    // Interest: CONT_POLL_RDHUP only
+    const int what = CONT_POLL_RDHUP;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitReadOneShot(SOCKET sock, void* cookie) {
+    // Interest: readability, disabled after the first event
+    const int what = CONT_POLL_READ | CONT_POLL_ONE_SHOT;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitWriteOneShot(SOCKET sock, void* cookie) {
+    // Interest: writability, disabled after the first event
+    const int what = CONT_POLL_WRITE | CONT_POLL_ONE_SHOT;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitReadWriteOneShot(SOCKET sock, void* cookie) {
+    // Interest: readability and writability, disabled after the first event
+    const int what = CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_ONE_SHOT;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::WaitReadWriteEdgeTriggered(SOCKET sock, void* cookie) {
+    // Interest: readability and writability, reported only for new events
+    const int what = CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_EDGE_TRIGGERED;
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::RestartReadWriteEdgeTriggered(SOCKET sock, void* cookie, bool empty) {
+    // Modify the existing edge-triggered read/write registration
+    int what = CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_MODIFY | CONT_POLL_EDGE_TRIGGERED;
+
+    // `empty` tells the poller that the caller has drained the backlog
+    if (empty) {
+        what |= CONT_POLL_BACKLOG_EMPTY;
+    }
+
+    Impl_->Set(cookie, sock, what);
+}
+
+void TSocketPoller::Unwait(SOCKET sock) { // drops the registration of sock from the underlying poller
+ Impl_->Remove(sock);
+}
+
+size_t TSocketPoller::WaitD(void** ev, size_t len, const TInstant& deadLine) { // fills ev with up to len cookies, returns how many were stored
+ return Impl_->DoWait(ev, len, deadLine);
+}
diff --git a/util/network/poller.h b/util/network/poller.h
new file mode 100644
index 0000000000..8dccd73140
--- /dev/null
+++ b/util/network/poller.h
@@ -0,0 +1,58 @@
+#pragma once
+
+#include "socket.h"
+
+#include <util/generic/ptr.h>
+#include <util/datetime/base.h>
+
+// Readiness poller for sockets. A socket is registered together with an opaque cookie,
+// and the wait calls report ready sockets by handing those cookies back.
+class TSocketPoller {
+public:
+    TSocketPoller();
+    ~TSocketPoller();
+
+    // Register interest in read, write, read+write or RDHUP events on sock
+    void WaitRead(SOCKET sock, void* cookie);
+    void WaitWrite(SOCKET sock, void* cookie);
+    void WaitReadWrite(SOCKET sock, void* cookie);
+    void WaitRdhup(SOCKET sock, void* cookie);
+
+    // Same interests, but the registration is disabled after the first event
+    void WaitReadOneShot(SOCKET sock, void* cookie);
+    void WaitWriteOneShot(SOCKET sock, void* cookie);
+    void WaitReadWriteOneShot(SOCKET sock, void* cookie);
+
+    // Edge-triggered registration and its restart; `empty` says whether the backlog was drained
+    void WaitReadWriteEdgeTriggered(SOCKET sock, void* cookie);
+    void RestartReadWriteEdgeTriggered(SOCKET sock, void* cookie, bool empty = true);
+
+    // Remove the registration of sock
+    void Unwait(SOCKET sock);
+
+    // Wait until deadLine; stores up to len cookies into events and returns their count
+    size_t WaitD(void** events, size_t len, const TInstant& deadLine);
+
+    // Same, with the deadline given as a timeout
+    inline size_t WaitT(void** events, size_t len, const TDuration& timeOut) {
+        const TInstant deadLine = timeOut.ToDeadLine();
+
+        return WaitD(events, len, deadLine);
+    }
+
+    // Same, with the deadline set to TInstant::Max()
+    inline size_t WaitI(void** events, size_t len) {
+        const TInstant never = TInstant::Max();
+
+        return WaitD(events, len, never);
+    }
+
+    // Single-event form: returns the cookie of one ready socket, or nullptr when none was reported
+    inline void* WaitD(const TInstant& deadLine) {
+        void* cookie = nullptr;
+        const size_t ready = WaitD(&cookie, 1, deadLine);
+
+        return ready ? cookie : nullptr;
+    }
+
+    inline void* WaitT(const TDuration& timeOut) {
+        const TInstant deadLine = timeOut.ToDeadLine();
+
+        return WaitD(deadLine);
+    }
+
+    inline void* WaitI() {
+        const TInstant never = TInstant::Max();
+
+        return WaitD(never);
+    }
+
+private:
+    class TImpl;
+    THolder<TImpl> Impl_;
+};
diff --git a/util/network/poller_ut.cpp b/util/network/poller_ut.cpp
new file mode 100644
index 0000000000..6df0dda8ec
--- /dev/null
+++ b/util/network/poller_ut.cpp
@@ -0,0 +1,236 @@
+#include <library/cpp/testing/unittest/registar.h>
+#include <util/system/error.h>
+
+#include "pair.h"
+#include "poller.h"
+#include "pollerimpl.h"
+
+Y_UNIT_TEST_SUITE(TSocketPollerTest) {
+ Y_UNIT_TEST(TestSimple) { // read registration: the cookie is reported only while an unread byte is pending
+ SOCKET sockets[2];
+ UNIT_ASSERT(SocketPair(sockets) == 0);
+
+ TSocketHolder s1(sockets[0]);
+ TSocketHolder s2(sockets[1]);
+
+ TSocketPoller poller;
+ poller.WaitRead(sockets[1], (void*)17);
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); // nothing has been sent yet
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ for (ui32 i = 0; i < 3; ++i) {
+ char buf[] = {18};
+ UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0));
+
+ UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero()));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0));
+ UNIT_ASSERT_VALUES_EQUAL(18, buf[0]);
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); // the byte was consumed, so nothing is reported
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ }
+ }
+
+ Y_UNIT_TEST(TestSimpleOneShot) { // one-shot read registration: only the first event after each registration is reported
+ SOCKET sockets[2];
+ UNIT_ASSERT(SocketPair(sockets) == 0);
+
+ TSocketHolder s1(sockets[0]);
+ TSocketHolder s2(sockets[1]);
+
+ TSocketPoller poller;
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ for (ui32 i = 0; i < 3; ++i) {
+ poller.WaitReadOneShot(sockets[1], (void*)17); // registered again on every iteration
+
+ char buf[1];
+
+ buf[0] = i + 20;
+
+ UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0));
+
+ UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero()));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0));
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 20), buf[0]);
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ buf[0] = i + 21;
+
+ UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0));
+
+ // this fails if the registration is not one-shot
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0));
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 21), buf[0]);
+ }
+ }
+
+ Y_UNIT_TEST(TestItIsSafeToUnregisterUnregisteredDescriptor) { // passes as long as Unwait() on a never-registered socket does not fail
+ SOCKET sockets[2];
+ UNIT_ASSERT(SocketPair(sockets) == 0);
+
+ TSocketHolder s1(sockets[0]);
+ TSocketHolder s2(sockets[1]);
+
+ TSocketPoller poller;
+
+ poller.Unwait(s1);
+ }
+
+ Y_UNIT_TEST(TestItIsSafeToReregisterDescriptor) { // passes as long as repeated registration of one socket does not fail
+ SOCKET sockets[2];
+ UNIT_ASSERT(SocketPair(sockets) == 0);
+
+ TSocketHolder s1(sockets[0]);
+ TSocketHolder s2(sockets[1]);
+
+ TSocketPoller poller;
+
+ poller.WaitRead(s1, nullptr);
+ poller.WaitRead(s1, nullptr); // same interest again
+ poller.WaitWrite(s1, nullptr); // different interest on the same socket
+ }
+
+ Y_UNIT_TEST(TestSimpleEdgeTriggered) { // edge-triggered read/write registration, including restarts with and without an empty backlog
+ SOCKET sockets[2];
+ UNIT_ASSERT(SocketPair(sockets) == 0);
+
+ TSocketHolder s1(sockets[0]);
+ TSocketHolder s2(sockets[1]);
+
+ SetNonBlock(sockets[1]);
+
+ TSocketPoller poller;
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ for (ui32 i = 0; i < 3; ++i) {
+ poller.WaitReadWriteEdgeTriggered(sockets[1], (void*)17);
+
+ // notify that the socket is writable
+ UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ char buf[2];
+
+ buf[0] = i + 10;
+ buf[1] = i + 20;
+
+ // send one byte
+ UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0));
+
+ UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ // restart without reading
+ poller.RestartReadWriteEdgeTriggered(sockets[1], (void*)17, false);
+
+ // after restart read and write might generate separate events
+ {
+ void* events[3];
+ size_t count = poller.WaitT(events, 3, TDuration::Zero());
+ UNIT_ASSERT_GE(count, 1);
+ UNIT_ASSERT_LE(count, 2);
+ UNIT_ASSERT_VALUES_EQUAL(events[0], (void*)17);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ // send two more bytes
+ UNIT_ASSERT_VALUES_EQUAL(2, send(sockets[0], buf, 2, 0));
+
+ // here the poller may or may not notify, because we have not seen the end of the data yet
+ Y_UNUSED(poller.WaitT(TDuration::Zero()));
+
+ // recv one, leave two
+ UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0));
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]);
+
+ // nothing new
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ // recv the rest
+ UNIT_ASSERT_VALUES_EQUAL(2, recv(sockets[1], buf, 2, 0));
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]);
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 20), buf[1]);
+
+ // still nothing new
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ // hit end
+ ClearLastSystemError();
+ UNIT_ASSERT_VALUES_EQUAL(-1, recv(sockets[1], buf, 1, 0));
+ UNIT_ASSERT_VALUES_EQUAL(EAGAIN, LastSystemError());
+
+ // restart after end (noop for epoll)
+ poller.RestartReadWriteEdgeTriggered(sockets[1], (void*)17, true);
+
+ // send and recv byte
+ UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0));
+
+ UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ // recv and see end
+ UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 2, 0));
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]);
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ // the same but send before restart
+ UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0));
+
+ // restart after end (noop for epoll)
+ poller.RestartReadWriteEdgeTriggered(sockets[1], (void*)17, true);
+
+ UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 2, 0));
+ UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]);
+
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+ UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero()));
+
+ poller.Unwait(sockets[1]);
+ }
+ }
+
+#if defined(HAVE_EPOLL_POLLER)
+ Y_UNIT_TEST(TestRdhup) { // epoll only: after the peer shuts down its write side, an RDHUP-only registration reports CONT_POLL_RDHUP
+ SOCKET sockets[2];
+ UNIT_ASSERT(SocketPair(sockets) == 0);
+
+ TSocketHolder s1(sockets[0]);
+ TSocketHolder s2(sockets[1]);
+
+ char buf[1] = {0};
+ UNIT_ASSERT_VALUES_EQUAL(1, send(s1, buf, 1, 0));
+ shutdown(s1, SHUT_WR);
+
+ using TPoller = TGenericPoller<TEpollPoller<TWithoutLocking>>; // the epoll backend is used directly, not through TSocketPoller
+ TPoller poller;
+ poller.Set((void*)17, s2, CONT_POLL_RDHUP);
+
+ TPoller::TEvent e;
+ UNIT_ASSERT_VALUES_EQUAL(poller.WaitD(&e, 1, TDuration::Zero().ToDeadLine()), 1);
+ UNIT_ASSERT_EQUAL(TPoller::ExtractStatus(&e), 0);
+ UNIT_ASSERT_EQUAL(TPoller::ExtractFilter(&e), CONT_POLL_RDHUP);
+ UNIT_ASSERT_EQUAL(TPoller::ExtractEvent(&e), (void*)17);
+ }
+#endif
+}
diff --git a/util/network/pollerimpl.cpp b/util/network/pollerimpl.cpp
new file mode 100644
index 0000000000..bf2ba16cf6
--- /dev/null
+++ b/util/network/pollerimpl.cpp
@@ -0,0 +1 @@
+#include "pollerimpl.h"
diff --git a/util/network/pollerimpl.h b/util/network/pollerimpl.h
new file mode 100644
index 0000000000..e8c7e40fba
--- /dev/null
+++ b/util/network/pollerimpl.h
@@ -0,0 +1,706 @@
+#pragma once
+
+#include "socket.h"
+
+#include <util/system/error.h>
+#include <util/system/mutex.h>
+#include <util/system/defaults.h>
+#include <util/generic/ylimits.h>
+#include <util/generic/utility.h>
+#include <util/generic/vector.h>
+#include <util/generic/yexception.h>
+#include <util/datetime/base.h>
+
+#if defined(_freebsd_) || defined(_darwin_)
+ #define HAVE_KQUEUE_POLLER
+#endif
+
+#if (defined(_linux_) && !defined(_bionic_)) || (__ANDROID_API__ >= 21)
+ #define HAVE_EPOLL_POLLER
+#endif
+
+//now we always have it
+#define HAVE_SELECT_POLLER
+
+#if defined(HAVE_KQUEUE_POLLER)
+ #include <sys/event.h>
+#endif
+
+#if defined(HAVE_EPOLL_POLLER)
+ #include <sys/epoll.h>
+#endif
+
+enum EContPoll { // bit flags combined into the `what` argument of the pollers' Set()/SetImpl()
+ CONT_POLL_READ = 1, // Wait for readability
+ CONT_POLL_WRITE = 2, // Wait for writability
+ CONT_POLL_RDHUP = 4, // Translated to EPOLLRDHUP by the epoll poller
+ CONT_POLL_ONE_SHOT = 8, // Disable after first event
+ CONT_POLL_MODIFY = 16, // Modify already added event
+ CONT_POLL_EDGE_TRIGGERED = 32, // Notify only about new events
+ CONT_POLL_BACKLOG_EMPTY = 64, // Backlog is empty (seen end of request, EAGAIN or truncated read)
+};
+
+static inline bool IsSocket(SOCKET fd) noexcept {
+    int sockType = 0;
+    socklen_t optLen = sizeof(sockType);
+    const int rc = getsockopt(fd, SOL_SOCKET, SO_TYPE, (char*)&sockType, &optLen);
+
+    // Only an explicit ENOTSOCK failure means "not a socket"; every other outcome counts as a socket
+    return rc == 0 || LastSystemError() != ENOTSOCK;
+}
+
+static inline int MicroToMilli(int timeout) noexcept {
+    // Zero stays zero: the caller asked for a non-blocking poll
+    if (timeout == 0) {
+        return 0;
+    }
+
+    /*
+     * 1. The epoll syscall API accepts timeouts with millisecond accuracy only.
+     * 2. It is quite complicated to guarantee a time resolution of a blocking
+     *    syscall finer than the kernel's 1/HZ.
+     *
+     * Without rounding up to at least one millisecond we would just waste cpu
+     * time on a lot of fast epoll_wait(..., 0) syscalls.
+     */
+    const int millis = timeout / 1000;
+
+    return Max(millis, 1);
+}
+
+struct TWithoutLocking { // lock policy that selects TFakeMutex as the poller's mutex type
+ using TMyMutex = TFakeMutex;
+};
+
+#if defined(HAVE_KQUEUE_POLLER)
+static inline int Kevent(int kq, struct kevent* changelist, int nchanges,
+                         struct kevent* eventlist, int nevents, const struct timespec* timeout) noexcept {
+    // kevent() wrapper that repeats the call for as long as it fails with EINTR
+    for (;;) {
+        const int rc = kevent(kq, changelist, nchanges, eventlist, nevents, timeout);
+
+        if (rc != -1 || errno != EINTR) {
+            return rc;
+        }
+    }
+}
+
+template <class TLockPolicy>
+class TKqueuePoller { // kqueue-based poller backend; TLockPolicy is not referenced inside this class
+public:
+ typedef struct ::kevent TEvent;
+
+ inline TKqueuePoller() // creates the kernel queue; throws TSystemError when kqueue() fails
+ : Fd_(kqueue())
+ {
+ if (Fd_ == -1) {
+ ythrow TSystemError() << "kqueue failed";
+ }
+ }
+
+ inline ~TKqueuePoller() {
+ close(Fd_);
+ }
+
+ inline int Fd() const noexcept { // descriptor of the kernel queue
+ return Fd_;
+ }
+
+ inline void SetImpl(void* data, int fd, int what) { // registers fd with the interests in `what`; data is stored as the event's udata
+ TEvent e[2];
+ int flags = EV_ADD;
+
+ if (what & CONT_POLL_EDGE_TRIGGERED) {
+ if (what & CONT_POLL_BACKLOG_EMPTY) {
+ // When backlog is empty, edge-triggered does not need restart.
+ return;
+ }
+ flags |= EV_CLEAR;
+ }
+
+ if (what & CONT_POLL_ONE_SHOT) {
+ flags |= EV_ONESHOT;
+ }
+
+ Zero(e);
+ // both filters are always submitted; the one not requested in `what` is added in disabled state
+ EV_SET(e + 0, fd, EVFILT_READ, flags | ((what & CONT_POLL_READ) ? EV_ENABLE : EV_DISABLE), 0, 0, data);
+ EV_SET(e + 1, fd, EVFILT_WRITE, flags | ((what & CONT_POLL_WRITE) ? EV_ENABLE : EV_DISABLE), 0, 0, data);
+
+ if (Kevent(Fd_, e, 2, nullptr, 0, nullptr) == -1) {
+ ythrow TSystemError() << "kevent add failed";
+ }
+ }
+
+ inline void Remove(int fd) noexcept { // deletes both filters of fd; an ENOENT failure is tolerated, any other failure aborts
+ TEvent e[2];
+
+ Zero(e);
+
+ EV_SET(e + 0, fd, EVFILT_READ, EV_DELETE, 0, 0, 0);
+ EV_SET(e + 1, fd, EVFILT_WRITE, EV_DELETE, 0, 0, 0);
+
+ Y_VERIFY(!(Kevent(Fd_, e, 2, nullptr, 0, nullptr) == -1 && errno != ENOENT), "kevent remove failed: %s", LastSystemErrorText());
+ }
+
+ inline size_t Wait(TEvent* events, size_t len, int timeout) noexcept { // timeout is in microseconds; returns the number of events stored
+ struct timespec ts;
+
+ ts.tv_sec = timeout / 1000000;
+ ts.tv_nsec = (timeout % 1000000) * 1000;
+
+ const int ret = Kevent(Fd_, nullptr, 0, events, len, &ts);
+
+ Y_VERIFY(ret >= 0, "kevent failed: %s", LastSystemErrorText());
+
+ return (size_t)ret;
+ }
+
+ static inline void* ExtractEvent(const TEvent* event) noexcept { // the cookie passed to SetImpl() as data
+ return event->udata;
+ }
+
+ static inline int ExtractStatus(const TEvent* event) noexcept { // EIO when EV_ERROR is set, otherwise the event's fflags
+ if (event->flags & EV_ERROR) {
+ return EIO;
+ }
+
+ return event->fflags;
+ }
+
+ static inline int ExtractFilterImpl(const TEvent* event) noexcept { // maps the kqueue filter of the event to CONT_POLL_* bits
+ if (event->filter == EVFILT_READ) {
+ return CONT_POLL_READ;
+ }
+
+ if (event->filter == EVFILT_WRITE) {
+ return CONT_POLL_WRITE;
+ }
+
+ if (event->flags & EV_EOF) {
+ return CONT_POLL_READ | CONT_POLL_WRITE;
+ }
+
+ return 0;
+ }
+
+private:
+ int Fd_;
+};
+#endif
+
+#if defined(HAVE_EPOLL_POLLER)
+static inline int ContEpollWait(int epfd, struct epoll_event* events, int maxevents, int timeout) noexcept {
+    // epoll_wait() wrapper: the timeout (milliseconds) is capped at 35 minutes and the
+    // call is repeated for as long as it fails with EINTR
+    const int cappedTimeout = Min<int>(timeout, 35 * 60 * 1000);
+
+    for (;;) {
+        const int rc = epoll_wait(epfd, events, maxevents, cappedTimeout);
+
+        if (rc != -1 || errno != EINTR) {
+            return rc;
+        }
+    }
+}
+
+template <class TLockPolicy>
+class TEpollPoller { // epoll-based poller backend; TLockPolicy is not referenced inside this class
+public:
+ typedef struct ::epoll_event TEvent;
+
+ inline TEpollPoller(bool closeOnExec = false) // creates the epoll instance, with EPOLL_CLOEXEC when requested; throws TSystemError on failure
+ : Fd_(epoll_create1(closeOnExec ? EPOLL_CLOEXEC : 0))
+ {
+ if (Fd_ == -1) {
+ ythrow TSystemError() << "epoll_create failed";
+ }
+ }
+
+ inline ~TEpollPoller() {
+ close(Fd_);
+ }
+
+ inline int Fd() const noexcept { // descriptor of the epoll instance
+ return Fd_;
+ }
+
+ inline void SetImpl(void* data, int fd, int what) { // registers or modifies fd with the interests in `what`; data is stored as the event cookie
+ TEvent e;
+
+ Zero(e);
+
+ if (what & CONT_POLL_EDGE_TRIGGERED) {
+ if (what & CONT_POLL_BACKLOG_EMPTY) {
+ // When backlog is empty, edge-triggered does not need restart.
+ return;
+ }
+ e.events |= EPOLLET;
+ }
+
+ if (what & CONT_POLL_ONE_SHOT) {
+ e.events |= EPOLLONESHOT;
+ }
+
+ if (what & CONT_POLL_READ) {
+ e.events |= EPOLLIN;
+ }
+
+ if (what & CONT_POLL_WRITE) {
+ e.events |= EPOLLOUT;
+ }
+
+ if (what & CONT_POLL_RDHUP) {
+ e.events |= EPOLLRDHUP;
+ }
+
+ e.data.ptr = data;
+ // EPOLL_CTL_ADD is tried first unless CONT_POLL_MODIFY is set; when it fails, EPOLL_CTL_MOD is attempted
+ if ((what & CONT_POLL_MODIFY) || epoll_ctl(Fd_, EPOLL_CTL_ADD, fd, &e) == -1) {
+ if (epoll_ctl(Fd_, EPOLL_CTL_MOD, fd, &e) == -1) {
+ ythrow TSystemError() << "epoll add failed";
+ }
+ }
+ }
+
+ inline void Remove(int fd) noexcept { // removes fd from the epoll set; the result of epoll_ctl is not checked
+ TEvent e;
+
+ Zero(e);
+
+ epoll_ctl(Fd_, EPOLL_CTL_DEL, fd, &e);
+ }
+
+ inline size_t Wait(TEvent* events, size_t len, int timeout) noexcept { // timeout is converted with MicroToMilli(); returns the number of events stored
+ const int ret = ContEpollWait(Fd_, events, len, MicroToMilli(timeout));
+
+ Y_VERIFY(ret >= 0, "epoll wait error: %s", LastSystemErrorText());
+
+ return (size_t)ret;
+ }
+
+ static inline void* ExtractEvent(const TEvent* event) noexcept { // the cookie passed to SetImpl() as data
+ return event->data.ptr;
+ }
+
+ static inline int ExtractStatus(const TEvent* event) noexcept { // EIO when EPOLLERR or EPOLLHUP is set, otherwise 0
+ if (event->events & (EPOLLERR | EPOLLHUP)) {
+ return EIO;
+ }
+
+ return 0;
+ }
+
+ static inline int ExtractFilterImpl(const TEvent* event) noexcept { // maps EPOLLIN/EPOLLOUT/EPOLLRDHUP to CONT_POLL_* bits
+ int ret = 0;
+
+ if (event->events & EPOLLIN) {
+ ret |= CONT_POLL_READ;
+ }
+
+ if (event->events & EPOLLOUT) {
+ ret |= CONT_POLL_WRITE;
+ }
+
+ if (event->events & EPOLLRDHUP) {
+ ret |= CONT_POLL_RDHUP;
+ }
+
+ return ret;
+ }
+
+private:
+ int Fd_;
+};
+#endif
+
+#if defined(HAVE_SELECT_POLLER)
+ #include <util/memory/tempbuf.h>
+ #include <util/generic/hash.h>
+
+ #include "pair.h"
+
+static inline int ContSelect(int n, fd_set* r, fd_set* w, fd_set* e, struct timeval* t) noexcept {
+    // select() wrapper that repeats the call for as long as it fails with EINTR
+    for (;;) {
+        const int rc = select(n, r, w, e, t);
+
+        if (rc != -1 || errno != EINTR) {
+            return rc;
+        }
+    }
+}
+
+struct TSelectPollerNoTemplate { // helper types of TSelectPoller that do not depend on the lock policy
+ struct THandle { // cookie plus CONT_POLL_* filter stored for one descriptor
+ void* Data_;
+ int Filter_;
+
+ inline THandle()
+ : Data_(nullptr)
+ , Filter_(0)
+ {
+ }
+
+ inline void* Data() const noexcept {
+ return Data_;
+ }
+
+ inline void Set(void* d, int s) noexcept {
+ Data_ = d;
+ Filter_ = s;
+ }
+
+ inline void Clear(int c) noexcept { // drops the bits of c from the filter
+ Filter_ &= ~c;
+ }
+
+ inline int Filter() const noexcept {
+ return Filter_;
+ }
+ };
+
+ class TFds: public THashMap<SOCKET, THandle> { // registered descriptors, keyed by socket
+ public:
+ inline void Set(SOCKET fd, void* data, int filter) { // inserts fd or overwrites its handle
+ (*this)[fd].Set(data, filter);
+ }
+
+ inline void Remove(SOCKET fd) {
+ erase(fd);
+ }
+
+ inline SOCKET Build(fd_set* r, fd_set* w, fd_set* e) const noexcept { // fills the three select() sets and returns the largest descriptor seen (0 when empty)
+ SOCKET ret = 0;
+
+ for (const auto& it : *this) {
+ const SOCKET fd = it.first;
+ const THandle& handle = it.second;
+
+ FD_SET(fd, e); // every registered descriptor is watched for exceptional conditions
+
+ if (handle.Filter() & CONT_POLL_READ) {
+ FD_SET(fd, r);
+ }
+
+ if (handle.Filter() & CONT_POLL_WRITE) {
+ FD_SET(fd, w);
+ }
+
+ if (fd > ret) {
+ ret = fd;
+ }
+ }
+
+ return ret;
+ }
+ };
+
+ struct TEvent: public THandle { // reported event: Filter_ holds CONT_POLL_* bits on success or a negated error code on failure
+ inline int Status() const noexcept { // the error code for a failed event, 0 otherwise
+ return -Min(Filter(), 0);
+ }
+
+ inline void Error(void* d, int err) noexcept { // stores err negated so that Status() can recover it
+ Set(d, -err);
+ }
+
+ inline void Success(void* d, int what) noexcept {
+ Set(d, what);
+ }
+ };
+};
+
+// Poller backend built on select().
+//
+// SetImpl()/Remove() do not touch the descriptor table directly: they append a command
+// under CommandLock_ and wake the waiter through a signalling socket pair. The queued
+// commands are applied at the beginning of the next WaitBase() call.
+template <class TLockPolicy>
+class TSelectPoller: public TSelectPollerNoTemplate {
+    using TMyMutex = typename TLockPolicy::TMyMutex;
+
+public:
+    // Creates the signalling socket pair and switches both ends to non-blocking mode.
+    // Throws TSystemError when the pair cannot be created.
+    inline TSelectPoller()
+        : Begin_(nullptr)
+        , End_(nullptr)
+    {
+        // Without a valid pair Signal_ would hold unusable descriptors, so fail loudly.
+        if (SocketPair(Signal_) != 0) {
+            ythrow TSystemError() << "socketpair failed";
+        }
+
+        SetNonBlock(WaitSock());
+        SetNonBlock(SigSock());
+    }
+
+    inline ~TSelectPoller() {
+        closesocket(Signal_[0]);
+        closesocket(Signal_[1]);
+    }
+
+    // Queues registration of fd with filter `what` and cookie `data`, then wakes the waiter.
+    inline void SetImpl(void* data, SOCKET fd, int what) {
+        with_lock (CommandLock_) {
+            Commands_.push_back(TCommand(fd, what, data));
+        }
+
+        Signal();
+    }
+
+    // Queues removal of fd (a command with zero filter), then wakes the waiter.
+    inline void Remove(SOCKET fd) noexcept {
+        with_lock (CommandLock_) {
+            Commands_.push_back(TCommand(fd, 0));
+        }
+
+        Signal();
+    }
+
+    // Copies up to len ready events into events; timeout is in microseconds.
+    // Events left over from an earlier select() round are handed out first. When the
+    // caller's buffer is smaller than EventNumberHint(), the select() results are staged
+    // in SavedEvents() and returned in portions.
+    inline size_t Wait(TEvent* events, size_t len, int timeout) noexcept {
+        auto guard = Guard(Lock_);
+
+        do {
+            if (Begin_ != End_) {
+                const size_t ret = Min<size_t>(End_ - Begin_, len);
+
+                memcpy(events, Begin_, sizeof(*events) * ret);
+                Begin_ += ret;
+
+                return ret;
+            }
+
+            if (len >= EventNumberHint()) {
+                return WaitBase(events, len, timeout);
+            }
+
+            Begin_ = SavedEvents();
+            End_ = Begin_ + WaitBase(Begin_, EventNumberHint(), timeout);
+        } while (Begin_ != End_);
+
+        return 0;
+    }
+
+    // Staging buffer of EventNumberHint() events, allocated on first use.
+    inline TEvent* SavedEvents() {
+        if (!SavedEvents_) {
+            SavedEvents_.Reset(new TEvent[EventNumberHint()]);
+        }
+
+        return SavedEvents_.Get();
+    }
+
+    // One select() round: applies the queued commands, waits up to `timeout` microseconds
+    // and converts the ready descriptors into events. Returns the number of events written.
+    inline size_t WaitBase(TEvent* events, size_t len, int timeout) noexcept {
+        with_lock (CommandLock_) {
+            for (auto command = Commands_.begin(); command != Commands_.end(); ++command) {
+                if (command->Filter_ != 0) {
+                    Fds_.Set(command->Fd_, command->Cookie_, command->Filter_);
+                } else {
+                    Fds_.Remove(command->Fd_);
+                }
+            }
+
+            Commands_.clear();
+        }
+
+        // One buffer holds the three fd_sets followed by the list of one-shot keys to erase
+        TTempBuf tmpBuf(3 * sizeof(fd_set) + Fds_.size() * sizeof(SOCKET));
+
+        fd_set* in = (fd_set*)tmpBuf.Data();
+        fd_set* out = &in[1];
+        fd_set* errFds = &in[2];
+
+        SOCKET* keysToDeleteBegin = (SOCKET*)&in[3];
+        SOCKET* keysToDeleteEnd = keysToDeleteBegin;
+
+    #if defined(_msan_enabled_) // msan doesn't handle FD_ZERO and cause false positive BALANCER-1347
+        memset(in, 0, sizeof(*in));
+        memset(out, 0, sizeof(*out));
+        memset(errFds, 0, sizeof(*errFds));
+    #endif
+
+        FD_ZERO(in);
+        FD_ZERO(out);
+        FD_ZERO(errFds);
+
+        // The read end of the signalling pair is always watched so that Signal() interrupts select()
+        FD_SET(WaitSock(), in);
+
+        const SOCKET maxFdNum = Max(Fds_.Build(in, out, errFds), WaitSock());
+        struct timeval tout;
+
+        tout.tv_sec = timeout / 1000000;
+        tout.tv_usec = timeout % 1000000;
+
+        int ret = ContSelect(int(maxFdNum + 1), in, out, errFds, &tout);
+
+        if (ret > 0 && FD_ISSET(WaitSock(), in)) {
+            // The wake-up itself is not a user event: discount it and drain the signal bytes
+            --ret;
+            TryWait();
+        }
+
+        Y_VERIFY(ret >= 0 && (size_t)ret <= len, "select error: %s", LastSystemErrorText());
+
+        TEvent* eventsStart = events;
+
+        for (typename TFds::iterator it = Fds_.begin(); it != Fds_.end(); ++it) {
+            const SOCKET fd = it->first;
+            THandle& handle = it->second;
+
+            if (FD_ISSET(fd, errFds)) {
+                (events++)->Error(handle.Data(), EIO);
+
+                if (handle.Filter() & CONT_POLL_ONE_SHOT) {
+                    *keysToDeleteEnd = fd;
+                    ++keysToDeleteEnd;
+                }
+
+            } else {
+                int what = 0;
+
+                if (FD_ISSET(fd, in)) {
+                    what |= CONT_POLL_READ;
+                }
+
+                if (FD_ISSET(fd, out)) {
+                    what |= CONT_POLL_WRITE;
+                }
+
+                if (what) {
+                    (events++)->Success(handle.Data(), what);
+
+                    if (handle.Filter() & CONT_POLL_ONE_SHOT) {
+                        *keysToDeleteEnd = fd;
+                        ++keysToDeleteEnd;
+                    }
+
+                    if (handle.Filter() & CONT_POLL_EDGE_TRIGGERED) {
+                        // Emulate edge-triggered for level-triggered select().
+                        // User must restart waiting this event when needed.
+                        handle.Clear(what);
+                    }
+                }
+            }
+        }
+
+        // One-shot registrations are erased only after the iteration over Fds_ is finished
+        while (keysToDeleteBegin != keysToDeleteEnd) {
+            Fds_.erase(*keysToDeleteBegin);
+            ++keysToDeleteBegin;
+        }
+
+        return events - eventsStart;
+    }
+
+    // Buffer size (in events) at which Wait() lets select() write directly into the caller's buffer.
+    inline size_t EventNumberHint() const noexcept {
+        return sizeof(fd_set) * 8 * 2;
+    }
+
+    static inline void* ExtractEvent(const TEvent* event) noexcept {
+        return event->Data();
+    }
+
+    static inline int ExtractStatus(const TEvent* event) noexcept {
+        return event->Status();
+    }
+
+    static inline int ExtractFilterImpl(const TEvent* event) noexcept {
+        return event->Filter();
+    }
+
+private:
+    // Wakes a waiter blocked in select() by writing one byte into the signalling pair.
+    // The result of send() is not checked.
+    inline void Signal() noexcept {
+        char ch = 13;
+
+        send(SigSock(), &ch, 1, 0);
+    }
+
+    // Drains the pending signal bytes from the (non-blocking) read end of the pair.
+    inline void TryWait() {
+        char ch[32];
+
+        while (recv(WaitSock(), ch, sizeof(ch), 0) > 0) {
+            Y_ASSERT(ch[0] == 13);
+        }
+    }
+
+    inline SOCKET WaitSock() const noexcept {
+        return Signal_[1];
+    }
+
+    inline SOCKET SigSock() const noexcept {
+        return Signal_[0];
+    }
+
+private:
+    // Deferred registration change, applied by WaitBase().
+    struct TCommand {
+        SOCKET Fd_;
+        int Filter_; // 0 to remove
+        void* Cookie_;
+
+        TCommand(SOCKET fd, int filter, void* cookie)
+            : Fd_(fd)
+            , Filter_(filter)
+            , Cookie_(cookie)
+        {
+        }
+
+        // Used for removal commands; the cookie is not needed there but must still be
+        // initialized, because the command object is copied into Commands_.
+        TCommand(SOCKET fd, int filter)
+            : Fd_(fd)
+            , Filter_(filter)
+            , Cookie_(nullptr)
+        {
+        }
+    };
+
+    TFds Fds_;
+
+    // Lock_ is held for the whole Wait() call
+    TMyMutex Lock_;
+    TArrayHolder<TEvent> SavedEvents_;
+    TEvent* Begin_;
+    TEvent* End_;
+
+    // CommandLock_ protects Commands_
+    TMyMutex CommandLock_;
+    TVector<TCommand> Commands_;
+
+    // Signal_[0] is written by Signal(), Signal_[1] is watched by select()
+    SOCKET Signal_[2];
+};
+#endif
+
+static inline TDuration PollStep(const TInstant& deadLine, const TInstant& now) noexcept {
+    // Time left until deadLine, capped at 1000 seconds; zero once the deadline has passed
+    if (!(deadLine < now)) {
+        return Min(deadLine - now, TDuration::Seconds(1000));
+    }
+
+    return TDuration::Zero();
+}
+
+// Adds the backend-independent part of the poller interface on top of a concrete
+// backend (TBase): Set(), ExtractFilter() and the deadline-based WaitD().
+template <class TBase>
+class TGenericPoller: public TBase {
+public:
+    using TBase::TBase;
+
+    using TEvent = typename TBase::TEvent;
+
+    // Registers fd with the interests in `what`; an empty `what` removes fd instead.
+    inline void Set(void* data, SOCKET fd, int what) {
+        if (!what) {
+            this->Remove(fd);
+
+            return;
+        }
+
+        this->SetImpl(data, fd, what);
+    }
+
+    // CONT_POLL_* bits of the event; an event with a non-zero status reports all of them.
+    static inline int ExtractFilter(const TEvent* event) noexcept {
+        const bool failed = TBase::ExtractStatus(event) != 0;
+
+        return failed ? (CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_RDHUP) : TBase::ExtractFilterImpl(event);
+    }
+
+    // Waits in PollStep()-sized portions until at least one event arrives or deadLine is
+    // reached. Returns the number of events stored in `events` (0 on timeout or when len is 0).
+    inline size_t WaitD(TEvent* events, size_t len, TInstant deadLine, TInstant now = TInstant::Now()) noexcept {
+        if (!len) {
+            return 0;
+        }
+
+        for (;;) {
+            const size_t ready = this->Wait(events, len, (int)PollStep(deadLine, now).MicroSeconds());
+
+            if (ready) {
+                return ready;
+            }
+
+            now = TInstant::Now();
+
+            if (!(now < deadLine)) {
+                return ready;
+            }
+        }
+    }
+};
+
+#if defined(HAVE_KQUEUE_POLLER)
+ #define TPollerImplBase TKqueuePoller
+#elif defined(HAVE_EPOLL_POLLER)
+ #define TPollerImplBase TEpollPoller
+#elif defined(HAVE_SELECT_POLLER)
+ #define TPollerImplBase TSelectPoller
+#else
+ #error "unsupported platform"
+#endif
+// Default poller of the platform: kqueue is preferred, then epoll, then select.
+template <class TLockPolicy>
+using TPollerImpl = TGenericPoller<TPollerImplBase<TLockPolicy>>;
+
+#undef TPollerImplBase
diff --git a/util/network/sock.cpp b/util/network/sock.cpp
new file mode 100644
index 0000000000..d4864a9c1c
--- /dev/null
+++ b/util/network/sock.cpp
@@ -0,0 +1 @@
+#include "sock.h"
diff --git a/util/network/sock.h b/util/network/sock.h
new file mode 100644
index 0000000000..b10be2f715
--- /dev/null
+++ b/util/network/sock.h
@@ -0,0 +1,608 @@
+#pragma once
+
+#include <util/folder/path.h>
+#include <util/system/defaults.h>
+#include <util/string/cast.h>
+#include <util/stream/output.h>
+#include <util/system/sysstat.h>
+
+#if defined(_win_) || defined(_cygwin_)
+ #include <util/system/file.h>
+#else
+ #include <sys/un.h>
+ #include <sys/stat.h>
+#endif //_win_
+
+#include "init.h"
+#include "ip.h"
+#include "socket.h"
+
+// Default `mode` argument of TBaseSocket::Bind (octal 0644); the unix
+// TSockAddrLocal::Bind passes it to Chmod for the socket path.
+constexpr ui16 DEF_LOCAL_SOCK_MODE = 00644;
+
+// Base abstract class for socket address
+struct ISockAddr {
+    virtual ~ISockAddr() = default;
+    // Max size of the address that we can store (arg of recvfrom)
+    virtual socklen_t Size() const = 0;
+    // Real length of the address (arg of sendto)
+    virtual socklen_t Len() const = 0;
+    // cast to sockaddr* to pass to any syscall
+    virtual sockaddr* SockAddr() = 0;
+    virtual const sockaddr* SockAddr() const = 0;
+    // address in human readable form
+    virtual TString ToString() const = 0;
+
+protected:
+    // below are the implementation methods that can be called by T*Socket classes
+    friend class TBaseSocket;
+    friend class TDgramSocket;
+    friend class TStreamSocket;
+
+    // Called before the address is used as a destination (sendto/connect).
+    // Returns 0 on success or a negative error code.
+    virtual int ResolveAddr() const {
+        // usually it's nothing to do here
+        return 0;
+    }
+    // Binds socket s to this address; `mode` is only meaningful for address
+    // kinds backed by a filesystem entry. Returns 0 or a negative error code.
+    virtual int Bind(SOCKET s, ui16 mode) const = 0;
+};
+
+#if defined(_win_) || defined(_cygwin_)
+ #define YAF_LOCAL AF_INET
+// Emulation of a local (unix-domain style) address for Windows/Cygwin: the
+// socket is an AF_INET socket on 127.0.0.1, and the file at Path stores the
+// port it was bound to, so that peers can find it by path.
+struct TSockAddrLocal: public ISockAddr {
+    TSockAddrLocal() {
+        Clear();
+    }
+
+    TSockAddrLocal(const char* path) {
+        Set(path);
+    }
+
+    socklen_t Size() const override {
+        return sizeof(sockaddr_in);
+    }
+
+    socklen_t Len() const override {
+        return Size();
+    }
+
+    inline void Clear() noexcept {
+        Zero(in);
+        Zero(Path);
+    }
+
+    // Remembers the path (truncated to PathSize - 1 characters); the port stays
+    // 0 until ResolveAddr() or Bind() fills it in.
+    inline void Set(const char* path) noexcept {
+        Clear();
+        in.sin_family = AF_INET;
+        in.sin_addr.s_addr = IpFromString("127.0.0.1");
+        in.sin_port = 0;
+        strlcpy(Path, path, PathSize);
+    }
+
+    sockaddr* SockAddr() override {
+        return (struct sockaddr*)(&in);
+    }
+
+    const sockaddr* SockAddr() const override {
+        return (const struct sockaddr*)(&in);
+    }
+
+    TString ToString() const override {
+        return TString(Path);
+    }
+
+    TFsPath ToPath() const {
+        return TFsPath(Path);
+    }
+
+    // Loads the port from the file at Path unless it is already known.
+    // Returns 0 on success, a negative error code otherwise (never 0 on failure).
+    int ResolveAddr() const override {
+        if (in.sin_port == 0) {
+            int ret = 0;
+            // 1. open file
+            TFileHandle f(Path, OpenExisting | RdOnly);
+            if (!f.IsOpen())
+                return -FailureCode(LastSystemError());
+
+            // 2. read the port from file
+            ret = f.Read(&in.sin_port, sizeof(in.sin_port));
+            if (ret != sizeof(in.sin_port))
+                return -FailureCode(errno);
+        }
+
+        return 0;
+    }
+
+    // Binds s to an ephemeral loopback port and stores that port in the file at
+    // Path. `mode` is ignored. Returns 0 on success, a negative error code
+    // otherwise (never 0 on failure).
+    int Bind(SOCKET s, ui16 mode) const override {
+        Y_UNUSED(mode);
+        int ret = 0;
+        // 1. open file
+        TFileHandle f(Path, CreateAlways | WrOnly);
+        if (!f.IsOpen())
+            return -FailureCode(LastSystemError());
+
+        // 2. find port and bind to it
+        in.sin_port = 0;
+        ret = bind(s, SockAddr(), Len());
+        if (ret != 0)
+            return -FailureCode(WSAGetLastError());
+
+        int size = Size();
+        ret = getsockname(s, (struct sockaddr*)(&in), &size);
+        if (ret != 0)
+            return -FailureCode(WSAGetLastError());
+
+        // 3. write port to file
+        ret = f.Write(&(in.sin_port), sizeof(in.sin_port));
+        if (ret != sizeof(in.sin_port))
+            return -FailureCode(errno);
+
+        return 0;
+    }
+
+    static constexpr size_t PathSize = 128;
+    mutable struct sockaddr_in in;
+    char Path[PathSize];
+
+private:
+    // A failure must never be reported as 0 (callers treat only negative
+    // results as errors), so an unset error code is replaced by EFAULT.
+    static int FailureCode(int code) noexcept {
+        return code ? code : EFAULT;
+    }
+};
+#else
+ #define YAF_LOCAL AF_LOCAL
+// Local (unix-domain) socket address: a sockaddr_un that also implements ISockAddr.
+struct TSockAddrLocal: public sockaddr_un, public ISockAddr {
+    TSockAddrLocal() {
+        Clear();
+    }
+
+    TSockAddrLocal(const char* path) {
+        Set(path);
+    }
+
+    // Capacity of the underlying sockaddr_un.
+    socklen_t Size() const override {
+        return sizeof(sockaddr_un);
+    }
+
+    // Length handed to bind/connect/sendto: the path length plus 2.
+    socklen_t Len() const override {
+        const size_t pathLen = strlen(sun_path);
+        return pathLen + 2;
+    }
+
+    // Zeroes only the sockaddr_un part of the object.
+    inline void Clear() noexcept {
+        sockaddr_un& raw = *(sockaddr_un*)this;
+        Zero(raw);
+    }
+
+    // Resets the address and stores `path` (truncated to fit sun_path).
+    inline void Set(const char* path) noexcept {
+        Clear();
+        strlcpy(sun_path, path, sizeof(sun_path));
+        sun_family = AF_UNIX;
+    }
+
+    sockaddr* SockAddr() override {
+        struct sockaddr_un* self = this;
+        return (struct sockaddr*)self;
+    }
+
+    const sockaddr* SockAddr() const override {
+        const struct sockaddr_un* self = this;
+        return (const struct sockaddr*)self;
+    }
+
+    TString ToString() const override {
+        return TString(sun_path);
+    }
+
+    TFsPath ToPath() const {
+        return TFsPath(sun_path);
+    }
+
+    // Removes any existing entry at the path, binds s to it and applies `mode`
+    // to the path. Returns 0 on success or -errno.
+    int Bind(SOCKET s, ui16 mode) const override {
+        (void)unlink(sun_path);
+
+        if (bind(s, SockAddr(), Len()) < 0) {
+            return -errno;
+        }
+
+        if (Chmod(sun_path, mode) < 0) {
+            return -errno;
+        }
+
+        return 0;
+    }
+};
+#endif // _win_
+
+// IPv4 socket address: a sockaddr_in that also implements ISockAddr.
+// The port is kept in network byte order inside sin_port; the accessors
+// below take and return it in host byte order.
+struct TSockAddrInet: public sockaddr_in, public ISockAddr {
+    TSockAddrInet() {
+        Clear();
+    }
+
+    TSockAddrInet(TIpHost ip, TIpPort port) {
+        Set(ip, port);
+    }
+
+    TSockAddrInet(const char* ip, TIpPort port) {
+        Set(IpFromString(ip), port);
+    }
+
+    socklen_t Size() const override {
+        return sizeof(sockaddr_in);
+    }
+
+    socklen_t Len() const override {
+        return Size();
+    }
+
+    // Zeroes only the sockaddr_in part of the object.
+    inline void Clear() noexcept {
+        Zero(*(sockaddr_in*)this);
+    }
+
+    // `ip` is stored as is; `port` is converted to network byte order.
+    inline void Set(TIpHost ip, TIpPort port) noexcept {
+        Clear();
+        sin_family = AF_INET;
+        sin_addr.s_addr = ip;
+        sin_port = HostToInet(port);
+    }
+
+    sockaddr* SockAddr() override {
+        return (struct sockaddr*)(struct sockaddr_in*)this;
+    }
+
+    const sockaddr* SockAddr() const override {
+        return (const struct sockaddr*)(const struct sockaddr_in*)this;
+    }
+
+    // "ip:port"
+    TString ToString() const override {
+        return IpToString(sin_addr.s_addr) + ":" + ::ToString(InetToHost(sin_port));
+    }
+
+    // Binds s to this address and then reads the bound address back into this
+    // object (so that a requested port of 0 is replaced by the assigned one).
+    // `mode` is ignored. Returns 0 on success or a negative system error code.
+    int Bind(SOCKET s, ui16 mode) const override {
+        Y_UNUSED(mode);
+        // LastSystemError() rather than errno: Winsock reports failures through
+        // its own last-error value, so errno could be 0 and the failure would
+        // be returned as success.
+        if (bind(s, SockAddr(), Len()) < 0)
+            return -LastSystemError();
+
+        socklen_t len = Len();
+        // the cast drops constness on purpose: the object is updated in place
+        if (getsockname(s, (struct sockaddr*)(SockAddr()), &len) < 0)
+            return -LastSystemError();
+
+        return 0;
+    }
+
+    TIpHost GetIp() const noexcept {
+        return sin_addr.s_addr;
+    }
+
+    // Port in host byte order.
+    TIpPort GetPort() const noexcept {
+        return InetToHost(sin_port);
+    }
+
+    void SetPort(TIpPort port) noexcept {
+        sin_port = HostToInet(port);
+    }
+};
+
+// IPv6 socket address: a sockaddr_in6 that also implements ISockAddr.
+// The port is kept in network byte order inside sin6_port; the accessors
+// below take and return it in host byte order.
+struct TSockAddrInet6: public sockaddr_in6, public ISockAddr {
+    TSockAddrInet6() {
+        Clear();
+    }
+
+    TSockAddrInet6(const char* ip6, const TIpPort port) {
+        Set(ip6, port);
+    }
+
+    socklen_t Size() const override {
+        return sizeof(sockaddr_in6);
+    }
+
+    socklen_t Len() const override {
+        return Size();
+    }
+
+    // Zeroes only the sockaddr_in6 part of the object.
+    inline void Clear() noexcept {
+        Zero(*(sockaddr_in6*)this);
+    }
+
+    // Parses the textual address ip6. The result of inet_pton is not checked:
+    // if parsing fails, sin6_addr keeps the all-zero value left by Clear().
+    inline void Set(const char* ip6, const TIpPort port) noexcept {
+        Clear();
+        sin6_family = AF_INET6;
+        inet_pton(AF_INET6, ip6, &sin6_addr);
+        sin6_port = HostToInet(port);
+    }
+
+    sockaddr* SockAddr() override {
+        return (struct sockaddr*)(struct sockaddr_in6*)this;
+    }
+
+    const sockaddr* SockAddr() const override {
+        return (const struct sockaddr*)(const struct sockaddr_in6*)this;
+    }
+
+    // "[ip]:port"
+    TString ToString() const override {
+        return "[" + GetIp() + "]:" + ::ToString(InetToHost(sin6_port));
+    }
+
+    // Binds s to this address and then reads the bound address back into this
+    // object (so that a requested port of 0 is replaced by the assigned one).
+    // `mode` is ignored. Returns 0 on success or a negative system error code.
+    int Bind(SOCKET s, ui16 mode) const override {
+        Y_UNUSED(mode);
+        // LastSystemError() rather than errno: Winsock reports failures through
+        // its own last-error value, so errno could be 0 and the failure would
+        // be returned as success.
+        if (bind(s, SockAddr(), Len()) < 0) {
+            return -LastSystemError();
+        }
+        socklen_t len = Len();
+        // the cast drops constness on purpose: the object is updated in place
+        if (getsockname(s, (struct sockaddr*)(SockAddr()), &len) < 0) {
+            return -LastSystemError();
+        }
+        return 0;
+    }
+
+    // Textual form of the address, or an empty string if it cannot be formatted.
+    TString GetIp() const noexcept {
+        // zero-initialized so that a failed inet_ntop never exposes garbage
+        char ip6[INET6_ADDRSTRLEN] = {};
+        if (!inet_ntop(AF_INET6, (void*)&sin6_addr, ip6, INET6_ADDRSTRLEN)) {
+            return TString();
+        }
+        return TString(ip6);
+    }
+
+    // Port in host byte order.
+    TIpPort GetPort() const noexcept {
+        return InetToHost(sin6_port);
+    }
+
+    void SetPort(TIpPort port) noexcept {
+        sin6_port = HostToInet(port);
+    }
+};
+
+// Per-socket-kind names for the address types; the stream and datagram
+// variants of each family are aliases of the same type.
+using TSockAddrLocalStream = TSockAddrLocal;
+using TSockAddrLocalDgram = TSockAddrLocal;
+using TSockAddrInetStream = TSockAddrInet;
+using TSockAddrInetDgram = TSockAddrInet;
+using TSockAddrInet6Stream = TSockAddrInet6;
+using TSockAddrInet6Dgram = TSockAddrInet6;
+
+// Common base of the datagram and stream socket wrappers; owns the descriptor
+// through TSocketHolder.
+class TBaseSocket: public TSocketHolder {
+protected:
+    TBaseSocket(SOCKET fd)
+        : TSocketHolder(fd)
+    {
+    }
+
+public:
+    // Binds this socket to addr. Returns what ISockAddr::Bind returns:
+    // 0 on success or a negative error code.
+    int Bind(const ISockAddr* addr, ui16 mode = DEF_LOCAL_SOCK_MODE) {
+        const SOCKET self = (SOCKET) * this;
+        return addr->Bind(self, mode);
+    }
+
+    // Throws TSystemError if the descriptor is INVALID_SOCKET.
+    void CheckSock() {
+        if ((SOCKET) * this != INVALID_SOCKET) {
+            return;
+        }
+
+        ythrow TSystemError() << "no socket";
+    }
+
+    // Passes a non-negative result through; turns a negative one (a negated
+    // error code) into TSystemError, mentioning the operation name `op`.
+    static ssize_t Check(ssize_t ret, const char* op = "") {
+        if (ret >= 0) {
+            return ret;
+        }
+
+        ythrow TSystemError(-(int)ret) << "socket operation " << op;
+    }
+};
+
+// Datagram socket wrapper. Both operations return the byte count on success
+// and a negative system error code on failure.
+class TDgramSocket: public TBaseSocket {
+protected:
+    TDgramSocket(SOCKET fd)
+        : TBaseSocket(fd)
+    {
+    }
+
+public:
+    // Sends len bytes of msg to toAddr (resolving the address first).
+    ssize_t SendTo(const void* msg, size_t len, const ISockAddr* toAddr) {
+        ssize_t ret = toAddr->ResolveAddr();
+        if (ret < 0) {
+            // ResolveAddr already returns a negative error code; the last
+            // system error may be unrelated (or 0, which would read as success)
+            return ret;
+        }
+
+        ret = sendto((SOCKET) * this, (const char*)msg, (int)len, 0, toAddr->SockAddr(), toAddr->Len());
+        if (ret < 0) {
+            return -LastSystemError();
+        }
+
+        return ret;
+    }
+
+    // Receives up to len bytes into buf and stores the sender in fromAddr.
+    ssize_t RecvFrom(void* buf, size_t len, ISockAddr* fromAddr) {
+        socklen_t fromSize = fromAddr->Size();
+        const ssize_t ret = recvfrom((SOCKET) * this, (char*)buf, (int)len, 0, fromAddr->SockAddr(), &fromSize);
+        if (ret < 0) {
+            return -LastSystemError();
+        }
+
+        return ret;
+    }
+};
+
+// Stream socket wrapper. Every operation returns a non-negative value on
+// success and a negative system error code on failure. The code is taken from
+// LastSystemError(), as in TDgramSocket, because Winsock does not report
+// failures through errno.
+class TStreamSocket: public TBaseSocket {
+protected:
+    explicit TStreamSocket(SOCKET fd)
+        : TBaseSocket(fd)
+    {
+    }
+
+public:
+    TStreamSocket()
+        : TBaseSocket(INVALID_SOCKET)
+    {
+    }
+
+    // Returns the number of bytes sent.
+    ssize_t Send(const void* msg, size_t len, int flags = 0) {
+        const ssize_t ret = send((SOCKET) * this, (const char*)msg, (int)len, flags);
+        if (ret < 0)
+            return -LastSystemError();
+
+        return ret;
+    }
+
+    // Returns the number of bytes received.
+    ssize_t Recv(void* buf, size_t len, int flags = 0) {
+        const ssize_t ret = recv((SOCKET) * this, (char*)buf, (int)len, flags);
+        if (ret < 0)
+            return -LastSystemError();
+
+        return ret;
+    }
+
+    // Resolves addr and connects to it.
+    int Connect(const ISockAddr* addr) {
+        int ret = addr->ResolveAddr();
+        if (ret < 0)
+            // ResolveAddr already returns a negative error code
+            return ret;
+
+        ret = connect((SOCKET) * this, addr->SockAddr(), addr->Len());
+        if (ret < 0)
+            return -LastSystemError();
+
+        return ret;
+    }
+
+    int Listen(int backlog) {
+        int ret = listen((SOCKET) * this, backlog);
+        if (ret < 0)
+            return -LastSystemError();
+
+        return ret;
+    }
+
+    // Accepts one connection: the new descriptor is swapped into acceptedSock
+    // (whatever acceptedSock held before is closed when the temporary holder
+    // goes out of scope), and the peer address is stored in acceptedAddr when
+    // it is given.
+    int Accept(TStreamSocket* acceptedSock, ISockAddr* acceptedAddr = nullptr) {
+        SOCKET s = INVALID_SOCKET;
+        if (acceptedAddr) {
+            socklen_t acceptedSize = acceptedAddr->Size();
+            s = accept((SOCKET) * this, acceptedAddr->SockAddr(), &acceptedSize);
+        } else {
+            s = accept((SOCKET) * this, nullptr, nullptr);
+        }
+
+        if (s == INVALID_SOCKET)
+            return -LastSystemError();
+
+        TSocketHolder sock(s);
+        acceptedSock->Swap(sock);
+        return 0;
+    }
+};
+
+// Datagram socket of the YAF_LOCAL family.
+class TLocalDgramSocket: public TDgramSocket {
+public:
+    // Creates a new socket; creation failure is not checked here (see CheckSock()).
+    TLocalDgramSocket()
+        : TDgramSocket(socket(YAF_LOCAL, SOCK_DGRAM, 0))
+    {
+    }
+
+    // Wraps an already existing descriptor.
+    TLocalDgramSocket(SOCKET fd)
+        : TDgramSocket(fd)
+    {
+    }
+};
+
+// Datagram socket of the AF_INET family.
+class TInetDgramSocket: public TDgramSocket {
+public:
+    // Creates a new socket; creation failure is not checked here (see CheckSock()).
+    TInetDgramSocket()
+        : TDgramSocket(socket(AF_INET, SOCK_DGRAM, 0))
+    {
+    }
+
+    // Wraps an already existing descriptor.
+    TInetDgramSocket(SOCKET fd)
+        : TDgramSocket(fd)
+    {
+    }
+};
+
+// Datagram socket of the AF_INET6 family.
+class TInet6DgramSocket: public TDgramSocket {
+public:
+    // Creates a new socket; creation failure is not checked here (see CheckSock()).
+    TInet6DgramSocket()
+        : TDgramSocket(socket(AF_INET6, SOCK_DGRAM, 0))
+    {
+    }
+
+    // Wraps an already existing descriptor.
+    TInet6DgramSocket(SOCKET fd)
+        : TDgramSocket(fd)
+    {
+    }
+};
+
+// Stream socket of the YAF_LOCAL family.
+class TLocalStreamSocket: public TStreamSocket {
+public:
+    // Creates a new socket; creation failure is not checked here (see CheckSock()).
+    TLocalStreamSocket()
+        : TStreamSocket(socket(YAF_LOCAL, SOCK_STREAM, 0))
+    {
+    }
+
+    // Wraps an already existing descriptor.
+    TLocalStreamSocket(SOCKET fd)
+        : TStreamSocket(fd)
+    {
+    }
+};
+
+// Stream socket of the AF_INET family.
+class TInetStreamSocket: public TStreamSocket {
+public:
+    // Creates a new socket; creation failure is not checked here (see CheckSock()).
+    TInetStreamSocket()
+        : TStreamSocket(socket(AF_INET, SOCK_STREAM, 0))
+    {
+    }
+
+    // Wraps an already existing descriptor.
+    TInetStreamSocket(SOCKET fd)
+        : TStreamSocket(fd)
+    {
+    }
+};
+
+// Stream socket of the AF_INET6 family.
+class TInet6StreamSocket: public TStreamSocket {
+public:
+    // Creates a new socket; creation failure is not checked here (see CheckSock()).
+    TInet6StreamSocket()
+        : TStreamSocket(socket(AF_INET6, SOCK_STREAM, 0))
+    {
+    }
+
+    // Wraps an already existing descriptor.
+    TInet6StreamSocket(SOCKET fd)
+        : TStreamSocket(fd)
+    {
+    }
+};
+
+// IInputStream adapter over a TStreamSocket. The socket is referenced, not owned.
+class TStreamSocketInput: public IInputStream {
+public:
+    TStreamSocketInput(TStreamSocket* socket)
+        : Socket(socket)
+    {
+    }
+    void SetSocket(TStreamSocket* socket) {
+        Socket = socket;
+    }
+
+protected:
+    TStreamSocket* Socket;
+
+    // One Recv() call per read; a receive failure becomes TSystemError.
+    size_t DoRead(void* buf, size_t len) override {
+        Y_VERIFY(Socket, "TStreamSocketInput: socket isn't set");
+        const ssize_t received = Socket->Recv(buf, len);
+
+        if (received < 0) {
+            ythrow TSystemError(-(int)received) << "can not read from socket input stream";
+        }
+
+        return (size_t)received;
+    }
+};
+
+// IOutputStream adapter over a TStreamSocket. The socket is referenced, not owned.
+class TStreamSocketOutput: public IOutputStream {
+public:
+    TStreamSocketOutput(TStreamSocket* socket)
+        : Socket(socket)
+    {
+    }
+    void SetSocket(TStreamSocket* socket) {
+        Socket = socket;
+    }
+
+    TStreamSocketOutput(TStreamSocketOutput&&) noexcept = default;
+    TStreamSocketOutput& operator=(TStreamSocketOutput&&) noexcept = default;
+
+protected:
+    TStreamSocket* Socket;
+
+    // Keeps calling Send() until the whole buffer is written; a send failure
+    // becomes TSystemError.
+    void DoWrite(const void* buf, size_t len) override {
+        Y_VERIFY(Socket, "TStreamSocketOutput: socket isn't set");
+
+        for (const char* ptr = (const char*)buf; len != 0;) {
+            const ssize_t ret = Socket->Send(ptr, len);
+
+            if (ret < 0) {
+                ythrow TSystemError(-(int)ret) << "can not write to socket output stream";
+            }
+
+            Y_ASSERT((size_t)ret <= len);
+
+            // advance past the part that was accepted by the socket
+            const size_t written = (size_t)ret;
+            ptr += written;
+            len -= written;
+        }
+    }
+};
diff --git a/util/network/sock_ut.cpp b/util/network/sock_ut.cpp
new file mode 100644
index 0000000000..fd8c783747
--- /dev/null
+++ b/util/network/sock_ut.cpp
@@ -0,0 +1,168 @@
+#include "sock.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+#include <library/cpp/threading/future/legacy_future.h>
+
+#include <util/system/fs.h>
+
+Y_UNIT_TEST_SUITE(TSocketTest) {
+    // Request/reply round trip between two IPv4 datagram sockets on loopback.
+    Y_UNIT_TEST(InetDgramTest) {
+        char buf[256];
+        // port 0: Bind() below reads the actually assigned port back into the address
+        TSockAddrInetDgram servAddr(IpFromString("127.0.0.1"), 0);
+        TSockAddrInetDgram cliAddr(IpFromString("127.0.0.1"), 0);
+        TSockAddrInetDgram servFromAddr;
+        TSockAddrInetDgram cliFromAddr;
+        TInetDgramSocket cliSock;
+        TInetDgramSocket servSock;
+        cliSock.CheckSock();
+        servSock.CheckSock();
+
+        TBaseSocket::Check(cliSock.Bind(&cliAddr));
+        TBaseSocket::Check(servSock.Bind(&servAddr));
+
+        // client
+        // sizeof() includes the terminating zero, so strcmp() on buf is safe below
+        const char reqStr[] = "Hello, world!!!";
+        TBaseSocket::Check(cliSock.SendTo(reqStr, sizeof(reqStr), &servAddr));
+
+        // server
+        TBaseSocket::Check(servSock.RecvFrom(buf, 256, &servFromAddr));
+        UNIT_ASSERT(strcmp(reqStr, buf) == 0);
+        const char repStr[] = "The World's greatings to you";
+        // reply to the sender address captured by RecvFrom
+        TBaseSocket::Check(servSock.SendTo(repStr, sizeof(repStr), &servFromAddr));
+
+        // client
+        TBaseSocket::Check(cliSock.RecvFrom(buf, 256, &cliFromAddr));
+        UNIT_ASSERT(strcmp(repStr, buf) == 0);
+    }
+
+    // Request/reply round trip between two local datagram sockets bound to the
+    // given paths. The socket files are not removed here.
+    void RunLocalDgramTest(const char* localServerSockName, const char* localClientSockName) {
+        char buf[256];
+        TSockAddrLocalDgram servAddr(localServerSockName);
+        TSockAddrLocalDgram cliAddr(localClientSockName);
+        TSockAddrLocalDgram servFromAddr;
+        TSockAddrLocalDgram cliFromAddr;
+        TLocalDgramSocket cliSock;
+        TLocalDgramSocket servSock;
+        cliSock.CheckSock();
+        servSock.CheckSock();
+
+        TBaseSocket::Check(cliSock.Bind(&cliAddr), "bind client");
+        TBaseSocket::Check(servSock.Bind(&servAddr), "bind server");
+
+        // client
+        // sizeof() includes the terminating zero, so strcmp() on buf is safe below
+        const char reqStr[] = "Hello, world!!!";
+        TBaseSocket::Check(cliSock.SendTo(reqStr, sizeof(reqStr), &servAddr), "send from client");
+
+        // server
+        TBaseSocket::Check(servSock.RecvFrom(buf, 256, &servFromAddr), "receive from client");
+        UNIT_ASSERT(strcmp(reqStr, buf) == 0);
+        const char repStr[] = "The World's greatings to you";
+        // reply to the sender address captured by RecvFrom
+        TBaseSocket::Check(servSock.SendTo(repStr, sizeof(repStr), &servFromAddr), "send to client");
+
+        // client
+        TBaseSocket::Check(cliSock.RecvFrom(buf, 256, &cliFromAddr), "receive from server");
+        UNIT_ASSERT(strcmp(repStr, buf) == 0);
+    }
+
+    // Runs the local datagram round trip on two paths in the current directory
+    // and removes the socket files afterwards (not reached if the run throws).
+    Y_UNIT_TEST(LocalDgramTest) {
+        const char* localServerSockName = "./serv_sock";
+        const char* localClientSockName = "./cli_sock";
+        RunLocalDgramTest(localServerSockName, localClientSockName);
+        NFs::Remove(localServerSockName);
+        NFs::Remove(localClientSockName);
+    }
+
+    // Connect/accept plus a request/reply exchange over stream sockets.
+    // A is the address type, S the socket type, ip the textual loopback address.
+    template <class A, class S>
+    void RunInetStreamTest(const char* ip) {
+        char buf[256];
+        // port 0: Bind() below reads the actually assigned port back into servAddr
+        A servAddr(ip, 0);
+        A newAddr;
+        S cliSock;
+        S servSock;
+        S newSock;
+        cliSock.CheckSock();
+        servSock.CheckSock();
+        newSock.CheckSock();
+
+        // server
+        int yes = 1;
+        CheckedSetSockOpt(servSock, SOL_SOCKET, SO_REUSEADDR, yes, "servSock, SO_REUSEADDR");
+        TBaseSocket::Check(servSock.Bind(&servAddr), "bind");
+        TBaseSocket::Check(servSock.Listen(10), "listen");
+
+        // client
+        TBaseSocket::Check(cliSock.Connect(&servAddr), "connect");
+
+        // server
+        // Accept() swaps the accepted descriptor into newSock
+        TBaseSocket::Check(servSock.Accept(&newSock, &newAddr), "accept");
+
+        // client
+        // sizeof() includes the terminating zero, so strcmp() on buf is safe below
+        const char reqStr[] = "Hello, world!!!";
+        TBaseSocket::Check(cliSock.Send(reqStr, sizeof(reqStr)), "send");
+
+        // server - new
+        TBaseSocket::Check(newSock.Recv(buf, 256), "recv");
+        UNIT_ASSERT(strcmp(reqStr, buf) == 0);
+        const char repStr[] = "The World's greatings to you";
+        TBaseSocket::Check(newSock.Send(repStr, sizeof(repStr)), "send");
+
+        // client
+        TBaseSocket::Check(cliSock.Recv(buf, 256), "recv");
+        UNIT_ASSERT(strcmp(repStr, buf) == 0);
+    }
+
+    // Stream round trip over IPv4 loopback.
+    Y_UNIT_TEST(InetStreamTest) {
+        RunInetStreamTest<TSockAddrInetStream, TInetStreamSocket>("127.0.0.1");
+    }
+
+    // Stream round trip over IPv6 loopback.
+    Y_UNIT_TEST(Inet6StreamTest) {
+        RunInetStreamTest<TSockAddrInet6Stream, TInet6StreamSocket>("::1");
+    }
+
+    // Connect/accept plus a request/reply exchange over local stream sockets
+    // bound to the given path. The socket file is not removed here.
+    void RunLocalStreamTest(const char* localServerSockName) {
+        char buf[256];
+        TSockAddrLocalStream servAddr(localServerSockName);
+        TSockAddrLocalStream newAddr;
+        TLocalStreamSocket cliSock;
+        TLocalStreamSocket servSock;
+        TLocalStreamSocket newSock;
+        cliSock.CheckSock();
+        servSock.CheckSock();
+        newSock.CheckSock();
+
+        // server
+        TBaseSocket::Check(servSock.Bind(&servAddr), "bind");
+        TBaseSocket::Check(servSock.Listen(10), "listen");
+
+        // the accept is handed to a future, unlike in RunInetStreamTest
+        NThreading::TLegacyFuture<void> f([&]() {
+            // server
+            TBaseSocket::Check(servSock.Accept(&newSock, &newAddr), "accept");
+        });
+
+        // client
+        TBaseSocket::Check(cliSock.Connect(&servAddr), "connect");
+
+        // wait for the accept to finish before using newSock
+        f.Get();
+
+        // client
+        // sizeof() includes the terminating zero, so strcmp() on buf is safe below
+        const char reqStr[] = "Hello, world!!!";
+        TBaseSocket::Check(cliSock.Send(reqStr, sizeof(reqStr)), "send");
+
+        // server - new
+        TBaseSocket::Check(newSock.Recv(buf, 256), "recv");
+        UNIT_ASSERT(strcmp(reqStr, buf) == 0);
+        const char repStr[] = "The World's greatings to you";
+        TBaseSocket::Check(newSock.Send(repStr, sizeof(repStr)), "send");
+
+        // client
+        TBaseSocket::Check(cliSock.Recv(buf, 256), "recv");
+        UNIT_ASSERT(strcmp(repStr, buf) == 0);
+    }
+
+    // Runs the local stream round trip on a path in the current directory and
+    // removes the socket file afterwards (not reached if the run throws).
+    Y_UNIT_TEST(LocalStreamTest) {
+        const char* localServerSockName = "./serv_sock2";
+        RunLocalStreamTest(localServerSockName);
+        NFs::Remove(localServerSockName);
+    }
+
+}
diff --git a/util/network/socket.cpp b/util/network/socket.cpp
new file mode 100644
index 0000000000..c1a42e849e
--- /dev/null
+++ b/util/network/socket.cpp
@@ -0,0 +1,1288 @@
+#include "ip.h"
+#include "socket.h"
+#include "address.h"
+#include "pollerimpl.h"
+#include "iovec.h"
+
+#include <util/system/defaults.h>
+#include <util/system/byteorder.h>
+
+#if defined(_unix_)
+ #include <netdb.h>
+ #include <sys/types.h>
+ #include <sys/socket.h>
+ #include <sys/un.h>
+ #include <sys/ioctl.h>
+ #include <netinet/in.h>
+ #include <netinet/tcp.h>
+ #include <arpa/inet.h>
+#endif
+
+#if defined(_freebsd_)
+ #include <sys/module.h>
+ #define ACCEPT_FILTER_MOD
+ #include <sys/socketvar.h>
+#endif
+
+#if defined(_win_)
+ #include <cerrno>
+ #include <winsock2.h>
+ #include <ws2tcpip.h>
+ #include <wspiapi.h>
+
+ #include <util/system/compat.h>
+#endif
+
+#include <util/generic/ylimits.h>
+
+#include <util/string/cast.h>
+#include <util/stream/mem.h>
+#include <util/system/datetime.h>
+#include <util/system/error.h>
+#include <util/memory/tempbuf.h>
+#include <util/generic/singleton.h>
+#include <util/generic/hash_set.h>
+
+#include <stddef.h>
+#include <sys/uio.h>
+
+using namespace NAddr;
+
+#if defined(_win_)
+
+// Windows replacement for inet_aton(): parses the IPv4 address text in cp via
+// WSAStringToAddress. Returns 1 and fills *inp on success, 0 otherwise.
+int inet_aton(const char* cp, struct in_addr* inp) {
+    sockaddr_in parsed;
+    parsed.sin_family = AF_INET;
+    int parsedLen = sizeof(parsed);
+
+    if (WSAStringToAddress((char*)cp, AF_INET, nullptr, (LPSOCKADDR)&parsed, &parsedLen) != 0) {
+        return 0;
+    }
+
+    memcpy(inp, &parsed.sin_addr, sizeof(in_addr));
+    return 1;
+}
+
+ #if (_WIN32_WINNT < 0x0600)
+// Replacement for inet_ntop() on Windows versions that lack it. Only AF_INET
+// is supported: any other family fails with EINVAL, and a destination buffer
+// that is too small fails with ENOSPC. Returns dst on success.
+const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) {
+    if (af != AF_INET) {
+        errno = EINVAL;
+        return nullptr;
+    }
+
+    const ui8* octets = (ui8*)src;
+    const int needed = snprintf(dst, size, "%u.%u.%u.%u", octets[0], octets[1], octets[2], octets[3]);
+
+    if (needed >= (int)size) {
+        errno = ENOSPC;
+        return nullptr;
+    }
+
+    return dst;
+}
+
+// One row of an event translation table used by convert_events(): `event` is
+// the source flag and `winevent` the flags it maps to. Despite the field
+// names, the same struct is used for both translation directions.
+struct evpair {
+    int event;
+    int winevent;
+};
+
+// poll() event flags -> WSAEventSelect() network events.
+// In convert_events(), -1 marks a flag that cannot be requested (the whole
+// conversion fails) and 0 marks a flag that is accepted but maps to nothing.
+static const evpair evpairs_to_win[] = {
+    {POLLIN, FD_READ | FD_CLOSE | FD_ACCEPT},
+    {POLLRDNORM, FD_READ | FD_CLOSE | FD_ACCEPT},
+    {POLLRDBAND, -1},
+    {POLLPRI, -1},
+    {POLLOUT, FD_WRITE | FD_CLOSE},
+    {POLLWRNORM, FD_WRITE | FD_CLOSE},
+    {POLLWRBAND, -1},
+    {POLLERR, 0},
+    {POLLHUP, 0},
+    {POLLNVAL, 0}};
+
+// number of rows in evpairs_to_win
+static const size_t nevpairs_to_win = sizeof(evpairs_to_win) / sizeof(evpairs_to_win[0]);
+
+// WSAEnumNetworkEvents() network events -> poll() revents flags.
+static const evpair evpairs_to_unix[] = {
+    {FD_ACCEPT, POLLIN | POLLRDNORM},
+    {FD_READ, POLLIN | POLLRDNORM},
+    {FD_WRITE, POLLOUT | POLLWRNORM},
+    {FD_CLOSE, POLLHUP},
+};
+
+// number of rows in evpairs_to_unix
+static const size_t nevpairs_to_unix = sizeof(evpairs_to_unix) / sizeof(evpairs_to_unix[0]);
+
+// Translates the flag set `events` through the table evpairs[0..nevpairs).
+// Returns the OR of the mapped values, or -1 when a matched row maps to -1
+// or when flags without a table row remain and ignoreUnknown is false.
+static int convert_events(int events, const evpair* evpairs, size_t nevpairs, bool ignoreUnknown) noexcept {
+    int converted = 0;
+
+    for (const evpair* row = evpairs; row != evpairs + nevpairs; ++row) {
+        if (!(events & row->event)) {
+            continue;
+        }
+
+        // mark the flag as handled
+        events ^= row->event;
+
+        const long mapped = row->winevent;
+
+        if (mapped == -1) {
+            return -1;
+        }
+
+        if (mapped != 0) {
+            converted |= mapped;
+        }
+    }
+
+    const bool hasUnknown = (events != 0);
+
+    return (hasUnknown && !ignoreUnknown) ? -1 : converted;
+}
+
+// Owns a WSA event handle and closes it with WSACloseEvent() on destruction.
+// Copying is disabled: two holders of the same handle would close it twice.
+class TWSAEventHolder {
+private:
+    HANDLE Event;
+
+public:
+    inline TWSAEventHolder(HANDLE event) noexcept
+        : Event(event)
+    {
+    }
+
+    TWSAEventHolder(const TWSAEventHolder&) = delete;
+    TWSAEventHolder& operator=(const TWSAEventHolder&) = delete;
+
+    inline ~TWSAEventHolder() {
+        WSACloseEvent(Event);
+    }
+
+    // The handle stays owned by the holder.
+    inline HANDLE Get() noexcept {
+        return Event;
+    }
+};
+
+// poll() emulation on top of WSAEventSelect()/select() for Windows versions
+// that lack a native one. Returns the number of entries with non-zero
+// revents, 0 on timeout, or -1 with errno set (EINVAL for unsupported event
+// flags, EIO for any other failure).
+int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept {
+    HANDLE rawEvent = WSACreateEvent();
+    if (rawEvent == WSA_INVALID_EVENT) {
+        errno = EIO;
+        return -1;
+    }
+
+    TWSAEventHolder event(rawEvent);
+
+    int checked_sockets = 0;
+
+    // Pass 1: subscribe every socket to the shared event and, with a
+    // zero-timeout select(), detect sockets that are ready right now.
+    for (pollfd* fd = fds; fd < fds + nfds; ++fd) {
+        int win_events = convert_events(fd->events, evpairs_to_win, nevpairs_to_win, false);
+        if (win_events == -1) {
+            errno = EINVAL;
+            return -1;
+        }
+        fd->revents = 0;
+        if (WSAEventSelect(fd->fd, event.Get(), win_events)) {
+            int error = WSAGetLastError();
+            if (error == WSAEINVAL || error == WSAENOTSOCK) {
+                // not a usable socket: reported per entry instead of failing the whole call
+                fd->revents = POLLNVAL;
+                ++checked_sockets;
+            } else {
+                errno = EIO;
+                return -1;
+            }
+        }
+        fd_set readfds;
+        fd_set writefds;
+        // NB: shadows the `timeout` parameter within this loop body
+        struct timeval timeout = {0, 0};
+        FD_ZERO(&readfds);
+        FD_ZERO(&writefds);
+        if (fd->events & POLLIN) {
+            FD_SET(fd->fd, &readfds);
+        }
+        if (fd->events & POLLOUT) {
+            FD_SET(fd->fd, &writefds);
+        }
+        int error = select(0, &readfds, &writefds, nullptr, &timeout);
+        if (error > 0) {
+            if (FD_ISSET(fd->fd, &readfds)) {
+                fd->revents |= POLLIN;
+            }
+            if (FD_ISSET(fd->fd, &writefds)) {
+                fd->revents |= POLLOUT;
+            }
+            ++checked_sockets;
+        }
+    }
+
+    if (checked_sockets > 0) {
+        // returns without wait since we already have sockets in desired conditions
+        return checked_sockets;
+    }
+
+    // Pass 2: nothing is ready yet, so block on the shared event.
+    HANDLE events[] = {event.Get()};
+    DWORD wait_result = WSAWaitForMultipleEvents(1, events, TRUE, timeout, FALSE);
+    if (wait_result == WSA_WAIT_TIMEOUT)
+        return 0;
+    else if (wait_result == WSA_WAIT_EVENT_0) {
+        // Collect what happened on each socket and translate it to revents.
+        for (pollfd* fd = fds; fd < fds + nfds; ++fd) {
+            if (fd->revents == POLLNVAL)
+                continue;
+            WSANETWORKEVENTS network_events;
+            if (WSAEnumNetworkEvents(fd->fd, event.Get(), &network_events)) {
+                errno = EIO;
+                return -1;
+            }
+            fd->revents = 0;
+            // any signalled event that carries an error code turns the entry into POLLERR
+            for (int i = 0; i < FD_MAX_EVENTS; ++i) {
+                if ((network_events.lNetworkEvents & (1 << i)) != 0 && network_events.iErrorCode[i]) {
+                    fd->revents = POLLERR;
+                    break;
+                }
+            }
+            if (fd->revents == POLLERR)
+                continue;
+            if (network_events.lNetworkEvents) {
+                fd->revents = static_cast<short>(convert_events(network_events.lNetworkEvents, evpairs_to_unix, nevpairs_to_unix, true));
+                if (fd->revents & POLLHUP) {
+                    // on hang-up keep only the hang-up and read flags (drops the write flags)
+                    fd->revents &= POLLHUP | POLLIN | POLLRDNORM;
+                }
+            }
+        }
+        // count the entries that ended up with something to report
+        int chanded_sockets = 0;
+        for (pollfd* fd = fds; fd < fds + nfds; ++fd)
+            if (fd->revents != 0)
+                ++chanded_sockets;
+        return chanded_sockets;
+    } else {
+        errno = EIO;
+        return -1;
+    }
+}
+ #endif
+
+#endif
+
+// Writes the host part of the peer address of Socket into str as a
+// zero-terminated string (at most size bytes including the terminator).
+// Returns false if size is 0, the peer address cannot be obtained, or
+// formatting throws.
+bool GetRemoteAddr(SOCKET Socket, char* str, socklen_t size) {
+    if (size == 0) {
+        return false;
+    }
+
+    TOpaqueAddr peer;
+
+    if (getpeername(Socket, peer.MutableAddr(), peer.LenPtr()) != 0) {
+        return false;
+    }
+
+    try {
+        // one byte is reserved for the terminating zero
+        TMemoryOutput out(str, size - 1);
+
+        PrintHost(out, peer);
+        *out.Buf() = 0;
+    } catch (...) {
+        // any failure while formatting is reported as "no address"
+        return false;
+    }
+
+    return true;
+}
+
+// Sets the send and receive timeouts of s to `timeout` seconds.
+void SetSocketTimeout(SOCKET s, long timeout) {
+    SetSocketTimeout(s, timeout, 0);
+}
+
+// Sets both SO_RCVTIMEO and SO_SNDTIMEO of s to sec seconds plus msec
+// milliseconds. Compiles to a no-op where SO_SNDTIMEO is not defined.
+void SetSocketTimeout(SOCKET s, long sec, long msec) {
+#ifdef SO_SNDTIMEO
+    #ifdef _darwin_
+    // unix flavours take a timeval (seconds, microseconds); msec is not
+    // normalized, so msec * 1000 goes into the microseconds field as is
+    const timeval timeout = {sec, (__darwin_suseconds_t)msec * 1000};
+    #elif defined(_unix_)
+    const timeval timeout = {sec, msec * 1000};
+    #else
+    // other platforms take a single value in milliseconds
+    const int timeout = sec * 1000 + msec;
+    #endif
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_RCVTIMEO, timeout, "recv timeout");
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_SNDTIMEO, timeout, "send timeout");
+#endif
+}
+
+// Sets SO_LINGER on s: `on` goes into l_onoff and `len` (narrowed to u_short)
+// into l_linger. Compiles to a no-op where SO_LINGER is not defined.
+void SetLinger(SOCKET s, bool on, unsigned len) {
+#ifdef SO_LINGER
+    struct linger l = {on, (u_short)len};
+
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_LINGER, l, "linger");
+#endif
+}
+
+// Enables lingering on s with a zero linger time.
+void SetZeroLinger(SOCKET s) {
+    SetLinger(s, 1, 0);
+}
+
+// Switches SO_KEEPALIVE on s on or off.
+void SetKeepAlive(SOCKET s, bool value) {
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_KEEPALIVE, (int)value, "keepalive");
+}
+
+// Sets SO_SNDBUF of s to `value`.
+void SetOutputBuffer(SOCKET s, unsigned value) {
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_SNDBUF, value, "output buffer");
+}
+
+// Sets SO_RCVBUF of s to `value`.
+void SetInputBuffer(SOCKET s, unsigned value) {
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_RCVBUF, value, "input buffer");
+}
+
+// On linux, supply the SO_REUSEPORT option number when the system headers do
+// not define it.
+#if defined(_linux_) && !defined(SO_REUSEPORT)
+    #define SO_REUSEPORT 15
+#endif
+
+// Switches SO_REUSEPORT on s on or off. Where the option is not defined at
+// all, throws TSystemError(ENOSYS).
+void SetReusePort(SOCKET s, bool value) {
+#if defined(SO_REUSEPORT)
+    CheckedSetSockOpt(s, SOL_SOCKET, SO_REUSEPORT, (int)value, "reuse port");
+#else
+    Y_UNUSED(s);
+    Y_UNUSED(value);
+    ythrow TSystemError(ENOSYS) << "SO_REUSEPORT is not defined";
+#endif
+}
+
+// Switches TCP_NODELAY on s on or off.
+void SetNoDelay(SOCKET s, bool value) {
+    CheckedSetSockOpt(s, IPPROTO_TCP, TCP_NODELAY, (int)value, "tcp no delay");
+}
+
+// Sets or clears the FD_CLOEXEC descriptor flag of s, leaving the other
+// descriptor flags untouched. Throws TSystemError if fcntl() fails.
+// Does nothing on non-unix platforms.
+void SetCloseOnExec(SOCKET s, bool value) {
+#if defined(_unix_)
+    const int current = fcntl(s, F_GETFD);
+
+    if (current == -1) {
+        ythrow TSystemError() << "fcntl() failed";
+    }
+
+    const int updated = value ? (current | FD_CLOEXEC) : (current & ~FD_CLOEXEC);
+
+    if (fcntl(s, F_SETFD, updated) == -1) {
+        ythrow TSystemError() << "fcntl() failed";
+    }
+#else
+    Y_UNUSED(s);
+    Y_UNUSED(value);
+#endif
+}
+
+// Returns the TCP_MAXSEG value of s. Falls back to 8192 when the option is
+// not available at compile time or cannot be read.
+size_t GetMaximumSegmentSize(SOCKET s) {
+#if defined(TCP_MAXSEG)
+    int val;
+
+    if (GetSockOpt(s, IPPROTO_TCP, TCP_MAXSEG, val) == 0) {
+        return (size_t)val;
+    }
+#endif
+
+    /*
+     * probably a good guess...
+     */
+    return 8192;
+}
+
+// Not implemented: always returns 8192, whatever the socket is.
+size_t GetMaximumTransferUnit(SOCKET /*s*/) {
+    // for someone who'll dare to write it
+    // Linux: there is rumored to be an IP_MTU getsockopt() request
+    // FreeBSD: request to a socket of type PF_ROUTE
+    // with peer address as a destination argument
+    return 8192;
+}
+
+// Reads the ToS / traffic class of s, taking the address family from the
+// local address of the socket. Throws TSystemError if getsockname() fails.
+int GetSocketToS(SOCKET s) {
+    TOpaqueAddr local;
+    const int rc = getsockname(s, local.MutableAddr(), local.LenPtr());
+
+    if (rc < 0) {
+        ythrow TSystemError() << "getsockname() failed";
+    }
+
+    return GetSocketToS(s, &local);
+}
+
+// Reads IP_TOS (AF_INET) or, where available, IPV6_TCLASS (AF_INET6) of s;
+// the family is taken from addr. Returns 0 for any other family and for
+// AF_INET6 when IPV6_TCLASS is not defined.
+int GetSocketToS(SOCKET s, const IRemoteAddr* addr) {
+    int tos = 0;
+    const int family = addr->Addr()->sa_family;
+
+    if (family == AF_INET) {
+        CheckedGetSockOpt(s, IPPROTO_IP, IP_TOS, tos, "tos");
+    } else if (family == AF_INET6) {
+#ifdef IPV6_TCLASS
+        CheckedGetSockOpt(s, IPPROTO_IPV6, IPV6_TCLASS, tos, "tos");
+#endif
+    }
+
+    return tos;
+}
+
+// Sets IP_TOS (AF_INET) or, where available, IPV6_TCLASS (AF_INET6) of s to
+// `tos`; the family is taken from addr. Throws yexception for any family
+// that cannot be handled.
+void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos) {
+    const int family = addr->Addr()->sa_family;
+
+    if (family == AF_INET) {
+        CheckedSetSockOpt(s, IPPROTO_IP, IP_TOS, tos, "tos");
+        return;
+    }
+
+#ifdef IPV6_TCLASS
+    if (family == AF_INET6) {
+        CheckedSetSockOpt(s, IPPROTO_IPV6, IPV6_TCLASS, tos, "tos");
+        return;
+    }
+#endif
+
+    ythrow yexception() << "SetSocketToS unsupported for family " << addr->Addr()->sa_family;
+}
+
+// Sets the ToS / traffic class of s, taking the address family from the
+// local address of the socket. Throws TSystemError if getsockname() fails.
+void SetSocketToS(SOCKET s, int tos) {
+    TOpaqueAddr local;
+    const int rc = getsockname(s, local.MutableAddr(), local.LenPtr());
+
+    if (rc < 0) {
+        ythrow TSystemError() << "getsockname() failed";
+    }
+
+    SetSocketToS(s, &local, tos);
+}
+
+// True when the local address of the socket is a loopback address, or when
+// its local and peer addresses are the same. The peer address is only
+// queried in the non-loopback case. Throws TSystemError if an address
+// cannot be obtained.
+bool HasLocalAddress(SOCKET socket) {
+    TOpaqueAddr local;
+    const int localRc = getsockname(socket, local.MutableAddr(), local.LenPtr());
+
+    if (localRc != 0) {
+        ythrow TSystemError() << "HasLocalAddress: getsockname() failed. ";
+    }
+
+    if (IsLoopback(local)) {
+        return true;
+    }
+
+    TOpaqueAddr remote;
+    const int remoteRc = getpeername(socket, remote.MutableAddr(), remote.LenPtr());
+
+    if (remoteRc != 0) {
+        ythrow TSystemError() << "HasLocalAddress: getpeername() failed. ";
+    }
+
+    return IsSame(local, remote);
+}
+
+namespace {
+// On linux, supply the TCP_FASTOPEN option number when the system headers do
+// not define it.
+#if defined(_linux_)
+    #if !defined(TCP_FASTOPEN)
+        #define TCP_FASTOPEN 23
+    #endif
+#endif
+
+#if defined(TCP_FASTOPEN)
+    // Run-time probe for TCP_FASTOPEN support: construction tries to set the
+    // option on a temporary socket and remembers whether that worked.
+    struct TTcpFastOpenFeature {
+        inline TTcpFastOpenFeature()
+            : HasFastOpen_(false)
+        {
+            TSocketHolder probe(socket(AF_INET, SOCK_STREAM, 0));
+            int enable = 1;
+
+            HasFastOpen_ = (SetSockOpt(probe, IPPROTO_TCP, TCP_FASTOPEN, enable) == 0);
+        }
+
+        // Sets TCP_FASTOPEN on s to qlen; silently does nothing when the
+        // probe found the option unsupported.
+        inline void SetFastOpen(SOCKET s, int qlen) const {
+            if (!HasFastOpen_) {
+                return;
+            }
+
+            CheckedSetSockOpt(s, IPPROTO_TCP, TCP_FASTOPEN, qlen, "setting TCP_FASTOPEN");
+        }
+
+        // Shared instance, so that the probe is not repeated per call.
+        static inline const TTcpFastOpenFeature* Instance() noexcept {
+            return Singleton<TTcpFastOpenFeature>();
+        }
+
+        bool HasFastOpen_;
+    };
+#endif
+}
+
+// Sets TCP_FASTOPEN on s to qlen when the option is defined at compile time
+// and the run-time probe found it usable; otherwise does nothing.
+void SetTcpFastOpen(SOCKET s, int qlen) {
+#if defined(TCP_FASTOPEN)
+    TTcpFastOpenFeature::Instance()->SetFastOpen(s, qlen);
+#else
+    Y_UNUSED(s);
+    Y_UNUSED(qlen);
+#endif
+}
+
+// True when lasterr is one of the "operation would block" codes.
+static bool IsBlocked(int lasterr) noexcept {
+    if (lasterr == EAGAIN) {
+        return true;
+    }
+
+    return lasterr == EWOULDBLOCK;
+}
+
+// Scope guard: puts the socket into non-blocking mode on construction and
+// into blocking mode on destruction. The previous mode is not remembered,
+// so the socket always ends up blocking.
+struct TUnblockingGuard {
+    SOCKET S_;
+
+    TUnblockingGuard(SOCKET s)
+        : S_(s)
+    {
+        SetNonBlock(S_, true);
+    }
+
+    ~TUnblockingGuard() {
+        SetNonBlock(S_, false);
+    }
+};
+
+// Peeks at one byte of s without consuming it and without blocking.
+// Returns the raw recv() result.
+static int MsgPeek(SOCKET s) {
+#if defined(_win_)
+    // no per-call "don't wait" flag here: the socket itself is switched to
+    // non-blocking mode for the duration of the call
+    const int flags = MSG_PEEK;
+    TUnblockingGuard unblocker(s);
+    Y_UNUSED(unblocker);
+#else
+    const int flags = MSG_PEEK | MSG_DONTWAIT;
+#endif
+
+    char probe;
+    return recv(s, &probe, 1, flags);
+}
+
+// True unless HasSocketDataToRead() classifies s as SocketClosed.
+bool IsNotSocketClosedByOtherSide(SOCKET s) {
+    return HasSocketDataToRead(s) != ESocketReadStatus::SocketClosed;
+}
+
+// Classifies s by a non-blocking one-byte peek: HasData when a byte is
+// available, NoData when the peek would block, SocketClosed in every other
+// case (a zero-length read or any other error).
+ESocketReadStatus HasSocketDataToRead(SOCKET s) {
+    const int peeked = MsgPeek(s);
+
+    if (peeked > 0) {
+        return ESocketReadStatus::HasData;
+    }
+
+    if (peeked == -1 && IsBlocked(LastSystemError())) {
+        return ESocketReadStatus::NoData;
+    }
+
+    return ESocketReadStatus::SocketClosed;
+}
+
+#if defined(_win_)
+// Windows variant: gathers and sends the iovcnt buffers of iov via writev().
+static ssize_t DoSendMsg(SOCKET sock, const struct iovec* iov, int iovcnt) {
+    return writev(sock, iov, iovcnt);
+}
+#else
+// Non-Windows variant: sends the iovcnt buffers of iov with a single
+// sendmsg() call, passing MSG_NOSIGNAL.
+static ssize_t DoSendMsg(SOCKET sock, const struct iovec* iov, int iovcnt) {
+    struct msghdr hdr;
+
+    Zero(hdr);
+    // msg_iov is not const-qualified in msghdr, hence the cast
+    hdr.msg_iovlen = iovcnt;
+    hdr.msg_iov = const_cast<struct iovec*>(iov);
+
+    return sendmsg(sock, &hdr, MSG_NOSIGNAL);
+}
+#endif
+
+// Closes the held descriptor, if any, and resets the holder to
+// INVALID_SOCKET. A failed close is tolerated unless the descriptor itself
+// turns out to be bad, in which case the process is aborted via Y_VERIFY.
+void TSocketHolder::Close() noexcept {
+    if (Fd_ != INVALID_SOCKET) {
+        bool ok = (closesocket(Fd_) == 0);
+        if (!ok) {
+// Do not quietly close bad descriptor,
+// because often it means double close
+// that is disastrous
+#ifdef _win_
+            Y_VERIFY(WSAGetLastError() != WSAENOTSOCK, "must not quietly close bad socket descriptor");
+#elif defined(_unix_)
+            Y_VERIFY(errno != EBADF, "must not quietly close bad descriptor: fd=%d", int(Fd_));
+#else
+    #error unsupported platform
+#endif
+        }
+
+        Fd_ = INVALID_SOCKET;
+    }
+}
+
+// Reference-counted state behind TSocket: the descriptor (owned through
+// TSocketHolder) and the operations object that the I/O calls are forwarded to.
+class TSocket::TImpl: public TAtomicRefCount<TImpl> {
+    using TOps = TSocket::TOps;
+
+public:
+    // Takes over fd; ops is stored as a plain pointer and is not owned here.
+    inline TImpl(SOCKET fd, TOps* ops)
+        : Fd_(fd)
+        , Ops_(ops)
+    {
+    }
+
+    inline ~TImpl() = default;
+
+    inline SOCKET Fd() const noexcept {
+        return Fd_;
+    }
+
+    // The three I/O methods below forward to the TOps object with the held descriptor.
+    inline ssize_t Send(const void* data, size_t len) {
+        return Ops_->Send(Fd_, data, len);
+    }
+
+    inline ssize_t Recv(void* buf, size_t len) {
+        return Ops_->Recv(Fd_, buf, len);
+    }
+
+    inline ssize_t SendV(const TPart* parts, size_t count) {
+        return Ops_->SendV(Fd_, parts, count);
+    }
+
+    inline void Close() {
+        Fd_.Close();
+    }
+
+private:
+    TSocketHolder Fd_;
+    TOps* Ops_;
+};
+
+// Prints an addrinfo list as "[addr1, addr2, ...]", preceded by the canonical
+// name in quotes when the first entry has the AI_CANONNAME flag set.
+// A null list prints as "[]".
+template <>
+void Out<const struct addrinfo*>(IOutputStream& os, const struct addrinfo* ai) {
+    // the loop below tolerates a null list, so the flag check must as well
+    if (ai && (ai->ai_flags & AI_CANONNAME)) {
+        os << "`" << ai->ai_canonname << "' ";
+    }
+
+    os << '[';
+    for (int i = 0; ai; ++i, ai = ai->ai_next) {
+        if (i > 0) {
+            os << ", ";
+        }
+
+        os << (const IRemoteAddr&)TAddrInfo(ai);
+    }
+    os << ']';
+}
+
+// Non-const overload: forwards to the const addrinfo* printer.
+template <>
+void Out<struct addrinfo*>(IOutputStream& os, struct addrinfo* ai) {
+ Out<const struct addrinfo*>(os, static_cast<const struct addrinfo*>(ai));
+}
+
+// Prints a TNetworkAddress by streaming a pointer to its first addrinfo node.
+template <>
+void Out<TNetworkAddress>(IOutputStream& os, const TNetworkAddress& addr) {
+ os << &*addr.Begin();
+}
+
+// Advances to the next candidate address. When addr is the last node of the chain,
+// throws TSystemError(sockerr) whose message names the whole chain starting at addr0.
+static inline const struct addrinfo* Iterate(const struct addrinfo* addr, const struct addrinfo* addr0, const int sockerr) {
+    const struct addrinfo* const next = addr->ai_next;
+
+    if (!next) {
+        ythrow TSystemError(sockerr) << "can not connect to " << addr0;
+    }
+
+    return next;
+}
+
+// Tries every address of the chain in order and returns the first descriptor that connects.
+// Each socket is switched to non-blocking mode before connect(); the returned descriptor is
+// still non-blocking. Throws TSystemError when PollD fails or when the last address fails.
+static inline SOCKET DoConnectImpl(const struct addrinfo* res, const TInstant& deadLine) {
+ const struct addrinfo* addr0 = res;
+
+ while (res) {
+ TSocketHolder s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
+
+ if (s.Closed()) {
+ // socket() failed: move on to the next address (Iterate throws on the last one)
+ res = Iterate(res, addr0, LastSystemError());
+
+ continue;
+ }
+
+ SetNonBlock(s, true);
+
+ if (connect(s, res->ai_addr, (int)res->ai_addrlen)) {
+ int err = LastSystemError();
+
+ if (err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK) {
+ /*
+ * connection is in progress: wait until the socket becomes writable
+ */
+ struct pollfd p = {
+ (SOCKET)s,
+ POLLOUT,
+ 0};
+
+ const ssize_t n = PollD(&p, 1, deadLine);
+
+ /*
+ * PollD failed: n is -(error code), -ETIMEDOUT when the deadline was reached
+ */
+ if (n < 0) {
+ ythrow TSystemError(-(int)n) << "can not connect";
+ }
+
+ // fetch the outcome of the asynchronous connect into err
+ CheckedGetSockOpt(s, SOL_SOCKET, SO_ERROR, err, "socket error");
+
+ if (!err) {
+ return s.Release();
+ }
+ }
+
+ // this address failed with err: try the next one (Iterate throws on the last one)
+ res = Iterate(res, addr0, err);
+
+ continue;
+ }
+
+ // connect() succeeded immediately
+ return s.Release();
+ }
+
+ // reached only when res was null on entry
+ ythrow yexception() << "something went wrong: nullptr at addrinfo";
+}
+
+// Connects via DoConnectImpl and switches the resulting socket back to blocking mode.
+// The descriptor is held in a TSocketHolder while SetNonBlock runs, so it is closed if that throws.
+static inline SOCKET DoConnect(const struct addrinfo* res, const TInstant& deadLine) {
+ TSocketHolder ret(DoConnectImpl(res, deadLine));
+
+ SetNonBlock(ret, false);
+
+ return ret.Release();
+}
+
+// Sends the iovec array with DoSendMsg, retrying for as long as the call fails with EINTR.
+// Returns the number of bytes sent, or -(system error code) on failure.
+static inline ssize_t DoSendV(SOCKET fd, const struct iovec* iov, size_t count) {
+    for (;;) {
+        const ssize_t sent = DoSendMsg(fd, iov, (int)count);
+
+        if (sent == -1 && errno == EINTR) {
+            // interrupted by a signal: try again
+            continue;
+        }
+
+        if (sent < 0) {
+            return -LastSystemError();
+        }
+
+        return sent;
+    }
+}
+
+// Primary template: passes the TPart array to DoSendV reinterpreted as an iovec array.
+// TCommonSockOps::SendVImpl selects it only when TPart and iovec have the same size
+// and the same field offsets.
+template <bool isCompat>
+struct TSender {
+ using TPart = TSocket::TPart;
+
+ static inline ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) {
+ return DoSendV(fd, (const iovec*)parts, count);
+ }
+};
+
+// Specialization for a TPart layout that differs from iovec: every part is converted
+// into an iovec stored in a temporary buffer before the data is sent.
+template <>
+struct TSender<false> {
+    using TPart = TSocket::TPart;
+
+    static inline ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) {
+        TTempBuf scratch(sizeof(struct iovec) * count);
+        struct iovec* const converted = (struct iovec*)scratch.Data();
+
+        for (size_t idx = 0; idx < count; ++idx) {
+            converted[idx].iov_base = (char*)parts[idx].buf;
+            converted[idx].iov_len = parts[idx].len;
+        }
+
+        return DoSendV(fd, converted, count);
+    }
+};
+
+// Default TSocket::TOps implementation on top of send(), recv() and DoSendV().
+// Every method returns the number of bytes transferred, or -(system error code) on failure.
+class TCommonSockOps: public TSocket::TOps {
+    using TPart = TSocket::TPart;
+
+    // send() and recv() are called with an int length below. Casting a size_t larger than
+    // INT_MAX straight to int yields a truncated or negative value, so the length is clamped
+    // first. A short transfer is a normal outcome: TSocketOutput::DoWrite loops over Send().
+    static inline int ClampToInt(size_t len) noexcept {
+        const size_t maxInt = (size_t)(((unsigned)-1) >> 1); // INT_MAX where int and unsigned have the same width
+
+        return (int)(len < maxInt ? len : maxInt);
+    }
+
+public:
+    inline TCommonSockOps() noexcept {
+    }
+
+    ~TCommonSockOps() override = default;
+
+    // Sends at most min(len, INT_MAX) bytes; retried while send() fails with EINTR.
+    ssize_t Send(SOCKET fd, const void* data, size_t len) override {
+        ssize_t ret = -1;
+        do {
+            ret = send(fd, (const char*)data, ClampToInt(len), MSG_NOSIGNAL);
+        } while (ret == -1 && errno == EINTR);
+
+        if (ret < 0) {
+            return -LastSystemError();
+        }
+
+        return ret;
+    }
+
+    // Receives at most min(len, INT_MAX) bytes; retried while recv() fails with EINTR.
+    ssize_t Recv(SOCKET fd, void* buf, size_t len) override {
+        ssize_t ret = -1;
+        do {
+            ret = recv(fd, (char*)buf, ClampToInt(len), 0);
+        } while (ret == -1 && errno == EINTR);
+
+        if (ret < 0) {
+            return -LastSystemError();
+        }
+
+        return ret;
+    }
+
+    // Sends all parts. If the first call transfers only a prefix, the rest is pushed
+    // out by SendVPartial.
+    ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) override {
+        ssize_t ret = SendVImpl(fd, parts, count);
+
+        if (ret < 0) {
+            return ret;
+        }
+
+        size_t len = TContIOVector::Bytes(parts, count);
+
+        if ((size_t)ret == len) {
+            return ret;
+        }
+
+        return SendVPartial(fd, parts, count, ret);
+    }
+
+    // One send attempt. The TSender variant is chosen at compile time: the reinterpreting
+    // one when TPart and iovec have the same size and field offsets, the converting one otherwise.
+    inline ssize_t SendVImpl(SOCKET fd, const TPart* parts, size_t count) {
+        constexpr bool sameLayout = (sizeof(iovec) == sizeof(TPart)) &&
+                                    (offsetof(iovec, iov_base) == offsetof(TPart, buf)) &&
+                                    (offsetof(iovec, iov_len) == offsetof(TPart, len));
+
+        return TSender<sameLayout>::SendV(fd, parts, count);
+    }
+
+    ssize_t SendVPartial(SOCKET fd, const TPart* constParts, size_t count, size_t written);
+};
+
+// Finishes a vectored send of which `written` bytes have already gone out.
+// Works on a private copy of the parts so that the caller's array stays untouched.
+// Returns the total number of bytes written, or the negative result of a failed send.
+ssize_t TCommonSockOps::SendVPartial(SOCKET fd, const TPart* constParts, size_t count, size_t written) {
+    TTempBuf scratch(sizeof(TPart) * count);
+    TPart* const mutableParts = (TPart*)scratch.Data();
+
+    for (size_t idx = 0; idx < count; ++idx) {
+        mutableParts[idx] = constParts[idx];
+    }
+
+    TContIOVector pending(mutableParts, count);
+
+    // skip what the first attempt already delivered, then keep sending until nothing is left
+    for (pending.Proceed(written); !pending.Complete();) {
+        const ssize_t sent = SendVImpl(fd, pending.Parts(), pending.Count());
+
+        if (sent < 0) {
+            return sent;
+        }
+
+        written += sent;
+        pending.Proceed((size_t)sent);
+    }
+
+    return written;
+}
+
+// Returns the TCommonSockOps instance obtained through Singleton<>.
+static inline TSocket::TOps* GetCommonSockOps() noexcept {
+ return Singleton<TCommonSockOps>();
+}
+
+// Creates a TSocket that holds INVALID_SOCKET and uses the common ops.
+TSocket::TSocket()
+ : Impl_(new TImpl(INVALID_SOCKET, GetCommonSockOps()))
+{
+}
+
+// Wraps an existing descriptor; the TImpl stores it in a TSocketHolder.
+TSocket::TSocket(SOCKET fd)
+ : Impl_(new TImpl(fd, GetCommonSockOps()))
+{
+}
+
+// Wraps an existing descriptor and routes I/O through the supplied ops object.
+TSocket::TSocket(SOCKET fd, TOps* ops)
+ : Impl_(new TImpl(fd, ops))
+{
+}
+
+// Connects to addr with no deadline (TInstant::Max()).
+TSocket::TSocket(const TNetworkAddress& addr)
+ : Impl_(new TImpl(DoConnect(addr.Info(), TInstant::Max()), GetCommonSockOps()))
+{
+}
+
+// Connects to addr; the timeout is converted into a deadline with ToDeadLine().
+TSocket::TSocket(const TNetworkAddress& addr, const TDuration& timeOut)
+ : Impl_(new TImpl(DoConnect(addr.Info(), timeOut.ToDeadLine()), GetCommonSockOps()))
+{
+}
+
+// Connects to addr, giving up at deadLine.
+TSocket::TSocket(const TNetworkAddress& addr, const TInstant& deadLine)
+ : Impl_(new TImpl(DoConnect(addr.Info(), deadLine), GetCommonSockOps()))
+{
+}
+
+TSocket::~TSocket() = default;
+
+// The members below forward to the shared TImpl.
+SOCKET TSocket::Fd() const noexcept {
+ return Impl_->Fd();
+}
+
+ssize_t TSocket::Send(const void* data, size_t len) {
+ return Impl_->Send(data, len);
+}
+
+ssize_t TSocket::Recv(void* buf, size_t len) {
+ return Impl_->Recv(buf, len);
+}
+
+ssize_t TSocket::SendV(const TPart* parts, size_t count) {
+ return Impl_->SendV(parts, count);
+}
+
+void TSocket::Close() {
+ Impl_->Close();
+}
+
+// Keeps its own TSocket copy of s.
+TSocketInput::TSocketInput(const TSocket& s) noexcept
+ : S_(s)
+{
+}
+
+TSocketInput::~TSocketInput() = default;
+
+// Reads with a single Recv() call and returns the number of bytes received.
+// A negative Recv() result is -(error code) and is rethrown as TSystemError.
+size_t TSocketInput::DoRead(void* buf, size_t len) {
+ const ssize_t ret = S_.Recv(buf, len);
+
+ if (ret >= 0) {
+ return (size_t)ret;
+ }
+
+ ythrow TSystemError(-(int)ret) << "can not read from socket input stream";
+}
+
+// Keeps its own TSocket copy of s.
+TSocketOutput::TSocketOutput(const TSocket& s) noexcept
+ : S_(s)
+{
+}
+
+// Finishes the stream on destruction; any exception thrown by Finish() is swallowed.
+TSocketOutput::~TSocketOutput() {
+ try {
+ Finish();
+ } catch (...) {
+ // deliberately ignored: nothing is allowed to escape from the destructor
+ }
+}
+
+// Writes the whole buffer, calling Send() repeatedly until no bytes remain.
+// A negative Send() result is -(error code); it is rethrown as TSystemError together
+// with the number of bytes that had been written before the failure.
+void TSocketOutput::DoWrite(const void* buf, size_t len) {
+    const char* cursor = (const char*)buf;
+    size_t remaining = len;
+    size_t written = 0;
+
+    while (remaining) {
+        const ssize_t chunk = S_.Send(cursor, remaining);
+
+        if (chunk < 0) {
+            ythrow TSystemError(-(int)chunk) << "can not write to socket output stream; " << written << " bytes already send";
+        }
+
+        cursor += chunk;
+        remaining -= chunk;
+        written += chunk;
+    }
+}
+
+// Writes all parts with a single SendV() call; a negative result is -(error code)
+// and is rethrown as TSystemError.
+void TSocketOutput::DoWriteV(const TPart* parts, size_t count) {
+ const ssize_t ret = S_.SendV(parts, count);
+
+ if (ret < 0) {
+ ythrow TSystemError(-(int)ret) << "can not writev to socket output stream";
+ }
+
+ /*
+ * todo for nonblocking sockets?
+ */
+}
+
+namespace {
+    //https://bugzilla.mozilla.org/attachment.cgi?id=503263&action=diff
+
+    // Set of host names that are treated as local, plus a check for 127.x.x.x literals.
+    struct TLocalNames: public THashSet<TStringBuf> {
+        inline TLocalNames() {
+            static const char* const wellKnown[] = {
+                "localhost",
+                "localhost.localdomain",
+                "localhost6",
+                "localhost6.localdomain6",
+                "::1",
+            };
+
+            for (const char* const entry : wellKnown) {
+                insert(entry);
+            }
+        }
+
+        // True when name parses as an IPv4 literal whose first octet is 127,
+        // or, for anything that is not an IPv4 literal, when it is in the set.
+        inline bool IsLocalName(const char* name) const noexcept {
+            struct sockaddr_in parsed;
+            memset(&parsed, 0, sizeof(parsed));
+
+            const bool isIPv4Literal = inet_pton(AF_INET, name, &(parsed.sin_addr)) == 1;
+
+            if (!isIPv4Literal) {
+                return contains(name);
+            }
+
+            return (InetToHost(parsed.sin_addr.s_addr) >> 24) == 127;
+        }
+    };
+}
+
+// Reference-counted owner of the addrinfo chain behind TNetworkAddress. The chain comes
+// either from getaddrinfo() or, for unix domain sockets, from a single hand-built node.
+class TNetworkAddress::TImpl: public TAtomicRefCount<TImpl> {
+private:
+    // Deleter that knows how the chain was allocated: freeaddrinfo() for chains produced
+    // by getaddrinfo(), plain free() for the node built by the unix-socket constructor.
+    class TAddrInfoDeleter {
+    public:
+        TAddrInfoDeleter(bool useFreeAddrInfo = true)
+            : UseFreeAddrInfo_(useFreeAddrInfo)
+        {
+        }
+
+        void operator()(struct addrinfo* ai) noexcept {
+            if (!UseFreeAddrInfo_ && ai != NULL) {
+                // only the first node's ai_addr is released here
+                if (ai->ai_addr != NULL) {
+                    free(ai->ai_addr);
+                }
+
+                struct addrinfo* p;
+                while (ai != NULL) {
+                    p = ai;
+                    ai = ai->ai_next;
+                    free(p->ai_canonname);
+                    free(p);
+                }
+            } else if (ai != NULL) {
+                freeaddrinfo(ai);
+            }
+        }
+
+    private:
+        bool UseFreeAddrInfo_ = true;
+    };
+
+public:
+    // Resolves host:port with getaddrinfo() for SOCK_STREAM sockets of any family.
+    // A null host adds AI_PASSIVE; a host that TLocalNames does not consider local adds AI_ADDRCONFIG.
+    // Throws TNetworkResolutionError when resolution fails.
+    inline TImpl(const char* host, ui16 port, int flags)
+        : Info_(nullptr, TAddrInfoDeleter{})
+    {
+        const TString port_st(ToString(port));
+        struct addrinfo hints;
+
+        memset(&hints, 0, sizeof(hints));
+
+        hints.ai_flags = flags;
+        hints.ai_family = PF_UNSPEC;
+        hints.ai_socktype = SOCK_STREAM;
+
+        if (!host) {
+            hints.ai_flags |= AI_PASSIVE;
+        } else {
+            if (!Singleton<TLocalNames>()->IsLocalName(host)) {
+                hints.ai_flags |= AI_ADDRCONFIG;
+            }
+        }
+
+        struct addrinfo* pai = NULL;
+        const int error = getaddrinfo(host, port_st.data(), &hints, &pai);
+
+        if (error) {
+            TAddrInfoDeleter()(pai);
+            ythrow TNetworkResolutionError(error) << ": can not resolve " << host << ":" << port;
+        }
+
+        Info_.reset(pai);
+    }
+
+    // Builds a one-node AF_UNIX / SOCK_STREAM chain for the given filesystem path.
+    // Throws when the path does not fit into sun_path or when an allocation fails.
+    inline TImpl(const char* path, int flags)
+        : Info_(nullptr, TAddrInfoDeleter{/* useFreeAddrInfo = */ false})
+    {
+        // calloc() hands back zero-filled memory, so no byte of the sockaddr_un
+        // (the whole struct is exposed through ai_addrlen) is left uninitialized.
+        THolder<struct sockaddr_un, TFree> sockAddr(
+            reinterpret_cast<struct sockaddr_un*>(calloc(1, sizeof(struct sockaddr_un))));
+        Y_ENSURE(sockAddr.Get() != nullptr, "can not allocate sockaddr_un");
+
+        Y_ENSURE(strlen(path) < sizeof(sockAddr->sun_path), "Unix socket path more than " << sizeof(sockAddr->sun_path));
+        sockAddr->sun_family = AF_UNIX;
+        strcpy(sockAddr->sun_path, path);
+
+        TAddrInfoPtr hints(reinterpret_cast<struct addrinfo*>(calloc(1, sizeof(struct addrinfo))), TAddrInfoDeleter{/* useFreeAddrInfo = */ false});
+        Y_ENSURE(hints.get() != nullptr, "can not allocate addrinfo");
+
+        hints->ai_flags = flags;
+        hints->ai_family = AF_UNIX;
+        hints->ai_socktype = SOCK_STREAM;
+        hints->ai_addrlen = sizeof(*sockAddr);
+        // from here on the address is owned by the addrinfo node and freed by its deleter
+        hints->ai_addr = (struct sockaddr*)sockAddr.Release();
+
+        Info_.reset(hints.release());
+    }
+
+    inline struct addrinfo* Info() const noexcept {
+        return Info_.get();
+    }
+
+private:
+    using TAddrInfoPtr = std::unique_ptr<struct addrinfo, TAddrInfoDeleter>;
+
+    TAddrInfoPtr Info_;
+};
+
+// Unix domain socket address built from a filesystem path.
+TNetworkAddress::TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags)
+ : Impl_(new TImpl(unixSocketPath.Path.data(), flags))
+{
+}
+
+// Resolves host:port; flags become the initial ai_flags of the getaddrinfo() hints.
+TNetworkAddress::TNetworkAddress(const TString& host, ui16 port, int flags)
+ : Impl_(new TImpl(host.data(), port, flags))
+{
+}
+
+// Resolves host:port with no extra flags.
+TNetworkAddress::TNetworkAddress(const TString& host, ui16 port)
+ : Impl_(new TImpl(host.data(), port, 0))
+{
+}
+
+// No host: TImpl receives a null host and resolves with AI_PASSIVE.
+TNetworkAddress::TNetworkAddress(ui16 port)
+ : Impl_(new TImpl(nullptr, port, 0))
+{
+}
+
+TNetworkAddress::~TNetworkAddress() = default;
+
+// Head of the addrinfo chain held by the implementation.
+struct addrinfo* TNetworkAddress::Info() const noexcept {
+ return Impl_->Info();
+}
+
+TNetworkResolutionError::TNetworkResolutionError(int error) {
+ const char* errMsg = nullptr;
+#ifdef _win_
+ errMsg = LastSystemErrorText(error); // gai_strerror is not thread-safe on Windows
+#else
+ errMsg = gai_strerror(error);
+#endif
+ (*this) << errMsg << "(" << error;
+
+#if defined(_unix_)
+ if (error == EAI_SYSTEM) {
+ (*this) << "; errno=" << LastSystemError();
+ }
+#endif
+
+ (*this) << "): ";
+}
+
+#if defined(_unix_)
+// Reads the file status flags of fd (F_GETFL); throws TSystemError on failure.
+static inline int GetFlags(int fd) {
+    const int flags = fcntl(fd, F_GETFL);
+
+    if (flags != -1) {
+        return flags;
+    }
+
+    ythrow TSystemError() << "can not get fd flags";
+}
+
+// Replaces the file status flags of fd (F_SETFL); throws TSystemError on failure.
+static inline void SetFlags(int fd, int flags) {
+    const int rc = fcntl(fd, F_SETFL, flags);
+
+    if (rc == -1) {
+        ythrow TSystemError() << "can not set fd flags";
+    }
+}
+
+// Turns flag on; F_SETFL is issued only when the flag word actually changes.
+static inline void EnableFlag(int fd, int flag) {
+    const int current = GetFlags(fd);
+    const int wanted = current | flag;
+
+    if (wanted == current) {
+        return;
+    }
+
+    SetFlags(fd, wanted);
+}
+
+// Turns flag off; F_SETFL is issued only when the flag word actually changes.
+static inline void DisableFlag(int fd, int flag) {
+    const int current = GetFlags(fd);
+    const int wanted = current & (~flag);
+
+    if (wanted == current) {
+        return;
+    }
+
+    SetFlags(fd, wanted);
+}
+
+// Enables or disables flag depending on value.
+static inline void SetFlag(int fd, int flag, bool value) {
+    if (!value) {
+        DisableFlag(fd, flag);
+        return;
+    }
+
+    EnableFlag(fd, flag);
+}
+
+// True when every bit of `flags` is set in the file status flags of fd.
+// Comparing against `flags` (rather than testing for a non-zero intersection) keeps the
+// answer correct for multi-bit masks; for a single-bit mask the result is unchanged.
+static inline bool FlagsAreEnabled(int fd, int flags) {
+    return (GetFlags(fd) & flags) == flags;
+}
+#endif
+
+#if defined(_win_)
+// Windows: sets the FIONBIO state of fd to value through WSAIoctl; throws TSystemError on failure.
+static inline void SetNonBlockSocket(SOCKET fd, int value) {
+ unsigned long inbuf = value;
+ unsigned long outbuf = 0;
+ DWORD written = 0;
+
+ // when switching back to blocking mode, WSAEventSelect is first called with a null event and an empty mask
+ if (!inbuf) {
+ WSAEventSelect(fd, nullptr, 0);
+ }
+
+ if (WSAIoctl(fd, FIONBIO, &inbuf, sizeof(inbuf), &outbuf, sizeof(outbuf), &written, 0, 0) == SOCKET_ERROR) {
+ ythrow TSystemError() << "can not set non block socket state";
+ }
+}
+
+// Windows: issues WSAIoctl(FIONBIO) with no input buffer and returns the output value as bool.
+// NOTE(review): FIONBIO is normally a "set" request; confirm that it reports the current mode this way.
+static inline bool IsNonBlockSocket(SOCKET fd) {
+ unsigned long buf = 0;
+
+ if (WSAIoctl(fd, FIONBIO, 0, 0, &buf, sizeof(buf), 0, 0, 0) == SOCKET_ERROR) {
+ ythrow TSystemError() << "can not get non block socket state";
+ }
+
+ return buf;
+}
+#endif
+
+// Switches fd to non-blocking (value == true) or blocking (value == false) mode.
+// unix: ioctl(FIONBIO) when FIONBIO is defined, otherwise the O_NONBLOCK flag via fcntl.
+// Windows: SetNonBlockSocket. Throws TSystemError on failure.
+void SetNonBlock(SOCKET fd, bool value) {
+#if defined(_unix_)
+ #if defined(FIONBIO)
+ Y_UNUSED(SetFlag); // shut up clang about unused function
+ int nb = value;
+
+ if (ioctl(fd, FIONBIO, &nb) < 0) {
+ ythrow TSystemError() << "ioctl failed";
+ }
+ #else
+ SetFlag(fd, O_NONBLOCK, value);
+ #endif
+#elif defined(_win_)
+ SetNonBlockSocket(fd, value);
+#else
+ #error todo
+#endif
+}
+
+// Reports whether fd is in non-blocking mode: O_NONBLOCK in the fd flags on unix,
+// IsNonBlockSocket on Windows.
+bool IsNonBlock(SOCKET fd) {
+#if defined(_unix_)
+ return FlagsAreEnabled(fd, O_NONBLOCK);
+#elif defined(_win_)
+ return IsNonBlockSocket(fd);
+#else
+ #error todo
+#endif
+}
+
+// Enables deferred accept on s using whichever mechanism the platform headers define;
+// does nothing when neither macro is available.
+void SetDeferAccept(SOCKET s) {
+ (void)s;
+
+#if defined(TCP_DEFER_ACCEPT)
+ // checked: throws TSystemError when setsockopt() fails
+ CheckedSetSockOpt(s, IPPROTO_TCP, TCP_DEFER_ACCEPT, 10, "defer accept");
+#endif
+
+#if defined(SO_ACCEPTFILTER)
+ struct accept_filter_arg afa;
+
+ Zero(afa);
+ strcpy(afa.af_name, "dataready");
+ // unchecked: the result of this SetSockOpt call is ignored
+ SetSockOpt(s, SOL_SOCKET, SO_ACCEPTFILTER, afa);
+#endif
+}
+
+// poll() with an absolute deadline. Polls in steps whose length is computed by PollStep
+// until at least one descriptor is ready, poll() fails, or deadLine is reached.
+// Returns the positive poll() result, -(error code) for a poll() failure other than
+// ETIMEDOUT/EINTR, or -ETIMEDOUT once the deadline has passed.
+ssize_t PollD(struct pollfd fds[], nfds_t nfds, const TInstant& deadLine) noexcept {
+ TInstant now = TInstant::Now();
+
+ do {
+ const TDuration toWait = PollStep(deadLine, now);
+ const int res = poll(fds, nfds, MicroToMilli(toWait.MicroSeconds()));
+
+ if (res > 0) {
+ return res;
+ }
+
+ if (res < 0) {
+ const int err = LastSystemError();
+
+ // ETIMEDOUT and EINTR only lead to another iteration
+ if (err != ETIMEDOUT && err != EINTR) {
+ return -err;
+ }
+ }
+ } while ((now = TInstant::Now()) < deadLine);
+
+ return -ETIMEDOUT;
+}
+
+// shutdown() wrapper that throws TSystemError when the call fails.
+void ShutDown(SOCKET s, int mode) {
+ if (shutdown(s, mode)) {
+ ythrow TSystemError() << "shutdown socket error";
+ }
+}
+
+// Runtime probe for SO_REUSEPORT. Returns false when the macro is not defined at build time.
+// Otherwise the answer is the Flag_ of a TCtx obtained through Singleton<TCtx>(); the
+// TCtx constructor performs the probe.
+extern "C" bool IsReusePortAvailable() {
+// SO_REUSEPORT is always defined for linux builds, see SetReusePort() implementation above
+#if defined(SO_REUSEPORT)
+
+ class TCtx {
+ public:
+ // Creates a throwaway AF_INET stream socket and queries SO_REUSEPORT on it:
+ // success -> available, ENOPROTOOPT -> not available, any other error -> TSystemError.
+ TCtx() {
+ TSocketHolder sock(::socket(AF_INET, SOCK_STREAM, 0));
+ // errno is captured right after each call, before anything else can overwrite it
+ const int e1 = errno;
+ if (sock == INVALID_SOCKET) {
+ ythrow TSystemError(e1) << "Cannot create AF_INET socket";
+ }
+ int val;
+ const int ret = GetSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, val);
+ const int e2 = errno;
+ if (ret == 0) {
+ Flag_ = true;
+ } else {
+ if (e2 == ENOPROTOOPT) {
+ Flag_ = false;
+ } else {
+ ythrow TSystemError(e2) << "Unexpected error in getsockopt";
+ }
+ }
+ }
+
+ static inline const TCtx* Instance() noexcept {
+ return Singleton<TCtx>();
+ }
+
+ public:
+ bool Flag_;
+ };
+
+ return TCtx::Instance()->Flag_;
+#else
+ return false;
+#endif
+}
diff --git a/util/network/socket.h b/util/network/socket.h
new file mode 100644
index 0000000000..357ad4079b
--- /dev/null
+++ b/util/network/socket.h
@@ -0,0 +1,431 @@
+#pragma once
+
+#include "init.h"
+
+#include <util/system/yassert.h>
+#include <util/system/defaults.h>
+#include <util/system/error.h>
+#include <util/stream/output.h>
+#include <util/stream/input.h>
+#include <util/generic/ptr.h>
+#include <util/generic/yexception.h>
+#include <util/generic/noncopyable.h>
+#include <util/datetime/base.h>
+
+#include <cerrno>
+
+#ifndef INET_ADDRSTRLEN
+ #define INET_ADDRSTRLEN 16
+#endif
+
+#if defined(_unix_)
+ #define get_host_error() h_errno
+#elif defined(_win_)
+ #pragma comment(lib, "Ws2_32.lib")
+
+ #if _WIN32_WINNT < 0x0600
+// Fallback definition of pollfd, compiled only on Windows when _WIN32_WINNT < 0x0600.
+struct pollfd {
+ SOCKET fd;
+ short events;
+ short revents;
+};
+
+ #define POLLIN (1 << 0)
+ #define POLLRDNORM (1 << 1)
+ #define POLLRDBAND (1 << 2)
+ #define POLLPRI (1 << 3)
+ #define POLLOUT (1 << 4)
+ #define POLLWRNORM (1 << 5)
+ #define POLLWRBAND (1 << 6)
+ #define POLLERR (1 << 7)
+ #define POLLHUP (1 << 8)
+ #define POLLNVAL (1 << 9)
+
+const char* inet_ntop(int af, const void* src, char* dst, socklen_t size);
+int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept;
+ #else
+ #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout)
+ #endif
+
+int inet_aton(const char* cp, struct in_addr* inp);
+
+ #define get_host_error() WSAGetLastError()
+
+ #define SHUT_RD SD_RECEIVE
+ #define SHUT_WR SD_SEND
+ #define SHUT_RDWR SD_BOTH
+
+ #define INFTIM (-1)
+#endif
+
+// setsockopt() wrapper: passes the address and size of opt; returns the setsockopt() result.
+template <class T>
+static inline int SetSockOpt(SOCKET s, int level, int optname, T opt) noexcept {
+ return setsockopt(s, level, optname, (const char*)&opt, sizeof(opt));
+}
+
+// getsockopt() wrapper: reads the option into opt; returns the getsockopt() result.
+template <class T>
+static inline int GetSockOpt(SOCKET s, int level, int optname, T& opt) noexcept {
+ socklen_t len = sizeof(opt);
+
+ return getsockopt(s, level, optname, (char*)&opt, &len);
+}
+
+// Same as SetSockOpt, but throws TSystemError on a non-zero result; err is appended to the message.
+template <class T>
+static inline void CheckedSetSockOpt(SOCKET s, int level, int optname, T opt, const char* err) {
+ if (SetSockOpt<T>(s, level, optname, opt)) {
+ ythrow TSystemError() << "setsockopt() failed for " << err;
+ }
+}
+
+// Same as GetSockOpt, but throws TSystemError on a non-zero result; err is appended to the message.
+template <class T>
+static inline void CheckedGetSockOpt(SOCKET s, int level, int optname, T& opt, const char* err) {
+ if (GetSockOpt<T>(s, level, optname, opt)) {
+ ythrow TSystemError() << "getsockopt() failed for " << err;
+ }
+}
+
+// Sets IPV6_V6ONLY to 1 on s when the option is defined; the setsockopt() result is ignored.
+static inline void FixIPv6ListenSocket(SOCKET s) {
+#if defined(IPV6_V6ONLY)
+ SetSockOpt(s, IPPROTO_IPV6, IPV6_V6ONLY, 1);
+#else
+ (void)s;
+#endif
+}
+
+namespace NAddr {
+ class IRemoteAddr;
+}
+
+void SetSocketTimeout(SOCKET s, long timeout);
+void SetSocketTimeout(SOCKET s, long sec, long msec);
+void SetNoDelay(SOCKET s, bool value);
+void SetKeepAlive(SOCKET s);
+void SetLinger(SOCKET s, bool on, unsigned len);
+void SetZeroLinger(SOCKET s);
+void SetKeepAlive(SOCKET s, bool value);
+void SetCloseOnExec(SOCKET s, bool value);
+void SetOutputBuffer(SOCKET s, unsigned value);
+void SetInputBuffer(SOCKET s, unsigned value);
+void SetReusePort(SOCKET s, bool value);
+void ShutDown(SOCKET s, int mode);
+bool GetRemoteAddr(SOCKET s, char* str, socklen_t size);
+size_t GetMaximumSegmentSize(SOCKET s);
+size_t GetMaximumTransferUnit(SOCKET s);
+void SetDeferAccept(SOCKET s);
+void SetSocketToS(SOCKET s, int tos);
+void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos);
+int GetSocketToS(SOCKET s);
+int GetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr);
+void SetTcpFastOpen(SOCKET s, int qlen);
+/**
+ * Deprecated, consider using HasSocketDataToRead instead.
+ **/
+bool IsNotSocketClosedByOtherSide(SOCKET s);
+// Result type of HasSocketDataToRead().
+enum class ESocketReadStatus {
+ HasData,
+ NoData,
+ SocketClosed
+};
+/**
+ * Useful for keep-alive connections.
+ **/
+ESocketReadStatus HasSocketDataToRead(SOCKET s);
+/**
+ * Determines whether connection on socket is local (same machine) or not.
+ **/
+bool HasLocalAddress(SOCKET socket);
+
+/**
+ * Runtime check if current kernel supports SO_REUSEPORT option.
+ **/
+extern "C" bool IsReusePortAvailable();
+
+bool IsNonBlock(SOCKET fd);
+void SetNonBlock(SOCKET fd, bool nonBlock = true);
+
+struct addrinfo;
+
+// Exception carrying a name-resolution failure.
+class TNetworkResolutionError: public yexception {
+public:
+ // @param error error code (EAI_XXX) returned by getaddrinfo or getnameinfo (not errno)
+ TNetworkResolutionError(int error);
+};
+
+// Wrapper that marks a string as a unix domain socket path; it selects the
+// TNetworkAddress constructor that takes a TUnixSocketPath.
+struct TUnixSocketPath {
+ TString Path;
+
+ // Creates a unix domain socket path from a string holding a filesystem path:
+ // TUnixSocketPath("/tmp/unixsocket") -> "/tmp/unixsocket"
+ explicit TUnixSocketPath(const TString& path)
+ : Path(path)
+ {
+ }
+};
+
+// Network address backed by an addrinfo chain; the chain lives in a TImpl held
+// through an intrusive pointer.
+class TNetworkAddress {
+ friend class TSocket;
+
+public:
+ // Forward iterator over the addrinfo chain; it follows ai_next.
+ class TIterator {
+ public:
+ inline TIterator(struct addrinfo* begin)
+ : C_(begin)
+ {
+ }
+
+ // Moves to the next node; must not be called on the end iterator (C_ is dereferenced).
+ inline void Next() noexcept {
+ C_ = C_->ai_next;
+ }
+
+ inline TIterator operator++(int) noexcept {
+ TIterator old(*this);
+
+ Next();
+
+ return old;
+ }
+
+ inline TIterator& operator++() noexcept {
+ Next();
+
+ return *this;
+ }
+
+ friend inline bool operator==(const TIterator& l, const TIterator& r) noexcept {
+ return l.C_ == r.C_;
+ }
+
+ friend inline bool operator!=(const TIterator& l, const TIterator& r) noexcept {
+ return !(l == r);
+ }
+
+ inline struct addrinfo& operator*() const noexcept {
+ return *C_;
+ }
+
+ inline struct addrinfo* operator->() const noexcept {
+ return C_;
+ }
+
+ private:
+ struct addrinfo* C_;
+ };
+
+ TNetworkAddress(ui16 port);
+ TNetworkAddress(const TString& host, ui16 port);
+ TNetworkAddress(const TString& host, ui16 port, int flags);
+ TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags = 0);
+ ~TNetworkAddress();
+
+ // Iterator at the head of the chain.
+ inline TIterator Begin() const noexcept {
+ return TIterator(Info());
+ }
+
+ // End iterator: wraps a null pointer.
+ inline TIterator End() const noexcept {
+ return TIterator(nullptr);
+ }
+
+private:
+ struct addrinfo* Info() const noexcept;
+
+private:
+ class TImpl;
+ TSimpleIntrusivePtr<TImpl> Impl_;
+};
+
+class TSocket;
+
+// Move-only owner of a socket descriptor: the destructor closes whatever is still held.
+class TSocketHolder: public TMoveOnly {
+public:
+    inline TSocketHolder()
+        : Fd_(INVALID_SOCKET)
+    {
+    }
+
+    // Takes ownership of fd. Intentionally implicit: a raw SOCKET converts to a holder.
+    inline TSocketHolder(SOCKET fd)
+        : Fd_(fd)
+    {
+    }
+
+    // Steals the descriptor; other is left holding INVALID_SOCKET.
+    inline TSocketHolder(TSocketHolder&& other) noexcept
+        : Fd_(other.Fd_)
+    {
+        other.Fd_ = INVALID_SOCKET;
+    }
+
+    // Closes the current descriptor and takes over the one held by other.
+    // Self-assignment is a no-op instead of closing the descriptor that is being kept.
+    inline TSocketHolder& operator=(TSocketHolder&& other) noexcept {
+        if (this != &other) {
+            Close();
+            Swap(other);
+        }
+
+        return *this;
+    }
+
+    inline ~TSocketHolder() {
+        Close();
+    }
+
+    // Gives up ownership: returns the descriptor and leaves INVALID_SOCKET behind.
+    inline SOCKET Release() noexcept {
+        SOCKET ret = Fd_;
+        Fd_ = INVALID_SOCKET;
+        return ret;
+    }
+
+    void Close() noexcept;
+
+    inline void ShutDown(int mode) const {
+        ::ShutDown(Fd_, mode);
+    }
+
+    inline void Swap(TSocketHolder& r) noexcept {
+        DoSwap(Fd_, r.Fd_);
+    }
+
+    // True when no descriptor is held.
+    inline bool Closed() const noexcept {
+        return Fd_ == INVALID_SOCKET;
+    }
+
+    inline operator SOCKET() const noexcept {
+        return Fd_;
+    }
+
+private:
+    SOCKET Fd_;
+
+    // do not allow construction of TSocketHolder from TSocket
+    TSocketHolder(const TSocket& fd) = delete;
+};
+
+// Socket handle whose state lives in a TImpl held through an intrusive pointer.
+// Send/Recv/SendV return the number of bytes transferred, or -(error code) on failure.
+class TSocket {
+public:
+ using TPart = IOutputStream::TPart;
+
+ // I/O strategy interface; every call receives the descriptor to operate on.
+ class TOps {
+ public:
+ inline TOps() noexcept = default;
+ virtual ~TOps() = default;
+
+ virtual ssize_t Send(SOCKET fd, const void* data, size_t len) = 0;
+ virtual ssize_t Recv(SOCKET fd, void* buf, size_t len) = 0;
+ virtual ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) = 0;
+ };
+
+ TSocket();
+ TSocket(SOCKET fd);
+ TSocket(SOCKET fd, TOps* ops);
+ TSocket(const TNetworkAddress& addr);
+ TSocket(const TNetworkAddress& addr, const TDuration& timeOut);
+ TSocket(const TNetworkAddress& addr, const TInstant& deadLine);
+
+ ~TSocket();
+
+ // Checked setsockopt(): throws TSystemError on failure.
+ template <class T>
+ inline void SetSockOpt(int level, int optname, T opt) {
+ CheckedSetSockOpt(Fd(), level, optname, opt, "TSocket");
+ }
+
+ // The setters and getters below forward to the free functions of the same purpose,
+ // passing the descriptor of this socket.
+ inline void SetSocketTimeout(long timeout) {
+ ::SetSocketTimeout(Fd(), timeout);
+ }
+
+ inline void SetSocketTimeout(long sec, long msec) {
+ ::SetSocketTimeout(Fd(), sec, msec);
+ }
+
+ inline void SetNoDelay(bool value) {
+ ::SetNoDelay(Fd(), value);
+ }
+
+ inline void SetLinger(bool on, unsigned len) {
+ ::SetLinger(Fd(), on, len);
+ }
+
+ inline void SetZeroLinger() {
+ ::SetZeroLinger(Fd());
+ }
+
+ inline void SetKeepAlive(bool value) {
+ ::SetKeepAlive(Fd(), value);
+ }
+
+ inline void SetOutputBuffer(unsigned value) {
+ ::SetOutputBuffer(Fd(), value);
+ }
+
+ inline void SetInputBuffer(unsigned value) {
+ ::SetInputBuffer(Fd(), value);
+ }
+
+ inline size_t MaximumSegmentSize() const {
+ return GetMaximumSegmentSize(Fd());
+ }
+
+ inline size_t MaximumTransferUnit() const {
+ return GetMaximumTransferUnit(Fd());
+ }
+
+ inline void ShutDown(int mode) const {
+ ::ShutDown(Fd(), mode);
+ }
+
+ void Close();
+
+ ssize_t Send(const void* data, size_t len);
+ ssize_t Recv(void* buf, size_t len);
+
+ /*
+ * scatter/gather io
+ */
+ ssize_t SendV(const TPart* parts, size_t count);
+
+ inline operator SOCKET() const noexcept {
+ return Fd();
+ }
+
+private:
+ SOCKET Fd() const noexcept;
+
+private:
+ class TImpl;
+ TSimpleIntrusivePtr<TImpl> Impl_;
+};
+
+// IInputStream that reads from a TSocket; it keeps its own TSocket copy.
+class TSocketInput: public IInputStream {
+public:
+ TSocketInput(const TSocket& s) noexcept;
+ ~TSocketInput() override;
+
+ TSocketInput(TSocketInput&&) noexcept = default;
+ TSocketInput& operator=(TSocketInput&&) noexcept = default;
+
+ const TSocket& GetSocket() const noexcept {
+ return S_;
+ }
+
+private:
+ size_t DoRead(void* buf, size_t len) override;
+
+private:
+ TSocket S_;
+};
+
+// IOutputStream that writes to a TSocket; it keeps its own TSocket copy.
+class TSocketOutput: public IOutputStream {
+public:
+ TSocketOutput(const TSocket& s) noexcept;
+ ~TSocketOutput() override;
+
+ TSocketOutput(TSocketOutput&&) noexcept = default;
+ TSocketOutput& operator=(TSocketOutput&&) noexcept = default;
+
+ const TSocket& GetSocket() const noexcept {
+ return S_;
+ }
+
+private:
+ void DoWrite(const void* buf, size_t len) override;
+ void DoWriteV(const TPart* parts, size_t count) override;
+
+private:
+ TSocket S_;
+};
+
+// returns -(error code) if an error occurred, or the number of ready fds
+ssize_t PollD(struct pollfd fds[], nfds_t nfds, const TInstant& deadLine) noexcept;
diff --git a/util/network/socket_ut.cpp b/util/network/socket_ut.cpp
new file mode 100644
index 0000000000..6b20e11f70
--- /dev/null
+++ b/util/network/socket_ut.cpp
@@ -0,0 +1,341 @@
+#include "socket.h"
+
+#include "pair.h"
+
+#include <library/cpp/testing/unittest/registar.h>
+
+#include <util/string/builder.h>
+#include <util/generic/vector.h>
+
+#include <ctime>
+
+#ifdef _linux_
+ #include <linux/version.h>
+ #include <sys/utsname.h>
+#endif
+
+// Unit tests for TSocket, TNetworkAddress and related helpers.
+class TSockTest: public TTestBase {
+ UNIT_TEST_SUITE(TSockTest);
+ UNIT_TEST(TestSock);
+ UNIT_TEST(TestTimeout);
+#ifndef _win_ // Test hangs on Windows
+ // the test passes when a yexception is thrown
+ UNIT_TEST_EXCEPTION(TestConnectionRefused, yexception);
+#endif
+ UNIT_TEST(TestNetworkResolutionError);
+ UNIT_TEST(TestNetworkResolutionErrorMessage);
+ UNIT_TEST(TestBrokenPipe);
+ UNIT_TEST(TestClose);
+ UNIT_TEST(TestReusePortAvailCheck);
+ UNIT_TEST_SUITE_END();
+
+public:
+ void TestSock();
+ void TestTimeout();
+ void TestConnectionRefused();
+ void TestNetworkResolutionError();
+ void TestNetworkResolutionErrorMessage();
+ void TestBrokenPipe();
+ void TestClose();
+ void TestReusePortAvailCheck();
+};
+
+UNIT_TEST_SUITE_REGISTRATION(TSockTest);
+
+// Smoke test: connects to yandex.ru:80, sends a minimal HTTP request through
+// TSocketOutput and expects a non-empty first line from TSocketInput.
+void TSockTest::TestSock() {
+ TNetworkAddress addr("yandex.ru", 80);
+ TSocket s(addr);
+ TSocketOutput so(s);
+ TSocketInput si(s);
+ const TStringBuf req = "GET / HTTP/1.1\r\nHost: yandex.ru\r\n\r\n";
+
+ so.Write(req.data(), req.size());
+
+ UNIT_ASSERT(!si.ReadLine().empty());
+}
+
+// Checks that a connect attempt with a 1000 ms timeout returns (by success or exception)
+// within timeout + 2000 ms.
+void TSockTest::TestTimeout() {
+ static const int timeout = 1000;
+ i64 startTime = millisec();
+ try {
+ TNetworkAddress addr("localhost", 1313);
+ TSocket s(addr, TDuration::MilliSeconds(timeout));
+ } catch (const yexception&) {
+ // a failed connect is fine here: only the elapsed time is checked
+ }
+ int realTimeout = (int)(millisec() - startTime);
+ if (realTimeout > timeout + 2000) {
+ TString err = TStringBuilder() << "Timeout exceeded: " << realTimeout << " ms (expected " << timeout << " ms)";
+ UNIT_FAIL(err);
+ }
+}
+
+// Registered with UNIT_TEST_EXCEPTION: connecting to localhost:1313 is expected to throw yexception.
+void TSockTest::TestConnectionRefused() {
+ TNetworkAddress addr("localhost", 1313);
+ TSocket s(addr);
+}
+
+// Resolving an empty host name: if it fails, the exception text must contain
+// the gai_strerror() message for EAI_NONAME.
+void TSockTest::TestNetworkResolutionError() {
+ TString errMsg;
+ try {
+ TNetworkAddress addr("", 0);
+ } catch (const TNetworkResolutionError& e) {
+ errMsg = e.what();
+ }
+
+ if (errMsg.empty()) {
+ return; // on Windows getaddrinfo("", 0, ...) returns "OK"
+ }
+
+ int expectedErr = EAI_NONAME;
+ TString expectedErrMsg = gai_strerror(expectedErr);
+ if (errMsg.find(expectedErrMsg) == TString::npos) {
+ UNIT_FAIL("TNetworkResolutionError contains\nInvalid msg: " + errMsg + "\nExpected msg: " + expectedErrMsg + "\n");
+ }
+}
+
+// Pins the message format of TNetworkResolutionError on unix:
+// "<gai_strerror>(<code>): ", with "; errno=<n>" inserted for EAI_SYSTEM.
+void TSockTest::TestNetworkResolutionErrorMessage() {
+#ifdef _unix_
+ auto str = [](int code) -> TString {
+ return TNetworkResolutionError(code).what();
+ };
+
+ auto expected = [](int code) -> TString {
+ return gai_strerror(code);
+ };
+
+ // restores errno on scope exit, because the test overwrites it below
+ struct TErrnoGuard {
+ TErrnoGuard()
+ : PrevValue_(errno)
+ {
+ }
+
+ ~TErrnoGuard() {
+ errno = PrevValue_;
+ }
+
+ private:
+ int PrevValue_;
+ } g;
+
+ UNIT_ASSERT_VALUES_EQUAL(expected(0) + "(0): ", str(0));
+ UNIT_ASSERT_VALUES_EQUAL(expected(-9) + "(-9): ", str(-9));
+
+ errno = 0;
+ UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=0): ",
+ str(EAI_SYSTEM));
+ errno = 110;
+ UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=110): ",
+ str(EAI_SYSTEM));
+#endif
+}
+
+// Scope guard: installs the default SIGPIPE disposition (SIG_DFL) on construction
+// and puts the previous handler back on destruction.
+class TTempEnableSigPipe {
+public:
+ TTempEnableSigPipe() {
+ OriginalSigHandler_ = signal(SIGPIPE, SIG_DFL);
+ Y_VERIFY(OriginalSigHandler_ != SIG_ERR);
+ }
+
+ ~TTempEnableSigPipe() {
+ auto ret = signal(SIGPIPE, OriginalSigHandler_);
+ Y_VERIFY(ret != SIG_ERR);
+ }
+
+private:
+ void (*OriginalSigHandler_)(int);
+};
+
+// With SIGPIPE at its default disposition, Send() and SendV() on a pair whose other
+// end has been shut down must report an error through a negative return value.
+void TSockTest::TestBrokenPipe() {
+ TTempEnableSigPipe guard;
+
+ SOCKET socks[2];
+
+ int ret = SocketPair(socks);
+ UNIT_ASSERT_VALUES_EQUAL(ret, 0);
+
+ TSocket sender(socks[0]);
+ TSocket receiver(socks[1]);
+ receiver.ShutDown(SHUT_RDWR);
+ int sent = sender.Send("FOO", 3);
+ UNIT_ASSERT(sent < 0);
+
+ IOutputStream::TPart parts[] = {
+ {"foo", 3},
+ {"bar", 3},
+ };
+ sent = sender.SendV(parts, 2);
+ UNIT_ASSERT(sent < 0);
+}
+
+// TSocket::Close() must close the descriptor and make the socket report INVALID_SOCKET.
+void TSockTest::TestClose() {
+    SOCKET socks[2];
+
+    UNIT_ASSERT_EQUAL(SocketPair(socks), 0);
+    // The test only exercises socks[1]; hold the other end so that it is closed
+    // on every exit path instead of being leaked.
+    TSocketHolder peer(socks[0]);
+    TSocket receiver(socks[1]);
+
+    UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), socks[1]);
+
+#if defined _linux_
+    // the descriptor is valid before Close() and invalid afterwards
+    UNIT_ASSERT_GE(fcntl(socks[1], F_GETFD), 0);
+    receiver.Close();
+    UNIT_ASSERT_EQUAL(fcntl(socks[1], F_GETFD), -1);
+#else
+    receiver.Close();
+#endif
+
+    UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), INVALID_SOCKET);
+}
+
+// On Linux kernels >= 3.9.1 IsReusePortAvailable() must return true; everywhere else
+// the test only checks that calling it twice neither crashes nor throws.
+void TSockTest::TestReusePortAvailCheck() {
+#if defined _linux_
+ utsname sysInfo;
+ Y_VERIFY(!uname(&sysInfo), "Error while call uname: %s", LastSystemErrorText());
+ TStringBuf release(sysInfo.release);
+ // keep only the leading run of digits and dots of the release string
+ release = release.substr(0, release.find_first_not_of(".0123456789"));
+ int v1 = FromString<int>(release.NextTok('.'));
+ int v2 = FromString<int>(release.NextTok('.'));
+ int v3 = FromString<int>(release.NextTok('.'));
+ int linuxVersionCode = KERNEL_VERSION(v1, v2, v3);
+ if (linuxVersionCode >= KERNEL_VERSION(3, 9, 1)) {
+ // new kernels support SO_REUSEPORT
+ UNIT_ASSERT(true == IsReusePortAvailable());
+ UNIT_ASSERT(true == IsReusePortAvailable());
+ } else {
+ // older kernels may or may not support SO_REUSEPORT
+ // just check that it doesn't crash or throw
+ (void)IsReusePortAvailable();
+ (void)IsReusePortAvailable();
+ }
+#else
+ // check that it doesn't crash or throw
+ (void)IsReusePortAvailable();
+ (void)IsReusePortAvailable();
+#endif
+}
+
+// Tests for poll(); the private helpers create plain AF_INET stream sockets.
+class TPollTest: public TTestBase {
+ UNIT_TEST_SUITE(TPollTest);
+ UNIT_TEST(TestPollInOut);
+ UNIT_TEST_SUITE_END();
+
+public:
+ // seeds rand(), which TestPollInOut uses to pick a port
+ inline TPollTest() {
+ srand(static_cast<unsigned int>(time(nullptr)));
+ }
+
+ void TestPollInOut();
+
+private:
+ sockaddr_in GetAddress(ui32 ip, ui16 port);
+ SOCKET CreateSocket();
+ SOCKET StartServerSocket(ui16 port, int backlog);
+ SOCKET StartClientSocket(ui32 ip, ui16 port);
+ SOCKET AcceptConnection(SOCKET serverSocket);
+};
+
+UNIT_TEST_SUITE_REGISTRATION(TPollTest);
+
+// Builds an AF_INET sockaddr_in from a host-order ip and port.
+sockaddr_in TPollTest::GetAddress(ui32 ip, ui16 port) {
+    struct sockaddr_in result;
+
+    memset(&result, 0, sizeof(result));
+    result.sin_addr.s_addr = htonl(ip);
+    result.sin_port = htons(port);
+    result.sin_family = AF_INET;
+
+    return result;
+}
+
+// Creates an AF_INET stream socket; throws yexception with the system error text on failure.
+SOCKET TPollTest::CreateSocket() {
+ SOCKET s = socket(AF_INET, SOCK_STREAM, 0);
+ if (s == INVALID_SOCKET) {
+ ythrow yexception() << "Can not create socket (" << LastSystemErrorText() << ")";
+ }
+ return s;
+}
+
+// Creates a socket bound to INADDR_ANY:port and puts it into listening state.
+// The socket sits in a TSocketHolder until it is returned, so it is closed if bind/listen throw.
+SOCKET TPollTest::StartServerSocket(ui16 port, int backlog) {
+ TSocketHolder s(CreateSocket());
+ sockaddr_in addr = GetAddress(ntohl(INADDR_ANY), port);
+ if (bind(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) {
+ ythrow yexception() << "Can not bind server socket (" << LastSystemErrorText() << ")";
+ }
+ if (listen(s, backlog) == SOCKET_ERROR) {
+ ythrow yexception() << "Can not listen on server socket (" << LastSystemErrorText() << ")";
+ }
+ return s.Release();
+}
+
+// Creates a socket and connects it to ip:port (ip in host byte order).
+// The socket sits in a TSocketHolder until it is returned, so it is closed if connect throws.
+SOCKET TPollTest::StartClientSocket(ui32 ip, ui16 port) {
+ TSocketHolder s(CreateSocket());
+ sockaddr_in addr = GetAddress(ip, port);
+ if (connect(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) {
+ ythrow yexception() << "Can not connect client socket (" << LastSystemErrorText() << ")";
+ }
+ return s.Release();
+}
+
+// Accepts one connection on serverSocket; the peer address is not requested.
+// Throws yexception with the system error text on failure.
+SOCKET TPollTest::AcceptConnection(SOCKET serverSocket) {
+ SOCKET connectedSocket = accept(serverSocket, nullptr, nullptr);
+ if (connectedSocket == INVALID_SOCKET) {
+ ythrow yexception() << "Can not accept connection on server socket (" << LastSystemErrorText() << ")";
+ }
+ return connectedSocket;
+}
+
+// Windows-only check of poll(): 1000 client/server socket pairs are set up in five
+// flavours selected by i % 5, and the revents reported for each accepted socket are
+// compared with the expected flag combination:
+//   0 - client sent a byte;              1 - idle connection;
+//   2 - client sent a byte, then closed; 3 - client closed;
+//   4 - INVALID_SOCKET is polled instead of the accepted socket.
+void TPollTest::TestPollInOut() {
+#ifdef _win_
+    const size_t socketCount = 1000;
+
+    ui16 port = static_cast<ui16>(1300 + rand() % 97);
+    TSocketHolder serverSocket = StartServerSocket(port, socketCount);
+
+    ui32 localIp = ntohl(inet_addr("127.0.0.1"));
+
+    TVector<TSimpleSharedPtr<TSocketHolder>> clientSockets;
+    TVector<TSimpleSharedPtr<TSocketHolder>> connectedSockets;
+    TVector<pollfd> fds;
+
+    for (size_t i = 0; i < socketCount; ++i) {
+        TSimpleSharedPtr<TSocketHolder> clientSocket(new TSocketHolder(StartClientSocket(localIp, port)));
+        clientSockets.push_back(clientSocket);
+
+        if (i % 5 == 0 || i % 5 == 2) {
+            char buffer = 'c';
+            if (send(*clientSocket, &buffer, 1, 0) == -1)
+                ythrow yexception() << "Can not send (" << LastSystemErrorText() << ")";
+        }
+
+        TSimpleSharedPtr<TSocketHolder> connectedSocket(new TSocketHolder(AcceptConnection(serverSocket)));
+        connectedSockets.push_back(connectedSocket);
+
+        if (i % 5 == 2 || i % 5 == 3) {
+            // Shut the connection down while the descriptor is still valid, then close it
+            // through the holder, so that ~TSocketHolder does not close the same
+            // (possibly already reused) handle a second time.
+            shutdown(*clientSocket, SD_BOTH);
+            clientSocket->Close();
+        }
+    }
+
+    int expectedCount = 0;
+    for (size_t i = 0; i < connectedSockets.size(); ++i) {
+        pollfd fd = {(i % 5 == 4) ? INVALID_SOCKET : static_cast<SOCKET>(*connectedSockets[i]), POLLIN | POLLOUT, 0};
+        fds.push_back(fd);
+        if (i % 5 != 4)
+            ++expectedCount;
+    }
+
+    int polledCount = poll(&fds[0], fds.size(), INFTIM);
+    UNIT_ASSERT_EQUAL(expectedCount, polledCount);
+
+    for (size_t i = 0; i < connectedSockets.size(); ++i) {
+        short revents = fds[i].revents;
+        if (i % 5 == 0) {
+            UNIT_ASSERT_EQUAL(static_cast<short>(POLLRDNORM | POLLWRNORM), revents);
+        } else if (i % 5 == 1) {
+            UNIT_ASSERT_EQUAL(static_cast<short>(POLLOUT | POLLWRNORM), revents);
+        } else if (i % 5 == 2) {
+            UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLRDNORM | POLLWRNORM), revents);
+        } else if (i % 5 == 3) {
+            UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLWRNORM), revents);
+        } else if (i % 5 == 4) {
+            UNIT_ASSERT_EQUAL(static_cast<short>(POLLNVAL), revents);
+        }
+    }
+#endif
+}
diff --git a/util/network/ut/ya.make b/util/network/ut/ya.make
new file mode 100644
index 0000000000..1ba03e167c
--- /dev/null
+++ b/util/network/ut/ya.make
@@ -0,0 +1,23 @@
+UNITTEST_FOR(util)
+
+REQUIREMENTS(network:full)
+
+OWNER(g:util)
+SUBSCRIBER(g:util-subscribers)
+
+PEERDIR(
+ library/cpp/threading/future
+)
+
+SRCS(
+ network/address_ut.cpp
+ network/endpoint_ut.cpp
+ network/ip_ut.cpp
+ network/poller_ut.cpp
+ network/sock_ut.cpp
+ network/socket_ut.cpp
+)
+
+INCLUDE(${ARCADIA_ROOT}/util/tests/ya_util_tests.inc)
+
+END()
diff --git a/util/network/ya.make b/util/network/ya.make
new file mode 100644
index 0000000000..79c9498ddd
--- /dev/null
+++ b/util/network/ya.make
@@ -0,0 +1,6 @@
+OWNER(g:util)
+SUBSCRIBER(g:util-subscribers)
+
+RECURSE_FOR_TESTS(
+ ut
+)