diff options
author | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <[email protected]> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/actors/dnsresolver |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/actors/dnsresolver')
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver.cpp | 475 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver.h | 128 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver_caching.cpp | 730 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver_caching_ut.cpp | 630 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver_ondemand.cpp | 64 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver_ondemand_ut.cpp | 24 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/dnsresolver_ut.cpp | 98 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/ut/ya.make | 20 | ||||
-rw-r--r-- | library/cpp/actors/dnsresolver/ya.make | 20 |
9 files changed, 2189 insertions, 0 deletions
diff --git a/library/cpp/actors/dnsresolver/dnsresolver.cpp b/library/cpp/actors/dnsresolver/dnsresolver.cpp new file mode 100644 index 00000000000..6329bb00833 --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver.cpp @@ -0,0 +1,475 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/core/hfunc.h> +#include <library/cpp/threading/queue/mpsc_htswap.h> +#include <util/network/pair.h> +#include <util/network/socket.h> +#include <util/string/builder.h> +#include <util/system/thread.h> + +#include <ares.h> + +#include <atomic> + +namespace NActors { +namespace NDnsResolver { + + class TAresLibraryInitBase { + protected: + TAresLibraryInitBase() noexcept { + int status = ares_library_init(ARES_LIB_INIT_ALL); + Y_VERIFY(status == ARES_SUCCESS, "Unexpected failure to initialize c-ares library"); + } + + ~TAresLibraryInitBase() noexcept { + ares_library_cleanup(); + } + }; + + class TCallbackQueueBase { + protected: + TCallbackQueueBase() noexcept { + int err = SocketPair(Sockets, false, true); + Y_VERIFY(err == 0, "Unexpected failure to create a socket pair"); + SetNonBlock(Sockets[0]); + SetNonBlock(Sockets[1]); + } + + ~TCallbackQueueBase() noexcept { + closesocket(Sockets[0]); + closesocket(Sockets[1]); + } + + protected: + using TCallback = std::function<void()>; + using TCallbackQueue = NThreading::THTSwapQueue<TCallback>; + + void PushCallback(TCallback callback) { + Y_VERIFY(callback, "Cannot push an empty callback"); + CallbackQueue.Push(std::move(callback)); // this is a lockfree queue + + // Wake up worker thread on the first activation + if (Activations.fetch_add(1, std::memory_order_acq_rel) == 0) { + char ch = 'x'; + ssize_t ret; +#ifdef _win_ + ret = send(SignalSock(), &ch, 1, 0); + if (ret == -1) { + Y_VERIFY(WSAGetLastError() == WSAEWOULDBLOCK, "Unexpected send error"); + return; + } +#else + do { + ret = send(SignalSock(), &ch, 1, 0); + } while (ret == -1 && errno == EINTR); + if (ret == -1) { + Y_VERIFY(errno == EAGAIN || errno == EWOULDBLOCK, "Unexpected send error"); + return; + } +#endif + Y_VERIFY(ret == 1, "Unexpected send result"); + } + } + + void RunCallbacks() noexcept { + char ch[32]; + ssize_t ret; + bool signalled = false; + for (;;) { + ret = recv(WaitSock(), ch, sizeof(ch), 0); + if (ret > 0) { + signalled = true; + } + if (ret == sizeof(ch)) { + continue; + } + if (ret != -1) { + break; + } +#ifdef _win_ + if (WSAGetLastError() == WSAEWOULDBLOCK) { + break; + } + Y_FAIL("Unexpected recv error"); +#else + if (errno == EAGAIN || errno == EWOULDBLOCK) { + break; + } + Y_VERIFY(errno == EINTR, "Unexpected recv error"); +#endif + } + + if (signalled) { + // There's exactly one write to SignalSock while Activations != 0 + // It's impossible to get signalled while Activations == 0 + // We must set Activations = 0 to receive new signals + size_t count = Activations.exchange(0, std::memory_order_acq_rel); + Y_VERIFY(count != 0); + + // N.B. due to the way HTSwap works we may not be able to pop + // all callbacks on this activation, however we expect a new + // delayed activation to happen at a later time. + while (auto callback = CallbackQueue.Pop()) { + callback(); + } + } + } + + SOCKET SignalSock() { + return Sockets[0]; + } + + SOCKET WaitSock() { + return Sockets[1]; + } + + private: + SOCKET Sockets[2]; + TCallbackQueue CallbackQueue; + std::atomic<size_t> Activations{ 0 }; + }; + + class TSimpleDnsResolver + : public TActor<TSimpleDnsResolver> + , private TAresLibraryInitBase + , private TCallbackQueueBase + { + public: + TSimpleDnsResolver(TSimpleDnsResolverOptions options) noexcept + : TActor(&TThis::StateWork) + , Options(std::move(options)) + , WorkerThread(&TThis::WorkerThreadStart, this) + { + InitAres(); + + WorkerThread.Start(); + } + + ~TSimpleDnsResolver() noexcept override { + if (!Stopped) { + PushCallback([this] { + // Mark as stopped first + Stopped = true; + + // Cancel all current ares requests (will not send replies) + ares_cancel(AresChannel); + }); + + WorkerThread.Join(); + } + + StopAres(); + } + + static constexpr EActivityType ActorActivityType() { + return DNS_RESOLVER; + } + + private: + void InitAres() noexcept { + struct ares_options options; + memset(&options, 0, sizeof(options)); + int optmask = 0; + + options.flags = ARES_FLAG_STAYOPEN; + optmask |= ARES_OPT_FLAGS; + + options.sock_state_cb = &TThis::SockStateCallback; + options.sock_state_cb_data = this; + optmask |= ARES_OPT_SOCK_STATE_CB; + + options.timeout = Options.Timeout.MilliSeconds(); + if (options.timeout > 0) { + optmask |= ARES_OPT_TIMEOUTMS; + } + + options.tries = Options.Attempts; + if (options.tries > 0) { + optmask |= ARES_OPT_TRIES; + } + + int err = ares_init_options(&AresChannel, &options, optmask); + Y_VERIFY(err == 0, "Unexpected failure to initialize c-ares channel"); + + if (Options.Servers) { + TStringBuilder csv; + for (const TString& server : Options.Servers) { + if (csv) { + csv << ','; + } + csv << server; + } + err = ares_set_servers_ports_csv(AresChannel, csv.c_str()); + Y_VERIFY(err == 0, "Unexpected failure to set a list of dns servers: %s", ares_strerror(err)); + } + } + + void StopAres() noexcept { + // Destroy the ares channel + ares_destroy(AresChannel); + AresChannel = nullptr; + } + + private: + STRICT_STFUNC(StateWork, { + hFunc(TEvents::TEvPoison, Handle); + hFunc(TEvDns::TEvGetHostByName, Handle); + hFunc(TEvDns::TEvGetAddr, Handle); + }) + + void Handle(TEvents::TEvPoison::TPtr&) { + Y_VERIFY(!Stopped); + + PushCallback([this] { + // Cancel all current ares requests (will send notifications) + ares_cancel(AresChannel); + + // Mark as stopped last + Stopped = true; + }); + + WorkerThread.Join(); + PassAway(); + } + + private: + enum class ERequestType { + GetHostByName, + GetAddr, + }; + + struct TRequestContext : public TThrRefBase { + using TPtr = TIntrusivePtr<TRequestContext>; + + TThis* Self; + TActorSystem* ActorSystem; + TActorId SelfId; + TActorId Sender; + ui64 Cookie; + ERequestType Type; + + TRequestContext(TThis* self, TActorSystem* as, TActorId selfId, TActorId sender, ui64 cookie, ERequestType type) + : Self(self) + , ActorSystem(as) + , SelfId(selfId) + , Sender(sender) + , Cookie(cookie) + , Type(type) + { } + }; + + private: + void Handle(TEvDns::TEvGetHostByName::TPtr& ev) { + auto* msg = ev->Get(); + auto reqCtx = MakeIntrusive<TRequestContext>( + this, TActivationContext::ActorSystem(), SelfId(), ev->Sender, ev->Cookie, ERequestType::GetHostByName); + PushCallback([this, reqCtx = std::move(reqCtx), name = std::move(msg->Name), family = msg->Family] () mutable { + StartGetHostByName(std::move(reqCtx), std::move(name), family); + }); + } + + void Handle(TEvDns::TEvGetAddr::TPtr& ev) { + auto* msg = ev->Get(); + auto reqCtx = MakeIntrusive<TRequestContext>( + this, TActivationContext::ActorSystem(), SelfId(), ev->Sender, ev->Cookie, ERequestType::GetAddr); + PushCallback([this, reqCtx = std::move(reqCtx), name = std::move(msg->Name), family = msg->Family] () mutable { + StartGetHostByName(std::move(reqCtx), std::move(name), family); + }); + } + + void StartGetHostByName(TRequestContext::TPtr reqCtx, TString name, int family) noexcept { + reqCtx->Ref(); + ares_gethostbyname(AresChannel, name.c_str(), family, + &TThis::GetHostByNameAresCallback, reqCtx.Get()); + } + + private: + static void GetHostByNameAresCallback(void* arg, int status, int timeouts, struct hostent* info) { + Y_UNUSED(timeouts); + TRequestContext::TPtr reqCtx(static_cast<TRequestContext*>(arg)); + reqCtx->UnRef(); + + if (reqCtx->Self->Stopped) { + // Don't send any replies after destruction + return; + } + + switch (reqCtx->Type) { + case ERequestType::GetHostByName: { + auto result = MakeHolder<TEvDns::TEvGetHostByNameResult>(); + if (status == 0) { + switch (info->h_addrtype) { + case AF_INET: { + for (int i = 0; info->h_addr_list[i] != nullptr; ++i) { + result->AddrsV4.emplace_back(*(struct in_addr*)(info->h_addr_list[i])); + } + break; + } + case AF_INET6: { + for (int i = 0; info->h_addr_list[i] != nullptr; ++i) { + result->AddrsV6.emplace_back(*(struct in6_addr*)(info->h_addr_list[i])); + } + break; + } + default: + Y_FAIL("unknown address family in ares callback"); + } + } else { + result->ErrorText = ares_strerror(status); + } + result->Status = status; + + reqCtx->ActorSystem->Send(new IEventHandle(reqCtx->Sender, reqCtx->SelfId, result.Release(), 0, reqCtx->Cookie)); + break; + } + + case ERequestType::GetAddr: { + auto result = MakeHolder<TEvDns::TEvGetAddrResult>(); + if (status == 0 && Y_UNLIKELY(info->h_addr_list[0] == nullptr)) { + status = ARES_ENODATA; + } + if (status == 0) { + switch (info->h_addrtype) { + case AF_INET: { + result->Addr = *(struct in_addr*)(info->h_addr_list[0]); + break; + } + case AF_INET6: { + result->Addr = *(struct in6_addr*)(info->h_addr_list[0]); + break; + } + default: + Y_FAIL("unknown address family in ares callback"); + } + } else { + result->ErrorText = ares_strerror(status); + } + result->Status = status; + + reqCtx->ActorSystem->Send(new IEventHandle(reqCtx->Sender, reqCtx->SelfId, result.Release(), 0, reqCtx->Cookie)); + break; + } + } + } + + private: + static void SockStateCallback(void* data, ares_socket_t socket_fd, int readable, int writable) { + static_cast<TThis*>(data)->DoSockStateCallback(socket_fd, readable, writable); + } + + void DoSockStateCallback(ares_socket_t socket_fd, int readable, int writable) noexcept { + int events = (readable ? (POLLRDNORM | POLLIN) : 0) | (writable ? (POLLWRNORM | POLLOUT) : 0); + if (events == 0) { + AresSockStates.erase(socket_fd); + } else { + AresSockStates[socket_fd].NeededEvents = events; + } + } + + private: + static void* WorkerThreadStart(void* arg) noexcept { + static_cast<TSimpleDnsResolver*>(arg)->WorkerThreadLoop(); + return nullptr; + } + + void WorkerThreadLoop() noexcept { + TThread::SetCurrentThreadName("DnsResolver"); + + TVector<struct pollfd> fds; + while (!Stopped) { + fds.clear(); + fds.reserve(1 + AresSockStates.size()); + { + auto& entry = fds.emplace_back(); + entry.fd = WaitSock(); + entry.events = POLLRDNORM | POLLIN; + } + for (auto& kv : AresSockStates) { + auto& entry = fds.emplace_back(); + entry.fd = kv.first; + entry.events = kv.second.NeededEvents; + } + + int timeout = -1; + struct timeval tv; + if (ares_timeout(AresChannel, nullptr, &tv)) { + timeout = tv.tv_sec * 1000 + tv.tv_usec / 1000; + } + + int ret = poll(fds.data(), fds.size(), timeout); + if (ret == -1) { + if (errno == EINTR) { + continue; + } + // we cannot handle failures, run callbacks and pretend everything is ok + RunCallbacks(); + if (Stopped) { + break; + } + ret = 0; + } + + bool ares_called = false; + if (ret > 0) { + for (size_t i = 0; i < fds.size(); ++i) { + auto& entry = fds[i]; + + // Handle WaitSock activation and run callbacks + if (i == 0) { + if (entry.revents & (POLLRDNORM | POLLIN)) { + RunCallbacks(); + if (Stopped) { + break; + } + } + continue; + } + + // All other sockets belong to ares + if (entry.revents == 0) { + continue; + } + // Previous invocation of aress_process_fd might have removed some sockets + if (Y_UNLIKELY(!AresSockStates.contains(entry.fd))) { + continue; + } + ares_process_fd( + AresChannel, + entry.revents & (POLLRDNORM | POLLIN) ? entry.fd : ARES_SOCKET_BAD, + entry.revents & (POLLWRNORM | POLLOUT) ? entry.fd : ARES_SOCKET_BAD); + ares_called = true; + } + + if (Stopped) { + break; + } + } + + if (!ares_called) { + // Let ares handle timeouts + ares_process_fd(AresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); + } + } + } + + private: + struct TSockState { + short NeededEvents = 0; // poll events + }; + + private: + TSimpleDnsResolverOptions Options; + TThread WorkerThread; + + ares_channel AresChannel; + THashMap<SOCKET, TSockState> AresSockStates; + + bool Stopped = false; + }; + + IActor* CreateSimpleDnsResolver(TSimpleDnsResolverOptions options) { + return new TSimpleDnsResolver(std::move(options)); + } + +} // namespace NDnsResolver +} // namespace NActors diff --git a/library/cpp/actors/dnsresolver/dnsresolver.h b/library/cpp/actors/dnsresolver/dnsresolver.h new file mode 100644 index 00000000000..88fc74df7d1 --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver.h @@ -0,0 +1,128 @@ +#pragma once + +#include <library/cpp/actors/core/actor.h> +#include <library/cpp/actors/core/events.h> +#include <library/cpp/actors/core/event_local.h> +#include <library/cpp/monlib/dynamic_counters/counters.h> +#include <util/network/address.h> +#include <variant> + +namespace NActors { +namespace NDnsResolver { + + struct TEvDns { + enum EEv { + EvGetHostByName = EventSpaceBegin(TEvents::ES_DNS), + EvGetHostByNameResult, + EvGetAddr, + EvGetAddrResult, + }; + + /** + * TEvGetHostByName returns the result of ares_gethostbyname + */ + struct TEvGetHostByName : public TEventLocal<TEvGetHostByName, EvGetHostByName> { + TString Name; + int Family; + + explicit TEvGetHostByName(TString name, int family = AF_UNSPEC) + : Name(std::move(name)) + , Family(family) + { } + }; + + struct TEvGetHostByNameResult : public TEventLocal<TEvGetHostByNameResult, EvGetHostByNameResult> { + TVector<struct in_addr> AddrsV4; + TVector<struct in6_addr> AddrsV6; + TString ErrorText; + int Status = 0; + }; + + /** + * TEvGetAddr returns a single address for a given hostname + */ + struct TEvGetAddr : public TEventLocal<TEvGetAddr, EvGetAddr> { + TString Name; + int Family; + + explicit TEvGetAddr(TString name, int family = AF_UNSPEC) + : Name(std::move(name)) + , Family(family) + { } + }; + + struct TEvGetAddrResult : public TEventLocal<TEvGetAddrResult, EvGetAddrResult> { + // N.B. "using" here doesn't work with Visual Studio compiler + typedef struct in6_addr TIPv6Addr; + typedef struct in_addr TIPv4Addr; + + std::variant<std::monostate, TIPv6Addr, TIPv4Addr> Addr; + TString ErrorText; + int Status = 0; + + bool IsV6() const { + return std::holds_alternative<TIPv6Addr>(Addr); + } + + bool IsV4() const { + return std::holds_alternative<TIPv4Addr>(Addr); + } + + const TIPv6Addr& GetAddrV6() const { + const TIPv6Addr* p = std::get_if<TIPv6Addr>(&Addr); + Y_VERIFY(p, "Result is not an ipv6 address"); + return *p; + } + + const TIPv4Addr& GetAddrV4() const { + const TIPv4Addr* p = std::get_if<TIPv4Addr>(&Addr); + Y_VERIFY(p, "Result is not an ipv4 address"); + return *p; + } + }; + }; + + struct TSimpleDnsResolverOptions { + // Initial per-server timeout, grows exponentially with each retry + TDuration Timeout = TDuration::Seconds(1); + // Number of attempts per-server + int Attempts = 2; + // Optional list of custom dns servers (ip.v4[:port], ip::v6 or [ip::v6]:port format) + TVector<TString> Servers; + }; + + IActor* CreateSimpleDnsResolver(TSimpleDnsResolverOptions options = TSimpleDnsResolverOptions()); + + struct TCachingDnsResolverOptions { + // Soft expire time specifies delay before name is refreshed in background + TDuration SoftNegativeExpireTime = TDuration::Seconds(1); + TDuration SoftPositiveExpireTime = TDuration::Seconds(10); + // Hard expire time specifies delay before the last result is forgotten + TDuration HardNegativeExpireTime = TDuration::Seconds(10); + TDuration HardPositiveExpireTime = TDuration::Hours(2); + // Allow these request families + bool AllowIPv6 = true; + bool AllowIPv4 = true; + // Optional counters + NMonitoring::TDynamicCounterPtr MonCounters = nullptr; + }; + + IActor* CreateCachingDnsResolver(TActorId upstream, TCachingDnsResolverOptions options = TCachingDnsResolverOptions()); + + struct TOnDemandDnsResolverOptions + : public TSimpleDnsResolverOptions + , public TCachingDnsResolverOptions + { + }; + + IActor* CreateOnDemandDnsResolver(TOnDemandDnsResolverOptions options = TOnDemandDnsResolverOptions()); + + /** + * Returns actor id of a globally registered dns resolver + */ + inline TActorId MakeDnsResolverActorId() { + return TActorId(0, TStringBuf("dnsresolver")); + } + +} // namespace NDnsResolver +} // namespace NActors diff --git a/library/cpp/actors/dnsresolver/dnsresolver_caching.cpp b/library/cpp/actors/dnsresolver/dnsresolver_caching.cpp new file mode 100644 index 00000000000..02760f4c275 --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver_caching.cpp @@ -0,0 +1,730 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/core/hfunc.h> +#include <util/generic/intrlist.h> + +#include <ares.h> + +#include <queue> + +namespace NActors { +namespace NDnsResolver { + + class TCachingDnsResolver : public TActor<TCachingDnsResolver> { + public: + struct TMonCounters { + NMonitoring::TDynamicCounters::TCounterPtr OutgoingInFlightV4; + NMonitoring::TDynamicCounters::TCounterPtr OutgoingInFlightV6; + NMonitoring::TDynamicCounters::TCounterPtr OutgoingErrorsV4; + NMonitoring::TDynamicCounters::TCounterPtr OutgoingErrorsV6; + NMonitoring::TDynamicCounters::TCounterPtr OutgoingTotalV4; + NMonitoring::TDynamicCounters::TCounterPtr OutgoingTotalV6; + + NMonitoring::TDynamicCounters::TCounterPtr IncomingInFlight; + NMonitoring::TDynamicCounters::TCounterPtr IncomingErrors; + NMonitoring::TDynamicCounters::TCounterPtr IncomingTotal; + + NMonitoring::TDynamicCounters::TCounterPtr CacheSize; + NMonitoring::TDynamicCounters::TCounterPtr CacheHits; + NMonitoring::TDynamicCounters::TCounterPtr CacheMisses; + + TMonCounters(const NMonitoring::TDynamicCounterPtr& counters) + : OutgoingInFlightV4(counters->GetCounter("DnsResolver/Outgoing/InFlight/V4", false)) + , OutgoingInFlightV6(counters->GetCounter("DnsResolver/Outgoing/InFlight/V6", false)) + , OutgoingErrorsV4(counters->GetCounter("DnsResolver/Outgoing/Errors/V4", true)) + , OutgoingErrorsV6(counters->GetCounter("DnsResolver/Outgoing/Errors/V6", true)) + , OutgoingTotalV4(counters->GetCounter("DnsResolver/Outgoing/Total/V4", true)) + , OutgoingTotalV6(counters->GetCounter("DnsResolver/Outgoing/Total/V6", true)) + , IncomingInFlight(counters->GetCounter("DnsResolver/Incoming/InFlight", false)) + , IncomingErrors(counters->GetCounter("DnsResolver/Incoming/Errors", true)) + , IncomingTotal(counters->GetCounter("DnsResolver/Incoming/Total", true)) + , CacheSize(counters->GetCounter("DnsResolver/Cache/Size", false)) + , CacheHits(counters->GetCounter("DnsResolver/Cache/Hits", true)) + , CacheMisses(counters->GetCounter("DnsResolver/Cache/Misses", true)) + { } + }; + + public: + TCachingDnsResolver(TActorId upstream, TCachingDnsResolverOptions options) + : TActor(&TThis::StateWork) + , Upstream(upstream) + , Options(std::move(options)) + , MonCounters(Options.MonCounters ? new TMonCounters(Options.MonCounters) : nullptr) + { } + + static constexpr EActivityType ActorActivityType() { + return DNS_RESOLVER; + } + + private: + STRICT_STFUNC(StateWork, { + hFunc(TEvents::TEvPoison, Handle); + hFunc(TEvDns::TEvGetHostByName, Handle); + hFunc(TEvDns::TEvGetAddr, Handle); + hFunc(TEvDns::TEvGetHostByNameResult, Handle); + hFunc(TEvents::TEvUndelivered, Handle); + }); + + void Handle(TEvents::TEvPoison::TPtr&) { + DropPending(ARES_ECANCELLED); + PassAway(); + } + + void Handle(TEvDns::TEvGetHostByName::TPtr& ev) { + auto req = MakeHolder<TIncomingRequest>(); + req->Type = EIncomingRequestType::GetHostByName; + req->Sender = ev->Sender; + req->Cookie = ev->Cookie; + req->Name = std::move(ev->Get()->Name); + req->Family = ev->Get()->Family; + EnqueueRequest(std::move(req)); + } + + void Handle(TEvDns::TEvGetAddr::TPtr& ev) { + auto req = MakeHolder<TIncomingRequest>(); + req->Type = EIncomingRequestType::GetAddr; + req->Sender = ev->Sender; + req->Cookie = ev->Cookie; + req->Name = std::move(ev->Get()->Name); + req->Family = ev->Get()->Family; + EnqueueRequest(std::move(req)); + } + + void Handle(TEvDns::TEvGetHostByNameResult::TPtr& ev) { + auto waitingIt = WaitingRequests.find(ev->Cookie); + Y_VERIFY(waitingIt != WaitingRequests.end(), "Unexpected reply, reqId=%" PRIu64, ev->Cookie); + auto waitingInfo = waitingIt->second; + WaitingRequests.erase(waitingIt); + + switch (waitingInfo.Family) { + case AF_INET6: + if (ev->Get()->Status) { + ProcessErrorV6(waitingInfo.Position, ev->Get()->Status, std::move(ev->Get()->ErrorText)); + } else { + ProcessAddrsV6(waitingInfo.Position, std::move(ev->Get()->AddrsV6)); + } + break; + + case AF_INET: + if (ev->Get()->Status) { + ProcessErrorV4(waitingInfo.Position, ev->Get()->Status, std::move(ev->Get()->ErrorText)); + } else { + ProcessAddrsV4(waitingInfo.Position, std::move(ev->Get()->AddrsV4)); + } + break; + + default: + Y_FAIL("Unexpected request family %d", waitingInfo.Family); + } + } + + void Handle(TEvents::TEvUndelivered::TPtr& ev) { + switch (ev->Get()->SourceType) { + case TEvDns::TEvGetHostByName::EventType: { + auto waitingIt = WaitingRequests.find(ev->Cookie); + Y_VERIFY(waitingIt != WaitingRequests.end(), "Unexpected TEvUndelivered, reqId=%" PRIu64, ev->Cookie); + auto waitingInfo = waitingIt->second; + WaitingRequests.erase(waitingIt); + + switch (waitingInfo.Family) { + case AF_INET6: + ProcessErrorV6(waitingInfo.Position, ARES_ENOTINITIALIZED, "Caching dns resolver cannot deliver to the underlying resolver"); + break; + case AF_INET: + ProcessErrorV4(waitingInfo.Position, ARES_ENOTINITIALIZED, "Caching dns resolver cannot deliver to the underlying resolver"); + break; + default: + Y_FAIL("Unexpected request family %d", waitingInfo.Family); + } + + break; + } + + default: + Y_FAIL("Unexpected TEvUndelievered, type=%" PRIu32, ev->Get()->SourceType); + } + } + + private: + enum EIncomingRequestType { + GetHostByName, + GetAddr, + }; + + struct TIncomingRequest : public TIntrusiveListItem<TIncomingRequest> { + EIncomingRequestType Type; + TActorId Sender; + ui64 Cookie; + TString Name; + int Family; + }; + + using TIncomingRequestList = TIntrusiveListWithAutoDelete<TIncomingRequest, TDelete>; + + void EnqueueRequest(THolder<TIncomingRequest> req) { + if (MonCounters) { + ++*MonCounters->IncomingTotal; + } + + CleanupExpired(TActivationContext::Now()); + + switch (req->Family) { + case AF_UNSPEC: + if (Options.AllowIPv6) { + EnqueueRequestIPv6(std::move(req)); + return; + } + if (Options.AllowIPv4) { + EnqueueRequestIPv4(std::move(req)); + return; + } + break; + + case AF_INET6: + if (Options.AllowIPv6) { + EnqueueRequestIPv6(std::move(req)); + return; + } + break; + + case AF_INET: + if (Options.AllowIPv4) { + EnqueueRequestIPv4(std::move(req)); + return; + } + break; + } + + ReplyWithError(std::move(req), ARES_EBADFAMILY); + } + + void EnqueueRequestIPv6(THolder<TIncomingRequest> req) { + auto now = TActivationContext::Now(); + + auto& fullState = NameToState[req->Name]; + if (MonCounters) { + *MonCounters->CacheSize = NameToState.size(); + } + + auto& state = fullState.StateIPv6; + EnsureRequest(state, req->Name, AF_INET6, now); + + if (state.IsHardExpired(now)) { + Y_VERIFY(state.Waiting); + if (MonCounters) { + ++*MonCounters->CacheMisses; + } + // We need to wait for ipv6 reply, schedule ipv4 request in parallel if needed + if (Options.AllowIPv4) { + EnsureRequest(fullState.StateIPv4, req->Name, AF_INET, now); + } + state.WaitingRequests.PushBack(req.Release()); + return; + } + + // We want to retry AF_UNSPEC with IPv4 in some cases + if (req->Family == AF_UNSPEC && Options.AllowIPv4 && state.RetryUnspec()) { + EnqueueRequestIPv4(std::move(req)); + return; + } + + if (MonCounters) { + ++*MonCounters->CacheHits; + } + + if (state.Status != 0) { + ReplyWithError(std::move(req), state.Status, state.ErrorText); + } else { + ReplyWithAddrs(std::move(req), fullState.AddrsIPv6); + } + } + + void EnqueueRequestIPv4(THolder<TIncomingRequest> req, bool isCacheMiss = false) { + auto now = TActivationContext::Now(); + + auto& fullState = NameToState[req->Name]; + if (MonCounters) { + *MonCounters->CacheSize = NameToState.size(); + } + + auto& state = fullState.StateIPv4; + EnsureRequest(state, req->Name, AF_INET, now); + + if (state.IsHardExpired(now)) { + Y_VERIFY(state.Waiting); + if (MonCounters && !isCacheMiss) { + ++*MonCounters->CacheMisses; + } + state.WaitingRequests.PushBack(req.Release()); + return; + } + + if (MonCounters && !isCacheMiss) { + ++*MonCounters->CacheHits; + } + + if (state.Status != 0) { + ReplyWithError(std::move(req), state.Status, state.ErrorText); + } else { + ReplyWithAddrs(std::move(req), fullState.AddrsIPv4); + } + } + + private: + struct TFamilyState { + TIncomingRequestList WaitingRequests; + TInstant SoftDeadline; + TInstant HardDeadline; + TInstant NextSoftDeadline; + TInstant NextHardDeadline; + TString ErrorText; + int Status = -1; // never requested before + bool InSoftHeap = false; + bool InHardHeap = false; + bool Waiting = false; + + bool Needed() const { + return InSoftHeap || InHardHeap || Waiting; + } + + bool RetryUnspec() const { + return ( + Status == ARES_ENODATA || + Status == ARES_EBADRESP || + Status == ARES_ETIMEOUT); + } + + bool ServerReplied() const { + return ServerReplied(Status); + } + + bool IsSoftExpired(TInstant now) const { + return !InSoftHeap || NextSoftDeadline < now; + } + + bool IsHardExpired(TInstant now) const { + return !InHardHeap || NextHardDeadline < now; + } + + static bool ServerReplied(int status) { + return ( + status == ARES_SUCCESS || + status == ARES_ENODATA || + status == ARES_ENOTFOUND); + } + }; + + struct TState { + TFamilyState StateIPv6; + TFamilyState StateIPv4; + TVector<struct in6_addr> AddrsIPv6; + TVector<struct in_addr> AddrsIPv4; + + bool Needed() const { + return StateIPv6.Needed() || StateIPv4.Needed(); + } + }; + + using TNameToState = THashMap<TString, TState>; + + template<const TFamilyState TState::* StateToFamily, + const TInstant TFamilyState::* FamilyToDeadline> + struct THeapCompare { + // returns true when b < a + bool operator()(TNameToState::iterator a, TNameToState::iterator b) const { + const TState& aState = a->second; + const TState& bState = b->second; + const TFamilyState& aFamily = aState.*StateToFamily; + const TFamilyState& bFamily = bState.*StateToFamily; + const TInstant& aDeadline = aFamily.*FamilyToDeadline; + const TInstant& bDeadline = bFamily.*FamilyToDeadline; + return bDeadline < aDeadline; + } + }; + + template<const TFamilyState TState::* StateToFamily, + const TInstant TFamilyState::* FamilyToDeadline> + using TStateHeap = std::priority_queue< + TNameToState::iterator, + std::vector<TNameToState::iterator>, + THeapCompare<StateToFamily, FamilyToDeadline> + >; + + struct TWaitingInfo { + TNameToState::iterator Position; + int Family; + }; + + private: + void EnsureRequest(TFamilyState& state, const TString& name, int family, TInstant now) { + if (state.Waiting) { + return; // request is already pending + } + + if (!state.IsSoftExpired(now) && !state.IsHardExpired(now)) { + return; // response is not expired yet + } + + if (MonCounters) { + switch (family) { + case AF_INET6: + ++*MonCounters->OutgoingInFlightV6; + ++*MonCounters->OutgoingTotalV6; + break; + case AF_INET: + ++*MonCounters->OutgoingInFlightV4; + ++*MonCounters->OutgoingTotalV4; + break; + } + } + + ui64 reqId = ++LastRequestId; + auto& req = WaitingRequests[reqId]; + req.Position = NameToState.find(name); + req.Family = family; + Y_VERIFY(req.Position != NameToState.end()); + + Send(Upstream, new TEvDns::TEvGetHostByName(name, family), IEventHandle::FlagTrackDelivery, reqId); + state.Waiting = true; + } + + template<TFamilyState TState::* StateToFamily, + TInstant TFamilyState::* FamilyToDeadline, + TInstant TFamilyState::* FamilyToNextDeadline, + bool TFamilyState::* FamilyToFlag, + class THeap> + void PushToHeap(THeap& heap, TNameToState::iterator it, TInstant newDeadline) { + auto& family = it->second.*StateToFamily; + TInstant& deadline = family.*FamilyToDeadline; + TInstant& nextDeadline = family.*FamilyToNextDeadline; + bool& flag = family.*FamilyToFlag; + nextDeadline = newDeadline; + if (!flag) { + deadline = newDeadline; + heap.push(it); + flag = true; + } + } + + void PushSoftV6(TNameToState::iterator it, TInstant newDeadline) { + PushToHeap<&TState::StateIPv6, &TFamilyState::SoftDeadline, &TFamilyState::NextSoftDeadline, &TFamilyState::InSoftHeap>(SoftHeapIPv6, it, newDeadline); + } + + void PushHardV6(TNameToState::iterator it, TInstant newDeadline) { + PushToHeap<&TState::StateIPv6, &TFamilyState::HardDeadline, &TFamilyState::NextHardDeadline, &TFamilyState::InHardHeap>(HardHeapIPv6, it, newDeadline); + } + + void PushSoftV4(TNameToState::iterator it, TInstant newDeadline) { + PushToHeap<&TState::StateIPv4, &TFamilyState::SoftDeadline, &TFamilyState::NextSoftDeadline, &TFamilyState::InSoftHeap>(SoftHeapIPv4, it, newDeadline); + } + + void PushHardV4(TNameToState::iterator it, TInstant newDeadline) { + PushToHeap<&TState::StateIPv4, &TFamilyState::HardDeadline, &TFamilyState::NextHardDeadline, &TFamilyState::InHardHeap>(HardHeapIPv4, it, newDeadline); + } + + void ProcessErrorV6(TNameToState::iterator it, int status, TString errorText) { + auto now = TActivationContext::Now(); + if (MonCounters) { + --*MonCounters->OutgoingInFlightV6; + ++*MonCounters->OutgoingErrorsV6; + } + + auto& state = it->second.StateIPv6; + Y_VERIFY(state.Waiting, "Got error for a state we are not waiting"); + state.Waiting = false; + + // When we have a cached positive reply, don't overwrite it with spurious errors + const bool serverReplied = TFamilyState::ServerReplied(status); + if (!serverReplied && state.ServerReplied() && !state.IsHardExpired(now)) { + PushSoftV6(it, now + Options.SoftNegativeExpireTime); + if (state.Status == ARES_SUCCESS) { + SendAddrsV6(it); + } else { + SendErrorsV6(it, now); + } + return; + } + + state.Status = status; + state.ErrorText = std::move(errorText); + PushSoftV6(it, now + Options.SoftNegativeExpireTime); + if (serverReplied) { + // Server actually replied, so keep it cached for longer + PushHardV6(it, now + Options.HardPositiveExpireTime); + } else { + PushHardV6(it, now + Options.HardNegativeExpireTime); + } + + SendErrorsV6(it, now); + } + + void SendErrorsV6(TNameToState::iterator it, TInstant now) { + bool cleaned = false; + auto& state = it->second.StateIPv6; + while (state.WaitingRequests) { + THolder<TIncomingRequest> req(state.WaitingRequests.PopFront()); + if (req->Family == AF_UNSPEC && Options.AllowIPv4 && state.RetryUnspec()) { + if (!cleaned) { + CleanupExpired(now); + cleaned = true; + } + EnqueueRequestIPv4(std::move(req), /* isCacheMiss */ true); + } else { + ReplyWithError(std::move(req), state.Status, state.ErrorText); + } + } + } + + void ProcessErrorV4(TNameToState::iterator it, int status, TString errorText) { + auto now = TActivationContext::Now(); + if (MonCounters) { + --*MonCounters->OutgoingInFlightV4; + ++*MonCounters->OutgoingErrorsV4; + } + + auto& state = it->second.StateIPv4; + Y_VERIFY(state.Waiting, "Got error for a state we are not waiting"); + state.Waiting = false; + + // When we have a cached positive reply, don't overwrite it with spurious errors + const bool serverReplied = TFamilyState::ServerReplied(status); + if (!serverReplied && state.ServerReplied() && !state.IsHardExpired(now)) { + PushSoftV4(it, now + Options.SoftNegativeExpireTime); + if (state.Status == ARES_SUCCESS) { + SendAddrsV4(it); + } else { + SendErrorsV4(it); + } + return; + } + + state.Status = status; + state.ErrorText = std::move(errorText); + PushSoftV4(it, now + Options.SoftNegativeExpireTime); + if (serverReplied) { + // Server actually replied, so keep it cached for longer + PushHardV4(it, now + Options.HardPositiveExpireTime); + } else { + PushHardV4(it, now + Options.HardNegativeExpireTime); + } + + SendErrorsV4(it); + } + + void SendErrorsV4(TNameToState::iterator it) { + auto& state = it->second.StateIPv4; + while (state.WaitingRequests) { + THolder<TIncomingRequest> req(state.WaitingRequests.PopFront()); + ReplyWithError(std::move(req), state.Status, state.ErrorText); + } + } + + void ProcessAddrsV6(TNameToState::iterator it, TVector<struct in6_addr> addrs) { + if (Y_UNLIKELY(addrs.empty())) { + // Probably unnecessary: we don't want to deal with empty address lists + return ProcessErrorV6(it, ARES_ENODATA, ares_strerror(ARES_ENODATA)); + } + + auto now = TActivationContext::Now(); + if (MonCounters) { + --*MonCounters->OutgoingInFlightV6; + } + + auto& state = it->second.StateIPv6; + Y_VERIFY(state.Waiting, "Got reply for a state we are not waiting"); + state.Waiting = false; + + state.Status = ARES_SUCCESS; + it->second.AddrsIPv6 = std::move(addrs); + PushSoftV6(it, now + Options.SoftPositiveExpireTime); + PushHardV6(it, now + Options.HardPositiveExpireTime); + + SendAddrsV6(it); + } + + void SendAddrsV6(TNameToState::iterator it) { + auto& state = it->second.StateIPv6; + while (state.WaitingRequests) { + THolder<TIncomingRequest> req(state.WaitingRequests.PopFront()); + ReplyWithAddrs(std::move(req), it->second.AddrsIPv6); + } + } + + void ProcessAddrsV4(TNameToState::iterator it, TVector<struct in_addr> addrs) { + if (Y_UNLIKELY(addrs.empty())) { + // Probably unnecessary: we don't want to deal with empty address lists + return ProcessErrorV4(it, ARES_ENODATA, ares_strerror(ARES_ENODATA)); + } + + auto now = TActivationContext::Now(); + if (MonCounters) { + --*MonCounters->OutgoingInFlightV4; + } + + auto& state = it->second.StateIPv4; + Y_VERIFY(state.Waiting, "Got reply for a state we are not waiting"); + state.Waiting = false; + + state.Status = ARES_SUCCESS; + it->second.AddrsIPv4 = std::move(addrs); + PushSoftV4(it, now + Options.SoftPositiveExpireTime); + PushHardV4(it, now + Options.HardPositiveExpireTime); + + SendAddrsV4(it); + } + + void SendAddrsV4(TNameToState::iterator it) { + auto& state = it->second.StateIPv4; + while (state.WaitingRequests) { + THolder<TIncomingRequest> req(state.WaitingRequests.PopFront()); + ReplyWithAddrs(std::move(req), it->second.AddrsIPv4); + } + } + + private: + template<TFamilyState TState::*StateToFamily, + TInstant TFamilyState::* FamilyToDeadline, + TInstant TFamilyState::* FamilyToNextDeadline, + bool TFamilyState::* FamilyToFlag> + void DoCleanupExpired(TStateHeap<StateToFamily, FamilyToDeadline>& heap, TInstant now) { + while (!heap.empty()) { + auto it = heap.top(); + auto& family = it->second.*StateToFamily; + TInstant& deadline = family.*FamilyToDeadline; + if (now <= deadline) { + break; + } + + bool& flag = family.*FamilyToFlag; + Y_VERIFY(flag); + heap.pop(); + flag = false; + + TInstant& nextDeadline = family.*FamilyToNextDeadline; + if (now < nextDeadline) { + deadline = nextDeadline; + heap.push(it); + flag = true; + continue; + } + + // Remove unnecessary items + if (!it->second.Needed()) { + NameToState.erase(it); + if (MonCounters) { + *MonCounters->CacheSize = NameToState.size(); + } + } + } + } + + void CleanupExpired(TInstant now) { + DoCleanupExpired<&TState::StateIPv6, &TFamilyState::SoftDeadline, &TFamilyState::NextSoftDeadline, &TFamilyState::InSoftHeap>(SoftHeapIPv6, now); + DoCleanupExpired<&TState::StateIPv6, &TFamilyState::HardDeadline, &TFamilyState::NextHardDeadline, &TFamilyState::InHardHeap>(HardHeapIPv6, now); + DoCleanupExpired<&TState::StateIPv4, &TFamilyState::SoftDeadline, &TFamilyState::NextSoftDeadline, &TFamilyState::InSoftHeap>(SoftHeapIPv4, now); + DoCleanupExpired<&TState::StateIPv4, &TFamilyState::HardDeadline, &TFamilyState::NextHardDeadline, &TFamilyState::InHardHeap>(HardHeapIPv4, now); + } + + template<class TEvent> + void SendError(TActorId replyTo, ui64 cookie, int status, const TString& errorText) { + auto reply = MakeHolder<TEvent>(); + reply->Status = status; + reply->ErrorText = errorText; + this->Send(replyTo, reply.Release(), 0, cookie); + } + + void ReplyWithError(THolder<TIncomingRequest> req, int status, const TString& errorText) { + if (MonCounters) { + ++*MonCounters->IncomingErrors; + } + switch (req->Type) { + case EIncomingRequestType::GetHostByName: { + SendError<TEvDns::TEvGetHostByNameResult>(req->Sender, req->Cookie, status, errorText); + break; + } + case EIncomingRequestType::GetAddr: { + SendError<TEvDns::TEvGetAddrResult>(req->Sender, req->Cookie, status, errorText); + break; + } + } + } + + void ReplyWithAddrs(THolder<TIncomingRequest> req, const TVector<struct in6_addr>& addrs) { + switch (req->Type) { + case EIncomingRequestType::GetHostByName: { + auto reply = MakeHolder<TEvDns::TEvGetHostByNameResult>(); + reply->AddrsV6 = addrs; + Send(req->Sender, reply.Release(), 0, req->Cookie); + break; + } + case EIncomingRequestType::GetAddr: { + Y_VERIFY(!addrs.empty()); + auto reply = MakeHolder<TEvDns::TEvGetAddrResult>(); + reply->Addr = addrs.front(); + Send(req->Sender, reply.Release(), 0, req->Cookie); + break; + } + } + } + + void ReplyWithAddrs(THolder<TIncomingRequest> req, const TVector<struct in_addr>& addrs) { + switch (req->Type) { + case EIncomingRequestType::GetHostByName: { + auto reply = MakeHolder<TEvDns::TEvGetHostByNameResult>(); + reply->AddrsV4 = addrs; + Send(req->Sender, reply.Release(), 0, req->Cookie); + break; + } + case EIncomingRequestType::GetAddr: { + Y_VERIFY(!addrs.empty()); + auto reply = MakeHolder<TEvDns::TEvGetAddrResult>(); + reply->Addr = addrs.front(); + Send(req->Sender, reply.Release(), 0, req->Cookie); + break; + } + } + } + + void ReplyWithError(THolder<TIncomingRequest> req, int status) { + ReplyWithError(std::move(req), status, ares_strerror(status)); + } + + void DropPending(TIncomingRequestList& list, int status, const TString& errorText) { + while (list) { + THolder<TIncomingRequest> req(list.PopFront()); + ReplyWithError(std::move(req), status, errorText); + } + } + + void DropPending(int status, const TString& errorText) { + for (auto& [name, state] : NameToState) { + DropPending(state.StateIPv6.WaitingRequests, status, errorText); + DropPending(state.StateIPv4.WaitingRequests, status, errorText); + } + } + + void DropPending(int status) { + DropPending(status, ares_strerror(status)); + } + + private: + const TActorId Upstream; + const TCachingDnsResolverOptions Options; + const THolder<TMonCounters> MonCounters; + + TNameToState NameToState; + TStateHeap<&TState::StateIPv6, &TFamilyState::SoftDeadline> SoftHeapIPv6; + TStateHeap<&TState::StateIPv6, &TFamilyState::HardDeadline> HardHeapIPv6; + TStateHeap<&TState::StateIPv4, &TFamilyState::SoftDeadline> SoftHeapIPv4; + TStateHeap<&TState::StateIPv4, &TFamilyState::HardDeadline> HardHeapIPv4; + + THashMap<ui64, TWaitingInfo> WaitingRequests; + ui64 LastRequestId = 0; + }; + + IActor* CreateCachingDnsResolver(TActorId upstream, TCachingDnsResolverOptions options) { + return new TCachingDnsResolver(upstream, std::move(options)); + } + +} // namespace NDnsResolver +} // namespace NActors diff --git a/library/cpp/actors/dnsresolver/dnsresolver_caching_ut.cpp b/library/cpp/actors/dnsresolver/dnsresolver_caching_ut.cpp new file mode 100644 index 00000000000..c3b7cb3c77c --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver_caching_ut.cpp @@ -0,0 +1,630 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/core/hfunc.h> +#include <library/cpp/actors/testlib/test_runtime.h> +#include <library/cpp/testing/unittest/registar.h> +#include <util/string/builder.h> + +#include <ares.h> + +using namespace NActors; +using namespace NActors::NDnsResolver; + +// FIXME: use a mock resolver +Y_UNIT_TEST_SUITE(CachingDnsResolver) { + + struct TAddrToString { + TString operator()(const std::monostate&) const { + return "<nothing>"; + } + + TString operator()(const struct in6_addr& addr) const { + char dst[INET6_ADDRSTRLEN]; + auto res = ares_inet_ntop(AF_INET6, &addr, dst, INET6_ADDRSTRLEN); + Y_VERIFY(res, "Cannot convert ipv6 address"); + return dst; + } + + TString operator()(const struct in_addr& addr) const { + char dst[INET_ADDRSTRLEN]; + auto res = ares_inet_ntop(AF_INET, &addr, dst, INET_ADDRSTRLEN); + Y_VERIFY(res, "Cannot convert ipv4 address"); + return dst; + } + }; + + TString AddrToString(const std::variant<std::monostate, struct in6_addr, struct in_addr>& v) { + return std::visit(TAddrToString(), v); + } + + struct TMockReply { + static constexpr TDuration DefaultDelay = TDuration::MilliSeconds(1); + + int Status = 0; + TDuration Delay; + TVector<struct in6_addr> AddrsV6; + TVector<struct in_addr> AddrsV4; + + static TMockReply Error(int status, TDuration delay = DefaultDelay) { + Y_VERIFY(status != 0); + TMockReply reply; + reply.Status = status; + reply.Delay = delay; + return reply; + } + + static TMockReply Empty(TDuration delay = DefaultDelay) { + TMockReply reply; + reply.Delay = delay; + return reply; + } + + static TMockReply ManyV6(const TVector<TString>& addrs, TDuration delay = DefaultDelay) { + TMockReply reply; + reply.Delay = delay; + for (const TString& addr : addrs) { + void* dst = &reply.AddrsV6.emplace_back(); + int status = ares_inet_pton(AF_INET6, addr.c_str(), dst); + Y_VERIFY(status == 1, "Invalid ipv6 address: %s", addr.c_str()); + } + return reply; + } + + static TMockReply ManyV4(const TVector<TString>& addrs, TDuration delay = DefaultDelay) { + TMockReply reply; + reply.Delay = delay; + for (const TString& addr : addrs) { + void* dst = &reply.AddrsV4.emplace_back(); + int status = ares_inet_pton(AF_INET, addr.c_str(), dst); + Y_VERIFY(status == 1, "Invalid ipv4 address: %s", addr.c_str()); + } + return reply; + } + + static TMockReply SingleV6(const TString& addr, TDuration delay = DefaultDelay) { + return ManyV6({ addr }, delay); + } + + static TMockReply SingleV4(const TString& addr, TDuration delay = DefaultDelay) { + return ManyV4({ addr }, delay); + } + }; + + using TMockDnsCallback = std::function<TMockReply (const TString&, int)>; + + class TMockDnsResolver : public TActor<TMockDnsResolver> { + public: + TMockDnsResolver(TMockDnsCallback callback) + : TActor(&TThis::StateWork) + , Callback(std::move(callback)) + { } + + private: + struct TEvPrivate { + enum EEv { + EvScheduled = EventSpaceBegin(TEvents::ES_PRIVATE), + }; + + struct TEvScheduled : public TEventLocal<TEvScheduled, EvScheduled> { + TActorId Sender; + ui64 Cookie; + TMockReply Reply; + + TEvScheduled(TActorId sender, ui64 cookie, TMockReply reply) + : Sender(sender) + , Cookie(cookie) + , Reply(std::move(reply)) + { } + }; + }; + + private: + STRICT_STFUNC(StateWork, { + hFunc(TEvents::TEvPoison, Handle); + hFunc(TEvDns::TEvGetHostByName, Handle); + hFunc(TEvPrivate::TEvScheduled, Handle); + }); + + void Handle(TEvents::TEvPoison::TPtr&) { + PassAway(); + } + + void Handle(TEvDns::TEvGetHostByName::TPtr& ev) { + auto reply = Callback(ev->Get()->Name, ev->Get()->Family); + if (reply.Delay) { + Schedule(reply.Delay, new TEvPrivate::TEvScheduled(ev->Sender, ev->Cookie, std::move(reply))); + } else { + SendReply(ev->Sender, ev->Cookie, std::move(reply)); + } + } + + void Handle(TEvPrivate::TEvScheduled::TPtr& ev) { + SendReply(ev->Get()->Sender, ev->Get()->Cookie, std::move(ev->Get()->Reply)); + } + + private: + void SendReply(const TActorId& sender, ui64 cookie, TMockReply&& reply) { + auto res = MakeHolder<TEvDns::TEvGetHostByNameResult>(); + res->Status = reply.Status; + if (res->Status != 0) { + res->ErrorText = ares_strerror(res->Status); + } else { + res->AddrsV6 = std::move(reply.AddrsV6); + res->AddrsV4 = std::move(reply.AddrsV4); + } + Send(sender, res.Release(), 0, cookie); + } + + private: + TMockDnsCallback Callback; + }; + + struct TCachingDnsRuntime : public TTestActorRuntimeBase { + TCachingDnsResolverOptions ResolverOptions; + TActorId MockResolver; + TActorId Resolver; + TActorId Sleeper; + TString Section_; + + NMonitoring::TDynamicCounters::TCounterPtr InFlight6; + NMonitoring::TDynamicCounters::TCounterPtr InFlight4; + NMonitoring::TDynamicCounters::TCounterPtr Total6; + NMonitoring::TDynamicCounters::TCounterPtr Total4; + NMonitoring::TDynamicCounters::TCounterPtr Misses; + NMonitoring::TDynamicCounters::TCounterPtr Hits; + + THashMap<TString, TMockReply> ReplyV6; + THashMap<TString, TMockReply> ReplyV4; + + TCachingDnsRuntime() { + SetScheduledEventFilter([](auto&&, auto&&, auto&&, auto&&) { return false; }); + ResolverOptions.MonCounters = new NMonitoring::TDynamicCounters; + + ReplyV6["localhost"] = TMockReply::SingleV6("::1"); + ReplyV4["localhost"] = TMockReply::SingleV4("127.0.0.1"); + ReplyV6["yandex.ru"] = TMockReply::SingleV6("2a02:6b8:a::a", TDuration::MilliSeconds(500)); + ReplyV4["yandex.ru"] = TMockReply::SingleV4("77.88.55.77", TDuration::MilliSeconds(250)); + ReplyV6["router.asus.com"] = TMockReply::Error(ARES_ENODATA); + ReplyV4["router.asus.com"] = TMockReply::SingleV4("192.168.0.1"); + } + + void Start(TMockDnsCallback callback) { + MockResolver = Register(new TMockDnsResolver(std::move(callback))); + EnableScheduleForActor(MockResolver); + Resolver = Register(CreateCachingDnsResolver(MockResolver, ResolverOptions)); + Sleeper = AllocateEdgeActor(); + + InFlight6 = ResolverOptions.MonCounters->GetCounter("DnsResolver/Outgoing/InFlight/V6", false); + InFlight4 = ResolverOptions.MonCounters->GetCounter("DnsResolver/Outgoing/InFlight/V4", false); + Total6 = ResolverOptions.MonCounters->GetCounter("DnsResolver/Outgoing/Total/V6", true); + Total4 = ResolverOptions.MonCounters->GetCounter("DnsResolver/Outgoing/Total/V4", true); + Misses = ResolverOptions.MonCounters->GetCounter("DnsResolver/Cache/Misses", true); + Hits = ResolverOptions.MonCounters->GetCounter("DnsResolver/Cache/Hits", true); + } + + void Start() { + Start([this](const TString& name, int family) { + switch (family) { + case AF_INET6: { + auto it = ReplyV6.find(name); + if (it != ReplyV6.end()) { + return it->second; + } + break; + } + case AF_INET: { + auto it = ReplyV4.find(name); + if (it != ReplyV4.end()) { + return it->second; + } + break; + } + } + return TMockReply::Error(ARES_ENOTFOUND); + }); + } + + void Section(const TString& section) { + Section_ = section; + } + + void Sleep(TDuration duration) { + Schedule(new IEventHandle(Sleeper, Sleeper, new TEvents::TEvWakeup), duration); + GrabEdgeEventRethrow<TEvents::TEvWakeup>(Sleeper); + } + + void WaitNoInFlight() { + if (*InFlight6 || *InFlight4) { + TDispatchOptions options; + options.CustomFinalCondition = [&]() { + return !*InFlight6 && !*InFlight4; + }; + DispatchEvents(options); + UNIT_ASSERT_C(!*InFlight6 && !*InFlight4, "Failed to wait for no inflight in " << Section_); + } + } + + void SendGetHostByName(const TActorId& sender, const TString& name, int family = AF_UNSPEC) { + Send(new IEventHandle(Resolver, sender, new TEvDns::TEvGetHostByName(name, family)), 0, true); + } + + void SendGetAddr(const TActorId& sender, const TString& name, int family = AF_UNSPEC) { + Send(new IEventHandle(Resolver, sender, new TEvDns::TEvGetAddr(name, family)), 0, true); + } + + TEvDns::TEvGetHostByNameResult::TPtr WaitGetHostByName(const TActorId& sender) { + return GrabEdgeEventRethrow<TEvDns::TEvGetHostByNameResult>(sender); + } + + TEvDns::TEvGetAddrResult::TPtr WaitGetAddr(const TActorId& sender) { + return GrabEdgeEventRethrow<TEvDns::TEvGetAddrResult>(sender); + } + + void ExpectInFlight6(i64 count) { + UNIT_ASSERT_VALUES_EQUAL_C(InFlight6->Val(), count, Section_); + } + + void ExpectInFlight4(i64 count) { + UNIT_ASSERT_VALUES_EQUAL_C(InFlight4->Val(), count, Section_); + } + + void ExpectTotal6(i64 count) { + UNIT_ASSERT_VALUES_EQUAL_C(Total6->Val(), count, Section_); + } + + void ExpectTotal4(i64 count) { + UNIT_ASSERT_VALUES_EQUAL_C(Total4->Val(), count, Section_); + } + + void Expect6(i64 total, i64 inflight) { + UNIT_ASSERT_C( + Total6->Val() == total && InFlight6->Val() == inflight, + Section_ << ": Expect6(" << total << ", " << inflight << ") " + << " but got (" << Total6->Val() << ", " << InFlight6->Val() << ")"); + } + + void Expect4(i64 total, i64 inflight) { + UNIT_ASSERT_C( + Total4->Val() == total && InFlight4->Val() == inflight, + Section_ << ": Expect4(" << total << ", " << inflight << ") " + << " got (" << Total4->Val() << ", " << InFlight4->Val() << ")"); + } + + void ExpectMisses(i64 count) { + UNIT_ASSERT_VALUES_EQUAL_C(Misses->Val(), count, Section_); + } + + void ExpectHits(i64 count) { + UNIT_ASSERT_VALUES_EQUAL_C(Hits->Val(), count, Section_); + } + + void ExpectGetHostByNameError(const TActorId& sender, int status) { + auto ev = WaitGetHostByName(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, status, Section_ << ": " << ev->Get()->ErrorText); + } + + void ExpectGetAddrError(const TActorId& sender, int status) { + auto ev = WaitGetAddr(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, status, Section_ << ": " << ev->Get()->ErrorText); + } + + void ExpectGetHostByNameSuccess(const TActorId& sender, const TString& expected) { + auto ev = WaitGetHostByName(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, 0, Section_ << ": " << ev->Get()->ErrorText); + TStringBuilder result; + for (const auto& addr : ev->Get()->AddrsV6) { + if (result) { + result << ','; + } + result << TAddrToString()(addr); + } + for (const auto& addr : ev->Get()->AddrsV4) { + if (result) { + result << ','; + } + result << TAddrToString()(addr); + } + UNIT_ASSERT_VALUES_EQUAL_C(TString(result), expected, Section_); + } + + void ExpectGetAddrSuccess(const TActorId& sender, const TString& expected) { + auto ev = WaitGetAddr(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, 0, Section_ << ": " << ev->Get()->ErrorText); + TString result = AddrToString(ev->Get()->Addr); + UNIT_ASSERT_VALUES_EQUAL_C(result, expected, Section_); + } + }; + + Y_UNIT_TEST(UnusableResolver) { + TCachingDnsRuntime runtime; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + + runtime.Send(new IEventHandle(runtime.MockResolver, { }, new TEvents::TEvPoison), 0, true); + runtime.SendGetAddr(sender, "foo.ru", AF_UNSPEC); + runtime.ExpectGetAddrError(sender, ARES_ENOTINITIALIZED); + } + + Y_UNIT_TEST(ResolveCaching) { + TCachingDnsRuntime runtime; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + // First time resolve, ipv4 and ipv6 sent in parallel, we wait for ipv6 result + runtime.Section("First time resolve"); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.Expect6(1, 0); + runtime.Expect4(1, 0); + runtime.ExpectMisses(1); + runtime.ExpectHits(0); + + // Second resolve, ipv6 and ipv4 queries result in a cache hit + runtime.Section("Second resolve, ipv6"); + runtime.SendGetAddr(sender, "yandex.ru", AF_INET6); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.Expect6(1, 0); + runtime.ExpectHits(1); + runtime.Section("Second resolve, ipv4"); + runtime.SendGetAddr(sender, "yandex.ru", AF_INET); + runtime.ExpectGetAddrSuccess(sender, "77.88.55.77"); + runtime.Expect4(1, 0); + runtime.ExpectHits(2); + + // Wait until soft expiration and try ipv4 again + // Will cause a cache hit, but will start a new ipv4 request in background + runtime.Section("Retry ipv4 after soft expiration"); + runtime.Sleep(TDuration::Seconds(15)); + runtime.SendGetAddr(sender, "yandex.ru", AF_INET); + runtime.ExpectGetAddrSuccess(sender, "77.88.55.77"); + runtime.Expect6(1, 0); + runtime.Expect4(2, 1); + runtime.ExpectMisses(1); + runtime.ExpectHits(3); + runtime.WaitNoInFlight(); + + // Wait until soft expiration and try both again + // Will cause a cache hit, but will start a new ipv6 request in background + runtime.Section("Retry both after soft expiration"); + runtime.Sleep(TDuration::Seconds(15)); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.Expect6(2, 1); + runtime.Expect4(2, 0); + runtime.ExpectMisses(1); + runtime.ExpectHits(4); + runtime.WaitNoInFlight(); + + // Wait until hard expiration and try both again + // Will cause a cache miss and new resolve requests + runtime.Section("Retry both after hard expiration"); + runtime.Sleep(TDuration::Hours(2)); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.Expect6(3, 0); + runtime.Expect4(3, 0); + runtime.ExpectMisses(2); + runtime.ExpectHits(4); + + // Wait half the hard expiration time, must always result in a cache hit + runtime.Section("Retry both after half hard expiration"); + for (ui64 i = 1; i <= 4; ++i) { + runtime.Sleep(TDuration::Hours(1)); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.Expect6(3 + i, 1); + runtime.ExpectHits(4 + i); + runtime.WaitNoInFlight(); + } + + // Change v6 result to a timeout, must keep using cached result until hard expiration + runtime.Section("Dns keeps timing out"); + runtime.ReplyV6["yandex.ru"] = TMockReply::Error(ARES_ETIMEOUT); + for (ui64 i = 1; i <= 4; ++i) { + runtime.Sleep(TDuration::Seconds(15)); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.Expect6(7 + i, 1); + runtime.ExpectHits(8 + i); + runtime.WaitNoInFlight(); + } + + // Change v6 result to nodata, must switch to a v4 result eventually + runtime.Section("Host changes to being ipv4 only"); + runtime.ReplyV6["yandex.ru"] = TMockReply::Error(ARES_ENODATA); + runtime.Sleep(TDuration::Seconds(2)); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + runtime.WaitNoInFlight(); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "77.88.55.77"); + runtime.Expect6(12, 0); + runtime.Expect4(4, 0); + runtime.ExpectMisses(3); + + // Change v6 result to nxdomain, must not fall back to a v4 result + runtime.Section("Host is removed from dns"); + runtime.ReplyV6["yandex.ru"] = TMockReply::Error(ARES_ENOTFOUND); + runtime.Sleep(TDuration::Seconds(15)); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "77.88.55.77"); + runtime.WaitNoInFlight(); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + } + + Y_UNIT_TEST(ResolveCachingV4) { + TCachingDnsRuntime runtime; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.Section("First request"); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "192.168.0.1"); + runtime.ExpectMisses(1); + + runtime.Section("Second request"); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "192.168.0.1"); + runtime.ExpectHits(1); + + runtime.Section("Dns keeps timing out"); + runtime.ReplyV6["router.asus.com"] = TMockReply::Error(ARES_ETIMEOUT); + runtime.ReplyV4["router.asus.com"] = TMockReply::Error(ARES_ETIMEOUT); + for (ui64 i = 1; i <= 4; ++i) { + runtime.Sleep(TDuration::Seconds(15)); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "192.168.0.1"); + runtime.Expect6(1 + i, 1); + runtime.Expect4(1 + i, 1); + runtime.ExpectHits(1 + i); + runtime.WaitNoInFlight(); + } + + runtime.Section("Host is removed from ipv4 dns"); + runtime.ReplyV4["router.asus.com"] = TMockReply::Error(ARES_ENOTFOUND); + runtime.Sleep(TDuration::Seconds(15)); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetAddrSuccess(sender, "192.168.0.1"); + runtime.WaitNoInFlight(); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + } + + Y_UNIT_TEST(EventualTimeout) { + TCachingDnsRuntime runtime; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.ReplyV6["notfound.ru"] = TMockReply::Error(ARES_ENODATA); + runtime.ReplyV4["notfound.ru"] = TMockReply::Error(ARES_ENOTFOUND); + runtime.SendGetAddr(sender, "notfound.ru", AF_UNSPEC); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + + runtime.ReplyV4["notfound.ru"] = TMockReply::Error(ARES_ETIMEOUT); + runtime.SendGetAddr(sender, "notfound.ru", AF_UNSPEC); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + runtime.WaitNoInFlight(); + + bool timeout = false; + for (ui64 i = 1; i <= 8; ++i) { + runtime.Sleep(TDuration::Minutes(30)); + runtime.SendGetAddr(sender, "notfound.ru", AF_UNSPEC); + auto ev = runtime.WaitGetAddr(sender); + if (ev->Get()->Status == ARES_ETIMEOUT && i > 2) { + timeout = true; + break; + } + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, ARES_ENOTFOUND, + "Iteration " << i << ": " << ev->Get()->ErrorText); + } + + UNIT_ASSERT_C(timeout, "DnsResolver did not reply with a timeout"); + } + + Y_UNIT_TEST(MultipleRequestsAndHosts) { + TCachingDnsRuntime runtime; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.SendGetHostByName(sender, "router.asus.com", AF_UNSPEC); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.SendGetHostByName(sender, "yandex.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetHostByNameSuccess(sender, "192.168.0.1"); + runtime.ExpectGetAddrSuccess(sender, "192.168.0.1"); + runtime.ExpectGetHostByNameSuccess(sender, "2a02:6b8:a::a"); + runtime.ExpectGetAddrSuccess(sender, "2a02:6b8:a::a"); + + runtime.SendGetHostByName(sender, "notfound.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "notfound.ru", AF_UNSPEC); + runtime.ExpectGetHostByNameError(sender, ARES_ENOTFOUND); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + } + + Y_UNIT_TEST(DisabledIPv6) { + TCachingDnsRuntime runtime; + runtime.ResolverOptions.AllowIPv6 = false; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.SendGetHostByName(sender, "yandex.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetHostByNameSuccess(sender, "77.88.55.77"); + runtime.ExpectGetAddrSuccess(sender, "77.88.55.77"); + + runtime.SendGetHostByName(sender, "yandex.ru", AF_INET6); + runtime.SendGetAddr(sender, "yandex.ru", AF_INET6); + runtime.ExpectGetHostByNameError(sender, ARES_EBADFAMILY); + runtime.ExpectGetAddrError(sender, ARES_EBADFAMILY); + + runtime.SendGetHostByName(sender, "yandex.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.ExpectGetHostByNameSuccess(sender, "77.88.55.77"); + runtime.ExpectGetAddrSuccess(sender, "77.88.55.77"); + + runtime.SendGetHostByName(sender, "notfound.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "notfound.ru", AF_UNSPEC); + runtime.ExpectGetHostByNameError(sender, ARES_ENOTFOUND); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + } + + Y_UNIT_TEST(DisabledIPv4) { + TCachingDnsRuntime runtime; + runtime.ResolverOptions.AllowIPv4 = false; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.SendGetHostByName(sender, "router.asus.com", AF_UNSPEC); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetHostByNameError(sender, ARES_ENODATA); + runtime.ExpectGetAddrError(sender, ARES_ENODATA); + + runtime.SendGetHostByName(sender, "router.asus.com", AF_INET); + runtime.SendGetAddr(sender, "router.asus.com", AF_INET); + runtime.ExpectGetHostByNameError(sender, ARES_EBADFAMILY); + runtime.ExpectGetAddrError(sender, ARES_EBADFAMILY); + + runtime.SendGetHostByName(sender, "router.asus.com", AF_UNSPEC); + runtime.SendGetAddr(sender, "router.asus.com", AF_UNSPEC); + runtime.ExpectGetHostByNameError(sender, ARES_ENODATA); + runtime.ExpectGetAddrError(sender, ARES_ENODATA); + + runtime.SendGetHostByName(sender, "notfound.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "notfound.ru", AF_UNSPEC); + runtime.ExpectGetHostByNameError(sender, ARES_ENOTFOUND); + runtime.ExpectGetAddrError(sender, ARES_ENOTFOUND); + } + + Y_UNIT_TEST(PoisonPill) { + TCachingDnsRuntime runtime; + runtime.Initialize(); + runtime.Start(); + + auto sender = runtime.AllocateEdgeActor(); + + runtime.SendGetHostByName(sender, "yandex.ru", AF_UNSPEC); + runtime.SendGetAddr(sender, "yandex.ru", AF_UNSPEC); + runtime.Send(new IEventHandle(runtime.Resolver, sender, new TEvents::TEvPoison), 0, true); + runtime.ExpectGetHostByNameError(sender, ARES_ECANCELLED); + runtime.ExpectGetAddrError(sender, ARES_ECANCELLED); + } + +} diff --git a/library/cpp/actors/dnsresolver/dnsresolver_ondemand.cpp b/library/cpp/actors/dnsresolver/dnsresolver_ondemand.cpp new file mode 100644 index 00000000000..2025162e951 --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver_ondemand.cpp @@ -0,0 +1,64 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/core/hfunc.h> + +namespace NActors { +namespace NDnsResolver { + + class TOnDemandDnsResolver : public TActor<TOnDemandDnsResolver> { + public: + TOnDemandDnsResolver(TOnDemandDnsResolverOptions options) + : TActor(&TThis::StateWork) + , Options(std::move(options)) + { } + + static constexpr EActivityType ActorActivityType() { + return DNS_RESOLVER; + } + + private: + STRICT_STFUNC(StateWork, { + cFunc(TEvents::TEvPoison::EventType, PassAway); + fFunc(TEvDns::TEvGetHostByName::EventType, Forward); + fFunc(TEvDns::TEvGetAddr::EventType, Forward); + }); + + void Forward(STATEFN_SIG) { + ev->Rewrite(ev->GetTypeRewrite(), GetUpstream()); + TActivationContext::Send(std::move(ev)); + } + + private: + TActorId GetUpstream() { + if (Y_UNLIKELY(!CachingResolverId)) { + if (Y_LIKELY(!SimpleResolverId)) { + SimpleResolverId = RegisterWithSameMailbox(CreateSimpleDnsResolver(Options)); + } + CachingResolverId = RegisterWithSameMailbox(CreateCachingDnsResolver(SimpleResolverId, Options)); + } + return CachingResolverId; + } + + void PassAway() override { + if (CachingResolverId) { + Send(CachingResolverId, new TEvents::TEvPoison); + CachingResolverId = { }; + } + if (SimpleResolverId) { + Send(SimpleResolverId, new TEvents::TEvPoison); + SimpleResolverId = { }; + } + } + + private: + TOnDemandDnsResolverOptions Options; + TActorId SimpleResolverId; + TActorId CachingResolverId; + }; + + IActor* CreateOnDemandDnsResolver(TOnDemandDnsResolverOptions options) { + return new TOnDemandDnsResolver(std::move(options)); + } + +} // namespace NDnsResolver +} // namespace NActors diff --git a/library/cpp/actors/dnsresolver/dnsresolver_ondemand_ut.cpp b/library/cpp/actors/dnsresolver/dnsresolver_ondemand_ut.cpp new file mode 100644 index 00000000000..27584845524 --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver_ondemand_ut.cpp @@ -0,0 +1,24 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/testlib/test_runtime.h> +#include <library/cpp/testing/unittest/registar.h> + +using namespace NActors; +using namespace NActors::NDnsResolver; + +Y_UNIT_TEST_SUITE(OnDemandDnsResolver) { + + Y_UNIT_TEST(ResolveLocalHost) { + TTestActorRuntimeBase runtime; + runtime.Initialize(); + auto sender = runtime.AllocateEdgeActor(); + auto resolver = runtime.Register(CreateOnDemandDnsResolver()); + runtime.Send(new IEventHandle(resolver, sender, new TEvDns::TEvGetHostByName("localhost", AF_UNSPEC)), + 0, true); + auto ev = runtime.GrabEdgeEventRethrow<TEvDns::TEvGetHostByNameResult>(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, 0, ev->Get()->ErrorText); + size_t addrs = ev->Get()->AddrsV4.size() + ev->Get()->AddrsV6.size(); + UNIT_ASSERT_C(addrs > 0, "Got " << addrs << " addresses"); + } + +} diff --git a/library/cpp/actors/dnsresolver/dnsresolver_ut.cpp b/library/cpp/actors/dnsresolver/dnsresolver_ut.cpp new file mode 100644 index 00000000000..0c343a805ce --- /dev/null +++ b/library/cpp/actors/dnsresolver/dnsresolver_ut.cpp @@ -0,0 +1,98 @@ +#include "dnsresolver.h" + +#include <library/cpp/actors/testlib/test_runtime.h> +#include <library/cpp/testing/unittest/registar.h> +#include <util/string/builder.h> + +#include <ares.h> + +using namespace NActors; +using namespace NActors::NDnsResolver; + +Y_UNIT_TEST_SUITE(DnsResolver) { + + struct TSilentUdpServer { + TInetDgramSocket Socket; + ui16 Port; + + TSilentUdpServer() { + TSockAddrInet addr("127.0.0.1", 0); + int err = Socket.Bind(&addr); + Y_VERIFY(err == 0, "Cannot bind a udp socket"); + Port = addr.GetPort(); + } + }; + + Y_UNIT_TEST(ResolveLocalHost) { + TTestActorRuntimeBase runtime; + runtime.Initialize(); + auto sender = runtime.AllocateEdgeActor(); + auto resolver = runtime.Register(CreateSimpleDnsResolver()); + runtime.Send(new IEventHandle(resolver, sender, new TEvDns::TEvGetHostByName("localhost", AF_UNSPEC)), + 0, true); + auto ev = runtime.GrabEdgeEventRethrow<TEvDns::TEvGetHostByNameResult>(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, 0, ev->Get()->ErrorText); + size_t addrs = ev->Get()->AddrsV4.size() + ev->Get()->AddrsV6.size(); + UNIT_ASSERT_C(addrs > 0, "Got " << addrs << " addresses"); + } + + Y_UNIT_TEST(ResolveYandexRu) { + TTestActorRuntimeBase runtime; + runtime.Initialize(); + auto sender = runtime.AllocateEdgeActor(); + auto resolver = runtime.Register(CreateSimpleDnsResolver()); + runtime.Send(new IEventHandle(resolver, sender, new TEvDns::TEvGetHostByName("yandex.ru", AF_UNSPEC)), + 0, true); + auto ev = runtime.GrabEdgeEventRethrow<TEvDns::TEvGetHostByNameResult>(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, 0, ev->Get()->ErrorText); + size_t addrs = ev->Get()->AddrsV4.size() + ev->Get()->AddrsV6.size(); + UNIT_ASSERT_C(addrs > 0, "Got " << addrs << " addresses"); + } + + Y_UNIT_TEST(GetAddrYandexRu) { + TTestActorRuntimeBase runtime; + runtime.Initialize(); + auto sender = runtime.AllocateEdgeActor(); + auto resolver = runtime.Register(CreateSimpleDnsResolver()); + + runtime.Send(new IEventHandle(resolver, sender, new TEvDns::TEvGetAddr("yandex.ru", AF_UNSPEC)), + 0, true); + auto ev = runtime.GrabEdgeEventRethrow<TEvDns::TEvGetAddrResult>(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, 0, ev->Get()->ErrorText); + UNIT_ASSERT_C(ev->Get()->IsV4() || ev->Get()->IsV6(), "Expect v4 or v6 address"); + } + + Y_UNIT_TEST(ResolveTimeout) { + TSilentUdpServer server; + TTestActorRuntimeBase runtime; + runtime.Initialize(); + auto sender = runtime.AllocateEdgeActor(); + TSimpleDnsResolverOptions options; + options.Timeout = TDuration::MilliSeconds(250); + options.Attempts = 2; + options.Servers.emplace_back(TStringBuilder() << "127.0.0.1:" << server.Port); + auto resolver = runtime.Register(CreateSimpleDnsResolver(options)); + runtime.Send(new IEventHandle(resolver, sender, new TEvDns::TEvGetHostByName("timeout.yandex.ru", AF_INET)), + 0, true); + auto ev = runtime.GrabEdgeEventRethrow<TEvDns::TEvGetHostByNameResult>(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, ARES_ETIMEOUT, ev->Get()->ErrorText); + } + + Y_UNIT_TEST(ResolveGracefulStop) { + TSilentUdpServer server; + TTestActorRuntimeBase runtime; + runtime.Initialize(); + auto sender = runtime.AllocateEdgeActor(); + TSimpleDnsResolverOptions options; + options.Timeout = TDuration::Seconds(5); + options.Attempts = 5; + options.Servers.emplace_back(TStringBuilder() << "127.0.0.1:" << server.Port); + auto resolver = runtime.Register(CreateSimpleDnsResolver(options)); + runtime.Send(new IEventHandle(resolver, sender, new TEvDns::TEvGetHostByName("timeout.yandex.ru", AF_INET)), + 0, true); + runtime.Send(new IEventHandle(resolver, sender, new TEvents::TEvPoison), 0, true); + auto ev = runtime.GrabEdgeEventRethrow<TEvDns::TEvGetHostByNameResult>(sender); + UNIT_ASSERT_VALUES_EQUAL_C(ev->Get()->Status, ARES_ECANCELLED, ev->Get()->ErrorText); + } + +} diff --git a/library/cpp/actors/dnsresolver/ut/ya.make b/library/cpp/actors/dnsresolver/ut/ya.make new file mode 100644 index 00000000000..ad936bdacd6 --- /dev/null +++ b/library/cpp/actors/dnsresolver/ut/ya.make @@ -0,0 +1,20 @@ +UNITTEST_FOR(library/cpp/actors/dnsresolver) + +OWNER(g:kikimr) + +PEERDIR( + library/cpp/actors/testlib +) + +SRCS( + dnsresolver_caching_ut.cpp + dnsresolver_ondemand_ut.cpp + dnsresolver_ut.cpp +) + +ADDINCL(contrib/libs/c-ares) + +TAG(ya:external) +REQUIREMENTS(network:full) + +END() diff --git a/library/cpp/actors/dnsresolver/ya.make b/library/cpp/actors/dnsresolver/ya.make new file mode 100644 index 00000000000..329c56c5b3a --- /dev/null +++ b/library/cpp/actors/dnsresolver/ya.make @@ -0,0 +1,20 @@ +LIBRARY() + +OWNER(g:kikimr) + +SRCS( + dnsresolver.cpp + dnsresolver_caching.cpp + dnsresolver_ondemand.cpp +) + +PEERDIR( + library/cpp/actors/core + contrib/libs/c-ares +) + +ADDINCL(contrib/libs/c-ares) + +END() + +RECURSE_FOR_TESTS(ut) |