diff options
author | ionagamed <ionagamed@yandex-team.com> | 2022-10-01 11:18:40 +0300 |
---|---|---|
committer | ionagamed <ionagamed@yandex-team.com> | 2022-10-01 11:18:40 +0300 |
commit | fecbe41234ec99b7709d8d9ec617219655ac51a7 (patch) | |
tree | 04de3a8f2f402fd94185c2a5d7d9dca80a21d67c | |
parent | d479316c284f478d8bd44dd5eb57d5899b1ded4e (diff) | |
download | ydb-fecbe41234ec99b7709d8d9ec617219655ac51a7.tar.gz |
util: add .As<T> to TSharedPtr
-rw-r--r-- | util/generic/ptr.h | 19 | ||||
-rw-r--r-- | util/generic/ptr_ut.cpp | 58 |
2 files changed, 76 insertions, 1 deletions
diff --git a/util/generic/ptr.h b/util/generic/ptr.h index 3addc85753e..b35d54d780a 100644 --- a/util/generic/ptr.h +++ b/util/generic/ptr.h @@ -7,6 +7,7 @@ #include "typetraits.h" #include "singleton.h" +#include <type_traits> #include <utility> #include <util/system/compiler.h> @@ -895,6 +896,24 @@ public: return C_ ? C_->Val() : 0; } + template <class TT> + [[nodiscard]] inline TSharedPtr<TT, C, D> As() & noexcept { + static_assert(std::has_virtual_destructor<TT>(), "Type should have a virtual dtor"); + static_assert(std::is_base_of<T, TT>(), "When downcasting from T to TT, T should be a parent of TT"); + Ref(); + return TSharedPtr<TT, C, D>(dynamic_cast<TT*>(T_), C_); + } + + template <class TT> + [[nodiscard]] inline TSharedPtr<TT, C, D> As() && noexcept { + static_assert(std::has_virtual_destructor<TT>(), "Type should have a virtual dtor"); + static_assert(std::is_base_of<T, TT>(), "When downcasting from T to TT, T should be a parent of TT"); + auto resultPtr = TSharedPtr<TT, C, D>(dynamic_cast<TT*>(T_), C_); + T_ = nullptr; + C_ = nullptr; + return resultPtr; + } + #ifdef __cpp_impl_three_way_comparison template <class Other> inline bool operator==(const Other& p) const noexcept { diff --git a/util/generic/ptr_ut.cpp b/util/generic/ptr_ut.cpp index 1b4071ea07a..6e028725b4e 100644 --- a/util/generic/ptr_ut.cpp +++ b/util/generic/ptr_ut.cpp @@ -2,6 +2,7 @@ #include "vector.h" #include "noncopyable.h" +#include <library/cpp/testing/common/probe.h> #include <library/cpp/testing/unittest/registar.h> #include <util/generic/hash_set.h> @@ -32,7 +33,8 @@ class TPointerTest: public TTestBase { UNIT_TEST(TestMakeShared); UNIT_TEST(TestComparison); UNIT_TEST(TestSimpleIntrusivePtrCtorTsan); - UNIT_TEST(TestRefCountedPtrsInHashSet) + UNIT_TEST(TestRefCountedPtrsInHashSet); + UNIT_TEST(TestSharedPtrDowncast); UNIT_TEST_SUITE_END(); private: @@ -86,6 +88,7 @@ private: template <class T, class TRefCountedPtr> void TestRefCountedPtrsInHashSetImpl(); void TestRefCountedPtrsInHashSet(); + void TestSharedPtrDowncast(); }; UNIT_TEST_SUITE_REGISTRATION(TPointerTest); @@ -834,3 +837,56 @@ void TPointerTest::TestIntrusiveConstConstruction() { UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 1); } } + +class TVirtualProbe: public NTesting::TProbe { +public: + using NTesting::TProbe::TProbe; + + virtual ~TVirtualProbe() = default; +}; + +class TDerivedProbe: public TVirtualProbe { +public: + using TVirtualProbe::TVirtualProbe; +}; + +void TPointerTest::TestSharedPtrDowncast() { + { + NTesting::TProbeState probeState = {}; + + { + TSimpleSharedPtr<TVirtualProbe> base = MakeSimpleShared<TDerivedProbe>(&probeState); + UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1); + + { + auto derived = base.As<TDerivedProbe>(); + UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1); + + UNIT_ASSERT_VALUES_EQUAL(base.Get(), derived.Get()); + UNIT_ASSERT_VALUES_EQUAL(base.ReferenceCounter(), derived.ReferenceCounter()); + + UNIT_ASSERT_VALUES_EQUAL(base.RefCount(), 2l); + UNIT_ASSERT_VALUES_EQUAL(derived.RefCount(), 2l); + } + + UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 0); + } + + UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1); + } + { + NTesting::TProbeState probeState = {}; + + { + TSimpleSharedPtr<TVirtualProbe> base = MakeSimpleShared<TDerivedProbe>(&probeState); + UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1); + + auto derived = std::move(base).As<TDerivedProbe>(); + UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1); + UNIT_ASSERT_VALUES_EQUAL(probeState.CopyConstructors, 0); + UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 0); + } + + UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1); + } +} |