aboutsummaryrefslogtreecommitdiffstats
path: root/util
diff options
context:
space:
mode:
authorsv-denisov <sv-denisov@yandex-team.com>2023-01-14 15:13:39 +0300
committersv-denisov <sv-denisov@yandex-team.com>2023-01-14 15:13:39 +0300
commit3657be5988251fc9074ba5b86b62bfa985ff4643 (patch)
tree2f9eca1c7158569dec9e2397c39b1a3731a7ee1c /util
parent380ce27d41b76ca1640e48b7271681b3719d6be0 (diff)
downloadydb-3657be5988251fc9074ba5b86b62bfa985ff4643.tar.gz
TSharedPtr::As()&: leakage fix
По следам https://a.yandex-team.ru/review/2996409/details Исправлены случаи при неудачных кастах: 1. контрольный блок дёргается даже при указании на nullptr (нехорошо для атомиков) 2. если исходный шаред поинтер разрушится раньше, чем полученный nullptr (например, при возврате из функции), то ресурс утечёт
Diffstat (limited to 'util')
-rw-r--r--util/generic/ptr.h21
-rw-r--r--util/generic/ptr_ut.cpp44
2 files changed, 59 insertions, 6 deletions
diff --git a/util/generic/ptr.h b/util/generic/ptr.h
index b35d54d780..cc2e3c0f51 100644
--- a/util/generic/ptr.h
+++ b/util/generic/ptr.h
@@ -900,18 +900,27 @@ public:
[[nodiscard]] inline TSharedPtr<TT, C, D> As() & noexcept {
static_assert(std::has_virtual_destructor<TT>(), "Type should have a virtual dtor");
static_assert(std::is_base_of<T, TT>(), "When downcasting from T to TT, T should be a parent of TT");
- Ref();
- return TSharedPtr<TT, C, D>(dynamic_cast<TT*>(T_), C_);
+ if (const auto ttPtr = dynamic_cast<TT*>(T_)) {
+ TSharedPtr<TT, C, D> ttSharedPtr(ttPtr, C_);
+ ttSharedPtr.Ref();
+ return ttSharedPtr;
+ } else {
+ return TSharedPtr<TT, C, D>{};
+ }
}
template <class TT>
[[nodiscard]] inline TSharedPtr<TT, C, D> As() && noexcept {
static_assert(std::has_virtual_destructor<TT>(), "Type should have a virtual dtor");
static_assert(std::is_base_of<T, TT>(), "When downcasting from T to TT, T should be a parent of TT");
- auto resultPtr = TSharedPtr<TT, C, D>(dynamic_cast<TT*>(T_), C_);
- T_ = nullptr;
- C_ = nullptr;
- return resultPtr;
+ if (const auto ttPtr = dynamic_cast<TT*>(T_)) {
+ TSharedPtr<TT, C, D> ttSharedPtr(ttPtr, C_);
+ T_ = nullptr;
+ C_ = nullptr;
+ return ttSharedPtr;
+ } else {
+ return TSharedPtr<TT, C, D>{};
+ }
}
#ifdef __cpp_impl_three_way_comparison
diff --git a/util/generic/ptr_ut.cpp b/util/generic/ptr_ut.cpp
index 6e028725b4..5f0fd2d470 100644
--- a/util/generic/ptr_ut.cpp
+++ b/util/generic/ptr_ut.cpp
@@ -850,6 +850,11 @@ public:
using TVirtualProbe::TVirtualProbe;
};
+class TDerivedProbeSibling: public TVirtualProbe {
+public:
+ using TVirtualProbe::TVirtualProbe;
+};
+
void TPointerTest::TestSharedPtrDowncast() {
{
NTesting::TProbeState probeState = {};
@@ -889,4 +894,43 @@ void TPointerTest::TestSharedPtrDowncast() {
UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1);
}
+ {
+ NTesting::TProbeState probeState = {};
+
+ {
+ TSimpleSharedPtr<TVirtualProbe> base = MakeSimpleShared<TDerivedProbe>(&probeState);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+
+ {
+ auto derivedSibling = base.As<TDerivedProbeSibling>();
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+
+ UNIT_ASSERT_VALUES_EQUAL(derivedSibling.Get(), nullptr);
+ UNIT_ASSERT_VALUES_UNEQUAL(base.ReferenceCounter(), derivedSibling.ReferenceCounter());
+
+ UNIT_ASSERT_VALUES_EQUAL(base.RefCount(), 1l);
+ UNIT_ASSERT_VALUES_EQUAL(derivedSibling.RefCount(), 0l);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 0);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1);
+ }
+ {
+ NTesting::TProbeState probeState = {};
+
+ {
+ TSimpleSharedPtr<TVirtualProbe> base = MakeSimpleShared<TDerivedProbe>(&probeState);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+
+ auto derived = std::move(base).As<TDerivedProbeSibling>();
+ UNIT_ASSERT_VALUES_EQUAL(derived.Get(), nullptr);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Constructors, 1);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.CopyConstructors, 0);
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 0);
+ }
+
+ UNIT_ASSERT_VALUES_EQUAL(probeState.Destructors, 1);
+ }
}