Serenity Operating System
1/*
2 * Copyright (c) 2018-2020, Andreas Kling <kling@serenityos.org>
3 *
4 * SPDX-License-Identifier: BSD-2-Clause
5 */
6
7#pragma once
8
9#include <Kernel/Library/LockWeakable.h>
10
11namespace AK {
12
13template<typename T>
14class [[nodiscard]] LockWeakPtr {
15 template<typename U>
16 friend class LockWeakable;
17
18public:
19 LockWeakPtr() = default;
20
21 template<typename U>
22 LockWeakPtr(WeakPtr<U> const& other)
23 requires(IsBaseOf<T, U>)
24 : m_link(other.m_link)
25 {
26 }
27
28 template<typename U>
29 LockWeakPtr(WeakPtr<U>&& other)
30 requires(IsBaseOf<T, U>)
31 : m_link(other.take_link())
32 {
33 }
34
35 template<typename U>
36 LockWeakPtr& operator=(WeakPtr<U>&& other)
37 requires(IsBaseOf<T, U>)
38 {
39 m_link = other.take_link();
40 return *this;
41 }
42
43 template<typename U>
44 LockWeakPtr& operator=(WeakPtr<U> const& other)
45 requires(IsBaseOf<T, U>)
46 {
47 if ((void const*)this != (void const*)&other)
48 m_link = other.m_link;
49 return *this;
50 }
51
52 LockWeakPtr& operator=(nullptr_t)
53 {
54 clear();
55 return *this;
56 }
57
58 template<typename U>
59 LockWeakPtr(U const& object)
60 requires(IsBaseOf<T, U>)
61 : m_link(object.template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link())
62 {
63 }
64
65 template<typename U>
66 LockWeakPtr(U const* object)
67 requires(IsBaseOf<T, U>)
68 {
69 if (object)
70 m_link = object->template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
71 }
72
73 template<typename U>
74 LockWeakPtr(LockRefPtr<U> const& object)
75 requires(IsBaseOf<T, U>)
76 {
77 object.do_while_locked([&](U* obj) {
78 if (obj)
79 m_link = obj->template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
80 });
81 }
82
83 template<typename U>
84 LockWeakPtr(NonnullLockRefPtr<U> const& object)
85 requires(IsBaseOf<T, U>)
86 {
87 object.do_while_locked([&](U* obj) {
88 if (obj)
89 m_link = obj->template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
90 });
91 }
92
93 template<typename U>
94 LockWeakPtr& operator=(U const& object)
95 requires(IsBaseOf<T, U>)
96 {
97 m_link = object.template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
98 return *this;
99 }
100
101 template<typename U>
102 LockWeakPtr& operator=(U const* object)
103 requires(IsBaseOf<T, U>)
104 {
105 if (object)
106 m_link = object->template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
107 else
108 m_link = nullptr;
109 return *this;
110 }
111
112 template<typename U>
113 LockWeakPtr& operator=(LockRefPtr<U> const& object)
114 requires(IsBaseOf<T, U>)
115 {
116 object.do_while_locked([&](U* obj) {
117 if (obj)
118 m_link = obj->template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
119 else
120 m_link = nullptr;
121 });
122 return *this;
123 }
124
125 template<typename U>
126 LockWeakPtr& operator=(NonnullLockRefPtr<U> const& object)
127 requires(IsBaseOf<T, U>)
128 {
129 object.do_while_locked([&](U* obj) {
130 if (obj)
131 m_link = obj->template try_make_weak_ptr<U>().release_value_but_fixme_should_propagate_errors().take_link();
132 else
133 m_link = nullptr;
134 });
135 return *this;
136 }
137
138 [[nodiscard]] LockRefPtr<T> strong_ref() const
139 {
140 // This only works with RefCounted objects, but it is the only
141 // safe way to get a strong reference from a LockWeakPtr. Any code
142 // that uses objects not derived from RefCounted will have to
143 // use unsafe_ptr(), but as the name suggests, it is not safe...
144 LockRefPtr<T> ref;
145 // Using do_while_locked protects against a race with clear()!
146 m_link.do_while_locked([&](LockWeakLink* link) {
147 if (link)
148 ref = link->template strong_ref<T>();
149 });
150 return ref;
151 }
152
153 [[nodiscard]] T* unsafe_ptr() const
154 {
155 T* ptr = nullptr;
156 m_link.do_while_locked([&](LockWeakLink* link) {
157 if (link)
158 ptr = link->unsafe_ptr<T>();
159 });
160 return ptr;
161 }
162
163 operator bool() const { return m_link ? !m_link->is_null() : false; }
164
165 [[nodiscard]] bool is_null() const { return !m_link || m_link->is_null(); }
166 void clear() { m_link = nullptr; }
167
168 [[nodiscard]] LockRefPtr<LockWeakLink> take_link() { return move(m_link); }
169
170private:
171 LockWeakPtr(LockRefPtr<LockWeakLink> const& link)
172 : m_link(link)
173 {
174 }
175
176 LockRefPtr<LockWeakLink> m_link;
177};
178
179template<typename T>
180template<typename U>
181inline ErrorOr<LockWeakPtr<U>> LockWeakable<T>::try_make_weak_ptr() const
182{
183 if constexpr (IsBaseOf<AtomicRefCountedBase, T>) {
184 // Checking m_being_destroyed isn't sufficient when dealing with
185 // a RefCounted type.The reference count will drop to 0 before the
186 // destructor is invoked and revoke_weak_ptrs is called. So, try
187 // to add a ref (which should fail if the ref count is at 0) so
188 // that we prevent the destructor and revoke_weak_ptrs from being
189 // triggered until we're done.
190 if (!static_cast<T const*>(this)->try_ref())
191 return LockWeakPtr<U> {};
192 } else {
193 // For non-RefCounted types this means a weak reference can be
194 // obtained until the ~LockWeakable destructor is invoked!
195 if (m_being_destroyed.load(AK::MemoryOrder::memory_order_acquire))
196 return LockWeakPtr<U> {};
197 }
198 if (!m_link) {
199 // There is a small chance that we create a new LockWeakLink and throw
200 // it away because another thread beat us to it. But the window is
201 // pretty small and the overhead isn't terrible.
202 m_link.assign_if_null(TRY(adopt_nonnull_lock_ref_or_enomem(new (nothrow) LockWeakLink(const_cast<T&>(static_cast<T const&>(*this))))));
203 }
204
205 LockWeakPtr<U> weak_ptr(m_link);
206
207 if constexpr (IsBaseOf<AtomicRefCountedBase, T>) {
208 // Now drop the reference we temporarily added
209 if (static_cast<T const*>(this)->unref()) {
210 // We just dropped the last reference, which should have called
211 // revoke_weak_ptrs, which should have invalidated our weak_ptr
212 VERIFY(!weak_ptr.strong_ref());
213 return LockWeakPtr<U> {};
214 }
215 }
216 return weak_ptr;
217}
218
219template<typename T>
220struct Formatter<LockWeakPtr<T>> : Formatter<T const*> {
221 ErrorOr<void> format(FormatBuilder& builder, LockWeakPtr<T> const& value)
222 {
223 auto ref = value.strong_ref();
224 return Formatter<T const*>::format(builder, ref.ptr());
225 }
226};
227
228template<typename T>
229ErrorOr<LockWeakPtr<T>> try_make_weak_ptr_if_nonnull(T const* ptr)
230{
231 if (ptr) {
232 return ptr->template try_make_weak_ptr<T>();
233 }
234 return LockWeakPtr<T> {};
235}
236
237}
238
239using AK::LockWeakPtr;