From ad4e25680439998e742798e32ead3f3fdd872b86 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 21 Dec 2023 14:08:31 +0100 Subject: [PATCH] [RLlib] New ConnectorV2 API #03: Introduce actual `ConnectorV2` API. (#41074) (#41212) --- rllib/BUILD | 17 +- rllib/algorithms/algorithm.py | 18 +- rllib/algorithms/algorithm_config.py | 143 ++++++++++ rllib/algorithms/impala/impala.py | 7 +- rllib/algorithms/pg/pg.py | 5 +- rllib/algorithms/ppo/ppo.py | 20 +- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 20 +- .../ppo/torch/ppo_torch_rl_module.py | 20 +- rllib/connectors/common/__init__.py | 0 rllib/connectors/common/frame_stacking.py | 136 +++++++++ rllib/connectors/connector_pipeline_v2.py | 268 ++++++++++++++++++ rllib/connectors/connector_v2.py | 199 +++++++++++++ rllib/connectors/env_to_module/__init__.py | 9 + .../env_to_module/default_env_to_module.py | 80 ++++++ .../env_to_module/env_to_module_pipeline.py | 32 +++ .../env_to_module/frame_stacking.py | 6 + .../env_to_module/prev_action_prev_reward.py | 132 +++++++++ rllib/connectors/input_output_types.py | 75 +++++ rllib/connectors/learner/__init__.py | 11 + .../learner/default_learner_connector.py | 228 +++++++++++++++ rllib/connectors/learner/frame_stacking.py | 6 + .../learner/learner_connector_pipeline.py | 5 + rllib/connectors/module_to_env/__init__.py | 9 + .../module_to_env/default_module_to_env.py | 155 ++++++++++ .../module_to_env/module_to_env_pipeline.py | 5 + rllib/connectors/utils/zero_padding.py | 135 +++++++++ rllib/core/learner/torch/torch_learner.py | 1 + rllib/core/models/catalog.py | 11 +- rllib/core/models/torch/encoder.py | 2 +- rllib/env/wrappers/atari_wrappers.py | 21 +- ..._CONNECTOR_EXAMPLES_TO_SEPARATE_FOLDER.txt | 0 .../connectors/connector_v2_frame_stacking.py | 178 ++++++++++++ rllib/utils/filter_manager.py | 2 +- rllib/utils/numpy.py | 14 +- rllib/utils/tests/test_minibatch_utils.py | 8 +- rllib/utils/torch_utils.py | 4 +- 36 files changed, 1911 insertions(+), 71 deletions(-) create mode 100644 rllib/connectors/common/__init__.py create mode 100644 rllib/connectors/common/frame_stacking.py create mode 100644 rllib/connectors/connector_pipeline_v2.py create mode 100644 rllib/connectors/connector_v2.py create mode 100644 rllib/connectors/env_to_module/__init__.py create mode 100644 rllib/connectors/env_to_module/default_env_to_module.py create mode 100644 rllib/connectors/env_to_module/env_to_module_pipeline.py create mode 100644 rllib/connectors/env_to_module/frame_stacking.py create mode 100644 rllib/connectors/env_to_module/prev_action_prev_reward.py create mode 100644 rllib/connectors/input_output_types.py create mode 100644 rllib/connectors/learner/__init__.py create mode 100644 rllib/connectors/learner/default_learner_connector.py create mode 100644 rllib/connectors/learner/frame_stacking.py create mode 100644 rllib/connectors/learner/learner_connector_pipeline.py create mode 100644 rllib/connectors/module_to_env/__init__.py create mode 100644 rllib/connectors/module_to_env/default_module_to_env.py create mode 100644 rllib/connectors/module_to_env/module_to_env_pipeline.py create mode 100644 rllib/connectors/utils/zero_padding.py create mode 100644 rllib/examples/connectors/TODO_MOVE_OLD_CONNECTOR_EXAMPLES_TO_SEPARATE_FOLDER.txt create mode 100644 rllib/examples/connectors/connector_v2_frame_stacking.py diff --git a/rllib/BUILD b/rllib/BUILD index 7079a7cd63425..0623e5455815e 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -747,7 +747,7 @@ py_test( # -------------------------------------------------------------------- -# Connector tests +# Connector(V1) tests # rllib/connector/ # # Tag: connector @@ -774,6 +774,21 @@ py_test( srcs = ["connectors/tests/test_agent.py"] ) +# -------------------------------------------------------------------- +# ConnectorV2 tests +# rllib/connector/ +# +# Tag: connector_v2 +# -------------------------------------------------------------------- + +# TODO (sven): Add these tests in a separate PR. +# py_test( +# name = "connectors/tests/test_connector_v2", +# tags = ["team:rllib", "connector_v2"], +# size = "small", +# srcs = ["connectors/tests/test_connector_v2.py"] +# ) + # -------------------------------------------------------------------- # Env tests # rllib/env/ diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 170675e0e3956..11ba31c794da3 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -564,6 +564,11 @@ def setup(self, config: AlgorithmConfig) -> None: config_obj.env = self._env_id self.config = config_obj + self._uses_new_env_runners = ( + self.config.env_runner_cls is not None + and not issubclass(self.config.env_runner_cls, RolloutWorker) + ) + # Set Algorithm's seed after we have - if necessary - enabled # tf eager-execution. update_global_seed_if_necessary(self.config.framework_str, self.config.seed) @@ -751,13 +756,12 @@ def setup(self, config: AlgorithmConfig) -> None: ) # Only when using RolloutWorkers: Update also the worker set's - # `should_module_be_updated_fn` (analogous to is_policy_to_train). + # `is_policy_to_train` (analogous to LearnerGroup's + # `should_module_be_updated_fn`). # Note that with the new EnvRunner API in combination with the new stack, # this information only needs to be kept in the LearnerGroup and not on the # EnvRunners anymore. - if self.config.env_runner_cls is None or issubclass( - self.config.env_runner_cls, RolloutWorker - ): + if not self._uses_new_env_runners: update_fn = self.learner_group.should_module_be_updated_fn self.workers.foreach_worker( lambda w: w.set_is_policy_to_train(update_fn), @@ -3030,11 +3034,7 @@ def _run_one_evaluation( """ eval_func_to_use = ( self._evaluate_async_with_env_runner - if ( - self.config.enable_async_evaluation - and self.config.env_runner_cls is not None - and not issubclass(self.config.env_runner_cls, RolloutWorker) - ) + if (self.config.enable_async_evaluation and self._uses_new_env_runners) else self._evaluate_async if self.config.enable_async_evaluation else self.evaluate diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index f02fac61b2a03..8b2620241198f 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -99,8 +99,10 @@ if TYPE_CHECKING: from ray.rllib.algorithms.algorithm import Algorithm + from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core.learner import Learner from ray.rllib.core.learner.learner_group import LearnerGroup + from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.evaluation.episode import Episode as OldEpisode logger = logging.getLogger(__name__) @@ -327,6 +329,8 @@ def __init__(self, algo_class=None): self.num_envs_per_worker = 1 self.create_env_on_local_worker = False self.enable_connectors = True + self._env_to_module_connector = None + self._module_to_env_connector = None # TODO (sven): Rename into `sample_timesteps` (or `sample_duration` # and `sample_duration_unit` (replacing batch_mode), like we do it # in the evaluation config). @@ -374,6 +378,7 @@ def __init__(self, algo_class=None): except AttributeError: pass + self._learner_connector = None self.optimizer = {} self.max_requests_in_flight_per_sampler_worker = 2 self._learner_class = None @@ -1152,6 +1157,121 @@ class directly. Note that this arg can also be specified via logger_creator=self.logger_creator, ) + def build_env_to_module_connector(self, env): + from ray.rllib.connectors.env_to_module import ( + EnvToModulePipeline, + DefaultEnvToModule, + ) + + custom_connectors = [] + # Create an env-to-module connector pipeline (including RLlib's default + # env->module connector piece) and return it. + if self._env_to_module_connector is not None: + val_ = self._env_to_module_connector(env) + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + + if isinstance(val_, ConnectorV2) and not isinstance( + val_, EnvToModulePipeline + ): + custom_connectors = [val_] + elif isinstance(val_, (list, tuple)): + custom_connectors = list(val_) + else: + return val_ + + pipeline = EnvToModulePipeline( + connectors=custom_connectors, + input_observation_space=env.single_observation_space, + input_action_space=env.single_action_space, + env=env, + ) + pipeline.append( + DefaultEnvToModule( + input_observation_space=pipeline.observation_space, + input_action_space=pipeline.action_space, + env=env, + ) + ) + return pipeline + + def build_module_to_env_connector(self, env): + + from ray.rllib.connectors.module_to_env import ( + DefaultModuleToEnv, + ModuleToEnvPipeline, + ) + + custom_connectors = [] + # Create a module-to-env connector pipeline (including RLlib's default + # module->env connector piece) and return it. + if self._module_to_env_connector is not None: + val_ = self._module_to_env_connector(env) + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + + if isinstance(val_, ConnectorV2) and not isinstance( + val_, ModuleToEnvPipeline + ): + custom_connectors = [val_] + elif isinstance(val_, (list, tuple)): + custom_connectors = list(val_) + else: + return val_ + + pipeline = ModuleToEnvPipeline( + connectors=custom_connectors, + input_observation_space=env.single_observation_space, + input_action_space=env.single_action_space, + env=env, + ) + pipeline.append( + DefaultModuleToEnv( + input_observation_space=pipeline.observation_space, + input_action_space=pipeline.action_space, + env=env, + normalize_actions=self.normalize_actions, + clip_actions=self.clip_actions, + ) + ) + return pipeline + + def build_learner_connector(self, input_observation_space, input_action_space): + from ray.rllib.connectors.learner import ( + DefaultLearnerConnector, + LearnerConnectorPipeline, + ) + + custom_connectors = [] + # Create a learner connector pipeline (including RLlib's default + # learner connector piece) and return it. + if self._learner_connector is not None: + val_ = self._learner_connector(input_observation_space, input_action_space) + + from ray.rllib.connectors.connector_v2 import ConnectorV2 + + if isinstance(val_, ConnectorV2) and not isinstance( + val_, LearnerConnectorPipeline + ): + custom_connectors = [val_] + elif isinstance(val_, (list, tuple)): + custom_connectors = list(val_) + else: + return val_ + + pipeline = LearnerConnectorPipeline( + connectors=custom_connectors, + input_observation_space=input_observation_space, + input_action_space=input_action_space, + ) + pipeline.append( + DefaultLearnerConnector( + input_observation_space=pipeline.observation_space, + input_action_space=pipeline.action_space, + ) + ) + return pipeline + def build_learner_group( self, *, @@ -1605,6 +1725,12 @@ def rollouts( create_env_on_local_worker: Optional[bool] = NotProvided, sample_collector: Optional[Type[SampleCollector]] = NotProvided, enable_connectors: Optional[bool] = NotProvided, + env_to_module_connector: Optional[ + Callable[[EnvType], "ConnectorV2"] + ] = NotProvided, + module_to_env_connector: Optional[ + Callable[[EnvType, "RLModule"], "ConnectorV2"] + ] = NotProvided, use_worker_filter_stats: Optional[bool] = NotProvided, update_worker_filter_stats: Optional[bool] = NotProvided, rollout_fragment_length: Optional[Union[int, str]] = NotProvided, @@ -1650,6 +1776,11 @@ def rollouts( enable_connectors: Use connector based environment runner, so that all preprocessing of obs and postprocessing of actions are done in agent and action connectors. + env_to_module_connector: A callable taking an Env as input arg and returning + an env-to-module ConnectorV2 (might be a pipeline) object. + module_to_env_connector: A callable taking an Env and an RLModule as input + args and returning a module-to-env ConnectorV2 (might be a pipeline) + object. use_worker_filter_stats: Whether to use the workers in the WorkerSet to update the central filters (held by the local worker). If False, stats from the workers will not be used and discarded. @@ -1737,6 +1868,10 @@ def rollouts( self.create_env_on_local_worker = create_env_on_local_worker if enable_connectors is not NotProvided: self.enable_connectors = enable_connectors + if env_to_module_connector is not NotProvided: + self._env_to_module_connector = env_to_module_connector + if module_to_env_connector is not NotProvided: + self._module_to_env_connector = module_to_env_connector if use_worker_filter_stats is not NotProvided: self.use_worker_filter_stats = use_worker_filter_stats if update_worker_filter_stats is not NotProvided: @@ -1855,6 +1990,9 @@ def training( optimizer: Optional[dict] = NotProvided, max_requests_in_flight_per_sampler_worker: Optional[int] = NotProvided, learner_class: Optional[Type["Learner"]] = NotProvided, + learner_connector: Optional[ + Callable[["RLModule"], "ConnectorV2"] + ] = NotProvided, # Deprecated arg. _enable_learner_api: Optional[bool] = NotProvided, ) -> "AlgorithmConfig": @@ -1916,6 +2054,9 @@ def training( in your experiment of timesteps. learner_class: The `Learner` class to use for (distributed) updating of the RLModule. Only used when `_enable_new_api_stack=True`. + learner_connector: A callable taking an env observation space and an env + action space as inputs and returning a learner ConnectorV2 (might be + a pipeline) object. Returns: This updated AlgorithmConfig object. @@ -1960,6 +2101,8 @@ def training( ) if learner_class is not NotProvided: self._learner_class = learner_class + if learner_connector is not NotProvided: + self._learner_connector = learner_connector return self diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index fabde3ee8eb4e..0f29ba3939d79 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -86,18 +86,17 @@ class ImpalaConfig(AlgorithmConfig): # Update the config object. config = config.training( - lr=tune.grid_search([0.0001, ]), grad_clip=20.0 + lr=tune.grid_search([0.0001, 0.0002]), grad_clip=20.0 ) config = config.resources(num_gpus=0) config = config.rollouts(num_rollout_workers=1) # Set the config object's env. config = config.environment(env="CartPole-v1") - # Use to_dict() to get the old-style python config dict - # when running with tune. + # Run with tune. tune.Tuner( "IMPALA", + param_space=config, run_config=air.RunConfig(stop={"training_iteration": 1}), - param_space=config.to_dict(), ).fit() .. testoutput:: diff --git a/rllib/algorithms/pg/pg.py b/rllib/algorithms/pg/pg.py index 390943f8fe143..b5cfa38044053 100644 --- a/rllib/algorithms/pg/pg.py +++ b/rllib/algorithms/pg/pg.py @@ -30,12 +30,11 @@ class PGConfig(AlgorithmConfig): >>> config = config.training(lr=tune.grid_search([0.001, 0.0001])) >>> # Set the config object's env. >>> config = config.environment(env="CartPole-v1") - >>> # Use to_dict() to get the old-style python config dict - >>> # when running with tune. + >>> # Run with tune. >>> tune.Tuner( # doctest: +SKIP ... "PG", ... run_config=air.RunConfig(stop={"episode_reward_mean": 200}), - ... param_space=config.to_dict(), + ... param_space=config, ... ).fit() """ diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index c394b96914d83..9f9605312e2e3 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -253,13 +253,10 @@ def training( # Pass kwargs onto super's `training()` method. super().training(**kwargs) - # TODO (sven): Move to generic AlgorithmConfig. - if lr_schedule is not NotProvided: - self.lr_schedule = lr_schedule if use_critic is not NotProvided: self.use_critic = use_critic - # TODO (Kourosh) This is experimental. Set learner_hps parameters as - # well. Don't forget to remove .use_critic from algorithm config. + # TODO (Kourosh) This is experimental. + # Don't forget to remove .use_critic from algorithm config. if use_gae is not NotProvided: self.use_gae = use_gae if lambda_ is not NotProvided: @@ -280,8 +277,6 @@ def training( self.vf_loss_coeff = vf_loss_coeff if entropy_coeff is not NotProvided: self.entropy_coeff = entropy_coeff - if entropy_coeff_schedule is not NotProvided: - self.entropy_coeff_schedule = entropy_coeff_schedule if clip_param is not NotProvided: self.clip_param = clip_param if vf_clip_param is not NotProvided: @@ -289,6 +284,12 @@ def training( if grad_clip is not NotProvided: self.grad_clip = grad_clip + # TODO (sven): Remove these once new API stack is only option for PPO. + if lr_schedule is not NotProvided: + self.lr_schedule = lr_schedule + if entropy_coeff_schedule is not NotProvided: + self.entropy_coeff_schedule = entropy_coeff_schedule + return self @override(AlgorithmConfig) @@ -312,8 +313,8 @@ def validate(self) -> None: raise ValueError( f"`sgd_minibatch_size` ({self.sgd_minibatch_size}) must be <= " f"`train_batch_size` ({self.train_batch_size}). In PPO, the train batch" - f" is be split into {self.sgd_minibatch_size} chunks, each of which is " - f"iterated over (used for updating the policy) {self.num_sgd_iter} " + f" will be split into {self.sgd_minibatch_size} chunks, each of which " + f"is iterated over (used for updating the policy) {self.num_sgd_iter} " "times." ) @@ -476,7 +477,6 @@ def training_step(self) -> ResultDict: self.workers.local_worker().set_weights(weights) if self.config._enable_new_api_stack: - kl_dict = {} if self.config.use_kl_loss: for pid in policies_to_update: diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 12856f9d0d8c0..2b30c810568da 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -20,13 +20,15 @@ class PPOTfRLModule(TfRLModule, PPORLModule): def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]: output = {} + # Encoder forward pass. encoder_outs = self.encoder(batch) if STATE_OUT in encoder_outs: output[STATE_OUT] = encoder_outs[STATE_OUT] - # Actions - action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - output[SampleBatch.ACTION_DIST_INPUTS] = action_logits + # Pi head. + output[SampleBatch.ACTION_DIST_INPUTS] = self.pi( + encoder_outs[ENCODER_OUT][ACTOR] + ) return output @@ -34,8 +36,8 @@ def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]: def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: """PPO forward pass during exploration. - Besides the action distribution, this method also returns the parameters of the - policy distribution to be used for computing KL divergence between the old + Besides the action distribution, this method also returns the parameters of + the policy distribution to be used for computing KL divergence between the old policy and the new policy during training. """ output = {} @@ -51,7 +53,6 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: # Policy head action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - output[SampleBatch.ACTION_DIST_INPUTS] = action_logits return output @@ -60,16 +61,17 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: def _forward_train(self, batch: NestedDict): output = {} - # Shared encoder + # Shared encoder. encoder_outs = self.encoder(batch) if STATE_OUT in encoder_outs: output[STATE_OUT] = encoder_outs[STATE_OUT] - # Value head + # Value head. vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC]) + # Squeeze out last dim (value function node). output[SampleBatch.VF_PREDS] = tf.squeeze(vf_out, axis=-1) - # Policy head + # Policy head. action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) output[SampleBatch.ACTION_DIST_INPUTS] = action_logits diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 09010c872c896..745f45bb603f6 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -20,21 +20,24 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule): def _forward_inference(self, batch: NestedDict) -> Dict[str, Any]: output = {} + # Encoder forward pass. encoder_outs = self.encoder(batch) if STATE_OUT in encoder_outs: output[STATE_OUT] = encoder_outs[STATE_OUT] - # Actions - action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - output[SampleBatch.ACTION_DIST_INPUTS] = action_logits + # Pi head. + output[SampleBatch.ACTION_DIST_INPUTS] = self.pi( + encoder_outs[ENCODER_OUT][ACTOR] + ) return output @override(RLModule) def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: """PPO forward pass during exploration. - Besides the action distribution, this method also returns the parameters of the - policy distribution to be used for computing KL divergence between the old + + Besides the action distribution, this method also returns the parameters of + the policy distribution to be used for computing KL divergence between the old policy and the new policy during training. """ output = {} @@ -58,16 +61,17 @@ def _forward_exploration(self, batch: NestedDict) -> Dict[str, Any]: def _forward_train(self, batch: NestedDict) -> Dict[str, Any]: output = {} - # Shared encoder + # Shared encoder. encoder_outs = self.encoder(batch) if STATE_OUT in encoder_outs: output[STATE_OUT] = encoder_outs[STATE_OUT] - # Value head + # Value head. vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC]) + # Squeeze out last dim (value function node). output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) - # Policy head + # Policy head. action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) output[SampleBatch.ACTION_DIST_INPUTS] = action_logits diff --git a/rllib/connectors/common/__init__.py b/rllib/connectors/common/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/connectors/common/frame_stacking.py b/rllib/connectors/common/frame_stacking.py new file mode 100644 index 0000000000000..3b7592b852a35 --- /dev/null +++ b/rllib/connectors/common/frame_stacking.py @@ -0,0 +1,136 @@ +import numpy as np +from typing import Any, List, Optional + +import gymnasium as gym +import tree # pip install dm_tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import batch +from ray.rllib.utils.typing import EpisodeType + + +class _FrameStackingConnector(ConnectorV2): + """A connector piece that stacks the previous n observations into one.""" + + def __init__( + self, + *, + # Base class constructor args. + input_observation_space: gym.Space, + input_action_space: gym.Space, + # Specific framestacking args. + num_frames: int = 1, + as_learner_connector: bool = False, + **kwargs, + ): + """Initializes a _FrameStackingConnector instance. + + Args: + num_frames: The number of observation frames to stack up (into a single + observation) for the RLModule's forward pass. + as_preprocessor: Whether this connector should simply postprocess the + received observations from the env and store these directly in the + episode object. In this mode, the connector can only be used in + an `EnvToModulePipeline` and it will act as a classic + RLlib framestacking postprocessor. + as_learner_connector: Whether this connector is part of a Learner connector + pipeline, as opposed to an env-to-module pipeline. + """ + super().__init__( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + **kwargs, + ) + + self.num_frames = num_frames + self.as_learner_connector = as_learner_connector + + # Some assumptions: Space is box AND last dim (the stacking one) is 1. + assert isinstance(self.observation_space, gym.spaces.Box) + assert self.observation_space.shape[-1] == 1 + + # Change our observation space according to the given stacking settings. + self.observation_space = gym.spaces.Box( + low=np.repeat(self.observation_space.low, repeats=self.num_frames, axis=-1), + high=np.repeat( + self.observation_space.high, repeats=self.num_frames, axis=-1 + ), + shape=list(self.observation_space.shape)[:-1] + [self.num_frames], + dtype=self.observation_space.dtype, + ) + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # This is a data-in-data-out connector, so we expect `data` to be a dict + # with: key=column name, e.g. "obs" and value=[data to be processed by + # RLModule]. We will add to `data` the last n observations. + observations = [] + + # Learner connector pipeline. Episodes have been finalized/numpy'ized. + if self.as_learner_connector: + for episode in episodes: + + def _map_fn(s): + # Squeeze out last dim. + s = np.squeeze(s, axis=-1) + # Calculate new shape and strides + new_shape = (len(episode), self.num_frames) + s.shape[1:] + new_strides = (s.strides[0],) + s.strides + # Create a strided view of the array. + return np.lib.stride_tricks.as_strided( + s, shape=new_shape, strides=new_strides + ) + + # Get all observations from the episode in one np array (except for + # the very last one, which is the final observation not needed for + # learning). + observations.append( + tree.map_structure( + _map_fn, + episode.get_observations( + indices=slice(-self.num_frames + 1, len(episode)), + neg_indices_left_of_zero=True, + fill=0.0, + ), + ) + ) + + # Move stack-dimension to the end and concatenate along batch axis. + data[SampleBatch.OBS] = tree.map_structure( + lambda *s: np.transpose(np.concatenate(s, axis=0), axes=[0, 2, 3, 1]), + *observations, + ) + + # Env-to-module pipeline. Episodes still operate on lists. + else: + for episode in episodes: + assert not episode.is_finalized + # Get the list of observations to stack. + obs_stack = episode.get_observations( + indices=slice(-self.num_frames, None), + fill=0.0, + ) + # Observation components are (w, h, 1) + # -> stack to (w, h, [num_frames], 1), then squeeze out last dim to get + # (w, h, [num_frames]). + stacked_obs = tree.map_structure( + lambda *s: np.squeeze(np.stack(s, axis=2), axis=-1), + *obs_stack, + ) + observations.append(stacked_obs) + + data[SampleBatch.OBS] = batch(observations) + + return data diff --git a/rllib/connectors/connector_pipeline_v2.py b/rllib/connectors/connector_pipeline_v2.py new file mode 100644 index 0000000000000..86f0649d66a39 --- /dev/null +++ b/rllib/connectors/connector_pipeline_v2.py @@ -0,0 +1,268 @@ +from collections import defaultdict +import logging +from typing import Any, Dict, List, Optional, Type, Union + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI +from ray.util.timer import _Timer + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="alpha") +class ConnectorPipelineV2(ConnectorV2): + """Utility class for quick manipulation of a connector pipeline.""" + + def __init__( + self, + *, + connectors: Optional[List[ConnectorV2]] = None, + **kwargs, + ): + super().__init__(**kwargs) + + self.connectors = connectors or [] + self._fix_input_output_types() + + self.timers = defaultdict(_Timer) + + @override(ConnectorV2) + def __call__( + self, + rl_module: RLModule, + data: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + """In a pipeline, we simply call each of our connector pieces after each other. + + Each connector piece receives as input the output of the previous connector + piece in the pipeline. + """ + # Loop through connector pieces and call each one with the output of the + # previous one. Thereby, time each connector piece's call. + ret = data + for connector in self.connectors: + timer = self.timers[str(connector)] + with timer: + ret = connector( + rl_module=rl_module, + data=ret, + episodes=episodes, + explore=explore, + shared_data=shared_data, + **kwargs, + ) + return ret + + def remove(self, name_or_class: Union[str, Type]): + """Remove a single connector piece in this pipeline by its name or class. + + Args: + name: The name of the connector piece to be removed from the pipeline. + """ + idx = -1 + for i, c in enumerate(self.connectors): + if c.__class__.__name__ == name_or_class: + idx = i + break + if idx >= 0: + del self.connectors[idx] + self._fix_input_output_types() + logger.info( + f"Removed connector {name_or_class} from {self.__class__.__name__}." + ) + else: + logger.warning( + f"Trying to remove a non-existent connector {name_or_class}." + ) + + def insert_before( + self, + name_or_class: Union[str, type], + connector: ConnectorV2, + ) -> ConnectorV2: + """Insert a new connector piece before an existing piece (by name or class). + + Args: + name_or_class: Name or class of the connector piece before which `connector` + will get inserted. + connector: The new connector piece to be inserted. + + Returns: + The ConnectorV2 before which `connector` has been inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if ( + isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class + ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class): + break + if idx < 0: + raise ValueError( + f"Can not find connector with name or type '{name_or_class}'!" + ) + next_connector = self.connectors[idx] + + self.connectors.insert(idx, connector) + self._fix_input_output_types() + + logger.info( + f"Inserted {connector.__class__.__name__} before {name_or_class} " + f"to {self.__class__.__name__}." + ) + return next_connector + + def insert_after( + self, + name_or_class: Union[str, Type], + connector: ConnectorV2, + ) -> ConnectorV2: + """Insert a new connector piece after an existing piece (by name or class). + + Args: + name_or_class: Name or class of the connector piece after which `connector` + will get inserted. + connector: The new connector piece to be inserted. + + Returns: + The ConnectorV2 after which `connector` has been inserted. + """ + idx = -1 + for idx, c in enumerate(self.connectors): + if ( + isinstance(name_or_class, str) and c.__class__.__name__ == name_or_class + ) or (isinstance(name_or_class, type) and c.__class__ is name_or_class): + break + if idx < 0: + raise ValueError( + f"Can not find connector with name or type '{name_or_class}'!" + ) + prev_connector = self.connectors[idx] + + self.connectors.insert(idx + 1, connector) + self._fix_input_output_types() + + logger.info( + f"Inserted {connector.__class__.__name__} after {name_or_class} " + f"to {self.__class__.__name__}." + ) + + return prev_connector + + def prepend(self, connector: ConnectorV2) -> None: + """Prepend a new connector at the beginning of a connector pipeline. + + Args: + connector: The new connector piece to be prepended to this pipeline. + """ + self.connectors.insert(0, connector) + self._fix_input_output_types() + + logger.info( + f"Added {connector.__class__.__name__} to the beginning of " + f"{self.__class__.__name__}." + ) + + def append(self, connector: ConnectorV2) -> None: + """Append a new connector at the end of a connector pipeline. + + Args: + connector: The new connector piece to be appended to this pipeline. + """ + self.connectors.append(connector) + self._fix_input_output_types() + + logger.info( + f"Added {connector.__class__.__name__} to the end of " + f"{self.__class__.__name__}." + ) + + @override(ConnectorV2) + def get_state(self) -> Dict[str, Any]: + states = {} + for i, connector in enumerate(self.connectors): + key = f"{i:03d}_{type(connector).__name__}" + state = connector.get_state() + states[key] = state + return states + + @override(ConnectorV2) + def set_state(self, state: Dict[str, Any]) -> None: + for i, connector in enumerate(self.connectors): + key = f"{i:03d}_{type(connector).__name__}" + if key not in state: + raise KeyError(f"No state found in `state` for connector piece: {key}!") + connector.set_state(state[key]) + + def __repr__(self, indentation: int = 0): + return "\n".join( + [" " * indentation + self.__class__.__name__] + + [c.__str__(indentation + 4) for c in self.connectors] + ) + + def __getitem__( + self, + key: Union[str, int, Type], + ) -> Union[ConnectorV2, List[ConnectorV2]]: + """Returns a single ConnectorV2 or list of ConnectorV2s that fit `key`. + + If key is an int, we return a single ConnectorV2 at that index in this pipeline. + If key is a ConnectorV2 type or a string matching the class name of a + ConnectorV2 in this pipeline, we return a list of all ConnectorV2s in this + pipeline matching the specified class. + + Args: + key: The key to find or to index by. + + Returns: + A single ConnectorV2 or a list of ConnectorV2s matching `key`. + """ + # Key is an int -> Index into pipeline and return. + if isinstance(key, int): + return self.connectors[key] + # Key is a class. + elif isinstance(key, type): + results = [] + for c in self.connectors: + if issubclass(c.__class__, key): + results.append(c) + return results + # Key is a string -> Find connector(s) by name. + elif isinstance(key, str): + results = [] + for c in self.connectors: + if c.name == key: + results.append(c) + return results + # Slicing not supported (yet). + elif isinstance(key, slice): + raise NotImplementedError( + "Slicing of ConnectorPipelineV2 is currently not supported!" + ) + else: + raise NotImplementedError( + f"Indexing ConnectorPipelineV2 by {type(key)} is currently not " + f"supported!" + ) + + def _fix_input_output_types(self): + if len(self.connectors) > 0: + self.input_type = self.connectors[0].input_type + self.output_type = self.connectors[-1].output_type + # TODO (sven): Create some examples for pipelines, in which the spaces + # are changed several times by the individual pieces. + self.input_observation_space = self.connectors[0].input_observation_space + self.input_action_space = self.connectors[0].input_action_space + self._observation_space = self.connectors[-1].observation_space + self._action_space = self.connectors[-1].action_space + else: + self.input_type = None + self.output_type = None + self._observation_space = None + self._action_space = None diff --git a/rllib/connectors/connector_v2.py b/rllib/connectors/connector_v2.py new file mode 100644 index 0000000000000..a4bad77b39da5 --- /dev/null +++ b/rllib/connectors/connector_v2.py @@ -0,0 +1,199 @@ +import abc +from typing import Any, Dict, List, Optional + +import gymnasium as gym + +from ray.rllib.connectors.input_output_types import INPUT_OUTPUT_TYPES +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class ConnectorV2(abc.ABC): + """Base class defining the API for an individual "connector piece". + + A ConnectorV2 ("connector piece") is usually part of a whole series of connector + pieces within a so-called connector pipeline, which in itself also abides to this + very API.. + For example, you might have a connector pipeline consisting of two connector pieces, + A and B, both instances of subclasses of ConnectorV2 and each one performing a + particular transformation on their input data. The resulting connector pipeline + (A->B) itself also abides to this very ConnectorV2 API and could thus be part of yet + another, higher-level connector pipeline. + + Any ConnectorV2 instance (individual pieces or several connector pieces in a + pipeline) is a callable and you should override their `__call__()` method. + When called, they take the outputs of a previous connector piece (or an empty dict + if there are no previous pieces) as well as all the data collected thus far in the + ongoing episode(s) (only applies to connectors used in EnvRunners) or retrieved + from a replay buffer or from an environment sampling step (only applies to + connectors used in Learner pipelines). From this input data, a ConnectorV2 then + performs a transformation step. + + There are 3 types of pipelines any ConnectorV2 piece can belong to: + 1) EnvToModulePipeline: The connector transforms environment data before it gets to + the RLModule. This type of pipeline is used by an EnvRunner for transforming + env output data into RLModule readable data (for the next RLModule forward pass). + For example, such a pipeline would include observation postprocessors, -filters, + or any RNN preparation code related to time-sequences and zero-padding. + 2) ModuleToEnvPipeline: This type of pipeline is used by an + EnvRunner to transform RLModule output data to env readable actions (for the next + `env.step()` call). For example, in case the RLModule only outputs action + distribution parameters (but not actual actions), the ModuleToEnvPipeline would + take care of sampling the actions to be sent back to the end from the + resulting distribution (made deterministic if exploration is off). + 3) LearnerConnectorPipeline: This connector pipeline type transforms data coming + from an `EnvRunner.sample()` call or a replay buffer and will then be sent into the + RLModule's `forward_train()` method in order to compute loss function inputs. + This type of pipeline is used by a Learner worker to transform raw training data + (a batch or a list of episodes) to RLModule readable training data (for the next + RLModule `forward_train()` call). + + Some connectors might be stateful, for example for keeping track of observation + filtering stats (mean and stddev values). Any Algorithm, which uses connectors is + responsible for frequently synchronizing the states of all connectors and connector + pipelines between the EnvRunners (owning the env-to-module and module-to-env + pipelines) and the Learners (owning the Learner pipelines). + """ + + # Set these in ALL subclasses. + # TODO (sven): Irrelevant for single-agent cases. Once multi-agent is supported + # by ConnectorV2, we need to elaborate more on the different input/output types. + # For single-agent, the types should always be just INPUT_OUTPUT_TYPES.DATA. + input_type = INPUT_OUTPUT_TYPES.DATA + output_type = INPUT_OUTPUT_TYPES.DATA + + @property + def observation_space(self): + """Getter for our (output) observation space. + + Logic: Use user provided space (if set via `observation_space` setter) + otherwise, use the same as the input space, assuming this connector piece + does not alter the space. + """ + return self._observation_space or self.input_observation_space + + @observation_space.setter + def observation_space(self, value): + """Setter for our (output) observation space.""" + self._observation_space = value + + @property + def action_space(self): + """Getter for our (output) action space. + + Logic: Use user provided space (if set via `action_space` setter) + otherwise, use the same as the input space, assuming this connector piece + does not alter the space. + """ + return self._action_space or self.input_action_space + + @action_space.setter + def action_space(self, value): + """Setter for our (output) action space.""" + self._action_space = value + + def __init__( + self, + *, + input_observation_space: gym.Space, + input_action_space: gym.Space, + **kwargs, + ): + """Initializes a ConnectorV2 instance. + + Args: + input_observation_space: The input observation space for this connector + piece. This is the space coming from a previous connector piece in the + (env-to-module or learner) pipeline or it is directly defined within + the used gym.Env. + input_action_space: The input action space for this connector piece. This + is the space coming from a previous connector piece in the + (module-to-env) pipeline or it is directly defined within the used + gym.Env. + **kwargs: Forward API-compatibility kwargs. + """ + self.input_observation_space = input_observation_space + self.input_action_space = input_action_space + + self._observation_space = None + self._action_space = None + + @abc.abstractmethod + def __call__( + self, + *, + rl_module: RLModule, + data: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + """Method for transforming input data into output data. + + Args: + rl_module: An optional RLModule object that the connector might need to know + about. Note that normally, only module-to-env connectors get this + information at construction time, but env-to-module and learner + connectors won't (b/c they get constructed before the RLModule). + data: The input data abiding to `self.input_type` to be transformed by + this connector. Transformations might either be done in-place or a new + structure may be returned that matches `self.output_type`. + episodes: The list of SingleAgentEpisode or MultiAgentEpisode objects, + each corresponding to one slot in the vector env. Note that episodes + should always be considered read-only and not be altered. + explore: Whether `explore` is currently on. Per convention, if True, the + RLModule's `forward_exploration` method should be called, if False, the + EnvRunner should call `forward_inference` instead. + shared_data: Optional additional context data that needs to be exchanged + between different Connector pieces and -pipelines. + kwargs: Forward API-compatibility kwargs. + + Returns: + The transformed connector output abiding to `self.output_type`. + """ + + def get_state(self) -> Dict[str, Any]: + """Returns the current state of this ConnectorV2 as a state dict. + + Returns: + A state dict mapping any string keys to their (state-defining) values. + """ + return {} + + def set_state(self, state: Dict[str, Any]) -> None: + """Sets the state of this ConnectorV2 to the given value. + + Args: + state: The state dict to define this ConnectorV2's new state. + """ + pass + + def reset_state(self) -> None: + """Resets the state of this ConnectorV2 to some initial value. + + Note that this may NOT be the exact state that this ConnectorV2 was originally + constructed with. + """ + pass + + @staticmethod + def merge_states(states: List[Dict[str, Any]]) -> Dict[str, Any]: + """Computes a resulting state given a list of other state dicts. + + Algorithms should use this method for synchronizing states between connectors + running on workers (of the same type, e.g. EnvRunner workers). + + Args: + states: The list of n other ConnectorV2 states to merge into a single + resulting state. + + Returns: + The resulting state dict. + """ + return {} + + def __str__(self, indentation: int = 0): + return " " * indentation + self.__class__.__name__ diff --git a/rllib/connectors/env_to_module/__init__.py b/rllib/connectors/env_to_module/__init__.py new file mode 100644 index 0000000000000..c156044aa9213 --- /dev/null +++ b/rllib/connectors/env_to_module/__init__.py @@ -0,0 +1,9 @@ +from ray.rllib.connectors.env_to_module.default_env_to_module import DefaultEnvToModule +from ray.rllib.connectors.env_to_module.env_to_module_pipeline import ( + EnvToModulePipeline, +) + +__all__ = [ + "DefaultEnvToModule", + "EnvToModulePipeline", +] diff --git a/rllib/connectors/env_to_module/default_env_to_module.py b/rllib/connectors/env_to_module/default_env_to_module.py new file mode 100644 index 0000000000000..9a5813403036f --- /dev/null +++ b/rllib/connectors/env_to_module/default_env_to_module.py @@ -0,0 +1,80 @@ +from typing import Any, List, Optional + +import numpy as np + +import tree +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.models.base import STATE_IN, STATE_OUT +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import batch +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class DefaultEnvToModule(ConnectorV2): + """Default connector piece added by RLlib to the end of any env-to-module pipeline. + + Makes sure that the output data will have at the minimum: + a) An observation (the most recent one returned by `env.step()`) under the + SampleBatch.OBS key for each agent and + b) In case the RLModule is stateful, a STATE_IN key populated with the most recently + computed STATE_OUT. + + The connector will not add any new data in case other connector pieces in the + pipeline already take care of populating these fields (obs and state in). + + TODO (sven): Generalize to MultiAgentEpisodes. + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any] = None, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # If observations cannot be found in `input`, add the most recent ones (from all + # episodes). + if SampleBatch.OBS not in data: + # Collect all most-recent observations from given episodes. + observations = [] + for episode in episodes: + observations.append(episode.get_observations(indices=-1)) + # Batch all collected observations together. + data[SampleBatch.OBS] = batch(observations) + + # If our module is stateful: + # - Add the most recent STATE_OUTs to `data`. + # - Make all data in `data` have a time rank (T=1). + if rl_module.is_stateful(): + # Collect all most recently computed STATE_OUT (or use initial states from + # RLModule if at beginning of episode). + states = [] + for episode in episodes: + # Make sure, we have at least one observation in the episode. + assert episode.observations + + # TODO (sven): Generalize to MultiAgentEpisodes. + # Episode just started -> Get initial state from our RLModule. + if len(episode) == 0: + state = rl_module.get_initial_state() + # Episode is already ongoing -> Use most recent STATE_OUT. + else: + state = episode.extra_model_outputs[STATE_OUT][-1] + states.append(state) + + # Make all other inputs have an additional T=1 axis. + data = tree.map_structure(lambda s: np.expand_dims(s, axis=1), data) + + # Batch states (from list of individual vector sub-env states). + # Note that state ins should NOT have the extra time dimension. + data[STATE_IN] = batch(states) + + return data diff --git a/rllib/connectors/env_to_module/env_to_module_pipeline.py b/rllib/connectors/env_to_module/env_to_module_pipeline.py new file mode 100644 index 0000000000000..5f790ec84a769 --- /dev/null +++ b/rllib/connectors/env_to_module/env_to_module_pipeline.py @@ -0,0 +1,32 @@ +from typing import Any, List, Optional + +from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class EnvToModulePipeline(ConnectorPipelineV2): + @override(ConnectorPipelineV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any] = None, + episodes: List[EpisodeType], + explore: bool, + shared_data: Optional[dict] = None, + **kwargs, + ): + # Make sure user does not necessarily send initial input into this pipeline. + # Might just be empty and to be populated from `episodes`. + return super().__call__( + rl_module=rl_module, + data=data if data is not None else {}, + episodes=episodes, + explore=explore, + shared_data=shared_data, + **kwargs, + ) diff --git a/rllib/connectors/env_to_module/frame_stacking.py b/rllib/connectors/env_to_module/frame_stacking.py new file mode 100644 index 0000000000000..b05385b6c10e2 --- /dev/null +++ b/rllib/connectors/env_to_module/frame_stacking.py @@ -0,0 +1,6 @@ +from functools import partial + +from ray.rllib.connectors.common.frame_stacking import _FrameStackingConnector + + +FrameStackingEnvToModule = partial(_FrameStackingConnector, as_learner_connector=False) diff --git a/rllib/connectors/env_to_module/prev_action_prev_reward.py b/rllib/connectors/env_to_module/prev_action_prev_reward.py new file mode 100644 index 0000000000000..cae717beee0b1 --- /dev/null +++ b/rllib/connectors/env_to_module/prev_action_prev_reward.py @@ -0,0 +1,132 @@ +from functools import partial +import numpy as np +from typing import Any, List, Optional + +import gymnasium as gym + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.spaces.space_utils import batch +from ray.rllib.utils.typing import EpisodeType + + +class _PrevRewardPrevActionConnector(ConnectorV2): + """A connector piece that adds previous rewards and actions to the input.""" + + def __init__( + self, + *, + # Base class constructor args. + input_observation_space: gym.Space, + input_action_space: gym.Space, + # Specific prev. r/a args. + n_prev_actions: int = 1, + n_prev_rewards: int = 1, + as_learner_connector: bool = False, + **kwargs, + ): + """Initializes a _PrevRewardPrevActionConnector instance. + + Args: + n_prev_actions: The number of previous actions to include in the output + data. Discrete actions are ont-hot'd. If > 1, will concatenate the + individual action tensors. + n_prev_rewards: The number of previous rewards to include in the output + data. + as_learner_connector: Whether this connector is part of a Learner connector + pipeline, as opposed to a env-to-module pipeline. + """ + super().__init__( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + **kwargs, + ) + + self.n_prev_actions = n_prev_actions + self.n_prev_rewards = n_prev_rewards + self.as_learner_connector = as_learner_connector + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Optional[Any], + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # This is a data-in-data-out connector, so we expect `data` to be a dict + # with: key=column name, e.g. "obs" and value=[data to be processed by + # RLModule]. We will just extract the most recent rewards and/or most recent + # actions from all episodes and store them inside the `data` data dict. + + prev_a = [] + prev_r = [] + for episode in episodes: + # TODO (sven): Get rid of this distinction. With the new Episode APIs, + # this should work the same, whether on finalized or non-finalized + # episodes. + # Learner connector pipeline. Episodes have been finalized/numpy'ized. + if self.as_learner_connector: + assert episode.is_finalized + # Loop through each timestep in the episode and add the previous n + # actions and previous m rewards (based on that timestep) to the batch. + for ts in range(len(episode)): + prev_a.append( + episode.get_actions( + # Extract n actions from `ts - n` to `ts` (excluding `ts`). + indices=slice(ts - self.n_prev_actions, ts), + # Make sure negative indices are NOT interpreted as + # "counting from the end", but as absolute indices meaning + # they refer to timesteps before 0 (which is the lookback + # buffer). + neg_indices_left_of_zero=True, + # In case we are at the very beginning of the episode, e.g. + # ts==0, fill the left side with zero-actions. + fill=0.0, + # Return one-hot arrays for those action components that are + # discrete or multi-discrete. + one_hot_discrete=True, + ) + ) + # Do the same for rewards. + prev_r.append( + episode.get_rewards( + indices=slice(ts - self.n_prev_rewards, ts), + neg_indices_left_of_zero=True, + fill=0.0, + ) + ) + # Env-to-module pipeline. Episodes still operate on lists. + else: + assert not episode.is_finalized + prev_a.append( + batch( + episode.get_actions( + indices=slice(-self.n_prev_actions, None), + fill=0.0, + one_hot_discrete=True, + ) + ) + ) + prev_r.append( + np.array( + episode.get_rewards( + indices=slice(-self.n_prev_rewards, None), + fill=0.0, + ) + ) + ) + + data[SampleBatch.PREV_ACTIONS] = batch(prev_a) + data[SampleBatch.PREV_REWARDS] = np.array(prev_r) + return data + + +PrevRewardPrevActionEnvToModule = partial( + _PrevRewardPrevActionConnector, as_learner_connector=False +) diff --git a/rllib/connectors/input_output_types.py b/rllib/connectors/input_output_types.py new file mode 100644 index 0000000000000..da9343c040678 --- /dev/null +++ b/rllib/connectors/input_output_types.py @@ -0,0 +1,75 @@ +from enum import Enum + + +class INPUT_OUTPUT_TYPES(Enum): + """Definitions of possible datatypes being processed by individual connectors. + + TODO: Make sure this is valid: + Each connector will always receive a list of Episodes (MultiAgentEpisodes or + SingleAgentEpisodes, depending on the setup and EnvRunner used). In addition, the + output of the previous connector (or an empty dict at the beginnnig) will be + received. + An IntoModule connector pipeline should eventually output a dict mapping module IDs + to SampleBatches + + Typical env-module-env pipeline: + env.step(List[Data]) -> List[MultiAgentEpisode] + + connector: auto-agent-extraction: List[MultiAgentEpisode] -> dict[AgentID, Data] + connector: auto-broadcast: Data -> Data (legacy postprocessing and filtering) + under the hood: dict[AgentID, Data] -> dict[AgentID, Data] + connector: auto-policy-mapping: dict[AgentID, Data] -> dict[ModuleID, Data] + + module.forward_exploration() -> dict[ModuleID, Data] + + connector: auto-action-sampling: dict[ModuleID, Data] -> dict[ModuleID, Data] + connector: action-clipping: Data -> Data + under the hood: dict[ModuleID, Data] -> dict[ModuleID, Data] + connector: auto-policy-unmapping: dict[ModuleID, Data] -> dict[AgentID, Data] + (using information stored in connector ctx) + connector: auto-action-sorting (using information stored in connector ctx): + dict[AgentID, Data] -> List[Data] + + env.step(List[Data]) ... repeats + + Typical training pipeline: + + + Default env-module-env pipeline picked by RLlib if no connector defined by user AND + module is RNN: + env.step(List[Data]) -> List[MultiAgentEpisode] + + connector: auto-agent-extraction: List[MultiAgentEpisode] -> dict[AgentID, Data] + connector: auto-policy-mapping: dict[AgentID, Data] -> dict[ModuleID, Data] + connector: auto-state-handling: dict[ModuleID, Data] -> + dict[ModuleID, Data + state] (using information stored in connector ctx) + + module.forward_exploration() -> dict[ModuleID, Data + state] + + connector: auto-state-handling: dict[ModuleID, Data + state] -> + dict[ModuleID, Data] (state was stored in ctx) + connector: auto-policy-unmapping: dict[ModuleID, Data] -> + dict[AgentID, Data] (using information stored in connector ctx) + connector: auto-action-sorting (using information stored in connector ctx): + dict[AgentID, Data] -> List[Data] + + env.step(List[Data]) ... repeats + """ + + # Normally, after `env.step()`, we have a list (vector env) of MultiAgentEpisodes + # as a starting point. + LIST_OF_MULTI_AGENT_EPISODES = 0 + # In the simplified case, there might be a list of SingleAgentEpisodes, instead. + LIST_OF_SINGLE_AGENT_EPISODES = 1 + + # From each MultiAgentEpisode, we might extract a dict, mapping agent IDs to data. + LIST_OF_DICTS_MAPPING_AGENT_IDS_TO_DATA = 10 + # Eventually boiling down to simply one dict mapping agent IDs to data. + # + DICT_MAPPING_AGENT_IDS_TO_DATA = 11 + + # Right after the module's forward pass, we usually have a single dict mapping + # Module IDs to data (model outputs). + DICT_MAPPING_MODULE_IDS_TO_DATA = 12 + + DATA = 11 diff --git a/rllib/connectors/learner/__init__.py b/rllib/connectors/learner/__init__.py new file mode 100644 index 0000000000000..dda5851866ebc --- /dev/null +++ b/rllib/connectors/learner/__init__.py @@ -0,0 +1,11 @@ +from ray.rllib.connectors.learner.default_learner_connector import ( + DefaultLearnerConnector, +) +from ray.rllib.connectors.learner.learner_connector_pipeline import ( + LearnerConnectorPipeline, +) + +__all__ = [ + "DefaultLearnerConnector", + "LearnerConnectorPipeline", +] diff --git a/rllib/connectors/learner/default_learner_connector.py b/rllib/connectors/learner/default_learner_connector.py new file mode 100644 index 0000000000000..6e17beb82f52a --- /dev/null +++ b/rllib/connectors/learner/default_learner_connector.py @@ -0,0 +1,228 @@ +from functools import partial +from typing import Any, List, Optional + +import numpy as np +import tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.models.base import STATE_IN, STATE_OUT +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.typing import EpisodeType + + +class DefaultLearnerConnector(ConnectorV2): + """Connector added by default by RLlib to the end of any learner connector pipeline. + + If provided with `episodes` data, this connector piece makes sure that the final + train batch going into the RLModule for updating (`forward_train()` call) contains + at the minimum: + - Observations: From all episodes under the SampleBatch.OBS key. + - Actions, rewards, terminal/truncation flags: From all episodes under the + respective keys. + - All data inside the episodes' `extra_model_outs` property, e.g. action logp and + action probs under the respective keys. + - States: If the RLModule is stateful, the episodes' STATE_OUTS will be extracted + and restructured under a new STATE_IN key in such a way that the resulting STATE_IN + batch has the shape (B', ...). Here, B' is the sum of splits we have to do over + the given episodes, such that each chunk is at most `max_seq_len` long (T-axis). + Also, all other data will be properly reshaped into (B, T=max_seq_len, ...) and + will be zero-padded, if necessary. + + If the user wants to customize their own data under the given keys (e.g. obs, + actions, ...), they can extract from the episodes or recompute from `data` + their own data and store it in `data` under those keys. In this case, the default + connector will not change the data under these keys and simply act as a + pass-through. + """ + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # If episodes are provided, extract the essential data from them, but only if + # respective keys are not present yet in `data`. + if not episodes: + return data + + # Get the data dicts for all episodes. + data_dicts = [episode.get_data_dict() for episode in episodes] + + state_in = None + T = rl_module.config.model_config_dict.get("max_seq_len") + + # RLModule is stateful and STATE_IN is not found in `data` (user's custom + # connectors have not provided this information yet) -> Perform separate + # handling of STATE_OUT/STATE_IN keys: + if rl_module.is_stateful() and STATE_IN not in data: + if T is None: + raise ValueError( + "You are using a stateful RLModule and are not providing custom " + f"'{STATE_IN}' data through your connector(s)! Therefore, you need " + "to provide the 'max_seq_len' key inside your model config dict. " + "You can set this dict and/or override keys in it via " + "`config.training(model={'max_seq_len': x})`." + ) + # Get model init state. + init_state = convert_to_numpy(rl_module.get_initial_state()) + # Get STATE_OUTs for all episodes and only keep those (as STATE_INs) that + # are located at the `max_seq_len` edges (state inputs to RNNs only have a + # B-axis, no T-axis). + state_ins = [] + for episode, data_dict in zip(episodes, data_dicts): + # Remove state outs (should not be part of the T-axis rearrangements). + state_outs = data_dict.pop(STATE_OUT) + state_ins.append( + tree.map_structure( + # [::T] = only keep every Tth (max_seq_len) state in. + # [:-1] = shift state outs by one (ignore very last state out, + # but therefore add the init state at the beginning). + lambda i, o: np.concatenate([[i], o[:-1]])[::T], + ( + # Episode has a (reset) beginning -> Prepend initial state. + init_state + if episode.t_started == 0 + # Episode starts somewhere in the middle (is a cut + # continuation chunk) -> Use previous chunk's last STATE_OUT + # as initial state. + else episode.get_extra_model_outputs( + key=STATE_OUT, indices=-1, neg_indices_left_of_zero=True + ) + ), + state_outs, + ) + ) + # Concatenate the individual episodes' STATE_INs. + state_in = tree.map_structure(lambda *s: np.concatenate(s), *state_ins) + + # Before adding anything else to the `data`, add the time axis to existing + # data. + data = tree.map_structure( + lambda s: split_and_pad_single_record(s, episodes, T=T), + data, + ) + + # Set the reduce function for all the data we might still have to extract + # from our list of episodes. This function takes a list of data (e.g. obs) + # with each item in the list representing one episode and properly + # splits along the time axis and zero-pads if necessary (based on + # T=max_seq_len). + reduce_fn = partial(split_and_pad, T=T) + + # No stateful module, normal batch (w/o T-axis or zero-padding). + else: + # Set the reduce function for all the data we might still have to extract + # from our list of episodes. Simply concatenate the data from the different + # episodes along the batch axis (axis=0). + reduce_fn = np.concatenate + + # Extract all data from the episodes and add to `data`, if not already in + # `data`. + for key in [ + SampleBatch.OBS, + SampleBatch.ACTIONS, + SampleBatch.REWARDS, + SampleBatch.TERMINATEDS, + SampleBatch.TRUNCATEDS, + SampleBatch.T, # TODO: remove (normally not needed in train batch) + *episodes[0].extra_model_outputs.keys(), + ]: + if key not in data and key != STATE_OUT: + # Concatenate everything together (along B-axis=0). + data[key] = tree.map_structure( + lambda *s: reduce_fn(s), + *[d[key] for d in data_dicts], + ) + + # Handle infos (always lists, not numpy arrays). + if SampleBatch.INFOS not in data: + data[SampleBatch.INFOS] = sum( + [d[SampleBatch.INFOS] for d in data_dicts], + [], + ) + + # Now that all "normal" fields are time-dim'd and zero-padded, add + # the STATE_IN column to `data`. + if rl_module.is_stateful(): + data[STATE_IN] = state_in + # Also, create the loss mask (b/c of our now possibly zero-padded data) as + # well as the seq_lens array and add these to `data` as well. + (data["loss_mask"], data[SampleBatch.SEQ_LENS],) = create_mask_and_seq_lens( + episode_lens=[len(episode) for episode in episodes], + T=T, + ) + + return data + + +def split_and_pad(episodes_data, T): + all_chunks = [] + + for data in episodes_data: + num_chunks = int(np.ceil(data.shape[0] / T)) + + for i in range(num_chunks): + start_index = i * T + end_index = start_index + T + + # Extract the chunk + chunk = data[start_index:end_index] + + # Pad the chunk if it's shorter than T + if chunk.shape[0] < T: + padding_shape = [(0, T - chunk.shape[0])] + [ + (0, 0) for _ in range(chunk.ndim - 1) + ] + chunk = np.pad(chunk, pad_width=padding_shape, mode="constant") + + all_chunks.append(chunk) + + # Combine all chunks into a single array + result = np.concatenate(all_chunks, axis=0) + + # Reshape the array to include the time dimension T. + # The new shape should be (-1, T) + original dimensions (excluding the batch + # dimension) + result = result.reshape((-1, T) + result.shape[1:]) + + return result + + +def split_and_pad_single_record(data, episodes, T): + episodes_data = [] + idx = 0 + for episode in episodes: + len_ = len(episode) + episodes_data.append(data[idx : idx + len_]) + idx += len_ + return split_and_pad(episodes_data, T) + + +def create_mask_and_seq_lens(episode_lens, T): + mask = [] + seq_lens = [] + for episode_len in episode_lens: + len_ = min(episode_len, T) + seq_lens.append(len_) + row = [1] * len_ + [0] * (T - len_) + mask.append(row) + + # Handle sequence lengths greater than T. + overflow = episode_len - T + while overflow > 0: + len_ = min(overflow, T) + seq_lens.append(len_) + extra_row = [1] * len_ + [0] * (T - len_) + mask.append(extra_row) + overflow -= T + + return np.array(mask, dtype=np.bool_), np.array(seq_lens, dtype=np.int32) diff --git a/rllib/connectors/learner/frame_stacking.py b/rllib/connectors/learner/frame_stacking.py new file mode 100644 index 0000000000000..9b4a9f53ad613 --- /dev/null +++ b/rllib/connectors/learner/frame_stacking.py @@ -0,0 +1,6 @@ +from functools import partial + +from ray.rllib.connectors.common.frame_stacking import _FrameStackingConnector + + +FrameStackingLearner = partial(_FrameStackingConnector, as_learner_connector=True) diff --git a/rllib/connectors/learner/learner_connector_pipeline.py b/rllib/connectors/learner/learner_connector_pipeline.py new file mode 100644 index 0000000000000..225b5a4436e06 --- /dev/null +++ b/rllib/connectors/learner/learner_connector_pipeline.py @@ -0,0 +1,5 @@ +from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2 + + +class LearnerConnectorPipeline(ConnectorPipelineV2): + pass diff --git a/rllib/connectors/module_to_env/__init__.py b/rllib/connectors/module_to_env/__init__.py new file mode 100644 index 0000000000000..b7ada36aebdbf --- /dev/null +++ b/rllib/connectors/module_to_env/__init__.py @@ -0,0 +1,9 @@ +from ray.rllib.connectors.module_to_env.default_module_to_env import DefaultModuleToEnv +from ray.rllib.connectors.module_to_env.module_to_env_pipeline import ( + ModuleToEnvPipeline, +) + +__all__ = [ + "DefaultModuleToEnv", + "ModuleToEnvPipeline", +] diff --git a/rllib/connectors/module_to_env/default_module_to_env.py b/rllib/connectors/module_to_env/default_module_to_env.py new file mode 100644 index 0000000000000..e36b4c4c4771b --- /dev/null +++ b/rllib/connectors/module_to_env/default_module_to_env.py @@ -0,0 +1,155 @@ +from typing import Any, List, Optional + +import numpy as np +import tree # pip install dm_tree + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.models.base import STATE_OUT +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import ( + clip_action, + get_base_struct_from_space, + unsquash_action, +) +from ray.rllib.utils.typing import EpisodeType +from ray.util.annotations import PublicAPI + + +@PublicAPI(stability="alpha") +class DefaultModuleToEnv(ConnectorV2): + """Default connector piece added by RLlib to the end of any module-to-env pipeline. + + If necessary, this connector samples actions, given action dist. inputs and a + dist. class. + The connector will only sample from the action distribution, if the + SampleBatch.ACTIONS key cannot be found in `data`. Otherwise, it'll behave + as pass through (noop). If SampleBatch.ACTIONS is not present, but + SampleBatch.ACTION_DIST_INPUTS are, the connector will create a new action + distribution using the RLModule in the connector context and sample from this + distribution (deterministically, if we are not exploring, stochastically, if we + are). + + input_type: INPUT_OUTPUT_TYPES.DICT_OF_MODULE_IDS_TO_DATA + Operates per RLModule as it will have to pull the action distribution from each + in order to sample actions if necessary. Searches for the ACTIONS and + ACTION_DIST_INPUTS keys in a module's outputs and - should ACTIONS not be + found - sample actions from the module's action distribution. + output_type: INPUT_OUTPUT_TYPES.DICT_OF_MODULE_IDS_TO_DATA (same as input: data in, + data out, however, data + out might contain an additional ACTIONS key if it was not previously present + in the input). + """ + + def __init__( + self, + *, + normalize_actions: bool, + clip_actions: bool, + **kwargs, + ): + """Initializes a DefaultModuleToEnv (connector piece) instance. + + Args: + normalize_actions: If True, actions coming from the RLModule's distribution + (or are directly computed by the RLModule w/o sampling) will + be assumed 0.0 centered with a small stddev (only affecting Box + components) and thus be unsquashed (and clipped, just in case) to the + bounds of the env's action space. For example, if the action space of + the environment is `Box(-2.0, -0.5, (1,))`, the model outputs + mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9 + from the resulting distribution, then this 0.9 will be unsquashed into + the [-2.0 -0.5] interval. If - after unsquashing - the action still + breaches the action space, it will simply be clipped. + clip_actions: If True, actions coming from the RLModule's distribution + (or are directly computed by the RLModule w/o sampling) will be clipped + such that they fit into the env's action space's bounds. + For example, if the action space of the environment is + `Box(-0.5, 0.5, (1,))`, the model outputs + mean and stddev as 0.1 and exp(0.2), and we sample an action of 0.9 + from the resulting distribution, then this 0.9 will be clipped to 0.5 + to fit into the [-0.5 0.5] interval. + """ + super().__init__(**kwargs) + + self._action_space_struct = get_base_struct_from_space(self.action_space) + self.normalize_actions = normalize_actions + self.clip_actions = clip_actions + + @override(ConnectorV2) + def __call__( + self, + *, + rl_module: RLModule, + data: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + + # Loop through all modules that created some output. + # for mid in data.keys(): + # sa_module = ctx.rl_module.get_module(module_id=mid) + + # If our RLModule is stateful, remove the T=1 axis from all model outputs + # (except the state outs, which never have this extra time axis). + if rl_module.is_stateful(): + state = data.pop(STATE_OUT, None) + data = tree.map_structure(lambda s: np.squeeze(s, axis=1), data) + if state: + data[STATE_OUT] = state + + # ACTION_DIST_INPUTS field returned by `forward_exploration|inference()` -> + # Create a new action distribution object. + action_dist = None + if SampleBatch.ACTION_DIST_INPUTS in data: + if explore: + action_dist_class = rl_module.get_exploration_action_dist_cls() + else: + action_dist_class = rl_module.get_inference_action_dist_cls() + action_dist = action_dist_class.from_logits( + data[SampleBatch.ACTION_DIST_INPUTS] + ) + + # TODO (sven): Should this not already be taken care of by RLModule's + # `get_...action_dist_cls()` methods? + if not explore: + action_dist = action_dist.to_deterministic() + + # If `forward_...()` returned actions, use them here as-is. + if SampleBatch.ACTIONS in data: + actions = data[SampleBatch.ACTIONS] + # Otherwise, sample actions from the distribution. + else: + if action_dist is None: + raise KeyError( + "Your RLModule's `forward_[exploration|inference]()` methods must " + f"return a dict with either the '{SampleBatch.ACTIONS}' key or " + f"the '{SampleBatch.ACTION_DIST_INPUTS}' key in it (or both)!" + ) + actions = action_dist.sample() + + # For convenience and if possible, compute action logp from distribution + # and add to output. + if action_dist is not None and SampleBatch.ACTION_LOGP not in data: + data[SampleBatch.ACTION_LOGP] = convert_to_numpy(action_dist.logp(actions)) + + actions = convert_to_numpy(actions) + + # Process actions according to Env's action space bounds, if necessary. + # Normalize actions. + if self.normalize_actions: + actions = unsquash_action(actions, self._action_space_struct) + # Clip actions. + elif self.clip_actions: + actions = clip_action(actions, self._action_space_struct) + + data[SampleBatch.ACTIONS] = actions + + # Convert everything into numpy. + data = convert_to_numpy(data) + + return data diff --git a/rllib/connectors/module_to_env/module_to_env_pipeline.py b/rllib/connectors/module_to_env/module_to_env_pipeline.py new file mode 100644 index 0000000000000..e0a11fdac4a63 --- /dev/null +++ b/rllib/connectors/module_to_env/module_to_env_pipeline.py @@ -0,0 +1,5 @@ +from ray.rllib.connectors.connector_pipeline_v2 import ConnectorPipelineV2 + + +class ModuleToEnvPipeline(ConnectorPipelineV2): + pass diff --git a/rllib/connectors/utils/zero_padding.py b/rllib/connectors/utils/zero_padding.py new file mode 100644 index 0000000000000..e34c0eab85cc5 --- /dev/null +++ b/rllib/connectors/utils/zero_padding.py @@ -0,0 +1,135 @@ +from typing import List, Tuple + +import numpy as np + + +def create_mask_and_seq_lens( + episode_lens: List[int], + T: int, +) -> Tuple[np._typing.NDArray, np._typing.NDArray]: + """Creates loss mask and a seq_lens array, given a list of episode lengths and T. + + Args: + episode_lens: A list of episode lengths to infer the loss mask and seq_lens + array from. + T: The maximum number of timesteps in each "row", also known as the maximum + sequence length (max_seq_len). Episodes are split into chunks that are at + most `T` long and remaining timesteps will be zero-padded (and masked out). + + Returns: + Tuple consisting of a) the loss mask to use (masking out areas that are past + the end of an episode (or rollout), but had to be zero-added due to the added + extra time rank (of length T) and b) the array of sequence lengths resulting + from splitting the given episodes into chunks of at most `T` timesteps. + """ + mask = [] + seq_lens = [] + for episode_len in episode_lens: + len_ = min(episode_len, T) + seq_lens.append(len_) + row = [1] * len_ + [0] * (T - len_) + mask.append(row) + + # Handle sequence lengths greater than T. + overflow = episode_len - T + while overflow > 0: + len_ = min(overflow, T) + seq_lens.append(len_) + extra_row = [1] * len_ + [0] * (T - len_) + mask.append(extra_row) + overflow -= T + + return np.array(mask, dtype=np.bool_), np.array(seq_lens, dtype=np.int32) + + +def split_and_pad(data_chunks: List[np._typing.NDArray], T: int) -> np._typing.NDArray: + """Splits and zero-pads data from episodes into a single ndarray with a fixed T-axis. + + Processes each data chunk in `data_chunks`, coming from one episode by splitting + the chunk into smaller sub-chunks, each of a maximum size `T`. If a sub-chunk is + smaller than `T`, it is right-padded with zeros to match the desired size T. + All sub-chunks are then re-combined (concatenated) into a single ndarray, which is + reshaped to include the new time dimension `T` as axis 1 (axis 0 is the batch + axis). The resulting output array has dimensions (B=number of sub-chunks, T, ...), + where '...' represents the original dimensions of the input data (excluding the + batch dimension). + + Args: + data_chunks: A list where each element is a NumPy array representing + an episode. Each array's shape should be (episode_length, ...) + where '...' represents any number of additional dimensions. + T: The desired time dimension size for each chunk. + + Returns: + A np.ndarray containing the reshaped and padded chunks. The shape of the + array will be (B, T, ...) where B is automatically determined by the number + of chunks in `data_chunks` and `T`. + '...' represents the original dimensions of the input data, excluding the + batch dimension. + """ + all_chunks = [] + + for data_chunk in data_chunks: + num_sub_chunks = int(np.ceil(data_chunk.shape[0] / T)) + + for i in range(num_sub_chunks): + start_index = i * T + end_index = start_index + T + + # Extract the chunk. + sub_chunk = data_chunk[start_index:end_index] + + # Pad the chunk if it's shorter than T + if sub_chunk.shape[0] < T: + padding_shape = [(0, T - sub_chunk.shape[0])] + [ + (0, 0) for _ in range(sub_chunk.ndim - 1) + ] + sub_chunk = np.pad(sub_chunk, pad_width=padding_shape, mode="constant") + + all_chunks.append(sub_chunk) + + # Combine all chunks into a single array. + result = np.concatenate(all_chunks, axis=0) + + # Reshape the array to include the time dimension T. + # The new shape should be (-1, T) + original dimensions (excluding the + # batch dimension). + result = result.reshape((-1, T) + result.shape[1:]) + + return result + + +def split_and_pad_single_record( + data: np._typing.NDArray, episode_lengths: List[int], T: int +): + """See `split_and_pad`, but initial data has already been concatenated over episodes. + + Given an np.ndarray of data that is the result of a concatenation of data chunks + coming from different episodes, the lengths of these episodes, as well as the + maximum time dimension, split and possibly right-zero-pad this input data, such that + the resulting shape of the returned np.ndarray is (B', T, ...), where B' is the + number of generated sub-chunks and ... is the original shape of the data (excluding + the batch dim). T is the size of the newly inserted time axis (on which zero-padding + is applied if necessary). + + Args: + data: The single np.ndarray input data to be split, zero-added, and reshaped. + episode_lengths: The list of episode lengths, from which `data` was originally + concat'd. + T: The maximum number of timesteps on the T-axis in the resulting np.ndarray. + + Returns: + A single np.ndarray, which contains the same data as `data`, but split into sub- + chunks of max. size T (zero-padded if necessary at the end of individual + episodes), then reshaped to (B', T, ...). + """ + # Chop up `data` into chunks of max len=T, based on the lengths of the episodes + # where this data came from. + episodes_data = [] + idx = 0 + for episode_len in episode_lengths: + episodes_data.append(data[idx : idx + episode_len]) + idx += episode_len + # Send everything through `split_and_pad` to perform the actual splitting into + # sub-chunks of max len=T and zero-padding. + return split_and_pad(episodes_data, T) diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index 6e229b5f299a8..c022909120794 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -215,6 +215,7 @@ def get_parameters(self, module: RLModule) -> Sequence[Param]: @override(Learner) def _convert_batch_type(self, batch: MultiAgentBatch) -> MultiAgentBatch: batch = convert_to_torch_tensor(batch.policy_batches, device=self._device) + # TODO (sven): This computation of `env_steps` is not accurate! length = max(len(b) for b in batch.values()) batch = MultiAgentBatch(batch, env_steps=length) return batch diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py index b956343babae7..aaf0f1d9fb83d 100644 --- a/rllib/core/models/catalog.py +++ b/rllib/core/models/catalog.py @@ -55,7 +55,6 @@ class Catalog: from ray.rllib.core.models.configs import MLPHeadConfig from ray.rllib.core.models.catalog import Catalog - class MyCatalog(Catalog): def __init__( self, @@ -64,17 +63,19 @@ def __init__( model_config_dict: dict, ): super().__init__(observation_space, action_space, model_config_dict) - self.my_model_config_dict = MLPHeadConfig( + self.my_model_config = MLPHeadConfig( hidden_layer_dims=[64, 32], input_dims=[self.observation_space.shape[0]], ) def build_my_head(self, framework: str): - return self.my_model_config_dict.build(framework=framework) + return self.my_model_config.build(framework=framework) # With that, RLlib can build and use models from this catalog like this: catalog = MyCatalog(gym.spaces.Box(0, 1), gym.spaces.Box(0, 1), {}) - my_head = catalog.build_my_head("torch") + my_head = catalog.build_my_head(framework="torch") + + # Make a call to the built model. out = my_head(torch.Tensor([[1]])) """ @@ -348,7 +349,7 @@ def get_tokenizer_config( ) -> ModelConfig: """Returns a tokenizer config for the given space. - This is useful for recurrent / tranformer models that need to tokenize their + This is useful for recurrent / transformer models that need to tokenize their inputs. By default, RLlib uses the models supported by Catalog out of the box to tokenize. diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py index dd90c5af02a35..5d5ee38ed8d5b 100644 --- a/rllib/core/models/torch/encoder.py +++ b/rllib/core/models/torch/encoder.py @@ -175,7 +175,7 @@ def __init__(self, config: RecurrentEncoderConfig) -> None: assert len(gru_input_dims) == 1 gru_input_dim = gru_input_dims[0] - # Create the torch LSTM layer. + # Create the torch GRU layer. self.gru = nn.GRU( gru_input_dim, config.hidden_dim, diff --git a/rllib/env/wrappers/atari_wrappers.py b/rllib/env/wrappers/atari_wrappers.py index 0dfd74729efae..fb4fa762c819a 100644 --- a/rllib/env/wrappers/atari_wrappers.py +++ b/rllib/env/wrappers/atari_wrappers.py @@ -240,6 +240,23 @@ def reset(self, **kwargs): return self.env.reset(**kwargs) +@PublicAPI +class NormalizedImageEnv(gym.ObservationWrapper): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.observation_space = gym.spaces.Box( + -1.0, + 1.0, + shape=self.observation_space.shape, + dtype=np.float32, + ) + + # Divide by scale and center around 0.0, such that observations are in the range + # of -1.0 and 1.0. + def observation(self, observation): + return (observation.astype(np.float32) / 128.0) - 1.0 + + @PublicAPI class WarpFrame(gym.ObservationWrapper): def __init__(self, env, dim): @@ -266,8 +283,8 @@ def __init__(self, env, k): self.frames = deque([], maxlen=k) shp = env.observation_space.shape self.observation_space = spaces.Box( - low=0, - high=255, + low=np.repeat(env.observation_space.low, repeats=k, axis=-1), + high=np.repeat(env.observation_space.high, repeats=k, axis=-1), shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype, ) diff --git a/rllib/examples/connectors/TODO_MOVE_OLD_CONNECTOR_EXAMPLES_TO_SEPARATE_FOLDER.txt b/rllib/examples/connectors/TODO_MOVE_OLD_CONNECTOR_EXAMPLES_TO_SEPARATE_FOLDER.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/examples/connectors/connector_v2_frame_stacking.py b/rllib/examples/connectors/connector_v2_frame_stacking.py new file mode 100644 index 0000000000000..1119c2539bdd3 --- /dev/null +++ b/rllib/examples/connectors/connector_v2_frame_stacking.py @@ -0,0 +1,178 @@ +import argparse +from functools import partial + +import gymnasium as gym + +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.connectors.env_to_module.frame_stacking import FrameStackingEnvToModule +from ray.rllib.connectors.learner.frame_stacking import FrameStackingLearner +from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner +from ray.rllib.env.wrappers.atari_wrappers import ( + EpisodicLifeEnv, + # FrameStack, # <- we do not want env-based frame stacking + MaxAndSkipEnv, + NoopResetEnv, + NormalizedImageEnv, + WarpFrame, # gray + resize +) +from ray.rllib.utils.test_utils import check_learning_achieved + + +parser = argparse.ArgumentParser() +parser.add_argument("--num-cpus", type=int, default=0) +parser.add_argument( + "--framework", + choices=["tf", "tf2", "torch"], + default="torch", + help="The DL framework specifier.", +) +parser.add_argument( + "--num-gpus", + type=int, + default=0, + help="The number of GPUs (Learner workers) to use.", +) +parser.add_argument( + "--num-frames", + type=int, + default=4, + help="The number of observation frames to stack.", +) +parser.add_argument( + "--as-test", + action="store_true", + help="Whether this script should be run as a test: --stop-reward must " + "be achieved within --stop-timesteps AND --stop-iters.", +) +parser.add_argument( + "--stop-iters", type=int, default=2000, help="Number of iterations to train." +) +parser.add_argument( + "--stop-timesteps", type=int, default=2000000, help="Number of timesteps to train." +) +parser.add_argument( + "--stop-reward", type=float, default=20.0, help="Reward at which we stop training." +) + + +if __name__ == "__main__": + import ray + from ray import air, tune + + args = parser.parse_args() + + ray.init() + + # Define our custom connector pipelines. + def _make_env_to_module_connector(env): + # Create the env-to-module connector. We return an individual connector piece + # here, which RLlib will then automatically integrate into a pipeline (and + # add its default connector piece to the end of that pipeline). + return FrameStackingEnvToModule( + input_observation_space=env.single_observation_space, + input_action_space=env.single_action_space, + num_frames=args.num_frames, + ) + + def _make_learner_connector(input_observation_space, input_action_space): + # Create the learner connector. + return FrameStackingLearner( + input_observation_space=input_observation_space, + input_action_space=input_action_space, + num_frames=args.num_frames, + ) + + # Create a custom Atari setup (w/o the usual RLlib-hard-coded framestacking in it). + # We would like our frame stacking connector to do this job. + tune.register_env( + "env", + ( + lambda cfg: ( + EpisodicLifeEnv( # each life is one episode + MaxAndSkipEnv( # frameskip=4 and take max over these 4 frames + NoopResetEnv( # perform n noops after a reset + # partial(FrameStack, k=4)( # <- no env-based framestacking + NormalizedImageEnv( + partial(WarpFrame, dim=64)( # grayscale + resize + partial( + gym.wrappers.TimeLimit, max_episode_steps=108000 + )( + gym.make( + "ALE/Pong-v5", + **dict(cfg, **{"render_mode": "rgb_array"}) + ) + ) + ) + ) + ) + ) + ) + ) + ), + ) + + config = ( + PPOConfig() + .framework(args.framework) + .environment( + "env", + env_config={ + # Make analogous to old v4 + NoFrameskip. + "frameskip": 1, + "full_action_space": False, + "repeat_action_probability": 0.0, + }, + clip_rewards=True, + ) + # Use new API stack ... + .experimental(_enable_new_api_stack=True) + .rollouts( + # ... new EnvRunner and our frame stacking env-to-module connector. + env_runner_cls=SingleAgentEnvRunner, + env_to_module_connector=_make_env_to_module_connector, + ) + .resources( + num_learner_workers=args.num_gpus, + num_gpus_per_learner_worker=1 if args.num_gpus else 0, + num_cpus_for_local_worker=1, + ) + .training( + # Use our frame stacking learner connector. + learner_connector=_make_learner_connector, + lambda_=0.95, + kl_coeff=0.5, + clip_param=0.1, + vf_clip_param=10.0, + entropy_coeff=0.01, + num_sgd_iter=10, + # Linearly adjust learning rate based on number of GPUs. + lr=0.00015 * (args.num_gpus or 1), + grad_clip=100.0, + grad_clip_by="global_norm", + model={ + "vf_share_layers": True, + "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], + "conv_activation": "relu", + "post_fcnet_hiddens": [256], + }, + ) + ) + + stop = { + "training_iteration": args.stop_iters, + "timesteps_total": args.stop_timesteps, + "episode_reward_mean": args.stop_reward, + } + + tuner = tune.Tuner( + config.algo_class, + param_space=config, + run_config=air.RunConfig(stop=stop), + tune_config=tune.TuneConfig(num_samples=1), + ) + results = tuner.fit() + + if args.as_test: + check_learning_achieved(results, args.stop_reward) + + ray.shutdown() diff --git a/rllib/utils/filter_manager.py b/rllib/utils/filter_manager.py index e4b71af66d09e..8bcba09793421 100644 --- a/rllib/utils/filter_manager.py +++ b/rllib/utils/filter_manager.py @@ -29,7 +29,7 @@ def synchronize( Args: local_filters: Filters to be synchronized. - remotes: Remote evaluators with filters. + worker_set: WorkerSet with remote EnvRunners with filters. update_remote: Whether to push updates from the local filters to the remote workers' filters. timeout_seconds: How long to wait for filter to get or set filters diff --git a/rllib/utils/numpy.py b/rllib/utils/numpy.py index 9f040c8a0c286..944d4b758c8c4 100644 --- a/rllib/utils/numpy.py +++ b/rllib/utils/numpy.py @@ -7,11 +7,7 @@ from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.deprecation import ( - DEPRECATED_VALUE, - deprecation_warning, - Deprecated, -) +from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.typing import SpaceStruct, TensorType, TensorStructType, Union @@ -122,9 +118,7 @@ def concat_aligned( @PublicAPI -def convert_to_numpy( - x: TensorStructType, reduce_type: bool = True, reduce_floats=DEPRECATED_VALUE -): +def convert_to_numpy(x: TensorStructType, reduce_type: bool = True) -> TensorStructType: """Converts values in `stats` to non-Tensor numpy or python types. Args: @@ -139,10 +133,6 @@ def convert_to_numpy( values converted to numpy arrays (on CPU). """ - if reduce_floats != DEPRECATED_VALUE: - deprecation_warning(old="reduce_floats", new="reduce_types", error=True) - reduce_type = reduce_floats - # The mapping function used to numpyize torch/tf Tensors (and move them # to the CPU beforehand). def mapping(item): diff --git a/rllib/utils/tests/test_minibatch_utils.py b/rllib/utils/tests/test_minibatch_utils.py index a8d8180d05129..0256e9ffab311 100644 --- a/rllib/utils/tests/test_minibatch_utils.py +++ b/rllib/utils/tests/test_minibatch_utils.py @@ -93,8 +93,8 @@ def test_minibatch_cyclic_iterator(self): check(policy_batch.count, mini_batch_size) iteration_counter += 1 - # for each policy check that the last item in batch matches the expected - # values, i.e. iteration_counter * mini_batch_size % agent_steps - 1 + # For each policy check that the last item in batch matches the expected + # values, i.e. iteration_counter * mini_batch_size % agent_steps - 1. total_steps = iteration_counter * mini_batch_size for policy_idx, policy_batch in enumerate( batch.policy_batches.values() @@ -104,8 +104,8 @@ def test_minibatch_cyclic_iterator(self): expected_last_item = 0.0 check(policy_batch["obs"][-1], expected_last_item) - # check iteration counter (should be - # ceil(num_gsd_iter * max(agent_steps) / mini_batch_size)) + # Check iteration counter (should be + # ceil(num_gsd_iter * max(agent_steps) / mini_batch_size)). expected_iteration_counter = np.ceil( num_sgd_iter * max(agent_steps) / mini_batch_size ) diff --git a/rllib/utils/torch_utils.py b/rllib/utils/torch_utils.py index 0a56abf83a502..68c8ebda458e3 100644 --- a/rllib/utils/torch_utils.py +++ b/rllib/utils/torch_utils.py @@ -217,8 +217,8 @@ def convert_to_torch_tensor(x: TensorStructType, device: Optional[str] = None): Returns: Any: A new struct with the same structure as `x`, but with all - values converted to torch Tensor types. This does not convert possibly - nested elements that are None because torch has no representation for that. + values converted to torch Tensor types. This does not convert possibly + nested elements that are None because torch has no representation for that. """ def mapping(item):