aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorionagamed <ionagamed@yandex-team.com>2022-10-01 11:18:40 +0300
committerionagamed <ionagamed@yandex-team.com>2022-10-01 11:18:40 +0300
commitfecbe41234ec99b7709d8d9ec617219655ac51a7 (patch)
tree04de3a8f2f402fd94185c2a5d7d9dca80a21d67c
parentd479316c284f478d8bd44dd5eb57d5899b1ded4e (diff)
downloadydb-fecbe41234ec99b7709d8d9ec617219655ac51a7.tar.gz
util: add .As<T> to TSharedPtr
-rw-r--r--util/generic/ptr.h19
-rw-r--r--util/generic/ptr_ut.cpp58
2 files changed, 76 insertions, 1 deletions
diff --git a/util/generic/ptr.h b/util/generic/ptr.h
index 3addc85753e..b35d54d780a 100644
--- a/util/generic/ptr.h
+++ b/util/generic/ptr.h
@@ -7,6 +7,7 @@
#include "typetraits.h"
#include "singleton.h"
+#include <type_traits>
#include <utility>
#include <util/system/compiler.h>
@@ -895,6 +896,24 @@ public:
return C_ ? C_->Val() : 0;
}
+ template <class TT>
+ [[nodiscard]] inline TSharedPtr<TT, C, D> As() & noexcept {
+ static_assert(std::has_virtual_destructor<TT>(), "Type should have a virtual dtor");
+ static_assert(std::is_base_of<T, TT>(), "When downcasting from T to TT, T should be a parent of TT");
+ Ref();
+ return TSharedPtr<TT, C, D>(dynamic_cast<TT*>(T_), C_);
+ }
+
+ template <class TT>
+ [[nodiscard]] inline TSharedPtr<TT, C, D> As() && noexcept {
+ static_assert(std::has_virtual_destructor<TT>(), "Type should have a virtual dtor");
+ static_assert(std::is_base_of<T, TT>(), "When downcasting from T to TT, T should be a parent of TT");
+ auto resultPtr = TSharedPtr<TT, C, D>(dynamic_cast<TT*>(T_), C_);
+ T_ = nullptr;
+ C_ = nullptr;
+ return resultPtr;
+ }
+
#ifdef __cpp_impl_three_way_comparison
template <class Other>
inline bool operator==(const Other& p) const noexcept {
diff --git a/util/generic/ptr_ut.cpp b/util/generic/ptr_ut.cpp
index 1b4071ea07a..6e028725b4e 100644
--- a/util/generic/ptr_ut.cpp
+++ b/util/generic/ptr_ut.cpp
@@ -2,6 +2,7 @@
#include "vector.h"
#include "noncopyable.h"
+#include <library/cpp/testing/common/probe.h>
#include <library/cpp/testing/unittest/registar.h>
#include <util/generic/hash_set.h>
@@ -32,7 +33,8 @@ class TPointerTest: public TTestBase {
UNIT_TEST(TestMakeShared);
UNIT_TEST(TestComparison);
UNIT_TEST(TestSimpleIntrusivePtrCtorTsan);
- UNIT_TEST(TestRefCountedPtrsInHashSet)
+ UNIT_TEST(TestRefCountedPtrsInHashSet);
+ UNIT_TEST(TestSharedPtrDowncast);
UNIT_TEST_SUITE_END();
private:
@@ -86,6 +88,7 @@ private:
template <class T, class TRefCountedPtr>
void TestRefCountedPtrsInHashSetImpl();
void TestRefCountedPtrsInHashSet();
+ void TestSharedPtrDowncast();
};
UNIT_TEST_SUITE_REGISTRATION(TPointerTest);
@@ -834,3 +837,56 @@ void TPointerTest::TestIntrusiveConstConstruction() {
UNIT_ASSERT_VALUES_EQUAL(cnt.Increments.load(), 1);
}
}
+
+class TVirtualProbe: public NTesting::TProbe {
+public:
+ using NTesting::TProbe::TProbe;
+
+ virtual ~TVirtualProbe() = default;
+};
+
+class TDerivedProbe: public TVirtualProbe {
+public:
+ using TVirtualProbe::TVirtualProbe;
+};
+
+void TPointerTest::TestSharedPtrDowncast() {
+ {
+ NTesting::TProbeState probeState = {};
+
+ {
+ TSimpleSharedPtr<TVirtualProbe> base = MakeSimpleShared<TDerivedProbe>(&probeState);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+
+ {
+ auto derived = base.As<TDerivedProbe>();
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+
+ UNIT_ASSERT_VALUES_EQUAL(base.Get(), derived.Get());
+ UNIT_ASSERT_VALUES_EQUAL(base.ReferenceCounter(), derived.ReferenceCounter());
+
+ UNIT_ASSERT_VALUES_EQUAL(base.RefCount(), 2l);
+ UNIT_ASSERT_VALUES_EQUAL(derived.RefCount(), 2l);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 0);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1);
+ }
+ {
+ NTesting::TProbeState probeState = {};
+
+ {
+ TSimpleSharedPtr<TVirtualProbe> base = MakeSimpleShared<TDerivedProbe>(&probeState);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+
+ auto derived = std::move(base).As<TDerivedProbe>();
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.CopyConstructors, 0);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 0);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1);
+ }
+}