Next Generation WASM Microkernel Operating System
at trap_handler 541 lines 19 kB view raw
// Copyright 2025 Jonas Kruckenberg
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::loom::sync::atomic::{self, AtomicUsize, Ordering};
use crate::task::PollResult;
use core::fmt;
use spin::Backoff;
use util::loom_const_fn;

/// Task state. The task stores its state in an atomic `usize` with various bitfields for the
/// necessary information.
///
/// Assuming `mycelium_bitfield` packs fields LSB-first in declaration order (TODO confirm
/// against the macro's docs), the `Snapshot` declaration below gives this layout:
///
/// ```text
/// | 63 ... 8 |     7      |    6 5     |        4        |     3     |   2   |    1     |    0    |
/// | refcount | has output | join waker | has join handle | cancelled | woken | complete | polling |
/// ```
///
/// NOTE(review): the original layout table here disagreed with the `Snapshot` bitfield
/// declaration (it omitted `HAS_OUTPUT`, labelled bits 0-1 "lifecycle", and started the
/// refcount at bit 7); the `Snapshot` declaration is authoritative.
pub(crate) struct State {
    // The packed state word. All transitions go through compare-exchange loops
    // (see `Self::transition`) or single atomic RMW ops (ref counting, join-handle drop).
    val: AtomicUsize,
}

mycelium_bitfield::bitfield! {
    /// A snapshot of a task's current state.
    #[derive(PartialEq, Eq)]
    pub(crate) struct Snapshot<usize> {
        /// If set, this task is currently being polled.
        pub const POLLING: bool;
        /// If set, this task's `Future` has completed (i.e., it has returned
        /// `Poll::Ready`).
        pub const COMPLETE: bool;
        /// If set, this task's `Waker` has been woken.
        pub(crate) const WOKEN: bool;
        /// If set, this task has been canceled.
        pub const CANCELLED: bool;
        /// If set, this task has a `JoinHandle` awaiting its completion.
        ///
        /// If the `JoinHandle` is dropped, this flag is unset.
        ///
        /// This flag does *not* indicate the presence of a `Waker` in the
        /// `join_waker` slot; it only indicates that a `JoinHandle` for this
        /// task *exists*. The join waker may not currently be registered if
        /// this flag is set.
        pub const HAS_JOIN_HANDLE: bool;
        /// The state of the task's `JoinHandle` `Waker` slot (see `JoinWakerState`).
        const JOIN_WAKER: JoinWakerState;
        /// If set, this task has output ready to be taken by a `JoinHandle`.
        const HAS_OUTPUT: bool;
        /// The number of currently live references to this task.
        ///
        /// When this is 0, the task may be deallocated.
        const REFS = ..;
    }
}

/// The lifecycle of the `JoinHandle`'s waker slot. `Registering` acts as a lock:
/// while a waker is mid-registration, the completing/cancelling side spins
/// (`State::wait_for_join_waker`) until the slot reaches `Waiting`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
enum JoinWakerState {
    /// There is no join waker; the slot is uninitialized.
    Empty = 0b00,
    /// A join waker is *being* registered.
    Registering = 0b01,
    /// A join waker is registered, the slot is initialized.
    Waiting = 0b10,
    /// The join waker has been woken.
    Woken = 0b11,
}

/// What the caller of [`State::start_poll`] should do next.
#[must_use]
pub(super) enum StartPollAction {
    /// Successful transition, it's okay to poll the task.
    Poll,
    /// Transition failed for some reason - most likely it is already running on another thread
    /// (which shouldn't happen) - doesn't matter though we shouldn't poll the task.
    DontPoll,
    /// Transition failed because the task was cancelled and its `JoinHandle` waker may need to be woken.
    Cancelled {
        /// If `true`, the task's join waker must be woken.
        wake_join_waker: bool,
    },
}

/// What the caller of [`State::try_join`] should do next.
#[must_use]
pub enum JoinAction {
    /// It's safe to take the task's output!
    TakeOutput,

    /// The task was canceled, it cannot be joined.
    Canceled {
        /// If `true`, the task completed successfully before it was cancelled.
        completed: bool,
    },

    /// Register the *first* join waker; there is no previous join waker and the
    /// slot is not initialized.
    Register,

    /// The task is not ready to read the output, but a previous join waker is
    /// registered.
    Reregister,
}

/// What the caller of [`State::wake_by_ref`] should do next.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum WakeByRefAction {
    /// The task should be enqueued.
    Enqueue,
    /// The task does not need to be enqueued.
    None,
}

/// What the caller of [`State::wake_by_val`] should do next. Unlike the by-ref
/// variant, waking by value consumes the waker's reference, so the task may
/// need to be dropped.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum WakeByValAction {
    /// The task should be enqueued.
    Enqueue,
    /// The task does not need to be enqueued.
    None,
    /// The task should be deallocated.
    Drop,
}

// One reference, expressed in the packed representation: the lowest bit of the
// REFS field. Adding/subtracting REF_ONE increments/decrements the refcount
// without disturbing the flag bits below it.
const REF_ONE: usize = Snapshot::REFS.first_bit();
// The maximum raw value of the REFS field; used to detect refcount overflow.
const REF_MAX: usize = Snapshot::REFS.raw_mask();

impl State {
    loom_const_fn! {
        /// Returns a task's initial state.
        ///
        /// NOTE(review): the original comment here claimed "a ref-count of three. See
        /// the comment on INITIAL_STATE" — but the stored value is `REF_ONE` (a single
        /// reference), and no `INITIAL_STATE` is visible in this file. Confirm which is
        /// correct; the comment looks stale.
        pub(super) const fn new() -> State {
            State {
                val: AtomicUsize::new(REF_ONE),
            }
        }
    }

    /// Atomically loads a snapshot of the current state with the given ordering.
    pub(super) fn load(&self, ordering: Ordering) -> Snapshot {
        Snapshot(self.val.load(ordering))
    }

    /// Attempt to transition the task from `IDLE` to `POLLING`, the returned enum indicates what
    /// to the with the task.
    ///
    /// This method should always be followed by a call to [`Self::end_poll`] after the actual poll
    /// is completed.
    #[tracing::instrument(level = "trace")]
    pub(super) fn start_poll(&self) -> StartPollAction {
        let mut should_wait_for_join_waker = false;
        let action = self.transition(|s| {
            // cannot start polling a task which is being polled on another
            // thread, or a task which has completed
            if s.get(Snapshot::POLLING) || s.get(Snapshot::COMPLETE) {
                return StartPollAction::DontPoll;
            }

            // if the task has been canceled, don't poll it.
            if s.get(Snapshot::CANCELLED) {
                // `has_join_waker` may also consume a registered waker
                // (JOIN_WAKER: Waiting -> Empty) and/or flag that a concurrent
                // registration is in flight via `should_wait_for_join_waker`.
                let wake_join_waker = s.has_join_waker(&mut should_wait_for_join_waker);
                return StartPollAction::Cancelled { wake_join_waker };
            }

            s
                // the task is now being polled.
                .set(Snapshot::POLLING, true)
                // if the task was woken, consume the wakeup.
                .set(Snapshot::WOKEN, false);

            StartPollAction::Poll
        });

        if should_wait_for_join_waker {
            // `should_wait_for_join_waker` is only set on the cancelled path,
            // when a join waker was mid-registration; spin until registration
            // finishes so the waker can be safely woken.
            debug_assert!(matches!(action, StartPollAction::Cancelled { .. }));
            self.wait_for_join_waker(self.load(Ordering::Acquire));
        }

        action
    }

    /// Transition the task from `POLLING` to `IDLE`, the returned enum indicates what to do with task.
    /// The `completed` argument should be set to true if the polled future returned a `Poll::Ready`
    /// indicating the task is completed and should not be rescheduled.
    #[tracing::instrument(level = "trace")]
    pub(super) fn end_poll(&self, completed: bool) -> PollResult {
        let mut should_wait_for_join_waker = false;
        let action = self.transition(|s| {
            // Cannot end a poll if a task is not being polled!
            debug_assert!(s.get(Snapshot::POLLING));
            debug_assert!(!s.get(Snapshot::COMPLETE));
            debug_assert!(
                s.ref_count() > 0,
                "cannot poll a task that has zero references, what is happening!"
            );

            s.set(Snapshot::POLLING, false)
                .set(Snapshot::COMPLETE, completed);

            // Was the task woken during the poll?
            if !completed && s.get(Snapshot::WOKEN) {
                return PollResult::PendingSchedule;
            }

            let had_join_waker = if completed {
                // set the output flag so that the JoinHandle knows it is now
                // safe to read the task's output.
                s.set(Snapshot::HAS_OUTPUT, true);
                s.has_join_waker(&mut should_wait_for_join_waker)
            } else {
                false
            };

            if had_join_waker {
                PollResult::ReadyJoined
            } else if completed {
                PollResult::Ready
            } else {
                PollResult::Pending
            }
        });

        if should_wait_for_join_waker {
            // Only set when the poll completed AND a join waker was
            // mid-registration; wait for the registration to finish before the
            // caller wakes the join waker.
            debug_assert_eq!(action, PollResult::ReadyJoined);
            self.wait_for_join_waker(self.load(Ordering::Acquire));
        }

        action
    }

    /// Spin-waits for a concurrent join-waker registration to complete, then
    /// marks the waker slot as `Woken`.
    ///
    /// The CAS only succeeds once the observed state has `JOIN_WAKER == Waiting`
    /// (we overwrite our local copy's slot to `Waiting` before each attempt),
    /// i.e. once the registering thread has finished; it then transitions the
    /// slot to `Woken` in a single step.
    #[tracing::instrument(level = "trace")]
    fn wait_for_join_waker(&self, mut state: Snapshot) {
        let mut boff = Backoff::new();
        loop {
            state.set(Snapshot::JOIN_WAKER, JoinWakerState::Waiting);
            let next = state.with(Snapshot::JOIN_WAKER, JoinWakerState::Woken);
            match self.val.compare_exchange_weak(
                state.0,
                next.0,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return,
                Err(actual) => state = Snapshot(actual),
            }
            boff.spin();
        }
    }

    /// Attempts to take the task's output on behalf of a `JoinHandle`, or tells
    /// the caller to (re)register its waker if the output is not ready yet.
    #[tracing::instrument(level = "trace")]
    pub(super) fn try_join(&self) -> JoinAction {
        // Decide between first-time registration and re-registration, and lock
        // the waker slot (`Registering`) so completion waits for us.
        fn should_register(s: &mut Snapshot) -> JoinAction {
            let action = match s.get(Snapshot::JOIN_WAKER) {
                JoinWakerState::Empty => JoinAction::Register,
                x => {
                    debug_assert_eq!(x, JoinWakerState::Waiting);
                    JoinAction::Reregister
                }
            };
            s.set(Snapshot::JOIN_WAKER, JoinWakerState::Registering);

            action
        }

        self.transition(|s| {
            let has_output = s.get(Snapshot::HAS_OUTPUT);

            if s.get(Snapshot::CANCELLED) {
                return JoinAction::Canceled {
                    completed: has_output,
                };
            }

            // If the task has not completed, we can't take its join output.
            if !s.get(Snapshot::COMPLETE) {
                return should_register(s);
            }

            // If the task does not have output, we cannot take it.
            if !has_output {
                return should_register(s);
            }

            // Clear the output flag so the output is taken exactly once.
            *s = s.with(Snapshot::HAS_OUTPUT, false);
            JoinAction::TakeOutput
        })
    }

    /// Finishes a join-waker registration: unlocks the waker slot
    /// (`Registering` -> `Waiting`) and records that a `JoinHandle` exists.
    #[tracing::instrument(level = "trace")]
    pub(super) fn join_waker_registered(&self) {
        self.transition(|s| {
            debug_assert_eq!(s.get(Snapshot::JOIN_WAKER), JoinWakerState::Registering);
            s.set(Snapshot::HAS_JOIN_HANDLE, true)
                .set(Snapshot::JOIN_WAKER, JoinWakerState::Waiting);
        });
    }

    /// Wakes the task, consuming the waker's reference (`Waker::wake`).
    /// The returned action tells the caller whether to enqueue or deallocate.
    #[tracing::instrument(level = "trace")]
    pub(super) fn wake_by_val(&self) -> WakeByValAction {
        self.transition(|s| {
            // If the task was woken *during* a poll, it will be re-queued by the
            // scheduler at the end of the poll if needed, so don't enqueue it now.
            if s.get(Snapshot::POLLING) {
                *s = s.with(Snapshot::WOKEN, true).drop_ref();
                // The scheduler still holds a reference while polling, so the
                // count cannot hit zero here.
                assert!(s.ref_count() > 0);

                return WakeByValAction::None;
            }

            // If the task is already completed or woken, we don't need to
            // requeue it, but decrement the ref count for the waker that was
            // used for this wakeup.
            if s.get(Snapshot::COMPLETE) || s.get(Snapshot::WOKEN) {
                let new_state = s.drop_ref();
                *s = new_state;
                return if new_state.ref_count() == 0 {
                    WakeByValAction::Drop
                } else {
                    WakeByValAction::None
                };
            }

            // Otherwise, transition to the woken state and enqueue the task.
            // The new reference is transferred to the run queue.
            *s = s.with(Snapshot::WOKEN, true).clone_ref();
            WakeByValAction::Enqueue
        })
    }

    /// Wakes the task without consuming a reference (`Waker::wake_by_ref`).
    #[tracing::instrument(level = "trace")]
    pub(super) fn wake_by_ref(&self) -> WakeByRefAction {
        self.transition(|state| {
            // Already finished or already woken: nothing to do.
            if state.get(Snapshot::COMPLETE) || state.get(Snapshot::WOKEN) {
                return WakeByRefAction::None;
            }

            // Mid-poll: record the wakeup; the scheduler re-queues at end_poll.
            if state.get(Snapshot::POLLING) {
                state.set(Snapshot::WOKEN, true);
                return WakeByRefAction::None;
            }

            // Otherwise, transition to the woken state and enqueue the task.
            // A new reference is created for the run queue.
            *state = state.with(Snapshot::WOKEN, true).clone_ref();
            WakeByRefAction::Enqueue
        })
    }

    /// Returns the current reference count.
    pub(super) fn refcount(&self) -> usize {
        let raw = self.val.load(Ordering::Acquire);
        Snapshot::REFS.unpack(raw)
    }

    /// Increments the reference count.
    #[tracing::instrument(level = "trace")]
    pub(super) fn clone_ref(&self) {
        // Using a relaxed ordering is alright here, as knowledge of the
        // original reference prevents other threads from erroneously deleting
        // the object.
        //
        // As explained in the [Boost documentation][1], Increasing the
        // reference counter can always be done with memory_order_relaxed: New
        // references to an object can only be formed from an existing
        // reference, and passing an existing reference from one thread to
        // another must already provide any required synchronization.
        //
        // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html)
        let old_refs = self.val.fetch_add(REF_ONE, Ordering::Relaxed);
        Snapshot::REFS.unpack(old_refs);

        // However we need to guard against massive refcounts in case someone
        // is `mem::forget`ing tasks. If we don't do this the count can overflow
        // and users will use-after free. We racily saturate to the field's max
        // on the assumption that there aren't ~2 billion threads incrementing
        // the reference count at once. This branch will never be taken in
        // any realistic program.
        //
        // We panic (via `assert!`) because such a program is incredibly
        // degenerate, and we don't care to support it.
        assert!(old_refs < REF_MAX, "task reference count overflow");
    }

    /// Decrements the reference count. Returns `true` if this was the last
    /// reference and the task may be deallocated by the caller.
    #[tracing::instrument(level = "trace")]
    pub(super) fn drop_ref(&self) -> bool {
        // We do not need to synchronize with other cores unless we are going to
        // delete the task.
        let old_refs = self.val.fetch_sub(REF_ONE, Ordering::Release);
        let old_refs = Snapshot::REFS.unpack(old_refs);

        // Did we drop the last ref?
        if old_refs > 1 {
            return false;
        }

        // Acquire fence pairs with the Release decrements above so all prior
        // accesses to the task happen-before its deallocation (the standard
        // Arc-style drop protocol).
        atomic::fence(Ordering::Acquire);
        true
    }

    /// Cancel the task.
    ///
    /// Returns `true` if the task was successfully canceled.
    #[tracing::instrument(level = "trace")]
    pub(super) fn cancel(&self) -> bool {
        self.transition(|s| {
            // you can't cancel a task that has already been canceled, that doesn't make sense.
            if s.get(Snapshot::CANCELLED) {
                return false;
            }

            // Also set WOKEN so the task gets scheduled and can observe the
            // cancellation.
            s.set(Snapshot::CANCELLED, true).set(Snapshot::WOKEN, true);

            true
        })
    }

    /// Records that a `JoinHandle` has been created for this task.
    /// At most one join handle may exist at a time (debug-asserted).
    #[tracing::instrument(level = "trace")]
    pub(super) fn create_join_handle(&self) {
        self.transition(|s| {
            debug_assert!(
                !s.get(Snapshot::HAS_JOIN_HANDLE),
                "task already has a join handle, cannot create a new one! state={s:?}"
            );

            *s = s.with(Snapshot::HAS_JOIN_HANDLE, true);
        });
    }

    /// Clears the `HAS_JOIN_HANDLE` flag when the `JoinHandle` is dropped.
    ///
    /// Uses a single `fetch_and` rather than a CAS loop since only one bit is
    /// cleared unconditionally. NOTE(review): the `JOIN_WAKER` slot is not
    /// reset here — presumably the handle-drop path deals with the waker slot
    /// elsewhere; confirm.
    #[tracing::instrument(level = "trace")]
    pub(super) fn drop_join_handle(&self) {
        const MASK: usize = !Snapshot::HAS_JOIN_HANDLE.raw_mask();
        let _prev = self.val.fetch_and(MASK, Ordering::Release);
        tracing::trace!(
            "drop_join_handle; prev_state:\n{}\nstate:\n{}",
            Snapshot::from_bits(_prev),
            self.load(Ordering::Acquire),
        );
        debug_assert!(
            Snapshot(_prev).get(Snapshot::HAS_JOIN_HANDLE),
            "tried to drop a join handle when the task did not have a join handle!\nstate: {:#?}",
            Snapshot(_prev),
        );
    }

    /// Runs `transition` against the current state in a compare-exchange loop
    /// until the result is successfully stored, returning the closure's result.
    ///
    /// If the closure leaves the snapshot unchanged, no store is attempted and
    /// the result is returned immediately. Note the closure may run multiple
    /// times when the CAS loses a race, so it must be free of side effects
    /// other than mutating the snapshot (and out-params it overwrites on every
    /// invocation).
    fn transition<T>(&self, mut transition: impl FnMut(&mut Snapshot) -> T) -> T {
        let mut current = self.load(Ordering::Acquire);
        loop {
            tracing::trace!("State::transition; current:\n{}", current);
            let mut next = current;
            // Run the transition function.
            let res = transition(&mut next);

            // No change requested; skip the CAS entirely.
            if current.0 == next.0 {
                return res;
            }

            tracing::trace!("State::transition; next:\n{}", next);
            match self.val.compare_exchange_weak(
                current.0,
                next.0,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => return res,
                Err(actual) => current = Snapshot(actual),
            }
        }
    }
}

impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Relaxed is fine for a diagnostic snapshot.
        self.load(Ordering::Relaxed).fmt(f)
    }
}

impl Snapshot {
    /// Unpacks the reference count from this snapshot.
    pub fn ref_count(self) -> usize {
        Snapshot::REFS.unpack(self.0)
    }

    /// Returns a copy of this snapshot with the refcount decremented by one.
    /// (Pure arithmetic on the local copy; the atomic store happens in the
    /// caller's CAS loop.)
    fn drop_ref(self) -> Self {
        Self(self.0 - REF_ONE)
    }

    /// Returns a copy of this snapshot with the refcount incremented by one.
    fn clone_ref(self) -> Self {
        Self(self.0 + REF_ONE)
    }

    /// Inspects (and possibly consumes) the join-waker slot on the completion
    /// or cancellation path.
    ///
    /// Returns `true` if a join waker exists (registered or mid-registration).
    /// Sets `*should_wait = true` iff a registration is in flight, in which
    /// case the caller must spin via `State::wait_for_join_waker` before
    /// waking. A fully registered waker (`Waiting`) is consumed here by
    /// resetting the slot to `Empty`.
    fn has_join_waker(&mut self, should_wait: &mut bool) -> bool {
        match self.get(Snapshot::JOIN_WAKER) {
            JoinWakerState::Empty => false,
            JoinWakerState::Registering => {
                *should_wait = true;
                debug_assert!(
                    self.get(Snapshot::HAS_JOIN_HANDLE),
                    "a task cannot register a join waker if it does not have a join handle!",
                );
                true
            }
            JoinWakerState::Waiting => {
                debug_assert!(
                    self.get(Snapshot::HAS_JOIN_HANDLE),
                    "a task cannot have a join waker if it does not have a join handle!",
                );
                *should_wait = false;
                self.set(Snapshot::JOIN_WAKER, JoinWakerState::Empty);
                true
            }
            JoinWakerState::Woken => {
                debug_assert!(
                    false,
                    "join waker should not be woken until task has completed, wtf"
                );
                false
            }
        }
    }
}

impl mycelium_bitfield::FromBits<usize> for JoinWakerState {
    type Error = core::convert::Infallible;

    /// The number of bits required to represent a value of this type.
    const BITS: u32 = 2;

    #[inline]
    fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
        match bits {
            b if b == Self::Registering as usize => Ok(Self::Registering),
            b if b == Self::Waiting as usize => Ok(Self::Waiting),
            b if b == Self::Empty as usize => Ok(Self::Empty),
            b if b == Self::Woken as usize => Ok(Self::Woken),
            _ => {
                // this should never happen unless the bitpacking code is broken
                unreachable!("invalid join waker state {bits:#b}")
            }
        }
    }

    #[inline]
    fn into_bits(self) -> usize {
        self as u8 as usize
    }
}