diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /library/cpp/ipmath | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'library/cpp/ipmath')
-rw-r--r-- | library/cpp/ipmath/ipmath.cpp | 357 | ||||
-rw-r--r-- | library/cpp/ipmath/ipmath.h | 160 | ||||
-rw-r--r-- | library/cpp/ipmath/ipmath_ut.cpp | 507 | ||||
-rw-r--r-- | library/cpp/ipmath/range_set.cpp | 99 | ||||
-rw-r--r-- | library/cpp/ipmath/range_set.h | 66 | ||||
-rw-r--r-- | library/cpp/ipmath/ut/ya.make | 16 | ||||
-rw-r--r-- | library/cpp/ipmath/ya.make | 17 |
7 files changed, 1222 insertions, 0 deletions
diff --git a/library/cpp/ipmath/ipmath.cpp b/library/cpp/ipmath/ipmath.cpp new file mode 100644 index 0000000000..b8cca00c80 --- /dev/null +++ b/library/cpp/ipmath/ipmath.cpp @@ -0,0 +1,357 @@ +#include "ipmath.h" + +namespace { + constexpr auto IPV4_BITS = 32; + constexpr auto IPV6_BITS = 128; + + const ui128 MAX_IPV4_ADDR = Max<ui32>(); + const ui128 MAX_IPV6_ADDR = Max<ui128>(); + + TStringBuf TypeToString(TIpv6Address::TIpType type) { + switch (type) { + case TIpv6Address::Ipv4: + return TStringBuf("IPv4"); + case TIpv6Address::Ipv6: + return TStringBuf("IPv6"); + default: + return TStringBuf("UNKNOWN"); + } + } + + size_t MaxPrefixLenForType(TIpv6Address::TIpType type) { + switch (type) { + case TIpv6Address::Ipv4: + return IPV4_BITS; + case TIpv6Address::Ipv6: + return IPV6_BITS; + case TIpv6Address::LAST: + ythrow yexception() << "invalid type"; + } + } + + template <ui8 ADDR_LEN> + ui128 LowerBoundForPrefix(ui128 value, ui8 prefixLen) { + const int shift = ADDR_LEN - prefixLen; + const ui128 shifted = (shift < 128) ? (ui128{1} << shift) : 0; + ui128 mask = ~(shifted - 1); + return value & mask; + } + + template <ui8 ADDR_LEN> + ui128 UpperBoundForPrefix(ui128 value, ui8 prefixLen) { + const int shift = ADDR_LEN - prefixLen; + const ui128 shifted = (shift < 128) ? (ui128{1} << shift) : 0; + ui128 mask = shifted - 1; + return value | mask; + } + + auto LowerBoundForPrefix4 = LowerBoundForPrefix<IPV4_BITS>; + auto LowerBoundForPrefix6 = LowerBoundForPrefix<IPV6_BITS>; + auto UpperBoundForPrefix4 = UpperBoundForPrefix<IPV4_BITS>; + auto UpperBoundForPrefix6 = UpperBoundForPrefix<IPV6_BITS>; + + TIpv6Address IpFromStringSafe(const TString& s) { + bool ok{}; + auto addr = TIpv6Address::FromString(s, ok); + Y_ENSURE(ok, "Failed to parse an IP address from " << s); + return addr; + } + + /// it's different from TIpv6Address::IsValid for 0.0.0.0 + bool IsValid(TIpv6Address addr) { + switch (addr.Type()) { + case TIpv6Address::Ipv4: + case TIpv6Address::Ipv6: + return true; + + case TIpv6Address::LAST: + return false; + } + } + + bool HasNext(TIpv6Address addr) { + switch (addr.Type()) { + case TIpv6Address::Ipv4: + return ui128(addr) != MAX_IPV4_ADDR; + case TIpv6Address::Ipv6: + return ui128(addr) != MAX_IPV6_ADDR; + case TIpv6Address::LAST: + return false; + } + } + + TIpv6Address Next(TIpv6Address addr) { + return {ui128(addr) + 1, addr.Type()}; + } +} // namespace + +TIpv6Address LowerBoundForPrefix(TIpv6Address value, ui8 prefixLen) { + auto type = value.Type(); + switch (type) { + case TIpv6Address::Ipv4: + return {LowerBoundForPrefix4(value, prefixLen), type}; + case TIpv6Address::Ipv6: + return {LowerBoundForPrefix6(value, prefixLen), type}; + default: + ythrow yexception() << "invalid type"; + } +} + +TIpv6Address UpperBoundForPrefix(TIpv6Address value, ui8 prefixLen) { + auto type = value.Type(); + switch (type) { + case TIpv6Address::Ipv4: + return {UpperBoundForPrefix4(value, prefixLen), type}; + case TIpv6Address::Ipv6: + return {UpperBoundForPrefix6(value, prefixLen), type}; + default: + ythrow yexception() << "invalid type"; + } +} + +TIpAddressRange::TIpAddressRangeBuilder::operator TIpAddressRange() { + return Build(); +} + +TIpAddressRange TIpAddressRange::TIpAddressRangeBuilder::Build() { + return TIpAddressRange{Start_, End_}; +} + +TIpAddressRange::TIpAddressRangeBuilder::TIpAddressRangeBuilder(const TString& from) + : TIpAddressRangeBuilder{IpFromStringSafe(from)} +{ +} + +TIpAddressRange::TIpAddressRangeBuilder::TIpAddressRangeBuilder(TIpv6Address from) { + Y_ENSURE_EX(IsValid(from), TInvalidIpRangeException() << "Address " << from.ToString() << " is invalid"); + Start_ = from; + End_ = Start_; +} + +TIpAddressRange::TIpAddressRangeBuilder& TIpAddressRange::TIpAddressRangeBuilder::To(const TString& to) { + End_ = IpFromStringSafe(to); + return *this; +} + +TIpAddressRange::TIpAddressRangeBuilder& TIpAddressRange::TIpAddressRangeBuilder::To(TIpv6Address to) { + Y_ENSURE_EX(IsValid(to), TInvalidIpRangeException() << "Address " << to.ToString() << " is invalid"); + Start_ = to; + return *this; +} + +TIpAddressRange::TIpAddressRangeBuilder& TIpAddressRange::TIpAddressRangeBuilder::WithPrefix(ui8 len) { + Y_ENSURE_EX(IsValid(Start_), TInvalidIpRangeException() << "Start value must be set before prefix"); + const auto type = Start_.Type(); + const auto maxLen = MaxPrefixLenForType(type); + Y_ENSURE_EX(len <= maxLen, TInvalidIpRangeException() << "Maximum prefix length for this address type is " + << maxLen << ", but requested " << (ui32)len); + + const auto lowerBound = LowerBoundForPrefix(Start_, len); + Y_ENSURE_EX(Start_ == lowerBound, TInvalidIpRangeException() << "Cannot create IP range from start address " + << Start_ << " with prefix length " << (ui32)len); + + End_ = UpperBoundForPrefix(Start_, len); + + return *this; +} + +void TIpAddressRange::Init(TIpv6Address from, TIpv6Address to) { + Start_ = from; + End_ = to; + + Y_ENSURE_EX(Start_ <= End_, TInvalidIpRangeException() << "Invalid IP address range: from " << Start_ << " to " << End_); + Y_ENSURE_EX(Start_.Type() == End_.Type(), TInvalidIpRangeException() + << "Address type mismtach: start address type is " << TypeToString(Start_.Type()) + << " end type is " << TypeToString(End_.Type())); +} + +TIpAddressRange::TIpAddressRange(TIpv6Address start, TIpv6Address end) { + Y_ENSURE_EX(IsValid(start), TInvalidIpRangeException() << "start address " << start.ToString() << " is invalid"); + Y_ENSURE_EX(IsValid(end), TInvalidIpRangeException() << "end address " << end.ToString() << " is invalid"); + Init(start, end); +} + +TIpAddressRange::TIpAddressRange(const TString& start, const TString& end) { + auto startAddr = IpFromStringSafe(start); + auto endAddr = IpFromStringSafe(end); + Init(startAddr, endAddr); +} + +TIpAddressRange::~TIpAddressRange() { +} + +TIpAddressRange::TIpType TIpAddressRange::Type() const { + return Start_.Type(); +} + +ui128 TIpAddressRange::Size() const { + return ui128(End_) - ui128(Start_) + 1; +} + +bool TIpAddressRange::IsSingle() const { + return Start_ == End_; +} + +bool TIpAddressRange::Contains(const TIpAddressRange& other) const { + return Start_ <= other.Start_ && End_ >= other.End_; +} + +bool TIpAddressRange::Contains(const TIpv6Address& addr) const { + return Start_ <= addr && End_ >= addr; +} + +bool TIpAddressRange::Overlaps(const TIpAddressRange& other) const { + return Start_ <= other.End_ && other.Start_ <= End_; +} + +bool TIpAddressRange::IsConsecutive(const TIpAddressRange& other) const { + return (HasNext(End_) && Next(End_) == other.Start_) + || (HasNext(other.End_) && Next(other.End_) == Start_); +} + +TIpAddressRange TIpAddressRange::Union(const TIpAddressRange& other) const { + Y_ENSURE(IsConsecutive(other) || Overlaps(other), "Can merge only consecutive or overlapping ranges"); + Y_ENSURE(other.Start_.Type() == Start_.Type(), "Cannot merge ranges of addresses of different types"); + + auto s = Start_; + auto e = End_; + + s = {Min<ui128>(Start_, other.Start_), Start_.Type()}; + e = {Max<ui128>(End_, other.End_), End_.Type()}; + + return {s, e}; +} + +TIpAddressRange TIpAddressRange::FromCidrString(const TString& str) { + if (auto result = TryFromCidrString(str)) { + return *result; + } + + ythrow TInvalidIpRangeException() << "Cannot parse " << str << " as a CIDR string"; +} + +TMaybe<TIpAddressRange> TIpAddressRange::TryFromCidrString(const TString& str) { + auto idx = str.rfind('/'); + if (idx == TString::npos) { + return Nothing(); + } + + TStringBuf sb{str}; + TStringBuf address, prefix; + sb.SplitAt(idx, address, prefix); + prefix.Skip(1); + + ui8 prefixLen{}; + if (!::TryFromString(prefix, prefixLen)) { + return Nothing(); + } + + return TIpAddressRange::From(TString{address}) + .WithPrefix(prefixLen); +} + +TIpAddressRange TIpAddressRange::FromRangeString(const TString& str) { + if (auto result = TryFromRangeString(str)) { + return *result; + } + + ythrow TInvalidIpRangeException() << "Cannot parse " << str << " as a range string"; +} + +TMaybe<TIpAddressRange> TIpAddressRange::TryFromRangeString(const TString& str) { + auto idx = str.find('-'); + if (idx == TString::npos) { + return Nothing(); + } + + TStringBuf sb{str}; + TStringBuf start, end; + sb.SplitAt(idx, start, end); + end.Skip(1); + + return TIpAddressRange::From(TString{start}).To(TString{end}); +} + +TIpAddressRange TIpAddressRange::FromString(const TString& str) { + if (auto result = TryFromString(str)) { + return *result; + } + + ythrow TInvalidIpRangeException() << "Cannot parse an IP address from " << str; +} + +TMaybe<TIpAddressRange> TIpAddressRange::TryFromString(const TString& str) { + if (auto idx = str.find('/'); idx != TString::npos) { + return TryFromCidrString(str); + } else if (idx = str.find('-'); idx != TString::npos) { + return TryFromRangeString(str); + } else { + bool ok{}; + auto addr = TIpv6Address::FromString(str, ok); + if (!ok) { + return Nothing(); + } + + return TIpAddressRange::From(addr); + } +} + +TString TIpAddressRange::ToRangeString() const { + bool ok{}; + return TStringBuilder() << Start_.ToString(ok) << "-" << End_.ToString(ok); +} + +TIpAddressRange::TIterator TIpAddressRange::begin() const { + return Begin(); +} + +TIpAddressRange::TIterator TIpAddressRange::Begin() const { + return TIpAddressRange::TIterator{Start_}; +} + +TIpAddressRange::TIterator TIpAddressRange::end() const { + return End(); +} + +TIpAddressRange::TIterator TIpAddressRange::End() const { + return TIpAddressRange::TIterator{{ui128(End_) + 1, End_.Type()}}; +} + +TIpAddressRange::TIpAddressRangeBuilder TIpAddressRange::From(TIpv6Address from) { + return TIpAddressRangeBuilder{from}; +}; + +TIpAddressRange::TIpAddressRangeBuilder TIpAddressRange::From(const TString& from) { + return TIpAddressRangeBuilder{from}; +}; + +bool operator==(const TIpAddressRange& lhs, const TIpAddressRange& rhs) { + return lhs.Start_ == rhs.Start_ && lhs.End_ == rhs.End_; +} + +bool operator!=(const TIpAddressRange& lhs, const TIpAddressRange& rhs) { + return !(lhs == rhs); +} + +TIpAddressRange::TIterator::TIterator(TIpv6Address val) noexcept + : Current_{val} +{ +} + +bool TIpAddressRange::TIterator::operator==(const TIpAddressRange::TIterator& other) noexcept { + return Current_ == other.Current_; +} + +bool TIpAddressRange::TIterator::operator!=(const TIpAddressRange::TIterator& other) noexcept { + return !(*this == other); +} + +TIpAddressRange::TIterator& TIpAddressRange::TIterator::operator++() noexcept { + ui128 numeric = Current_; + Current_ = {numeric + 1, Current_.Type()}; + return *this; +} + +const TIpv6Address& TIpAddressRange::TIterator::operator*() noexcept { + return Current_; +} diff --git a/library/cpp/ipmath/ipmath.h b/library/cpp/ipmath/ipmath.h new file mode 100644 index 0000000000..b6df5416f8 --- /dev/null +++ b/library/cpp/ipmath/ipmath.h @@ -0,0 +1,160 @@ +#pragma once + +#include <library/cpp/ipv6_address/ipv6_address.h> + +#include <util/generic/maybe.h> +#include <util/ysaveload.h> + +struct TInvalidIpRangeException: public virtual yexception { +}; + +class TIpAddressRange { + friend bool operator==(const TIpAddressRange& lhs, const TIpAddressRange& rhs); + friend bool operator!=(const TIpAddressRange& lhs, const TIpAddressRange& rhs); + + class TIpAddressRangeBuilder; +public: + class TIterator; + using TIpType = TIpv6Address::TIpType; + + TIpAddressRange() = default; + TIpAddressRange(TIpv6Address start, TIpv6Address end); + TIpAddressRange(const TString& start, const TString& end); + ~TIpAddressRange(); + + static TIpAddressRangeBuilder From(TIpv6Address from); + static TIpAddressRangeBuilder From(const TString& from); + + /** + * Parses a string tormatted in Classless Iter-Domain Routing (CIDR) notation. + * @param str a CIDR-formatted string, e.g. "192.168.0.0/16" + * @return a new TIpAddressRange + * @throws TInvalidIpRangeException if the string cannot be parsed. + */ + static TIpAddressRange FromCidrString(const TString& str); + static TMaybe<TIpAddressRange> TryFromCidrString(const TString& str); + + /** + * Parses a string formatted as two dash-separated addresses. + * @param str a CIDR-formatted string, e.g. "192.168.0.0-192.168.0.2" + * @return a new TIpAddressRange + * @throws TInvalidIpRangeException if the string cannot be parsed. + */ + static TIpAddressRange FromRangeString(const TString& str); + static TMaybe<TIpAddressRange> TryFromRangeString(const TString& str); + + TString ToRangeString() const; + + /** + * Tries to guess the format and parse it. Format must be one of: CIDR ("10.0.0.0/24"), range ("10.0.0.0-10.0.0.10") or a single address. + * @return a new TIpAddressRange + * @throws TInvlidIpRangeException if the string doesn't match any known format or if parsing failed. + */ + static TIpAddressRange FromString(const TString& str); + static TMaybe<TIpAddressRange> TryFromString(const TString& str); + + TIpType Type() const; + + // XXX: uint128 cannot hold size of the complete range of IPv6 addresses. Use IsComplete to determine whether this is the case. + ui128 Size() const; + + /** + * Determines whether this range contains only one address. + * @return true if contains only one address, otherwise false. + */ + bool IsSingle() const; + bool IsComplete() const; + + bool Contains(const TIpAddressRange& other) const; + bool Contains(const TIpv6Address& addr) const; + + bool Overlaps(const TIpAddressRange& other) const; + + /** + * Determines whether two ranges follow one after another without any gaps. + * @return true if either this range follows the given one or vice versa, otherwise false. + */ + bool IsConsecutive(const TIpAddressRange& other) const; + + /** + * Concatenates another range into this one. + * Note, that ranges must be either consecutive or overlapping. + * @throws yexception if ranges are neither consecutive nor overlapping. + */ + TIpAddressRange Union(const TIpAddressRange& other) const; + + template <typename TFunction> + void ForEach(TFunction func); + + // for-each compliance interface + TIterator begin() const; + TIterator end() const; + + // Arcadia style-guide friendly + TIterator Begin() const; + TIterator End() const; + + Y_SAVELOAD_DEFINE(Start_, End_); + +private: + void Init(TIpv6Address, TIpv6Address); + + TIpv6Address Start_; + TIpv6Address End_; +}; + +bool operator==(const TIpAddressRange& lhs, const TIpAddressRange& rhs); +bool operator!=(const TIpAddressRange& lhs, const TIpAddressRange& rhs); + +TIpv6Address LowerBoundForPrefix(TIpv6Address value, ui8 prefixLen); +TIpv6Address UpperBoundForPrefix(TIpv6Address value, ui8 prefixLen); + + +class TIpAddressRange::TIpAddressRangeBuilder { + friend class TIpAddressRange; + TIpAddressRangeBuilder() = default; + TIpAddressRangeBuilder(TIpv6Address from); + TIpAddressRangeBuilder(const TString& from); + TIpAddressRangeBuilder(const TIpAddressRangeBuilder&) = default; + TIpAddressRangeBuilder& operator=(const TIpAddressRangeBuilder&) = default; + + TIpAddressRangeBuilder(TIpAddressRangeBuilder&&) = default; + TIpAddressRangeBuilder& operator=(TIpAddressRangeBuilder&&) = default; + +public: + operator TIpAddressRange(); + TIpAddressRange Build(); + + TIpAddressRangeBuilder& To(const TString&); + TIpAddressRangeBuilder& To(TIpv6Address); + + TIpAddressRangeBuilder& WithPrefix(ui8 len); + +private: + TIpv6Address Start_; + TIpv6Address End_; +}; + + +class TIpAddressRange::TIterator { +public: + TIterator(TIpv6Address val) noexcept; + + bool operator==(const TIpAddressRange::TIterator& other) noexcept; + bool operator!=(const TIpAddressRange::TIterator& other) noexcept; + + TIterator& operator++() noexcept; + const TIpv6Address& operator*() noexcept; + +private: + TIpv6Address Current_; +}; + + +template <typename TFunction> +void TIpAddressRange::ForEach(TFunction func) { + static_assert(std::is_invocable<TFunction, TIpv6Address>::value, "function must take single address argument"); + for (auto addr : *this) { + func(addr); + } +} diff --git a/library/cpp/ipmath/ipmath_ut.cpp b/library/cpp/ipmath/ipmath_ut.cpp new file mode 100644 index 0000000000..5fe459ecc8 --- /dev/null +++ b/library/cpp/ipmath/ipmath_ut.cpp @@ -0,0 +1,507 @@ +#include "ipmath.h" +#include "range_set.h" + +#include <library/cpp/testing/unittest/registar.h> +#include <library/cpp/testing/gmock_in_unittest/gmock.h> + +#include <library/cpp/ipv6_address/ipv6_address.h> + +using namespace testing; + +static constexpr auto MIN_IPV6_ADDR = "::"; +static constexpr auto MAX_IPV6_ADDR = "ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"; + +std::ostream& operator<<(std::ostream& os, const TIpAddressRange& r) { + auto s = r.ToRangeString(); + os.write(s.data(), s.size()); + return os; +} + +std::ostream& operator<<(std::ostream& os, const TIpRangeSet& rangeSet) { + os << "{\n"; + + for (auto&& r : rangeSet) { + os << r << '\n'; + } + + os << "}\n"; + return os; +} + +class TIpRangeTests: public TTestBase { +public: + UNIT_TEST_SUITE(TIpRangeTests); + UNIT_TEST(IpRangeFromIpv4); + UNIT_TEST(IpRangeFromIpv6); + UNIT_TEST(FullIpRange); + UNIT_TEST(IpRangeFromCidr); + UNIT_TEST(IpRangeFromIpv4Builder); + UNIT_TEST(IpRangeFromInvalidIpv4); + UNIT_TEST(IpRangeFromInvalidIpv6); + UNIT_TEST(RangeFromSingleAddress); + UNIT_TEST(RangeFromRangeString); + UNIT_TEST(ManualIteration); + UNIT_TEST(RangeRelations); + UNIT_TEST(RangeUnion); + UNIT_TEST_SUITE_END(); + + void RangeFromSingleAddress() { + for (auto addrStr : {"192.168.0.1", "::2"}) { + auto range = TIpAddressRange::From(addrStr).Build(); + ASSERT_THAT(range.Size(), Eq(1)); + ASSERT_TRUE(range.IsSingle()); + + auto range2 = TIpAddressRange{addrStr, addrStr}; + ASSERT_THAT(range2, Eq(range)); + + TVector<ui128> result; + range.ForEach([&result] (TIpv6Address addr) { + result.push_back(addr); + }); + + bool ok{}; + ASSERT_THAT(result, ElementsAre(TIpv6Address::FromString(addrStr, ok))); + } + + } + + void IpRangeFromIpv4() { + bool ok{}; + + auto s = TIpv6Address::FromString("192.168.0.0", ok); + ASSERT_TRUE(ok); + auto e = TIpv6Address::FromString("192.168.0.255", ok); + ASSERT_TRUE(ok); + + TIpAddressRange range{s, e}; + + ASSERT_THAT(range.Size(), Eq(256)); + ASSERT_THAT(range.Type(), Eq(TIpAddressRange::TIpType::Ipv4)); + + TIpAddressRange range2{"192.168.0.0", "192.168.0.255"}; + ASSERT_THAT(range2.Size(), Eq(256)); + ASSERT_THAT(range2, Eq(range)); + } + + void IpRangeFromIpv6() { + bool ok{}; + + auto s = TIpv6Address::FromString("ffce:abcd::", ok); + ASSERT_TRUE(ok); + auto e = TIpv6Address::FromString("ffce:abcd::00ff", ok); + ASSERT_TRUE(ok); + + TIpAddressRange range{s, e}; + + ASSERT_THAT(range.Size(), Eq(256)); + + TIpAddressRange range2{"ffce:abcd::", "ffce:abcd::00ff"}; + ASSERT_THAT(range2.Size(), Eq(256)); + ASSERT_THAT(range.Type(), Eq(TIpAddressRange::TIpType::Ipv6)); + ASSERT_THAT(range2, Eq(range)); + } + + + void FullIpRange() { + auto check = [] (auto start, auto end, ui128 expectedSize) { + auto range = TIpAddressRange::From(start).To(end).Build(); + ASSERT_THAT(range.Size(), Eq(expectedSize)); + }; + + check("0.0.0.0", "255.255.255.255", ui128(Max<ui32>()) + 1); + // XXX + // check(MIN_IPV6_ADDR, MAX_IPV6_ADDR, ui128(Max<ui128>() + 1)); + } + + void IpRangeFromCidr() { + auto range = TIpAddressRange::FromCidrString("10.0.0.0/30"); + + ASSERT_THAT(range.Size(), Eq(4)); + TVector<TIpv6Address> result; + Copy(range.begin(), range.end(), std::back_inserter(result)); + + bool ok; + TVector<TIpv6Address> expected { + TIpv6Address::FromString("10.0.0.0", ok), + TIpv6Address::FromString("10.0.0.1", ok), + TIpv6Address::FromString("10.0.0.2", ok), + TIpv6Address::FromString("10.0.0.3", ok), + }; + + ASSERT_THAT(result, ElementsAreArray(expected)); + + // single host + ASSERT_THAT(TIpAddressRange::FromCidrString("ffce:abcd::/128"), Eq(TIpAddressRange::From("ffce:abcd::").Build())); + ASSERT_THAT(TIpAddressRange::FromCidrString("192.168.0.1/32"), Eq(TIpAddressRange::From("192.168.0.1").Build())); + + // full range + ASSERT_THAT(TIpAddressRange::FromCidrString("::/0"), Eq(TIpAddressRange::From(MIN_IPV6_ADDR).To(MAX_IPV6_ADDR).Build())); + ASSERT_THAT(TIpAddressRange::FromCidrString("0.0.0.0/0"), Eq(TIpAddressRange::From("0.0.0.0").To("255.255.255.255").Build())); + + // illformed + ASSERT_THROW(TIpAddressRange::FromCidrString("::/"), TInvalidIpRangeException); + ASSERT_THROW(TIpAddressRange::FromCidrString("::"), TInvalidIpRangeException); + ASSERT_THROW(TIpAddressRange::FromCidrString("/::"), TInvalidIpRangeException); + ASSERT_THROW(TIpAddressRange::FromCidrString("::/150"), TInvalidIpRangeException); + } + + void RangeFromRangeString() { + { + auto range = TIpAddressRange::FromRangeString("10.0.0.0-10.0.0.3"); + + TVector<TIpv6Address> result; + Copy(range.begin(), range.end(), std::back_inserter(result)); + + bool ok; + TVector<TIpv6Address> expected { + TIpv6Address::FromString("10.0.0.0", ok), + TIpv6Address::FromString("10.0.0.1", ok), + TIpv6Address::FromString("10.0.0.2", ok), + TIpv6Address::FromString("10.0.0.3", ok), + }; + + ASSERT_THAT(result, ElementsAreArray(expected)); + } + { + auto range = TIpAddressRange::FromRangeString("10.0.0.0-10.0.0.3"); + + TVector<TIpv6Address> result; + Copy(range.begin(), range.end(), std::back_inserter(result)); + + bool ok; + TVector<TIpv6Address> expected { + TIpv6Address::FromString("10.0.0.0", ok), + TIpv6Address::FromString("10.0.0.1", ok), + TIpv6Address::FromString("10.0.0.2", ok), + TIpv6Address::FromString("10.0.0.3", ok), + }; + + ASSERT_THAT(result, ElementsAreArray(expected)); + } + } + + void IpRangeFromIpv4Builder() { + auto range = TIpAddressRange::From("192.168.0.0") + .To("192.168.0.255") + .Build(); + + ASSERT_THAT(range.Size(), Eq(256)); + } + + void IpRangeFromInvalidIpv4() { + auto build = [] (auto from, auto to) { + return TIpAddressRange::From(from).To(to).Build(); + }; + + ASSERT_THROW(build("192.168.0.255", "192.168.0.0"), yexception); + ASSERT_THROW(build("192.168.0.0", "192.168.0.300"), yexception); + ASSERT_THROW(build("192.168.0.300", "192.168.0.330"), yexception); + ASSERT_THROW(build("192.168.0.0", "::1"), yexception); + ASSERT_THROW(build(TIpv6Address{}, TIpv6Address{}), yexception); + } + + void IpRangeFromInvalidIpv6() { + auto build = [] (auto from, auto to) { + return TIpAddressRange::From(from).To(to).Build(); + }; + + ASSERT_THROW(build("ffce:abcd::00ff", "ffce:abcd::"), yexception); + ASSERT_THROW(build("ffce:abcd::", "ffce:abcd::fffff"), yexception); + ASSERT_THROW(build("ffce:abcd::10000", "ffce:abcd::ffff"), yexception); + ASSERT_THROW(build("ffce:abcd::", TIpv6Address{}), yexception); + + auto ctor = [] (auto s, auto e) { + return TIpAddressRange{s, e}; + }; + + ASSERT_THROW(ctor(TIpv6Address{}, TIpv6Address{}), yexception); + ASSERT_THROW(ctor("", ""), yexception); + } + + void ManualIteration() { + { + TIpAddressRange range{"::", "::"}; + auto it = range.Begin(); + ++it; + bool ok; + ASSERT_THAT(*it, Eq(TIpv6Address::FromString("::1", ok))); + + for (auto i = 0; i < 254; ++i, ++it) { + } + + ASSERT_THAT(*it, Eq(TIpv6Address::FromString("::ff", ok))); + } + + { + TIpAddressRange range{"0.0.0.0", "0.0.0.0"}; + auto it = range.Begin(); + ++it; + bool ok; + ASSERT_THAT(*it, Eq(TIpv6Address::FromString("0.0.0.1", ok))); + + for (auto i = 0; i < 254; ++i, ++it) { + } + + ASSERT_THAT(*it, Eq(TIpv6Address::FromString("0.0.0.255", ok))); + } + } + + void RangeRelations() { + { + auto range = TIpAddressRange::From(MIN_IPV6_ADDR) + .To(MAX_IPV6_ADDR) + .Build(); + + ASSERT_TRUE(range.Overlaps(range)); + ASSERT_TRUE(range.Contains(range)); + // XXX + //ASSERT_FALSE(range.IsConsecutive(range)); + } + { + auto range = TIpAddressRange::From("0.0.0.1").To("0.0.0.4").Build(); + auto range0 = TIpAddressRange::From("0.0.0.0").Build(); + auto range1 = TIpAddressRange::From("0.0.0.1").Build(); + auto range2 = TIpAddressRange::From("0.0.0.5").Build(); + auto range4 = TIpAddressRange::From("0.0.0.4").Build(); + + ASSERT_FALSE(range.Overlaps(range0)); + ASSERT_TRUE(range.IsConsecutive(range0)); + ASSERT_FALSE(range.Contains(range0)); + + ASSERT_TRUE(range.Overlaps(range1)); + ASSERT_FALSE(range.IsConsecutive(range1)); + ASSERT_TRUE(range.Contains(range1)); + + ASSERT_TRUE(range.Overlaps(range4)); + ASSERT_FALSE(range.IsConsecutive(range4)); + ASSERT_TRUE(range.Contains(range4)); + } + { + auto range = TIpAddressRange::From("0.0.0.1").To("0.0.0.4").Build(); + auto range2 = TIpAddressRange::From("0.0.0.0").To("0.0.0.2").Build(); + + ASSERT_TRUE(range.Overlaps(range2)); + ASSERT_FALSE(range.IsConsecutive(range2)); + ASSERT_FALSE(range.Contains(range2)); + + bool ok; + ASSERT_TRUE(range.Contains(TIpv6Address::FromString("0.0.0.1", ok))); + ASSERT_TRUE(range.Contains(TIpv6Address::FromString("0.0.0.2", ok))); + ASSERT_FALSE(range.Contains(TIpv6Address::FromString("0.0.0.5", ok))); + } + } + + void RangeUnion() { + { + auto range = TIpAddressRange::From(MIN_IPV6_ADDR) + .To(MAX_IPV6_ADDR) + .Build(); + + ASSERT_THAT(range.Union(range), Eq(range)); + ASSERT_THAT(range.Union(TIpAddressRange::From("::")), range); + ASSERT_THAT(range.Union(TIpAddressRange::From("::1")), range); + + ASSERT_THROW(range.Union(TIpAddressRange::From("0.0.0.0")), yexception); + } + + { + auto expected = TIpAddressRange::From("0.0.0.1").To("0.0.0.10").Build(); + + auto range = TIpAddressRange{"0.0.0.1", "0.0.0.3"}.Union({"0.0.0.4", "0.0.0.10"}); + ASSERT_THAT(range, Eq(expected)); + + auto range2 = TIpAddressRange{"0.0.0.1", "0.0.0.3"}.Union({"0.0.0.2", "0.0.0.10"}); + ASSERT_THAT(range2, Eq(expected)); + + auto range3 = TIpAddressRange{"0.0.0.2", "0.0.0.3"}.Union({"0.0.0.1", "0.0.0.10"}); + ASSERT_THAT(range2, Eq(expected)); + + auto range4 = TIpAddressRange{"0.0.0.1", "0.0.0.10"}.Union({"0.0.0.2", "0.0.0.3"}); + ASSERT_THAT(range2, Eq(expected)); + + ASSERT_THROW(range.Union(TIpAddressRange::From("10.0.0.0")), yexception); + } + } +}; + +class TRangeSetTests: public TTestBase { +public: + UNIT_TEST_SUITE(TRangeSetTests); + UNIT_TEST(AddDisjoint); + UNIT_TEST(AddOverlapping); + UNIT_TEST(AddConsecutive); + UNIT_TEST(DisallowsMixingTypes); + UNIT_TEST(MembershipTest); + UNIT_TEST_SUITE_END(); + + void AddDisjoint() { + TIpRangeSet set; + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("0.0.0.2").Build(), + TIpAddressRange::From("0.0.0.4").To("255.255.255.255").Build(), + }; + + for (auto&& r : expected) { + set.Add(r); + } + + ASSERT_THAT(set, ElementsAreArray(expected)); + } + + void TestAdding(const TVector<TIpAddressRange>& toInsert, const TVector<TIpAddressRange>& expected) { + TIpRangeSet set; + { + set.Add(toInsert); + + ASSERT_THAT(set, ElementsAreArray(expected)); + } + + { + for (auto it = toInsert.rbegin(); it != toInsert.rend(); ++it) { + set.Add(*it); + } + + ASSERT_THAT(set, ElementsAreArray(expected)); + } + } + + void AddOverlapping() { + { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.2").Build(), + TIpAddressRange::From("0.0.0.2").To("0.0.0.4").Build(), + TIpAddressRange::From("0.0.0.4").To("255.255.255.255").Build(), + }; + + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("255.255.255.255").Build(), + }; + + TestAdding(toInsert, expected); + } + { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.8").To("0.0.0.10").Build(), + TIpAddressRange::From("0.0.0.7").To("0.0.0.12").Build(), + }; + + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.7").To("0.0.0.12").Build(), + }; + + TestAdding(toInsert, expected); + } + { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.7").To("0.0.0.12").Build(), + TIpAddressRange::From("0.0.0.8").To("0.0.0.10").Build(), + }; + + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.7").To("0.0.0.12").Build(), + }; + + TestAdding(toInsert, expected); + } + { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.3").To("0.0.0.10").Build(), + }; + + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("0.0.0.10").Build(), + }; + + TestAdding(toInsert, expected); + } + } + + void DisallowsMixingTypes() { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("::1").Build(), + }; + + TIpRangeSet rangeSet; + + ASSERT_THROW([&] { rangeSet.Add(toInsert); }(), yexception); + ASSERT_THROW([&] { rangeSet.Add(toInsert[1]); rangeSet.Add(toInsert[0]); }(), yexception); + } + + void AddConsecutive() { + { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.6").To("0.0.0.7").Build(), + TIpAddressRange::From("0.0.0.8").To("0.0.0.10").Build(), + }; + + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("0.0.0.10").Build(), + }; + + TestAdding(toInsert, expected); + } + { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("255.255.255.255").Build(), + TIpAddressRange::From("0.0.0.0").To("255.255.255.255").Build(), + }; + + TVector<TIpAddressRange> expected { + TIpAddressRange::From("0.0.0.0").To("255.255.255.255").Build(), + }; + + TestAdding(toInsert, expected); + } + } + + void MembershipTest() { + TVector<TIpAddressRange> toInsert { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.7").To("0.0.0.12").Build(), + TIpAddressRange::From("0.0.0.8").To("0.0.0.10").Build(), + }; + + TIpRangeSet rangeSet; + rangeSet.Add(toInsert); + + TVector<TIpAddressRange> in { + TIpAddressRange::From("0.0.0.0").To("0.0.0.5").Build(), + TIpAddressRange::From("0.0.0.7").To("0.0.0.12").Build(), + }; + + TVector<TIpAddressRange> notIn { + TIpAddressRange::From("0.0.0.6").Build(), + TIpAddressRange::From("0.0.0.13").To("0.0.0.255").Build(), + // enumerating full range is slow and makes little sense + TIpAddressRange::From("255.255.255.0").To("255.255.255.255").Build(), + }; + + for (auto&& range : in) { + for (auto&& addr : range) { + ASSERT_TRUE(rangeSet.Contains(addr)); + ASSERT_THAT(*rangeSet.Find(addr), Eq(range)); + } + } + + for (auto&& range : notIn) { + for (auto&& addr : range) { + ASSERT_FALSE(rangeSet.Contains(addr)); + ASSERT_THAT(rangeSet.Find(addr), rangeSet.End()); + } + } + + bool ok{}; + ASSERT_THAT(rangeSet.Find(TIpv6Address::FromString("::1", ok)), Eq(rangeSet.End())); + ASSERT_FALSE(rangeSet.Contains(TIpv6Address::FromString("::1", ok))); + } +}; + +UNIT_TEST_SUITE_REGISTRATION(TIpRangeTests); +UNIT_TEST_SUITE_REGISTRATION(TRangeSetTests); diff --git a/library/cpp/ipmath/range_set.cpp b/library/cpp/ipmath/range_set.cpp new file mode 100644 index 0000000000..55f42e451d --- /dev/null +++ b/library/cpp/ipmath/range_set.cpp @@ -0,0 +1,99 @@ +#include "range_set.h" + +#include <util/generic/algorithm.h> + +namespace { + bool ShouldJoin(const TIpAddressRange& lhs, const TIpAddressRange& rhs) { + return lhs.Overlaps(rhs) || lhs.IsConsecutive(rhs); + } +} + +bool TIpRangeSet::TRangeLess::operator()(const TIpAddressRange& lhs, const TIpAddressRange& rhs) const { + return *lhs.Begin() < *rhs.Begin(); +} + +TIpRangeSet::TIpRangeSet() = default; +TIpRangeSet::~TIpRangeSet() = default; + +void TIpRangeSet::Add(TIpAddressRange r) { + Y_ENSURE(IsEmpty() || r.Type() == Type(), "Mixing IPv4 and IPv6 ranges is disallowed"); + + auto lowerIt = Ranges_.lower_bound(r); + + // still may overlap the last interval in our tree + if (IsEmpty()) { + Ranges_.insert(r); + return; + } else if (lowerIt == Ranges_.end()) { + if (auto it = Ranges_.rbegin(); ShouldJoin(*it, r)) { + auto unitedRange = it->Union(r); + Ranges_.erase(--it.base()); + Ranges_.insert(unitedRange); + } else { + Ranges_.insert(r); + } + + return; + } + + + TIpAddressRange unitedRange{r}; + + auto joined = lowerIt; + if (lowerIt != Ranges_.begin()) { + if (ShouldJoin(unitedRange, *(--joined))) { + unitedRange = unitedRange.Union(*joined); + } else { + ++joined; + } + } + + auto it = lowerIt; + for (; it != Ranges_.end() && ShouldJoin(*it, unitedRange); ++it) { + unitedRange = unitedRange.Union(*it); + } + + Ranges_.erase(joined, it); + Ranges_.insert(unitedRange); +} + +TIpAddressRange::TIpType TIpRangeSet::Type() const { + return IsEmpty() + ? TIpAddressRange::TIpType::LAST + : Ranges_.begin()->Type(); +} + +bool TIpRangeSet::IsEmpty() const { + return Ranges_.empty(); +} + +TIpRangeSet::TIterator TIpRangeSet::Find(TIpv6Address addr) const { + if (IsEmpty() || addr.Type() != Type()) { + return End(); + } + + auto lowerIt = Ranges_.lower_bound(TIpAddressRange(addr, addr)); + + if (lowerIt == Ranges_.begin()) { + return lowerIt->Contains(addr) + ? lowerIt + : End(); + } else if (lowerIt == Ranges_.end()) { + auto rbegin = Ranges_.crbegin(); + return rbegin->Contains(addr) + ? (++rbegin).base() + : End(); + } else if (lowerIt->Contains(addr)) { + return lowerIt; + } + + --lowerIt; + + return lowerIt->Contains(addr) + ? lowerIt + : End(); +} + +bool TIpRangeSet::Contains(TIpv6Address addr) const { + return Find(addr) != End(); +} diff --git a/library/cpp/ipmath/range_set.h b/library/cpp/ipmath/range_set.h new file mode 100644 index 0000000000..d9e2451822 --- /dev/null +++ b/library/cpp/ipmath/range_set.h @@ -0,0 +1,66 @@ +#pragma once + +#include "ipmath.h" + +#include <util/generic/set.h> +#include <util/ysaveload.h> + + +/// @brief Maintains a disjoint set of added ranges. Allows for efficient membership queries +/// for an address in a set of IP ranges. +class TIpRangeSet { + struct TRangeLess { + bool operator()(const TIpAddressRange& lhs, const TIpAddressRange& rhs) const; + }; + + using TTree = TSet<TIpAddressRange, TRangeLess>; + +public: + using iterator = TTree::iterator; + using const_iterator = TTree::const_iterator; + using value_type = TTree::value_type; + using TIterator = TTree::iterator; + using TConstIterator = TTree::const_iterator; + + TIpRangeSet(); + ~TIpRangeSet(); + + void Add(TIpAddressRange range); + + template <typename TContainer> + void Add(TContainer&& addrs) { + using T = typename std::decay<TContainer>::type::value_type; + static_assert(std::is_convertible<T, TIpAddressRange>::value); + + for (auto&& addr : addrs) { + Add(addr); + } + } + + TIpAddressRange::TIpType Type() const; + + bool IsEmpty() const; + bool Contains(TIpv6Address addr) const; + TConstIterator Find(TIpv6Address addr) const; + + TConstIterator Begin() const { + return Ranges_.begin(); + } + + TConstIterator End() const { + return Ranges_.end(); + } + + TConstIterator begin() const { + return Begin(); + } + + TConstIterator end() const { + return End(); + } + + Y_SAVELOAD_DEFINE(Ranges_); + +private: + TTree Ranges_; +}; diff --git a/library/cpp/ipmath/ut/ya.make b/library/cpp/ipmath/ut/ya.make new file mode 100644 index 0000000000..b860cefd03 --- /dev/null +++ b/library/cpp/ipmath/ut/ya.make @@ -0,0 +1,16 @@ +UNITTEST_FOR(library/cpp/ipmath) + +OWNER( + msherbakov + g:solomon +) + +SRCS( + ipmath_ut.cpp +) + +PEERDIR( + library/cpp/testing/gmock_in_unittest +) + +END() diff --git a/library/cpp/ipmath/ya.make b/library/cpp/ipmath/ya.make new file mode 100644 index 0000000000..244838962e --- /dev/null +++ b/library/cpp/ipmath/ya.make @@ -0,0 +1,17 @@ +LIBRARY() + +OWNER( + msherbakov + g:solomon +) + +SRCS( + ipmath.cpp + range_set.cpp +) + +PEERDIR(library/cpp/ipv6_address) + +END() + +RECURSE_FOR_TESTS(ut) |