summaryrefslogtreecommitdiffstats
path: root/library/cpp/actors/dnsresolver/dnsresolver.cpp
diff options
context:
space:
mode:
authorDevtools Arcadia <[email protected]>2022-02-07 18:08:42 +0300
committerDevtools Arcadia <[email protected]>2022-02-07 18:08:42 +0300
commit1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch)
treee26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/actors/dnsresolver/dnsresolver.cpp
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/actors/dnsresolver/dnsresolver.cpp')
-rw-r--r--library/cpp/actors/dnsresolver/dnsresolver.cpp475
1 files changed, 475 insertions, 0 deletions
diff --git a/library/cpp/actors/dnsresolver/dnsresolver.cpp b/library/cpp/actors/dnsresolver/dnsresolver.cpp
new file mode 100644
index 00000000000..6329bb00833
--- /dev/null
+++ b/library/cpp/actors/dnsresolver/dnsresolver.cpp
@@ -0,0 +1,475 @@
+#include "dnsresolver.h"
+
+#include <library/cpp/actors/core/hfunc.h>
+#include <library/cpp/threading/queue/mpsc_htswap.h>
+#include <util/network/pair.h>
+#include <util/network/socket.h>
+#include <util/string/builder.h>
+#include <util/system/thread.h>
+
+#include <ares.h>
+
+#include <atomic>
+
+namespace NActors {
+namespace NDnsResolver {
+
+ class TAresLibraryInitBase {
+ protected:
+ TAresLibraryInitBase() noexcept {
+ int status = ares_library_init(ARES_LIB_INIT_ALL);
+ Y_VERIFY(status == ARES_SUCCESS, "Unexpected failure to initialize c-ares library");
+ }
+
+ ~TAresLibraryInitBase() noexcept {
+ ares_library_cleanup();
+ }
+ };
+
+ class TCallbackQueueBase {
+ protected:
+ TCallbackQueueBase() noexcept {
+ int err = SocketPair(Sockets, false, true);
+ Y_VERIFY(err == 0, "Unexpected failure to create a socket pair");
+ SetNonBlock(Sockets[0]);
+ SetNonBlock(Sockets[1]);
+ }
+
+ ~TCallbackQueueBase() noexcept {
+ closesocket(Sockets[0]);
+ closesocket(Sockets[1]);
+ }
+
+ protected:
+ using TCallback = std::function<void()>;
+ using TCallbackQueue = NThreading::THTSwapQueue<TCallback>;
+
+ void PushCallback(TCallback callback) {
+ Y_VERIFY(callback, "Cannot push an empty callback");
+ CallbackQueue.Push(std::move(callback)); // this is a lockfree queue
+
+ // Wake up worker thread on the first activation
+ if (Activations.fetch_add(1, std::memory_order_acq_rel) == 0) {
+ char ch = 'x';
+ ssize_t ret;
+#ifdef _win_
+ ret = send(SignalSock(), &ch, 1, 0);
+ if (ret == -1) {
+ Y_VERIFY(WSAGetLastError() == WSAEWOULDBLOCK, "Unexpected send error");
+ return;
+ }
+#else
+ do {
+ ret = send(SignalSock(), &ch, 1, 0);
+ } while (ret == -1 && errno == EINTR);
+ if (ret == -1) {
+ Y_VERIFY(errno == EAGAIN || errno == EWOULDBLOCK, "Unexpected send error");
+ return;
+ }
+#endif
+ Y_VERIFY(ret == 1, "Unexpected send result");
+ }
+ }
+
+ void RunCallbacks() noexcept {
+ char ch[32];
+ ssize_t ret;
+ bool signalled = false;
+ for (;;) {
+ ret = recv(WaitSock(), ch, sizeof(ch), 0);
+ if (ret > 0) {
+ signalled = true;
+ }
+ if (ret == sizeof(ch)) {
+ continue;
+ }
+ if (ret != -1) {
+ break;
+ }
+#ifdef _win_
+ if (WSAGetLastError() == WSAEWOULDBLOCK) {
+ break;
+ }
+ Y_FAIL("Unexpected recv error");
+#else
+ if (errno == EAGAIN || errno == EWOULDBLOCK) {
+ break;
+ }
+ Y_VERIFY(errno == EINTR, "Unexpected recv error");
+#endif
+ }
+
+ if (signalled) {
+ // There's exactly one write to SignalSock while Activations != 0
+ // It's impossible to get signalled while Activations == 0
+ // We must set Activations = 0 to receive new signals
+ size_t count = Activations.exchange(0, std::memory_order_acq_rel);
+ Y_VERIFY(count != 0);
+
+ // N.B. due to the way HTSwap works we may not be able to pop
+ // all callbacks on this activation, however we expect a new
+ // delayed activation to happen at a later time.
+ while (auto callback = CallbackQueue.Pop()) {
+ callback();
+ }
+ }
+ }
+
+ SOCKET SignalSock() {
+ return Sockets[0];
+ }
+
+ SOCKET WaitSock() {
+ return Sockets[1];
+ }
+
+ private:
+ SOCKET Sockets[2];
+ TCallbackQueue CallbackQueue;
+ std::atomic<size_t> Activations{ 0 };
+ };
+
+ class TSimpleDnsResolver
+ : public TActor<TSimpleDnsResolver>
+ , private TAresLibraryInitBase
+ , private TCallbackQueueBase
+ {
+ public:
+ TSimpleDnsResolver(TSimpleDnsResolverOptions options) noexcept
+ : TActor(&TThis::StateWork)
+ , Options(std::move(options))
+ , WorkerThread(&TThis::WorkerThreadStart, this)
+ {
+ InitAres();
+
+ WorkerThread.Start();
+ }
+
+ ~TSimpleDnsResolver() noexcept override {
+ if (!Stopped) {
+ PushCallback([this] {
+ // Mark as stopped first
+ Stopped = true;
+
+ // Cancel all current ares requests (will not send replies)
+ ares_cancel(AresChannel);
+ });
+
+ WorkerThread.Join();
+ }
+
+ StopAres();
+ }
+
+ static constexpr EActivityType ActorActivityType() {
+ return DNS_RESOLVER;
+ }
+
+ private:
+ void InitAres() noexcept {
+ struct ares_options options;
+ memset(&options, 0, sizeof(options));
+ int optmask = 0;
+
+ options.flags = ARES_FLAG_STAYOPEN;
+ optmask |= ARES_OPT_FLAGS;
+
+ options.sock_state_cb = &TThis::SockStateCallback;
+ options.sock_state_cb_data = this;
+ optmask |= ARES_OPT_SOCK_STATE_CB;
+
+ options.timeout = Options.Timeout.MilliSeconds();
+ if (options.timeout > 0) {
+ optmask |= ARES_OPT_TIMEOUTMS;
+ }
+
+ options.tries = Options.Attempts;
+ if (options.tries > 0) {
+ optmask |= ARES_OPT_TRIES;
+ }
+
+ int err = ares_init_options(&AresChannel, &options, optmask);
+ Y_VERIFY(err == 0, "Unexpected failure to initialize c-ares channel");
+
+ if (Options.Servers) {
+ TStringBuilder csv;
+ for (const TString& server : Options.Servers) {
+ if (csv) {
+ csv << ',';
+ }
+ csv << server;
+ }
+ err = ares_set_servers_ports_csv(AresChannel, csv.c_str());
+ Y_VERIFY(err == 0, "Unexpected failure to set a list of dns servers: %s", ares_strerror(err));
+ }
+ }
+
+ void StopAres() noexcept {
+ // Destroy the ares channel
+ ares_destroy(AresChannel);
+ AresChannel = nullptr;
+ }
+
+ private:
+ STRICT_STFUNC(StateWork, {
+ hFunc(TEvents::TEvPoison, Handle);
+ hFunc(TEvDns::TEvGetHostByName, Handle);
+ hFunc(TEvDns::TEvGetAddr, Handle);
+ })
+
+ void Handle(TEvents::TEvPoison::TPtr&) {
+ Y_VERIFY(!Stopped);
+
+ PushCallback([this] {
+ // Cancel all current ares requests (will send notifications)
+ ares_cancel(AresChannel);
+
+ // Mark as stopped last
+ Stopped = true;
+ });
+
+ WorkerThread.Join();
+ PassAway();
+ }
+
+ private:
+ enum class ERequestType {
+ GetHostByName,
+ GetAddr,
+ };
+
+ struct TRequestContext : public TThrRefBase {
+ using TPtr = TIntrusivePtr<TRequestContext>;
+
+ TThis* Self;
+ TActorSystem* ActorSystem;
+ TActorId SelfId;
+ TActorId Sender;
+ ui64 Cookie;
+ ERequestType Type;
+
+ TRequestContext(TThis* self, TActorSystem* as, TActorId selfId, TActorId sender, ui64 cookie, ERequestType type)
+ : Self(self)
+ , ActorSystem(as)
+ , SelfId(selfId)
+ , Sender(sender)
+ , Cookie(cookie)
+ , Type(type)
+ { }
+ };
+
+ private:
+ void Handle(TEvDns::TEvGetHostByName::TPtr& ev) {
+ auto* msg = ev->Get();
+ auto reqCtx = MakeIntrusive<TRequestContext>(
+ this, TActivationContext::ActorSystem(), SelfId(), ev->Sender, ev->Cookie, ERequestType::GetHostByName);
+ PushCallback([this, reqCtx = std::move(reqCtx), name = std::move(msg->Name), family = msg->Family] () mutable {
+ StartGetHostByName(std::move(reqCtx), std::move(name), family);
+ });
+ }
+
+ void Handle(TEvDns::TEvGetAddr::TPtr& ev) {
+ auto* msg = ev->Get();
+ auto reqCtx = MakeIntrusive<TRequestContext>(
+ this, TActivationContext::ActorSystem(), SelfId(), ev->Sender, ev->Cookie, ERequestType::GetAddr);
+ PushCallback([this, reqCtx = std::move(reqCtx), name = std::move(msg->Name), family = msg->Family] () mutable {
+ StartGetHostByName(std::move(reqCtx), std::move(name), family);
+ });
+ }
+
+ void StartGetHostByName(TRequestContext::TPtr reqCtx, TString name, int family) noexcept {
+ reqCtx->Ref();
+ ares_gethostbyname(AresChannel, name.c_str(), family,
+ &TThis::GetHostByNameAresCallback, reqCtx.Get());
+ }
+
+ private:
+ static void GetHostByNameAresCallback(void* arg, int status, int timeouts, struct hostent* info) {
+ Y_UNUSED(timeouts);
+ TRequestContext::TPtr reqCtx(static_cast<TRequestContext*>(arg));
+ reqCtx->UnRef();
+
+ if (reqCtx->Self->Stopped) {
+ // Don't send any replies after destruction
+ return;
+ }
+
+ switch (reqCtx->Type) {
+ case ERequestType::GetHostByName: {
+ auto result = MakeHolder<TEvDns::TEvGetHostByNameResult>();
+ if (status == 0) {
+ switch (info->h_addrtype) {
+ case AF_INET: {
+ for (int i = 0; info->h_addr_list[i] != nullptr; ++i) {
+ result->AddrsV4.emplace_back(*(struct in_addr*)(info->h_addr_list[i]));
+ }
+ break;
+ }
+ case AF_INET6: {
+ for (int i = 0; info->h_addr_list[i] != nullptr; ++i) {
+ result->AddrsV6.emplace_back(*(struct in6_addr*)(info->h_addr_list[i]));
+ }
+ break;
+ }
+ default:
+ Y_FAIL("unknown address family in ares callback");
+ }
+ } else {
+ result->ErrorText = ares_strerror(status);
+ }
+ result->Status = status;
+
+ reqCtx->ActorSystem->Send(new IEventHandle(reqCtx->Sender, reqCtx->SelfId, result.Release(), 0, reqCtx->Cookie));
+ break;
+ }
+
+ case ERequestType::GetAddr: {
+ auto result = MakeHolder<TEvDns::TEvGetAddrResult>();
+ if (status == 0 && Y_UNLIKELY(info->h_addr_list[0] == nullptr)) {
+ status = ARES_ENODATA;
+ }
+ if (status == 0) {
+ switch (info->h_addrtype) {
+ case AF_INET: {
+ result->Addr = *(struct in_addr*)(info->h_addr_list[0]);
+ break;
+ }
+ case AF_INET6: {
+ result->Addr = *(struct in6_addr*)(info->h_addr_list[0]);
+ break;
+ }
+ default:
+ Y_FAIL("unknown address family in ares callback");
+ }
+ } else {
+ result->ErrorText = ares_strerror(status);
+ }
+ result->Status = status;
+
+ reqCtx->ActorSystem->Send(new IEventHandle(reqCtx->Sender, reqCtx->SelfId, result.Release(), 0, reqCtx->Cookie));
+ break;
+ }
+ }
+ }
+
+ private:
+ static void SockStateCallback(void* data, ares_socket_t socket_fd, int readable, int writable) {
+ static_cast<TThis*>(data)->DoSockStateCallback(socket_fd, readable, writable);
+ }
+
+ void DoSockStateCallback(ares_socket_t socket_fd, int readable, int writable) noexcept {
+ int events = (readable ? (POLLRDNORM | POLLIN) : 0) | (writable ? (POLLWRNORM | POLLOUT) : 0);
+ if (events == 0) {
+ AresSockStates.erase(socket_fd);
+ } else {
+ AresSockStates[socket_fd].NeededEvents = events;
+ }
+ }
+
+ private:
+ static void* WorkerThreadStart(void* arg) noexcept {
+ static_cast<TSimpleDnsResolver*>(arg)->WorkerThreadLoop();
+ return nullptr;
+ }
+
+ void WorkerThreadLoop() noexcept {
+ TThread::SetCurrentThreadName("DnsResolver");
+
+ TVector<struct pollfd> fds;
+ while (!Stopped) {
+ fds.clear();
+ fds.reserve(1 + AresSockStates.size());
+ {
+ auto& entry = fds.emplace_back();
+ entry.fd = WaitSock();
+ entry.events = POLLRDNORM | POLLIN;
+ }
+ for (auto& kv : AresSockStates) {
+ auto& entry = fds.emplace_back();
+ entry.fd = kv.first;
+ entry.events = kv.second.NeededEvents;
+ }
+
+ int timeout = -1;
+ struct timeval tv;
+ if (ares_timeout(AresChannel, nullptr, &tv)) {
+ timeout = tv.tv_sec * 1000 + tv.tv_usec / 1000;
+ }
+
+ int ret = poll(fds.data(), fds.size(), timeout);
+ if (ret == -1) {
+ if (errno == EINTR) {
+ continue;
+ }
+ // we cannot handle failures, run callbacks and pretend everything is ok
+ RunCallbacks();
+ if (Stopped) {
+ break;
+ }
+ ret = 0;
+ }
+
+ bool ares_called = false;
+ if (ret > 0) {
+ for (size_t i = 0; i < fds.size(); ++i) {
+ auto& entry = fds[i];
+
+ // Handle WaitSock activation and run callbacks
+ if (i == 0) {
+ if (entry.revents & (POLLRDNORM | POLLIN)) {
+ RunCallbacks();
+ if (Stopped) {
+ break;
+ }
+ }
+ continue;
+ }
+
+ // All other sockets belong to ares
+ if (entry.revents == 0) {
+ continue;
+ }
+ // Previous invocation of aress_process_fd might have removed some sockets
+ if (Y_UNLIKELY(!AresSockStates.contains(entry.fd))) {
+ continue;
+ }
+ ares_process_fd(
+ AresChannel,
+ entry.revents & (POLLRDNORM | POLLIN) ? entry.fd : ARES_SOCKET_BAD,
+ entry.revents & (POLLWRNORM | POLLOUT) ? entry.fd : ARES_SOCKET_BAD);
+ ares_called = true;
+ }
+
+ if (Stopped) {
+ break;
+ }
+ }
+
+ if (!ares_called) {
+ // Let ares handle timeouts
+ ares_process_fd(AresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
+ }
+ }
+ }
+
+ private:
+ struct TSockState {
+ short NeededEvents = 0; // poll events
+ };
+
+ private:
+ TSimpleDnsResolverOptions Options;
+ TThread WorkerThread;
+
+ ares_channel AresChannel;
+ THashMap<SOCKET, TSockState> AresSockStates;
+
+ bool Stopped = false;
+ };
+
+ IActor* CreateSimpleDnsResolver(TSimpleDnsResolverOptions options) {
+ return new TSimpleDnsResolver(std::move(options));
+ }
+
+} // namespace NDnsResolver
+} // namespace NActors