diff options
author | lukyan <lukyan@yandex-team.com> | 2022-08-15 22:02:21 +0300 |
---|---|---|
committer | lukyan <lukyan@yandex-team.com> | 2022-08-15 22:02:21 +0300 |
commit | d9fbef9fe6370e3b86bab8759ee0b6fd44b0effe (patch) | |
tree | df739788c3940d2b53ed3cbb25dba4d93132124c /library | |
parent | 7a392c977a198d4bf975393edd99ef939a94a8ba (diff) | |
download | ydb-d9fbef9fe6370e3b86bab8759ee0b6fd44b0effe.tar.gz |
Add simple atomic ptr
Diffstat (limited to 'library')
-rw-r--r-- | library/cpp/yt/memory/atomic_intrusive_ptr-inl.h | 211 | ||||
-rw-r--r-- | library/cpp/yt/memory/atomic_intrusive_ptr.h | 80 | ||||
-rw-r--r-- | library/cpp/yt/memory/ref_counted-inl.h | 21 | ||||
-rw-r--r-- | library/cpp/yt/memory/ref_counted.h | 8 | ||||
-rw-r--r-- | library/cpp/yt/memory/unittests/atomic_intrusive_ptr_ut.cpp | 267 | ||||
-rw-r--r-- | library/cpp/yt/memory/unittests/intrusive_ptr_ut.cpp | 4 |
6 files changed, 574 insertions, 17 deletions
diff --git a/library/cpp/yt/memory/atomic_intrusive_ptr-inl.h b/library/cpp/yt/memory/atomic_intrusive_ptr-inl.h new file mode 100644 index 0000000000..e2a6fd41fd --- /dev/null +++ b/library/cpp/yt/memory/atomic_intrusive_ptr-inl.h @@ -0,0 +1,211 @@ +#ifndef ATOMIC_INTRUSIVE_PTR_INL_H_ +#error "Direct inclusion of this file is not allowed, include atomic_intrusive_ptr.h" +// For the sake of sane code completion. +#include "atomic_intrusive_ptr.h" +#endif +#undef ATOMIC_INTRUSIVE_PTR_INL_H_ + +#include <util/system/spinlock.h> + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +TAtomicIntrusivePtr<T>::TAtomicIntrusivePtr(std::nullptr_t) +{ } + +template <class T> +TAtomicIntrusivePtr<T>::TAtomicIntrusivePtr(TIntrusivePtr<T> other) + : Ptr_(AcquireObject(other.Release(), true)) +{ } + +template <class T> +TAtomicIntrusivePtr<T>::TAtomicIntrusivePtr(TAtomicIntrusivePtr&& other) + : Ptr_(other.Ptr_.load(std::memory_order_relaxed)) +{ + other.Ptr_.store(nullptr, std::memory_order_relaxed); +} + +template <class T> +TAtomicIntrusivePtr<T>::~TAtomicIntrusivePtr() +{ + ReleaseObject(Ptr_.load()); +} + +template <class T> +TAtomicIntrusivePtr<T>& TAtomicIntrusivePtr<T>::operator=(TIntrusivePtr<T> other) +{ + ReleaseObject(Ptr_.exchange(AcquireObject(other.Release(), true))); + return *this; +} + +template <class T> +TAtomicIntrusivePtr<T>& TAtomicIntrusivePtr<T>::operator=(std::nullptr_t) +{ + Reset(); + return *this; +} + +template <class T> +TIntrusivePtr<T> TAtomicIntrusivePtr<T>::Acquire() +{ + char* ptr = Ptr_.load(); + while (true) { + auto [localRefs, obj] = UnpackPointer<T>(ptr); + + if (!obj) { + return TIntrusivePtr<T>(); + } + + YT_VERIFY(localRefs < ReservedRefCount); + + auto newLocalRefs = localRefs + 1; + + if (newLocalRefs == ReservedRefCount) { + SpinLockPause(); + + ptr = Ptr_.load(); + continue; + } + + // Can not Ref(obj) here because it can be destroyed. + if (Ptr_.compare_exchange_weak(ptr, PackPointer(obj, newLocalRefs))) { + if (Y_UNLIKELY(newLocalRefs > ReservedRefCount / 2)) { + Ref(obj, ReservedRefCount / 2); + + // Decrease local ref count. + while (true) { + auto [localRefs, currentObj] = UnpackPointer<T>(ptr); + + if (currentObj != obj || localRefs <= ReservedRefCount / 2) { + Unref(obj, ReservedRefCount / 2); + break; + } + + if (Ptr_.compare_exchange_weak(ptr, PackPointer(obj, localRefs - ReservedRefCount / 2))) { + break; + } + } + } + + return TIntrusivePtr<T>(obj, false); + } + } +} + +template <class T> +TIntrusivePtr<T> TAtomicIntrusivePtr<T>::Exchange(TIntrusivePtr<T> other) +{ + auto [localRefs, obj] = UnpackPointer<T>(Ptr_.exchange(AcquireObject(other.Release(), true))); + DoRelease(obj, localRefs + 1); + return TIntrusivePtr<T>(obj, false); +} + +template <class T> +void TAtomicIntrusivePtr<T>::Reset() +{ + ReleaseObject(Ptr_.exchange(nullptr)); +} + +template <class T> +bool TAtomicIntrusivePtr<T>::CompareAndSwap(void*& comparePtr, T* target) +{ + auto targetPtr = AcquireObject(target, false); + + auto currentPtr = Ptr_.load(); + if (UnpackPointer<T>(currentPtr).Ptr == comparePtr && Ptr_.compare_exchange_strong(currentPtr, targetPtr)) { + ReleaseObject(currentPtr); + return true; + } + + comparePtr = UnpackPointer<T>(currentPtr).Ptr; + + ReleaseObject(targetPtr); + return false; +} + +template <class T> +bool TAtomicIntrusivePtr<T>::CompareAndSwap(void*& comparePtr, TIntrusivePtr<T> target) +{ + // TODO(lukyan): Make helper for packed owning ptr? + auto targetPtr = AcquireObject(target.Release(), true); + + auto currentPtr = Ptr_.load(); + if (UnpackPointer<T>(currentPtr).Ptr == comparePtr && Ptr_.compare_exchange_strong(currentPtr, targetPtr)) { + ReleaseObject(currentPtr); + return true; + } + + comparePtr = UnpackPointer<T>(currentPtr).Ptr; + + ReleaseObject(targetPtr); + return false; +} + +template <class T> +void* TAtomicIntrusivePtr<T>::Get() const +{ + return UnpackPointer<void>(Ptr_.load()).Ptr; +} + +template <class T> +TAtomicIntrusivePtr<T>::operator bool() const +{ + return Get(); +} + +template <class T> +char* TAtomicIntrusivePtr<T>::AcquireObject(T* obj, bool consumeRef) +{ + if (obj) { + Ref(obj, static_cast<int>(ReservedRefCount - consumeRef)); + } + + return PackPointer(obj, 0); +} + +template <class T> +void TAtomicIntrusivePtr<T>::ReleaseObject(void* packedPtr) +{ + auto [localRefs, obj] = UnpackPointer<T>(packedPtr); + DoRelease(obj, localRefs); +} + +template <class T> +void TAtomicIntrusivePtr<T>::DoRelease(T* obj, int refs) +{ + if (obj) { + Unref(obj, static_cast<int>(ReservedRefCount - refs)); + } +} + +//////////////////////////////////////////////////////////////////////////////// + +template <class T> +bool operator==(const TAtomicIntrusivePtr<T>& lhs, const TIntrusivePtr<T>& rhs) +{ + return lhs.Get() == rhs.Get(); +} + +template <class T> +bool operator==(const TIntrusivePtr<T>& lhs, const TAtomicIntrusivePtr<T>& rhs) +{ + return lhs.Get() == rhs.Get(); +} + +template <class T> +bool operator!=(const TAtomicIntrusivePtr<T>& lhs, const TIntrusivePtr<T>& rhs) +{ + return lhs.Get() != rhs.Get(); +} + +template <class T> +bool operator!=(const TIntrusivePtr<T>& lhs, const TAtomicIntrusivePtr<T>& rhs) +{ + return lhs.Get() != rhs.Get(); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT diff --git a/library/cpp/yt/memory/atomic_intrusive_ptr.h b/library/cpp/yt/memory/atomic_intrusive_ptr.h new file mode 100644 index 0000000000..c487330352 --- /dev/null +++ b/library/cpp/yt/memory/atomic_intrusive_ptr.h @@ -0,0 +1,80 @@ +#pragma once + +#include "intrusive_ptr.h" + +namespace NYT { + +//////////////////////////////////////////////////////////////////////////////// + +// Atomic ptr based on https://github.com/facebook/folly/blob/main/folly/concurrency/AtomicSharedPtr.h + +// Operators * and -> for TAtomicIntrusivePtr are useless because it is not safe to work with atomic ptr such way +// Safe usage is to convert to TIntrusivePtr. + +// Max TAtomicIntrusivePtr count per object is (2**16 = 2**32 / 2**16). + +template <class T> +class TAtomicIntrusivePtr +{ +public: + TAtomicIntrusivePtr() = default; + TAtomicIntrusivePtr(std::nullptr_t); + + explicit TAtomicIntrusivePtr(TIntrusivePtr<T> other); + TAtomicIntrusivePtr(TAtomicIntrusivePtr&& other); + + ~TAtomicIntrusivePtr(); + + TAtomicIntrusivePtr& operator=(TIntrusivePtr<T> other); + TAtomicIntrusivePtr& operator=(std::nullptr_t); + + TIntrusivePtr<T> Acquire(); + + TIntrusivePtr<T> Exchange(TIntrusivePtr<T> other); + + void Reset(); + bool CompareAndSwap(void*& comparePtr, T* target); + bool CompareAndSwap(void*& comparePtr, TIntrusivePtr<T> target); + + // Result is suitable only for comparison. Not dereference. + void* Get() const; + + explicit operator bool() const; + +private: + template <class U> + friend bool operator==(const TAtomicIntrusivePtr<U>& lhs, const TIntrusivePtr<U>& rhs); + + template <class U> + friend bool operator==(const TIntrusivePtr<U>& lhs, const TAtomicIntrusivePtr<U>& rhs); + + template <class U> + friend bool operator!=(const TAtomicIntrusivePtr<U>& lhs, const TIntrusivePtr<U>& rhs); + + template <class U> + friend bool operator!=(const TIntrusivePtr<U>& lhs, const TAtomicIntrusivePtr<U>& rhs); + + // Keeps packed pointer (localRefCount, objectPtr). + // Atomic ptr holds N references, where N = ReservedRefCount - localRefCount. + // LocalRefCount is incremented in Acquire method. + // When localRefCount exceeds ReservedRefCount / 2 a new portion of refs are required globally. + std::atomic<char*> Ptr_ = nullptr; + + constexpr static int CounterBits = 64 - PtrBits; + constexpr static int ReservedRefCount = (1 << CounterBits) - 1; + + // Consume ref if ownership is transferred. + // AcquireObject(ptr.Release(), true) + // AcquireObject(ptr.Get(), false) + static char* AcquireObject(T* obj, bool consumeRef = false); + static void ReleaseObject(void* packedPtr); + static void DoRelease(T* obj, int refs); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace NYT + +#define ATOMIC_INTRUSIVE_PTR_INL_H_ +#include "atomic_intrusive_ptr-inl.h" +#undef ATOMIC_INTRUSIVE_PTR_INL_H_ diff --git a/library/cpp/yt/memory/ref_counted-inl.h b/library/cpp/yt/memory/ref_counted-inl.h index b41f64e803..1ad7433ca9 100644 --- a/library/cpp/yt/memory/ref_counted-inl.h +++ b/library/cpp/yt/memory/ref_counted-inl.h @@ -69,10 +69,10 @@ Y_FORCE_INLINE int TRefCounter::GetRefCount() const noexcept return StrongCount_.load(std::memory_order_acquire); } -Y_FORCE_INLINE void TRefCounter::Ref() const noexcept +Y_FORCE_INLINE void TRefCounter::Ref(int n) const noexcept { // It is safe to use relaxed here, since new reference is always created from another live reference. - StrongCount_.fetch_add(1, std::memory_order_relaxed); + StrongCount_.fetch_add(n, std::memory_order_relaxed); YT_ASSERT(WeakCount_.load(std::memory_order_relaxed) > 0); } @@ -86,16 +86,16 @@ Y_FORCE_INLINE bool TRefCounter::TryRef() const noexcept return value != 0; } -Y_FORCE_INLINE bool TRefCounter::Unref() const +Y_FORCE_INLINE bool TRefCounter::Unref(int n) const { // We must properly synchronize last access to object with it destruction. // Otherwise compiler might reorder access to object past this decrement. // // See http://www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html#boost_atomic.usage_examples.example_reference_counters // - auto oldStrongCount = StrongCount_.fetch_sub(1, std::memory_order_release); - YT_ASSERT(oldStrongCount > 0); - if (oldStrongCount == 1) { + auto oldStrongCount = StrongCount_.fetch_sub(n, std::memory_order_release); + YT_ASSERT(oldStrongCount >= n); + if (oldStrongCount == n) { std::atomic_thread_fence(std::memory_order_acquire); return true; } else { @@ -216,15 +216,15 @@ Y_FORCE_INLINE void DeallocateRefCounted(const T* obj) //////////////////////////////////////////////////////////////////////////////// template <class T> -Y_FORCE_INLINE void Ref(T* obj) +Y_FORCE_INLINE void Ref(T* obj, int n) { - GetRefCounter(obj)->Ref(); + GetRefCounter(obj)->Ref(n); } template <class T> -Y_FORCE_INLINE void Unref(T* obj) +Y_FORCE_INLINE void Unref(T* obj, int n) { - if (GetRefCounter(obj)->Unref()) { + if (GetRefCounter(obj)->Unref(n)) { DestroyRefCounted(obj); } } @@ -243,7 +243,6 @@ Y_FORCE_INLINE void TRefCounted::WeakUnref() const } } - template <class T> void TRefCounted::DestroyRefCountedImpl(T* ptr) { diff --git a/library/cpp/yt/memory/ref_counted.h b/library/cpp/yt/memory/ref_counted.h index a295444f49..9c049c0998 100644 --- a/library/cpp/yt/memory/ref_counted.h +++ b/library/cpp/yt/memory/ref_counted.h @@ -65,13 +65,13 @@ public: int GetRefCount() const noexcept; //! Increments the strong reference counter. - void Ref() const noexcept; + void Ref(int n = 1) const noexcept; //! Increments the strong reference counter if it is not null. bool TryRef() const noexcept; //! Decrements the strong reference counter. - bool Unref() const; + bool Unref(int n = 1) const; //! Returns current number of weak references to the object. int GetWeakRefCount() const noexcept; @@ -103,10 +103,10 @@ void DeallocateRefCounted(const T* obj); // API template <class T> -void Ref(T* obj); +void Ref(T* obj, int n = 1); template <class T> -void Unref(T* obj); +void Unref(T* obj, int n = 1); //////////////////////////////////////////////////////////////////////////////// diff --git a/library/cpp/yt/memory/unittests/atomic_intrusive_ptr_ut.cpp b/library/cpp/yt/memory/unittests/atomic_intrusive_ptr_ut.cpp new file mode 100644 index 0000000000..27c87d379b --- /dev/null +++ b/library/cpp/yt/memory/unittests/atomic_intrusive_ptr_ut.cpp @@ -0,0 +1,267 @@ +#include <library/cpp/testing/gtest/gtest.h> + +#include <library/cpp/yt/memory/new.h> +#include <library/cpp/yt/memory/ref_counted.h> +#include <library/cpp/yt/memory/atomic_intrusive_ptr.h> + +namespace NYT { +namespace { + +//////////////////////////////////////////////////////////////////////////////// + +using ::testing::IsNull; +using ::testing::NotNull; +using ::testing::InSequence; +using ::testing::MockFunction; +using ::testing::StrictMock; + +//////////////////////////////////////////////////////////////////////////////// +// Auxiliary types and functions. +//////////////////////////////////////////////////////////////////////////////// + +// This object tracks number of increments and decrements +// to the reference counter (see traits specialization below). +struct TIntricateObject + : private TNonCopyable +{ + int Increments = 0; + int Decrements = 0; + int Zeros = 0; + + void Ref(int n) + { + Increments += n; + } + + void Unref(int n) + { + Decrements += n; + if (Increments == Decrements) { + ++Zeros; + } + } +}; + +typedef TIntrusivePtr<TIntricateObject> TIntricateObjectPtr; + +void Ref(TIntricateObject* obj, int n = 1) +{ + obj->Ref(n); +} + +void Unref(TIntricateObject* obj, int n = 1) +{ + obj->Unref(n); +} + +MATCHER_P3(HasRefCounts, increments, decrements, zeros, + "Reference counter " \ + "was incremented " + ::testing::PrintToString(increments) + " times, " + + "was decremented " + ::testing::PrintToString(decrements) + " times, " + + "vanished to zero " + ::testing::PrintToString(zeros) + " times") +{ + Y_UNUSED(result_listener); + return + arg.Increments == increments && + arg.Decrements == decrements && + arg.Zeros == zeros; +} + +void PrintTo(const TIntricateObject& arg, ::std::ostream* os) +{ + *os << arg.Increments << " increments, " + << arg.Decrements << " decrements and " + << arg.Zeros << " times vanished"; +} + +// This is an object which creates intrusive pointers to the self +// during its construction. +class TObjectWithSelfPointers + : public TRefCounted +{ +public: + explicit TObjectWithSelfPointers(IOutputStream* output) + : Output_(output) + { + *Output_ << "Cb"; + + for (int i = 0; i < 3; ++i) { + *Output_ << '!'; + TIntrusivePtr<TObjectWithSelfPointers> ptr(this); + } + + *Output_ << "Ca"; + } + + virtual ~TObjectWithSelfPointers() + { + *Output_ << 'D'; + } + +private: + IOutputStream* const Output_; + +}; + +// This is a simple object with simple reference counting. +class TObjectWithSimpleRC + : public TRefCounted +{ +public: + explicit TObjectWithSimpleRC(IOutputStream* output) + : Output_(output) + { + *Output_ << 'C'; + } + + virtual ~TObjectWithSimpleRC() + { + *Output_ << 'D'; + } + + void DoSomething() + { + *Output_ << '!'; + } + +private: + IOutputStream* const Output_; + +}; + +// This is a simple object with full-fledged reference counting. +class TObjectWithFullRC + : public TRefCounted +{ +public: + explicit TObjectWithFullRC(IOutputStream* output) + : Output_(output) + { + *Output_ << 'C'; + } + + virtual ~TObjectWithFullRC() + { + *Output_ << 'D'; + } + + void DoSomething() + { + *Output_ << '!'; + } + +private: + IOutputStream* const Output_; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +TEST(TAtomicPtrTest, Empty) +{ + TIntricateObjectPtr emptyPointer; + EXPECT_EQ(nullptr, emptyPointer.Get()); +} + +// Reserved ref count. +constexpr int RRC = 65535; + +TEST(TAtomicPtrTest, Basic) +{ + TIntricateObject object; + + EXPECT_THAT(object, HasRefCounts(0, 0, 0)); + + { + TIntricateObjectPtr owningPointer(&object); + EXPECT_THAT(object, HasRefCounts(1, 0, 0)); + EXPECT_EQ(&object, owningPointer.Get()); + } + + EXPECT_THAT(object, HasRefCounts(1, 1, 1)); + + + { + TIntricateObjectPtr owningPointer(&object); + TAtomicIntrusivePtr<TIntricateObject> atomicPointer(owningPointer); + + EXPECT_THAT(object, HasRefCounts(2 + RRC, 1, 1)); + EXPECT_EQ(&object, owningPointer.Get()); + + + auto p1 = atomicPointer.Acquire(); + + EXPECT_THAT(object, HasRefCounts(2 + RRC, 1, 1)); + + p1.Reset(); + + EXPECT_THAT(object, HasRefCounts(2 + RRC, 2, 1)); + + owningPointer.Reset(); + + EXPECT_THAT(object, HasRefCounts(2 + RRC, 3, 1)); + } + + EXPECT_THAT(object, HasRefCounts(2 + RRC, 2 + RRC, 2)); +} + +TEST(TAtomicPtrTest, Acquire) +{ + TIntricateObject object; + { + TAtomicIntrusivePtr<TIntricateObject> atomicPtr{TIntricateObjectPtr(&object)}; + EXPECT_THAT(object, HasRefCounts(RRC, 0, 0)); + + for (int i = 0; i < RRC / 2; ++i) { + { + auto tmp = atomicPtr.Acquire(); + EXPECT_THAT(object, HasRefCounts(RRC, i, 0)); + } + EXPECT_THAT(object, HasRefCounts(RRC, i + 1, 0)); + } + + { + auto tmp = atomicPtr.Acquire(); + EXPECT_THAT(object, HasRefCounts( RRC + RRC / 2, RRC - 1, 0)); + } + + EXPECT_THAT(object, HasRefCounts(RRC + RRC / 2, RRC, 0)); + } + + EXPECT_THAT(object, HasRefCounts(RRC + RRC / 2, RRC + RRC / 2, 1)); +} + +TEST(TAtomicPtrTest, CAS) +{ + TIntricateObject o1; + TIntricateObject o2; + { + + TAtomicIntrusivePtr<TIntricateObject> atomicPtr{TIntricateObjectPtr(&o1)}; + EXPECT_THAT(o1, HasRefCounts(RRC, 0, 0)); + + TIntricateObjectPtr p2(&o2); + EXPECT_THAT(o2, HasRefCounts(1, 0, 0)); + + void* rawPtr = &o1; + EXPECT_TRUE(atomicPtr.CompareAndSwap(rawPtr, std::move(p2))); + EXPECT_EQ(rawPtr, &o1); + + EXPECT_THAT(o1, HasRefCounts(RRC, RRC, 1)); + EXPECT_THAT(o2, HasRefCounts(RRC, 0, 0)); + + rawPtr = nullptr; + EXPECT_FALSE(atomicPtr.CompareAndSwap(rawPtr, TIntricateObjectPtr(&o1))); + EXPECT_EQ(rawPtr, &o2); + + EXPECT_THAT(o1, HasRefCounts(2 * RRC, 2 * RRC, 2)); + EXPECT_THAT(o2, HasRefCounts(RRC, 0, 0)); + } + + EXPECT_THAT(o2, HasRefCounts(RRC, RRC, 1)); +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace +} // namespace NYT diff --git a/library/cpp/yt/memory/unittests/intrusive_ptr_ut.cpp b/library/cpp/yt/memory/unittests/intrusive_ptr_ut.cpp index 622bed0eb0..5c35612ee0 100644 --- a/library/cpp/yt/memory/unittests/intrusive_ptr_ut.cpp +++ b/library/cpp/yt/memory/unittests/intrusive_ptr_ut.cpp @@ -43,12 +43,12 @@ struct TIntricateObject typedef TIntrusivePtr<TIntricateObject> TIntricateObjectPtr; -void Ref(TIntricateObject* obj) +void Ref(TIntricateObject* obj, int /*n*/ = 1) { obj->Ref(); } -void Unref(TIntricateObject* obj) +void Unref(TIntricateObject* obj, int /*n*/ = 1) { obj->Unref(); } |