diff options
author | Devtools Arcadia <arcadia-devtools@yandex-team.ru> | 2022-02-07 18:08:42 +0300 |
---|---|---|
committer | Devtools Arcadia <arcadia-devtools@mous.vla.yp-c.yandex.net> | 2022-02-07 18:08:42 +0300 |
commit | 1110808a9d39d4b808aef724c861a2e1a38d2a69 (patch) | |
tree | e26c9fed0de5d9873cce7e00bc214573dc2195b7 /util/ysafeptr.h | |
download | ydb-1110808a9d39d4b808aef724c861a2e1a38d2a69.tar.gz |
intermediate changes
ref:cde9a383711a11544ce7e107a78147fb96cc4029
Diffstat (limited to 'util/ysafeptr.h')
-rw-r--r-- | util/ysafeptr.h | 429 |
1 files changed, 429 insertions, 0 deletions
diff --git a/util/ysafeptr.h b/util/ysafeptr.h new file mode 100644 index 0000000000..af7dfd4bed --- /dev/null +++ b/util/ysafeptr.h @@ -0,0 +1,429 @@ +#pragma once + +#include <stddef.h> +#include <util/system/yassert.h> +#include <util/system/defaults.h> +#include <util/system/tls.h> + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// There are different templates of pointers: +// 1. Simple pointers. +// 2. TPtr with refereces. +// 3. TObj/TMObj with ownership. After destruction of a TObj the object it referenced to is +// cleaned up and marked as non valid. Similarly does TMobj organizing the parallel ownership +// of an object. +// +// Limitations: +// 1. It may be necessary to use BASIC_REGISTER_CLASS() in .cpp files to be able to use a +// pointer to a forward declared class. +// 2. It's prohibited to override the 'new' operator, since the standard 'delete' will be used +// for destruction of objects (because of 'delete this'). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(_MSC_VER) && defined(_DEBUG) + #include <util/system/winint.h> + #define CHECK_YPTR2 +#endif + +struct IBinSaver; + +class IObjectBase { +private: +#ifdef CHECK_YPTR2 + static Y_POD_THREAD(bool) DisableThreadCheck; + void CheckThreadId() { + if (dwThreadId == 0) + dwThreadId = GetCurrentThreadId(); + else + Y_ASSERT(dwThreadId == GetCurrentThreadId() || DisableThreadCheck); + } + void AddRef() { + CheckThreadId(); + ++RefData; + } + void AddObj(int nRef) { + CheckThreadId(); + ObjData += nRef; + } +#else + void CheckThreadId() { + } + void AddRef() { + ++RefData; + } + void AddObj(int nRef) { + ObjData += nRef; + } +#endif + void ReleaseRefComplete(); + void ReleaseObjComplete(int nMask); + void DecRef() { + CheckThreadId(); + --RefData; + } + void DecObj(int nRef) { + CheckThreadId(); + ObjData -= nRef; + } + void ReleaseRef() { + CheckThreadId(); + --RefData; + if (RefData == 0) + ReleaseRefComplete(); + } + void ReleaseObj(int nRef, int nMask) { + CheckThreadId(); + ObjData -= nRef; + if ((ObjData & nMask) == 0) + ReleaseObjComplete(nMask); + } + +protected: +#ifdef CHECK_YPTR2 + DWORD dwThreadId; +#endif + ui32 ObjData; + ui32 RefData; + // function should clear contents of object, easy to implement via consequent calls to + // destructor and constructor, this function should not be called directly, use Clear() + virtual void DestroyContents() = 0; + virtual ~IObjectBase() = default; + inline void CopyValidFlag(const IObjectBase& a) { + ObjData &= 0x7fffffff; + ObjData |= a.ObjData & 0x80000000; + } + +public: + IObjectBase() + : ObjData(0) + , RefData(0) + { +#ifdef CHECK_YPTR2 + dwThreadId = 0; +#endif + } + // do not copy refcount when copy object + IObjectBase(const IObjectBase& a) + : ObjData(0) + , RefData(0) + { +#ifdef CHECK_YPTR2 + dwThreadId = 0; +#endif + CopyValidFlag(a); + } + IObjectBase& operator=(const IObjectBase& a) { + CopyValidFlag(a); + return *this; + } +#ifdef CHECK_YPTR2 + static void SetThreadCheckMode(bool val) { + DisableThreadCheck = !val; + } + void ResetThreadId() { + Y_ASSERT(RefData == 0 && ObjData == 0); // can reset thread check only for ref free objects + dwThreadId = 0; + } +#else + static void SetThreadCheckMode(bool) { + } + void ResetThreadId() { + } +#endif + + // class name of derived class + virtual const char* GetClassName() const = 0; + + ui32 IsRefInvalid() const { + return (ObjData & 0x80000000); + } + ui32 IsRefValid() const { + return !IsRefInvalid(); + } + // reset data in class to default values, saves RefCount from destruction + void Clear() { + AddRef(); + DestroyContents(); + DecRef(); + } + + virtual int operator&(IBinSaver&) { + return 0; + } + + struct TRefO { + void AddRef(IObjectBase* pObj) { + pObj->AddObj(1); + } + void DecRef(IObjectBase* pObj) { + pObj->DecObj(1); + } + void Release(IObjectBase* pObj) { + pObj->ReleaseObj(1, 0x000fffff); + } + }; + struct TRefM { + void AddRef(IObjectBase* pObj) { + pObj->AddObj(0x100000); + } + void DecRef(IObjectBase* pObj) { + pObj->DecObj(0x100000); + } + void Release(IObjectBase* pObj) { + pObj->ReleaseObj(0x100000, 0x3ff00000); + } + }; + struct TRef { + void AddRef(IObjectBase* pObj) { + pObj->AddRef(); + } + void DecRef(IObjectBase* pObj) { + pObj->DecRef(); + } + void Release(IObjectBase* pObj) { + pObj->ReleaseRef(); + } + }; + friend struct IObjectBase::TRef; + friend struct IObjectBase::TRefO; + friend struct IObjectBase::TRefM; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// macro that helps to create neccessary members for proper operation of refcount system +// if class needs special destructor, use CFundament +#define OBJECT_METHODS(classname) \ +public: \ + virtual const char* GetClassName() const override { \ + return #classname; \ + } \ + static IObjectBase* NewSaveLoadNullItem() { \ + return new classname(); \ + } \ + \ +protected: \ + virtual void DestroyContents() override { \ + this->~classname(); \ + int nHoldRefs = this->RefData, nHoldObjs = this->ObjData; \ + new (this) classname(); \ + this->RefData += nHoldRefs; \ + this->ObjData += nHoldObjs; \ + } \ + \ +private: +#define OBJECT_NOCOPY_METHODS(classname) OBJECT_METHODS(classname) +#define BASIC_REGISTER_CLASS(classname) \ + Y_PRAGMA_DIAGNOSTIC_PUSH \ + Y_PRAGMA_NO_UNUSED_FUNCTION \ + template <> \ + IObjectBase* CastToObjectBaseImpl<classname>(classname * p, void*) { \ + return p; \ + } \ + template <> \ + classname* CastToUserObjectImpl<classname>(IObjectBase * p, classname*, void*) { \ + return dynamic_cast<classname*>(p); \ + } \ + Y_PRAGMA_DIAGNOSTIC_POP + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template <class TUserObj> +IObjectBase* CastToObjectBaseImpl(TUserObj* p, void*); +template <class TUserObj> +IObjectBase* CastToObjectBaseImpl(TUserObj* p, IObjectBase*) { + return p; +} +template <class TUserObj> +TUserObj* CastToUserObjectImpl(IObjectBase* p, TUserObj*, void*); +template <class TUserObj> +TUserObj* CastToUserObjectImpl(IObjectBase* _p, TUserObj*, IObjectBase*) { + return dynamic_cast<TUserObj*>(_p); +} +template <class TUserObj> +inline IObjectBase* CastToObjectBase(TUserObj* p) { + return CastToObjectBaseImpl(p, p); +} +template <class TUserObj> +inline const IObjectBase* CastToObjectBase(const TUserObj* p) { + return p; +} +template <class TUserObj> +inline TUserObj* CastToUserObject(IObjectBase* p, TUserObj* pu) { + return CastToUserObjectImpl(p, pu, pu); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TObject - base object for reference counting, TUserObj - user object name +// TRef - struct with AddRef/DecRef/Release methods for refcounting to use +template <class TUserObj, class TRef> +class TPtrBase { +private: + TUserObj* ptr; + + void AddRef(TUserObj* _ptr) { + TRef p; + if (_ptr) + p.AddRef(CastToObjectBase(_ptr)); + } + void DecRef(TUserObj* _ptr) { + TRef p; + if (_ptr) + p.DecRef(CastToObjectBase(_ptr)); + } + void Release(TUserObj* _ptr) { + TRef p; + if (_ptr) + p.Release(CastToObjectBase(_ptr)); + } + +protected: + void SetObject(TUserObj* _ptr) { + TUserObj* pOld = ptr; + ptr = _ptr; + AddRef(ptr); + Release(pOld); + } + +public: + TPtrBase() + : ptr(nullptr) + { + } + TPtrBase(TUserObj* _ptr) + : ptr(_ptr) + { + AddRef(ptr); + } + TPtrBase(const TPtrBase& a) + : ptr(a.ptr) + { + AddRef(ptr); + } + ~TPtrBase() { + Release(ptr); + } + + void Set(TUserObj* _ptr) { + SetObject(_ptr); + } + TUserObj* Extract() { + TUserObj* pRes = ptr; + DecRef(ptr); + ptr = nullptr; + return pRes; + } + + const char* GetClassName() const { + return ptr->GetClassName(); + } + + // assignment operators + TPtrBase& operator=(TUserObj* _ptr) { + Set(_ptr); + return *this; + } + TPtrBase& operator=(const TPtrBase& a) { + Set(a.ptr); + return *this; + } + // access + TUserObj* operator->() const { + return ptr; + } + operator TUserObj*() const { + return ptr; + } + TUserObj* Get() const { + return ptr; + } + IObjectBase* GetBarePtr() const { + return CastToObjectBase(ptr); + } + int operator&(IBinSaver& f); +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// +template <class T> +inline bool IsValid(T* p) { + return p != nullptr && !CastToObjectBase(p)->IsRefInvalid(); +} +template <class T, class TRef> +inline bool IsValid(const TPtrBase<T, TRef>& p) { + return p.Get() && !p.GetBarePtr()->IsRefInvalid(); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// +#define BASIC_PTR_DECLARE(TPtrName, TRef) \ + template <class T> \ + class TPtrName: public TPtrBase<T, TRef> { \ + using CBase = TPtrBase<T, TRef>; \ + \ + public: \ + using CDestType = T; \ + TPtrName() { \ + } \ + TPtrName(T* _ptr) \ + : CBase(_ptr) \ + { \ + } \ + TPtrName(const TPtrName& a) \ + : CBase(a) \ + { \ + } \ + TPtrName& operator=(T* _ptr) { \ + this->Set(_ptr); \ + return *this; \ + } \ + TPtrName& operator=(const TPtrName& a) { \ + this->SetObject(a.Get()); \ + return *this; \ + } \ + int operator&(IBinSaver& f) { \ + return (*(CBase*)this) & (f); \ + } \ + }; + +BASIC_PTR_DECLARE(TPtr, IObjectBase::TRef) +BASIC_PTR_DECLARE(TObj, IObjectBase::TRefO) +BASIC_PTR_DECLARE(TMObj, IObjectBase::TRefM) +// misuse guard +template <class T> +inline bool IsValid(TObj<T>* p) { + return p->YouHaveMadeMistake(); +} +template <class T> +inline bool IsValid(TPtr<T>* p) { + return p->YouHaveMadeMistake(); +} +template <class T> +inline bool IsValid(TMObj<T>* p) { + return p->YouHaveMadeMistake(); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// +// assumes base class is IObjectBase +template <class T> +class TDynamicCast { + T* ptr; + +public: + template <class TT> + TDynamicCast(TT* _ptr) { + ptr = dynamic_cast<T*>(CastToObjectBase(_ptr)); + } + template <class TT> + TDynamicCast(const TT* _ptr) { + ptr = dynamic_cast<T*>(CastToObjectBase(const_cast<TT*>(_ptr))); + } + template <class T1, class T2> + TDynamicCast(const TPtrBase<T1, T2>& _ptr) { + ptr = dynamic_cast<T*>(_ptr.GetBarePtr()); + } + operator T*() const { + return ptr; + } + T* operator->() const { + return ptr; + } + T* Get() const { + return ptr; + } +}; +template <class T> +inline bool IsValid(const TDynamicCast<T>& p) { + return IsValid(p.Get()); +} |