#pragma once

#include "init.h"

#include <util/system/yassert.h>
#include <util/system/defaults.h>
#include <util/system/error.h>
#include <util/stream/output.h>
#include <util/stream/input.h>
#include <util/generic/ptr.h>
#include <util/generic/yexception.h>
#include <util/generic/noncopyable.h>
#include <util/datetime/base.h>

#include <cerrno>

// Fallback for platforms whose headers do not provide INET_ADDRSTRLEN.
// 16 bytes fits the longest dotted IPv4 string "255.255.255.255" (15 chars) plus NUL.
#ifndef INET_ADDRSTRLEN
    #define INET_ADDRSTRLEN 16
#endif

// get_host_error(): platform-neutral access to the last host-lookup error code
// (h_errno on unix, WSAGetLastError() on Windows).
#if defined(_unix_)
    #define get_host_error() h_errno
#elif defined(_win_)
    // MSVC-specific: ask the linker to pull in the Winsock 2 library.
    #pragma comment(lib, "Ws2_32.lib")

    // Targets at or above _WIN32_WINNT 0x0600 (Windows Vista) map poll() onto
    // WSAPoll (the #else branch). Older targets get a local struct pollfd, its
    // event flags, and declarations of poll()/inet_ntop().
    // NOTE(review): those functions are presumably defined in the matching .cpp.
    #if _WIN32_WINNT < 0x0600
struct pollfd {
    SOCKET fd;
    short events;
    short revents;
};

        // Event bits for the local pollfd replacement above.
        // NOTE(review): these values are local to this shim; confirm they are
        // only interpreted by the local poll() implementation.
        #define POLLIN (1 << 0)
        #define POLLRDNORM (1 << 1)
        #define POLLRDBAND (1 << 2)
        #define POLLPRI (1 << 3)
        #define POLLOUT (1 << 4)
        #define POLLWRNORM (1 << 5)
        #define POLLWRBAND (1 << 6)
        #define POLLERR (1 << 7)
        #define POLLHUP (1 << 8)
        #define POLLNVAL (1 << 9)

const char* inet_ntop(int af, const void* src, char* dst, socklen_t size);
int poll(struct pollfd fds[], nfds_t nfds, int timeout) noexcept;
    #else
        #define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout)
    #endif

// Declared here for Windows builds only (unix provides it in system headers).
int inet_aton(const char* cp, struct in_addr* inp);

    #define get_host_error() WSAGetLastError()

    // Map the POSIX shutdown() mode names onto their Winsock equivalents.
    #define SHUT_RD SD_RECEIVE
    #define SHUT_WR SD_SEND
    #define SHUT_RDWR SD_BOTH

    // "Infinite" timeout value for poll(): a negative timeout means wait indefinitely.
    #define INFTIM (-1)
#endif

/**
 * Thin setsockopt() wrapper: passes the address and byte size of the by-value `opt`.
 *
 * @return the setsockopt() result, unchanged.
 **/
template <class T>
static inline int SetSockOpt(SOCKET s, int level, int optname, T opt) noexcept {
    const char* const optBytes = reinterpret_cast<const char*>(&opt);
    const socklen_t optSize = static_cast<socklen_t>(sizeof(opt));
    return setsockopt(s, level, optname, optBytes, optSize);
}

/**
 * Thin getsockopt() wrapper: reads the option straight into `opt`, offering
 * sizeof(T) bytes of storage. The length written back by the kernel is not inspected.
 *
 * @return the getsockopt() result, unchanged.
 **/
template <class T>
static inline int GetSockOpt(SOCKET s, int level, int optname, T& opt) noexcept {
    char* const optBytes = reinterpret_cast<char*>(&opt);
    socklen_t optSize = sizeof(opt);
    return getsockopt(s, level, optname, optBytes, &optSize);
}

/**
 * SetSockOpt() that throws instead of returning an error code.
 *
 * @param err  label appended to the exception message to identify the option/caller.
 * @throws TSystemError when SetSockOpt() returns non-zero.
 **/
template <class T>
static inline void CheckedSetSockOpt(SOCKET s, int level, int optname, T opt, const char* err) {
    const int rc = SetSockOpt<T>(s, level, optname, opt);
    if (rc == 0) {
        return;
    }

    ythrow TSystemError() << "setsockopt() failed for " << err;
}

/**
 * GetSockOpt() that throws instead of returning an error code.
 *
 * @param err  label appended to the exception message to identify the option/caller.
 * @throws TSystemError when GetSockOpt() returns non-zero.
 **/
template <class T>
static inline void CheckedGetSockOpt(SOCKET s, int level, int optname, T& opt, const char* err) {
    const int rc = GetSockOpt<T>(s, level, optname, opt);
    if (rc == 0) {
        return;
    }

    ythrow TSystemError() << "getsockopt() failed for " << err;
}

/**
 * Enables IPV6_V6ONLY on `s` where the platform defines that option; no-op otherwise.
 *
 * The SetSockOpt() result is ignored, so a failure here is silent.
 * NOTE(review): presumably best-effort by design — confirm before making it checked.
 **/
static inline void FixIPv6ListenSocket(SOCKET s) {
#if !defined(IPV6_V6ONLY)
    (void)s;
#else
    const int v6Only = 1;
    SetSockOpt(s, IPPROTO_IPV6, IPV6_V6ONLY, v6Only);
#endif
}

namespace NAddr {
    class IRemoteAddr;
}

void SetSocketTimeout(SOCKET s, long timeout);
void SetSocketTimeout(SOCKET s, long sec, long msec);
void SetNoDelay(SOCKET s, bool value);
void SetKeepAlive(SOCKET s);
void SetLinger(SOCKET s, bool on, unsigned len);
void SetZeroLinger(SOCKET s);
void SetKeepAlive(SOCKET s, bool value);
void SetCloseOnExec(SOCKET s, bool value);
void SetOutputBuffer(SOCKET s, unsigned value);
void SetInputBuffer(SOCKET s, unsigned value);
void SetReusePort(SOCKET s, bool value);
void ShutDown(SOCKET s, int mode);
bool GetRemoteAddr(SOCKET s, char* str, socklen_t size);
size_t GetMaximumSegmentSize(SOCKET s);
size_t GetMaximumTransferUnit(SOCKET s);
void SetDeferAccept(SOCKET s);
void SetSocketToS(SOCKET s, int tos);
void SetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr, int tos);
int GetSocketToS(SOCKET s);
int GetSocketToS(SOCKET s, const NAddr::IRemoteAddr* addr);
void SetSocketPriority(SOCKET s, int priority);
void SetTcpFastOpen(SOCKET s, int qlen);
/**
 * Deprecated, consider using HasSocketDataToRead instead.
 **/
bool IsNotSocketClosedByOtherSide(SOCKET s);
enum class ESocketReadStatus {
    HasData,
    NoData,
    SocketClosed
};
/**
 * Useful for keep-alive connections.
 **/
ESocketReadStatus HasSocketDataToRead(SOCKET s);
/**
 * Determines whether connection on socket is local (same machine) or not.
 **/
bool HasLocalAddress(SOCKET socket);

/**
 * Runtime check if current kernel supports SO_REUSEPORT option.
 **/
extern "C" bool IsReusePortAvailable();

bool IsNonBlock(SOCKET fd);
void SetNonBlock(SOCKET fd, bool nonBlock = true);

struct addrinfo;

/**
 * Exception carrying a name-resolution failure code.
 **/
class TNetworkResolutionError: public yexception {
public:
    /// @param error  EAI_XXX code returned by getaddrinfo()/getnameinfo() — not an errno value.
    TNetworkResolutionError(int error);
};

/**
 * Filesystem path of a unix domain socket, wrapped in a distinct type so it is
 * accepted by TNetworkAddress(const TUnixSocketPath&, int) rather than being
 * mistaken for a host name.
 *
 * Example: TUnixSocketPath("/tmp/unixsocket").Path == "/tmp/unixsocket"
 **/
struct TUnixSocketPath {
    TString Path;

    /// Stores `path` verbatim; no validation or normalization is performed.
    explicit TUnixSocketPath(const TString& path)
        : Path(path)
    {
    }
};

/**
 * Resolved network address: a chain of addrinfo entries, obtained through the
 * private TImpl and walked with Begin()/End().
 **/
class TNetworkAddress {
    friend class TSocket;

public:
    /**
     * Forward iterator over an addrinfo chain linked through ai_next.
     * A null pointer marks the end (see End()).
     **/
    class TIterator {
    public:
        TIterator(struct addrinfo* begin)
            : C_(begin)
        {
        }

        /// Steps to the next entry. Must not be called on the end iterator (dereferences C_).
        void Next() noexcept {
            C_ = C_->ai_next;
        }

        TIterator& operator++() noexcept {
            Next();
            return *this;
        }

        TIterator operator++(int) noexcept {
            TIterator previous = *this;
            ++*this;
            return previous;
        }

        friend bool operator==(const TIterator& l, const TIterator& r) noexcept {
            return l.C_ == r.C_;
        }

        friend bool operator!=(const TIterator& l, const TIterator& r) noexcept {
            return l.C_ != r.C_;
        }

        struct addrinfo& operator*() const noexcept {
            return *C_;
        }

        struct addrinfo* operator->() const noexcept {
            return C_;
        }

    private:
        struct addrinfo* C_;
    };

    TNetworkAddress(ui16 port);
    TNetworkAddress(const TString& host, ui16 port);
    // NOTE(review): `flags` is presumably forwarded as addrinfo hints — confirm in the .cpp.
    TNetworkAddress(const TString& host, ui16 port, int flags);
    TNetworkAddress(const TUnixSocketPath& unixSocketPath, int flags = 0);
    ~TNetworkAddress();

    /// Iterator positioned at the first resolved entry.
    TIterator Begin() const noexcept {
        return TIterator(Info());
    }

    /// Past-the-end iterator (wraps a null pointer).
    TIterator End() const noexcept {
        return TIterator(nullptr);
    }

private:
    struct addrinfo* Info() const noexcept;

private:
    class TImpl;
    TSimpleIntrusivePtr<TImpl> Impl_;
};

class TSocket;

class TSocketHolder: public TMoveOnly {
public:
    inline TSocketHolder()
        : Fd_(INVALID_SOCKET)
    {
    }

    inline TSocketHolder(SOCKET fd)
        : Fd_(fd)
    {
    }

    inline TSocketHolder(TSocketHolder&& other) noexcept {
        Fd_ = other.Fd_;
        other.Fd_ = INVALID_SOCKET;
    }

    inline TSocketHolder& operator=(TSocketHolder&& other) noexcept {
        Close();
        Swap(other);

        return *this;
    }

    inline ~TSocketHolder() {
        Close();
    }

    inline SOCKET Release() noexcept {
        SOCKET ret = Fd_;
        Fd_ = INVALID_SOCKET;
        return ret;
    }

    void Close() noexcept;

    inline void ShutDown(int mode) const {
        ::ShutDown(Fd_, mode);
    }

    inline void Swap(TSocketHolder& r) noexcept {
        DoSwap(Fd_, r.Fd_);
    }

    inline bool Closed() const noexcept {
        return Fd_ == INVALID_SOCKET;
    }

    inline operator SOCKET() const noexcept {
        return Fd_;
    }

private:
    SOCKET Fd_;

    // do not allow construction of TSocketHolder from TSocket
    TSocketHolder(const TSocket& fd);
};

/**
 * Socket handle wrapper backed by a private TImpl held through TSimpleIntrusivePtr.
 * NOTE(review): copies presumably share one TImpl via the intrusive pointer — confirm.
 **/
class TSocket {
public:
    using TPart = IOutputStream::TPart;

    /**
     * Overridable low-level I/O operations, supplied via TSocket(SOCKET, TOps*).
     * NOTE(review): ownership of the TOps pointer is decided in the .cpp — confirm.
     **/
    class TOps {
    public:
        TOps() noexcept = default;
        virtual ~TOps() = default;

        virtual ssize_t Send(SOCKET fd, const void* data, size_t len) = 0;
        virtual ssize_t Recv(SOCKET fd, void* buf, size_t len) = 0;
        virtual ssize_t SendV(SOCKET fd, const TPart* parts, size_t count) = 0;
    };

    TSocket();
    TSocket(SOCKET fd);
    TSocket(SOCKET fd, TOps* ops);
    TSocket(const TNetworkAddress& addr);
    TSocket(const TNetworkAddress& addr, const TDuration& timeOut);
    TSocket(const TNetworkAddress& addr, const TInstant& deadLine);

    ~TSocket();

    /// Generic setsockopt(); throws TSystemError on failure (via CheckedSetSockOpt).
    template <class T>
    void SetSockOpt(int level, int optname, T opt) {
        CheckedSetSockOpt(Fd(), level, optname, opt, "TSocket");
    }

    // Option setters below forward to the free functions of the same name.

    void SetSocketTimeout(long timeout) {
        ::SetSocketTimeout(Fd(), timeout);
    }

    void SetSocketTimeout(long sec, long msec) {
        ::SetSocketTimeout(Fd(), sec, msec);
    }

    void SetNoDelay(bool value) {
        ::SetNoDelay(Fd(), value);
    }

    void SetLinger(bool on, unsigned len) {
        ::SetLinger(Fd(), on, len);
    }

    void SetZeroLinger() {
        ::SetZeroLinger(Fd());
    }

    void SetKeepAlive(bool value) {
        ::SetKeepAlive(Fd(), value);
    }

    void SetOutputBuffer(unsigned value) {
        ::SetOutputBuffer(Fd(), value);
    }

    void SetInputBuffer(unsigned value) {
        ::SetInputBuffer(Fd(), value);
    }

    // Queries, forwarding to GetMaximumSegmentSize / GetMaximumTransferUnit.

    size_t MaximumSegmentSize() const {
        return GetMaximumSegmentSize(Fd());
    }

    size_t MaximumTransferUnit() const {
        return GetMaximumTransferUnit(Fd());
    }

    void ShutDown(int mode) const {
        ::ShutDown(Fd(), mode);
    }

    void Close();

    ssize_t Send(const void* data, size_t len);
    ssize_t Recv(void* buf, size_t len);

    /// Scatter/gather send of `count` parts.
    ssize_t SendV(const TPart* parts, size_t count);

    operator SOCKET() const noexcept {
        return Fd();
    }

private:
    SOCKET Fd() const noexcept;

private:
    class TImpl;
    TSimpleIntrusivePtr<TImpl> Impl_;
};

/**
 * IInputStream adapter over a TSocket; stores its own TSocket value (S_).
 * DoRead() is implemented out of line.
 **/
class TSocketInput: public IInputStream {
public:
    TSocketInput(const TSocket& s) noexcept;
    ~TSocketInput() override;

    TSocketInput(TSocketInput&&) noexcept = default;
    TSocketInput& operator=(TSocketInput&&) noexcept = default;

    /// The socket this stream reads from.
    const TSocket& GetSocket() const noexcept {
        return S_;
    }

private:
    size_t DoRead(void* buf, size_t len) override;

    TSocket S_;
};

/**
 * IOutputStream adapter over a TSocket; stores its own TSocket value (S_).
 * DoWrite()/DoWriteV() are implemented out of line.
 **/
class TSocketOutput: public IOutputStream {
public:
    TSocketOutput(const TSocket& s) noexcept;
    ~TSocketOutput() override;

    TSocketOutput(TSocketOutput&&) noexcept = default;
    TSocketOutput& operator=(TSocketOutput&&) noexcept = default;

    /// The socket this stream writes to.
    const TSocket& GetSocket() const noexcept {
        return S_;
    }

private:
    void DoWrite(const void* buf, size_t len) override;
    void DoWriteV(const TPart* parts, size_t count) override;

    TSocket S_;
};

/**
 * Polls `nfds` descriptors in `fds`, bounded by the absolute time `deadLine`.
 * NOTE(review): exact deadline/timeout handling lives in the .cpp — confirm there.
 *
 * @return the number of ready descriptors, or -(error code) if an error occurred.
 **/
ssize_t PollD(struct pollfd* fds, nfds_t nfds, const TInstant& deadLine) noexcept;