8
8
9
9
use std:: iter;
10
10
11
- use proc_macro2:: TokenStream ;
11
+ use proc_macro2:: { Span , TokenStream } ;
12
12
use quote:: quote;
13
13
use syn:: {
14
14
parse:: { Parse , ParseStream } ,
15
15
parse_macro_input,
16
16
punctuated:: Punctuated ,
17
17
token:: Plus ,
18
- Error , FnArg , Generics , Ident , ItemTrait , Pat , PatType , Result , ReturnType , Signature , Token ,
19
- TraitBound , TraitItem , TraitItemConst , TraitItemFn , TraitItemType , Type , TypeImplTrait ,
20
- TypeParamBound ,
18
+ Error , FnArg , Generics , Ident , ItemTrait , Pat , PatType , Receiver , Result , ReturnType ,
19
+ Signature , Token , TraitBound , TraitItem , TraitItemConst , TraitItemFn , TraitItemType , Type ,
20
+ TypeImplTrait , TypeParamBound , WhereClause ,
21
21
} ;
22
22
23
23
struct Attrs {
@@ -119,10 +119,10 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
119
119
// fn stream(&self) -> impl Iterator<Item = i32> + Send;
120
120
// fn call(&self) -> u32;
121
121
// }
122
- let TraitItem :: Fn ( fn_item @ TraitItemFn { sig, .. } ) = item else {
122
+ let TraitItem :: Fn ( fn_item @ TraitItemFn { sig, default , .. } ) = item else {
123
123
return item. clone ( ) ;
124
124
} ;
125
- let ( arrow , output ) = if sig. asyncness . is_some ( ) {
125
+ let ( sig , default ) = if sig. asyncness . is_some ( ) {
126
126
let orig = match & sig. output {
127
127
ReturnType :: Default => quote ! { ( ) } ,
128
128
ReturnType :: Type ( _, ty) => quote ! { #ty } ,
@@ -134,7 +134,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
134
134
. chain ( bounds. iter ( ) . cloned ( ) )
135
135
. collect ( ) ,
136
136
} ) ;
137
- ( syn:: parse2 ( quote ! { -> } ) . unwrap ( ) , ty)
137
+ let mut sig = sig. clone ( ) ;
138
+ if default. is_some ( ) {
139
+ add_receiver_bounds ( & mut sig) ;
140
+ }
141
+
142
+ (
143
+ Signature {
144
+ asyncness : None ,
145
+ output : ReturnType :: Type ( syn:: parse2 ( quote ! { -> } ) . unwrap ( ) , Box :: new ( ty) ) ,
146
+ ..sig. clone ( )
147
+ } ,
148
+ fn_item
149
+ . default
150
+ . as_ref ( )
151
+ . map ( |b| syn:: parse2 ( quote ! { { async move #b } } ) . unwrap ( ) ) ,
152
+ )
138
153
} else {
139
154
match & sig. output {
140
155
ReturnType :: Type ( arrow, ty) => match & * * ty {
@@ -143,19 +158,22 @@ fn transform_item(item: &TraitItem, bounds: &Vec<TypeParamBound>) -> TraitItem {
143
158
impl_token : it. impl_token ,
144
159
bounds : it. bounds . iter ( ) . chain ( bounds) . cloned ( ) . collect ( ) ,
145
160
} ) ;
146
- ( * arrow, ty)
161
+ (
162
+ Signature {
163
+ output : ReturnType :: Type ( * arrow, Box :: new ( ty) ) ,
164
+ ..sig. clone ( )
165
+ } ,
166
+ fn_item. default . clone ( ) ,
167
+ )
147
168
}
148
169
_ => return item. clone ( ) ,
149
170
} ,
150
171
ReturnType :: Default => return item. clone ( ) ,
151
172
}
152
173
} ;
153
174
TraitItem :: Fn ( TraitItemFn {
154
- sig : Signature {
155
- asyncness : None ,
156
- output : ReturnType :: Type ( arrow, Box :: new ( output) ) ,
157
- ..sig. clone ( )
158
- } ,
175
+ sig,
176
+ default,
159
177
..fn_item. clone ( )
160
178
} )
161
179
}
@@ -164,8 +182,26 @@ fn mk_blanket_impl(attrs: &Attrs, tr: &ItemTrait) -> TokenStream {
164
182
let orig = & tr. ident ;
165
183
let variant = & attrs. variant . name ;
166
184
let items = tr. items . iter ( ) . map ( |item| blanket_impl_item ( item, variant) ) ;
185
+ let self_is_sync = tr
186
+ . items
187
+ . iter ( )
188
+ . any ( |item| {
189
+ matches ! (
190
+ item,
191
+ TraitItem :: Fn ( TraitItemFn {
192
+ default : Some ( _) ,
193
+ ..
194
+ } )
195
+ )
196
+ } )
197
+ . then ( || quote ! { Self : Sync } )
198
+ . unwrap_or_default ( ) ;
167
199
quote ! {
168
- impl <T > #orig for T where T : #variant {
200
+ impl <T > #orig for T
201
+ where
202
+ T : #variant,
203
+ #self_is_sync
204
+ {
169
205
#( #items) *
170
206
}
171
207
}
@@ -205,6 +241,7 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
205
241
} else {
206
242
quote ! { }
207
243
} ;
244
+
208
245
quote ! {
209
246
#sig {
210
247
<Self as #variant>:: #ident( #( #args) , * ) #maybe_await
@@ -228,3 +265,40 @@ fn blanket_impl_item(item: &TraitItem, variant: &Ident) -> TokenStream {
228
265
_ => Error :: new_spanned ( item, "unsupported item type" ) . into_compile_error ( ) ,
229
266
}
230
267
}
268
+
269
+ fn add_receiver_bounds ( sig : & mut Signature ) {
270
+ if let Some ( FnArg :: Receiver ( Receiver { ty, reference, .. } ) ) = sig. inputs . first_mut ( ) {
271
+ let predicate =
272
+ if let ( Type :: Reference ( reference) , Some ( ( _and, lt) ) ) = ( & mut * * ty, reference) {
273
+ let lifetime = syn:: Lifetime {
274
+ apostrophe : Span :: mixed_site ( ) ,
275
+ ident : Ident :: new ( "the_self_lt" , Span :: mixed_site ( ) ) ,
276
+ } ;
277
+ sig. generics . params . insert (
278
+ 0 ,
279
+ syn:: GenericParam :: Lifetime ( syn:: LifetimeParam {
280
+ lifetime : lifetime. clone ( ) ,
281
+ colon_token : None ,
282
+ bounds : Default :: default ( ) ,
283
+ attrs : Default :: default ( ) ,
284
+ } ) ,
285
+ ) ;
286
+ reference. lifetime = Some ( lifetime. clone ( ) ) ;
287
+ let predicate = syn:: parse2 ( quote ! { #reference: Send } ) . unwrap ( ) ;
288
+ * lt = Some ( lifetime) ;
289
+ predicate
290
+ } else {
291
+ syn:: parse2 ( quote ! { #ty: Send } ) . unwrap ( )
292
+ } ;
293
+
294
+ if let Some ( wh) = & mut sig. generics . where_clause {
295
+ wh. predicates . push ( predicate) ;
296
+ } else {
297
+ let where_clause = WhereClause {
298
+ where_token : Token ! [ where ] ( Span :: mixed_site ( ) ) ,
299
+ predicates : Punctuated :: from_iter ( [ predicate] ) ,
300
+ } ;
301
+ sig. generics . where_clause = Some ( where_clause) ;
302
+ }
303
+ }
304
+ }
0 commit comments