From 6e959db134c7bf9902f486d45136aec05bff3025 Mon Sep 17 00:00:00 2001 From: Gino Valente <49806985+MrGVSV@users.noreply.github.com> Date: Sun, 28 Jan 2024 09:24:03 -0700 Subject: [PATCH] bevy_reflect: Type parameter bounds (#9046) # Objective Fixes #8965. #### Background For convenience and to ensure everything is setup properly, we automatically add certain bounds to the derived types. The current implementation does this by taking the types from all active fields and adding them to the where-clause of the generated impls. I believe this method was chosen because it won't add bounds to types that are otherwise ignored. ```rust #[derive(Reflect)] struct Foo { t: T, u: U::Assoc, #[reflect(ignore)] v: [V; 2] } // Generates something like: impl for Foo where // Active: T: Reflect, U::Assoc: Reflect, // Ignored: [V; 2]: Send + Sync + Any { // ... } ``` The self-referential type fails because it ends up using _itself_ as a type bound due to being one of its own active fields. ```rust #[derive(Reflect)] struct Foo { foo: Vec } // Foo where Vec: Reflect -> Vec where T: Reflect -> Foo where Vec: Reflect -> ... ``` ## Solution We can't simply parse all field types for the name of our type. That would be both complex and prone to errors and false-positives. And even if it wasn't, what would we replace the bound with? Instead, I opted to go for a solution that only adds the bounds to what really needs it: the type parameters. While the bounds on concrete types make errors a bit cleaner, they aren't strictly necessary. This means we can change our generated where-clause to only add bounds to generic type parameters. Doing this, though, returns us back to the problem of over-bounding parameters that don't need to be bounded. To solve this, I added a new container attribute (based on [this](https://github.com/dtolnay/syn/issues/422#issuecomment-406882925) comment and @nicopap's [comment](https://github.com/bevyengine/bevy/pull/9046#issuecomment-1623593780)) that allows us to pass in a custom where clause to modify what bounds are added to these type parameters. This allows us to do stuff like: ```rust trait Trait { type Assoc; } // We don't need `T` to be reflectable since we only care about `T::Assoc`. #[derive(Reflect)] #[reflect(where T::Assoc: FromReflect)] struct Foo(T::Assoc); #[derive(TypePath)] struct Bar; impl Trait for Bar { type Assoc = usize; } #[derive(Reflect)] struct Baz { a: Foo, } ``` > **Note** > I also [tried](https://github.com/bevyengine/bevy/commit/dc139ea34c4737da3ea7ab7ea2e8701462693d02) allowing `#[reflect(ignore)]` to be used on the type parameters themselves, but that proved problematic since the derive macro does not consume the attribute. This is why I went with the container attribute approach. ### Alternatives One alternative could possibly be to just not add reflection bounds automatically (i.e. only add required bounds like `Send`, `Sync`, `Any`, and `TypePath`). The downside here is we add more friction to using reflection, which already comes with its own set of considerations. This is a potentially viable option, but we really need to consider whether or not the ergonomics hit is worth it. If we did decide to go the more manual route, we should at least consider something like #5772 to make it easier for users to add the right bounds (although, this could still become tricky with `FromReflect` also being automatically derived). ### Open Questions 1. Should we go with this approach or the manual alternative? 2. ~~Should we add a `skip_params` attribute to avoid the `T: 'static` trick?~~ ~~Decided to go with `custom_where()` as it's the simplest~~ Scratch that, went with a normal where clause 3. ~~`custom_where` bikeshedding?~~ No longer needed since we are using a normal where clause ### TODO - [x] Add compile-fail tests --- ## Changelog - Fixed issue preventing recursive types from deriving `Reflect` - Changed how where-clause bounds are generated by the `Reflect` derive macro - They are now only applied to the type parameters, not to all active fields - Added `#[reflect(where T: Trait, U::Assoc: Trait, ...)]` container attribute ## Migration Guide When deriving `Reflect`, generic type params that do not need the automatic reflection bounds (such as `Reflect`) applied to them will need to opt-out using a custom where clause like: `#[reflect(where T: Trait, U::Assoc: Trait, ...)]`. The attribute can define custom bounds only used by the reflection impls. To simply opt-out all the type params, we can pass in an empty where clause: `#[reflect(where)]`. ```rust // BEFORE: #[derive(Reflect)] struct Foo(#[reflect(ignore)] T); // AFTER: #[derive(Reflect)] #[reflect(where)] struct Foo(#[reflect(ignore)] T); ``` --------- Co-authored-by: Nicola Papale --- crates/bevy_asset/src/handle.rs | 1 + crates/bevy_asset/src/id.rs | 1 + .../src/container_attributes.rs | 105 ++++-- .../bevy_reflect_derive/src/derive_data.rs | 31 +- .../bevy_reflect_derive/src/from_reflect.rs | 46 +-- .../bevy_reflect_derive/src/impls/enums.rs | 3 +- .../bevy_reflect_derive/src/impls/structs.rs | 4 +- .../src/impls/tuple_structs.rs | 3 +- .../bevy_reflect_derive/src/impls/typed.rs | 10 +- .../bevy_reflect_derive/src/impls/values.rs | 6 +- .../bevy_reflect_derive/src/lib.rs | 51 +++ .../bevy_reflect_derive/src/registration.rs | 4 +- .../bevy_reflect_derive/src/utility.rs | 303 ++++++++---------- crates/bevy_reflect/src/lib.rs | 91 ++++++ .../tests/reflect_derive/custom_where.fail.rs | 31 ++ .../reflect_derive/custom_where.fail.stderr | 11 + .../tests/reflect_derive/custom_where.pass.rs | 31 ++ 17 files changed, 447 insertions(+), 285 deletions(-) create mode 100644 crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.rs create mode 100644 crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.stderr create mode 100644 crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.pass.rs diff --git a/crates/bevy_asset/src/handle.rs b/crates/bevy_asset/src/handle.rs index d006d85ed6f78..fa9f3ee2dc1dd 100644 --- a/crates/bevy_asset/src/handle.rs +++ b/crates/bevy_asset/src/handle.rs @@ -122,6 +122,7 @@ impl std::fmt::Debug for StrongHandle { /// [`Handle::Strong`] also provides access to useful [`Asset`] metadata, such as the [`AssetPath`] (if it exists). #[derive(Component, Reflect)] #[reflect(Component)] +#[reflect(where A: Asset)] pub enum Handle { /// A "strong" reference to a live (or loading) [`Asset`]. If a [`Handle`] is [`Handle::Strong`], the [`Asset`] will be kept /// alive until the [`Handle`] is dropped. Strong handles also provide access to additional asset metadata. diff --git a/crates/bevy_asset/src/id.rs b/crates/bevy_asset/src/id.rs index 1efc61bacc1af..a11fcad7746ff 100644 --- a/crates/bevy_asset/src/id.rs +++ b/crates/bevy_asset/src/id.rs @@ -17,6 +17,7 @@ use thiserror::Error; /// /// For an "untyped" / "generic-less" id, see [`UntypedAssetId`]. #[derive(Reflect)] +#[reflect(where A: Asset)] pub enum AssetId { /// A small / efficient runtime identifier that can be used to efficiently look up an asset stored in [`Assets`]. This is /// the "default" identifier used for assets. The alternative(s) (ex: [`AssetId::Uuid`]) will only be used if assets are diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs b/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs index ee388731d7497..b8dc979cb7b04 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/container_attributes.rs @@ -7,13 +7,13 @@ use crate::utility; use bevy_macro_utils::fq_std::{FQAny, FQOption}; -use proc_macro2::{Ident, Span}; +use proc_macro2::{Ident, Span, TokenTree}; use quote::quote_spanned; use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::token::Comma; -use syn::{Expr, LitBool, Meta, Path}; +use syn::{Expr, LitBool, Meta, MetaList, Path, WhereClause}; // The "special" trait idents that are used internally for reflection. // Received via attributes like `#[reflect(PartialEq, Hash, ...)]` @@ -211,24 +211,40 @@ pub(crate) struct ReflectTraits { partial_eq: TraitImpl, from_reflect_attrs: FromReflectAttrs, type_path_attrs: TypePathAttrs, + custom_where: Option, idents: Vec, } impl ReflectTraits { - pub fn from_metas( + pub fn from_meta_list( + meta: &MetaList, + is_from_reflect_derive: bool, + ) -> Result { + match meta.tokens.clone().into_iter().next() { + // Handles `#[reflect(where T: Trait, U::Assoc: Trait)]` + Some(TokenTree::Ident(ident)) if ident == "where" => Ok(Self { + custom_where: Some(meta.parse_args::()?), + ..Self::default() + }), + _ => Self::from_metas( + meta.parse_args_with(Punctuated::::parse_terminated)?, + is_from_reflect_derive, + ), + } + } + + fn from_metas( metas: Punctuated, is_from_reflect_derive: bool, ) -> Result { let mut traits = ReflectTraits::default(); for meta in &metas { match meta { - // Handles `#[reflect( Hash, Default, ... )]` + // Handles `#[reflect( Debug, PartialEq, Hash, SomeTrait )]` Meta::Path(path) => { - // Get the first ident in the path (hopefully the path only contains one and not `std::hash::Hash`) - let Some(segment) = path.segments.iter().next() else { + let Some(ident) = path.get_ident() else { continue; }; - let ident = &segment.ident; let ident_name = ident.to_string(); // Track the span where the trait is implemented for future errors @@ -255,38 +271,38 @@ impl ReflectTraits { } } } + // Handles `#[reflect( Debug(custom_debug_fn) )]` + Meta::List(list) if list.path.is_ident(DEBUG_ATTR) => { + let ident = list.path.get_ident().unwrap(); + list.parse_nested_meta(|meta| { + let trait_func_ident = TraitImpl::Custom(meta.path, ident.span()); + traits.debug.merge(trait_func_ident) + })?; + } + // Handles `#[reflect( PartialEq(custom_partial_eq_fn) )]` + Meta::List(list) if list.path.is_ident(PARTIAL_EQ_ATTR) => { + let ident = list.path.get_ident().unwrap(); + list.parse_nested_meta(|meta| { + let trait_func_ident = TraitImpl::Custom(meta.path, ident.span()); + traits.partial_eq.merge(trait_func_ident) + })?; + } // Handles `#[reflect( Hash(custom_hash_fn) )]` - Meta::List(list) => { - // Get the first ident in the path (hopefully the path only contains one and not `std::hash::Hash`) - let Some(segment) = list.path.segments.iter().next() else { - continue; - }; - - let ident = segment.ident.to_string(); - - // Track the span where the trait is implemented for future errors - let span = ident.span(); - + Meta::List(list) if list.path.is_ident(HASH_ATTR) => { + let ident = list.path.get_ident().unwrap(); list.parse_nested_meta(|meta| { - // This should be the path of the custom function - let trait_func_ident = TraitImpl::Custom(meta.path, span); - match ident.as_str() { - DEBUG_ATTR => { - traits.debug.merge(trait_func_ident)?; - } - PARTIAL_EQ_ATTR => { - traits.partial_eq.merge(trait_func_ident)?; - } - HASH_ATTR => { - traits.hash.merge(trait_func_ident)?; - } - _ => { - return Err(syn::Error::new(span, "Can only use custom functions for special traits (i.e. `Hash`, `PartialEq`, `Debug`)")); - } - } - Ok(()) + let trait_func_ident = TraitImpl::Custom(meta.path, ident.span()); + traits.hash.merge(trait_func_ident) })?; } + Meta::List(list) => { + return Err(syn::Error::new_spanned( + list, + format!( + "expected one of [{DEBUG_ATTR:?}, {PARTIAL_EQ_ATTR:?}, {HASH_ATTR:?}]" + ), + )); + } Meta::NameValue(pair) => { if pair.path.is_ident(FROM_REFLECT_ATTR) { traits.from_reflect_attrs.auto_derive = @@ -402,6 +418,10 @@ impl ReflectTraits { } } + pub fn custom_where(&self) -> Option<&WhereClause> { + self.custom_where.as_ref() + } + /// Merges the trait implementations of this [`ReflectTraits`] with another one. /// /// An error is returned if the two [`ReflectTraits`] have conflicting implementations. @@ -411,11 +431,26 @@ impl ReflectTraits { self.partial_eq.merge(other.partial_eq)?; self.from_reflect_attrs.merge(other.from_reflect_attrs)?; self.type_path_attrs.merge(other.type_path_attrs)?; + + self.merge_custom_where(other.custom_where); + for ident in other.idents { add_unique_ident(&mut self.idents, ident)?; } Ok(()) } + + fn merge_custom_where(&mut self, other: Option) { + match (&mut self.custom_where, other) { + (Some(this), Some(other)) => { + this.predicates.extend(other.predicates); + } + (None, Some(other)) => { + self.custom_where = Some(other); + } + _ => {} + } + } } impl Parse for ReflectTraits { diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs b/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs index ce8777bbc041d..cfdf641714abf 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/derive_data.rs @@ -167,10 +167,8 @@ impl<'a> ReflectDerive<'a> { } reflect_mode = Some(ReflectMode::Normal); - let new_traits = ReflectTraits::from_metas( - meta_list.parse_args_with(Punctuated::::parse_terminated)?, - is_from_reflect_derive, - )?; + let new_traits = + ReflectTraits::from_meta_list(meta_list, is_from_reflect_derive)?; traits.merge(new_traits)?; } Meta::List(meta_list) if meta_list.path.is_ident(REFLECT_VALUE_ATTRIBUTE_NAME) => { @@ -182,10 +180,8 @@ impl<'a> ReflectDerive<'a> { } reflect_mode = Some(ReflectMode::Value); - let new_traits = ReflectTraits::from_metas( - meta_list.parse_args_with(Punctuated::::parse_terminated)?, - is_from_reflect_derive, - )?; + let new_traits = + ReflectTraits::from_meta_list(meta_list, is_from_reflect_derive)?; traits.merge(new_traits)?; } Meta::Path(path) if path.is_ident(REFLECT_VALUE_ATTRIBUTE_NAME) => { @@ -484,7 +480,7 @@ impl<'a> ReflectStruct<'a> { } pub fn where_clause_options(&self) -> WhereClauseOptions { - WhereClauseOptions::new(self.meta(), self.active_fields(), self.ignored_fields()) + WhereClauseOptions::new(self.meta()) } } @@ -507,22 +503,8 @@ impl<'a> ReflectEnum<'a> { &self.variants } - /// Get an iterator of fields which are exposed to the reflection API - pub fn active_fields(&self) -> impl Iterator> { - self.variants() - .iter() - .flat_map(|variant| variant.active_fields()) - } - - /// Get an iterator of fields which are ignored by the reflection API - pub fn ignored_fields(&self) -> impl Iterator> { - self.variants() - .iter() - .flat_map(|variant| variant.ignored_fields()) - } - pub fn where_clause_options(&self) -> WhereClauseOptions { - WhereClauseOptions::new(self.meta(), self.active_fields(), self.ignored_fields()) + WhereClauseOptions::new(self.meta()) } } @@ -668,6 +650,7 @@ impl<'a> ReflectTypePath<'a> { where_clause: None, params: Punctuated::new(), }; + match self { Self::Internal { generics, .. } | Self::External { generics, .. } => generics, _ => EMPTY_GENERICS, diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs b/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs index bca7162de8b2d..c55267d3741b9 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/from_reflect.rs @@ -2,7 +2,7 @@ use crate::container_attributes::REFLECT_DEFAULT; use crate::derive_data::ReflectEnum; use crate::enum_utility::{get_variant_constructors, EnumVariantConstructors}; use crate::field_attributes::DefaultBehavior; -use crate::utility::{extend_where_clause, ident_or_index, WhereClauseOptions}; +use crate::utility::{ident_or_index, WhereClauseOptions}; use crate::{ReflectMeta, ReflectStruct}; use bevy_macro_utils::fq_std::{FQAny, FQClone, FQDefault, FQOption}; use proc_macro2::Span; @@ -24,7 +24,7 @@ pub(crate) fn impl_value(meta: &ReflectMeta) -> proc_macro2::TokenStream { let bevy_reflect_path = meta.bevy_reflect_path(); let (impl_generics, ty_generics, where_clause) = type_path.generics().split_for_impl(); let where_from_reflect_clause = - extend_where_clause(where_clause, &WhereClauseOptions::new_value(meta)); + WhereClauseOptions::new_type_path(meta).extend_where_clause(where_clause); quote! { impl #impl_generics #bevy_reflect_path::FromReflect for #type_path #ty_generics #where_from_reflect_clause { fn from_reflect(reflect: &dyn #bevy_reflect_path::Reflect) -> #FQOption { @@ -50,22 +50,8 @@ pub(crate) fn impl_enum(reflect_enum: &ReflectEnum) -> proc_macro2::TokenStream let (impl_generics, ty_generics, where_clause) = enum_path.generics().split_for_impl(); // Add FromReflect bound for each active field - let where_from_reflect_clause = extend_where_clause( - where_clause, - &WhereClauseOptions::new_with_bounds( - reflect_enum.meta(), - reflect_enum.active_fields(), - reflect_enum.ignored_fields(), - |field| match &field.attrs.default { - DefaultBehavior::Default => Some(quote!(#FQDefault)), - _ => None, - }, - |field| match &field.attrs.default { - DefaultBehavior::Func(_) => None, - _ => Some(quote!(#FQDefault)), - }, - ), - ); + let where_from_reflect_clause = + WhereClauseOptions::new(reflect_enum.meta()).extend_where_clause(where_clause); quote! { impl #impl_generics #bevy_reflect_path::FromReflect for #enum_path #ty_generics #where_from_reflect_clause { @@ -144,28 +130,8 @@ fn impl_struct_internal( .split_for_impl(); // Add FromReflect bound for each active field - let where_from_reflect_clause = extend_where_clause( - where_clause, - &WhereClauseOptions::new_with_bounds( - reflect_struct.meta(), - reflect_struct.active_fields(), - reflect_struct.ignored_fields(), - |field| match &field.attrs.default { - DefaultBehavior::Default => Some(quote!(#FQDefault)), - _ => None, - }, - |field| { - if is_defaultable { - None - } else { - match &field.attrs.default { - DefaultBehavior::Func(_) => None, - _ => Some(quote!(#FQDefault)), - } - } - }, - ), - ); + let where_from_reflect_clause = + WhereClauseOptions::new(reflect_struct.meta()).extend_where_clause(where_clause); quote! { impl #impl_generics #bevy_reflect_path::FromReflect for #struct_path #ty_generics #where_from_reflect_clause { diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/impls/enums.rs b/crates/bevy_reflect/bevy_reflect_derive/src/impls/enums.rs index 2d2bebfc9c08d..b9cd484267e5d 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/impls/enums.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/impls/enums.rs @@ -1,7 +1,6 @@ use crate::derive_data::{EnumVariant, EnumVariantFields, ReflectEnum, StructField}; use crate::enum_utility::{get_variant_constructors, EnumVariantConstructors}; use crate::impls::{impl_type_path, impl_typed}; -use crate::utility::extend_where_clause; use bevy_macro_utils::fq_std::{FQAny, FQBox, FQOption, FQResult}; use proc_macro2::{Ident, Span}; use quote::quote; @@ -92,7 +91,7 @@ pub(crate) fn impl_enum(reflect_enum: &ReflectEnum) -> proc_macro2::TokenStream let (impl_generics, ty_generics, where_clause) = reflect_enum.meta().type_path().generics().split_for_impl(); - let where_reflect_clause = extend_where_clause(where_clause, &where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); quote! { #get_type_registration_impl diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/impls/structs.rs b/crates/bevy_reflect/bevy_reflect_derive/src/impls/structs.rs index 9aef44d3505a8..90c51f36232db 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/impls/structs.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/impls/structs.rs @@ -1,5 +1,5 @@ use crate::impls::{impl_type_path, impl_typed}; -use crate::utility::{extend_where_clause, ident_or_index}; +use crate::utility::ident_or_index; use crate::ReflectStruct; use bevy_macro_utils::fq_std::{FQAny, FQBox, FQDefault, FQOption, FQResult}; use quote::{quote, ToTokens}; @@ -99,7 +99,7 @@ pub(crate) fn impl_struct(reflect_struct: &ReflectStruct) -> proc_macro2::TokenS .generics() .split_for_impl(); - let where_reflect_clause = extend_where_clause(where_clause, &where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); quote! { #get_type_registration_impl diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/impls/tuple_structs.rs b/crates/bevy_reflect/bevy_reflect_derive/src/impls/tuple_structs.rs index 14af4851fd2e9..b8a17100d04bc 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/impls/tuple_structs.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/impls/tuple_structs.rs @@ -1,5 +1,4 @@ use crate::impls::{impl_type_path, impl_typed}; -use crate::utility::extend_where_clause; use crate::ReflectStruct; use bevy_macro_utils::fq_std::{FQAny, FQBox, FQDefault, FQOption, FQResult}; use quote::{quote, ToTokens}; @@ -90,7 +89,7 @@ pub(crate) fn impl_tuple_struct(reflect_struct: &ReflectStruct) -> proc_macro2:: .generics() .split_for_impl(); - let where_reflect_clause = extend_where_clause(where_clause, &where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); quote! { #get_type_registration_impl diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/impls/typed.rs b/crates/bevy_reflect/bevy_reflect_derive/src/impls/typed.rs index 46edd1895c3c5..294a8cce83945 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/impls/typed.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/impls/typed.rs @@ -1,4 +1,4 @@ -use crate::utility::{extend_where_clause, StringExpr, WhereClauseOptions}; +use crate::utility::{StringExpr, WhereClauseOptions}; use quote::{quote, ToTokens}; use crate::{ @@ -49,9 +49,7 @@ pub(crate) enum TypedProperty { } pub(crate) fn impl_type_path(meta: &ReflectMeta) -> proc_macro2::TokenStream { - // Use `WhereClauseOptions::new_value` here so we don't enforce reflection bounds, - // ensuring the impl applies in the most cases possible. - let where_clause_options = &WhereClauseOptions::new_value(meta); + let where_clause_options = WhereClauseOptions::new_type_path(meta); if !meta.traits().type_path_attrs().should_auto_derive() { return proc_macro2::TokenStream::new(); @@ -102,7 +100,7 @@ pub(crate) fn impl_type_path(meta: &ReflectMeta) -> proc_macro2::TokenStream { let (impl_generics, ty_generics, where_clause) = type_path.generics().split_for_impl(); // Add Typed bound for each active field - let where_reflect_clause = extend_where_clause(where_clause, where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); quote! { #primitive_assert @@ -143,7 +141,7 @@ pub(crate) fn impl_typed( let (impl_generics, ty_generics, where_clause) = type_path.generics().split_for_impl(); - let where_reflect_clause = extend_where_clause(where_clause, where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); quote! { impl #impl_generics #bevy_reflect_path::Typed for #type_path #ty_generics #where_reflect_clause { diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/impls/values.rs b/crates/bevy_reflect/bevy_reflect_derive/src/impls/values.rs index 17e0838d799d3..db1082236f0e2 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/impls/values.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/impls/values.rs @@ -1,5 +1,5 @@ use crate::impls::{impl_type_path, impl_typed}; -use crate::utility::{extend_where_clause, WhereClauseOptions}; +use crate::utility::WhereClauseOptions; use crate::ReflectMeta; use bevy_macro_utils::fq_std::{FQAny, FQBox, FQClone, FQOption, FQResult}; use quote::quote; @@ -21,7 +21,7 @@ pub(crate) fn impl_value(meta: &ReflectMeta) -> proc_macro2::TokenStream { #[cfg(not(feature = "documentation"))] let with_docs: Option = None; - let where_clause_options = WhereClauseOptions::new_value(meta); + let where_clause_options = WhereClauseOptions::new_type_path(meta); let typed_impl = impl_typed( meta, &where_clause_options, @@ -34,7 +34,7 @@ pub(crate) fn impl_value(meta: &ReflectMeta) -> proc_macro2::TokenStream { let type_path_impl = impl_type_path(meta); let (impl_generics, ty_generics, where_clause) = type_path.generics().split_for_impl(); - let where_reflect_clause = extend_where_clause(where_clause, &where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); let get_type_registration_impl = meta.get_type_registration(&where_clause_options); quote! { diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs index 79503785e176f..55d1fc845caa5 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/lib.rs @@ -131,6 +131,53 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name"; /// This is useful for when a type can't or shouldn't implement `TypePath`, /// or if a manual implementation is desired. /// +/// ## `#[reflect(where T: Trait, U::Assoc: Trait, ...)]` +/// +/// By default, the derive macro will automatically add certain trait bounds to all generic type parameters +/// in order to make them compatible with reflection without the user needing to add them manually. +/// This includes traits like `Reflect` and `FromReflect`. +/// However, this may not always be desired, and some type paramaters can't or shouldn't require those bounds +/// (i.e. their usages in fields are ignored or they're only used for their associated types). +/// +/// With this attribute, you can specify a custom `where` clause to be used instead of the default. +/// If this attribute is present, none of the type parameters will receive the default bounds. +/// Only the bounds specified by the type itself and by this attribute will be used. +/// The only exceptions to this are the `Any`, `Send`, `Sync`, and `TypePath` bounds, +/// which will always be added regardless of this attribute due to their necessity for reflection +/// in general. +/// +/// This means that if you want to opt-out of the default bounds for _all_ type parameters, +/// you can add `#[reflect(where)]` to the container item to indicate +/// that an empty `where` clause should be used. +/// +/// ### Example +/// +/// ```ignore +/// trait Trait { +/// type Assoc; +/// } +/// +/// #[derive(Reflect)] +/// #[reflect(where T::Assoc: FromReflect)] +/// struct Foo where T::Assoc: Default { +/// value: T::Assoc, +/// } +/// +/// // Generates a where clause like the following +/// // (notice that `T` does not have any `Reflect` or `FromReflect` bounds): +/// // +/// // impl bevy_reflect::Reflect for Foo +/// // where +/// // Self: 'static, +/// // T::Assoc: Default, +/// // T: bevy_reflect::TypePath +/// // + ::core::any::Any +/// // + ::core::marker::Send +/// // + ::core::marker::Sync, +/// // T::Assoc: FromReflect, +/// // {/* ... */} +/// ``` +/// /// # Field Attributes /// /// Along with the container attributes, this macro comes with some attributes that may be applied @@ -144,6 +191,10 @@ pub(crate) static TYPE_NAME_ATTRIBUTE_NAME: &str = "type_name"; /// which may be useful for maintaining invariants, keeping certain data private, /// or allowing the use of types that do not implement `Reflect` within the container. /// +/// If the field contains a generic type parameter, you will likely need to add a +/// [`#[reflect(where)]`](#reflectwheret-trait-uassoc-trait-) +/// attribute to the container in order to avoid the default bounds being applied to the type parameter. +/// /// ## `#[reflect(skip_serializing)]` /// /// This works similar to `#[reflect(ignore)]`, but rather than opting out of _all_ of reflection, diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/registration.rs b/crates/bevy_reflect/bevy_reflect_derive/src/registration.rs index 115274ad46ae1..45d9731c18c0e 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/registration.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/registration.rs @@ -2,7 +2,7 @@ use crate::derive_data::ReflectMeta; use crate::serialization::SerializationDataDef; -use crate::utility::{extend_where_clause, WhereClauseOptions}; +use crate::utility::WhereClauseOptions; use quote::quote; /// Creates the `GetTypeRegistration` impl for the given type data. @@ -16,7 +16,7 @@ pub(crate) fn impl_get_type_registration( let bevy_reflect_path = meta.bevy_reflect_path(); let registration_data = meta.traits().idents(); let (impl_generics, ty_generics, where_clause) = type_path.generics().split_for_impl(); - let where_reflect_clause = extend_where_clause(where_clause, where_clause_options); + let where_reflect_clause = where_clause_options.extend_where_clause(where_clause); let from_reflect_data = if meta.from_reflect().should_auto_derive() { Some(quote! { diff --git a/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs b/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs index 7bce217f27424..b50b0b92eb8a7 100644 --- a/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs +++ b/crates/bevy_reflect/bevy_reflect_derive/src/utility.rs @@ -1,13 +1,13 @@ //! General-purpose utility functions for internal usage within this crate. -use crate::derive_data::{ReflectMeta, StructField}; +use crate::derive_data::ReflectMeta; use bevy_macro_utils::{ fq_std::{FQAny, FQOption, FQSend, FQSync}, BevyManifest, }; use proc_macro2::{Ident, Span}; use quote::{quote, ToTokens}; -use syn::{spanned::Spanned, LitStr, Member, Path, Type, WhereClause}; +use syn::{spanned::Spanned, LitStr, Member, Path, WhereClause}; /// Returns the correct path for `bevy_reflect`. pub(crate) fn get_bevy_reflect_path() -> Path { @@ -66,188 +66,153 @@ pub(crate) fn ident_or_index(ident: Option<&Ident>, index: usize) -> Member { ) } -/// Options defining how to extend the `where` clause in reflection with any additional bounds needed. -pub(crate) struct WhereClauseOptions { - /// Type parameters that need extra trait bounds. - parameter_types: Box<[Ident]>, - /// Trait bounds to add to the type parameters. - parameter_trait_bounds: Box<[proc_macro2::TokenStream]>, - /// Any types that will be reflected and need an extra trait bound - active_types: Box<[Type]>, - /// Trait bounds to add to the active types - active_trait_bounds: Box<[proc_macro2::TokenStream]>, - /// Any types that won't be reflected and need an extra trait bound - ignored_types: Box<[Type]>, - /// Trait bounds to add to the ignored types - ignored_trait_bounds: Box<[proc_macro2::TokenStream]>, +/// Options defining how to extend the `where` clause for reflection. +pub(crate) struct WhereClauseOptions<'a, 'b> { + meta: &'a ReflectMeta<'b>, + additional_bounds: proc_macro2::TokenStream, + required_bounds: proc_macro2::TokenStream, } -impl Default for WhereClauseOptions { - /// By default, don't add any additional bounds to the `where` clause - fn default() -> Self { +impl<'a, 'b> WhereClauseOptions<'a, 'b> { + /// Create [`WhereClauseOptions`] for a reflected struct or enum type. + pub fn new(meta: &'a ReflectMeta<'b>) -> Self { + let bevy_reflect_path = meta.bevy_reflect_path(); + + let active_bound = if meta.from_reflect().should_auto_derive() { + quote!(#bevy_reflect_path::FromReflect) + } else { + quote!(#bevy_reflect_path::Reflect) + }; + + let type_path_bound = if meta.traits().type_path_attrs().should_auto_derive() { + Some(quote!(#bevy_reflect_path::TypePath +)) + } else { + None + }; + Self { - parameter_types: Box::new([]), - active_types: Box::new([]), - ignored_types: Box::new([]), - active_trait_bounds: Box::new([]), - ignored_trait_bounds: Box::new([]), - parameter_trait_bounds: Box::new([]), + meta, + additional_bounds: quote!(#type_path_bound #active_bound), + required_bounds: quote!(#type_path_bound #FQAny + #FQSend + #FQSync), } } -} -impl WhereClauseOptions { - /// Create [`WhereClauseOptions`] for a struct or enum type. - pub fn new<'a: 'b, 'b>( - meta: &ReflectMeta, - active_fields: impl Iterator>, - ignored_fields: impl Iterator>, - ) -> Self { - Self::new_with_bounds(meta, active_fields, ignored_fields, |_| None, |_| None) - } + /// Create [`WhereClauseOptions`] with the minimum bounds needed to fulfill `TypePath`. + pub fn new_type_path(meta: &'a ReflectMeta<'b>) -> Self { + let bevy_reflect_path = meta.bevy_reflect_path(); - /// Create [`WhereClauseOptions`] for a simple value type. - pub fn new_value(meta: &ReflectMeta) -> Self { - Self::new_with_bounds( + Self { meta, - std::iter::empty(), - std::iter::empty(), - |_| None, - |_| None, - ) + additional_bounds: quote!(#bevy_reflect_path::TypePath), + required_bounds: quote!(#bevy_reflect_path::TypePath + #FQAny + #FQSend + #FQSync), + } } - /// Create [`WhereClauseOptions`] for a struct or enum type. + /// Extends the `where` clause in reflection with additional bounds needed for reflection. /// - /// Compared to [`WhereClauseOptions::new`], this version allows you to specify - /// custom trait bounds for each field. - pub fn new_with_bounds<'a: 'b, 'b>( - meta: &ReflectMeta, - active_fields: impl Iterator>, - ignored_fields: impl Iterator>, - active_bounds: impl Fn(&StructField<'a>) -> Option, - ignored_bounds: impl Fn(&StructField<'a>) -> Option, - ) -> Self { - let bevy_reflect_path = meta.bevy_reflect_path(); - let is_from_reflect = meta.from_reflect().should_auto_derive(); - - let (active_types, active_trait_bounds): (Vec<_>, Vec<_>) = active_fields - .map(|field| { - let ty = field.data.ty.clone(); - - let custom_bounds = active_bounds(field).map(|bounds| quote!(+ #bounds)); - - let bounds = if is_from_reflect { - quote!(#bevy_reflect_path::FromReflect #custom_bounds) - } else { - quote!(#bevy_reflect_path::Reflect #custom_bounds) - }; - - (ty, bounds) - }) - .unzip(); - - let (ignored_types, ignored_trait_bounds): (Vec<_>, Vec<_>) = ignored_fields - .map(|field| { - let ty = field.data.ty.clone(); - - let custom_bounds = ignored_bounds(field).map(|bounds| quote!(+ #bounds)); - let bounds = quote!(#FQAny + #FQSend + #FQSync #custom_bounds); - - (ty, bounds) - }) - .unzip(); - - let (parameter_types, parameter_trait_bounds): (Vec<_>, Vec<_>) = - if meta.traits().type_path_attrs().should_auto_derive() { - meta.type_path() - .generics() - .type_params() - .map(|param| { - let ident = param.ident.clone(); - let bounds = quote!(#bevy_reflect_path::TypePath); - (ident, bounds) - }) - .unzip() - } else { - // If we don't need to derive `TypePath` for the type parameters, - // we can skip adding its bound to the `where` clause. - (Vec::new(), Vec::new()) - }; + /// This will only add bounds for generic type parameters. + /// + /// If the container has a `#[reflect(where)]` attribute, + /// this method will extend the type parameters with the _required_ bounds. + /// If the attribute is not present, it will extend the type parameters with the _additional_ bounds. + /// + /// The required bounds are the minimum bounds needed for a type to be reflected. + /// These include `TypePath`, `Any`, `Send`, and `Sync`. + /// + /// The additional bounds are added bounds used to enforce that a generic type parameter + /// is itself reflectable. + /// These include `Reflect` and `FromReflect`, as well as `TypePath`. + /// + /// # Example + /// + /// Take the following struct: + /// + /// ```ignore (bevy_reflect is not accessible from this crate) + /// #[derive(Reflect)] + /// struct Foo { + /// a: T, + /// #[reflect(ignore)] + /// b: U + /// } + /// ``` + /// + /// It has type parameters `T` and `U`. + /// + /// Since there is no `#[reflect(where)]` attribute, this method will extend the type parameters + /// with the additional bounds: + /// + /// ```ignore (bevy_reflect is not accessible from this crate) + /// where + /// T: FromReflect + TypePath, // additional bounds + /// U: FromReflect + TypePath, // additional bounds + /// ``` + /// + /// If we had this struct: + /// ```ignore (bevy_reflect is not accessible from this crate) + /// #[derive(Reflect)] + /// #[reflect(where T: FromReflect + Default)] + /// struct Foo { + /// a: T, + /// #[reflect(ignore)] + /// b: U + /// } + /// ``` + /// + /// Since there is a `#[reflect(where)]` attribute, this method will extend the type parameters + /// with _just_ the required bounds along with the predicates specified in the attribute: + /// + /// ```ignore (bevy_reflect is not accessible from this crate) + /// where + /// T: FromReflect + Default, // predicates from attribute + /// T: TypePath + Any + Send + Sync, // required bounds + /// U: TypePath + Any + Send + Sync, // required bounds + /// ``` + pub fn extend_where_clause( + &self, + where_clause: Option<&WhereClause>, + ) -> proc_macro2::TokenStream { + // Maintain existing where clause, if any. + let mut generic_where_clause = if let Some(where_clause) = where_clause { + let predicates = where_clause.predicates.iter(); + quote! {where Self: 'static, #(#predicates,)*} + } else { + quote!(where Self: 'static,) + }; + + // Add additional reflection trait bounds + let types = self.type_param_idents(); + let custom_where = self + .meta + .traits() + .custom_where() + .map(|clause| &clause.predicates); + let trait_bounds = self.trait_bounds(); + + generic_where_clause.extend(quote! { + #(#types: #trait_bounds,)* + #custom_where + }); + + generic_where_clause + } - Self { - active_types: active_types.into_boxed_slice(), - active_trait_bounds: active_trait_bounds.into_boxed_slice(), - ignored_types: ignored_types.into_boxed_slice(), - ignored_trait_bounds: ignored_trait_bounds.into_boxed_slice(), - parameter_types: parameter_types.into_boxed_slice(), - parameter_trait_bounds: parameter_trait_bounds.into_boxed_slice(), + /// Returns the trait bounds to use for all type parameters. + fn trait_bounds(&self) -> &proc_macro2::TokenStream { + if self.meta.traits().custom_where().is_some() { + &self.required_bounds + } else { + &self.additional_bounds } } -} -/// Extends the `where` clause in reflection with any additional bounds needed. -/// -/// This is mostly used to add additional bounds to reflected objects with generic types. -/// For reflection purposes, we usually have: -/// * `active_trait_bounds: Reflect` -/// * `ignored_trait_bounds: Any + Send + Sync` -/// -/// # Arguments -/// -/// * `where_clause`: existing `where` clause present on the object to be derived -/// * `where_clause_options`: additional parameters defining which trait bounds to add to the `where` clause -/// -/// # Example -/// -/// The struct: -/// ```ignore (bevy_reflect is not accessible from this crate) -/// #[derive(Reflect)] -/// struct Foo { -/// a: T, -/// #[reflect(ignore)] -/// b: U -/// } -/// ``` -/// will have active types: `[T]` and ignored types: `[U]` -/// -/// The `extend_where_clause` function will yield the following `where` clause: -/// ```ignore (bevy_reflect is not accessible from this crate) -/// where -/// T: Reflect, // active_trait_bounds -/// U: Any + Send + Sync, // ignored_trait_bounds -/// ``` -pub(crate) fn extend_where_clause( - where_clause: Option<&WhereClause>, - where_clause_options: &WhereClauseOptions, -) -> proc_macro2::TokenStream { - let parameter_types = &where_clause_options.parameter_types; - let active_types = &where_clause_options.active_types; - let ignored_types = &where_clause_options.ignored_types; - let parameter_trait_bounds = &where_clause_options.parameter_trait_bounds; - let active_trait_bounds = &where_clause_options.active_trait_bounds; - let ignored_trait_bounds = &where_clause_options.ignored_trait_bounds; - - let mut generic_where_clause = if let Some(where_clause) = where_clause { - let predicates = where_clause.predicates.iter(); - quote! {where #(#predicates,)*} - } else if !(parameter_types.is_empty() && active_types.is_empty() && ignored_types.is_empty()) { - quote! {where} - } else { - quote!() - }; - - // The nested parentheses here are required to properly scope HRTBs coming - // from field types to the type itself, as the compiler will scope them to - // the whole bound by default, resulting in a failure to prove trait - // adherence. - generic_where_clause.extend(quote! { - #((#active_types): #active_trait_bounds,)* - #((#ignored_types): #ignored_trait_bounds,)* - // Leave parameter bounds to the end for more sane error messages. - #((#parameter_types): #parameter_trait_bounds,)* - }); - generic_where_clause + /// Returns an iterator of the type parameter idents for the reflected type. + fn type_param_idents(&self) -> impl Iterator { + self.meta + .type_path() + .generics() + .type_params() + .map(|param| ¶m.ident) + } } impl Default for ResultSifter { diff --git a/crates/bevy_reflect/src/lib.rs b/crates/bevy_reflect/src/lib.rs index 8989ecaea9ce7..59a69ca6c7f7c 100644 --- a/crates/bevy_reflect/src/lib.rs +++ b/crates/bevy_reflect/src/lib.rs @@ -545,6 +545,7 @@ mod tests { ser::{to_string_pretty, PrettyConfig}, Deserializer, }; + use static_assertions::{assert_impl_all, assert_not_impl_all}; use std::{ any::TypeId, borrow::Cow, @@ -1867,6 +1868,72 @@ bevy_reflect::tests::Test { assert_eq!("123", format!("{:?}", foo)); } + #[test] + fn should_allow_custom_where() { + #[derive(Reflect)] + #[reflect(where T: Default)] + struct Foo(String, #[reflect(ignore)] PhantomData); + + #[derive(Default, TypePath)] + struct Bar; + + #[derive(TypePath)] + struct Baz; + + assert_impl_all!(Foo: Reflect); + assert_not_impl_all!(Foo: Reflect); + } + + #[test] + fn should_allow_empty_custom_where() { + #[derive(Reflect)] + #[reflect(where)] + struct Foo(String, #[reflect(ignore)] PhantomData); + + #[derive(TypePath)] + struct Bar; + + assert_impl_all!(Foo: Reflect); + } + + #[test] + fn should_allow_multiple_custom_where() { + #[derive(Reflect)] + #[reflect(where T: Default + FromReflect)] + #[reflect(where U: std::ops::Add + FromReflect)] + struct Foo(T, U); + + #[derive(Reflect)] + struct Baz { + a: Foo, + b: Foo, + } + + assert_impl_all!(Foo: Reflect); + assert_not_impl_all!(Foo: Reflect); + } + + #[test] + fn should_allow_custom_where_wtih_assoc_type() { + trait Trait { + type Assoc: FromReflect + TypePath; + } + + // We don't need `T` to be `Reflect` since we only care about `T::Assoc` + #[derive(Reflect)] + #[reflect(where T::Assoc: FromReflect)] + struct Foo(T::Assoc); + + #[derive(TypePath)] + struct Bar; + + impl Trait for Bar { + type Assoc = usize; + } + + assert_impl_all!(Foo: Reflect); + } + #[test] fn recursive_typed_storage_does_not_hang() { #[derive(Reflect)] @@ -1874,12 +1941,36 @@ bevy_reflect::tests::Test { let _ = > as Typed>::type_info(); let _ = > as TypePath>::type_path(); + + #[derive(Reflect)] + struct SelfRecurse { + recurse: Vec, + } + + let _ = ::type_info(); + let _ = ::type_path(); + + #[derive(Reflect)] + enum RecurseA { + Recurse(RecurseB), + } + + #[derive(Reflect)] + struct RecurseB { + vector: Vec, + } + + let _ = ::type_info(); + let _ = ::type_path(); + let _ = ::type_info(); + let _ = ::type_path(); } #[test] fn can_opt_out_type_path() { #[derive(Reflect)] #[reflect(type_path = false)] + #[reflect(where)] struct Foo { #[reflect(ignore)] _marker: PhantomData, diff --git a/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.rs b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.rs new file mode 100644 index 0000000000000..f10793a5f08a9 --- /dev/null +++ b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.rs @@ -0,0 +1,31 @@ +use bevy_reflect::{Reflect, FromType}; +use std::marker::PhantomData; + +#[derive(Clone)] +struct ReflectMyTrait; + +impl FromType for ReflectMyTrait { + fn from_type() -> Self { + Self + } +} + +// Reason: where clause cannot be used with #[reflect(MyTrait)] +#[derive(Reflect)] +#[reflect(MyTrait, where)] +pub struct Foo { + value: String, + #[reflect(ignore)] + _marker: PhantomData, +} + +// Reason: where clause cannot be used with #[reflect(MyTrait)] +#[derive(Reflect)] +#[reflect(where, MyTrait)] +pub struct Bar { + value: String, + #[reflect(ignore)] + _marker: PhantomData, +} + +fn main() {} \ No newline at end of file diff --git a/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.stderr b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.stderr new file mode 100644 index 0000000000000..10916ef3541ae --- /dev/null +++ b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.fail.stderr @@ -0,0 +1,11 @@ +error: expected identifier, found keyword `where` + --> tests/reflect_derive/custom_where.fail.rs:15:20 + | +15 | #[reflect(MyTrait, where)] + | ^^^^^ + +error: unexpected token + --> tests/reflect_derive/custom_where.fail.rs:24:16 + | +24 | #[reflect(where, MyTrait)] + | ^ diff --git a/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.pass.rs b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.pass.rs new file mode 100644 index 0000000000000..f6c4eede76e74 --- /dev/null +++ b/crates/bevy_reflect_compile_fail_tests/tests/reflect_derive/custom_where.pass.rs @@ -0,0 +1,31 @@ +use bevy_reflect::{Reflect, FromType}; +use std::marker::PhantomData; + +#[derive(Clone)] +struct ReflectMyTrait; + +impl FromType for ReflectMyTrait { + fn from_type() -> Self { + Self + } +} + +#[derive(Reflect)] +#[reflect(MyTrait)] +#[reflect(where)] +pub struct Foo { + value: String, + #[reflect(ignore)] + _marker: PhantomData, +} + +#[derive(Reflect)] +#[reflect(where)] +#[reflect(MyTrait)] +pub struct Bar { + value: String, + #[reflect(ignore)] + _marker: PhantomData, +} + +fn main() {}