#include "cache.h"

#include "thread.h"

#include <util/system/tls.h>
#include <util/system/info.h>
#include <util/system/rwlock.h>
#include <util/thread/singleton.h>
#include <util/generic/singleton.h>
#include <util/generic/hash.h>

using namespace NDns;

namespace {
    // A single resolve request: what to resolve (Info) and how the lookup is
    // performed if it misses every cache level (Method).
    struct TResolveTask {
        enum EMethod {
            Normal,   // TNetworkAddress lookup in the calling thread
            Threaded  // lookup through ThreadedResolve()
        };

        TResolveTask(const TResolveInfo& info, EMethod method)
            : Info(info)
            , Method(method)
        {
        }

        // Borrowed reference: the referenced TResolveInfo must outlive the task.
        const TResolveInfo& Info;
        const EMethod Method;
    };

    // One layer of the DNS resolver chain. Resolve() returns a non-owning
    // pointer; in this file the pointee is owned by the global cache, never by
    // the caller.
    struct IDns {
        virtual ~IDns() = default;

        virtual const TResolvedHost* Resolve(const TResolveTask& task) = 0;
    };

    // Shared-ownership handle to a resolved host; the global cache stores these.
    using TResolvedHostPtr = TAtomicSharedPtr<TResolvedHost>;

    struct THashResolveInfo {
        inline size_t operator()(const TResolveInfo& ri) const {
            return ComputeHash(ri.Host) ^ ri.Port;
        }
    };

    struct TCompareResolveInfo {
        inline bool operator()(const NDns::TResolveInfo& x, const NDns::TResolveInfo& y) const {
            return x.Host == y.Host && x.Port == y.Port;
        }
    };

    // Process-wide DNS cache shared by all threads (see Instance()).
    //
    // Resolved hosts are kept as shared pointers keyed by (host, port). Nothing
    // in this class ever erases from C_, so a TResolvedHost* returned by
    // Resolve() stays valid for the lifetime of this object.
    //
    // Locking: L_ guards the cache C_; LA_ guards the alias table A_.
    class TGlobalCachedDns: public IDns, public TNonCopyable {
    public:
        // Returns the cached entry for rt.Info, resolving and caching it on a miss.
        //
        // The slow resolution runs without holding L_, so two threads may resolve
        // the same host concurrently. The first insert wins, and the other
        // thread's result is dropped in favour of the stored entry. If ResolveA()
        // throws, nothing is cached and the exception propagates.
        const TResolvedHost* Resolve(const TResolveTask& rt) override {
            //2. search host in cache
            {
                TReadGuard guard(L_);

                TCache::const_iterator it = C_.find(rt.Info);

                if (it != C_.end()) {
                    return it->second.Get();
                }
            }

            TResolvedHostPtr res = ResolveA(rt);

            //update cache
            {
                TWriteGuard guard(L_);

                // The key is built from res->Host, which is owned by the cached
                // value itself, rather than from rt.Info.
                // NOTE(review): presumably TResolveInfo only references its host
                // string (cache.h not shown) — confirm before changing the key source.
                std::pair<TCache::iterator, bool> updateResult = C_.insert(std::make_pair(TResolveInfo(res->Host, rt.Info.Port), res));
                TResolvedHost* rh = updateResult.first->second.Get();

                if (updateResult.second) {
                    //fresh resolved host, set cache record id for it
                    // (entries are never erased, so this is the insertion index)
                    rh->Id = C_.size() - 1;
                }

                return rh;
            }
        }

        // Registers `alias` as the name to resolve instead of `host`.
        // host == "*" acts as a wildcard applied to every host, and an exact-host
        // alias overrides it. Only resolutions that miss the cache after this
        // call see the alias; already-cached entries are unaffected.
        void AddAlias(const TString& host, const TString& alias) noexcept {
            TWriteGuard guard(LA_);

            A_[host] = alias;
        }

        // Returns the shared instance (singleton with destruction priority 65530).
        static inline TGlobalCachedDns* Instance() {
            return SingletonWithPriority<TGlobalCachedDns, 65530>();
        }

    private:
        // Resolves rt.Info from scratch, in three steps:
        //   - applies aliases;
        //   - unwraps a bracketed IPv6 literal;
        //   - runs the lookup selected by rt.Method.
        // The returned TResolvedHost carries the original (pre-alias) host name.
        // Throws yexception on an unknown resolve method.
        inline TResolvedHostPtr ResolveA(const TResolveTask& rt) {
            TString originalHost(rt.Info.Host);
            TString host(originalHost);

            //3. replace host to alias, if exist
            {
                // A_ must only be inspected under LA_ (including its emptiness):
                // AddAlias() may be mutating it concurrently from another thread.
                TReadGuard guard(LA_);

                if (!A_.empty()) {
                    // The wildcard is applied first; an exact-host alias overrides it.
                    // Names are viewed from originalHost, not host: host is reassigned
                    // inside the loop, and a view into it could be left dangling.
                    const TStringBuf names[] = {"*", originalHost};

                    for (const auto& name : names) {
                        TAliases::const_iterator it = A_.find(name);

                        if (it != A_.end()) {
                            host = it->second;
                        }
                    }
                }
            }

            // Unwrap bracketed IPv6 literals ("[::1]" -> "::1"). Strip only when
            // both brackets are present, so a malformed name is not silently
            // truncated into a different host name.
            if (host.length() > 2 && host[0] == '[' && host[host.length() - 1] == ']') {
                TString unbracedIpV6(host.data() + 1, host.size() - 2);
                host.swap(unbracedIpV6);
            }

            TAutoPtr<TNetworkAddress> na;

            //4. getaddrinfo (direct or in separate thread)
            if (rt.Method == TResolveTask::Normal) {
                na.Reset(new TNetworkAddress(host, rt.Info.Port));
            } else if (rt.Method == TResolveTask::Threaded) {
                na = ThreadedResolve(host, rt.Info.Port);
            } else {
                Y_ASSERT(0);
                throw yexception() << TStringBuf("invalid resolve method");
            }

            return new TResolvedHost(originalHost, *na);
        }

        // (host, port) -> resolved host; guarded by L_.
        typedef THashMap<TResolveInfo, TResolvedHostPtr, THashResolveInfo, TCompareResolveInfo> TCache;
        TCache C_;
        TRWMutex L_;
        // host -> replacement name ("*" = wildcard); guarded by LA_.
        typedef THashMap<TString, TString> TAliases;
        TAliases A_;
        TRWMutex LA_;
    };

    // Unsynchronized front cache that forwards misses to another IDns layer.
    // It has no locking of its own. In this file each thread gets a private
    // instance through TThreadedDns / ThrDns().
    class TCachedDns: public IDns {
    public:
        // `slave` answers cache misses. It is borrowed (never deleted here) and
        // must outlive this object.
        inline TCachedDns(IDns* slave)
            : S_(slave)
        {
        }

        const TResolvedHost* Resolve(const TResolveTask& rt) override {
            //1. search in local thread cache
            const auto cached = C_.find(rt.Info);
            if (cached != C_.end()) {
                return cached->second;
            }

            const TResolvedHost* const resolved = S_->Resolve(rt);

            // Key the entry by the resolved host's own name storage rather than
            // by rt.Info, which belongs to the caller.
            // NOTE(review): presumably TResolveInfo only references its host string — confirm in cache.h.
            const TResolveInfo key(resolved->Host, rt.Info.Port);
            C_[key] = resolved;

            return resolved;
        }

    private:
        using TCache = THashMap<TResolveInfo, const TResolvedHost*, THashResolveInfo, TCompareResolveInfo>;

        TCache C_;  // (host, port) -> non-owning pointer into the slave's storage
        IDns* S_;   // next resolver layer, consulted on a miss
    };

    // Front cache whose misses fall through to the shared TGlobalCachedDns instance.
    struct TThreadedDns: TCachedDns {
        TThreadedDns()
            : TCachedDns(TGlobalCachedDns::Instance())
        {
        }
    };

    // Returns the calling thread's TThreadedDns (obtained via FastTlsSingleton).
    inline IDns* ThrDns() {
        IDns* const perThread = FastTlsSingleton<TThreadedDns>();
        return perThread;
    }
}

namespace NDns {
    const TResolvedHost* CachedResolve(const TResolveInfo& ri) {
        TResolveTask rt(ri, TResolveTask::Normal);

        return ThrDns()->Resolve(rt);
    }

    const TResolvedHost* CachedThrResolve(const TResolveInfo& ri) {
        TResolveTask rt(ri, TResolveTask::Threaded);

        return ThrDns()->Resolve(rt);
    }

    void AddHostAlias(const TString& host, const TString& alias) {
        TGlobalCachedDns::Instance()->AddAlias(host, alias);
    }
}