//! Next Generation WASM Microkernel Operating System
1// Copyright 2025 Jonas Kruckenberg
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::loom::sync::atomic::{self, AtomicUsize, Ordering};
9use crate::task::PollResult;
10use core::fmt;
11use spin::Backoff;
12use util::loom_const_fn;
13
/// Task state. The task stores its state in an atomic `usize` with various bitfields for the
/// necessary information. The state has the following layout (the [`Snapshot`]
/// bitfield below is the authoritative definition):
///
/// ```text
/// | 63 .. 8  |     7      |    6 5     |        4        |     3     |   2   |    1     |    0    |
/// | refcount | has output | join waker | has join handle | cancelled | woken | complete | polling |
/// ```
///
/// The rest of the bits (everything above bit 7) are used for the ref-count.
pub(crate) struct State {
    // Raw packed state; only read/written through `Snapshot` accessors.
    val: AtomicUsize,
}
26
mycelium_bitfield::bitfield! {
    /// A snapshot of a task's current state.
    ///
    /// Fields are packed starting at the least significant bit, in
    /// declaration order; `REFS` occupies all remaining (upper) bits.
    #[derive(PartialEq, Eq)]
    pub(crate) struct Snapshot<usize> {
        /// If set, this task is currently being polled.
        pub const POLLING: bool;
        /// If set, this task's `Future` has completed (i.e., it has returned
        /// `Poll::Ready`).
        pub const COMPLETE: bool;
        /// If set, this task's `Waker` has been woken.
        pub(crate) const WOKEN: bool;
        /// If set, this task has been canceled.
        pub const CANCELLED: bool;
        /// If set, this task has a `JoinHandle` awaiting its completion.
        ///
        /// If the `JoinHandle` is dropped, this flag is unset.
        ///
        /// This flag does *not* indicate the presence of a `Waker` in the
        /// `join_waker` slot; it only indicates that a `JoinHandle` for this
        /// task *exists*. The join waker may not currently be registered if
        /// this flag is set.
        pub const HAS_JOIN_HANDLE: bool;
        /// The state of the task's `JoinHandle` `Waker` slot (two bits; see
        /// [`JoinWakerState`]).
        const JOIN_WAKER: JoinWakerState;
        /// If set, this task has output ready to be taken by a `JoinHandle`.
        const HAS_OUTPUT: bool;
        /// The number of currently live references to this task.
        ///
        /// When this is 0, the task may be deallocated.
        const REFS = ..;
    }
}
59
/// The state of a task's `JoinHandle` waker slot, stored in the two
/// `JOIN_WAKER` bits of [`Snapshot`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[repr(u8)]
enum JoinWakerState {
    /// There is no join waker; the slot is uninitialized.
    Empty = 0b00,
    /// A join waker is *being* registered.
    Registering = 0b01,
    /// A join waker is registered, the slot is initialized.
    Waiting = 0b10,
    /// The join waker has been woken.
    Woken = 0b11,
}
72
/// What the caller should do after attempting to transition a task to
/// `POLLING` (returned by `State::start_poll`).
#[must_use]
pub(super) enum StartPollAction {
    /// Successful transition, it's okay to poll the task.
    Poll,
    /// Transition failed for some reason - most likely it is already running on another thread
    /// (which shouldn't happen) - either way, we shouldn't poll the task.
    DontPoll,
    /// Transition failed because the task was cancelled and its `JoinHandle` waker may need to be woken.
    Cancelled {
        /// If `true`, the task's join waker must be woken.
        wake_join_waker: bool,
    },
}
86
/// What a `JoinHandle` should do after attempting to join its task
/// (returned by `State::try_join`).
#[must_use]
pub enum JoinAction {
    /// It's safe to take the task's output!
    TakeOutput,

    /// The task was cancelled, it cannot be joined.
    Canceled {
        /// If `true`, the task completed successfully before it was cancelled.
        completed: bool,
    },

    /// Register the *first* join waker; there is no previous join waker and the
    /// slot is not initialized.
    Register,

    /// The task is not ready to read the output, but a previous join waker is
    /// registered.
    Reregister,
}
106
/// What the scheduler should do after a wake through a borrowed `Waker`
/// (returned by `State::wake_by_ref`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum WakeByRefAction {
    /// The task should be enqueued.
    Enqueue,
    /// The task does not need to be enqueued.
    None,
}
114
/// What the scheduler should do after a wake through an owned `Waker`,
/// which also consumes the waker's reference (returned by
/// `State::wake_by_val`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(super) enum WakeByValAction {
    /// The task should be enqueued.
    Enqueue,
    /// The task does not need to be enqueued.
    None,
    /// The waker held the last reference: the task should be deallocated.
    Drop,
}
124
// The raw value of a single reference (the lowest bit of the `REFS` field);
// used to increment/decrement the packed reference count.
const REF_ONE: usize = Snapshot::REFS.first_bit();
// The raw mask of the `REFS` field; used as the overflow guard when
// incrementing the reference count.
const REF_MAX: usize = Snapshot::REFS.raw_mask();
127
128impl State {
129 loom_const_fn! {
130 /// Returns a task's initial state.
131 pub(super) const fn new() -> State {
132 // The raw task returned by this method has a ref-count of three. See
133 // the comment on INITIAL_STATE for more.
134 State {
135 val: AtomicUsize::new(REF_ONE),
136 }
137 }
138 }
139
140 pub(super) fn load(&self, ordering: Ordering) -> Snapshot {
141 Snapshot(self.val.load(ordering))
142 }
143
144 /// Attempt to transition the task from `IDLE` to `POLLING`, the returned enum indicates what
145 /// to the with the task.
146 ///
147 /// This method should always be followed by a call to [`Self::end_poll`] after the actual poll
148 /// is completed.
149 #[tracing::instrument(level = "trace")]
150 pub(super) fn start_poll(&self) -> StartPollAction {
151 let mut should_wait_for_join_waker = false;
152 let action = self.transition(|s| {
153 // cannot start polling a task which is being polled on another
154 // thread, or a task which has completed
155 if s.get(Snapshot::POLLING) || s.get(Snapshot::COMPLETE) {
156 return StartPollAction::DontPoll;
157 }
158
159 // if the task has been canceled, don't poll it.
160 if s.get(Snapshot::CANCELLED) {
161 let wake_join_waker = s.has_join_waker(&mut should_wait_for_join_waker);
162 return StartPollAction::Cancelled { wake_join_waker };
163 }
164
165 s
166 // the task is now being polled.
167 .set(Snapshot::POLLING, true)
168 // if the task was woken, consume the wakeup.
169 .set(Snapshot::WOKEN, false);
170
171 StartPollAction::Poll
172 });
173
174 if should_wait_for_join_waker {
175 debug_assert!(matches!(action, StartPollAction::Cancelled { .. }));
176 self.wait_for_join_waker(self.load(Ordering::Acquire));
177 }
178
179 action
180 }
181
182 /// Transition the task from `POLLING` to `IDLE`, the returned enum indicates what to do with task.
183 /// The `completed` argument should be set to true if the polled future returned a `Poll::Ready`
184 /// indicating the task is completed and should not be rescheduled.
185 #[tracing::instrument(level = "trace")]
186 pub(super) fn end_poll(&self, completed: bool) -> PollResult {
187 let mut should_wait_for_join_waker = false;
188 let action = self.transition(|s| {
189 // Cannot end a poll if a task is not being polled!
190 debug_assert!(s.get(Snapshot::POLLING));
191 debug_assert!(!s.get(Snapshot::COMPLETE));
192 debug_assert!(
193 s.ref_count() > 0,
194 "cannot poll a task that has zero references, what is happening!"
195 );
196
197 s.set(Snapshot::POLLING, false)
198 .set(Snapshot::COMPLETE, completed);
199
200 // Was the task woken during the poll?
201 if !completed && s.get(Snapshot::WOKEN) {
202 return PollResult::PendingSchedule;
203 }
204
205 let had_join_waker = if completed {
206 // set the output flag so that the JoinHandle knows it is now
207 // safe to read the task's output.
208 s.set(Snapshot::HAS_OUTPUT, true);
209 s.has_join_waker(&mut should_wait_for_join_waker)
210 } else {
211 false
212 };
213
214 if had_join_waker {
215 PollResult::ReadyJoined
216 } else if completed {
217 PollResult::Ready
218 } else {
219 PollResult::Pending
220 }
221 });
222
223 if should_wait_for_join_waker {
224 debug_assert_eq!(action, PollResult::ReadyJoined);
225 self.wait_for_join_waker(self.load(Ordering::Acquire));
226 }
227
228 action
229 }
230
231 #[tracing::instrument(level = "trace")]
232 fn wait_for_join_waker(&self, mut state: Snapshot) {
233 let mut boff = Backoff::new();
234 loop {
235 state.set(Snapshot::JOIN_WAKER, JoinWakerState::Waiting);
236 let next = state.with(Snapshot::JOIN_WAKER, JoinWakerState::Woken);
237 match self.val.compare_exchange_weak(
238 state.0,
239 next.0,
240 Ordering::AcqRel,
241 Ordering::Acquire,
242 ) {
243 Ok(_) => return,
244 Err(actual) => state = Snapshot(actual),
245 }
246 boff.spin();
247 }
248 }
249
250 #[tracing::instrument(level = "trace")]
251 pub(super) fn try_join(&self) -> JoinAction {
252 fn should_register(s: &mut Snapshot) -> JoinAction {
253 let action = match s.get(Snapshot::JOIN_WAKER) {
254 JoinWakerState::Empty => JoinAction::Register,
255 x => {
256 debug_assert_eq!(x, JoinWakerState::Waiting);
257 JoinAction::Reregister
258 }
259 };
260 s.set(Snapshot::JOIN_WAKER, JoinWakerState::Registering);
261
262 action
263 }
264
265 self.transition(|s| {
266 let has_output = s.get(Snapshot::HAS_OUTPUT);
267
268 if s.get(Snapshot::CANCELLED) {
269 return JoinAction::Canceled {
270 completed: has_output,
271 };
272 }
273
274 // If the task has not completed, we can't take its join output.
275 if !s.get(Snapshot::COMPLETE) {
276 return should_register(s);
277 }
278
279 // If the task does not have output, we cannot take it.
280 if !has_output {
281 return should_register(s);
282 }
283
284 *s = s.with(Snapshot::HAS_OUTPUT, false);
285 JoinAction::TakeOutput
286 })
287 }
288
289 #[tracing::instrument(level = "trace")]
290 pub(super) fn join_waker_registered(&self) {
291 self.transition(|s| {
292 debug_assert_eq!(s.get(Snapshot::JOIN_WAKER), JoinWakerState::Registering);
293 s.set(Snapshot::HAS_JOIN_HANDLE, true)
294 .set(Snapshot::JOIN_WAKER, JoinWakerState::Waiting);
295 });
296 }
297
298 #[tracing::instrument(level = "trace")]
299 pub(super) fn wake_by_val(&self) -> WakeByValAction {
300 self.transition(|s| {
301 // If the task was woken *during* a poll, it will be re-queued by the
302 // scheduler at the end of the poll if needed, so don't enqueue it now.
303 if s.get(Snapshot::POLLING) {
304 *s = s.with(Snapshot::WOKEN, true).drop_ref();
305 assert!(s.ref_count() > 0);
306
307 return WakeByValAction::None;
308 }
309
310 // If the task is already completed or woken, we don't need to
311 // requeue it, but decrement the ref count for the waker that was
312 // used for this wakeup.
313 if s.get(Snapshot::COMPLETE) || s.get(Snapshot::WOKEN) {
314 let new_state = s.drop_ref();
315 *s = new_state;
316 return if new_state.ref_count() == 0 {
317 WakeByValAction::Drop
318 } else {
319 WakeByValAction::None
320 };
321 }
322
323 // Otherwise, transition to the woken state and enqueue the task.
324 *s = s.with(Snapshot::WOKEN, true).clone_ref();
325 WakeByValAction::Enqueue
326 })
327 }
328
329 #[tracing::instrument(level = "trace")]
330 pub(super) fn wake_by_ref(&self) -> WakeByRefAction {
331 self.transition(|state| {
332 if state.get(Snapshot::COMPLETE) || state.get(Snapshot::WOKEN) {
333 return WakeByRefAction::None;
334 }
335
336 if state.get(Snapshot::POLLING) {
337 state.set(Snapshot::WOKEN, true);
338 return WakeByRefAction::None;
339 }
340
341 // Otherwise, transition to the woken state and enqueue the task.
342 *state = state.with(Snapshot::WOKEN, true).clone_ref();
343 WakeByRefAction::Enqueue
344 })
345 }
346
347 pub(super) fn refcount(&self) -> usize {
348 let raw = self.val.load(Ordering::Acquire);
349 Snapshot::REFS.unpack(raw)
350 }
351
352 #[tracing::instrument(level = "trace")]
353 pub(super) fn clone_ref(&self) {
354 // Using a relaxed ordering is alright here, as knowledge of the
355 // original reference prevents other threads from erroneously deleting
356 // the object.
357 //
358 // As explained in the [Boost documentation][1], Increasing the
359 // reference counter can always be done with memory_order_relaxed: New
360 // references to an object can only be formed from an existing
361 // reference, and passing an existing reference from one thread to
362 // another must already provide any required synchronization.
363 //
364 // [1]: (www.boost.org/doc/libs/1_55_0/doc/html/atomic/usage_examples.html)
365 let old_refs = self.val.fetch_add(REF_ONE, Ordering::Relaxed);
366 Snapshot::REFS.unpack(old_refs);
367
368 // However we need to guard against massive refcounts in case someone
369 // is `mem::forget`ing tasks. If we don't do this the count can overflow
370 // and users will use-after free. We racily saturate to `isize::MAX` on
371 // the assumption that there aren't ~2 billion threads incrementing
372 // the reference count at once. This branch will never be taken in
373 // any realistic program.
374 //
375 // We abort because such a program is incredibly degenerate, and we
376 // don't care to support it.
377 assert!(old_refs < REF_MAX, "task reference count overflow");
378 }
379
380 #[tracing::instrument(level = "trace")]
381 pub(super) fn drop_ref(&self) -> bool {
382 // We do not need to synchronize with other cores unless we are going to
383 // delete the task.
384 let old_refs = self.val.fetch_sub(REF_ONE, Ordering::Release);
385 let old_refs = Snapshot::REFS.unpack(old_refs);
386
387 // Did we drop the last ref?
388 if old_refs > 1 {
389 return false;
390 }
391
392 atomic::fence(Ordering::Acquire);
393 true
394 }
395
396 /// Cancel the task.
397 ///
398 /// Returns `true` if the task was successfully canceled.
399 #[tracing::instrument(level = "trace")]
400 pub(super) fn cancel(&self) -> bool {
401 self.transition(|s| {
402 // you can't cancel a task that has already been canceled, that doesn't make sense.
403 if s.get(Snapshot::CANCELLED) {
404 return false;
405 }
406
407 s.set(Snapshot::CANCELLED, true).set(Snapshot::WOKEN, true);
408
409 true
410 })
411 }
412
413 #[tracing::instrument(level = "trace")]
414 pub(super) fn create_join_handle(&self) {
415 self.transition(|s| {
416 debug_assert!(
417 !s.get(Snapshot::HAS_JOIN_HANDLE),
418 "task already has a join handle, cannot create a new one! state={s:?}"
419 );
420
421 *s = s.with(Snapshot::HAS_JOIN_HANDLE, true);
422 });
423 }
424
425 #[tracing::instrument(level = "trace")]
426 pub(super) fn drop_join_handle(&self) {
427 const MASK: usize = !Snapshot::HAS_JOIN_HANDLE.raw_mask();
428 let _prev = self.val.fetch_and(MASK, Ordering::Release);
429 tracing::trace!(
430 "drop_join_handle; prev_state:\n{}\nstate:\n{}",
431 Snapshot::from_bits(_prev),
432 self.load(Ordering::Acquire),
433 );
434 debug_assert!(
435 Snapshot(_prev).get(Snapshot::HAS_JOIN_HANDLE),
436 "tried to drop a join handle when the task did not have a join handle!\nstate: {:#?}",
437 Snapshot(_prev),
438 );
439 }
440
441 fn transition<T>(&self, mut transition: impl FnMut(&mut Snapshot) -> T) -> T {
442 let mut current = self.load(Ordering::Acquire);
443 loop {
444 tracing::trace!("State::transition; current:\n{}", current);
445 let mut next = current;
446 // Run the transition function.
447 let res = transition(&mut next);
448
449 if current.0 == next.0 {
450 return res;
451 }
452
453 tracing::trace!("State::transition; next:\n{}", next);
454 match self.val.compare_exchange_weak(
455 current.0,
456 next.0,
457 Ordering::AcqRel,
458 Ordering::Acquire,
459 ) {
460 Ok(_) => return res,
461 Err(actual) => current = Snapshot(actual),
462 }
463 }
464 }
465}
466
467impl fmt::Debug for State {
468 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469 self.load(Ordering::Relaxed).fmt(f)
470 }
471}
472
473impl Snapshot {
474 pub fn ref_count(self) -> usize {
475 Snapshot::REFS.unpack(self.0)
476 }
477
478 fn drop_ref(self) -> Self {
479 Self(self.0 - REF_ONE)
480 }
481
482 fn clone_ref(self) -> Self {
483 Self(self.0 + REF_ONE)
484 }
485
486 fn has_join_waker(&mut self, should_wait: &mut bool) -> bool {
487 match self.get(Snapshot::JOIN_WAKER) {
488 JoinWakerState::Empty => false,
489 JoinWakerState::Registering => {
490 *should_wait = true;
491 debug_assert!(
492 self.get(Snapshot::HAS_JOIN_HANDLE),
493 "a task cannot register a join waker if it does not have a join handle!",
494 );
495 true
496 }
497 JoinWakerState::Waiting => {
498 debug_assert!(
499 self.get(Snapshot::HAS_JOIN_HANDLE),
500 "a task cannot have a join waker if it does not have a join handle!",
501 );
502 *should_wait = false;
503 self.set(Snapshot::JOIN_WAKER, JoinWakerState::Empty);
504 true
505 }
506 JoinWakerState::Woken => {
507 debug_assert!(
508 false,
509 "join waker should not be woken until task has completed, wtf"
510 );
511 false
512 }
513 }
514 }
515}
516
517impl mycelium_bitfield::FromBits<usize> for JoinWakerState {
518 type Error = core::convert::Infallible;
519
520 /// The number of bits required to represent a value of this type.
521 const BITS: u32 = 2;
522
523 #[inline]
524 fn try_from_bits(bits: usize) -> Result<Self, Self::Error> {
525 match bits {
526 b if b == Self::Registering as usize => Ok(Self::Registering),
527 b if b == Self::Waiting as usize => Ok(Self::Waiting),
528 b if b == Self::Empty as usize => Ok(Self::Empty),
529 b if b == Self::Woken as usize => Ok(Self::Woken),
530 _ => {
531 // this should never happen unless the bitpacking code is broken
532 unreachable!("invalid join waker state {bits:#b}")
533 }
534 }
535 }
536
537 #[inline]
538 fn into_bits(self) -> usize {
539 self as u8 as usize
540 }
541}