diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 8bb07d54..92627d43 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -13,12 +13,33 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability) .unwrap_or_else(|e| e.to_compile_error()) } -fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { - match &mut stmt { +fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> { + match stmt { Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - Context::new(mutability).replace_expr_match(expr) + Context::new(mutability).replace_expr_match(expr); + } + Stmt::Expr(Expr::If(expr_if)) => { + let mut expr_if = expr_if; + while let Expr::Let(ref mut expr) = &mut *expr_if.cond { + Context::new(mutability).replace_expr_let(expr); + if let Some((_, ref mut expr)) = expr_if.else_branch { + if let Expr::If(new_expr_if) = &mut **expr { + expr_if = new_expr_if; + continue; + } + } + break; + } } Stmt::Local(local) => Context::new(mutability).replace_local(local)?, + _ => {} + } + Ok(()) +} + +fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { + replace_stmt(&mut stmt, mutability)?; + match &mut stmt { Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?, Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability), Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?, @@ -73,6 +94,10 @@ impl Context { Ok(()) } + fn replace_expr_let(&mut self, expr: &mut ExprLet) { + self.replace_pat(&mut expr.pat, true) + } + fn replace_expr_match(&mut self, expr: &mut ExprMatch) { expr.arms.iter_mut().for_each(|Arm { pat, .. }| self.replace_pat(pat, true)) } @@ -195,17 +220,18 @@ impl FnVisitor { expr.attrs.find_remove(self.name())? } Stmt::Local(local) => local.attrs.find_remove(self.name())?, + Stmt::Expr(Expr::If(expr_if)) => { + if let Expr::Let(_) = &*expr_if.cond { + expr_if.attrs.find_remove(self.name())? + } else { + None + } + } _ => return Ok(()), }; if let Some(attr) = attr { parse_as_empty(&attr.tokens)?; - match node { - Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - Context::new(self.mutability).replace_expr_match(expr) - } - Stmt::Local(local) => Context::new(self.mutability).replace_local(local)?, - _ => unreachable!(), - } + replace_stmt(node, self.mutability)?; } Ok(()) } diff --git a/tests/project.rs b/tests/project.rs index 84858d8a..6463ab57 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -2,6 +2,16 @@ #![warn(rust_2018_idioms, single_use_lifetimes)] #![allow(dead_code)] +// This hack is needed until https://github.com/rust-lang/rust/pull/69201 +// makes it way into stable. +// Ceurrently, `#[attr] if true {}` doesn't even *parse* on stable, +// which means that it will error even behind a `#[rustversion::nightly]` +// +// This trick makes sure that we don't even attempt to parse +// the `#[project] if let _` test on stable. +#[rustversion::nightly] +include!("project_if_attr.rs.in"); + use pin_project::{pin_project, project}; use std::pin::Pin; diff --git a/tests/project_if_attr.rs.in b/tests/project_if_attr.rs.in new file mode 100644 index 00000000..2d1eb267 --- /dev/null +++ b/tests/project_if_attr.rs.in @@ -0,0 +1,29 @@ +// FIXME: Once https://github.com/rust-lang/rust/pull/69201 makes its +// way into stable, move this back into `project.rs + +#[test] +#[project] +fn project_if_let() { + #[pin_project] + enum Foo { + Variant1(#[pin] A), + Variant2(u8), + Variant3 { + #[pin] field: B + } + } + + let mut foo: Foo = Foo::Variant1(true); + let foo = Pin::new(&mut foo).project(); + + #[project] + if let Foo::Variant1(a) = foo { + let a: Pin<&mut bool> = a; + assert_eq!(*a, true); + } else if let Foo::Variant2(_) = foo { + unreachable!(); + } else if let Foo::Variant3 { .. } = foo { + unreachable!(); + } +} +