diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/network | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/network')
34 files changed, 5332 insertions, 0 deletions
diff --git a/util/network/address.cpp b/util/network/address.cpp new file mode 100644 index 0000000000..a81a9e6994 --- /dev/null +++ b/util/network/address.cpp @@ -0,0 +1,204 @@ +#include <util/stream/str.h> + +#include "address.h" + +#if defined(_unix_) + #include <sys/types.h> + #include <sys/un.h> +#endif + +using namespace NAddr; + +template <bool printPort> +static inline void PrintAddr(IOutputStream& out, const IRemoteAddr& addr) { + const sockaddr* a = addr.Addr(); + char buf[INET6_ADDRSTRLEN + 10]; + + switch (a->sa_family) { + case AF_INET: { + const TIpAddress sa(*(const sockaddr_in*)a); + + out << IpToString(sa.Host(), buf, sizeof(buf)); + + if (printPort) { + out << ":" << sa.Port(); + } + + break; + } + + case AF_INET6: { + const sockaddr_in6* sa = (const sockaddr_in6*)a; + + if (!inet_ntop(AF_INET6, (void*)&sa->sin6_addr.s6_addr, buf, sizeof(buf))) { + ythrow TSystemError() << "inet_ntop() failed"; + } + + if (printPort) { + out << "[" << buf << "]" + << ":" << InetToHost(sa->sin6_port); + } else { + out << buf; + } + + break; + } + +#if defined(AF_UNIX) + case AF_UNIX: { + const sockaddr_un* sa = (const sockaddr_un*)a; + + out << TStringBuf(sa->sun_path); + + break; + } +#endif + + default: { + size_t len = addr.Len(); + + const char* b = (const char*)a; + const char* e = b + len; + + bool allZeros = true; + for (size_t i = 0; i < len; ++i) { + if (b[i] != 0) { + allZeros = false; + break; + } + } + + if (allZeros) { + out << "(raw all zeros)"; + } else { + out << "(raw " << (int)a->sa_family << " "; + + while (b != e) { + //just print raw bytes + out << (int)*b++; + if (b != e) { + out << " "; + } + } + + out << ")"; + } + + break; + } + } +} + +template <> +void Out<IRemoteAddr>(IOutputStream& out, const IRemoteAddr& addr) { + PrintAddr<true>(out, addr); +} + +template <> +void Out<NAddr::TAddrInfo>(IOutputStream& out, const NAddr::TAddrInfo& addr) { + PrintAddr<true>(out, addr); +} + +template <> +void Out<NAddr::TIPv4Addr>(IOutputStream& out, 
const NAddr::TIPv4Addr& addr) { + PrintAddr<true>(out, addr); +} + +template <> +void Out<NAddr::TIPv6Addr>(IOutputStream& out, const NAddr::TIPv6Addr& addr) { + PrintAddr<true>(out, addr); +} + +template <> +void Out<NAddr::TOpaqueAddr>(IOutputStream& out, const NAddr::TOpaqueAddr& addr) { + PrintAddr<true>(out, addr); +} + +void NAddr::PrintHost(IOutputStream& out, const IRemoteAddr& addr) { + PrintAddr<false>(out, addr); +} + +TString NAddr::PrintHost(const IRemoteAddr& addr) { + TStringStream ss; + PrintAddr<false>(ss, addr); + return ss.Str(); +} + +TString NAddr::PrintHostAndPort(const IRemoteAddr& addr) { + TStringStream ss; + PrintAddr<true>(ss, addr); + return ss.Str(); +} + +IRemoteAddrPtr NAddr::GetSockAddr(SOCKET s) { + auto addr = MakeHolder<TOpaqueAddr>(); + + if (getsockname(s, addr->MutableAddr(), addr->LenPtr()) < 0) { + ythrow TSystemError() << "getsockname() failed"; + } + + return addr; +} + +IRemoteAddrPtr NAddr::GetPeerAddr(SOCKET s) { + auto addr = MakeHolder<TOpaqueAddr>(); + + if (getpeername(s, addr->MutableAddr(), addr->LenPtr()) < 0) { + ythrow TSystemError() << "getpeername() failed"; + } + + return addr; +} + +static const in_addr& InAddr(const IRemoteAddr& addr) { + return ((const sockaddr_in*)addr.Addr())->sin_addr; +} + +static const in6_addr& In6Addr(const IRemoteAddr& addr) { + return ((const sockaddr_in6*)addr.Addr())->sin6_addr; +} + +bool NAddr::IsLoopback(const IRemoteAddr& addr) { + if (addr.Addr()->sa_family == AF_INET) { + return ((ntohl(InAddr(addr).s_addr) >> 24) & 0xff) == 127; + } + + if (addr.Addr()->sa_family == AF_INET6) { + return 0 == memcmp(&In6Addr(addr), &in6addr_loopback, sizeof(in6_addr)); + } + + return false; +} + +bool NAddr::IsSame(const IRemoteAddr& lhs, const IRemoteAddr& rhs) { + if (lhs.Addr()->sa_family != rhs.Addr()->sa_family) { + return false; + } + + if (lhs.Addr()->sa_family == AF_INET) { + return InAddr(lhs).s_addr == InAddr(rhs).s_addr; + } + + if (lhs.Addr()->sa_family == AF_INET6) { + return 
0 == memcmp(&In6Addr(lhs), &In6Addr(rhs), sizeof(in6_addr)); + } + + ythrow yexception() << "unsupported addr family: " << lhs.Addr()->sa_family; +} + +socklen_t NAddr::SockAddrLength(const sockaddr* addr) { + switch (addr->sa_family) { + case AF_INET: + return sizeof(sockaddr_in); + + case AF_INET6: + return sizeof(sockaddr_in6); + +#if defined(AF_LOCAL) + case AF_LOCAL: + return sizeof(sockaddr_un); +#endif + } + + ythrow yexception() << "unsupported address family: " << addr->sa_family; +} diff --git a/util/network/address.h b/util/network/address.h new file mode 100644 index 0000000000..448fcac0c9 --- /dev/null +++ b/util/network/address.h @@ -0,0 +1,136 @@ +#pragma once + +#include "ip.h" +#include "socket.h" + +#include <util/generic/ptr.h> +#include <util/generic/string.h> + +namespace NAddr { + class IRemoteAddr { + public: + virtual ~IRemoteAddr() = default; + + virtual const sockaddr* Addr() const = 0; + virtual socklen_t Len() const = 0; + }; + + using IRemoteAddrPtr = THolder<IRemoteAddr>; + using IRemoteAddrRef = TAtomicSharedPtr<NAddr::IRemoteAddr>; + + IRemoteAddrPtr GetSockAddr(SOCKET s); + IRemoteAddrPtr GetPeerAddr(SOCKET s); + void PrintHost(IOutputStream& out, const IRemoteAddr& addr); + + TString PrintHost(const IRemoteAddr& addr); + TString PrintHostAndPort(const IRemoteAddr& addr); + + bool IsLoopback(const IRemoteAddr& addr); + bool IsSame(const IRemoteAddr& lhs, const IRemoteAddr& rhs); + + socklen_t SockAddrLength(const sockaddr* addr); + + //for accept, recvfrom - see LenPtr() + class TOpaqueAddr: public IRemoteAddr { + public: + inline TOpaqueAddr() noexcept + : L_(sizeof(S_)) + { + Zero(S_); + } + + inline TOpaqueAddr(const IRemoteAddr* addr) noexcept { + Assign(addr->Addr(), addr->Len()); + } + + inline TOpaqueAddr(const sockaddr* addr) { + Assign(addr, SockAddrLength(addr)); + } + + const sockaddr* Addr() const override { + return MutableAddr(); + } + + socklen_t Len() const override { + return L_; + } + + inline sockaddr* 
MutableAddr() const noexcept { + return (sockaddr*)&S_; + } + + inline socklen_t* LenPtr() noexcept { + return &L_; + } + + private: + inline void Assign(const sockaddr* addr, socklen_t len) noexcept { + L_ = len; + memcpy(MutableAddr(), addr, L_); + } + + private: + sockaddr_storage S_; + socklen_t L_; + }; + + //for TNetworkAddress + class TAddrInfo: public IRemoteAddr { + public: + inline TAddrInfo(const addrinfo* ai) noexcept + : AI_(ai) + { + } + + const sockaddr* Addr() const override { + return AI_->ai_addr; + } + + socklen_t Len() const override { + return (socklen_t)AI_->ai_addrlen; + } + + private: + const addrinfo* const AI_; + }; + + //compat, for TIpAddress + class TIPv4Addr: public IRemoteAddr { + public: + inline TIPv4Addr(const TIpAddress& addr) noexcept + : A_(addr) + { + } + + const sockaddr* Addr() const override { + return A_; + } + + socklen_t Len() const override { + return A_; + } + + private: + const TIpAddress A_; + }; + + //same, for ipv6 addresses + class TIPv6Addr: public IRemoteAddr { + public: + inline TIPv6Addr(const sockaddr_in6& a) noexcept + : A_(a) + { + } + + const sockaddr* Addr() const override { + return (sockaddr*)&A_; + } + + socklen_t Len() const override { + return sizeof(A_); + } + + private: + const sockaddr_in6 A_; + }; +} diff --git a/util/network/address_ut.cpp b/util/network/address_ut.cpp new file mode 100644 index 0000000000..28f45172ff --- /dev/null +++ b/util/network/address_ut.cpp @@ -0,0 +1,39 @@ +#include <library/cpp/testing/unittest/registar.h> + +#include "address.h" + +using namespace NAddr; + +Y_UNIT_TEST_SUITE(IRemoteAddr_ToString) { + Y_UNIT_TEST(Raw) { + THolder<TOpaqueAddr> opaque(new TOpaqueAddr); + IRemoteAddr* addr = opaque.Get(); + + TString s = ToString(*addr); + UNIT_ASSERT_VALUES_EQUAL("(raw all zeros)", s); + + opaque->MutableAddr()->sa_data[10] = 17; + + TString t = ToString(*addr); + + UNIT_ASSERT_C(t.StartsWith("(raw 0 0"), t); + UNIT_ASSERT_C(t.EndsWith(')'), t); + } + + Y_UNIT_TEST(Ipv6) 
{ + TNetworkAddress address("::1", 22); + TNetworkAddress::TIterator it = address.Begin(); + UNIT_ASSERT(it != address.End()); + UNIT_ASSERT(it->ai_family == AF_INET6); + TString toString = ToString((const IRemoteAddr&)TAddrInfo(&*it)); + UNIT_ASSERT_VALUES_EQUAL(TString("[::1]:22"), toString); + } + + Y_UNIT_TEST(Loopback) { + TNetworkAddress localAddress("127.70.0.1", 22); + UNIT_ASSERT_VALUES_EQUAL(NAddr::IsLoopback(TAddrInfo(&*localAddress.Begin())), true); + + TNetworkAddress localAddress2("127.0.0.1", 22); + UNIT_ASSERT_VALUES_EQUAL(NAddr::IsLoopback(TAddrInfo(&*localAddress2.Begin())), true); + } +} diff --git a/util/network/endpoint.cpp b/util/network/endpoint.cpp new file mode 100644 index 0000000000..9acdd06940 --- /dev/null +++ b/util/network/endpoint.cpp @@ -0,0 +1,67 @@ +#include "endpoint.h" +#include "sock.h" + +TEndpoint::TEndpoint(const TEndpoint::TAddrRef& addr) + : Addr_(addr) +{ + const sockaddr* sa = Addr_->Addr(); + + if (sa->sa_family != AF_INET && sa->sa_family != AF_INET6 && sa->sa_family != AF_UNIX) { + ythrow yexception() << TStringBuf("endpoint can contain only ipv4, ipv6 or unix address"); + } +} + +TEndpoint::TEndpoint() + : Addr_(new NAddr::TIPv4Addr(TIpAddress(TIpHost(0), TIpPort(0)))) +{ +} + +void TEndpoint::SetPort(ui16 port) { + if (Port() == port || Addr_->Addr()->sa_family == AF_UNIX) { + return; + } + + NAddr::TOpaqueAddr* oa = new NAddr::TOpaqueAddr(Addr_.Get()); + Addr_.Reset(oa); + sockaddr* sa = oa->MutableAddr(); + + if (sa->sa_family == AF_INET) { + ((sockaddr_in*)sa)->sin_port = HostToInet(port); + } else { + ((sockaddr_in6*)sa)->sin6_port = HostToInet(port); + } +} + +ui16 TEndpoint::Port() const noexcept { + if (Addr_->Addr()->sa_family == AF_UNIX) { + return 0; + } + + const sockaddr* sa = Addr_->Addr(); + + if (sa->sa_family == AF_INET) { + return InetToHost(((const sockaddr_in*)sa)->sin_port); + } else { + return InetToHost(((const sockaddr_in6*)sa)->sin6_port); + } +} + +size_t TEndpoint::Hash() const { + const 
sockaddr* sa = Addr_->Addr(); + + if (sa->sa_family == AF_INET) { + const sockaddr_in* sa4 = (const sockaddr_in*)sa; + + return IntHash((((ui64)sa4->sin_addr.s_addr) << 16) ^ sa4->sin_port); + } else if (sa->sa_family == AF_INET6) { + const sockaddr_in6* sa6 = (const sockaddr_in6*)sa; + const ui64* ptr = (const ui64*)&sa6->sin6_addr; + + return IntHash(ptr[0] ^ ptr[1] ^ sa6->sin6_port); + } else { + const sockaddr_un* un = (const sockaddr_un*)sa; + THash<TString> strHash; + + return strHash(un->sun_path); + } +} diff --git a/util/network/endpoint.h b/util/network/endpoint.h new file mode 100644 index 0000000000..a3e59b4925 --- /dev/null +++ b/util/network/endpoint.h @@ -0,0 +1,61 @@ +#pragma once + +#include "address.h" + +#include <util/str_stl.h> + +//some equivalent boost::asio::ip::endpoint (easy for using pair ip:port) +class TEndpoint { +public: + using TAddrRef = NAddr::IRemoteAddrRef; + + TEndpoint(const TAddrRef& addr); + TEndpoint(); + + inline const TAddrRef& Addr() const noexcept { + return Addr_; + } + inline const sockaddr* SockAddr() const { + return Addr_->Addr(); + } + inline socklen_t SockAddrLen() const { + return Addr_->Len(); + } + + inline bool IsIpV4() const { + return Addr_->Addr()->sa_family == AF_INET; + } + inline bool IsIpV6() const { + return Addr_->Addr()->sa_family == AF_INET6; + } + inline bool IsUnix() const { + return Addr_->Addr()->sa_family == AF_UNIX; + } + + inline TString IpToString() const { + return NAddr::PrintHost(*Addr_); + } + + void SetPort(ui16 port); + ui16 Port() const noexcept; + + size_t Hash() const; + +private: + TAddrRef Addr_; +}; + +template <> +struct THash<TEndpoint> { + inline size_t operator()(const TEndpoint& ep) const { + return ep.Hash(); + } +}; + +inline bool operator==(const TEndpoint& l, const TEndpoint& r) { + try { + return NAddr::IsSame(*l.Addr(), *r.Addr()) && l.Port() == r.Port(); + } catch (...) 
{ + return false; + } +} diff --git a/util/network/endpoint_ut.cpp b/util/network/endpoint_ut.cpp new file mode 100644 index 0000000000..d5e40dd6e1 --- /dev/null +++ b/util/network/endpoint_ut.cpp @@ -0,0 +1,123 @@ +#include "endpoint.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/hash_set.h> +#include <util/generic/strbuf.h> + +Y_UNIT_TEST_SUITE(TEndpointTest) { + Y_UNIT_TEST(TestSimple) { + TVector<TNetworkAddress> addrs; + + TEndpoint ep0; + + UNIT_ASSERT(ep0.IsIpV4()); + UNIT_ASSERT_VALUES_EQUAL(0, ep0.Port()); + UNIT_ASSERT_VALUES_EQUAL("0.0.0.0", ep0.IpToString()); + + TEndpoint ep1; + + try { + TNetworkAddress na1("25.26.27.28", 24242); + + addrs.push_back(na1); + + ep1 = TEndpoint(new NAddr::TAddrInfo(&*na1.Begin())); + + UNIT_ASSERT(ep1.IsIpV4()); + UNIT_ASSERT_VALUES_EQUAL("25.26.27.28", ep1.IpToString()); + UNIT_ASSERT_VALUES_EQUAL(24242, ep1.Port()); + } catch (const TNetworkResolutionError&) { + TNetworkAddress n("2a02:6b8:0:1420:0::5f6c:f3c2", 11111); + + addrs.push_back(n); + + ep1 = TEndpoint(new NAddr::TAddrInfo(&*n.Begin())); + } + + ep0.SetPort(12345); + + TEndpoint ep2(ep0); + + ep0.SetPort(0); + + UNIT_ASSERT_VALUES_EQUAL(12345, ep2.Port()); + + TEndpoint ep2_; + + ep2_.SetPort(12345); + + UNIT_ASSERT(ep2 == ep2_); + + TNetworkAddress na3("2a02:6b8:0:1410::5f6c:f3c2", 54321); + TEndpoint ep3(new NAddr::TAddrInfo(&*na3.Begin())); + + UNIT_ASSERT(ep3.IsIpV6()); + UNIT_ASSERT(ep3.IpToString().StartsWith(TStringBuf("2a02:6b8:0:1410:"))); + UNIT_ASSERT(ep3.IpToString().EndsWith(TStringBuf(":5f6c:f3c2"))); + UNIT_ASSERT_VALUES_EQUAL(54321, ep3.Port()); + + TNetworkAddress na4("2a02:6b8:0:1410:0::5f6c:f3c2", 1); + TEndpoint ep4(new NAddr::TAddrInfo(&*na4.Begin())); + + TEndpoint ep3_ = ep4; + + ep3_.SetPort(54321); + + THashSet<TEndpoint> he; + + he.insert(ep0); + he.insert(ep1); + he.insert(ep2); + + UNIT_ASSERT_VALUES_EQUAL(3u, he.size()); + + he.insert(ep2_); + + UNIT_ASSERT_VALUES_EQUAL(3u, he.size()); + + 
he.insert(ep3); + he.insert(ep3_); + + UNIT_ASSERT_VALUES_EQUAL(4u, he.size()); + + he.insert(ep4); + + UNIT_ASSERT_VALUES_EQUAL(5u, he.size()); + } + + Y_UNIT_TEST(TestEqual) { + const TString ip1 = "2a02:6b8:0:1410::5f6c:f3c2"; + const TString ip2 = "2a02:6b8:0:1410::5f6c:f3c3"; + + TNetworkAddress na1(ip1, 24242); + TEndpoint ep1(new NAddr::TAddrInfo(&*na1.Begin())); + + TNetworkAddress na2(ip1, 24242); + TEndpoint ep2(new NAddr::TAddrInfo(&*na2.Begin())); + + TNetworkAddress na3(ip2, 24242); + TEndpoint ep3(new NAddr::TAddrInfo(&*na3.Begin())); + + TNetworkAddress na4(ip2, 24243); + TEndpoint ep4(new NAddr::TAddrInfo(&*na4.Begin())); + + UNIT_ASSERT(ep1 == ep2); + UNIT_ASSERT(!(ep1 == ep3)); + UNIT_ASSERT(!(ep1 == ep4)); + } + + Y_UNIT_TEST(TestIsUnixSocket) { + TNetworkAddress na1(TUnixSocketPath("/tmp/unixsocket")); + TEndpoint ep1(new NAddr::TAddrInfo(&*na1.Begin())); + + TNetworkAddress na2("2a02:6b8:0:1410::5f6c:f3c2", 24242); + TEndpoint ep2(new NAddr::TAddrInfo(&*na2.Begin())); + + UNIT_ASSERT(ep1.IsUnix()); + UNIT_ASSERT(ep1.SockAddr()->sa_family == AF_UNIX); + + UNIT_ASSERT(!ep2.IsUnix()); + UNIT_ASSERT(ep2.SockAddr()->sa_family != AF_UNIX); + } +} diff --git a/util/network/hostip.cpp b/util/network/hostip.cpp new file mode 100644 index 0000000000..cb8d43bf90 --- /dev/null +++ b/util/network/hostip.cpp @@ -0,0 +1,76 @@ +#include "socket.h" +#include "hostip.h" + +#include <util/system/defaults.h> +#include <util/system/byteorder.h> + +#if defined(_unix_) || defined(_cygwin_) + #include <netdb.h> +#endif + +#if !defined(BIND_LIB) + #if !defined(__FreeBSD__) && !defined(_win32_) && !defined(_cygwin_) + #define AGENT_USE_GETADDRINFO + #endif + + #if defined(__FreeBSD__) + #define AGENT_USE_GETADDRINFO + #endif +#endif + +int NResolver::GetHostIP(const char* hostname, ui32* ip, size_t* slots) { + size_t i = 0; + size_t ipsFound = 0; + +#ifdef AGENT_USE_GETADDRINFO + int ret = 0; + struct addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family 
= AF_INET; + hints.ai_socktype = SOCK_STREAM; + struct addrinfo* gai_res = nullptr; + int gai_ret = getaddrinfo(hostname, nullptr, &hints, &gai_res); + if (gai_ret == 0 && gai_res->ai_addr) { + struct addrinfo* cur = gai_res; + for (i = 0; i < *slots && cur; i++, cur = cur->ai_next, ipsFound++) { + ip[i] = *(ui32*)(&((sockaddr_in*)(cur->ai_addr))->sin_addr); + } + } else { + if (gai_ret == EAI_NONAME || gai_ret == EAI_SERVICE) { + ret = HOST_NOT_FOUND; + } else { + ret = GetDnsError(); + } + } + + if (gai_res) { + freeaddrinfo(gai_res); + } + + if (ret) { + return ret; + } +#else + hostent* hostent = gethostbyname(hostname); + + if (!hostent) + return GetDnsError(); + + if (hostent->h_addrtype != AF_INET || (unsigned)hostent->h_length < sizeof(ui32)) + return HOST_NOT_FOUND; + + char** cur = hostent->h_addr_list; + for (i = 0; i < *slots && *cur; i++, cur++, ipsFound++) + ip[i] = *(ui32*)*cur; +#endif + for (i = 0; i < ipsFound; i++) { + ip[i] = InetToHost(ip[i]); + } + *slots = ipsFound; + + return 0; +} + +int NResolver::GetDnsError() { + return h_errno; +} diff --git a/util/network/hostip.h b/util/network/hostip.h new file mode 100644 index 0000000000..cf63e4846a --- /dev/null +++ b/util/network/hostip.h @@ -0,0 +1,16 @@ +#pragma once + +#include <util/system/defaults.h> + +namespace NResolver { + // resolve hostname and fills up to *slots slots in ip array; + // actual number of slots filled is returned in *slots; + int GetHostIP(const char* hostname, ui32* ip, size_t* slots); + int GetDnsError(); + + inline int GetHostIP(const char* hostname, ui32* ip) { + size_t slots = 1; + + return GetHostIP(hostname, ip, &slots); + } +} diff --git a/util/network/init.cpp b/util/network/init.cpp new file mode 100644 index 0000000000..366e65682c --- /dev/null +++ b/util/network/init.cpp @@ -0,0 +1,34 @@ +#include "init.h" + +#include <util/system/compat.h> +#include <util/system/yassert.h> +#include <util/system/defaults.h> +#include <util/generic/singleton.h> + +#include 
<cstdio> +#include <cstdlib> + +namespace { + class TNetworkInit { + public: + inline TNetworkInit() { +#ifndef ROBOT_SIGPIPE + signal(SIGPIPE, SIG_IGN); +#endif + +#if defined(_win_) + #pragma comment(lib, "ws2_32.lib") + WSADATA wsaData; + int result = WSAStartup(MAKEWORD(2, 2), &wsaData); + Y_ASSERT(!result); + if (result) { + exit(-1); + } +#endif + } + }; +} + +void InitNetworkSubSystem() { + (void)Singleton<TNetworkInit>(); +} diff --git a/util/network/init.h b/util/network/init.h new file mode 100644 index 0000000000..08a79c0fca --- /dev/null +++ b/util/network/init.h @@ -0,0 +1,60 @@ +#pragma once + +#include <util/system/error.h> + +#if defined(_unix_) + #include <fcntl.h> + #include <netdb.h> + #include <time.h> + #include <unistd.h> + #include <poll.h> + + #include <sys/uio.h> + #include <sys/time.h> + #include <sys/types.h> + #include <sys/socket.h> + + #include <netinet/in.h> + #include <netinet/tcp.h> + #include <arpa/inet.h> + +using SOCKET = int; + + #define closesocket(s) close(s) + #define SOCKET_ERROR -1 + #define INVALID_SOCKET -1 + #define WSAGetLastError() errno +#elif defined(_win_) + #include <util/system/winint.h> + #include <io.h> + #include <winsock2.h> + #include <ws2tcpip.h> + +using nfds_t = ULONG; + + #undef Yield + +struct sockaddr_un { + short sun_family; + char sun_path[108]; +}; + + #define PF_LOCAL AF_UNIX + #define NETDB_INTERNAL -1 + #define NETDB_SUCCESS 0 + +#endif + +#if defined(_win_) || defined(_darwin_) + #ifndef MSG_NOSIGNAL + #define MSG_NOSIGNAL 0 + #endif +#endif // _win_ or _darwin_ + +void InitNetworkSubSystem(); + +static struct TNetworkInitializer { + inline TNetworkInitializer() { + InitNetworkSubSystem(); + } +} NetworkInitializerObject; diff --git a/util/network/interface.cpp b/util/network/interface.cpp new file mode 100644 index 0000000000..256776c6d3 --- /dev/null +++ b/util/network/interface.cpp @@ -0,0 +1,79 @@ +#include "interface.h" + +#if defined(_unix_) + #include <ifaddrs.h> +#endif + +#ifdef _win_ + 
#include <iphlpapi.h> + #pragma comment(lib, "Iphlpapi.lib") +#endif + +namespace NAddr { + static bool IsInetAddress(sockaddr* addr) { + return (addr != nullptr) && ((addr->sa_family == AF_INET) || (addr->sa_family == AF_INET6)); + } + + TNetworkInterfaceList GetNetworkInterfaces() { + TNetworkInterfaceList result; + +#ifdef _win_ + TVector<char> buf; + buf.resize(1000000); + PIP_ADAPTER_ADDRESSES adapterBuf = (PIP_ADAPTER_ADDRESSES)&buf[0]; + ULONG bufSize = buf.ysize(); + + if (GetAdaptersAddresses(AF_UNSPEC, 0, nullptr, adapterBuf, &bufSize) == ERROR_SUCCESS) { + for (PIP_ADAPTER_ADDRESSES ptr = adapterBuf; ptr != 0; ptr = ptr->Next) { + // The check below makes code working on Vista+ + if ((ptr->Flags & (IP_ADAPTER_IPV4_ENABLED | IP_ADAPTER_IPV6_ENABLED)) == 0) { + continue; + } + if (ptr->IfType == IF_TYPE_TUNNEL) { + // ignore tunnels + continue; + } + if (ptr->OperStatus != IfOperStatusUp) { + // ignore disabled adapters + continue; + } + + for (IP_ADAPTER_UNICAST_ADDRESS* addr = ptr->FirstUnicastAddress; addr != 0; addr = addr->Next) { + sockaddr* a = (sockaddr*)addr->Address.lpSockaddr; + if (IsInetAddress(a)) { + TNetworkInterface networkInterface; + + // Not very efficient but straightforward; FriendlyName is a wide string, so keep the full code unit for the ASCII range check + for (size_t i = 0; ptr->FriendlyName[i] != 0; i++) { + WCHAR w = ptr->FriendlyName[i]; + char c = (w < 0x80) ? 
char(w) : '?'; + networkInterface.Name.append(1, c); + } + + networkInterface.Address = new TOpaqueAddr(a); + result.push_back(networkInterface); + } + } + } + } +#else + ifaddrs* ifap; + if (getifaddrs(&ifap) != -1) { + for (ifaddrs* ifa = ifap; ifa != nullptr; ifa = ifa->ifa_next) { + if (IsInetAddress(ifa->ifa_addr)) { + TNetworkInterface interface; + interface.Name = ifa->ifa_name; + interface.Address = new TOpaqueAddr(ifa->ifa_addr); + if (IsInetAddress(ifa->ifa_netmask)) { + interface.Mask = new TOpaqueAddr(ifa->ifa_netmask); + } + result.push_back(interface); + } + } + freeifaddrs(ifap); + } +#endif + + return result; + } +} diff --git a/util/network/interface.h b/util/network/interface.h new file mode 100644 index 0000000000..dda4555021 --- /dev/null +++ b/util/network/interface.h @@ -0,0 +1,17 @@ +#pragma once + +#include "address.h" + +#include <util/generic/vector.h> + +namespace NAddr { + struct TNetworkInterface { + TString Name; + IRemoteAddrRef Address; + IRemoteAddrRef Mask; + }; + + using TNetworkInterfaceList = TVector<TNetworkInterface>; + + TNetworkInterfaceList GetNetworkInterfaces(); +} diff --git a/util/network/iovec.cpp b/util/network/iovec.cpp new file mode 100644 index 0000000000..7251038848 --- /dev/null +++ b/util/network/iovec.cpp @@ -0,0 +1 @@ +#include "iovec.h" diff --git a/util/network/iovec.h b/util/network/iovec.h new file mode 100644 index 0000000000..ac15a41f54 --- /dev/null +++ b/util/network/iovec.h @@ -0,0 +1,65 @@ +#pragma once + +#include <util/stream/output.h> +#include <util/system/types.h> +#include <util/system/yassert.h> + +class TContIOVector { + using TPart = IOutputStream::TPart; + +public: + inline TContIOVector(TPart* parts, size_t count) + : Parts_(parts) + , Count_(count) + { + } + + inline void Proceed(size_t len) noexcept { + while (Count_) { + if (len < Parts_->len) { + Parts_->len -= len; + Parts_->buf = (const char*)Parts_->buf + len; + + return; + } else { + len -= Parts_->len; + --Count_; + ++Parts_; + } 
+ } + + if (len) { + Y_ASSERT(0 && "non zero length left"); + } + } + + inline const TPart* Parts() const noexcept { + return Parts_; + } + + inline size_t Count() const noexcept { + return Count_; + } + + static inline size_t Bytes(const TPart* parts, size_t count) noexcept { + size_t ret = 0; + + for (size_t i = 0; i < count; ++i) { + ret += parts[i].len; + } + + return ret; + } + + inline size_t Bytes() const noexcept { + return Bytes(Parts_, Count_); + } + + inline bool Complete() const noexcept { + return !Count(); + } + +private: + TPart* Parts_; + size_t Count_; +}; diff --git a/util/network/ip.cpp b/util/network/ip.cpp new file mode 100644 index 0000000000..a43bcdadcf --- /dev/null +++ b/util/network/ip.cpp @@ -0,0 +1 @@ +#include "ip.h" diff --git a/util/network/ip.h b/util/network/ip.h new file mode 100644 index 0000000000..dc7c2d24a0 --- /dev/null +++ b/util/network/ip.h @@ -0,0 +1,119 @@ +#pragma once + +#include "socket.h" +#include "hostip.h" + +#include <util/system/error.h> +#include <util/system/byteorder.h> +#include <util/generic/string.h> +#include <util/generic/yexception.h> + +/// IPv4 address in network format +using TIpHost = ui32; + +/// Port number in host format +using TIpPort = ui16; + +/* + * ipStr is in 'ddd.ddd.ddd.ddd' format + * returns IPv4 address in inet format + */ +static inline TIpHost IpFromString(const char* ipStr) { + in_addr ia; + + if (inet_aton(ipStr, &ia) == 0) { + ythrow TSystemError() << "Failed to convert (" << ipStr << ") to ip address"; + } + + return (ui32)ia.s_addr; +} + +static inline char* IpToString(TIpHost ip, char* buf, size_t len) { + if (!inet_ntop(AF_INET, (void*)&ip, buf, (socklen_t)len)) { + ythrow TSystemError() << "Failed to get ip address string"; + } + + return buf; +} + +static inline TString IpToString(TIpHost ip) { + char buf[INET_ADDRSTRLEN]; + + return TString(IpToString(ip, buf, sizeof(buf))); +} + +static inline TIpHost ResolveHost(const char* data, size_t len) { + TIpHost ret; + const 
TString s(data, len); + + if (NResolver::GetHostIP(s.data(), &ret) != 0) { + ythrow TSystemError(NResolver::GetDnsError()) << "can not resolve(" << s << ")"; + } + + return HostToInet(ret); +} + +/// socket address +struct TIpAddress: public sockaddr_in { + inline TIpAddress() noexcept { + Clear(); + } + + inline TIpAddress(const sockaddr_in& addr) noexcept + : sockaddr_in(addr) + , tmp(0) + { + } + + inline TIpAddress(TIpHost ip, TIpPort port) noexcept { + Set(ip, port); + } + + inline TIpAddress(TStringBuf ip, TIpPort port) { + Set(ResolveHost(ip.data(), ip.size()), port); + } + + inline TIpAddress(const char* ip, TIpPort port) { + Set(ResolveHost(ip, strlen(ip)), port); + } + + inline operator sockaddr*() const noexcept { + return (sockaddr*)(sockaddr_in*)this; + } + + inline operator socklen_t*() const noexcept { + tmp = sizeof(sockaddr_in); + + return (socklen_t*)&tmp; + } + + inline operator socklen_t() const noexcept { + tmp = sizeof(sockaddr_in); + + return tmp; + } + + inline void Clear() noexcept { + Zero((sockaddr_in&)(*this)); + } + + inline void Set(TIpHost ip, TIpPort port) noexcept { + Clear(); + + sin_family = AF_INET; + sin_addr.s_addr = ip; + sin_port = HostToInet(port); + } + + inline TIpHost Host() const noexcept { + return sin_addr.s_addr; + } + + inline TIpPort Port() const noexcept { + return InetToHost(sin_port); + } + +private: + // required for "operator socklen_t*()" + mutable socklen_t tmp; +}; diff --git a/util/network/ip_ut.cpp b/util/network/ip_ut.cpp new file mode 100644 index 0000000000..6716c6a699 --- /dev/null +++ b/util/network/ip_ut.cpp @@ -0,0 +1,63 @@ +#include "ip.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/generic/yexception.h> + +class TSysIpTest: public TTestBase { + UNIT_TEST_SUITE(TSysIpTest); + UNIT_TEST(TestIpFromString); + UNIT_TEST_EXCEPTION(TestIpFromString2, yexception); + UNIT_TEST_EXCEPTION(TestIpFromString3, yexception); + UNIT_TEST_EXCEPTION(TestIpFromString4, yexception); + 
UNIT_TEST_EXCEPTION(TestIpFromString5, yexception); + UNIT_TEST(TestIpToString); + UNIT_TEST_SUITE_END(); + +private: + void TestIpFromString(); + void TestIpFromString2(); + void TestIpFromString3(); + void TestIpFromString4(); + void TestIpFromString5(); + void TestIpToString(); +}; + +UNIT_TEST_SUITE_REGISTRATION(TSysIpTest); + +void TSysIpTest::TestIpFromString() { + const char* ipStr[] = {"192.168.0.1", "87.255.18.167", "255.255.0.31", "188.225.124.255"}; + ui8 ipArr[][4] = {{192, 168, 0, 1}, {87, 255, 18, 167}, {255, 255, 0, 31}, {188, 225, 124, 255}}; + + for (size_t i = 0; i < Y_ARRAY_SIZE(ipStr); ++i) { + const ui32 ip = IpFromString(ipStr[i]); + + UNIT_ASSERT(memcmp(&ip, ipArr[i], sizeof(ui32)) == 0); + } +} + +void TSysIpTest::TestIpFromString2() { + IpFromString("XXXXXXWXW"); +} + +void TSysIpTest::TestIpFromString3() { + IpFromString("986.0.37.255"); +} + +void TSysIpTest::TestIpFromString4() { + IpFromString("256.0.22.365"); +} + +void TSysIpTest::TestIpFromString5() { + IpFromString("245.12..0"); +} + +void TSysIpTest::TestIpToString() { + ui8 ipArr[][4] = {{192, 168, 0, 1}, {87, 255, 18, 167}, {255, 255, 0, 31}, {188, 225, 124, 255}}; + + const char* ipStr[] = {"192.168.0.1", "87.255.18.167", "255.255.0.31", "188.225.124.255"}; + + for (size_t i = 0; i < Y_ARRAY_SIZE(ipStr); ++i) { + UNIT_ASSERT(IpToString(*reinterpret_cast<TIpHost*>(&(ipArr[i]))) == ipStr[i]); + } +} diff --git a/util/network/nonblock.cpp b/util/network/nonblock.cpp new file mode 100644 index 0000000000..e515c27cc5 --- /dev/null +++ b/util/network/nonblock.cpp @@ -0,0 +1,104 @@ +#include "nonblock.h" + +#include <util/system/platform.h> + +#include <util/generic/singleton.h> + +#if defined(_unix_) + #include <dlfcn.h> +#endif + +#if defined(_linux_) + #if !defined(SOCK_NONBLOCK) + #define SOCK_NONBLOCK 04000 + #endif +#endif + +namespace { + struct TFeatureCheck { + inline TFeatureCheck() + : Accept4(nullptr) + , HaveSockNonBlock(false) + { +#if defined(_unix_) && 
defined(SOCK_NONBLOCK) + { + Accept4 = reinterpret_cast<TAccept4>(dlsym(RTLD_DEFAULT, "accept4")); + + #if defined(_musl_) + //musl always statically linked + if (!Accept4) { + Accept4 = accept4; + } + #endif + + if (Accept4) { + Accept4(-1, nullptr, nullptr, SOCK_NONBLOCK); + + if (errno == ENOSYS) { + Accept4 = nullptr; + } + } + } +#endif + +#if defined(SOCK_NONBLOCK) + { + TSocketHolder tmp(socket(PF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0)); + + HaveSockNonBlock = !tmp.Closed(); + } +#endif + } + + inline SOCKET FastAccept(SOCKET s, struct sockaddr* addr, socklen_t* addrlen) const { +#if defined(SOCK_NONBLOCK) + if (Accept4) { + return Accept4(s, addr, addrlen, SOCK_NONBLOCK); + } +#endif + + const SOCKET ret = accept(s, addr, addrlen); + +#if !defined(_freebsd_) + //freebsd inherit O_NONBLOCK flag + if (ret != INVALID_SOCKET) { + SetNonBlock(ret); + } +#endif + + return ret; + } + + inline SOCKET FastSocket(int domain, int type, int protocol) const { +#if defined(SOCK_NONBLOCK) + if (HaveSockNonBlock) { + return socket(domain, type | SOCK_NONBLOCK, protocol); + } +#endif + + const SOCKET ret = socket(domain, type, protocol); + + if (ret != INVALID_SOCKET) { + SetNonBlock(ret); + } + + return ret; + } + + static inline const TFeatureCheck* Instance() noexcept { + return Singleton<TFeatureCheck>(); + } + + using TAccept4 = int (*)(int sockfd, struct sockaddr* addr, socklen_t* addrlen, int flags); + TAccept4 Accept4; + bool HaveSockNonBlock; + }; +} + +SOCKET Accept4(SOCKET s, struct sockaddr* addr, socklen_t* addrlen) { + return TFeatureCheck::Instance()->FastAccept(s, addr, addrlen); +} + +SOCKET Socket4(int domain, int type, int protocol) { + return TFeatureCheck::Instance()->FastSocket(domain, type, protocol); +} diff --git a/util/network/nonblock.h b/util/network/nonblock.h new file mode 100644 index 0000000000..54e5e44ae3 --- /dev/null +++ b/util/network/nonblock.h @@ -0,0 +1,8 @@ +#pragma once + +#include "socket.h" + +//assume s is non-blocking, return 
non-blocking socket +SOCKET Accept4(SOCKET s, struct sockaddr* addr, socklen_t* addrlen); +//create non-blocking socket +SOCKET Socket4(int domain, int type, int protocol); diff --git a/util/network/pair.cpp b/util/network/pair.cpp new file mode 100644 index 0000000000..9751ef5c96 --- /dev/null +++ b/util/network/pair.cpp @@ -0,0 +1,97 @@ +#include "pair.h" + +int SocketPair(SOCKET socks[2], bool overlapped, bool cloexec) { +#if defined(_win_) + struct sockaddr_in addr; + SOCKET listener; + int e; + int addrlen = sizeof(addr); + DWORD flags = (overlapped ? WSA_FLAG_OVERLAPPED : 0) | (cloexec ? WSA_FLAG_NO_HANDLE_INHERIT : 0); + + if (socks == 0) { + WSASetLastError(WSAEINVAL); + + return SOCKET_ERROR; + } + + socks[0] = INVALID_SOCKET; + socks[1] = INVALID_SOCKET; + + if ((listener = socket(AF_INET, SOCK_STREAM, 0)) == INVALID_SOCKET) { + return SOCKET_ERROR; + } + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(0x7f000001); + addr.sin_port = 0; + + e = bind(listener, (const struct sockaddr*)&addr, sizeof(addr)); + + if (e == SOCKET_ERROR) { + e = WSAGetLastError(); + closesocket(listener); + WSASetLastError(e); + + return SOCKET_ERROR; + } + + e = getsockname(listener, (struct sockaddr*)&addr, &addrlen); + + if (e == SOCKET_ERROR) { + e = WSAGetLastError(); + closesocket(listener); + WSASetLastError(e); + + return SOCKET_ERROR; + } + + do { + if (listen(listener, 1) == SOCKET_ERROR) + break; + + if ((socks[0] = WSASocket(AF_INET, SOCK_STREAM, 0, nullptr, 0, flags)) == INVALID_SOCKET) + break; + + if (connect(socks[0], (const struct sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) + break; + + if ((socks[1] = accept(listener, nullptr, nullptr)) == INVALID_SOCKET) + break; + + closesocket(listener); + + return 0; + } while (0); + + e = WSAGetLastError(); + closesocket(listener); + closesocket(socks[0]); + closesocket(socks[1]); + WSASetLastError(e); + + return SOCKET_ERROR; +#else + (void)overlapped; + + #if 
defined(_linux_) + return socketpair(AF_LOCAL, SOCK_STREAM | (cloexec ? SOCK_CLOEXEC : 0), 0, socks); + #else + int r = socketpair(AF_LOCAL, SOCK_STREAM, 0, socks); + // Non-atomic wrt exec + if (r == 0 && cloexec) { + for (int i = 0; i < 2; ++i) { + int flags = fcntl(socks[i], F_GETFD, 0); + if (flags < 0) { + return flags; + } + r = fcntl(socks[i], F_SETFD, flags | FD_CLOEXEC); + if (r < 0) { + return r; + } + } + } + return r; + #endif +#endif +} diff --git a/util/network/pair.h b/util/network/pair.h new file mode 100644 index 0000000000..0d4506f880 --- /dev/null +++ b/util/network/pair.h @@ -0,0 +1,9 @@ +#pragma once + +#include "init.h" + +int SocketPair(SOCKET socks[2], bool overlapped, bool cloexec = false); + +static inline int SocketPair(SOCKET socks[2]) { + return SocketPair(socks, false, false); +} diff --git a/util/network/poller.cpp b/util/network/poller.cpp new file mode 100644 index 0000000000..7954d0e8b5 --- /dev/null +++ b/util/network/poller.cpp @@ -0,0 +1,86 @@ +#include "poller.h" +#include "pollerimpl.h" + +#include <util/memory/tempbuf.h> + +namespace { + struct TMutexLocking { + using TMyMutex = TMutex; + }; +} + +class TSocketPoller::TImpl: public TPollerImpl<TMutexLocking> { +public: + inline size_t DoWaitReal(void** ev, TEvent* events, size_t len, const TInstant& deadLine) { + const size_t ret = WaitD(events, len, deadLine); + + for (size_t i = 0; i < ret; ++i) { + ev[i] = ExtractEvent(&events[i]); + } + + return ret; + } + + inline size_t DoWait(void** ev, size_t len, const TInstant& deadLine) { + if (len == 1) { + TEvent tmp; + + return DoWaitReal(ev, &tmp, 1, deadLine); + } else { + TTempArray<TEvent> tmpEvents(len); + + return DoWaitReal(ev, tmpEvents.Data(), len, deadLine); + } + } +}; + +TSocketPoller::TSocketPoller() + : Impl_(new TImpl()) +{ +} + +TSocketPoller::~TSocketPoller() = default; + +void TSocketPoller::WaitRead(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_READ); +} + +void 
TSocketPoller::WaitWrite(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_WRITE); +} + +void TSocketPoller::WaitReadWrite(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_READ | CONT_POLL_WRITE); +} + +void TSocketPoller::WaitRdhup(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_RDHUP); +} + +void TSocketPoller::WaitReadOneShot(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_READ | CONT_POLL_ONE_SHOT); +} + +void TSocketPoller::WaitWriteOneShot(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_WRITE | CONT_POLL_ONE_SHOT); +} + +void TSocketPoller::WaitReadWriteOneShot(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_ONE_SHOT); +} + +void TSocketPoller::WaitReadWriteEdgeTriggered(SOCKET sock, void* cookie) { + Impl_->Set(cookie, sock, CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_EDGE_TRIGGERED); +} + +void TSocketPoller::RestartReadWriteEdgeTriggered(SOCKET sock, void* cookie, bool empty) { + Impl_->Set(cookie, sock, CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_MODIFY | CONT_POLL_EDGE_TRIGGERED | (empty ? 
CONT_POLL_BACKLOG_EMPTY : 0)); +} + +void TSocketPoller::Unwait(SOCKET sock) { + Impl_->Remove(sock); +} + +size_t TSocketPoller::WaitD(void** ev, size_t len, const TInstant& deadLine) { + return Impl_->DoWait(ev, len, deadLine); +} diff --git a/util/network/poller.h b/util/network/poller.h new file mode 100644 index 0000000000..8dccd73140 --- /dev/null +++ b/util/network/poller.h @@ -0,0 +1,58 @@ +#pragma once + +#include "socket.h" + +#include <util/generic/ptr.h> +#include <util/datetime/base.h> + +class TSocketPoller { +public: + TSocketPoller(); + ~TSocketPoller(); + + void WaitRead(SOCKET sock, void* cookie); + void WaitWrite(SOCKET sock, void* cookie); + void WaitReadWrite(SOCKET sock, void* cookie); + void WaitRdhup(SOCKET sock, void* cookie); + + void WaitReadOneShot(SOCKET sock, void* cookie); + void WaitWriteOneShot(SOCKET sock, void* cookie); + void WaitReadWriteOneShot(SOCKET sock, void* cookie); + + void WaitReadWriteEdgeTriggered(SOCKET sock, void* cookie); + void RestartReadWriteEdgeTriggered(SOCKET sock, void* cookie, bool empty = true); + + void Unwait(SOCKET sock); + + size_t WaitD(void** events, size_t len, const TInstant& deadLine); + + inline size_t WaitT(void** events, size_t len, const TDuration& timeOut) { + return WaitD(events, len, timeOut.ToDeadLine()); + } + + inline size_t WaitI(void** events, size_t len) { + return WaitD(events, len, TInstant::Max()); + } + + inline void* WaitD(const TInstant& deadLine) { + void* ret; + + if (WaitD(&ret, 1, deadLine)) { + return ret; + } + + return nullptr; + } + + inline void* WaitT(const TDuration& timeOut) { + return WaitD(timeOut.ToDeadLine()); + } + + inline void* WaitI() { + return WaitD(TInstant::Max()); + } + +private: + class TImpl; + THolder<TImpl> Impl_; +}; diff --git a/util/network/poller_ut.cpp b/util/network/poller_ut.cpp new file mode 100644 index 0000000000..6df0dda8ec --- /dev/null +++ b/util/network/poller_ut.cpp @@ -0,0 +1,236 @@ +#include <library/cpp/testing/unittest/registar.h> 
+#include <util/system/error.h> + +#include "pair.h" +#include "poller.h" +#include "pollerimpl.h" + +Y_UNIT_TEST_SUITE(TSocketPollerTest) { + Y_UNIT_TEST(TestSimple) { + SOCKET sockets[2]; + UNIT_ASSERT(SocketPair(sockets) == 0); + + TSocketHolder s1(sockets[0]); + TSocketHolder s2(sockets[1]); + + TSocketPoller poller; + poller.WaitRead(sockets[1], (void*)17); + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + for (ui32 i = 0; i < 3; ++i) { + char buf[] = {18}; + UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0)); + + UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero())); + + UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0)); + UNIT_ASSERT_VALUES_EQUAL(18, buf[0]); + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + } + } + + Y_UNIT_TEST(TestSimpleOneShot) { + SOCKET sockets[2]; + UNIT_ASSERT(SocketPair(sockets) == 0); + + TSocketHolder s1(sockets[0]); + TSocketHolder s2(sockets[1]); + + TSocketPoller poller; + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + for (ui32 i = 0; i < 3; ++i) { + poller.WaitReadOneShot(sockets[1], (void*)17); + + char buf[1]; + + buf[0] = i + 20; + + UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0)); + + UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero())); + + UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0)); + UNIT_ASSERT_VALUES_EQUAL(char(i + 20), buf[0]); + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + buf[0] = i + 21; + + UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0)); + + // this fails if socket is not oneshot + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + 
UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0)); + UNIT_ASSERT_VALUES_EQUAL(char(i + 21), buf[0]); + } + } + + Y_UNIT_TEST(TestItIsSafeToUnregisterUnregisteredDescriptor) { + SOCKET sockets[2]; + UNIT_ASSERT(SocketPair(sockets) == 0); + + TSocketHolder s1(sockets[0]); + TSocketHolder s2(sockets[1]); + + TSocketPoller poller; + + poller.Unwait(s1); + } + + Y_UNIT_TEST(TestItIsSafeToReregisterDescriptor) { + SOCKET sockets[2]; + UNIT_ASSERT(SocketPair(sockets) == 0); + + TSocketHolder s1(sockets[0]); + TSocketHolder s2(sockets[1]); + + TSocketPoller poller; + + poller.WaitRead(s1, nullptr); + poller.WaitRead(s1, nullptr); + poller.WaitWrite(s1, nullptr); + } + + Y_UNIT_TEST(TestSimpleEdgeTriggered) { + SOCKET sockets[2]; + UNIT_ASSERT(SocketPair(sockets) == 0); + + TSocketHolder s1(sockets[0]); + TSocketHolder s2(sockets[1]); + + SetNonBlock(sockets[1]); + + TSocketPoller poller; + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + for (ui32 i = 0; i < 3; ++i) { + poller.WaitReadWriteEdgeTriggered(sockets[1], (void*)17); + + // notify about writeble + UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + char buf[2]; + + buf[0] = i + 10; + buf[1] = i + 20; + + // send one byte + UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0)); + + UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + // restart without reading + poller.RestartReadWriteEdgeTriggered(sockets[1], (void*)17, false); + + // after restart read and write might generate separate events + { + void* events[3]; + size_t count = poller.WaitT(events, 3, TDuration::Zero()); + UNIT_ASSERT_GE(count, 1); + UNIT_ASSERT_LE(count, 2); + 
UNIT_ASSERT_VALUES_EQUAL(events[0], (void*)17); + } + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + // second two more bytes + UNIT_ASSERT_VALUES_EQUAL(2, send(sockets[0], buf, 2, 0)); + + // here poller could notify or not because we haven't seen end + Y_UNUSED(poller.WaitT(TDuration::Zero())); + + // recv one, leave two + UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 1, 0)); + UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]); + + // nothing new + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + // recv the rest + UNIT_ASSERT_VALUES_EQUAL(2, recv(sockets[1], buf, 2, 0)); + UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]); + UNIT_ASSERT_VALUES_EQUAL(char(i + 20), buf[1]); + + // still nothing new + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + // hit end + ClearLastSystemError(); + UNIT_ASSERT_VALUES_EQUAL(-1, recv(sockets[1], buf, 1, 0)); + UNIT_ASSERT_VALUES_EQUAL(EAGAIN, LastSystemError()); + + // restart after end (noop for epoll) + poller.RestartReadWriteEdgeTriggered(sockets[1], (void*)17, true); + + // send and recv byte + UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0)); + + UNIT_ASSERT_VALUES_EQUAL((void*)17, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + // recv and see end + UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 2, 0)); + UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]); + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + // the same but send before restart + UNIT_ASSERT_VALUES_EQUAL(1, send(sockets[0], buf, 1, 0)); + + // restart after end (noop for epoll) + poller.RestartReadWriteEdgeTriggered(sockets[1], (void*)17, true); + + UNIT_ASSERT_VALUES_EQUAL((void*)17, 
poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + UNIT_ASSERT_VALUES_EQUAL(1, recv(sockets[1], buf, 2, 0)); + UNIT_ASSERT_VALUES_EQUAL(char(i + 10), buf[0]); + + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + UNIT_ASSERT_VALUES_EQUAL(nullptr, poller.WaitT(TDuration::Zero())); + + poller.Unwait(sockets[1]); + } + } + +#if defined(HAVE_EPOLL_POLLER) + Y_UNIT_TEST(TestRdhup) { + SOCKET sockets[2]; + UNIT_ASSERT(SocketPair(sockets) == 0); + + TSocketHolder s1(sockets[0]); + TSocketHolder s2(sockets[1]); + + char buf[1] = {0}; + UNIT_ASSERT_VALUES_EQUAL(1, send(s1, buf, 1, 0)); + shutdown(s1, SHUT_WR); + + using TPoller = TGenericPoller<TEpollPoller<TWithoutLocking>>; + TPoller poller; + poller.Set((void*)17, s2, CONT_POLL_RDHUP); + + TPoller::TEvent e; + UNIT_ASSERT_VALUES_EQUAL(poller.WaitD(&e, 1, TDuration::Zero().ToDeadLine()), 1); + UNIT_ASSERT_EQUAL(TPoller::ExtractStatus(&e), 0); + UNIT_ASSERT_EQUAL(TPoller::ExtractFilter(&e), CONT_POLL_RDHUP); + UNIT_ASSERT_EQUAL(TPoller::ExtractEvent(&e), (void*)17); + } +#endif +} diff --git a/util/network/pollerimpl.cpp b/util/network/pollerimpl.cpp new file mode 100644 index 0000000000..bf2ba16cf6 --- /dev/null +++ b/util/network/pollerimpl.cpp @@ -0,0 +1 @@ +#include "pollerimpl.h" diff --git a/util/network/pollerimpl.h b/util/network/pollerimpl.h new file mode 100644 index 0000000000..e8c7e40fba --- /dev/null +++ b/util/network/pollerimpl.h @@ -0,0 +1,706 @@ +#pragma once + +#include "socket.h" + +#include <util/system/error.h> +#include <util/system/mutex.h> +#include <util/system/defaults.h> +#include <util/generic/ylimits.h> +#include <util/generic/utility.h> +#include <util/generic/vector.h> +#include <util/generic/yexception.h> +#include <util/datetime/base.h> + +#if defined(_freebsd_) || defined(_darwin_) + #define HAVE_KQUEUE_POLLER +#endif + +#if (defined(_linux_) && !defined(_bionic_)) || (__ANDROID_API__ >= 21) + #define 
HAVE_EPOLL_POLLER +#endif + +//now we always have it +#define HAVE_SELECT_POLLER + +#if defined(HAVE_KQUEUE_POLLER) + #include <sys/event.h> +#endif + +#if defined(HAVE_EPOLL_POLLER) + #include <sys/epoll.h> +#endif + +enum EContPoll { + CONT_POLL_READ = 1, + CONT_POLL_WRITE = 2, + CONT_POLL_RDHUP = 4, + CONT_POLL_ONE_SHOT = 8, // Disable after first event + CONT_POLL_MODIFY = 16, // Modify already added event + CONT_POLL_EDGE_TRIGGERED = 32, // Notify only about new events + CONT_POLL_BACKLOG_EMPTY = 64, // Backlog is empty (seen end of request, EAGAIN or truncated read) +}; + +static inline bool IsSocket(SOCKET fd) noexcept { + int val = 0; + socklen_t len = sizeof(val); + + if (getsockopt(fd, SOL_SOCKET, SO_TYPE, (char*)&val, &len) == 0) { + return true; + } + + return LastSystemError() != ENOTSOCK; +} + +static inline int MicroToMilli(int timeout) noexcept { + if (timeout) { + /* + * 1. API of epoll syscall allows to specify timeout with millisecond + * accuracy only + * 2. It is quite complicated to guarantee time resolution of blocking + * syscall less than kernel 1/HZ + * + * Without this rounding we just waste cpu time and do a lot of + * fast epoll_wait(..., 0) syscalls. 
+ */ + return Max(timeout / 1000, 1); + } + + return 0; +} + +struct TWithoutLocking { + using TMyMutex = TFakeMutex; +}; + +#if defined(HAVE_KQUEUE_POLLER) +static inline int Kevent(int kq, struct kevent* changelist, int nchanges, + struct kevent* eventlist, int nevents, const struct timespec* timeout) noexcept { + int ret; + + do { + ret = kevent(kq, changelist, nchanges, eventlist, nevents, timeout); + } while (ret == -1 && errno == EINTR); + + return ret; +} + +template <class TLockPolicy> +class TKqueuePoller { +public: + typedef struct ::kevent TEvent; + + inline TKqueuePoller() + : Fd_(kqueue()) + { + if (Fd_ == -1) { + ythrow TSystemError() << "kqueue failed"; + } + } + + inline ~TKqueuePoller() { + close(Fd_); + } + + inline int Fd() const noexcept { + return Fd_; + } + + inline void SetImpl(void* data, int fd, int what) { + TEvent e[2]; + int flags = EV_ADD; + + if (what & CONT_POLL_EDGE_TRIGGERED) { + if (what & CONT_POLL_BACKLOG_EMPTY) { + // When backlog is empty, edge-triggered does not need restart. + return; + } + flags |= EV_CLEAR; + } + + if (what & CONT_POLL_ONE_SHOT) { + flags |= EV_ONESHOT; + } + + Zero(e); + + EV_SET(e + 0, fd, EVFILT_READ, flags | ((what & CONT_POLL_READ) ? EV_ENABLE : EV_DISABLE), 0, 0, data); + EV_SET(e + 1, fd, EVFILT_WRITE, flags | ((what & CONT_POLL_WRITE) ? 
EV_ENABLE : EV_DISABLE), 0, 0, data); + + if (Kevent(Fd_, e, 2, nullptr, 0, nullptr) == -1) { + ythrow TSystemError() << "kevent add failed"; + } + } + + inline void Remove(int fd) noexcept { + TEvent e[2]; + + Zero(e); + + EV_SET(e + 0, fd, EVFILT_READ, EV_DELETE, 0, 0, 0); + EV_SET(e + 1, fd, EVFILT_WRITE, EV_DELETE, 0, 0, 0); + + Y_VERIFY(!(Kevent(Fd_, e, 2, nullptr, 0, nullptr) == -1 && errno != ENOENT), "kevent remove failed: %s", LastSystemErrorText()); + } + + inline size_t Wait(TEvent* events, size_t len, int timeout) noexcept { + struct timespec ts; + + ts.tv_sec = timeout / 1000000; + ts.tv_nsec = (timeout % 1000000) * 1000; + + const int ret = Kevent(Fd_, nullptr, 0, events, len, &ts); + + Y_VERIFY(ret >= 0, "kevent failed: %s", LastSystemErrorText()); + + return (size_t)ret; + } + + static inline void* ExtractEvent(const TEvent* event) noexcept { + return event->udata; + } + + static inline int ExtractStatus(const TEvent* event) noexcept { + if (event->flags & EV_ERROR) { + return EIO; + } + + return event->fflags; + } + + static inline int ExtractFilterImpl(const TEvent* event) noexcept { + if (event->filter == EVFILT_READ) { + return CONT_POLL_READ; + } + + if (event->filter == EVFILT_WRITE) { + return CONT_POLL_WRITE; + } + + if (event->flags & EV_EOF) { + return CONT_POLL_READ | CONT_POLL_WRITE; + } + + return 0; + } + +private: + int Fd_; +}; +#endif + +#if defined(HAVE_EPOLL_POLLER) +static inline int ContEpollWait(int epfd, struct epoll_event* events, int maxevents, int timeout) noexcept { + int ret; + + do { + ret = epoll_wait(epfd, events, maxevents, Min<int>(timeout, 35 * 60 * 1000)); + } while (ret == -1 && errno == EINTR); + + return ret; +} + +template <class TLockPolicy> +class TEpollPoller { +public: + typedef struct ::epoll_event TEvent; + + inline TEpollPoller(bool closeOnExec = false) + : Fd_(epoll_create1(closeOnExec ? 
EPOLL_CLOEXEC : 0)) + { + if (Fd_ == -1) { + ythrow TSystemError() << "epoll_create failed"; + } + } + + inline ~TEpollPoller() { + close(Fd_); + } + + inline int Fd() const noexcept { + return Fd_; + } + + inline void SetImpl(void* data, int fd, int what) { + TEvent e; + + Zero(e); + + if (what & CONT_POLL_EDGE_TRIGGERED) { + if (what & CONT_POLL_BACKLOG_EMPTY) { + // When backlog is empty, edge-triggered does not need restart. + return; + } + e.events |= EPOLLET; + } + + if (what & CONT_POLL_ONE_SHOT) { + e.events |= EPOLLONESHOT; + } + + if (what & CONT_POLL_READ) { + e.events |= EPOLLIN; + } + + if (what & CONT_POLL_WRITE) { + e.events |= EPOLLOUT; + } + + if (what & CONT_POLL_RDHUP) { + e.events |= EPOLLRDHUP; + } + + e.data.ptr = data; + + if ((what & CONT_POLL_MODIFY) || epoll_ctl(Fd_, EPOLL_CTL_ADD, fd, &e) == -1) { + if (epoll_ctl(Fd_, EPOLL_CTL_MOD, fd, &e) == -1) { + ythrow TSystemError() << "epoll add failed"; + } + } + } + + inline void Remove(int fd) noexcept { + TEvent e; + + Zero(e); + + epoll_ctl(Fd_, EPOLL_CTL_DEL, fd, &e); + } + + inline size_t Wait(TEvent* events, size_t len, int timeout) noexcept { + const int ret = ContEpollWait(Fd_, events, len, MicroToMilli(timeout)); + + Y_VERIFY(ret >= 0, "epoll wait error: %s", LastSystemErrorText()); + + return (size_t)ret; + } + + static inline void* ExtractEvent(const TEvent* event) noexcept { + return event->data.ptr; + } + + static inline int ExtractStatus(const TEvent* event) noexcept { + if (event->events & (EPOLLERR | EPOLLHUP)) { + return EIO; + } + + return 0; + } + + static inline int ExtractFilterImpl(const TEvent* event) noexcept { + int ret = 0; + + if (event->events & EPOLLIN) { + ret |= CONT_POLL_READ; + } + + if (event->events & EPOLLOUT) { + ret |= CONT_POLL_WRITE; + } + + if (event->events & EPOLLRDHUP) { + ret |= CONT_POLL_RDHUP; + } + + return ret; + } + +private: + int Fd_; +}; +#endif + +#if defined(HAVE_SELECT_POLLER) + #include <util/memory/tempbuf.h> + #include 
<util/generic/hash.h> + + #include "pair.h" + +static inline int ContSelect(int n, fd_set* r, fd_set* w, fd_set* e, struct timeval* t) noexcept { + int ret; + + do { + ret = select(n, r, w, e, t); + } while (ret == -1 && errno == EINTR); + + return ret; +} + +struct TSelectPollerNoTemplate { + struct THandle { + void* Data_; + int Filter_; + + inline THandle() + : Data_(nullptr) + , Filter_(0) + { + } + + inline void* Data() const noexcept { + return Data_; + } + + inline void Set(void* d, int s) noexcept { + Data_ = d; + Filter_ = s; + } + + inline void Clear(int c) noexcept { + Filter_ &= ~c; + } + + inline int Filter() const noexcept { + return Filter_; + } + }; + + class TFds: public THashMap<SOCKET, THandle> { + public: + inline void Set(SOCKET fd, void* data, int filter) { + (*this)[fd].Set(data, filter); + } + + inline void Remove(SOCKET fd) { + erase(fd); + } + + inline SOCKET Build(fd_set* r, fd_set* w, fd_set* e) const noexcept { + SOCKET ret = 0; + + for (const auto& it : *this) { + const SOCKET fd = it.first; + const THandle& handle = it.second; + + FD_SET(fd, e); + + if (handle.Filter() & CONT_POLL_READ) { + FD_SET(fd, r); + } + + if (handle.Filter() & CONT_POLL_WRITE) { + FD_SET(fd, w); + } + + if (fd > ret) { + ret = fd; + } + } + + return ret; + } + }; + + struct TEvent: public THandle { + inline int Status() const noexcept { + return -Min(Filter(), 0); + } + + inline void Error(void* d, int err) noexcept { + Set(d, -err); + } + + inline void Success(void* d, int what) noexcept { + Set(d, what); + } + }; +}; + +template <class TLockPolicy> +class TSelectPoller: public TSelectPollerNoTemplate { + using TMyMutex = typename TLockPolicy::TMyMutex; + +public: + inline TSelectPoller() + : Begin_(nullptr) + , End_(nullptr) + { + SocketPair(Signal_); + SetNonBlock(WaitSock()); + SetNonBlock(SigSock()); + } + + inline ~TSelectPoller() { + closesocket(Signal_[0]); + closesocket(Signal_[1]); + } + + inline void SetImpl(void* data, SOCKET fd, int what) { + 
with_lock (CommandLock_) { + Commands_.push_back(TCommand(fd, what, data)); + } + + Signal(); + } + + inline void Remove(SOCKET fd) noexcept { + with_lock (CommandLock_) { + Commands_.push_back(TCommand(fd, 0)); + } + + Signal(); + } + + inline size_t Wait(TEvent* events, size_t len, int timeout) noexcept { + auto guard = Guard(Lock_); + + do { + if (Begin_ != End_) { + const size_t ret = Min<size_t>(End_ - Begin_, len); + + memcpy(events, Begin_, sizeof(*events) * ret); + Begin_ += ret; + + return ret; + } + + if (len >= EventNumberHint()) { + return WaitBase(events, len, timeout); + } + + Begin_ = SavedEvents(); + End_ = Begin_ + WaitBase(Begin_, EventNumberHint(), timeout); + } while (Begin_ != End_); + + return 0; + } + + inline TEvent* SavedEvents() { + if (!SavedEvents_) { + SavedEvents_.Reset(new TEvent[EventNumberHint()]); + } + + return SavedEvents_.Get(); + } + + inline size_t WaitBase(TEvent* events, size_t len, int timeout) noexcept { + with_lock (CommandLock_) { + for (auto command = Commands_.begin(); command != Commands_.end(); ++command) { + if (command->Filter_ != 0) { + Fds_.Set(command->Fd_, command->Cookie_, command->Filter_); + } else { + Fds_.Remove(command->Fd_); + } + } + + Commands_.clear(); + } + + TTempBuf tmpBuf(3 * sizeof(fd_set) + Fds_.size() * sizeof(SOCKET)); + + fd_set* in = (fd_set*)tmpBuf.Data(); + fd_set* out = &in[1]; + fd_set* errFds = &in[2]; + + SOCKET* keysToDeleteBegin = (SOCKET*)&in[3]; + SOCKET* keysToDeleteEnd = keysToDeleteBegin; + + #if defined(_msan_enabled_) // msan doesn't handle FD_ZERO and cause false positive BALANCER-1347 + memset(in, 0, sizeof(*in)); + memset(out, 0, sizeof(*out)); + memset(errFds, 0, sizeof(*errFds)); + #endif + + FD_ZERO(in); + FD_ZERO(out); + FD_ZERO(errFds); + + FD_SET(WaitSock(), in); + + const SOCKET maxFdNum = Max(Fds_.Build(in, out, errFds), WaitSock()); + struct timeval tout; + + tout.tv_sec = timeout / 1000000; + tout.tv_usec = timeout % 1000000; + + int ret = ContSelect(int(maxFdNum 
+ 1), in, out, errFds, &tout); + + if (ret > 0 && FD_ISSET(WaitSock(), in)) { + --ret; + TryWait(); + } + + Y_VERIFY(ret >= 0 && (size_t)ret <= len, "select error: %s", LastSystemErrorText()); + + TEvent* eventsStart = events; + + for (typename TFds::iterator it = Fds_.begin(); it != Fds_.end(); ++it) { + const SOCKET fd = it->first; + THandle& handle = it->second; + + if (FD_ISSET(fd, errFds)) { + (events++)->Error(handle.Data(), EIO); + + if (handle.Filter() & CONT_POLL_ONE_SHOT) { + *keysToDeleteEnd = fd; + ++keysToDeleteEnd; + } + + } else { + int what = 0; + + if (FD_ISSET(fd, in)) { + what |= CONT_POLL_READ; + } + + if (FD_ISSET(fd, out)) { + what |= CONT_POLL_WRITE; + } + + if (what) { + (events++)->Success(handle.Data(), what); + + if (handle.Filter() & CONT_POLL_ONE_SHOT) { + *keysToDeleteEnd = fd; + ++keysToDeleteEnd; + } + + if (handle.Filter() & CONT_POLL_EDGE_TRIGGERED) { + // Emulate edge-triggered for level-triggered select(). + // User must restart waiting this event when needed. 
+ handle.Clear(what); + } + } + } + } + + while (keysToDeleteBegin != keysToDeleteEnd) { + Fds_.erase(*keysToDeleteBegin); + ++keysToDeleteBegin; + } + + return events - eventsStart; + } + + inline size_t EventNumberHint() const noexcept { + return sizeof(fd_set) * 8 * 2; + } + + static inline void* ExtractEvent(const TEvent* event) noexcept { + return event->Data(); + } + + static inline int ExtractStatus(const TEvent* event) noexcept { + return event->Status(); + } + + static inline int ExtractFilterImpl(const TEvent* event) noexcept { + return event->Filter(); + } + +private: + inline void Signal() noexcept { + char ch = 13; + + send(SigSock(), &ch, 1, 0); + } + + inline void TryWait() { + char ch[32]; + + while (recv(WaitSock(), ch, sizeof(ch), 0) > 0) { + Y_ASSERT(ch[0] == 13); + } + } + + inline SOCKET WaitSock() const noexcept { + return Signal_[1]; + } + + inline SOCKET SigSock() const noexcept { + return Signal_[0]; + } + +private: + struct TCommand { + SOCKET Fd_; + int Filter_; // 0 to remove + void* Cookie_; + + TCommand(SOCKET fd, int filter, void* cookie) + : Fd_(fd) + , Filter_(filter) + , Cookie_(cookie) + { + } + + TCommand(SOCKET fd, int filter) + : Fd_(fd) + , Filter_(filter) + { + } + }; + + TFds Fds_; + + TMyMutex Lock_; + TArrayHolder<TEvent> SavedEvents_; + TEvent* Begin_; + TEvent* End_; + + TMyMutex CommandLock_; + TVector<TCommand> Commands_; + + SOCKET Signal_[2]; +}; +#endif + +static inline TDuration PollStep(const TInstant& deadLine, const TInstant& now) noexcept { + if (deadLine < now) { + return TDuration::Zero(); + } + + return Min(deadLine - now, TDuration::Seconds(1000)); +} + +template <class TBase> +class TGenericPoller: public TBase { +public: + using TBase::TBase; + + using TEvent = typename TBase::TEvent; + + inline void Set(void* data, SOCKET fd, int what) { + if (what) { + this->SetImpl(data, fd, what); + } else { + this->Remove(fd); + } + } + + static inline int ExtractFilter(const TEvent* event) noexcept { + if 
(TBase::ExtractStatus(event)) { + return CONT_POLL_READ | CONT_POLL_WRITE | CONT_POLL_RDHUP; + } + + return TBase::ExtractFilterImpl(event); + } + + inline size_t WaitD(TEvent* events, size_t len, TInstant deadLine, TInstant now = TInstant::Now()) noexcept { + if (!len) { + return 0; + } + + size_t ret; + + do { + ret = this->Wait(events, len, (int)PollStep(deadLine, now).MicroSeconds()); + } while (!ret && ((now = TInstant::Now()) < deadLine)); + + return ret; + } +}; + +#if defined(HAVE_KQUEUE_POLLER) + #define TPollerImplBase TKqueuePoller +#elif defined(HAVE_EPOLL_POLLER) + #define TPollerImplBase TEpollPoller +#elif defined(HAVE_SELECT_POLLER) + #define TPollerImplBase TSelectPoller +#else + #error "unsupported platform" +#endif + +template <class TLockPolicy> +using TPollerImpl = TGenericPoller<TPollerImplBase<TLockPolicy>>; + +#undef TPollerImplBase diff --git a/util/network/sock.cpp b/util/network/sock.cpp new file mode 100644 index 0000000000..d4864a9c1c --- /dev/null +++ b/util/network/sock.cpp @@ -0,0 +1 @@ +#include "sock.h" diff --git a/util/network/sock.h b/util/network/sock.h new file mode 100644 index 0000000000..b10be2f715 --- /dev/null +++ b/util/network/sock.h @@ -0,0 +1,608 @@ +#pragma once + +#include <util/folder/path.h> +#include <util/system/defaults.h> +#include <util/string/cast.h> +#include <util/stream/output.h> +#include <util/system/sysstat.h> + +#if defined(_win_) || defined(_cygwin_) + #include <util/system/file.h> +#else + #include <sys/un.h> + #include <sys/stat.h> +#endif //_win_ + +#include "init.h" +#include "ip.h" +#include "socket.h" + +constexpr ui16 DEF_LOCAL_SOCK_MODE = 00644; + +// Base abstract class for socket address +struct ISockAddr { + virtual ~ISockAddr() = default; + // Max size of the address that we can store (arg of recvfrom) + virtual socklen_t Size() const = 0; + // Real length of the address (arg of sendto) + virtual socklen_t Len() const = 0; + // cast to sockaddr* to pass to any syscall + virtual sockaddr* 
SockAddr() = 0; + virtual const sockaddr* SockAddr() const = 0; + // address in human readable form + virtual TString ToString() const = 0; + +protected: + // below are the implemetation methods that can be called by T*Socket classes + friend class TBaseSocket; + friend class TDgramSocket; + friend class TStreamSocket; + + virtual int ResolveAddr() const { + // usually it's nothing to do here + return 0; + } + virtual int Bind(SOCKET s, ui16 mode) const = 0; +}; + +#if defined(_win_) || defined(_cygwin_) + #define YAF_LOCAL AF_INET +struct TSockAddrLocal: public ISockAddr { + TSockAddrLocal() { + Clear(); + } + + TSockAddrLocal(const char* path) { + Set(path); + } + + socklen_t Size() const { + return sizeof(sockaddr_in); + } + + socklen_t Len() const { + return Size(); + } + + inline void Clear() noexcept { + Zero(in); + Zero(Path); + } + + inline void Set(const char* path) noexcept { + Clear(); + in.sin_family = AF_INET; + in.sin_addr.s_addr = IpFromString("127.0.0.1"); + in.sin_port = 0; + strlcpy(Path, path, PathSize); + } + + sockaddr* SockAddr() { + return (struct sockaddr*)(&in); + } + + const sockaddr* SockAddr() const { + return (const struct sockaddr*)(&in); + } + + TString ToString() const { + return TString(Path); + } + + TFsPath ToPath() const { + return TFsPath(Path); + } + + int ResolveAddr() const { + if (in.sin_port == 0) { + int ret = 0; + // 1. open file + TFileHandle f(Path, OpenExisting | RdOnly); + if (!f.IsOpen()) + return -errno; + + // 2. read the port from file + ret = f.Read(&in.sin_port, sizeof(in.sin_port)); + if (ret != sizeof(in.sin_port)) + return -(errno ? errno : EFAULT); + } + + return 0; + } + + int Bind(SOCKET s, ui16 mode) const { + Y_UNUSED(mode); + int ret = 0; + // 1. open file + TFileHandle f(Path, CreateAlways | WrOnly); + if (!f.IsOpen()) + return -errno; + + // 2. 
find port and bind to it + in.sin_port = 0; + ret = bind(s, SockAddr(), Len()); + if (ret != 0) + return -WSAGetLastError(); + + int size = Size(); + ret = getsockname(s, (struct sockaddr*)(&in), &size); + if (ret != 0) + return -WSAGetLastError(); + + // 3. write port to file + ret = f.Write(&(in.sin_port), sizeof(in.sin_port)); + if (ret != sizeof(in.sin_port)) + return -errno; + + return 0; + } + + static constexpr size_t PathSize = 128; + mutable struct sockaddr_in in; + char Path[PathSize]; +}; +#else + #define YAF_LOCAL AF_LOCAL +struct TSockAddrLocal: public sockaddr_un, public ISockAddr { + TSockAddrLocal() { + Clear(); + } + + TSockAddrLocal(const char* path) { + Set(path); + } + + socklen_t Size() const override { + return sizeof(sockaddr_un); + } + + socklen_t Len() const override { + return strlen(sun_path) + 2; + } + + inline void Clear() noexcept { + Zero(*(sockaddr_un*)this); + } + + inline void Set(const char* path) noexcept { + Clear(); + sun_family = AF_UNIX; + strlcpy(sun_path, path, sizeof(sun_path)); + } + + sockaddr* SockAddr() override { + return (struct sockaddr*)(struct sockaddr_un*)this; + } + + const sockaddr* SockAddr() const override { + return (const struct sockaddr*)(const struct sockaddr_un*)this; + } + + TString ToString() const override { + return TString(sun_path); + } + + TFsPath ToPath() const { + return TFsPath(sun_path); + } + + int Bind(SOCKET s, ui16 mode) const override { + (void)unlink(sun_path); + + int ret = bind(s, SockAddr(), Len()); + if (ret < 0) + return -errno; + + ret = Chmod(sun_path, mode); + if (ret < 0) + return -errno; + return 0; + } +}; +#endif // _win_ + +struct TSockAddrInet: public sockaddr_in, public ISockAddr { + TSockAddrInet() { + Clear(); + } + + TSockAddrInet(TIpHost ip, TIpPort port) { + Set(ip, port); + } + + TSockAddrInet(const char* ip, TIpPort port) { + Set(IpFromString(ip), port); + } + + socklen_t Size() const override { + return sizeof(sockaddr_in); + } + + socklen_t Len() const override { 
+ return Size(); + } + + inline void Clear() noexcept { + Zero(*(sockaddr_in*)this); + } + + inline void Set(TIpHost ip, TIpPort port) noexcept { + Clear(); + sin_family = AF_INET; + sin_addr.s_addr = ip; + sin_port = HostToInet(port); + } + + sockaddr* SockAddr() override { + return (struct sockaddr*)(struct sockaddr_in*)this; + } + + const sockaddr* SockAddr() const override { + return (const struct sockaddr*)(const struct sockaddr_in*)this; + } + + TString ToString() const override { + return IpToString(sin_addr.s_addr) + ":" + ::ToString(InetToHost(sin_port)); + } + + int Bind(SOCKET s, ui16 mode) const override { + Y_UNUSED(mode); + int ret = bind(s, SockAddr(), Len()); + if (ret < 0) + return -errno; + + socklen_t len = Len(); + if (getsockname(s, (struct sockaddr*)(SockAddr()), &len) < 0) + return -WSAGetLastError(); + + return 0; + } + + TIpHost GetIp() const noexcept { + return sin_addr.s_addr; + } + + TIpPort GetPort() const noexcept { + return InetToHost(sin_port); + } + + void SetPort(TIpPort port) noexcept { + sin_port = HostToInet(port); + } +}; + +struct TSockAddrInet6: public sockaddr_in6, public ISockAddr { + TSockAddrInet6() { + Clear(); + } + + TSockAddrInet6(const char* ip6, const TIpPort port) { + Set(ip6, port); + } + + socklen_t Size() const override { + return sizeof(sockaddr_in6); + } + + socklen_t Len() const override { + return Size(); + } + + inline void Clear() noexcept { + Zero(*(sockaddr_in6*)this); + } + + inline void Set(const char* ip6, const TIpPort port) noexcept { + Clear(); + sin6_family = AF_INET6; + inet_pton(AF_INET6, ip6, &sin6_addr); + sin6_port = HostToInet(port); + } + + sockaddr* SockAddr() override { + return (struct sockaddr*)(struct sockaddr_in6*)this; + } + + const sockaddr* SockAddr() const override { + return (const struct sockaddr*)(const struct sockaddr_in6*)this; + } + + TString ToString() const override { + return "[" + GetIp() + "]:" + ::ToString(InetToHost(sin6_port)); + } + + int Bind(SOCKET s, ui16 mode) 
const override { + Y_UNUSED(mode); + int ret = bind(s, SockAddr(), Len()); + if (ret < 0) { + return -errno; + } + socklen_t len = Len(); + if (getsockname(s, (struct sockaddr*)(SockAddr()), &len) < 0) { + return -WSAGetLastError(); + } + return 0; + } + + TString GetIp() const noexcept { + char ip6[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, (void*)&sin6_addr, ip6, INET6_ADDRSTRLEN); + return TString(ip6); + } + + TIpPort GetPort() const noexcept { + return InetToHost(sin6_port); + } + + void SetPort(TIpPort port) noexcept { + sin6_port = HostToInet(port); + } +}; + +using TSockAddrLocalStream = TSockAddrLocal; +using TSockAddrLocalDgram = TSockAddrLocal; +using TSockAddrInetStream = TSockAddrInet; +using TSockAddrInetDgram = TSockAddrInet; +using TSockAddrInet6Stream = TSockAddrInet6; +using TSockAddrInet6Dgram = TSockAddrInet6; + +class TBaseSocket: public TSocketHolder { +protected: + TBaseSocket(SOCKET fd) + : TSocketHolder(fd) + { + } + +public: + int Bind(const ISockAddr* addr, ui16 mode = DEF_LOCAL_SOCK_MODE) { + return addr->Bind((SOCKET) * this, mode); + } + + void CheckSock() { + if ((SOCKET) * this == INVALID_SOCKET) + ythrow TSystemError() << "no socket"; + } + + static ssize_t Check(ssize_t ret, const char* op = "") { + if (ret < 0) + ythrow TSystemError(-(int)ret) << "socket operation " << op; + return ret; + } +}; + +class TDgramSocket: public TBaseSocket { +protected: + TDgramSocket(SOCKET fd) + : TBaseSocket(fd) + { + } + +public: + ssize_t SendTo(const void* msg, size_t len, const ISockAddr* toAddr) { + ssize_t ret = toAddr->ResolveAddr(); + if (ret < 0) { + return -LastSystemError(); + } + + ret = sendto((SOCKET) * this, (const char*)msg, (int)len, 0, toAddr->SockAddr(), toAddr->Len()); + if (ret < 0) { + return -LastSystemError(); + } + + return ret; + } + + ssize_t RecvFrom(void* buf, size_t len, ISockAddr* fromAddr) { + socklen_t fromSize = fromAddr->Size(); + const ssize_t ret = recvfrom((SOCKET) * this, (char*)buf, (int)len, 0, 
fromAddr->SockAddr(), &fromSize); + if (ret < 0) { + return -LastSystemError(); + } + + return ret; + } +}; + +class TStreamSocket: public TBaseSocket { +protected: + explicit TStreamSocket(SOCKET fd) + : TBaseSocket(fd) + { + } + +public: + TStreamSocket() + : TBaseSocket(INVALID_SOCKET) + { + } + + ssize_t Send(const void* msg, size_t len, int flags = 0) { + const ssize_t ret = send((SOCKET) * this, (const char*)msg, (int)len, flags); + if (ret < 0) + return -errno; + + return ret; + } + + ssize_t Recv(void* buf, size_t len, int flags = 0) { + const ssize_t ret = recv((SOCKET) * this, (char*)buf, (int)len, flags); + if (ret < 0) + return -errno; + + return ret; + } + + int Connect(const ISockAddr* addr) { + int ret = addr->ResolveAddr(); + if (ret < 0) + return -errno; + + ret = connect((SOCKET) * this, addr->SockAddr(), addr->Len()); + if (ret < 0) + return -errno; + + return ret; + } + + int Listen(int backlog) { + int ret = listen((SOCKET) * this, backlog); + if (ret < 0) + return -errno; + + return ret; + } + + int Accept(TStreamSocket* acceptedSock, ISockAddr* acceptedAddr = nullptr) { + SOCKET s = INVALID_SOCKET; + if (acceptedAddr) { + socklen_t acceptedSize = acceptedAddr->Size(); + s = accept((SOCKET) * this, acceptedAddr->SockAddr(), &acceptedSize); + } else { + s = accept((SOCKET) * this, nullptr, nullptr); + } + + if (s == INVALID_SOCKET) + return -errno; + + TSocketHolder sock(s); + acceptedSock->Swap(sock); + return 0; + } +}; + +class TLocalDgramSocket: public TDgramSocket { +public: + TLocalDgramSocket(SOCKET fd) + : TDgramSocket(fd) + { + } + + TLocalDgramSocket() + : TDgramSocket(socket(YAF_LOCAL, SOCK_DGRAM, 0)) + { + } +}; + +class TInetDgramSocket: public TDgramSocket { +public: + TInetDgramSocket(SOCKET fd) + : TDgramSocket(fd) + { + } + + TInetDgramSocket() + : TDgramSocket(socket(AF_INET, SOCK_DGRAM, 0)) + { + } +}; + +class TInet6DgramSocket: public TDgramSocket { +public: + TInet6DgramSocket(SOCKET fd) + : TDgramSocket(fd) + { + } + + 
TInet6DgramSocket() + : TDgramSocket(socket(AF_INET6, SOCK_DGRAM, 0)) + { + } +}; + +class TLocalStreamSocket: public TStreamSocket { +public: + TLocalStreamSocket(SOCKET fd) + : TStreamSocket(fd) + { + } + + TLocalStreamSocket() + : TStreamSocket(socket(YAF_LOCAL, SOCK_STREAM, 0)) + { + } +}; + +class TInetStreamSocket: public TStreamSocket { +public: + TInetStreamSocket(SOCKET fd) + : TStreamSocket(fd) + { + } + + TInetStreamSocket() + : TStreamSocket(socket(AF_INET, SOCK_STREAM, 0)) + { + } +}; + +class TInet6StreamSocket: public TStreamSocket { +public: + TInet6StreamSocket(SOCKET fd) + : TStreamSocket(fd) + { + } + + TInet6StreamSocket() + : TStreamSocket(socket(AF_INET6, SOCK_STREAM, 0)) + { + } +}; + +class TStreamSocketInput: public IInputStream { +public: + TStreamSocketInput(TStreamSocket* socket) + : Socket(socket) + { + } + void SetSocket(TStreamSocket* socket) { + Socket = socket; + } + +protected: + TStreamSocket* Socket; + + size_t DoRead(void* buf, size_t len) override { + Y_VERIFY(Socket, "TStreamSocketInput: socket isn't set"); + const ssize_t ret = Socket->Recv(buf, len); + + if (ret >= 0) { + return (size_t)ret; + } + + ythrow TSystemError(-(int)ret) << "can not read from socket input stream"; + } +}; + +class TStreamSocketOutput: public IOutputStream { +public: + TStreamSocketOutput(TStreamSocket* socket) + : Socket(socket) + { + } + void SetSocket(TStreamSocket* socket) { + Socket = socket; + } + + TStreamSocketOutput(TStreamSocketOutput&&) noexcept = default; + TStreamSocketOutput& operator=(TStreamSocketOutput&&) noexcept = default; + +protected: + TStreamSocket* Socket; + + void DoWrite(const void* buf, size_t len) override { + Y_VERIFY(Socket, "TStreamSocketOutput: socket isn't set"); + + const char* ptr = (const char*)buf; + while (len) { + const ssize_t ret = Socket->Send(ptr, len); + + if (ret < 0) { + ythrow TSystemError(-(int)ret) << "can not write to socket output stream"; + } + + Y_ASSERT((size_t)ret <= len); + len -= (size_t)ret; + 
ptr += (size_t)ret; + } + } +}; diff --git a/util/network/sock_ut.cpp b/util/network/sock_ut.cpp new file mode 100644 index 0000000000..fd8c783747 --- /dev/null +++ b/util/network/sock_ut.cpp @@ -0,0 +1,168 @@ +#include "sock.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <library/cpp/threading/future/legacy_future.h> + +#include <util/system/fs.h> + +Y_UNIT_TEST_SUITE(TSocketTest) { + Y_UNIT_TEST(InetDgramTest) { + char buf[256]; + TSockAddrInetDgram servAddr(IpFromString("127.0.0.1"), 0); + TSockAddrInetDgram cliAddr(IpFromString("127.0.0.1"), 0); + TSockAddrInetDgram servFromAddr; + TSockAddrInetDgram cliFromAddr; + TInetDgramSocket cliSock; + TInetDgramSocket servSock; + cliSock.CheckSock(); + servSock.CheckSock(); + + TBaseSocket::Check(cliSock.Bind(&cliAddr)); + TBaseSocket::Check(servSock.Bind(&servAddr)); + + // client + const char reqStr[] = "Hello, world!!!"; + TBaseSocket::Check(cliSock.SendTo(reqStr, sizeof(reqStr), &servAddr)); + + // server + TBaseSocket::Check(servSock.RecvFrom(buf, 256, &servFromAddr)); + UNIT_ASSERT(strcmp(reqStr, buf) == 0); + const char repStr[] = "The World's greatings to you"; + TBaseSocket::Check(servSock.SendTo(repStr, sizeof(repStr), &servFromAddr)); + + // client + TBaseSocket::Check(cliSock.RecvFrom(buf, 256, &cliFromAddr)); + UNIT_ASSERT(strcmp(repStr, buf) == 0); + } + + void RunLocalDgramTest(const char* localServerSockName, const char* localClientSockName) { + char buf[256]; + TSockAddrLocalDgram servAddr(localServerSockName); + TSockAddrLocalDgram cliAddr(localClientSockName); + TSockAddrLocalDgram servFromAddr; + TSockAddrLocalDgram cliFromAddr; + TLocalDgramSocket cliSock; + TLocalDgramSocket servSock; + cliSock.CheckSock(); + servSock.CheckSock(); + + TBaseSocket::Check(cliSock.Bind(&cliAddr), "bind client"); + TBaseSocket::Check(servSock.Bind(&servAddr), "bind server"); + + // client + const char reqStr[] = "Hello, world!!!"; + TBaseSocket::Check(cliSock.SendTo(reqStr, sizeof(reqStr), 
&servAddr), "send from client"); + + // server + TBaseSocket::Check(servSock.RecvFrom(buf, 256, &servFromAddr), "receive from client"); + UNIT_ASSERT(strcmp(reqStr, buf) == 0); + const char repStr[] = "The World's greatings to you"; + TBaseSocket::Check(servSock.SendTo(repStr, sizeof(repStr), &servFromAddr), "send to client"); + + // client + TBaseSocket::Check(cliSock.RecvFrom(buf, 256, &cliFromAddr), "receive from server"); + UNIT_ASSERT(strcmp(repStr, buf) == 0); + } + + Y_UNIT_TEST(LocalDgramTest) { + const char* localServerSockName = "./serv_sock"; + const char* localClientSockName = "./cli_sock"; + RunLocalDgramTest(localServerSockName, localClientSockName); + NFs::Remove(localServerSockName); + NFs::Remove(localClientSockName); + } + + template <class A, class S> + void RunInetStreamTest(const char* ip) { + char buf[256]; + A servAddr(ip, 0); + A newAddr; + S cliSock; + S servSock; + S newSock; + cliSock.CheckSock(); + servSock.CheckSock(); + newSock.CheckSock(); + + // server + int yes = 1; + CheckedSetSockOpt(servSock, SOL_SOCKET, SO_REUSEADDR, yes, "servSock, SO_REUSEADDR"); + TBaseSocket::Check(servSock.Bind(&servAddr), "bind"); + TBaseSocket::Check(servSock.Listen(10), "listen"); + + // client + TBaseSocket::Check(cliSock.Connect(&servAddr), "connect"); + + // server + TBaseSocket::Check(servSock.Accept(&newSock, &newAddr), "accept"); + + // client + const char reqStr[] = "Hello, world!!!"; + TBaseSocket::Check(cliSock.Send(reqStr, sizeof(reqStr)), "send"); + + // server - new + TBaseSocket::Check(newSock.Recv(buf, 256), "recv"); + UNIT_ASSERT(strcmp(reqStr, buf) == 0); + const char repStr[] = "The World's greatings to you"; + TBaseSocket::Check(newSock.Send(repStr, sizeof(repStr)), "send"); + + // client + TBaseSocket::Check(cliSock.Recv(buf, 256), "recv"); + UNIT_ASSERT(strcmp(repStr, buf) == 0); + } + + Y_UNIT_TEST(InetStreamTest) { + RunInetStreamTest<TSockAddrInetStream, TInetStreamSocket>("127.0.0.1"); + } + + Y_UNIT_TEST(Inet6StreamTest) { + 
RunInetStreamTest<TSockAddrInet6Stream, TInet6StreamSocket>("::1"); + } + + void RunLocalStreamTest(const char* localServerSockName) { + char buf[256]; + TSockAddrLocalStream servAddr(localServerSockName); + TSockAddrLocalStream newAddr; + TLocalStreamSocket cliSock; + TLocalStreamSocket servSock; + TLocalStreamSocket newSock; + cliSock.CheckSock(); + servSock.CheckSock(); + newSock.CheckSock(); + + // server + TBaseSocket::Check(servSock.Bind(&servAddr), "bind"); + TBaseSocket::Check(servSock.Listen(10), "listen"); + + NThreading::TLegacyFuture<void> f([&]() { + // server + TBaseSocket::Check(servSock.Accept(&newSock, &newAddr), "accept"); + }); + + // client + TBaseSocket::Check(cliSock.Connect(&servAddr), "connect"); + + f.Get(); + + // client + const char reqStr[] = "Hello, world!!!"; + TBaseSocket::Check(cliSock.Send(reqStr, sizeof(reqStr)), "send"); + + // server - new + TBaseSocket::Check(newSock.Recv(buf, 256), "recv"); + UNIT_ASSERT(strcmp(reqStr, buf) == 0); + const char repStr[] = "The World's greatings to you"; + TBaseSocket::Check(newSock.Send(repStr, sizeof(repStr)), "send"); + + // client + TBaseSocket::Check(cliSock.Recv(buf, 256), "recv"); + UNIT_ASSERT(strcmp(repStr, buf) == 0); + } + + Y_UNIT_TEST(LocalStreamTest) { + const char* localServerSockName = "./serv_sock2"; + RunLocalStreamTest(localServerSockName); + NFs::Remove(localServerSockName); + } + +} diff --git a/util/network/socket.cpp b/util/network/socket.cpp new file mode 100644 index 0000000000..c1a42e849e --- /dev/null +++ b/util/network/socket.cpp @@ -0,0 +1,1288 @@ +#include "ip.h" +#include "socket.h" +#include "address.h" +#include "pollerimpl.h" +#include "iovec.h" + +#include <util/system/defaults.h> +#include <util/system/byteorder.h> + +#if defined(_unix_) + #include <netdb.h> + #include <sys/types.h> + #include <sys/socket.h> + #include <sys/un.h> + #include <sys/ioctl.h> + #include <netinet/in.h> + #include <netinet/tcp.h> + #include <arpa/inet.h> +#endif + +#if 
defined(_freebsd_) + #include <sys/module.h> + #define ACCEPT_FILTER_MOD + #include <sys/socketvar.h> +#endif + +#if defined(_win_) + #include <cerrno> + #include <winsock2.h> + #include <ws2tcpip.h> + #include <wspiapi.h> + + #include <util/system/compat.h> +#endif + +#include <util/generic/ylimits.h> + +#include <util/string/cast.h> +#include <util/stream/mem.h> +#include <util/system/datetime.h> +#include <util/system/error.h> +#include <util/memory/tempbuf.h> +#include <util/generic/singleton.h> +#include <util/generic/hash_set.h> + +#include <stddef.h> +#include <sys/uio.h> + +using namespace NAddr; + +#if defined(_win_) + +int inet_aton(const char* cp, struct in_addr* inp) { + sockaddr_in addr; + addr.sin_family = AF_INET; + int psz = sizeof(addr); + if (0 == WSAStringToAddress((char*)cp, AF_INET, nullptr, (LPSOCKADDR)&addr, &psz)) { + memcpy(inp, &addr.sin_addr, sizeof(in_addr)); + return 1; + } + return 0; +} + + #if (_WIN32_WINNT < 0x0600) +const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) { + if (af != AF_INET) { + errno = EINVAL; + return 0; + } + const ui8* ia = (ui8*)src; + if (snprintf(dst, size, "%u.%u.%u.%u", ia[0], ia[1], ia[2], ia[3]) >= (int)size) { + errno = ENOSPC; + return 0; + } + return dst; +} + +struct evpair { + int event; + int winevent; +}; + +static const evpair evpairs_to_win[] = { + {POLLIN, FD_READ | FD_CLOSE | FD_ACCEPT}, + {POLLRDNORM, FD_READ | FD_CLOSE | FD_ACCEPT}, + {POLLRDBAND, -1}, + {POLLPRI, -1}, + {POLLOUT, FD_WRITE | FD_CLOSE}, + {POLLWRNORM, FD_WRITE | FD_CLOSE}, + {POLLWRBAND, -1}, + {POLLERR, 0}, + {POLLHUP, 0}, + {POLLNVAL, 0}}; + +static const size_t nevpairs_to_win = sizeof(evpairs_to_win) / sizeof(evpairs_to_win[0]); + +static const evpair evpairs_to_unix[] = { + {FD_ACCEPT, POLLIN | POLLRDNORM}, + {FD_READ, POLLIN | POLLRDNORM}, + {FD_WRITE, POLLOUT | POLLWRNORM}, + {FD_CLOSE, POLLHUP}, +}; + +static const size_t nevpairs_to_unix = sizeof(evpairs_to_unix) / sizeof(evpairs_to_unix[0]); + 
+static int convert_events(int events, const evpair* evpairs, size_t nevpairs, bool ignoreUnknown) noexcept { + int result = 0; + for (size_t i = 0; i < nevpairs; ++i) { + int event = evpairs[i].event; + if (events & event) { + events ^= event; + long winEvent = evpairs[i].winevent; + if (winEvent == -1) + return -1; + if (winEvent == 0) + continue; + result |= winEvent; + } + } + if (events != 0 && !ignoreUnknown) + return -1; + return result; +} + +class TWSAEventHolder { +private: + HANDLE Event; + +public: + inline TWSAEventHolder(HANDLE event) noexcept + : Event(event) + { + } + + inline ~TWSAEventHolder() { + WSACloseEvent(Event); + } + + inline HANDLE Get() noexcept { + return Event; + } +}; + +int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept { + HANDLE rawEvent = WSACreateEvent(); + if (rawEvent == WSA_INVALID_EVENT) { + errno = EIO; + return -1; + } + + TWSAEventHolder event(rawEvent); + + int checked_sockets = 0; + + for (pollfd* fd = fds; fd < fds + nfds; ++fd) { + int win_events = convert_events(fd->events, evpairs_to_win, nevpairs_to_win, false); + if (win_events == -1) { + errno = EINVAL; + return -1; + } + fd->revents = 0; + if (WSAEventSelect(fd->fd, event.Get(), win_events)) { + int error = WSAGetLastError(); + if (error == WSAEINVAL || error == WSAENOTSOCK) { + fd->revents = POLLNVAL; + ++checked_sockets; + } else { + errno = EIO; + return -1; + } + } + fd_set readfds; + fd_set writefds; + struct timeval timeout = {0, 0}; + FD_ZERO(&readfds); + FD_ZERO(&writefds); + if (fd->events & POLLIN) { + FD_SET(fd->fd, &readfds); + } + if (fd->events & POLLOUT) { + FD_SET(fd->fd, &writefds); + } + int error = select(0, &readfds, &writefds, nullptr, &timeout); + if (error > 0) { + if (FD_ISSET(fd->fd, &readfds)) { + fd->revents |= POLLIN; + } + if (FD_ISSET(fd->fd, &writefds)) { + fd->revents |= POLLOUT; + } + ++checked_sockets; + } + } + + if (checked_sockets > 0) { + // returns without wait since we already have sockets in desired 
conditions + return checked_sockets; + } + + HANDLE events[] = {event.Get()}; + DWORD wait_result = WSAWaitForMultipleEvents(1, events, TRUE, timeout, FALSE); + if (wait_result == WSA_WAIT_TIMEOUT) + return 0; + else if (wait_result == WSA_WAIT_EVENT_0) { + for (pollfd* fd = fds; fd < fds + nfds; ++fd) { + if (fd->revents == POLLNVAL) + continue; + WSANETWORKEVENTS network_events; + if (WSAEnumNetworkEvents(fd->fd, event.Get(), &network_events)) { + errno = EIO; + return -1; + } + fd->revents = 0; + for (int i = 0; i < FD_MAX_EVENTS; ++i) { + if ((network_events.lNetworkEvents & (1 << i)) != 0 && network_events.iErrorCode[i]) { + fd->revents = POLLERR; + break; + } + } + if (fd->revents == POLLERR) + continue; + if (network_events.lNetworkEvents) { + fd->revents = static_cast<short>(convert_events(network_events.lNetworkEvents, evpairs_to_unix, nevpairs_to_unix, true)); + if (fd->revents & POLLHUP) { + fd->revents &= POLLHUP | POLLIN | POLLRDNORM; + } + } + } + int chanded_sockets = 0; + for (pollfd* fd = fds; fd < fds + nfds; ++fd) + if (fd->revents != 0) + ++chanded_sockets; + return chanded_sockets; + } else { + errno = EIO; + return -1; + } +} + #endif + +#endif + +bool GetRemoteAddr(SOCKET Socket, char* str, socklen_t size) { + if (!size) { + return false; + } + + TOpaqueAddr addr; + + if (getpeername(Socket, addr.MutableAddr(), addr.LenPtr()) != 0) { + return false; + } + + try { + TMemoryOutput out(str, size - 1); + + PrintHost(out, addr); + *out.Buf() = 0; + + return true; + } catch (...) 
{ + // ¯\_(ツ)_/¯ + } + + return false; +} + +void SetSocketTimeout(SOCKET s, long timeout) { + SetSocketTimeout(s, timeout, 0); +} + +void SetSocketTimeout(SOCKET s, long sec, long msec) { +#ifdef SO_SNDTIMEO + #ifdef _darwin_ + const timeval timeout = {sec, (__darwin_suseconds_t)msec * 1000}; + #elif defined(_unix_) + const timeval timeout = {sec, msec * 1000}; + #else + const int timeout = sec * 1000 + msec; + #endif + CheckedSetSockOpt(s, SOL_SOCKET, SO_RCVTIMEO, timeout, "recv timeout"); + CheckedSetSockOpt(s, SOL_SOCKET, SO_SNDTIMEO, timeout, "send timeout"); +#endif +} + +void SetLinger(SOCKET s, bool on, unsigned len) { +#ifdef SO_LINGER + struct linger l = {on, (u_short)len}; + + CheckedSetSockOpt(s, SOL_SOCKET, SO_LINGER, l, "linger"); +#endif +} + +void SetZeroLinger(SOCKET s) { + SetLinger(s, 1, 0); +} + +void SetKeepAlive(SOCKET s, bool value) { + CheckedSetSockOpt(s, SOL_SOCKET, SO_KEEPALIVE, (int)value, "keepalive"); +} + +void SetOutputBuffer(SOCKET s, unsigned value) { + CheckedSetSockOpt(s, SOL_SOCKET, SO_SNDBUF, value, "output buffer"); +} + +void SetInputBuffer(SOCKET s, unsigned value) { + CheckedSetSockOpt(s, SOL_SOCKET, SO_RCVBUF, value, "input buffer"); +} + +#if defined(_linux_) && !defined(SO_REUSEPORT) + #define SO_REUSEPORT 15 +#endif + +void SetReusePort(SOCKET s, bool value) { +#if defined(SO_REUSEPORT) + CheckedSetSockOpt(s, SOL_SOCKET, SO_REUSEPORT, (int)value, "reuse port"); +#else + Y_UNUSED(s); + Y_UNUSED(value); + ythrow TSystemError(ENOSYS) << "SO_REUSEPORT is not defined"; +#endif +} + +void SetNoDelay(SOCKET s, bool value) { + CheckedSetSockOpt(s, IPPROTO_TCP, TCP_NODELAY, (int)value, "tcp no delay"); +} + +void SetCloseOnExec(SOCKET s, bool value) { +#if defined(_unix_) + int flags = fcntl(s, F_GETFD); + if (flags == -1) { + ythrow TSystemError() << "fcntl() failed"; + } + if (value) { + flags |= FD_CLOEXEC; + } else { + flags &= ~FD_CLOEXEC; + } + if (fcntl(s, F_SETFD, flags) == -1) { + ythrow TSystemError() << "fcntl() 
failed"; + } +#else + Y_UNUSED(s); + Y_UNUSED(value); +#endif +} + +size_t GetMaximumSegmentSize(SOCKET s) { +#if defined(TCP_MAXSEG) + int val; + + if (GetSockOpt(s, IPPROTO_TCP, TCP_MAXSEG, val) == 0) { + return (size_t)val; + } +#endif + + /* + * probably a good guess... + */ + return 8192; +} + +size_t GetMaximumTransferUnit(SOCKET /*s*/) { + // for someone who'll dare to write it + // Linux: there rummored to be IP_MTU getsockopt() request + // FreeBSD: request to a socket of type PF_ROUTE + // with peer address as a destination argument + return 8192; +} + +int GetSocketToS(SOCKET s) { + TOpaqueAddr addr; + + if (getsockname(s, addr.MutableAddr(), addr.LenPtr()) < 0) { + ythrow TSystemError() << "getsockname() failed"; + } + + return GetSocketToS(s, &addr); +} + +int GetSocketToS(SOCKET s, const IRemoteAddr* addr) { + int result = 0; + + switch (addr->Addr()->sa_family) { + case AF_INET: + CheckedGetSockOpt(s, IPPROTO_IP, IP_TOS, result, "tos"); + break; + + case AF_INET6: +#ifdef IPV6_TCLASS + CheckedGetSockOpt(s, IPPROTO_IPV6, IPV6_TCLASS, result, "tos"); +#endif + break; + } + + return result; +} + +void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos) { + switch (addr->Addr()->sa_family) { + case AF_INET: + CheckedSetSockOpt(s, IPPROTO_IP, IP_TOS, tos, "tos"); + return; + + case AF_INET6: +#ifdef IPV6_TCLASS + CheckedSetSockOpt(s, IPPROTO_IPV6, IPV6_TCLASS, tos, "tos"); + return; +#endif + break; + } + + ythrow yexception() << "SetSocketToS unsupported for family " << addr->Addr()->sa_family; +} + +void SetSocketToS(SOCKET s, int tos) { + TOpaqueAddr addr; + + if (getsockname(s, addr.MutableAddr(), addr.LenPtr()) < 0) { + ythrow TSystemError() << "getsockname() failed"; + } + + SetSocketToS(s, &addr, tos); +} + +bool HasLocalAddress(SOCKET socket) { + TOpaqueAddr localAddr; + if (getsockname(socket, localAddr.MutableAddr(), localAddr.LenPtr()) != 0) { + ythrow TSystemError() << "HasLocalAddress: getsockname() failed. 
"; + } + if (IsLoopback(localAddr)) { + return true; + } + + TOpaqueAddr remoteAddr; + if (getpeername(socket, remoteAddr.MutableAddr(), remoteAddr.LenPtr()) != 0) { + ythrow TSystemError() << "HasLocalAddress: getpeername() failed. "; + } + return IsSame(localAddr, remoteAddr); +} + +namespace { +#if defined(_linux_) + #if !defined(TCP_FASTOPEN) + #define TCP_FASTOPEN 23 + #endif +#endif + +#if defined(TCP_FASTOPEN) + struct TTcpFastOpenFeature { + inline TTcpFastOpenFeature() + : HasFastOpen_(false) + { + TSocketHolder tmp(socket(AF_INET, SOCK_STREAM, 0)); + int val = 1; + int ret = SetSockOpt(tmp, IPPROTO_TCP, TCP_FASTOPEN, val); + HasFastOpen_ = (ret == 0); + } + + inline void SetFastOpen(SOCKET s, int qlen) const { + if (HasFastOpen_) { + CheckedSetSockOpt(s, IPPROTO_TCP, TCP_FASTOPEN, qlen, "setting TCP_FASTOPEN"); + } + } + + static inline const TTcpFastOpenFeature* Instance() noexcept { + return Singleton<TTcpFastOpenFeature>(); + } + + bool HasFastOpen_; + }; +#endif +} + +void SetTcpFastOpen(SOCKET s, int qlen) { +#if defined(TCP_FASTOPEN) + TTcpFastOpenFeature::Instance()->SetFastOpen(s, qlen); +#else + Y_UNUSED(s); + Y_UNUSED(qlen); +#endif +} + +static bool IsBlocked(int lasterr) noexcept { + return lasterr == EAGAIN || lasterr == EWOULDBLOCK; +} + +struct TUnblockingGuard { + SOCKET S_; + + TUnblockingGuard(SOCKET s) + : S_(s) + { + SetNonBlock(S_, true); + } + + ~TUnblockingGuard() { + SetNonBlock(S_, false); + } +}; + +static int MsgPeek(SOCKET s) { + int flags = MSG_PEEK; + +#if defined(_win_) + TUnblockingGuard unblocker(s); + Y_UNUSED(unblocker); +#else + flags |= MSG_DONTWAIT; +#endif + + char c; + return recv(s, &c, 1, flags); +} + +bool IsNotSocketClosedByOtherSide(SOCKET s) { + return HasSocketDataToRead(s) != ESocketReadStatus::SocketClosed; +} + +ESocketReadStatus HasSocketDataToRead(SOCKET s) { + const int r = MsgPeek(s); + if (r == -1 && IsBlocked(LastSystemError())) { + return ESocketReadStatus::NoData; + } + if (r > 0) { + return 
ESocketReadStatus::HasData; + } + return ESocketReadStatus::SocketClosed; +} + +#if defined(_win_) +static ssize_t DoSendMsg(SOCKET sock, const struct iovec* iov, int iovcnt) { + return writev(sock, iov, iovcnt); +} +#else +static ssize_t DoSendMsg(SOCKET sock, const struct iovec* iov, int iovcnt) { + struct msghdr message; + + Zero(message); + message.msg_iov = const_cast<struct iovec*>(iov); + message.msg_iovlen = iovcnt; + + return sendmsg(sock, &message, MSG_NOSIGNAL); +} +#endif + +void TSocketHolder::Close() noexcept { + if (Fd_ != INVALID_SOCKET) { + bool ok = (closesocket(Fd_) == 0); + if (!ok) { +// Do not quietly close bad descriptor, +// because often it means double close +// that is disasterous +#ifdef _win_ + Y_VERIFY(WSAGetLastError() != WSAENOTSOCK, "must not quietly close bad socket descriptor"); +#elif defined(_unix_) + Y_VERIFY(errno != EBADF, "must not quietly close bad descriptor: fd=%d", int(Fd_)); +#else + #error unsupported platform +#endif + } + + Fd_ = INVALID_SOCKET; + } +} + +class TSocket::TImpl: public TAtomicRefCount<TImpl> { + using TOps = TSocket::TOps; + +public: + inline TImpl(SOCKET fd, TOps* ops) + : Fd_(fd) + , Ops_(ops) + { + } + + inline ~TImpl() = default; + + inline SOCKET Fd() const noexcept { + return Fd_; + } + + inline ssize_t Send(const void* data, size_t len) { + return Ops_->Send(Fd_, data, len); + } + + inline ssize_t Recv(void* buf, size_t len) { + return Ops_->Recv(Fd_, buf, len); + } + + inline ssize_t SendV(const TPart* parts, size_t count) { + return Ops_->SendV(Fd_, parts, count); + } + + inline void Close() { + Fd_.Close(); + } + +private: + TSocketHolder Fd_; + TOps* Ops_; +}; + +template <> +void Out<const struct addrinfo*>(IOutputStream& os, const struct addrinfo* ai) { + if (ai->ai_flags & AI_CANONNAME) { + os << "`" << ai->ai_canonname << "' "; + } + + os << '['; + for (int i = 0; ai; ++i, ai = ai->ai_next) { + if (i > 0) { + os << ", "; + } + + os << (const IRemoteAddr&)TAddrInfo(ai); + } + os << ']'; 
+} + +template <> +void Out<struct addrinfo*>(IOutputStream& os, struct addrinfo* ai) { + Out<const struct addrinfo*>(os, static_cast<const struct addrinfo*>(ai)); +} + +template <> +void Out<TNetworkAddress>(IOutputStream& os, const TNetworkAddress& addr) { + os << &*addr.Begin(); +} + +static inline const struct addrinfo* Iterate(const struct addrinfo* addr, const struct addrinfo* addr0, const int sockerr) { + if (addr->ai_next) { + return addr->ai_next; + } + + ythrow TSystemError(sockerr) << "can not connect to " << addr0; +} + +static inline SOCKET DoConnectImpl(const struct addrinfo* res, const TInstant& deadLine) { + const struct addrinfo* addr0 = res; + + while (res) { + TSocketHolder s(socket(res->ai_family, res->ai_socktype, res->ai_protocol)); + + if (s.Closed()) { + res = Iterate(res, addr0, LastSystemError()); + + continue; + } + + SetNonBlock(s, true); + + if (connect(s, res->ai_addr, (int)res->ai_addrlen)) { + int err = LastSystemError(); + + if (err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK) { + /* + * must wait + */ + struct pollfd p = { + (SOCKET)s, + POLLOUT, + 0}; + + const ssize_t n = PollD(&p, 1, deadLine); + + /* + * timeout occured + */ + if (n < 0) { + ythrow TSystemError(-(int)n) << "can not connect"; + } + + CheckedGetSockOpt(s, SOL_SOCKET, SO_ERROR, err, "socket error"); + + if (!err) { + return s.Release(); + } + } + + res = Iterate(res, addr0, err); + + continue; + } + + return s.Release(); + } + + ythrow yexception() << "something went wrong: nullptr at addrinfo"; +} + +static inline SOCKET DoConnect(const struct addrinfo* res, const TInstant& deadLine) { + TSocketHolder ret(DoConnectImpl(res, deadLine)); + + SetNonBlock(ret, false); + + return ret.Release(); +} + +static inline ssize_t DoSendV(SOCKET fd, const struct iovec* iov, size_t count) { + ssize_t ret = -1; + do { + ret = DoSendMsg(fd, iov, (int)count); + } while (ret == -1 && errno == EINTR); + + if (ret < 0) { + return -LastSystemError(); + } + + return ret; +} + 
+template <bool isCompat> +struct TSender { + using TPart = TSocket::TPart; + + static inline ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) { + return DoSendV(fd, (const iovec*)parts, count); + } +}; + +template <> +struct TSender<false> { + using TPart = TSocket::TPart; + + static inline ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) { + TTempBuf tempbuf(sizeof(struct iovec) * count); + struct iovec* iov = (struct iovec*)tempbuf.Data(); + + for (size_t i = 0; i < count; ++i) { + struct iovec& io = iov[i]; + const TPart& part = parts[i]; + + io.iov_base = (char*)part.buf; + io.iov_len = part.len; + } + + return DoSendV(fd, iov, count); + } +}; + +class TCommonSockOps: public TSocket::TOps { + using TPart = TSocket::TPart; + +public: + inline TCommonSockOps() noexcept { + } + + ~TCommonSockOps() override = default; + + ssize_t Send(SOCKET fd, const void* data, size_t len) override { + ssize_t ret = -1; + do { + ret = send(fd, (const char*)data, (int)len, MSG_NOSIGNAL); + } while (ret == -1 && errno == EINTR); + + if (ret < 0) { + return -LastSystemError(); + } + + return ret; + } + + ssize_t Recv(SOCKET fd, void* buf, size_t len) override { + ssize_t ret = -1; + do { + ret = recv(fd, (char*)buf, (int)len, 0); + } while (ret == -1 && errno == EINTR); + + if (ret < 0) { + return -LastSystemError(); + } + + return ret; + } + + ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) override { + ssize_t ret = SendVImpl(fd, parts, count); + + if (ret < 0) { + return ret; + } + + size_t len = TContIOVector::Bytes(parts, count); + + if ((size_t)ret == len) { + return ret; + } + + return SendVPartial(fd, parts, count, ret); + } + + inline ssize_t SendVImpl(SOCKET fd, const TPart* parts, size_t count) { + return TSender < (sizeof(iovec) == sizeof(TPart)) && (offsetof(iovec, iov_base) == offsetof(TPart, buf)) && (offsetof(iovec, iov_len) == offsetof(TPart, len)) > ::SendV(fd, parts, count); + } + + ssize_t SendVPartial(SOCKET fd, const TPart* 
constParts, size_t count, size_t written); +}; + +ssize_t TCommonSockOps::SendVPartial(SOCKET fd, const TPart* constParts, size_t count, size_t written) { + TTempBuf tempbuf(sizeof(TPart) * count); + TPart* parts = (TPart*)tempbuf.Data(); + + for (size_t i = 0; i < count; ++i) { + parts[i] = constParts[i]; + } + + TContIOVector vec(parts, count); + vec.Proceed(written); + + while (!vec.Complete()) { + ssize_t ret = SendVImpl(fd, vec.Parts(), vec.Count()); + + if (ret < 0) { + return ret; + } + + written += ret; + + vec.Proceed((size_t)ret); + } + + return written; +} + +static inline TSocket::TOps* GetCommonSockOps() noexcept { + return Singleton<TCommonSockOps>(); +} + +TSocket::TSocket() + : Impl_(new TImpl(INVALID_SOCKET, GetCommonSockOps())) +{ +} + +TSocket::TSocket(SOCKET fd) + : Impl_(new TImpl(fd, GetCommonSockOps())) +{ +} + +TSocket::TSocket(SOCKET fd, TOps* ops) + : Impl_(new TImpl(fd, ops)) +{ +} + +TSocket::TSocket(const TNetworkAddress& addr) + : Impl_(new TImpl(DoConnect(addr.Info(), TInstant::Max()), GetCommonSockOps())) +{ +} + +TSocket::TSocket(const TNetworkAddress& addr, const TDuration& timeOut) + : Impl_(new TImpl(DoConnect(addr.Info(), timeOut.ToDeadLine()), GetCommonSockOps())) +{ +} + +TSocket::TSocket(const TNetworkAddress& addr, const TInstant& deadLine) + : Impl_(new TImpl(DoConnect(addr.Info(), deadLine), GetCommonSockOps())) +{ +} + +TSocket::~TSocket() = default; + +SOCKET TSocket::Fd() const noexcept { + return Impl_->Fd(); +} + +ssize_t TSocket::Send(const void* data, size_t len) { + return Impl_->Send(data, len); +} + +ssize_t TSocket::Recv(void* buf, size_t len) { + return Impl_->Recv(buf, len); +} + +ssize_t TSocket::SendV(const TPart* parts, size_t count) { + return Impl_->SendV(parts, count); +} + +void TSocket::Close() { + Impl_->Close(); +} + +TSocketInput::TSocketInput(const TSocket& s) noexcept + : S_(s) +{ +} + +TSocketInput::~TSocketInput() = default; + +size_t TSocketInput::DoRead(void* buf, size_t len) { + const ssize_t 
ret = S_.Recv(buf, len); + + if (ret >= 0) { + return (size_t)ret; + } + + ythrow TSystemError(-(int)ret) << "can not read from socket input stream"; +} + +TSocketOutput::TSocketOutput(const TSocket& s) noexcept + : S_(s) +{ +} + +TSocketOutput::~TSocketOutput() { + try { + Finish(); + } catch (...) { + // ¯\_(ツ)_/¯ + } +} + +void TSocketOutput::DoWrite(const void* buf, size_t len) { + size_t send = 0; + while (len) { + const ssize_t ret = S_.Send(buf, len); + + if (ret < 0) { + ythrow TSystemError(-(int)ret) << "can not write to socket output stream; " << send << " bytes already send"; + } + buf = (const char*)buf + ret; + len -= ret; + send += ret; + } +} + +void TSocketOutput::DoWriteV(const TPart* parts, size_t count) { + const ssize_t ret = S_.SendV(parts, count); + + if (ret < 0) { + ythrow TSystemError(-(int)ret) << "can not writev to socket output stream"; + } + + /* + * todo for nonblocking sockets? + */ +} + +namespace { + //https://bugzilla.mozilla.org/attachment.cgi?id=503263&action=diff + + struct TLocalNames: public THashSet<TStringBuf> { + inline TLocalNames() { + insert("localhost"); + insert("localhost.localdomain"); + insert("localhost6"); + insert("localhost6.localdomain6"); + insert("::1"); + } + + inline bool IsLocalName(const char* name) const noexcept { + struct sockaddr_in sa; + memset(&sa, 0, sizeof(sa)); + + if (inet_pton(AF_INET, name, &(sa.sin_addr)) == 1) { + return (InetToHost(sa.sin_addr.s_addr) >> 24) == 127; + } + + return contains(name); + } + }; +} + +class TNetworkAddress::TImpl: public TAtomicRefCount<TImpl> { +private: + class TAddrInfoDeleter { + public: + TAddrInfoDeleter(bool useFreeAddrInfo = true) + : UseFreeAddrInfo_(useFreeAddrInfo) + { + } + + void operator()(struct addrinfo* ai) noexcept { + if (!UseFreeAddrInfo_ && ai != NULL) { + if (ai->ai_addr != NULL) { + free(ai->ai_addr); + } + + struct addrinfo* p; + while (ai != NULL) { + p = ai; + ai = ai->ai_next; + free(p->ai_canonname); + free(p); + } + } else if (ai != 
NULL) { + freeaddrinfo(ai); + } + } + + private: + bool UseFreeAddrInfo_ = true; + }; + +public: + inline TImpl(const char* host, ui16 port, int flags) + : Info_(nullptr, TAddrInfoDeleter{}) + { + const TString port_st(ToString(port)); + struct addrinfo hints; + + memset(&hints, 0, sizeof(hints)); + + hints.ai_flags = flags; + hints.ai_family = PF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + + if (!host) { + hints.ai_flags |= AI_PASSIVE; + } else { + if (!Singleton<TLocalNames>()->IsLocalName(host)) { + hints.ai_flags |= AI_ADDRCONFIG; + } + } + + struct addrinfo* pai = NULL; + const int error = getaddrinfo(host, port_st.data(), &hints, &pai); + + if (error) { + TAddrInfoDeleter()(pai); + ythrow TNetworkResolutionError(error) << ": can not resolve " << host << ":" << port; + } + + Info_.reset(pai); + } + + inline TImpl(const char* path, int flags) + : Info_(nullptr, TAddrInfoDeleter{/* useFreeAddrInfo = */ false}) + { + THolder<struct sockaddr_un, TFree> sockAddr( + reinterpret_cast<struct sockaddr_un*>(malloc(sizeof(struct sockaddr_un)))); + + Y_ENSURE(strlen(path) < sizeof(sockAddr->sun_path), "Unix socket path more than " << sizeof(sockAddr->sun_path)); + sockAddr->sun_family = AF_UNIX; + strcpy(sockAddr->sun_path, path); + + TAddrInfoPtr hints(reinterpret_cast<struct addrinfo*>(malloc(sizeof(struct addrinfo))), TAddrInfoDeleter{/* useFreeAddrInfo = */ false}); + memset(hints.get(), 0, sizeof(*hints)); + + hints->ai_flags = flags; + hints->ai_family = AF_UNIX; + hints->ai_socktype = SOCK_STREAM; + hints->ai_addrlen = sizeof(*sockAddr); + hints->ai_addr = (struct sockaddr*)sockAddr.Release(); + + Info_.reset(hints.release()); + } + + inline struct addrinfo* Info() const noexcept { + return Info_.get(); + } + +private: + using TAddrInfoPtr = std::unique_ptr<struct addrinfo, TAddrInfoDeleter>; + + TAddrInfoPtr Info_; +}; + +TNetworkAddress::TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags) + : Impl_(new TImpl(unixSocketPath.Path.data(), flags)) +{ 
+} + +TNetworkAddress::TNetworkAddress(const TString& host, ui16 port, int flags) + : Impl_(new TImpl(host.data(), port, flags)) +{ +} + +TNetworkAddress::TNetworkAddress(const TString& host, ui16 port) + : Impl_(new TImpl(host.data(), port, 0)) +{ +} + +TNetworkAddress::TNetworkAddress(ui16 port) + : Impl_(new TImpl(nullptr, port, 0)) +{ +} + +TNetworkAddress::~TNetworkAddress() = default; + +struct addrinfo* TNetworkAddress::Info() const noexcept { + return Impl_->Info(); +} + +TNetworkResolutionError::TNetworkResolutionError(int error) { + const char* errMsg = nullptr; +#ifdef _win_ + errMsg = LastSystemErrorText(error); // gai_strerror is not thread-safe on Windows +#else + errMsg = gai_strerror(error); +#endif + (*this) << errMsg << "(" << error; + +#if defined(_unix_) + if (error == EAI_SYSTEM) { + (*this) << "; errno=" << LastSystemError(); + } +#endif + + (*this) << "): "; +} + +#if defined(_unix_) +static inline int GetFlags(int fd) { + const int ret = fcntl(fd, F_GETFL); + + if (ret == -1) { + ythrow TSystemError() << "can not get fd flags"; + } + + return ret; +} + +static inline void SetFlags(int fd, int flags) { + if (fcntl(fd, F_SETFL, flags) == -1) { + ythrow TSystemError() << "can not set fd flags"; + } +} + +static inline void EnableFlag(int fd, int flag) { + const int oldf = GetFlags(fd); + const int newf = oldf | flag; + + if (oldf != newf) { + SetFlags(fd, newf); + } +} + +static inline void DisableFlag(int fd, int flag) { + const int oldf = GetFlags(fd); + const int newf = oldf & (~flag); + + if (oldf != newf) { + SetFlags(fd, newf); + } +} + +static inline void SetFlag(int fd, int flag, bool value) { + if (value) { + EnableFlag(fd, flag); + } else { + DisableFlag(fd, flag); + } +} + +static inline bool FlagsAreEnabled(int fd, int flags) { + return GetFlags(fd) & flags; +} +#endif + +#if defined(_win_) +static inline void SetNonBlockSocket(SOCKET fd, int value) { + unsigned long inbuf = value; + unsigned long outbuf = 0; + DWORD written = 0; + 
+ if (!inbuf) { + WSAEventSelect(fd, nullptr, 0); + } + + if (WSAIoctl(fd, FIONBIO, &inbuf, sizeof(inbuf), &outbuf, sizeof(outbuf), &written, 0, 0) == SOCKET_ERROR) { + ythrow TSystemError() << "can not set non block socket state"; + } +} + +static inline bool IsNonBlockSocket(SOCKET fd) { + unsigned long buf = 0; + + if (WSAIoctl(fd, FIONBIO, 0, 0, &buf, sizeof(buf), 0, 0, 0) == SOCKET_ERROR) { + ythrow TSystemError() << "can not get non block socket state"; + } + + return buf; +} +#endif + +void SetNonBlock(SOCKET fd, bool value) { +#if defined(_unix_) + #if defined(FIONBIO) + Y_UNUSED(SetFlag); // shut up clang about unused function + int nb = value; + + if (ioctl(fd, FIONBIO, &nb) < 0) { + ythrow TSystemError() << "ioctl failed"; + } + #else + SetFlag(fd, O_NONBLOCK, value); + #endif +#elif defined(_win_) + SetNonBlockSocket(fd, value); +#else + #error todo +#endif +} + +bool IsNonBlock(SOCKET fd) { +#if defined(_unix_) + return FlagsAreEnabled(fd, O_NONBLOCK); +#elif defined(_win_) + return IsNonBlockSocket(fd); +#else + #error todo +#endif +} + +void SetDeferAccept(SOCKET s) { + (void)s; + +#if defined(TCP_DEFER_ACCEPT) + CheckedSetSockOpt(s, IPPROTO_TCP, TCP_DEFER_ACCEPT, 10, "defer accept"); +#endif + +#if defined(SO_ACCEPTFILTER) + struct accept_filter_arg afa; + + Zero(afa); + strcpy(afa.af_name, "dataready"); + SetSockOpt(s, SOL_SOCKET, SO_ACCEPTFILTER, afa); +#endif +} + +ssize_t PollD(struct pollfd fds[], nfds_t nfds, const TInstant& deadLine) noexcept { + TInstant now = TInstant::Now(); + + do { + const TDuration toWait = PollStep(deadLine, now); + const int res = poll(fds, nfds, MicroToMilli(toWait.MicroSeconds())); + + if (res > 0) { + return res; + } + + if (res < 0) { + const int err = LastSystemError(); + + if (err != ETIMEDOUT && err != EINTR) { + return -err; + } + } + } while ((now = TInstant::Now()) < deadLine); + + return -ETIMEDOUT; +} + +void ShutDown(SOCKET s, int mode) { + if (shutdown(s, mode)) { + ythrow TSystemError() << "shutdown 
socket error"; + } +} + +extern "C" bool IsReusePortAvailable() { +// SO_REUSEPORT is always defined for linux builds, see SetReusePort() implementation above +#if defined(SO_REUSEPORT) + + class TCtx { + public: + TCtx() { + TSocketHolder sock(::socket(AF_INET, SOCK_STREAM, 0)); + const int e1 = errno; + if (sock == INVALID_SOCKET) { + ythrow TSystemError(e1) << "Cannot create AF_INET socket"; + } + int val; + const int ret = GetSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, val); + const int e2 = errno; + if (ret == 0) { + Flag_ = true; + } else { + if (e2 == ENOPROTOOPT) { + Flag_ = false; + } else { + ythrow TSystemError(e2) << "Unexpected error in getsockopt"; + } + } + } + + static inline const TCtx* Instance() noexcept { + return Singleton<TCtx>(); + } + + public: + bool Flag_; + }; + + return TCtx::Instance()->Flag_; +#else + return false; +#endif +} diff --git a/util/network/socket.h b/util/network/socket.h new file mode 100644 index 0000000000..357ad4079b --- /dev/null +++ b/util/network/socket.h @@ -0,0 +1,431 @@ +#pragma once + +#include "init.h" + +#include <util/system/yassert.h> +#include <util/system/defaults.h> +#include <util/system/error.h> +#include <util/stream/output.h> +#include <util/stream/input.h> +#include <util/generic/ptr.h> +#include <util/generic/yexception.h> +#include <util/generic/noncopyable.h> +#include <util/datetime/base.h> + +#include <cerrno> + +#ifndef INET_ADDRSTRLEN + #define INET_ADDRSTRLEN 16 +#endif + +#if defined(_unix_) + #define get_host_error() h_errno +#elif defined(_win_) + #pragma comment(lib, "Ws2_32.lib") + + #if _WIN32_WINNT < 0x0600 +struct pollfd { + SOCKET fd; + short events; + short revents; +}; + + #define POLLIN (1 << 0) + #define POLLRDNORM (1 << 1) + #define POLLRDBAND (1 << 2) + #define POLLPRI (1 << 3) + #define POLLOUT (1 << 4) + #define POLLWRNORM (1 << 5) + #define POLLWRBAND (1 << 6) + #define POLLERR (1 << 7) + #define POLLHUP (1 << 8) + #define POLLNVAL (1 << 9) + +const char* inet_ntop(int af, const 
void* src, char* dst, socklen_t size); +int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept; + #else + #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) + #endif + +int inet_aton(const char* cp, struct in_addr* inp); + + #define get_host_error() WSAGetLastError() + + #define SHUT_RD SD_RECEIVE + #define SHUT_WR SD_SEND + #define SHUT_RDWR SD_BOTH + + #define INFTIM (-1) +#endif + +template <class T> +static inline int SetSockOpt(SOCKET s, int level, int optname, T opt) noexcept { + return setsockopt(s, level, optname, (const char*)&opt, sizeof(opt)); +} + +template <class T> +static inline int GetSockOpt(SOCKET s, int level, int optname, T& opt) noexcept { + socklen_t len = sizeof(opt); + + return getsockopt(s, level, optname, (char*)&opt, &len); +} + +template <class T> +static inline void CheckedSetSockOpt(SOCKET s, int level, int optname, T opt, const char* err) { + if (SetSockOpt<T>(s, level, optname, opt)) { + ythrow TSystemError() << "setsockopt() failed for " << err; + } +} + +template <class T> +static inline void CheckedGetSockOpt(SOCKET s, int level, int optname, T& opt, const char* err) { + if (GetSockOpt<T>(s, level, optname, opt)) { + ythrow TSystemError() << "getsockopt() failed for " << err; + } +} + +static inline void FixIPv6ListenSocket(SOCKET s) { +#if defined(IPV6_V6ONLY) + SetSockOpt(s, IPPROTO_IPV6, IPV6_V6ONLY, 1); +#else + (void)s; +#endif +} + +namespace NAddr { + class IRemoteAddr; +} + +void SetSocketTimeout(SOCKET s, long timeout); +void SetSocketTimeout(SOCKET s, long sec, long msec); +void SetNoDelay(SOCKET s, bool value); +void SetKeepAlive(SOCKET s); +void SetLinger(SOCKET s, bool on, unsigned len); +void SetZeroLinger(SOCKET s); +void SetKeepAlive(SOCKET s, bool value); +void SetCloseOnExec(SOCKET s, bool value); +void SetOutputBuffer(SOCKET s, unsigned value); +void SetInputBuffer(SOCKET s, unsigned value); +void SetReusePort(SOCKET s, bool value); +void ShutDown(SOCKET s, int mode); +bool 
GetRemoteAddr(SOCKET s, char* str, socklen_t size); +size_t GetMaximumSegmentSize(SOCKET s); +size_t GetMaximumTransferUnit(SOCKET s); +void SetDeferAccept(SOCKET s); +void SetSocketToS(SOCKET s, int tos); +void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos); +int GetSocketToS(SOCKET s); +int GetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr); +void SetTcpFastOpen(SOCKET s, int qlen); +/** + * Deprecated, consider using HasSocketDataToRead instead. + **/ +bool IsNotSocketClosedByOtherSide(SOCKET s); +enum class ESocketReadStatus { + HasData, + NoData, + SocketClosed +}; +/** + * Useful for keep-alive connections. + **/ +ESocketReadStatus HasSocketDataToRead(SOCKET s); +/** + * Determines whether connection on socket is local (same machine) or not. + **/ +bool HasLocalAddress(SOCKET socket); + +/** + * Runtime check if current kernel supports SO_REUSEPORT option. + **/ +extern "C" bool IsReusePortAvailable(); + +bool IsNonBlock(SOCKET fd); +void SetNonBlock(SOCKET fd, bool nonBlock = true); + +struct addrinfo; + +class TNetworkResolutionError: public yexception { +public: + // @param error error code (EAI_XXX) returned by getaddrinfo or getnameinfo (not errno) + TNetworkResolutionError(int error); +}; + +struct TUnixSocketPath { + TString Path; + + // Constructs a unix domain socket path from a string containing a filesystem path + // TUnixSocketPath("/tmp/unixsocket") -> "/tmp/unixsocket" + explicit TUnixSocketPath(const TString& path) + : Path(path) + { + } +}; + +class TNetworkAddress { + friend class TSocket; + +public: + class TIterator { + public: + inline TIterator(struct addrinfo* begin) + : C_(begin) + { + } + + inline void Next() noexcept { + C_ = C_->ai_next; + } + + inline TIterator operator++(int) noexcept { + TIterator old(*this); + + Next(); + + return old; + } + + inline TIterator& operator++() noexcept { + Next(); + + return *this; + } + + friend inline bool operator==(const TIterator& l, const TIterator& r) noexcept { + 
return l.C_ == r.C_; + } + + friend inline bool operator!=(const TIterator& l, const TIterator& r) noexcept { + return !(l == r); + } + + inline struct addrinfo& operator*() const noexcept { + return *C_; + } + + inline struct addrinfo* operator->() const noexcept { + return C_; + } + + private: + struct addrinfo* C_; + }; + + TNetworkAddress(ui16 port); + TNetworkAddress(const TString& host, ui16 port); + TNetworkAddress(const TString& host, ui16 port, int flags); + TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags = 0); + ~TNetworkAddress(); + + inline TIterator Begin() const noexcept { + return TIterator(Info()); + } + + inline TIterator End() const noexcept { + return TIterator(nullptr); + } + +private: + struct addrinfo* Info() const noexcept; + +private: + class TImpl; + TSimpleIntrusivePtr<TImpl> Impl_; +}; + +class TSocket; + +class TSocketHolder: public TMoveOnly { +public: + inline TSocketHolder() + : Fd_(INVALID_SOCKET) + { + } + + inline TSocketHolder(SOCKET fd) + : Fd_(fd) + { + } + + inline TSocketHolder(TSocketHolder&& other) noexcept { + Fd_ = other.Fd_; + other.Fd_ = INVALID_SOCKET; + } + + inline TSocketHolder& operator=(TSocketHolder&& other) noexcept { + Close(); + Swap(other); + + return *this; + } + + inline ~TSocketHolder() { + Close(); + } + + inline SOCKET Release() noexcept { + SOCKET ret = Fd_; + Fd_ = INVALID_SOCKET; + return ret; + } + + void Close() noexcept; + + inline void ShutDown(int mode) const { + ::ShutDown(Fd_, mode); + } + + inline void Swap(TSocketHolder& r) noexcept { + DoSwap(Fd_, r.Fd_); + } + + inline bool Closed() const noexcept { + return Fd_ == INVALID_SOCKET; + } + + inline operator SOCKET() const noexcept { + return Fd_; + } + +private: + SOCKET Fd_; + + // do not allow construction of TSocketHolder from TSocket + TSocketHolder(const TSocket& fd); +}; + +class TSocket { +public: + using TPart = IOutputStream::TPart; + + class TOps { + public: + inline TOps() noexcept = default; + virtual ~TOps() = 
default; + + virtual ssize_t Send(SOCKET fd, const void* data, size_t len) = 0; + virtual ssize_t Recv(SOCKET fd, void* buf, size_t len) = 0; + virtual ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) = 0; + }; + + TSocket(); + TSocket(SOCKET fd); + TSocket(SOCKET fd, TOps* ops); + TSocket(const TNetworkAddress& addr); + TSocket(const TNetworkAddress& addr, const TDuration& timeOut); + TSocket(const TNetworkAddress& addr, const TInstant& deadLine); + + ~TSocket(); + + template <class T> + inline void SetSockOpt(int level, int optname, T opt) { + CheckedSetSockOpt(Fd(), level, optname, opt, "TSocket"); + } + + inline void SetSocketTimeout(long timeout) { + ::SetSocketTimeout(Fd(), timeout); + } + + inline void SetSocketTimeout(long sec, long msec) { + ::SetSocketTimeout(Fd(), sec, msec); + } + + inline void SetNoDelay(bool value) { + ::SetNoDelay(Fd(), value); + } + + inline void SetLinger(bool on, unsigned len) { + ::SetLinger(Fd(), on, len); + } + + inline void SetZeroLinger() { + ::SetZeroLinger(Fd()); + } + + inline void SetKeepAlive(bool value) { + ::SetKeepAlive(Fd(), value); + } + + inline void SetOutputBuffer(unsigned value) { + ::SetOutputBuffer(Fd(), value); + } + + inline void SetInputBuffer(unsigned value) { + ::SetInputBuffer(Fd(), value); + } + + inline size_t MaximumSegmentSize() const { + return GetMaximumSegmentSize(Fd()); + } + + inline size_t MaximumTransferUnit() const { + return GetMaximumTransferUnit(Fd()); + } + + inline void ShutDown(int mode) const { + ::ShutDown(Fd(), mode); + } + + void Close(); + + ssize_t Send(const void* data, size_t len); + ssize_t Recv(void* buf, size_t len); + + /* + * scatter/gather io + */ + ssize_t SendV(const TPart* parts, size_t count); + + inline operator SOCKET() const noexcept { + return Fd(); + } + +private: + SOCKET Fd() const noexcept; + +private: + class TImpl; + TSimpleIntrusivePtr<TImpl> Impl_; +}; + +class TSocketInput: public IInputStream { +public: + TSocketInput(const TSocket& s) noexcept; 
+ ~TSocketInput() override; + + TSocketInput(TSocketInput&&) noexcept = default; + TSocketInput& operator=(TSocketInput&&) noexcept = default; + + const TSocket& GetSocket() const noexcept { + return S_; + } + +private: + size_t DoRead(void* buf, size_t len) override; + +private: + TSocket S_; +}; + +class TSocketOutput: public IOutputStream { +public: + TSocketOutput(const TSocket& s) noexcept; + ~TSocketOutput() override; + + TSocketOutput(TSocketOutput&&) noexcept = default; + TSocketOutput& operator=(TSocketOutput&&) noexcept = default; + + const TSocket& GetSocket() const noexcept { + return S_; + } + +private: + void DoWrite(const void* buf, size_t len) override; + void DoWriteV(const TPart* parts, size_t count) override; + +private: + TSocket S_; +}; + +//return -(error code) if error occurred, or number of ready fds +ssize_t PollD(struct pollfd fds[], nfds_t nfds, const TInstant& deadLine) noexcept; diff --git a/util/network/socket_ut.cpp b/util/network/socket_ut.cpp new file mode 100644 index 0000000000..6b20e11f70 --- /dev/null +++ b/util/network/socket_ut.cpp @@ -0,0 +1,341 @@ +#include "socket.h" + +#include "pair.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/string/builder.h> +#include <util/generic/vector.h> + +#include <ctime> + +#ifdef _linux_ + #include <linux/version.h> + #include <sys/utsname.h> +#endif + +class TSockTest: public TTestBase { + UNIT_TEST_SUITE(TSockTest); + UNIT_TEST(TestSock); + UNIT_TEST(TestTimeout); +#ifndef _win_ // Test hangs on Windows + UNIT_TEST_EXCEPTION(TestConnectionRefused, yexception); +#endif + UNIT_TEST(TestNetworkResolutionError); + UNIT_TEST(TestNetworkResolutionErrorMessage); + UNIT_TEST(TestBrokenPipe); + UNIT_TEST(TestClose); + UNIT_TEST(TestReusePortAvailCheck); + UNIT_TEST_SUITE_END(); + +public: + void TestSock(); + void TestTimeout(); + void TestConnectionRefused(); + void TestNetworkResolutionError(); + void TestNetworkResolutionErrorMessage(); + void TestBrokenPipe(); + void 
TestClose(); + void TestReusePortAvailCheck(); +}; + +UNIT_TEST_SUITE_REGISTRATION(TSockTest); + +void TSockTest::TestSock() { + TNetworkAddress addr("yandex.ru", 80); + TSocket s(addr); + TSocketOutput so(s); + TSocketInput si(s); + const TStringBuf req = "GET / HTTP/1.1\r\nHost: yandex.ru\r\n\r\n"; + + so.Write(req.data(), req.size()); + + UNIT_ASSERT(!si.ReadLine().empty()); +} + +void TSockTest::TestTimeout() { + static const int timeout = 1000; + i64 startTime = millisec(); + try { + TNetworkAddress addr("localhost", 1313); + TSocket s(addr, TDuration::MilliSeconds(timeout)); + } catch (const yexception&) { + } + int realTimeout = (int)(millisec() - startTime); + if (realTimeout > timeout + 2000) { + TString err = TStringBuilder() << "Timeout exceeded: " << realTimeout << " ms (expected " << timeout << " ms)"; + UNIT_FAIL(err); + } +} + +void TSockTest::TestConnectionRefused() { + TNetworkAddress addr("localhost", 1313); + TSocket s(addr); +} + +void TSockTest::TestNetworkResolutionError() { + TString errMsg; + try { + TNetworkAddress addr("", 0); + } catch (const TNetworkResolutionError& e) { + errMsg = e.what(); + } + + if (errMsg.empty()) { + return; // on Windows getaddrinfo("", 0, ...) 
returns "OK" + } + + int expectedErr = EAI_NONAME; + TString expectedErrMsg = gai_strerror(expectedErr); + if (errMsg.find(expectedErrMsg) == TString::npos) { + UNIT_FAIL("TNetworkResolutionError contains\nInvalid msg: " + errMsg + "\nExpected msg: " + expectedErrMsg + "\n"); + } +} + +void TSockTest::TestNetworkResolutionErrorMessage() { +#ifdef _unix_ + auto str = [](int code) -> TString { + return TNetworkResolutionError(code).what(); + }; + + auto expected = [](int code) -> TString { + return gai_strerror(code); + }; + + struct TErrnoGuard { + TErrnoGuard() + : PrevValue_(errno) + { + } + + ~TErrnoGuard() { + errno = PrevValue_; + } + + private: + int PrevValue_; + } g; + + UNIT_ASSERT_VALUES_EQUAL(expected(0) + "(0): ", str(0)); + UNIT_ASSERT_VALUES_EQUAL(expected(-9) + "(-9): ", str(-9)); + + errno = 0; + UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=0): ", + str(EAI_SYSTEM)); + errno = 110; + UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=110): ", + str(EAI_SYSTEM)); +#endif +} + +class TTempEnableSigPipe { +public: + TTempEnableSigPipe() { + OriginalSigHandler_ = signal(SIGPIPE, SIG_DFL); + Y_VERIFY(OriginalSigHandler_ != SIG_ERR); + } + + ~TTempEnableSigPipe() { + auto ret = signal(SIGPIPE, OriginalSigHandler_); + Y_VERIFY(ret != SIG_ERR); + } + +private: + void (*OriginalSigHandler_)(int); +}; + +void TSockTest::TestBrokenPipe() { + TTempEnableSigPipe guard; + + SOCKET socks[2]; + + int ret = SocketPair(socks); + UNIT_ASSERT_VALUES_EQUAL(ret, 0); + + TSocket sender(socks[0]); + TSocket receiver(socks[1]); + receiver.ShutDown(SHUT_RDWR); + int sent = sender.Send("FOO", 3); + UNIT_ASSERT(sent < 0); + + IOutputStream::TPart parts[] = { + {"foo", 3}, + {"bar", 3}, + }; + sent = sender.SendV(parts, 2); + UNIT_ASSERT(sent < 0); +} + +void TSockTest::TestClose() { + SOCKET socks[2]; + + UNIT_ASSERT_EQUAL(SocketPair(socks), 0); + TSocket receiver(socks[1]); + + 
UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), socks[1]); + +#if defined _linux_ + UNIT_ASSERT_GE(fcntl(socks[1], F_GETFD), 0); + receiver.Close(); + UNIT_ASSERT_EQUAL(fcntl(socks[1], F_GETFD), -1); +#else + receiver.Close(); +#endif + + UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), INVALID_SOCKET); +} + +void TSockTest::TestReusePortAvailCheck() { +#if defined _linux_ + utsname sysInfo; + Y_VERIFY(!uname(&sysInfo), "Error while call uname: %s", LastSystemErrorText()); + TStringBuf release(sysInfo.release); + release = release.substr(0, release.find_first_not_of(".0123456789")); + int v1 = FromString<int>(release.NextTok('.')); + int v2 = FromString<int>(release.NextTok('.')); + int v3 = FromString<int>(release.NextTok('.')); + int linuxVersionCode = KERNEL_VERSION(v1, v2, v3); + if (linuxVersionCode >= KERNEL_VERSION(3, 9, 1)) { + // new kernels support SO_REUSEPORT + UNIT_ASSERT(true == IsReusePortAvailable()); + UNIT_ASSERT(true == IsReusePortAvailable()); + } else { + // older kernels may or may not support SO_REUSEPORT + // just check that it doesn't crash or throw + (void)IsReusePortAvailable(); + (void)IsReusePortAvailable(); + } +#else + // check that it doesn't crash or throw + (void)IsReusePortAvailable(); + (void)IsReusePortAvailable(); +#endif +} + +class TPollTest: public TTestBase { + UNIT_TEST_SUITE(TPollTest); + UNIT_TEST(TestPollInOut); + UNIT_TEST_SUITE_END(); + +public: + inline TPollTest() { + srand(static_cast<unsigned int>(time(nullptr))); + } + + void TestPollInOut(); + +private: + sockaddr_in GetAddress(ui32 ip, ui16 port); + SOCKET CreateSocket(); + SOCKET StartServerSocket(ui16 port, int backlog); + SOCKET StartClientSocket(ui32 ip, ui16 port); + SOCKET AcceptConnection(SOCKET serverSocket); +}; + +UNIT_TEST_SUITE_REGISTRATION(TPollTest); + +sockaddr_in TPollTest::GetAddress(ui32 ip, ui16 port) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + 
addr.sin_addr.s_addr = htonl(ip); + return addr; +} + +SOCKET TPollTest::CreateSocket() { + SOCKET s = socket(AF_INET, SOCK_STREAM, 0); + if (s == INVALID_SOCKET) { + ythrow yexception() << "Can not create socket (" << LastSystemErrorText() << ")"; + } + return s; +} + +SOCKET TPollTest::StartServerSocket(ui16 port, int backlog) { + TSocketHolder s(CreateSocket()); + sockaddr_in addr = GetAddress(ntohl(INADDR_ANY), port); + if (bind(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) { + ythrow yexception() << "Can not bind server socket (" << LastSystemErrorText() << ")"; + } + if (listen(s, backlog) == SOCKET_ERROR) { + ythrow yexception() << "Can not listen on server socket (" << LastSystemErrorText() << ")"; + } + return s.Release(); +} + +SOCKET TPollTest::StartClientSocket(ui32 ip, ui16 port) { + TSocketHolder s(CreateSocket()); + sockaddr_in addr = GetAddress(ip, port); + if (connect(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) { + ythrow yexception() << "Can not connect client socket (" << LastSystemErrorText() << ")"; + } + return s.Release(); +} + +SOCKET TPollTest::AcceptConnection(SOCKET serverSocket) { + SOCKET connectedSocket = accept(serverSocket, nullptr, nullptr); + if (connectedSocket == INVALID_SOCKET) { + ythrow yexception() << "Can not accept connection on server socket (" << LastSystemErrorText() << ")"; + } + return connectedSocket; +} + +void TPollTest::TestPollInOut() { +#ifdef _win_ + const size_t socketCount = 1000; + + ui16 port = static_cast<ui16>(1300 + rand() % 97); + TSocketHolder serverSocket = StartServerSocket(port, socketCount); + + ui32 localIp = ntohl(inet_addr("127.0.0.1")); + + TVector<TSimpleSharedPtr<TSocketHolder>> clientSockets; + TVector<TSimpleSharedPtr<TSocketHolder>> connectedSockets; + TVector<pollfd> fds; + + for (size_t i = 0; i < socketCount; ++i) { + TSimpleSharedPtr<TSocketHolder> clientSocket(new TSocketHolder(StartClientSocket(localIp, port))); + clientSockets.push_back(clientSocket); + + if (i % 5 == 0 
|| i % 5 == 2) { + char buffer = 'c'; + if (send(*clientSocket, &buffer, 1, 0) == -1) + ythrow yexception() << "Can not send (" << LastSystemErrorText() << ")"; + } + + TSimpleSharedPtr<TSocketHolder> connectedSocket(new TSocketHolder(AcceptConnection(serverSocket))); + connectedSockets.push_back(connectedSocket); + + if (i % 5 == 2 || i % 5 == 3) { + closesocket(*clientSocket); + shutdown(*clientSocket, SD_BOTH); + } + } + + int expectedCount = 0; + for (size_t i = 0; i < connectedSockets.size(); ++i) { + pollfd fd = {(i % 5 == 4) ? INVALID_SOCKET : static_cast<SOCKET>(*connectedSockets[i]), POLLIN | POLLOUT, 0}; + fds.push_back(fd); + if (i % 5 != 4) + ++expectedCount; + } + + int polledCount = poll(&fds[0], fds.size(), INFTIM); + UNIT_ASSERT_EQUAL(expectedCount, polledCount); + + for (size_t i = 0; i < connectedSockets.size(); ++i) { + short revents = fds[i].revents; + if (i % 5 == 0) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLRDNORM | POLLWRNORM), revents); + } else if (i % 5 == 1) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLOUT | POLLWRNORM), revents); + } else if (i % 5 == 2) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLRDNORM | POLLWRNORM), revents); + } else if (i % 5 == 3) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLWRNORM), revents); + } else if (i % 5 == 4) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLNVAL), revents); + } + } +#endif +} diff --git a/util/network/ut/ya.make b/util/network/ut/ya.make new file mode 100644 index 0000000000..1ba03e167c --- /dev/null +++ b/util/network/ut/ya.make @@ -0,0 +1,23 @@ +UNITTEST_FOR(util) + +REQUIREMENTS(network:full) + +OWNER(g:util) +SUBSCRIBER(g:util-subscribers) + +PEERDIR( + library/cpp/threading/future +) + +SRCS( + network/address_ut.cpp + network/endpoint_ut.cpp + network/ip_ut.cpp + network/poller_ut.cpp + network/sock_ut.cpp + network/socket_ut.cpp +) + +INCLUDE(${ARCADIA_ROOT}/util/tests/ya_util_tests.inc) + +END() diff --git a/util/network/ya.make b/util/network/ya.make new file 
mode 100644 index 0000000000..79c9498ddd --- /dev/null +++ b/util/network/ya.make @@ -0,0 +1,6 @@ +OWNER(g:util) +SUBSCRIBER(g:util-subscribers) + +RECURSE_FOR_TESTS( + ut +) |