diff --git a/pin-project-internal/src/lib.rs b/pin-project-internal/src/lib.rs index f7a73865..1146e4a8 100644 --- a/pin-project-internal/src/lib.rs +++ b/pin-project-internal/src/lib.rs @@ -31,6 +31,32 @@ use syn::parse::Nothing; /// the field. /// - For the other fields, makes the unpinned reference to the field. /// +/// The following methods are implemented on the original `#[pin_project]` type: +/// +/// ```ignore +/// fn project(&mut Pin<&mut Self>) -> ProjectedType; +/// fn project_into(Pin<&mut Self>) -> ProjectedType; +/// ``` +/// +/// The `project` method takes a mutable reference to a pinned +/// type, and returns a projection struct. This is the method +/// you'll usually want to use - since it takes a mutable reference, +/// it can be called multiple times, and allows you to use +/// the original Pin type later on (e.g. to call [`Pin::set`](core::pin::Pin::set)) +/// +/// The `project_into` type takes a pinned type by value (consuming it), +/// and returns a projection struct. The difference between this and the `project` +/// method lies in the lifetime. While the type returned by `project` only lives +/// as long as the 'outer' mutable reference, the type returned by this method +/// lives for as long as the original Pin. This can be useful when returning a pin +/// projection from a method: +/// +/// ```ignore +/// fn get_pin_mut<'a>(mut self: Pin<&'a mut Self>) -> Pin<&'a mut T> { +/// self.project_into().pinned +/// } +/// ``` +/// /// ## Safety /// /// This attribute is completely safe. In the absence of other `unsafe` code *that you write*, diff --git a/pin-project-internal/src/pin_project/enums.rs b/pin-project-internal/src/pin_project/enums.rs index 9c3cbaaa..e0f119e1 100644 --- a/pin-project-internal/src/pin_project/enums.rs +++ b/pin-project-internal/src/pin_project/enums.rs @@ -31,27 +31,33 @@ pub(super) fn parse(cx: &mut Context, mut item: ItemEnum) -> Result let (proj_variants, proj_arms) = variants(cx, &mut item)?; - let Context { proj_ident, proj_trait, orig_ident, lifetime, .. } = &cx; + let proj_ident = &cx.proj_ident; let proj_generics = cx.proj_generics(); - let proj_ty_generics = proj_generics.split_for_impl().1; - let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); + let where_clause = item.generics.split_for_impl().2; let mut proj_items = quote! { #[allow(clippy::mut_mut)] #[allow(dead_code)] enum #proj_ident #proj_generics #where_clause { #(#proj_variants,)* } }; - proj_items.extend(quote! { - impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #orig_ident #ty_generics> #where_clause { - fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause { - unsafe { - match self.as_mut().get_unchecked_mut() { - #(#proj_arms,)* - } - } + + let project_body = quote! { + unsafe { + match self.as_mut().get_unchecked_mut() { + #(#proj_arms,)* } } - }); + }; + + let project_into_body = quote! { + unsafe { + match self.get_unchecked_mut() { + #(#proj_arms,)* + } + } + }; + + proj_items.extend(cx.make_trait_impl(&project_body, &project_into_body)); let mut item = item.into_token_stream(); item.extend(proj_items); diff --git a/pin-project-internal/src/pin_project/mod.rs b/pin-project-internal/src/pin_project/mod.rs index aec9a484..7b7a1d17 100644 --- a/pin-project-internal/src/pin_project/mod.rs +++ b/pin-project-internal/src/pin_project/mod.rs @@ -8,6 +8,7 @@ use syn::{ use crate::utils::{ self, crate_path, proj_ident, proj_lifetime_name, proj_trait_ident, DEFAULT_LIFETIME_NAME, + TRAIT_LIFETIME_NAME, }; mod enums; @@ -68,6 +69,9 @@ struct Context { /// Lifetime added to projected type. lifetime: Lifetime, + /// Lifetime on the generated projection trait + trait_lifetime: Lifetime, + /// Where-clause for conditional Unpin implementation. impl_unpin: WhereClause, @@ -92,6 +96,10 @@ impl Context { proj_lifetime_name(&mut lifetime_name, &generics.params); let lifetime = Lifetime::new(&lifetime_name, Span::call_site()); + let mut trait_lifetime_name = String::from(TRAIT_LIFETIME_NAME); + proj_lifetime_name(&mut trait_lifetime_name, &generics.params); + let trait_lifetime = Lifetime::new(&trait_lifetime_name, Span::call_site()); + let mut generics = generics.clone(); let mut impl_unpin = generics.make_where_clause().clone(); if let Some(unsafe_unpin) = unsafe_unpin { @@ -111,6 +119,7 @@ impl Context { vis: vis.clone(), generics, lifetime, + trait_lifetime, impl_unpin, pinned_fields: vec![], unsafe_unpin: unsafe_unpin.is_some(), @@ -118,6 +127,36 @@ impl Context { }) } + /// Creates an implementation of the projection trait. + /// The provided TokenStream will be used as the body of the + /// 'project' and 'project_into' implementations + fn make_trait_impl( + &self, + project_body: &TokenStream, + project_into_body: &TokenStream, + ) -> TokenStream { + let Context { proj_ident, proj_trait, orig_ident, lifetime, trait_lifetime, .. } = &self; + let proj_generics = self.proj_generics(); + + let project_into_generics = self.project_into_generics(); + + let proj_ty_generics = proj_generics.split_for_impl().1; + let (impl_generics, project_into_ty_generics, _) = project_into_generics.split_for_impl(); + let (_, ty_generics, where_clause) = self.generics.split_for_impl(); + + quote! { + impl #impl_generics #proj_trait #project_into_ty_generics for ::core::pin::Pin<&#trait_lifetime mut #orig_ident #ty_generics> #where_clause { + fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause { + #project_body + } + + fn project_into(self) -> #proj_ident #project_into_ty_generics #where_clause { + #project_into_body + } + } + } + } + /// Makes the generics of projected type from the reference of the original generics. fn proj_generics(&self) -> Generics { let mut generics = self.generics.clone(); @@ -125,6 +164,13 @@ impl Context { generics } + /// Makes the generics for the 'project_into' method + fn project_into_generics(&self) -> Generics { + let mut generics = self.generics.clone(); + utils::proj_generics(&mut generics, self.trait_lifetime.clone()); + generics + } + fn push_unpin_bounds(&mut self, ty: Type) { self.pinned_fields.push(ty); } @@ -345,11 +391,17 @@ impl Context { let proj_generics = self.proj_generics(); let proj_ty_generics = proj_generics.split_for_impl().1; - let (orig_generics, _, orig_where_clause) = self.generics.split_for_impl(); + // Add trait lifetime to trait generics + let mut trait_generics = self.generics.clone(); + utils::proj_generics(&mut trait_generics, self.trait_lifetime.clone()); + + let (trait_generics, trait_ty_generics, orig_where_clause) = + trait_generics.split_for_impl(); quote! { - trait #proj_trait #orig_generics { + trait #proj_trait #trait_generics { fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #orig_where_clause; + fn project_into(self) -> #proj_ident #trait_ty_generics #orig_where_clause; } } } diff --git a/pin-project-internal/src/pin_project/structs.rs b/pin-project-internal/src/pin_project/structs.rs index ee7dc855..c69474bf 100644 --- a/pin-project-internal/src/pin_project/structs.rs +++ b/pin-project-internal/src/pin_project/structs.rs @@ -28,26 +28,31 @@ pub(super) fn parse(cx: &mut Context, mut item: ItemStruct) -> Result unnamed(cx, fields)?, }; - let Context { proj_ident, proj_trait, orig_ident, lifetime, .. } = &cx; + let proj_ident = &cx.proj_ident; let proj_generics = cx.proj_generics(); - let proj_ty_generics = proj_generics.split_for_impl().1; - let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl(); + let where_clause = item.generics.split_for_impl().2; let mut proj_items = quote! { #[allow(clippy::mut_mut)] #[allow(dead_code)] struct #proj_ident #proj_generics #where_clause #proj_fields }; - proj_items.extend(quote! { - impl #impl_generics #proj_trait #ty_generics for ::core::pin::Pin<&mut #orig_ident #ty_generics> #where_clause { - fn project<#lifetime>(&#lifetime mut self) -> #proj_ident #proj_ty_generics #where_clause { - unsafe { - let this = self.as_mut().get_unchecked_mut(); - #proj_ident #proj_init - } - } + + let project_body = quote! { + unsafe { + let this = self.as_mut().get_unchecked_mut(); + #proj_ident #proj_init + } + }; + + let project_into_body = quote! { + unsafe { + let this = self.get_unchecked_mut(); + #proj_ident #proj_init } - }); + }; + + proj_items.extend(cx.make_trait_impl(&project_body, &project_into_body)); let mut item = item.into_token_stream(); item.extend(proj_items); diff --git a/pin-project-internal/src/utils.rs b/pin-project-internal/src/utils.rs index befe4d61..aba3d117 100644 --- a/pin-project-internal/src/utils.rs +++ b/pin-project-internal/src/utils.rs @@ -6,6 +6,7 @@ use syn::{ }; pub(crate) const DEFAULT_LIFETIME_NAME: &str = "'_pin"; +pub(crate) const TRAIT_LIFETIME_NAME: &str = "'_outer_pin"; /// Makes the ident of projected type from the reference of the original ident. pub(crate) fn proj_ident(ident: &Ident) -> Ident { diff --git a/tests/pin_project.rs b/tests/pin_project.rs index df595bcb..098176bf 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -262,3 +262,36 @@ fn test_private_type_in_public_type() { OtherVariant(u8), } } + +#[test] +fn test_lifetime_project() { + #[pin_project::pin_project] + struct Struct { + #[pin] + pinned: T, + unpinned: U, + } + + #[pin_project::pin_project] + enum Enum { + Variant { + #[pin] + pinned: T, + unpinned: U, + }, + } + + impl Struct { + fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> { + self.project_into().pinned + } + } + + impl Enum { + fn get_pin_mut<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut T> { + match self.project_into() { + __EnumProjection::Variant { pinned, .. } => pinned, + } + } + } +}