Skip to content

Commit 8c6e4e5

Browse files
committed
DRY out VisitEntities derive macros
1 parent a9cd2fc commit 8c6e4e5

File tree

1 file changed

+23
-76
lines changed
  • crates/bevy_ecs/macros/src

1 file changed

+23
-76
lines changed

crates/bevy_ecs/macros/src/lib.rs

Lines changed: 23 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use crate::{query_data::derive_query_data_impl, query_filter::derive_query_filte
1414
use bevy_macro_utils::{derive_label, ensure_no_collision, get_struct_fields, BevyManifest};
1515
use proc_macro::TokenStream;
1616
use proc_macro2::Span;
17+
use proc_macro2::TokenStream as TokenStream2;
1718
use quote::{format_ident, quote};
1819
use syn::{
1920
parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
@@ -180,8 +181,11 @@ pub fn derive_bundle(input: TokenStream) -> TokenStream {
180181
})
181182
}
182183

183-
#[proc_macro_derive(VisitEntitiesMut, attributes(visit_entities))]
184-
pub fn derive_visit_entities_mut(input: TokenStream) -> TokenStream {
184+
fn derive_visit_entities_base(
185+
input: TokenStream,
186+
trait_name: TokenStream2,
187+
gen_methods: impl FnOnce(Vec<TokenStream2>) -> TokenStream2,
188+
) -> TokenStream {
185189
let ast = parse_macro_input!(input as DeriveInput);
186190
let ecs_path = bevy_ecs_path();
187191

@@ -223,7 +227,7 @@ pub fn derive_visit_entities_mut(input: TokenStream) -> TokenStream {
223227
if field.is_empty() {
224228
return syn::Error::new(
225229
ast.span(),
226-
"Invalid `VisitEntitiesMut` type: at least one field",
230+
format!("Invalid `{}` type: at least one field", trait_name),
227231
)
228232
.into_compile_error()
229233
.into();
@@ -246,93 +250,36 @@ pub fn derive_visit_entities_mut(input: TokenStream) -> TokenStream {
246250
})
247251
.collect::<Vec<_>>();
248252

253+
let methods = gen_methods(field_access);
254+
249255
let generics = ast.generics;
250256
let (impl_generics, ty_generics, _) = generics.split_for_impl();
251257
let struct_name = &ast.ident;
252258

253259
TokenStream::from(quote! {
254-
impl #impl_generics #ecs_path::entity::VisitEntitiesMut for #struct_name #ty_generics {
260+
impl #impl_generics #ecs_path::entity:: #trait_name for #struct_name #ty_generics {
261+
#methods
262+
}
263+
})
264+
}
265+
266+
#[proc_macro_derive(VisitEntitiesMut, attributes(visit_entities))]
267+
pub fn derive_visit_entities_mut(input: TokenStream) -> TokenStream {
268+
derive_visit_entities_base(input, quote! { VisitEntitiesMut }, |field| {
269+
quote! {
255270
fn visit_entities_mut<F: FnMut(&mut Entity)>(&mut self, mut f: F) {
256-
#(#field_access.visit_entities_mut(&mut f);)*
271+
#(#field.visit_entities_mut(&mut f);)*
257272
}
258273
}
259274
})
260275
}
261276

262277
#[proc_macro_derive(VisitEntities, attributes(visit_entities))]
263278
pub fn derive_visit_entities(input: TokenStream) -> TokenStream {
264-
let ast = parse_macro_input!(input as DeriveInput);
265-
let ecs_path = bevy_ecs_path();
266-
267-
let named_fields = match get_struct_fields(&ast.data) {
268-
Ok(fields) => fields,
269-
Err(e) => return e.into_compile_error().into(),
270-
};
271-
272-
let field = named_fields
273-
.iter()
274-
.filter_map(|field| {
275-
if let Some(attr) = field
276-
.attrs
277-
.iter()
278-
.find(|a| a.path().is_ident("visit_entities"))
279-
{
280-
let ignore = attr.parse_nested_meta(|meta| {
281-
if meta.path.is_ident("ignore") {
282-
Ok(())
283-
} else {
284-
Err(meta.error("Invalid visit_entities attribute. Use `ignore`"))
285-
}
286-
});
287-
return match ignore {
288-
Ok(()) => None,
289-
Err(e) => Some(Err(e)),
290-
};
291-
}
292-
Some(Ok(field))
293-
})
294-
.map(|res| res.map(|field| field.ident.as_ref()))
295-
.collect::<Result<Vec<_>, _>>();
296-
297-
let field = match field {
298-
Ok(field) => field,
299-
Err(e) => return e.into_compile_error().into(),
300-
};
301-
302-
if field.is_empty() {
303-
return syn::Error::new(
304-
ast.span(),
305-
"Invalid `VisitEntities` type: at least one field",
306-
)
307-
.into_compile_error()
308-
.into();
309-
}
310-
311-
let field_access = field
312-
.iter()
313-
.enumerate()
314-
.map(|(n, f)| {
315-
if let Some(ident) = f {
316-
quote! {
317-
self.#ident
318-
}
319-
} else {
320-
let idx = Index::from(n);
321-
quote! {
322-
self.#idx
323-
}
324-
}
325-
})
326-
.collect::<Vec<_>>();
327-
328-
let generics = ast.generics;
329-
let (impl_generics, ty_generics, _) = generics.split_for_impl();
330-
let struct_name = &ast.ident;
331-
332-
TokenStream::from(quote! {
333-
impl #impl_generics #ecs_path::entity::VisitEntities for #struct_name #ty_generics {
279+
derive_visit_entities_base(input, quote! { VisitEntities }, |field| {
280+
quote! {
334281
fn visit_entities<F: FnMut(Entity)>(&self, mut f: F) {
335-
#(#field_access.visit_entities(&mut f);)*
282+
#(#field.visit_entities(&mut f);)*
336283
}
337284
}
338285
})

0 commit comments

Comments
 (0)