Skip to content

Commit

Permalink
Auto-label function systems with SystemTypeIdLabel (#4224)
Browse files Browse the repository at this point in the history
This adds the concept of "default labels" for systems (currently scoped to "parallel systems", but this could just as easily be implemented for "exclusive systems"). Function systems now include their function's `SystemTypeIdLabel` by default.

This enables the following patterns:

```rust
// ordering two systems without manually defining labels
app
  .add_system(update_velocity)
  .add_system(movement.after(update_velocity))

// ordering sets of systems without manually defining labels
app
  .add_system(foo)
  .add_system_set(
    SystemSet::new()
      .after(foo)
      .with_system(bar)
      .with_system(baz)
  )
```

Fixes: #4219
Related to: #4220 

Credit to @aevyrie @alice-i-cecile @DJMcNab (and probably others) for proposing (and supporting) this idea about a year ago. I was a big dummy that both shut down this (very good) idea and then forgot I did that. Sorry. You all were right!
  • Loading branch information
cart committed Mar 23, 2022
1 parent d51b54a commit b1c3e98
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 51 deletions.
1 change: 1 addition & 0 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub mod prelude {
system::{
Commands, In, IntoChainSystem, IntoExclusiveSystem, IntoSystem, Local, NonSend,
NonSendMut, Query, QuerySet, RemovedComponents, Res, ResMut, System,
SystemParamFunction,
},
world::{FromWorld, Mut, World},
};
Expand Down
46 changes: 28 additions & 18 deletions crates/bevy_ecs/src/schedule/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,15 +1578,25 @@ mod tests {
fn ambiguity_detection() {
use super::{find_ambiguities, SystemContainer};

fn find_ambiguities_first_labels(
fn find_ambiguities_first_str_labels(
systems: &[impl SystemContainer],
) -> Vec<(BoxedSystemLabel, BoxedSystemLabel)> {
find_ambiguities(systems)
.drain(..)
.map(|(index_a, index_b, _conflicts)| {
(
systems[index_a].labels()[0].clone(),
systems[index_b].labels()[0].clone(),
systems[index_a]
.labels()
.iter()
.find(|a| (&***a).type_id() == std::any::TypeId::of::<&str>())
.unwrap()
.clone(),
systems[index_b]
.labels()
.iter()
.find(|a| (&***a).type_id() == std::any::TypeId::of::<&str>())
.unwrap()
.clone(),
)
})
.collect()
Expand Down Expand Up @@ -1616,7 +1626,7 @@ mod tests {
.with_system(component.label("4"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("4")))
|| ambiguities.contains(&(Box::new("4"), Box::new("1")))
Expand All @@ -1631,7 +1641,7 @@ mod tests {
.with_system(resource.label("4"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("4")))
|| ambiguities.contains(&(Box::new("4"), Box::new("1")))
Expand All @@ -1656,7 +1666,7 @@ mod tests {
.with_system(resource.label("4"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("0"), Box::new("3")))
|| ambiguities.contains(&(Box::new("3"), Box::new("0")))
Expand All @@ -1675,7 +1685,7 @@ mod tests {
.with_system(resource.label("4").in_ambiguity_set("a"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("0"), Box::new("3")))
|| ambiguities.contains(&(Box::new("3"), Box::new("0")))
Expand All @@ -1688,7 +1698,7 @@ mod tests {
.with_system(component.label("2"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("0"), Box::new("1")))
|| ambiguities.contains(&(Box::new("1"), Box::new("0")))
Expand All @@ -1701,7 +1711,7 @@ mod tests {
.with_system(component.label("2").after("0"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("2")))
|| ambiguities.contains(&(Box::new("2"), Box::new("1")))
Expand All @@ -1715,7 +1725,7 @@ mod tests {
.with_system(component.label("3").after("1").after("2"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("2")))
|| ambiguities.contains(&(Box::new("2"), Box::new("1")))
Expand All @@ -1729,7 +1739,7 @@ mod tests {
.with_system(component.label("3").after("1").after("2"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert_eq!(ambiguities.len(), 0);

let mut stage = SystemStage::parallel()
Expand All @@ -1739,7 +1749,7 @@ mod tests {
.with_system(component.label("3").after("1").after("2"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("2")))
|| ambiguities.contains(&(Box::new("2"), Box::new("1")))
Expand Down Expand Up @@ -1769,7 +1779,7 @@ mod tests {
);
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("2")))
|| ambiguities.contains(&(Box::new("2"), Box::new("1")))
Expand Down Expand Up @@ -1819,7 +1829,7 @@ mod tests {
);
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert_eq!(ambiguities.len(), 0);

let mut stage = SystemStage::parallel()
Expand Down Expand Up @@ -1850,7 +1860,7 @@ mod tests {
);
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.parallel);
let ambiguities = find_ambiguities_first_str_labels(&stage.parallel);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("4")))
|| ambiguities.contains(&(Box::new("4"), Box::new("1")))
Expand Down Expand Up @@ -1884,7 +1894,7 @@ mod tests {
.with_system(empty.exclusive_system().label("6").after("2").after("5"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.exclusive_at_start);
let ambiguities = find_ambiguities_first_str_labels(&stage.exclusive_at_start);
assert!(
ambiguities.contains(&(Box::new("1"), Box::new("3")))
|| ambiguities.contains(&(Box::new("3"), Box::new("1")))
Expand Down Expand Up @@ -1921,7 +1931,7 @@ mod tests {
.with_system(empty.exclusive_system().label("6").after("2").after("5"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.exclusive_at_start);
let ambiguities = find_ambiguities_first_str_labels(&stage.exclusive_at_start);
assert!(
ambiguities.contains(&(Box::new("2"), Box::new("3")))
|| ambiguities.contains(&(Box::new("3"), Box::new("2")))
Expand All @@ -1947,7 +1957,7 @@ mod tests {
.with_system(empty.exclusive_system().label("3").in_ambiguity_set("a"));
stage.initialize_systems(&mut world);
stage.rebuild_orders_and_dependencies();
let ambiguities = find_ambiguities_first_labels(&stage.exclusive_at_start);
let ambiguities = find_ambiguities_first_str_labels(&stage.exclusive_at_start);
assert_eq!(ambiguities.len(), 0);
}

Expand Down
27 changes: 15 additions & 12 deletions crates/bevy_ecs/src/schedule/system_descriptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use crate::{
AmbiguitySetLabel, BoxedAmbiguitySetLabel, BoxedSystemLabel, IntoRunCriteria,
RunCriteriaDescriptorOrLabel, SystemLabel,
},
system::{BoxedSystem, ExclusiveSystem, ExclusiveSystemCoerced, ExclusiveSystemFn, IntoSystem},
system::{
AsSystemLabel, BoxedSystem, ExclusiveSystem, ExclusiveSystemCoerced, ExclusiveSystemFn,
IntoSystem,
},
};

/// Encapsulates a system and information on when it run in a `SystemStage`.
Expand Down Expand Up @@ -105,9 +108,9 @@ pub struct ParallelSystemDescriptor {

fn new_parallel_descriptor(system: BoxedSystem<(), ()>) -> ParallelSystemDescriptor {
ParallelSystemDescriptor {
labels: system.default_labels(),
system,
run_criteria: None,
labels: Vec::new(),
before: Vec::new(),
after: Vec::new(),
ambiguity_sets: Vec::new(),
Expand All @@ -126,10 +129,10 @@ pub trait ParallelSystemDescriptorCoercion<Params> {
fn label(self, label: impl SystemLabel) -> ParallelSystemDescriptor;

/// Specifies that the system should run before systems with the given label.
fn before(self, label: impl SystemLabel) -> ParallelSystemDescriptor;
fn before<Marker>(self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor;

/// Specifies that the system should run after systems with the given label.
fn after(self, label: impl SystemLabel) -> ParallelSystemDescriptor;
fn after<Marker>(self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor;

/// Specifies that the system is exempt from execution order ambiguity detection
/// with other systems in this set.
Expand All @@ -150,13 +153,13 @@ impl ParallelSystemDescriptorCoercion<()> for ParallelSystemDescriptor {
self
}

fn before(mut self, label: impl SystemLabel) -> ParallelSystemDescriptor {
self.before.push(Box::new(label));
fn before<Marker>(mut self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor {
self.before.push(Box::new(label.as_system_label()));
self
}

fn after(mut self, label: impl SystemLabel) -> ParallelSystemDescriptor {
self.after.push(Box::new(label));
fn after<Marker>(mut self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor {
self.after.push(Box::new(label.as_system_label()));
self
}

Expand All @@ -182,11 +185,11 @@ where
new_parallel_descriptor(Box::new(IntoSystem::into_system(self))).label(label)
}

fn before(self, label: impl SystemLabel) -> ParallelSystemDescriptor {
fn before<Marker>(self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor {
new_parallel_descriptor(Box::new(IntoSystem::into_system(self))).before(label)
}

fn after(self, label: impl SystemLabel) -> ParallelSystemDescriptor {
fn after<Marker>(self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor {
new_parallel_descriptor(Box::new(IntoSystem::into_system(self))).after(label)
}

Expand All @@ -207,11 +210,11 @@ impl ParallelSystemDescriptorCoercion<()> for BoxedSystem<(), ()> {
new_parallel_descriptor(self).label(label)
}

fn before(self, label: impl SystemLabel) -> ParallelSystemDescriptor {
fn before<Marker>(self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor {
new_parallel_descriptor(self).before(label)
}

fn after(self, label: impl SystemLabel) -> ParallelSystemDescriptor {
fn after<Marker>(self, label: impl AsSystemLabel<Marker>) -> ParallelSystemDescriptor {
new_parallel_descriptor(self).after(label)
}

Expand Down
69 changes: 68 additions & 1 deletion crates/bevy_ecs/src/system/function_system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ use crate::{
archetype::{Archetype, ArchetypeComponentId, ArchetypeGeneration, ArchetypeId},
component::ComponentId,
query::{Access, FilteredAccessSet},
schedule::SystemLabel,
system::{
check_system_change_tick, ReadOnlySystemParamFetch, System, SystemParam, SystemParamFetch,
SystemParamState,
},
world::{World, WorldId},
};
use bevy_ecs_macros::all_tuples;
use std::{borrow::Cow, marker::PhantomData};
use std::{borrow::Cow, fmt::Debug, hash::Hash, marker::PhantomData};

/// The metadata of a [`System`].
pub struct SystemMeta {
Expand Down Expand Up @@ -422,6 +423,47 @@ where
self.system_meta.name.as_ref(),
);
}
fn default_labels(&self) -> Vec<Box<dyn SystemLabel>> {
vec![Box::new(self.func.as_system_label())]
}
}

/// A [`SystemLabel`] that was automatically generated for a system on the basis of its `TypeId`.
pub struct SystemTypeIdLabel<T: 'static>(PhantomData<fn() -> T>);

impl<T> Debug for SystemTypeIdLabel<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("SystemTypeIdLabel")
.field(&std::any::type_name::<T>())
.finish()
}
}
impl<T> Hash for SystemTypeIdLabel<T> {
fn hash<H: std::hash::Hasher>(&self, _state: &mut H) {
// All SystemTypeIds of a given type are the same.
}
}
impl<T> Clone for SystemTypeIdLabel<T> {
fn clone(&self) -> Self {
Self(PhantomData)
}
}

impl<T> Copy for SystemTypeIdLabel<T> {}

impl<T> PartialEq for SystemTypeIdLabel<T> {
#[inline]
fn eq(&self, _other: &Self) -> bool {
// All labels of a given type are equal, as they will all have the same type id
true
}
}
impl<T> Eq for SystemTypeIdLabel<T> {}

impl<T> SystemLabel for SystemTypeIdLabel<T> {
fn dyn_clone(&self) -> Box<dyn SystemLabel> {
Box::new(*self)
}
}

/// A trait implemented for all functions that can be used as [`System`]s.
Expand Down Expand Up @@ -490,3 +532,28 @@ macro_rules! impl_system_function {
}

all_tuples!(impl_system_function, 0, 16, F);

/// Used to implicitly convert systems to their default labels. For example, it will convert
/// "system functions" to their [`SystemTypeIdLabel`].
pub trait AsSystemLabel<Marker> {
type SystemLabel: SystemLabel;
fn as_system_label(&self) -> Self::SystemLabel;
}

impl<In, Out, Param: SystemParam, Marker, T: SystemParamFunction<In, Out, Param, Marker>>
AsSystemLabel<(In, Out, Param, Marker)> for T
{
type SystemLabel = SystemTypeIdLabel<Self>;

fn as_system_label(&self) -> Self::SystemLabel {
SystemTypeIdLabel(PhantomData::<fn() -> Self>)
}
}

impl<T: SystemLabel + Clone> AsSystemLabel<()> for T {
type SystemLabel = T;

fn as_system_label(&self) -> Self::SystemLabel {
self.clone()
}
}
5 changes: 5 additions & 0 deletions crates/bevy_ecs/src/system/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
archetype::{Archetype, ArchetypeComponentId},
component::ComponentId,
query::Access,
schedule::SystemLabel,
world::World,
};
use std::borrow::Cow;
Expand Down Expand Up @@ -56,6 +57,10 @@ pub trait System: Send + Sync + 'static {
/// Initialize the system.
fn initialize(&mut self, _world: &mut World);
fn check_change_tick(&mut self, change_tick: u32);
/// The default labels for the system
fn default_labels(&self) -> Vec<Box<dyn SystemLabel>> {
Vec::new()
}
}

/// A convenience type alias for a boxed [`System`] trait object.
Expand Down
Loading

0 comments on commit b1c3e98

Please sign in to comment.