From f00aec2454dff63a992dcef30256b8663c38ac5b Mon Sep 17 00:00:00 2001 From: MrGVSV Date: Fri, 4 Feb 2022 03:07:18 +0000 Subject: [PATCH] Added method to restart the current state (#3328) # Objective It would be useful to be able to restart a state (such as if an operation fails and needs to be retried from `on_enter`). Currently, it seems the way to restart a state is to transition to a dummy state and then transition back. ## Solution The solution is to add a `restart` method on `State` that allows for transitioning to the already-active state. ## Context Based on [this](https://discord.com/channels/691052431525675048/742884593551802431/920335041756815441) question from the Discord. Closes #2385 Co-authored-by: Carter Anderson --- crates/bevy_ecs/src/schedule/state.rs | 107 ++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/crates/bevy_ecs/src/schedule/state.rs b/crates/bevy_ecs/src/schedule/state.rs index 24b8801917cf5..7dcbefb6afb87 100644 --- a/crates/bevy_ecs/src/schedule/state.rs +++ b/crates/bevy_ecs/src/schedule/state.rs @@ -21,6 +21,9 @@ impl StateData for T where T: Send + Sync + Clone + Eq + Debug + Hash + 'stat #[derive(Debug)] pub struct State { transition: Option>, + /// The current states in the stack. + /// + /// There is always guaranteed to be at least one. stack: Vec, scheduled: Option>, end_next_loop: bool, @@ -369,6 +372,25 @@ where Ok(()) } + /// Schedule a state change that restarts the active state. + /// This will fail if there is a scheduled operation + pub fn restart(&mut self) -> Result<(), StateError> { + if self.scheduled.is_some() { + return Err(StateError::StateAlreadyQueued); + } + + let state = self.stack.last().unwrap(); + self.scheduled = Some(ScheduledOperation::Set(state.clone())); + Ok(()) + } + + /// Same as [`Self::restart`], but if there is already a scheduled state operation, + /// it will be overwritten instead of failing + pub fn overwrite_restart(&mut self) { + let state = self.stack.last().unwrap(); + self.scheduled = Some(ScheduledOperation::Set(state.clone())); + } + pub fn current(&self) -> &T { self.stack.last().unwrap() } @@ -655,4 +677,89 @@ mod test { stage.run(&mut world); assert!(*world.get_resource::().unwrap(), "after test"); } + + #[test] + fn restart_state_tests() { + #[derive(Clone, PartialEq, Eq, Debug, Hash)] + enum LoadState { + Load, + Finish, + } + + #[derive(PartialEq, Eq, Debug)] + enum LoadStatus { + EnterLoad, + ExitLoad, + EnterFinish, + } + + let mut world = World::new(); + world.insert_resource(Vec::::new()); + world.insert_resource(State::new(LoadState::Load)); + + let mut stage = SystemStage::parallel(); + stage.add_system_set(State::::get_driver()); + + // Systems to track loading status + stage + .add_system_set( + State::on_enter_set(LoadState::Load) + .with_system(|mut r: ResMut>| r.push(LoadStatus::EnterLoad)), + ) + .add_system_set( + State::on_exit_set(LoadState::Load) + .with_system(|mut r: ResMut>| r.push(LoadStatus::ExitLoad)), + ) + .add_system_set( + State::on_enter_set(LoadState::Finish) + .with_system(|mut r: ResMut>| r.push(LoadStatus::EnterFinish)), + ); + + stage.run(&mut world); + + // A. Restart state + let mut state = world.get_resource_mut::>().unwrap(); + let result = state.restart(); + assert!(matches!(result, Ok(()))); + stage.run(&mut world); + + // B. Restart state (overwrite schedule) + let mut state = world.get_resource_mut::>().unwrap(); + state.set(LoadState::Finish).unwrap(); + state.overwrite_restart(); + stage.run(&mut world); + + // C. Fail restart state (transition already scheduled) + let mut state = world.get_resource_mut::>().unwrap(); + state.set(LoadState::Finish).unwrap(); + let result = state.restart(); + assert!(matches!(result, Err(StateError::StateAlreadyQueued))); + stage.run(&mut world); + + const EXPECTED: &[LoadStatus] = &[ + LoadStatus::EnterLoad, + // A + LoadStatus::ExitLoad, + LoadStatus::EnterLoad, + // B + LoadStatus::ExitLoad, + LoadStatus::EnterLoad, + // C + LoadStatus::ExitLoad, + LoadStatus::EnterFinish, + ]; + + let mut collected = world.get_resource_mut::>().unwrap(); + let mut count = 0; + for (found, expected) in collected.drain(..).zip(EXPECTED) { + assert_eq!(found, *expected); + count += 1; + } + // If not equal, some elements weren't executed + assert_eq!(EXPECTED.len(), count); + assert_eq!( + world.get_resource::>().unwrap().current(), + &LoadState::Finish + ); + } }