diff options
| author | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 | 
|---|---|---|
| committer | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 | 
| commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
| tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/actors/dnsresolver/dnsresolver.cpp | |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/actors/dnsresolver/dnsresolver.cpp')
| -rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver.cpp | 475 | 
1 files changed, 475 insertions, 0 deletions
diff --git a/library/cpp/actors/dnsresolver/dnsresolver.cpp b/library/cpp/actors/dnsresolver/dnsresolver.cpp new file mode 100644 index 00000000000..6329bb00833 --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver.cpp @@ -0,0 +1,475 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/core/hfunc.h> +#include <library/cpp/threading/queue/mpsc_htswap.h> +#include <util/network/pair.h> +#include <util/network/socket.h> +#include <util/string/builder.h> +#include <util/system/thread.h> + +#include <ares.h> + +#include <atomic> + +namespace NActors { +namespace NDnsResolver { + +    class TAresLibraryInitBase { +    protected: +        TAresLibraryInitBase() noexcept { +            int status = ares_library_init(ARES_LIB_INIT_ALL); +            Y_VERIFY(status == ARES_SUCCESS, "Unexpected failure to initialize c-ares library"); +        } + +        ~TAresLibraryInitBase() noexcept { +            ares_library_cleanup(); +        } +    }; + +    class TCallbackQueueBase { +    protected: +        TCallbackQueueBase() noexcept { +            int err = SocketPair(Sockets, false, true); +            Y_VERIFY(err == 0, "Unexpected failure to create a socket pair"); +            SetNonBlock(Sockets[0]); +            SetNonBlock(Sockets[1]); +        } + +        ~TCallbackQueueBase() noexcept { +            closesocket(Sockets[0]); +            closesocket(Sockets[1]); +        } + +    protected: +        using TCallback = std::function<void()>; +        using TCallbackQueue = NThreading::THTSwapQueue<TCallback>; + +        void PushCallback(TCallback callback) { +            Y_VERIFY(callback, "Cannot push an empty callback"); +            CallbackQueue.Push(std::move(callback)); // this is a lockfree queue + +            // Wake up worker thread on the first activation +            if (Activations.fetch_add(1, std::memory_order_acq_rel) == 0) { +                char ch = 'x'; +                ssize_t ret; +#ifdef _win_ +                ret = send(SignalSock(), &ch, 1, 0); +                if (ret == -1) { +                    Y_VERIFY(WSAGetLastError() == WSAEWOULDBLOCK, "Unexpected send error"); +                    return; +                } +#else +                do { +                    ret = send(SignalSock(), &ch, 1, 0); +                } while (ret == -1 && errno == EINTR); +                if (ret == -1) { +                    Y_VERIFY(errno == EAGAIN || errno == EWOULDBLOCK, "Unexpected send error"); +                    return; +                } +#endif +                Y_VERIFY(ret == 1, "Unexpected send result"); +            } +        } + +        void RunCallbacks() noexcept { +            char ch[32]; +            ssize_t ret; +            bool signalled = false; +            for (;;) { +                ret = recv(WaitSock(), ch, sizeof(ch), 0); +                if (ret > 0) { +                    signalled = true; +                } +                if (ret == sizeof(ch)) { +                    continue; +                } +                if (ret != -1) { +                    break; +                } +#ifdef _win_ +                if (WSAGetLastError() == WSAEWOULDBLOCK) { +                    break; +                } +                Y_FAIL("Unexpected recv error"); +#else +                if (errno == EAGAIN || errno == EWOULDBLOCK) { +                    break; +                } +                Y_VERIFY(errno == EINTR, "Unexpected recv error"); +#endif +            } + +            if (signalled) { +                // There's exactly one write to SignalSock while Activations != 0 +                // It's impossible to get signalled while Activations == 0 +                // We must set Activations = 0 to receive new signals +                size_t count = Activations.exchange(0, std::memory_order_acq_rel); +                Y_VERIFY(count != 0); + +                // N.B. due to the way HTSwap works we may not be able to pop +                // all callbacks on this activation, however we expect a new +                // delayed activation to happen at a later time. +                while (auto callback = CallbackQueue.Pop()) { +                    callback(); +                } +            } +        } + +        SOCKET SignalSock() { +            return Sockets[0]; +        } + +        SOCKET WaitSock() { +            return Sockets[1]; +        } + +    private: +        SOCKET Sockets[2]; +        TCallbackQueue CallbackQueue; +        std::atomic<size_t> Activations{ 0 }; +    }; + +    class TSimpleDnsResolver +        : public TActor<TSimpleDnsResolver> +        , private TAresLibraryInitBase +        , private TCallbackQueueBase +    { +    public: +        TSimpleDnsResolver(TSimpleDnsResolverOptions options) noexcept +            : TActor(&TThis::StateWork) +            , Options(std::move(options)) +            , WorkerThread(&TThis::WorkerThreadStart, this) +        { +            InitAres(); + +            WorkerThread.Start(); +        } + +        ~TSimpleDnsResolver() noexcept override { +            if (!Stopped) { +                PushCallback([this] { +                    // Mark as stopped first +                    Stopped = true; + +                    // Cancel all current ares requests (will not send replies) +                    ares_cancel(AresChannel); +                }); + +                WorkerThread.Join(); +            } + +            StopAres(); +        } + +        static constexpr EActivityType ActorActivityType() { +            return DNS_RESOLVER; +        } + +    private: +        void InitAres() noexcept { +            struct ares_options options; +            memset(&options, 0, sizeof(options)); +            int optmask = 0; + +            options.flags = ARES_FLAG_STAYOPEN; +            optmask |= ARES_OPT_FLAGS; + +            options.sock_state_cb = &TThis::SockStateCallback; +            options.sock_state_cb_data = this; +            optmask |= ARES_OPT_SOCK_STATE_CB; + +            options.timeout = Options.Timeout.MilliSeconds(); +            if (options.timeout > 0) { +                optmask |= ARES_OPT_TIMEOUTMS; +            } + +            options.tries = Options.Attempts; +            if (options.tries > 0) { +                optmask |= ARES_OPT_TRIES; +            } + +            int err = ares_init_options(&AresChannel, &options, optmask); +            Y_VERIFY(err == 0, "Unexpected failure to initialize c-ares channel"); + +            if (Options.Servers) { +                TStringBuilder csv; +                for (const TString& server : Options.Servers) { +                    if (csv) { +                        csv << ','; +                    } +                    csv << server; +                } +                err = ares_set_servers_ports_csv(AresChannel, csv.c_str()); +                Y_VERIFY(err == 0, "Unexpected failure to set a list of dns servers: %s", ares_strerror(err)); +            } +        } + +        void StopAres() noexcept { +            // Destroy the ares channel +            ares_destroy(AresChannel); +            AresChannel = nullptr; +        } + +    private: +        STRICT_STFUNC(StateWork, { +            hFunc(TEvents::TEvPoison, Handle); +            hFunc(TEvDns::TEvGetHostByName, Handle); +            hFunc(TEvDns::TEvGetAddr, Handle); +        }) + +        void Handle(TEvents::TEvPoison::TPtr&) { +            Y_VERIFY(!Stopped); + +            PushCallback([this] { +                // Cancel all current ares requests (will send notifications) +                ares_cancel(AresChannel); + +                // Mark as stopped last +                Stopped = true; +            }); + +            WorkerThread.Join(); +            PassAway(); +        } + +    private: +        enum class ERequestType { +            GetHostByName, +            GetAddr, +        }; + +        struct TRequestContext : public TThrRefBase { +            using TPtr = TIntrusivePtr<TRequestContext>; + +            TThis* Self; +            TActorSystem* ActorSystem; +            TActorId SelfId; +            TActorId Sender; +            ui64 Cookie; +            ERequestType Type; + +            TRequestContext(TThis* self, TActorSystem* as, TActorId selfId, TActorId sender, ui64 cookie, ERequestType type) +                : Self(self) +                , ActorSystem(as) +                , SelfId(selfId) +                , Sender(sender) +                , Cookie(cookie) +                , Type(type) +            { } +        }; + +    private: +        void Handle(TEvDns::TEvGetHostByName::TPtr& ev) { +            auto* msg = ev->Get(); +            auto reqCtx = MakeIntrusive<TRequestContext>( +                this, TActivationContext::ActorSystem(), SelfId(), ev->Sender, ev->Cookie, ERequestType::GetHostByName); +            PushCallback([this, reqCtx = std::move(reqCtx), name = std::move(msg->Name), family = msg->Family] () mutable { +                StartGetHostByName(std::move(reqCtx), std::move(name), family); +            }); +        } + +        void Handle(TEvDns::TEvGetAddr::TPtr& ev) { +            auto* msg = ev->Get(); +            auto reqCtx = MakeIntrusive<TRequestContext>( +                this, TActivationContext::ActorSystem(), SelfId(), ev->Sender, ev->Cookie, ERequestType::GetAddr); +            PushCallback([this, reqCtx = std::move(reqCtx), name = std::move(msg->Name), family = msg->Family] () mutable { +                StartGetHostByName(std::move(reqCtx), std::move(name), family); +            }); +        } + +        void StartGetHostByName(TRequestContext::TPtr reqCtx, TString name, int family) noexcept { +            reqCtx->Ref(); +            ares_gethostbyname(AresChannel, name.c_str(), family, +                &TThis::GetHostByNameAresCallback, reqCtx.Get()); +        } + +    private: +        static void GetHostByNameAresCallback(void* arg, int status, int timeouts, struct hostent* info) { +            Y_UNUSED(timeouts); +            TRequestContext::TPtr reqCtx(static_cast<TRequestContext*>(arg)); +            reqCtx->UnRef(); + +            if (reqCtx->Self->Stopped) { +                // Don't send any replies after destruction +                return; +            } + +            switch (reqCtx->Type) { +                case ERequestType::GetHostByName: { +                    auto result = MakeHolder<TEvDns::TEvGetHostByNameResult>(); +                    if (status == 0) { +                        switch (info->h_addrtype) { +                            case AF_INET: { +                                for (int i = 0; info->h_addr_list[i] != nullptr; ++i) { +                                    result->AddrsV4.emplace_back(*(struct in_addr*)(info->h_addr_list[i])); +                                } +                                break; +                            } +                            case AF_INET6: { +                                for (int i = 0; info->h_addr_list[i] != nullptr; ++i) { +                                    result->AddrsV6.emplace_back(*(struct in6_addr*)(info->h_addr_list[i])); +                                } +                                break; +                            } +                            default: +                                Y_FAIL("unknown address family in ares callback"); +                        } +                    } else { +                        result->ErrorText = ares_strerror(status); +                    } +                    result->Status = status; + +                    reqCtx->ActorSystem->Send(new IEventHandle(reqCtx->Sender, reqCtx->SelfId, result.Release(), 0, reqCtx->Cookie)); +                    break; +                } + +                case ERequestType::GetAddr: { +                    auto result = MakeHolder<TEvDns::TEvGetAddrResult>(); +                    if (status == 0 && Y_UNLIKELY(info->h_addr_list[0] == nullptr)) { +                        status = ARES_ENODATA; +                    } +                    if (status == 0) { +                        switch (info->h_addrtype) { +                            case AF_INET: { +                                result->Addr = *(struct in_addr*)(info->h_addr_list[0]); +                                break; +                            } +                            case AF_INET6: { +                                result->Addr = *(struct in6_addr*)(info->h_addr_list[0]); +                                break; +                            } +                            default: +                                Y_FAIL("unknown address family in ares callback"); +                        } +                    } else { +                        result->ErrorText = ares_strerror(status); +                    } +                    result->Status = status; + +                    reqCtx->ActorSystem->Send(new IEventHandle(reqCtx->Sender, reqCtx->SelfId, result.Release(), 0, reqCtx->Cookie)); +                    break; +                } +            } +        } + +    private: +        static void SockStateCallback(void* data, ares_socket_t socket_fd, int readable, int writable) { +            static_cast<TThis*>(data)->DoSockStateCallback(socket_fd, readable, writable); +        } + +        void DoSockStateCallback(ares_socket_t socket_fd, int readable, int writable) noexcept { +            int events = (readable ? (POLLRDNORM | POLLIN) : 0) | (writable ? (POLLWRNORM | POLLOUT) : 0); +            if (events == 0) { +                AresSockStates.erase(socket_fd); +            } else { +                AresSockStates[socket_fd].NeededEvents = events; +            } +        } + +    private: +        static void* WorkerThreadStart(void* arg) noexcept { +            static_cast<TSimpleDnsResolver*>(arg)->WorkerThreadLoop(); +            return nullptr; +        } + +        void WorkerThreadLoop() noexcept { +            TThread::SetCurrentThreadName("DnsResolver"); + +            TVector<struct pollfd> fds; +            while (!Stopped) { +                fds.clear(); +                fds.reserve(1 + AresSockStates.size()); +                { +                    auto& entry = fds.emplace_back(); +                    entry.fd = WaitSock(); +                    entry.events = POLLRDNORM | POLLIN; +                } +                for (auto& kv : AresSockStates) { +                    auto& entry = fds.emplace_back(); +                    entry.fd = kv.first; +                    entry.events = kv.second.NeededEvents; +                } + +                int timeout = -1; +                struct timeval tv; +                if (ares_timeout(AresChannel, nullptr, &tv)) { +                    timeout = tv.tv_sec * 1000 + tv.tv_usec / 1000; +                } + +                int ret = poll(fds.data(), fds.size(), timeout); +                if (ret == -1) { +                    if (errno == EINTR) { +                        continue; +                    } +                    // we cannot handle failures, run callbacks and pretend everything is ok +                    RunCallbacks(); +                    if (Stopped) { +                        break; +                    } +                    ret = 0; +                } + +                bool ares_called = false; +                if (ret > 0) { +                    for (size_t i = 0; i < fds.size(); ++i) { +                        auto& entry = fds[i]; + +                        // Handle WaitSock activation and run callbacks +                        if (i == 0) { +                            if (entry.revents & (POLLRDNORM | POLLIN)) { +                                RunCallbacks(); +                                if (Stopped) { +                                    break; +                                } +                            } +                            continue; +                        } + +                        // All other sockets belong to ares +                        if (entry.revents == 0) { +                            continue; +                        } +                        // Previous invocation of aress_process_fd might have removed some sockets +                        if (Y_UNLIKELY(!AresSockStates.contains(entry.fd))) { +                            continue; +                        } +                        ares_process_fd( +                                AresChannel, +                                entry.revents & (POLLRDNORM | POLLIN) ? entry.fd : ARES_SOCKET_BAD, +                                entry.revents & (POLLWRNORM | POLLOUT) ? entry.fd : ARES_SOCKET_BAD); +                        ares_called = true; +                    } + +                    if (Stopped) { +                        break; +                    } +                } + +                if (!ares_called) { +                    // Let ares handle timeouts +                    ares_process_fd(AresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); +                } +            } +        } + +    private: +        struct TSockState { +            short NeededEvents = 0; // poll events +        }; + +    private: +        TSimpleDnsResolverOptions Options; +        TThread WorkerThread; + +        ares_channel AresChannel; +        THashMap<SOCKET, TSockState> AresSockStates; + +        bool Stopped = false; +    }; + +    IActor* CreateSimpleDnsResolver(TSimpleDnsResolverOptions options) { +        return new TSimpleDnsResolver(std::move(options)); +    } + +} // namespace NDnsResolver +} // namespace NActors  | 
