Skip to content

Commit

Permalink
Added method to restart the current state (#3328)
Browse files Browse the repository at this point in the history
# Objective

It would be useful to be able to restart a state (such as if an operation fails and needs to be retried from `on_enter`). Currently, it seems the way to restart a state is to transition to a dummy state and then transition back.

## Solution

The solution is to add a `restart` method on `State<T>` that allows for transitioning to the already-active state.

## Context

Based on [this](https://discord.com/channels/691052431525675048/742884593551802431/920335041756815441) question from the Discord.

Closes #2385


Co-authored-by: Carter Anderson <mcanders1@gmail.com>
  • Loading branch information
MrGVSV and cart committed Feb 4, 2022
1 parent e2cce09 commit f00aec2
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions crates/bevy_ecs/src/schedule/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ impl<T> StateData for T where T: Send + Sync + Clone + Eq + Debug + Hash + 'stat
#[derive(Debug)]
pub struct State<T: StateData> {
transition: Option<StateTransition<T>>,
/// The current states in the stack.
///
/// There is always guaranteed to be at least one.
stack: Vec<T>,
scheduled: Option<ScheduledOperation<T>>,
end_next_loop: bool,
Expand Down Expand Up @@ -369,6 +372,25 @@ where
Ok(())
}

/// Schedule a state change that restarts the active state.
/// This will fail if there is a scheduled operation
pub fn restart(&mut self) -> Result<(), StateError> {
if self.scheduled.is_some() {
return Err(StateError::StateAlreadyQueued);
}

let state = self.stack.last().unwrap();
self.scheduled = Some(ScheduledOperation::Set(state.clone()));
Ok(())
}

/// Same as [`Self::restart`], but if there is already a scheduled state operation,
/// it will be overwritten instead of failing
pub fn overwrite_restart(&mut self) {
let state = self.stack.last().unwrap();
self.scheduled = Some(ScheduledOperation::Set(state.clone()));
}

pub fn current(&self) -> &T {
self.stack.last().unwrap()
}
Expand Down Expand Up @@ -655,4 +677,89 @@ mod test {
stage.run(&mut world);
assert!(*world.get_resource::<bool>().unwrap(), "after test");
}

#[test]
fn restart_state_tests() {
#[derive(Clone, PartialEq, Eq, Debug, Hash)]
enum LoadState {
Load,
Finish,
}

#[derive(PartialEq, Eq, Debug)]
enum LoadStatus {
EnterLoad,
ExitLoad,
EnterFinish,
}

let mut world = World::new();
world.insert_resource(Vec::<LoadStatus>::new());
world.insert_resource(State::new(LoadState::Load));

let mut stage = SystemStage::parallel();
stage.add_system_set(State::<LoadState>::get_driver());

// Systems to track loading status
stage
.add_system_set(
State::on_enter_set(LoadState::Load)
.with_system(|mut r: ResMut<Vec<LoadStatus>>| r.push(LoadStatus::EnterLoad)),
)
.add_system_set(
State::on_exit_set(LoadState::Load)
.with_system(|mut r: ResMut<Vec<LoadStatus>>| r.push(LoadStatus::ExitLoad)),
)
.add_system_set(
State::on_enter_set(LoadState::Finish)
.with_system(|mut r: ResMut<Vec<LoadStatus>>| r.push(LoadStatus::EnterFinish)),
);

stage.run(&mut world);

// A. Restart state
let mut state = world.get_resource_mut::<State<LoadState>>().unwrap();
let result = state.restart();
assert!(matches!(result, Ok(())));
stage.run(&mut world);

// B. Restart state (overwrite schedule)
let mut state = world.get_resource_mut::<State<LoadState>>().unwrap();
state.set(LoadState::Finish).unwrap();
state.overwrite_restart();
stage.run(&mut world);

// C. Fail restart state (transition already scheduled)
let mut state = world.get_resource_mut::<State<LoadState>>().unwrap();
state.set(LoadState::Finish).unwrap();
let result = state.restart();
assert!(matches!(result, Err(StateError::StateAlreadyQueued)));
stage.run(&mut world);

const EXPECTED: &[LoadStatus] = &[
LoadStatus::EnterLoad,
// A
LoadStatus::ExitLoad,
LoadStatus::EnterLoad,
// B
LoadStatus::ExitLoad,
LoadStatus::EnterLoad,
// C
LoadStatus::ExitLoad,
LoadStatus::EnterFinish,
];

let mut collected = world.get_resource_mut::<Vec<LoadStatus>>().unwrap();
let mut count = 0;
for (found, expected) in collected.drain(..).zip(EXPECTED) {
assert_eq!(found, *expected);
count += 1;
}
// If not equal, some elements weren't executed
assert_eq!(EXPECTED.len(), count);
assert_eq!(
world.get_resource::<State<LoadState>>().unwrap().current(),
&LoadState::Finish
);
}
}

0 comments on commit f00aec2

Please sign in to comment.