Skip to content

Commit

Permalink
Issue #45: Tests for new features
Browse files Browse the repository at this point in the history
  • Loading branch information
LorenzzoQM committed Sep 15, 2023
1 parent 8487b4b commit e0ba845
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from bsk_rl.envs.general_satellite_tasking.scenario import sat_actions as sa
from bsk_rl.envs.general_satellite_tasking.scenario import sat_observations as so
from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
NadirTarget,
StaticTargets,
)
from bsk_rl.envs.general_satellite_tasking.simulation import dynamics, environment, fsw
Expand Down Expand Up @@ -148,3 +149,42 @@ def test_desat_action(self):
assert np.linalg.norm(
self.env.satellite.dynamics.wheel_speeds
) < np.linalg.norm(init_speeds)


class TestNadirImagingActions:
class ImageSat(
sa.NadirImagingActions,
so.TimeState,
):
dyn_type = dynamics.ContinuousImagingDynModel
fsw_type = fsw.ContinuousImagingFSWModel

env = gym.make(
"SingleSatelliteTasking-v1",
satellites=ImageSat(
"EO-1",
n_ahead_act=10,
sat_args=ImageSat.default_sat_args(
oe=random_orbit,
imageAttErrorRequirement=0.05,
imageRateErrorRequirement=0.05,
instrumentBaudRate=1.0,
dataStorageCapacity=3.0,
transmitterBaudRate=-1.0,
),
),
env_type=environment.BasicEnvironmentModel,
env_args=environment.BasicEnvironmentModel.default_env_args(),
env_features=NadirTarget(),
data_manager=data.NoDataManager(),
sim_rate=1.0,
time_limit=10000.0,
max_step_duration=1e9,
disable_env_checker=True,
)

def test_image(self):
self.env.reset()
storage_init = self.env.satellite.dynamics.storage_level
self.env.step(0)
assert self.env.satellite.dynamics.storage_level > storage_init
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,81 @@ def test_calc_reward_custom_fn(self):
}
)
assert reward == approx(1.5)


class TestNadirScanningTimeData:
def test_add_null(self):
dat1 = data.NadirScanningTimeData()
dat2 = data.NadirScanningTimeData()
dat = dat1 + dat2
assert dat.scanning_time == 0.0

def test_add_to_null(self):
dat1 = data.NadirScanningTimeData(1.0)
dat2 = data.NadirScanningTimeData()
dat = dat1 + dat2
assert dat.scanning_time == 1.0

def test_add(self):
dat1 = data.NadirScanningTimeData(1.0)
dat2 = data.NadirScanningTimeData(3.0)
dat = dat1 + dat2
assert dat.scanning_time == 4.0


class TestScanningNadirTimeStore:
def test_get_log_state(self):
sat = MagicMock()
sat.dynamics.storageUnit.storageUnitDataOutMsg.read().storageLevel = 6
sat.dynamics.instrument.nodeBaudRate = 3
ds = data.ScanningNadirTimeStore(MagicMock(), sat)
assert ds._get_log_state() == 2.0

@pytest.mark.parametrize(
"before,after,new_time",
[
(0, 1, 1),
(1, 2, 1),
(1, 1, 0),
],
)
def test_compare_log_states(self, before, after, new_time):
sat = MagicMock()
ds = data.ScanningNadirTimeStore(MagicMock(), sat)
dat = ds._compare_log_states(before, after)
assert dat.scanning_time == new_time


class TestNadirScanningManager:
def test_calc_reward(self):
dm = data.NadirScanningManager(MagicMock())
dm.data = data.NadirScanningTimeData([])
reward = dm._calc_reward(
{
"sat1": data.NadirScanningTimeData(1),
"sat2": data.NadirScanningTimeData(2),
}
)
assert reward == approx(3)

def test_calc_reward_existing(self):
dm = data.NadirScanningManager(MagicMock())
dm.data = data.NadirScanningTimeData(1)
reward = dm._calc_reward(
{
"sat1": data.NadirScanningTimeData(2),
"sat2": data.NadirScanningTimeData(3),
}
)
assert reward == approx(5)

def test_calc_reward_custom_fn(self):
dm = data.NadirScanningManager(MagicMock(), reward_fnc=lambda x: 1 / x)
dm.data = data.NadirScanningTimeData([])
reward = dm._calc_reward(
{
"sat1": data.NadirScanningTimeData(2),
"sat2": data.NadirScanningTimeData(2),
}
)
assert reward == approx(1.0)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from bsk_rl.envs.general_satellite_tasking.scenario.environment_features import (
CityTargets,
NadirTarget,
StaticTargets,
Target,
lla2ecef,
Expand Down Expand Up @@ -146,3 +147,10 @@ def test_regenerate_targets_offset(self, mock_read_csv, mock_lla2ecef):
for target in ct.targets:
assert np.linalg.norm(target.location - nominal) <= 0.03
assert np.linalg.norm(target.location) == approx(1.0)


class TestNadirTarget:
def test_init(self):
st = NadirTarget()
assert st.name == "nadir"
assert st.location == [0, 0, 0]
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,28 @@ def test_set_action(self, sat_init, discrete_set, target):
sat.image.assert_called_once()


@patch.multiple(sa.NadirImagingActions, __abstractmethods__=set())
@patch(
"bsk_rl.envs.general_satellite_tasking.scenario.satellites.ImagingSatellite.__init__"
)
class TestNadirImagingActions:
def test_init(self, sat_init):
sat = sa.NadirImagingActions()
sat_init.assert_called_once()
assert sat.action_map == {"0-0": "nadirImage"}

class MockTarget(MagicMock, Target):
@property
def id(self):
return "target_1"

@pytest.mark.parametrize("target", [1, "target_1", MockTarget()])
def test_image(self, sat_init, target):
sat = sa.NadirImagingActions()
sat.task_target_for_imaging = MagicMock()
assert "nadir_image" == sat.nadirImage(target)


@patch.multiple(sa.ChargingAction, __abstractmethods__=set())
@patch.multiple(sa.DriftAction, __abstractmethods__=set())
@patch.multiple(sa.DesatAction, __abstractmethods__=set())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from bsk_rl.envs.general_satellite_tasking.simulation import environment
from bsk_rl.envs.general_satellite_tasking.simulation.dynamics import (
BasicDynamicsModel,
ContinuousImagingDynModel,
DynamicsModel,
GroundStationDynModel,
ImagingDynModel,
Expand Down Expand Up @@ -256,3 +257,15 @@ def test_init_objects(self, *args):
GroundStationDynModel(MagicMock(simulator=MagicMock()), 1.0)
for setter in args:
setter.assert_called_once()


@patch(imdyn + "requires_env", MagicMock(return_value=[]))
@patch(imdyn + "_init_dynamics_objects", MagicMock())
class TestContinuousImagingDynModel:
def test_storage_properties(self):
dyn = ContinuousImagingDynModel(MagicMock(simulator=MagicMock()), 1.0)
dyn.storageUnit = MagicMock()
dyn.storageUnit.storageUnitDataOutMsg.read.return_value.storageLevel = 50.0
dyn.storageUnit.storageCapacity = 100.0
assert dyn.storage_level == 50.0
assert dyn.storage_level_fraction == 0.5

0 comments on commit e0ba845

Please sign in to comment.