#include "socket.h"
#include "pair.h"
#include <library/cpp/testing/unittest/registar.h>
#include <util/string/builder.h>
#include <util/generic/vector.h>
#include <ctime>
#ifdef _linux_
#include <linux/version.h>
#include <sys/utsname.h>
#endif
// Test suite for TSocket / TNetworkAddress / TNetworkResolutionError and the
// helper free functions declared in socket.h. Registration order below is the
// execution order; TestConnectionRefused is expected to throw yexception and
// is skipped on Windows where it hangs.
class TSockTest: public TTestBase {
    UNIT_TEST_SUITE(TSockTest);
    UNIT_TEST(TestSock);
    UNIT_TEST(TestTimeout);
#ifndef _win_ // Test hangs on Windows
    UNIT_TEST_EXCEPTION(TestConnectionRefused, yexception);
#endif
    UNIT_TEST(TestNetworkResolutionError);
    UNIT_TEST(TestNetworkResolutionErrorMessage);
    UNIT_TEST(TestBrokenPipe);
    UNIT_TEST(TestClose);
    UNIT_TEST(TestReusePortAvailCheck);
    UNIT_TEST_SUITE_END();

public:
    // Establishing outgoing connections.
    void TestSock();
    void TestTimeout();
    void TestConnectionRefused();

    // Name-resolution error reporting.
    void TestNetworkResolutionError();
    void TestNetworkResolutionErrorMessage();

    // Operations on already-created sockets.
    void TestBrokenPipe();
    void TestClose();

    // Platform capability probing.
    void TestReusePortAvailCheck();
};

UNIT_TEST_SUITE_REGISTRATION(TSockTest);
void TSockTest::TestSock() {
TNetworkAddress addr("yandex.ru", 80);
TSocket s(addr);
TSocketOutput so(s);
TSocketInput si(s);
const TStringBuf req = "GET / HTTP/1.1\r\nHost: yandex.ru\r\n\r\n";
so.Write(req.data(), req.size());
UNIT_ASSERT(!si.ReadLine().empty());
}
void TSockTest::TestTimeout() {
static const int timeout = 1000;
i64 startTime = millisec();
try {
TNetworkAddress addr("localhost", 1313);
TSocket s(addr, TDuration::MilliSeconds(timeout));
} catch (const yexception&) {
}
int realTimeout = (int)(millisec() - startTime);
if (realTimeout > timeout + 2000) {
TString err = TStringBuilder() << "Timeout exceeded: " << realTimeout << " ms (expected " << timeout << " ms)";
UNIT_FAIL(err);
}
}
void TSockTest::TestConnectionRefused() {
TNetworkAddress addr("localhost", 1313);
TSocket s(addr);
}
void TSockTest::TestNetworkResolutionError() {
TString errMsg;
try {
TNetworkAddress addr("", 0);
} catch (const TNetworkResolutionError& e) {
errMsg = e.what();
}
if (errMsg.empty()) {
return; // on Windows getaddrinfo("", 0, ...) returns "OK"
}
int expectedErr = EAI_NONAME;
TString expectedErrMsg = gai_strerror(expectedErr);
if (errMsg.find(expectedErrMsg) == TString::npos) {
UNIT_FAIL("TNetworkResolutionError contains\nInvalid msg: " + errMsg + "\nExpected msg: " + expectedErrMsg + "\n");
}
}
// Pins the exact what() text of TNetworkResolutionError on Unix:
// gai_strerror(code) followed by "(code): ", and for EAI_SYSTEM additionally
// "; errno=N" with the errno value current when the error object is built.
// The caller's errno is saved on entry and restored on exit.
void TSockTest::TestNetworkResolutionErrorMessage() {
#ifdef _unix_
    struct TErrnoRestorer {
        const int Saved_ = errno;

        ~TErrnoRestorer() {
            errno = Saved_;
        }
    } restoreErrno;

    const auto actualText = [](int code) -> TString {
        return TNetworkResolutionError(code).what();
    };
    const auto gaiText = [](int code) -> TString {
        return gai_strerror(code);
    };

    UNIT_ASSERT_VALUES_EQUAL(gaiText(0) + "(0): ", actualText(0));
    UNIT_ASSERT_VALUES_EQUAL(gaiText(-9) + "(-9): ", actualText(-9));

    errno = 0;
    UNIT_ASSERT_VALUES_EQUAL(gaiText(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=0): ",
                             actualText(EAI_SYSTEM));

    errno = 110;
    UNIT_ASSERT_VALUES_EQUAL(gaiText(EAI_SYSTEM) + "(" + IntToString<10>(EAI_SYSTEM) + "; errno=110): ",
                             actualText(EAI_SYSTEM));
#endif
}
// Scope guard: on construction sets the SIGPIPE disposition to SIG_DFL and
// remembers the previous handler; on destruction reinstalls that handler.
// Aborts (Y_ABORT_UNLESS) if either signal() call fails.
// NOTE(review): presumably used so that a SIGPIPE escaping from a send on a
// broken socket would terminate the test process rather than be ignored.
class TTempEnableSigPipe {
public:
    TTempEnableSigPipe() {
        OriginalSigHandler_ = signal(SIGPIPE, SIG_DFL);
        Y_ABORT_UNLESS(OriginalSigHandler_ != SIG_ERR);
    }

    // A copy would reinstall the saved handler a second time from its own
    // destructor, so the guard is neither copyable nor (implicitly) movable.
    TTempEnableSigPipe(const TTempEnableSigPipe&) = delete;
    TTempEnableSigPipe& operator=(const TTempEnableSigPipe&) = delete;

    ~TTempEnableSigPipe() {
        auto ret = signal(SIGPIPE, OriginalSigHandler_);
        Y_ABORT_UNLESS(ret != SIG_ERR);
    }

private:
    void (*OriginalSigHandler_)(int) = nullptr;
};
void TSockTest::TestBrokenPipe() {
TTempEnableSigPipe guard;
SOCKET socks[2];
int ret = SocketPair(socks);
UNIT_ASSERT_VALUES_EQUAL(ret, 0);
TSocket sender(socks[0]);
TSocket receiver(socks[1]);
receiver.ShutDown(SHUT_RDWR);
int sent = sender.Send("FOO", 3);
UNIT_ASSERT(sent < 0);
IOutputStream::TPart parts[] = {
{"foo", 3},
{"bar", 3},
};
sent = sender.SendV(parts, 2);
UNIT_ASSERT(sent < 0);
}
void TSockTest::TestClose() {
SOCKET socks[2];
UNIT_ASSERT_EQUAL(SocketPair(socks), 0);
TSocket receiver(socks[1]);
UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), socks[1]);
#if defined _linux_
UNIT_ASSERT_GE(fcntl(socks[1], F_GETFD), 0);
receiver.Close();
UNIT_ASSERT_EQUAL(fcntl(socks[1], F_GETFD), -1);
#else
receiver.Close();
#endif
UNIT_ASSERT_EQUAL(static_cast<SOCKET>(receiver), INVALID_SOCKET);
}
// IsReusePortAvailable() must never crash or throw. On Linux kernels >= 3.9.1
// (which support SO_REUSEPORT) it must also return true. The probe is issued
// twice in every configuration (NOTE(review): presumably to cover a cached
// second answer).
void TSockTest::TestReusePortAvailCheck() {
#if defined _linux_
    utsname uts;
    Y_ABORT_UNLESS(!uname(&uts), "Error while call uname: %s", LastSystemErrorText());

    // Keep only the leading run of digits and dots, e.g. "5.15.0" from
    // "5.15.0-76-generic", then split it into its three components.
    TStringBuf version(uts.release);
    version = version.substr(0, version.find_first_not_of(".0123456789"));
    const int majorVer = FromString<int>(version.NextTok('.'));
    const int minorVer = FromString<int>(version.NextTok('.'));
    const int patchVer = FromString<int>(version.NextTok('.'));

    if (KERNEL_VERSION(majorVer, minorVer, patchVer) >= KERNEL_VERSION(3, 9, 1)) {
        UNIT_ASSERT(true == IsReusePortAvailable());
        UNIT_ASSERT(true == IsReusePortAvailable());
        return;
    }
#endif
    // Non-Linux platforms and older kernels may or may not support
    // SO_REUSEPORT: only require that probing neither crashes nor throws.
    (void)IsReusePortAvailable();
    (void)IsReusePortAvailable();
}
// Test suite for poll() over raw sockets. Its only test body is compiled on
// Windows only; the helpers wrap plain BSD-socket calls.
class TPollTest: public TTestBase {
    UNIT_TEST_SUITE(TPollTest);
    UNIT_TEST(TestPollInOut);
    UNIT_TEST_SUITE_END();

public:
    // Seeds rand(), which TestPollInOut uses to choose its listening port.
    TPollTest() {
        srand(static_cast<unsigned int>(time(nullptr)));
    }

    void TestPollInOut();

private:
    // Raw-socket helpers; those that call the socket API throw yexception
    // on failure.
    sockaddr_in GetAddress(ui32 ip, ui16 port);
    SOCKET CreateSocket();
    SOCKET StartServerSocket(ui16 port, int backlog);
    SOCKET StartClientSocket(ui32 ip, ui16 port);
    SOCKET AcceptConnection(SOCKET serverSocket);
};

UNIT_TEST_SUITE_REGISTRATION(TPollTest);
// Builds a zero-filled IPv4 socket address. Both ip and port are taken in host
// byte order and converted to network order here.
sockaddr_in TPollTest::GetAddress(ui32 ip, ui16 port) {
    sockaddr_in result;
    memset(&result, 0, sizeof(result));

    result.sin_family = AF_INET;
    result.sin_addr.s_addr = htonl(ip);
    result.sin_port = htons(port);

    return result;
}
// Creates an IPv4 stream socket. Throws yexception (with the system error
// text) if socket() fails.
SOCKET TPollTest::CreateSocket() {
    const SOCKET fd = socket(AF_INET, SOCK_STREAM, 0);
    if (fd != INVALID_SOCKET) {
        return fd;
    }
    ythrow yexception() << "Can not create socket (" << LastSystemErrorText() << ")";
}
// Creates a socket bound to the wildcard address on `port` and puts it into
// listening mode with the given backlog. The descriptor is held by a
// TSocketHolder until both calls succeed and is released to the caller on
// success; bind()/listen() failures throw yexception.
SOCKET TPollTest::StartServerSocket(ui16 port, int backlog) {
    TSocketHolder listener(CreateSocket());
    sockaddr_in local = GetAddress(ntohl(INADDR_ANY), port);

    const bool bound = bind(listener, reinterpret_cast<sockaddr*>(&local), sizeof(local)) != SOCKET_ERROR;
    if (!bound) {
        ythrow yexception() << "Can not bind server socket (" << LastSystemErrorText() << ")";
    }

    const bool listening = listen(listener, backlog) != SOCKET_ERROR;
    if (!listening) {
        ythrow yexception() << "Can not listen on server socket (" << LastSystemErrorText() << ")";
    }

    return listener.Release();
}
// Creates a socket and connects it to ip:port (both in host byte order). The
// descriptor is held by a TSocketHolder until connect() succeeds and is then
// released to the caller; a failed connect() throws yexception.
SOCKET TPollTest::StartClientSocket(ui32 ip, ui16 port) {
    TSocketHolder client(CreateSocket());
    sockaddr_in remote = GetAddress(ip, port);

    const bool connected = connect(client, reinterpret_cast<sockaddr*>(&remote), sizeof(remote)) != SOCKET_ERROR;
    if (!connected) {
        ythrow yexception() << "Can not connect client socket (" << LastSystemErrorText() << ")";
    }

    return client.Release();
}
// Accepts one pending connection on serverSocket, discarding the peer
// address, and returns the new descriptor. Throws yexception if accept() fails.
SOCKET TPollTest::AcceptConnection(SOCKET serverSocket) {
    const SOCKET accepted = accept(serverSocket, nullptr, nullptr);
    if (accepted != INVALID_SOCKET) {
        return accepted;
    }
    ythrow yexception() << "Can not accept connection on server socket (" << LastSystemErrorText() << ")";
}
// Windows-only check of poll() event reporting. socketCount loopback
// connections are created; depending on the index modulo 5 each one is put
// into a different state, then the server-side ends are polled for
// POLLIN | POLLOUT and every revents value is checked exactly:
//   0: client sent one byte             -> POLLRDNORM | POLLWRNORM
//   1: client idle                      -> POLLOUT | POLLWRNORM
//   2: client sent one byte, then closed -> POLLHUP | POLLRDNORM | POLLWRNORM
//   3: client closed                    -> POLLHUP | POLLWRNORM
//   4: entry polled as INVALID_SOCKET   -> POLLNVAL (not counted by poll())
void TPollTest::TestPollInOut() {
#ifdef _win_
    const size_t socketCount = 1000;

    ui16 port = static_cast<ui16>(1300 + rand() % 97);
    TSocketHolder serverSocket = StartServerSocket(port, socketCount);

    ui32 localIp = ntohl(inet_addr("127.0.0.1"));

    TVector<TSimpleSharedPtr<TSocketHolder>> clientSockets;
    TVector<TSimpleSharedPtr<TSocketHolder>> connectedSockets;
    clientSockets.reserve(socketCount);
    connectedSockets.reserve(socketCount);

    for (size_t i = 0; i < socketCount; ++i) {
        TSimpleSharedPtr<TSocketHolder> clientSocket(new TSocketHolder(StartClientSocket(localIp, port)));
        clientSockets.push_back(clientSocket);

        if (i % 5 == 0 || i % 5 == 2) {
            char buffer = 'c';
            if (send(*clientSocket, &buffer, 1, 0) == -1) {
                ythrow yexception() << "Can not send (" << LastSystemErrorText() << ")";
            }
        }

        TSimpleSharedPtr<TSocketHolder> connectedSocket(new TSocketHolder(AcceptConnection(serverSocket)));
        connectedSockets.push_back(connectedSocket);

        if (i % 5 == 2 || i % 5 == 3) {
            // shutdown() must come before the handle is closed: calling it
            // after closesocket() operates on a dead (or reused) handle.
            // Closing through the holder instead of a raw closesocket() keeps
            // its destructor from closing the same handle a second time.
            // NOTE(review): relies on TSocketHolder::Close() resetting the
            // stored handle.
            shutdown(*clientSocket, SD_BOTH);
            clientSocket->Close();
        }
    }

    int expectedCount = 0;
    TVector<pollfd> fds;
    fds.reserve(connectedSockets.size());
    for (size_t i = 0; i < connectedSockets.size(); ++i) {
        const bool invalid = (i % 5 == 4);
        pollfd fd = {invalid ? INVALID_SOCKET : static_cast<SOCKET>(*connectedSockets[i]), POLLIN | POLLOUT, 0};
        fds.push_back(fd);
        if (!invalid) {
            ++expectedCount;
        }
    }

    int polledCount = poll(fds.data(), fds.size(), INFTIM);
    UNIT_ASSERT_EQUAL(expectedCount, polledCount);

    for (size_t i = 0; i < connectedSockets.size(); ++i) {
        short revents = fds[i].revents;
        if (i % 5 == 0) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLRDNORM | POLLWRNORM), revents);
        } else if (i % 5 == 1) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLOUT | POLLWRNORM), revents);
        } else if (i % 5 == 2) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLRDNORM | POLLWRNORM), revents);
        } else if (i % 5 == 3) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLHUP | POLLWRNORM), revents);
        } else if (i % 5 == 4) {
            UNIT_ASSERT_EQUAL(static_cast<short>(POLLNVAL), revents);
        }
    }
#endif
}