Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #[project] attribute on non-statement expressions #197

Merged
merged 1 commit into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 108 additions & 89 deletions pin-project-internal/src/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ pub(crate) fn attribute(args: &TokenStream, input: Stmt, mutability: Mutability)
.unwrap_or_else(|e| e.to_compile_error())
}

fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> {
match stmt {
Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
fn replace_expr(expr: &mut Expr, mutability: Mutability) {
match expr {
Expr::Match(expr) => {
Context::new(mutability).replace_expr_match(expr);
}
Stmt::Expr(Expr::If(expr_if)) => {
Expr::If(expr_if) => {
let mut expr_if = expr_if;
while let Expr::Let(ref mut expr) = &mut *expr_if.cond {
Context::new(mutability).replace_expr_let(expr);
Expand All @@ -31,15 +31,14 @@ fn replace_stmt(stmt: &mut Stmt, mutability: Mutability) -> Result<()> {
break;
}
}
Stmt::Local(local) => Context::new(mutability).replace_local(local)?,
_ => {}
}
Ok(())
}

fn parse(mut stmt: Stmt, mutability: Mutability) -> Result<TokenStream> {
replace_stmt(&mut stmt, mutability)?;
match &mut stmt {
Stmt::Expr(expr) | Stmt::Semi(expr, _) => replace_expr(expr, mutability),
Stmt::Local(local) => Context::new(mutability).replace_local(local)?,
Stmt::Item(Item::Fn(item)) => replace_item_fn(item, mutability)?,
Stmt::Item(Item::Impl(item)) => replace_item_impl(item, mutability),
Stmt::Item(Item::Use(item)) => replace_item_use(item, mutability)?,
Expand Down Expand Up @@ -68,7 +67,7 @@ impl Context {

fn compare_paths(&self, ident: &Ident, len: usize) -> bool {
match &self.register {
Some((i, l)) => *l == len && ident == i,
Some((i, l)) => *l == len && i == ident,
None => false,
}
}
Expand Down Expand Up @@ -99,7 +98,7 @@ impl Context {
}

fn replace_expr_match(&mut self, expr: &mut ExprMatch) {
expr.arms.iter_mut().for_each(|Arm { pat, .. }| self.replace_pat(pat, true))
expr.arms.iter_mut().for_each(|arm| self.replace_pat(&mut arm.pat, true))
}

fn replace_pat(&mut self, pat: &mut Pat, allow_pat_path: bool) {
Expand Down Expand Up @@ -155,6 +154,10 @@ fn is_replaceable(pat: &Pat, allow_pat_path: bool) -> bool {
}
}

fn replace_ident(ident: &mut Ident, mutability: Mutability) {
*ident = proj_ident(ident, mutability);
}

fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) {
let PathSegment { ident, arguments } = match &mut *item.self_ty {
Type::Path(TypePath { qself: None, path }) => path.segments.last_mut().unwrap(),
Expand Down Expand Up @@ -185,106 +188,122 @@ fn replace_item_impl(item: &mut ItemImpl, mutability: Mutability) {
}

fn replace_item_fn(item: &mut ItemFn, mutability: Mutability) -> Result<()> {
let mut visitor = FnVisitor { res: Ok(()), mutability };
visitor.visit_block_mut(&mut item.block);
visitor.res
}

fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> {
let mut visitor = UseTreeVisitor { res: Ok(()), mutability };
visitor.visit_item_use_mut(item);
visitor.res
}

fn replace_ident(ident: &mut Ident, mutability: Mutability) {
*ident = proj_ident(ident, mutability);
}

// =================================================================================================
// visitors

struct FnVisitor {
res: Result<()>,
mutability: Mutability,
}
struct FnVisitor {
res: Result<()>,
mutability: Mutability,
}

impl FnVisitor {
/// Returns the attribute name.
fn name(&self) -> &str {
match self.mutability {
Mutable => "project",
Immutable => "project_ref",
Owned => "project_replace",
impl FnVisitor {
/// Returns the attribute name.
fn name(&self) -> &str {
match self.mutability {
Mutable => "project",
Immutable => "project_ref",
Owned => "project_replace",
}
}
}

fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
let attr = match node {
Stmt::Expr(Expr::Match(expr)) | Stmt::Semi(Expr::Match(expr), _) => {
expr.attrs.find_remove(self.name())?
fn visit_stmt(&mut self, node: &mut Stmt) -> Result<()> {
match node {
Stmt::Expr(expr) | Stmt::Semi(expr, _) => {
visit_mut::visit_expr_mut(self, expr);
self.visit_expr(expr)
}
Stmt::Local(local) => {
visit_mut::visit_local_mut(self, local);
if let Some(attr) = local.attrs.find_remove(self.name())? {
parse_as_empty(&attr.tokens)?;
Context::new(self.mutability).replace_local(local)?;
}
Ok(())
}
// Do not recurse into nested items.
Stmt::Item(_) => Ok(()),
}
Stmt::Local(local) => local.attrs.find_remove(self.name())?,
Stmt::Expr(Expr::If(expr_if)) => {
if let Expr::Let(_) = &*expr_if.cond {
expr_if.attrs.find_remove(self.name())?
} else {
None
}

fn visit_expr(&mut self, node: &mut Expr) -> Result<()> {
let attr = match node {
Expr::Match(expr) => expr.attrs.find_remove(self.name())?,
Expr::If(expr_if) => {
if let Expr::Let(_) = &*expr_if.cond {
expr_if.attrs.find_remove(self.name())?
} else {
None
}
}
_ => return Ok(()),
};
if let Some(attr) = attr {
parse_as_empty(&attr.tokens)?;
replace_expr(node, self.mutability);
}
_ => return Ok(()),
};
if let Some(attr) = attr {
parse_as_empty(&attr.tokens)?;
replace_stmt(node, self.mutability)?;
Ok(())
}
Ok(())
}
}

impl VisitMut for FnVisitor {
fn visit_stmt_mut(&mut self, node: &mut Stmt) {
if self.res.is_err() {
return;
impl VisitMut for FnVisitor {
fn visit_stmt_mut(&mut self, node: &mut Stmt) {
if self.res.is_err() {
return;
}
if let Err(e) = self.visit_stmt(node) {
self.res = Err(e)
}
}

visit_mut::visit_stmt_mut(self, node);

if let Err(e) = self.visit_stmt(node) {
self.res = Err(e)
fn visit_expr_mut(&mut self, node: &mut Expr) {
if self.res.is_err() {
return;
}
if let Err(e) = self.visit_expr(node) {
self.res = Err(e)
}
}
}

fn visit_item_mut(&mut self, _: &mut Item) {
// Do not recurse into nested items.
fn visit_item_mut(&mut self, _: &mut Item) {
// Do not recurse into nested items.
}
}
}

struct UseTreeVisitor {
res: Result<()>,
mutability: Mutability,
let mut visitor = FnVisitor { res: Ok(()), mutability };
visitor.visit_block_mut(&mut item.block);
visitor.res
}

impl VisitMut for UseTreeVisitor {
fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
if self.res.is_err() {
return;
}
fn replace_item_use(item: &mut ItemUse, mutability: Mutability) -> Result<()> {
struct UseTreeVisitor {
res: Result<()>,
mutability: Mutability,
}

match node {
// Desugar `use tree::<name>` into `tree::__<name>Projection`.
UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability),
UseTree::Glob(glob) => {
self.res =
Err(error!(glob, "#[project] attribute may not be used on glob imports"));
impl VisitMut for UseTreeVisitor {
fn visit_use_tree_mut(&mut self, node: &mut UseTree) {
if self.res.is_err() {
return;
}
UseTree::Rename(rename) => {
// TODO: Consider allowing the projected type to be renamed by `#[project] use Foo as Bar`.
self.res =
Err(error!(rename, "#[project] attribute may not be used on renamed imports"));
}
node @ UseTree::Path(_) | node @ UseTree::Group(_) => {
visit_mut::visit_use_tree_mut(self, node)

match node {
// Desugar `use tree::<name>` into `tree::__<name>Projection`.
UseTree::Name(name) => replace_ident(&mut name.ident, self.mutability),
UseTree::Glob(glob) => {
self.res =
Err(error!(glob, "#[project] attribute may not be used on glob imports"));
}
UseTree::Rename(rename) => {
self.res = Err(error!(
rename,
"#[project] attribute may not be used on renamed imports"
));
}
node @ UseTree::Path(_) | node @ UseTree::Group(_) => {
visit_mut::visit_use_tree_mut(self, node)
}
}
}
}

let mut visitor = UseTreeVisitor { res: Ok(()), mutability };
visitor.visit_item_use_mut(item);
visitor.res
}
24 changes: 24 additions & 0 deletions tests/pin_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,27 @@ fn self_in_where_clause() {
type Foo = Struct1<T>;
}
}

#[test]
fn where_clause() {
#[pin_project]
struct StructWhereClause<T>
where
T: Copy,
{
field: T,
}

#[pin_project]
struct TupleStructWhereClause<T>(T)
where
T: Copy;

#[pin_project]
enum EnumWhereClause<T>
where
T: Copy,
{
Variant(T),
}
}
42 changes: 19 additions & 23 deletions tests/project.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
#![warn(rust_2018_idioms, single_use_lifetimes)]
#![allow(dead_code)]

// This hack is needed until https://github.com/rust-lang/rust/pull/69201
// makes it way into stable.
// Ceurrently, `#[attr] if true {}` doesn't even *parse* on stable,
// which means that it will error even behind a `#[rustversion::nightly]`
// Ceurrently, `#[attr] if true {}` doesn't even *parse* on MSRV,
// which means that it will error even behind a `#[rustversion::since(..)]`
//
// This trick makes sure that we don't even attempt to parse
// the `#[project] if let _` test on stable.
#[rustversion::nightly]
// the `#[project] if let _` test on MSRV.
#[rustversion::since(1.43)]
include!("project_if_attr.rs.in");

use pin_project::{pin_project, project};
Expand Down Expand Up @@ -194,23 +192,21 @@ mod project_use_2 {
}
}

#[pin_project]
struct StructWhereClause<T>
where
T: Copy,
{
field: T,
}
#[test]
#[project]
fn non_stmt_expr_match() {
#[pin_project]
enum Enum<A> {
Variant(#[pin] A),
}

#[pin_project]
struct TupleStructWhereClause<T>(T)
where
T: Copy;
let mut x = Enum::Variant(1);
let x = Pin::new(&mut x).project();

#[pin_project]
enum EnumWhereClause<T>
where
T: Copy,
{
Variant(T),
Some(
#[project]
match x {
Enum::Variant(_x) => {}
},
);
}
20 changes: 17 additions & 3 deletions tests/project_if_attr.rs.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// FIXME: Once https://github.com/rust-lang/rust/pull/69201 makes its
// way into stable, move this back into `project.rs

#[test]
#[project]
fn project_if_let() {
Expand All @@ -27,3 +24,20 @@ fn project_if_let() {
}
}

#[test]
#[project]
fn non_stmt_expr_if_let() {
#[pin_project]
enum Enum<A> {
Variant(#[pin] A),
}

let mut x = Enum::Variant(1);
let x = Pin::new(&mut x).project();

#[allow(irrefutable_let_patterns)]
Some(
#[project]
if let Enum::Variant(_x) = x {},
);
}