diff --git a/src/bsk_rl/__init__.py b/src/bsk_rl/__init__.py index b5ce49e..78a1c78 100644 --- a/src/bsk_rl/__init__.py +++ b/src/bsk_rl/__init__.py @@ -2,7 +2,12 @@ from gymnasium.envs.registration import register from bsk_rl.check_bsk_version import check_bsk_version -from bsk_rl.gym import ConstellationTasking, GeneralSatelliteTasking, SatelliteTasking +from bsk_rl.gym import ( + NO_ACTION, + ConstellationTasking, + GeneralSatelliteTasking, + SatelliteTasking, +) __all__ = [ "GeneralSatelliteTasking", diff --git a/src/bsk_rl/gym.py b/src/bsk_rl/gym.py index 7547675..4f7d942 100644 --- a/src/bsk_rl/gym.py +++ b/src/bsk_rl/gym.py @@ -28,6 +28,8 @@ MultiSatAct = Iterable[SatAct] SatArgRandomizer = Callable[[list[Satellite]], dict[Satellite, dict[str, Any]]] +NO_ACTION = int(2**31) - 1 + class GeneralSatelliteTasking(Env, Generic[SatObs, SatAct]): @@ -356,7 +358,8 @@ def _step(self, actions: MultiSatAct) -> None: if len(actions) != len(self.satellites): raise ValueError("There must be the same number of actions and satellites") for satellite, action in zip(self.satellites, actions): - if action is not None: + satellite.info = [] # reset satellite info log + if action is not None and action != NO_ACTION: satellite.requires_retasking = False satellite.set_action(action) if not satellite.is_alive(): diff --git a/src/bsk_rl/utils/rllib.py b/src/bsk_rl/utils/rllib.py index 78fd0a0..74237c2 100644 --- a/src/bsk_rl/utils/rllib.py +++ b/src/bsk_rl/utils/rllib.py @@ -1,8 +1,27 @@ """``bsk_rl.utils.rllib`` is a collection of utilities for working with RLlib.""" -from typing import Any +from typing import Any, List, Optional +import gymnasium as gym +import numpy as np from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner +from ray.rllib.algorithms.ppo.tf.ppo_tf_learner import PPOTfLearner +from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.evaluation.postprocessing import Postprocessing +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.postprocessing.episodes import ( + add_one_ts_to_episodes_and_truncate, + remove_last_ts_from_data, + remove_last_ts_from_episodes_and_restore_truncateds, +) +from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary +from ray.rllib.utils.typing import EpisodeType + +from bsk_rl import NO_ACTION def unpack_config(env): @@ -57,37 +76,509 @@ def pull_env_metrics(self, env) -> dict[str, float]: """ return {} - def on_episode_end( + def on_episode_step( self, env=None, metrics_logger=None, + episode=None, **kwargs, ) -> None: """Call pull_env_metrics and log the results. :meta private: """ - if "base_env" in kwargs: # Old RLlib stack - env = kwargs["base_env"].vector_env.envs[0] # noqa: F841 - env_data = self.pull_env_metrics(env) - for k, v in env_data.items(): - kwargs["episode"].custom_metrics[k] = v - else: # New RLlib stack - env = env.envs[0] - env_data = self.pull_env_metrics(env) - for k, v in env_data.items(): - metrics_logger.log_value(k, v) - - def on_train_result(self, *, algorithm, result, **kwargs): - """Log frames per second. + # print(episode.is_done) + env = env.envs[-1] + env_data = self.pull_env_metrics(env) + for k, v in env_data.items(): + metrics_logger.log_value(k, v, clear_on_reduce=True) + + # def on_episode_end( + # self, + # env=None, + # metrics_logger=None, + # **kwargs, + # ) -> None: + # """Call pull_env_metrics and log the results. + + # :meta private: + # """ + # if "base_env" in kwargs: # Old RLlib stack + # env = kwargs["base_env"].vector_env.envs[0] # noqa: F841 + # env_data = self.pull_env_metrics(env) + # for k, v in env_data.items(): + # kwargs["episode"].custom_metrics[k] = v + # else: # New RLlib stack + # env = env.envs[-1] + # env_data = self.pull_env_metrics(env) + # for k, v in env_data.items(): + # metrics_logger.log_value(k, v, clear_on_reduce=True) + + # def on_train_result(self, *, algorithm, result, **kwargs): + # """Log frames per second. + + # :meta private: + # """ + # if "num_env_steps_sampled_this_iter" in result: + # result["fps"] = ( + # result["num_env_steps_sampled_this_iter"] / result["time_this_iter_s"] + # ) + + +class MultiagentEpisodeDataCallbacks(DefaultCallbacks): + def __init__(self, *args, **kwargs): + """Log information at the end of each episode. + + Make a subclass of ``MultiagentEpisodeDataCallbacks`` and override + ``pull_env_metrics`` and ``pull_sat_metrics`` to log environment-specific + information at the end of each episode. Satellite metrics are logged per-satellite + and as a mean across all satellites. + + Note that satellites persist in the simulator even after death, so recorded values + from the end of the episode may not be the same as the values when the agent died. + """ + super().__init__(*args, **kwargs) + + def pull_env_metrics(self, env) -> dict[str, float]: + """Log environment metrics at the end of each episode. + + This function is called whenever ``env`` is terminated or truncated. It should + return a dictionary of metrics to log. + + Args: + env: An environment that has completed. + """ + return {} + + def pull_sat_metrics(self, env, satellite) -> dict[str, float]: + """Log per-satellite metrics at the end of each episode. + + This function is called whenever ``env`` is terminated or truncated. It should + return a dictionary of metrics to log. + + Args: + env: An environment that has completed. + satellite: A satellite in the environment. + """ + return {} + + def on_episode_end( + self, + env=None, + metrics_logger=None, + **kwargs, + ) -> None: + """Call pull_env_metrics and log the results. :meta private: """ - if "num_env_steps_sampled_this_iter" in result: - result["fps"] = ( - result["num_env_steps_sampled_this_iter"] / result["time_this_iter_s"] + env = env.par_env + env_data = self.pull_env_metrics(env) + for k, v in env_data.items(): + metrics_logger.log_value(k, v, clear_on_reduce=True) + + all_sat_data = [] + + for sat in env.satellites: + sat_data = self.pull_sat_metrics(env, sat) + all_sat_data.append(sat_data) + for k, v in sat_data.items(): + metrics_logger.log_value(f"{sat.id}/{k}", v, clear_on_reduce=True) + + for k in all_sat_data[0].keys(): + metrics_logger.log_value( + f"mean/{k}", + np.mean([sat_data[k] for sat_data in all_sat_data]), + clear_on_reduce=True, + ) + + +class ContinuePreviousAction(ConnectorV2): + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + + for sa_episode in self.single_agent_episode_iterator( + episodes, + agents_that_stepped_only=True, + ): + if not sa_episode.get_infos(-1)["requires_retasking"]: + if sa_episode.agent_id is None: + assert len(data[Columns.ACTIONS]) == 1 + id_tuple = list(data[Columns.ACTIONS].keys())[0] + else: + id_tuples = [ + id_tuple + for id_tuple in data[Columns.ACTIONS].keys() + if id_tuple[1] == sa_episode.agent_id + ] + if len(id_tuples) == 0: + return data + else: + id_tuple = id_tuples[0] + data[Columns.ACTIONS][id_tuple][0] = NO_ACTION + return data + + +class MakeAddedStepActionValid(ConnectorV2): + def __init__(self, *args, expected_train_batch_size, **kwargs): + super().__init__(*args, **kwargs) + self.expected_train_batch_size = expected_train_batch_size + + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + + total_episodes = 0 + total_steps = 0 + episode_lens = [] + episode_multi_agent_ids = [] + for episode in self.single_agent_episode_iterator( + episodes, agents_that_stepped_only=False + ): + episode_lens.append(len(episode)) + episode_multi_agent_ids.append(episode.multi_agent_episode_id) + + if episode_multi_agent_ids[0] is None: + total_episodes = len(episode_lens) + total_steps = sum(episode_lens) + else: + total_episodes = len(set(episode_multi_agent_ids)) + max_lens = {} + for episode_id, length in zip(episode_multi_agent_ids, episode_lens): + if episode_id not in max_lens or length > max_lens[episode_id]: + max_lens[episode_id] = length + total_steps = sum(max_lens.values()) + + one_ts_added = False + if total_steps == self.expected_train_batch_size + total_episodes: + one_ts_added = True + + if one_ts_added: + for episode in self.single_agent_episode_iterator( + episodes, agents_that_stepped_only=False + ): + last_action = NO_ACTION + for action in reversed(episode.actions): + if last_action == NO_ACTION: + last_action = action + else: + break + if last_action == NO_ACTION: + last_action = 0 + episode.actions[-1] = last_action + + episode.validate() + + return data + + +class CondenseMultiStepActions(ConnectorV2): + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + + for episode in self.single_agent_episode_iterator( + episodes, agents_that_stepped_only=False + ): + if NO_ACTION not in episode.actions: + continue + + action_idx = list( + np.argwhere( + [action != NO_ACTION for action in episode.actions.data] + ).flatten() + ) + obs_idx = action_idx.copy() + obs_idx.append(len(episode) - 1) + + lookback = episode.actions.data[: episode.actions.lookback] + new_lookback = episode.actions.lookback + for action in lookback: + if action == NO_ACTION: + new_lookback -= 1 + else: + break + + # Only keep non-None actions + episode.actions.data = np.array(episode.actions.data)[action_idx] + episode.actions.lookback = new_lookback + episode.extra_model_outputs[Columns.ACTION_DIST_INPUTS].data = ( + episode.extra_model_outputs[Columns.ACTION_DIST_INPUTS].data[action_idx] + ) + episode.extra_model_outputs[Columns.ACTION_DIST_INPUTS].lookback = ( + new_lookback + ) + episode.extra_model_outputs[Columns.ACTION_LOGP].data = ( + episode.extra_model_outputs[Columns.ACTION_LOGP].data[action_idx] + ) + episode.extra_model_outputs[Columns.ACTION_LOGP].lookback = new_lookback + + # Update episode length + episode.t = episode.t_started + len(episode.actions) + + # Only keep obs that resulted in those actions + episode.observations.data = np.array(episode.observations.data)[obs_idx] + episode.observations.lookback = new_lookback + + # Associate subsequent rewards with prior action + rewards = [] + requires_retasking = [] + for i, idx_start in enumerate(action_idx): + if i == len(action_idx) - 1: + idx_end = len(episode) - 1 + else: + idx_end = action_idx[i + 1] - 1 + rewards.append( + sum(episode.rewards[idx_start : idx_end + 1]) + ) # Doesn't discount over course of multistep + requires_retasking.append( + episode.infos.data[idx_start]["requires_retasking"] + ) + requires_retasking.append(True) + episode.rewards.data = np.array(rewards) + episode.rewards.lookback = new_lookback + + # Accumulate d_ts from prior action + d_ts = [] + for i, idx_end in enumerate(obs_idx): + if i == 0: + idx_start = 0 + else: + idx_start = obs_idx[i - 1] + 1 + d_ts.append( + sum( + info["d_ts"] + for info in episode.infos.data[idx_start : idx_end + 1] + ) + ) + episode.infos.data = [ + dict(d_ts=d_ts, requires_retasking=requires_retasking) + for d_ts, requires_retasking in zip(d_ts, requires_retasking) + ] + episode.infos.lookback = new_lookback + + episode.validate() + + return data + + +def compute_value_targets_time_discounted( + values, + rewards, + terminateds, + truncateds, + step_durations, + gamma: float, + lambda_: float, +): + """Computes value function (vf) targets given vf predictions and rewards. + + Note that advantages can then easily be computed via the formula: + advantages = targets - vf_predictions + """ + + # Shift step durations to associate with previous timestep + # delta_t->t+1 comes with t+1's info, but should be used with t + step_durations = np.concatenate((step_durations[1:], [step_durations[-1]])) + + # Force-set all values at terminals (not at truncations!) to 0.0. + orig_values = flat_values = values * (1.0 - terminateds) + + flat_values = np.append(flat_values, 0.0) + # intermediates = rewards + gamma * (1 - lambda_) * flat_values[1:] + # intermediates = rewards + gamma**step_durations * (1 - lambda_) * flat_values[1:] + intermediates = gamma**step_durations * (rewards + (1 - lambda_) * flat_values[1:]) + continues = 1.0 - terminateds + + Rs = [] + last = flat_values[-1] + for t in reversed(range(intermediates.shape[0])): + last = ( + intermediates[t] + + continues[t] * gamma ** step_durations[t] * lambda_ * last + ) + # last = ( + # intermediates[t] + # + continues[t] * gamma * lambda_ * last + # ) + + Rs.append(last) + if truncateds[t]: + last = orig_values[t] + + # Reverse back to correct (time) direction. + value_targets = np.stack(list(reversed(Rs)), axis=0) + + return value_targets.astype(np.float32) + + +class TimeDiscountedGAEPPOLearner(PPOLearner): + def _compute_gae_from_episodes( + self, + *, + episodes: Optional[List[EpisodeType]] = None, + ) -> tuple[Optional[dict[str, Any]], Optional[List[EpisodeType]]]: + if not episodes: + raise ValueError( + "`PPOLearner._compute_gae_from_episodes()` must have the `episodes` " + "arg provided! Otherwise, GAE/advantage computation can't be performed." + ) + + batch = {} + + sa_episodes_list = list( + self._learner_connector.single_agent_episode_iterator( + episodes, agents_that_stepped_only=False + ) + ) + # Make all episodes one ts longer in order to just have a single batch + # (and distributed forward pass) for both vf predictions AND the bootstrap + # vf computations. + orig_truncateds_of_sa_episodes = add_one_ts_to_episodes_and_truncate( + sa_episodes_list + ) + + # Call the learner connector (on the artificially elongated episodes) + # in order to get the batch to pass through the module for vf (and + # bootstrapped vf) computations. + batch_for_vf = self._learner_connector( + rl_module=self.module, + data={}, + episodes=episodes, + shared_data={}, + ) + + # print(batch_for_vf) + # Perform the value model's forward pass. + vf_preds = convert_to_numpy(self._compute_values(batch_for_vf)) + + for module_id, module_vf_preds in vf_preds.items(): + # Collect new (single-agent) episode lengths. + episode_lens_plus_1 = [ + len(e) + for e in sa_episodes_list + if e.module_id is None or e.module_id == module_id + ] + + # Remove all zero-padding again, if applicable, for the upcoming + # GAE computations. + module_vf_preds = unpad_data_if_necessary( + episode_lens_plus_1, module_vf_preds + ) + + # Compute value targets. + module_value_targets = compute_value_targets_time_discounted( + values=module_vf_preds, + rewards=unpad_data_if_necessary( + episode_lens_plus_1, + convert_to_numpy(batch_for_vf[module_id][Columns.REWARDS]), + ), + terminateds=unpad_data_if_necessary( + episode_lens_plus_1, + convert_to_numpy(batch_for_vf[module_id][Columns.TERMINATEDS]), + ), + truncateds=unpad_data_if_necessary( + episode_lens_plus_1, + convert_to_numpy(batch_for_vf[module_id][Columns.TRUNCATEDS]), + ), + step_durations=unpad_data_if_necessary( + episode_lens_plus_1, + np.array( + [ + info["d_ts"] + for info in batch_for_vf[module_id][Columns.INFOS] + ] + ), + ), + gamma=self.config.gamma, + lambda_=self.config.lambda_, ) + # Remove the extra timesteps again from vf_preds and value targets. Now that + # the GAE computation is done, we don't need this last timestep anymore in + # any of our data. + module_vf_preds, module_value_targets = remove_last_ts_from_data( + episode_lens_plus_1, + module_vf_preds, + module_value_targets, + ) + module_advantages = module_value_targets - module_vf_preds + # Drop vf-preds, not needed in loss. Note that in the PPORLModule, vf-preds + # are recomputed with each `forward_train` call anyway. + # Standardize advantages (used for more stable and better weighted + # policy gradient computations). + module_advantages = (module_advantages - module_advantages.mean()) / max( + 1e-4, module_advantages.std() + ) + + # Restructure ADVANTAGES and VALUE_TARGETS in a way that the Learner + # connector can properly re-batch these new fields. + batch_pos = 0 + for eps in sa_episodes_list: + if eps.module_id is not None and eps.module_id != module_id: + continue + len_ = len(eps) - 1 + self._learner_connector.add_n_batch_items( + batch=batch, + column=Postprocessing.ADVANTAGES, + items_to_add=module_advantages[batch_pos : batch_pos + len_], + num_items=len_, + single_agent_episode=eps, + ) + self._learner_connector.add_n_batch_items( + batch=batch, + column=Postprocessing.VALUE_TARGETS, + items_to_add=module_value_targets[batch_pos : batch_pos + len_], + num_items=len_, + single_agent_episode=eps, + ) + batch_pos += len_ + + # Remove the extra (artificial) timesteps again at the end of all episodes. + remove_last_ts_from_episodes_and_restore_truncateds( + sa_episodes_list, + orig_truncateds_of_sa_episodes, + ) + + return batch, episodes + + +class TimeDiscountedGAEPPOTorchLearner(PPOTorchLearner, TimeDiscountedGAEPPOLearner): + pass + + +class TimeDiscountedGAEPPOTfLearner(PPOTfLearner, TimeDiscountedGAEPPOLearner): + pass + __doc_title__ = "RLlib Utilities" -__all__ = ["unpack_config", "EpisodeDataCallbacks"] +__all__ = [ + "unpack_config", + "EpisodeDataCallbacks", + "MultiagentEpisodeDataCallbacks", + "ContinuePreviousAction", + "MakeAddedStepActionValid", + "CondenseMultiStepActions", +]