From 14e0ad8d1ed5bd7a9d81814a9954c1e5bb4cb85b Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Mon, 4 May 2020 21:57:16 +0900 Subject: [PATCH] Fix #[project] on non-statement expressions --- pin-project-internal/src/project.rs | 57 +++++++++++++++++++++-------- tests/pin_project.rs | 24 ++++++++++++ tests/project.rs | 42 ++++++++++----------- tests/project_if_attr.rs.in | 20 ++++++++-- 4 files changed, 101 insertions(+), 42 deletions(-) diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 1ab70b9d..bbd2f9f5 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -13,12 +13,12 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability) .unwrap_or_else(|e| e.to_compile_error()) } -fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> { - match stmt { - Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { +fn replace_expr(expr: &mut Expr, mutability: Mutability) { + match expr { + Expr::Match(expr) => { Context::new(mutability).replace_expr_match(expr); } - Stmt::Expr(Expr::If(expr_if)) => { + Expr::If(expr_if) => { let mut expr_if = expr_if; while let Expr::Let(ref mut expr) = &mut *expr_if.cond { Context::new(mutability).replace_expr_let(expr); @@ -31,15 +31,18 @@ fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> { break; } } - Stmt::Local(local) => Context::new(mutability).replace_local(local)?, _ => {} } - Ok(()) +} + +fn replace_local(local: &mut Local, mutability: Mutability) -> Result<()> { + Context::new(mutability).replace_local(local) } fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { - replace_stmt(&mut stmt, mutability)?; match &mut stmt { + Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, mutability), + Stmt::Local(local) => replace_local(local, mutability)?, Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?, Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability), Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?, @@ -219,12 +222,28 @@ impl FnVisitor { } fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> { - let attr = match node { - Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - expr.attrs.find_remove(self.name())? + match node { + Stmt::Expr(expr) | Stmt::Semi(expr, _) => { + visit_mut::visit_expr_mut(self, expr); + self.visit_expr(expr) + } + Stmt::Local(local) => { + visit_mut::visit_local_mut(self, local); + if let Some(attr) = local.attrs.find_remove(self.name())? { + parse_as_empty(&attr.tokens)?; + replace_local(local, self.mutability)?; + } + Ok(()) } - Stmt::Local(local) => local.attrs.find_remove(self.name())?, - Stmt::Expr(Expr::If(expr_if)) => { + // Do not recurse into nested items. + Stmt::Item(_) => Ok(()), + } + } + + fn visit_expr(&mut self, node: &mut Expr) -> Result<()> { + let attr = match node { + Expr::Match(expr) => expr.attrs.find_remove(self.name())?, + Expr::If(expr_if) => { if let Expr::Let(_) = &*expr_if.cond { expr_if.attrs.find_remove(self.name())? } else { @@ -235,7 +254,7 @@ impl FnVisitor { }; if let Some(attr) = attr { parse_as_empty(&attr.tokens)?; - replace_stmt(node, self.mutability)?; + replace_expr(node, self.mutability); } Ok(()) } @@ -246,14 +265,20 @@ impl VisitMut for FnVisitor { if self.res.is_err() { return; } - - visit_mut::visit_stmt_mut(self, node); - if let Err(e) = self.visit_stmt(node) { self.res = Err(e) } } + fn visit_expr_mut(&mut self, node: &mut Expr) { + if self.res.is_err() { + return; + } + if let Err(e) = self.visit_expr(node) { + self.res = Err(e) + } + } + fn visit_item_mut(&mut self, _: &mut Item) { // Do not recurse into nested items. } diff --git a/tests/pin_project.rs b/tests/pin_project.rs index 595e69e6..8f23b9e4 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -556,3 +556,27 @@ fn self_in_where_clause() { type Foo = Struct1; } } + +#[test] +fn where_clause() { + #[pin_project] + struct StructWhereClause + where + T: Copy, + { + field: T, + } + + #[pin_project] + struct TupleStructWhereClause(T) + where + T: Copy; + + #[pin_project] + enum EnumWhereClause + where + T: Copy, + { + Variant(T), + } +} diff --git a/tests/project.rs b/tests/project.rs index bbf37214..a51caad4 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -1,14 +1,12 @@ #![warn(rust_2018_idioms, single_use_lifetimes)] #![allow(dead_code)] -// This hack is needed until https://github.com/rust-lang/rust/pull/69201 -// makes it way into stable. -// Ceurrently, `#[attr] if true {}` doesn't even *parse* on stable, -// which means that it will error even behind a `#[rustversion::nightly]` +// Ceurrently, `#[attr] if true {}` doesn't even *parse* on MSRV, +// which means that it will error even behind a `#[rustversion::since(..)]` // // This trick makes sure that we don't even attempt to parse -// the `#[project] if let _` test on stable. -#[rustversion::nightly] +// the `#[project] if let _` test on MSRV. +#[rustversion::since(1.43)] include!("project_if_attr.rs.in"); use pin_project::{pin_project, project}; @@ -194,23 +192,21 @@ mod project_use_2 { } } -#[pin_project] -struct StructWhereClause -where - T: Copy, -{ - field: T, -} +#[test] +#[project] +fn non_stmt_expr_match() { + #[pin_project] + enum Enum { + Variant(#[pin] A), + } -#[pin_project] -struct TupleStructWhereClause(T) -where - T: Copy; + let mut x = Enum::Variant(1); + let x = Pin::new(&mut x).project(); -#[pin_project] -enum EnumWhereClause -where - T: Copy, -{ - Variant(T), + Some( + #[project] + match x { + Enum::Variant(_x) => {} + }, + ); } diff --git a/tests/project_if_attr.rs.in b/tests/project_if_attr.rs.in index 2d1eb267..670d49f9 100644 --- a/tests/project_if_attr.rs.in +++ b/tests/project_if_attr.rs.in @@ -1,6 +1,3 @@ -// FIXME: Once https://github.com/rust-lang/rust/pull/69201 makes its -// way into stable, move this back into `project.rs - #[test] #[project] fn project_if_let() { @@ -27,3 +24,20 @@ fn project_if_let() { } } +#[test] +#[project] +fn non_stmt_expr_if_let() { + #[pin_project] + enum Enum { + Variant(#[pin] A), + } + + let mut x = Enum::Variant(1); + let x = Pin::new(&mut x).project(); + + #[allow(irrefutable_let_patterns)] + Some( + #[project] + if let Enum::Variant(_x) = x {}, + ); +}