Skip to content

Commit d279a3c

Browse files
committed
Add support for defaulted methods
1 parent 6a5e7ab commit d279a3c

File tree

2 files changed

+104
-14
lines changed

2 files changed

+104
-14
lines changed

Diff for: trait-variant/examples/variant.rs

+16
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,25 @@ pub trait LocalIntFactory {
1717
Self: 'a;
1818

1919
async fn make(&self, x: u32, y: &str) -> i32;
20+
async fn make_mut(&mut self);
2021
fn stream(&self) -> impl Iterator<Item = i32>;
2122
fn call(&self) -> u32;
2223
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
24+
async fn defaulted(&self) -> i32 {
25+
self.make(10, "10").await
26+
}
27+
async fn defaulted_mut(&mut self) -> i32 {
28+
self.make(10, "10").await
29+
}
30+
async fn defaulted_mut_2(&mut self) {
31+
self.make_mut().await
32+
}
33+
async fn defaulted_move(self) -> i32
34+
where
35+
Self: Sized,
36+
{
37+
self.make(10, "10").await
38+
}
2339
}
2440

2541
#[allow(dead_code)]

Diff for: trait-variant/src/variant.rs

+88-14
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@
88

99
use std::iter;
1010

11-
use proc_macro2::TokenStream;
11+
use proc_macro2::{Span, TokenStream};
1212
use quote::quote;
1313
use syn::{
1414
parse::{Parse, ParseStream},
1515
parse_macro_input,
1616
punctuated::Punctuated,
1717
token::Plus,
18-
Error, FnArg, Generics, Ident, ItemTrait, Pat, PatType, Result, ReturnType, Signature, Token,
19-
TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type, TypeImplTrait,
20-
TypeParamBound,
18+
Error, FnArg, Generics, Ident, ItemTrait, Pat, PatType, Receiver, Result, ReturnType,
19+
Signature, Token, TraitBound, TraitItem, TraitItemConst, TraitItemFn, TraitItemType, Type,
20+
TypeImplTrait, TypeParamBound, WhereClause,
2121
};
2222

2323
struct Attrs {
@@ -119,10 +119,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
119119
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
120120
// fn call(&self) -> u32;
121121
// }
122-
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
122+
let TraitItem::Fn(fn_item @ TraitItemFn { sig, default, .. }) = item else {
123123
return item.clone();
124124
};
125-
let (arrow, output) = if sig.asyncness.is_some() {
125+
let (sig, default) = if sig.asyncness.is_some() {
126126
let orig = match &sig.output {
127127
ReturnType::Default => quote! { () },
128128
ReturnType::Type(_, ty) => quote! { #ty },
@@ -134,7 +134,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
134134
.chain(bounds.iter().cloned())
135135
.collect(),
136136
});
137-
(syn::parse2(quote! { -> }).unwrap(), ty)
137+
let mut sig = sig.clone();
138+
if default.is_some() {
139+
add_receiver_bounds(&mut sig);
140+
}
141+
142+
(
143+
Signature {
144+
asyncness: None,
145+
output: ReturnType::Type(syn::parse2(quote! { -> }).unwrap(), Box::new(ty)),
146+
..sig.clone()
147+
},
148+
fn_item
149+
.default
150+
.as_ref()
151+
.map(|b| syn::parse2(quote! { { async move #b } }).unwrap()),
152+
)
138153
} else {
139154
match &sig.output {
140155
ReturnType::Type(arrow, ty) => match &**ty {
@@ -143,19 +158,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
143158
impl_token: it.impl_token,
144159
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
145160
});
146-
(*arrow, ty)
161+
(
162+
Signature {
163+
output: ReturnType::Type(*arrow, Box::new(ty)),
164+
..sig.clone()
165+
},
166+
fn_item.default.clone(),
167+
)
147168
}
148169
_ => return item.clone(),
149170
},
150171
ReturnType::Default => return item.clone(),
151172
}
152173
};
153174
TraitItem::Fn(TraitItemFn {
154-
sig: Signature {
155-
asyncness: None,
156-
output: ReturnType::Type(arrow, Box::new(output)),
157-
..sig.clone()
158-
},
175+
sig,
176+
default,
159177
..fn_item.clone()
160178
})
161179
}
@@ -164,8 +182,26 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
164182
let orig = &tr.ident;
165183
let variant = &attrs.variant.name;
166184
let items = tr.items.iter().map(|item| blanket_impl_item(item, variant));
185+
let self_is_sync = tr
186+
.items
187+
.iter()
188+
.any(|item| {
189+
matches!(
190+
item,
191+
TraitItem::Fn(TraitItemFn {
192+
default: Some(_),
193+
..
194+
})
195+
)
196+
})
197+
.then(|| quote! { Self: Sync })
198+
.unwrap_or_default();
167199
quote! {
168-
impl<T> #orig for T where T: #variant {
200+
impl<T> #orig for T
201+
where
202+
T: #variant,
203+
#self_is_sync
204+
{
169205
#(#items)*
170206
}
171207
}
@@ -205,6 +241,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
205241
} else {
206242
quote! {}
207243
};
244+
208245
quote! {
209246
#sig {
210247
<Self as #variant>::#ident(#(#args),*)#maybe_await
@@ -228,3 +265,40 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
228265
_ => Error::new_spanned(item, "unsupported item type").into_compile_error(),
229266
}
230267
}
268+
269+
fn add_receiver_bounds(sig: &mut Signature) {
270+
if let Some(FnArg::Receiver(Receiver { ty, reference, .. })) = sig.inputs.first_mut() {
271+
let predicate =
272+
if let (Type::Reference(reference), Some((_and, lt))) = (&mut **ty, reference) {
273+
let lifetime = syn::Lifetime {
274+
apostrophe: Span::mixed_site(),
275+
ident: Ident::new("the_self_lt", Span::mixed_site()),
276+
};
277+
sig.generics.params.insert(
278+
0,
279+
syn::GenericParam::Lifetime(syn::LifetimeParam {
280+
lifetime: lifetime.clone(),
281+
colon_token: None,
282+
bounds: Default::default(),
283+
attrs: Default::default(),
284+
}),
285+
);
286+
reference.lifetime = Some(lifetime.clone());
287+
let predicate = syn::parse2(quote! { #reference: Send }).unwrap();
288+
*lt = Some(lifetime);
289+
predicate
290+
} else {
291+
syn::parse2(quote! { #ty: Send }).unwrap()
292+
};
293+
294+
if let Some(wh) = &mut sig.generics.where_clause {
295+
wh.predicates.push(predicate);
296+
} else {
297+
let where_clause = WhereClause {
298+
where_token: Token![where](Span::mixed_site()),
299+
predicates: Punctuated::from_iter([predicate]),
300+
};
301+
sig.generics.where_clause = Some(where_clause);
302+
}
303+
}
304+
}

0 commit comments

Comments
 (0)