diff --git a/crates/bevy_ecs/src/component.rs b/crates/bevy_ecs/src/component.rs index 7a67571d89056..7f8954b82e114 100644 --- a/crates/bevy_ecs/src/component.rs +++ b/crates/bevy_ecs/src/component.rs @@ -1376,10 +1376,10 @@ impl Components { // SAFETY: Component ID and constructor match the ones on the original requiree. // The original requiree is responsible for making sure the registration is safe. unsafe { - required_components.register_dynamic( + required_components.register_dynamic_with( *component_id, - component.constructor.clone(), component.inheritance_depth + depth + 1, + || component.constructor.clone(), ); }; } @@ -1422,21 +1422,21 @@ impl Components { .collect(); // Register the new required components. - for (component_id, component) in inherited_requirements.iter().cloned() { + for (component_id, component) in inherited_requirements.iter() { // Register the required component for the requiree. // SAFETY: Component ID and constructor match the ones on the original requiree. unsafe { - required_components.register_dynamic( - component_id, - component.constructor, + required_components.register_dynamic_with( + *component_id, component.inheritance_depth, + || component.constructor.clone(), ); }; // Add the requiree to the list of components that require the required component. // SAFETY: The caller ensures that the required components are valid. let required_by = unsafe { - self.get_required_by_mut(component_id) + self.get_required_by_mut(*component_id) .debug_checked_unwrap() }; required_by.insert(requiree); @@ -1992,25 +1992,30 @@ impl RequiredComponents { /// `constructor` _must_ initialize a component for `component_id` in such a way that /// matches the storage type of the component. It must only use the given `table_row` or `Entity` to /// initialize the storage for `component_id` corresponding to the given entity. - pub unsafe fn register_dynamic( + pub unsafe fn register_dynamic_with( &mut self, component_id: ComponentId, - constructor: RequiredComponentConstructor, inheritance_depth: u16, + constructor: impl FnOnce() -> RequiredComponentConstructor, ) { - self.0 - .entry(component_id) - .and_modify(|component| { - if component.inheritance_depth > inheritance_depth { - // New registration is more specific than existing requirement - component.constructor = constructor.clone(); - component.inheritance_depth = inheritance_depth; + let entry = self.0.entry(component_id); + match entry { + bevy_platform_support::collections::hash_map::Entry::Occupied(mut occupied) => { + let current = occupied.get_mut(); + if current.inheritance_depth > inheritance_depth { + *current = RequiredComponent { + constructor: constructor(), + inheritance_depth, + } } - }) - .or_insert(RequiredComponent { - constructor, - inheritance_depth, - }); + } + bevy_platform_support::collections::hash_map::Entry::Vacant(vacant) => { + vacant.insert(RequiredComponent { + constructor: constructor(), + inheritance_depth, + }); + } + } } /// Registers a required component. @@ -2037,64 +2042,66 @@ impl RequiredComponents { constructor: fn() -> C, inheritance_depth: u16, ) { - let erased: RequiredComponentConstructor = RequiredComponentConstructor({ - // `portable-atomic-util` `Arc` is not able to coerce an unsized - // type like `std::sync::Arc` can. Creating a `Box` first does the - // coercion. - // - // This would be resolved by https://github.com/rust-lang/rust/issues/123430 - - #[cfg(not(target_has_atomic = "ptr"))] - use alloc::boxed::Box; - - type Constructor = dyn for<'a, 'b> Fn( - &'a mut Table, - &'b mut SparseSets, - Tick, - TableRow, - Entity, - MaybeLocation, - ); + let erased = || { + RequiredComponentConstructor({ + // `portable-atomic-util` `Arc` is not able to coerce an unsized + // type like `std::sync::Arc` can. Creating a `Box` first does the + // coercion. + // + // This would be resolved by https://github.com/rust-lang/rust/issues/123430 + + #[cfg(not(target_has_atomic = "ptr"))] + use alloc::boxed::Box; + + type Constructor = dyn for<'a, 'b> Fn( + &'a mut Table, + &'b mut SparseSets, + Tick, + TableRow, + Entity, + MaybeLocation, + ); - #[cfg(not(target_has_atomic = "ptr"))] - type Intermediate = Box; - - #[cfg(target_has_atomic = "ptr")] - type Intermediate = Arc; - - let boxed: Intermediate = Intermediate::new( - move |table, sparse_sets, change_tick, table_row, entity, caller| { - OwningPtr::make(constructor(), |ptr| { - // SAFETY: This will only be called in the context of `BundleInfo::write_components`, which will - // pass in a valid table_row and entity requiring a C constructor - // C::STORAGE_TYPE is the storage type associated with `component_id` / `C` - // `ptr` points to valid `C` data, which matches the type associated with `component_id` - unsafe { - BundleInfo::initialize_required_component( - table, - sparse_sets, - change_tick, - table_row, - entity, - component_id, - C::STORAGE_TYPE, - ptr, - caller, - ); - } - }); - }, - ); + #[cfg(not(target_has_atomic = "ptr"))] + type Intermediate = Box; + + #[cfg(target_has_atomic = "ptr")] + type Intermediate = Arc; + + let boxed: Intermediate = Intermediate::new( + move |table, sparse_sets, change_tick, table_row, entity, caller| { + OwningPtr::make(constructor(), |ptr| { + // SAFETY: This will only be called in the context of `BundleInfo::write_components`, which will + // pass in a valid table_row and entity requiring a C constructor + // C::STORAGE_TYPE is the storage type associated with `component_id` / `C` + // `ptr` points to valid `C` data, which matches the type associated with `component_id` + unsafe { + BundleInfo::initialize_required_component( + table, + sparse_sets, + change_tick, + table_row, + entity, + component_id, + C::STORAGE_TYPE, + ptr, + caller, + ); + } + }); + }, + ); - Arc::from(boxed) - }); + Arc::from(boxed) + }) + }; // SAFETY: // `component_id` matches the type initialized by the `erased` constructor above. // `erased` initializes a component for `component_id` in such a way that // matches the storage type of the component. It only uses the given `table_row` or `Entity` to // initialize the storage corresponding to the given entity. - unsafe { self.register_dynamic(component_id, erased, inheritance_depth) }; + unsafe { self.register_dynamic_with(component_id, inheritance_depth, erased) }; } /// Iterates the ids of all required components. This includes recursive required components. @@ -2112,11 +2119,26 @@ impl RequiredComponents { } } - // Merges `required_components` into this collection. This only inserts a required component - // if it _did not already exist_. + /// Merges `required_components` into this collection. This only inserts a required component + /// if it _did not already exist_ *or* if the required component is more specific than the existing one + /// (in other words, if the inheritance depth is smaller). + /// + /// See [`register_dynamic_with`](Self::register_dynamic_with) for details. pub(crate) fn merge(&mut self, required_components: &RequiredComponents) { - for (id, constructor) in &required_components.0 { - self.0.entry(*id).or_insert_with(|| constructor.clone()); + for ( + component_id, + RequiredComponent { + constructor, + inheritance_depth, + }, + ) in required_components.0.iter() + { + // SAFETY: This exact registration must have been done on `required_components`, so safety is ensured by that caller. + unsafe { + self.register_dynamic_with(*component_id, *inheritance_depth, || { + constructor.clone() + }); + } } } } diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index bae3dbfed2438..7962c8bf728a8 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -2635,6 +2635,37 @@ mod tests { assert_eq!(to_vec(required_z), vec![(b, 0), (c, 1)]); } + #[test] + fn required_components_inheritance_depth_bias() { + #[derive(Component, PartialEq, Eq, Clone, Copy, Debug)] + struct MyRequired(bool); + + #[derive(Component, Default)] + #[require(MyRequired(|| MyRequired(false)))] + struct MiddleMan; + + #[derive(Component, Default)] + #[require(MiddleMan)] + struct ConflictingRequire; + + #[derive(Component, Default)] + #[require(MyRequired(|| MyRequired(true)))] + struct MyComponent; + + let mut world = World::new(); + let order_a = world + .spawn((ConflictingRequire, MyComponent)) + .get::() + .cloned(); + let order_b = world + .spawn((MyComponent, ConflictingRequire)) + .get::() + .cloned(); + + assert_eq!(order_a, Some(MyRequired(true))); + assert_eq!(order_b, Some(MyRequired(true))); + } + #[test] #[should_panic = "Recursive required components detected: A → B → C → B\nhelp: If this is intentional, consider merging the components."] fn required_components_recursion_errors() {