rust: driver: drop device private data post unbind

Currently, the driver's device private data is allocated and initialized
from driver core code called from bus abstractions after the driver's
probe() callback returned the corresponding initializer.

Similarly, the driver's device private data is dropped within the
remove() callback of bus abstractions after calling the remove()
callback of the corresponding driver.

However, commit 6f61a2637abe ("rust: device: introduce
Device::drvdata()") introduced an accessor for the driver's device
private data for a Device<Bound>, i.e. a device that is currently bound
to a driver.

Obviously, this is in conflict with dropping the driver's device private
data in remove(), since a device can not be considered to be fully
unbound after remove() has finished:

We also have to consider registrations guarded by devres - such as IRQ
or class device registrations - which are torn down after remove() in
devres_release_all().

Thus, it can happen that, for instance, a class device or IRQ callback
still calls Device::drvdata(), which then runs concurrently to remove()
(which sets dev->driver_data to NULL and drops the driver's device
private data), before devres_release_all() started to tear down the
corresponding registration. This is because devres guarded registrations
can, as expected, access the corresponding Device<Bound> that defines
their scope.

In C it simply is the driver's responsibility to ensure that its device
private data is freed after e.g. an IRQ registration is unregistered.

Typically, C drivers achieve this by allocating their device private data
with e.g. devm_kzalloc() before doing anything else, i.e. before e.g.
registering an IRQ with devm_request_threaded_irq(), relying on the
reverse order cleanup of devres.

Technically, we could do something similar in Rust. However, the
resulting code would be pretty messy:

In Rust we have to differentiate between allocated but uninitialized
memory and initialized memory in the type system. Thus, we would need to
somehow keep track of whether the driver's device private data object
has been initialized (i.e. probe() was successful and returned a valid
initializer for this memory) and conditionally call the destructor of
the corresponding object when it is freed.

This is because we'd need to allocate and register the memory of the
driver's device private data *before* it is initialized by the
initializer returned by the driver's probe() callback, because the
driver could already register devres guarded registrations within
probe() outside of the driver's device private data initializer.

Luckily there is a much simpler solution: Instead of dropping the
driver's device private data at the end of remove(), we just drop it
after the device has been fully unbound, i.e. after all devres callbacks
have been processed.

For this, we introduce a new post_unbind() callback private to the
driver-core, i.e. the callback is neither exposed to drivers, nor to bus
abstractions.

This way, the driver-core code can simply continue to conditionally
allocate the memory for the driver's device private data when the
driver's initializer is returned from probe() - no change needed - and
drop it when the driver-core code receives the post_unbind() callback.

Closes: https://lore.kernel.org/all/DEZMS6Y4A7XE.XE7EUBT5SJFJ@kernel.org/
Fixes: 6f61a2637abe ("rust: device: introduce Device::drvdata()")
Acked-by: Alice Ryhl <aliceryhl@google.com>
Acked-by: Greg Kroah-Hartman <gregkh@linuxfoundation.org>
Acked-by: Igor Korotin <igor.korotin.linux@gmail.com>
Link: https://patch.msgid.link/20260107103511.570525-7-dakr@kernel.org
[ Remove #ifdef CONFIG_RUST, rename post_unbind() to post_unbind_rust().
- Danilo]
Signed-off-by: Danilo Krummrich <dakr@kernel.org>

+67 -20
+2
drivers/base/dd.c
··· 548 548 static void device_unbind_cleanup(struct device *dev) 549 549 { 550 550 devres_release_all(dev); 551 + if (dev->driver->p_cb.post_unbind_rust) 552 + dev->driver->p_cb.post_unbind_rust(dev); 551 553 arch_teardown_dma_ops(dev); 552 554 kfree(dev->dma_range_map); 553 555 dev->dma_range_map = NULL;
+9
include/linux/device/driver.h
··· 85 85 * uevent. 86 86 * @p: Driver core's private data, no one other than the driver 87 87 * core can touch this. 88 + * @p_cb: Callbacks private to the driver core; no one other than the 89 + * driver core is allowed to touch this. 88 90 * 89 91 * The device driver-model tracks all of the drivers known to the system. 90 92 * The main reason for this tracking is to enable the driver core to match ··· 121 119 void (*coredump) (struct device *dev); 122 120 123 121 struct driver_private *p; 122 + struct { 123 + /* 124 + * Called after remove() and after all devres entries have been 125 + * processed. This is a Rust only callback. 126 + */ 127 + void (*post_unbind_rust)(struct device *dev); 128 + } p_cb; 124 129 }; 125 130 126 131
+2 -2
rust/kernel/auxiliary.rs
··· 96 96 // SAFETY: `remove_callback` is only ever called after a successful call to 97 97 // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called 98 98 // and stored a `Pin<KBox<T>>`. 99 - let data = unsafe { adev.as_ref().drvdata_obtain::<T>() }; 99 + let data = unsafe { adev.as_ref().drvdata_borrow::<T>() }; 100 100 101 - T::unbind(adev, data.as_ref()); 101 + T::unbind(adev, data); 102 102 } 103 103 } 104 104
+11 -9
rust/kernel/device.rs
··· 232 232 /// 233 233 /// # Safety 234 234 /// 235 - /// - Must only be called once after a preceding call to [`Device::set_drvdata`]. 236 235 /// - The type `T` must match the type of the `ForeignOwnable` previously stored by 237 236 /// [`Device::set_drvdata`]. 238 - pub unsafe fn drvdata_obtain<T: 'static>(&self) -> Pin<KBox<T>> { 237 + pub(crate) unsafe fn drvdata_obtain<T: 'static>(&self) -> Option<Pin<KBox<T>>> { 239 238 // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. 240 239 let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; 241 240 242 241 // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. 243 242 unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) }; 244 243 244 + if ptr.is_null() { 245 + return None; 246 + } 247 + 245 248 // SAFETY: 246 - // - By the safety requirements of this function, `ptr` comes from a previous call to 247 - // `into_foreign()`. 249 + // - If `ptr` is not NULL, it comes from a previous call to `into_foreign()`. 248 250 // - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` 249 251 // in `into_foreign()`. 250 - unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) } 252 + Some(unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) }) 251 253 } 252 254 253 255 /// Borrow the driver's private data bound to this [`Device`]. 254 256 /// 255 257 /// # Safety 256 258 /// 257 - /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before 258 - /// [`Device::drvdata_obtain`]. 259 + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before the 260 + /// device is fully unbound. 259 261 /// - The type `T` must match the type of the `ForeignOwnable` previously stored by 260 262 /// [`Device::set_drvdata`]. 261 263 pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { ··· 273 271 /// # Safety 274 272 /// 275 273 /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before 276 - /// [`Device::drvdata_obtain`]. 274 + /// the device is fully unbound. 277 275 /// - The type `T` must match the type of the `ForeignOwnable` previously stored by 278 276 /// [`Device::set_drvdata`]. 279 277 unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { ··· 322 320 323 321 // SAFETY: 324 322 // - The above check of `dev_get_drvdata()` guarantees that we are called after 325 - // `set_drvdata()` and before `drvdata_obtain()`. 323 + // `set_drvdata()`. 326 324 // - We've just checked that the type of the driver's private data is in fact `T`. 327 325 Ok(unsafe { self.drvdata_unchecked() }) 328 326 }
+35 -1
rust/kernel/driver.rs
··· 177 177 // any thread, so `Registration` is `Send`. 178 178 unsafe impl<T: RegistrationOps> Send for Registration<T> {} 179 179 180 - impl<T: RegistrationOps> Registration<T> { 180 + impl<T: RegistrationOps + 'static> Registration<T> { 181 + extern "C" fn post_unbind_callback(dev: *mut bindings::device) { 182 + // SAFETY: The driver core only ever calls the post unbind callback with a valid pointer to 183 + // a `struct device`. 184 + // 185 + // INVARIANT: `dev` is valid for the duration of the `post_unbind_callback()`. 186 + let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal>>() }; 187 + 188 + // `remove()` and all devres callbacks have been completed at this point, hence drop the 189 + // driver's device private data. 190 + // 191 + // SAFETY: By the safety requirements of the `Driver` trait, `T::DriverData` is the 192 + // driver's device private data type. 193 + drop(unsafe { dev.drvdata_obtain::<T::DriverData>() }); 194 + } 195 + 196 + /// Attach generic `struct device_driver` callbacks. 197 + fn callbacks_attach(drv: &Opaque<T::DriverType>) { 198 + let ptr = drv.get().cast::<u8>(); 199 + 200 + // SAFETY: 201 + // - `drv.get()` yields a valid pointer to `Self::DriverType`. 202 + // - Adding `DEVICE_DRIVER_OFFSET` yields the address of the embedded `struct device_driver` 203 + // as guaranteed by the safety requirements of the `Driver` trait. 204 + let base = unsafe { ptr.add(T::DEVICE_DRIVER_OFFSET) }; 205 + 206 + // CAST: `base` points to the offset of the embedded `struct device_driver`. 207 + let base = base.cast::<bindings::device_driver>(); 208 + 209 + // SAFETY: It is safe to set the fields of `struct device_driver` on initialization. 210 + unsafe { (*base).p_cb.post_unbind_rust = Some(Self::post_unbind_callback) }; 211 + } 212 + 181 213 /// Creates a new instance of the registration object. 182 214 pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> { 183 215 try_pin_init!(Self { ··· 220 188 // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write, and it has 221 189 // just been initialised above, so it's also valid for read. 222 190 let drv = unsafe { &*(ptr as *const Opaque<T::DriverType>) }; 191 + 192 + Self::callbacks_attach(drv); 223 193 224 194 // SAFETY: `drv` is guaranteed to be pinned until `T::unregister`. 225 195 unsafe { T::register(drv, name, module) }
+2 -2
rust/kernel/i2c.rs
··· 178 178 // SAFETY: `remove_callback` is only ever called after a successful call to 179 179 // `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called 180 180 // and stored a `Pin<KBox<T>>`. 181 - let data = unsafe { idev.as_ref().drvdata_obtain::<T>() }; 181 + let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; 182 182 183 - T::unbind(idev, data.as_ref()); 183 + T::unbind(idev, data); 184 184 } 185 185 186 186 extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) {
+2 -2
rust/kernel/pci.rs
··· 123 123 // SAFETY: `remove_callback` is only ever called after a successful call to 124 124 // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called 125 125 // and stored a `Pin<KBox<T>>`. 126 - let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; 126 + let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; 127 127 128 - T::unbind(pdev, data.as_ref()); 128 + T::unbind(pdev, data); 129 129 } 130 130 } 131 131
+2 -2
rust/kernel/platform.rs
··· 101 101 // SAFETY: `remove_callback` is only ever called after a successful call to 102 102 // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called 103 103 // and stored a `Pin<KBox<T>>`. 104 - let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; 104 + let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; 105 105 106 - T::unbind(pdev, data.as_ref()); 106 + T::unbind(pdev, data); 107 107 } 108 108 } 109 109
+2 -2
rust/kernel/usb.rs
··· 103 103 // SAFETY: `disconnect_callback` is only ever called after a successful call to 104 104 // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called 105 105 // and stored a `Pin<KBox<T>>`. 106 - let data = unsafe { dev.drvdata_obtain::<T>() }; 106 + let data = unsafe { dev.drvdata_borrow::<T>() }; 107 107 108 - T::disconnect(intf, data.as_ref()); 108 + T::disconnect(intf, data); 109 109 } 110 110 } 111 111