at v6.15-rc2 209 lines 5.4 kB view raw
1// SPDX-License-Identifier: Apache-2.0 OR MIT 2 3#![allow(clippy::undocumented_unsafe_blocks)] 4#![cfg_attr(feature = "alloc", feature(allocator_api))] 5#![allow(clippy::missing_safety_doc)] 6 7use core::{ 8 cell::{Cell, UnsafeCell}, 9 marker::PhantomPinned, 10 ops::{Deref, DerefMut}, 11 pin::Pin, 12 sync::atomic::{AtomicBool, Ordering}, 13}; 14use std::{ 15 sync::Arc, 16 thread::{self, park, sleep, Builder, Thread}, 17 time::Duration, 18}; 19 20use pin_init::*; 21#[expect(unused_attributes)] 22#[path = "./linked_list.rs"] 23pub mod linked_list; 24use linked_list::*; 25 26pub struct SpinLock { 27 inner: AtomicBool, 28} 29 30impl SpinLock { 31 #[inline] 32 pub fn acquire(&self) -> SpinLockGuard<'_> { 33 while self 34 .inner 35 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) 36 .is_err() 37 { 38 while self.inner.load(Ordering::Relaxed) { 39 thread::yield_now(); 40 } 41 } 42 SpinLockGuard(self) 43 } 44 45 #[inline] 46 #[allow(clippy::new_without_default)] 47 pub const fn new() -> Self { 48 Self { 49 inner: AtomicBool::new(false), 50 } 51 } 52} 53 54pub struct SpinLockGuard<'a>(&'a SpinLock); 55 56impl Drop for SpinLockGuard<'_> { 57 #[inline] 58 fn drop(&mut self) { 59 self.0.inner.store(false, Ordering::Release); 60 } 61} 62 63#[pin_data] 64pub struct CMutex<T> { 65 #[pin] 66 wait_list: ListHead, 67 spin_lock: SpinLock, 68 locked: Cell<bool>, 69 #[pin] 70 data: UnsafeCell<T>, 71} 72 73impl<T> CMutex<T> { 74 #[inline] 75 pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> { 76 pin_init!(CMutex { 77 wait_list <- ListHead::new(), 78 spin_lock: SpinLock::new(), 79 locked: Cell::new(false), 80 data <- unsafe { 81 pin_init_from_closure(|slot: *mut UnsafeCell<T>| { 82 val.__pinned_init(slot.cast::<T>()) 83 }) 84 }, 85 }) 86 } 87 88 #[inline] 89 pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> { 90 let mut sguard = self.spin_lock.acquire(); 91 if self.locked.get() { 92 stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list)); 93 // 
println!("wait list length: {}", self.wait_list.size()); 94 while self.locked.get() { 95 drop(sguard); 96 park(); 97 sguard = self.spin_lock.acquire(); 98 } 99 // This does have an effect, as the ListHead inside wait_entry implements Drop! 100 #[expect(clippy::drop_non_drop)] 101 drop(wait_entry); 102 } 103 self.locked.set(true); 104 unsafe { 105 Pin::new_unchecked(CMutexGuard { 106 mtx: self, 107 _pin: PhantomPinned, 108 }) 109 } 110 } 111 112 #[allow(dead_code)] 113 pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T { 114 // SAFETY: we have an exclusive reference and thus nobody has access to data. 115 unsafe { &mut *self.data.get() } 116 } 117} 118 119unsafe impl<T: Send> Send for CMutex<T> {} 120unsafe impl<T: Send> Sync for CMutex<T> {} 121 122pub struct CMutexGuard<'a, T> { 123 mtx: &'a CMutex<T>, 124 _pin: PhantomPinned, 125} 126 127impl<T> Drop for CMutexGuard<'_, T> { 128 #[inline] 129 fn drop(&mut self) { 130 let sguard = self.mtx.spin_lock.acquire(); 131 self.mtx.locked.set(false); 132 if let Some(list_field) = self.mtx.wait_list.next() { 133 let wait_entry = list_field.as_ptr().cast::<WaitEntry>(); 134 unsafe { (*wait_entry).thread.unpark() }; 135 } 136 drop(sguard); 137 } 138} 139 140impl<T> Deref for CMutexGuard<'_, T> { 141 type Target = T; 142 143 #[inline] 144 fn deref(&self) -> &Self::Target { 145 unsafe { &*self.mtx.data.get() } 146 } 147} 148 149impl<T> DerefMut for CMutexGuard<'_, T> { 150 #[inline] 151 fn deref_mut(&mut self) -> &mut Self::Target { 152 unsafe { &mut *self.mtx.data.get() } 153 } 154} 155 156#[pin_data] 157#[repr(C)] 158struct WaitEntry { 159 #[pin] 160 wait_list: ListHead, 161 thread: Thread, 162} 163 164impl WaitEntry { 165 #[inline] 166 fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ { 167 pin_init!(Self { 168 thread: thread::current(), 169 wait_list <- ListHead::insert_prev(list), 170 }) 171 } 172} 173 174#[cfg(not(any(feature = "std", feature = "alloc")))] 175fn main() {} 176 177#[allow(dead_code)] 
/// Stress test for [`CMutex`].
///
/// Spawns `thread_count` named workers. Each increments a shared counter `workload` times,
/// pauses `10 * i` ms, then increments it `workload` more times. The final value must
/// account for every increment. Also runs as a `#[test]` when built with `--test`.
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let thread_count: usize = 20;
    let workload: usize = if cfg!(miri) { 100 } else { 1_000 };
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let shared = mtx.clone();
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Lock and unlock once per increment to maximize contention.
                    let hammer = || {
                        for _ in 0..workload {
                            *shared.lock() += 1;
                        }
                    };
                    hammer();
                    println!("{i} halfway");
                    // Per-thread pause of `10 * i` ms between the two halves.
                    sleep(Duration::from_millis(10 * i as u64));
                    hammer();
                    println!("{i} finished");
                })
                .expect("should not fail")
        })
        .collect();

    handles
        .into_iter()
        .for_each(|handle| handle.join().expect("thread panicked"));

    println!("{:?}", &*mtx.lock());
    let expected = workload * thread_count * 2;
    assert_eq!(*mtx.lock(), expected);
}