Linux kernel mirror (for testing)
git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel
os
linux
1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3#![allow(clippy::undocumented_unsafe_blocks)]
4#![cfg_attr(feature = "alloc", feature(allocator_api))]
5#![allow(clippy::missing_safety_doc)]
6
7use core::{
8 cell::{Cell, UnsafeCell},
9 marker::PhantomPinned,
10 ops::{Deref, DerefMut},
11 pin::Pin,
12 sync::atomic::{AtomicBool, Ordering},
13};
14use std::{
15 sync::Arc,
16 thread::{self, park, sleep, Builder, Thread},
17 time::Duration,
18};
19
20use pin_init::*;
21#[expect(unused_attributes)]
22#[path = "./linked_list.rs"]
23pub mod linked_list;
24use linked_list::*;
25
/// A minimal spin lock built on an [`AtomicBool`].
///
/// Contended acquirers spin on a relaxed load, yielding between checks, and only
/// retry the `compare_exchange` once the lock looks free.
pub struct SpinLock {
    /// `true` while some [`SpinLockGuard`] for this lock is alive.
    inner: AtomicBool,
}

impl SpinLock {
    /// Spins until the lock is acquired and returns a guard that releases it on drop.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait until the lock looks free before retrying the compare-exchange.
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates an unlocked spin lock. Usable in `const` contexts.
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

impl Default for SpinLock {
    /// Equivalent to [`SpinLock::new`]: returns an unlocked spin lock.
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

/// Proof that a [`SpinLock`] is held; the lock is released when this is dropped.
#[must_use = "if unused the SpinLock will immediately unlock"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // Release pairs with the Acquire in `acquire`, so writes made while the
        // lock was held are visible to the next owner.
        self.0.inner.store(false, Ordering::Release);
    }
}
62
/// A blocking mutex whose contended callers park (via `std::thread::park`) until woken.
///
/// After construction, `locked` and `wait_list` are only read or written while
/// `spin_lock` is held; the protected value lives in `data`.
#[pin_data]
pub struct CMutex<T> {
    /// Intrusive list of [`WaitEntry`]s for threads blocked in [`CMutex::lock`].
    #[pin]
    wait_list: ListHead,
    /// Short-term lock guarding `locked` and `wait_list`.
    spin_lock: SpinLock,
    /// Whether some [`CMutexGuard`] currently owns the mutex.
    locked: Cell<bool>,
    /// The protected value; structurally pinned (`#[pin]`) and initialized in place.
    #[pin]
    data: UnsafeCell<T>,
}
72
73impl<T> CMutex<T> {
74 #[inline]
75 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
76 pin_init!(CMutex {
77 wait_list <- ListHead::new(),
78 spin_lock: SpinLock::new(),
79 locked: Cell::new(false),
80 data <- unsafe {
81 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
82 val.__pinned_init(slot.cast::<T>())
83 })
84 },
85 })
86 }
87
88 #[inline]
89 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
90 let mut sguard = self.spin_lock.acquire();
91 if self.locked.get() {
92 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
93 // println!("wait list length: {}", self.wait_list.size());
94 while self.locked.get() {
95 drop(sguard);
96 park();
97 sguard = self.spin_lock.acquire();
98 }
99 // This does have an effect, as the ListHead inside wait_entry implements Drop!
100 #[expect(clippy::drop_non_drop)]
101 drop(wait_entry);
102 }
103 self.locked.set(true);
104 unsafe {
105 Pin::new_unchecked(CMutexGuard {
106 mtx: self,
107 _pin: PhantomPinned,
108 })
109 }
110 }
111
112 #[allow(dead_code)]
113 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
114 // SAFETY: we have an exclusive reference and thus nobody has access to data.
115 unsafe { &mut *self.data.get() }
116 }
117}
118
119unsafe impl<T: Send> Send for CMutex<T> {}
120unsafe impl<T: Send> Sync for CMutex<T> {}
121
122pub struct CMutexGuard<'a, T> {
123 mtx: &'a CMutex<T>,
124 _pin: PhantomPinned,
125}
126
127impl<T> Drop for CMutexGuard<'_, T> {
128 #[inline]
129 fn drop(&mut self) {
130 let sguard = self.mtx.spin_lock.acquire();
131 self.mtx.locked.set(false);
132 if let Some(list_field) = self.mtx.wait_list.next() {
133 let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
134 unsafe { (*wait_entry).thread.unpark() };
135 }
136 drop(sguard);
137 }
138}
139
140impl<T> Deref for CMutexGuard<'_, T> {
141 type Target = T;
142
143 #[inline]
144 fn deref(&self) -> &Self::Target {
145 unsafe { &*self.mtx.data.get() }
146 }
147}
148
149impl<T> DerefMut for CMutexGuard<'_, T> {
150 #[inline]
151 fn deref_mut(&mut self) -> &mut Self::Target {
152 unsafe { &mut *self.mtx.data.get() }
153 }
154}
155
/// A waiter queued on a [`CMutex`]'s wait list, stored on the waiting thread's stack.
///
/// `#[repr(C)]` together with `wait_list` being the first field means a pointer to
/// the embedded `ListHead` is also a pointer to the whole `WaitEntry`.
/// `CMutexGuard::drop` relies on this when it casts a list node back to a
/// `WaitEntry` in order to unpark its thread.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Link in the mutex's wait list. Must remain the first field (see above).
    #[pin]
    wait_list: ListHead,
    /// Handle of the waiting thread, used to unpark it.
    thread: Thread,
}
163
164impl WaitEntry {
165 #[inline]
166 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
167 pin_init!(Self {
168 thread: thread::current(),
169 wait_list <- ListHead::insert_prev(list),
170 })
171 }
172}
173
// The real example below needs `Arc` and OS threads, so when neither `std` nor
// `alloc` is enabled there is nothing to run; this empty entry point keeps the
// example buildable in that configuration.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}
176
/// Stress test: many threads increment one shared counter through a `CMutex`.
///
/// Each worker does `workload` increments, reports, sleeps for a
/// worker-specific time, and then does `workload` more. The final count must
/// therefore equal `workload * thread_count * 2`.
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };

    // `collect` drives the iterator, so every worker is spawned (in order)
    // before any join begins.
    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let shared = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    let bump_all = || (0..workload).for_each(|_| *shared.lock() += 1);
                    bump_all();
                    println!("{i} halfway");
                    sleep(Duration::from_millis(10 * i as u64));
                    bump_all();
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    handles
        .into_iter()
        .for_each(|h| h.join().expect("thread panicked"));

    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}