Linux kernel mirror (for testing)
git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel
os
linux
1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3#![allow(clippy::undocumented_unsafe_blocks)]
4#![cfg_attr(feature = "alloc", feature(allocator_api))]
5#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
6#![allow(clippy::missing_safety_doc)]
7
8use core::{
9 cell::{Cell, UnsafeCell},
10 marker::PhantomPinned,
11 ops::{Deref, DerefMut},
12 pin::Pin,
13 sync::atomic::{AtomicBool, Ordering},
14};
15#[cfg(feature = "std")]
16use std::{
17 sync::Arc,
18 thread::{self, sleep, Builder, Thread},
19 time::Duration,
20};
21
22use pin_init::*;
23#[allow(unused_attributes)]
24#[path = "./linked_list.rs"]
25pub mod linked_list;
26use linked_list::*;
27
/// A minimal test-and-test-and-set spin lock.
///
/// [`SpinLock::acquire`] returns a [`SpinLockGuard`]; the lock is released when that guard
/// is dropped.
pub struct SpinLock {
    /// `true` while a [`SpinLockGuard`] for this lock exists.
    inner: AtomicBool,
}

impl SpinLock {
    /// Acquires the lock, spinning until it becomes available.
    ///
    /// While another holder owns the lock, the caller waits using plain relaxed loads and only
    /// retries the read-modify-write once the lock looks free, so contending threads do not
    /// keep bouncing the cache line with failing compare-exchanges. With the `std` feature the
    /// wait yields to the OS scheduler; without it, a CPU spin-loop hint is issued instead.
    #[inline]
    #[must_use = "the lock is released as soon as the returned guard is dropped"]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // `Acquire` on success pairs with the `Release` store in `SpinLockGuard::drop`, so all
        // writes made by the previous holder are visible here. A failed attempt needs no
        // ordering. The weak variant may fail spuriously, which this retry loop tolerates.
        while self
            .inner
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            while self.inner.load(Ordering::Relaxed) {
                #[cfg(feature = "std")]
                thread::yield_now();
                #[cfg(not(feature = "std"))]
                core::hint::spin_loop();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    // Clippy suggests a `Default` impl alongside `new`; this type only offers the `const fn`.
    #[allow(clippy::new_without_default)]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

/// Proof that the wrapped [`SpinLock`] is held; dropping it releases the lock.
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` publishes every write made while holding the lock to the next acquirer.
        self.0.inner.store(false, Ordering::Release);
    }
}
65
/// A blocking mutex built from a [`SpinLock`] and an intrusive queue of waiting threads.
///
/// Note: field declaration order also determines drop order.
#[pin_data]
pub struct CMutex<T> {
    /// Intrusive list of `WaitEntry`s enqueued by contended callers of `lock`.
    #[pin]
    wait_list: ListHead,
    /// Held whenever `locked` or `wait_list` is read or modified.
    spin_lock: SpinLock,
    /// `true` while a `CMutexGuard` for this mutex exists; only accessed under `spin_lock`.
    locked: Cell<bool>,
    /// The protected value (structurally pinned).
    #[pin]
    data: UnsafeCell<T>,
}
75
76impl<T> CMutex<T> {
77 #[inline]
78 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
79 pin_init!(CMutex {
80 wait_list <- ListHead::new(),
81 spin_lock: SpinLock::new(),
82 locked: Cell::new(false),
83 data <- unsafe {
84 pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
85 val.__pinned_init(slot.cast::<T>())
86 })
87 },
88 })
89 }
90
91 #[inline]
92 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
93 let mut sguard = self.spin_lock.acquire();
94 if self.locked.get() {
95 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
96 // println!("wait list length: {}", self.wait_list.size());
97 while self.locked.get() {
98 drop(sguard);
99 #[cfg(feature = "std")]
100 thread::park();
101 sguard = self.spin_lock.acquire();
102 }
103 // This does have an effect, as the ListHead inside wait_entry implements Drop!
104 #[expect(clippy::drop_non_drop)]
105 drop(wait_entry);
106 }
107 self.locked.set(true);
108 unsafe {
109 Pin::new_unchecked(CMutexGuard {
110 mtx: self,
111 _pin: PhantomPinned,
112 })
113 }
114 }
115
116 #[allow(dead_code)]
117 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
118 // SAFETY: we have an exclusive reference and thus nobody has access to data.
119 unsafe { &mut *self.data.get() }
120 }
121}
122
123unsafe impl<T: Send> Send for CMutex<T> {}
124unsafe impl<T: Send> Sync for CMutex<T> {}
125
126pub struct CMutexGuard<'a, T> {
127 mtx: &'a CMutex<T>,
128 _pin: PhantomPinned,
129}
130
131impl<T> Drop for CMutexGuard<'_, T> {
132 #[inline]
133 fn drop(&mut self) {
134 let sguard = self.mtx.spin_lock.acquire();
135 self.mtx.locked.set(false);
136 if let Some(list_field) = self.mtx.wait_list.next() {
137 let _wait_entry = list_field.as_ptr().cast::<WaitEntry>();
138 #[cfg(feature = "std")]
139 unsafe {
140 (*_wait_entry).thread.unpark()
141 };
142 }
143 drop(sguard);
144 }
145}
146
147impl<T> Deref for CMutexGuard<'_, T> {
148 type Target = T;
149
150 #[inline]
151 fn deref(&self) -> &Self::Target {
152 unsafe { &*self.mtx.data.get() }
153 }
154}
155
156impl<T> DerefMut for CMutexGuard<'_, T> {
157 #[inline]
158 fn deref_mut(&mut self) -> &mut Self::Target {
159 unsafe { &mut *self.mtx.data.get() }
160 }
161}
162
/// A node in a `CMutex`'s wait queue, pinned on the stack of the thread waiting in `lock`.
///
/// `#[repr(C)]` with `wait_list` as the first field places the `ListHead` at offset 0. This
/// lets `CMutexGuard::drop` cast a pointer to the link back to a pointer to the entry.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Link into the mutex's `wait_list`; must remain the first field (see above).
    #[pin]
    wait_list: ListHead,
    /// The waiting thread, unparked when the mutex is released.
    #[cfg(feature = "std")]
    thread: Thread,
}
171
172impl WaitEntry {
173 #[inline]
174 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
175 #[cfg(feature = "std")]
176 {
177 pin_init!(Self {
178 thread: thread::current(),
179 wait_list <- ListHead::insert_prev(list),
180 })
181 }
182 #[cfg(not(feature = "std"))]
183 {
184 pin_init!(Self {
185 wait_list <- ListHead::insert_prev(list),
186 })
187 }
188 }
189}
190
/// Stress test for `CMutex`.
///
/// `thread_count` named workers each increment a shared counter `workload` times, sleep for a
/// staggered duration, then increment it `workload` more times. The final count is checked
/// against the expected total. Only runs with the `std` feature.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn main() {
    #[cfg(feature = "std")]
    {
        let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
        let thread_count = 20;
        let workload = if cfg!(miri) { 100 } else { 1_000 };
        let handles: Vec<_> = (0..thread_count)
            .map(|i| {
                let mtx = Pin::clone(&mtx);
                Builder::new()
                    .name(format!("worker #{i}"))
                    .spawn(move || {
                        let bump = || {
                            for _ in 0..workload {
                                *mtx.lock() += 1;
                            }
                        };
                        bump();
                        println!("{i} halfway");
                        sleep(Duration::from_millis((i as u64) * 10));
                        bump();
                        println!("{i} finished");
                    })
                    .expect("should not fail")
            })
            .collect();
        handles
            .into_iter()
            .for_each(|h| h.join().expect("thread panicked"));
        println!("{:?}", &*mtx.lock());
        assert_eq!(*mtx.lock(), workload * thread_count * 2);
    }
}