diff --git a/drivers/android/transaction.rs b/drivers/android/transaction.rs index 81db91aed5f354..39f24bad147f83 100644 --- a/drivers/android/transaction.rs +++ b/drivers/android/transaction.rs @@ -7,6 +7,7 @@ use kernel::{ io_buffer::IoBufferWriter, linked_list::List, linked_list::{GetLinks, Links}, + macros::pinned_drop, new_spinlock, prelude::*, sync::{Arc, SpinLock, UniqueArc}, @@ -27,7 +28,7 @@ struct TransactionInner { file_list: List>, } -#[pin_project] +#[pin_project(PinnedDrop)] pub(crate) struct Transaction { #[pin] inner: SpinLock, @@ -276,8 +277,9 @@ impl DeliverToRead for Transaction { } } -impl Drop for Transaction { - fn drop(&mut self) { +#[pinned_drop] +impl PinnedDrop for Transaction { + fn drop(self: Pin<&mut Self>) { if self.free_allocation.load(Ordering::Relaxed) { self.to.buffer_get(self.data_address); } diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 0341a591f9b7bc..107e345715d4f0 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -641,6 +641,18 @@ where } } +/// Trait facilitating pinned destruction. +/// +/// Use [`pinned_drop`] to implement this trait safely. +pub unsafe trait PinnedDrop { + /// # Safety + /// + /// Only call this from `::drop`. + unsafe fn drop(self: Pin<&mut Self>); + #[doc(hidden)] + fn __ensure_no_unsafe_op_in_drop(self: Pin<&mut Self>); +} + /// Smart pointer that can initialize memory in-place. pub trait InPlaceInit: Sized { /// Use the given initializer to in-place initialize a `T`. diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index c5cbbe45c630b6..bd25db53b66b6a 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -6,6 +6,7 @@ mod concat_idents; mod helpers; mod module; mod pin_project; +mod pinned_drop; mod vtable; use proc_macro::TokenStream; @@ -211,3 +212,9 @@ pub fn concat_idents(ts: TokenStream) -> TokenStream { pub fn pin_project(inner: TokenStream, item: TokenStream) -> TokenStream { pin_project::pin_project(inner, item) } + +/// TODO +#[proc_macro_attribute] +pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { + pinned_drop::pinned_drop(args, input) +} diff --git a/rust/macros/pinned_drop.rs b/rust/macros/pinned_drop.rs new file mode 100644 index 00000000000000..c4ca15a129b3bd --- /dev/null +++ b/rust/macros/pinned_drop.rs @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: GPL-2.0 + +use proc_macro::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree}; + +pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut toks = input.into_iter().collect::>(); + assert!(!toks.is_empty()); + // ensure that we have an impl item + assert!(matches!(&toks[0], TokenTree::Ident(i) if i.to_string() == "impl")); + // ensure that we are implementing `PinnedDrop` + let mut nesting: usize = 0; + let mut pinned_drop_idx = None; + for (i, tt) in toks.iter().enumerate() { + match tt { + TokenTree::Punct(p) if p.as_char() == '<' => { + nesting += 1; + } + TokenTree::Punct(p) if p.as_char() == '>' => { + nesting = nesting.checked_sub(1).unwrap(); + } + _ => {} + } + if i >= 1 && nesting == 0 { + assert!(matches!(tt, TokenTree::Ident(i) if i.to_string() == "PinnedDrop")); + pinned_drop_idx = Some(i); + break; + } + } + let idx = pinned_drop_idx.unwrap(); + //inserting `::kernel::init::` in reverse order + toks.insert(idx, TokenTree::Punct(Punct::new(':', Spacing::Alone))); + toks.insert(idx, TokenTree::Punct(Punct::new(':', Spacing::Joint))); + toks.insert(idx, TokenTree::Ident(Ident::new("init", Span::call_site()))); + toks.insert(idx, TokenTree::Punct(Punct::new(':', Spacing::Alone))); + toks.insert(idx, TokenTree::Punct(Punct::new(':', Spacing::Joint))); + toks.insert( + idx, + TokenTree::Ident(Ident::new("kernel", Span::call_site())), + ); + toks.insert(idx, TokenTree::Punct(Punct::new(':', Spacing::Alone))); + toks.insert(idx, TokenTree::Punct(Punct::new(':', Spacing::Joint))); + if let Some(TokenTree::Group(last)) = toks.pop() { + let mut inner = last.stream().into_iter().collect::>(); + if let Some(TokenTree::Group(inner_last)) = inner.pop() { + // make the impl unsafe + toks.insert(0, TokenTree::Ident(Ident::new("unsafe", Span::call_site()))); + // make the first function unsafe + inner.insert(0, TokenTree::Ident(Ident::new("unsafe", Span::call_site()))); + // re-add the body + inner.push(TokenTree::Group(inner_last.clone())); + add_ensure_no_unsafe_op_in_drop(&mut inner, inner_last); + toks.push(TokenTree::Group(Group::new( + Delimiter::Brace, + TokenStream::from_iter(inner), + ))); + TokenStream::from_iter(toks) + } else { + toks.push(TokenTree::Group(last)); + TokenStream::from_iter(toks) + } + } else { + TokenStream::from_iter(toks) + } +} + +fn add_ensure_no_unsafe_op_in_drop(v: &mut Vec, inner_last: Group) { + v.push(TokenTree::Ident(Ident::new("fn", Span::call_site()))); + v.push(TokenTree::Ident(Ident::new( + "__ensure_no_unsafe_op_in_drop", + Span::call_site(), + ))); + v.push(TokenTree::Group(Group::new( + Delimiter::Parenthesis, + TokenStream::from_iter(vec![ + TokenTree::Ident(Ident::new("self", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("core", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("pin", Span::call_site())), + TokenTree::Punct(Punct::new(':', Spacing::Joint)), + TokenTree::Punct(Punct::new(':', Spacing::Alone)), + TokenTree::Ident(Ident::new("Pin", Span::call_site())), + TokenTree::Punct(Punct::new('<', Spacing::Alone)), + TokenTree::Punct(Punct::new('&', Spacing::Alone)), + TokenTree::Ident(Ident::new("mut", Span::call_site())), + TokenTree::Ident(Ident::new("Self", Span::call_site())), + TokenTree::Punct(Punct::new('>', Spacing::Alone)), + ]), + ))); + v.push(TokenTree::Group(Group::new( + Delimiter::Brace, + TokenStream::from_iter(vec![ + TokenTree::Ident(Ident::new("if", Span::call_site())), + TokenTree::Ident(Ident::new("false", Span::call_site())), + TokenTree::Group(inner_last), + ]), + ))); +}