#include "udp.h"
#include "details.h"
#include "neh.h"
#include "location.h"
#include "utils.h"
#include "factory.h"

#include <library/cpp/dns/cache.h>

#include <util/network/socket.h>
#include <util/network/address.h>
#include <util/generic/deque.h>
#include <util/generic/hash.h>
#include <util/generic/string.h>
#include <util/generic/buffer.h>
#include <util/generic/singleton.h>
#include <util/digest/murmur.h>
#include <util/random/random.h>
#include <util/ysaveload.h>
#include <util/system/thread.h>
#include <util/system/pipe.h>
#include <util/system/error.h>
#include <util/stream/mem.h>
#include <util/stream/buffer.h>
#include <util/string/cast.h>

using namespace NNeh;
using namespace NDns;
using namespace NAddr;

namespace {
    namespace NUdp {
        // Packet kinds. The numeric value is the one-byte tag stored right after the checksum,
        // so the values are part of the wire format and must stay fixed.
        enum EPacketType {
            PT_REQUEST = 1,  // tag written by TRequestPacket
            PT_RESPONSE = 2, // tag written by TResponsePacket
            PT_STOP = 3,     // tag written by TStopPacket; ends the receive loop
            PT_TIMEOUT = 4,  // synthesized locally when the socket read reports EAGAIN/EWOULDBLOCK
        };

        // Request handle returned to callers of the UDP protocol.
        // It adds nothing to TNotifyHandle except the two overrides below.
        struct TUdpHandle: public TNotifyHandle {
            TUdpHandle(IOnRecv* r, const TMessage& msg, TStatCollector* sc) noexcept
                : TNotifyHandle(r, msg, sc)
            {
            }

            // Cancellation is delegated to THandle::Cancel() (informs the stat collector);
            // nothing transport-specific is done here.
            void Cancel() noexcept override {
                THandle::Cancel();
            }

            // TODO: always reports true; actual send completion is not tracked.
            bool MessageSendedCompletely() const noexcept override {
                return true;
            }
        };

        // Builds the loopback address (IPv4 or IPv6, matching the socket's family) that carries
        // the same port the socket s is bound to. Throws yexception for any other family.
        static inline IRemoteAddrPtr GetSendAddr(SOCKET s) {
            IRemoteAddrPtr bound = GetSockAddr(s);
            const sockaddr* sa = bound->Addr();

            if (sa->sa_family == AF_INET) {
                const TIpAddress v4 = *reinterpret_cast<const sockaddr_in*>(sa);

                return MakeHolder<TIPv4Addr>(TIpAddress(InetToHost(INADDR_LOOPBACK), v4.Port()));
            }

            if (sa->sa_family == AF_INET6) {
                // Copy the bound address and overwrite only the host part.
                sockaddr_in6 v6 = *reinterpret_cast<const sockaddr_in6*>(sa);

                v6.sin6_addr = in6addr_loopback;

                return MakeHolder<TIPv6Addr>(v6);
            }

            ythrow yexception() << "unsupported";
        }

        // Type of the checksum that occupies the first sizeof(TCheckSum) bytes of a packet.
        using TCheckSum = ui32;

        // Returns a 16-byte binary string made of two random 64-bit numbers.
        // The bytes are raw, not a printable representation.
        static inline TString GenerateGuid() {
            ui64 parts[2];

            parts[0] = RandomNumber<ui64>();
            parts[1] = RandomNumber<ui64>();

            return TString(reinterpret_cast<const char*>(parts), sizeof(parts));
        }

        // MurmurHash of the bytes of s, passed through HostToInet before being returned.
        static inline TCheckSum Sum(const TStringBuf& s) noexcept {
            const TCheckSum hash = MurmurHash<TCheckSum>(s.data(), s.size());

            return HostToInet(hash);
        }

        // Forward declarations: TPacket's templated constructor calls Serialize(TPacket&, ...),
        // whose definition follows the TPacket definition further below.
        struct TPacket;

        template <class T>
        static inline void Serialize(TPacket& p, const T& t);

        struct TPacket {
            inline TPacket(IRemoteAddrPtr addr)
                : Addr(std::move(addr))
            {
            }

            template <class T>
            inline TPacket(const T& t, IRemoteAddrPtr addr)
                : Addr(std::move(addr))
            {
                NUdp::Serialize(*this, t);
            }

            inline TPacket(TSocketHolder& s, TBuffer& tmp) {
                TAutoPtr<TOpaqueAddr> addr(new TOpaqueAddr());

            retry_on_intr : {
                const int rv = recvfrom(s, tmp.Data(), tmp.size(), MSG_WAITALL, addr->MutableAddr(), addr->LenPtr());

                if (rv < 0) {
                    int err = LastSystemError();
                    if (err == EAGAIN || err == EWOULDBLOCK) {
                        Data.Resize(sizeof(TCheckSum) + 1);
                        *(Data.data() + sizeof(TCheckSum)) = static_cast<char>(PT_TIMEOUT);
                    } else if (err == EINTR) {
                        goto retry_on_intr;
                    } else {
                        ythrow TSystemError() << "recv failed";
                    }
                } else {
                    Data.Append(tmp.Data(), (size_t)rv);
                    Addr.Reset(addr.Release());
                    CheckSign();
                }
            }
            }

            inline void SendTo(TSocketHolder& s) {
                Sign();

                if (sendto(s, Data.data(), Data.size(), 0, Addr->Addr(), Addr->Len()) < 0) {
                    Cdbg << LastSystemErrorText() << Endl;
                }
            }

            IRemoteAddrPtr Addr;
            TBuffer Data;

            inline void Sign() {
                const TCheckSum sum = CalcSign();

                memcpy(Data.Data(), &sum, sizeof(sum));
            }

            inline char Type() const noexcept {
                return *(Data.data() + sizeof(TCheckSum));
            }

            inline void CheckSign() const {
                if (Data.size() < 16) {
                    ythrow yexception() << "small packet";
                }

                if (StoredSign() != CalcSign()) {
                    ythrow yexception() << "bad checksum";
                }
            }

            inline TCheckSum StoredSign() const noexcept {
                TCheckSum sum;

                memcpy(&sum, Data.Data(), sizeof(sum));

                return sum;
            }

            inline TCheckSum CalcSign() const noexcept {
                return Sum(Body());
            }

            inline TStringBuf Body() const noexcept {
                return TStringBuf(Data.data() + sizeof(TCheckSum), Data.End());
            }
        };

        // Owning pointer to a packet.
        using TPacketRef = TAutoPtr<TPacket>;

        // Input stream over a packet's body (the bytes after the checksum).
        // It reads from the packet's own buffer, so the packet must outlive the stream.
        class TPacketInput: public TMemoryInput {
        public:
            TPacketInput(const TPacket& p)
                : TMemoryInput(p.Body().data(), p.Body().size())
            {
            }
        };

        // Output stream that appends to a packet's Data buffer.
        // On construction it advances the buffer by sizeof(TCheckSum), reserving the
        // header bytes that TPacket::Sign() fills in later.
        class TPacketOutput: public TBufferOutput {
        public:
            TPacketOutput(TPacket& p)
                : TBufferOutput(p.Data)
            {
                p.Data.Proceed(sizeof(TCheckSum));
            }
        };

        // Writes t to out: first the type tag returned by t.Type(), then t's own fields.
        template <class T>
        static inline void Serialize(TPacketOutput* out, const T& t) {
            const auto tag = t.Type();

            Save(out, tag);
            t.Serialize(out);
        }

        // Serializes t into packet p through a TPacketOutput, which reserves the checksum bytes.
        template <class T>
        static inline void Serialize(TPacket& p, const T& t) {
            TPacketOutput sink(p);

            NUdp::Serialize(&sink, t);
        }

        namespace NPrivate {
            template <class T>
            static inline void Deserialize(TPacketInput* in, T& t) {
                char type;
                Load(in, type);

                if (type != t.Type()) {
                    ythrow yexception() << "unsupported packet";
                }

                t.Deserialize(in);
            }

            template <class T>
            static inline void Deserialize(const TPacket& p, T& t) {
                TPacketInput in(p);

                Deserialize(&in, t);
            }
        }

        // Payload of a PT_REQUEST packet: request id, target service name and request body.
        struct TRequestPacket {
            // Parses a request out of a received packet; throws if the packet is not a valid request.
            TRequestPacket(const TPacket& p) {
                NPrivate::Deserialize(p, *this);
            }

            // Creates an outgoing request with a freshly generated Guid.
            TRequestPacket(const TString& srv, const TString& data)
                : Guid(GenerateGuid())
                , Service(srv)
                , Data(data)
            {
            }

            char Type() const noexcept {
                return static_cast<char>(PT_REQUEST);
            }

            // Field order on the wire: Guid, Service, Data.
            void Serialize(TPacketOutput* out) const {
                Save(out, Guid);
                Save(out, Service);
                Save(out, Data);
            }

            // Reads the fields in the same order Serialize() writes them.
            void Deserialize(TPacketInput* in) {
                Load(in, Guid);
                Load(in, Service);
                Load(in, Data);
            }

            TString Guid;
            TString Service;
            TString Data;
        };

        // Payload of a PT_RESPONSE packet: the Guid of the request being answered plus the
        // response body, held in a TStore.
        template <class TStore>
        struct TResponsePacket {
            // Outgoing response. The content of data is swapped into the packet, so data
            // holds the packet's previous (default-constructed) content afterwards.
            TResponsePacket(const TString& guid, TStore& data)
                : Guid(guid)
            {
                Data.swap(data);
            }

            // Parses a response out of a received packet; throws if the packet is not a valid response.
            TResponsePacket(const TPacket& p) {
                NPrivate::Deserialize(p, *this);
            }

            char Type() const noexcept {
                return static_cast<char>(PT_RESPONSE);
            }

            // Field order on the wire: Guid, Data.
            void Serialize(TPacketOutput* out) const {
                Save(out, Guid);
                Save(out, Data);
            }

            void Deserialize(TPacketInput* in) {
                Load(in, Guid);
                Load(in, Data);
            }

            TString Guid;
            TStore Data;
        };

        struct TStopPacket {
            inline char Type() const noexcept {
                return static_cast<char>(PT_STOP);
            }

            inline void Serialize(TPacketOutput* out) const {
                Save(out, TString("stop packet"));
            }
        };

        // Thrown by CreateSocket() when bind() fails; a distinct type so that callers
        // can catch bind failures separately from other system errors.
        struct TBindError: public TSystemError {};

        // A bound socket together with the address family it was created for.
        struct TSocketDescr {
            TSocketHolder S;
            int Family;

            // Takes the descriptor over from s (s is released).
            TSocketDescr(TSocketHolder& s, int family)
                : S(s.Release())
                , Family(family)
            {
            }
        };

        // Owning pointer to a socket descriptor, and a list of them.
        using TSocketRef = TAutoPtr<TSocketDescr>;
        using TSockets = TVector<TSocketRef>;

        // Creates a UDP socket of addr's address family, binds it to addr and hands it to the
        // caller through s. s is touched only after every step has succeeded.
        //
        // Throws TSystemError if the socket cannot be created and TBindError if bind() fails.
        static inline void CreateSocket(TSocketHolder& s, const IRemoteAddr& addr) {
            TSocketHolder res(socket(addr.Addr()->sa_family, SOCK_DGRAM, IPPROTO_UDP));

            // A failed socket() yields an invalid descriptor, which is not zero.
            // NOTE(review): `!res` appears to go through TSocketHolder's SOCKET conversion and so
            // is true only for descriptor 0 - confirm against util/network/socket.h. Closed()
            // asks the holder directly, so a creation failure is reported here instead of
            // surfacing later as a TBindError.
            if (res.Closed()) {
                ythrow TSystemError() << "can not create socket";
            }

            FixIPv6ListenSocket(res);

            if (bind(res, addr.Addr(), addr.Len()) != 0) {
                ythrow TBindError() << "can not bind " << PrintHostAndPort(addr);
            }

            res.Swap(s);
        }

        // Appends to s one bound UDP socket for every address that TNetworkAddress(port) yields.
        // Exceptions from CreateSocket() propagate; sockets already appended stay in s.
        static inline void CreateSockets(TSockets& s, ui16 port) {
            TNetworkAddress local(port);
            TNetworkAddress::TIterator cur = local.Begin();

            while (cur != local.End()) {
                TSocketHolder bound;

                CreateSocket(bound, TAddrInfo(&*cur));

                s.push_back(new TSocketDescr(bound, cur->ai_family));
                ++cur;
            }
        }

        // Picks a random port in [5000, 5999] and tries to bind sockets on it, repeating with a
        // new random port for as long as binding fails with TBindError. On success the created
        // sockets replace the content of s. Any other exception propagates.
        static inline void CreateSocketsOnRandomPort(TSockets& s) {
            for (;;) {
                try {
                    // Collected separately so that a failed attempt leaves s untouched.
                    TSockets attempt;
                    const ui16 port = 5000 + (RandomNumber<ui16>() % 1000);

                    CreateSockets(attempt, port);
                    attempt.swap(s);

                    return;
                } catch (const TBindError&) {
                    // port is taken - try another one
                }
            }
        }

        // Coarse time value used for request ageing.
        using TTimeStamp = ui64;

        // Current cycle counter divided by 2^31, i.e. one unit is 2^31 CPU cycles.
        static inline TTimeStamp TimeStamp() noexcept {
            const auto cycles = GetCycleCount();

            return cycles >> 31;
        }

        // Bookkeeping record for a request that has been sent and awaits a response.
        // It is an intrusive list item so that records can also be kept in creation order.
        struct TRequestDescr: public TIntrusiveListItem<TRequestDescr> {
            TString Guid;          // request id, used to match the response
            TNotifyHandleRef Hndl; // handle to notify about the response or an error
            TMessage Msg;          // copy of the original message
            TTimeStamp TS;         // TimeStamp() at construction

            TRequestDescr(const TString& guid, const TNotifyHandleRef& hndl, const TMessage& msg)
                : Guid(guid)
                , Hndl(hndl)
                , Msg(msg)
                , TS(TimeStamp())
            {
            }
        };

        // Owning pointer to a request record.
        using TRequestDescrRef = TAutoPtr<TRequestDescr>;

        class TProto {
            class TRequest: public IRequest, public TRequestPacket {
            public:
                inline TRequest(TPacket& p, TProto* parent)
                    : TRequestPacket(p)
                    , Addr_(std::move(p.Addr))
                    , H_(PrintHostByRfc(*Addr_))
                    , P_(parent)
                {
                }

                TStringBuf Scheme() const override {
                    return TStringBuf("udp");
                }

                TString RemoteHost() const override {
                    return H_;
                }

                TStringBuf Service() const override {
                    return ((TRequestPacket&)(*this)).Service;
                }

                TStringBuf Data() const override {
                    return ((TRequestPacket&)(*this)).Data;
                }

                TStringBuf RequestId() const override {
                    return ((TRequestPacket&)(*this)).Guid;
                }

                bool Canceled() const override {
                    //TODO ?
                    return false;
                }

                void SendReply(TData& data) override {
                    P_->Schedule(new TPacket(TResponsePacket<TData>(Guid, data), std::move(Addr_)));
                }

                void SendError(TResponseError, const TString&) override {
                    // TODO
                }

            private:
                IRemoteAddrPtr Addr_;
                TString H_;
                TProto* P_;
            };

        public:
            inline TProto(IOnRequest* cb, TSocketHolder& s)
                : CB_(cb)
                , ToSendEv_(TSystemEvent::rAuto)
                , S_(s.Release())
            {
                SetSocketTimeout(S_, 10);
                Thrs_.push_back(Spawn<TProto, &TProto::ExecuteRecv>(this));
                Thrs_.push_back(Spawn<TProto, &TProto::ExecuteSend>(this));
            }

            inline ~TProto() {
                Schedule(new TPacket(TStopPacket(), GetSendAddr(S_)));

                for (size_t i = 0; i < Thrs_.size(); ++i) {
                    Thrs_[i]->Join();
                }
            }

            inline TPacketRef Recv() {
                TBuffer tmp;

                tmp.Resize(128 * 1024);

                while (true) {
                    try {
                        return new TPacket(S_, tmp);
                    } catch (...) {
                        Cdbg << CurrentExceptionMessage() << Endl;

                        continue;
                    }
                }
            }

            typedef THashMap<TString, TRequestDescrRef> TInFlyBase;

            struct TInFly: public TInFlyBase, public TIntrusiveList<TRequestDescr> {
                typedef TInFlyBase::iterator TIter;
                typedef TInFlyBase::const_iterator TContsIter;

                inline void Insert(TRequestDescrRef& d) {
                    PushBack(d.Get());
                    (*this)[d->Guid] = d;
                }

                inline void EraseStale() noexcept {
                    const TTimeStamp now = TimeStamp();

                    for (TIterator it = Begin(); (it != End()) && (it->TS < now) && ((now - it->TS) > 120);) {
                        it->Hndl->NotifyError("request timeout");
                        TString safe_key = (it++)->Guid;
                        erase(safe_key);
                    }
                }
            };

            inline void ExecuteRecv() {
                SetHighestThreadPriority();

                TInFly infly;

                while (true) {
                    TPacketRef p = Recv();

                    switch (static_cast<EPacketType>(p->Type())) {
                        case PT_REQUEST:
                            if (CB_) {
                                CB_->OnRequest(new TRequest(*p, this));
                            } else {
                                //skip request in case of client
                            }

                            break;

                        case PT_RESPONSE: {
                            CancelStaleRequests(infly);

                            TResponsePacket<TString> rp(*p);

                            TInFly::TIter it = static_cast<TInFlyBase&>(infly).find(rp.Guid);

                            if (it == static_cast<TInFlyBase&>(infly).end()) {
                                break;
                            }

                            const TRequestDescrRef& d = it->second;
                            d->Hndl->NotifyResponse(rp.Data);

                            infly.erase(it);

                            break;
                        }

                        case PT_STOP:
                            Schedule(nullptr);

                            return;

                        case PT_TIMEOUT:
                            CancelStaleRequests(infly);

                            break;
                    }
                }
            }

            inline void ExecuteSend() {
                SetHighestThreadPriority();

                while (true) {
                    TPacketRef p;

                    while (!ToSend_.Dequeue(&p)) {
                        ToSendEv_.Wait();
                    }

                    //shutdown
                    if (!p) {
                        return;
                    }

                    p->SendTo(S_);
                }
            }

            inline void Schedule(TPacketRef p) {
                ToSend_.Enqueue(p);
                ToSendEv_.Signal();
            }

            inline void Schedule(TRequestDescrRef dsc, TPacketRef p) {
                ScheduledReqs_.Enqueue(dsc);
                Schedule(p);
            }

        protected:
            void CancelStaleRequests(TInFly& infly) {
                TRequestDescrRef d;

                while (ScheduledReqs_.Dequeue(&d)) {
                    infly.Insert(d);
                }

                infly.EraseStale();
            }

            IOnRequest* CB_;
            NNeh::TAutoLockFreeQueue<TPacket> ToSend_;
            NNeh::TAutoLockFreeQueue<TRequestDescr> ScheduledReqs_;
            TSystemEvent ToSendEv_;
            TSocketHolder S_;
            TVector<TThreadRef> Thrs_;
        };

        class TProtos {
        public:
            inline TProtos() {
                TSockets s;

                CreateSocketsOnRandomPort(s);
                Init(nullptr, s);
            }

            inline TProtos(IOnRequest* cb, ui16 port) {
                TSockets s;

                CreateSockets(s, port);
                Init(cb, s);
            }

            static inline TProtos* Instance() {
                return Singleton<TProtos>();
            }

            inline void Schedule(const TMessage& msg, const TNotifyHandleRef& hndl) {
                TParsedLocation loc(msg.Addr);
                const TNetworkAddress* addr = &CachedThrResolve(TResolveInfo(loc.Host, loc.GetPort()))->Addr;

                for (TNetworkAddress::TIterator ai = addr->Begin(); ai != addr->End(); ++ai) {
                    TProto* proto = Find(ai->ai_family);

                    if (proto) {
                        TRequestPacket rp(ToString(loc.Service), msg.Data);
                        TRequestDescrRef rd(new TRequestDescr(rp.Guid, hndl, msg));
                        IRemoteAddrPtr raddr(new TAddrInfo(&*ai));
                        TPacketRef p(new TPacket(rp, std::move(raddr)));

                        proto->Schedule(rd, p);

                        return;
                    }
                }

                ythrow yexception() << "unsupported protocol family";
            }

        private:
            inline void Init(IOnRequest* cb, TSockets& s) {
                for (auto& it : s) {
                    P_[it->Family] = new TProto(cb, it->S);
                }
            }

            inline TProto* Find(int family) const {
                TProtoStorage::const_iterator it = P_.find(family);

                if (it == P_.end()) {
                    return nullptr;
                }

                return it->second.Get();
            }

        private:
            typedef TAutoPtr<TProto> TProtoRef;
            typedef THashMap<int, TProtoRef> TProtoStorage;
            TProtoStorage P_;
        };

        // Server-side requester: a TProtos bound to a fixed port that feeds requests to cb.
        // Destroying it destroys the endpoints it owns.
        class TRequester: public IRequester, public TProtos {
        public:
            TRequester(IOnRequest* cb, ui16 port)
                : TProtos(cb, port)
            {
            }
        };

        // IProtocol implementation for the "udp" scheme.
        class TProtocol: public IProtocol {
        public:
            // Name under which this protocol is known.
            TStringBuf Scheme() const noexcept override {
                return TStringBuf("udp");
            }

            // Starts listening on the port taken from loc and delivers requests to cb.
            IRequesterRef CreateRequester(IOnRequest* cb, const TParsedLocation& loc) override {
                return new TRequester(cb, loc.GetPort());
            }

            // Sends msg through the shared client TProtos instance and returns the handle that
            // will be notified about the outcome. A stat collector is attached only when ss is set.
            THandleRef ScheduleRequest(const TMessage& msg, IOnRecv* fallback, TServiceStatRef& ss) override {
                TNotifyHandleRef handle(new TUdpHandle(fallback, msg, !ss ? nullptr : new TStatCollector(ss)));

                TProtos::Instance()->Schedule(msg, handle);

                return handle.Get();
            }
        };
    }
}

// Returns the process-wide instance of the UDP protocol implementation.
IProtocol* NNeh::UdpProtocol() {
    NUdp::TProtocol* const instance = Singleton<NUdp::TProtocol>();

    return instance;
}