// Copyright 2025 Jonas Kruckenberg // // Licensed under the Apache License, Version 2.0, or the MIT license , at your option. This file may not be // copied, modified, or distributed except according to those terms. use crate::task::Id; use crate::task::TaskRef; use alloc::boxed::Box; use alloc::string::String; use core::any::Any; use core::fmt; use core::future::Future; use core::marker::PhantomData; use core::ops::Deref; use core::panic::{RefUnwindSafe, UnwindSafe}; use core::pin::Pin; use core::task::{Context, Poll}; pub struct JoinHandle { state: JoinHandleState, id: Id, _p: PhantomData, } static_assertions::assert_impl_all!(JoinHandle<()>: Send); #[derive(Debug)] enum JoinHandleState { Task(TaskRef), Empty, } pub struct JoinError { kind: JoinErrorKind, id: Id, output: Option, } #[derive(Debug)] #[non_exhaustive] pub enum JoinErrorKind { Cancelled { completed: bool }, Panic(Box), } // === impl JoinHandle === impl UnwindSafe for JoinHandle {} impl RefUnwindSafe for JoinHandle {} impl Unpin for JoinHandle {} impl Drop for JoinHandle { fn drop(&mut self) { // if the JoinHandle has not already been consumed, clear the join // handle flag on the task. if let JoinHandleState::Task(ref task) = self.state { tracing::trace!( state=?self.state, task.id=?task.id(), consumed=false, "drop JoinHandle" ); task.state().drop_join_handle(); } else { tracing::trace!( state=?self.state, consumed=false, "drop JoinHandle" ); } } } impl fmt::Debug for JoinHandle where T: fmt::Debug, { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("JoinHandle") .field("output", &core::any::type_name::()) .field("task", &self.state) .field("id", &self.id) .finish() } } impl Future for JoinHandle { type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.get_mut(); let task = match core::mem::replace(&mut this.state, JoinHandleState::Empty) { JoinHandleState::Task(task) => task, JoinHandleState::Empty => { panic!("`TaskRef` only taken while polling a `JoinHandle`; this is a bug") } }; // // Keep track of task budget // TODO let coop = ready!(crate::runtime::coop::poll_proceed(cx)); // Safety: the `JoinHandle` must have been constructed with the // task's actual output type! let poll = unsafe { task.poll_join::(cx) }; if poll.is_pending() { this.state = JoinHandleState::Task(task); } else { // TODO coop.made_progress(); // clear join interest task.state().drop_join_handle(); } poll } } // ==== PartialEq impls for JoinHandle/TaskRef ==== impl PartialEq for JoinHandle { fn eq(&self, other: &TaskRef) -> bool { match self.state { JoinHandleState::Task(ref task) => task == other, _ => false, } } } impl PartialEq<&'_ TaskRef> for JoinHandle { fn eq(&self, other: &&TaskRef) -> bool { match self.state { JoinHandleState::Task(ref task) => task == *other, _ => false, } } } impl PartialEq> for TaskRef { fn eq(&self, other: &JoinHandle) -> bool { match other.state { JoinHandleState::Task(ref task) => self == task, _ => false, } } } impl PartialEq<&'_ JoinHandle> for TaskRef { fn eq(&self, other: &&JoinHandle) -> bool { match other.state { JoinHandleState::Task(ref task) => self == task, _ => false, } } } // ==== PartialEq impls for JoinHandle/Id ==== impl PartialEq for JoinHandle { #[inline] fn eq(&self, other: &Id) -> bool { self.id == *other } } impl PartialEq> for Id { #[inline] fn eq(&self, other: &JoinHandle) -> bool { *self == other.id } } impl PartialEq<&'_ JoinHandle> for Id { #[inline] fn eq(&self, other: &&JoinHandle) -> bool { *self == other.id } } impl JoinHandle { pub(crate) fn new(task: TaskRef) -> Self { task.state().create_join_handle(); Self { id: task.id(), state: JoinHandleState::Task(task), _p: PhantomData, } } /// Cancels the task associated with the handle. /// /// Awaiting a cancelled task might complete as usual if the task was already completed at /// the time it was cancelled, but most likely it will fail with a [cancelled] [`JoinError`]. /// /// See [the module level docs] for details. /// /// [cancelled]: JoinError::is_cancelled /// [the module level docs]: super#cancellation pub fn cancel(&self) -> bool { match self.state { JoinHandleState::Task(ref task) => task.cancel(), _ => false, } } #[inline] #[must_use] pub fn is_complete(&self) -> bool { match self.state { JoinHandleState::Task(ref task) => task.is_complete(), // if the `JoinHandle`'s `TaskRef` has been taken, we know the // `Future` impl for `JoinHandle` completed, and the task has // _definitely_ completed. _ => true, } } } // === impl JoinError === impl JoinError<()> { pub(super) fn cancelled(completed: bool, id: Id) -> Self { Self { kind: JoinErrorKind::Cancelled { completed }, id, output: None, } } pub(super) fn with_output(self, output: Option) -> JoinError { JoinError { kind: self.kind, id: self.id, output, } } } impl JoinError { pub(super) fn panic(id: Id, err: Box) -> Self { Self { kind: JoinErrorKind::Panic(err), id, output: None, } } pub fn is_completed(&self) -> bool { matches!(&self.kind, JoinErrorKind::Cancelled { completed: true }) } /// Returns true if the error was caused by the task being cancelled. /// /// See [the module level docs] for more information on cancellation. /// /// [the module level docs]: crate::task#cancellation pub fn is_cancelled(&self) -> bool { matches!(&self.kind, JoinErrorKind::Cancelled { .. }) } /// Returns true if the error was caused by the task panicking. /// /// # Examples /// // ``` // use std::panic; // // #[tokio::main] // async fn main() { // let err = tokio::spawn(async { // panic!("boom"); // }).await.unwrap_err(); // // assert!(err.is_panic()); // } // ``` pub fn is_panic(&self) -> bool { matches!(&self.kind, JoinErrorKind::Panic(_)) } /// Consumes the join error, returning the object with which the task panicked. /// /// # Panics /// /// `into_panic()` panics if the `Error` does not represent the underlying /// task terminating with a panic. Use `is_panic` to check the error reason /// or `try_into_panic` for a variant that does not panic. /// /// # Examples // // ```should_panic // use std::panic; // // #[tokio::main] // async fn main() { // let err = tokio::spawn(async { // panic!("boom"); // }).await.unwrap_err(); // // if err.is_panic() { // // Resume the panic on the main task // panic::begin_unwind(err.into_panic()); // } // } // ``` #[track_caller] pub fn into_panic(self) -> Box { self.try_into_panic() .expect("`JoinError` reason is not a panic.") } /// Consumes the join error, returning the object with which the task /// panicked if the task terminated due to a panic. Otherwise, `self` is /// returned. /// // # Examples // // ```should_panic // use std::panic; // // let err = tokio::spawn(async { // panic!("boom"); // }).await.unwrap_err(); // // if let Ok(reason) = err.try_into_panic() { // /// Resume the panic on the main task // panic::begin_unwind(reason); // } // ``` /// /// # Errors /// /// Returns an `Err(Self)` when the error was **not** a panic. pub fn try_into_panic(self) -> Result, Self> { match self.kind { JoinErrorKind::Panic(p) => Ok(p), _ => Err(self), } } /// Returns a [task ID] that identifies the task which errored relative to /// other currently spawned tasks. /// /// [task ID]: Id pub fn id(&self) -> Id { self.id } /// Returns the task's output, if the task completed successfully before it /// was canceled. /// /// Otherwise, returns `None`. pub fn output(self) -> Option { self.output } } impl fmt::Display for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.kind { JoinErrorKind::Cancelled { completed: false } => { write!(fmt, "task {} was cancelled before completion", self.id) } JoinErrorKind::Cancelled { completed: true } => { write!(fmt, "task {} was cancelled after completion", self.id) } JoinErrorKind::Panic(p) => { write!( fmt, "task {} panicked with message {:?}", self.id, panic_payload_as_str(p.deref()) ) } } } } impl fmt::Debug for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.kind { JoinErrorKind::Cancelled { completed } => write!( fmt, "JoinError::Cancelled({:?}, completed: {completed})", self.id ), JoinErrorKind::Panic(p) => { write!( fmt, "JoinError::Panic({:?}, {:?}, ...)", self.id, panic_payload_as_str(p.deref()) ) } } } } impl core::error::Error for JoinError {} fn panic_payload_as_str(payload: &dyn Any) -> &str { if let Some(&s) = payload.downcast_ref::<&'static str>() { s } else if let Some(s) = payload.downcast_ref::() { s.as_str() } else { "Box" } }