diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index 7d581c2c..a6cf276a 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -242,8 +242,11 @@ def _update_timed_terminal_event( def _disable_timed_terminal_event(self) -> None: """Turn off simulator termination due to this satellite's window close checker""" - if self._timed_terminal_event_name is not None: - self.simulator.eventMap[self._timed_terminal_event_name].eventActive = False + if ( + self._timed_terminal_event_name is not None + and self._timed_terminal_event_name in self.simulator.eventMap + ): + self.simulator.delete_event(self._timed_terminal_event_name) @abstractmethod # pragma: no cover def get_obs(self) -> SatObs: @@ -562,8 +565,12 @@ def _update_image_event(self, target: Target) -> None: def _disable_image_event(self) -> None: """Turn off simulator termination due to this satellite's imaging checker""" - if self._image_event_name is not None: - self.simulator.eventMap[self._image_event_name].eventActive = False + if ( + self._image_event_name is not None + and self._image_event_name in self.simulator.eventMap + ): + self.simulator.delete_event(self._image_event_name) + # self.simulator.eventMap[self._image_event_name].eventActive = False def parse_target_selection(self, target_query: Union[int, Target, str]): """Identify a target based on upcoming target index, Target object, or target diff --git a/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py b/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py index 08d164bc..9e6c30b0 100644 --- a/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py +++ b/bsk_rl/envs/general_satellite_tasking/simulation/simulator.py @@ -83,6 +83,12 @@ def run(self) -> None: self.ConfigureStopTime(simulation_time) self.ExecuteSimulation() + def delete_event(self, event_name) -> None: + """Removes an event from the event map. Makes event checking faster""" + event = self.eventMap[event_name] + self.eventList.remove(event) + del self.eventMap[event_name] + def __del__(self): if MEMORY_LEAK_CHECKING: # pragma: no cover print("~~~ BSK SIMULATOR DELETED ~~~") diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py index 2f7d6f42..27f5a71f 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_satellites.py @@ -115,17 +115,17 @@ def test_update_timed_terminal_event(self): def test_disable_timed_event(self): sat = sats.Satellite(name="TestSat", sat_args={}) - sat.simulator = MagicMock() + sat.simulator = MagicMock(eventMap={"some_event": 1}) sat._timed_terminal_event_name = "some_event" sat._disable_timed_terminal_event() - assert sat.simulator.eventMap.__getitem__.called + sat.simulator.delete_event.assert_called_with("some_event") def test_disable_timed_event_no_event(self): sat = sats.Satellite(name="TestSat", sat_args={}) - sat.simulator = MagicMock() + sat.simulator = MagicMock(eventMap={"some_event": 1}) sat._timed_terminal_event_name = None sat._disable_timed_terminal_event() - assert not sat.simulator.eventMap.__getitem__.called + assert not sat.simulator.delete_event.called def test_proxy_setters(self): # Must be last test or others break @@ -292,17 +292,17 @@ def test_update_image_event_existing(self): def test_disable_image_event(self): sat = self.make_sat() - sat.simulator = MagicMock() + sat.simulator = MagicMock(eventMap={"some_image_event": 1}) sat._image_event_name = "some_image_event" sat._disable_image_event() - assert sat.simulator.eventMap.__getitem__.called + sat.simulator.delete_event.assert_called_with("some_image_event") def test_disable_image_event_no_event(self): sat = self.make_sat() - sat.simulator = MagicMock() + sat.simulator = MagicMock(eventMap={"some_event": 1}) sat._image_event_name = None sat._disable_image_event() - assert not sat.simulator.eventMap.__getitem__.called + assert not sat.simulator.delete_event.called upcoming_targets = [Target(f"tgt_{i}", [0, 0, 0], 1.0) for i in range(20)] diff --git a/tests/unittest/envs/general_satellite_tasking/simulation/test_simulator.py b/tests/unittest/envs/general_satellite_tasking/simulation/test_simulator.py index d6fa5efd..98898e06 100644 --- a/tests/unittest/envs/general_satellite_tasking/simulation/test_simulator.py +++ b/tests/unittest/envs/general_satellite_tasking/simulation/test_simulator.py @@ -53,6 +53,15 @@ def test_set_environment(self, simbase_init): assert sim.environment.sim == sim assert sim.environment.rate == sim.sim_rate + def test_delete_event(self, simbase_init): + sim = self.mock_sim() + event = MagicMock() + sim.eventMap = {"event": event, "other": MagicMock()} + sim.eventList = [MagicMock(), event, MagicMock()] + sim.delete_event("event") + assert "event" not in sim.eventMap + assert event not in sim.eventList + @pytest.mark.parametrize( "start_time,step_duration,time_limit,stop_time", [(0, 100, 50, 50), (0, 100, 200, 100), (10, 10, 50, 20)],