diff options
author | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/network/socket.h |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/network/socket.h')
-rw-r--r-- | util/network/socket.h | 431 |
1 files changed, 431 insertions, 0 deletions
diff --git a/util/network/socket.h b/util/network/socket.h new file mode 100644 index 00000000000..357ad4079bc --- /dev/null +++ b/util/network/socket.h @@ -0,0 +1,431 @@ +#pragma once + +#include "init.h" + +#include <util/system/yassert.h> +#include <util/system/defaults.h> +#include <util/system/error.h> +#include <util/stream/output.h> +#include <util/stream/input.h> +#include <util/generic/ptr.h> +#include <util/generic/yexception.h> +#include <util/generic/noncopyable.h> +#include <util/datetime/base.h> + +#include <cerrno> + +#ifndef INET_ADDRSTRLEN + #define INET_ADDRSTRLEN 16 +#endif + +#if defined(_unix_) + #define get_host_error() h_errno +#elif defined(_win_) + #pragma comment(lib, "Ws2_32.lib") + + #if _WIN32_WINNT < 0x0600 +struct pollfd { + SOCKET fd; + short events; + short revents; +}; + + #define POLLIN (1 << 0) + #define POLLRDNORM (1 << 1) + #define POLLRDBAND (1 << 2) + #define POLLPRI (1 << 3) + #define POLLOUT (1 << 4) + #define POLLWRNORM (1 << 5) + #define POLLWRBAND (1 << 6) + #define POLLERR (1 << 7) + #define POLLHUP (1 << 8) + #define POLLNVAL (1 << 9) + +const char* inet_ntop(int af, const void* src, char* dst, socklen_t size); +int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept; + #else + #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) + #endif + +int inet_aton(const char* cp, struct in_addr* inp); + + #define get_host_error() WSAGetLastError() + + #define SHUT_RD SD_RECEIVE + #define SHUT_WR SD_SEND + #define SHUT_RDWR SD_BOTH + + #define INFTIM (-1) +#endif + +template <class T> +static inline int SetSockOpt(SOCKET s, int level, int optname, T opt) noexcept { + return setsockopt(s, level, optname, (const char*)&opt, sizeof(opt)); +} + +template <class T> +static inline int GetSockOpt(SOCKET s, int level, int optname, T& opt) noexcept { + socklen_t len = sizeof(opt); + + return getsockopt(s, level, optname, (char*)&opt, &len); +} + +template <class T> +static inline void CheckedSetSockOpt(SOCKET s, int level, int optname, T opt, const char* err) { + if (SetSockOpt<T>(s, level, optname, opt)) { + ythrow TSystemError() << "setsockopt() failed for " << err; + } +} + +template <class T> +static inline void CheckedGetSockOpt(SOCKET s, int level, int optname, T& opt, const char* err) { + if (GetSockOpt<T>(s, level, optname, opt)) { + ythrow TSystemError() << "getsockopt() failed for " << err; + } +} + +static inline void FixIPv6ListenSocket(SOCKET s) { +#if defined(IPV6_V6ONLY) + SetSockOpt(s, IPPROTO_IPV6, IPV6_V6ONLY, 1); +#else + (void)s; +#endif +} + +namespace NAddr { + class IRemoteAddr; +} + +void SetSocketTimeout(SOCKET s, long timeout); +void SetSocketTimeout(SOCKET s, long sec, long msec); +void SetNoDelay(SOCKET s, bool value); +void SetKeepAlive(SOCKET s); +void SetLinger(SOCKET s, bool on, unsigned len); +void SetZeroLinger(SOCKET s); +void SetKeepAlive(SOCKET s, bool value); +void SetCloseOnExec(SOCKET s, bool value); +void SetOutputBuffer(SOCKET s, unsigned value); +void SetInputBuffer(SOCKET s, unsigned value); +void SetReusePort(SOCKET s, bool value); +void ShutDown(SOCKET s, int mode); +bool GetRemoteAddr(SOCKET s, char* str, socklen_t size); +size_t GetMaximumSegmentSize(SOCKET s); +size_t GetMaximumTransferUnit(SOCKET s); +void SetDeferAccept(SOCKET s); +void SetSocketToS(SOCKET s, int tos); +void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos); +int GetSocketToS(SOCKET s); +int GetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr); +void SetTcpFastOpen(SOCKET s, int qlen); +/** + * Deprecated, consider using HasSocketDataToRead instead. + **/ +bool IsNotSocketClosedByOtherSide(SOCKET s); +enum class ESocketReadStatus { + HasData, + NoData, + SocketClosed +}; +/** + * Useful for keep-alive connections. + **/ +ESocketReadStatus HasSocketDataToRead(SOCKET s); +/** + * Determines whether connection on socket is local (same machine) or not. + **/ +bool HasLocalAddress(SOCKET socket); + +/** + * Runtime check if current kernel supports SO_REUSEPORT option. + **/ +extern "C" bool IsReusePortAvailable(); + +bool IsNonBlock(SOCKET fd); +void SetNonBlock(SOCKET fd, bool nonBlock = true); + +struct addrinfo; + +class TNetworkResolutionError: public yexception { +public: + // @param error error code (EAI_XXX) returned by getaddrinfo or getnameinfo (not errno) + TNetworkResolutionError(int error); +}; + +struct TUnixSocketPath { + TString Path; + + // Constructor for create unix domain socket path from string with path in filesystem + // TUnixSocketPath("/tmp/unixsocket") -> "/tmp/unixsocket" + explicit TUnixSocketPath(const TString& path) + : Path(path) + { + } +}; + +class TNetworkAddress { + friend class TSocket; + +public: + class TIterator { + public: + inline TIterator(struct addrinfo* begin) + : C_(begin) + { + } + + inline void Next() noexcept { + C_ = C_->ai_next; + } + + inline TIterator operator++(int) noexcept { + TIterator old(*this); + + Next(); + + return old; + } + + inline TIterator& operator++() noexcept { + Next(); + + return *this; + } + + friend inline bool operator==(const TIterator& l, const TIterator& r) noexcept { + return l.C_ == r.C_; + } + + friend inline bool operator!=(const TIterator& l, const TIterator& r) noexcept { + return !(l == r); + } + + inline struct addrinfo& operator*() const noexcept { + return *C_; + } + + inline struct addrinfo* operator->() const noexcept { + return C_; + } + + private: + struct addrinfo* C_; + }; + + TNetworkAddress(ui16 port); + TNetworkAddress(const TString& host, ui16 port); + TNetworkAddress(const TString& host, ui16 port, int flags); + TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags = 0); + ~TNetworkAddress(); + + inline TIterator Begin() const noexcept { + return TIterator(Info()); + } + + inline TIterator End() const noexcept { + return TIterator(nullptr); + } + +private: + struct addrinfo* Info() const noexcept; + +private: + class TImpl; + TSimpleIntrusivePtr<TImpl> Impl_; +}; + +class TSocket; + +class TSocketHolder: public TMoveOnly { +public: + inline TSocketHolder() + : Fd_(INVALID_SOCKET) + { + } + + inline TSocketHolder(SOCKET fd) + : Fd_(fd) + { + } + + inline TSocketHolder(TSocketHolder&& other) noexcept { + Fd_ = other.Fd_; + other.Fd_ = INVALID_SOCKET; + } + + inline TSocketHolder& operator=(TSocketHolder&& other) noexcept { + Close(); + Swap(other); + + return *this; + } + + inline ~TSocketHolder() { + Close(); + } + + inline SOCKET Release() noexcept { + SOCKET ret = Fd_; + Fd_ = INVALID_SOCKET; + return ret; + } + + void Close() noexcept; + + inline void ShutDown(int mode) const { + ::ShutDown(Fd_, mode); + } + + inline void Swap(TSocketHolder& r) noexcept { + DoSwap(Fd_, r.Fd_); + } + + inline bool Closed() const noexcept { + return Fd_ == INVALID_SOCKET; + } + + inline operator SOCKET() const noexcept { + return Fd_; + } + +private: + SOCKET Fd_; + + // do not allow construction of TSocketHolder from TSocket + TSocketHolder(const TSocket& fd); +}; + +class TSocket { +public: + using TPart = IOutputStream::TPart; + + class TOps { + public: + inline TOps() noexcept = default; + virtual ~TOps() = default; + + virtual ssize_t Send(SOCKET fd, const void* data, size_t len) = 0; + virtual ssize_t Recv(SOCKET fd, void* buf, size_t len) = 0; + virtual ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) = 0; + }; + + TSocket(); + TSocket(SOCKET fd); + TSocket(SOCKET fd, TOps* ops); + TSocket(const TNetworkAddress& addr); + TSocket(const TNetworkAddress& addr, const TDuration& timeOut); + TSocket(const TNetworkAddress& addr, const TInstant& deadLine); + + ~TSocket(); + + template <class T> + inline void SetSockOpt(int level, int optname, T opt) { + CheckedSetSockOpt(Fd(), level, optname, opt, "TSocket"); + } + + inline void SetSocketTimeout(long timeout) { + ::SetSocketTimeout(Fd(), timeout); + } + + inline void SetSocketTimeout(long sec, long msec) { + ::SetSocketTimeout(Fd(), sec, msec); + } + + inline void SetNoDelay(bool value) { + ::SetNoDelay(Fd(), value); + } + + inline void SetLinger(bool on, unsigned len) { + ::SetLinger(Fd(), on, len); + } + + inline void SetZeroLinger() { + ::SetZeroLinger(Fd()); + } + + inline void SetKeepAlive(bool value) { + ::SetKeepAlive(Fd(), value); + } + + inline void SetOutputBuffer(unsigned value) { + ::SetOutputBuffer(Fd(), value); + } + + inline void SetInputBuffer(unsigned value) { + ::SetInputBuffer(Fd(), value); + } + + inline size_t MaximumSegmentSize() const { + return GetMaximumSegmentSize(Fd()); + } + + inline size_t MaximumTransferUnit() const { + return GetMaximumTransferUnit(Fd()); + } + + inline void ShutDown(int mode) const { + ::ShutDown(Fd(), mode); + } + + void Close(); + + ssize_t Send(const void* data, size_t len); + ssize_t Recv(void* buf, size_t len); + + /* + * scatter/gather io + */ + ssize_t SendV(const TPart* parts, size_t count); + + inline operator SOCKET() const noexcept { + return Fd(); + } + +private: + SOCKET Fd() const noexcept; + +private: + class TImpl; + TSimpleIntrusivePtr<TImpl> Impl_; +}; + +class TSocketInput: public IInputStream { +public: + TSocketInput(const TSocket& s) noexcept; + ~TSocketInput() override; + + TSocketInput(TSocketInput&&) noexcept = default; + TSocketInput& operator=(TSocketInput&&) noexcept = default; + + const TSocket& GetSocket() const noexcept { + return S_; + } + +private: + size_t DoRead(void* buf, size_t len) override; + +private: + TSocket S_; +}; + +class TSocketOutput: public IOutputStream { +public: + TSocketOutput(const TSocket& s) noexcept; + ~TSocketOutput() override; + + TSocketOutput(TSocketOutput&&) noexcept = default; + TSocketOutput& operator=(TSocketOutput&&) noexcept = default; + + const TSocket& GetSocket() const noexcept { + return S_; + } + +private: + void DoWrite(const void* buf, size_t len) override; + void DoWriteV(const TPart* parts, size_t count) override; + +private: + TSocket S_; +}; + +//return -(error code) if error occured, or number of ready fds +ssize_t PollD(struct pollfd fds[], nfds_t nfds, const TInstant& deadLine) noexcept; |