//! Next Generation WASM Microkernel Operating System
1// Copyright 2025. Jonas Kruckenberg
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8mod builder;
9mod id;
10mod join_handle;
11mod state;
12mod yield_now;
13
14use crate::loom::cell::UnsafeCell;
15use crate::task::state::{JoinAction, StartPollAction, State, WakeByRefAction, WakeByValAction};
16use alloc::boxed::Box;
17use cordyceps::mpsc_queue;
18use core::alloc::Allocator;
19use core::any::type_name;
20use core::mem::offset_of;
21use core::panic::AssertUnwindSafe;
22use core::pin::Pin;
23use core::ptr::NonNull;
24use core::sync::atomic::Ordering;
25use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
26use core::{fmt, mem};
27use util::{CachePadded, CheckedMaybeUninit, loom_const_fn};
28
29use crate::executor::Scheduler;
30pub use builder::TaskBuilder;
31pub use id::Id;
32pub use join_handle::{JoinError, JoinHandle};
33pub use yield_now::yield_now;
34
/// Outcome of calling [`Task::poll`].
///
/// This type describes how to proceed with a given task, whether it needs to be rescheduled
/// or can be dropped etc.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PollResult {
    /// The task has completed, without waking a [`JoinHandle`] waker.
    ///
    /// The scheduler can increment a counter of completed tasks, and then drop
    /// the [`TaskRef`].
    Ready,

    /// The task has completed and a [`JoinHandle`] waker has been woken.
    ///
    /// The scheduler can increment a counter of completed tasks, and then drop
    /// the [`TaskRef`].
    ReadyJoined,

    /// The task is pending, but not woken.
    ///
    /// The scheduler can drop the [`TaskRef`], as whoever intends to wake the
    /// task later is holding a clone of its [`Waker`].
    Pending,

    /// The task has woken itself during the poll (wake-during-poll race).
    ///
    /// The scheduler should re-schedule the task, rather than dropping the [`TaskRef`].
    PendingSchedule,
}
64
/// A type-erased, reference-counted pointer to a spawned `Task`.
///
/// Once a `Task` is spawned, it is generally pinned in memory (a requirement of [`Future`]). Instead
/// of moving `Task`s around the scheduler, we therefore use `TaskRef`s which are just pointers to the
/// pinned `Task`. `TaskRef`s are type-erased, interacting with the allocated `Task`s through their
/// `VTable` methods. This is done to reduce the monomorphization cost otherwise incurred, since futures,
/// especially ones created through `async {}` blocks, `async` closures or `async fn` calls are all
/// treated as *unique*, *disjoint* types which would all cause separate monomorphizations. E.g. spawning
/// 10 futures on the runtime (which is not a lot) would cause 10 different copies of the entire runtime
/// to be compiled, obviously terrible! The `VTable` allows us to treat all spawned futures, regardless
/// of their exact type, the same way.
///
/// `TaskRef`s are reference-counted, and the task will be deallocated when the
/// last `TaskRef` pointing to it is dropped.
#[derive(Eq, PartialEq)]
pub struct TaskRef(NonNull<Header>);
81
/// A spawned task: the type-erased [`Header`] (via [`Schedulable`]) followed by the
/// concrete future state and join-waker slot, wrapped in `CachePadded` (presumably to
/// avoid false sharing between tasks — see `CachePadded`'s own docs).
///
/// `repr(C)` guarantees the [`Schedulable`]/[`Header`] prefix lives at offset zero,
/// which the pointer casts in this module rely on.
#[repr(C)]
pub struct Task<F: Future>(CachePadded<TaskInner<F>>);
84
/// The inner layout of a [`Task`]: the type-erased scheduling prefix plus the
/// monomorphized future state and join waker.
#[repr(C)]
struct TaskInner<F: Future> {
    /// This must be the first field of the `Task` struct!
    /// (The vtable functions cast `NonNull<Header>` to `NonNull<Self>` and back.)
    schedulable: Schedulable,

    /// The future that the task is running.
    ///
    /// If `COMPLETE` is one, then the `JoinHandle` has exclusive access to this field
    /// If COMPLETE is zero, then the RUNNING bitfield functions as
    /// a lock for the stage field, and it can be accessed only by the thread
    /// that set RUNNING to one.
    stage: UnsafeCell<Stage<F>>,

    /// Consumer task waiting on completion of this task.
    ///
    /// This field may be accessed by different threads: on one cpu we may complete a task and *read*
    /// the waker field to invoke the waker, and in another thread the task's `JoinHandle` may be
    /// polled, and if the task hasn't yet completed, the `JoinHandle` may *write* a waker to the
    /// waker field. The `JOIN_WAKER` bit in the header's `state` field ensures safe access by multiple
    /// cpus to the waker field using the following rules:
    ///
    /// 1. `JOIN_WAKER` is initialized to zero.
    ///
    /// 2. If `JOIN_WAKER` is zero, then the `JoinHandle` has exclusive (mutable)
    ///    access to the waker field.
    ///
    /// 3. If `JOIN_WAKER` is one, then the `JoinHandle` has shared (read-only) access to the waker
    ///    field.
    ///
    /// 4. If `JOIN_WAKER` is one and COMPLETE is one, then the executor has shared (read-only) access
    ///    to the waker field.
    ///
    /// 5. If the `JoinHandle` needs to write to the waker field, then the `JoinHandle` needs to
    ///    (i) successfully set `JOIN_WAKER` to zero if it is not already zero to gain exclusive access
    ///    to the waker field per rule 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER`
    ///    to one. If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
    ///    to clear the waker field, only steps (i) and (ii) are relevant.
    ///
    /// 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
    ///    the task hasn't yet completed). The executor can change `JOIN_WAKER` only
    ///    if COMPLETE is one.
    ///
    /// 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the executor has
    ///    exclusive (mutable) access to the waker field. This might happen if the
    ///    `JoinHandle` gets dropped right after the task completes and the executor
    ///    sets the `COMPLETE` bit. In this case the executor needs the mutable access
    ///    to the waker field to drop it.
    ///
    /// Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
    /// race. If step (i) fails, then the attempt to write a waker is aborted. If step (iii) fails
    /// because `COMPLETE` is set to one by another thread after step (i), then the waker field is
    /// cleared. Once `COMPLETE` is one (i.e. task has completed), the `JoinHandle` will not
    /// modify `JOIN_WAKER`. After the runtime sets COMPLETE to one, it invokes the waker if there
    /// is one so in this case when a task completes the `JOIN_WAKER` bit indicates to the runtime
    /// whether it should invoke the waker or not. After the runtime is done with using the waker
    /// during task completion, it unsets the `JOIN_WAKER` bit to give the `JoinHandle` exclusive
    /// access again so that it is able to drop the waker at a later point.
    join_waker: UnsafeCell<Option<Waker>>,
}
144
/// The scheduler-facing prefix of every task: the type-erased [`Header`] plus the
/// scheduler the task will be enqueued on when woken.
#[repr(C)]
struct Schedulable {
    /// This must be the first field of the `Schedulable` struct!
    /// (Vtable and waker functions cast between `*const Self` and `*const Header`.)
    header: Header,
    /// The scheduler this task is bound to, set by [`TaskRef::bind_scheduler`].
    ///
    /// `None` until the task is bound; `Schedulable::schedule` panics (via `expect`)
    /// if a task is woken while still unbound.
    scheduler: UnsafeCell<Option<&'static Scheduler>>,
}
151
/// The current lifecycle stage of the future. Either the future itself or its output.
///
/// Access is guarded by the task's atomic state (see [`TaskInner::stage`]).
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
enum Stage<F: Future> {
    /// The future is still pending.
    Pending(F),

    /// The future has completed, and its output is ready to be taken by a
    /// `JoinHandle`, if one exists. `Err` holds a `JoinError` if the future panicked.
    Ready(Result<F::Output, JoinError<F::Output>>),

    /// The future has completed, and the task's output has been taken or is not
    /// needed. Also the post-panic state (the panic guard in `Stage::poll` sets it).
    Consumed,
}
166
/// The type-erased part of a task, shared by all tasks regardless of their future's type.
///
/// A `NonNull<Header>` is what `TaskRef` and every vtable function traffics in.
#[derive(Debug)]
pub(crate) struct Header {
    /// The task's state.
    ///
    /// This field is accessed with atomic instructions, so it's always safe to access it.
    state: State,
    /// The task vtable for this task.
    vtable: &'static VTable,
    /// The task's ID.
    id: Id,
    /// Intrusive links used to enqueue this header in a scheduler's MPSC run queue
    /// (see the `cordyceps::Linked` impl below).
    run_queue_links: mpsc_queue::Links<Self>,
    /// The tracing span associated with this task, for debugging purposes.
    span: tracing::Span,
}
181
/// Function table giving type-erased access to a monomorphized `Task<F>`.
///
/// Every function receives the task's `NonNull<Header>` (or the equivalent
/// `*const ()` for the waker entry) and casts it back to the concrete task type.
#[derive(Debug)]
struct VTable {
    /// Poll the future, returning a [`PollResult`] that indicates what the
    /// scheduler should do with the polled task.
    poll: unsafe fn(NonNull<Header>) -> PollResult,

    /// Poll the task's `JoinHandle` for completion, storing the output at the
    /// provided [`NonNull`] pointer if the task has completed.
    ///
    /// If the task has not completed, the [`Waker`] from the provided
    /// [`Context`] is registered to be woken when the task completes.
    // Splitting this up into type aliases just makes it *harder* to understand
    // IMO...
    #[expect(clippy::type_complexity, reason = "")]
    poll_join: unsafe fn(
        ptr: NonNull<Header>,
        outptr: NonNull<()>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), JoinError<()>>>,

    /// Drops the task and deallocates its memory.
    deallocate: unsafe fn(NonNull<Header>),

    /// The `wake_by_ref` function from the task's [`RawWakerVTable`].
    ///
    /// This is duplicated here as it's used to wake canceled tasks when a task
    /// is canceled by a [`TaskRef`] or [`JoinHandle`].
    wake_by_ref: unsafe fn(*const ()),
}
211
212// === impl TaskRef ===
213
impl TaskRef {
    /// Converts a freshly heap-allocated [`Task`] into a `TaskRef` plus the
    /// [`JoinHandle`] used to await its output.
    ///
    /// The task must still carry its initial reference count of one (asserted
    /// below); that reference is logically transferred to the returned
    /// `TaskRef`, and the `JoinHandle` holds a second, freshly incremented one.
    ///
    /// NOTE(review): `Box::into_raw` discards the allocator `A`, while
    /// `Task::deallocate` frees via a plain `Box::from_raw` (global allocator).
    /// This looks unsound for any `A` other than the global allocator — confirm
    /// callers only ever pass globally-allocated boxes.
    #[track_caller]
    pub(crate) fn new_allocated<F, A>(task: Box<Task<F>, A>) -> (Self, JoinHandle<F::Output>)
    where
        F: Future,
        A: Allocator,
    {
        assert_eq!(task.state().refcount(), 1);
        let ptr = Box::into_raw(task);

        // Safety: we just allocated the ptr so it is never null
        let task = Self(unsafe { NonNull::new_unchecked(ptr).cast() });
        let join = JoinHandle::new(task.clone());

        (task, join)
    }

    /// Returns the tasks unique[^1] identifier.
    ///
    /// [^1]: Unique to all *currently running* tasks, *not* unique across spacetime. See [`Id`] for details.
    pub fn id(&self) -> Id {
        self.header().id
    }

    /// Returns `true` when this task has run to completion.
    pub fn is_complete(&self) -> bool {
        self.state()
            .load(Ordering::Acquire)
            .get(state::Snapshot::COMPLETE)
    }

    /// Cancels the task.
    ///
    /// Returns `true` if this call performed the cancel transition (i.e. the
    /// task was not already canceled/completed per the state protocol).
    pub fn cancel(&self) -> bool {
        // try to set the canceled bit.
        let canceled = self.state().cancel();

        // if the task was successfully canceled, wake it so that it can clean
        // up after itself.
        if canceled {
            tracing::trace!("woke canceled task");
            self.wake_by_ref();
        }

        canceled
    }

    /// Creates a new `TaskRef` from a raw header pointer, incrementing the
    /// task's reference count to account for the new handle.
    pub(crate) fn clone_from_raw(ptr: NonNull<Header>) -> TaskRef {
        let this = Self(ptr);
        this.state().clone_ref();
        this
    }

    /// Returns the raw pointer to the task's type-erased [`Header`].
    pub(crate) fn header_ptr(&self) -> NonNull<Header> {
        self.0
    }

    /// Returns a shared reference to the task's type-erased [`Header`].
    pub(crate) fn header(&self) -> &Header {
        // Safety: constructor ensures the pointer is always valid
        unsafe { self.0.as_ref() }
    }

    /// Returns a reference to the task's state.
    pub(crate) fn state(&self) -> &State {
        &self.header().state
    }

    /// Wakes the task without consuming this handle, via the vtable's copy of
    /// the waker `wake_by_ref` function.
    pub(crate) fn wake_by_ref(&self) {
        tracing::trace!("TaskRef::wake_by_ref {self:?}");
        let wake_by_ref_fn = self.header().vtable.wake_by_ref;
        // Safety: Called through our Vtable so this access should be fine
        unsafe { wake_by_ref_fn(self.0.as_ptr().cast::<()>()) }
    }

    /// Polls the task's future once, returning what the scheduler should do
    /// with the task next.
    pub(crate) fn poll(&self) -> PollResult {
        let poll_fn = self.header().vtable.poll;
        // Safety: Called through our Vtable so this access should be fine
        unsafe { poll_fn(self.0) }
    }

    /// Polls for the task's completion, yielding its output (or the
    /// [`JoinError`] describing why it failed) once it has completed.
    ///
    /// # Safety
    ///
    /// The caller needs to make sure that `T` is the same type as the one that this `TaskRef` was
    /// created with.
    pub(crate) unsafe fn poll_join<T>(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<T, JoinError<T>>> {
        let poll_join_fn = self.header().vtable.poll_join;
        // stack slot the (type-erased) vtable call writes the output into.
        let mut slot = CheckedMaybeUninit::<Result<T, JoinError<T>>>::uninit();

        // Safety: This is called through the Vtable and as long as the caller makes sure that the `T` is the right
        // type, this call is safe
        let result = unsafe { poll_join_fn(self.0, NonNull::from(&mut slot).cast::<()>(), cx) };

        result.map(|result| {
            if let Err(e) = result {
                let output = if e.is_completed() {
                    // Safety: if the task completed before being canceled, we can still
                    // take its output.
                    Some(unsafe { slot.assume_init_read() }?)
                } else {
                    None
                };
                Err(e.with_output(output))
            } else {
                // Safety: if the poll function returned `Ok`, we get to take the
                // output!
                unsafe { slot.assume_init_read() }
            }
        })
    }

    /// Bind this task to a new scheduler.
    ///
    /// The scheduler reference is what `Schedulable::schedule` enqueues the
    /// task on when it is woken.
    ///
    /// NOTE(review): rebinding a task that is already enqueued on another
    /// scheduler's run queue is not accounted for here — confirm callers bind
    /// a task exactly once, before it is first scheduled.
    pub(crate) fn bind_scheduler(&self, scheduler: &'static Scheduler) {
        // Safety: the repr(C) on Schedulable ensures the layout matches and this cast is safe
        unsafe {
            self.0
                .cast::<Schedulable>()
                .as_ref()
                .scheduler
                .with_mut(|current| *current = Some(scheduler));
        }
    }
}
344
345impl fmt::Debug for TaskRef {
346 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
347 f.debug_struct("TaskRef")
348 .field("id", &self.id())
349 .field("addr", &self.0)
350 .finish()
351 }
352}
353
354impl fmt::Pointer for TaskRef {
355 #[inline]
356 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
357 fmt::Pointer::fmt(&self.0, f)
358 }
359}
360
impl Clone for TaskRef {
    #[inline]
    #[track_caller]
    fn clone(&self) -> Self {
        // the caller's location is recorded purely for tracing/debugging.
        let loc = core::panic::Location::caller();
        tracing::trace!(
            task.addr=?self.0,
            task.is_stub=self.id().is_stub(),
            loc.file = loc.file(),
            loc.line = loc.line(),
            loc.col = loc.column(),
            "TaskRef::clone",
        );
        // a new handle means one more reference: bump the count before
        // returning the copy of the pointer.
        self.state().clone_ref();
        Self(self.0)
    }
}
378
impl Drop for TaskRef {
    #[inline]
    #[track_caller]
    fn drop(&mut self) {
        tracing::trace!(
            task.addr=?self.0,
            task.is_stub=self.id().is_stub(),
            "TaskRef::drop"
        );
        // decrement the reference count; `drop_ref` returns `true` only when
        // this was the last reference, in which case we must deallocate.
        if !self.state().drop_ref() {
            return;
        }

        let deallocate = self.header().vtable.deallocate;
        // Safety: as long as we're constructed from a NonNull<Header> this is safe
        unsafe {
            deallocate(self.0);
        }
    }
}
399
// Safety: The state protocol ensures synchronized access to the inner task,
// so a `TaskRef` may be sent to another thread.
unsafe impl Send for TaskRef {}
// Safety: The state protocol ensures synchronized access to the inner task,
// so a `TaskRef` may be shared between threads.
unsafe impl Sync for TaskRef {}
404
405// === impl Task ===
406
impl<F: Future> Task<F> {
    /// The vtable for "real" (non-stub) tasks, forwarding each type-erased
    /// entry point to the monomorphized functions below.
    const TASK_VTABLE: VTable = VTable {
        poll: Self::poll,
        poll_join: Self::poll_join,
        deallocate: Self::deallocate,
        wake_by_ref: Schedulable::wake_by_ref,
    };

    loom_const_fn! {
        /// Creates a new task wrapping `future` with the given `task_id` and
        /// tracing `span`. The task is not bound to a scheduler yet.
        pub const fn new(future: F, task_id: Id, span: tracing::Span) -> Self {
            let inner = TaskInner {
                schedulable: Schedulable {
                    header: Header {
                        state: State::new(),
                        vtable: &Self::TASK_VTABLE,
                        id: task_id,
                        run_queue_links: mpsc_queue::Links::new(),
                        span,
                    },
                    scheduler: UnsafeCell::new(None)
                },
                stage: UnsafeCell::new(Stage::Pending(future)),
                join_waker: UnsafeCell::new(None),
            };
            Self(CachePadded(inner))
        }
    }

    /// Poll the future, returning a [`PollResult`] that indicates what the
    /// scheduler should do with the polled task.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must belong to a task with the same `F` type parameter
    ///   as `Self`).
    unsafe fn poll(ptr: NonNull<Header>) -> PollResult {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>().as_ref();

            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.id(),
                "Task::poll",
            );

            match this.state().start_poll() {
                // Successfully transitioned to `POLLING`, all is good!
                StartPollAction::Poll => {}
                // Something isn't right, we shouldn't poll the task right now...
                StartPollAction::DontPoll => {
                    tracing::warn!(task.addr=?ptr, "failed to transition to polling",);
                    return PollResult::Ready;
                }
                StartPollAction::Cancelled { wake_join_waker } => {
                    tracing::trace!(task.addr=?ptr, "task cancelled");
                    if wake_join_waker {
                        this.wake_join_waker();
                        return PollResult::ReadyJoined;
                    } else {
                        return PollResult::Ready;
                    }
                }
            }

            // wrap the waker in `ManuallyDrop` because we're converting it from an
            // existing task ref, rather than incrementing the task ref count. if
            // this waker is consumed during the poll, we don't want to decrement
            // its ref count when the poll ends.
            let waker = {
                let raw = Schedulable::raw_waker(ptr.as_ptr().cast());
                mem::ManuallyDrop::new(Waker::from_raw(raw))
            };

            // actually poll the task
            let poll = {
                let cx = Context::from_waker(&waker);

                this.poll_inner(cx)
            };

            let result = this.state().end_poll(poll.is_ready());

            // if the task is ready and has a `JoinHandle` to wake, wake the join
            // waker now.
            if result == PollResult::ReadyJoined {
                this.wake_join_waker();
            }

            result
        }
    }

    /// Poll to join the task pointed to by `ptr`, taking its output if it has
    /// completed.
    ///
    /// If the task has completed, this method returns [`Poll::Ready`], and the
    /// task's output is stored at the memory location pointed to by `outptr`.
    /// This function is called by [`JoinHandle`]s to poll the task they
    /// correspond to.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must belong to a task with the same `F` type parameter
    ///   as `Self`).
    /// - `outptr` must point to a valid `MaybeUninit<F::Output>`.
    unsafe fn poll_join(
        ptr: NonNull<Header>,
        outptr: NonNull<()>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), JoinError<()>>> {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>().as_ref();
            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.id(),
                "Task::poll_join"
            );

            match this.state().try_join() {
                JoinAction::TakeOutput => {
                    // safety: if the state transition returns
                    // `JoinAction::TakeOutput`, this indicates that we have
                    // exclusive permission to read the task output.
                    this.take_output(outptr);
                    return Poll::Ready(Ok(()));
                }
                JoinAction::Canceled { completed } => {
                    // if the task has completed before it was canceled, also try to
                    // read the output, so that it can be returned in the `JoinError`.
                    if completed {
                        // safety: if the state transition returned `Canceled`
                        // with `completed` set, this indicates that we have
                        // exclusive permission to take the output.
                        this.take_output(outptr);
                    }
                    return Poll::Ready(Err(JoinError::cancelled(completed, *this.id())));
                }
                JoinAction::Register => {
                    // first registration: the state transition grants exclusive
                    // access to the waker slot (rule 2 of the `JOIN_WAKER`
                    // protocol documented on `TaskInner::join_waker`).
                    this.0.0.join_waker.with_mut(|waker| {
                        waker.write(Some(cx.waker().clone()));
                    });
                }
                JoinAction::Reregister => {
                    this.0.0.join_waker.with_mut(|waker| {
                        let waker = (*waker).as_mut().unwrap();

                        // only clone the new waker if it would actually wake a
                        // different task than the stored one.
                        let new_waker = cx.waker();
                        if !waker.will_wake(new_waker) {
                            *waker = new_waker.clone();
                        }
                    });
                }
            }
            this.state().join_waker_registered();
            Poll::Pending
        }
    }

    /// Drops the task and deallocates its memory.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must belong to a task with the same `F` type parameter
    ///   as `Self`).
    /// - the task's reference count must have dropped to zero (debug-asserted
    ///   below).
    ///
    /// NOTE(review): this frees the allocation via `Box::from_raw`, i.e. the
    /// global allocator — tasks created with a non-global allocator would be
    /// freed with the wrong allocator; confirm only globally-allocated tasks
    /// reach this point.
    unsafe fn deallocate(ptr: NonNull<Header>) {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>();
            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.as_ref().id(),
                task.is_stub=?this.as_ref().id().is_stub(),
                "Task::deallocate",
            );
            debug_assert_eq!(
                ptr.as_ref().state.load(Ordering::Acquire).ref_count(),
                0,
                "a task may not be deallocated if its ref count is greater than zero!"
            );
            drop(Box::from_raw(this.as_ptr()));
        }
    }

    /// Polls the future. If the future completes, the output is written to the
    /// stage field.
    ///
    /// # Safety
    ///
    /// The caller has to ensure this cpu has exclusive mutable access to the tasks `stage` field (i.e. the
    /// future or output).
    pub unsafe fn poll_inner(&self, mut cx: Context) -> Poll<()> {
        // enter the task's tracing span for the duration of the poll.
        let _span = self.span().enter();

        self.0.0.stage.with_mut(|stage| {
            // Safety: ensured by caller
            let stage = unsafe { &mut *stage };
            stage.poll(&mut cx, *self.id())
        })
    }

    /// Wakes the task's [`JoinHandle`], if it has one.
    ///
    /// # Safety
    ///
    /// - The caller must have exclusive access to the task's `JoinWaker`. This
    ///   is ensured by the task's state management.
    unsafe fn wake_join_waker(&self) {
        // Safety: ensured by caller
        unsafe {
            self.0.0.join_waker.with_mut(|waker| {
                if let Some(join_waker) = (*waker).take() {
                    tracing::trace!("waking {join_waker:?}");
                    join_waker.wake();
                } else {
                    tracing::trace!("called wake_join_waker on non-existing waker");
                }
            });
        }
    }

    /// Moves the completed task's output out of the `stage` field and into `dst`,
    /// leaving the stage `Consumed`.
    ///
    /// # Panics
    ///
    /// Panics if the stage is not `Ready`, i.e. the output was already taken or
    /// the future never completed.
    ///
    /// # Safety
    ///
    /// - the caller must have exclusive access to the `stage` field (granted by
    ///   the state transitions in `poll_join`).
    /// - `dst` must point to a valid
    ///   `CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>`.
    unsafe fn take_output(&self, dst: NonNull<()>) {
        // Safety: ensured by caller
        unsafe {
            self.0.0.stage.with_mut(|stage| {
                match mem::replace(&mut *stage, Stage::Consumed) {
                    Stage::Ready(output) => {
                        // safety: the caller is responsible for ensuring that this
                        // points to a `MaybeUninit<F::Output>`.
                        let dst = dst
                            .cast::<CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>>()
                            .as_mut();

                        // that's right, it goes in the `NonNull<()>` hole!
                        dst.write(output);
                    }
                    _ => panic!("JoinHandle polled after completion"),
                }
            });
        }
    }

    /// Returns the task's ID (stored in the type-erased header).
    fn id(&self) -> &Id {
        &self.0.0.schedulable.header.id
    }
    /// Returns the task's atomic state (stored in the type-erased header).
    fn state(&self) -> &State {
        &self.0.0.schedulable.header.state
    }
    /// Returns the task's tracing span (stored in the type-erased header).
    #[inline]
    fn span(&self) -> &tracing::Span {
        &self.0.0.schedulable.header.span
    }
}
674
impl Task<Stub> {
    /// Vtable for the heap-allocated stub task used as the MPSC run-queue
    /// placeholder. Its poll/join/wake entry points must never be called.
    const HEAP_STUB_VTABLE: VTable = VTable {
        poll: stub_poll,
        poll_join: stub_poll_join,
        // Heap allocated stub tasks *will* need to be deallocated, since the
        // scheduler will deallocate its stub task if it's dropped.
        deallocate: Self::deallocate,
        wake_by_ref: stub_wake_by_ref,
    };

    loom_const_fn! {
        /// Create a new stub task.
        ///
        /// The stub carries the stub [`Id`], a `none` tracing span, and stub
        /// queue links; it only exists to satisfy the MPSC queue's need for a
        /// sentinel node.
        pub(crate) const fn new_stub() -> Self {
            let inner = TaskInner {
                schedulable: Schedulable {
                    header: Header {
                        state: State::new(),
                        vtable: &Self::HEAP_STUB_VTABLE,
                        id: Id::stub(),
                        run_queue_links: mpsc_queue::Links::new_stub(),
                        span: tracing::Span::none(),
                    },
                    scheduler: UnsafeCell::new(None)
                },
                stage: UnsafeCell::new(Stage::Pending(Stub)),
                join_waker: UnsafeCell::new(None),
            };

            Self(CachePadded(inner))
        }
    }
}
707
708// === impl Stage ===
709
impl<F> Stage<F>
where
    F: Future,
{
    /// Polls the inner future, converting completion — or a panic during the
    /// poll — into the `Ready` stage so a `JoinHandle` can observe it.
    fn poll(&mut self, cx: &mut Context<'_>, id: Id) -> Poll<()> {
        // Panic guard: if `future.poll` unwinds, `Drop` runs during unwinding
        // and replaces the (possibly broken) future with `Consumed`, dropping
        // it inside the guard so it can never be polled again. On the normal
        // path the guard is defused with `mem::forget` below.
        struct Guard<'a, T: Future> {
            stage: &'a mut Stage<T>,
        }
        impl<T: Future> Drop for Guard<'_, T> {
            fn drop(&mut self) {
                // If the future panics on poll, we drop it inside the panic
                // guard.
                // Safety: caller has to ensure mutual exclusion
                *self.stage = Stage::Consumed;
            }
        }

        let poll = AssertUnwindSafe(|| -> Poll<F::Output> {
            let guard = Guard { stage: self };

            // Safety: caller has to ensure mutual exclusion
            let Stage::Pending(future) = guard.stage else {
                // TODO this will be caught by the `catch_unwind` which isn't great
                unreachable!("unexpected stage");
            };

            // Safety: The caller ensures the future is pinned.
            let future = unsafe { Pin::new_unchecked(future) };
            let res = future.poll(cx);
            // the poll returned normally — defuse the guard so the pending
            // future (or the output written below) survives.
            mem::forget(guard);
            res
        });

        // How panics are caught depends on the build: std's `catch_unwind`
        // under `cfg(test)`, the `unwind2` crate when that feature is enabled,
        // and otherwise no catching at all (presumably panics cannot unwind in
        // that configuration — TODO confirm).
        cfg_if::cfg_if! {
            if #[cfg(test)] {
                let result = ::std::panic::catch_unwind(poll);
            } else if #[cfg(feature = "unwind2")] {
                let result = panic_unwind2::catch_unwind(poll);
            } else {
                let result = Ok(poll());
            }
        }

        match result {
            Ok(Poll::Pending) => Poll::Pending,
            Ok(Poll::Ready(ready)) => {
                *self = Stage::Ready(Ok(ready));
                Poll::Ready(())
            }
            Err(err) => {
                // the future panicked: surface the payload to the `JoinHandle`
                // as a `JoinError`.
                *self = Stage::Ready(Err(JoinError::panic(id, err)));
                Poll::Ready(())
            }
        }
    }
}
766
767// === impl Schedulable ===
768
impl Schedulable {
    /// The `RawWaker` vtable shared by all task wakers; each entry casts the
    /// opaque `*const ()` data pointer back to `*const Schedulable`.
    const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::clone_waker,
        Self::wake_by_val,
        Self::wake_by_ref,
        Self::drop_waker,
    );

    // `Waker::will_wake` is used all over the place to optimize waker code (e.g. only update wakers if they
    // have a different wake target). Problem is `will_wake` only checks for pointer equality and since
    // the `into_raw_waker` would usually be inlined in release mode (and with it `WAKER_VTABLE`) the
    // Waker identity would be different before and after calling `.clone()`. This isn't a correctness
    // problem since it's still the same waker in the end, it just causes a lot of unnecessary wake ups.
    // the `inline(never)` below is therefore quite load-bearing
    #[inline(never)]
    fn raw_waker(this: *const Self) -> RawWaker {
        RawWaker::new(this.cast::<()>(), &Self::WAKER_VTABLE)
    }

    /// Returns the task's atomic state.
    #[inline(always)]
    fn state(&self) -> &State {
        &self.header.state
    }

    /// Enqueues `this` on the scheduler the task was bound to, transferring
    /// ownership of the `TaskRef`'s reference to the run queue.
    ///
    /// Panics (via `expect`) if the task was never bound to a scheduler.
    unsafe fn schedule(this: TaskRef) {
        // Safety: ensured by caller
        unsafe {
            this.header_ptr()
                .cast::<Self>()
                .as_ref()
                .scheduler
                .with(|scheduler| {
                    (*scheduler)
                        .as_ref()
                        .expect("task doesn't have an associated scheduler, this is a bug!")
                        .schedule(this);
                });
        }
    }

    /// Decrements the task's reference count, deallocating the task (through
    /// its vtable) if this was the last reference.
    #[inline]
    unsafe fn drop_ref(this: NonNull<Self>) {
        // Safety: ensured by caller
        unsafe {
            tracing::trace!(task.addr=?this, task.id=?this.as_ref().header.id, "Task::drop_ref");
            if !this.as_ref().state().drop_ref() {
                return;
            }

            let deallocate = this.as_ref().header.vtable.deallocate;
            deallocate(this.cast::<Header>());
        }
    }

    // === Waker vtable methods ===

    /// `RawWakerVTable::wake`: wakes the task, consuming the waker's reference.
    unsafe fn wake_by_val(ptr: *const ()) {
        // Safety: called through RawWakerVtable
        unsafe {
            let ptr = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?ptr,
                    task.tid = (*ptr).header.id.as_u64()
                },
                "Task::wake_by_val"
            );

            let this = NonNull::new_unchecked(ptr.cast_mut());
            match this.as_ref().header.state.wake_by_val() {
                WakeByValAction::Enqueue => {
                    // the task should be enqueued.
                    //
                    // in the case that the task is enqueued, the state
                    // transition does *not* decrement the reference count. this is
                    // in order to avoid dropping the task while it is being
                    // scheduled. one reference is consumed by enqueuing the task...
                    Self::schedule(TaskRef(this.cast::<Header>()));
                    // now that the task has been enqueued, decrement the reference
                    // count to drop the waker that performed the `wake_by_val`.
                    Self::drop_ref(this);
                }
                WakeByValAction::Drop => Self::drop_ref(this),
                WakeByValAction::None => {}
            }
        }
    }

    /// `RawWakerVTable::wake_by_ref`: wakes the task without consuming the
    /// waker's reference (the enqueue path constructs its own `TaskRef`).
    unsafe fn wake_by_ref(ptr: *const ()) {
        // Safety: called through RawWakerVtable
        unsafe {
            let this = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?this,
                    task.tid = (*this).header.id.as_u64()
                },
                "Task::wake_by_ref"
            );

            let this = NonNull::new_unchecked(this.cast_mut()).cast::<Self>();
            if this.as_ref().state().wake_by_ref() == WakeByRefAction::Enqueue {
                Self::schedule(TaskRef(this.cast::<Header>()));
            }
        }
    }

    /// `RawWakerVTable::clone`: bumps the task's reference count and hands out
    /// another raw waker over the same task.
    unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
        // Safety: called through RawWakerVtable
        unsafe {
            let ptr = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?ptr,
                    task.tid = (*ptr).header.id.as_u64()
                },
                "Task::clone_waker"
            );

            (*ptr).header.state.clone_ref();
            Self::raw_waker(ptr)
        }
    }

    /// `RawWakerVTable::drop`: releases the waker's reference to the task.
    unsafe fn drop_waker(ptr: *const ()) {
        // Safety: called through RawWakerVtable
        unsafe {
            let ptr = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?ptr,
                    task.tid = (*ptr).header.id.as_u64()
                },
                "Task::drop_waker"
            );

            let this = ptr.cast_mut();
            Self::drop_ref(NonNull::new_unchecked(this));
        }
    }
}
914
915// === impl Header ===
916
917// Safety: tasks are always treated as pinned in memory (a requirement for polling them)
918// and care has been taken below to ensure the underlying memory isn't freed as long as the
919// `TaskRef` is part of the owned tasks list.
unsafe impl cordyceps::Linked<mpsc_queue::Links<Self>> for Header {
    type Handle = TaskRef;

    fn into_ptr(task: Self::Handle) -> NonNull<Self> {
        let ptr = task.0;
        // converting a `TaskRef` into a pointer to enqueue it assigns ownership
        // of the ref count to the queue, so we don't want to run its `Drop`
        // impl.
        mem::forget(task);
        ptr
    }

    unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle {
        // `TaskRef` is just a newtype wrapper around `NonNull<Header>`; this
        // takes back the reference that `into_ptr` left with the queue, so no
        // refcount increment happens here.
        TaskRef(ptr)
    }

    unsafe fn links(ptr: NonNull<Self>) -> NonNull<mpsc_queue::Links<Self>>
    where
        Self: Sized,
    {
        // offset the header pointer to its embedded `run_queue_links` field.
        ptr.map_addr(|addr| {
            let offset = offset_of!(Self, run_queue_links);
            addr.checked_add(offset).unwrap()
        })
        .cast()
    }
}
948
/// DO NOT confuse this with [`TaskStub`]. This type is just a zero-size placeholder so we
/// can plug *something* into the generics when creating the *heap allocated* stub task.
/// This type is *not* publicly exported, contrary to [`TaskStub`] which users will have to statically
/// allocate themselves.
#[derive(Copy, Clone, Debug)]
pub(crate) struct Stub;
955
956impl Future for Stub {
957 type Output = ();
958 fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
959 unreachable!("the stub task should never be polled!")
960 }
961}
962
/// Stub implementation of the vtable `poll` entry point.
///
/// # Safety
///
/// `ptr` must point to the stub task's [`Header`]. This function must never
/// actually be called — the stub only serves as the queue sentinel.
unsafe fn stub_poll(ptr: NonNull<Header>) -> PollResult {
    // Safety: this method should never be called
    unsafe {
        debug_assert!(ptr.as_ref().id.is_stub());
        unreachable!("stub task ({ptr:?}) should never be polled!");
    }
}
970
/// Stub implementation of the vtable `poll_join` entry point.
///
/// # Safety
///
/// `ptr` must point to the stub task's [`Header`]. This function must never
/// actually be called — the stub never has a `JoinHandle`.
unsafe fn stub_poll_join(
    ptr: NonNull<Header>,
    _outptr: NonNull<()>,
    _cx: &mut Context<'_>,
) -> Poll<Result<(), JoinError<()>>> {
    // Safety: this method should never be called
    unsafe {
        debug_assert!(ptr.as_ref().id.is_stub());
        unreachable!("stub task ({ptr:?}) should never be polled!");
    }
}
982
/// Stub implementation of the vtable/waker `wake_by_ref` entry point.
///
/// # Safety
///
/// This function must never actually be called — no waker is ever created for
/// the stub task.
unsafe fn stub_wake_by_ref(ptr: *const ()) {
    unreachable!("stub task ({ptr:p}) has no waker and should never be woken!");
}