Skip to content

Commit 4ab0022

Browse files
committed
Add support for defaulted methods
1 parent 6a5e7ab commit 4ab0022

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

Diff for: trait-variant/examples/variant.rs

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ pub trait LocalIntFactory {
2020
fn stream(&self) -> impl Iterator<Item = i32>;
2121
fn call(&self) -> u32;
2222
fn another_async(&self, input: Result<(), &str>) -> Self::MyFut<'_>;
23+
async fn defaulted(&self) -> i32 {
24+
self.make(10, "10").await
25+
}
2326
}
2427

2528
#[allow(dead_code)]

Diff for: trait-variant/src/variant.rs

+44-4
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
122122
let TraitItem::Fn(fn_item @ TraitItemFn { sig, .. }) = item else {
123123
return item.clone();
124124
};
125-
let (arrow, output) = if sig.asyncness.is_some() {
125+
let (arrow, output, generics, default) = if sig.asyncness.is_some() {
126126
let orig = match &sig.output {
127127
ReturnType::Default => quote! { () },
128128
ReturnType::Type(_, ty) => quote! { #ty },
@@ -134,7 +134,24 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
134134
.chain(bounds.iter().cloned())
135135
.collect(),
136136
});
137-
(syn::parse2(quote! { -> }).unwrap(), ty)
137+
let mut generics = fn_item.sig.generics.clone();
138+
if fn_item.default.is_some() {
139+
if let Some(wh) = &mut generics.where_clause {
140+
wh.predicates
141+
.push(syn::parse2(quote! { Self: Sync }).unwrap());
142+
} else {
143+
generics.where_clause = Some(syn::parse2(quote! { where Self: Sync }).unwrap())
144+
}
145+
}
146+
(
147+
syn::parse2(quote! { -> }).unwrap(),
148+
ty,
149+
generics,
150+
fn_item
151+
.default
152+
.as_ref()
153+
.map(|b| syn::parse2(quote! { { async move #b } }).unwrap()),
154+
)
138155
} else {
139156
match &sig.output {
140157
ReturnType::Type(arrow, ty) => match &**ty {
@@ -143,7 +160,12 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
143160
impl_token: it.impl_token,
144161
bounds: it.bounds.iter().chain(bounds).cloned().collect(),
145162
});
146-
(*arrow, ty)
163+
(
164+
*arrow,
165+
ty,
166+
fn_item.sig.generics.clone(),
167+
fn_item.default.clone(),
168+
)
147169
}
148170
_ => return item.clone(),
149171
},
@@ -154,8 +176,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
154176
sig: Signature {
155177
asyncness: None,
156178
output: ReturnType::Type(arrow, Box::new(output)),
179+
generics,
157180
..sig.clone()
158181
},
182+
default,
159183
..fn_item.clone()
160184
})
161185
}
@@ -164,8 +188,24 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
164188
let orig = &tr.ident;
165189
let variant = &attrs.variant.name;
166190
let items = tr.items.iter().map(|item| blanket_impl_item(item, variant));
191+
// if there is a defaulted method, than that defaulted method has a bound of Self: Sync,
192+
// which means the blanket impl must also require T: Sync
193+
let self_is_sync = tr
194+
.items
195+
.iter()
196+
.any(|item| {
197+
matches!(
198+
item,
199+
TraitItem::Fn(TraitItemFn {
200+
default: Some(_),
201+
..
202+
})
203+
)
204+
})
205+
.then(|| quote! { T: Sync })
206+
.unwrap_or_default();
167207
quote! {
168-
impl<T> #orig for T where T: #variant {
208+
impl<T> #orig for T where T: #variant, #self_is_sync {
169209
#(#items)*
170210
}
171211
}

0 commit comments

Comments
 (0)