diff --git a/src/bsk_rl/env/gym_env.py b/src/bsk_rl/env/gym_env.py index ffb0d991..431d0fa7 100644 --- a/src/bsk_rl/env/gym_env.py +++ b/src/bsk_rl/env/gym_env.py @@ -2,15 +2,15 @@ Three classes are provided for creating environments: -+-------------------------------------+------------+-------------+--------------------------------------------------------------------+ -| Environment | API | Agent Count | Purpose | -+-------------------------------------+------------+-------------+--------------------------------------------------------------------+ -| :class:`SingleSatelliteTasking` | Gymnasium | 1 | Single-agent training; compatible with most RL libraries. | -+-------------------------------------+------------+-------------+--------------------------------------------------------------------+ -| :class:`GeneralSatelliteTasking` | Gymnasium | ≥1 | Multi-agent testing; actions and observations are given in tuples. | -+-------------------------------------+------------+-------------+--------------------------------------------------------------------+ -| :class:`MultiagentSatelliteTasking` | PettingZoo | ≥1 | Multi-agent training; compatible with multiagency RL libraries. | -+-------------------------------------+------------+-------------+--------------------------------------------------------------------+ ++-------------------------------------+------------+---------------+--------------------------------------------------------------------+ +| **Environment** | **API** |**Agent Count**| **Purpose** | ++-------------------------------------+------------+---------------+--------------------------------------------------------------------+ +| :class:`SingleSatelliteTasking` | Gymnasium | 1 | Single-agent training; compatible with most RL libraries. | ++-------------------------------------+------------+---------------+--------------------------------------------------------------------+ +| :class:`GeneralSatelliteTasking` | Gymnasium | ≥1 | Multi-agent testing; actions and observations are given in tuples. | ++-------------------------------------+------------+---------------+--------------------------------------------------------------------+ +| :class:`MultiagentSatelliteTasking` | PettingZoo | ≥1 | Multi-agent training; compatible with multiagency RL libraries. | ++-------------------------------------+------------+---------------+--------------------------------------------------------------------+ Environments are customized by passing keyword arguments to the environment constructor. When using ``gym.make``, the syntax looks like this: @@ -105,7 +105,9 @@ def __init__( function}, where function is called at reset to set the value (used for randomization). sim_rate: Rate for model simulation [s]. - max_step_duration: Maximum time to propagate sim at a step [s]. + max_step_duration: Maximum time to propagate sim at a step [s]. If + satellites are using variable interval actions, the step duration will + be less than or equal to this value. failure_penalty: Reward for satellite failure. Should be nonpositive. time_limit: Time at which to truncate the simulation [s]. terminate_on_time_limit: Send terminations signal time_limit instead of just @@ -463,7 +465,7 @@ def __init__(self, *args, **kwargs) -> None: *args: Passed to :class:`GeneralSatelliteTasking`. **kwargs: Passed to :class:`GeneralSatelliteTasking`. """ - pass + super().__init__(*args, **kwargs) def reset( self, seed: int | None = None, options=None diff --git a/src/bsk_rl/env/scenario/actions.py b/src/bsk_rl/env/scenario/actions.py index 91183e86..9cb84701 100644 --- a/src/bsk_rl/env/scenario/actions.py +++ b/src/bsk_rl/env/scenario/actions.py @@ -1,216 +1,416 @@ -"""Satellite action types can be used to add actions to the agents.""" - +"""Satellite action types can be used to add actions to an agent. + +To configure the observation, set the ``action_spec`` attribute of a +:class:`~bsk_rl.env.scenario.satellites.Satellite` subclass. For example: + +.. code-block:: python + + class MyActionSatellite(Satellite): + action_spec = [ + Charge(duration=60.0), + Desat(duration=30.0), + Downlink(duration=60.0), + Image(n_ahead_image=10), + ] + +Actions in an ``action_spec`` should all be of the same subclass of :class:`Action`. The +following actions are currently available: + +Discrete Actions: :class:`DiscreteAction` +----------------------------------------- +For integer-indexable, discrete actions. + ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| **Action** |**Count**| **Description** | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| :class:`DiscreteFSWAction` | 1 | Call an arbitrary ``@action`` decorated function in the :class:`~bsk_rl.env.simulation.fsw.FSWModel`. | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| :class:`Charge` | 1 | Point the solar panels at the sun. | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| :class:`Drift` | 1 | Do nothing. | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| :class:`Desat` | 1 | Desaturate the reaction wheels with RCS thrusters. Needs to be called multiple times. | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| :class:`Downlink` | 1 | Downlink data to any ground station that is in range. | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ +| :class:`Image` | ≥1 | Image one of the next ``N`` upcoming, unimaged targets once in range. | ++----------------------------+---------+-------------------------------------------------------------------------------------------------------+ + +""" + +import logging +from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union + +if TYPE_CHECKING: # pragma: no cover + from bsk_rl.env.types import Satellite, Simulator import numpy as np from gymnasium import spaces from bsk_rl.env.scenario.environment_features import Target -from bsk_rl.env.scenario.satellites import ImagingSatellite, Satellite -from bsk_rl.utils.functional import bind, configurable +from bsk_rl.utils.functional import AbstractClassProperty, bind, configurable + + +def select_action_builder(satellite: "Satellite") -> "ActionBuilder": + """Identify the proper action builder based on a satellite's action spec. + + Args: + satellite: Satellite to build actions for. + Returns: + action builder of the appropriate type -class SatAction(Satellite): - """Base satellite subclass for composing actions.""" + :meta private: + """ + builder_types = [spec.builder_type for spec in satellite.action_spec] + if all([builder_type == builder_types[0] for builder_type in builder_types]): + return builder_types[0](satellite) + else: + raise NotImplementedError("Heterogenous action builders not supported.") + + +class ActionBuilder(ABC): + """:meta private:""" + + def __init__(self, satellite: "Satellite") -> None: + self.satellite = satellite + self.simulator: "Simulator" + self.action_spec = deepcopy(self.satellite.action_spec) + for act in self.action_spec: + act.link_satellite(self.satellite) + + def reset_post_sim(self) -> None: + """Perform any once-per-episode setup.""" + self.simulator = self.satellite.simulator # already a proxy + for act in self.action_spec: + act.link_simulator(self.simulator) # already a proxy + act.reset_post_sim() - pass + @property + @abstractmethod + def action_space(self) -> spaces.Space: + """Return the action space.""" + pass + @property + @abstractmethod + def action_description(self) -> Any: + """Return a description of the action space.""" + pass -class DiscreteSatAction(SatAction): - """Base satellite subclass for composing discrete actions.""" + @abstractmethod + def set_action(self, action: Any) -> None: + """Set the action to be taken.""" + pass - def __init__(self, *args, **kwargs) -> None: - """Construct satellite with discrete actions. - Actions are added to the satellite for each DiscreteSatAction subclass, and can - be accessed by index in order added. +class DiscreteActionBuilder(ActionBuilder): + """:meta private:""" + + def __init__(self, satellite: "Satellite") -> None: + super().__init__(satellite) + self.prev_action_key = None + + def reset_post_sim(self) -> None: + super().reset_post_sim() + self.prev_action_key = None + + @property + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(sum([act.n_actions for act in self.action_spec])) + + @property + def action_description(self) -> list[str]: + actions = [] + for act in self.action_spec: + if act.n_actions == 1: + actions.append(act.name) + else: + actions.extend([f"{act.name}_{i}" for i in range(act.n_actions)]) + return actions + + def set_action(self, action: int) -> None: + self.satellite._disable_timed_terminal_event() + if not np.issubdtype(type(action), np.integer): + logging.warning( + f"Action '{action}' is not an integer. Will attempt to use compatible set_action_override method." + ) + for act in self.action_spec: + try: + self.prev_action_key = act.set_action_override( + action, prev_action_key=self.prev_action_key + ) + return + except AttributeError: + pass + except TypeError: + pass + else: + raise ValueError( + f"Action '{action}' is not an integer and no compatible set_action_override method found." + ) + index = 0 + for act in self.action_spec: + if index + act.n_actions > action: + self.prev_action_key = act.set_action( + action - index, prev_action_key=self.prev_action_key + ) + return + index += act.n_actions + else: + raise ValueError(f"Action index {action} out of range.") + + +class Action(ABC): + builder_type: type[ActionBuilder] = AbstractClassProperty() #: :meta private: + + def __init__(self, name: str = "act") -> None: + """Base class for all actions. + + Args: + name: Name of the action. """ - super().__init__(*args, **kwargs) - self.action_list = [] - self.action_map = {} + self.name = name + self.satellite: "Satellite" + self.simulator: "Simulator" - def reset_pre_sim(self) -> None: - """Reset the previous action key.""" - self.prev_action_key = None # Used to avoid retasking of BSK tasks - return super().reset_pre_sim() + def link_satellite(self, satellite: "Satellite") -> None: + """Link the action to a satellite. - def add_action( - self, act_fn, act_name: Optional[str] = None, n_actions: Optional[int] = None - ): - """Add an action to the action map. + Args: + satellite: Satellite to link to + + :meta private: + """ + self.satellite = satellite # already a proxy + + def link_simulator(self, simulator: "Simulator") -> None: + """Link the action to a simulator. Args: - act_fn: Function to call when selecting action. Takes as a keyword - prev_action_key, used to avoid retasking of BSK models. Can accept an - integer argument. - act_name: String to refer to action. - n_actions: If not none, add action n_actions times, calling it with an - increasing integer argument for each subsequent action. + simulator: Simulator to link to + + :meta private: """ - if act_name is None: - act_name = act_fn.__name__ + self.simulator = simulator # already a proxy - if n_actions is None: - self.action_map[f"{len(self.action_list)}"] = act_name - self.action_list.append(act_fn) - else: - self.action_map[ - f"{len(self.action_list)}-{len(self.action_list)+n_actions-1}" - ] = act_name - for i in range(n_actions): - act_i = self.generate_indexed_action(act_fn, i) - act_i.__name__ = f"act_{act_fn.__name__}_{i}" - self.action_list.append(bind(self, deepcopy(act_i))) + def reset_post_sim(self) -> None: # pragma: no cover + """Perform any once-per-episode setup.""" + pass + + @abstractmethod + def set_action(self, action: Any) -> None: # pragma: no cover + """Execute code to perform an action.""" + pass + + +class DiscreteAction(Action): + builder_type = DiscreteActionBuilder + + def __init__(self, name: str = "discrete_act", n_actions: int = 1): + """Base class for discrete, integer-indexable actions. - def generate_indexed_action(self, act_fn, index: int): - """Create an indexed action function. + A discrete action may represent multiple indexed actions of the same type. - Makes an indexed action function from an action function that takes an index - as an argument. + Optionally, discrete actions may have a ``set_action_override`` function defined. + If the action passed to the satellite is not an integer, the satellite will iterate + over the ``action_spec`` and attempt to call ``set_action_override`` on each action + until one is successful. Args: - act_fn: Action function to index. - index: Index to pass to act_fn. + name: Name of the action. + n_actions: Number of actions available. """ + super().__init__(name=name) + self.n_actions = n_actions + + @abstractmethod + def set_action(self, action: int, prev_action_key=None) -> str: + """Activate an action by local index.""" + pass + + +class DiscreteFSWAction(DiscreteAction): + def __init__( + self, + fsw_action, + name=None, + duration: Optional[float] = None, + reset_task: bool = False, + ): + """Discrete action to task a flight software action function. - def act_i(self, prev_action_key=None) -> Any: - return getattr(self, act_fn.__name__)( - index, prev_action_key=prev_action_key - ) + This action executes a function of a :class:`~bsk_rl.env.simulation.fsw.FSWModel` + instance that takes no arguments, typically decorated with ``@action``. - return act_i + Args: + fsw_action: Name of the flight software function to task. + name: Name of the action. If not specified, defaults to the ``fsw_action`` name. + duration: Duration of the action in seconds. Defaults to a large value so that + the :class:`~bsk_rl.env.gym_env.GeneralSatelliteTasking` ``max_step_duration`` + controls step length. + reset_task: If true, reset the action if the previous action was the same. + Generally, this parameter should be false to ensure realistic, continuous + operation of satellite modes; however, some Basilisk modules may require + frequent resetting for normal operation. + """ + if name is None: + name = fsw_action + super().__init__(name=name, n_actions=1) + self.fsw_action = fsw_action + self.reset_task = reset_task + if duration is None: + duration = 1e9 + self.duration = duration + + def set_action(self, action: int, prev_action_key=None) -> str: + """Activate the ``fsw_action`` function. - def set_action(self, action: int): - """Call action function my index.""" - self._disable_timed_terminal_event() - self.prev_action_key = self.action_list[action]( - prev_action_key=self.prev_action_key - ) # Update prev action data to avoid retasking + Args: + action: Should always be ``1``. + prev_action_key: Previous action key. - @property - def action_space(self) -> spaces.Discrete: - """Infer action space.""" - return spaces.Discrete(len(self.action_list)) + Returns: + The name of the activated action. + """ + assert action == 0 + self.satellite.log_info(f"{self.name} tasked for {self.duration} seconds") + self.satellite._update_timed_terminal_event( + self.simulator.sim_time + self.duration, info=f"for {self.fsw_action}" + ) + if self.reset_task or prev_action_key != self.fsw_action: + getattr(self.satellite.fsw, self.fsw_action)() -def fsw_action_gen( - fsw_action: str, action_duration: float = 1e9, always_reset: bool = False -) -> type: - """Generate an action class for a FSW @action. + return self.fsw_action - Args: - fsw_action: Function name of FSW action. - action_duration: Time to task action for. - always_reset: Reset action if selected more than once in a row. - Returns: - Satellite action class with fsw_action action. - """ +class Charge(DiscreteFSWAction): + def __init__(self, name: Optional[str] = None, duration: Optional[float] = None): + """Action to enter a sun-pointing charging mode (:class:`~bsk_rl.env.simulation.fsw.BasicFSWModel.action_charge`). - @configurable - class FSWAction(DiscreteSatAction): - def __init__( - self, *args, action_duration: float = action_duration, **kwargs - ) -> None: - """Discrete action to perform a fsw action. - - Typically this is includes a function decorated by @action. - - Args: - action_duration: Time to act when action selected. [s] - args: Passed through to satellite - kwargs: Passed through to satellite - - """ - super().__init__(*args, **kwargs) - setattr(self, fsw_action + "_duration", action_duration) - - def act(self, prev_action_key=None) -> str: - """Activate action. - - Returns: - action key - """ - duration = getattr(self, fsw_action + "_duration") - self.log_info(f"{fsw_action} tasked for {duration} seconds") - self._disable_timed_terminal_event() - self._update_timed_terminal_event( - self.simulator.sim_time + duration, info=f"for {fsw_action}" - ) - if prev_action_key != fsw_action or always_reset: - getattr(self.fsw, fsw_action)() - return fsw_action + Charging will only occur if the satellite is in sunlight. - act.__name__ = f"act_{fsw_action}" + Args: + name: Action name. + duration: Time to task action, in seconds. + """ + super().__init__(fsw_action="action_charge", name=name, duration=duration) - self.add_action( - bind(self, act), - act_name=fsw_action, - ) - return FSWAction +class Drift(DiscreteFSWAction): + def __init__(self, name: Optional[str] = None, duration: Optional[float] = None): + """Action to disable all FSW tasks (:class:`~bsk_rl.env.simulation.fsw.BasicFSWModel.action_drift`). + Args: + name: Action name. + duration: Time to task action, in seconds. + """ + super().__init__(fsw_action="action_drift", name=name, duration=duration) -# Charges the satellite -ChargingAction = fsw_action_gen("action_charge") -# Disables all actuators and control -DriftAction = fsw_action_gen("action_drift") +class Desat(DiscreteFSWAction): + def __init__(self, name: Optional[str] = None, duration: Optional[float] = None): + """Action to desaturate reaction wheels (:class:`~bsk_rl.env.simulation.fsw.BasicFSWModel.action_desat`). -# Points in a specified direction while firing desat thrusters and desaturating wheels -DesatAction = fsw_action_gen("action_desat", always_reset=True) + This action must be called repeatedly to fully desaturate the reaction wheels. -# Points nadir while downlinking data -DownlinkAction = fsw_action_gen("action_downlink") + Args: + name: Action name. + duration: Time to task action, in seconds. + """ + super().__init__( + fsw_action="action_desat", name=name, duration=duration, reset_task=True + ) -@configurable -class ImagingActions(DiscreteSatAction, ImagingSatellite): - """Satellite subclass to add upcoming target imaging to action space.""" +class Downlink(DiscreteFSWAction): + def __init__(self, name: Optional[str] = None, duration: Optional[float] = None): + """Action to transmit data from the data buffer (:class:`~bsk_rl.env.simulation.fsw.ImagingFSWModel.action_downlink`). - def __init__(self, *args, n_ahead_act=10, **kwargs) -> None: - """Discrete action to image upcoming targets. + If not in range of a ground station (defined in + :class:`~bsk_rl.env.simulation.environment.GroundStationEnvModel`), no data will + be downlinked. Args: - n_ahead_act: Number of actions to include in action space. - args: Passed through to satellite - kwargs: Passed through to satellite + name: Action name. + duration: Time to task action, in seconds. """ - super().__init__(*args, **kwargs) - self.add_action(self.image, n_actions=n_ahead_act, act_name="image") + super().__init__(fsw_action="action_downlink", name=name, duration=duration) + - def image(self, target: Union[int, Target, str], prev_action_key=None) -> str: - """Activate imaging action. +class Scan(DiscreteFSWAction): + def __init__(self, name: Optional[str] = None, duration: Optional[float] = None): + """Action to collect data from a :class:`~bsk_rl.env.scenario.environment_features.UniformNadirFeature` (:class:`~bsk_rl.env.simulation.fsw.ContinuousImagingFSWModel.action_nadir_scan`). Args: - target: Target, in terms of upcoming index, Target, or ID, - prev_action_key: Previous action key + name: Action name. + duration: Time to task action, in seconds. + """ + super().__init__(fsw_action="action_nadir_scan", name=name, duration=duration) - Returns: - Target ID + +class Image(DiscreteAction): + def __init__( + self, + n_ahead_image: int, + name: str = "action_image", + ): + """Actions to image upcoming target (:class:`~bsk_rl.env.simulation.fsw.ImagingFSWModel.action_image`). + + Adds `n_ahead_image` actions to the action space, corresponding to the next + `n_ahead_image` unimaged targets. The action may be unsuccessful if the target + exits the satellite's field of regard before the satellite settles on the target + and takes an image. The action with stop as soon as the image is successfully + taken, or when the the target exits the field of regard. + + This action implements a `set_action_override` that allows a target to be tasked + based on the target's ID string or the Target object. + + Args: + name: Action name. + n_ahead_image: Number of unimaged, along-track targets to consider. """ - if np.issubdtype(type(target), np.integer): - self.log_info(f"target index {target} tasked") + from bsk_rl.env.scenario.satellites import ImagingSatellite + + self.satellite: "ImagingSatellite" + super().__init__(name=name, n_actions=n_ahead_image) - target = self.parse_target_selection(target) + def image( + self, target: Union[int, Target, str], prev_action_key: Optional[str] = None + ) -> str: + """:meta private:""" + target = self.satellite.parse_target_selection(target) if target.id != prev_action_key: - self.task_target_for_imaging(target) + self.satellite.task_target_for_imaging(target) else: - self.enable_target_window(target) + self.satellite.enable_target_window(target) return target.id - def set_action(self, action: Union[int, Target, str]): - """Allow the satellite to be tasked by Target or target id. + def set_action(self, action: int, prev_action_key: Optional[str] = None) -> str: + """Image a target by local index. + + Args: + action: Index of the target to image. + prev_action_key: Previous action key. - Allows for additional tasking modes in addition to action index-based tasking. + :meta_private: """ - self._disable_image_event() - if isinstance(action, (Target, str)): - self.prev_action_key = self.image(action, self.prev_action_key) - else: - super().set_action(action) + self.satellite.log_info(f"target index {action} tasked") + return self.image(action, prev_action_key) + + def set_action_override( + self, action: Union[Target, str], prev_action_key: Optional[str] = None + ) -> str: + """Image a target by target index, Target, or ID. + Args: + target: Target to image. + prev_action_key: Previous action key. -NadirImagingAction = fsw_action_gen("action_nadir_scan") + :meta_private: + """ + return self.image(action, prev_action_key) diff --git a/src/bsk_rl/env/scenario/observations.py b/src/bsk_rl/env/scenario/observations.py index 2d2e469e..08db20c4 100644 --- a/src/bsk_rl/env/scenario/observations.py +++ b/src/bsk_rl/env/scenario/observations.py @@ -1,8 +1,8 @@ """Satellite observation types can be used to add information to the observation. -:class:`SatObservation` provides an interface for creating new observation types. To -configure the observation, set the ``observation_spec`` attribute of the satellite -subclass. For example: +:class:`Observation` provides an interface for creating new observation types. To +configure the observation, set the ``observation_spec`` attribute of a +:class:`~bsk_rl.env.scenario.satellites.Satellite` subclass. For example: .. code-block:: python @@ -37,6 +37,8 @@ class MyObservationSatellite(Satellite): from copy import deepcopy from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from gymnasium import spaces + if TYPE_CHECKING: # pragma: no cover from bsk_rl.env.types import Satellite, Simulator @@ -48,15 +50,41 @@ class MyObservationSatellite(Satellite): logger = logging.getLogger(__name__) +def obs_dict_to_space(obs_dict): + """Convert an observation dictionary to a gym space. + + Args: + obs_dict: Observation dictionary + + Returns: + gym.Space: Observation space + + :meta private: + """ + if isinstance(obs_dict, dict): + return spaces.Dict( + {key: obs_dict_to_space(value) for key, value in obs_dict.items()} + ) + elif isinstance(obs_dict, list): + return spaces.Box( + low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float64 + ) + elif isinstance(obs_dict, (float, int)): + return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64) + else: + return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float64) + + class ObservationBuilder: - """:meta private:""" # noqa: D415 + """:meta private:""" def __init__(self, satellite: "Satellite", obs_type: type = np.ndarray) -> None: """Satellite subclass for composing observations. Args: satellite: Satellite to observe - obs_type: Datatype of satellite's returned observation + obs_type: Datatype of satellite's returned observation. Can be ``np.ndarray`` + (default), ``dict``, or ``list``. """ self.obs_type = obs_type self.obs_dict_cache = None @@ -121,8 +149,24 @@ def get_obs(self) -> Union[dict, np.ndarray, list]: else: raise ValueError(f"Invalid observation type: {self.obs_type}") + @property + def observation_space(self) -> spaces.Space: + """Space of the observation.""" + obs = self.get_obs() + if isinstance(obs, (list, np.ndarray)): + return spaces.Box(low=-1e16, high=1e16, shape=obs.shape, dtype=np.float64) + elif isinstance(obs, dict): + return obs_dict_to_space(obs) + else: + raise ValueError(f"Invalid observation type: {self.obs_type}") + + @property + def observation_description(self) -> Any: + """Human-interpretable description of observation space.""" + return self.obs_array_keys() + -class SatObservation(ABC): +class Observation(ABC): """Base observations class.""" def __init__(self, name: str = "obs") -> None: @@ -165,7 +209,7 @@ def get_obs(self) -> Any: pass -class SatProperties(SatObservation): +class SatProperties(Observation): """Add arbitrary `dynamics` and `fsw` .""" def __init__( @@ -239,7 +283,7 @@ def get_obs(self) -> dict[str, Any]: return obs -class Time(SatObservation): +class Time(Observation): def __init__(self, norm=None, name="time"): """Include the simulation time in the observation. @@ -276,7 +320,7 @@ def _target_angle(sat, opp): return np.arccos(np.dot(vector_target_spacecraft_P_hat, sat.fsw.c_hat_P)) -class OpportunityProperties(SatObservation): +class OpportunityProperties(Observation): _fn_map = { "priority": lambda sat, opp: opp[opp["type"]].priority, @@ -291,7 +335,7 @@ class OpportunityProperties(SatObservation): def __init__( self, *target_properties: dict[str, Any], - n_ahead_observe: int = 10, + n_ahead_observe: int, type="target", name=None, ): @@ -401,7 +445,7 @@ def get_obs(self): return obs -class Eclipse(SatObservation): +class Eclipse(Observation): def __init__(self, norm=5700.0, name="eclipse"): """Include a tuple of the next eclipse start and end times in the observation. diff --git a/src/bsk_rl/env/scenario/satellites.py b/src/bsk_rl/env/scenario/satellites.py index 0bb0f343..9135b905 100644 --- a/src/bsk_rl/env/scenario/satellites.py +++ b/src/bsk_rl/env/scenario/satellites.py @@ -3,7 +3,7 @@ import bisect import inspect import logging -from abc import ABC, abstractmethod +from abc import ABC from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from weakref import proxy @@ -12,7 +12,7 @@ DynamicsModel, FSWModel, Simulator, - SatObservation, + Observation, ) import numpy as np @@ -20,6 +20,7 @@ from gymnasium import spaces from scipy.optimize import minimize_scalar, root_scalar +from bsk_rl.env.scenario.actions import select_action_builder from bsk_rl.env.scenario.data import DataStore, UniqueImageStore from bsk_rl.env.scenario.environment_features import Target from bsk_rl.env.scenario.observations import ObservationBuilder @@ -41,7 +42,8 @@ class Satellite(ABC): dyn_type: type["DynamicsModel"] = AbstractClassProperty() fsw_type: type["FSWModel"] = AbstractClassProperty() - observation_spec: list["SatObservation"] = AbstractClassProperty() + observation_spec: list["Observation"] = AbstractClassProperty() + action_spec: list["Action"] = AbstractClassProperty() @classmethod def default_sat_args(cls, **kwargs) -> dict[str, Any]: @@ -95,6 +97,7 @@ def __init__( self.variable_interval = variable_interval self._timed_terminal_event_name = None self.observation_builder = ObservationBuilder(self, obs_type=obs_type) + self.action_builder = select_action_builder(self) @property def id(self) -> str: @@ -164,27 +167,35 @@ def set_fsw(self, fsw_rate: float) -> "FSWModel": def reset_post_sim(self) -> None: """Reset in environment reset, after simulator initialization.""" self.observation_builder.reset_post_sim() + self.action_builder.reset_post_sim() @property - def observation_space(self) -> spaces.Box: + def observation_space(self) -> spaces.Space: """Observation space for single satellite, determined from observation. Returns: gymanisium observation space """ - return spaces.Box( - low=-1e16, high=1e16, shape=self.get_obs().shape, dtype=np.float64 - ) + return self.observation_builder.observation_space + + @property + def observation_description(self) -> Any: + """Human-interpretable description of observation space.""" + return self.observation_builder.observation_description @property - @abstractmethod # pragma: no cover def action_space(self) -> spaces.Space: """Action space for single satellite. Returns: gymanisium action space """ - pass + return self.action_builder.action_space + + @property + def action_description(self) -> Any: + """Human-interpretable description of action space.""" + return self.action_builder.action_description def is_alive(self, log_failure=False) -> bool: """Check if the satellite is violating any aliveness requirements. @@ -273,8 +284,7 @@ def get_obs(self) -> SatObs: """ return self.observation_builder.get_obs() - @abstractmethod # pragma: no cover - def set_action(self, action: int) -> None: + def set_action(self, action: Any) -> None: """Enable certain processes in the simulator to command the satellite task. Should call an @action from FSW, among other things. @@ -282,7 +292,7 @@ def set_action(self, action: int) -> None: Args: action: action index """ - pass + self.action_builder.set_action(action) class AccessSatellite(Satellite): diff --git a/src/bsk_rl/env/types.py b/src/bsk_rl/env/types.py index 41e1cd97..415f65a5 100644 --- a/src/bsk_rl/env/types.py +++ b/src/bsk_rl/env/types.py @@ -4,7 +4,7 @@ from bsk_rl.env.scenario.communication import CommunicationMethod from bsk_rl.env.scenario.data import DataManager, DataStore, DataType from bsk_rl.env.scenario.environment_features import EnvironmentFeatures -from bsk_rl.env.scenario.observations import SatObservation +from bsk_rl.env.scenario.observations import Observation from bsk_rl.env.scenario.satellites import Satellite from bsk_rl.env.simulation.dynamics import DynamicsModel from bsk_rl.env.simulation.environment import EnvironmentModel diff --git a/tests/integration/env/scenario/test_int_communication.py b/tests/integration/env/scenario/test_int_communication.py index 2e4ef293..ecdb07fa 100644 --- a/tests/integration/env/scenario/test_int_communication.py +++ b/tests/integration/env/scenario/test_int_communication.py @@ -35,18 +35,18 @@ ) -class FullFeaturedSatellite(sats.SteeringImagerSatellite, act.ImagingActions): +class FullFeaturedSatellite(sats.SteeringImagerSatellite): observation_spec = [ obs.SatProperties(dict(prop="r_BN_P", module="dynamics", norm=6e6)), obs.Time(), ] + action_spec = [act.Image(n_ahead_image=10)] def make_communication_env(oes, comm_type): satellites = [ FullFeaturedSatellite( "EO-1", - n_ahead_act=10, sat_args=FullFeaturedSatellite.default_sat_args( oe=oe, imageAttErrorRequirement=0.05, diff --git a/tests/integration/env/scenario/test_int_environment_features.py b/tests/integration/env/scenario/test_int_environment_features.py index 7253afe2..f779a7d7 100644 --- a/tests/integration/env/scenario/test_int_environment_features.py +++ b/tests/integration/env/scenario/test_int_environment_features.py @@ -3,24 +3,23 @@ from bsk_rl.env.scenario import actions as act from bsk_rl.env.scenario import data from bsk_rl.env.scenario import observations as obs +from bsk_rl.env.scenario import satellites as sats from bsk_rl.env.scenario.environment_features import CityTargets, StaticTargets from bsk_rl.env.simulation import dynamics, environment, fsw from bsk_rl.utils.orbital import random_orbit def make_env(env_features): - class ImageSat( - act.ImagingActions, - ): + class ImageSat(sats.ImagingSatellite): dyn_type = dynamics.GroundStationDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [obs.Time()] + action_spec = [act.Image(n_ahead_image=10)] env = gym.make( "SingleSatelliteTasking-v1", satellites=ImageSat( "EO-1", - n_ahead_act=10, sat_args=ImageSat.default_sat_args( oe=random_orbit, imageAttErrorRequirement=0.05, diff --git a/tests/integration/env/scenario/test_int_sat_actions.py b/tests/integration/env/scenario/test_int_sat_actions.py index cb61c7c4..b907747c 100644 --- a/tests/integration/env/scenario/test_int_sat_actions.py +++ b/tests/integration/env/scenario/test_int_sat_actions.py @@ -5,6 +5,7 @@ from bsk_rl.env.scenario import actions as act from bsk_rl.env.scenario import data from bsk_rl.env.scenario import observations as obs +from bsk_rl.env.scenario import satellites as sats from bsk_rl.env.scenario.environment_features import StaticTargets, UniformNadirFeature from bsk_rl.env.simulation import dynamics, environment, fsw from bsk_rl.utils.orbital import random_orbit @@ -15,19 +16,16 @@ class TestImagingAndDownlink: - class ImageSat( - act.ImagingActions, - act.DownlinkAction, - ): + class ImageSat(sats.SteeringImagerSatellite): dyn_type = dynamics.GroundStationDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [obs.Time()] + action_spec = [act.Downlink(), act.Image(n_ahead_image=10)] env = gym.make( "SingleSatelliteTasking-v1", satellites=ImageSat( "EO-1", - # n_ahead_act=10, sat_args=ImageSat.default_sat_args( oe=random_orbit, imageAttErrorRequirement=0.05, @@ -75,12 +73,11 @@ def test_image_by_name(self): class TestChargingAction: - class ChargeSat( - act.ChargingAction, - ): + class ChargeSat(sats.Satellite): dyn_type = dynamics.BasicDynamicsModel fsw_type = fsw.BasicFSWModel observation_spec = [obs.Time()] + action_spec = [act.Charge()] env = gym.make( "SingleSatelliteTasking-v1", @@ -109,12 +106,11 @@ def test_charging_action(self): class TestDesatAction: - class DesatSat( - act.DesatAction, - ): + class DesatSat(sats.Satellite): dyn_type = dynamics.BasicDynamicsModel fsw_type = fsw.BasicFSWModel observation_spec = [obs.Time()] + action_spec = [act.Desat()] def make_env(self): return gym.make( @@ -171,18 +167,16 @@ def test_desat_action_pointing(self): class TestNadirImagingActions: - class ImageSat( - act.NadirImagingAction, - ): + class ImageSat(sats.Satellite): dyn_type = dynamics.ContinuousImagingDynModel fsw_type = fsw.ContinuousImagingFSWModel observation_spec = [obs.Time()] + action_spec = [act.Scan()] env = gym.make( "SingleSatelliteTasking-v1", satellites=ImageSat( "EO-1", - # n_ahead_act=10, sat_args=ImageSat.default_sat_args( oe=random_orbit, imageAttErrorRequirement=0.05, diff --git a/tests/integration/env/scenario/test_int_sat_observations.py b/tests/integration/env/scenario/test_int_sat_observations.py index ccc71d24..2a3af7f7 100644 --- a/tests/integration/env/scenario/test_int_sat_observations.py +++ b/tests/integration/env/scenario/test_int_sat_observations.py @@ -15,10 +15,7 @@ # Composed Observation Tests # ############################## class TestComposedState: - class ComposedPropSat( - act.DriftAction, - sats.ImagingSatellite, - ): + class ComposedPropSat(sats.ImagingSatellite): dyn_type = dynamics.ImagingDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [ @@ -30,6 +27,7 @@ class ComposedPropSat( obs.OpportunityProperties(dict(prop="priority"), n_ahead_observe=2), obs.Eclipse(), ] + action_spec = [act.Drift()] env = gym.make( "SingleSatelliteTasking-v1", @@ -60,7 +58,7 @@ def test_normd_property_state(self): class TestSatProperties: - class SatPropertiesSat(act.DriftAction): + class SatPropertiesSat(sats.Satellite): dyn_type = dynamics.BasicDynamicsModel fsw_type = fsw.BasicFSWModel observation_spec = [ @@ -69,6 +67,7 @@ class SatPropertiesSat(act.DriftAction): dict(prop="r_BN_N", norm=7000.0 * 1e3), ), ] + action_spec = [act.Drift()] env = gym.make( "SingleSatelliteTasking-v1", @@ -93,10 +92,11 @@ def test_normd_property_state(self): class TestTime: - class TimedSat(act.DriftAction): + class TimedSat(sats.Satellite): dyn_type = dynamics.BasicDynamicsModel fsw_type = fsw.BasicFSWModel observation_spec = [obs.Time()] + action_spec = [act.Drift()] env = gym.make( "SingleSatelliteTasking-v1", @@ -120,12 +120,13 @@ def test_normd_property_state(self): class TestOpportunityProperties: - class TargetSat(act.DriftAction, sats.ImagingSatellite): + class TargetSat(sats.ImagingSatellite): dyn_type = dynamics.ImagingDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [ obs.OpportunityProperties(dict(prop="priority"), n_ahead_observe=2) ] + action_spec = [act.Drift()] env = gym.make( "SingleSatelliteTasking-v1", @@ -149,10 +150,11 @@ def test_target_state(self): class TestEclipse: - class EclipseSat(act.DriftAction): + class EclipseSat(sats.Satellite): dyn_type = dynamics.BasicDynamicsModel fsw_type = fsw.BasicFSWModel observation_spec = [obs.Eclipse()] + action_spec = [act.Drift()] env = gym.make( "SingleSatelliteTasking-v1", @@ -177,7 +179,7 @@ def test_eclipse_state(self): class TestGroundStationProperties: - class GroundSat(act.DriftAction, sats.AccessSatellite): + class GroundSat(sats.AccessSatellite): dyn_type = dynamics.GroundStationDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [ @@ -189,6 +191,7 @@ class GroundSat(act.DriftAction, sats.AccessSatellite): type="ground_station", ), ] + action_spec = [act.Drift()] env = gym.make( "SingleSatelliteTasking-v1", diff --git a/tests/integration/env/scenario/test_int_satellites.py b/tests/integration/env/scenario/test_int_satellites.py index ad30a9d9..dc987f6f 100644 --- a/tests/integration/env/scenario/test_int_satellites.py +++ b/tests/integration/env/scenario/test_int_satellites.py @@ -5,24 +5,23 @@ from bsk_rl.env.scenario import actions as act from bsk_rl.env.scenario import data from bsk_rl.env.scenario import observations as obs +from bsk_rl.env.scenario import satellites as sats from bsk_rl.env.scenario.environment_features import StaticTargets from bsk_rl.env.simulation import dynamics, fsw from bsk_rl.utils.orbital import random_orbit class TestImagingSatellite: - class ImageSat( - act.ImagingActions.configure(), - ): + class ImageSat(sats.ImagingSatellite): dyn_type = dynamics.ImagingDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [obs.Time()] + action_spec = [act.Image(n_ahead_image=10)] env = gym.make( "SingleSatelliteTasking-v1", satellites=ImageSat( "EO-1", - n_ahead_act=10, initial_generation_duration=1000.0, generation_duration=100.0, sat_args=ImageSat.default_sat_args( diff --git a/tests/integration/env/simulation/test_int_dynamics.py b/tests/integration/env/simulation/test_int_dynamics.py index 573266d5..81b60363 100644 --- a/tests/integration/env/simulation/test_int_dynamics.py +++ b/tests/integration/env/simulation/test_int_dynamics.py @@ -4,6 +4,7 @@ from bsk_rl.env.scenario import actions as act from bsk_rl.env.scenario import data from bsk_rl.env.scenario import observations as obs +from bsk_rl.env.scenario import satellites as sats from bsk_rl.env.scenario.environment_features import StaticTargets from bsk_rl.env.simulation import dynamics, fsw from bsk_rl.utils.orbital import random_orbit @@ -27,19 +28,16 @@ class TestImagingDynModelStorage: ) def test_storageInit(self, storage_capacity, initial_storage): - class ImageSat( - act.ImagingActions, - act.DownlinkAction, - ): + class ImageSat(sats.ImagingSatellite): dyn_type = dynamics.ImagingDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [obs.Time()] + action_spec = [act.Downlink(), act.Image(n_ahead_image=10)] env = gym.make( "SingleSatelliteTasking-v1", satellites=ImageSat( "EO-1", - # n_ahead_act=10, sat_args=ImageSat.default_sat_args( oe=random_orbit, dataStorageCapacity=storage_capacity, @@ -70,18 +68,16 @@ class ImageSat( ) def test_storageInit_downlink(self, storage_capacity, initial_storage): - class ImageSat( - act.DownlinkAction, - ): + class ImageSat(sats.ImagingSatellite): dyn_type = dynamics.FullFeaturedDynModel fsw_type = fsw.ImagingFSWModel observation_spec = [obs.Time()] + action_spec = [act.Downlink()] env = gym.make( "SingleSatelliteTasking-v1", satellites=ImageSat( "EO-1", - # n_ahead_act=10, sat_args=ImageSat.default_sat_args( oe=random_orbit, dataStorageCapacity=storage_capacity, diff --git a/tests/integration/env/test_int_full_environments.py b/tests/integration/env/test_int_full_environments.py index 29fb6e05..707d3f39 100644 --- a/tests/integration/env/test_int_full_environments.py +++ b/tests/integration/env/test_int_full_environments.py @@ -14,11 +14,12 @@ from bsk_rl.utils.orbital import random_orbit -class FullFeaturedSatellite(sats.SteeringImagerSatellite, act.ImagingActions): +class FullFeaturedSatellite(sats.SteeringImagerSatellite): observation_spec = [ obs.SatProperties(dict(prop="r_BN_P", module="dynamics", norm=6e6)), obs.Time(), ] + action_spec = [act.Image(n_ahead_image=10)] multi_env = gym.make( diff --git a/tests/integration/env/test_int_gym_env.py b/tests/integration/env/test_int_gym_env.py index 449f9d74..04198c41 100644 --- a/tests/integration/env/test_int_gym_env.py +++ b/tests/integration/env/test_int_gym_env.py @@ -10,10 +10,9 @@ from bsk_rl.utils.orbital import random_orbit -class DoNothingSatellite(sats.SteeringImagerSatellite, act.DriftAction): - observation_spec = [ - obs.Time(), - ] +class DoNothingSatellite(sats.SteeringImagerSatellite): + observation_spec = [obs.Time()] + action_spec = [act.Drift()] class TestSingleSatelliteTasking: @@ -75,7 +74,7 @@ class TestSingleSatelliteDeath: satellites=DoNothingSatellite( "Skydiver", sat_args=DoNothingSatellite.default_sat_args( - rN=[0, 0, 7e6], vN=[0, 0, -100.0] + rN=[0, 0, 7e6], vN=[0, 0, -100.0], oe=None ), ), env_features=StaticTargets(n_targets=0), diff --git a/tests/unittest/env/scenario/test_actions.py b/tests/unittest/env/scenario/test_actions.py index c5c5c217..f37c5d79 100644 --- a/tests/unittest/env/scenario/test_actions.py +++ b/tests/unittest/env/scenario/test_actions.py @@ -7,211 +7,128 @@ from bsk_rl.env.scenario.environment_features import Target -@patch.multiple(act.DiscreteSatAction, __abstractmethods__=set()) -@patch("bsk_rl.env.scenario.satellites.Satellite.__init__") -@patch( - "bsk_rl.env.scenario.satellites.Satellite.reset_pre_sim", - MagicMock, -) -class TestDiscreteSatAction: - def test_init(self, sat_init): - act.DiscreteSatAction() - sat_init.assert_called_once() - - mock_action = MagicMock(__name__="some_action") - - @pytest.mark.parametrize( - "kwargs,expected_map,expected_list", - [ - (dict(act_fn=mock_action), {"0": "some_action"}, [mock_action]), - ( - dict(act_fn=mock_action, act_name="new_name"), - {"0": "new_name"}, - [mock_action], - ), - ], - ) - def test_add_single_action(self, sat_init, kwargs, expected_map, expected_list): - sat = act.DiscreteSatAction() - sat.add_action(**kwargs) - assert sat.action_map == expected_map - assert sat.action_list == expected_list - - @pytest.mark.parametrize("n_actions", [1, 3]) - def test_add_multiple_actions(self, sat_init, n_actions): - sat = act.DiscreteSatAction() - sat.some_action = MagicMock( - __name__="some_action", side_effect=lambda x, prev_action_key=None: x +@patch.multiple(act.ActionBuilder, __abstractmethods__=set()) +class TestActionBuilder: + def test_init(self): + action_spec = [MagicMock() for _ in range(3)] + satellite = MagicMock(action_spec=action_spec) + ab = act.ActionBuilder(satellite) + for a in ab.action_spec: + a.link_satellite.assert_called_once() + + def test_reset_post_sim(self): + ab = act.ActionBuilder(MagicMock(action_spec=[MagicMock() for _ in range(3)])) + ab.reset_post_sim() + for a in ab.action_spec: + a.link_simulator.assert_called_once() + a.reset_post_sim.assert_called_once() + + +class TestDiscreteActionBuilder: + def test_action_space(self): + satellite = MagicMock( + action_spec=[MagicMock(n_actions=1), MagicMock(n_actions=2)] ) - sat.add_action(sat.some_action, n_actions=n_actions) - assert sat.action_map == {f"0-{n_actions-1}": "some_action"} - assert [act() for act in sat.action_list] == list(range(n_actions)) - sat.some_action.assert_has_calls( - [call(i, prev_action_key=None) for i in range(n_actions)] + ab = act.DiscreteActionBuilder(satellite) + assert ab.action_space == spaces.Discrete(3) + + def test_action_description(self): + satellite = MagicMock( + action_spec=[ + MagicMock(n_actions=1), + MagicMock(n_actions=2), + ] ) - - @patch("bsk_rl.env.scenario.satellites.Satellite._disable_timed_terminal_event") - def test_set_action(self, sat_init, disable_timed): - sat = act.DiscreteSatAction() - sat.reset_pre_sim() - sat.action_list = [MagicMock(return_value="act_key")] - sat.set_action(0) - disable_timed.assert_called_once() - sat.action_list[0].assert_called_once() - assert sat.prev_action_key == "act_key" - - def test_action_space(self, sat_init): - sat = act.DiscreteSatAction() - sat.action_list = [0, 1, 2] - assert sat.action_space == spaces.Discrete(3) - - def test_reset_pre_sim(self, sat_init): - sat = act.DiscreteSatAction() - sat.prev_action_key = "some_action" - sat.reset_pre_sim() - assert sat.prev_action_key is None - - -@patch.multiple(act.DiscreteSatAction, __abstractmethods__=set()) -@patch("bsk_rl.env.scenario.satellites.Satellite.__init__") -@patch( - "bsk_rl.env.scenario.satellites.Satellite.reset_pre_sim", - MagicMock, -) -class TestFSWAction: - def test_init(self, sat_init): - FSWAct = act.fsw_action_gen("cool_action") - sat = FSWAct(action_duration=10.0) - sat_init.assert_called_once() - assert sat.cool_action_duration == 10.0 - - def make_action_sat(self): - FSWAct = act.fsw_action_gen("cool_action", 60.0) - sat = FSWAct() - sat.reset_pre_sim() - sat.fsw = MagicMock(cool_action=MagicMock()) - sat.log_info = MagicMock() - sat._disable_timed_terminal_event = MagicMock() - sat._update_timed_terminal_event = MagicMock() - sat.simulator = MagicMock(sim_time=0.0) - return sat - - def test_act(self, sat_init): - sat = self.make_action_sat() - assert sat.action_list[0].__name__ == "act_cool_action" - sat.set_action(0) - assert "cool_action" == sat.prev_action_key - sat.log_info.assert_called_once_with("cool_action tasked for 60.0 seconds") - sat.fsw.cool_action.assert_called_once() - - def make_action_sat_configured(self): - FSWAct = act.fsw_action_gen("cool_action", 59.0).configure(action_duration=60.0) - sat = FSWAct() - sat.reset_pre_sim() - sat.fsw = MagicMock(cool_action=MagicMock()) - sat.log_info = MagicMock() - sat._disable_timed_terminal_event = MagicMock() - sat._update_timed_terminal_event = MagicMock() - sat.simulator = MagicMock(sim_time=0.0) - return sat - - def test_act_configured(self, sat_init): - sat = self.make_action_sat_configured() - assert sat.action_list[0].__name__ == "act_cool_action" - sat.set_action(0) - assert "cool_action" == sat.prev_action_key - sat.log_info.assert_called_once_with("cool_action tasked for 60.0 seconds") - sat.fsw.cool_action.assert_called_once() - - def test_retask(self, sat_init): - sat = self.make_action_sat() - sat.set_action(0) - sat.set_action(0) - sat.fsw.cool_action.assert_called_once() - - -@patch.multiple(act.ImagingActions, __abstractmethods__=set()) -@patch("bsk_rl.env.scenario.satellites.ImagingSatellite.__init__") -class TestImagingActions: - def test_init(self, sat_init): - sat = act.ImagingActions(n_ahead_act=3) - sat_init.assert_called_once() - assert sat.action_map == {"0-2": "image"} - - class MockTarget(MagicMock, Target): - @property - def id(self): - return "target_1" - - @pytest.mark.parametrize("target", [1, "target_1", MockTarget()]) - def test_image(self, sat_init, target): - sat = act.ImagingActions() - sat.log_info = MagicMock() - sat.parse_target_selection = MagicMock(return_value=self.MockTarget()) - sat.task_target_for_imaging = MagicMock() - assert "target_1" == sat.image(target) - sat.task_target_for_imaging.assert_called_once_with( - sat.parse_target_selection() + satellite.action_spec[0].name = "foo" + satellite.action_spec[1].name = "bar" + ab = act.DiscreteActionBuilder(satellite) + assert ab.action_description == ["foo", "bar_0", "bar_1"] + + def test_set_action(self): + satellite = MagicMock( + action_spec=[ + MagicMock(n_actions=1, set_action=MagicMock(return_value="foo")), + MagicMock(n_actions=2, set_action=MagicMock(return_value="bar")), + MagicMock(n_actions=1, set_action=MagicMock(return_value="baz")), + ] + ) + ab = act.DiscreteActionBuilder(satellite) + ab.set_action(0) + assert ab.action_spec[0].set_action.call_args == call(0, prev_action_key=None) + ab.set_action(1) + assert ab.action_spec[1].set_action.call_args == call(0, prev_action_key="foo") + ab.set_action(2) + assert ab.action_spec[1].set_action.call_args == call(1, prev_action_key="bar") + ab.set_action(3) + assert ab.action_spec[2].set_action.call_args == call(0, prev_action_key="bar") + + def test_set_action_override(self): + satellite = MagicMock( + action_spec=[ + MagicMock(n_actions=1, set_action_override=None), + MagicMock(n_actions=2, set_action_override=MagicMock()), + ] ) + ab = act.DiscreteActionBuilder(satellite) + ab.set_action("foo") + assert ab.action_spec[1].set_action_override.call_args == call( + "foo", prev_action_key=None + ) + - @pytest.mark.parametrize("target", [1, "target_1", MockTarget()]) - def test_image_retask(self, sat_init, target): - sat = act.ImagingActions() - sat.log_info = MagicMock() - sat.enable_target_window = MagicMock() - sat.parse_target_selection = MagicMock(return_value=self.MockTarget()) - sat.task_target_for_imaging = MagicMock() - sat.image(target, prev_action_key="target_1") - sat.task_target_for_imaging.assert_not_called() - sat.enable_target_window.assert_called() - - @patch("bsk_rl.env.scenario.actions.DiscreteSatAction.set_action") - @pytest.mark.parametrize("target", [1, "target_1", MockTarget()]) - def test_set_action(self, sat_init, discrete_set, target): - sat = act.ImagingActions() - sat.prev_action_key = None - sat._disable_image_event = MagicMock() - sat.image = MagicMock() - sat.set_action(target) - sat._disable_image_event.assert_called() - if isinstance(target, int): - discrete_set.assert_called_once() - elif isinstance(target, (Target, str)): - sat.image.assert_called_once() - - -@patch.multiple(act.NadirImagingAction, __abstractmethods__=set()) -@patch("bsk_rl.env.scenario.satellites.Satellite.__init__") -class TestNadirImagingActions: - def test_init(self, sat_init): - sat = act.NadirImagingAction() - sat_init.assert_called_once() - assert sat.action_map == {"0": "action_nadir_scan"} - - -@patch.multiple(act.ChargingAction, __abstractmethods__=set()) -@patch.multiple(act.DriftAction, __abstractmethods__=set()) -@patch.multiple(act.DesatAction, __abstractmethods__=set()) -@patch.multiple(act.DownlinkAction, __abstractmethods__=set()) -@patch.multiple(act.ImagingActions, __abstractmethods__=set()) -@patch("bsk_rl.env.scenario.satellites.ImagingSatellite.__init__") -def test_combination(sat_init): - class ComboAct( - act.ImagingActions.configure(n_ahead_act=3), - act.DownlinkAction, - act.DesatAction, - act.DriftAction, - act.ChargingAction, - ): - pass - - sat = ComboAct() - assert sat.action_map == { - "0": "action_charge", - "1": "action_drift", - "2": "action_desat", - "3": "action_downlink", - "4-6": "image", - } - assert len(sat.action_list) == 7 - sat_init.assert_called_once() +class TestDiscreteFSWAction: + def test_set_action(self): + fswact = act.DiscreteFSWAction("action_fsw") + fswact.satellite = MagicMock() + fswact.simulator = MagicMock() + fswact.set_action(0) + fswact.satellite.fsw.action_fsw.assert_called_once() + + def test_set_action_again(self): + fswact = act.DiscreteFSWAction("action_fsw") + fswact.satellite = MagicMock() + fswact.simulator = MagicMock() + fswact.set_action(0, prev_action_key="action_fsw") + fswact.satellite.fsw.action_fsw.assert_not_called() + + def test_set_action_reset(self): + fswact = act.DiscreteFSWAction("action_fsw", reset_task=True) + fswact.satellite = MagicMock() + fswact.simulator = MagicMock() + fswact.set_action(0, prev_action_key="action_fsw") + fswact.satellite.fsw.action_fsw.assert_called_once() + + +class TestImage: + target = MagicMock() + target.id = "target_1" + + def test_image(self): + image = act.Image(n_ahead_image=10) + image.satellite = MagicMock() + image.satellite.parse_target_selection.return_value = self.target + out = image.image(5, None) + image.satellite.task_target_for_imaging.assert_called_once_with(self.target) + assert out == "target_1" + + def test_image_retask(self): + image = act.Image(n_ahead_image=10) + image.satellite = MagicMock() + image.satellite.parse_target_selection.return_value = self.target + out = image.image(5, "target_1") + image.satellite.enable_target_window.assert_called_once_with(self.target) + assert out == "target_1" + + def test_set_action(self): + image = act.Image(n_ahead_image=10) + image.satellite = MagicMock() + image.image = MagicMock() + image.set_action(5) + image.image.assert_called_once_with(5, None) + + def test_set_action_override(self): + image = act.Image(n_ahead_image=10) + image.satellite = MagicMock() + image.image = MagicMock() + image.set_action_override("image") + image.image.assert_called_once_with("image", None) diff --git a/tests/unittest/env/scenario/test_observations.py b/tests/unittest/env/scenario/test_observations.py index 6224be85..bf6d46be 100644 --- a/tests/unittest/env/scenario/test_observations.py +++ b/tests/unittest/env/scenario/test_observations.py @@ -2,6 +2,7 @@ import numpy as np import pytest +from gymnasium import spaces from bsk_rl.env.scenario import observations as obs @@ -64,6 +65,41 @@ def test_obs_cache(self): ob.satellite.simulator.sim_time = 1.0 assert ob.get_obs()[0] == 1 + @pytest.mark.parametrize( + "observation,space", + [ + ( + np.array([1]), + spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64), + ), + ( + np.array([1, 2]), + spaces.Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float64), + ), + ( + {"a": 1, "b": {"c": 1}}, + spaces.Dict( + { + "a": spaces.Box( + low=-1e16, high=1e16, shape=(1,), dtype=np.float64 + ), + "b": spaces.Dict( + { + "c": spaces.Box( + low=-1e16, high=1e16, shape=(1,), dtype=np.float64 + ) + } + ), + } + ), + ), + ], + ) + def test_obs_space(self, observation, space): + ob = obs.ObservationBuilder(MagicMock()) + ob.get_obs = MagicMock(return_value=observation) + assert ob.observation_space == space + class TestSatProperties: def test_init(self): @@ -133,6 +169,7 @@ def test_init(self): dict( prop="double_priority", fn=lambda sat, opp: opp["target"].priority * 2.0 ), + n_ahead_observe=2, ) assert ob.target_properties[0]["fn"] assert ob.target_properties[0]["name"] == "location_normd" @@ -157,7 +194,7 @@ def test_get_obs(self): def test_init_bad(self): with pytest.raises(ValueError): - obs.OpportunityProperties(dict(prop="not_a_prop")) + obs.OpportunityProperties(dict(prop="not_a_prop"), n_ahead_observe=2) class TestEclipse: diff --git a/tests/unittest/env/scenario/test_satellites.py b/tests/unittest/env/scenario/test_satellites.py index e87688ce..7287efc2 100644 --- a/tests/unittest/env/scenario/test_satellites.py +++ b/tests/unittest/env/scenario/test_satellites.py @@ -14,6 +14,7 @@ @patch.multiple(sats.Satellite, __abstractmethods__=set()) @patch("bsk_rl.env.scenario.satellites.Satellite.observation_spec", MagicMock()) +@patch("bsk_rl.env.scenario.satellites.Satellite.action_spec", [MagicMock()]) class TestSatellite: sats.Satellite.dyn_type = MagicMock(with_defaults=MagicMock(defaults={"a": 1})) Task.with_defaults = MagicMock(defaults={"c": 3}) @@ -67,18 +68,6 @@ def test_generate_sat_args(self): # assert sat.info == [] # assert sat._timed_terminal_event_name is None - @pytest.mark.parametrize( - "obs,space", - [ - (np.array([1]), Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64)), - (np.array([1, 2]), Box(low=-1e16, high=1e16, shape=(2,), dtype=np.float64)), - ], - ) - def test_obs_space(self, obs, space): - sat = sats.Satellite(name="TestSat", sat_args={}) - sat.get_obs = MagicMock(return_value=obs) - assert sat.observation_space == space - @pytest.mark.parametrize( "dyn_state,fsw_state", [(False, False), (False, True), (True, False), (True, True)],