From e28af77b3232cd5b5655b24923d4e5243d67f338 Mon Sep 17 00:00:00 2001 From: agri Date: Tue, 23 Jun 2026 10:21:04 +0300 Subject: Fix weak ptr for TTrueAtomicSharedPtr fix weak ptr add swap test reinterpret\_cast, precise memory\_order commit_hash:a1b73fa28d9314d4cd21a473f421cfedcaf19330 --- .../atomic_shared_ptr/atomic_shared_ptr.h | 104 +++++++++++++-------- 1 file changed, 66 insertions(+), 38 deletions(-) (limited to 'library/cpp') diff --git a/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h b/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h index a991eb3bccc..fbc9b9576b8 100644 --- a/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h +++ b/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h @@ -36,11 +36,11 @@ namespace NPrivate { } ui64 Ref(ui64 add = USE_INCREMENT) noexcept { - return ref_count_.fetch_add(add, std::memory_order_relaxed) + add; + return ref_count_.fetch_add(add, std::memory_order_acquire) + add; } ui64 Unref(ui64 sub = USE_INCREMENT) noexcept { - return ref_count_.fetch_sub(sub, std::memory_order_seq_cst) - sub; + return ref_count_.fetch_sub(sub, std::memory_order_acq_rel) - sub; } void UnrefAndDelete(ui64 sub = USE_INCREMENT) { @@ -48,8 +48,7 @@ namespace NPrivate { return; ui64 expect = 0; bool flag_is_set = ref_count_.compare_exchange_strong( - expect, DESTROYED_FLAG, - std::memory_order_seq_cst, std::memory_order_relaxed); + expect, DESTROYED_FLAG, std::memory_order_acq_rel, std::memory_order_relaxed); if (!flag_is_set) return; DestroyPayload(); @@ -61,8 +60,7 @@ namespace NPrivate { return nullptr; ui64 expect = 0; bool flag_is_set = ref_count_.compare_exchange_strong( - expect, DESTROYED_FLAG, - std::memory_order_seq_cst, std::memory_order_relaxed); + expect, DESTROYED_FLAG, std::memory_order_acq_rel, std::memory_order_relaxed); if (!flag_is_set) return nullptr; void* result = GetPtr(); @@ -79,17 +77,16 @@ namespace NPrivate { } ui64 RefWeak(ui64 add = 1) noexcept { - return weak_count_.fetch_add(add, std::memory_order_relaxed) + add; + return weak_count_.fetch_add(add, std::memory_order_acquire) + add; } ui64 UnrefWeak(ui64 sub = 1) noexcept { - return weak_count_.fetch_sub(sub, std::memory_order_seq_cst) - sub; + return weak_count_.fetch_sub(sub, std::memory_order_acq_rel) - sub; } void UnrefWeakAndDelete(ui64 sub = 1) noexcept { - if (UnrefWeak(sub) != 0) - return; - delete this; + if (UnrefWeak(sub) == 0) + delete this; } }; @@ -125,7 +122,7 @@ namespace NPrivate { explicit TSharedBasePtr(TRefCounter* counter_ptr) noexcept : ptr_((uintptr_t)counter_ptr) { - Y_ABORT_UNLESS((ptr_.load(std::memory_order_relaxed) & ~PTR_MASK) == 0, + Y_ABORT_UNLESS((ptr_.load(std::memory_order_acquire) & ~PTR_MASK) == 0, "you must provide a clean ptr"); } @@ -140,21 +137,32 @@ namespace NPrivate { TRefCounter* ConcurrentAcquire() noexcept { auto result = - ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_seq_cst); + ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_acquire); auto ptr_result = CleanUpPtr(result); - if (ptr_result) - ptr_result->Ref(CONCURRENT_INCREMENT + TRefCounter::USE_INCREMENT); + if (!ptr_result) + return nullptr; + ptr_result->Ref(CONCURRENT_INCREMENT + TRefCounter::USE_INCREMENT); return ptr_result; } TRefCounter* ConcurrentWeakAcquire() noexcept { auto result = - ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_seq_cst); + ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_acquire); + auto ptr_result = CleanUpPtr(result); + if (!ptr_result) + return nullptr; + ptr_result->RefWeak(); + ptr_result->Ref(CONCURRENT_INCREMENT); + return ptr_result; + } + + TRefCounter* ConcurrentWeakAcquireFromWeak() noexcept { + auto result = + ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_acquire); auto ptr_result = CleanUpPtr(result); - if (ptr_result) { - ptr_result->RefWeak(); - ptr_result->Ref(CONCURRENT_INCREMENT); - } + if (!ptr_result) + return nullptr; + ptr_result->RefWeak(CONCURRENT_INCREMENT + 1); return ptr_result; } @@ -176,35 +184,35 @@ namespace NPrivate { static void DestroyWeakPtr(TRefCounter* ptr) noexcept { auto clean_ptr = CleanUpPtr(ptr); - if (clean_ptr) { - ui64 cnt = GetCounter(ptr); - clean_ptr->UnrefWeakAndDelete(cnt); - } + if (!clean_ptr) + return; + ui64 cnt = GetCounter(ptr); + clean_ptr->UnrefWeakAndDelete(cnt + 1); } static TRefCounter* CleanUpPtr(TRefCounter* ptr) noexcept { - return (TRefCounter*)((uintptr_t)ptr & PTR_MASK); + return CleanUpPtr(reinterpret_cast(ptr)); } static TRefCounter* CleanUpPtr(uintptr_t ptr) noexcept { - return (TRefCounter*)(ptr & PTR_MASK); + return reinterpret_cast(ptr & PTR_MASK); } static uintptr_t GetCounter(TRefCounter* ptr) noexcept { - return (uintptr_t)ptr & ~PTR_MASK; + return reinterpret_cast(ptr) & ~PTR_MASK; } TRefCounter* GetClean() const noexcept { - return CleanUpPtr(ptr_.load(std::memory_order_relaxed)); + return CleanUpPtr(ptr_.load(std::memory_order_acquire)); } TRefCounter* GetRaw() const noexcept { - return (TRefCounter*)ptr_.load(std::memory_order_relaxed); + return reinterpret_cast(ptr_.load(std::memory_order_acquire)); } TRefCounter* Swap(TRefCounter* other) noexcept { - return (TRefCounter*)ptr_.exchange( - (uintptr_t)other, std::memory_order_seq_cst); + return reinterpret_cast( + ptr_.exchange(reinterpret_cast(other))); } size_t UseCount() const noexcept { @@ -212,7 +220,7 @@ namespace NPrivate { auto clean_ptr = CleanUpPtr(ptr); if (!clean_ptr) return 0; - ui64 result = clean_ptr->ref_count_.load(std::memory_order_relaxed); + ui64 result = clean_ptr->ref_count_.load(std::memory_order_acquire); result -= GetCounter(ptr); result /= TRefCounter::USE_INCREMENT; return result; @@ -368,17 +376,36 @@ public: } TTrueAtomicWeakPtr(const TTrueAtomicWeakPtr& other) noexcept - : ptr_(other.ptr_.ConcurrentWeakAcquire()) + : ptr_(other.ptr_.ConcurrentWeakAcquireFromWeak()) { } - TTrueAtomicWeakPtr& operator=(const TTrueAtomicWeakPtr& other) noexcept { + + TTrueAtomicWeakPtr(TTrueAtomicWeakPtr&& other) noexcept + : ptr_(other.ptr_.Swap(nullptr)) + { + } + + ~TTrueAtomicWeakPtr() + { + auto ptr = ptr_.GetRaw(); + NPrivate::TSharedBasePtr::DestroyWeakPtr(ptr); + } + + TTrueAtomicWeakPtr& operator=(const TTrueAtomicSharedPtr& other) noexcept { auto new_ptr = other.ptr_.ConcurrentWeakAcquire(); auto back_ptr = ptr_.Swap(new_ptr); NPrivate::TSharedBasePtr::DestroyWeakPtr(back_ptr); return *this; } + TTrueAtomicWeakPtr& operator=(const TTrueAtomicWeakPtr& other) noexcept { + auto new_ptr = other.ptr_.ConcurrentWeakAcquireFromWeak(); + auto back_ptr = ptr_.Swap(new_ptr); + NPrivate::TSharedBasePtr::DestroyWeakPtr(back_ptr); + return *this; + } + TTrueAtomicWeakPtr& operator=(TTrueAtomicWeakPtr&& other) noexcept { auto new_ptr = other.ptr_.Swap(nullptr); auto back_ptr = ptr_.Swap(new_ptr); @@ -389,9 +416,10 @@ public: TTrueAtomicSharedPtr lock() noexcept { // create local TTrueAtomicWeakPtr to avoid concurrent changes of this TTrueAtomicWeakPtr local(*this); - if (!local.ptr_.GetRaw()->RefFromWeak()) - return TTrueAtomicSharedPtr(); - return TTrueAtomicSharedPtr(ptr_.GetClean()); + if (auto raw_ptr = local.ptr_.GetRaw()) + if (raw_ptr->RefFromWeak()) + return TTrueAtomicSharedPtr(ptr_.GetClean()); + return TTrueAtomicSharedPtr(); } void reset() noexcept { @@ -403,7 +431,7 @@ public: // Having A with possible accesses from multiple threads and B with exclusive // use from a single thread, then calling A.swap(B) is atomic. void swap(TTrueAtomicWeakPtr& other) noexcept { - auto copy = other.ptr_.ConcurrentWeakAcquire(); + auto copy = other.ptr_.ConcurrentWeakAcquireFromWeak(); auto back_ptr = ptr_.Swap(copy); back_ptr = other.ptr_.Swap(back_ptr); NPrivate::TSharedBasePtr::DestroyWeakPtr(back_ptr); -- cgit v1.3