From 017dcfc592ac62fd48489719599960d7326bfcb8 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Fri, 10 Nov 2023 15:16:27 +0100 Subject: [PATCH 1/3] wip Signed-off-by: sven1977 --- rllib/algorithms/algorithm.py | 5 +- rllib/algorithms/algorithm_config.py | 33 +++++++--- rllib/core/learner/learner.py | 2 + rllib/core/learner/learner_group.py | 2 + rllib/core/models/torch/encoder.py | 60 ++++++++----------- rllib/core/rl_module/rl_module.py | 4 +- rllib/env/env_runner.py | 1 - rllib/env/multi_agent_episode.py | 44 +++++++------- rllib/env/single_agent_env_runner.py | 45 ++++++-------- rllib/env/tests/test_single_agent_episode.py | 2 +- rllib/evaluation/worker_set.py | 5 +- ...nt-cartpole-crashing-restart-env-appo.yaml | 2 +- rllib/utils/spaces/space_utils.py | 50 ++++++++++++++++ rllib/utils/typing.py | 8 ++- 14 files changed, 161 insertions(+), 102 deletions(-) diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index eac1d06eb435f..a06666a68bf8b 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -1932,7 +1932,10 @@ def compute_actions( filtered_obs, filtered_state = [], [] for agent_id, ob in observations.items(): worker = self.workers.local_worker() - preprocessed = worker.preprocessors[policy_id].transform(ob) + if worker.preprocessors.get(policy_id) is not None: + preprocessed = worker.preprocessors[policy_id].transform(ob) + else: + preprocessed = ob filtered = worker.filters[policy_id](preprocessed, update=False) filtered_obs.append(filtered) if state is None: diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index b7d28a8a4ee84..42ad1ff6e256f 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -319,26 +319,37 @@ def __init__(self, algo_class=None): # If not specified, we will try to auto-detect this. self._is_atari = None + # TODO (sven): Rename this method into `AlgorithmConfig.sampling()` # `self.rollouts()` self.env_runner_cls = None + # TODO (sven): Rename into `num_env_runner_workers`. self.num_rollout_workers = 0 self.num_envs_per_worker = 1 - self.sample_collector = SimpleListCollector self.create_env_on_local_worker = False - self.sample_async = False self.enable_connectors = True - self.update_worker_filter_stats = True - self.use_worker_filter_stats = True + # TODO (sven): Rename into `sample_timesteps` (or `sample_duration` + # and `sample_duration_unit` (replacing batch_mode), like we do it + # in the evaluation config). self.rollout_fragment_length = 200 + # TODO (sven): Rename into `sample_mode`. self.batch_mode = "truncate_episodes" + # TODO (sven): Rename into `validate_env_runner_workers_after_construction`. + self.validate_workers_after_construction = True + self.compress_observations = False + # TODO (sven): Rename into `env_runner_perf_stats_ema_coef`. + self.sampler_perf_stats_ema_coef = None + + # TODO (sven): Deprecate together with old API stack. + self.sample_async = False self.remote_worker_envs = False self.remote_env_batch_wait_ms = 0 - self.validate_workers_after_construction = True + self.enable_tf1_exec_eagerly = False + self.sample_collector = SimpleListCollector self.preprocessor_pref = "deepmind" self.observation_filter = "NoFilter" - self.compress_observations = False - self.enable_tf1_exec_eagerly = False - self.sampler_perf_stats_ema_coef = None + self.update_worker_filter_stats = True + self.use_worker_filter_stats = True + # TODO (sven): End: deprecate. 
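For context on the `compute_actions()` change near the top of this patch: the per-policy preprocessor entry can be None (presumably when preprocessing is disabled, e.g. on the new API stack), so the raw observation must be forwarded unchanged. A minimal restatement of the added guard, using only the names already in scope in that method:

    # Sketch of the guard added above: fall back to the raw observation when
    # no preprocessor is registered for this policy.
    preprocessor = worker.preprocessors.get(policy_id)
    preprocessed = preprocessor.transform(ob) if preprocessor is not None else ob
    # Observation filtering still applies in both cases.
    filtered = worker.filters[policy_id](preprocessed, update=False)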
# `self.training()` self.gamma = 0.99 @@ -890,7 +901,7 @@ def validate(self) -> None: error=True, ) - # RLModule API only works with connectors and with Learner API. + # New API stack (RLModule, Learner APIs) only works with connectors. if not self.enable_connectors and self._enable_new_api_stack: raise ValueError( "The new API stack (RLModule and Learner APIs) only works with " @@ -937,6 +948,8 @@ def validate(self) -> None: "https://github.com/ray-project/ray/issues/35409 for more details." ) + # TODO (sven): Remove this hack. We should not have env-var dependent logic + # in the codebase. if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)): # Enable RLModule API and connectors if env variable is set # (to be used in unittesting) @@ -1764,6 +1777,8 @@ def training( dashboard. If you're seeing that the object store is filling up, turn down the number of remote requests in flight, or enable compression in your experiment of timesteps. + learner_class: The `Learner` class to use for (distributed) updating of the + RLModule. Only used when `_enable_new_api_stack=True`. Returns: This updated AlgorithmConfig object. diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index 1a0e6515f6e0d..41575149c3c4a 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -63,6 +63,7 @@ ResultDict, TensorType, ) +from ray.util.annotations import PublicAPI if TYPE_CHECKING: from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig @@ -226,6 +227,7 @@ def get_hps_for_module(self, module_id: ModuleID) -> "LearnerHyperparameters": return self +@PublicAPI(stability="alpha") class Learner: """Base class for Learners. diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index f8d89fdc8d3ae..60088b0e29d8f 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -29,6 +29,7 @@ from ray.rllib.utils.numpy import convert_to_numpy from ray.train._internal.backend_executor import BackendExecutor from ray.tune.utils.file_transfer import sync_dir_between_nodes +from ray.util.annotations import PublicAPI if TYPE_CHECKING: @@ -58,6 +59,7 @@ def _is_module_trainable(module_id: ModuleID, batch: MultiAgentBatch) -> bool: return True +@PublicAPI(stability="alpha") class LearnerGroup: """Coordinator of Learners. diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py index 2b2aa7c376678..16f02be09efbe 100644 --- a/rllib/core/models/torch/encoder.py +++ b/rllib/core/models/torch/encoder.py @@ -285,30 +285,31 @@ def __init__(self, config: RecurrentEncoderConfig) -> None: bias=config.use_bias, ) + self.state_in_out_spec = { + "h": TensorSpec( + "b, l, d", + d=self.config.hidden_dim, + l=self.config.num_layers, + framework="torch", + ), + "c": TensorSpec( + "b, l, d", + d=self.config.hidden_dim, + l=self.config.num_layers, + framework="torch", + ), + } + + @override(Model) def get_input_specs(self) -> Optional[Spec]: - return SpecDict( - { - # b, t for batch major; t, b for time major. - SampleBatch.OBS: TensorSpec( - "b, t, d", d=self.config.input_dims[0], framework="torch" - ), - STATE_IN: { - "h": TensorSpec( - "b, l, h", - h=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - "c": TensorSpec( - "b, l, h", - h=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - }, - } - ) + return SpecDict({ + # b, t for batch major; t, b for time major. 
+ SampleBatch.OBS: TensorSpec( + "b, t, d", d=self.config.input_dims[0], framework="torch" + ), + STATE_IN: self.state_in_out_spec, + }) @override(Model) def get_output_specs(self) -> Optional[Spec]: @@ -317,20 +318,7 @@ def get_output_specs(self) -> Optional[Spec]: ENCODER_OUT: TensorSpec( "b, t, d", d=self.config.output_dims[0], framework="torch" ), - STATE_OUT: { - "h": TensorSpec( - "b, l, h", - h=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - "c": TensorSpec( - "b, l, h", - h=self.config.hidden_dim, - l=self.config.num_layers, - framework="torch", - ), - }, + STATE_OUT: self.state_in_out_spec, } ) diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 9d6247b9ca5a0..1b2d3c426e772 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -474,9 +474,9 @@ def get_initial_state(self) -> Any: @OverrideToImplementCustomLogic def is_stateful(self) -> bool: - """Returns True if the initial state is empty. + """Returns False if the initial state is an empty dict (or None). - By default, RLlib assumes that the module is not recurrent if the initial + By default, RLlib assumes that the module is non-recurrent if the initial state is an empty dict and recurrent otherwise. This behavior can be overridden by implementing this method. """ diff --git a/rllib/env/env_runner.py b/rllib/env/env_runner.py index 7c165386df5ac..2cae4089272c1 100644 --- a/rllib/env/env_runner.py +++ b/rllib/env/env_runner.py @@ -1,5 +1,4 @@ import abc - from typing import Any, Dict, TYPE_CHECKING from ray.rllib.utils.actor_manager import FaultAwareApply diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py index e64edbd08a59a..249646d95513e 100644 --- a/rllib/env/multi_agent_episode.py +++ b/rllib/env/multi_agent_episode.py @@ -212,75 +212,75 @@ def get_observations( return self._getattr_by_index("observations", indices, global_ts) - def get_actions( + def get_infos( self, indices: Union[int, List[int]] = -1, global_ts: bool = True ) -> MultiAgentDict: - """Gets actions for all agents that stepped in the last timesteps. + """Gets infos for all agents that stepped in the last timesteps. - Note that actions are only returned for agents that stepped + Note that infos are only returned for agents that stepped during the given index range. Args: indices: Either a single index or a list of indices. The indices can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). - This defines the time indices for which the actions + This defines the time indices for which the infos should be returned. global_ts: Boolean that defines, if the indices should be considered environment (`True`) or agent (`False`) steps. - Returns: A dictionary mapping agent ids to actions (of different + Returns: A dictionary mapping agent ids to infos (of different timesteps). Only for agents that have stepped (were ready) at a - timestep, actions are returned (i.e. not all agent ids are + timestep, infos are returned (i.e. not all agent ids are necessarily in the keys). """ + return self._getattr_by_index("infos", indices, global_ts) - return self._getattr_by_index("actions", indices, global_ts) - - def get_rewards( + def get_actions( self, indices: Union[int, List[int]] = -1, global_ts: bool = True ) -> MultiAgentDict: - """Gets rewards for all agents that stepped in the last timesteps. + """Gets actions for all agents that stepped in the last timesteps. 
- Note that rewards are only returned for agents that stepped + Note that actions are only returned for agents that stepped during the given index range. Args: indices: Either a single index or a list of indices. The indices can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). - This defines the time indices for which the rewards + This defines the time indices for which the actions should be returned. global_ts: Boolean that defines, if the indices should be considered environment (`True`) or agent (`False`) steps. - Returns: A dictionary mapping agent ids to rewards (of different + Returns: A dictionary mapping agent ids to actions (of different timesteps). Only for agents that have stepped (were ready) at a - timestep, rewards are returned (i.e. not all agent ids are + timestep, actions are returned (i.e. not all agent ids are necessarily in the keys). """ - return self._getattr_by_index("rewards", indices, global_ts) - def get_infos( + return self._getattr_by_index("actions", indices, global_ts) + + def get_rewards( self, indices: Union[int, List[int]] = -1, global_ts: bool = True ) -> MultiAgentDict: - """Gets infos for all agents that stepped in the last timesteps. + """Gets rewards for all agents that stepped in the last timesteps. - Note that infos are only returned for agents that stepped + Note that rewards are only returned for agents that stepped during the given index range. Args: indices: Either a single index or a list of indices. The indices can be reversed (e.g. [-1, -2]) or absolute (e.g. [98, 99]). - This defines the time indices for which the infos + This defines the time indices for which the rewards should be returned. global_ts: Boolean that defines, if the indices should be considered environment (`True`) or agent (`False`) steps. - Returns: A dictionary mapping agent ids to infos (of different + Returns: A dictionary mapping agent ids to rewards (of different timesteps). Only for agents that have stepped (were ready) at a - timestep, infos are returned (i.e. not all agent ids are + timestep, rewards are returned (i.e. not all agent ids are necessarily in the keys). """ - return self._getattr_by_index("infos", indices, global_ts) + return self._getattr_by_index("rewards", indices, global_ts) def get_extra_model_outputs( self, indices: Union[int, List[int]] = -1, global_ts: bool = True diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index 0146e114bbc26..92d966a64602a 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -23,7 +23,7 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig # TODO (sven): This gives a tricky circular import that goes - # deep into the library. We have to see, where to dissolve it. + # deep into the library. We have to see, where to dissolve it. from ray.rllib.env.single_agent_episode import SingleAgentEpisode _, tf, _ = try_import_tf() @@ -41,56 +41,48 @@ def __init__(self, config: "AlgorithmConfig", **kwargs): # Get the worker index on which this instance is running. self.worker_index: int = kwargs.get("worker_index") + # Create the vectorized gymnasium env. + # Register env for the local context. # Note, `gym.register` has to be called on each worker. 
- gym.register( - "custom-env-v0", - partial( + if ( + isinstance(self.config.env, str) + and _global_registry.contains(ENV_CREATOR, self.config.env) + ): + entry_point = partial( _global_registry.get(ENV_CREATOR, self.config.env), self.config.env_config, ) - if _global_registry.contains(ENV_CREATOR, self.config.env) - else partial( + + else: + entry_point = partial( _gym_env_creator, env_context=self.config.env_config, env_descriptor=self.config.env, - ), - ) + ) + gym.register("rllib-single-agent-env-runner-v0", entry_point=entry_point) - # Create the vectorized gymnasium env. # Wrap into `VectorListInfo`` wrapper to get infos as lists. self.env: gym.Wrapper = gym.wrappers.VectorListInfo( gym.vector.make( - "custom-env-v0", + "rllib-single-agent-env-runner-v0", num_envs=self.config.num_envs_per_worker, asynchronous=self.config.remote_worker_envs, ) ) - self.num_envs: int = self.env.num_envs assert self.num_envs == self.config.num_envs_per_worker - # Create our own instance of the single-agent `RLModule` (which + # Create our own instance of the (single-agent) `RLModule` (which # the needs to be weight-synched) each iteration. - # TODO (sven, simon): We need to get rid here of the policy_dict, - # but the 'RLModule' takes the 'policy_spec.observation_space' - # from it. - # Below is the non nice solution. - # policy_dict, _ = self.config.get_multi_agent_setup(env=self.env) module_spec: SingleAgentRLModuleSpec = self.config.get_default_rl_module_spec() module_spec.observation_space = self.env.envs[0].observation_space # TODO (simon): The `gym.Wrapper` for `gym.vector.VectorEnv` should - # actually hold the spaces for a single env, but for boxes the - # shape is (1, 1) which brings a problem with the action dists. - # shape=(1,) is expected. + # actually hold the spaces for a single env, but for boxes the + # shape is (1, 1) which brings a problem with the action dists. + # shape=(1,) is expected. module_spec.action_space = self.env.envs[0].action_space module_spec.model_config_dict = self.config.model - - # TODO (sven): By time the `AlgorithmConfig` will get rid of `PolicyDict` - # as well. Then we have to change this function parameter. - # module_spec: MultiAgentRLModuleSpec = self.config.get_marl_module_spec( - # policy_dict=module_dict - # ) self.module: RLModule = module_spec.build() # This should be the default. @@ -208,7 +200,6 @@ def _sample_timesteps( # Loop through env in enumerate.(self._episodes): ts = 0 - # print(f"EnvRunner {self.worker_index}: {self.module.weights[0][0][0]}") while ts < num_timesteps: # Act randomly. if random_actions: diff --git a/rllib/env/tests/test_single_agent_episode.py b/rllib/env/tests/test_single_agent_episode.py index dc549e4f5589d..f15a39bb0b6d6 100644 --- a/rllib/env/tests/test_single_agent_episode.py +++ b/rllib/env/tests/test_single_agent_episode.py @@ -9,7 +9,7 @@ from ray.rllib.env.single_agent_episode import SingleAgentEpisode # TODO (simon): Add to the tests `info` and `extra_model_outputs` -# as soon as #39732 is merged. +# as soon as #39732 is merged. 
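To see the env registration pattern introduced above in isolation, here is a rough stand-alone sketch. The env id and the CartPole factory are placeholders, and it assumes a gymnasium version that still ships `gym.vector.make` and `gym.wrappers.VectorListInfo` (both used in this diff); the actual runner registers "rllib-single-agent-env-runner-v0" with an entry point resolved from the tune registry or from the env descriptor:

    from functools import partial

    import gymnasium as gym
    from gymnasium.envs.classic_control.cartpole import CartPoleEnv

    # Register a factory callable under a fixed env id (placeholder id/factory).
    gym.register("sketch-cartpole-copy-v0", entry_point=partial(CartPoleEnv))

    # Vectorize it and get the per-sub-env infos back as a list, mirroring the
    # VectorListInfo wrapping done by the runner.
    env = gym.wrappers.VectorListInfo(
        gym.vector.make("sketch-cartpole-copy-v0", num_envs=2, asynchronous=False)
    )
    obs, infos = env.reset(seed=0)
    assert env.num_envs == 2 and isinstance(infos, list)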
class TestEnv(gym.Env): diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 8eeed101af997..75499ff66a461 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -21,11 +21,11 @@ from ray.exceptions import RayActorError from ray.rllib.core.learner import LearnerGroup from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.env.env_runner import EnvRunner from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.utils.actor_manager import RemoteCallResults from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.env_runner import EnvRunner from ray.rllib.offline import get_dataset_and_shards from ray.rllib.policy.policy import Policy, PolicyState from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID @@ -690,6 +690,9 @@ def foreach_worker( if local_worker and self.local_worker() is not None: local_result = [func(self.local_worker())] + if not self.__worker_manager.actor_ids(): + return local_result + remote_results = self.__worker_manager.foreach_actor( func, healthy_only=healthy_only, diff --git a/rllib/tuned_examples/appo/multi-agent-cartpole-crashing-restart-env-appo.yaml b/rllib/tuned_examples/appo/multi-agent-cartpole-crashing-restart-env-appo.yaml index 34132fc02c922..89e8e4faf1a81 100644 --- a/rllib/tuned_examples/appo/multi-agent-cartpole-crashing-restart-env-appo.yaml +++ b/rllib/tuned_examples/appo/multi-agent-cartpole-crashing-restart-env-appo.yaml @@ -30,7 +30,7 @@ multi-agent-cartpole-crashing-appo: # Switch on resiliency for failed sub environments (within a vectorized stack). restart_failed_sub_environments: true - # Switch on evaluation workers being managed by AsyncRequestsManager object. + # Switch on asynchronous handling of evaluation workers. enable_async_evaluation: true evaluation_num_workers: 5 diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index 2bf7cafe2890c..39e27268bc602 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -205,6 +205,56 @@ def flatten_to_single_ndarray(input_): return input_ +@DeveloperAPI +def batch(list_of_structs, individual_items_already_have_batch_1: bool = False): + """Converts input from a list of (nested) structs to (nested) struct of batches. + + Input: Batch (list) of structs (each of these structs representing a + single item). + [ + {"a": 1, "b": (4, 7.0)}, <- item 1 + {"a": 2, "b": (5, 8.0)}, <- item 2 + {"a": 3, "b": (6, 9.0)}, <- item 3 + ] + + Output: Struct of different batches (each batch has size=3 b/c there were 3 items + in the original list): + { + "a": np.array([1, 2, 3]), + "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0])) + } + + Args: + list_of_structs: The list of rows. Each item + in this list represents a single (maybe complex) struct. + individual_items_already_have_batch_1: True, if the individual items in + `list_of_structs` already have a batch dim (of 1). In this case, we will + concatenate (instead of stack) at the end. + + Returns: + The struct of component batches. Each leaf item + in this struct represents the batch for a single component + (in case struct is tuple/dict). Alternatively, a simple batch of + primitives (non tuple/dict) might be returned. + """ + flat = item = None + + for item in list_of_structs: + flattened_item = tree.flatten(item) + # Create the main list, in which each slot represents one leaf in the (nested) + # struct. 
Each slot holds a list of batch values. + if flat is None: + flat = [[] for _ in range(len(flattened_item))] + for i, value in enumerate(flattened_item): + flat[i].append(value) + + # Unflatten everything into the + out = tree.unflatten_as(item, flat) + np_func = np.stack if not individual_items_already_have_batch_1 else np.concatenate + out = tree.map_structure_up_to(item, lambda s: np_func(s, axis=0), out) + return out + + @DeveloperAPI def unbatch(batches_struct): """Converts input from (nested) struct of batches to batch of structs. diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index b61b5298bb411..24369a185f82c 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -19,6 +19,8 @@ if TYPE_CHECKING: from ray.rllib.env.env_context import EnvContext + from ray.rllib.env.multi_agent_episode import MultiAgentEpisode + from ray.rllib.env.single_agent_episode import SingleAgentEpisode from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.policy import PolicySpec @@ -75,7 +77,8 @@ # Represents a BaseEnv, MultiAgentEnv, ExternalEnv, ExternalMultiAgentEnv, # VectorEnv, gym.Env, or ActorHandle. -EnvType = Any +# TODO (sven): Specify this type more strictly (it should just be gym.Env). +EnvType = Union[Any, gym.Env] # A callable, taking a EnvContext object # (config dict + properties: `worker_index`, `vector_index`, `num_workers`, @@ -101,6 +104,9 @@ # Represents an episode id. EpisodeID = int +# A new stack Episode type: Either single-agent or multi-agent. +EpisodeType = Type[Union["SingleAgentEpisode", "MultiAgentEpisode"]] + # Represents an "unroll" (maybe across different sub-envs in a vector env). UnrollID = int From 768b88c717363aa2bb0ce3c5743a0a859ef4f494 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 16 Nov 2023 12:16:11 +0100 Subject: [PATCH 2/3] wip Signed-off-by: sven1977 --- rllib/core/models/torch/encoder.py | 21 +++++----- rllib/utils/spaces/space_utils.py | 39 ++++++++++-------- rllib/utils/spaces/tests/test_space_utils.py | 43 +++++++++++++++++++- 3 files changed, 76 insertions(+), 27 deletions(-) diff --git a/rllib/core/models/torch/encoder.py b/rllib/core/models/torch/encoder.py index 16f02be09efbe..dd90c5af02a35 100644 --- a/rllib/core/models/torch/encoder.py +++ b/rllib/core/models/torch/encoder.py @@ -285,7 +285,7 @@ def __init__(self, config: RecurrentEncoderConfig) -> None: bias=config.use_bias, ) - self.state_in_out_spec = { + self._state_in_out_spec = { "h": TensorSpec( "b, l, d", d=self.config.hidden_dim, @@ -300,16 +300,17 @@ def __init__(self, config: RecurrentEncoderConfig) -> None: ), } - @override(Model) def get_input_specs(self) -> Optional[Spec]: - return SpecDict({ - # b, t for batch major; t, b for time major. - SampleBatch.OBS: TensorSpec( - "b, t, d", d=self.config.input_dims[0], framework="torch" - ), - STATE_IN: self.state_in_out_spec, - }) + return SpecDict( + { + # b, t for batch major; t, b for time major. 
+ SampleBatch.OBS: TensorSpec( + "b, t, d", d=self.config.input_dims[0], framework="torch" + ), + STATE_IN: self._state_in_out_spec, + } + ) @override(Model) def get_output_specs(self) -> Optional[Spec]: @@ -318,7 +319,7 @@ def get_output_specs(self) -> Optional[Spec]: ENCODER_OUT: TensorSpec( "b, t, d", d=self.config.output_dims[0], framework="torch" ), - STATE_OUT: self.state_in_out_spec, + STATE_OUT: self._state_in_out_spec, } ) diff --git a/rllib/utils/spaces/space_utils.py b/rllib/utils/spaces/space_utils.py index 39e27268bc602..c3d621189d8f8 100644 --- a/rllib/utils/spaces/space_utils.py +++ b/rllib/utils/spaces/space_utils.py @@ -206,15 +206,17 @@ def flatten_to_single_ndarray(input_): @DeveloperAPI -def batch(list_of_structs, individual_items_already_have_batch_1: bool = False): - """Converts input from a list of (nested) structs to (nested) struct of batches. +def batch( + list_of_structs: List[Any], + individual_items_already_have_batch_1: bool = False, +): + """Converts input from a list of (nested) structs to a (nested) struct of batches. - Input: Batch (list) of structs (each of these structs representing a - single item). + Input: List of structs (each of these structs representing a single batch item). [ - {"a": 1, "b": (4, 7.0)}, <- item 1 - {"a": 2, "b": (5, 8.0)}, <- item 2 - {"a": 3, "b": (6, 9.0)}, <- item 3 + {"a": 1, "b": (4, 7.0)}, <- batch item 1 + {"a": 2, "b": (5, 8.0)}, <- batch item 2 + {"a": 3, "b": (6, 9.0)}, <- batch item 3 ] Output: Struct of different batches (each batch has size=3 b/c there were 3 items @@ -225,20 +227,25 @@ def batch(list_of_structs, individual_items_already_have_batch_1: bool = False): } Args: - list_of_structs: The list of rows. Each item - in this list represents a single (maybe complex) struct. + list_of_structs: The list of (possibly nested) structs. Each item + in this list represents a single batch item. individual_items_already_have_batch_1: True, if the individual items in `list_of_structs` already have a batch dim (of 1). In this case, we will - concatenate (instead of stack) at the end. + concatenate (instead of stack) at the end. In the example above, this would + look like this: Input: [{"a": [1], "b": ([4], [7.0])}, ...] -> Output: same + as in above example. Returns: - The struct of component batches. Each leaf item - in this struct represents the batch for a single component - (in case struct is tuple/dict). Alternatively, a simple batch of - primitives (non tuple/dict) might be returned. + The struct of component batches. Each leaf item in this struct represents the + batch for a single component (in case struct is tuple/dict). If the input is a + simple list of primitive items, e.g. a list of floats, a np.array of floats + will be returned. """ flat = item = None + if not list_of_structs: + raise ValueError("Input `list_of_structs` does not contain any items.") + for item in list_of_structs: flattened_item = tree.flatten(item) # Create the main list, in which each slot represents one leaf in the (nested) @@ -280,8 +287,8 @@ def unbatch(batches_struct): primitives (non tuple/dict). Returns: - List[struct[components]]: The list of rows. Each item - in the returned list represents a single (maybe complex) struct. + The list of individual structs. Each item in the returned list represents a + single (maybe complex) batch item. 
""" flat_batches = tree.flatten(batches_struct) diff --git a/rllib/utils/spaces/tests/test_space_utils.py b/rllib/utils/spaces/tests/test_space_utils.py index 9283e675dad89..c0200b6870ddb 100644 --- a/rllib/utils/spaces/tests/test_space_utils.py +++ b/rllib/utils/spaces/tests/test_space_utils.py @@ -2,13 +2,18 @@ import unittest -import numpy as np from gymnasium.spaces import Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict +import numpy as np +import tree # pip install dm_tree + from ray.rllib.utils.spaces.space_utils import ( + batch, convert_element_to_space_type, get_base_struct_from_space, + unbatch, unsquash_action, ) +from ray.rllib.utils.test_utils import check class TestSpaceUtils(unittest.TestCase): @@ -69,6 +74,42 @@ def test_unsquash_action(self): self.assertEqual(action[0], 6) self.assertEqual(action[1], 6) + def test_batch_and_unbatch(self): + """Tests the two utility functions `batch` and `unbatch`.""" + # Create a complex struct of individual batches (B=2). + complex_struct = { + "a": ( + np.array([-10.0, -20.0]), + { + "a1": np.array([-1, -2]), + "a2": np.array([False, False]), + }, + ), + "b": np.array([0, 1]), + "c": { + "c1": np.array([True, False]), + "c2": np.array([1, 2]), + "c3": (np.array([3, 4]), np.array([5, 6])), + }, + "d": np.array([0.0, 0.1]), + } + complex_struct_unbatched = unbatch(complex_struct) + # Check that we now have a list of two complex items, the first one + # containing all the index=0 values, the second one containing all the index=1 + # values. + check( + complex_struct_unbatched, + [ + tree.map_structure(lambda s: s[0], complex_struct), + tree.map_structure(lambda s: s[1], complex_struct), + ], + ) + + # Re-batch the unbatched struct. + complex_struct_rebatched = batch(complex_struct_unbatched) + # Should be identical to original struct. + check(complex_struct, complex_struct_rebatched) + if __name__ == "__main__": import pytest From 1b7d1ccc1ddfa56cca28d3efafe8ea13412b1ceb Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 16 Nov 2023 12:27:42 +0100 Subject: [PATCH 3/3] wip Signed-off-by: sven1977 --- rllib/env/single_agent_env_runner.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index 92d966a64602a..55a0572e540ce 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -45,9 +45,8 @@ def __init__(self, config: "AlgorithmConfig", **kwargs): # Register env for the local context. # Note, `gym.register` has to be called on each worker. - if ( - isinstance(self.config.env, str) - and _global_registry.contains(ENV_CREATOR, self.config.env) + if isinstance(self.config.env, str) and _global_registry.contains( + ENV_CREATOR, self.config.env ): entry_point = partial( _global_registry.get(ENV_CREATOR, self.config.env),