From b73d6d4772de77eb48ce6a563554c9b619caa45c Mon Sep 17 00:00:00 2001
From: Sven Mika
Date: Thu, 27 Jun 2024 16:34:07 +0200
Subject: [PATCH] [RLlib] Cleanup examples folder 17: Add example for custom RLModule with an LSTM. (#46276)

---
 rllib/BUILD                                        |  25 ++-
 .../rl_modules/classes/lstm_containing_rlm.py      | 192 ++++++++++++++++++
 ...iny_atari_cnn.py => tiny_atari_cnn_rlm.py}      |  56 +++--
 ...m_rl_module.py => custom_cnn_rl_module.py}      |  12 +-
 .../rl_modules/custom_lstm_rl_module.py            | 104 ++++++++++
 rllib/tuned_examples/impala/pong_impala.py         |   4 +-
 .../impala/pong_impala_pb2_hyperopt.py             |   4 +-
 rllib/utils/typing.py                              |   2 +-
 8 files changed, 350 insertions(+), 49 deletions(-)
 create mode 100644 rllib/examples/rl_modules/classes/lstm_containing_rlm.py
 rename rllib/examples/rl_modules/classes/{tiny_atari_cnn.py => tiny_atari_cnn_rlm.py} (97%)
 rename rllib/examples/rl_modules/{custom_rl_module.py => custom_cnn_rl_module.py} (92%)
 create mode 100644 rllib/examples/rl_modules/custom_lstm_rl_module.py

diff --git a/rllib/BUILD b/rllib/BUILD
index df43ccc7e9fa..048963b0911e 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -3104,15 +3104,6 @@ py_test(
 
 # subdirectory: rl_modules/
 # ....................................
-py_test(
-    name = "examples/rl_modules/custom_rl_module",
-    main = "examples/rl_modules/custom_rl_module.py",
-    tags = ["team:rllib", "examples"],
-    size = "medium",
-    srcs = ["examples/rl_modules/custom_rl_module.py"],
-    args = ["--enable-new-api-stack", "--stop-iters=3"],
-)
-
 py_test(
     name = "examples/rl_modules/action_masking_rlm",
     main = "examples/rl_modules/action_masking_rlm.py",
@@ -3130,6 +3121,22 @@ py_test(
     srcs = ["examples/rl_modules/autoregressive_actions_rlm.py"],
     args = ["--enable-new-api-stack"],
 )
+py_test(
+    name = "examples/rl_modules/custom_cnn_rl_module",
+    main = "examples/rl_modules/custom_cnn_rl_module.py",
+    tags = ["team:rllib", "examples"],
+    size = "medium",
+    srcs = ["examples/rl_modules/custom_cnn_rl_module.py"],
+    args = ["--enable-new-api-stack", "--stop-iters=3"],
+)
+py_test(
+    name = "examples/rl_modules/custom_lstm_rl_module",
+    main = "examples/rl_modules/custom_lstm_rl_module.py",
+    tags = ["team:rllib", "examples"],
+    size = "medium",
+    srcs = ["examples/rl_modules/custom_lstm_rl_module.py"],
+    args = ["--as-test", "--enable-new-api-stack"],
+)
 
 #@OldAPIStack @HybridAPIStack
 py_test(
diff --git a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py
new file mode 100644
index 000000000000..58946df56b1f
--- /dev/null
+++ b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py
@@ -0,0 +1,192 @@
+from typing import Any
+
+import numpy as np
+
+from ray.rllib.core.columns import Columns
+from ray.rllib.core.rl_module.torch import TorchRLModule
+from ray.rllib.models.torch.torch_distributions import TorchCategorical
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.framework import try_import_torch
+
+torch, nn = try_import_torch()
+
+
+class LSTMContainingRLModule(TorchRLModule):
+    """An example TorchRLModule that contains an LSTM layer.
+
+    .. testcode::
+
+        import numpy as np
+        import gymnasium as gym
+        import torch
+        import tree  # pip install dm_tree
+        from ray.rllib.core.columns import Columns
+        from ray.rllib.core.rl_module.rl_module import RLModuleConfig
+
+        B = 10  # batch size
+        T = 5  # seq len
+        f = 25  # feature dim
+        CELL = 32  # LSTM cell size
+
+        # Construct the RLModule.
+        rl_module_config = RLModuleConfig(
+            observation_space=gym.spaces.Box(-1.0, 1.0, (f,), np.float32),
+            action_space=gym.spaces.Discrete(4),
+            model_config_dict={"lstm_cell_size": CELL},
+        )
+        my_net = LSTMContainingRLModule(rl_module_config)
+
+        # Create some dummy input.
+        obs = torch.from_numpy(
+            np.random.random_sample(size=(B, T, f)).astype(np.float32)
+        )
+        state_in = my_net.get_initial_state()
+        # Repeat state_in across batch.
+        state_in = tree.map_structure(
+            lambda s: torch.from_numpy(s).unsqueeze(0).repeat(B, 1), state_in
+        )
+        input_dict = {
+            Columns.OBS: obs,
+            Columns.STATE_IN: state_in,
+        }
+
+        # Run through all 3 forward passes.
+        print(my_net.forward_inference(input_dict))
+        print(my_net.forward_exploration(input_dict))
+        print(my_net.forward_train(input_dict))
+
+        # Print out the number of parameters.
+        num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
+        print(f"num params = {num_all_params}")
+    """
+
+    @override(TorchRLModule)
+    def setup(self):
+        """Use this method to create all the model components that you require.
+
+        Feel free to access the following useful properties in this class:
+        - `self.config.model_config_dict`: The config dict for this RLModule class,
+        which should contain flexible settings, for example: {"hiddens": [256, 256]}.
+        - `self.config.observation|action_space`: The observation and action space that
+        this RLModule is subject to. Note that the observation space might not be the
+        exact space from your env, but that it might have already gone through
+        preprocessing through a connector pipeline (for example, flattening,
+        frame-stacking, mean/std-filtering, etc.).
+        """
+        # Assume a simple Box(1D) tensor as input shape.
+        in_size = self.config.observation_space.shape[0]
+
+        # Get the LSTM cell size from our RLModuleConfig's (self.config)
+        # `model_config_dict` property:
+        self._lstm_cell_size = self.config.model_config_dict.get("lstm_cell_size", 256)
+        self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=False)
+        in_size = self._lstm_cell_size
+
+        # Build a sequential stack.
+        layers = []
+        # Get the dense layer pre-stack configuration from the same config dict.
+        dense_layers = self.config.model_config_dict.get("dense_layers", [128, 128])
+        for out_size in dense_layers:
+            # Dense layer.
+            layers.append(nn.Linear(in_size, out_size))
+            # ReLU activation.
+            layers.append(nn.ReLU())
+            in_size = out_size
+
+        self._fc_net = nn.Sequential(*layers)
+
+        # Logits layer (no activation).
+        self._logits = nn.Linear(in_size, self.config.action_space.n)
+        # Single-node value layer.
+        self._values = nn.Linear(in_size, 1)
+
+    @override(TorchRLModule)
+    def get_initial_state(self) -> Any:
+        return {
+            "h": np.zeros(shape=(self._lstm_cell_size,), dtype=np.float32),
+            "c": np.zeros(shape=(self._lstm_cell_size,), dtype=np.float32),
+        }
+
+    @override(TorchRLModule)
+    def _forward_inference(self, batch, **kwargs):
+        # Compute the basic 1D feature tensor (inputs to policy- and value-heads).
+        _, state_out, logits = self._compute_features_state_out_and_logits(batch)
+
+        # Return logits as ACTION_DIST_INPUTS (categorical distribution).
+        # Note that the default `GetActions` connector piece (in the EnvRunner) will
+        # take care of argmax-"sampling" from the logits to yield the inference
+        # (greedy) action.
+        return {
+            Columns.STATE_OUT: state_out,
+            Columns.ACTION_DIST_INPUTS: logits,
+        }
+
+    @override(TorchRLModule)
+    def _forward_exploration(self, batch, **kwargs):
+        # Exact same as `_forward_inference`.
+        # Note that the default `GetActions` connector piece (in the EnvRunner) will
+        # take care of stochastic sampling from the Categorical defined by the logits
+        # to yield the exploration action.
+        return self._forward_inference(batch, **kwargs)
+
+    @override(TorchRLModule)
+    def _forward_train(self, batch, **kwargs):
+        # Compute the basic 1D feature tensor (inputs to policy- and value-heads).
+        features, state_out, logits = self._compute_features_state_out_and_logits(batch)
+        # Besides the action logits, we also have to return value predictions here
+        # (to be used inside the loss function).
+        values = self._values(features).squeeze(-1)
+        return {
+            Columns.STATE_OUT: state_out,
+            Columns.ACTION_DIST_INPUTS: logits,
+            Columns.VF_PREDS: values,
+        }
+
+    # TODO (sven): We still need to define the distribution to use here, even though
+    #  we have a pretty standard action space (Discrete), which should simply always
+    #  map to a categorical dist. by default.
+    @override(TorchRLModule)
+    def get_inference_action_dist_cls(self):
+        return TorchCategorical
+
+    @override(TorchRLModule)
+    def get_exploration_action_dist_cls(self):
+        return TorchCategorical
+
+    @override(TorchRLModule)
+    def get_train_action_dist_cls(self):
+        return TorchCategorical
+
+    # TODO (sven): In order for this RLModule to work with PPO, we must define
+    #  our own `_compute_values()` method. This would become more obvious if we simply
+    #  subclassed the `PPOTorchRLModule` directly here (which we didn't do for
+    #  simplicity and to keep some generality). We might even get rid of algo-
+    #  specific RLModule subclasses altogether in the future and replace them
+    #  by mere algo-specific APIs (w/o any actual implementations).
+    def _compute_values(self, batch):
+        obs = batch[Columns.OBS]
+        state_in = batch[Columns.STATE_IN]
+        h, c = state_in["h"], state_in["c"]
+        # Unsqueeze the layer dim (we only have 1 LSTM layer).
+        features, _ = self._lstm(
+            obs.permute(1, 0, 2),  # we have to permute, b/c our LSTM is time-major
+            (h.unsqueeze(0), c.unsqueeze(0)),
+        )
+        # Make batch-major again.
+        features = features.permute(1, 0, 2)
+        # Push through our FC net.
+        features = self._fc_net(features)
+        return self._values(features).squeeze(-1)
+
+    def _compute_features_state_out_and_logits(self, batch):
+        obs = batch[Columns.OBS]
+        state_in = batch[Columns.STATE_IN]
+        h, c = state_in["h"], state_in["c"]
+        # Unsqueeze the layer dim (we only have 1 LSTM layer).
+        features, (h, c) = self._lstm(
+            obs.permute(1, 0, 2),  # we have to permute, b/c our LSTM is time-major
+            (h.unsqueeze(0), c.unsqueeze(0)),
+        )
+        # Make batch-major again.
+        features = features.permute(1, 0, 2)
+        # Push through our FC net.
+        features = self._fc_net(features)
+        logits = self._logits(features)
+        return features, {"h": h.squeeze(0), "c": c.squeeze(0)}, logits
diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py
similarity index 97%
rename from rllib/examples/rl_modules/classes/tiny_atari_cnn.py
rename to rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py
index e19d175d28e9..9864a8cec17c 100644
--- a/rllib/examples/rl_modules/classes/tiny_atari_cnn.py
+++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py
@@ -9,7 +9,6 @@
 from ray.rllib.models.torch.torch_distributions import TorchCategorical
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 
 torch, nn = try_import_torch()
 
@@ -24,6 +23,31 @@ class TinyAtariCNN(TorchRLModule):
     and n 1x1 filters, where n is the number of actions in the (discrete) action
     space. Simple reshaping (no flattening or extra linear layers necessary) lead to
     the action logits, which can directly be used inside a distribution or loss.
+
+    import numpy as np
+    import gymnasium as gym
+    from ray.rllib.core.rl_module.rl_module import RLModuleConfig
+
+    rl_module_config = RLModuleConfig(
+        observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
+        action_space=gym.spaces.Discrete(4),
+    )
+    my_net = TinyAtariCNN(rl_module_config)
+
+    B = 10
+    w = 42
+    h = 42
+    c = 4
+    data = torch.from_numpy(
+        np.random.random_sample(size=(B, w, h, c)).astype(np.float32)
+    )
+    print(my_net.forward_inference({"obs": data}))
+    print(my_net.forward_exploration({"obs": data}))
+    print(my_net.forward_train({"obs": data}))
+
+    num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
+    print(f"num params = {num_all_params}")
+
     """
 
     @override(TorchRLModule)
@@ -124,8 +148,8 @@ def _forward_train(self, batch, **kwargs):
     # simplicity and to keep some generality). We might even get rid of algo-
     # specific RLModule subclasses altogether in the future and replace them
     # by mere algo-specific APIs (w/o any actual implementations).
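+    # Note: `batch` is expected to already contain (correctly placed) torch tensors
+    # here; no numpy-to-torch conversion or device transfer happens in this method.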
-    def _compute_values(self, batch, device):
-        obs = convert_to_torch_tensor(batch[Columns.OBS], device=device)
+    def _compute_values(self, batch):
+        obs = batch[Columns.OBS]
         features = self._base_cnn_stack(obs.permute(0, 3, 1, 2))
         features = torch.squeeze(features, dim=[-1, -2])
         return self._values(features).squeeze(-1)
@@ -156,29 +180,3 @@ def get_exploration_action_dist_cls(self):
 
     @override(RLModule)
     def get_inference_action_dist_cls(self):
         return TorchCategorical
-
-
-if __name__ == "__main__":
-    import numpy as np
-    import gymnasium as gym
-    from ray.rllib.core.rl_module.rl_module import RLModuleConfig
-
-    rl_module_config = RLModuleConfig(
-        observation_space=gym.spaces.Box(-1.0, 1.0, (42, 42, 4), np.float32),
-        action_space=gym.spaces.Discrete(4),
-    )
-    my_net = TinyAtariCNN(rl_module_config)
-
-    B = 10
-    w = 42
-    h = 42
-    c = 4
-    data = torch.from_numpy(
-        np.random.random_sample(size=(B, w, h, c)).astype(np.float32)
-    )
-    print(my_net.forward_inference({"obs": data}))
-    print(my_net.forward_exploration({"obs": data}))
-    print(my_net.forward_train({"obs": data}))
-
-    num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters())
-    print(f"num params = {num_all_params}")
diff --git a/rllib/examples/rl_modules/custom_rl_module.py b/rllib/examples/rl_modules/custom_cnn_rl_module.py
similarity index 92%
rename from rllib/examples/rl_modules/custom_rl_module.py
rename to rllib/examples/rl_modules/custom_cnn_rl_module.py
index a75d59960044..a171f83467d2 100644
--- a/rllib/examples/rl_modules/custom_rl_module.py
+++ b/rllib/examples/rl_modules/custom_cnn_rl_module.py
@@ -1,9 +1,9 @@
-"""Example of implementing and configuring a custom (torch) RLModule.
+"""Example of implementing and configuring a custom (torch) CNN containing RLModule.
 
 This example:
-    - demonstrates how you can subclass the TorchRLModule base class and setup your
-    own neural network architecture by overriding `setup()`.
-    - how to override the 3 forward methods: `_forward_inference()`,
+    - demonstrates how you can subclass the TorchRLModule base class and set up your
+    own CNN-stack architecture by overriding the `setup()` method.
+    - shows how to override the 3 forward methods: `_forward_inference()`,
     `_forward_exploration()`, and `forward_train()` to implement your own custom forward
     logic(s). You will also learn, when each of these 3 methods is called by RLlib or
     the users of your RLModule.
@@ -56,7 +56,7 @@
 from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
-from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn import TinyAtariCNN
+from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
 from ray.rllib.utils.test_utils import (
     add_rllib_example_script_args,
     run_rllib_example_script_experiment,
@@ -114,4 +114,4 @@
         )
     )
 
-    run_rllib_example_script_experiment(base_config, args, stop={})
+    run_rllib_example_script_experiment(base_config, args)
diff --git a/rllib/examples/rl_modules/custom_lstm_rl_module.py b/rllib/examples/rl_modules/custom_lstm_rl_module.py
new file mode 100644
index 000000000000..5612df47104d
--- /dev/null
+++ b/rllib/examples/rl_modules/custom_lstm_rl_module.py
@@ -0,0 +1,104 @@
+"""Example of implementing and configuring a custom (torch) LSTM containing RLModule.
+
+This example:
+    - demonstrates how you can subclass the TorchRLModule base class and set up your
+    own LSTM-containing NN architecture by overriding the `setup()` method.
+    - shows how to override the 3 forward methods: `_forward_inference()`,
+    `_forward_exploration()`, and `_forward_train()` to implement your own custom
+    forward logic(s), including how to handle STATE in- and outputs to and from
+    these calls.
+    - explains when each of these 3 methods is called by RLlib or the users of your
+    RLModule.
+    - shows how you then configure an RLlib Algorithm such that it uses your custom
+    RLModule (instead of a default RLModule).
+
+We implement a simple LSTM layer here, followed by a series of Linear layers.
+After the last Linear layer, we add a fork of 2 Linear (non-activated) layers, one
+for the action logits and one for the value function output.
+
+We test the LSTM-containing RLModule on the StatelessCartPole environment, a variant
+of CartPole that is non-Markovian (partially observable). Only an RNN-based network
+can learn a decent policy in this environment due to the lack of any velocity
+information: From a single observation alone, one cannot tell whether the cart is
+currently moving left or right, nor whether the pole is moving up or down.
+
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack`
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+
+Results to expect
+-----------------
+In the console output, you should see the episode return increase over time,
+eventually reaching the default stopping reward of 300.0 (configured further below
+in this script) on the StatelessCartPole environment.
+
+"""
+from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
+from ray.rllib.examples.envs.classes.multi_agent import MultiAgentStatelessCartPole
+from ray.rllib.examples.rl_modules.classes.lstm_containing_rlm import (
+    LSTMContainingRLModule,
+)
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    run_rllib_example_script_experiment,
+)
+from ray.tune.registry import get_trainable_cls, register_env
+
+parser = add_rllib_example_script_args(
+    default_reward=300.0,
+    default_timesteps=2000000,
+)
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    assert (
+        args.enable_new_api_stack
+    ), "Must set --enable-new-api-stack when running this script!"
+
+    if args.num_agents == 0:
+        register_env("env", lambda cfg: StatelessCartPole())
+    else:
+        register_env("env", lambda cfg: MultiAgentStatelessCartPole(cfg))
+
+    base_config = (
+        get_trainable_cls(args.algo)
+        .get_default_config()
+        .environment(
+            env="env",
+            env_config={"num_agents": args.num_agents},
+        )
+        .training(
+            train_batch_size_per_learner=1024,
+            num_sgd_iter=6,
+            lr=0.0009,
+            vf_loss_coeff=0.001,
+            entropy_coeff=0.0,
+        )
+        .rl_module(
+            # Plug-in our custom RLModule class.
+            rl_module_spec=SingleAgentRLModuleSpec(
+                module_class=LSTMContainingRLModule,
+                # Feel free to specify your own `model_config_dict` settings below.
+                # The `model_config_dict` defined here will be available inside your
+                # custom RLModule class through the `self.config.model_config_dict`
+                # property.
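+                # For example, `LSTMContainingRLModule.setup()` reads
+                # `self.config.model_config_dict.get("lstm_cell_size", 256)` and
+                # `self.config.model_config_dict.get("dense_layers", [128, 128])`,
+                # so the two keys set here simply override those defaults.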
+                model_config_dict={
+                    "lstm_cell_size": 256,
+                    "dense_layers": [256, 256],
+                },
+            ),
+        )
+    )
+
+    run_rllib_example_script_experiment(base_config, args)
diff --git a/rllib/tuned_examples/impala/pong_impala.py b/rllib/tuned_examples/impala/pong_impala.py
index 1e0b4ea651b6..d1a5dd5fe076 100644
--- a/rllib/tuned_examples/impala/pong_impala.py
+++ b/rllib/tuned_examples/impala/pong_impala.py
@@ -3,7 +3,7 @@
 from ray.rllib.algorithms.impala import ImpalaConfig
 from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
-from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn import TinyAtariCNN
+from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
@@ -21,7 +21,7 @@
     "3 CNN layers ([32, 4, 2, same], [64, 4, 2, same], [256, 11, 1, valid]) for the "
     "base features and then a CNN pi-head with an output of [num-actions, 1, 1] and "
     "a Linear(1) layer for the values. The actual RLModule class used can be found "
-    "here: ray.rllib.examples.rl_modules.classes.tiny_atari_cnn",
+    "here: ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm",
 )
 args = parser.parse_args()
 
diff --git a/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py b/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py
index 9d3cc86fb122..32dd58612270 100644
--- a/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py
+++ b/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py
@@ -3,7 +3,7 @@
 from ray.rllib.algorithms.impala import ImpalaConfig
 from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
-from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn import TinyAtariCNN
+from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
 from ray.rllib.utils.metrics import (
     ENV_RUNNER_RESULTS,
     EPISODE_RETURN_MEAN,
@@ -23,7 +23,7 @@
     "3 CNN layers ([32, 4, 2, same], [64, 4, 2, same], [256, 11, 1, valid]) for the "
     "base features and then a CNN pi-head with an output of [num-actions, 1, 1] and "
     "a Linear(1) layer for the values. The actual RLModule class used can be found "
-    "here: ray.rllib.examples.rl_modules.classes.tiny_atari_cnn",
+    "here: ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm",
 )
 args = parser.parse_args()
 
diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py
index 1df194aff04b..6869a75d49f4 100644
--- a/rllib/utils/typing.py
+++ b/rllib/utils/typing.py
@@ -194,7 +194,7 @@
 ModelInputDict = Dict[str, TensorType]
 
 # Some kind of sample batch.
-SampleBatchType = Union["SampleBatch", "MultiAgentBatch"]
+SampleBatchType = Union["SampleBatch", "MultiAgentBatch", Dict[str, Any]]
 
 # A (possibly nested) space struct: Either a gym.spaces.Space or a
 # (possibly nested) dict|tuple of gym.space.Spaces.
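
A minimal, standalone PyTorch sketch (illustrative only, not part of the files changed
above) of the shape handling that `LSTMContainingRLModule` performs: batch-major inputs
are permuted to time-major for the `nn.LSTM`, and the single-layer h/c states gain and
lose their leading layer dimension. All variable names below are made up for the sketch.

    import torch
    from torch import nn

    B, T, f, CELL = 10, 5, 25, 32  # batch size, seq len, feature dim, LSTM cell size
    lstm = nn.LSTM(f, CELL, batch_first=False)  # time-major, as in the RLModule above

    obs = torch.randn(B, T, f)   # batch-major observations (Columns.OBS)
    h_in = torch.zeros(B, CELL)  # per-batch-item "h" state (Columns.STATE_IN)
    c_in = torch.zeros(B, CELL)  # per-batch-item "c" state (Columns.STATE_IN)

    features, (h_out, c_out) = lstm(
        # Permute obs to time-major (T, B, f), b/c the LSTM is not batch_first.
        obs.permute(1, 0, 2),
        # Unsqueeze the (single) layer dim -> (1, B, CELL) each.
        (h_in.unsqueeze(0), c_in.unsqueeze(0)),
    )
    # Back to batch-major (B, T, CELL); squeeze the layer dim off the out-states.
    features = features.permute(1, 0, 2)
    state_out = {"h": h_out.squeeze(0), "c": c_out.squeeze(0)}
    print(features.shape, state_out["h"].shape)  # (B, T, CELL) and (B, CELL)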