Next Generation WASM Microkernel Operating System
at main 947 lines 34 kB view raw
1// Copyright 2025. Jonas Kruckenberg 2// 3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or 4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or 5// http://opensource.org/licenses/MIT>, at your option. This file may not be 6// copied, modified, or distributed except according to those terms. 7 8mod builder; 9mod id; 10mod join_handle; 11mod state; 12mod yield_now; 13 14use alloc::boxed::Box; 15use core::alloc::AllocError; 16use core::any::type_name; 17use core::mem::{ManuallyDrop, offset_of}; 18use core::panic::AssertUnwindSafe; 19use core::pin::Pin; 20use core::ptr::NonNull; 21use core::sync::atomic::Ordering; 22use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; 23use core::{fmt, mem}; 24 25pub use builder::TaskBuilder; 26use cordyceps::mpsc_queue; 27pub use id::Id; 28pub use join_handle::{JoinError, JoinHandle}; 29use k32_util::{CachePadded, CheckedMaybeUninit, loom_const_fn}; 30pub use yield_now::yield_now; 31 32use crate::executor::Scheduler; 33use crate::loom::cell::UnsafeCell; 34use crate::task::state::{JoinAction, StartPollAction, State, WakeByRefAction, WakeByValAction}; 35 36/// Outcome of calling [`Task::poll`]. 37/// 38/// This type describes how to proceed with a given task, whether it needs to be rescheduled 39/// or can be dropped etc. 40#[derive(Debug, Clone, Copy, PartialEq, Eq)] 41pub(crate) enum PollResult { 42 /// The task has completed, without waking a [`JoinHandle`] waker. 43 /// 44 /// The scheduler can increment a counter of completed tasks, and then drop 45 /// the [`TaskRef`]. 46 Ready, 47 48 /// The task has completed and a [`JoinHandle`] waker has been woken. 49 /// 50 /// The scheduler can increment a counter of completed tasks, and then drop 51 /// the [`TaskRef`]. 52 ReadyJoined, 53 54 /// The task is pending, but not woken. 55 /// 56 /// The scheduler can drop the [`TaskRef`], as whoever intends to wake the 57 /// task later is holding a clone of its [`Waker`]. 58 Pending, 59 60 /// The task has woken itself during the poll. 61 /// 62 /// The scheduler should re-schedule the task, rather than dropping the [`TaskRef`]. 63 PendingSchedule, 64} 65 66/// A type-erased, reference-counted pointer to a spawned `Task`. 67/// 68/// Once a `Task` is spawned, it is generally pinned in memory (a requirement of [`Future`]). Instead 69/// of moving `Task`s around the scheduler, we therefore use `TaskRef`s which are just pointers to the 70/// pinned `Task`. `TaskRef`s are type-erased interacting with the allocated `Tasks` through their 71/// `Vtable` methods. This is done to reduce the monopolization cost otherwise incurred, since futures, 72/// especially ones crated through `async {}` blocks, `async` closures or `async fn` calls are all 73/// treated as *unique*, *disjoint* types which would all cause separate normalizations. E.g. spawning 74/// 10 futures on the runtime (which is not a lot) would cause 10 different copies of the entire runtime 75/// to be compiled, obviously terrible! The `Vtable` allows us to treat all spawned futures, regardless 76/// of their exact type, the same way. 77/// 78/// `TaskRef`s are reference-counted, and the task will be deallocated when the 79/// last `TaskRef` pointing to it is dropped. 80#[derive(Eq, PartialEq)] 81pub struct TaskRef(NonNull<Header>); 82 83#[repr(C)] 84struct Task<F: Future, M>(CachePadded<TaskInner<F, M>>); 85 86#[repr(C)] 87struct TaskInner<F: Future, M> { 88 /// This must be the first field of the `Task` struct! 89 header_and_metadata: HeaderAndMetadata<M>, 90 91 /// The future that the task is running. 92 /// 93 /// If `COMPLETE` is one, then the `JoinHandle` has exclusive access to this field 94 /// If COMPLETE is zero, then the RUNNING bitfield functions as 95 /// a lock for the stage field, and it can be accessed only by the thread 96 /// that set RUNNING to one. 97 stage: UnsafeCell<Stage<F>>, 98 99 /// Consumer task waiting on completion of this task. 100 /// 101 /// This field may be access by different threads: on one cpu we may complete a task and *read* 102 /// the waker field to invoke the waker, and in another thread the task's `JoinHandle` may be 103 /// polled, and if the task hasn't yet completed, the `JoinHandle` may *write* a waker to the 104 /// waker field. The `JOIN_WAKER` bit in the headers`state` field ensures safe access by multiple 105 /// cpu to the waker field using the following rules: 106 /// 107 /// 1. `JOIN_WAKER` is initialized to zero. 108 /// 109 /// 2. If `JOIN_WAKER` is zero, then the `JoinHandle` has exclusive (mutable) 110 /// access to the waker field. 111 /// 112 /// 3. If `JOIN_WAKER` is one, then the `JoinHandle` has shared (read-only) access to the waker 113 /// field. 114 /// 115 /// 4. If `JOIN_WAKER` is one and COMPLETE is one, then the executor has shared (read-only) access 116 /// to the waker field. 117 /// 118 /// 5. If the `JoinHandle` needs to write to the waker field, then the `JoinHandle` needs to 119 /// (i) successfully set `JOIN_WAKER` to zero if it is not already zero to gain exclusive access 120 /// to the waker field per rule 2, (ii) write a waker, and (iii) successfully set `JOIN_WAKER` 121 /// to one. If the `JoinHandle` unsets `JOIN_WAKER` in the process of being dropped 122 /// to clear the waker field, only steps (i) and (ii) are relevant. 123 /// 124 /// 6. The `JoinHandle` can change `JOIN_WAKER` only if COMPLETE is zero (i.e. 125 /// the task hasn't yet completed). The executor can change `JOIN_WAKER` only 126 /// if COMPLETE is one. 127 /// 128 /// 7. If `JOIN_INTEREST` is zero and COMPLETE is one, then the executor has 129 /// exclusive (mutable) access to the waker field. This might happen if the 130 /// `JoinHandle` gets dropped right after the task completes and the executor 131 /// sets the `COMPLETE` bit. In this case the executor needs the mutable access 132 /// to the waker field to drop it. 133 /// 134 /// Rule 6 implies that the steps (i) or (iii) of rule 5 may fail due to a 135 /// race. If step (i) fails, then the attempt to write a waker is aborted. If step (iii) fails 136 /// because `COMPLETE` is set to one by another thread after step (i), then the waker field is 137 /// cleared. Once `COMPLETE` is one (i.e. task has completed), the `JoinHandle` will not 138 /// modify `JOIN_WAKER`. After the runtime sets COMPLETE to one, it invokes the waker if there 139 /// is one so in this case when a task completes the `JOIN_WAKER` bit implicates to the runtime 140 /// whether it should invoke the waker or not. After the runtime is done with using the waker 141 /// during task completion, it unsets the `JOIN_WAKER` bit to give the `JoinHandle` exclusive 142 /// access again so that it is able to drop the waker at a later point. 143 join_waker: UnsafeCell<Option<Waker>>, 144} 145 146#[repr(C)] 147struct HeaderAndMetadata<M> { 148 /// This must be the first field of the `HeaderAndMetadata` struct! 149 header: Header, 150 metadata: M, 151} 152 153/// The current lifecycle stage of the future. Either the future itself or its output. 154#[repr(C)] // https://github.com/rust-lang/miri/issues/3780 155enum Stage<F: Future> { 156 /// The future is still pending. 157 Pending(F), 158 159 /// The future has completed, and its output is ready to be taken by a 160 /// `JoinHandle`, if one exists. 161 Ready(Result<F::Output, JoinError<F::Output>>), 162 163 /// The future has completed, and the task's output has been taken or is not 164 /// needed. 165 Consumed, 166} 167 168#[derive(Debug)] 169pub(crate) struct Header { 170 /// The task's state. 171 /// 172 /// This field is access with atomic instructions, so it's always safe to access it. 173 state: State, 174 /// The task vtable for this task. 175 vtable: &'static VTable, 176 /// The task's ID. 177 id: Id, 178 run_queue_links: mpsc_queue::Links<Self>, 179 /// The tracing span associated with this task, for debugging purposes. 180 span: tracing::Span, 181 scheduler: UnsafeCell<Option<&'static Scheduler>>, 182} 183 184#[derive(Debug)] 185struct VTable { 186 /// Poll the future, returning a [`PollResult`] that indicates what the 187 /// scheduler should do with the polled task. 188 poll: unsafe fn(NonNull<Header>) -> PollResult, 189 190 /// Poll the task's `JoinHandle` for completion, storing the output at the 191 /// provided [`NonNull`] pointer if the task has completed. 192 /// 193 /// If the task has not completed, the [`Waker`] from the provided 194 /// [`Context`] is registered to be woken when the task completes. 195 // Splitting this up into type aliases just makes it *harder* to understand 196 // IMO... 197 #[expect(clippy::type_complexity, reason = "")] 198 poll_join: unsafe fn( 199 ptr: NonNull<Header>, 200 outptr: NonNull<()>, 201 cx: &mut Context<'_>, 202 ) -> Poll<Result<(), JoinError<()>>>, 203 204 /// Drops the task and deallocates its memory. 205 deallocate: unsafe fn(NonNull<Header>), 206} 207 208// === impl TaskRef === 209 210impl TaskRef { 211 /// Returns the tasks unique[^1] identifier. 212 /// 213 /// [^1]: Unique to all *currently running* tasks, *not* unique across spacetime. See [`Id`] for details. 214 pub fn id(&self) -> Id { 215 self.header().id 216 } 217 218 /// # Safety 219 /// 220 /// The caller must ensure the generic argument matches the metadata type this task got created with. 221 pub unsafe fn metadata<M>(&self) -> &M { 222 // Safety: ensured by caller 223 unsafe { &self.0.cast::<HeaderAndMetadata<M>>().as_ref().metadata } 224 } 225 226 /// Returns `true` when this task has run to completion. 227 pub fn is_complete(&self) -> bool { 228 self.state() 229 .load(Ordering::Acquire) 230 .get(state::Snapshot::COMPLETE) 231 } 232 233 /// Cancels the task. 234 pub fn cancel(&self) -> bool { 235 // try to set the canceled bit. 236 let canceled = self.state().cancel(); 237 238 // if the task was successfully canceled, wake it so that it can clean 239 // up after itself. 240 if canceled { 241 tracing::trace!("waking canceled task..."); 242 todo!() 243 } 244 245 canceled 246 } 247 248 pub(crate) fn poll(&self) -> PollResult { 249 let poll_fn = self.header().vtable.poll; 250 // Safety: Called through our Vtable so this access should be fine 251 unsafe { poll_fn(self.0) } 252 } 253 254 /// # Safety 255 /// 256 /// The caller needs to make sure that `T` is the same type as the one that this `TaskRef` was 257 /// created with. 258 pub(crate) unsafe fn poll_join<T>( 259 &self, 260 cx: &mut Context<'_>, 261 ) -> Poll<Result<T, JoinError<T>>> { 262 let poll_join_fn = self.header().vtable.poll_join; 263 let mut slot = CheckedMaybeUninit::<Result<T, JoinError<T>>>::uninit(); 264 265 // Safety: This is called through the Vtable and as long as the caller makes sure that the `T` is the right 266 // type, this call is safe 267 let result = unsafe { poll_join_fn(self.0, NonNull::from(&mut slot).cast::<()>(), cx) }; 268 269 result.map(|result| { 270 if let Err(e) = result { 271 let output = if e.is_completed() { 272 // Safety: if the task completed before being canceled, we can still 273 // take its output. 274 Some(unsafe { slot.assume_init_read() }?) 275 } else { 276 None 277 }; 278 Err(e.with_output(output)) 279 } else { 280 // Safety: if the poll function returned `Ok`, we get to take the 281 // output! 282 unsafe { slot.assume_init_read() } 283 } 284 }) 285 } 286 287 /// Bind this task to a new scheduler 288 /// 289 /// # Safety 290 /// 291 /// The new scheduler `S` must be of the **same** type as the scheduler that this task got created 292 /// with. The shape of the allocated tasks depend on the type of the scheduler, binding a task 293 /// to a differently typed scheduler will therefore cause invalid memory accesses. 294 pub(crate) fn bind_scheduler(&self, scheduler: &'static Scheduler) { 295 // Safety: the repr(C) on Schedulable ensures the layout matches and this cast is safe 296 unsafe { 297 self.0 298 .as_ref() 299 .scheduler 300 .with_mut(|current| *current = Some(scheduler)); 301 } 302 } 303 304 pub(crate) fn new_stub() -> Result<Self, AllocError> { 305 const HEAP_STUB_VTABLE: VTable = VTable { 306 poll: stub_poll, 307 poll_join: stub_poll_join, 308 deallocate: stub_deallocate, 309 }; 310 311 unsafe fn stub_poll(ptr: NonNull<Header>) -> PollResult { 312 // Safety: this method should never be called 313 unsafe { 314 debug_assert!(ptr.as_ref().id.is_stub()); 315 unreachable!("stub task ({ptr:?}) should never be polled!"); 316 } 317 } 318 319 unsafe fn stub_poll_join( 320 ptr: NonNull<Header>, 321 _outptr: NonNull<()>, 322 _cx: &mut Context<'_>, 323 ) -> Poll<Result<(), JoinError<()>>> { 324 // Safety: this method should never be called 325 unsafe { 326 debug_assert!(ptr.as_ref().id.is_stub()); 327 unreachable!("stub task ({ptr:?}) should never be polled!"); 328 } 329 } 330 331 unsafe fn stub_deallocate(ptr: NonNull<Header>) { 332 // Safety: this method should never be called 333 unsafe { 334 debug_assert!(ptr.as_ref().id.is_stub()); 335 drop(Box::from_raw(ptr.as_ptr())); 336 } 337 } 338 339 let inner = Box::into_raw(Box::try_new(Header { 340 state: State::new(), 341 vtable: &HEAP_STUB_VTABLE, 342 id: Id::stub(), 343 run_queue_links: mpsc_queue::Links::new_stub(), 344 span: tracing::Span::none(), 345 scheduler: UnsafeCell::new(None), 346 })?); 347 348 // Safety: we just allocated the ptr so it is never null 349 Ok(Self(unsafe { NonNull::new_unchecked(inner) })) 350 } 351 352 pub(crate) fn clone_from_raw(ptr: NonNull<Header>) -> TaskRef { 353 let this = Self(ptr); 354 this.state().clone_ref(); 355 this 356 } 357 358 pub(crate) fn header_ptr(&self) -> NonNull<Header> { 359 self.0 360 } 361 362 // ===== private methods ===== 363 364 #[track_caller] 365 fn new_allocated<F, M>(task: Box<Task<F, M>>) -> (Self, JoinHandle<F::Output, M>) 366 where 367 F: Future, 368 { 369 assert_eq!(task.state().refcount(), 1); 370 let ptr = Box::into_raw(task); 371 372 // Safety: we just allocated the ptr so it is never null 373 let task = Self(unsafe { NonNull::new_unchecked(ptr).cast() }); 374 let join = JoinHandle::new(task.clone()); 375 376 (task, join) 377 } 378 379 fn schedule(self) { 380 let scheduler = self.header().scheduler.with(|scheduler| { 381 // Safety: ensured by caller 382 unsafe { 383 (*scheduler).expect("task doesn't have an associated scheduler, this is a bug!") 384 } 385 }); 386 387 scheduler.schedule(self); 388 } 389 390 fn header(&self) -> &Header { 391 // Safety: constructor ensures the pointer is always valid 392 unsafe { self.0.as_ref() } 393 } 394 395 /// Returns a reference to the task's state. 396 fn state(&self) -> &State { 397 &self.header().state 398 } 399 400 unsafe fn from_raw(ptr: *const Header) -> Self { 401 // Safety: ensured by caller 402 Self(unsafe { NonNull::new_unchecked(ptr.cast_mut()) }) 403 } 404 405 fn into_raw(self) -> *const Header { 406 let this = ManuallyDrop::new(self); 407 this.0.as_ptr() 408 } 409 410 // ===== private waker-related methods ===== 411 412 fn wake(self) { 413 match self.state().wake_by_val() { 414 // `Enqueue` means we should put the task back into the queue. In this case we want to assign 415 // ownership of the TaskRef to the scheduler. 416 WakeByValAction::Enqueue => { 417 self.schedule(); 418 } 419 // `Drop` means (unsurprisingly) the opposite: The task does not need to be enqueued, and 420 // we should therefore drop this handle. 421 WakeByValAction::Drop => { 422 drop(self); 423 } 424 } 425 } 426 427 /// # Safety 428 /// 429 /// The caller must ensure this task is currently bound to a scheduler. 430 unsafe fn wake_by_val(&self) { 431 if self.state().wake_by_ref() == WakeByRefAction::Enqueue { 432 self.clone().schedule(); 433 } 434 } 435 436 fn into_raw_waker(self) -> RawWaker { 437 // Increment the reference count of the arc to clone it. 438 // 439 // The #[inline(always)] is to ensure that raw_waker and clone_waker are 440 // always generated in the same code generation unit as one another, and 441 // therefore that the structurally identical const-promoted RawWakerVTable 442 // within both functions is deduplicated at LLVM IR code generation time. 443 // This allows optimizing Waker::will_wake to a single pointer comparison of 444 // the vtable pointers, rather than comparing all four function pointers 445 // within the vtables. 446 #[inline(always)] 447 unsafe fn clone_waker(waker: *const ()) -> RawWaker { 448 // Safety: ensured by VTable 449 unsafe { 450 waker.cast::<Header>().as_ref().unwrap().state.clone_ref(); 451 } 452 453 RawWaker::new( 454 waker, 455 &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker), 456 ) 457 } 458 459 // Wake by value, moving the Arc into the Wake::wake function 460 unsafe fn wake(waker: *const ()) { 461 // Safety: ensured by VTable 462 let task = unsafe { TaskRef::from_raw(waker.cast::<Header>()) }; 463 464 tracing::trace!( 465 task.addr = ?task.header_ptr(), 466 task.tid = task.id().as_u64(), 467 "Task::wake_by_val" 468 ); 469 470 task.wake(); 471 } 472 473 // Wake by reference, wrap the waker in ManuallyDrop to avoid dropping it 474 unsafe fn wake_by_ref(waker: *const ()) { 475 // Safety: ensured by VTable 476 let task = unsafe { ManuallyDrop::new(TaskRef::from_raw(waker.cast::<Header>())) }; 477 478 tracing::trace!( 479 task.addr = ?task.header_ptr(), 480 task.tid = task.id().as_u64(), 481 "Task::wake_by_ref" 482 ); 483 484 // Safety: ensured by the caller 485 unsafe { 486 task.wake_by_val(); 487 } 488 } 489 490 // Decrement the reference count of the Arc on drop 491 unsafe fn drop_waker(waker: *const ()) { 492 // Safety: ensured by VTable 493 drop(unsafe { TaskRef::from_raw(waker.cast()) }); 494 } 495 496 RawWaker::new( 497 self.into_raw().cast(), 498 &RawWakerVTable::new(clone_waker, wake, wake_by_ref, drop_waker), 499 ) 500 } 501} 502 503impl fmt::Debug for TaskRef { 504 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 505 f.debug_struct("TaskRef") 506 .field("id", &self.id()) 507 .field("addr", &self.0) 508 .finish() 509 } 510} 511 512impl fmt::Pointer for TaskRef { 513 #[inline] 514 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 515 fmt::Pointer::fmt(&self.0, f) 516 } 517} 518 519impl Clone for TaskRef { 520 #[inline] 521 #[track_caller] 522 fn clone(&self) -> Self { 523 let loc = core::panic::Location::caller(); 524 tracing::trace!( 525 task.addr=?self.0, 526 task.is_stub=self.id().is_stub(), 527 loc.file = loc.file(), 528 loc.line = loc.line(), 529 loc.col = loc.column(), 530 "TaskRef::clone", 531 ); 532 self.state().clone_ref(); 533 Self(self.0) 534 } 535} 536 537impl Drop for TaskRef { 538 #[inline] 539 #[track_caller] 540 fn drop(&mut self) { 541 tracing::trace!( 542 task.addr=?self.0, 543 task.is_stub=self.id().is_stub(), 544 "TaskRef::drop" 545 ); 546 if !self.state().drop_ref() { 547 return; 548 } 549 550 let deallocate = self.header().vtable.deallocate; 551 // Safety: as long as we're constructed from a NonNull<Header> this is safe 552 unsafe { 553 deallocate(self.0); 554 } 555 } 556} 557 558// Safety: The state protocol ensured synchronized access to the inner task 559unsafe impl Send for TaskRef {} 560// Safety: The state protocol ensured synchronized access to the inner task 561unsafe impl Sync for TaskRef {} 562 563// === impl Task === 564 565impl<F: Future, M> Task<F, M> { 566 const TASK_VTABLE: VTable = VTable { 567 poll: Self::poll, 568 poll_join: Self::poll_join, 569 deallocate: Self::deallocate, 570 }; 571 572 loom_const_fn! { 573 pub const fn new(future: F, task_id: Id, span: tracing::Span, metadata: M) -> Self { 574 let inner = TaskInner { 575 header_and_metadata: HeaderAndMetadata { 576 header: Header { 577 state: State::new(), 578 vtable: &Self::TASK_VTABLE, 579 id: task_id, 580 run_queue_links: mpsc_queue::Links::new(), 581 span, 582 scheduler: UnsafeCell::new(None) 583 }, 584 metadata 585 }, 586 stage: UnsafeCell::new(Stage::Pending(future)), 587 join_waker: UnsafeCell::new(None), 588 }; 589 Self(CachePadded(inner)) 590 } 591 } 592 593 /// Poll the future, returning a [`PollResult`] that indicates what the 594 /// scheduler should do with the polled task. 595 /// 596 /// This is a type-erased function called through the task's [`Vtable`]. 597 /// 598 /// # Safety 599 /// 600 /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the 601 /// pointed header must have the same `S`, `F`, and `STO` type parameters 602 /// as `Self`). 603 /// - the task must currently be bound to a scheduler 604 unsafe fn poll(ptr: NonNull<Header>) -> PollResult { 605 // Safety: ensured by caller 606 unsafe { 607 let this = ptr.cast::<Self>().as_ref(); 608 609 tracing::trace!( 610 task.addr=?ptr, 611 task.output=type_name::<F::Output>(), 612 task.id=?this.id(), 613 "Task::poll", 614 ); 615 616 match this.state().start_poll() { 617 // Successfully to transitioned to `POLLING` all is good! 618 StartPollAction::Poll => {} 619 // Something isn't right, we shouldn't poll the task right now... 620 StartPollAction::DontPoll => { 621 tracing::warn!(task.addr=?ptr, "failed to transition to polling",); 622 return PollResult::Ready; 623 } 624 StartPollAction::Cancelled { wake_join_waker } => { 625 tracing::trace!(task.addr=?ptr, "task cancelled"); 626 if wake_join_waker { 627 this.wake_join_waker(); 628 return PollResult::ReadyJoined; 629 } else { 630 return PollResult::Ready; 631 } 632 } 633 } 634 635 // wrap the waker in `ManuallyDrop` because we're converting it from an 636 // existing task ref, rather than incrementing the task ref count. if 637 // this waker is consumed during the poll, we don't want to decrement 638 // its ref count when the poll ends. 639 let waker = { 640 let raw = TaskRef(ptr).into_raw_waker(); 641 ManuallyDrop::new(Waker::from_raw(raw)) 642 }; 643 644 // actually poll the task 645 let poll = { 646 let cx = Context::from_waker(&waker); 647 648 this.poll_inner(cx) 649 }; 650 651 let result = this.state().end_poll(poll.is_ready()); 652 653 // if the task is ready and has a `JoinHandle` to wake, wake the join 654 // waker now. 655 if result == PollResult::ReadyJoined { 656 this.wake_join_waker(); 657 } 658 659 result 660 } 661 } 662 663 /// Poll to join the task pointed to by `ptr`, taking its output if it has 664 /// completed. 665 /// 666 /// If the task has completed, this method returns [`Poll::Ready`], and the 667 /// task's output is stored at the memory location pointed to by `outptr`. 668 /// This function is called by [`JoinHandle`]s o poll the task they 669 /// correspond to. 670 /// 671 /// This is a type-erased function called through the task's [`Vtable`]. 672 /// 673 /// # Safety 674 /// 675 /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the 676 /// pointed header must have the same `S`, `F`, and `STO` type parameters 677 /// as `Self`). 678 /// - `outptr` must point to a valid `MaybeUninit<F::Output>`. 679 unsafe fn poll_join( 680 ptr: NonNull<Header>, 681 outptr: NonNull<()>, 682 cx: &mut Context<'_>, 683 ) -> Poll<Result<(), JoinError<()>>> { 684 // Safety: ensured by caller 685 unsafe { 686 let this = ptr.cast::<Self>().as_ref(); 687 tracing::trace!( 688 task.addr=?ptr, 689 task.output=type_name::<F::Output>(), 690 task.id=?this.id(), 691 "Task::poll_join" 692 ); 693 694 match this.state().try_join() { 695 JoinAction::TakeOutput => { 696 // safety: if the state transition returns 697 // `JoinAction::TakeOutput`, this indicates that we have 698 // exclusive permission to read the task output. 699 this.take_output(outptr); 700 return Poll::Ready(Ok(())); 701 } 702 JoinAction::Canceled { completed } => { 703 // if the task has completed before it was canceled, also try to 704 // read the output, so that it can be returned in the `JoinError`. 705 if completed { 706 // safety: if the state transition returned `Canceled` 707 // with `completed` set, this indicates that we have 708 // exclusive permission to take the output. 709 this.take_output(outptr); 710 } 711 return Poll::Ready(Err(JoinError::cancelled(completed, *this.id()))); 712 } 713 JoinAction::Register => { 714 this.0.0.join_waker.with_mut(|waker| { 715 waker.write(Some(cx.waker().clone())); 716 }); 717 } 718 JoinAction::Reregister => { 719 this.0.0.join_waker.with_mut(|waker| { 720 let waker = (*waker).as_mut().unwrap(); 721 722 let new_waker = cx.waker(); 723 if !waker.will_wake(new_waker) { 724 *waker = new_waker.clone(); 725 } 726 }); 727 } 728 } 729 this.state().join_waker_registered(); 730 Poll::Pending 731 } 732 } 733 734 /// Drops the task and deallocates its memory. 735 /// 736 /// This is a type-erased function called through the task's [`Vtable`]. 737 /// 738 /// # Safety 739 /// 740 /// - `ptr` must point to the [`Header`] of a task of type `Self` (i.e. the 741 /// pointed header must have the same `S`, `F`, and `STO` type parameters 742 /// as `Self`). 743 unsafe fn deallocate(ptr: NonNull<Header>) { 744 // Safety: ensured by caller 745 unsafe { 746 let this = ptr.cast::<Self>(); 747 tracing::trace!( 748 task.addr=?ptr, 749 task.output=type_name::<F::Output>(), 750 task.id=?this.as_ref().id(), 751 task.is_stub=?this.as_ref().id().is_stub(), 752 "Task::deallocate", 753 ); 754 debug_assert_eq!( 755 ptr.as_ref().state.load(Ordering::Acquire).ref_count(), 756 0, 757 "a task may not be deallocated if its ref count is greater than zero!" 758 ); 759 drop(Box::from_raw(this.as_ptr())); 760 } 761 } 762 763 /// Polls the future. If the future completes, the output is written to the 764 /// stage field. 765 /// 766 /// # Safety 767 /// 768 /// The caller has to ensure this cpu has exclusive mutable access to the tasks `stage` field (ie the 769 /// future or output). 770 pub unsafe fn poll_inner(&self, mut cx: Context) -> Poll<()> { 771 let _span = self.span().enter(); 772 773 self.0.0.stage.with_mut(|stage| { 774 // Safety: ensured by caller 775 let stage = unsafe { &mut *stage }; 776 stage.poll(&mut cx, *self.id()) 777 }) 778 } 779 780 /// Wakes the task's [`JoinHandle`], if it has one. 781 /// 782 /// # Safety 783 /// 784 /// - The caller must have exclusive access to the task's `JoinWaker`. This 785 /// is ensured by the task's state management. 786 unsafe fn wake_join_waker(&self) { 787 // Safety: ensured by caller 788 unsafe { 789 self.0.0.join_waker.with_mut(|waker| { 790 if let Some(join_waker) = (*waker).take() { 791 tracing::trace!("waking {join_waker:?}"); 792 join_waker.wake(); 793 } else { 794 tracing::trace!("called wake_join_waker on non-existing waker"); 795 } 796 }); 797 } 798 } 799 800 unsafe fn take_output(&self, dst: NonNull<()>) { 801 // Safety: ensured by caller 802 unsafe { 803 self.0.0.stage.with_mut(|stage| { 804 match mem::replace(&mut *stage, Stage::Consumed) { 805 Stage::Ready(output) => { 806 // let output = self.stage.take_output(); 807 // safety: the caller is responsible for ensuring that this 808 // points to a `MaybeUninit<F::Output>`. 809 let dst = dst 810 .cast::<CheckedMaybeUninit<Result<F::Output, JoinError<F::Output>>>>() 811 .as_mut(); 812 813 // that's right, it goes in the `NonNull<()>` hole! 814 dst.write(output); 815 } 816 _ => panic!("JoinHandle polled after completion"), 817 } 818 }); 819 } 820 } 821 822 fn id(&self) -> &Id { 823 &self.0.0.header_and_metadata.header.id 824 } 825 fn state(&self) -> &State { 826 &self.0.0.header_and_metadata.header.state 827 } 828 #[inline] 829 fn span(&self) -> &tracing::Span { 830 &self.0.0.header_and_metadata.header.span 831 } 832} 833 834// === impl Stage === 835 836impl<F> Stage<F> 837where 838 F: Future, 839{ 840 fn poll(&mut self, cx: &mut Context<'_>, id: Id) -> Poll<()> { 841 struct Guard<'a, T: Future> { 842 stage: &'a mut Stage<T>, 843 } 844 impl<T: Future> Drop for Guard<'_, T> { 845 fn drop(&mut self) { 846 // If the future panics on poll, we drop it inside the panic 847 // guard. 848 // Safety: caller has to ensure mutual exclusion 849 *self.stage = Stage::Consumed; 850 } 851 } 852 853 let poll = AssertUnwindSafe(|| -> Poll<F::Output> { 854 let guard = Guard { stage: self }; 855 856 // Safety: caller has to ensure mutual exclusion 857 let Stage::Pending(future) = guard.stage else { 858 // TODO this will be caught by the `catch_unwind` which isn't great 859 unreachable!("unexpected stage"); 860 }; 861 862 // Safety: The caller ensures the future is pinned. 863 let future = unsafe { Pin::new_unchecked(future) }; 864 let res = future.poll(cx); 865 mem::forget(guard); 866 res 867 }); 868 869 cfg_if::cfg_if! { 870 if #[cfg(test)] { 871 let result = ::std::panic::catch_unwind(poll); 872 } else if #[cfg(feature = "k23-unwind")] { 873 let result = k23_panic_unwind::catch_unwind(poll); 874 } else { 875 let result = Ok(poll()); 876 } 877 } 878 879 match result { 880 Ok(Poll::Pending) => Poll::Pending, 881 Ok(Poll::Ready(ready)) => { 882 *self = Stage::Ready(Ok(ready)); 883 Poll::Ready(()) 884 } 885 Err(err) => { 886 *self = Stage::Ready(Err(JoinError::panic(id, err))); 887 Poll::Ready(()) 888 } 889 } 890 } 891} 892 893// === impl Header === 894 895// Safety: tasks are always treated as pinned in memory (a requirement for polling them) 896// and care has been taken below to ensure the underlying memory isn't freed as long as the 897// `TaskRef` is part of the owned tasks list. 898unsafe impl cordyceps::Linked<mpsc_queue::Links<Self>> for Header { 899 type Handle = TaskRef; 900 901 fn into_ptr(task: Self::Handle) -> NonNull<Self> { 902 let ptr = task.0; 903 // converting a `TaskRef` into a pointer to enqueue it assigns ownership 904 // of the ref count to the queue, so we don't want to run its `Drop` 905 // impl. 906 mem::forget(task); 907 ptr 908 } 909 910 unsafe fn from_ptr(ptr: NonNull<Self>) -> Self::Handle { 911 TaskRef(ptr) 912 } 913 914 unsafe fn links(ptr: NonNull<Self>) -> NonNull<mpsc_queue::Links<Self>> 915 where 916 Self: Sized, 917 { 918 // Safety: `TaskRef` is just a newtype wrapper around `NonNull<Header>` 919 ptr.map_addr(|addr| { 920 let offset = offset_of!(Self, run_queue_links); 921 addr.checked_add(offset).unwrap() 922 }) 923 .cast() 924 } 925} 926 927#[cfg(test)] 928mod tests { 929 use alloc::boxed::Box; 930 931 use crate::loom; 932 use crate::task::{Id, Task, TaskRef}; 933 934 #[test] 935 fn metadata() { 936 loom::model(|| { 937 let (t1, _) = TaskRef::new_allocated(Box::new(Task::new( 938 async {}, 939 Id::next(), 940 tracing::Span::none(), 941 42usize, 942 ))); 943 944 assert_eq!(unsafe { *t1.metadata::<usize>() }, 42); 945 }); 946 } 947}