Next Generation WASM Microkernel Operating System
1// Copyright 2025 Jonas Kruckenberg
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::task::Id;
9use crate::task::TaskRef;
10use alloc::boxed::Box;
11use alloc::string::String;
12use core::any::Any;
13use core::fmt;
14use core::future::Future;
15use core::marker::PhantomData;
16use core::ops::Deref;
17use core::panic::{RefUnwindSafe, UnwindSafe};
18use core::pin::Pin;
19use core::task::{Context, Poll};
20
/// Handle to a task that can be awaited for the task's output.
///
/// Polling the handle as a [`Future`] yields `Result<T, JoinError<T>>`, where
/// `T` is the output type the handle was created for.
pub struct JoinHandle<T> {
    /// Holds the [`TaskRef`] until a poll observes the join as ready; `Empty`
    /// afterwards.
    state: JoinHandleState,
    /// ID of the task, copied from the `TaskRef` when the handle is created.
    id: Id,
    /// Ties the handle to the output type `T` without storing a `T`.
    _p: PhantomData<T>,
}
// Compile-time check that `JoinHandle<()>` implements `Send`.
static_assertions::assert_impl_all!(JoinHandle<()>: Send);
27
/// Tracks whether a [`JoinHandle`] still owns its reference to the task.
#[derive(Debug)]
enum JoinHandleState {
    /// The handle holds a reference to the task.
    Task(TaskRef),
    /// The reference has been taken out by the `Future` impl of `JoinHandle`.
    Empty,
}
33
/// Error half of the `Result` produced by awaiting a [`JoinHandle`].
///
/// Describes a task that was cancelled or that panicked; see
/// [`JoinErrorKind`] for the individual reasons.
pub struct JoinError<T> {
    /// Why joining the task failed.
    kind: JoinErrorKind,
    /// ID of the task this error belongs to.
    id: Id,
    /// The task's output, if one was attached to the error (see
    /// [`JoinError::output`]).
    output: Option<T>,
}
39
/// The reason a [`JoinError`] was produced.
#[derive(Debug)]
#[non_exhaustive]
pub enum JoinErrorKind {
    /// The task was cancelled. `completed` distinguishes cancellation after
    /// completion (`true`) from cancellation before completion (`false`).
    Cancelled { completed: bool },
    /// The task panicked; carries the panic payload.
    Panic(Box<dyn Any + Send + 'static>),
}
46
47// === impl JoinHandle ===
48
// Explicitly marks `JoinHandle<T>` as unwind safe for every `T`.
impl<T> UnwindSafe for JoinHandle<T> {}
50
// Explicitly marks shared references to `JoinHandle<T>` as unwind safe for every `T`.
impl<T> RefUnwindSafe for JoinHandle<T> {}
52
// `JoinHandle<T>` is `Unpin` for every `T`; the `Future` impl relies on this
// when it calls `Pin::get_mut`.
impl<T> Unpin for JoinHandle<T> {}
54
55impl<T> Drop for JoinHandle<T> {
56 fn drop(&mut self) {
57 // if the JoinHandle has not already been consumed, clear the join
58 // handle flag on the task.
59 if let JoinHandleState::Task(ref task) = self.state {
60 tracing::trace!(
61 state=?self.state,
62 task.id=?task.id(),
63 consumed=false,
64 "drop JoinHandle"
65 );
66
67 task.state().drop_join_handle();
68 } else {
69 tracing::trace!(
70 state=?self.state,
71 consumed=false,
72 "drop JoinHandle"
73 );
74 }
75 }
76}
77
78impl<T> fmt::Debug for JoinHandle<T>
79where
80 T: fmt::Debug,
81{
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 f.debug_struct("JoinHandle")
84 .field("output", &core::any::type_name::<T>())
85 .field("task", &self.state)
86 .field("id", &self.id)
87 .finish()
88 }
89}
90
91impl<T> Future for JoinHandle<T> {
92 type Output = Result<T, JoinError<T>>;
93
94 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95 let this = self.get_mut();
96 let task = match core::mem::replace(&mut this.state, JoinHandleState::Empty) {
97 JoinHandleState::Task(task) => task,
98 JoinHandleState::Empty => {
99 panic!("`TaskRef` only taken while polling a `JoinHandle`; this is a bug")
100 }
101 };
102
103 // // Keep track of task budget
104 // TODO let coop = ready!(crate::runtime::coop::poll_proceed(cx));
105
106 // Safety: the `JoinHandle` must have been constructed with the
107 // task's actual output type!
108 let poll = unsafe { task.poll_join::<T>(cx) };
109
110 if poll.is_pending() {
111 this.state = JoinHandleState::Task(task);
112 } else {
113 // TODO coop.made_progress();
114
115 // clear join interest
116 task.state().drop_join_handle();
117 }
118 poll
119 }
120}
121
122// ==== PartialEq impls for JoinHandle/TaskRef ====
123
124impl<T> PartialEq<TaskRef> for JoinHandle<T> {
125 fn eq(&self, other: &TaskRef) -> bool {
126 match self.state {
127 JoinHandleState::Task(ref task) => task == other,
128 _ => false,
129 }
130 }
131}
132
133impl<T> PartialEq<&'_ TaskRef> for JoinHandle<T> {
134 fn eq(&self, other: &&TaskRef) -> bool {
135 match self.state {
136 JoinHandleState::Task(ref task) => task == *other,
137 _ => false,
138 }
139 }
140}
141
142impl<T> PartialEq<JoinHandle<T>> for TaskRef {
143 fn eq(&self, other: &JoinHandle<T>) -> bool {
144 match other.state {
145 JoinHandleState::Task(ref task) => self == task,
146 _ => false,
147 }
148 }
149}
150
151impl<T> PartialEq<&'_ JoinHandle<T>> for TaskRef {
152 fn eq(&self, other: &&JoinHandle<T>) -> bool {
153 match other.state {
154 JoinHandleState::Task(ref task) => self == task,
155 _ => false,
156 }
157 }
158}
159
160// ==== PartialEq impls for JoinHandle/Id ====
161
162impl<T> PartialEq<Id> for JoinHandle<T> {
163 #[inline]
164 fn eq(&self, other: &Id) -> bool {
165 self.id == *other
166 }
167}
168
169impl<T> PartialEq<JoinHandle<T>> for Id {
170 #[inline]
171 fn eq(&self, other: &JoinHandle<T>) -> bool {
172 *self == other.id
173 }
174}
175
176impl<T> PartialEq<&'_ JoinHandle<T>> for Id {
177 #[inline]
178 fn eq(&self, other: &&JoinHandle<T>) -> bool {
179 *self == other.id
180 }
181}
182
183impl<T> JoinHandle<T> {
184 pub(crate) fn new(task: TaskRef) -> Self {
185 task.state().create_join_handle();
186
187 Self {
188 id: task.id(),
189 state: JoinHandleState::Task(task),
190 _p: PhantomData,
191 }
192 }
193
194 /// Cancels the task associated with the handle.
195 ///
196 /// Awaiting a cancelled task might complete as usual if the task was already completed at
197 /// the time it was cancelled, but most likely it will fail with a [cancelled] [`JoinError`].
198 ///
199 /// See [the module level docs] for details.
200 ///
201 /// [cancelled]: JoinError::is_cancelled
202 /// [the module level docs]: super#cancellation
203 pub fn cancel(&self) -> bool {
204 match self.state {
205 JoinHandleState::Task(ref task) => task.cancel(),
206 _ => false,
207 }
208 }
209
210 #[inline]
211 #[must_use]
212 pub fn is_complete(&self) -> bool {
213 match self.state {
214 JoinHandleState::Task(ref task) => task.is_complete(),
215 // if the `JoinHandle`'s `TaskRef` has been taken, we know the
216 // `Future` impl for `JoinHandle` completed, and the task has
217 // _definitely_ completed.
218 _ => true,
219 }
220 }
221}
222
223// === impl JoinError ===
224
225impl JoinError<()> {
226 pub(super) fn cancelled(completed: bool, id: Id) -> Self {
227 Self {
228 kind: JoinErrorKind::Cancelled { completed },
229 id,
230 output: None,
231 }
232 }
233
234 pub(super) fn with_output<T>(self, output: Option<T>) -> JoinError<T> {
235 JoinError {
236 kind: self.kind,
237 id: self.id,
238 output,
239 }
240 }
241}
242
243impl<T> JoinError<T> {
244 pub(super) fn panic(id: Id, err: Box<dyn Any + Send + 'static>) -> Self {
245 Self {
246 kind: JoinErrorKind::Panic(err),
247 id,
248 output: None,
249 }
250 }
251
252 pub fn is_completed(&self) -> bool {
253 matches!(&self.kind, JoinErrorKind::Cancelled { completed: true })
254 }
255
256 /// Returns true if the error was caused by the task being cancelled.
257 ///
258 /// See [the module level docs] for more information on cancellation.
259 ///
260 /// [the module level docs]: crate::task#cancellation
261 pub fn is_cancelled(&self) -> bool {
262 matches!(&self.kind, JoinErrorKind::Cancelled { .. })
263 }
264
265 /// Returns true if the error was caused by the task panicking.
266 ///
267 /// # Examples
268 ///
269 // ```
270 // use std::panic;
271 //
272 // #[tokio::main]
273 // async fn main() {
274 // let err = tokio::spawn(async {
275 // panic!("boom");
276 // }).await.unwrap_err();
277 //
278 // assert!(err.is_panic());
279 // }
280 // ```
281 pub fn is_panic(&self) -> bool {
282 matches!(&self.kind, JoinErrorKind::Panic(_))
283 }
284
285 /// Consumes the join error, returning the object with which the task panicked.
286 ///
287 /// # Panics
288 ///
289 /// `into_panic()` panics if the `Error` does not represent the underlying
290 /// task terminating with a panic. Use `is_panic` to check the error reason
291 /// or `try_into_panic` for a variant that does not panic.
292 ///
293 /// # Examples
294 //
295 // ```should_panic
296 // use std::panic;
297 //
298 // #[tokio::main]
299 // async fn main() {
300 // let err = tokio::spawn(async {
301 // panic!("boom");
302 // }).await.unwrap_err();
303 //
304 // if err.is_panic() {
305 // // Resume the panic on the main task
306 // panic::begin_unwind(err.into_panic());
307 // }
308 // }
309 // ```
310 #[track_caller]
311 pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
312 self.try_into_panic()
313 .expect("`JoinError` reason is not a panic.")
314 }
315
316 /// Consumes the join error, returning the object with which the task
317 /// panicked if the task terminated due to a panic. Otherwise, `self` is
318 /// returned.
319 ///
320 // # Examples
321 //
322 // ```should_panic
323 // use std::panic;
324 //
325 // let err = tokio::spawn(async {
326 // panic!("boom");
327 // }).await.unwrap_err();
328 //
329 // if let Ok(reason) = err.try_into_panic() {
330 // /// Resume the panic on the main task
331 // panic::begin_unwind(reason);
332 // }
333 // ```
334 ///
335 /// # Errors
336 ///
337 /// Returns an `Err(Self)` when the error was **not** a panic.
338 pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, Self> {
339 match self.kind {
340 JoinErrorKind::Panic(p) => Ok(p),
341 _ => Err(self),
342 }
343 }
344
345 /// Returns a [task ID] that identifies the task which errored relative to
346 /// other currently spawned tasks.
347 ///
348 /// [task ID]: Id
349 pub fn id(&self) -> Id {
350 self.id
351 }
352
353 /// Returns the task's output, if the task completed successfully before it
354 /// was canceled.
355 ///
356 /// Otherwise, returns `None`.
357 pub fn output(self) -> Option<T> {
358 self.output
359 }
360}
361
362impl<T> fmt::Display for JoinError<T> {
363 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
364 match &self.kind {
365 JoinErrorKind::Cancelled { completed: false } => {
366 write!(fmt, "task {} was cancelled before completion", self.id)
367 }
368 JoinErrorKind::Cancelled { completed: true } => {
369 write!(fmt, "task {} was cancelled after completion", self.id)
370 }
371 JoinErrorKind::Panic(p) => {
372 write!(
373 fmt,
374 "task {} panicked with message {:?}",
375 self.id,
376 panic_payload_as_str(p.deref())
377 )
378 }
379 }
380 }
381}
382
383impl<T> fmt::Debug for JoinError<T> {
384 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
385 match &self.kind {
386 JoinErrorKind::Cancelled { completed } => write!(
387 fmt,
388 "JoinError::Cancelled({:?}, completed: {completed})",
389 self.id
390 ),
391 JoinErrorKind::Panic(p) => {
392 write!(
393 fmt,
394 "JoinError::Panic({:?}, {:?}, ...)",
395 self.id,
396 panic_payload_as_str(p.deref())
397 )
398 }
399 }
400 }
401}
402
// `JoinError<T>` is an error type for every `T`; the required `Debug` and
// `Display` impls above place no bounds on `T`.
impl<T> core::error::Error for JoinError<T> {}
404
/// Extracts a printable message from a panic payload.
///
/// Returns the payload's text when it is a `&'static str` or a `String`;
/// for any other payload type the placeholder `"Box<dyn Any>"` is returned.
fn panic_payload_as_str(payload: &dyn Any) -> &str {
    match payload.downcast_ref::<&'static str>() {
        Some(literal) => *literal,
        None => payload
            .downcast_ref::<String>()
            .map_or("Box<dyn Any>", String::as_str),
    }
}
413}