//! Next Generation WASM Microkernel Operating System
1// Copyright 2025. Jonas Kruckenberg
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8mod builder;
9mod id;
10mod join_handle;
11mod state;
12mod yield_now;
13
14use alloc::boxed::Box;
15use core::alloc::AllocError;
16use core::any::type_name;
17use core::mem::{ManuallyDrop, offset_of};
18use core::panic::AssertUnwindSafe;
19use core::pin::Pin;
20use core::ptr::NonNull;
21use core::sync::atomic::Ordering;
22use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
23use core::{fmt, mem};
24
25pub use builder::TaskBuilder;
26use cordyceps::mpsc_queue;
27pub use id::Id;
28pub use join_handle::{JoinError, JoinHandle};
29use k32_util::{CachePadded, CheckedMaybeUninit, loom_const_fn};
30pub use yield_now::yield_now;
31
32use crate::executor::Scheduler;
33use crate::loom::cell::UnsafeCell;
34use crate::task::state::{JoinAction, StartPollAction, State, WakeByRefAction, WakeByValAction};
35
/// Outcome of calling [`Task::poll`].
///
/// This type describes how to proceed with a given task: whether it needs to be rescheduled,
/// can be dropped, etc.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PollResult {
    /// The task has completed, without waking a [`JoinHandle`] waker.
    ///
    /// The scheduler can increment a counter of completed tasks, and then drop
    /// the [`TaskRef`].
    Ready,

    /// The task has completed and a [`JoinHandle`] waker has been woken.
    ///
    /// The scheduler can increment a counter of completed tasks, and then drop
    /// the [`TaskRef`].
    ReadyJoined,

    /// The task is pending, but not woken.
    ///
    /// The scheduler can drop the [`TaskRef`], as whoever intends to wake the
    /// task later is holding a clone of its [`Waker`].
    Pending,

    /// The task has woken itself during the poll.
    ///
    /// The scheduler should re-schedule the task, rather than dropping the [`TaskRef`].
    PendingSchedule,
}
65
/// A type-erased, reference-counted pointer to a spawned `Task`.
///
/// Once a `Task` is spawned, it is generally pinned in memory (a requirement of [`Future`]). Instead
/// of moving `Task`s around the scheduler, we therefore use `TaskRef`s which are just pointers to the
/// pinned `Task`. `TaskRef`s are type-erased, interacting with the allocated `Task`s through their
/// `VTable` methods. This is done to reduce the monomorphization cost otherwise incurred, since futures,
/// especially ones created through `async {}` blocks, `async` closures or `async fn` calls are all
/// treated as *unique*, *disjoint* types which would all cause separate monomorphizations. E.g. spawning
/// 10 futures on the runtime (which is not a lot) would cause 10 different copies of the entire runtime
/// to be compiled, obviously terrible! The `VTable` allows us to treat all spawned futures, regardless
/// of their exact type, the same way.
///
/// `TaskRef`s are reference-counted, and the task will be deallocated when the
/// last `TaskRef` pointing to it is dropped.
#[derive(Eq, PartialEq)]
pub struct TaskRef(NonNull<Header>);
82
/// The concrete, generic backing allocation for a spawned task.
///
/// `#[repr(C)]` (here and on the nested types) guarantees that the [`Header`]
/// is laid out first, so a pointer to a `Task<F, M>` can be cast to a
/// `NonNull<Header>` and back — this is what makes the type erasure in
/// [`TaskRef`] sound.
#[repr(C)]
struct Task<F: Future, M>(CachePadded<TaskInner<F, M>>);
85
/// The fields of a task: header/metadata, the future itself, and the join waker.
#[repr(C)]
struct TaskInner<F: Future, M> {
    /// This must be the first field of the `Task` struct!
    header_and_metadata: HeaderAndMetadata<M>,

    /// The future that the task is running.
    ///
    /// If `COMPLETE` is one, then the `JoinHandle` has exclusive access to this field.
    /// If COMPLETE is zero, then the RUNNING bitfield functions as
    /// a lock for the stage field, and it can be accessed only by the thread
    /// that set RUNNING to one.
    stage: UnsafeCell<Stage<F>>,

    /// Consumer task waiting on completion of this task.
    ///
    /// This field may be accessed by different threads: on one cpu we may complete a task and *read*
    /// the waker field to invoke the waker, and in another thread the task's `JoinHandle` may be
    /// polled, and if the task hasn't yet completed, the `JoinHandle` may *write* a waker to the
    /// waker field. The `JOIN_WAKER` bit in the header's `state` field ensures safe access by multiple
    /// cpus to the waker field using the following rules:
    ///
    /// 1. `JOIN_WAKER` is initialized to zero.
    ///
    /// 2. If `JOIN_WAKER` is zero, then the `JoinHandle` has exclusive (mutable)
    ///    access to the waker field.
    ///
    /// 3. If `JOIN_WAKER` is one, then the `JoinHandle` has shared (read-only) access to the waker
    ///    field.
    ///
    /// 4. If `JOIN_WAKER` is one and COMPLETE is one, then the executor has shared (read-only) access
    ///    to the waker field.
    ///
    /// 5. If the `JoinHandle` needs to write to the waker field, then the `JoinHandle` needs to
    ///    (i) successfully set `JOIN_WAKER` to zero if it is not already zero to gain exclusive access
    ///    to the waker field per rule 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER`
    ///    to one. If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
    ///    to clear the waker field, only steps (i) and (ii) are relevant.
    ///
    /// 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
    ///    the task hasn't yet completed). The executor can change `JOIN_WAKER` only
    ///    if COMPLETE is one.
    ///
    /// 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the executor has
    ///    exclusive (mutable) access to the waker field. This might happen if the
    ///    `JoinHandle` gets dropped right after the task completes and the executor
    ///    sets the `COMPLETE` bit. In this case the executor needs the mutable access
    ///    to the waker field to drop it.
    ///
    /// Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
    /// race. If step (i) fails, then the attempt to write a waker is aborted. If step (iii) fails
    /// because `COMPLETE` is set to one by another thread after step (i), then the waker field is
    /// cleared. Once `COMPLETE` is one (i.e. task has completed), the `JoinHandle` will not
    /// modify `JOIN_WAKER`. After the runtime sets COMPLETE to one, it invokes the waker if there
    /// is one so in this case when a task completes the `JOIN_WAKER` bit indicates to the runtime
    /// whether it should invoke the waker or not. After the runtime is done with using the waker
    /// during task completion, it unsets the `JOIN_WAKER` bit to give the `JoinHandle` exclusive
    /// access again so that it is able to drop the waker at a later point.
    join_waker: UnsafeCell<Option<Waker>>,
}
145
/// Pairs the type-erased [`Header`] with the user-provided metadata `M`.
///
/// `#[repr(C)]` guarantees `header` comes first, so a `NonNull<Header>` can be
/// cast to `NonNull<HeaderAndMetadata<M>>` to reach the metadata (see
/// [`TaskRef::metadata`]).
#[repr(C)]
struct HeaderAndMetadata<M> {
    /// This must be the first field of the `HeaderAndMetadata` struct!
    header: Header,
    metadata: M,
}
152
/// The current lifecycle stage of the future. Either the future itself or its output.
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
enum Stage<F: Future> {
    /// The future is still pending.
    Pending(F),

    /// The future has completed, and its output is ready to be taken by a
    /// `JoinHandle`, if one exists.
    Ready(Result<F::Output, JoinError<F::Output>>),

    /// The future has completed, and the task's output has been taken or is not
    /// needed.
    Consumed,
}
167
/// The type-erased "head" of every task allocation.
///
/// A [`TaskRef`] is a pointer to this header; the concrete `Task<F, M>` layout
/// behind it is recovered through the `vtable`.
#[derive(Debug)]
pub(crate) struct Header {
    /// The task's state.
    ///
    /// This field is accessed with atomic instructions, so it's always safe to access it.
    state: State,
    /// The task vtable for this task.
    vtable: &'static VTable,
    /// The task's ID.
    id: Id,
    /// Intrusive links used to enqueue this task in a scheduler's run queue.
    run_queue_links: mpsc_queue::Links<Self>,
    /// The tracing span associated with this task, for debugging purposes.
    span: tracing::Span,
    /// The scheduler this task is bound to, if any (set by [`TaskRef::bind_scheduler`]).
    scheduler: UnsafeCell<Option<&'static Scheduler>>,
}
183
/// Type-erased entry points into the concrete `Task<F, M>` behind a [`Header`].
#[derive(Debug)]
struct VTable {
    /// Poll the future, returning a [`PollResult`] that indicates what the
    /// scheduler should do with the polled task.
    poll: unsafe fn(NonNull<Header>) -> PollResult,

    /// Poll the task's `JoinHandle` for completion, storing the output at the
    /// provided [`NonNull`] pointer if the task has completed.
    ///
    /// If the task has not completed, the [`Waker`] from the provided
    /// [`Context`] is registered to be woken when the task completes.
    // Splitting this up into type aliases just makes it *harder* to understand
    // IMO...
    #[expect(clippy::type_complexity, reason = "")]
    poll_join: unsafe fn(
        ptr: NonNull<Header>,
        outptr: NonNull<()>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), JoinError<()>>>,

    /// Drops the task and deallocates its memory.
    deallocate: unsafe fn(NonNull<Header>),
}
207
208// === impl TaskRef ===
209
impl TaskRef {
    /// Returns the tasks unique[^1] identifier.
    ///
    /// [^1]: Unique to all *currently running* tasks, *not* unique across spacetime. See [`Id`] for details.
    pub fn id(&self) -> Id {
        self.header().id
    }

    /// Returns a reference to this task's metadata of type `M`.
    ///
    /// # Safety
    ///
    /// The caller must ensure the generic argument matches the metadata type this task got created with.
    pub unsafe fn metadata<M>(&self) -> &M {
        // Safety: ensured by caller. `HeaderAndMetadata` is `repr(C)` with the
        // header first, so this cast recovers the metadata field.
        unsafe { &self.0.cast::<HeaderAndMetadata<M>>().as_ref().metadata }
    }

    /// Returns `true` when this task has run to completion.
    pub fn is_complete(&self) -> bool {
        self.state()
            .load(Ordering::Acquire)
            .get(state::Snapshot::COMPLETE)
    }

    /// Cancels the task.
    ///
    /// Returns whether the canceled bit was newly set by this call.
    pub fn cancel(&self) -> bool {
        // try to set the canceled bit.
        let canceled = self.state().cancel();

        // if the task was successfully canceled, wake it so that it can clean
        // up after itself.
        if canceled {
            tracing::trace!("waking canceled task...");
            // NOTE(review): waking the canceled task is not implemented yet;
            // reaching this branch currently panics.
            todo!()
        }

        canceled
    }

    /// Polls the underlying future through the task's vtable.
    pub(crate) fn poll(&self) -> PollResult {
        let poll_fn = self.header().vtable.poll;
        // Safety: Called through our Vtable so this access should be fine
        unsafe { poll_fn(self.0) }
    }

    /// Polls for this task's output on behalf of a [`JoinHandle`].
    ///
    /// # Safety
    ///
    /// The caller needs to make sure that `T` is the same type as the one that this `TaskRef` was
    /// created with.
    pub(crate) unsafe fn poll_join<T>(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<T, JoinError<T>>> {
        let poll_join_fn = self.header().vtable.poll_join;
        // The vtable fn writes the (typed) output through this type-erased slot.
        let mut slot = CheckedMaybeUninit::<Result<T, JoinError<T>>>::uninit();

        // Safety: This is called through the Vtable and as long as the caller makes sure that the `T` is the right
        // type, this call is safe
        let result = unsafe { poll_join_fn(self.0, NonNull::from(&mut slot).cast::<()>(), cx) };

        result.map(|result| {
            if let Err(e) = result {
                let output = if e.is_completed() {
                    // Safety: if the task completed before being canceled, we can still
                    // take its output.
                    Some(unsafe { slot.assume_init_read() }?)
                } else {
                    None
                };
                Err(e.with_output(output))
            } else {
                // Safety: if the poll function returned `Ok`, we get to take the
                // output!
                unsafe { slot.assume_init_read() }
            }
        })
    }

    /// Bind this task to a new scheduler
    ///
    /// # Safety
    ///
    /// The new scheduler `S` must be of the **same** type as the scheduler that this task got created
    /// with. The shape of the allocated tasks depend on the type of the scheduler, binding a task
    /// to a differently typed scheduler will therefore cause invalid memory accesses.
    // NOTE(review): the `# Safety` text above looks stale — the signature takes a
    // concrete `&'static Scheduler`, not a generic `S`. Confirm and update.
    pub(crate) fn bind_scheduler(&self, scheduler: &'static Scheduler) {
        // Safety: the repr(C) on Schedulable ensures the layout matches and this cast is safe
        unsafe {
            self.0
                .as_ref()
                .scheduler
                .with_mut(|current| *current = Some(scheduler));
        }
    }

    /// Allocates the stub task used as the sentinel node of the MPSC run queue.
    ///
    /// The stub must never be polled or joined; its vtable entries assert this.
    pub(crate) fn new_stub() -> Result<Self, AllocError> {
        const HEAP_STUB_VTABLE: VTable = VTable {
            poll: stub_poll,
            poll_join: stub_poll_join,
            deallocate: stub_deallocate,
        };

        unsafe fn stub_poll(ptr: NonNull<Header>) -> PollResult {
            // Safety: this method should never be called
            unsafe {
                debug_assert!(ptr.as_ref().id.is_stub());
                unreachable!("stub task ({ptr:?}) should never be polled!");
            }
        }

        unsafe fn stub_poll_join(
            ptr: NonNull<Header>,
            _outptr: NonNull<()>,
            _cx: &mut Context<'_>,
        ) -> Poll<Result<(), JoinError<()>>> {
            // Safety: this method should never be called
            unsafe {
                debug_assert!(ptr.as_ref().id.is_stub());
                unreachable!("stub task ({ptr:?}) should never be polled!");
            }
        }

        unsafe fn stub_deallocate(ptr: NonNull<Header>) {
            // Safety: the stub header was allocated via `Box` in `new_stub`
            // below, so reconstituting the `Box` to drop it is sound.
            unsafe {
                debug_assert!(ptr.as_ref().id.is_stub());
                drop(Box::from_raw(ptr.as_ptr()));
            }
        }

        let inner = Box::into_raw(Box::try_new(Header {
            state: State::new(),
            vtable: &HEAP_STUB_VTABLE,
            id: Id::stub(),
            run_queue_links: mpsc_queue::Links::new_stub(),
            span: tracing::Span::none(),
            scheduler: UnsafeCell::new(None),
        })?);

        // Safety: we just allocated the ptr so it is never null
        Ok(Self(unsafe { NonNull::new_unchecked(inner) }))
    }

    /// Creates a new `TaskRef` from a raw header pointer, incrementing the
    /// reference count (unlike `from_raw`, which assumes ownership of one).
    pub(crate) fn clone_from_raw(ptr: NonNull<Header>) -> TaskRef {
        let this = Self(ptr);
        this.state().clone_ref();
        this
    }

    /// Returns the raw header pointer without affecting the reference count.
    pub(crate) fn header_ptr(&self) -> NonNull<Header> {
        self.0
    }

    // ===== private methods =====

    /// Wraps a freshly allocated task (ref count must be exactly 1) into a
    /// (`TaskRef`, `JoinHandle`) pair; the clone for the handle bumps it to 2.
    #[track_caller]
    fn new_allocated<F, M>(task: Box<Task<F, M>>) -> (Self, JoinHandle<F::Output, M>)
    where
        F: Future,
    {
        assert_eq!(task.state().refcount(), 1);
        let ptr = Box::into_raw(task);

        // Safety: we just allocated the ptr so it is never null
        let task = Self(unsafe { NonNull::new_unchecked(ptr).cast() });
        let join = JoinHandle::new(task.clone());

        (task, join)
    }

    /// Hands this `TaskRef` (and its ref count) to the bound scheduler's run queue.
    ///
    /// Panics if the task has not been bound to a scheduler yet.
    fn schedule(self) {
        let scheduler = self.header().scheduler.with(|scheduler| {
            // Safety: ensured by caller
            unsafe {
                (*scheduler).expect("task doesn't have an associated scheduler, this is a bug!")
            }
        });

        scheduler.schedule(self);
    }

    fn header(&self) -> &Header {
        // Safety: constructor ensures the pointer is always valid
        unsafe { self.0.as_ref() }
    }

    /// Returns a reference to the task's state.
    fn state(&self) -> &State {
        &self.header().state
    }

    /// Reconstructs a `TaskRef` from a raw pointer, taking ownership of one
    /// ref count (no increment is performed).
    unsafe fn from_raw(ptr: *const Header) -> Self {
        // Safety: ensured by caller
        Self(unsafe { NonNull::new_unchecked(ptr.cast_mut()) })
    }

    /// Consumes `self` into a raw pointer *without* decrementing the ref
    /// count — the count travels with the pointer.
    fn into_raw(self) -> *const Header {
        let this = ManuallyDrop::new(self);
        this.0.as_ptr()
    }

    // ===== private waker-related methods =====

    /// Wake-by-value: consumes this `TaskRef`, either transferring it to the
    /// scheduler or dropping it, per the state transition.
    fn wake(self) {
        match self.state().wake_by_val() {
            // `Enqueue` means we should put the task back into the queue. In this case we want to assign
            // ownership of the TaskRef to the scheduler.
            WakeByValAction::Enqueue => {
                self.schedule();
            }
            // `Drop` means (unsurprisingly) the opposite: The task does not need to be enqueued, and
            // we should therefore drop this handle.
            WakeByValAction::Drop => {
                drop(self);
            }
        }
    }

    /// Wake-by-reference: clones a new `TaskRef` for the scheduler if needed.
    ///
    /// NOTE(review): despite its name this performs the *wake-by-ref* state
    /// transition (it calls `state().wake_by_ref()` and takes `&self`);
    /// consider renaming to `wake_by_ref` for clarity.
    ///
    /// # Safety
    ///
    /// The caller must ensure this task is currently bound to a scheduler.
    unsafe fn wake_by_val(&self) {
        if self.state().wake_by_ref() == WakeByRefAction::Enqueue {
            self.clone().schedule();
        }
    }

    /// Converts this `TaskRef` into a [`RawWaker`], transferring its ref count
    /// to the waker.
    fn into_raw_waker(self) -> RawWaker {
        // Increment the reference count of the arc to clone it.
        //
        // The #[inline(always)] is to ensure that raw_waker and clone_waker are
        // always generated in the same code generation unit as one another, and
        // therefore that the structurally identical const-promoted RawWakerVTable
        // within both functions is deduplicated at LLVM IR code generation time.
        // This allows optimizing Waker::will_wake to a single pointer comparison of
        // the vtable pointers, rather than comparing all four function pointers
        // within the vtables.
        #[inline(always)]
        unsafe fn clone_waker(waker: *const ()) -> RawWaker {
            // Safety: ensured by VTable
            unsafe {
                waker.cast::<Header>().as_ref().unwrap().state.clone_ref();
            }

            RawWaker::new(
                waker,
                &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker),
            )
        }

        // Wake by value, moving the Arc into the Wake::wake function
        unsafe fn wake(waker: *const ()) {
            // Safety: ensured by VTable
            let task = unsafe { TaskRef::from_raw(waker.cast::<Header>()) };

            tracing::trace!(
                task.addr = ?task.header_ptr(),
                task.tid = task.id().as_u64(),
                "Task::wake_by_val"
            );

            task.wake();
        }

        // Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it
        unsafe fn wake_by_ref(waker: *const ()) {
            // Safety: ensured by VTable
            let task = unsafe { ManuallyDrop::new(TaskRef::from_raw(waker.cast::<Header>())) };

            tracing::trace!(
                task.addr = ?task.header_ptr(),
                task.tid = task.id().as_u64(),
                "Task::wake_by_ref"
            );

            // Safety: ensured by the caller
            unsafe {
                task.wake_by_val();
            }
        }

        // Decrement the reference count of the Arc on drop
        unsafe fn drop_waker(waker: *const ()) {
            // Safety: ensured by VTable
            drop(unsafe { TaskRef::from_raw(waker.cast()) });
        }

        RawWaker::new(
            self.into_raw().cast(),
            &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker),
        )
    }
}
502
503impl fmt::Debug for TaskRef {
504 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
505 f.debug_struct("TaskRef")
506 .field("id", &self.id())
507 .field("addr", &self.0)
508 .finish()
509 }
510}
511
512impl fmt::Pointer for TaskRef {
513 #[inline]
514 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
515 fmt::Pointer::fmt(&self.0, f)
516 }
517}
518
impl Clone for TaskRef {
    /// Clones the handle by bumping the task's atomic reference count; the
    /// caller location is traced to help debug refcount leaks.
    #[inline]
    #[track_caller]
    fn clone(&self) -> Self {
        let loc = core::panic::Location::caller();
        tracing::trace!(
            task.addr=?self.0,
            task.is_stub=self.id().is_stub(),
            loc.file = loc.file(),
            loc.line = loc.line(),
            loc.col = loc.column(),
            "TaskRef::clone",
        );
        self.state().clone_ref();
        Self(self.0)
    }
}
536
impl Drop for TaskRef {
    /// Decrements the reference count and, if this was the last handle,
    /// deallocates the task through its vtable.
    #[inline]
    #[track_caller]
    fn drop(&mut self) {
        tracing::trace!(
            task.addr=?self.0,
            task.is_stub=self.id().is_stub(),
            "TaskRef::drop"
        );
        // `drop_ref` returns whether this was the last reference.
        if !self.state().drop_ref() {
            return;
        }

        let deallocate = self.header().vtable.deallocate;
        // Safety: as long as we're constructed from a NonNull<Header> this is safe
        unsafe {
            deallocate(self.0);
        }
    }
}
557
// Safety: The state protocol ensures synchronized access to the inner task
unsafe impl Send for TaskRef {}
// Safety: The state protocol ensures synchronized access to the inner task
unsafe impl Sync for TaskRef {}
562
563// === impl Task ===
564
impl<F: Future, M> Task<F, M> {
    /// The vtable for this concrete `F`/`M` instantiation; a pointer to it is
    /// stored in the [`Header`] so type-erased callers can reach these fns.
    const TASK_VTABLE: VTable = VTable {
        poll: Self::poll,
        poll_join: Self::poll_join,
        deallocate: Self::deallocate,
    };

    loom_const_fn! {
        /// Creates a new, not-yet-scheduled task wrapping `future`.
        pub const fn new(future: F, task_id: Id, span: tracing::Span, metadata: M) -> Self {
            let inner = TaskInner {
                header_and_metadata: HeaderAndMetadata {
                    header: Header {
                        state: State::new(),
                        vtable: &Self::TASK_VTABLE,
                        id: task_id,
                        run_queue_links: mpsc_queue::Links::new(),
                        span,
                        scheduler: UnsafeCell::new(None)
                    },
                    metadata
                },
                stage: UnsafeCell::new(Stage::Pending(future)),
                join_waker: UnsafeCell::new(None),
            };
            Self(CachePadded(inner))
        }
    }

    /// Poll the future, returning a [`PollResult`] that indicates what the
    /// scheduler should do with the polled task.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must have the same `F` and `M` type parameters
    ///   as `Self`).
    /// - the task must currently be bound to a scheduler
    unsafe fn poll(ptr: NonNull<Header>) -> PollResult {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>().as_ref();

            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.id(),
                "Task::poll",
            );

            match this.state().start_poll() {
                // Successfully transitioned to `POLLING`, all is good!
                StartPollAction::Poll => {}
                // Something isn't right, we shouldn't poll the task right now...
                StartPollAction::DontPoll => {
                    tracing::warn!(task.addr=?ptr, "failed to transition to polling",);
                    return PollResult::Ready;
                }
                StartPollAction::Cancelled { wake_join_waker } => {
                    tracing::trace!(task.addr=?ptr, "task cancelled");
                    if wake_join_waker {
                        this.wake_join_waker();
                        return PollResult::ReadyJoined;
                    } else {
                        return PollResult::Ready;
                    }
                }
            }

            // wrap the waker in `ManuallyDrop` because we're converting it from an
            // existing task ref, rather than incrementing the task ref count. if
            // this waker is consumed during the poll, we don't want to decrement
            // its ref count when the poll ends.
            let waker = {
                let raw = TaskRef(ptr).into_raw_waker();
                ManuallyDrop::new(Waker::from_raw(raw))
            };

            // actually poll the task
            let poll = {
                let cx = Context::from_waker(&waker);

                this.poll_inner(cx)
            };

            let result = this.state().end_poll(poll.is_ready());

            // if the task is ready and has a `JoinHandle` to wake, wake the join
            // waker now.
            if result == PollResult::ReadyJoined {
                this.wake_join_waker();
            }

            result
        }
    }

    /// Poll to join the task pointed to by `ptr`, taking its output if it has
    /// completed.
    ///
    /// If the task has completed, this method returns [`Poll::Ready`], and the
    /// task's output is stored at the memory location pointed to by `outptr`.
    /// This function is called by [`JoinHandle`]s to poll the task they
    /// correspond to.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must have the same `F` and `M` type parameters
    ///   as `Self`).
    /// - `outptr` must point to a valid `MaybeUninit<F::Output>`.
    unsafe fn poll_join(
        ptr: NonNull<Header>,
        outptr: NonNull<()>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), JoinError<()>>> {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>().as_ref();
            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.id(),
                "Task::poll_join"
            );

            match this.state().try_join() {
                JoinAction::TakeOutput => {
                    // safety: if the state transition returns
                    // `JoinAction::TakeOutput`, this indicates that we have
                    // exclusive permission to read the task output.
                    this.take_output(outptr);
                    return Poll::Ready(Ok(()));
                }
                JoinAction::Canceled { completed } => {
                    // if the task has completed before it was canceled, also try to
                    // read the output, so that it can be returned in the `JoinError`.
                    if completed {
                        // safety: if the state transition returned `Canceled`
                        // with `completed` set, this indicates that we have
                        // exclusive permission to take the output.
                        this.take_output(outptr);
                    }
                    return Poll::Ready(Err(JoinError::cancelled(completed, *this.id())));
                }
                JoinAction::Register => {
                    // First registration: write the caller's waker.
                    this.0.0.join_waker.with_mut(|waker| {
                        waker.write(Some(cx.waker().clone()));
                    });
                }
                JoinAction::Reregister => {
                    // A waker is already stored; only replace it if it would
                    // wake a different task.
                    this.0.0.join_waker.with_mut(|waker| {
                        let waker = (*waker).as_mut().unwrap();

                        let new_waker = cx.waker();
                        if !waker.will_wake(new_waker) {
                            *waker = new_waker.clone();
                        }
                    });
                }
            }
            this.state().join_waker_registered();
            Poll::Pending
        }
    }

    /// Drops the task and deallocates its memory.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must have the same `F` and `M` type parameters
    ///   as `Self`).
    unsafe fn deallocate(ptr: NonNull<Header>) {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>();
            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.as_ref().id(),
                task.is_stub=?this.as_ref().id().is_stub(),
                "Task::deallocate",
            );
            debug_assert_eq!(
                ptr.as_ref().state.load(Ordering::Acquire).ref_count(),
                0,
                "a task may not be deallocated if its ref count is greater than zero!"
            );
            drop(Box::from_raw(this.as_ptr()));
        }
    }

    /// Polls the future. If the future completes, the output is written to the
    /// stage field.
    ///
    /// # Safety
    ///
    /// The caller has to ensure this cpu has exclusive mutable access to the task's `stage` field (i.e. the
    /// future or output).
    pub unsafe fn poll_inner(&self, mut cx: Context) -> Poll<()> {
        // Enter the task's tracing span for the duration of the poll.
        let _span = self.span().enter();

        self.0.0.stage.with_mut(|stage| {
            // Safety: ensured by caller
            let stage = unsafe { &mut *stage };
            stage.poll(&mut cx, *self.id())
        })
    }

    /// Wakes the task's [`JoinHandle`], if it has one.
    ///
    /// # Safety
    ///
    /// - The caller must have exclusive access to the task's `JoinWaker`. This
    ///   is ensured by the task's state management.
    unsafe fn wake_join_waker(&self) {
        // Safety: ensured by caller
        unsafe {
            self.0.0.join_waker.with_mut(|waker| {
                if let Some(join_waker) = (*waker).take() {
                    tracing::trace!("waking {join_waker:?}");
                    join_waker.wake();
                } else {
                    tracing::trace!("called wake_join_waker on non-existing waker");
                }
            });
        }
    }

    /// Moves the task's output out of the stage field and into `dst`.
    ///
    /// Panics if the stage is not `Ready` (i.e. the output was already taken
    /// or the future never completed).
    ///
    /// # Safety
    ///
    /// - The caller must have exclusive access to the stage field.
    /// - `dst` must point to a valid `MaybeUninit<Result<F::Output, JoinError<F::Output>>>`.
    unsafe fn take_output(&self, dst: NonNull<()>) {
        // Safety: ensured by caller
        unsafe {
            self.0.0.stage.with_mut(|stage| {
                match mem::replace(&mut *stage, Stage::Consumed) {
                    Stage::Ready(output) => {
                        // safety: the caller is responsible for ensuring that this
                        // points to a `MaybeUninit<F::Output>`.
                        let dst = dst
                            .cast::<CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>>()
                            .as_mut();

                        // that's right, it goes in the `NonNull<()>` hole!
                        dst.write(output);
                    }
                    _ => panic!("JoinHandle polled after completion"),
                }
            });
        }
    }

    fn id(&self) -> &Id {
        &self.0.0.header_and_metadata.header.id
    }
    fn state(&self) -> &State {
        &self.0.0.header_and_metadata.header.state
    }
    #[inline]
    fn span(&self) -> &tracing::Span {
        &self.0.0.header_and_metadata.header.span
    }
}
833
834// === impl Stage ===
835
impl<F> Stage<F>
where
    F: Future,
{
    /// Polls the pending future, transitioning `self` to `Ready` on completion
    /// or on panic (the panic is captured into a `JoinError`).
    fn poll(&mut self, cx: &mut Context<'_>, id: Id) -> Poll<()> {
        // Panic guard: if the future panics while being polled, its `Drop`
        // runs here and the stage is left in the `Consumed` state instead of
        // holding a partially-dropped future.
        struct Guard<'a, T: Future> {
            stage: &'a mut Stage<T>,
        }
        impl<T: Future> Drop for Guard<'_, T> {
            fn drop(&mut self) {
                // If the future panics on poll, we drop it inside the panic
                // guard.
                // Safety: caller has to ensure mutual exclusion
                *self.stage = Stage::Consumed;
            }
        }

        let poll = AssertUnwindSafe(|| -> Poll<F::Output> {
            let guard = Guard { stage: self };

            // Safety: caller has to ensure mutual exclusion
            let Stage::Pending(future) = guard.stage else {
                // TODO this will be caught by the `catch_unwind` which isn't great
                unreachable!("unexpected stage");
            };

            // Safety: The caller ensures the future is pinned.
            let future = unsafe { Pin::new_unchecked(future) };
            let res = future.poll(cx);
            // Defuse the guard on the non-panicking path so the stage isn't
            // clobbered to `Consumed`.
            mem::forget(guard);
            res
        });

        cfg_if::cfg_if! {
            if #[cfg(test)] {
                let result = ::std::panic::catch_unwind(poll);
            } else if #[cfg(feature = "k23-unwind")] {
                let result = k23_panic_unwind::catch_unwind(poll);
            } else {
                // No unwinding support: a panic in the future aborts.
                let result = Ok(poll());
            }
        }

        match result {
            Ok(Poll::Pending) => Poll::Pending,
            Ok(Poll::Ready(ready)) => {
                *self = Stage::Ready(Ok(ready));
                Poll::Ready(())
            }
            Err(err) => {
                *self = Stage::Ready(Err(JoinError::panic(id, err)));
                Poll::Ready(())
            }
        }
    }
}
892
893// === impl Header ===
894
// Safety: tasks are always treated as pinned in memory (a requirement for polling them)
// and care has been taken below to ensure the underlying memory isn't freed as long as the
// `TaskRef` is part of the owned tasks list.
unsafe impl cordyceps::Linked<mpsc_queue::Links<Self>> for Header {
    type Handle = TaskRef;

    fn into_ptr(task: Self::Handle) -> NonNull<Self> {
        let ptr = task.0;
        // converting a `TaskRef` into a pointer to enqueue it assigns ownership
        // of the ref count to the queue, so we don't want to run its `Drop`
        // impl.
        mem::forget(task);
        ptr
    }

    unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
        // Ownership of the queue's ref count transfers back into the returned
        // `TaskRef` (mirror of `into_ptr` above).
        TaskRef(ptr)
    }

    unsafe fn links(ptr: NonNull<Self>) -> NonNull<mpsc_queue::Links<Self>>
    where
        Self: Sized,
    {
        // Safety: `TaskRef` is just a newtype wrapper around `NonNull<Header>`
        ptr.map_addr(|addr| {
            let offset = offset_of!(Self, run_queue_links);
            addr.checked_add(offset).unwrap()
        })
        .cast()
    }
}
926
#[cfg(test)]
mod tests {
    use alloc::boxed::Box;

    use crate::loom;
    use crate::task::{Id, Task, TaskRef};

    /// Spawning a task with `usize` metadata and reading it back through the
    /// type-erased accessor must round-trip the value.
    #[test]
    fn metadata() {
        loom::model(|| {
            let (t1, _) = TaskRef::new_allocated(Box::new(Task::new(
                async {},
                Id::next(),
                tracing::Span::none(),
                42usize,
            )));

            // Safety: the task was created with `usize` metadata above.
            assert_eq!(unsafe { *t1.metadata::<usize>() }, 42);
        });
    }
}