Next Generation WASM Microkernel Operating System
at trap_handler 985 lines 35 kB view raw
// Copyright 2025. Jonas Kruckenberg
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

mod builder;
mod id;
mod join_handle;
mod state;
mod yield_now;

use crate::loom::cell::UnsafeCell;
use crate::task::state::{JoinAction, StartPollAction, State, WakeByRefAction, WakeByValAction};
use alloc::boxed::Box;
use cordyceps::mpsc_queue;
use core::alloc::Allocator;
use core::any::type_name;
use core::mem::offset_of;
use core::panic::AssertUnwindSafe;
use core::pin::Pin;
use core::ptr::NonNull;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use core::{fmt, mem};
use util::{CachePadded, CheckedMaybeUninit, loom_const_fn};

use crate::executor::Scheduler;
pub use builder::TaskBuilder;
pub use id::Id;
pub use join_handle::{JoinError, JoinHandle};
pub use yield_now::yield_now;

/// Outcome of calling [`Task::poll`].
///
/// This type describes how to proceed with a given task, whether it needs to be rescheduled
/// or can be dropped etc.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PollResult {
    /// The task has completed, without waking a [`JoinHandle`] waker.
    ///
    /// The scheduler can increment a counter of completed tasks, and then drop
    /// the [`TaskRef`].
    Ready,

    /// The task has completed and a [`JoinHandle`] waker has been woken.
    ///
    /// The scheduler can increment a counter of completed tasks, and then drop
    /// the [`TaskRef`].
    ReadyJoined,

    /// The task is pending, but not woken.
    ///
    /// The scheduler can drop the [`TaskRef`], as whoever intends to wake the
    /// task later is holding a clone of its [`Waker`].
    Pending,

    /// The task has woken itself during the poll.
    ///
    /// The scheduler should re-schedule the task, rather than dropping the [`TaskRef`].
    PendingSchedule,
}

/// A type-erased, reference-counted pointer to a spawned `Task`.
///
/// Once a `Task` is spawned, it is generally pinned in memory (a requirement of [`Future`]). Instead
/// of moving `Task`s around the scheduler, we therefore use `TaskRef`s which are just pointers to the
/// pinned `Task`. `TaskRef`s are type-erased, interacting with the allocated `Task`s through their
/// `VTable` methods. This is done to reduce the monomorphization cost otherwise incurred, since futures,
/// especially ones created through `async {}` blocks, `async` closures or `async fn` calls are all
/// treated as *unique*, *disjoint* types which would all cause separate monomorphizations. E.g. spawning
/// 10 futures on the runtime (which is not a lot) would cause 10 different copies of the entire runtime
/// to be compiled, obviously terrible! The `VTable` allows us to treat all spawned futures, regardless
/// of their exact type, the same way.
///
/// `TaskRef`s are reference-counted, and the task will be deallocated when the
/// last `TaskRef` pointing to it is dropped.
#[derive(Eq, PartialEq)]
pub struct TaskRef(NonNull<Header>);

/// A future plus the bookkeeping needed to run it (header, scheduler slot, stage and join
/// waker), wrapped in [`CachePadded`].
///
/// The `#[repr(C)]` layout matters: code in this file casts a `NonNull<Header>` to
/// `Schedulable` and to `Task<F>`, which relies on `Header` being the very first field of
/// `Schedulable`, and `Schedulable` being the very first field of `TaskInner`.
#[repr(C)]
pub struct Task<F: Future>(CachePadded<TaskInner<F>>);

/// The actual contents of a [`Task`].
#[repr(C)]
struct TaskInner<F: Future> {
    /// This must be the first field of the `Task` struct!
    schedulable: Schedulable,

    /// The future that the task is running.
    ///
    /// If `COMPLETE` is one, then the `JoinHandle` has exclusive access to this field
    /// If COMPLETE is zero, then the RUNNING bitfield functions as
    /// a lock for the stage field, and it can be accessed only by the thread
    /// that set RUNNING to one.
    stage: UnsafeCell<Stage<F>>,

    /// Consumer task waiting on completion of this task.
    ///
    /// This field may be accessed by different threads: on one cpu we may complete a task and *read*
    /// the waker field to invoke the waker, and in another thread the task's `JoinHandle` may be
    /// polled, and if the task hasn't yet completed, the `JoinHandle` may *write* a waker to the
    /// waker field. The `JOIN_WAKER` bit in the header's `state` field ensures safe access by multiple
    /// cpus to the waker field using the following rules:
    ///
    /// 1. `JOIN_WAKER` is initialized to zero.
    ///
    /// 2. If `JOIN_WAKER` is zero, then the `JoinHandle` has exclusive (mutable)
    ///    access to the waker field.
    ///
    /// 3. If `JOIN_WAKER` is one, then the `JoinHandle` has shared (read-only) access to the waker
    ///    field.
    ///
    /// 4. If `JOIN_WAKER` is one and COMPLETE is one, then the executor has shared (read-only) access
    ///    to the waker field.
    ///
    /// 5. If the `JoinHandle` needs to write to the waker field, then the `JoinHandle` needs to
    ///    (i) successfully set `JOIN_WAKER` to zero if it is not already zero to gain exclusive access
    ///    to the waker field per rule 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER`
    ///    to one. If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped
    ///    to clear the waker field, only steps (i) and (ii) are relevant.
    ///
    /// 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e.
    ///    the task hasn't yet completed). The executor can change `JOIN_WAKER` only
    ///    if COMPLETE is one.
    ///
    /// 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the executor has
    ///    exclusive (mutable) access to the waker field. This might happen if the
    ///    `JoinHandle` gets dropped right after the task completes and the executor
    ///    sets the `COMPLETE` bit. In this case the executor needs the mutable access
    ///    to the waker field to drop it.
    ///
    /// Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a
    /// race. If step (i) fails, then the attempt to write a waker is aborted. If step (iii) fails
    /// because `COMPLETE` is set to one by another thread after step (i), then the waker field is
    /// cleared. Once `COMPLETE` is one (i.e. task has completed), the `JoinHandle` will not
    /// modify `JOIN_WAKER`. After the runtime sets COMPLETE to one, it invokes the waker if there
    /// is one so in this case when a task completes the `JOIN_WAKER` bit indicates to the runtime
    /// whether it should invoke the waker or not. After the runtime is done with using the waker
    /// during task completion, it unsets the `JOIN_WAKER` bit to give the `JoinHandle` exclusive
    /// access again so that it is able to drop the waker at a later point.
    join_waker: UnsafeCell<Option<Waker>>,
}

/// The future-type-independent prefix of every task: the [`Header`] plus the scheduler the
/// task is bound to. Wakers created by this module point at this struct.
#[repr(C)]
struct Schedulable {
    /// This must be the first field of the `Schedulable` struct!
    header: Header,
    /// The scheduler this task is enqueued on when woken. `None` until
    /// [`TaskRef::bind_scheduler`] has been called.
    scheduler: UnsafeCell<Option<&'static Scheduler>>,
}

/// The current lifecycle stage of the future. Either the future itself or its output.
#[repr(C)] // https://github.com/rust-lang/miri/issues/3780
enum Stage<F: Future> {
    /// The future is still pending.
    Pending(F),

    /// The future has completed, and its output is ready to be taken by a
    /// `JoinHandle`, if one exists.
    Ready(Result<F::Output, JoinError<F::Output>>),

    /// The future has completed, and the task's output has been taken or is not
    /// needed.
    Consumed,
}

/// The type-erased header shared by all tasks; this is what a [`TaskRef`] points to.
#[derive(Debug)]
pub(crate) struct Header {
    /// The task's state.
    ///
    /// This field is accessed with atomic instructions, so it's always safe to access it.
    state: State,
    /// The task vtable for this task.
    vtable: &'static VTable,
    /// The task's ID.
    id: Id,
    /// Intrusive links used to enqueue this task in a `cordyceps` MPSC run queue
    /// (see the `cordyceps::Linked` impl for `Header`).
    run_queue_links: mpsc_queue::Links<Self>,
    /// The tracing span associated with this task, for debugging purposes.
    span: tracing::Span,
}

/// Table of type-erased functions for operating on a task through its [`Header`] pointer.
#[derive(Debug)]
struct VTable {
    /// Poll the future, returning a [`PollResult`] that indicates what the
    /// scheduler should do with the polled task.
    poll: unsafe fn(NonNull<Header>) -> PollResult,

    /// Poll the task's `JoinHandle` for completion, storing the output at the
    /// provided [`NonNull`] pointer if the task has completed.
    ///
    /// If the task has not completed, the [`Waker`] from the provided
    /// [`Context`] is registered to be woken when the task completes.
    // Splitting this up into type aliases just makes it *harder* to understand
    // IMO...
    #[expect(clippy::type_complexity, reason = "")]
    poll_join: unsafe fn(
        ptr: NonNull<Header>,
        outptr: NonNull<()>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), JoinError<()>>>,

    /// Drops the task and deallocates its memory.
    deallocate: unsafe fn(NonNull<Header>),

    /// The `wake_by_ref` function from the task's [`RawWakerVTable`].
    ///
    /// This is duplicated here as it's used to wake canceled tasks when a task
    /// is canceled by a [`TaskRef`] or [`JoinHandle`].
    wake_by_ref: unsafe fn(*const ()),
}

// === impl TaskRef ===

impl TaskRef {
    /// Converts a freshly boxed [`Task`] into a type-erased `TaskRef` plus the
    /// [`JoinHandle`] for its output.
    ///
    /// The box is leaked via `Box::into_raw`; the memory is reclaimed by the vtable's
    /// `deallocate` function once the last reference is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the task's reference count is not exactly 1.
    // NOTE(review): this accepts a `Box<_, A>` for an arbitrary allocator `A`, while
    // `Task::deallocate` rebuilds the box with `Box::from_raw` (no allocator argument).
    // Confirm callers only ever pass boxes from the global allocator.
    #[track_caller]
    pub(crate) fn new_allocated<F, A>(task: Box<Task<F>, A>) -> (Self, JoinHandle<F::Output>)
    where
        F: Future,
        A: Allocator,
    {
        assert_eq!(task.state().refcount(), 1);
        let ptr = Box::into_raw(task);

        // Safety: we just allocated the ptr so it is never null
        let task = Self(unsafe { NonNull::new_unchecked(ptr).cast() });
        // the `JoinHandle` holds its own reference, hence the clone.
        let join = JoinHandle::new(task.clone());

        (task, join)
    }

    /// Returns the task's unique[^1] identifier.
    ///
    /// [^1]: Unique to all *currently running* tasks, *not* unique across spacetime. See [`Id`] for details.
    pub fn id(&self) -> Id {
        self.header().id
    }

    /// Returns `true` when this task has run to completion.
    pub fn is_complete(&self) -> bool {
        self.state()
            .load(Ordering::Acquire)
            .get(state::Snapshot::COMPLETE)
    }

    /// Cancels the task.
    ///
    /// Returns `true` if the state transition reported that the task was successfully
    /// canceled; in that case the task is also woken so that it can clean up after itself.
    pub fn cancel(&self) -> bool {
        // try to set the canceled bit.
        let canceled = self.state().cancel();

        // if the task was successfully canceled, wake it so that it can clean
        // up after itself.
        if canceled {
            tracing::trace!("woke canceled task");
            self.wake_by_ref();
        }

        canceled
    }

    /// Creates an additional `TaskRef` from a raw header pointer, incrementing the task's
    /// reference count.
    // NOTE(review): this is a safe fn but dereferences `ptr`; it presumably requires `ptr`
    // to point at a live task header — confirm at the call sites.
    pub(crate) fn clone_from_raw(ptr: NonNull<Header>) -> TaskRef {
        let this = Self(ptr);
        this.state().clone_ref();
        this
    }

    /// Returns the raw pointer to the task's [`Header`] without touching the reference count.
    pub(crate) fn header_ptr(&self) -> NonNull<Header> {
        self.0
    }

    /// Returns a reference to the task's [`Header`].
    pub(crate) fn header(&self) -> &Header {
        // Safety: constructor ensures the pointer is always valid
        unsafe { self.0.as_ref() }
    }

    /// Returns a reference to the task's state.
    pub(crate) fn state(&self) -> &State {
        &self.header().state
    }

    /// Wakes the task through the `wake_by_ref` entry of its vtable, without consuming
    /// this reference.
    pub(crate) fn wake_by_ref(&self) {
        tracing::trace!("TaskRef::wake_by_ref {self:?}");
        let wake_by_ref_fn = self.header().vtable.wake_by_ref;
        // Safety: Called through our Vtable so this access should be fine
        unsafe { wake_by_ref_fn(self.0.as_ptr().cast::<()>()) }
    }

    /// Polls the task through its vtable, returning what the scheduler should do with it next.
    pub(crate) fn poll(&self) -> PollResult {
        let poll_fn = self.header().vtable.poll;
        // Safety: Called through our Vtable so this access should be fine
        unsafe { poll_fn(self.0) }
    }

    /// Polls the task for completion on behalf of a [`JoinHandle`], returning its output
    /// (or a [`JoinError`]) once it is available.
    ///
    /// # Safety
    ///
    /// The caller needs to make sure that `T` is the same type as the one that this `TaskRef` was
    /// created with.
    pub(crate) unsafe fn poll_join<T>(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<Result<T, JoinError<T>>> {
        let poll_join_fn = self.header().vtable.poll_join;
        // the type-erased vtable function cannot return `T`, so it writes the output
        // into this slot through a `NonNull<()>` instead.
        let mut slot = CheckedMaybeUninit::<Result<T, JoinError<T>>>::uninit();

        // Safety: This is called through the Vtable and as long as the caller makes sure that the `T` is the right
        // type, this call is safe
        let result = unsafe { poll_join_fn(self.0, NonNull::from(&mut slot).cast::<()>(), cx) };

        result.map(|result| {
            if let Err(e) = result {
                let output = if e.is_completed() {
                    // Safety: if the task completed before being canceled, we can still
                    // take its output.
                    // (the `?` returns early from this closure if the stored result is an `Err`)
                    Some(unsafe { slot.assume_init_read() }?)
                } else {
                    None
                };
                Err(e.with_output(output))
            } else {
                // Safety: if the poll function returned `Ok`, we get to take the
                // output!
                unsafe { slot.assume_init_read() }
            }
        })
    }

    /// Bind this task to a new scheduler
    ///
    /// Stores `scheduler` in the task's `Schedulable::scheduler` slot, replacing any
    /// previously bound scheduler.
    ///
    // NOTE(review): this used to carry a `# Safety` section requiring the new scheduler `S`
    // to be of the same type as the one the task was created with. The function is not
    // `unsafe` and takes the concrete `Scheduler` type, so that requirement looks stale.
    // The write below is not synchronized by anything visible here — confirm that callers
    // only bind a task while nothing else can concurrently schedule it.
    pub(crate) fn bind_scheduler(&self, scheduler: &'static Scheduler) {
        // Safety: the repr(C) on Schedulable ensures the layout matches and this cast is safe
        unsafe {
            self.0
                .cast::<Schedulable>()
                .as_ref()
                .scheduler
                .with_mut(|current| *current = Some(scheduler));
        }
    }
}

impl fmt::Debug for TaskRef {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TaskRef")
            .field("id", &self.id())
            .field("addr", &self.0)
            .finish()
    }
}

impl fmt::Pointer for TaskRef {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Pointer::fmt(&self.0, f)
    }
}

impl Clone for TaskRef {
    /// Increments the task's reference count and returns another pointer to the same task.
    #[inline]
    #[track_caller]
    fn clone(&self) -> Self {
        let loc = core::panic::Location::caller();
        tracing::trace!(
            task.addr=?self.0,
            task.is_stub=self.id().is_stub(),
            loc.file = loc.file(),
            loc.line = loc.line(),
            loc.col = loc.column(),
            "TaskRef::clone",
        );
        self.state().clone_ref();
        Self(self.0)
    }
}

impl Drop for TaskRef {
    /// Decrements the task's reference count, deallocating the task through its vtable
    /// when `State::drop_ref` reports that it should be freed.
    #[inline]
    #[track_caller]
    fn drop(&mut self) {
        tracing::trace!(
            task.addr=?self.0,
            task.is_stub=self.id().is_stub(),
            "TaskRef::drop"
        );
        if !self.state().drop_ref() {
            return;
        }

        let deallocate = self.header().vtable.deallocate;
        // Safety: as long as we're constructed from a NonNull<Header> this is safe
        unsafe {
            deallocate(self.0);
        }
    }
}

// Safety: The state protocol ensures synchronized access to the inner task
unsafe impl Send for TaskRef {}
// Safety: The state protocol ensures synchronized access to the inner task
unsafe impl Sync for TaskRef {}

// === impl Task ===

impl<F: Future> Task<F> {
    /// The vtable shared by all tasks running a future of type `F`.
    const TASK_VTABLE: VTable = VTable {
        poll: Self::poll,
        poll_join: Self::poll_join,
        deallocate: Self::deallocate,
        wake_by_ref: Schedulable::wake_by_ref,
    };

    // Creates a new task in the `Stage::Pending` stage wrapping `future`, with no bound
    // scheduler and no registered join waker.
    loom_const_fn! {
        pub const fn new(future: F, task_id: Id, span: tracing::Span) -> Self {
            let inner = TaskInner {
                schedulable: Schedulable {
                    header: Header {
                        state: State::new(),
                        vtable: &Self::TASK_VTABLE,
                        id: task_id,
                        run_queue_links: mpsc_queue::Links::new(),
                        span,
                    },
                    scheduler: UnsafeCell::new(None)
                },
                stage: UnsafeCell::new(Stage::Pending(future)),
                join_waker: UnsafeCell::new(None),
            };
            Self(CachePadded(inner))
        }
    }

    /// Poll the future, returning a [`PollResult`] that indicates what the
    /// scheduler should do with the polled task.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must have the same `F` type parameter as `Self`).
    unsafe fn poll(ptr: NonNull<Header>) -> PollResult {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>().as_ref();

            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.id(),
                "Task::poll",
            );

            match this.state().start_poll() {
                // Successfully transitioned to `POLLING` all is good!
                StartPollAction::Poll => {}
                // Something isn't right, we shouldn't poll the task right now...
                StartPollAction::DontPoll => {
                    tracing::warn!(task.addr=?ptr, "failed to transition to polling",);
                    return PollResult::Ready;
                }
                StartPollAction::Cancelled { wake_join_waker } => {
                    tracing::trace!(task.addr=?ptr, "task cancelled");
                    if wake_join_waker {
                        this.wake_join_waker();
                        return PollResult::ReadyJoined;
                    } else {
                        return PollResult::Ready;
                    }
                }
            }

            // wrap the waker in `ManuallyDrop` because we're converting it from an
            // existing task ref, rather than incrementing the task ref count. if
            // this waker is consumed during the poll, we don't want to decrement
            // its ref count when the poll ends.
            let waker = {
                let raw = Schedulable::raw_waker(ptr.as_ptr().cast());
                mem::ManuallyDrop::new(Waker::from_raw(raw))
            };

            // actually poll the task
            let poll = {
                let cx = Context::from_waker(&waker);

                this.poll_inner(cx)
            };

            let result = this.state().end_poll(poll.is_ready());

            // if the task is ready and has a `JoinHandle` to wake, wake the join
            // waker now.
            if result == PollResult::ReadyJoined {
                this.wake_join_waker();
            }

            result
        }
    }

    /// Poll to join the task pointed to by `ptr`, taking its output if it has
    /// completed.
    ///
    /// If the task has completed, this method returns [`Poll::Ready`], and the
    /// task's output is stored at the memory location pointed to by `outptr`.
    /// This function is called by [`JoinHandle`]s to poll the task they
    /// correspond to.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must have the same `F` type parameter as `Self`).
    /// - `outptr` must point to a valid
    ///   `CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>` (this is the type
    ///   `take_output` casts it to).
    unsafe fn poll_join(
        ptr: NonNull<Header>,
        outptr: NonNull<()>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), JoinError<()>>> {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>().as_ref();
            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.id(),
                "Task::poll_join"
            );

            match this.state().try_join() {
                JoinAction::TakeOutput => {
                    // safety: if the state transition returns
                    // `JoinAction::TakeOutput`, this indicates that we have
                    // exclusive permission to read the task output.
                    this.take_output(outptr);
                    return Poll::Ready(Ok(()));
                }
                JoinAction::Canceled { completed } => {
                    // if the task has completed before it was canceled, also try to
                    // read the output, so that it can be returned in the `JoinError`.
                    if completed {
                        // safety: if the state transition returned `Canceled`
                        // with `completed` set, this indicates that we have
                        // exclusive permission to take the output.
                        this.take_output(outptr);
                    }
                    return Poll::Ready(Err(JoinError::cancelled(completed, *this.id())));
                }
                JoinAction::Register => {
                    this.0.0.join_waker.with_mut(|waker| {
                        // raw-pointer write: the previous contents of the slot are
                        // overwritten without being dropped.
                        waker.write(Some(cx.waker().clone()));
                    });
                }
                JoinAction::Reregister => {
                    this.0.0.join_waker.with_mut(|waker| {
                        let waker = (*waker).as_mut().unwrap();

                        // only replace the stored waker if the new one would wake
                        // a different target.
                        let new_waker = cx.waker();
                        if !waker.will_wake(new_waker) {
                            *waker = new_waker.clone();
                        }
                    });
                }
            }
            // only reached for `Register`/`Reregister`; the other arms return early.
            this.state().join_waker_registered();
            Poll::Pending
        }
    }

    /// Drops the task and deallocates its memory.
    ///
    /// This is a type-erased function called through the task's [`VTable`].
    ///
    /// # Safety
    ///
    /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the
    ///   pointed header must have the same `F` type parameter as `Self`).
    unsafe fn deallocate(ptr: NonNull<Header>) {
        // Safety: ensured by caller
        unsafe {
            let this = ptr.cast::<Self>();
            tracing::trace!(
                task.addr=?ptr,
                task.output=type_name::<F::Output>(),
                task.id=?this.as_ref().id(),
                task.is_stub=?this.as_ref().id().is_stub(),
                "Task::deallocate",
            );
            debug_assert_eq!(
                ptr.as_ref().state.load(Ordering::Acquire).ref_count(),
                0,
                "a task may not be deallocated if its ref count is greater than zero!"
            );
            // re-materialize the box leaked in `TaskRef::new_allocated` so that dropping
            // it runs the task's destructor and frees the allocation.
            drop(Box::from_raw(this.as_ptr()));
        }
    }

    /// Polls the future. If the future completes, the output is written to the
    /// stage field.
    ///
    /// The poll runs inside the task's tracing span.
    ///
    /// # Safety
    ///
    /// The caller has to ensure this cpu has exclusive mutable access to the tasks `stage` field (ie the
    /// future or output).
    pub unsafe fn poll_inner(&self, mut cx: Context) -> Poll<()> {
        let _span = self.span().enter();

        self.0.0.stage.with_mut(|stage| {
            // Safety: ensured by caller
            let stage = unsafe { &mut *stage };
            stage.poll(&mut cx, *self.id())
        })
    }

    /// Wakes the task's [`JoinHandle`], if it has one.
    ///
    /// The stored waker is taken out of the `join_waker` slot, leaving `None` behind.
    ///
    /// # Safety
    ///
    /// - The caller must have exclusive access to the task's `JoinWaker`. This
    ///   is ensured by the task's state management.
    unsafe fn wake_join_waker(&self) {
        // Safety: ensured by caller
        unsafe {
            self.0.0.join_waker.with_mut(|waker| {
                if let Some(join_waker) = (*waker).take() {
                    tracing::trace!("waking {join_waker:?}");
                    join_waker.wake();
                } else {
                    tracing::trace!("called wake_join_waker on non-existing waker");
                }
            });
        }
    }

    /// Moves the task's output out of the stage field into `dst`, leaving
    /// `Stage::Consumed` behind.
    ///
    /// # Panics
    ///
    /// Panics if the stage is not `Stage::Ready`.
    ///
    /// # Safety
    ///
    /// - The caller must have exclusive access to the task's `stage` field.
    /// - `dst` must point to a valid
    ///   `CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>`.
    unsafe fn take_output(&self, dst: NonNull<()>) {
        // Safety: ensured by caller
        unsafe {
            self.0.0.stage.with_mut(|stage| {
                match mem::replace(&mut *stage, Stage::Consumed) {
                    Stage::Ready(output) => {
                        // let output = self.stage.take_output();
                        // safety: the caller is responsible for ensuring that this
                        // points to a `MaybeUninit<F::Output>`.
                        let dst = dst
                            .cast::<CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>>()
                            .as_mut();

                        // that's right, it goes in the `NonNull<()>` hole!
                        dst.write(output);
                    }
                    _ => panic!("JoinHandle polled after completion"),
                }
            });
        }
    }

    /// Returns the task's ID.
    fn id(&self) -> &Id {
        &self.0.0.schedulable.header.id
    }
    /// Returns the task's state.
    fn state(&self) -> &State {
        &self.0.0.schedulable.header.state
    }
    /// Returns the tracing span associated with this task.
    #[inline]
    fn span(&self) -> &tracing::Span {
        &self.0.0.schedulable.header.span
    }
}

impl Task<Stub> {
    /// Vtable for the heap allocated stub task: polling, joining and waking are routed to
    /// the `stub_*` functions, only deallocation is real.
    const HEAP_STUB_VTABLE: VTable = VTable {
        poll: stub_poll,
        poll_join: stub_poll_join,
        // Heap allocated stub tasks *will* need to be deallocated, since the
        // scheduler will deallocate its stub task if it's dropped.
        deallocate: Self::deallocate,
        wake_by_ref: stub_wake_by_ref,
    };

    loom_const_fn! {
        /// Create a new stub task.
        pub(crate) const fn new_stub() -> Self {
            let inner = TaskInner {
                schedulable: Schedulable {
                    header: Header {
                        state: State::new(),
                        vtable: &Self::HEAP_STUB_VTABLE,
                        id: Id::stub(),
                        run_queue_links: mpsc_queue::Links::new_stub(),
                        span: tracing::Span::none(),
                    },
                    scheduler: UnsafeCell::new(None)
                },
                stage: UnsafeCell::new(Stage::Pending(Stub)),
                join_waker: UnsafeCell::new(None),
            };

            Self(CachePadded(inner))
        }
    }
}

// === impl Stage ===

impl<F> Stage<F>
where
    F: Future,
{
    /// Polls the pending future once.
    ///
    /// When the future completes, `self` becomes `Stage::Ready(Ok(output))` and
    /// `Poll::Ready(())` is returned. When panics are caught (in tests, or with the
    /// `unwind2` feature) and the poll panics, the future is dropped and `self` becomes
    /// `Stage::Ready(Err(..))` carrying a panic [`JoinError`] for `id`.
    fn poll(&mut self, cx: &mut Context<'_>, id: Id) -> Poll<()> {
        // Drop guard that replaces the stage with `Consumed` (dropping the future) unless
        // it is defused with `mem::forget` after a poll that did not panic.
        struct Guard<'a, T: Future> {
            stage: &'a mut Stage<T>,
        }
        impl<T: Future> Drop for Guard<'_, T> {
            fn drop(&mut self) {
                // If the future panics on poll, we drop it inside the panic
                // guard.
                // Safety: caller has to ensure mutual exclusion
                *self.stage = Stage::Consumed;
            }
        }

        let poll = AssertUnwindSafe(|| -> Poll<F::Output> {
            let guard = Guard { stage: self };

            // Safety: caller has to ensure mutual exclusion
            let Stage::Pending(future) = guard.stage else {
                // TODO this will be caught by the `catch_unwind` which isn't great
                unreachable!("unexpected stage");
            };

            // Safety: The caller ensures the future is pinned.
            let future = unsafe { Pin::new_unchecked(future) };
            let res = future.poll(cx);
            // the poll returned normally: defuse the guard so the stage is left untouched.
            mem::forget(guard);
            res
        });

        cfg_if::cfg_if! {
            if #[cfg(test)] {
                let result = ::std::panic::catch_unwind(poll);
            } else if #[cfg(feature = "unwind2")] {
                let result = panic_unwind2::catch_unwind(poll);
            } else {
                let result = Ok(poll());
            }
        }

        match result {
            Ok(Poll::Pending) => Poll::Pending,
            Ok(Poll::Ready(ready)) => {
                *self = Stage::Ready(Ok(ready));
                Poll::Ready(())
            }
            Err(err) => {
                *self = Stage::Ready(Err(JoinError::panic(id, err)));
                Poll::Ready(())
            }
        }
    }
}

// === impl Schedulable ===

impl Schedulable {
    /// The [`RawWakerVTable`] used by every waker created for a task.
    const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
        Self::clone_waker,
        Self::wake_by_val,
        Self::wake_by_ref,
        Self::drop_waker,
    );

    // `Waker::will_wake` is used all over the place to optimize waker code (e.g. only update wakers if they
    // have a different wake target). Problem is `will_wake` only checks for pointer equality and since
    // the `into_raw_waker` would usually be inlined in release mode (and with it `WAKER_VTABLE`) the
    // Waker identity would be different before and after calling `.clone()`. This isn't a correctness
    // problem since it's still the same waker in the end, it just causes a lot of unnecessary wake ups.
    // the `inline(never)` below is therefore quite load-bearing
    #[inline(never)]
    fn raw_waker(this: *const Self) -> RawWaker {
        RawWaker::new(this.cast::<()>(), &Self::WAKER_VTABLE)
    }

    /// Returns the task's state.
    #[inline(always)]
    fn state(&self) -> &State {
        &self.header.state
    }

    /// Hands `this` to the scheduler the task is bound to.
    ///
    /// # Panics
    ///
    /// Panics if no scheduler has been bound to the task.
    ///
    /// # Safety
    ///
    /// `this` must point to a live task whose allocation starts with a `Schedulable`.
    unsafe fn schedule(this: TaskRef) {
        // Safety: ensured by caller
        unsafe {
            this.header_ptr()
                .cast::<Self>()
                .as_ref()
                .scheduler
                .with(|scheduler| {
                    (*scheduler)
                        .as_ref()
                        .expect("task doesn't have an associated scheduler, this is a bug!")
                        .schedule(this);
                });
        }
    }

    /// Decrements the task's reference count, deallocating the task through its vtable
    /// when `State::drop_ref` reports that it should be freed.
    ///
    /// # Safety
    ///
    /// `this` must point to a live task, and the caller must own the reference being dropped.
    #[inline]
    unsafe fn drop_ref(this: NonNull<Self>) {
        // Safety: ensured by caller
        unsafe {
            tracing::trace!(task.addr=?this, task.id=?this.as_ref().header.id, "Task::drop_ref");
            if !this.as_ref().state().drop_ref() {
                return;
            }

            let deallocate = this.as_ref().header.vtable.deallocate;
            deallocate(this.cast::<Header>());
        }
    }

    // === Waker vtable methods ===

    /// `wake` entry of [`Self::WAKER_VTABLE`]: wakes the task, consuming the waker.
    ///
    /// # Safety
    ///
    /// `ptr` must be the data pointer of a waker created by [`Self::raw_waker`].
    unsafe fn wake_by_val(ptr: *const ()) {
        // Safety: called through RawWakerVtable
        unsafe {
            let ptr = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?ptr,
                    task.tid = (*ptr).header.id.as_u64()
                },
                "Task::wake_by_val"
            );

            let this = NonNull::new_unchecked(ptr.cast_mut());
            match this.as_ref().header.state.wake_by_val() {
                WakeByValAction::Enqueue => {
                    // the task should be enqueued.
                    //
                    // in the case that the task is enqueued, the state
                    // transition does *not* decrement the reference count. this is
                    // in order to avoid dropping the task while it is being
                    // scheduled. one reference is consumed by enqueuing the task...
                    Self::schedule(TaskRef(this.cast::<Header>()));
                    // now that the task has been enqueued, decrement the reference
                    // count to drop the waker that performed the `wake_by_val`.
                    Self::drop_ref(this);
                }
                WakeByValAction::Drop => Self::drop_ref(this),
                WakeByValAction::None => {}
            }
        }
    }

    /// `wake_by_ref` entry of [`Self::WAKER_VTABLE`]: wakes the task without consuming
    /// the waker, enqueuing it if the state transition asks for that.
    ///
    /// # Safety
    ///
    /// `ptr` must point to the `Schedulable` at the start of a live task.
    unsafe fn wake_by_ref(ptr: *const ()) {
        // Safety: called through RawWakerVtable
        unsafe {
            let this = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?this,
                    task.tid = (*this).header.id.as_u64()
                },
                "Task::wake_by_ref"
            );

            let this = NonNull::new_unchecked(this.cast_mut()).cast::<Self>();
            if this.as_ref().state().wake_by_ref() == WakeByRefAction::Enqueue {
                Self::schedule(TaskRef(this.cast::<Header>()));
            }
        }
    }

    /// `clone` entry of [`Self::WAKER_VTABLE`]: increments the task's reference count and
    /// returns a new raw waker for the same task.
    ///
    /// # Safety
    ///
    /// `ptr` must be the data pointer of a waker created by [`Self::raw_waker`].
    unsafe fn clone_waker(ptr: *const ()) -> RawWaker {
        // Safety: called through RawWakerVtable
        unsafe {
            let ptr = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?ptr,
                    task.tid = (*ptr).header.id.as_u64()
                },
                "Task::clone_waker"
            );

            (*ptr).header.state.clone_ref();
            Self::raw_waker(ptr)
        }
    }

    /// `drop` entry of [`Self::WAKER_VTABLE`]: releases the reference held by the waker.
    ///
    /// # Safety
    ///
    /// `ptr` must be the data pointer of a waker created by [`Self::raw_waker`].
    unsafe fn drop_waker(ptr: *const ()) {
        // Safety: called through RawWakerVtable
        unsafe {
            let ptr = ptr.cast::<Self>();
            tracing::trace!(
                target: "scheduler:waker",
                {
                    task.addr = ?ptr,
                    task.tid = (*ptr).header.id.as_u64()
                },
                "Task::drop_waker"
            );

            let this = ptr.cast_mut();
            Self::drop_ref(NonNull::new_unchecked(this));
        }
    }
}

// === impl Header ===

// Safety: tasks are always treated as pinned in memory (a requirement for polling them)
// and care has been taken below to ensure the underlying memory isn't freed as long as the
// `TaskRef` is part of the owned tasks list.
920unsafe impl cordyceps::Linked<mpsc_queue::Links<Self>> for Header { 921 type Handle = TaskRef; 922 923 fn into_ptr(task: Self::Handle) -> NonNull<Self> { 924 let ptr = task.0; 925 // converting a `TaskRef` into a pointer to enqueue it assigns ownership 926 // of the ref count to the queue, so we don't want to run its `Drop` 927 // impl. 928 mem::forget(task); 929 ptr 930 } 931 932 unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle { 933 TaskRef(ptr) 934 } 935 936 unsafe fn links(ptr: NonNull<Self>) -> NonNull<mpsc_queue::Links<Self>> 937 where 938 Self: Sized, 939 { 940 // Safety: `TaskRef` is just a newtype wrapper around `NonNull<Header>` 941 ptr.map_addr(|addr| { 942 let offset = offset_of!(Self, run_queue_links); 943 addr.checked_add(offset).unwrap() 944 }) 945 .cast() 946 } 947} 948 949/// DO NOT confuse this with [`TaskSTub`]. This type is just a zero-size placeholder so we 950/// can plug *something* into the generics when creating the *heap allocated* stub task. 951/// This type is *not* publicly exported, contrary to [`TaskSTub`] which users will have to statically 952/// allocate themselves. 
953#[derive(Copy, Clone, Debug)] 954pub(crate) struct Stub; 955 956impl Future for Stub { 957 type Output = (); 958 fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> { 959 unreachable!("the stub task should never be polled!") 960 } 961} 962 963unsafe fn stub_poll(ptr: NonNull<Header>) -> PollResult { 964 // Safety: this method should never be called 965 unsafe { 966 debug_assert!(ptr.as_ref().id.is_stub()); 967 unreachable!("stub task ({ptr:?}) should never be polled!"); 968 } 969} 970 971unsafe fn stub_poll_join( 972 ptr: NonNull<Header>, 973 _outptr: NonNull<()>, 974 _cx: &mut Context<'_>, 975) -> Poll<Result<(), JoinError<()>>> { 976 // Safety: this method should never be called 977 unsafe { 978 debug_assert!(ptr.as_ref().id.is_stub()); 979 unreachable!("stub task ({ptr:?}) should never be polled!"); 980 } 981} 982 983unsafe fn stub_wake_by_ref(ptr: *const ()) { 984 unreachable!("stub task ({ptr:p}) has no waker and should never be woken!"); 985}