diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/network/socket_ut.cpp | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/network/socket_ut.cpp')
-rw-r--r-- | util/network/socket_ut.cpp | 341 |
1 files changed, 341 insertions, 0 deletions
diff --git a/util/network/socket_ut.cpp b/util/network/socket_ut.cpp new file mode 100644 index 0000000000..6b20e11f70 --- /dev/null +++ b/util/network/socket_ut.cpp @@ -0,0 +1,341 @@ +#include "socket.h" + +#include "pair.h" + +#include <library/cpp/testing/unittest/registar.h> + +#include <util/string/builder.h> +#include <util/generic/vector.h> + +#include <ctime> + +#ifdef _linux_ + #include <linux/version.h> + #include <sys/utsname.h> +#endif + +class TSockTest: public TTestBase { + UNIT_TEST_SUITE(TSockTest); + UNIT_TEST(TestSock); + UNIT_TEST(TestTimeout); +#ifndef _win_ // Test hangs on Windows + UNIT_TEST_EXCEPTION(TestConnectionRefused, yexception); +#endif + UNIT_TEST(TestNetworkResolutionError); + UNIT_TEST(TestNetworkResolutionErrorMessage); + UNIT_TEST(TestBrokenPipe); + UNIT_TEST(TestClose); + UNIT_TEST(TestReusePortAvailCheck); + UNIT_TEST_SUITE_END(); + +public: + void TestSock(); + void TestTimeout(); + void TestConnectionRefused(); + void TestNetworkResolutionError(); + void TestNetworkResolutionErrorMessage(); + void TestBrokenPipe(); + void TestClose(); + void TestReusePortAvailCheck(); +}; + +UNIT_TEST_SUITE_REGISTRATION(TSockTest); + +void TSockTest::TestSock() { + TNetworkAddress addr("yandex.ru", 80); + TSocket s(addr); + TSocketOutput so(s); + TSocketInput si(s); + const TStringBuf req = "GET / HTTP/1.1\r\nHost: yandex.ru\r\n\r\n"; + + so.Write(req.data(), req.size()); + + UNIT_ASSERT(!si.ReadLine().empty()); +} + +void TSockTest::TestTimeout() { + static const int timeout = 1000; + i64 startTime = millisec(); + try { + TNetworkAddress addr("localhost", 1313); + TSocket s(addr, TDuration::MilliSeconds(timeout)); + } catch (const yexception&) { + } + int realTimeout = (int)(millisec() - startTime); + if (realTimeout > timeout + 2000) { + TString err = TStringBuilder() << "Timeout exceeded: " << realTimeout << " ms (expected " << timeout << " ms)"; + UNIT_FAIL(err); + } +} + +void TSockTest::TestConnectionRefused() { + TNetworkAddress addr("localhost", 1313); + TSocket s(addr); +} + +void TSockTest::TestNetworkResolutionError() { + TString errMsg; + try { + TNetworkAddress addr("", 0); + } catch (const TNetworkResolutionError& e) { + errMsg = e.what(); + } + + if (errMsg.empty()) { + return; // on Windows getaddrinfo("", 0, ...) returns "OK" + } + + int expectedErr = EAI_NONAME; + TString expectedErrMsg = gai_strerror(expectedErr); + if (errMsg.find(expectedErrMsg) == TString::npos) { + UNIT_FAIL("TNetworkResolutionError contains\nInvalid msg: " + errMsg + "\nExpected msg: " + expectedErrMsg + "\n"); + } +} + +void TSockTest::TestNetworkResolutionErrorMessage() { +#ifdef _unix_ + auto str = [](int code) -> TString { + return TNetworkResolutionError(code).what(); + }; + + auto expected = [](int code) -> TString { + return gai_strerror(code); + }; + + struct TErrnoGuard { + TErrnoGuard() + : PrevValue_(errno) + { + } + + ~TErrnoGuard() { + errno = PrevValue_; + } + + private: + int PrevValue_; + } g; + + UNIT_ASSERT_VALUES_EQUAL(expected(0) + "(0): ", str(0)); + UNIT_ASSERT_VALUES_EQUAL(expected(-9) + "(-9): ", str(-9)); + + errno = 0; + UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=0): ", + str(EAI_SYSTEM)); + errno = 110; + UNIT_ASSERT_VALUES_EQUAL(expected(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=110): ", + str(EAI_SYSTEM)); +#endif +} + +class TTempEnableSigPipe { +public: + TTempEnableSigPipe() { + OriginalSigHandler_ = signal(SIGPIPE, SIG_DFL); + Y_VERIFY(OriginalSigHandler_ != SIG_ERR); + } + + ~TTempEnableSigPipe() { + auto ret = signal(SIGPIPE, OriginalSigHandler_); + Y_VERIFY(ret != SIG_ERR); + } + +private: + void (*OriginalSigHandler_)(int); +}; + +void TSockTest::TestBrokenPipe() { + TTempEnableSigPipe guard; + + SOCKET socks[2]; + + int ret = SocketPair(socks); + UNIT_ASSERT_VALUES_EQUAL(ret, 0); + + TSocket sender(socks[0]); + TSocket receiver(socks[1]); + receiver.ShutDown(SHUT_RDWR); + int sent = sender.Send("FOO", 3); + UNIT_ASSERT(sent < 0); + + IOutputStream::TPart parts[] = { + {"foo", 3}, + {"bar", 3}, + }; + sent = sender.SendV(parts, 2); + UNIT_ASSERT(sent < 0); +} + +void TSockTest::TestClose() { + SOCKET socks[2]; + + UNIT_ASSERT_EQUAL(SocketPair(socks), 0); + TSocket receiver(socks[1]); + + UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), socks[1]); + +#if defined _linux_ + UNIT_ASSERT_GE(fcntl(socks[1], F_GETFD), 0); + receiver.Close(); + UNIT_ASSERT_EQUAL(fcntl(socks[1], F_GETFD), -1); +#else + receiver.Close(); +#endif + + UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), INVALID_SOCKET); +} + +void TSockTest::TestReusePortAvailCheck() { +#if defined _linux_ + utsname sysInfo; + Y_VERIFY(!uname(&sysInfo), "Error while call uname: %s", LastSystemErrorText()); + TStringBuf release(sysInfo.release); + release = release.substr(0, release.find_first_not_of(".0123456789")); + int v1 = FromString<int>(release.NextTok('.')); + int v2 = FromString<int>(release.NextTok('.')); + int v3 = FromString<int>(release.NextTok('.')); + int linuxVersionCode = KERNEL_VERSION(v1, v2, v3); + if (linuxVersionCode >= KERNEL_VERSION(3, 9, 1)) { + // new kernels support SO_REUSEPORT + UNIT_ASSERT(true == IsReusePortAvailable()); + UNIT_ASSERT(true == IsReusePortAvailable()); + } else { + // older kernels may or may not support SO_REUSEPORT + // just check that it doesn't crash or throw + (void)IsReusePortAvailable(); + (void)IsReusePortAvailable(); + } +#else + // check that it doesn't crash or throw + (void)IsReusePortAvailable(); + (void)IsReusePortAvailable(); +#endif +} + +class TPollTest: public TTestBase { + UNIT_TEST_SUITE(TPollTest); + UNIT_TEST(TestPollInOut); + UNIT_TEST_SUITE_END(); + +public: + inline TPollTest() { + srand(static_cast<unsigned int>(time(nullptr))); + } + + void TestPollInOut(); + +private: + sockaddr_in GetAddress(ui32 ip, ui16 port); + SOCKET CreateSocket(); + SOCKET StartServerSocket(ui16 port, int backlog); + SOCKET StartClientSocket(ui32 ip, ui16 port); + SOCKET AcceptConnection(SOCKET serverSocket); +}; + +UNIT_TEST_SUITE_REGISTRATION(TPollTest); + +sockaddr_in TPollTest::GetAddress(ui32 ip, ui16 port) { + struct sockaddr_in addr; + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + addr.sin_addr.s_addr = htonl(ip); + return addr; +} + +SOCKET TPollTest::CreateSocket() { + SOCKET s = socket(AF_INET, SOCK_STREAM, 0); + if (s == INVALID_SOCKET) { + ythrow yexception() << "Can not create socket (" << LastSystemErrorText() << ")"; + } + return s; +} + +SOCKET TPollTest::StartServerSocket(ui16 port, int backlog) { + TSocketHolder s(CreateSocket()); + sockaddr_in addr = GetAddress(ntohl(INADDR_ANY), port); + if (bind(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) { + ythrow yexception() << "Can not bind server socket (" << LastSystemErrorText() << ")"; + } + if (listen(s, backlog) == SOCKET_ERROR) { + ythrow yexception() << "Can not listen on server socket (" << LastSystemErrorText() << ")"; + } + return s.Release(); +} + +SOCKET TPollTest::StartClientSocket(ui32 ip, ui16 port) { + TSocketHolder s(CreateSocket()); + sockaddr_in addr = GetAddress(ip, port); + if (connect(s, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) { + ythrow yexception() << "Can not connect client socket (" << LastSystemErrorText() << ")"; + } + return s.Release(); +} + +SOCKET TPollTest::AcceptConnection(SOCKET serverSocket) { + SOCKET connectedSocket = accept(serverSocket, nullptr, nullptr); + if (connectedSocket == INVALID_SOCKET) { + ythrow yexception() << "Can not accept connection on server socket (" << LastSystemErrorText() << ")"; + } + return connectedSocket; +} + +void TPollTest::TestPollInOut() { +#ifdef _win_ + const size_t socketCount = 1000; + + ui16 port = static_cast<ui16>(1300 + rand() % 97); + TSocketHolder serverSocket = StartServerSocket(port, socketCount); + + ui32 localIp = ntohl(inet_addr("127.0.0.1")); + + TVector<TSimpleSharedPtr<TSocketHolder>> clientSockets; + TVector<TSimpleSharedPtr<TSocketHolder>> connectedSockets; + TVector<pollfd> fds; + + for (size_t i = 0; i < socketCount; ++i) { + TSimpleSharedPtr<TSocketHolder> clientSocket(new TSocketHolder(StartClientSocket(localIp, port))); + clientSockets.push_back(clientSocket); + + if (i % 5 == 0 || i % 5 == 2) { + char buffer = 'c'; + if (send(*clientSocket, &buffer, 1, 0) == -1) + ythrow yexception() << "Can not send (" << LastSystemErrorText() << ")"; + } + + TSimpleSharedPtr<TSocketHolder> connectedSocket(new TSocketHolder(AcceptConnection(serverSocket))); + connectedSockets.push_back(connectedSocket); + + if (i % 5 == 2 || i % 5 == 3) { + closesocket(*clientSocket); + shutdown(*clientSocket, SD_BOTH); + } + } + + int expectedCount = 0; + for (size_t i = 0; i < connectedSockets.size(); ++i) { + pollfd fd = {(i % 5 == 4) ? INVALID_SOCKET : static_cast<SOCKET>(*connectedSockets[i]), POLLIN | POLLOUT, 0}; + fds.push_back(fd); + if (i % 5 != 4) + ++expectedCount; + } + + int polledCount = poll(&fds[0], fds.size(), INFTIM); + UNIT_ASSERT_EQUAL(expectedCount, polledCount); + + for (size_t i = 0; i < connectedSockets.size(); ++i) { + short revents = fds[i].revents; + if (i % 5 == 0) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLRDNORM | POLLWRNORM), revents); + } else if (i % 5 == 1) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLOUT | POLLWRNORM), revents); + } else if (i % 5 == 2) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLRDNORM | POLLWRNORM), revents); + } else if (i % 5 == 3) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLWRNORM), revents); + } else if (i % 5 == 4) { + UNIT_ASSERT_EQUAL(static_cast<short>(POLLNVAL), revents); + } + } +#endif +} |