From 37adacb724ca2f09f8df184d7a986ca921e620de Mon Sep 17 00:00:00 2001 From: Mark Stephenson Date: Tue, 24 Oct 2023 12:32:00 -0600 Subject: [PATCH] Issue #18: Maintain current action by selecting None action --- .../envs/general_satellite_tasking/gym_env.py | 14 ++++++++++-- .../scenario/satellites.py | 6 ++++- .../scenario/test_sat_observations.py | 2 ++ .../general_satellite_tasking/test_gym_env.py | 22 ++++++++++++++++++- 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/bsk_rl/envs/general_satellite_tasking/gym_env.py b/bsk_rl/envs/general_satellite_tasking/gym_env.py index 62e47937..b0e2e91d 100644 --- a/bsk_rl/envs/general_satellite_tasking/gym_env.py +++ b/bsk_rl/envs/general_satellite_tasking/gym_env.py @@ -5,6 +5,7 @@ from gymnasium import Env, spaces from bsk_rl.envs.general_satellite_tasking.scenario.communication import NoCommunication +from bsk_rl.envs.general_satellite_tasking.scenario.satellites import REQUIRES_RETASKING from bsk_rl.envs.general_satellite_tasking.simulation.simulator import Simulator from bsk_rl.envs.general_satellite_tasking.types import ( CommunicationMethod, @@ -220,7 +221,7 @@ def step( """Propagate the simulation, update information, and get rewards Args: - Joint action for satellites + Joint action for satellites. Can be none to maintain current task. Returns: observation, reward, terminated, truncated, info @@ -228,8 +229,17 @@ def step( if len(actions) != len(self.satellites): raise ValueError("There must be the same number of actions and satellites") for satellite, action in zip(self.satellites, actions): + old_info = satellite.info satellite.info = [] # reset satellite info log - satellite.set_action(action) + if action is not None: + satellite.set_action(action) + else: + if REQUIRES_RETASKING in old_info: + print( + f"Satellite {satellite.id} requires retasking " + "but received no task." + ) + satellite.info.append(REQUIRES_RETASKING) previous_time = self.simulator.sim_time # should these be recorded in simulator self.simulator.run() diff --git a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py index 20c2df0c..31e2cf64 100644 --- a/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py +++ b/bsk_rl/envs/general_satellite_tasking/scenario/satellites.py @@ -35,6 +35,8 @@ SatObs = Any SatAct = Any +REQUIRES_RETASKING = "REQUIRES_RETASKING" + class Satellite(ABC): dyn_type: type["DynamicsModel"] # Type of dynamics model used by this satellite @@ -153,7 +155,7 @@ def set_fsw(self, fsw_rate: float) -> "FSWModel": def reset_post_sim(self) -> None: """Called in environment reset, after simulator initialization""" - pass + self.info.append(REQUIRES_RETASKING) @property def observation_space(self) -> spaces.Box: @@ -235,6 +237,7 @@ def _update_timed_terminal_event( [f"self.TotalSim.CurrentNanos * {macros.NANO2SEC} >= {t_close}"], [ self._info_command(f"timed termination at {t_close:.1f} " + info), + self._satellite_command + f".info.append('{REQUIRES_RETASKING}')", ] + extra_actions, terminal=self.variable_interval, @@ -747,6 +750,7 @@ def _update_image_event(self, target: Target) -> None: [ self._info_command(f"imaged {target}"), self._satellite_command + ".imaged += 1", + self._satellite_command + f".info.append('{REQUIRES_RETASKING}')", ], terminal=self.variable_interval, ) diff --git a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py index d7b0745c..e3834991 100644 --- a/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py +++ b/tests/unittest/envs/general_satellite_tasking/scenario/test_sat_observations.py @@ -100,12 +100,14 @@ def test_init(self, sat_init): def test_explicit_normalization(self, sat_init): sat = so.TimeState(normalization_time=10.0) + sat.info = MagicMock() sat.simulator = MagicMock(sim_time=1.0) sat.reset_post_sim() assert sat.normalized_time() == 0.1 def test_implicit_normalization(self, sat_init): sat = so.TimeState(normalization_time=None) + sat.info = MagicMock() sat.simulator = MagicMock(sim_time=1.0, time_limit=10.0) sat.reset_post_sim() assert sat.normalized_time() == 0.1 diff --git a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py index 7d267c56..8dcd6054 100644 --- a/tests/unittest/envs/general_satellite_tasking/test_gym_env.py +++ b/tests/unittest/envs/general_satellite_tasking/test_gym_env.py @@ -7,7 +7,10 @@ GeneralSatelliteTasking, SingleSatelliteTasking, ) -from bsk_rl.envs.general_satellite_tasking.scenario.satellites import Satellite +from bsk_rl.envs.general_satellite_tasking.scenario.satellites import ( + REQUIRES_RETASKING, + Satellite, +) class TestGeneralSatelliteTasking: @@ -144,6 +147,23 @@ def test_step_bad_action(self): with pytest.raises(ValueError): env.step((0,)) + @patch.multiple(Satellite, __abstractmethods__=set()) + def test_step_retask_needed(self, capfd): + mock_sat = MagicMock() + env = SingleSatelliteTasking( + satellites=[mock_sat], + env_type=MagicMock(), + env_features=MagicMock(), + data_manager=MagicMock(reward=MagicMock(return_value=25.0)), + ) + env.simulator = MagicMock(sim_time=101.0) + env.step(None) + assert REQUIRES_RETASKING not in mock_sat.info + mock_sat.info = [REQUIRES_RETASKING] + env.step(None) + assert REQUIRES_RETASKING in mock_sat.info + assert "requires retasking but received no task" in capfd.readouterr().out + @pytest.mark.parametrize("sat_death", [True, False]) @pytest.mark.parametrize("timeout", [True, False]) @pytest.mark.parametrize("terminate_on_time_limit", [True, False])