Skip to content

Commit 7f8977a

Browse files
y86-devojeda
authored andcommitted
rust: init: add {pin_}chain functions to {Pin}Init<T, E>
The `{pin_}chain` functions extend an initializer: it not only initializes the value, but also executes a closure taking a reference to the initialized value. This allows to do something with a value directly after initialization. Suggested-by: Asahi Lina <lina@asahilina.net> Reviewed-by: Martin Rodriguez Reboredo <yakoyoku@gmail.com> Signed-off-by: Benno Lossin <benno.lossin@proton.me> Reviewed-by: Alice Ryhl <aliceryhl@google.com> Link: https://lore.kernel.org/r/20230814084602.25699-13-benno.lossin@proton.me [ Cleaned a few trivial nits. ] Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
1 parent 1a8076a commit 7f8977a

File tree

2 files changed

+143
-1
lines changed

2 files changed

+143
-1
lines changed

Diff for: rust/kernel/init.rs

+142
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,79 @@ pub unsafe trait PinInit<T: ?Sized, E = Infallible>: Sized {
767767
/// deallocate.
768768
/// - `slot` will not move until it is dropped, i.e. it will be pinned.
769769
unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E>;
770+
771+
/// First initializes the value using `self` then calls the function `f` with the initialized
772+
/// value.
773+
///
774+
/// If `f` returns an error the value is dropped and the initializer will forward the error.
775+
///
776+
/// # Examples
777+
///
778+
/// ```rust
779+
/// # #![allow(clippy::disallowed_names)]
780+
/// use kernel::{types::Opaque, init::pin_init_from_closure};
781+
/// #[repr(C)]
782+
/// struct RawFoo([u8; 16]);
783+
/// extern {
784+
/// fn init_foo(_: *mut RawFoo);
785+
/// }
786+
///
787+
/// #[pin_data]
788+
/// struct Foo {
789+
/// #[pin]
790+
/// raw: Opaque<RawFoo>,
791+
/// }
792+
///
793+
/// impl Foo {
794+
/// fn setup(self: Pin<&mut Self>) {
795+
/// pr_info!("Setting up foo");
796+
/// }
797+
/// }
798+
///
799+
/// let foo = pin_init!(Foo {
800+
/// raw <- unsafe {
801+
/// Opaque::ffi_init(|s| {
802+
/// init_foo(s);
803+
/// })
804+
/// },
805+
/// }).pin_chain(|foo| {
806+
/// foo.setup();
807+
/// Ok(())
808+
/// });
809+
/// ```
810+
fn pin_chain<F>(self, f: F) -> ChainPinInit<Self, F, T, E>
811+
where
812+
F: FnOnce(Pin<&mut T>) -> Result<(), E>,
813+
{
814+
ChainPinInit(self, f, PhantomData)
815+
}
816+
}
817+
818+
/// An initializer returned by [`PinInit::pin_chain`].
819+
pub struct ChainPinInit<I, F, T: ?Sized, E>(I, F, __internal::Invariant<(E, Box<T>)>);
820+
821+
// SAFETY: The `__pinned_init` function is implemented such that it
822+
// - returns `Ok(())` on successful initialization,
823+
// - returns `Err(err)` on error and in this case `slot` will be dropped.
824+
// - considers `slot` pinned.
825+
unsafe impl<T: ?Sized, E, I, F> PinInit<T, E> for ChainPinInit<I, F, T, E>
826+
where
827+
I: PinInit<T, E>,
828+
F: FnOnce(Pin<&mut T>) -> Result<(), E>,
829+
{
830+
unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> {
831+
// SAFETY: All requirements fulfilled since this function is `__pinned_init`.
832+
unsafe { self.0.__pinned_init(slot)? };
833+
// SAFETY: The above call initialized `slot` and we still have unique access.
834+
let val = unsafe { &mut *slot };
835+
// SAFETY: `slot` is considered pinned.
836+
let val = unsafe { Pin::new_unchecked(val) };
837+
(self.1)(val).map_err(|e| {
838+
// SAFETY: `slot` was initialized above.
839+
unsafe { core::ptr::drop_in_place(slot) };
840+
e
841+
})
842+
}
770843
}
771844

772845
/// An initializer for `T`.
@@ -808,6 +881,75 @@ pub unsafe trait Init<T: ?Sized, E = Infallible>: PinInit<T, E> {
808881
/// - the caller does not touch `slot` when `Err` is returned, they are only permitted to
809882
/// deallocate.
810883
unsafe fn __init(self, slot: *mut T) -> Result<(), E>;
884+
885+
/// First initializes the value using `self` then calls the function `f` with the initialized
886+
/// value.
887+
///
888+
/// If `f` returns an error the value is dropped and the initializer will forward the error.
889+
///
890+
/// # Examples
891+
///
892+
/// ```rust
893+
/// # #![allow(clippy::disallowed_names)]
894+
/// use kernel::{types::Opaque, init::{self, init_from_closure}};
895+
/// struct Foo {
896+
/// buf: [u8; 1_000_000],
897+
/// }
898+
///
899+
/// impl Foo {
900+
/// fn setup(&mut self) {
901+
/// pr_info!("Setting up foo");
902+
/// }
903+
/// }
904+
///
905+
/// let foo = init!(Foo {
906+
/// buf <- init::zeroed()
907+
/// }).chain(|foo| {
908+
/// foo.setup();
909+
/// Ok(())
910+
/// });
911+
/// ```
912+
fn chain<F>(self, f: F) -> ChainInit<Self, F, T, E>
913+
where
914+
F: FnOnce(&mut T) -> Result<(), E>,
915+
{
916+
ChainInit(self, f, PhantomData)
917+
}
918+
}
919+
920+
/// An initializer returned by [`Init::chain`].
921+
pub struct ChainInit<I, F, T: ?Sized, E>(I, F, __internal::Invariant<(E, Box<T>)>);
922+
923+
// SAFETY: The `__init` function is implemented such that it
924+
// - returns `Ok(())` on successful initialization,
925+
// - returns `Err(err)` on error and in this case `slot` will be dropped.
926+
unsafe impl<T: ?Sized, E, I, F> Init<T, E> for ChainInit<I, F, T, E>
927+
where
928+
I: Init<T, E>,
929+
F: FnOnce(&mut T) -> Result<(), E>,
930+
{
931+
unsafe fn __init(self, slot: *mut T) -> Result<(), E> {
932+
// SAFETY: All requirements fulfilled since this function is `__init`.
933+
unsafe { self.0.__pinned_init(slot)? };
934+
// SAFETY: The above call initialized `slot` and we still have unique access.
935+
(self.1)(unsafe { &mut *slot }).map_err(|e| {
936+
// SAFETY: `slot` was initialized above.
937+
unsafe { core::ptr::drop_in_place(slot) };
938+
e
939+
})
940+
}
941+
}
942+
943+
// SAFETY: `__pinned_init` behaves exactly the same as `__init`.
944+
unsafe impl<T: ?Sized, E, I, F> PinInit<T, E> for ChainInit<I, F, T, E>
945+
where
946+
I: Init<T, E>,
947+
F: FnOnce(&mut T) -> Result<(), E>,
948+
{
949+
unsafe fn __pinned_init(self, slot: *mut T) -> Result<(), E> {
950+
// SAFETY: `__init` has less strict requirements compared to `__pinned_init`.
951+
unsafe { self.__init(slot) }
952+
}
811953
}
812954

813955
/// Creates a new [`PinInit<T, E>`] from the given closure.

Diff for: rust/kernel/init/__internal.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use super::*;
1313
///
1414
/// [nomicon]: https://doc.rust-lang.org/nomicon/subtyping.html
1515
/// [this table]: https://doc.rust-lang.org/nomicon/phantom-data.html#table-of-phantomdata-patterns
16-
type Invariant<T> = PhantomData<fn(*mut T) -> *mut T>;
16+
pub(super) type Invariant<T> = PhantomData<fn(*mut T) -> *mut T>;
1717

1818
/// This is the module-internal type implementing `PinInit` and `Init`. It is unsafe to create this
1919
/// type, since the closure needs to fulfill the same safety requirement as the

0 commit comments

Comments
 (0)