From 32faa1b8975f5e149f82cf9147f400d559fcdee5 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 22 Jul 2024 13:32:39 +0200 Subject: [PATCH 01/13] wip Signed-off-by: sven1977 --- rllib/BUILD | 20 +++ rllib/connectors/connector_pipeline_v2.py | 7 + rllib/examples/connectors/classes/__init__.py | 0 .../classes/count_based_curiosity.py | 92 +++++++++++++ .../euclidian_distance_based_curiosity.py | 0 .../connectors/count_based_curiosity.py | 14 ++ rllib/examples/curiosity/__init__.py | 0 .../curiosity/count_based_curiosity.py | 129 ++++++++++++++++++ .../euclidian_distance_based_curiosity.py | 0 9 files changed, 262 insertions(+) create mode 100644 rllib/examples/connectors/classes/__init__.py create mode 100644 rllib/examples/connectors/classes/count_based_curiosity.py create mode 100644 rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py create mode 100644 rllib/examples/connectors/count_based_curiosity.py create mode 100644 rllib/examples/curiosity/__init__.py create mode 100644 rllib/examples/curiosity/count_based_curiosity.py create mode 100644 rllib/examples/curiosity/euclidian_distance_based_curiosity.py diff --git a/rllib/BUILD b/rllib/BUILD index 043bbf4a3022..409e1ce8e4b6 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2544,6 +2544,26 @@ py_test( # args = ["--enable-new-api-stack", "--num-agents=2", "--as-test", "--stop-reward=-600.0", "--framework=torch", "--algo=IMPALA", "--num-env-runners=5", "--num-cpus=6"] # ) +# subdirectory: curiosity/ +# .................................... +py_test( + name = "examples/curiosity/count_based_curiosity", + main = "examples/curiosity/count_based_curiosity.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/curiosity/count_based_curiosity.py"], + args = ["--enable-new-api-stack", "--as-test"] +) + +py_test( + name = "examples/curiosity/euclidian_distance_based_curiosity", + main = "examples/curiosity/euclidian_distance_based_curiosity.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "medium", + srcs = ["examples/curiosity/euclidian_distance_based_curiosity.py"], + args = ["--enable-new-api-stack", "--as-test"] +) + # subdirectory: curriculum/ # .................................... py_test( diff --git a/rllib/connectors/connector_pipeline_v2.py b/rllib/connectors/connector_pipeline_v2.py index 0bd46b6aff69..5863312cc60c 100644 --- a/rllib/connectors/connector_pipeline_v2.py +++ b/rllib/connectors/connector_pipeline_v2.py @@ -90,6 +90,13 @@ def __call__( shared_data=shared_data, **kwargs, ) + if not isinstance(data, dict): + raise ValueError( + f"`data` returned by ConnectorV2 {connector} must be a dict! " + f"You returned {data}. Check your (custom) connectors' " + f"`__call__()` method's return value and make sure you return " + f"the `data` arg passed in (either altered or unchanged)." + ) return data def remove(self, name_or_class: Union[str, Type]): diff --git a/rllib/examples/connectors/classes/__init__.py b/rllib/examples/connectors/classes/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/examples/connectors/classes/count_based_curiosity.py b/rllib/examples/connectors/classes/count_based_curiosity.py new file mode 100644 index 000000000000..f9619cf620c0 --- /dev/null +++ b/rllib/examples/connectors/classes/count_based_curiosity.py @@ -0,0 +1,92 @@ +from collections import Counter +from typing import Any, List, Optional + +import gymnasium as gym + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.typing import EpisodeType + + +class CountBasedCuriosity(ConnectorV2): + """Learning ConnectorV2 piece to compute intrinsic rewards based on obs counts. + + Add this connector piece to your Learner pipeline, through your algo config: + ``` + config.training( + learner_connector=lambda obs_sp, act_sp: CountBasedCuriosity() + ) + ``` + + Intrinsic rewards are computed on the Learner side based on naive observation + counts, which is why this connector should only be used for simple environments + with a reasonable number of possible observations. The intrinsic reward for a given + timestep is: + r(i) = intrinsic_reward_coeff * (1 / C(obs(i))) + where C is the total (lifetime) count of the obs at timestep i. + + The instrinsic reward is added to the extrinsic reward and saved back into the + episode (under the main "rewards" key). + + Note that the computation and saving back to the episode all happens before the + actual train batch is generated from the episode data. Thus, the Learner and the + RLModule used do not take notice of the extra reward added. + + If you would like to use a more sophisticated mechanism for intrinsic reward + computations, take a look at the `EuclidianDistanceBasedCuriosity` connector piece + at `ray.rllib.examples.connectors.classes.euclidian_distance_based_curiosity` + """ + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + intrinsic_reward_coeff: float = 1.0, + **kwargs, + ): + """Initializes a CountBasedCuriosity instance. + + Args: + intrinsic_reward_coeff: The weight with which to multiply the intrinsic + reward before adding (and saving) it back to the main (extrinsic) + reward of the episode at each timestep. + """ + super().__init__(input_observation_space, input_action_space) + + # Naive observation counter. + self._counts = Counter() + self.intrinsic_reward_coeff = intrinsic_reward_coeff + + def __call__( + self, + *, + rl_module: RLModule, + data: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # Loop through all episodes and change the reward to + # [reward + intrinsic reward] + for sa_episode in self.single_agent_episode_iterator( + episodes=episodes, agents_that_stepped_only=False + ): + # Loop through all obs, except the last one. + observations = sa_episode.get_observations(slice(None, -1)) + # Get all respective (extrinsic) rewards. + rewards = sa_episode.get_rewards() + + for i, (obs, rew) in enumerate(zip(observations, rewards)): + # Add 1 to obs counter. + obs = tuple(obs) + self._counts[obs] += 1 + # Compute our count-based intrinsic reward and add it to the main + # (extrinsic) reward. + rew += self.intrinsic_reward_coeff * (1 / self._counts[obs]) + # Store the new reward back to the episode (under the correct + # timestep/index). + sa_episode.set_rewards(new_data=rew, at_indices=i) + + return data diff --git a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/examples/connectors/count_based_curiosity.py b/rllib/examples/connectors/count_based_curiosity.py new file mode 100644 index 000000000000..ad09e4ceb6bf --- /dev/null +++ b/rllib/examples/connectors/count_based_curiosity.py @@ -0,0 +1,14 @@ +"""Placeholder for training with count-based curiosity. + +The actual script can be found at a different location (see code below). +""" + +if __name__ == "__main__": + import subprocess + import sys + + # Forward to "python ../curiosity/[same script name].py [same options]" + command = [sys.executable, "../curiosity/", sys.argv[0]] + sys.argv[1:] + + # Run the script. + subprocess.run(command, capture_output=True) diff --git a/rllib/examples/curiosity/__init__.py b/rllib/examples/curiosity/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/examples/curiosity/count_based_curiosity.py b/rllib/examples/curiosity/count_based_curiosity.py new file mode 100644 index 000000000000..62d79a387023 --- /dev/null +++ b/rllib/examples/curiosity/count_based_curiosity.py @@ -0,0 +1,129 @@ +"""Example of using a count-based curiosity mechanism to learn in sparse-rewards envs. + +This example: + - demonstrates how to define your own count-based curiosity ConnectorV2 piece + that computes intrinsic rewards based on simple observation counts and adds these + intrinsic rewards to the "main" (extrinsic) rewards. + - shows how this connector piece overrides the main (extrinsic) rewards in the + episode and thus demonstrates how to do reward shaping in general with RLlib. + - shows how to plug this connector piece into your algorithm's config. + - uses Tune and RLlib to learn the env described above and compares 2 + algorithms, one that does use curiosity vs one that does not. + +We use a FrozenLake (sparse reward) environment with a map size of 8x8 and a time step +limit of 14 to make it almost impossible for a non-curiosity based policy to learn. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +Use the `--no-curiosity` flag to disable curiosity learning and force your policy +to be trained on the task w/o the use of intrinsic rewards. With this option, the +algorithm should NOT succeed. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +In the console output, you can see that only a PPO policy that uses curiosity can +actually learn. + +Policy using count-based curiosity: ++-------------------------------+------------+--------+------------------+ +| Trial name | status | iter | total time (s) | +| | | | | +|-------------------------------+------------+--------+------------------+ +| PPO_FrozenLake-v1_109de_00000 | TERMINATED | 48 | 44.46 | ++-------------------------------+------------+--------+------------------+ ++------------------------+-------------------------+------------------------+ +| episode_return_mean | num_episodes_lifetime | num_env_steps_traine | +| | | d_lifetime | +|------------------------+-------------------------+------------------------| +| 0.99 | 12960 | 194000 | ++------------------------+-------------------------+------------------------+ + +Policy NOT using curiosity: +[DOES NOT LEARN AT ALL] +""" +from ray.rllib.connectors.env_to_module import FlattenObservations +from ray.rllib.examples.connectors.classes.count_based_curiosity import ( + CountBasedCuriosity, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls + +parser = add_rllib_example_script_args( + default_reward=0.99, default_iters=200, default_timesteps=1000000 +) +parser.set_defaults(enable_new_api_stack=True) +parser.add_argument( + "--no-curiosity", + action="store_true", + help="Whether to NOT use count-based curiosity.", +) + +ENV_OPTIONS = { + "is_slippery": False, + # Use this hard-to-solve 8x8 map with lots of holes (H) to fall into and only very + # few valid paths from the starting state (S) to the goal state (G). + "desc": [ + "SFFHFFFH", + "FFFHFFFF", + "FFFHHFFF", + "FFFFFFFH", + "HFFHFFFF", + "HHFHFFHF", + "FFFHFHHF", + "FHFFFFFG", + ], + # Limit the number of steps the agent is allowed to make in the env to + # make it almost impossible to learn without (count-based) curiosity. + "max_episode_steps": 14, +} + + +if __name__ == "__main__": + args = parser.parse_args() + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment( + "FrozenLake-v1", + env_config=ENV_OPTIONS, + ) + .env_runners( + num_envs_per_env_runner=5, + # Flatten discrete observations (into one-hot vectors). + env_to_module_connector=lambda env: FlattenObservations(), + ) + .training( + # The main code in this example: We add the `CountBasedCuriosity` connector + # piece to our Learner connector pipeline. + # This pipeline is fed with collected episodes (either directly from the + # EnvRunners in on-policy fashion or from a replay buffer) and converts + # these episodes into the final train batch. The added piece computes + # intrinsic rewards based on simple observation counts and add them to + # the "main" (extrinsic) rewards. + learner_connector=( + None if args.no_curiosity else lambda *ags, **kw: CountBasedCuriosity() + ), + num_sgd_iter=10, + vf_loss_coeff=0.01, + ) + .rl_module(model_config_dict={"vf_share_layers": True}) + ) + + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py new file mode 100644 index 000000000000..e69de29bb2d1 From 31e6230799391a57ffbaf4b47b75c2515be57a77 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 22 Jul 2024 15:58:17 +0200 Subject: [PATCH 02/13] wip Signed-off-by: sven1977 --- .../classes/count_based_curiosity.py | 6 +- .../euclidian_distance_based_curiosity.py | 117 ++++++++++++++++ .../curiosity/count_based_curiosity.py | 7 + .../euclidian_distance_based_curiosity.py | 129 ++++++++++++++++++ 4 files changed, 256 insertions(+), 3 deletions(-) diff --git a/rllib/examples/connectors/classes/count_based_curiosity.py b/rllib/examples/connectors/classes/count_based_curiosity.py index f9619cf620c0..37af0ad9bf13 100644 --- a/rllib/examples/connectors/classes/count_based_curiosity.py +++ b/rllib/examples/connectors/classes/count_based_curiosity.py @@ -9,7 +9,7 @@ class CountBasedCuriosity(ConnectorV2): - """Learning ConnectorV2 piece to compute intrinsic rewards based on obs counts. + """Learner ConnectorV2 piece to compute intrinsic rewards based on obs counts. Add this connector piece to your Learner pipeline, through your algo config: ``` @@ -25,7 +25,7 @@ class CountBasedCuriosity(ConnectorV2): r(i) = intrinsic_reward_coeff * (1 / C(obs(i))) where C is the total (lifetime) count of the obs at timestep i. - The instrinsic reward is added to the extrinsic reward and saved back into the + The intrinsic reward is added to the extrinsic reward and saved back into the episode (under the main "rewards" key). Note that the computation and saving back to the episode all happens before the @@ -79,8 +79,8 @@ def __call__( rewards = sa_episode.get_rewards() for i, (obs, rew) in enumerate(zip(observations, rewards)): - # Add 1 to obs counter. obs = tuple(obs) + # Add 1 to obs counter. self._counts[obs] += 1 # Compute our count-based intrinsic reward and add it to the main # (extrinsic) reward. diff --git a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py index e69de29bb2d1..58a9e2746d2f 100644 --- a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py +++ b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py @@ -0,0 +1,117 @@ +from collections import deque +from typing import Any, List, Optional + +import gymnasium as gym +import numpy as np + +from ray.rllib.connectors.connector_v2 import ConnectorV2 +from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.utils.typing import EpisodeType + + +class EuclidianDistanceBasedCuriosity(ConnectorV2): + """Learner ConnectorV2 piece computing intrinsic rewards with euclidian distance. + + Add this connector piece to your Learner pipeline, through your algo config: + ``` + config.training( + learner_connector=lambda obs_sp, act_sp: EuclidianDistanceBasedCuriosity() + ) + ``` + + Intrinsic rewards are computed on the Learner side based on comparing the euclidian + distance of observations vs already seen ones. A configurable number of observations + will be stored in a FIFO buffer and all incoming observations have their distance + measured against those. + + The minimum distance measured is the intrinsic reward for the incoming obs + (multiplied by a fixed coeffieicnt and added to the "main" extrinsic reward): + r(i) = intrinsic_reward_coeff * min(ED(o, o(i)) for o in stored_obs)) + where `ED` is the euclidian distance and `stored_obs` is the buffer. + + The intrinsic reward is then added to the extrinsic reward and saved back into the + episode (under the main "rewards" key). + + Note that the computation and saving back to the episode all happens before the + actual train batch is generated from the episode data. Thus, the Learner and the + RLModule used do not take notice of the extra reward added. + + Only one observation per incoming episode will be stored as a new one in the buffer. + Thereby, we pick the observation with the largest `min(ED)` value over all already + stored observations to be stored per episode. + + If you would like to use a simpler, count-based mechanism for intrinsic reward + computations, take a look at the `CountBasedCuriosity` connector piece + at `ray.rllib.examples.connectors.classes.count_based_curiosity` + """ + + def __init__( + self, + input_observation_space: Optional[gym.Space] = None, + input_action_space: Optional[gym.Space] = None, + *, + intrinsic_reward_coeff: float = 1.0, + max_buffer_size: int = 100, + **kwargs, + ): + """Initializes a CountBasedCuriosity instance. + + Args: + intrinsic_reward_coeff: The weight with which to multiply the intrinsic + reward before adding (and saving) it back to the main (extrinsic) + reward of the episode at each timestep. + """ + super().__init__(input_observation_space, input_action_space) + + # Create an observation buffer + self.obs_buffer = deque(maxlen=max_buffer_size) + self.intrinsic_reward_coeff = intrinsic_reward_coeff + + def __call__( + self, + *, + rl_module: RLModule, + data: Any, + episodes: List[EpisodeType], + explore: Optional[bool] = None, + shared_data: Optional[dict] = None, + **kwargs, + ) -> Any: + # Loop through all episodes and change the reward to + # [reward + intrinsic reward] + for sa_episode in self.single_agent_episode_iterator( + episodes=episodes, agents_that_stepped_only=False + ): + # Loop through all obs, except the last one. + observations = sa_episode.get_observations(slice(None, -1)) + # Get all respective (extrinsic) rewards. + rewards = sa_episode.get_rewards() + + max_dist_obs = None + max_dist = float("-inf") + for i, (obs, rew) in enumerate(zip(observations, rewards)): + # Compare obs to all stored observations and compute euclidian distance. + min_dist = 0.0 + if self.obs_buffer: + min_dist = min( + np.sqrt(np.sum((obs - stored_obs) ** 2)) + for stored_obs in self.obs_buffer + ) + if min_dist > max_dist: + max_dist = min_dist + max_dist_obs = obs + + # Compute our euclidian distance-based intrinsic reward and add it to + # the main (extrinsic) reward. + rew += self.intrinsic_reward_coeff * min_dist + # Store the new reward back to the episode (under the correct + # timestep/index). + sa_episode.set_rewards(new_data=rew, at_indices=i) + + # Add the one observation of this episode with the largest (min) euclidian + # dist to all already stored obs to the buffer (maybe throwing out the + # oldest obs in there). + if max_dist_obs is not None: + self.obs_buffer.append(max_dist_obs) + + return data diff --git a/rllib/examples/curiosity/count_based_curiosity.py b/rllib/examples/curiosity/count_based_curiosity.py index 62d79a387023..90f69a513ac9 100644 --- a/rllib/examples/curiosity/count_based_curiosity.py +++ b/rllib/examples/curiosity/count_based_curiosity.py @@ -68,6 +68,13 @@ default_reward=0.99, default_iters=200, default_timesteps=1000000 ) parser.set_defaults(enable_new_api_stack=True) +parser.add_argument( + "--intrinsic-reward-coeff", + type=float, + default=1.0, + help="The weight with which to multiply intrinsic rewards before adding them to " + "the extrinsic ones (default is 1.0).", +) parser.add_argument( "--no-curiosity", action="store_true", diff --git a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py index e69de29bb2d1..14fa09fb596c 100644 --- a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py +++ b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py @@ -0,0 +1,129 @@ +"""Example of a euclidian-distance curiosity mechanism to learn in sparse-rewards envs. + +This example: + - demonstrates how to define your own euclidian-distance-based curiosity ConnectorV2 + piece that computes intrinsic rewards based on the delta between incoming + observations and some set of already stored (prior) observations. Thereby, the + further away the incoming observation is from the already stored ones, the higher + its corresponding intrinsic reward. + - shows how this connector piece adds the intrinsic reward to the corresponding + "main" (extrinsic) reward and overrides the value in the "rewards" key in the + episode. It thus demonstrates how to do reward shaping in general with RLlib. + - shows how to plug this connector piece into your algorithm's config. + - uses Tune and RLlib to learn the env described above and compares 2 + algorithms, one that does use curiosity vs one that does not. + +We use the MountainCar-v0 environment, a sparse-reward env that is very hard to learn +for a regular PPO algorithm. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +Use the `--no-curiosity` flag to disable curiosity learning and force your policy +to be trained on the task w/o the use of intrinsic rewards. With this option, the +algorithm should NOT succeed. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +In the console output, you can see that only a PPO policy that uses curiosity can +actually learn. + +Policy using count-based curiosity: ++-------------------------------+------------+--------+------------------+ +| Trial name | status | iter | total time (s) | +| | | | | +|-------------------------------+------------+--------+------------------+ +| PPO_FrozenLake-v1_109de_00000 | TERMINATED | 48 | 44.46 | ++-------------------------------+------------+--------+------------------+ ++------------------------+-------------------------+------------------------+ +| episode_return_mean | num_episodes_lifetime | num_env_steps_traine | +| | | d_lifetime | +|------------------------+-------------------------+------------------------| +| 0.99 | 12960 | 194000 | ++------------------------+-------------------------+------------------------+ + +Policy NOT using curiosity: +[DOES NOT LEARN AT ALL] +""" +from ray.rllib.connectors.env_to_module import MeanStdFilter +from ray.rllib.examples.connectors.classes.euclidian_distance_based_curiosity import ( + EuclidianDistanceBasedCuriosity, +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) +from ray.tune.registry import get_trainable_cls + +parser = add_rllib_example_script_args( + default_reward=-100.0, default_iters=2000, default_timesteps=1000000 +) +parser.set_defaults( + enable_new_api_stack=True, + num_env_runners=4, +) +parser.add_argument( + "--intrinsic-reward-coeff", + type=float, + default=0.00001, + help="The weight with which to multiply intrinsic rewards before adding them to " + "the extrinsic ones (default is 1.0).", +) +parser.add_argument( + "--no-curiosity", + action="store_true", + help="Whether to NOT use count-based curiosity.", +) + + +if __name__ == "__main__": + args = parser.parse_args() + + base_config = ( + get_trainable_cls(args.algo) + .get_default_config() + .environment("MountainCar-v0") + .env_runners( + env_to_module_connector=lambda env: MeanStdFilter(), + num_envs_per_env_runner=5, + ) + .training( + # The main code in this example: We add the + # `EuclidianDistanceBasedCuriosity` connector piece to our Learner connector + # pipeline. This pipeline is fed with collected episodes (either directly + # from the EnvRunners in on-policy fashion or from a replay buffer) and + # converts these episodes into the final train batch. The added piece + # computes intrinsic rewards based on simple observation counts and add them + # to the "main" (extrinsic) rewards. + learner_connector=( + None if args.no_curiosity + else lambda *ags, **kw: EuclidianDistanceBasedCuriosity() + ), + #train_batch_size_per_learner=512, + gamma=0.99, + lr=0.0002, + #lambda_=0.1, + #vf_clip_param=10.0, + #sgd_minibatch_size=64, + + #lambda_=0.98, + #num_sgd_iter=16, + #vf_loss_coeff=0.01, + #lr=0.0003, + ) + #.rl_module(model_config_dict={"vf_share_layers": True}) + ) + + run_rllib_example_script_experiment(base_config, args) From 5ae97edc24ad17b4c49d3ffaa66628554fc83167 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 22 Jul 2024 17:01:40 +0200 Subject: [PATCH 03/13] wip Signed-off-by: sven1977 --- rllib/examples/curiosity/euclidian_distance_based_curiosity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py index 14fa09fb596c..d45a486e6d5a 100644 --- a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py +++ b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py @@ -114,7 +114,7 @@ #train_batch_size_per_learner=512, gamma=0.99, lr=0.0002, - #lambda_=0.1, + lambda_=0.98, #vf_clip_param=10.0, #sgd_minibatch_size=64, From f3348fcc57154a0ff412211fbee40f5755e61440 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 22 Jul 2024 18:48:52 +0200 Subject: [PATCH 04/13] wip Signed-off-by: sven1977 --- .../euclidian_distance_based_curiosity.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py index d45a486e6d5a..d8bab56f3e96 100644 --- a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py +++ b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py @@ -77,9 +77,9 @@ parser.add_argument( "--intrinsic-reward-coeff", type=float, - default=0.00001, + default=0.0001, help="The weight with which to multiply intrinsic rewards before adding them to " - "the extrinsic ones (default is 1.0).", + "the extrinsic ones (default is 0.0001).", ) parser.add_argument( "--no-curiosity", @@ -108,22 +108,16 @@ # computes intrinsic rewards based on simple observation counts and add them # to the "main" (extrinsic) rewards. learner_connector=( - None if args.no_curiosity + None + if args.no_curiosity else lambda *ags, **kw: EuclidianDistanceBasedCuriosity() ), - #train_batch_size_per_learner=512, + # train_batch_size_per_learner=512, gamma=0.99, lr=0.0002, lambda_=0.98, - #vf_clip_param=10.0, - #sgd_minibatch_size=64, - - #lambda_=0.98, - #num_sgd_iter=16, - #vf_loss_coeff=0.01, - #lr=0.0003, ) - #.rl_module(model_config_dict={"vf_share_layers": True}) + # .rl_module(model_config_dict={"vf_share_layers": True}) ) run_rllib_example_script_experiment(base_config, args) From 00f50856b6418d4438b504ada4d39a2ed2e7f877 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 22 Jul 2024 19:13:07 +0200 Subject: [PATCH 05/13] Learns MountainCar-v0 until -119 reward Signed-off-by: sven1977 --- .../connectors/classes/euclidian_distance_based_curiosity.py | 5 +++++ .../examples/curiosity/euclidian_distance_based_curiosity.py | 4 +++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py index 58a9e2746d2f..0babff5a33f0 100644 --- a/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py +++ b/rllib/examples/connectors/classes/euclidian_distance_based_curiosity.py @@ -67,6 +67,8 @@ def __init__( self.obs_buffer = deque(maxlen=max_buffer_size) self.intrinsic_reward_coeff = intrinsic_reward_coeff + self._test = 0 + def __call__( self, *, @@ -77,6 +79,9 @@ def __call__( shared_data: Optional[dict] = None, **kwargs, ) -> Any: + if self._test > 10: + return data + self._test += 1 # Loop through all episodes and change the reward to # [reward + intrinsic reward] for sa_episode in self.single_agent_episode_iterator( diff --git a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py index d8bab56f3e96..640bdee4d06f 100644 --- a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py +++ b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py @@ -113,11 +113,13 @@ else lambda *ags, **kw: EuclidianDistanceBasedCuriosity() ), # train_batch_size_per_learner=512, + grad_clip=20.0, + entropy_coeff=0.003, gamma=0.99, lr=0.0002, lambda_=0.98, ) - # .rl_module(model_config_dict={"vf_share_layers": True}) + #.rl_module(model_config_dict={"fcnet_activation": "relu"}) ) run_rllib_example_script_experiment(base_config, args) From 4b64182fb43028b749f70a8659a4e298d2c8462c Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 23 Jul 2024 12:21:39 +0200 Subject: [PATCH 06/13] wip Signed-off-by: sven1977 --- .../euclidian_distance_based_curiosity.py | 14 ++++++++++++++ .../euclidian_distance_based_curiosity.py | 8 ++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 rllib/examples/connectors/euclidian_distance_based_curiosity.py diff --git a/rllib/examples/connectors/euclidian_distance_based_curiosity.py b/rllib/examples/connectors/euclidian_distance_based_curiosity.py new file mode 100644 index 000000000000..6e52de767913 --- /dev/null +++ b/rllib/examples/connectors/euclidian_distance_based_curiosity.py @@ -0,0 +1,14 @@ +"""Placeholder for training with euclidian distance-based curiosity. + +The actual script can be found at a different location (see code below). +""" + +if __name__ == "__main__": + import subprocess + import sys + + # Forward to "python ../curiosity/[same script name].py [same options]" + command = [sys.executable, "../curiosity/", sys.argv[0]] + sys.argv[1:] + + # Run the script. + subprocess.run(command, capture_output=True) diff --git a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py index 640bdee4d06f..c24b805ed062 100644 --- a/rllib/examples/curiosity/euclidian_distance_based_curiosity.py +++ b/rllib/examples/curiosity/euclidian_distance_based_curiosity.py @@ -67,8 +67,12 @@ ) from ray.tune.registry import get_trainable_cls +# TODO (sven): SB3's PPO does seem to learn MountainCar-v0 until a reward of ~-110. +# We might have to play around some more with different initializations, more +# randomized SGD minibatching (we don't shuffle batch rn), etc.. to get to these +# results as well. parser = add_rllib_example_script_args( - default_reward=-100.0, default_iters=2000, default_timesteps=1000000 + default_reward=-130.0, default_iters=2000, default_timesteps=1000000 ) parser.set_defaults( enable_new_api_stack=True, @@ -119,7 +123,7 @@ lr=0.0002, lambda_=0.98, ) - #.rl_module(model_config_dict={"fcnet_activation": "relu"}) + # .rl_module(model_config_dict={"fcnet_activation": "relu"}) ) run_rllib_example_script_experiment(base_config, args) From 471f726d598f5fcebed55b1af632634f29df8616 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 23 Jul 2024 15:50:25 +0200 Subject: [PATCH 07/13] wip Signed-off-by: sven1977 --- .../inverse_dynamics_model_based_curiosity.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py diff --git a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py new file mode 100644 index 000000000000..3cff7c9f20f1 --- /dev/null +++ b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py @@ -0,0 +1,86 @@ +class InverseDynamicsModelBasedCuriosity: + """Implementation of: + [1] Curiosity-driven Exploration by Self-supervised Prediction + Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. + https://arxiv.org/pdf/1705.05363.pdf + + Learns a simplified model of the environment based on three networks: + 1) Embedding observations into latent space ("feature" network). + 2) Predicting the action, given two consecutive embedded observations + ("inverse" network). + 3) Predicting the next embedded obs, given an obs and action + ("forward" network). + + The less the agent is able to predict the actually observed next feature + vector, given obs and action (through the forwards network), the larger the + "intrinsic reward", which will be added to the extrinsic reward. + Therefore, if a state transition was unexpected, the agent becomes + "curious" and will further explore this transition leading to better + exploration in sparse rewards environments. + """ + + def __init__(self): + + + def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + # Push both observations through feature net to get both phis. + phis, _ = self.model._curiosity_feature_net( + { + SampleBatch.OBS: torch.cat( + [ + torch.from_numpy(sample_batch[SampleBatch.OBS]).to( + policy.device + ), + torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to( + policy.device + ), + ] + ) + } + ) + phi, next_phi = torch.chunk(phis, 2) + actions_tensor = ( + torch.from_numpy(sample_batch[SampleBatch.ACTIONS]).long().to(policy.device) + ) + + # Predict next phi with forward model. + predicted_next_phi = self.model._curiosity_forward_fcnet( + torch.cat([phi, one_hot(actions_tensor, self.action_space).float()], dim=-1) + ) + + # Forward loss term (predicted phi', given phi and action vs actually + # observed phi'). + forward_l2_norm_sqared = 0.5 * torch.sum( + torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1 + ) + forward_loss = torch.mean(forward_l2_norm_sqared) + + # Scale intrinsic reward by eta hyper-parameter. + sample_batch[SampleBatch.REWARDS] = ( + sample_batch[SampleBatch.REWARDS] + + self.eta * forward_l2_norm_sqared.detach().cpu().numpy() + ) + + # Inverse loss term (prediced action that led from phi to phi' vs + # actual action taken). + phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1) + dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi) + action_dist = ( + TorchCategorical(dist_inputs, self.model) + if isinstance(self.action_space, Discrete) + else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec) + ) + # Neg log(p); p=probability of observed action given the inverse-NN + # predicted action distribution. + inverse_loss = -action_dist.logp(actions_tensor) + inverse_loss = torch.mean(inverse_loss) + + # Calculate the ICM loss. + loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss + # Perform an optimizer step. + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + + # Return the postprocessed sample batch (with the corrected rewards). + return sample_batch From f99182a0f771264a16c64c79b499c8befe1aa3f6 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Wed, 24 Jul 2024 00:11:13 +0200 Subject: [PATCH 08/13] wip Signed-off-by: sven1977 --- .../inverse_dynamics_model_based_curiosity.py | 12 +- .../classes/inverse_dynamics_model_rlm.py | 158 ++++++++++++++++++ 2 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py diff --git a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py index 3cff7c9f20f1..374e6838d342 100644 --- a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py +++ b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py @@ -1,4 +1,14 @@ -class InverseDynamicsModelBasedCuriosity: +from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner + + +class PPOTorchLearnerWithCuriosity(PPOTorchLearner): + def build(self): + super().build() + + # Add + + +class InverseDynamicsBasedCuriosity: """Implementation of: [1] Curiosity-driven Exploration by Self-supervised Prediction Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. diff --git a/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py b/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py new file mode 100644 index 000000000000..dda39aed0fe7 --- /dev/null +++ b/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py @@ -0,0 +1,158 @@ +from typing import Any, Dict + +import numpy as np + +from ray.rllib.core.columns import Columns +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.models.torch.torch_distributions import TorchCategorical +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() + + +class InverseDynamicsModel(TorchRLModule): + """An inverse-dynamics model as TorchRLModule for curiosity-based exploration. + + For more details, see: + [1] Curiosity-driven Exploration by Self-supervised Prediction + Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. + https://arxiv.org/pdf/1705.05363.pdf + + Learns a simplified model of the environment based on three networks: + 1) Embedding observations into latent space ("feature" network). + 2) Predicting the action, given two consecutive embedded observations + ("inverse" network). + 3) Predicting the next embedded obs, given an obs and action + ("forward" network). + + The less the agent is able to predict the actually observed next feature + vector, given obs and action (through the forwards network), the larger the + "intrinsic reward", which will be added to the extrinsic reward. + Therefore, if a state transition was unexpected, the agent becomes + "curious" and will further explore this transition leading to better + exploration in sparse rewards environments. + + .. testcode:: + + import numpy as np + import gymnasium as gym + from ray.rllib.core.rl_module.rl_module import RLModuleConfig + + B = 10 # batch size + T = 5 # seq len + f = 25 # feature dim + + # Construct the RLModule. + rl_module_config = RLModuleConfig( + observation_space=gym.spaces.Box(-1.0, 1.0, (f,), np.float32), + action_space=gym.spaces.Discrete(4), + model_config_dict={"lstm_cell_size": CELL} + ) + my_net = LSTMContainingRLModule(rl_module_config) + + # Create some dummy input. + obs = torch.from_numpy( + np.random.random_sample(size=(B, T, f) + ).astype(np.float32)) + state_in = my_net.get_initial_state() + # Repeat state_in across batch. + state_in = tree.map_structure( + lambda s: torch.from_numpy(s).unsqueeze(0).repeat(B, 1), state_in + ) + input_dict = { + Columns.OBS: obs, + Columns.STATE_IN: state_in, + } + + # Run through all 3 forward passes. + print(my_net.forward_inference(input_dict)) + print(my_net.forward_exploration(input_dict)) + print(my_net.forward_train(input_dict)) + + # Print out the number of parameters. + num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters()) + print(f"num params = {num_all_params}") + """ + + @override(TorchRLModule) + def setup(self): + # Assume a simple Box(1D) tensor as input shape. + in_size = self.config.observation_space.shape[0] + + # Get the IDM achitecture settings from the our RLModuleConfig's (self.config) + # `model_config_dict` property: + cfg = self.config.model_config_dict + self._feature_dim = cfg.get("feature_dim", 288) + feature_net_config: Optional[ModelConfigDict] = None, + self._inverse_net_hiddens = cfg.get("inverse_net_hiddens", (256,)) + self._inverse_net_activation = cfg.get("inverse_net_activation", "relu") + self._forward_net_hiddens = cfg.get("forward_net_hiddens", (256,)) + self._forward_net_activation = cfg.get("forward_net_activation", "relu") + + + # Build the inverse model (predicting . + layers = [] + # Get the dense layer pre-stack configuration from the same config dict. + dense_layers = self.config.model_config_dict.get("dense_layers", [128, 128]) + for out_size in dense_layers: + # Dense layer. + layers.append(nn.Linear(in_size, out_size)) + # ReLU activation. + layers.append(nn.ReLU()) + in_size = out_size + + self._fc_net = nn.Sequential(*layers) + + # Logits layer (no bias, no activation). + self._logits = nn.Linear(in_size, self.config.action_space.n) + # Single-node value layer. + self._values = nn.Linear(in_size, 1) + + @override(TorchRLModule) + def _forward_train(self, batch, **kwargs): + # Compute the basic 1D feature tensor (inputs to policy- and value-heads). + features, state_out, logits = self._compute_features_state_out_and_logits(batch) + # Besides the action logits, we also have to return value predictions here + # (to be used inside the loss function). + values = self._values(features).squeeze(-1) + return { + Columns.STATE_OUT: state_out, + Columns.ACTION_DIST_INPUTS: logits, + Columns.VF_PREDS: values, + } + + @override(TorchRLModule) + def get_train_action_dist_cls(self): + return TorchCategorical + + def _compute_features_state_out_and_logits(self, batch): + obs = batch[Columns.OBS] + state_in = batch[Columns.STATE_IN] + h, c = state_in["h"], state_in["c"] + # Unsqueeze the layer dim (we only have 1 LSTM layer. + features, (h, c) = self._lstm( + obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major + (h.unsqueeze(0), c.unsqueeze(0)), + ) + # Make batch-major again. + features = features.permute(1, 0, 2) + # Push through our FC net. + features = self._fc_net(features) + logits = self._logits(features) + return features, {"h": h.squeeze(0), "c": c.squeeze(0)}, logits + + # Inference and exploration not supported (this is a world-model that should only + # be used for training). + @override(TorchRLModule) + def _forward_inference(self, batch, **kwargs): + raise NotImplementedError( + "InverseDynamicsModel should only be used for training! " + "Use `forward_train()` instead." + ) + + @override(TorchRLModule) + def _forward_exploration(self, batch, **kwargs): + return self._forward_inference(batch) From a7d8338a075800e684c480d2f929e7399e9daec2 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Thu, 25 Jul 2024 13:25:33 +0200 Subject: [PATCH 09/13] learning 12x12 frozenlake with max-steps=22 w/ 1M env steps Signed-off-by: sven1977 --- rllib/algorithms/dqn/dqn_rainbow_learner.py | 4 +- .../common/batch_individual_items.py | 4 +- rllib/connectors/connector_v2.py | 24 +- rllib/core/columns.py | 3 + rllib/core/learner/learner.py | 9 +- rllib/core/models/specs/checker.py | 26 +- rllib/core/rl_module/torch/torch_rl_module.py | 12 +- rllib/core/testing/testing_learner.py | 2 +- .../inverse_dynamics_model_based_curiosity.py | 231 +++++++++++------- .../curriculum/curriculum_learning.py | 11 +- .../classes/curiosity_ppo_torch_learner.py | 143 +++++++++++ .../classes/inverse_dynamics_model_rlm.py | 182 ++++++++++---- 12 files changed, 473 insertions(+), 178 deletions(-) create mode 100644 rllib/examples/learners/classes/curiosity_ppo_torch_learner.py diff --git a/rllib/algorithms/dqn/dqn_rainbow_learner.py b/rllib/algorithms/dqn/dqn_rainbow_learner.py index 9728eeef84de..562e707d1427 100644 --- a/rllib/algorithms/dqn/dqn_rainbow_learner.py +++ b/rllib/algorithms/dqn/dqn_rainbow_learner.py @@ -55,8 +55,8 @@ def build(self) -> None: ) ) - # Prepend a NEXT_OBS from episodes to train batch connector piece (right - # after the observation default piece). + # Prepend a "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." default piece). if self.config.add_default_connectors_to_learner_pipeline: self._learner_connector.insert_after( AddObservationsFromEpisodesToBatch, diff --git a/rllib/connectors/common/batch_individual_items.py b/rllib/connectors/common/batch_individual_items.py index b095d4d77a7a..7708a0417d9f 100644 --- a/rllib/connectors/common/batch_individual_items.py +++ b/rllib/connectors/common/batch_individual_items.py @@ -58,8 +58,8 @@ def __call__( # connector piece is called. if not self._multi_agent: continue - # If MA Off-Policy and independent sampling we need to overcome - # this check. + # If MA Off-Policy and independent sampling we need to overcome this + # check. module_data = column_data for col, col_data in module_data.copy().items(): if isinstance(col_data, list) and col != Columns.INFOS: diff --git a/rllib/connectors/connector_v2.py b/rllib/connectors/connector_v2.py index 33b3e6da28eb..fb678be07c2d 100644 --- a/rllib/connectors/connector_v2.py +++ b/rllib/connectors/connector_v2.py @@ -415,17 +415,18 @@ def add_batch_item( `column`. """ sub_key = None - if ( - single_agent_episode is not None - and single_agent_episode.agent_id is not None - ): - sub_key = ( - single_agent_episode.multi_agent_episode_id, - single_agent_episode.agent_id, - single_agent_episode.module_id, - ) - elif single_agent_episode is not None: - sub_key = (single_agent_episode.id_,) + # SAEpisode is provided ... + if single_agent_episode is not None: + # ... and has `agent_id` -> Use agent ID and module ID from it. + if single_agent_episode.agent_id is not None: + sub_key = ( + single_agent_episode.multi_agent_episode_id, + single_agent_episode.agent_id, + single_agent_episode.module_id, + ) + # Otherwise, just use episode's ID. + else: + sub_key = (single_agent_episode.id_,) if column not in batch: batch[column] = [] if sub_key is None else {sub_key: []} @@ -443,6 +444,7 @@ def add_n_batch_items( items_to_add: Any, num_items: int, single_agent_episode: Optional[SingleAgentEpisode] = None, + ) -> None: """Adds a list of items (or batched item) under `column` to the given `batch`. diff --git a/rllib/core/columns.py b/rllib/core/columns.py index 7a46d1282f10..0944d521e2c1 100644 --- a/rllib/core/columns.py +++ b/rllib/core/columns.py @@ -59,6 +59,9 @@ class Columns: ADVANTAGES = "advantages" VALUE_TARGETS = "value_targets" + # Intrinsic rewards (learning with curiosity). + INTRINSIC_REWARDS = "intrinsic_rewards" + # Loss mask. If provided in a train batch, a Learner's compute_loss_for_module # method should respect the False-set value in here and mask out the respective # items form the loss. diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index a840665bb633..68bfeab16898 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -835,11 +835,8 @@ def should_module_be_updated(self, module_id, multi_agent_batch=None): @OverrideToImplementCustomLogic def compute_loss( - self, - *, - fwd_out: Dict[str, Any], - batch: Dict[str, Any], - ) -> Union[TensorType, Dict[str, Any]]: + self, *, fwd_out: Dict[str, Any], batch: Dict[str, Any] + ) -> Dict[str, Any]: """Computes the loss for the module being optimized. This method must be overridden by multiagent-specific algorithm learners to @@ -892,7 +889,7 @@ def compute_loss_for_module( self, *, module_id: ModuleID, - config: Optional["AlgorithmConfig"] = None, + config: "AlgorithmConfig", batch: Dict[str, Any], fwd_out: Dict[str, TensorType], ) -> TensorType: diff --git a/rllib/core/models/specs/checker.py b/rllib/core/models/specs/checker.py index 190820552cfd..1ff3ead16ecd 100644 --- a/rllib/core/models/specs/checker.py +++ b/rllib/core/models/specs/checker.py @@ -167,7 +167,7 @@ def check_input_specs( This is a stateful decorator (https://realpython.com/primer-on-python-decorators/#stateful-decorators) to enforce input specs for any instance method that has an argument named - `input_data` in its args. + `batch` in its args. See more examples in ../tests/test_specs_dict.py) @@ -183,7 +183,7 @@ def input_specs(self): return {"obs": TensorSpec("b, d", d=64)} @check_input_specs("input_specs", only_check_on_retry=False) - def forward(self, input_data, return_loss=False): + def forward(self, batch, return_loss=False): ... model = MyModel() @@ -194,11 +194,11 @@ def forward(self, input_data, return_loss=False): Args: func: The instance method to decorate. It should be a callable that takes - `self` as the first argument, `input_data` as the second argument and any + `self` as the first argument, `batch` as the second argument and any other keyword argument thereafter. input_specs: `self` should have an instance attribute whose name matches the string in input_specs and returns the `SpecDict`, `Spec`, or simply the - `Type` that the `input_data` should comply with. It can also be None or + `Type` that the `batch` should comply with. It can also be None or empty list / dict to enforce no input spec. only_check_on_retry: If True, the spec will not be checked. Only if the decorated method raises an Exception, we check the spec to provide a more @@ -220,7 +220,7 @@ def forward(self, input_data, return_loss=False): def decorator(func): @functools.wraps(func) - def wrapper(self, input_data, **kwargs): + def wrapper(self, batch, **kwargs): if cache and not hasattr(self, "__checked_input_specs_cache__"): self.__checked_input_specs_cache__ = {} @@ -228,7 +228,7 @@ def wrapper(self, input_data, **kwargs): if only_check_on_retry: # Attempt to run the function without spec checking try: - return func(self, input_data, **kwargs) + return func(self, batch, **kwargs) except SpecCheckingError as e: raise e except Exception as e: @@ -242,7 +242,7 @@ def wrapper(self, input_data, **kwargs): ) # If the function was not executed successfully yet, we check specs - checked_data = input_data + checked_data = batch if input_specs and ( initial_exception @@ -262,7 +262,7 @@ def wrapper(self, input_data, **kwargs): checked_data = _validate( cls_instance=self, method=func, - data=input_data, + data=batch, spec=spec, tag="input", ) @@ -312,17 +312,17 @@ def output_specs(self): return {"obs": TensorSpec("b, d", d=64)} @check_output_specs("output_specs") - def forward(self, input_data, return_loss=False): + def forward(self, batch, return_loss=False): return {"obs": torch.randn(32, 64)} Args: func: The instance method to decorate. It should be a callable that takes - `self` as the first argument, `input_data` as the second argument and any + `self` as the first argument, `batch` as the second argument and any other keyword argument thereafter. It should return a single dict-like object (i.e. not a tuple). output_specs: `self` should have an instance attribute whose name matches the string in output_specs and returns the `SpecDict`, `Spec`, or simply the - `Type` that the `input_data` should comply with. It can alos be None or + `Type` that the `batch` should comply with. It can alos be None or empty list / dict to enforce no input spec. cache: If True, only checks the data validation for the first time the instance method is called. @@ -338,11 +338,11 @@ def forward(self, input_data, return_loss=False): def decorator(func): @functools.wraps(func) - def wrapper(self, input_data, **kwargs): + def wrapper(self, batch, **kwargs): if cache and not hasattr(self, "__checked_output_specs_cache__"): self.__checked_output_specs_cache__ = {} - output_data = func(self, input_data, **kwargs) + output_data = func(self, batch, **kwargs) if output_specs and ( not cache or func.__name__ not in self.__checked_output_specs_cache__ diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 10a29a621232..92d428d18d91 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -163,12 +163,12 @@ def restore_from_path(self, *args, **kwargs): def get_metadata(self, *args, **kwargs): self.unwrapped().get_metadata(*args, **kwargs) - # TODO (sven): Figure out a better way to avoid having to method-spam this wrapper - # class, whenever we add a new API to any wrapped RLModule here. We could try - # auto generating the wrapper methods, but this will bring its own challenge - # (e.g. recursive calls due to __getattr__ checks, etc..). - def _compute_values(self, *args, **kwargs): - return self.unwrapped()._compute_values(*args, **kwargs) + ## TODO (sven): Figure out a better way to avoid having to method-spam this wrapper + ## class, whenever we add a new API to any wrapped RLModule here. We could try + ## auto generating the wrapper methods, but this will bring its own challenge + ## (e.g. recursive calls due to __getattr__ checks, etc..). + #def _compute_values(self, *args, **kwargs): + # return self.unwrapped()._compute_values(*args, **kwargs) @override(RLModule) def unwrapped(self) -> "RLModule": diff --git a/rllib/core/testing/testing_learner.py b/rllib/core/testing/testing_learner.py index 057784a300af..0ef5a1fd9f36 100644 --- a/rllib/core/testing/testing_learner.py +++ b/rllib/core/testing/testing_learner.py @@ -63,7 +63,7 @@ def get_default_rl_module_spec(self) -> "RLModuleSpec": class BaseTestingLearner(Learner): @override(Learner) - def compute_loss_for_module(self, *, module_id, config=None, batch, fwd_out): + def compute_loss_for_module(self, *, module_id, config, batch, fwd_out): # This is to check if in the multi-gpu case, the weights across workers are # the same. It is really only needed during testing. if config.report_mean_weights: diff --git a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py index 374e6838d342..ddeefae8d3dd 100644 --- a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py +++ b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py @@ -1,96 +1,153 @@ -from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner - - -class PPOTorchLearnerWithCuriosity(PPOTorchLearner): - def build(self): - super().build() - - # Add - - -class InverseDynamicsBasedCuriosity: - """Implementation of: - [1] Curiosity-driven Exploration by Self-supervised Prediction - Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. - https://arxiv.org/pdf/1705.05363.pdf - - Learns a simplified model of the environment based on three networks: - 1) Embedding observations into latent space ("feature" network). - 2) Predicting the action, given two consecutive embedded observations - ("inverse" network). - 3) Predicting the next embedded obs, given an obs and action - ("forward" network). - - The less the agent is able to predict the actually observed next feature - vector, given obs and action (through the forwards network), the larger the - "intrinsic reward", which will be added to the extrinsic reward. - Therefore, if a state transition was unexpected, the agent becomes - "curious" and will further explore this transition leading to better - exploration in sparse rewards environments. - """ - +"""Implementation of: +[1] Curiosity-driven Exploration by Self-supervised Prediction +Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. +https://arxiv.org/pdf/1705.05363.pdf + +Learns a simplified model of the environment based on three networks: +1) Embedding observations into latent space ("feature" network). +2) Predicting the action, given two consecutive embedded observations +("inverse" network). +3) Predicting the next embedded obs, given an obs and action +("forward" network). + +The less the agent is able to predict the actually observed next feature +vector, given obs and action (through the forwards network), the larger the +"intrinsic reward", which will be added to the extrinsic reward. +Therefore, if a state transition was unexpected, the agent becomes +"curious" and will further explore this transition leading to better +exploration in sparse rewards environments. +""" +from collections import defaultdict + +from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.rllib.connectors.env_to_module import FlattenObservations +from ray.rllib.examples.learners.classes.curiosity_ppo_torch_learner import ( + PPOConfigWithCuriosity, PPOTorchLearnerWithCuriosity +) +from ray.rllib.utils.test_utils import ( + add_rllib_example_script_args, + run_rllib_example_script_experiment, +) + +parser = add_rllib_example_script_args( + default_iters=20000, + default_timesteps=100000000, + default_reward=1.0, +) +parser.set_defaults(enable_new_api_stack=True) + + + +class PrintMaxDistanceFrozenLakeCallback(DefaultCallbacks): def __init__(self): - - - def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: - # Push both observations through feature net to get both phis. - phis, _ = self.model._curiosity_feature_net( - { - SampleBatch.OBS: torch.cat( - [ - torch.from_numpy(sample_batch[SampleBatch.OBS]).to( - policy.device - ), - torch.from_numpy(sample_batch[SampleBatch.NEXT_OBS]).to( - policy.device - ), - ] - ) - } + super().__init__() + self.max_dists = defaultdict(float) + self.max_dists_lifetime = 0.0 + + def on_episode_step( + self, + *, + episode, + env_runner, + metrics_logger, + env, + env_index, + rl_module, + **kwargs, + ): + obs = episode.get_observations(-1) + num_rows = env.envs[0].unwrapped.nrow + num_cols = env.envs[0].unwrapped.ncol + row = obs // num_cols + col = obs % num_rows + curr_dist = (row ** 2 + col ** 2) ** 0.5 + if curr_dist > self.max_dists[episode.id_]: + self.max_dists[episode.id_] = curr_dist + + def on_episode_end( + self, + *, + episode, + env_runner, + metrics_logger, + env, + env_index, + rl_module, + **kwargs, + ): + # Compute current maximum distance across all running episodes + # (including the just ended one). + max_dist = max(self.max_dists.values()) + metrics_logger.log_value( + key="max_dist_travelled_across_running_episodes", + value=max_dist, + window=10, ) - phi, next_phi = torch.chunk(phis, 2) - actions_tensor = ( - torch.from_numpy(sample_batch[SampleBatch.ACTIONS]).long().to(policy.device) + if max_dist > self.max_dists_lifetime: + self.max_dists_lifetime = max_dist + del self.max_dists[episode.id_] + + def on_sample_end( + self, + *, + env_runner, + metrics_logger, + samples, + **kwargs, + ): + metrics_logger.log_value( + key="max_dist_travelled_lifetime", + value=self.max_dists_lifetime, + window=1, ) - # Predict next phi with forward model. - predicted_next_phi = self.model._curiosity_forward_fcnet( - torch.cat([phi, one_hot(actions_tensor, self.action_space).float()], dim=-1) - ) - # Forward loss term (predicted phi', given phi and action vs actually - # observed phi'). - forward_l2_norm_sqared = 0.5 * torch.sum( - torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1 +if __name__ == "__main__": + args = parser.parse_args() + + base_config = ( + PPOConfigWithCuriosity() + .environment( + "FrozenLake-v1", + env_config={ + "desc": [ + "SFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFF", + "FFFFFFFFFFFG", + ], + "is_slippery": False, + # Limit the number of steps the agent is allowed to make in the env to + # make it almost impossible to learn without the curriculum. + "max_episode_steps": 22, + }, ) - forward_loss = torch.mean(forward_l2_norm_sqared) - - # Scale intrinsic reward by eta hyper-parameter. - sample_batch[SampleBatch.REWARDS] = ( - sample_batch[SampleBatch.REWARDS] - + self.eta * forward_l2_norm_sqared.detach().cpu().numpy() + # Use our custom `curiosity` method to set up the ICM and our PPO/ICM-Learner. + .curiosity( + #curiosity_feature_net_hiddens=[256, 256], + #curiosity_inverse_net_activation="relu" ) - - # Inverse loss term (prediced action that led from phi to phi' vs - # actual action taken). - phi_cat_next_phi = torch.cat([phi, next_phi], dim=-1) - dist_inputs = self.model._curiosity_inverse_fcnet(phi_cat_next_phi) - action_dist = ( - TorchCategorical(dist_inputs, self.model) - if isinstance(self.action_space, Discrete) - else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec) + .callbacks(PrintMaxDistanceFrozenLakeCallback) + .env_runners( + num_envs_per_env_runner=5, + env_to_module_connector=lambda env: FlattenObservations(), ) - # Neg log(p); p=probability of observed action given the inverse-NN - # predicted action distribution. - inverse_loss = -action_dist.logp(actions_tensor) - inverse_loss = torch.mean(inverse_loss) - - # Calculate the ICM loss. - loss = (1.0 - self.beta) * inverse_loss + self.beta * forward_loss - # Perform an optimizer step. - self._optimizer.zero_grad() - loss.backward() - self._optimizer.step() + .training( + learner_class=PPOTorchLearnerWithCuriosity, + train_batch_size_per_learner=2000, + num_sgd_iter=6, + #vf_loss_coeff=0.01, + lr=0.0003, + ) + .rl_module(model_config_dict={"vf_share_layers": True}) + ) - # Return the postprocessed sample batch (with the corrected rewards). - return sample_batch + run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/curriculum/curriculum_learning.py b/rllib/examples/curriculum/curriculum_learning.py index b906d0e017c9..a6f0e9fb2d26 100644 --- a/rllib/examples/curriculum/curriculum_learning.py +++ b/rllib/examples/curriculum/curriculum_learning.py @@ -72,6 +72,7 @@ from ray.tune.registry import get_trainable_cls parser = add_rllib_example_script_args(default_iters=100, default_timesteps=600000) +parser.set_defaults(enable_new_api_stack=True) parser.add_argument( "--upgrade-task-threshold", type=float, @@ -212,16 +213,16 @@ def on_train_result( **ENV_OPTIONS, }, ) + .env_runners( + num_envs_per_env_runner=5, + env_to_module_connector=lambda env: FlattenObservations(), + ) .training( num_sgd_iter=6, vf_loss_coeff=0.01, lr=0.0002, - model={"vf_share_layers": True}, - ) - .env_runners( - num_envs_per_env_runner=5, - env_to_module_connector=lambda env: FlattenObservations(), ) + .rl_module(model_config_dict={"vf_share_layers": True}) ) stop = { diff --git a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py new file mode 100644 index 000000000000..29c3dad70813 --- /dev/null +++ b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py @@ -0,0 +1,143 @@ +from typing import Any, Dict, Tuple + +from ray.rllib.algorithms.algorithm_config import NotProvided +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner +from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( + AddObservationsFromEpisodesToBatch, +) +from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa + AddNextObservationsFromEpisodesToTrainBatch, +) +from ray.rllib.core import Columns, DEFAULT_MODULE_ID +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.examples.rl_modules.classes.inverse_dynamics_model_rlm import ( + InverseDynamicsModel +) +from ray.rllib.utils.metrics import ALL_MODULES + +ICM_MODULE_ID = "_inverse_dynamics_model" + + +class PPOConfigWithCuriosity(PPOConfig): + # Define defaults. + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.curiosity_feature_net_hiddens = (256, 256) + self.curiosity_feature_net_activation = "relu" + self.curiosity_inverse_net_hiddens = (256, 256) + self.curiosity_inverse_net_activation = "relu" + self.curiosity_forward_net_hiddens = (256, 256) + self.curiosity_forward_net_activation = "relu" + self.curiosity_beta = 0.2 + self.curiosity_eta = 1.0 + + # Allow users to change curiosity settings. + def curiosity( + self, + *, + curiosity_feature_net_hiddens: Tuple[int, ...] = NotProvided, + curiosity_feature_net_activation: str = NotProvided, + curiosity_inverse_net_hiddens: Tuple[int, ...] = NotProvided, + curiosity_inverse_net_activation: str = NotProvided, + curiosity_forward_net_hiddens: Tuple[int, ...] = NotProvided, + curiosity_forward_net_activation: str = NotProvided, + curiosity_beta: float = NotProvided, + curiosity_eta: float = NotProvided, + ): + if curiosity_feature_net_hiddens is not NotProvided: + self.curiosity_feature_net_hiddens = curiosity_feature_net_hiddens + if curiosity_feature_net_activation is not NotProvided: + self.curiosity_feature_net_activation = curiosity_feature_net_activation + if curiosity_inverse_net_hiddens is not NotProvided: + self.curiosity_inverse_net_hiddens = curiosity_inverse_net_hiddens + if curiosity_inverse_net_activation is not NotProvided: + self.curiosity_inverse_net_activation = curiosity_inverse_net_activation + if curiosity_forward_net_hiddens is not NotProvided: + self.curiosity_forward_net_hiddens = curiosity_forward_net_hiddens + if curiosity_forward_net_activation is not NotProvided: + self.curiosity_forward_net_activation = curiosity_forward_net_activation + if curiosity_beta is not NotProvided: + self.curiosity_beta = curiosity_beta + if curiosity_eta is not NotProvided: + self.curiosity_eta = curiosity_eta + return self + + +class PPOTorchLearnerWithCuriosity(PPOTorchLearner): + + def build(self): + super().build() + + # Assert, we are only training one policy (RLModule). + assert len(self.module) == 1 and DEFAULT_MODULE_ID in self.module + + # Add an InverseDynamicsModel to our MARLModule. + icm_spec = SingleAgentRLModuleSpec( + module_class=InverseDynamicsModel, + observation_space=self.module[DEFAULT_MODULE_ID].config.observation_space, + action_space=self.module[DEFAULT_MODULE_ID].config.action_space, + model_config_dict={ + "feature_net_hiddens": self.config.curiosity_feature_net_hiddens, + "feature_net_activation": self.config.curiosity_feature_net_activation, + "inverse_net_hiddens": self.config.curiosity_inverse_net_hiddens, + "inverse_net_activation": self.config.curiosity_inverse_net_activation, + "forward_net_hiddens": self.config.curiosity_forward_net_hiddens, + "forward_net_activation": self.config.curiosity_forward_net_activation, + }, + ) + self.add_module( + module_id=ICM_MODULE_ID, + module_spec=icm_spec, + ) + + # Prepend a "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # after the corresponding "add-OBS-..." default piece). + if self.config.add_default_connectors_to_learner_pipeline: + self._learner_connector.insert_after( + AddObservationsFromEpisodesToBatch, + AddNextObservationsFromEpisodesToTrainBatch(), + ) + + + def compute_loss( + self, + *, + fwd_out: Dict[str, Any], + batch: Dict[str, Any], + ) -> Dict[str, Any]: + # Compute the ICM loss first (so we'll have the chance to change the rewards + # in the batch for the "main" RLModule (before we compute its loss with the + # intrinsic rewards). + icm = self.module[ICM_MODULE_ID] + # Send the exact same batch to the ICM module that we used for the "main" + # RLModule's forward pass. + icm_fwd_out = icm.forward_train(batch=batch[DEFAULT_MODULE_ID]) + # Compute the loss of the ICM module. + icm_loss = icm.compute_loss_for_module( + learner=self, + module_id=ICM_MODULE_ID, + config=self.config.get_config_for_module(ICM_MODULE_ID), + batch=batch[DEFAULT_MODULE_ID], + fwd_out=icm_fwd_out, + ) + + # Add intrinsic rewards from ICM's `fwd_out` (multiplied by factor `eta`) + # to "main" module batch's extrinsic rewards. + batch[DEFAULT_MODULE_ID][Columns.REWARDS] += ( + self.config.curiosity_eta * icm_fwd_out[Columns.INTRINSIC_REWARDS] + ) + + # Compute the "main" RLModule's loss. + main_loss = self.compute_loss_for_module( + module_id=DEFAULT_MODULE_ID, + config=self.config.get_config_for_module(DEFAULT_MODULE_ID), + batch=batch[DEFAULT_MODULE_ID], + fwd_out=fwd_out[DEFAULT_MODULE_ID], + ) + + return { + DEFAULT_MODULE_ID: main_loss, + ICM_MODULE_ID: icm_loss, + ALL_MODULES: main_loss + icm_loss, + } diff --git a/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py b/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py index dda39aed0fe7..c632d09ba92a 100644 --- a/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py +++ b/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, TYPE_CHECKING import numpy as np @@ -6,9 +6,15 @@ from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.torch.torch_distributions import TorchCategorical +from ray.rllib.models.utils import get_activation_fn from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.torch_utils import one_hot +from ray.rllib.utils.typing import ModuleID + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.core.learner.torch.torch_learner import TorchLearner torch, nn = try_import_torch() @@ -79,70 +85,156 @@ class InverseDynamicsModel(TorchRLModule): @override(TorchRLModule) def setup(self): - # Assume a simple Box(1D) tensor as input shape. - in_size = self.config.observation_space.shape[0] - # Get the IDM achitecture settings from the our RLModuleConfig's (self.config) # `model_config_dict` property: cfg = self.config.model_config_dict - self._feature_dim = cfg.get("feature_dim", 288) - feature_net_config: Optional[ModelConfigDict] = None, - self._inverse_net_hiddens = cfg.get("inverse_net_hiddens", (256,)) - self._inverse_net_activation = cfg.get("inverse_net_activation", "relu") - self._forward_net_hiddens = cfg.get("forward_net_hiddens", (256,)) - self._forward_net_activation = cfg.get("forward_net_activation", "relu") + feature_dim = cfg.get("feature_dim", 288) - # Build the inverse model (predicting . + # Build the feature model (encoder of observations to feature space). layers = [] - # Get the dense layer pre-stack configuration from the same config dict. - dense_layers = self.config.model_config_dict.get("dense_layers", [128, 128]) + dense_layers = cfg.get("feature_net_hiddens", (256, 256)) + # `in_size` is the observation space (assume a simple Box(1D)). + in_size = self.config.observation_space.shape[0] for out_size in dense_layers: - # Dense layer. layers.append(nn.Linear(in_size, out_size)) - # ReLU activation. - layers.append(nn.ReLU()) + if cfg.get("feature_net_activation"): + layers.append( + get_activation_fn(cfg["feature_net_activation"], "torch")() + ) in_size = out_size + # Last feature layer of n nodes (feature dimension). + layers.append(nn.Linear(in_size, feature_dim)) + self._feature_net = nn.Sequential(*layers) - self._fc_net = nn.Sequential(*layers) + # Build the inverse model (predicting the action between two observations). + layers = [] + dense_layers = cfg.get("inverse_net_hiddens", (256,)) + # `in_size` is 2x the feature dim. + in_size = feature_dim * 2 + for out_size in dense_layers: + layers.append(nn.Linear(in_size, out_size)) + if cfg.get("inverse_net_activation"): + layers.append( + get_activation_fn(cfg["inverse_net_activation"], "torch")() + ) + in_size = out_size + # Last feature layer of n nodes (action space). + layers.append(nn.Linear(in_size, self.config.action_space.n)) + self._inverse_net = nn.Sequential(*layers) - # Logits layer (no bias, no activation). - self._logits = nn.Linear(in_size, self.config.action_space.n) - # Single-node value layer. - self._values = nn.Linear(in_size, 1) + # Build the forward model (predicting the next observation from current one and + # action). + layers = [] + dense_layers = cfg.get("forward_net_hiddens", (256,)) + # `in_size` is the feature dim + action space (one-hot). + in_size = feature_dim + self.config.action_space.n + for out_size in dense_layers: + layers.append(nn.Linear(in_size, out_size)) + if cfg.get("forward_net_activation"): + layers.append( + get_activation_fn(cfg["forward_net_activation"], "torch")() + ) + in_size = out_size + # Last feature layer of n nodes (feature dimension). + layers.append(nn.Linear(in_size, feature_dim)) + self._forward_net = nn.Sequential(*layers) @override(TorchRLModule) def _forward_train(self, batch, **kwargs): - # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - features, state_out, logits = self._compute_features_state_out_and_logits(batch) - # Besides the action logits, we also have to return value predictions here - # (to be used inside the loss function). - values = self._values(features).squeeze(-1) - return { - Columns.STATE_OUT: state_out, - Columns.ACTION_DIST_INPUTS: logits, - Columns.VF_PREDS: values, + # Push both observations through feature net to get feature vectors (phis). + # We cat/batch them here for efficiency reasons (save one forward pass). + phis = self._feature_net( + torch.cat( + [ + batch[Columns.OBS], + batch[Columns.NEXT_OBS], + ], + dim=0, + ) + ) + # Split again to yield 2 individual phi tensors. + phi, next_phi = torch.chunk(phis, 2) + + # Predict next feature vector (next_phi) with forward model (using obs and + # actions). + predicted_next_phi = self._forward_net( + torch.cat( + [ + phi, + one_hot(batch[Columns.ACTIONS].long(), self.config.action_space).float(), + ], + dim=-1, + ) + ) + + # Forward loss term: Predicted phi - given phi and action - vs actually observed + # phi (square-root of L2 norm). Note that this is the intrinsic reward that + # will be used and the mean of this is the forward net loss. + forward_l2_norm_sqrt = 0.5 * torch.sum( + torch.pow(predicted_next_phi - next_phi, 2.0), dim=-1 + ) + + output = { + Columns.INTRINSIC_REWARDS: forward_l2_norm_sqrt, + # Computed feature vectors (used to compute the losses later). + "phi": phi, + "next_phi": next_phi, } + return output + @override(TorchRLModule) def get_train_action_dist_cls(self): + ## TorchCategorical + ## if isinstance(self.action_space, Discrete) + ## else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec) + ##) return TorchCategorical - def _compute_features_state_out_and_logits(self, batch): - obs = batch[Columns.OBS] - state_in = batch[Columns.STATE_IN] - h, c = state_in["h"], state_in["c"] - # Unsqueeze the layer dim (we only have 1 LSTM layer. - features, (h, c) = self._lstm( - obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major - (h.unsqueeze(0), c.unsqueeze(0)), + @staticmethod + def compute_loss_for_module( + *, + learner: "TorchLearner", + module_id: ModuleID, + config: "AlgorithmConfig", + batch: Dict[str, Any], + fwd_out: Dict[str, Any], + ) -> Dict[str, Any]: + module = learner.module[module_id] + + # Forward net loss. + forward_loss = torch.mean(fwd_out[Columns.INTRINSIC_REWARDS]) + + # Inverse loss term (predicted action that led from phi to phi' vs + # actual action taken). + dist_inputs = module._inverse_net( + torch.cat([fwd_out["phi"], fwd_out["next_phi"]], dim=-1) ) - # Make batch-major again. - features = features.permute(1, 0, 2) - # Push through our FC net. - features = self._fc_net(features) - logits = self._logits(features) - return features, {"h": h.squeeze(0), "c": c.squeeze(0)}, logits + action_dist = module.get_train_action_dist_cls().from_logits(dist_inputs) + + # Neg log(p); p=probability of observed action given the inverse-NN + # predicted action distribution. + inverse_loss = -action_dist.logp(batch[Columns.ACTIONS]) + inverse_loss = torch.mean(inverse_loss) + + # Calculate the ICM loss. + total_loss = ( + (1.0 - config.curiosity_beta) * inverse_loss + + config.curiosity_beta * forward_loss + ) + + learner.metrics.log_dict( + { + "mean_intrinsic_rewards": forward_loss, + "forward_loss": forward_loss, + "inverse_loss": inverse_loss, + }, + key=module_id, + window=1, + ) + + return total_loss # Inference and exploration not supported (this is a world-model that should only # be used for training). From 82ff7030461b2881988f48a70fad01dba46ca7cc Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 29 Jul 2024 15:45:09 +0200 Subject: [PATCH 10/13] wip Signed-off-by: sven1977 --- rllib/BUILD | 10 ++++++++++ rllib/algorithms/dqn/dqn_rainbow_learner.py | 2 +- rllib/core/rl_module/torch/torch_rl_module.py | 7 ------- .../inverse_dynamics_model_based_curiosity.py | 12 +++++++++--- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 36ad7ab1902b..7c675265a1ec 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2513,6 +2513,16 @@ py_test( args = ["--enable-new-api-stack", "--as-test"] ) +py_test( + name = "examples/curiosity/inverse_dynamics_model_based_curiosity", + main = "examples/curiosity/inverse_dynamics_model_based_curiosity.py", + tags = ["team:rllib", "exclusive", "examples"], + size = "large", + srcs = ["examples/curiosity/inverse_dynamics_model_based_curiosity.py"], + args = ["--enable-new-api-stack", "--as-test"] +) + + # subdirectory: curriculum/ # .................................... py_test( diff --git a/rllib/algorithms/dqn/dqn_rainbow_learner.py b/rllib/algorithms/dqn/dqn_rainbow_learner.py index 562e707d1427..082cb587c405 100644 --- a/rllib/algorithms/dqn/dqn_rainbow_learner.py +++ b/rllib/algorithms/dqn/dqn_rainbow_learner.py @@ -55,7 +55,7 @@ def build(self) -> None: ) ) - # Prepend a "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right + # Prepend the "add-NEXT_OBS-from-episodes-to-train-batch" connector piece (right # after the corresponding "add-OBS-..." default piece). if self.config.add_default_connectors_to_learner_pipeline: self._learner_connector.insert_after( diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index 92d428d18d91..da0ff4a37f5a 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -163,13 +163,6 @@ def restore_from_path(self, *args, **kwargs): def get_metadata(self, *args, **kwargs): self.unwrapped().get_metadata(*args, **kwargs) - ## TODO (sven): Figure out a better way to avoid having to method-spam this wrapper - ## class, whenever we add a new API to any wrapped RLModule here. We could try - ## auto generating the wrapper methods, but this will bring its own challenge - ## (e.g. recursive calls due to __getattr__ checks, etc..). - #def _compute_values(self, *args, **kwargs): - # return self.unwrapped()._compute_values(*args, **kwargs) - @override(RLModule) def unwrapped(self) -> "RLModule": return self.module diff --git a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py index ddeefae8d3dd..fc955d353c4f 100644 --- a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py +++ b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py @@ -38,7 +38,13 @@ -class PrintMaxDistanceFrozenLakeCallback(DefaultCallbacks): +class MeasureMaxDistanceToStart(DefaultCallbacks): + """Callback measuring the dist of the agent to its start position in FrozenLake-v1. + + Makes the naive assumption that the start position ("S") is in the upper left + corner of the used map. + Uses the MetricsLogger to record the (euclidian) distance value. + """ def __init__(self): super().__init__() self.max_dists = defaultdict(float) @@ -110,6 +116,7 @@ def on_sample_end( .environment( "FrozenLake-v1", env_config={ + # Use a 12x12 map. "desc": [ "SFFFFFFFFFFF", "FFFFFFFFFFFF", @@ -135,7 +142,7 @@ def on_sample_end( #curiosity_feature_net_hiddens=[256, 256], #curiosity_inverse_net_activation="relu" ) - .callbacks(PrintMaxDistanceFrozenLakeCallback) + .callbacks(MeasureMaxDistanceToStart) .env_runners( num_envs_per_env_runner=5, env_to_module_connector=lambda env: FlattenObservations(), @@ -144,7 +151,6 @@ def on_sample_end( learner_class=PPOTorchLearnerWithCuriosity, train_batch_size_per_learner=2000, num_sgd_iter=6, - #vf_loss_coeff=0.01, lr=0.0003, ) .rl_module(model_config_dict={"vf_share_layers": True}) From 09a1c95993d62206cf3132c0065ccd35978f0674 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Mon, 29 Jul 2024 18:00:24 +0200 Subject: [PATCH 11/13] wip Signed-off-by: sven1977 --- rllib/connectors/connector_v2.py | 1 - rllib/env/single_agent_env_runner.py | 4 +- .../inverse_dynamics_model_based_curiosity.py | 82 ++++++++++++++--- .../classes/curiosity_ppo_torch_learner.py | 4 +- .../classes/inverse_dynamics_model_rlm.py | 92 ++++++++++--------- 5 files changed, 119 insertions(+), 64 deletions(-) diff --git a/rllib/connectors/connector_v2.py b/rllib/connectors/connector_v2.py index fb678be07c2d..ef566d5b9438 100644 --- a/rllib/connectors/connector_v2.py +++ b/rllib/connectors/connector_v2.py @@ -444,7 +444,6 @@ def add_n_batch_items( items_to_add: Any, num_items: int, single_agent_episode: Optional[SingleAgentEpisode] = None, - ) -> None: """Adds a list of items (or batched item) under `column` to the given `batch`. diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index 7dc8b33d0ce3..4c6f50cb7d43 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -207,8 +207,8 @@ def sample( # For complete episodes mode, sample a single episode and # leave coordination of sampling to `synchronous_parallel_sample`. # TODO (simon, sven): The coordination will eventually move - # to `EnvRunnerGroup` in the future. So from the algorithm one - # would do `EnvRunnerGroup.sample()`. + # to `EnvRunnerGroup` in the future. So from the algorithm one + # would do `EnvRunnerGroup.sample()`. else: samples = self._sample_episodes( num_episodes=1, diff --git a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py index fc955d353c4f..496c0b49c7bb 100644 --- a/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py +++ b/rllib/examples/curiosity/inverse_dynamics_model_based_curiosity.py @@ -1,28 +1,84 @@ -"""Implementation of: -[1] Curiosity-driven Exploration by Self-supervised Prediction -Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. -https://arxiv.org/pdf/1705.05363.pdf +"""Example of implementing and running inverse dynamics model (ICM) based curiosity. -Learns a simplified model of the environment based on three networks: +This type of curiosity-based learning trains a simplified model of the environment +dynamics based on three networks: 1) Embedding observations into latent space ("feature" network). 2) Predicting the action, given two consecutive embedded observations ("inverse" network). 3) Predicting the next embedded obs, given an obs and action ("forward" network). -The less the agent is able to predict the actually observed next feature -vector, given obs and action (through the forwards network), the larger the -"intrinsic reward", which will be added to the extrinsic reward. +The less the ICM is able to predict the actually observed next feature vector, +given obs and action (through the forwards network), the larger the +"intrinsic reward", which will be added to the extrinsic reward of the agent. + Therefore, if a state transition was unexpected, the agent becomes "curious" and will further explore this transition leading to better exploration in sparse rewards environments. + +For more details, see here: + +[1] Curiosity-driven Exploration by Self-supervised Prediction +Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017. +https://arxiv.org/pdf/1705.05363.pdf + +This example: + - demonstrates how to write a custom RLModule, representing the ICM from the paper + above. Note that this custom RLModule does not belong to any individual agent. + - demonstrates how to write a custom (PPO) TorchLearner that a) adds the ICM to its + MultiRLModule, b) trains the regular PPO Policy plus the ICM module, using the + PPO parent loss and the ICM's RLModule's own loss function. + +We use a FrozenLake (sparse reward) environment with a custom map size of 12x12 and a +hard time step limit of 22 to make it almost impossible for a non-curiosity based +learners to learn a good policy. + + +How to run this script +---------------------- +`python [script file name].py --enable-new-api-stack` + +Use the `--no-curiosity` flag to disable curiosity learning and force your policy +to be trained on the task w/o the use of intrinsic rewards. With this option, the +algorithm should NOT succeed. + +For debugging, use the following additional command line options +`--no-tune --num-env-runners=0` +which should allow you to set breakpoints anywhere in the RLlib code and +have the execution stop there for inspection and debugging. + +For logging to your WandB account, use: +`--wandb-key=[your WandB API key] --wandb-project=[some project name] +--wandb-run-name=[optional: WandB run name (within the defined project)]` + + +Results to expect +----------------- +In the console output, you can see that only a PPO policy that uses curiosity can +actually learn. + +Policy using ICM-based curiosity: ++-------------------------------+------------+-----------------+--------+ +| Trial name | status | loc | iter | +|-------------------------------+------------+-----------------+--------+ +| PPO_FrozenLake-v1_52ab2_00000 | TERMINATED | 127.0.0.1:73318 | 392 | ++-------------------------------+------------+-----------------+--------+ ++------------------+--------+----------+--------------------+ +| total time (s) | ts | reward | episode_len_mean | +|------------------+--------+----------+--------------------| +| 236.652 | 786000 | 1.0 | 22.0 | ++------------------+--------+----------+--------------------+ + +Policy NOT using curiosity: +[DOES NOT LEARN AT ALL] """ from collections import defaultdict from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.connectors.env_to_module import FlattenObservations from ray.rllib.examples.learners.classes.curiosity_ppo_torch_learner import ( - PPOConfigWithCuriosity, PPOTorchLearnerWithCuriosity + PPOConfigWithCuriosity, + PPOTorchLearnerWithCuriosity, ) from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, @@ -37,7 +93,6 @@ parser.set_defaults(enable_new_api_stack=True) - class MeasureMaxDistanceToStart(DefaultCallbacks): """Callback measuring the dist of the agent to its start position in FrozenLake-v1. @@ -45,6 +100,7 @@ class MeasureMaxDistanceToStart(DefaultCallbacks): corner of the used map. Uses the MetricsLogger to record the (euclidian) distance value. """ + def __init__(self): super().__init__() self.max_dists = defaultdict(float) @@ -66,7 +122,7 @@ def on_episode_step( num_cols = env.envs[0].unwrapped.ncol row = obs // num_cols col = obs % num_rows - curr_dist = (row ** 2 + col ** 2) ** 0.5 + curr_dist = (row**2 + col**2) ** 0.5 if curr_dist > self.max_dists[episode.id_]: self.max_dists[episode.id_] = curr_dist @@ -139,8 +195,8 @@ def on_sample_end( ) # Use our custom `curiosity` method to set up the ICM and our PPO/ICM-Learner. .curiosity( - #curiosity_feature_net_hiddens=[256, 256], - #curiosity_inverse_net_activation="relu" + # curiosity_feature_net_hiddens=[256, 256], + # curiosity_inverse_net_activation="relu", ) .callbacks(MeasureMaxDistanceToStart) .env_runners( diff --git a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py index 29c3dad70813..4a0e30b36aa1 100644 --- a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py +++ b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py @@ -12,7 +12,7 @@ from ray.rllib.core import Columns, DEFAULT_MODULE_ID from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec from ray.rllib.examples.rl_modules.classes.inverse_dynamics_model_rlm import ( - InverseDynamicsModel + InverseDynamicsModel, ) from ray.rllib.utils.metrics import ALL_MODULES @@ -65,7 +65,6 @@ def curiosity( class PPOTorchLearnerWithCuriosity(PPOTorchLearner): - def build(self): super().build() @@ -99,7 +98,6 @@ def build(self): AddNextObservationsFromEpisodesToTrainBatch(), ) - def compute_loss( self, *, diff --git a/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py b/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py index c632d09ba92a..05f0a9a5e802 100644 --- a/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py +++ b/rllib/examples/rl_modules/classes/inverse_dynamics_model_rlm.py @@ -1,9 +1,6 @@ from typing import Any, Dict, TYPE_CHECKING -import numpy as np - from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.torch.torch_distributions import TorchCategorical from ray.rllib.models.utils import get_activation_fn @@ -43,44 +40,52 @@ class InverseDynamicsModel(TorchRLModule): .. testcode:: - import numpy as np - import gymnasium as gym - from ray.rllib.core.rl_module.rl_module import RLModuleConfig + import numpy as np + import gymnasium as gym + import torch - B = 10 # batch size - T = 5 # seq len - f = 25 # feature dim + from ray.rllib.core import Columns + from ray.rllib.core.rl_module.rl_module import RLModuleConfig + from ray.rllib.examples.rl_modules.classes.inverse_dynamics_model_rlm import ( # noqa + InverseDynamicsModel + ) - # Construct the RLModule. - rl_module_config = RLModuleConfig( - observation_space=gym.spaces.Box(-1.0, 1.0, (f,), np.float32), - action_space=gym.spaces.Discrete(4), - model_config_dict={"lstm_cell_size": CELL} - ) - my_net = LSTMContainingRLModule(rl_module_config) - - # Create some dummy input. - obs = torch.from_numpy( - np.random.random_sample(size=(B, T, f) - ).astype(np.float32)) - state_in = my_net.get_initial_state() - # Repeat state_in across batch. - state_in = tree.map_structure( - lambda s: torch.from_numpy(s).unsqueeze(0).repeat(B, 1), state_in - ) - input_dict = { - Columns.OBS: obs, - Columns.STATE_IN: state_in, - } + B = 10 # batch size + O = 4 # obs (1D) dim + A = 2 # num actions + f = 25 # feature dim - # Run through all 3 forward passes. - print(my_net.forward_inference(input_dict)) - print(my_net.forward_exploration(input_dict)) - print(my_net.forward_train(input_dict)) + # Construct the RLModule. + rl_module_config = RLModuleConfig( + observation_space=gym.spaces.Box(-1.0, 1.0, (O,), np.float32), + action_space=gym.spaces.Discrete(A), + ) + icm_net = InverseDynamicsModel(rl_module_config) - # Print out the number of parameters. - num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters()) - print(f"num params = {num_all_params}") + # Create some dummy input. + obs = torch.from_numpy( + np.random.random_sample(size=(B, O)).astype(np.float32) + ) + next_obs = torch.from_numpy( + np.random.random_sample(size=(B, O)).astype(np.float32) + ) + actions = torch.from_numpy( + np.random.random_integers(0, A - 1, size=(B,)) + ) + input_dict = { + Columns.OBS: obs, + Columns.NEXT_OBS: next_obs, + Columns.ACTIONS: actions, + } + + # Call `forward_train()` to get phi (feature vector from obs), next-phi + # (feature vector from next obs), and the intrinsic rewards (individual, per + # batch-item forward loss values). + print(icm_net.forward_train(input_dict)) + + # Print out the number of parameters. + num_all_params = sum(int(np.prod(p.size())) for p in icm_net.parameters()) + print(f"num params = {num_all_params}") """ @override(TorchRLModule) @@ -162,7 +167,9 @@ def _forward_train(self, batch, **kwargs): torch.cat( [ phi, - one_hot(batch[Columns.ACTIONS].long(), self.config.action_space).float(), + one_hot( + batch[Columns.ACTIONS].long(), self.config.action_space + ).float(), ], dim=-1, ) @@ -186,10 +193,6 @@ def _forward_train(self, batch, **kwargs): @override(TorchRLModule) def get_train_action_dist_cls(self): - ## TorchCategorical - ## if isinstance(self.action_space, Discrete) - ## else TorchMultiCategorical(dist_inputs, self.model, self.action_space.nvec) - ##) return TorchCategorical @staticmethod @@ -220,9 +223,8 @@ def compute_loss_for_module( # Calculate the ICM loss. total_loss = ( - (1.0 - config.curiosity_beta) * inverse_loss - + config.curiosity_beta * forward_loss - ) + 1.0 - config.curiosity_beta + ) * inverse_loss + config.curiosity_beta * forward_loss learner.metrics.log_dict( { From 6333405244f3e553b5038f864f13228085c48747 Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 30 Jul 2024 09:18:25 +0200 Subject: [PATCH 12/13] wip Signed-off-by: sven1977 --- rllib/BUILD | 4 ++-- rllib/core/models/tests/test_catalog.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 6251a4f09126..ce00d6a35eac 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -2499,7 +2499,7 @@ py_test( name = "examples/curiosity/count_based_curiosity", main = "examples/curiosity/count_based_curiosity.py", tags = ["team:rllib", "exclusive", "examples"], - size = "medium", + size = "large", srcs = ["examples/curiosity/count_based_curiosity.py"], args = ["--enable-new-api-stack", "--as-test"] ) @@ -2508,7 +2508,7 @@ py_test( name = "examples/curiosity/euclidian_distance_based_curiosity", main = "examples/curiosity/euclidian_distance_based_curiosity.py", tags = ["team:rllib", "exclusive", "examples"], - size = "medium", + size = "large", srcs = ["examples/curiosity/euclidian_distance_based_curiosity.py"], args = ["--enable-new-api-stack", "--as-test"] ) diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py index 17790278a0f6..86d561a3f752 100644 --- a/rllib/core/models/tests/test_catalog.py +++ b/rllib/core/models/tests/test_catalog.py @@ -465,7 +465,7 @@ def _determine_components(self): module = spec.build() module.forward_inference( - input_data={"obs": torch.ones((32, *env.observation_space.shape))} + batch={"obs": torch.ones((32, *env.observation_space.shape))} ) From 5afe5cdf7a4b9c9fcfd58e909233156616a08fef Mon Sep 17 00:00:00 2001 From: sven1977 Date: Tue, 30 Jul 2024 11:00:26 +0200 Subject: [PATCH 13/13] fix Signed-off-by: sven1977 --- .../examples/learners/classes/curiosity_ppo_torch_learner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py index 4a0e30b36aa1..f5a1390d8cea 100644 --- a/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py +++ b/rllib/examples/learners/classes/curiosity_ppo_torch_learner.py @@ -10,7 +10,7 @@ AddNextObservationsFromEpisodesToTrainBatch, ) from ray.rllib.core import Columns, DEFAULT_MODULE_ID -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.examples.rl_modules.classes.inverse_dynamics_model_rlm import ( InverseDynamicsModel, ) @@ -72,7 +72,7 @@ def build(self): assert len(self.module) == 1 and DEFAULT_MODULE_ID in self.module # Add an InverseDynamicsModel to our MARLModule. - icm_spec = SingleAgentRLModuleSpec( + icm_spec = RLModuleSpec( module_class=InverseDynamicsModel, observation_space=self.module[DEFAULT_MODULE_ID].config.observation_space, action_space=self.module[DEFAULT_MODULE_ID].config.action_space,