diff --git a/CHANGELOG.md b/CHANGELOG.md index 17216af99..dd04cbd5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,13 @@ - `DevicePathInstance::to_boxed`, `DevicePathInstance::to_owned`, and `DevicePathInstance::as_bytes` - `DevicePathNode::data` - Added `Event::from_ptr`, `Event::as_ptr`, and `Handle::as_ptr`. +- Added `ScopedProtocol::get` and `ScopedProtocol::get_mut` to access + potentially-null interfaces without panicking. ### Changed - Renamed `LoadImageSource::FromFilePath` to `LoadImageSource::FromDevicePath` +- The `Deref` and `DerefMut` impls for `ScopedProtocol` will now panic if the + interface pointer is null. ### Removed diff --git a/uefi-test-runner/src/boot/mod.rs b/uefi-test-runner/src/boot/mod.rs index e8f55fe47..89114a7f7 100644 --- a/uefi-test-runner/src/boot/mod.rs +++ b/uefi-test-runner/src/boot/mod.rs @@ -85,11 +85,18 @@ fn test_load_image(bt: &BootServices) { buffer: image_data.as_slice(), file_path: None, }; - let _ = bt + let loaded_image = bt .load_image(bt.image_handle(), load_source) .expect("should load image"); log::debug!("load_image with FromBuffer strategy works"); + + // Check that the `LoadedImageDevicePath` protocol can be opened and + // that the interface data is `None`. + let loaded_image_device_path = bt + .open_protocol_exclusive::(loaded_image) + .expect("should open LoadedImageDevicePath protocol"); + assert!(loaded_image_device_path.get().is_none()); } // Variant B: FromDevicePath { diff --git a/uefi/src/table/boot.rs b/uefi/src/table/boot.rs index 2669f4e8e..af4ee30d2 100644 --- a/uefi/src/table/boot.rs +++ b/uefi/src/table/boot.rs @@ -1338,10 +1338,13 @@ impl BootServices { attributes as u32, ) .to_result_with_val(|| { - let interface = P::mut_ptr_from_ffi(interface) as *const UnsafeCell

; + let interface = (!interface.is_null()).then(|| { + let interface = P::mut_ptr_from_ffi(interface) as *const UnsafeCell

; + &*interface + }); ScopedProtocol { - interface: &*interface, + interface, open_params: params, boot_services: self, } @@ -1814,12 +1817,23 @@ pub struct OpenProtocolParams { /// An open protocol interface. Automatically closes the protocol /// interface on drop. /// +/// Most protocols have interface data associated with them. `ScopedProtocol` +/// implements [`Deref`] and [`DerefMut`] to access this data. A few protocols +/// (such as [`DevicePath`] and [`LoadedImageDevicePath`]) may be installed with +/// null interface data, in which case [`Deref`] and [`DerefMut`] will +/// panic. The [`get`] and [`get_mut`] methods may be used to access the +/// optional interface data without panicking. +/// /// See also the [`BootServices`] documentation for details of how to open a /// protocol and why [`UnsafeCell`] is used. +/// +/// [`LoadedImageDevicePath`]: crate::proto::device_path::LoadedImageDevicePath +/// [`get`]: ScopedProtocol::get +/// [`get_mut`]: ScopedProtocol::get_mut #[derive(Debug)] pub struct ScopedProtocol<'a, P: Protocol + ?Sized> { /// The protocol interface. - interface: &'a UnsafeCell

, + interface: Option<&'a UnsafeCell

>, open_params: OpenProtocolParams, boot_services: &'a BootServices, @@ -1847,14 +1861,32 @@ impl<'a, P: Protocol + ?Sized> Drop for ScopedProtocol<'a, P> { impl<'a, P: Protocol + ?Sized> Deref for ScopedProtocol<'a, P> { type Target = P; + #[track_caller] fn deref(&self) -> &Self::Target { - unsafe { &*self.interface.get() } + unsafe { &*self.interface.unwrap().get() } } } impl<'a, P: Protocol + ?Sized> DerefMut for ScopedProtocol<'a, P> { + #[track_caller] fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *self.interface.get() } + unsafe { &mut *self.interface.unwrap().get() } + } +} + +impl<'a, P: Protocol + ?Sized> ScopedProtocol<'a, P> { + /// Get the protocol interface data, or `None` if the open protocol's + /// interface is null. + #[must_use] + pub fn get(&self) -> Option<&P> { + self.interface.map(|p| unsafe { &*p.get() }) + } + + /// Get the protocol interface data, or `None` if the open protocol's + /// interface is null. + #[must_use] + pub fn get_mut(&self) -> Option<&mut P> { + self.interface.map(|p| unsafe { &mut *p.get() }) } }