From 746f6b622c41c5be7b48839cf046417fa1839bf2 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 22 Aug 2024 22:24:55 +0200 Subject: [PATCH] [RLlib] Add APPO/IMPALA multi-agent StatelessCartPole learning tests to CI (+ fix some bugs related to this). (#47245) --- .buildkite/rllib.rayci.yml | 4 +- rllib/BUILD | 134 +++++++--- .../add_states_from_episodes_to_batch.py | 19 +- .../env_to_module/mean_std_filter.py | 13 +- .../add_one_ts_to_episodes_and_truncate.py | 16 +- .../remove_single_ts_time_rank_from_batch.py | 20 +- rllib/core/learner/learner.py | 243 +++++++----------- .../core/learner/tests/test_learner_group.py | 2 +- .../appo/multi_agent_cartpole_appo.py | 6 +- .../multi_agent_stateless_cartpole_appo.py | 67 +++++ .../dqn/multi_agent_cartpole_dqn.py | 7 +- .../impala/multi_agent_cartpole_impala.py | 2 +- .../multi_agent_stateless_cartpole_impala.py | 67 +++++ .../ppo/multi_agent_cartpole_ppo.py | 7 +- .../ppo/multi_agent_pendulum_ppo.py | 7 +- .../sac/multi_agent_pendulum_sac.py | 9 +- rllib/utils/metrics/__init__.py | 1 + rllib/utils/test_utils.py | 11 +- 18 files changed, 413 insertions(+), 222 deletions(-) create mode 100644 rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py create mode 100644 rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py diff --git a/.buildkite/rllib.rayci.yml b/.buildkite/rllib.rayci.yml index f6cae0dc02f7..3072ffc3998f 100644 --- a/.buildkite/rllib.rayci.yml +++ b/.buildkite/rllib.rayci.yml @@ -83,7 +83,7 @@ steps: tags: - rllib_gpu - gpu - parallelism: 4 + parallelism: 5 instance_type: gpu commands: - bazel run //ci/ray_ci:test_in_docker -- //rllib/... rllib @@ -165,7 +165,7 @@ steps: tags: - rllib_gpu - gpu - parallelism: 4 + parallelism: 5 instance_type: gpu-large commands: - bazel run //ci/ray_ci:test_in_docker -- //rllib/... 
rllib diff --git a/rllib/BUILD b/rllib/BUILD index bed02f025a09..333c3c760ed0 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -184,23 +184,6 @@ py_test( srcs = ["tuned_examples/appo/cartpole_appo.py"], args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] ) -# StatelessCartPole -py_test( - name = "learning_tests_stateless_cartpole_appo", - main = "tuned_examples/appo/stateless_cartpole_appo.py", - tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], - size = "large", - srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"] -) -py_test( - name = "learning_tests_stateless_cartpole_appo_multi_gpu", - main = "tuned_examples/appo/stateless_cartpole_appo.py", - tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], - size = "large", - srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] -) # MultiAgentCartPole py_test( name = "learning_tests_multi_agent_cartpole_appo", @@ -234,6 +217,72 @@ py_test( srcs = ["tuned_examples/appo/multi_agent_cartpole_appo.py"], args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-gpus=2", "--num-cpus=7"] ) +# StatelessCartPole +py_test( + name = "learning_tests_stateless_cartpole_appo", + main = "tuned_examples/appo/stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], + size = "large", + srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"] +) +py_test( + name = "learning_tests_stateless_cartpole_appo_gpu", + main = "tuned_examples/appo/stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], + size = "large", + srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-gpus=1"] +) +py_test( + name = "learning_tests_stateless_cartpole_appo_multi_cpu", + main = "tuned_examples/appo/stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], + size = "large", + srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] +) +py_test( + name = "learning_tests_stateless_cartpole_appo_multi_gpu", + main = "tuned_examples/appo/stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], + size = "large", + srcs = ["tuned_examples/appo/stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] +) +# MultiAgentStatelessCartPole +py_test( + name = "learning_tests_multi_agent_stateless_cartpole_appo", + main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], + size = "large", + srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"] +) 
+py_test( + name = "learning_tests_multi_agent_stateless_cartpole_appo_gpu", + main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"], + size = "large", + srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-gpus=1"] +) +py_test( + name = "learning_tests_multi_agent_stateless_cartpole_appo_multi_cpu", + main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], + size = "large", + srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] +) +py_test( + name = "learning_tests_multi_agent_stateless_cartpole_appo_multi_gpu", + main = "tuned_examples/appo/multi_agent_stateless_cartpole_appo.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], + size = "large", + srcs = ["tuned_examples/appo/multi_agent_stateless_cartpole_appo.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] +) #@OldAPIStack py_test( @@ -462,23 +511,6 @@ py_test( srcs = ["tuned_examples/impala/cartpole_impala.py"], args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] ) -# StatelessCartPole -py_test( - name = "learning_tests_stateless_cartpole_impala", - main = "tuned_examples/impala/stateless_cartpole_impala.py", - tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], - size = "large", - srcs = ["tuned_examples/impala/stateless_cartpole_impala.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"] -) -py_test( - name = "learning_tests_stateless_cartpole_impala_multi_gpu", - main = "tuned_examples/impala/stateless_cartpole_impala.py", - tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], - size = "large", - srcs = ["tuned_examples/impala/stateless_cartpole_impala.py"], - args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] -) # MultiAgentCartPole py_test( name = "learning_tests_multi_agent_cartpole_impala", @@ -512,6 +544,40 @@ py_test( srcs = ["tuned_examples/impala/multi_agent_cartpole_impala.py"], args = ["--as-test", "--enable-new-api-stack", "--num-agents=2", "--num-gpus=2", "--num-cpus=7"] ) +# StatelessCartPole +py_test( + name = "learning_tests_stateless_cartpole_impala", + main = "tuned_examples/impala/stateless_cartpole_impala.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], + size = "large", + srcs = ["tuned_examples/impala/stateless_cartpole_impala.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"] +) +py_test( + name = "learning_tests_stateless_cartpole_impala_multi_gpu", + main = "tuned_examples/impala/stateless_cartpole_impala.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], + size = "large", + srcs = ["tuned_examples/impala/stateless_cartpole_impala.py"], + args = ["--as-test", "--enable-new-api-stack", 
"--num-gpus=2"] +) +# MultiAgentStatelessCartPole +py_test( + name = "learning_tests_multi_agent_stateless_cartpole_impala", + main = "tuned_examples/impala/multi_agent_stateless_cartpole_impala.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"], + size = "large", + srcs = ["tuned_examples/impala/multi_agent_stateless_cartpole_impala.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"] +) +py_test( + name = "learning_tests_multi_agent_stateless_cartpole_impala_multi_gpu", + main = "tuned_examples/impala/multi_agent_stateless_cartpole_impala.py", + tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"], + size = "large", + srcs = ["tuned_examples/impala/multi_agent_stateless_cartpole_impala.py"], + args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"] +) #@OldAPIstack py_test( diff --git a/rllib/connectors/common/add_states_from_episodes_to_batch.py b/rllib/connectors/common/add_states_from_episodes_to_batch.py index 8a494dc69a94..995064dcf6c7 100644 --- a/rllib/connectors/common/add_states_from_episodes_to_batch.py +++ b/rllib/connectors/common/add_states_from_episodes_to_batch.py @@ -228,7 +228,7 @@ def __call__( # Also, let module-to-env pipeline know that we had added a single timestep # time rank to the data (to remove it again). if not self._as_learner_connector: - for column, column_data in data.copy().items(): + for column in data.keys(): self.foreach_batch_item_change_in_place( batch=data, column=column, @@ -250,11 +250,20 @@ def __call__( # Before adding STATE_IN to the `data`, zero-pad existing data and batch # into max_seq_len chunks. for column, column_data in data.copy().items(): + # Do not zero-pad INFOS column. + if column == Columns.INFOS: + continue for key, item_list in column_data.items(): - if column != Columns.INFOS: - column_data[key] = split_and_zero_pad_list( - item_list, T=self.max_seq_len - ) + # Multi-agent case AND RLModule is not stateful -> Do not zero-pad + # for this model. + assert isinstance(key, tuple) + if len(key) == 3: + eps_id, aid, mid = key + if not rl_module[mid].is_stateful(): + continue + column_data[key] = split_and_zero_pad_list( + item_list, T=self.max_seq_len + ) for sa_episode in self.single_agent_episode_iterator( episodes, diff --git a/rllib/connectors/env_to_module/mean_std_filter.py b/rllib/connectors/env_to_module/mean_std_filter.py index d568a7bc36a4..c92e33e139f0 100644 --- a/rllib/connectors/env_to_module/mean_std_filter.py +++ b/rllib/connectors/env_to_module/mean_std_filter.py @@ -116,9 +116,16 @@ def __call__( # anymore to the original observations). for sa_episode in self.single_agent_episode_iterator(episodes): sa_obs = sa_episode.get_observations(indices=-1) - normalized_sa_obs = self._filters[sa_episode.agent_id]( - sa_obs, update=self._update_stats - ) + try: + normalized_sa_obs = self._filters[sa_episode.agent_id]( + sa_obs, update=self._update_stats + ) + except KeyError: + raise KeyError( + "KeyError trying to access a filter by agent ID " + f"`{sa_episode.agent_id}`! You probably did NOT pass the " + f"`multi_agent=True` flag into the `MeanStdFilter()` constructor. 
" + ) sa_episode.set_observations(at_indices=-1, new_data=normalized_sa_obs) # We set the Episode's observation space to ours so that we can safely # set the last obs to the new value (without causing a space mismatch diff --git a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py index fd72e4b5f2a6..d6997fe3cafb 100644 --- a/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py +++ b/rllib/connectors/learner/add_one_ts_to_episodes_and_truncate.py @@ -3,6 +3,7 @@ from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.env.multi_agent_episode import MultiAgentEpisode from ray.rllib.utils.annotations import override from ray.rllib.utils.postprocessing.episodes import add_one_ts_to_episodes_and_truncate from ray.rllib.utils.typing import EpisodeType @@ -101,10 +102,23 @@ def __call__( # batch: - - - - - - - T B0- - - - - R Bx- - - - R Bx # mask : t t t t t t t t f t t t t t t f t t t t t f + # TODO (sven): Same situation as in TODO below, but for multi-agent episode. + # Maybe add a dedicated connector piece for this task? + # We extend the MultiAgentEpisode's ID by a running number here to make sure + # we treat each MAEpisode chunk as separate (for potentially upcoming v-trace + # and LSTM zero-padding) and don't mix data from different chunks. + if isinstance(episodes[0], MultiAgentEpisode): + for i, ma_episode in enumerate(episodes): + ma_episode.id_ += "_" + str(i) + # Also change the underlying single-agent episode's + # `multi_agent_episode_id` properties. + for sa_episode in ma_episode.agent_episodes.values(): + sa_episode.multi_agent_episode_id = ma_episode.id_ + for i, sa_episode in enumerate( self.single_agent_episode_iterator(episodes, agents_that_stepped_only=False) ): - # TODO (sven): This is a little bit of a hack: By expanding the Episode's + # TODO (sven): This is a little bit of a hack: By extending the Episode's # ID, we make sure that each episode chunk in `episodes` is treated as a # separate episode in the `self.add_n_batch_items` below. Some algos (e.g. # APPO) may have >1 episode chunks from the same episode (same ID) in the diff --git a/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py b/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py index 15aebf266053..811f30b3ea48 100644 --- a/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py +++ b/rllib/connectors/module_to_env/remove_single_ts_time_rank_from_batch.py @@ -50,9 +50,21 @@ def __call__( if shared_data is None or not shared_data.get("_added_single_ts_time_rank"): return data - data = tree.map_structure_with_path( - lambda p, s: s if Columns.STATE_OUT in p else np.squeeze(s, axis=0), - data, - ) + def _remove_single_ts(item, eps_id, aid, mid): + # Only remove time-rank for modules that are statefule (only for those has + # a timerank been added). + if mid is None or rl_module[mid].is_stateful(): + return tree.map_structure(lambda s: np.squeeze(s, axis=0), item) + return item + + for column, column_data in data.copy().items(): + # Skip state_out (doesn't have a time rank). 
+ if column == Columns.STATE_OUT: + continue + self.foreach_batch_item_change_in_place( + data, + column=column, + func=_remove_single_ts, + ) return data diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index b52a135c89f2..bab75f15513a 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -53,6 +53,7 @@ NUM_ENV_STEPS_TRAINED, NUM_MODULE_STEPS_TRAINED, LEARNER_CONNECTOR_TIMER, + MODULE_TRAIN_BATCH_SIZE_MEAN, ) from ray.rllib.utils.metrics.metrics_logger import MetricsLogger from ray.rllib.utils.minibatch_utils import ( @@ -1025,6 +1026,84 @@ def update_from_episodes( num_total_mini_batches=num_total_mini_batches, ) + def update_from_iterator( + self, + iterator, + *, + timesteps: Optional[Dict[str, Any]] = None, + minibatch_size: Optional[int] = None, + num_iters: int = None, + **kwargs, + ): + self._check_is_built() + minibatch_size = minibatch_size or 32 + + # Call `before_gradient_based_update` to allow for non-gradient based + # preparation, logging, and update logic to happen. + self.before_gradient_based_update(timesteps=timesteps or {}) + + def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: + # Note, the incoming batch is a dictionary with a numpy array + # holding the `MultiAgentBatch`. + batch = self._convert_batch_type(batch["batch"][0]) + return {"batch": self._set_slicing_by_batch_id(batch, value=True)} + + i = 0 + for batch in iterator.iter_batches( + batch_size=minibatch_size, + _finalize_fn=_finalize_fn, + **kwargs, + ): + # Update the iteration counter. + i += 1 + + # Note, `_finalize_fn` must return a dictionary. + batch = batch["batch"] + # Check whether our RLModule contains all ModuleIDs found in this + # MultiAgentBatch. If not, raise an error. + unknown_module_ids = set(batch.policy_batches.keys()) - set( + self.module.keys() + ) + if len(unknown_module_ids) > 0: + raise ValueError( + "Batch contains one or more ModuleIDs that are not in this " + f"Learner! Found IDs: {unknown_module_ids}" + ) + + # Log metrics. + self._log_steps_trained_metrics(batch) + + # Make the actual in-graph/traced `_update` call. This should return + # all tensor values (no numpy). + fwd_out, loss_per_module, tensor_metrics = self._update( + batch.policy_batches + ) + + self._set_slicing_by_batch_id(batch, value=False) + # If `num_iters` is reached, break and return. + if num_iters and i == num_iters: + break + + logger.info(f"[Learner] Iterations run in epoch: {i}") + # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger) + # to actual (numpy) values. + self.metrics.tensors_to_numpy(tensor_metrics) + + # Log all individual RLModules' loss terms and their registered optimizers' + # current learning rates. + for mid, loss in convert_to_numpy(loss_per_module).items(): + self.metrics.log_value( + key=(mid, self.TOTAL_LOSS_KEY), + value=loss, + window=1, + ) + # Call `after_gradient_based_update` to allow for non-gradient based + # cleanup, logging, and update logic to happen. + self.after_gradient_based_update(timesteps=timesteps or {}) + + # Reduce results across all minibatch update steps.
+ return self.metrics.reduce() + @OverrideToImplementCustomLogic @abc.abstractmethod def _update( @@ -1130,95 +1209,6 @@ def _set_optimizer_state(self, state: StateDict) -> None: """ raise NotImplementedError - def update_from_iterator( - self, - iterator, - *, - timesteps: Optional[Dict[str, Any]] = None, - minibatch_size: Optional[int] = None, - num_iters: int = None, - **kwargs, - ): - self._check_is_built() - minibatch_size = minibatch_size or 32 - - # Call `before_gradient_based_update` to allow for non-gradient based - # preparations-, logging-, and update logic to happen. - self.before_gradient_based_update(timesteps=timesteps or {}) - - def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: - # Note, the incoming batch is a dictionary with a numpy array - # holding the `MultiAgentBatch`. - batch = self._convert_batch_type(batch["batch"][0]) - return {"batch": self._set_slicing_by_batch_id(batch, value=True)} - - i = 0 - for batch in iterator.iter_batches( - batch_size=minibatch_size, - _finalize_fn=_finalize_fn, - **kwargs, - ): - # Update the iteration counter. - i += 1 - - # Note, `_finalize_fn` must return a dictionary. - batch = batch["batch"] - # Check the MultiAgentBatch, whether our RLModule contains all ModuleIDs - # found in this batch. If not, throw an error. - unknown_module_ids = set(batch.policy_batches.keys()) - set( - self.module.keys() - ) - if len(unknown_module_ids) > 0: - raise ValueError( - "Batch contains one or more ModuleIDs that are not in this " - f"Learner! Found IDs: {unknown_module_ids}" - ) - - # Log metrics. - self.metrics.log_dict( - { - (ALL_MODULES, NUM_ENV_STEPS_TRAINED): batch.env_steps(), - (ALL_MODULES, NUM_MODULE_STEPS_TRAINED): batch.agent_steps(), - **{ - (mid, NUM_MODULE_STEPS_TRAINED): len(b) - for mid, b in batch.policy_batches.items() - }, - }, - reduce="sum", - clear_on_reduce=True, - ) - - # Make the actual in-graph/traced `_update` call. This should return - # all tensor values (no numpy). - fwd_out, loss_per_module, tensor_metrics = self._update( - batch.policy_batches - ) - - self._set_slicing_by_batch_id(batch, value=False) - # If `num_iters` is reached break and return. - if num_iters and i == num_iters: - break - - logger.info(f"[Learner] Iterations run in epoch: {i}") - # Convert logged tensor metrics (logged during tensor-mode of MetricsLogger) - # to actual (numpy) values. - self.metrics.tensors_to_numpy(tensor_metrics) - - # Log all individual RLModules' loss terms and its registered optimizers' - # current learning rates. - for mid, loss in convert_to_numpy(loss_per_module).items(): - self.metrics.log_value( - key=(mid, self.TOTAL_LOSS_KEY), - value=loss, - window=1, - ) - # Call `after_gradient_based_update` to allow for non-gradient based - # cleanups-, logging-, and update logic to happen. - self.after_gradient_based_update(timesteps=timesteps or {}) - - # Reduce results across all minibatch update steps. - return self.metrics.reduce() - def _update_from_batch_or_episodes( self, *, @@ -1294,24 +1284,8 @@ def _update_from_batch_or_episodes( if not self.should_module_be_updated(module_id, batch): del batch.policy_batches[module_id] - # Log all timesteps (env, agent, modules) based on given episodes. - if self._learner_connector is not None and episodes is not None: - self._log_steps_trained_metrics(episodes, batch, shared_data) - # TODO (sven): Possibly remove this if-else block entirely. We might be in a - # world soon where we always learn from episodes, never from an incoming batch. 
- else: - self.metrics.log_dict( - { - (ALL_MODULES, NUM_ENV_STEPS_TRAINED): batch.env_steps(), - (ALL_MODULES, NUM_MODULE_STEPS_TRAINED): batch.agent_steps(), - **{ - (mid, NUM_MODULE_STEPS_TRAINED): len(b) - for mid, b in batch.policy_batches.items() - }, - }, - reduce="sum", - clear_on_reduce=True, - ) + # Log all timesteps (env, agent, modules) based on given episodes/batch. + self._log_steps_trained_metrics(batch) if minibatch_size: if self._learner_connector is not None: @@ -1581,54 +1555,33 @@ def _set_optimizer_lr(optimizer: Optimizer, lr: float) -> None: def _get_clip_function() -> Callable: """Returns the gradient clipping function to use, given the framework.""" - def _log_steps_trained_metrics(self, episodes, batch, shared_data): - # Logs this iteration's steps trained, based on given `episodes`. - env_steps = sum(len(e) for e in episodes) + def _log_steps_trained_metrics(self, batch: MultiAgentBatch): + """Logs this iteration's steps trained, based on given `batch`.""" + log_dict = defaultdict(dict) - orig_lengths = shared_data.get("_sa_episodes_lengths", {}) - for sa_episode in self._learner_connector.single_agent_episode_iterator( - episodes, agents_that_stepped_only=False - ): - mid = ( - sa_episode.module_id - if sa_episode.module_id is not None - else DEFAULT_MODULE_ID - ) - # Do not log steps trained for those ModuleIDs that should not be updated. - if mid != ALL_MODULES and mid not in batch.policy_batches: - continue - - _len = ( - orig_lengths[sa_episode.id_] - if sa_episode.id_ in orig_lengths - else len(sa_episode) + for mid, module_batch in batch.policy_batches.items(): + module_batch_size = len(module_batch) + # Log average batch size (for each module). + self.metrics.log_value( + key=(mid, MODULE_TRAIN_BATCH_SIZE_MEAN), + value=module_batch_size, ) - # TODO (sven): Decide, whether agent_ids should be part of LEARNER_RESULTS. - # Currently and historically, only ModuleID keys and ALL_MODULES were used - # and expected. Does it make sense to include e.g. agent steps trained? - # I'm not sure atm. - # aid = ( - # sa_episode.agent_id if sa_episode.agent_id is not None - # else DEFAULT_AGENT_ID - # ) + # Log module steps (for each module). if NUM_MODULE_STEPS_TRAINED not in log_dict[mid]: - log_dict[mid][NUM_MODULE_STEPS_TRAINED] = _len + log_dict[mid][NUM_MODULE_STEPS_TRAINED] = module_batch_size else: - log_dict[mid][NUM_MODULE_STEPS_TRAINED] += _len - # TODO (sven): See above. - # if NUM_AGENT_STEPS_TRAINED not in log_dict[aid]: - # log_dict[aid][NUM_AGENT_STEPS_TRAINED] = _len - # else: - # log_dict[aid][NUM_AGENT_STEPS_TRAINED] += _len + log_dict[mid][NUM_MODULE_STEPS_TRAINED] += module_batch_size + + # Log module steps (sum of all modules). if NUM_MODULE_STEPS_TRAINED not in log_dict[ALL_MODULES]: - log_dict[ALL_MODULES][NUM_MODULE_STEPS_TRAINED] = _len + log_dict[ALL_MODULES][NUM_MODULE_STEPS_TRAINED] = module_batch_size else: - log_dict[ALL_MODULES][NUM_MODULE_STEPS_TRAINED] += _len + log_dict[ALL_MODULES][NUM_MODULE_STEPS_TRAINED] += module_batch_size # Log env steps (all modules). 
self.metrics.log_value( (ALL_MODULES, NUM_ENV_STEPS_TRAINED), - env_steps, + batch.env_steps(), reduce="sum", clear_on_reduce=True, ) diff --git a/rllib/core/learner/tests/test_learner_group.py b/rllib/core/learner/tests/test_learner_group.py index ca51dffd7859..605a694889e0 100644 --- a/rllib/core/learner/tests/test_learner_group.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -518,7 +518,7 @@ def test_save_to_path_and_restore_from_path(self): del learner_group # Compare the results of the two updates. - check(results_2nd_update_with_break, results_2nd_without_break) + check(results_2nd_update_with_break, results_2nd_without_break, rtol=0.05) check( weights_after_2_updates_with_break, weights_after_2_updates_without_break, diff --git a/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py b/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py index 177cc9a0cd10..dd1e76a0c978 100644 --- a/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py +++ b/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py @@ -8,7 +8,7 @@ from ray.rllib.utils.test_utils import add_rllib_example_script_args from ray.tune.registry import register_env -parser = add_rllib_example_script_args() +parser = add_rllib_example_script_args(default_timesteps=2000000) parser.set_defaults( enable_new_api_stack=True, num_agents=2, @@ -45,8 +45,8 @@ ) stop = { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 400.0 * args.num_agents, - f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 2000000, + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0 * args.num_agents, + f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": args.stop_timesteps, } diff --git a/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py b/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py new file mode 100644 index 000000000000..a8713f4350d3 --- /dev/null +++ b/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py @@ -0,0 +1,67 @@ +from ray.rllib.algorithms.appo import APPOConfig +from ray.rllib.connectors.env_to_module import MeanStdFilter +from ray.rllib.examples.envs.classes.multi_agent import MultiAgentStatelessCartPole +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + EPISODE_RETURN_MEAN, + NUM_ENV_STEPS_SAMPLED_LIFETIME, +) +from ray.rllib.utils.test_utils import add_rllib_example_script_args +from ray.tune.registry import register_env + +parser = add_rllib_example_script_args( + default_timesteps=2000000, + default_reward=350.0, +) +parser.set_defaults( + enable_new_api_stack=True, + num_agents=2, + num_env_runners=3, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values to set up `config` below. +args = parser.parse_args() + +register_env("env", lambda cfg: MultiAgentStatelessCartPole(config=cfg)) + + +config = ( + APPOConfig() + # Enable new API stack and use EnvRunner.
+ .api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + .environment("env", env_config={"num_agents": args.num_agents}) + .env_runners( + env_to_module_connector=lambda env: MeanStdFilter(multi_agent=True), + ) + .training( + train_batch_size_per_learner=600, + lr=0.0005 * ((args.num_gpus or 1) ** 0.5), + num_sgd_iter=6, + vf_loss_coeff=0.05, + grad_clip=20.0, + ) + .rl_module( + model_config_dict={ + "use_lstm": True, + "uses_new_env_runners": True, + "max_seq_len": 50, + }, + ) + .multi_agent( + policy_mapping_fn=(lambda agent_id, episode, **kwargs: f"p{agent_id}"), + policies={f"p{i}" for i in range(args.num_agents)}, + ) +) + +stop = { + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 200.0 * args.num_agents, + NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, +} + +if __name__ == "__main__": + from ray.rllib.utils.test_utils import run_rllib_example_script_experiment + + run_rllib_example_script_experiment(config, args, stop=stop) diff --git a/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py b/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py index 94aac4c2c8f0..e7aa087413e5 100644 --- a/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py +++ b/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py @@ -20,10 +20,7 @@ # and (if needed) use their values to set up `config` below. args = parser.parse_args() -register_env( - "multi_agent_cartpole", - lambda _: MultiAgentCartPole({"num_agents": args.num_agents}), -) +register_env("multi_agent_cartpole", lambda cfg: MultiAgentCartPole(config=cfg)) config = ( DQNConfig() @@ -31,7 +28,7 @@ enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True, ) - .environment(env="multi_agent_cartpole") + .environment(env="multi_agent_cartpole", env_config={"num_agents": args.num_agents}) .training( lr=0.0005 * (args.num_gpus or 1) ** 0.5, train_batch_size_per_learner=32, diff --git a/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py b/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py index a48e671c1353..b000f40cca7b 100644 --- a/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py +++ b/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py @@ -49,7 +49,7 @@ ) stop = { - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 400.0 * args.num_agents, + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 350.0 * args.num_agents, NUM_ENV_STEPS_SAMPLED_LIFETIME: 2500000, } diff --git a/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py b/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py new file mode 100644 index 000000000000..2940c4093f90 --- /dev/null +++ b/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py @@ -0,0 +1,67 @@ +from ray.rllib.algorithms.impala import IMPALAConfig +from ray.rllib.connectors.env_to_module import MeanStdFilter +from ray.rllib.examples.envs.classes.multi_agent import MultiAgentStatelessCartPole +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + EPISODE_RETURN_MEAN, + NUM_ENV_STEPS_SAMPLED_LIFETIME, +) +from ray.rllib.utils.test_utils import add_rllib_example_script_args +from ray.tune.registry import register_env + +parser = add_rllib_example_script_args(default_timesteps=5000000) +parser.set_defaults( + enable_new_api_stack=True, + num_agents=2, + num_env_runners=4, +) +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values to set up `config` below.
+args = parser.parse_args() + +register_env( + "multi_stateless_cart", + lambda cfg: MultiAgentStatelessCartPole(config=cfg), +) + + +config = ( + IMPALAConfig() + .api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + .environment("multi_stateless_cart", env_config={"num_agents": args.num_agents}) + .env_runners( + env_to_module_connector=lambda env: MeanStdFilter(multi_agent=True), + ) + .training( + train_batch_size_per_learner=600, + lr=0.0003 * ((args.num_gpus or 1) ** 0.5), + vf_loss_coeff=0.05, + entropy_coeff=0.0, + grad_clip=20.0, + ) + .rl_module( + model_config_dict={ + "use_lstm": True, + "uses_new_env_runners": True, + "max_seq_len": 50, + }, + ) + .multi_agent( + policy_mapping_fn=(lambda agent_id, episode, **kwargs: f"p{agent_id}"), + policies={f"p{i}" for i in range(args.num_agents)}, + ) +) + +stop = { + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 200.0 * args.num_agents, + NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, +} + + +if __name__ == "__main__": + from ray.rllib.utils.test_utils import run_rllib_example_script_experiment + + run_rllib_example_script_experiment(config, args, stop=stop) diff --git a/rllib/tuned_examples/ppo/multi_agent_cartpole_ppo.py b/rllib/tuned_examples/ppo/multi_agent_cartpole_ppo.py index b9f4aa7ca322..054cfc056831 100644 --- a/rllib/tuned_examples/ppo/multi_agent_cartpole_ppo.py +++ b/rllib/tuned_examples/ppo/multi_agent_cartpole_ppo.py @@ -17,10 +17,7 @@ # and (if needed) use their values toset up `config` below. args = parser.parse_args() -register_env( - "multi_agent_cartpole", - lambda _: MultiAgentCartPole({"num_agents": args.num_agents}), -) +register_env("multi_agent_cartpole", lambda cfg: MultiAgentCartPole(config=cfg)) config = ( PPOConfig() @@ -28,7 +25,7 @@ enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True, ) - .environment("multi_agent_cartpole") + .environment("multi_agent_cartpole", env_config={"num_agents": args.num_agents}) .rl_module( model_config_dict={ "fcnet_hiddens": [32], diff --git a/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py b/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py index d1b162e7bea3..fb7b4ac70258 100644 --- a/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py +++ b/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py @@ -17,10 +17,7 @@ # and (if needed) use their values toset up `config` below. args = parser.parse_args() -register_env( - "multi_agent_pendulum", - lambda _: MultiAgentPendulum({"num_agents": args.num_agents}), -) +register_env("multi_agent_pendulum", lambda cfg: MultiAgentPendulum(config=cfg)) config = ( PPOConfig() @@ -28,7 +25,7 @@ enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True, ) - .environment("multi_agent_pendulum") + .environment("multi_agent_pendulum", env_config={"num_agents": args.num_agents}) .rl_module( model_config_dict={ "fcnet_activation": "relu", diff --git a/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py b/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py index ef9279c28efc..edc7a5fa71ff 100644 --- a/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py +++ b/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py @@ -22,10 +22,7 @@ # and (if needed) use their values to set up `config` below. 
args = parser.parse_args() -register_env( - "multi_agent_pendulum", - lambda _: MultiAgentPendulum({"num_agents": args.num_agents}), -) +register_env("multi_agent_pendulum", lambda cfg: MultiAgentPendulum(config=cfg)) config = ( SACConfig() @@ -33,7 +30,7 @@ enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True, ) - .environment("multi_agent_pendulum") + .environment("multi_agent_pendulum", env_config={"num_agents": args.num_agents}) .training( initial_alpha=1.001, # Use a smaller learning rate for the policy. @@ -79,7 +76,7 @@ stop = { NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, # `episode_return_mean` is the sum of all agents/policies' returns. - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -400.0 * args.num_agents, + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -450.0 * args.num_agents, } if __name__ == "__main__": diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py index 2cbac3f12a5e..59b828321992 100644 --- a/rllib/utils/metrics/__init__.py +++ b/rllib/utils/metrics/__init__.py @@ -47,6 +47,7 @@ NUM_ENV_STEPS_TRAINED_THIS_ITER = "num_env_steps_trained_this_iter" # @OldAPIStack NUM_MODULE_STEPS_TRAINED = "num_module_steps_trained" NUM_MODULE_STEPS_TRAINED_LIFETIME = "num_module_steps_trained_lifetime" +MODULE_TRAIN_BATCH_SIZE_MEAN = "module_train_batch_size_mean" # Backward compatibility: Replace with num_env_steps_... or num_agent_steps_... STEPS_TRAINED_THIS_ITER_COUNTER = "num_steps_trained_this_iter" diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 26f83ff2a2bc..8925024ee764 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -349,8 +349,14 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): assert bool(x) is not bool(y), f"ERROR: x ({x}) is y ({y})!" else: assert bool(x) is bool(y), f"ERROR: x ({x}) is not y ({y})!" - # Nones or primitives. - elif x is None or y is None or isinstance(x, (str, int)): + # Nones or primitives (excluding int vs float, which should be compared with + # tolerance/decimals as well). + elif ( + x is None + or y is None + or isinstance(x, str) + or (isinstance(x, int) and isinstance(y, int)) + ): if false is True: assert x != y, f"ERROR: x ({x}) is the same as y ({y})!" else: @@ -367,6 +373,7 @@ def check(x, y, decimals=5, atol=None, rtol=None, false=False): if false is False: raise e # Everything else (assume numeric or tf/torch.Tensor). + # Also includes int vs float comparison, which is performed with tolerance/decimals. else: if tf1 is not None: # y should never be a Tensor (y=expected value).
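A minimal sketch (not part of the patch) of the `check()` behavior the final `test_utils.py` hunk aims for, using only the signature shown in the diff (`check(x, y, decimals=5, atol=None, rtol=None, false=False)`): int-vs-float pairs now fall through to the numeric branch and are compared with `decimals`/`atol`/`rtol`, while int-vs-int (and None/str) pairs keep the exact comparison. The values below are arbitrary illustrations.

from ray.rllib.utils.test_utils import check

# Both values are ints -> still compared exactly.
check(2, 2)
# Int vs. float -> now handled by the numeric branch (default decimals=5),
# so a tiny numerical difference passes instead of failing an exact `==`.
check(2, 2.0000001)
# A difference beyond the tolerance can be asserted as "not equal" instead.
check(2, 2.1, false=True)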