From 9be39c4e7769136958b1b24fedd5ff5d9c0e8cd9 Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Tue, 5 May 2020 02:17:15 +0900 Subject: [PATCH] Fix #[project] attribute on non-statement expressions --- pin-project-internal/src/project.rs | 197 +++++++++++++++------------- tests/pin_project.rs | 24 ++++ tests/project.rs | 42 +++--- tests/project_if_attr.rs.in | 20 ++- 4 files changed, 168 insertions(+), 115 deletions(-) diff --git a/pin-project-internal/src/project.rs b/pin-project-internal/src/project.rs index 1ab70b9d..f2a17f83 100644 --- a/pin-project-internal/src/project.rs +++ b/pin-project-internal/src/project.rs @@ -13,12 +13,12 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability) .unwrap_or_else(|e| e.to_compile_error()) } -fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> { - match stmt { - Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { +fn replace_expr(expr: &mut Expr, mutability: Mutability) { + match expr { + Expr::Match(expr) => { Context::new(mutability).replace_expr_match(expr); } - Stmt::Expr(Expr::If(expr_if)) => { + Expr::If(expr_if) => { let mut expr_if = expr_if; while let Expr::Let(ref mut expr) = &mut *expr_if.cond { Context::new(mutability).replace_expr_let(expr); @@ -31,15 +31,14 @@ fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> { break; } } - Stmt::Local(local) => Context::new(mutability).replace_local(local)?, _ => {} } - Ok(()) } fn parse(mut stmt: Stmt, mutability: Mutability) -> Result { - replace_stmt(&mut stmt, mutability)?; match &mut stmt { + Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, mutability), + Stmt::Local(local) => Context::new(mutability).replace_local(local)?, Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?, Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability), Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?, @@ -68,7 +67,7 @@ impl Context { fn compare_paths(&self, ident: &Ident, len: usize) -> bool { match &self.register { - Some((i, l)) => *l == len && ident == i, + Some((i, l)) => *l == len && i == ident, None => false, } } @@ -99,7 +98,7 @@ impl Context { } fn replace_expr_match(&mut self, expr: &mut ExprMatch) { - expr.arms.iter_mut().for_each(|Arm { pat, .. }| self.replace_pat(pat, true)) + expr.arms.iter_mut().for_each(|arm| self.replace_pat(&mut arm.pat, true)) } fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) { @@ -155,6 +154,10 @@ fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool { } } +fn replace_ident(ident: &mut Ident, mutability: Mutability) { + *ident = proj_ident(ident, mutability); +} + fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) { let PathSegment { ident, arguments } = match &mut *item.self_ty { Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(), @@ -185,106 +188,122 @@ fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) { } fn replace_item_fn(item: &mut ItemFn, mutability: Mutability) -> Result<()> { - let mut visitor = FnVisitor { res: Ok(()), mutability }; - visitor.visit_block_mut(&mut item.block); - visitor.res -} - -fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> { - let mut visitor = UseTreeVisitor { res: Ok(()), mutability }; - visitor.visit_item_use_mut(item); - visitor.res -} - -fn replace_ident(ident: &mut Ident, mutability: Mutability) { - *ident = proj_ident(ident, mutability); -} - -// ================================================================================================= -// visitors - -struct FnVisitor { - res: Result<()>, - mutability: Mutability, -} + struct FnVisitor { + res: Result<()>, + mutability: Mutability, + } -impl FnVisitor { - /// Returns the attribute name. - fn name(&self) -> &str { - match self.mutability { - Mutable => "project", - Immutable => "project_ref", - Owned => "project_replace", + impl FnVisitor { + /// Returns the attribute name. + fn name(&self) -> &str { + match self.mutability { + Mutable => "project", + Immutable => "project_ref", + Owned => "project_replace", + } } - } - fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> { - let attr = match node { - Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => { - expr.attrs.find_remove(self.name())? + fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> { + match node { + Stmt::Expr(expr) | Stmt::Semi(expr, _) => { + visit_mut::visit_expr_mut(self, expr); + self.visit_expr(expr) + } + Stmt::Local(local) => { + visit_mut::visit_local_mut(self, local); + if let Some(attr) = local.attrs.find_remove(self.name())? { + parse_as_empty(&attr.tokens)?; + Context::new(self.mutability).replace_local(local)?; + } + Ok(()) + } + // Do not recurse into nested items. + Stmt::Item(_) => Ok(()), } - Stmt::Local(local) => local.attrs.find_remove(self.name())?, - Stmt::Expr(Expr::If(expr_if)) => { - if let Expr::Let(_) = &*expr_if.cond { - expr_if.attrs.find_remove(self.name())? - } else { - None + } + + fn visit_expr(&mut self, node: &mut Expr) -> Result<()> { + let attr = match node { + Expr::Match(expr) => expr.attrs.find_remove(self.name())?, + Expr::If(expr_if) => { + if let Expr::Let(_) = &*expr_if.cond { + expr_if.attrs.find_remove(self.name())? + } else { + None + } } + _ => return Ok(()), + }; + if let Some(attr) = attr { + parse_as_empty(&attr.tokens)?; + replace_expr(node, self.mutability); } - _ => return Ok(()), - }; - if let Some(attr) = attr { - parse_as_empty(&attr.tokens)?; - replace_stmt(node, self.mutability)?; + Ok(()) } - Ok(()) } -} -impl VisitMut for FnVisitor { - fn visit_stmt_mut(&mut self, node: &mut Stmt) { - if self.res.is_err() { - return; + impl VisitMut for FnVisitor { + fn visit_stmt_mut(&mut self, node: &mut Stmt) { + if self.res.is_err() { + return; + } + if let Err(e) = self.visit_stmt(node) { + self.res = Err(e) + } } - visit_mut::visit_stmt_mut(self, node); - - if let Err(e) = self.visit_stmt(node) { - self.res = Err(e) + fn visit_expr_mut(&mut self, node: &mut Expr) { + if self.res.is_err() { + return; + } + if let Err(e) = self.visit_expr(node) { + self.res = Err(e) + } } - } - fn visit_item_mut(&mut self, _: &mut Item) { - // Do not recurse into nested items. + fn visit_item_mut(&mut self, _: &mut Item) { + // Do not recurse into nested items. + } } -} -struct UseTreeVisitor { - res: Result<()>, - mutability: Mutability, + let mut visitor = FnVisitor { res: Ok(()), mutability }; + visitor.visit_block_mut(&mut item.block); + visitor.res } -impl VisitMut for UseTreeVisitor { - fn visit_use_tree_mut(&mut self, node: &mut UseTree) { - if self.res.is_err() { - return; - } +fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> { + struct UseTreeVisitor { + res: Result<()>, + mutability: Mutability, + } - match node { - // Desugar `use tree::` into `tree::__Projection`. - UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability), - UseTree::Glob(glob) => { - self.res = - Err(error!(glob, "#[project] attribute may not be used on glob imports")); + impl VisitMut for UseTreeVisitor { + fn visit_use_tree_mut(&mut self, node: &mut UseTree) { + if self.res.is_err() { + return; } - UseTree::Rename(rename) => { - // TODO: Consider allowing the projected type to be renamed by `#[project] use Foo as Bar`. - self.res = - Err(error!(rename, "#[project] attribute may not be used on renamed imports")); - } - node @ UseTree::Path(_) | node @ UseTree::Group(_) => { - visit_mut::visit_use_tree_mut(self, node) + + match node { + // Desugar `use tree::` into `tree::__Projection`. + UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability), + UseTree::Glob(glob) => { + self.res = + Err(error!(glob, "#[project] attribute may not be used on glob imports")); + } + UseTree::Rename(rename) => { + self.res = Err(error!( + rename, + "#[project] attribute may not be used on renamed imports" + )); + } + node @ UseTree::Path(_) | node @ UseTree::Group(_) => { + visit_mut::visit_use_tree_mut(self, node) + } } } } + + let mut visitor = UseTreeVisitor { res: Ok(()), mutability }; + visitor.visit_item_use_mut(item); + visitor.res } diff --git a/tests/pin_project.rs b/tests/pin_project.rs index 595e69e6..8f23b9e4 100644 --- a/tests/pin_project.rs +++ b/tests/pin_project.rs @@ -556,3 +556,27 @@ fn self_in_where_clause() { type Foo = Struct1; } } + +#[test] +fn where_clause() { + #[pin_project] + struct StructWhereClause + where + T: Copy, + { + field: T, + } + + #[pin_project] + struct TupleStructWhereClause(T) + where + T: Copy; + + #[pin_project] + enum EnumWhereClause + where + T: Copy, + { + Variant(T), + } +} diff --git a/tests/project.rs b/tests/project.rs index bbf37214..a51caad4 100644 --- a/tests/project.rs +++ b/tests/project.rs @@ -1,14 +1,12 @@ #![warn(rust_2018_idioms, single_use_lifetimes)] #![allow(dead_code)] -// This hack is needed until https://github.com/rust-lang/rust/pull/69201 -// makes it way into stable. -// Ceurrently, `#[attr] if true {}` doesn't even *parse* on stable, -// which means that it will error even behind a `#[rustversion::nightly]` +// Ceurrently, `#[attr] if true {}` doesn't even *parse* on MSRV, +// which means that it will error even behind a `#[rustversion::since(..)]` // // This trick makes sure that we don't even attempt to parse -// the `#[project] if let _` test on stable. -#[rustversion::nightly] +// the `#[project] if let _` test on MSRV. +#[rustversion::since(1.43)] include!("project_if_attr.rs.in"); use pin_project::{pin_project, project}; @@ -194,23 +192,21 @@ mod project_use_2 { } } -#[pin_project] -struct StructWhereClause -where - T: Copy, -{ - field: T, -} +#[test] +#[project] +fn non_stmt_expr_match() { + #[pin_project] + enum Enum { + Variant(#[pin] A), + } -#[pin_project] -struct TupleStructWhereClause(T) -where - T: Copy; + let mut x = Enum::Variant(1); + let x = Pin::new(&mut x).project(); -#[pin_project] -enum EnumWhereClause -where - T: Copy, -{ - Variant(T), + Some( + #[project] + match x { + Enum::Variant(_x) => {} + }, + ); } diff --git a/tests/project_if_attr.rs.in b/tests/project_if_attr.rs.in index 2d1eb267..670d49f9 100644 --- a/tests/project_if_attr.rs.in +++ b/tests/project_if_attr.rs.in @@ -1,6 +1,3 @@ -// FIXME: Once https://github.com/rust-lang/rust/pull/69201 makes its -// way into stable, move this back into `project.rs - #[test] #[project] fn project_if_let() { @@ -27,3 +24,20 @@ fn project_if_let() { } } +#[test] +#[project] +fn non_stmt_expr_if_let() { + #[pin_project] + enum Enum { + Variant(#[pin] A), + } + + let mut x = Enum::Variant(1); + let x = Pin::new(&mut x).project(); + + #[allow(irrefutable_let_patterns)] + Some( + #[project] + if let Enum::Variant(_x) = x {}, + ); +}