aboutsummaryrefslogtreecommitdiffstats
path: root/library/cpp/messagebus/misc/weak_ptr.h
blob: 45f05cae5696ff57cef0648a3a2c4d3e5fe90d81 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#pragma once

#include <util/generic/ptr.h>
#include <util/system/mutex.h>

template <typename T>
struct TWeakPtr;

/// CRTP base giving TSelf an atomic intrusive strong reference count
/// (Ref/UnRef/DecRef/RefCount, as used by TIntrusivePtr) plus support for
/// TWeakPtr<TSelf>.
///
/// Every instance owns a separately allocated TRef cell that points back at
/// the object. TWeakPtr holds a reference to that cell rather than to the
/// object, so the cell can outlive the object. When the last strong reference
/// is dropped, UnRef() clears the back pointer under the cell's mutex and then
/// deletes the object with `delete static_cast<TSelf*>(this)`; after that,
/// weak pointers yield null.
template <typename TSelf>
struct TWeakRefCounted {
    template <typename>
    friend struct TWeakPtr;

private:
    // Back-pointer cell shared by the object and all of its weak pointers.
    struct TRef: public TAtomicRefCount<TRef> {
        // Guards Outer; serializes weak-pointer Get() against Release().
        TMutex Mutex;
        // The referenced object, or nullptr once it has been released.
        TSelf* Outer;

        explicit TRef(TSelf* outer)
            : Outer(outer)
        {
        }

        // Detaches the cell from its object. Called from UnRef() right
        // before the object is deleted.
        void Release() {
            TGuard<TMutex> g(Mutex);
            Y_ASSERT(!!Outer);
            Outer = nullptr;
        }

        // Returns a new strong reference to the object, or null if the
        // object has been released or its strong count has already dropped
        // to zero (i.e. it is about to be deleted).
        TIntrusivePtr<TSelf> Get() {
            TGuard<TMutex> g(Mutex);
            if (!Outer) {
                return nullptr;
            }

            // UnRef() decrements Counter WITHOUT holding Mutex, so Outer can
            // still be set while the count is already zero: the owner is on
            // its way into Release() and will then delete the object.
            // "Read RefCount(), then increment" would race with that, so take
            // a tentative reference and look at the value the increment
            // itself produced. A result of 1 means the count was zero and the
            // object must not be resurrected. Undoing the tentative reference
            // cannot cause a second deletion: with a zero count no strong
            // holders exist, other Get() calls are serialized by Mutex, and
            // the owner cannot get past Release() (and so cannot delete)
            // while we hold Mutex.
            if (Outer->Counter.Inc() == 1) {
                Outer->Counter.Dec();
                return nullptr;
            }

            TIntrusivePtr<TSelf> result(Outer);
            // Drop the tentative reference; `result` now holds its own.
            Outer->Counter.Dec();
            return result;
        }
    };

    TAtomicCounter Counter;
    TIntrusivePtr<TRef> RefPtr;

public:
    TWeakRefCounted()
        : RefPtr(new TRef(static_cast<TSelf*>(this)))
    {
    }

    // A copy is a distinct object: it starts with zero strong references and
    // gets its own weak cell. The implicit copy constructor would have copied
    // RefPtr, making the copy share the original's cell, so destroying the
    // copy would have detached every TWeakPtr to the original.
    TWeakRefCounted(const TWeakRefCounted&)
        : RefPtr(new TRef(static_cast<TSelf*>(this)))
    {
    }

    // Assignment copies neither the reference count nor the weak cell; both
    // belong to the object's identity, not to its value.
    TWeakRefCounted& operator=(const TWeakRefCounted&) {
        return *this;
    }

    // Adds a strong reference.
    void Ref() {
        Counter.Inc();
    }

    // Drops a strong reference. The last one detaches all weak pointers and
    // deletes the object.
    void UnRef() {
        if (Counter.Dec() == 0) {
            RefPtr->Release();

            // Drop our reference to the cell before deleting the object, so
            // the destructor does not touch it (weak pointers may still keep
            // the cell alive).
            RefPtr.Drop();

            delete static_cast<TSelf*>(this);
        }
    }

    // Drops a strong reference without ever deleting the object.
    void DecRef() {
        Counter.Dec();
    }

    // Current number of strong references.
    unsigned RefCount() const {
        return Counter.Val();
    }
};

template <typename T>
struct TWeakPtr {
private:
    typedef TIntrusivePtr<typename T::TRef> TRefPtr;
    TRefPtr RefPtr;

public:
    TWeakPtr() {
    }

    TWeakPtr(T* t) {
        if (!!t) {
            RefPtr = t->RefPtr;
        }
    }

    TWeakPtr(TIntrusivePtr<T> t) {
        if (!!t) {
            RefPtr = t->RefPtr;
        }
    }

    TIntrusivePtr<T> Get() {
        if (!RefPtr) {
            return nullptr; 
        } else {
            return RefPtr->Get();
        }
    }
};