diff --git a/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py b/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py index 7df395d7..a09d67e8 100644 --- a/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py +++ b/bsk_rl/envs/general_satellite_tasking/simulation/dynamics.py @@ -684,13 +684,16 @@ def _set_transmitter_power_sink( ) self.powerMonitor.addPowerNodeToModel(self.transmitterPowerSink.nodePowerOutMsg) - @default_args(dataStorageCapacity=20 * 8e6, bufferNames=None) + @default_args( + dataStorageCapacity=20 * 8e6, bufferNames=None, storageUnitValidCheck=True + ) def _set_storage_unit( self, dataStorageCapacity: int, transmitterNumBuffers: Optional[int] = None, bufferNames: Optional[Iterable[str]] = None, priority: int = 699, + storageUnitValidCheck: bool = True, **kwargs, ) -> None: """Configure the storage unit and its buffers. @@ -701,12 +704,15 @@ def _set_storage_unit( given. bufferNames: List of buffer names to use. Named by number if None. priority: Model priority. + storageUnitValidCheck: If True, check that the storage level is below the + storage capacity. """ self.storageUnit = partitionedStorageUnit.PartitionedStorageUnit() self.storageUnit.ModelTag = "storageUnit" + self.satellite.id self.storageUnit.storageCapacity = dataStorageCapacity # bits self.storageUnit.addDataNodeToModel(self.instrument.nodeDataOutMsg) self.storageUnit.addDataNodeToModel(self.transmitter.nodeDataOutMsg) + self.storageUnitValidCheck = storageUnitValidCheck # Add all of the targets to the data buffer if bufferNames is None: for buffer_idx in range(transmitterNumBuffers): @@ -733,7 +739,13 @@ def _set_storage_unit( @aliveness_checker def data_storage_valid(self) -> bool: """Check that the buffer has not run out of space.""" - return self.storage_level <= self.storageUnit.storageCapacity + storage_check = self.storageUnitValidCheck + if storage_check: + return self.storage_level < self.storageUnit.storageCapacity or np.isclose( + self.storage_level, self.storageUnit.storageCapacity + ) + else: + return True @default_args( groundLocationPlanetRadius=orbitalMotion.REQ_EARTH * 1e3, @@ -809,11 +821,12 @@ def _set_instrument( self.task_name, self.instrument, ModelPriority=priority ) - @default_args(dataStorageCapacity=20 * 8e6) + @default_args(dataStorageCapacity=20 * 8e6, storageUnitValidCheck=True) def _set_storage_unit( self, dataStorageCapacity: int, priority: int = 699, + storageUnitValidCheck: bool = True, **kwargs, ) -> None: """Configure the storage unit and its buffers. @@ -821,12 +834,15 @@ def _set_storage_unit( Args: dataStorageCapacity: Maximum data to be stored [bits] priority: Model priority. + storageUnitValidCheck: If True, check that the storage level is below the + storage capacity. """ self.storageUnit = simpleStorageUnit.SimpleStorageUnit() self.storageUnit.ModelTag = "storageUnit" + self.satellite.id self.storageUnit.storageCapacity = dataStorageCapacity # bits self.storageUnit.addDataNodeToModel(self.instrument.nodeDataOutMsg) self.storageUnit.addDataNodeToModel(self.transmitter.nodeDataOutMsg) + self.storageUnitValidCheck = storageUnitValidCheck # Add the storage unit to the transmitter self.transmitter.addStorageUnitToTransmitter( diff --git a/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py b/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py index c57d4077..866836ab 100644 --- a/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py +++ b/tests/unittest/envs/general_satellite_tasking/simulation/test_dynamics.py @@ -210,12 +210,20 @@ def test_storage_properties(self): assert dyn.storage_level_fraction == 0.5 @pytest.mark.parametrize( - "level,valid", - [(10, True), (0, True), (110, False)], + "level,valid_check,valid", + [ + (10, True, True), + (0, True, True), + (110, True, False), + (100.001, True, True), + (10, False, True), + (110, False, True), + ], ) - def test_data_storage_valid(self, level, valid): + def test_data_storage_valid(self, level, valid_check, valid): dyn = ImagingDynModel(MagicMock(simulator=MagicMock()), 1.0) dyn.storageUnit = MagicMock() + dyn.storageUnitValidCheck = valid_check dyn.storageUnit.storageUnitDataOutMsg.read.return_value.storageLevel = level dyn.storageUnit.storageCapacity = 100.0 assert dyn.data_storage_valid() == valid