diff options
| author | agri <[email protected]> | 2026-06-23 10:21:04 +0300 |
|---|---|---|
| committer | agri <[email protected]> | 2026-06-23 10:57:56 +0300 |
| commit | e28af77b3232cd5b5655b24923d4e5243d67f338 (patch) | |
| tree | ef1d46f8480b58678b2013edee1588b360d25bb8 /library/cpp/threading | |
| parent | 68a6016f5e3a1fe07371b65b84f7f2c7ede7219d (diff) | |
Fix weak ptr for TTrueAtomicSharedPtr
fix weak ptr
add swap test
reinterpret\_cast, precise memory\_order
commit_hash:a1b73fa28d9314d4cd21a473f421cfedcaf19330
Diffstat (limited to 'library/cpp/threading')
| -rw-r--r-- | library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h | 104 |
1 files changed, 66 insertions, 38 deletions
diff --git a/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h b/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h index a991eb3bccc..fbc9b9576b8 100644 --- a/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h +++ b/library/cpp/threading/atomic_shared_ptr/atomic_shared_ptr.h @@ -36,11 +36,11 @@ namespace NPrivate { } ui64 Ref(ui64 add = USE_INCREMENT) noexcept { - return ref_count_.fetch_add(add, std::memory_order_relaxed) + add; + return ref_count_.fetch_add(add, std::memory_order_acquire) + add; } ui64 Unref(ui64 sub = USE_INCREMENT) noexcept { - return ref_count_.fetch_sub(sub, std::memory_order_seq_cst) - sub; + return ref_count_.fetch_sub(sub, std::memory_order_acq_rel) - sub; } void UnrefAndDelete(ui64 sub = USE_INCREMENT) { @@ -48,8 +48,7 @@ namespace NPrivate { return; ui64 expect = 0; bool flag_is_set = ref_count_.compare_exchange_strong( - expect, DESTROYED_FLAG, - std::memory_order_seq_cst, std::memory_order_relaxed); + expect, DESTROYED_FLAG, std::memory_order_acq_rel, std::memory_order_relaxed); if (!flag_is_set) return; DestroyPayload(); @@ -61,8 +60,7 @@ namespace NPrivate { return nullptr; ui64 expect = 0; bool flag_is_set = ref_count_.compare_exchange_strong( - expect, DESTROYED_FLAG, - std::memory_order_seq_cst, std::memory_order_relaxed); + expect, DESTROYED_FLAG, std::memory_order_acq_rel, std::memory_order_relaxed); if (!flag_is_set) return nullptr; void* result = GetPtr(); @@ -79,17 +77,16 @@ namespace NPrivate { } ui64 RefWeak(ui64 add = 1) noexcept { - return weak_count_.fetch_add(add, std::memory_order_relaxed) + add; + return weak_count_.fetch_add(add, std::memory_order_acquire) + add; } ui64 UnrefWeak(ui64 sub = 1) noexcept { - return weak_count_.fetch_sub(sub, std::memory_order_seq_cst) - sub; + return weak_count_.fetch_sub(sub, std::memory_order_acq_rel) - sub; } void UnrefWeakAndDelete(ui64 sub = 1) noexcept { - if (UnrefWeak(sub) != 0) - return; - delete this; + if (UnrefWeak(sub) == 0) + delete this; } }; @@ -125,7 +122,7 @@ namespace NPrivate { explicit TSharedBasePtr(TRefCounter* counter_ptr) noexcept : ptr_((uintptr_t)counter_ptr) { - Y_ABORT_UNLESS((ptr_.load(std::memory_order_relaxed) & ~PTR_MASK) == 0, + Y_ABORT_UNLESS((ptr_.load(std::memory_order_acquire) & ~PTR_MASK) == 0, "you must provide a clean ptr"); } @@ -140,21 +137,32 @@ namespace NPrivate { TRefCounter* ConcurrentAcquire() noexcept { auto result = - ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_seq_cst); + ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_acquire); auto ptr_result = CleanUpPtr(result); - if (ptr_result) - ptr_result->Ref(CONCURRENT_INCREMENT + TRefCounter::USE_INCREMENT); + if (!ptr_result) + return nullptr; + ptr_result->Ref(CONCURRENT_INCREMENT + TRefCounter::USE_INCREMENT); return ptr_result; } TRefCounter* ConcurrentWeakAcquire() noexcept { auto result = - ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_seq_cst); + ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_acquire); + auto ptr_result = CleanUpPtr(result); + if (!ptr_result) + return nullptr; + ptr_result->RefWeak(); + ptr_result->Ref(CONCURRENT_INCREMENT); + return ptr_result; + } + + TRefCounter* ConcurrentWeakAcquireFromWeak() noexcept { + auto result = + ptr_.fetch_add(CONCURRENT_INCREMENT, std::memory_order_acquire); auto ptr_result = CleanUpPtr(result); - if (ptr_result) { - ptr_result->RefWeak(); - ptr_result->Ref(CONCURRENT_INCREMENT); - } + if (!ptr_result) + return nullptr; + ptr_result->RefWeak(CONCURRENT_INCREMENT + 1); return ptr_result; } @@ -176,35 +184,35 @@ namespace NPrivate { static void DestroyWeakPtr(TRefCounter* ptr) noexcept { auto clean_ptr = CleanUpPtr(ptr); - if (clean_ptr) { - ui64 cnt = GetCounter(ptr); - clean_ptr->UnrefWeakAndDelete(cnt); - } + if (!clean_ptr) + return; + ui64 cnt = GetCounter(ptr); + clean_ptr->UnrefWeakAndDelete(cnt + 1); } static TRefCounter* CleanUpPtr(TRefCounter* ptr) noexcept { - return (TRefCounter*)((uintptr_t)ptr & PTR_MASK); + return CleanUpPtr(reinterpret_cast<uintptr_t>(ptr)); } static TRefCounter* CleanUpPtr(uintptr_t ptr) noexcept { - return (TRefCounter*)(ptr & PTR_MASK); + return reinterpret_cast<TRefCounter*>(ptr & PTR_MASK); } static uintptr_t GetCounter(TRefCounter* ptr) noexcept { - return (uintptr_t)ptr & ~PTR_MASK; + return reinterpret_cast<uintptr_t>(ptr) & ~PTR_MASK; } TRefCounter* GetClean() const noexcept { - return CleanUpPtr(ptr_.load(std::memory_order_relaxed)); + return CleanUpPtr(ptr_.load(std::memory_order_acquire)); } TRefCounter* GetRaw() const noexcept { - return (TRefCounter*)ptr_.load(std::memory_order_relaxed); + return reinterpret_cast<TRefCounter*>(ptr_.load(std::memory_order_acquire)); } TRefCounter* Swap(TRefCounter* other) noexcept { - return (TRefCounter*)ptr_.exchange( - (uintptr_t)other, std::memory_order_seq_cst); + return reinterpret_cast<TRefCounter*>( + ptr_.exchange(reinterpret_cast<uintptr_t>(other))); } size_t UseCount() const noexcept { @@ -212,7 +220,7 @@ namespace NPrivate { auto clean_ptr = CleanUpPtr(ptr); if (!clean_ptr) return 0; - ui64 result = clean_ptr->ref_count_.load(std::memory_order_relaxed); + ui64 result = clean_ptr->ref_count_.load(std::memory_order_acquire); result -= GetCounter(ptr); result /= TRefCounter::USE_INCREMENT; return result; @@ -368,17 +376,36 @@ public: } TTrueAtomicWeakPtr(const TTrueAtomicWeakPtr& other) noexcept - : ptr_(other.ptr_.ConcurrentWeakAcquire()) + : ptr_(other.ptr_.ConcurrentWeakAcquireFromWeak()) { } - TTrueAtomicWeakPtr& operator=(const TTrueAtomicWeakPtr& other) noexcept { + + TTrueAtomicWeakPtr(TTrueAtomicWeakPtr&& other) noexcept + : ptr_(other.ptr_.Swap(nullptr)) + { + } + + ~TTrueAtomicWeakPtr() + { + auto ptr = ptr_.GetRaw(); + NPrivate::TSharedBasePtr::DestroyWeakPtr(ptr); + } + + TTrueAtomicWeakPtr& operator=(const TTrueAtomicSharedPtr<PayloadType>& other) noexcept { auto new_ptr = other.ptr_.ConcurrentWeakAcquire(); auto back_ptr = ptr_.Swap(new_ptr); NPrivate::TSharedBasePtr::DestroyWeakPtr(back_ptr); return *this; } + TTrueAtomicWeakPtr& operator=(const TTrueAtomicWeakPtr& other) noexcept { + auto new_ptr = other.ptr_.ConcurrentWeakAcquireFromWeak(); + auto back_ptr = ptr_.Swap(new_ptr); + NPrivate::TSharedBasePtr::DestroyWeakPtr(back_ptr); + return *this; + } + TTrueAtomicWeakPtr& operator=(TTrueAtomicWeakPtr&& other) noexcept { auto new_ptr = other.ptr_.Swap(nullptr); auto back_ptr = ptr_.Swap(new_ptr); @@ -389,9 +416,10 @@ public: TTrueAtomicSharedPtr<PayloadType> lock() noexcept { // create local TTrueAtomicWeakPtr to avoid concurrent changes of this TTrueAtomicWeakPtr<PayloadType> local(*this); - if (!local.ptr_.GetRaw()->RefFromWeak()) - return TTrueAtomicSharedPtr<PayloadType>(); - return TTrueAtomicSharedPtr<PayloadType>(ptr_.GetClean()); + if (auto raw_ptr = local.ptr_.GetRaw()) + if (raw_ptr->RefFromWeak()) + return TTrueAtomicSharedPtr<PayloadType>(ptr_.GetClean()); + return TTrueAtomicSharedPtr<PayloadType>(); } void reset() noexcept { @@ -403,7 +431,7 @@ public: // Having A with possible accesses from multiple threads and B with exclusive // use from a single thread, then calling A.swap(B) is atomic. void swap(TTrueAtomicWeakPtr& other) noexcept { - auto copy = other.ptr_.ConcurrentWeakAcquire(); + auto copy = other.ptr_.ConcurrentWeakAcquireFromWeak(); auto back_ptr = ptr_.Swap(copy); back_ptr = other.ptr_.Swap(back_ptr); NPrivate::TSharedBasePtr::DestroyWeakPtr(back_ptr); |
