#include "ip.h"
#include "socket.h"
#include "address.h"
#include "pollerimpl.h"
#include "iovec.h"
#include <util/system/defaults.h>
#include <util/system/byteorder.h>
#if defined(_unix_)
#include <netdb.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#endif
#if defined(_freebsd_)
#include <sys/module.h>
#define ACCEPT_FILTER_MOD
#include <sys/socketvar.h>
#endif
#if defined(_win_)
#include <cerrno>
#include <winsock2.h>
#include <ws2tcpip.h>
#include <wspiapi.h>
#include <util/system/compat.h>
#endif
#include <util/generic/ylimits.h>
#include <util/string/cast.h>
#include <util/stream/mem.h>
#include <util/system/datetime.h>
#include <util/system/error.h>
#include <util/memory/tempbuf.h>
#include <util/generic/singleton.h>
#include <util/generic/hash_set.h>
#include <stddef.h>
#include <sys/uio.h>
using namespace NAddr;
#if defined(_win_)
int inet_aton(const char* cp, struct in_addr* inp) {
sockaddr_in addr;
addr.sin_family = AF_INET;
int psz = sizeof(addr);
if (0 == WSAStringToAddress((char*)cp, AF_INET, nullptr, (LPSOCKADDR)&addr, &psz)) {
memcpy(inp, &addr.sin_addr, sizeof(in_addr));
return 1;
}
return 0;
}
#if (_WIN32_WINNT < 0x0600)
const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) {
if (af != AF_INET) {
errno = EINVAL;
return 0;
}
const ui8* ia = (ui8*)src;
if (snprintf(dst, size, "%u.%u.%u.%u", ia[0], ia[1], ia[2], ia[3]) >= (int)size) {
errno = ENOSPC;
return 0;
}
return dst;
}
// One row of an event translation table consumed by convert_events():
// when a bit of `event` is present in the input mask, `winevent` is OR-ed
// into the result. Despite the field names the struct is used in both
// directions (see evpairs_to_win and evpairs_to_unix).
struct evpair {
// Source-side flag that is looked up in the input mask.
int event;
// Target-side flags; -1 makes convert_events() fail, 0 contributes nothing.
int winevent;
};
// poll() request flags -> WinSock FD_* network events.
// Rows mapped to -1 (POLLRDBAND, POLLPRI, POLLWRBAND) make convert_events()
// return -1; rows mapped to 0 (POLLERR, POLLHUP, POLLNVAL) are accepted but
// add no network event to the result.
static const evpair evpairs_to_win[] = {
{POLLIN, FD_READ | FD_CLOSE | FD_ACCEPT},
{POLLRDNORM, FD_READ | FD_CLOSE | FD_ACCEPT},
{POLLRDBAND, -1},
{POLLPRI, -1},
{POLLOUT, FD_WRITE | FD_CLOSE},
{POLLWRNORM, FD_WRITE | FD_CLOSE},
{POLLWRBAND, -1},
{POLLERR, 0},
{POLLHUP, 0},
{POLLNVAL, 0}};
// Number of rows in evpairs_to_win.
static const size_t nevpairs_to_win = sizeof(evpairs_to_win) / sizeof(evpairs_to_win[0]);
// WinSock FD_* network events -> poll() result flags. Here the `event` field
// holds the WinSock bit and the `winevent` field holds the poll() flags.
static const evpair evpairs_to_unix[] = {
{FD_ACCEPT, POLLIN | POLLRDNORM},
{FD_READ, POLLIN | POLLRDNORM},
{FD_WRITE, POLLOUT | POLLWRNORM},
{FD_CLOSE, POLLHUP},
};
// Number of rows in evpairs_to_unix.
static const size_t nevpairs_to_unix = sizeof(evpairs_to_unix) / sizeof(evpairs_to_unix[0]);
// Translates the bit mask `events` through a mapping table.
//
// For every table row whose `event` mask intersects `events`, the row's
// `winevent` value is OR-ed into the result and the row's bits are removed
// from `events`.
//
// Returns the translated mask, or -1 when
//   * a matching row is mapped to -1 (unsupported flag), or
//   * bits remain in `events` after the whole table was applied and
//     ignoreUnknown is false.
static int convert_events(int events, const evpair* evpairs, size_t nevpairs, bool ignoreUnknown) noexcept {
    int result = 0;
    for (size_t i = 0; i < nevpairs; ++i) {
        const int event = evpairs[i].event;
        if ((events & event) == 0) {
            continue;
        }
        // Clear the row's bits. XOR is not used here: if the row mask has
        // several bits and only some of them were requested, XOR would switch
        // the missing bits ON and a later row could then reject the request.
        events &= ~event;
        const int mapped = evpairs[i].winevent;
        if (mapped == -1) {
            return -1;
        }
        // A row mapped to 0 is accepted but contributes nothing.
        result |= mapped;
    }
    if (events != 0 && !ignoreUnknown) {
        return -1;
    }
    return result;
}
// Owns a WSA event handle and closes it with WSACloseEvent() on destruction.
// Copying is disabled: a copy would close the same handle a second time.
class TWSAEventHolder {
private:
    HANDLE Event;

public:
    // Takes ownership of `event`.
    inline explicit TWSAEventHolder(HANDLE event) noexcept
        : Event(event)
    {
    }

    TWSAEventHolder(const TWSAEventHolder&) = delete;
    TWSAEventHolder& operator=(const TWSAEventHolder&) = delete;

    inline ~TWSAEventHolder() {
        WSACloseEvent(Event);
    }

    // Returns the owned handle without giving up ownership.
    inline HANDLE Get() const noexcept {
        return Event;
    }
};
// poll() emulation for Windows builds (compiled only when _WIN32_WINNT < 0x0600).
//
// Phase 1: every descriptor is attached to one shared WSA event and probed
//          with a zero-timeout select(); if at least one descriptor is already
//          ready (or invalid) the function returns immediately.
// Phase 2: otherwise it waits on the shared event for up to `timeout` and then
//          collects the network events of every descriptor.
//
// Returns the number of descriptors with non-zero revents, 0 on timeout, or -1
// with errno set to EINVAL (unsupported request flags) or EIO (WinSock failure).
int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept {
HANDLE rawEvent = WSACreateEvent();
if (rawEvent == WSA_INVALID_EVENT) {
errno = EIO;
return -1;
}
// From here on the event is closed automatically on every return path.
TWSAEventHolder event(rawEvent);
int checked_sockets = 0;
for (pollfd* fd = fds; fd < fds + nfds; ++fd) {
// Unknown or unsupported request flags abort the whole call.
int win_events = convert_events(fd->events, evpairs_to_win, nevpairs_to_win, false);
if (win_events == -1) {
errno = EINVAL;
return -1;
}
fd->revents = 0;
if (WSAEventSelect(fd->fd, event.Get(), win_events)) {
int error = WSAGetLastError();
// A bad descriptor is reported per entry, not as a failure of the call.
if (error == WSAEINVAL || error == WSAENOTSOCK) {
fd->revents = POLLNVAL;
++checked_sockets;
} else {
errno = EIO;
return -1;
}
}
// NOTE(review): an entry flagged POLLNVAL above still goes through the
// select() probe below - confirm this is intended.
fd_set readfds;
fd_set writefds;
// Shadows the `timeout` parameter inside the loop: zero means "do not block".
struct timeval timeout = {0, 0};
FD_ZERO(&readfds);
FD_ZERO(&writefds);
if (fd->events & POLLIN) {
FD_SET(fd->fd, &readfds);
}
if (fd->events & POLLOUT) {
FD_SET(fd->fd, &writefds);
}
// Immediate readiness check for this single descriptor.
int error = select(0, &readfds, &writefds, nullptr, &timeout);
if (error > 0) {
if (FD_ISSET(fd->fd, &readfds)) {
fd->revents |= POLLIN;
}
if (FD_ISSET(fd->fd, &writefds)) {
fd->revents |= POLLOUT;
}
++checked_sockets;
}
}
if (checked_sockets > 0) {
// Return without waiting: some descriptors are already ready or invalid.
return checked_sockets;
}
HANDLE events[] = {event.Get()};
// `timeout` is the function parameter again (the loop-local one is out of scope).
DWORD wait_result = WSAWaitForMultipleEvents(1, events, TRUE, timeout, FALSE);
if (wait_result == WSA_WAIT_TIMEOUT)
return 0;
else if (wait_result == WSA_WAIT_EVENT_0) {
for (pollfd* fd = fds; fd < fds + nfds; ++fd) {
if (fd->revents == POLLNVAL)
continue;
WSANETWORKEVENTS network_events;
if (WSAEnumNetworkEvents(fd->fd, event.Get(), &network_events)) {
errno = EIO;
return -1;
}
fd->revents = 0;
// Any signalled event that carries an error code turns the entry into POLLERR.
for (int i = 0; i < FD_MAX_EVENTS; ++i) {
if ((network_events.lNetworkEvents & (1 << i)) != 0 && network_events.iErrorCode[i]) {
fd->revents = POLLERR;
break;
}
}
if (fd->revents == POLLERR)
continue;
if (network_events.lNetworkEvents) {
fd->revents = static_cast<short>(convert_events(network_events.lNetworkEvents, evpairs_to_unix, nevpairs_to_unix, true));
// On hang-up only the hang-up and read flags are kept; write flags are dropped.
if (fd->revents & POLLHUP) {
fd->revents &= POLLHUP | POLLIN | POLLRDNORM;
}
}
}
// Count the entries that have something to report.
int chanded_sockets = 0;
for (pollfd* fd = fds; fd < fds + nfds; ++fd)
if (fd->revents != 0)
++chanded_sockets;
return chanded_sockets;
} else {
// Any other wait result is treated as an I/O error.
errno = EIO;
return -1;
}
}
#endif
#endif
// Writes the host part of the peer address of Socket into str as a
// NUL-terminated string. `size` is the capacity of str including the
// terminator. Returns false when size is 0, when getpeername() fails, or when
// formatting throws (for example because the text does not fit).
bool GetRemoteAddr(SOCKET Socket, char* str, socklen_t size) {
    if (size == 0) {
        return false;
    }

    TOpaqueAddr peer;
    const int rc = getpeername(Socket, peer.MutableAddr(), peer.LenPtr());
    if (rc != 0) {
        return false;
    }

    bool printed = false;
    try {
        // One byte is reserved for the terminating zero.
        TMemoryOutput sink(str, size - 1);
        PrintHost(sink, peer);
        *sink.Buf() = 0;
        printed = true;
    } catch (...) {
        // Best effort: any formatting failure is reported as `false`.
    }
    return printed;
}
// Convenience overload: sets the send/receive timeout to `timeout` whole
// seconds with no millisecond part.
void SetSocketTimeout(SOCKET s, long timeout) {
    const long noMilliseconds = 0;
    SetSocketTimeout(s, timeout, noMilliseconds);
}
// Applies the same timeout to SO_RCVTIMEO and SO_SNDTIMEO of socket s.
// On unix-like systems the option value is a timeval {sec, msec * 1000};
// elsewhere it is an int holding the total number of milliseconds.
// Does nothing when SO_SNDTIMEO is not defined.
void SetSocketTimeout(SOCKET s, long sec, long msec) {
#ifdef SO_SNDTIMEO
#ifdef _darwin_
    const timeval limit = {sec, (__darwin_suseconds_t)msec * 1000};
#elif defined(_unix_)
    const timeval limit = {sec, msec * 1000};
#else
    const int limit = sec * 1000 + msec;
#endif
    CheckedSetSockOpt(s, SOL_SOCKET, SO_RCVTIMEO, limit, "recv timeout");
    CheckedSetSockOpt(s, SOL_SOCKET, SO_SNDTIMEO, limit, "send timeout");
#endif
}
// Sets SO_LINGER on s: `on` enables lingering and `len` is the linger value
// (narrowed to u_short). Does nothing when SO_LINGER is not defined.
void SetLinger(SOCKET s, bool on, unsigned len) {
#ifdef SO_LINGER
    struct linger lingerOpt = {on, (u_short)len};
    CheckedSetSockOpt(s, SOL_SOCKET, SO_LINGER, lingerOpt, "linger");
#endif
}
// Enables SO_LINGER with a zero linger value.
void SetZeroLinger(SOCKET s) {
    SetLinger(s, true, 0u);
}
// Turns SO_KEEPALIVE on (value == true) or off for socket s.
void SetKeepAlive(SOCKET s, bool value) {
    CheckedSetSockOpt(s, SOL_SOCKET, SO_KEEPALIVE, static_cast<int>(value), "keepalive");
}
// Sets the SO_SNDBUF option of socket s to `value`.
void SetOutputBuffer(SOCKET s, unsigned value) {
    CheckedSetSockOpt(
        s, SOL_SOCKET, SO_SNDBUF,
        value,
        "output buffer");
}
// Sets the SO_RCVBUF option of socket s to `value`.
void SetInputBuffer(SOCKET s, unsigned value) {
    CheckedSetSockOpt(
        s, SOL_SOCKET, SO_RCVBUF,
        value,
        "input buffer");
}
// Fallback definition for Linux builds whose headers do not provide
// SO_REUSEPORT, so the option is always available to the code below.
#if defined(_linux_) && !defined(SO_REUSEPORT)
#define SO_REUSEPORT 15
#endif

// Turns SO_REUSEPORT on or off for socket s.
// Throws TSystemError(ENOSYS) on platforms where SO_REUSEPORT is not defined.
void SetReusePort(SOCKET s, bool value) {
#if defined(SO_REUSEPORT)
    CheckedSetSockOpt(s, SOL_SOCKET, SO_REUSEPORT, static_cast<int>(value), "reuse port");
#else
    Y_UNUSED(s);
    Y_UNUSED(value);
    ythrow TSystemError(ENOSYS) << "SO_REUSEPORT is not defined";
#endif
}
// Turns the TCP_NODELAY option on or off for socket s.
void SetNoDelay(SOCKET s, bool value) {
    CheckedSetSockOpt(s, IPPROTO_TCP, TCP_NODELAY, static_cast<int>(value), "tcp no delay");
}
// Sets or clears the FD_CLOEXEC descriptor flag of s (unix only; a no-op on
// other platforms). Throws TSystemError when either fcntl() call fails.
void SetCloseOnExec(SOCKET s, bool value) {
#if defined(_unix_)
    const int current = fcntl(s, F_GETFD);
    if (current == -1) {
        ythrow TSystemError() << "fcntl() failed";
    }

    const int updated = value ? (current | FD_CLOEXEC) : (current & ~FD_CLOEXEC);
    if (fcntl(s, F_SETFD, updated) == -1) {
        ythrow TSystemError() << "fcntl() failed";
    }
#else
    Y_UNUSED(s);
    Y_UNUSED(value);
#endif
}
// Returns the TCP_MAXSEG option value of s when it can be queried; otherwise
// (option unavailable or the query failed) falls back to 8192.
size_t GetMaximumSegmentSize(SOCKET s) {
#if defined(TCP_MAXSEG)
    int mss;
    const bool known = GetSockOpt(s, IPPROTO_TCP, TCP_MAXSEG, mss) == 0;
    if (known) {
        return (size_t)mss;
    }
#endif
    // Fallback guess used when the real value is not available.
    return 8192;
}
// Placeholder: the socket is not inspected and a fixed value is returned.
// Hints for a real implementation:
//   Linux:   there is rumoured to be an IP_MTU getsockopt() request;
//   FreeBSD: a request to a PF_ROUTE socket with the peer address as the
//            destination argument.
size_t GetMaximumTransferUnit(SOCKET /*s*/) {
    const size_t assumedMtu = 8192;
    return assumedMtu;
}
// Looks up the local address of s with getsockname() and returns the
// type-of-service value for that address family.
// Throws TSystemError when getsockname() fails.
int GetSocketToS(SOCKET s) {
    TOpaqueAddr local;
    const int rc = getsockname(s, local.MutableAddr(), local.LenPtr());
    if (rc < 0) {
        ythrow TSystemError() << "getsockname() failed";
    }
    return GetSocketToS(s, &local);
}
// Reads the type-of-service value of s according to the family of `addr`:
// IP_TOS for AF_INET, IPV6_TCLASS for AF_INET6 (when that option is defined).
// Returns 0 for any other family or when the IPv6 option is unavailable.
int GetSocketToS(SOCKET s, const IRemoteAddr* addr) {
    int tos = 0;
    const auto family = addr->Addr()->sa_family;
    if (family == AF_INET) {
        CheckedGetSockOpt(s, IPPROTO_IP, IP_TOS, tos, "tos");
    } else if (family == AF_INET6) {
#ifdef IPV6_TCLASS
        CheckedGetSockOpt(s, IPPROTO_IPV6, IPV6_TCLASS, tos, "tos");
#endif
    }
    return tos;
}
// Sets the type-of-service value of s according to the family of `addr`:
// IP_TOS for AF_INET, IPV6_TCLASS for AF_INET6 (when that option is defined).
// Throws yexception for every family that cannot be handled.
void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos) {
    const auto family = addr->Addr()->sa_family;
    if (family == AF_INET) {
        CheckedSetSockOpt(s, IPPROTO_IP, IP_TOS, tos, "tos");
        return;
    }
#ifdef IPV6_TCLASS
    if (family == AF_INET6) {
        CheckedSetSockOpt(s, IPPROTO_IPV6, IPV6_TCLASS, tos, "tos");
        return;
    }
#endif
    ythrow yexception() << "SetSocketToS unsupported for family " << addr->Addr()->sa_family;
}
// Looks up the local address of s with getsockname() and sets the
// type-of-service value for that address family.
// Throws TSystemError when getsockname() fails.
void SetSocketToS(SOCKET s, int tos) {
    TOpaqueAddr local;
    const int rc = getsockname(s, local.MutableAddr(), local.LenPtr());
    if (rc < 0) {
        ythrow TSystemError() << "getsockname() failed";
    }
    SetSocketToS(s, &local, tos);
}
// Sets SO_PRIORITY on s where that option exists; a no-op elsewhere.
void SetSocketPriority(SOCKET s, int priority) {
#if !defined(SO_PRIORITY)
    Y_UNUSED(s);
    Y_UNUSED(priority);
#else
    CheckedSetSockOpt(s, SOL_SOCKET, SO_PRIORITY, priority, "priority");
#endif
}
// Returns true when the local address of `socket` is a loopback address or is
// the same as the peer address (as decided by IsLoopback() / IsSame()).
// Throws TSystemError when getsockname() or getpeername() fails; the peer
// address is only queried when the local address is not loopback.
bool HasLocalAddress(SOCKET socket) {
    TOpaqueAddr local;
    if (getsockname(socket, local.MutableAddr(), local.LenPtr()) != 0) {
        ythrow TSystemError() << "HasLocalAddress: getsockname() failed. ";
    }

    const bool loopback = IsLoopback(local);
    if (!loopback) {
        TOpaqueAddr remote;
        if (getpeername(socket, remote.MutableAddr(), remote.LenPtr()) != 0) {
            ythrow TSystemError() << "HasLocalAddress: getpeername() failed. ";
        }
        return IsSame(local, remote);
    }
    return true;
}
namespace {
#if defined(_linux_)
#if !defined(TCP_FASTOPEN)
#define TCP_FASTOPEN 23
#endif
#endif
#if defined(TCP_FASTOPEN)
struct TTcpFastOpenFeature {
inline TTcpFastOpenFeature()
: HasFastOpen_(false)
{
TSocketHolder tmp(socket(AF_INET, SOCK_STREAM, 0));
int val = 1;
int ret = SetSockOpt(tmp, IPPROTO_TCP, TCP_FASTOPEN, val);
HasFastOpen_ = (ret == 0);
}
inline void SetFastOpen(SOCKET s, int qlen) const {
if (HasFastOpen_) {
CheckedSetSockOpt(s, IPPROTO_TCP, TCP_FASTOPEN, qlen, "setting TCP_FASTOPEN");
}
}
static inline const TTcpFastOpenFeature* Instance() noexcept {
return Singleton<TTcpFastOpenFeature>();
}
bool HasFastOpen_;
};
#endif
}
// Requests TCP_FASTOPEN with value qlen on s when the option is defined at
// build time; the request is forwarded to the shared TTcpFastOpenFeature
// instance. A no-op on builds without TCP_FASTOPEN.
void SetTcpFastOpen(SOCKET s, int qlen) {
#if defined(TCP_FASTOPEN)
    const TTcpFastOpenFeature* const feature = TTcpFastOpenFeature::Instance();
    feature->SetFastOpen(s, qlen);
#else
    Y_UNUSED(s);
    Y_UNUSED(qlen);
#endif
}
// True when the error code means "the operation would block"
// (EAGAIN or EWOULDBLOCK).
static bool IsBlocked(int lasterr) noexcept {
    if (lasterr == EAGAIN) {
        return true;
    }
    return lasterr == EWOULDBLOCK;
}
// Scope guard: switches socket s to non-blocking mode on construction and back
// to blocking mode on destruction (the previous mode is not remembered).
// Copying is disabled so the mode is restored exactly once per guard.
struct TUnblockingGuard {
    SOCKET S_;

    explicit TUnblockingGuard(SOCKET s)
        : S_(s)
    {
        SetNonBlock(S_, true);
    }

    TUnblockingGuard(const TUnblockingGuard&) = delete;
    TUnblockingGuard& operator=(const TUnblockingGuard&) = delete;

    ~TUnblockingGuard() {
        SetNonBlock(S_, false);
    }
};
// Peeks at one byte of s without consuming it and without blocking.
// On Windows the socket is temporarily switched to non-blocking mode through
// TUnblockingGuard; elsewhere MSG_DONTWAIT is added to the recv() flags.
// Returns the raw recv() result.
static int MsgPeek(SOCKET s) {
#if defined(_win_)
    const int flags = MSG_PEEK;
    TUnblockingGuard unblocker(s);
    Y_UNUSED(unblocker);
#else
    const int flags = MSG_PEEK | MSG_DONTWAIT;
#endif
    char probe;
    return recv(s, &probe, 1, flags);
}
// True unless HasSocketDataToRead() reports the socket as closed.
bool IsNotSocketClosedByOtherSide(SOCKET s) {
    const ESocketReadStatus status = HasSocketDataToRead(s);
    return status != ESocketReadStatus::SocketClosed;
}
// Classifies the read side of s with a non-blocking one-byte peek:
//   HasData      - at least one byte is available;
//   NoData       - the peek failed with a "would block" error;
//   SocketClosed - everything else (zero-length read or another error).
ESocketReadStatus HasSocketDataToRead(SOCKET s) {
    const int peeked = MsgPeek(s);
    if (peeked > 0) {
        return ESocketReadStatus::HasData;
    }
    if (peeked == -1 && IsBlocked(LastSystemError())) {
        return ESocketReadStatus::NoData;
    }
    return ESocketReadStatus::SocketClosed;
}
// Sends the iovcnt buffers described by iov with one system call:
// writev() on Windows builds, sendmsg() with MSG_NOSIGNAL elsewhere.
// Returns the raw result of that call.
static ssize_t DoSendMsg(SOCKET sock, const struct iovec* iov, int iovcnt) {
#if defined(_win_)
    return writev(sock, iov, iovcnt);
#else
    struct msghdr header;
    Zero(header);
    // msghdr has no const-qualified iovec pointer, hence the cast.
    header.msg_iov = const_cast<struct iovec*>(iov);
    header.msg_iovlen = iovcnt;
    return sendmsg(sock, &header, MSG_NOSIGNAL);
#endif
}
// Closes the held descriptor (if any) and marks the holder as empty.
// A failed close that reports "bad descriptor" aborts through Y_VERIFY:
// it usually means the descriptor was closed twice, which is disastrous and
// must not pass silently.
void TSocketHolder::Close() noexcept {
    if (Fd_ == INVALID_SOCKET) {
        return;
    }

    if (closesocket(Fd_) != 0) {
#ifdef _win_
        Y_VERIFY(WSAGetLastError() != WSAENOTSOCK, "must not quietly close bad socket descriptor");
#elif defined(_unix_)
        Y_VERIFY(errno != EBADF, "must not quietly close bad descriptor: fd=%d", int(Fd_));
#else
#error unsupported platform
#endif
    }

    Fd_ = INVALID_SOCKET;
}
// Reference-counted state behind TSocket: the owned descriptor plus the
// operations table that performs the actual I/O on it.
class TSocket::TImpl: public TAtomicRefCount<TImpl> {
    using TOps = TSocket::TOps;

public:
    // Stores fd in a TSocketHolder; `ops` is kept as a plain pointer.
    inline TImpl(SOCKET fd, TOps* ops)
        : Fd_(fd)
        , Ops_(ops)
    {
    }

    inline ~TImpl() = default;

    // Raw descriptor.
    inline SOCKET Fd() const noexcept {
        return Fd_;
    }

    // The three I/O calls forward to the operations table.
    inline ssize_t Send(const void* data, size_t len) {
        return Ops_->Send(Fd_, data, len);
    }

    inline ssize_t Recv(void* buf, size_t len) {
        return Ops_->Recv(Fd_, buf, len);
    }

    inline ssize_t SendV(const TPart* parts, size_t count) {
        return Ops_->SendV(Fd_, parts, count);
    }

    // Closes the descriptor through the holder.
    inline void Close() {
        Fd_.Close();
    }

private:
    TSocketHolder Fd_;
    TOps* Ops_;
};
// Prints an addrinfo list as "[addr1, addr2, ...]". When the first node has
// AI_CANONNAME set in ai_flags, the canonical name is printed first, enclosed
// in a backtick and an apostrophe and followed by a space.
// A null list prints as "[]": the flag test is guarded so that it does not
// dereference a null pointer, consistent with the loop below.
template <>
void Out<const struct addrinfo*>(IOutputStream& os, const struct addrinfo* ai) {
    if (ai != nullptr && (ai->ai_flags & AI_CANONNAME)) {
        os << "`" << ai->ai_canonname << "' ";
    }

    os << '[';
    bool first = true;
    for (; ai != nullptr; ai = ai->ai_next) {
        if (!first) {
            os << ", ";
        }
        first = false;
        os << (const IRemoteAddr&)TAddrInfo(ai);
    }
    os << ']';
}
// Non-const overload: forwards to the printer for const addrinfo lists.
template <>
void Out<struct addrinfo*>(IOutputStream& os, struct addrinfo* ai) {
    const struct addrinfo* const readOnly = ai;
    Out<const struct addrinfo*>(os, readOnly);
}
// Prints a TNetworkAddress by streaming a pointer to the element that
// addr.Begin() refers to.
template <>
void Out<TNetworkAddress>(IOutputStream& os, const TNetworkAddress& addr) {
    const auto head = &*addr.Begin();
    os << head;
}
// Advances to the next candidate address. When `addr` is the last one, throws
// TSystemError(sockerr) whose message includes the whole list starting at addr0.
static inline const struct addrinfo* Iterate(const struct addrinfo* addr, const struct addrinfo* addr0, const int sockerr) {
    const struct addrinfo* const next = addr->ai_next;
    if (!next) {
        ythrow TSystemError(sockerr) << "can not connect to " << addr0;
    }
    return next;
}
// Tries every address in the list `res` in order and returns the first socket
// that connects. Each attempt uses a fresh non-blocking socket, so a connect
// that is still in progress is awaited with PollD() until `deadLine`.
// The returned socket is still in non-blocking mode.
//
// Throws TSystemError when PollD() fails (including the deadline passing) or,
// through Iterate(), when the last address has failed; throws yexception when
// `res` is null.
static inline SOCKET DoConnectImpl(const struct addrinfo* res, const TInstant& deadLine) {
// Start of the list, kept for the error message produced by Iterate().
const struct addrinfo* addr0 = res;
while (res) {
TSocketHolder s(socket(res->ai_family, res->ai_socktype, res->ai_protocol));
if (s.Closed()) {
// Could not create a socket for this address: move on (or throw on the last one).
res = Iterate(res, addr0, LastSystemError());
continue;
}
SetNonBlock(s, true);
if (connect(s, res->ai_addr, (int)res->ai_addrlen)) {
int err = LastSystemError();
if (err == EINPROGRESS || err == EAGAIN || err == EWOULDBLOCK) {
/*
 * The connect is in progress: wait until the socket becomes writable.
 */
struct pollfd p = {
(SOCKET)s,
POLLOUT,
0};
const ssize_t n = PollD(&p, 1, deadLine);
/*
 * PollD reports failures, including the deadline passing,
 * as a negated error code.
 */
if (n < 0) {
ythrow TSystemError(-(int)n) << "can not connect";
}
// The outcome of the asynchronous connect is read from SO_ERROR.
CheckedGetSockOpt(s, SOL_SOCKET, SO_ERROR, err, "socket error");
if (!err) {
return s.Release();
}
}
// This address failed (immediately or after waiting): try the next one.
res = Iterate(res, addr0, err);
continue;
}
// connect() succeeded immediately.
return s.Release();
}
ythrow yexception() << "something went wrong: nullptr at addrinfo";
}
// Connects with DoConnectImpl() and switches the resulting socket back to
// blocking mode before handing it to the caller. The temporary holder keeps
// ownership of the socket while SetNonBlock() runs.
static inline SOCKET DoConnect(const struct addrinfo* res, const TInstant& deadLine) {
    TSocketHolder connected(DoConnectImpl(res, deadLine));
    SetNonBlock(connected, false);
    return connected.Release();
}
// Vectored send that restarts when interrupted by a signal (EINTR).
// Returns the number of bytes sent, or the negated system error code.
static inline ssize_t DoSendV(SOCKET fd, const struct iovec* iov, size_t count) {
    for (;;) {
        const ssize_t sent = DoSendMsg(fd, iov, (int)count);
        if (sent == -1 && errno == EINTR) {
            continue;
        }
        if (sent < 0) {
            return -LastSystemError();
        }
        return sent;
    }
}
template <bool isCompat>
struct TSender {
using TPart = TSocket::TPart;
static inline ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) {
return DoSendV(fd, (const iovec*)parts, count);
}
};
template <>
struct TSender<false> {
using TPart = TSocket::TPart;
static inline ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) {
TTempBuf tempbuf(sizeof(struct iovec) * count);
struct iovec* iov = (struct iovec*)tempbuf.Data();
for (size_t i = 0; i < count; ++i) {
struct iovec& io = iov[i];
const TPart& part = parts[i];
io.iov_base = (char*)part.buf;
io.iov_len = part.len;
}
return DoSendV(fd, iov, count);
}
};
class TCommonSockOps: public TSocket::TOps {
using TPart = TSocket::TPart;
public:
inline TCommonSockOps() noexcept {
}
~TCommonSockOps() override = default;
ssize_t Send(SOCKET fd, const void* data, size_t len) override {
ssize_t ret = -1;
do {
ret = send(fd, (const char*)data, (int)len, MSG_NOSIGNAL);
} while (ret == -1 && errno == EINTR);
if (ret < 0) {
return -LastSystemError();
}
return ret;
}
ssize_t Recv(SOCKET fd, void* buf, size_t len) override {
ssize_t ret = -1;
do {
ret = recv(fd, (char*)buf, (int)len, 0);
} while (ret == -1 && errno == EINTR);
if (ret < 0) {
return -LastSystemError();
}
return ret;
}
ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) override {
ssize_t ret = SendVImpl(fd, parts, count);
if (ret < 0) {
return ret;
}
size_t len = TContIOVector::Bytes(parts, count);
if ((size_t)ret == len) {
return ret;
}
return SendVPartial(fd, parts, count, ret);
}
inline ssize_t SendVImpl(SOCKET fd, const TPart* parts, size_t count) {
return TSender < (sizeof(iovec) == sizeof(TPart)) && (offsetof(iovec, iov_base) == offsetof(TPart, buf)) && (offsetof(iovec, iov_len) == offsetof(TPart, len)) > ::SendV(fd, parts, count);
}
ssize_t SendVPartial(SOCKET fd, const TPart* constParts, size_t count, size_t written);
};
// Continues a vectored send after `written` bytes have already gone out.
// The parts are copied into a scratch buffer because TContIOVector is given a
// mutable array to advance through. Returns the total number of bytes written
// (including the initial `written`) or the first negative error code.
ssize_t TCommonSockOps::SendVPartial(SOCKET fd, const TPart* constParts, size_t count, size_t written) {
    TTempBuf scratch(sizeof(TPart) * count);
    TPart* const mutableParts = (TPart*)scratch.Data();
    for (size_t idx = 0; idx != count; ++idx) {
        mutableParts[idx] = constParts[idx];
    }

    TContIOVector remaining(mutableParts, count);
    remaining.Proceed(written);

    for (; !remaining.Complete();) {
        const ssize_t sent = SendVImpl(fd, remaining.Parts(), remaining.Count());
        if (sent < 0) {
            return sent;
        }
        remaining.Proceed((size_t)sent);
        written += sent;
    }
    return written;
}
// Returns the shared TCommonSockOps instance obtained through Singleton<>.
static inline TSocket::TOps* GetCommonSockOps() noexcept {
    TSocket::TOps* const shared = Singleton<TCommonSockOps>();
    return shared;
}
// Empty socket: holds INVALID_SOCKET and the common operations.
TSocket::TSocket()
    : Impl_(new TImpl(INVALID_SOCKET, GetCommonSockOps()))
{
}

// Wraps an existing descriptor using the common operations.
TSocket::TSocket(SOCKET fd)
    : Impl_(new TImpl(fd, GetCommonSockOps()))
{
}

// Wraps an existing descriptor using caller-supplied operations.
TSocket::TSocket(SOCKET fd, TOps* ops)
    : Impl_(new TImpl(fd, ops))
{
}

// Connects to addr; the deadline passed to DoConnect() is TInstant::Max().
TSocket::TSocket(const TNetworkAddress& addr)
    : Impl_(new TImpl(DoConnect(addr.Info(), TInstant::Max()), GetCommonSockOps()))
{
}

// Connects to addr; the deadline is derived from timeOut via ToDeadLine().
TSocket::TSocket(const TNetworkAddress& addr, const TDuration& timeOut)
    : Impl_(new TImpl(DoConnect(addr.Info(), timeOut.ToDeadLine()), GetCommonSockOps()))
{
}

// Connects to addr with an explicit deadline.
TSocket::TSocket(const TNetworkAddress& addr, const TInstant& deadLine)
    : Impl_(new TImpl(DoConnect(addr.Info(), deadLine), GetCommonSockOps()))
{
}
TSocket::~TSocket() = default;

// Everything below forwards to the implementation object.

// Raw descriptor of the socket.
SOCKET TSocket::Fd() const noexcept {
    return Impl_->Fd();
}

// Sends up to len bytes from data; the result comes from the operations table.
ssize_t TSocket::Send(const void* data, size_t len) {
    return Impl_->Send(data, len);
}

// Receives up to len bytes into buf; the result comes from the operations table.
ssize_t TSocket::Recv(void* buf, size_t len) {
    return Impl_->Recv(buf, len);
}

// Vectored send of `count` parts.
ssize_t TSocket::SendV(const TPart* parts, size_t count) {
    return Impl_->SendV(parts, count);
}

// Closes the descriptor held by the implementation object.
void TSocket::Close() {
    Impl_->Close();
}
// Input stream over a socket; keeps its own TSocket copy in S_.
TSocketInput::TSocketInput(const TSocket& s) noexcept
    : S_(s)
{
}

TSocketInput::~TSocketInput() = default;
// Reads up to len bytes with one Recv() call and returns the number of bytes
// read. A negative Recv() result is rethrown as TSystemError with that code.
size_t TSocketInput::DoRead(void* buf, size_t len) {
    const ssize_t got = S_.Recv(buf, len);
    if (got < 0) {
        ythrow TSystemError(-(int)got) << "can not read from socket input stream";
    }
    return (size_t)got;
}
// Output stream over a socket; keeps its own TSocket copy in S_.
TSocketOutput::TSocketOutput(const TSocket& s) noexcept
    : S_(s)
{
}

// Finishes the stream on destruction. Errors are swallowed on purpose so that
// the destructor never throws.
TSocketOutput::~TSocketOutput() {
    try {
        Finish();
    } catch (...) {
        // Intentionally ignored.
    }
}
// Writes all len bytes, calling Send() repeatedly until nothing is left.
// A negative Send() result is rethrown as TSystemError; the message reports
// how many bytes had been written before the failure.
void TSocketOutput::DoWrite(const void* buf, size_t len) {
    const char* cursor = (const char*)buf;
    size_t done = 0;
    while (len != 0) {
        const ssize_t n = S_.Send(cursor, len);
        if (n < 0) {
            ythrow TSystemError(-(int)n) << "can not write to socket output stream; " << done << " bytes already send";
        }
        cursor += n;
        len -= n;
        done += n;
    }
}
// Writes all parts with a single SendV() call. A negative result is rethrown
// as TSystemError with that code.
void TSocketOutput::DoWriteV(const TPart* parts, size_t count) {
    const ssize_t written = S_.SendV(parts, count);
    if (written >= 0) {
        // todo: what should happen here for non-blocking sockets?
        return;
    }
    ythrow TSystemError(-(int)written) << "can not writev to socket output stream";
}
namespace {
    // Set of host names that always denote the local machine.
    // See https://bugzilla.mozilla.org/attachment.cgi?id=503263&action=diff
    struct TLocalNames: public THashSet<TStringBuf> {
        inline TLocalNames() {
            static const char* const wellKnown[] = {
                "localhost",
                "localhost.localdomain",
                "localhost6",
                "localhost6.localdomain6",
                "::1",
            };
            for (const char* entry : wellKnown) {
                insert(entry);
            }
        }

        // True when `name` is an IPv4 literal whose first octet is 127, or,
        // for anything that is not an IPv4 literal, one of the names above.
        inline bool IsLocalName(const char* name) const noexcept {
            struct in_addr parsed;
            memset(&parsed, 0, sizeof(parsed));
            const bool isIpv4Literal = inet_pton(AF_INET, name, &parsed) == 1;
            if (!isIpv4Literal) {
                return contains(name);
            }
            return (InetToHost(parsed.s_addr) >> 24) == 127;
        }
    };
}
// Reference-counted owner of the addrinfo list behind TNetworkAddress.
// The list either comes from getaddrinfo() (host/port constructor) or is
// built by hand with malloc() (unix socket path constructor); the deleter
// remembers which of the two release strategies applies.
class TNetworkAddress::TImpl: public TAtomicRefCount<TImpl> {
private:
    class TAddrInfoDeleter {
    public:
        // useFreeAddrInfo == true  -> the list is released with freeaddrinfo();
        // useFreeAddrInfo == false -> the list was malloc()-ed here and is
        //                             released piece by piece with free().
        TAddrInfoDeleter(bool useFreeAddrInfo = true)
            : UseFreeAddrInfo_(useFreeAddrInfo)
        {
        }

        void operator()(struct addrinfo* ai) noexcept {
            if (ai == nullptr) {
                return;
            }
            if (UseFreeAddrInfo_) {
                freeaddrinfo(ai);
                return;
            }
            // Hand-built list: only the first node's address is released
            // here, then every node and its canonical name.
            free(ai->ai_addr);
            while (ai != nullptr) {
                struct addrinfo* const node = ai;
                ai = ai->ai_next;
                free(node->ai_canonname);
                free(node);
            }
        }

    private:
        bool UseFreeAddrInfo_ = true;
    };

public:
    // Resolves host:port for SOCK_STREAM with getaddrinfo().
    // A null host requests a passive address (AI_PASSIVE). AI_ADDRCONFIG is
    // added unless the host is recognised as a local name by TLocalNames.
    // Throws TNetworkResolutionError when resolution fails.
    inline TImpl(const char* host, ui16 port, int flags)
        : Info_(nullptr, TAddrInfoDeleter{})
    {
        const TString port_st(ToString(port));

        struct addrinfo hints;
        memset(&hints, 0, sizeof(hints));
        hints.ai_flags = flags;
        hints.ai_family = PF_UNSPEC;
        hints.ai_socktype = SOCK_STREAM;
        if (!host) {
            hints.ai_flags |= AI_PASSIVE;
        } else if (!Singleton<TLocalNames>()->IsLocalName(host)) {
            hints.ai_flags |= AI_ADDRCONFIG;
        }

        struct addrinfo* pai = nullptr;
        const int error = getaddrinfo(host, port_st.data(), &hints, &pai);
        if (error) {
            TAddrInfoDeleter()(pai);
            ythrow TNetworkResolutionError(error) << ": can not resolve " << host << ":" << port;
        }
        Info_.reset(pai);
    }

    // Builds a single-node AF_UNIX/SOCK_STREAM addrinfo for `path`.
    // Throws when the path does not fit into sun_path or when memory cannot
    // be allocated.
    inline TImpl(const char* path, int flags)
        : Info_(nullptr, TAddrInfoDeleter{/* useFreeAddrInfo = */ false})
    {
        THolder<struct sockaddr_un, TFree> sockAddr(
            static_cast<struct sockaddr_un*>(malloc(sizeof(struct sockaddr_un))));
        // malloc() may fail; do not dereference a null pointer below.
        Y_ENSURE(sockAddr.Get() != nullptr, "can not allocate memory for unix socket address");
        // ai_addrlen covers the whole structure, so leave no byte uninitialised.
        memset(sockAddr.Get(), 0, sizeof(struct sockaddr_un));

        Y_ENSURE(strlen(path) < sizeof(sockAddr->sun_path), "Unix socket path more than " << sizeof(sockAddr->sun_path));
        sockAddr->sun_family = AF_UNIX;
        strcpy(sockAddr->sun_path, path);

        TAddrInfoPtr hints(static_cast<struct addrinfo*>(malloc(sizeof(struct addrinfo))), TAddrInfoDeleter{/* useFreeAddrInfo = */ false});
        Y_ENSURE(hints.get() != nullptr, "can not allocate memory for addrinfo");
        memset(hints.get(), 0, sizeof(*hints));
        hints->ai_flags = flags;
        hints->ai_family = AF_UNIX;
        hints->ai_socktype = SOCK_STREAM;
        hints->ai_addrlen = sizeof(*sockAddr);
        // Ownership of the address moves into the node; the deleter frees it.
        hints->ai_addr = (struct sockaddr*)sockAddr.Release();

        Info_.reset(hints.release());
    }

    // Head of the owned addrinfo list.
    inline struct addrinfo* Info() const noexcept {
        return Info_.get();
    }

private:
    using TAddrInfoPtr = std::unique_ptr<struct addrinfo, TAddrInfoDeleter>;

    TAddrInfoPtr Info_;
};
// Address of a unix domain socket at unixSocketPath.Path.
TNetworkAddress::TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags)
    : Impl_(new TImpl(unixSocketPath.Path.data(), flags))
{
}

// Resolves host:port; `flags` become the initial getaddrinfo() hint flags.
TNetworkAddress::TNetworkAddress(const TString& host, ui16 port, int flags)
    : Impl_(new TImpl(host.data(), port, flags))
{
}

// Resolves host:port with no extra hint flags.
TNetworkAddress::TNetworkAddress(const TString& host, ui16 port)
    : Impl_(new TImpl(host.data(), port, 0))
{
}

// No host given: resolves a passive address for `port`.
TNetworkAddress::TNetworkAddress(ui16 port)
    : Impl_(new TImpl(nullptr, port, 0))
{
}
TNetworkAddress::~TNetworkAddress() = default;

// Head of the addrinfo list owned by the implementation object.
struct addrinfo* TNetworkAddress::Info() const noexcept {
    struct addrinfo* const head = Impl_->Info();
    return head;
}
// Starts the exception message with "<text>(<code>): ". On unix builds, when
// the code is EAI_SYSTEM, the current system error number is added as
// "; errno=<n>" inside the parentheses.
TNetworkResolutionError::TNetworkResolutionError(int error) {
#ifdef _win_
    // gai_strerror is not thread-safe on Windows
    const char* const errMsg = LastSystemErrorText(error);
#else
    const char* const errMsg = gai_strerror(error);
#endif

    (*this) << errMsg << "(" << error;
#if defined(_unix_)
    const bool systemError = (error == EAI_SYSTEM);
    if (systemError) {
        (*this) << "; errno=" << LastSystemError();
    }
#endif
    (*this) << "): ";
}
#if defined(_unix_)
// Returns the file status flags of fd (fcntl F_GETFL).
// Throws TSystemError when the call fails.
static inline int GetFlags(int fd) {
    const int flags = fcntl(fd, F_GETFL);
    if (flags != -1) {
        return flags;
    }
    ythrow TSystemError() << "can not get fd flags";
}
// Replaces the file status flags of fd (fcntl F_SETFL).
// Throws TSystemError when the call fails.
static inline void SetFlags(int fd, int flags) {
    const int rc = fcntl(fd, F_SETFL, flags);
    if (rc == -1) {
        ythrow TSystemError() << "can not set fd flags";
    }
}
// Turns `flag` on in the status flags of fd; nothing is written when the
// flags would not change.
static inline void EnableFlag(int fd, int flag) {
    const int before = GetFlags(fd);
    const int after = before | flag;
    if (after == before) {
        return;
    }
    SetFlags(fd, after);
}
// Turns `flag` off in the status flags of fd; nothing is written when the
// flags would not change.
static inline void DisableFlag(int fd, int flag) {
    const int before = GetFlags(fd);
    const int after = before & (~flag);
    if (after == before) {
        return;
    }
    SetFlags(fd, after);
}
// Enables `flag` on fd when value is true, disables it otherwise.
static inline void SetFlag(int fd, int flag, bool value) {
    if (!value) {
        DisableFlag(fd, flag);
        return;
    }
    EnableFlag(fd, flag);
}
// True when at least one bit of `flags` is set in the status flags of fd.
static inline bool FlagsAreEnabled(int fd, int flags) {
    const int common = GetFlags(fd) & flags;
    return common != 0;
}
#endif
#if defined(_win_)
// Windows: sets the FIONBIO mode of fd to `value` through WSAIoctl().
// When blocking mode is requested (value == 0), the socket's event selection
// is cleared first with WSAEventSelect(fd, nullptr, 0).
// Throws TSystemError when WSAIoctl() fails.
static inline void SetNonBlockSocket(SOCKET fd, int value) {
    unsigned long mode = value;
    unsigned long unusedOut = 0;
    DWORD bytesReturned = 0;

    if (mode == 0) {
        WSAEventSelect(fd, nullptr, 0);
    }

    const int rc = WSAIoctl(fd, FIONBIO, &mode, sizeof(mode), &unusedOut, sizeof(unusedOut), &bytesReturned, 0, 0);
    if (rc == SOCKET_ERROR) {
        ythrow TSystemError() << "can not set non block socket state";
    }
}
// Windows: issues WSAIoctl(FIONBIO) with no input buffer and reports whether
// the returned value is non-zero. Throws TSystemError when the call fails.
// NOTE(review): confirm that FIONBIO can be used as a query this way.
static inline bool IsNonBlockSocket(SOCKET fd) {
    unsigned long state = 0;
    const int rc = WSAIoctl(fd, FIONBIO, 0, 0, &state, sizeof(state), 0, 0, 0);
    if (rc == SOCKET_ERROR) {
        ythrow TSystemError() << "can not get non block socket state";
    }
    return state != 0;
}
#endif
// Switches fd to non-blocking (value == true) or blocking mode.
// unix: ioctl(FIONBIO) when that request is defined, otherwise the O_NONBLOCK
//       status flag is changed; Windows: SetNonBlockSocket().
// Throws TSystemError on failure.
void SetNonBlock(SOCKET fd, bool value) {
#if defined(_unix_)
#if defined(FIONBIO)
    Y_UNUSED(SetFlag); // keeps clang quiet about the otherwise unused helper
    int enable = value;
    const int rc = ioctl(fd, FIONBIO, &enable);
    if (rc < 0) {
        ythrow TSystemError() << "ioctl failed";
    }
#else
    SetFlag(fd, O_NONBLOCK, value);
#endif
#elif defined(_win_)
    SetNonBlockSocket(fd, value);
#else
#error todo
#endif
}
// Reports whether fd is in non-blocking mode: the O_NONBLOCK status flag on
// unix, IsNonBlockSocket() on Windows.
bool IsNonBlock(SOCKET fd) {
#if defined(_unix_)
    const bool nonBlocking = FlagsAreEnabled(fd, O_NONBLOCK);
    return nonBlocking;
#elif defined(_win_)
    const bool nonBlocking = IsNonBlockSocket(fd);
    return nonBlocking;
#else
#error todo
#endif
}
// Applies whichever deferred-accept options the platform defines:
//   TCP_DEFER_ACCEPT  - set to 10 through CheckedSetSockOpt();
//   SO_ACCEPTFILTER   - the "dataready" accept filter, set through
//                       SetSockOpt() whose result is not checked.
// A no-op when neither option is defined.
void SetDeferAccept(SOCKET s) {
    Y_UNUSED(s);

#if defined(TCP_DEFER_ACCEPT)
    CheckedSetSockOpt(s, IPPROTO_TCP, TCP_DEFER_ACCEPT, 10, "defer accept");
#endif

#if defined(SO_ACCEPTFILTER)
    struct accept_filter_arg filter;
    Zero(filter);
    strcpy(filter.af_name, "dataready");
    SetSockOpt(s, SOL_SOCKET, SO_ACCEPTFILTER, filter);
#endif
}
ssize_t PollD(struct pollfd fds[], nfds_t nfds, const TInstant& deadLine) noexcept {
TInstant now = TInstant::Now();
do {
const TDuration toWait = PollStep(deadLine, now);
const int res = poll(fds, nfds, MicroToMilli(toWait.MicroSeconds()));
if (res > 0) {
return res;
}
if (res < 0) {
const int err = LastSystemError();
if (err != ETIMEDOUT && err != EINTR) {
return -err;
}
}
} while ((now = TInstant::Now()) < deadLine);
return -ETIMEDOUT;
}
// Calls shutdown(s, mode) and throws TSystemError when it fails.
void ShutDown(SOCKET s, int mode) {
    const int rc = shutdown(s, mode);
    if (rc != 0) {
        ythrow TSystemError() << "shutdown socket error";
    }
}
extern "C" bool IsReusePortAvailable() {
// SO_REUSEPORT is always defined for linux builds, see SetReusePort() implementation above
#if defined(SO_REUSEPORT)
class TCtx {
public:
TCtx() {
TSocketHolder sock(::socket(AF_INET, SOCK_STREAM, 0));
const int e1 = errno;
if (sock == INVALID_SOCKET) {
ythrow TSystemError(e1) << "Cannot create AF_INET socket";
}
int val;
const int ret = GetSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, val);
const int e2 = errno;
if (ret == 0) {
Flag_ = true;
} else {
if (e2 == ENOPROTOOPT) {
Flag_ = false;
} else {
ythrow TSystemError(e2) << "Unexpected error in getsockopt";
}
}
}
static inline const TCtx* Instance() noexcept {
return Singleton<TCtx>();
}
public:
bool Flag_;
};
return TCtx::Instance()->Flag_;
#else
return false;
#endif
}