Linux kernel mirror (for testing)
git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel
os
linux
1// SPDX-License-Identifier: Apache-2.0 OR MIT
2
3#![allow(clippy::undocumented_unsafe_blocks)]
4#![cfg_attr(feature = "alloc", feature(allocator_api))]
5
6use core::{
7 cell::{Cell, UnsafeCell},
8 mem::MaybeUninit,
9 ops,
10 pin::Pin,
11 time::Duration,
12};
13use pin_init::*;
14use std::{
15 sync::Arc,
16 thread::{sleep, Builder},
17};
18
19#[expect(unused_attributes)]
20mod mutex;
21use mutex::*;
22
23pub struct StaticInit<T, I> {
24 cell: UnsafeCell<MaybeUninit<T>>,
25 init: Cell<Option<I>>,
26 lock: SpinLock,
27 present: Cell<bool>,
28}
29
30unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
31unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
32
33impl<T, I: PinInit<T>> StaticInit<T, I> {
34 pub const fn new(init: I) -> Self {
35 Self {
36 cell: UnsafeCell::new(MaybeUninit::uninit()),
37 init: Cell::new(Some(init)),
38 lock: SpinLock::new(),
39 present: Cell::new(false),
40 }
41 }
42}
43
44impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
45 type Target = T;
46 fn deref(&self) -> &Self::Target {
47 if self.present.get() {
48 unsafe { (*self.cell.get()).assume_init_ref() }
49 } else {
50 println!("acquire spinlock on static init");
51 let _guard = self.lock.acquire();
52 println!("rechecking present...");
53 std::thread::sleep(std::time::Duration::from_millis(200));
54 if self.present.get() {
55 return unsafe { (*self.cell.get()).assume_init_ref() };
56 }
57 println!("doing init");
58 let ptr = self.cell.get().cast::<T>();
59 match self.init.take() {
60 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
61 None => unsafe { core::hint::unreachable_unchecked() },
62 }
63 self.present.set(true);
64 unsafe { (*self.cell.get()).assume_init_ref() }
65 }
66 }
67}
68
69pub struct CountInit;
70
71unsafe impl PinInit<CMutex<usize>> for CountInit {
72 unsafe fn __pinned_init(
73 self,
74 slot: *mut CMutex<usize>,
75 ) -> Result<(), core::convert::Infallible> {
76 let init = CMutex::new(0);
77 std::thread::sleep(std::time::Duration::from_millis(1000));
78 unsafe { init.__pinned_init(slot) }
79 }
80}
81
82pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
83
/// The demo below requires the `std` or `alloc` feature; without either, `main`
/// does nothing.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

/// Spawns worker threads that repeatedly lock both a heap-allocated `CMutex` and
/// the lazily initialized [`COUNT`], then checks that no increments were lost on
/// either counter.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;
    let mut handles = Vec::with_capacity(thread_count);
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // The first `COUNT.lock()` made by any thread runs (and races
                    // on) the static's lazy initialization.
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        std::thread::sleep(std::time::Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    // Each worker bumps `mtx` once per iteration in both loops...
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
    // ...and `COUNT` twice per iteration in the first loop only. Checking it
    // verifies that the static was initialized exactly once under contention.
    assert_eq!(*COUNT.lock(), workload * thread_count * 2);
}