Skip to content

Commit

Permalink
Fix #[project] on non-statement expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
taiki-e committed May 4, 2020
1 parent 7d46ff1 commit 14e0ad8
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 42 deletions.
57 changes: 41 additions & 16 deletions pin-project-internal/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability)
.unwrap_or_else(|e| e.to_compile_error())
}

fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> {
match stmt {
Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
fn replace_expr(expr: &mut Expr, mutability: Mutability) {
match expr {
Expr::Match(expr) => {
Context::new(mutability).replace_expr_match(expr);
}
Stmt::Expr(Expr::If(expr_if)) => {
Expr::If(expr_if) => {
let mut expr_if = expr_if;
while let Expr::Let(ref mut expr) = &mut *expr_if.cond {
Context::new(mutability).replace_expr_let(expr);
Expand All @@ -31,15 +31,18 @@ fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> {
break;
}
}
Stmt::Local(local) => Context::new(mutability).replace_local(local)?,
_ => {}
}
Ok(())
}

fn replace_local(local: &mut Local, mutability: Mutability) -> Result<()> {
Context::new(mutability).replace_local(local)
}

fn parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream> {
replace_stmt(&mut stmt, mutability)?;
match &mut stmt {
Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, mutability),
Stmt::Local(local) => replace_local(local, mutability)?,
Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?,
Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability),
Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?,
Expand Down Expand Up @@ -219,12 +222,28 @@ impl FnVisitor {
}

fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
let attr = match node {
Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
expr.attrs.find_remove(self.name())?
match node {
Stmt::Expr(expr) | Stmt::Semi(expr, _) => {
visit_mut::visit_expr_mut(self, expr);
self.visit_expr(expr)
}
Stmt::Local(local) => {
visit_mut::visit_local_mut(self, local);
if let Some(attr) = local.attrs.find_remove(self.name())? {
parse_as_empty(&attr.tokens)?;
replace_local(local, self.mutability)?;
}
Ok(())
}
Stmt::Local(local) => local.attrs.find_remove(self.name())?,
Stmt::Expr(Expr::If(expr_if)) => {
// Do not recurse into nested items.
Stmt::Item(_) => Ok(()),
}
}

fn visit_expr(&mut self, node: &mut Expr) -> Result<()> {
let attr = match node {
Expr::Match(expr) => expr.attrs.find_remove(self.name())?,
Expr::If(expr_if) => {
if let Expr::Let(_) = &*expr_if.cond {
expr_if.attrs.find_remove(self.name())?
} else {
Expand All @@ -235,7 +254,7 @@ impl FnVisitor {
};
if let Some(attr) = attr {
parse_as_empty(&attr.tokens)?;
replace_stmt(node, self.mutability)?;
replace_expr(node, self.mutability);
}
Ok(())
}
Expand All @@ -246,14 +265,20 @@ impl VisitMut for FnVisitor {
if self.res.is_err() {
return;
}

visit_mut::visit_stmt_mut(self, node);

if let Err(e) = self.visit_stmt(node) {
self.res = Err(e)
}
}

fn visit_expr_mut(&mut self, node: &mut Expr) {
if self.res.is_err() {
return;
}
if let Err(e) = self.visit_expr(node) {
self.res = Err(e)
}
}

fn visit_item_mut(&mut self, _: &mut Item) {
// Do not recurse into nested items.
}
Expand Down
24 changes: 24 additions & 0 deletions tests/pin_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,27 @@ fn self_in_where_clause() {
type Foo = Struct1<T>;
}
}

#[test]
fn where_clause() {
#[pin_project]
struct StructWhereClause<T>
where
T: Copy,
{
field: T,
}

#[pin_project]
struct TupleStructWhereClause<T>(T)
where
T: Copy;

#[pin_project]
enum EnumWhereClause<T>
where
T: Copy,
{
Variant(T),
}
}
42 changes: 19 additions & 23 deletions tests/project.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
#![warn(rust_2018_idioms, single_use_lifetimes)]
#![allow(dead_code)]

// This hack is needed until https://github.com/rust-lang/rust/pull/69201
// makes it way into stable.
// Ceurrently, `#[attr] if true {}` doesn't even *parse* on stable,
// which means that it will error even behind a `#[rustversion::nightly]`
// Ceurrently, `#[attr] if true {}` doesn't even *parse* on MSRV,
// which means that it will error even behind a `#[rustversion::since(..)]`
//
// This trick makes sure that we don't even attempt to parse
// the `#[project] if let _` test on stable.
#[rustversion::nightly]
// the `#[project] if let _` test on MSRV.
#[rustversion::since(1.43)]
include!("project_if_attr.rs.in");

use pin_project::{pin_project, project};
Expand Down Expand Up @@ -194,23 +192,21 @@ mod project_use_2 {
}
}

#[pin_project]
struct StructWhereClause<T>
where
T: Copy,
{
field: T,
}
#[test]
#[project]
fn non_stmt_expr_match() {
#[pin_project]
enum Enum<A> {
Variant(#[pin] A),
}

#[pin_project]
struct TupleStructWhereClause<T>(T)
where
T: Copy;
let mut x = Enum::Variant(1);
let x = Pin::new(&mut x).project();

#[pin_project]
enum EnumWhereClause<T>
where
T: Copy,
{
Variant(T),
Some(
#[project]
match x {
Enum::Variant(_x) => {}
},
);
}
20 changes: 17 additions & 3 deletions tests/project_if_attr.rs.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// FIXME: Once https://github.com/rust-lang/rust/pull/69201 makes its
// way into stable, move this back into `project.rs

#[test]
#[project]
fn project_if_let() {
Expand All @@ -27,3 +24,20 @@ fn project_if_let() {
}
}

#[test]
#[project]
fn non_stmt_expr_if_let() {
#[pin_project]
enum Enum<A> {
Variant(#[pin] A),
}

let mut x = Enum::Variant(1);
let x = Pin::new(&mut x).project();

#[allow(irrefutable_let_patterns)]
Some(
#[project]
if let Enum::Variant(_x) = x {},
);
}

0 comments on commit 14e0ad8

Please sign in to comment.