Skip to content

Commit

Permalink
Merge pull request #149 from dtolnay/bound
Browse files Browse the repository at this point in the history
Implied bounds for transparent attribute
  • Loading branch information
dtolnay authored Sep 5, 2021
2 parents e95b4ad + 3e699aa commit 0a1c5bd
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
37 changes: 27 additions & 10 deletions impl/src/expand.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::ast::{Enum, Field, Input, Struct};
use crate::attr::Trait;
use proc_macro2::TokenStream;
use quote::{format_ident, quote, quote_spanned, ToTokens};
use std::collections::BTreeSet as Set;
Expand All @@ -24,9 +25,16 @@ fn impl_struct(input: Struct) -> TokenStream {
let error_where_clause = error_generics.make_where_clause();

let source_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
let only_field = &input.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
}
let member = &only_field.member;
Some(quote! {
std::error::Error::source(self.#only_field.as_dyn_error())
std::error::Error::source(self.#member.as_dyn_error())
})
} else if let Some(source_field) = input.source_field() {
let source = &source_field.member;
Expand Down Expand Up @@ -101,14 +109,15 @@ fn impl_struct(input: Struct) -> TokenStream {
}
});

let mut display_implied_bounds = &Set::new();
let mut display_implied_bounds = Set::new();
let display_body = if input.attrs.transparent.is_some() {
let only_field = &input.fields[0].member;
display_implied_bounds.insert((0, Trait::Display));
Some(quote! {
std::fmt::Display::fmt(&self.#only_field, __formatter)
})
} else if let Some(display) = &input.attrs.display {
display_implied_bounds = &display.implied_bounds;
display_implied_bounds = display.implied_bounds.clone();
let use_as_display = if display.has_bonus_display {
Some(quote! {
#[allow(unused_imports)]
Expand All @@ -130,7 +139,7 @@ fn impl_struct(input: Struct) -> TokenStream {
let display_impl = display_body.map(|body| {
let mut display_generics = input.generics.clone();
let display_where_clause = display_generics.make_where_clause();
for &(field, bound) in display_implied_bounds {
for (field, bound) in display_implied_bounds {
let field = &input.fields[field];
if field.contains_generic {
let field_ty = field.ty;
Expand Down Expand Up @@ -193,10 +202,17 @@ fn impl_enum(input: Enum) -> TokenStream {
let arms = input.variants.iter().map(|variant| {
let ident = &variant.ident;
if variant.attrs.transparent.is_some() {
let only_field = &variant.fields[0].member;
let only_field = &variant.fields[0];
if only_field.contains_generic {
let ty = only_field.ty;
error_where_clause
.predicates
.push(parse_quote!(#ty: std::error::Error));
}
let member = &only_field.member;
let source = quote!(std::error::Error::source(transparent.as_dyn_error()));
quote! {
#ty::#ident {#only_field: transparent} => #source,
#ty::#ident {#member: transparent} => #source,
}
} else if let Some(source_field) = variant.source_field() {
let source = &source_field.member;
Expand Down Expand Up @@ -345,21 +361,22 @@ fn impl_enum(input: Enum) -> TokenStream {
None
};
let arms = input.variants.iter().map(|variant| {
let mut display_implied_bounds = &Set::new();
let mut display_implied_bounds = Set::new();
let display = match &variant.attrs.display {
Some(display) => {
display_implied_bounds = &display.implied_bounds;
display_implied_bounds = display.implied_bounds.clone();
display.to_token_stream()
}
None => {
let only_field = match &variant.fields[0].member {
Member::Named(ident) => ident.clone(),
Member::Unnamed(index) => format_ident!("_{}", index),
};
display_implied_bounds.insert((0, Trait::Display));
quote!(std::fmt::Display::fmt(#only_field, __formatter))
}
};
for &(field, bound) in display_implied_bounds {
for (field, bound) in display_implied_bounds {
let field = &variant.fields[field];
if field.contains_generic {
let field_ty = field.ty;
Expand Down
32 changes: 32 additions & 0 deletions tests/test_generics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ fn test_display_enum_compound() {
assert_eq!(format!("{}", instance), "DebugOnly");
}

// Should expand to:
//
// impl<E> Display for EnumTransparentGeneric<E>
// where
// E: Display;
//
// impl<E> Error for EnumTransparentGeneric<E>
// where
// E: Error,
// Self: Debug + Display;
//
#[derive(Error, Debug)]
pub enum EnumTransparentGeneric<E> {
#[error(transparent)]
Other(E),
}

// Should expand to:
//
// impl<E> Display for StructDebugGeneric<E>
Expand Down Expand Up @@ -127,3 +144,18 @@ pub struct StructFromGeneric<E> {
#[from]
pub source: StructDebugGeneric<E>,
}

// Should expand to:
//
// impl<E> Display for StructTransparentGeneric<E>
// where
// E: Display;
//
// impl<E> Error for StructTransparentGeneric<E>
// where
// E: Error,
// Self: Debug + Display;
//
#[derive(Error, Debug)]
#[error(transparent)]
pub struct StructTransparentGeneric<E>(E);

0 comments on commit 0a1c5bd

Please sign in to comment.