From a41c245a1abeb36bc9c16352825446973c8e740a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 17:40:09 +0100 Subject: [PATCH 01/24] initial Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 130 ++++++++++++++++++ .../ppo/torch/ppo_torch_rl_module.py | 88 ++++++------ 2 files changed, 178 insertions(+), 40 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 8feb601b9f00..69cdd58a5342 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,12 +1,21 @@ import numpy as np +import tree import unittest import ray import ray.rllib.algorithms.ppo as ppo +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( + get_expected_model_config, + PPOTorchRLModule, + get_ppo_loss, +) +from ray.rllib.utils.torch_utils import convert_to_torch_tensor from ray.rllib.utils.test_utils import ( check, check_compute_single_action, @@ -176,6 +185,127 @@ def test_ppo_exploration_setup(self): check(np.mean(actions), 1.5, atol=0.2) trainer.stop() + def test_torch_model_creation(self): + pass + + def test_torch_model_creation_lstm(self): + pass + + def test_rollouts(self): + for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [False]: # , True]" + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + import gym + + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + obs = env.reset() + if lstm: + states = [ + s.get_initial_state() + for s in ( + module.shared_encoder, + module.encoder_vf, + module.encoder_pi, + ) + ] + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + **{f"state_in_{i}": s for i, s in enumerate(states)}, + } + else: + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + + if fwd_fn == "forward_exploration": + module.forward_exploration(batch) + elif fwd_fn == "forward_inference": + module.forward_inference(batch) + + def test_forward_train(self): + for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [False]: # , True]" + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + import gym + + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + # collect a batch of data + batch = [] + obs = env.reset() + tstep = 0 + if lstm: + states = {} + for i, model in enumerate( + [ + module.shared_encoder, + module.encoder_pi, + module.encoder_vf, + ] + ): + states[i] = model.get_inital_state() + while tstep < 10: + fwd_out = module.forward_exploration( + {"obs": convert_to_torch_tensor(obs)[None]} + ) + action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, done, _ = 
env.step(action) + step = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.DONES: np.array(done), + } + if lstm: + assert "state_out" in fwd_out + for k, v in states.items(): + step[f"state_in_{k}"] = v + states[k] = fwd_out["state_out"][k] + batch.append(step) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batch) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) + for k, v in batch.items() + } + + # forward train + # before training make sure it's on the right device and it's on + # trianing mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + self.assertIsNotNone(param.grad) + if __name__ == "__main__": import pytest diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 83758c983b9f..43ea05fe1285 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -42,37 +42,45 @@ def get_ppo_loss(fwd_in, fwd_out): # TODO: Most of the neural network, and model specs in this file will eventually be # retreived from the model catalog. That includes FCNet, Encoder, etc. -def get_shared_encoder_config(env): - return PPOModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - encoder_config=FCConfig( - hidden_layers=[32], - activation="ReLU", - ), - pi_config=FCConfig( - hidden_layers=[32], - activation="ReLU", - ), - vf_config=FCConfig( - hidden_layers=[32], - activation="ReLU", - ), +def get_expected_model_config(env, lstm, shared_encoder): + assert not env.observation_space.shape[-1] == 3, "Implement VisionNet first" + if lstm: + encoder_config_class = LSTMConfig + else: + encoder_config_class = FCConfig + + pi_config = encoder_config_class( + hidden_layers=[32], + activation="ReLU", + ) + vf_config = encoder_config_class( + input_dim=32, + hidden_layers=[32], + activation="ReLU", ) + if shared_encoder: + shared_encoder_config = ( + encoder_config_class( + input_dim=env.observation_space.shape[0], + hidden_layers=[32], + activation="ReLU", + ), + ) + pi_config.input_dim = 32 + vf_config.input_dim = 32 + else: + shared_encoder_config = None + pi_config.input_dim = env.observation_space.shape[0] + vf_config.input_dim = env.observation_space.shape[0] -def get_separate_encoder_config(env): return PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, - pi_config=FCConfig( - hidden_layers=[32], - activation="ReLU", - ), - vf_config=FCConfig( - hidden_layers=[32], - activation="ReLU", - ), + shared_encoder_config=shared_encoder_config, + pi_config=pi_config, + vf_config=vf_config, + shared_encoder=shared_encoder, ) @@ -83,7 +91,7 @@ class PPOModuleConfig(RLModuleConfig): Attributes: pi_config: The configuration for the policy network. vf_config: The configuration for the value network. - encoder_config: The configuration for the encoder network. + shared_encoder_config: The configuration for the encoder network. free_log_std: For DiagGaussian action distributions, make the second half of the model outputs floating bias variables instead of state-dependent. 
            This only has an effect when using the default fully connected net.
@@ -92,7 +100,7 @@ class PPOModuleConfig(RLModuleConfig):
 
     pi_config: FCConfig = None
     vf_config: FCConfig = None
-    encoder_config: FCConfig = None
+    shared_encoder_config: FCConfig = None
     free_log_std: bool = False
     shared_encoder: bool = True
 
@@ -109,13 +117,13 @@ def setup(self) -> None:
         assert self.config.vf_config, "vf_config must be provided."
 
         if self.config.shared_encoder:
-            self.shared_encoder = self.config.encoder_config.build()
-            self.encoder_pi = IdentityEncoder(self.config.encoder_config)
-            self.encoder_vf = IdentityEncoder(self.config.encoder_config)
+            self.shared_encoder = self.config.shared_encoder_config.build()
+            self.encoder_pi = IdentityEncoder(self.config.pi_config)
+            self.encoder_vf = IdentityEncoder(self.config.vf_config)
         else:
-            self.shared_encoder = IdentityEncoder(self.config.encoder_config)
-            self.encoder_pi = self.config.encoder_config.build()
-            self.encoder_vf = self.config.encoder_config.build()
+            self.shared_encoder = IdentityEncoder(self.config.shared_encoder_config)
+            self.encoder_pi = self.config.pi_config.build()
+            self.encoder_vf = self.config.vf_config.build()
 
         self.pi = FCNet(
             input_dim=self.config.pi_config.input_dim,
@@ -162,14 +170,14 @@ def from_model_config(
 
         if use_lstm:
             assert vf_share_layers, "LSTM not supported with vf_share_layers=False"
-            encoder_config = LSTMConfig(
+            shared_encoder_config = LSTMConfig(
                 hidden_dim=model_config["lstm_cell_size"],
                 batch_first=not model_config["_time_major"],
                 output_dim=model_config["lstm_cell_size"],
                 num_layers=1,
             )
         else:
-            encoder_config = FCConfig(
+            shared_encoder_config = FCConfig(
                 hidden_layers=fcnet_hiddens,
                 activation=activation,
                 output_dim=model_config["fcnet_hiddens"][-1],
@@ -191,8 +199,8 @@ def from_model_config(
         )
 
         # build pi network
-        encoder_config.input_dim = observation_space.shape[0]
-        pi_config.input_dim = encoder_config.output_dim
+        shared_encoder_config.input_dim = observation_space.shape[0]
+        pi_config.input_dim = shared_encoder_config.output_dim
 
         if isinstance(action_space, gym.spaces.Discrete):
             pi_config.output_dim = action_space.n
@@ -200,14 +208,14 @@ def from_model_config(
             pi_config.output_dim = action_space.shape[0] * 2
 
         # build vf network
-        vf_config.input_dim = encoder_config.output_dim
+        vf_config.input_dim = shared_encoder_config.output_dim
         vf_config.output_dim = 1
 
         config_ = PPOModuleConfig(
             observation_space=observation_space,
             action_space=action_space,
             max_seq_len=model_config["max_seq_len"],
-            encoder_config=encoder_config,
+            shared_encoder_config=shared_encoder_config,
             pi_config=pi_config,
             vf_config=vf_config,
             free_log_std=free_log_std,
@@ -218,7 +226,7 @@ def from_model_config(
         return module
 
     def get_initial_state(self) -> NestedDict:
-        if isinstance(self.config.encoder_config, LSTMConfig):
+        if isinstance(self.config.shared_encoder_config, LSTMConfig):
             # TODO (Kourosh): How does this work in RLlib today?
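             # (For comparison, a rough sketch of today's ModelV2 convention,
             # with hypothetical shapes: get_initial_state() returns a list of
             # unbatched zero tensors, e.g. [torch.zeros(cell_size),
             # torch.zeros(cell_size)] for an LSTM, which the sampler then
             # tiles across the batch. Not the actual RLlib internals, just the
             # shape contract.)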
if isinstance(self.shared_encoder, LSTMEncoder): return self.shared_encoder.get_inital_state() From a15a82d46b4ed087f8c7675dabbada289ff9bc9a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 17:58:47 +0100 Subject: [PATCH 02/24] tests complete Signed-off-by: Artur Niederfahrenhorst --- .../ppo/torch/ppo_torch_rl_module.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 43ea05fe1285..d5bfe93bdfb2 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -43,29 +43,27 @@ def get_ppo_loss(fwd_in, fwd_out): # TODO: Most of the neural network, and model specs in this file will eventually be # retreived from the model catalog. That includes FCNet, Encoder, etc. def get_expected_model_config(env, lstm, shared_encoder): - assert not env.observation_space.shape[-1] == 3, "Implement VisionNet first" if lstm: encoder_config_class = LSTMConfig else: encoder_config_class = FCConfig pi_config = encoder_config_class( + output_dim=32, hidden_layers=[32], activation="ReLU", ) vf_config = encoder_config_class( - input_dim=32, + output_dim=32, hidden_layers=[32], activation="ReLU", ) if shared_encoder: - shared_encoder_config = ( - encoder_config_class( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ), + shared_encoder_config = encoder_config_class( + input_dim=env.observation_space.shape[0], + hidden_layers=[32], + activation="ReLU", ) pi_config.input_dim = 32 vf_config.input_dim = 32 @@ -126,15 +124,15 @@ def setup(self) -> None: self.encoder_vf = self.config.vf_config.build() self.pi = FCNet( - input_dim=self.config.pi_config.input_dim, - output_dim=self.config.pi_config.output_dim, + input_dim=self.config.pi_config.output_dim, + output_dim=2, hidden_layers=self.config.pi_config.hidden_layers, activation=self.config.pi_config.activation, ) self.vf = FCNet( - input_dim=self.config.vf_config.input_dim, - output_dim=self.config.vf_config.output_dim, + input_dim=self.config.vf_config.output_dim, + output_dim=1, hidden_layers=self.config.vf_config.hidden_layers, activation=self.config.vf_config.activation, ) From 37d57089543947a371e0c1116ccd85a6dfb18d6a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 18:07:23 +0100 Subject: [PATCH 03/24] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 69 +++++++++++++++---- .../ppo/torch/ppo_torch_rl_module.py | 42 ----------- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 69cdd58a5342..80f2e0c50b33 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,27 +1,74 @@ +import unittest + import numpy as np import tree -import unittest import ray import ray.rllib.algorithms.ppo as ppo - -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( - get_expected_model_config, + PPOModuleConfig, PPOTorchRLModule, 
get_ppo_loss, ) -from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.core.rl_module.encoder import ( + FCConfig, + LSTMConfig, +) +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.test_utils import ( check, check_compute_single_action, check_train_results, framework_iterator, ) +from ray.rllib.utils.torch_utils import convert_to_torch_tensor + + +# TODO: Most of the neural network, and model specs in this file will eventually be +# retreived from the model catalog. That includes FCNet, Encoder, etc. +def get_expected_model_config(env, lstm, shared_encoder): + assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" + if lstm: + encoder_config_class = LSTMConfig + else: + encoder_config_class = FCConfig + + pi_config = encoder_config_class( + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + vf_config = encoder_config_class( + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + + if shared_encoder: + shared_encoder_config = encoder_config_class( + input_dim=env.observation_space.shape[0], + hidden_layers=[32], + activation="ReLU", + ) + pi_config.input_dim = 32 + vf_config.input_dim = 32 + else: + shared_encoder_config = None + pi_config.input_dim = env.observation_space.shape[0] + vf_config.input_dim = env.observation_space.shape[0] + + return PPOModuleConfig( + observation_space=env.observation_space, + action_space=env.action_space, + shared_encoder_config=shared_encoder_config, + pi_config=pi_config, + vf_config=vf_config, + shared_encoder=shared_encoder, + ) class MyCallbacks(DefaultCallbacks): @@ -185,12 +232,6 @@ def test_ppo_exploration_setup(self): check(np.mean(actions), 1.5, atol=0.2) trainer.stop() - def test_torch_model_creation(self): - pass - - def test_torch_model_creation_lstm(self): - pass - def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: for fwd_fn in ["forward_exploration", "forward_inference"]: diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index d5bfe93bdfb2..54228e66ce60 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -40,48 +40,6 @@ def get_ppo_loss(fwd_in, fwd_out): return loss -# TODO: Most of the neural network, and model specs in this file will eventually be -# retreived from the model catalog. That includes FCNet, Encoder, etc. 
-def get_expected_model_config(env, lstm, shared_encoder): - if lstm: - encoder_config_class = LSTMConfig - else: - encoder_config_class = FCConfig - - pi_config = encoder_config_class( - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - vf_config = encoder_config_class( - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - - if shared_encoder: - shared_encoder_config = encoder_config_class( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ) - pi_config.input_dim = 32 - vf_config.input_dim = 32 - else: - shared_encoder_config = None - pi_config.input_dim = env.observation_space.shape[0] - vf_config.input_dim = env.observation_space.shape[0] - - return PPOModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - shared_encoder_config=shared_encoder_config, - pi_config=pi_config, - vf_config=vf_config, - shared_encoder=shared_encoder, - ) - - @dataclass class PPOModuleConfig(RLModuleConfig): """Configuration for the PPO module. From 2d405693cc3c129561cd1588676c7081890c77ce Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 19:06:05 +0100 Subject: [PATCH 04/24] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 103 ++++++++++-------- .../ppo/torch/ppo_torch_rl_module.py | 34 +++--- rllib/core/rl_module/encoder.py | 2 +- 3 files changed, 76 insertions(+), 63 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 80f2e0c50b33..acc80398fd68 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -2,6 +2,7 @@ import numpy as np import tree +import gym import ray import ray.rllib.algorithms.ppo as ppo @@ -28,43 +29,55 @@ from ray.rllib.utils.torch_utils import convert_to_torch_tensor -# TODO: Most of the neural network, and model specs in this file will eventually be -# retreived from the model catalog. That includes FCNet, Encoder, etc. +# Get a model config that we expect from the catalog def get_expected_model_config(env, lstm, shared_encoder): assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
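     # Resulting layout (simplified): obs -> shared_encoder -> pi_encoder -> pi
     # and obs -> shared_encoder -> vf_encoder -> vf; whichever stage is not
     # configured collapses to an identity module.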
- if lstm: - encoder_config_class = LSTMConfig - else: - encoder_config_class = FCConfig - pi_config = encoder_config_class( - output_dim=32, + shared_encoder_config = FCConfig( + input_dim=env.observation_space.shape[0], hidden_layers=[32], activation="ReLU", ) - vf_config = encoder_config_class( + + if lstm: + pi_encoder_config = LSTMConfig( + hidden_dim=32, + batch_first=True, + output_dim=32, + num_layers=1, + ) + else: + pi_encoder_config = FCConfig( + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + + vf_encoder_config = FCConfig( output_dim=32, hidden_layers=[32], activation="ReLU", ) - if shared_encoder: - shared_encoder_config = encoder_config_class( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ) - pi_config.input_dim = 32 - vf_config.input_dim = 32 - else: + if not shared_encoder: shared_encoder_config = None - pi_config.input_dim = env.observation_space.shape[0] - vf_config.input_dim = env.observation_space.shape[0] + pi_encoder_config.input_dim = env.observation_space.shape[0] + vf_encoder_config.input_dim = env.observation_space.shape[0] + + pi_config = FCConfig() + vf_config = FCConfig() + + if isinstance(env.action_space, gym.spaces.Discrete): + pi_config.output_dim = env.action_space.n + else: + pi_config.output_dim = env.action_space.shape[0] * 2 return PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, shared_encoder_config=shared_encoder_config, + pi_encoder_config=pi_encoder_config, + vf_encoder_config=vf_encoder_config, pi_config=pi_config, vf_config=vf_config, shared_encoder=shared_encoder, @@ -233,10 +246,11 @@ def test_ppo_exploration_setup(self): trainer.stop() def test_rollouts(self): - for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space + for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False]: # , True]" + for lstm in [False, True]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -249,18 +263,11 @@ def test_rollouts(self): module = PPOTorchRLModule(config) obs = env.reset() + if lstm: - states = [ - s.get_initial_state() - for s in ( - module.shared_encoder, - module.encoder_vf, - module.encoder_pi, - ) - ] batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - **{f"state_in_{i}": s for i, s in enumerate(states)}, + "state_in": module.pi_encoder.get_inital_state() } else: batch = { @@ -273,10 +280,11 @@ def test_rollouts(self): module.forward_inference(batch) def test_forward_train(self): - for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space + for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False]: # , True]" + for lstm in [False, True]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -293,19 +301,19 @@ def test_forward_train(self): obs = env.reset() tstep = 0 if lstm: - states = {} - for i, model in enumerate( - [ - module.shared_encoder, - module.encoder_pi, - module.encoder_vf, - ] - ): - states[i] = model.get_inital_state() + # TODO (Artur): Multiple states + state_in = module.pi_encoder.get_inital_state() while tstep < 10: - fwd_out = module.forward_exploration( - {"obs": 
convert_to_torch_tensor(obs)[None]} - ) + if lstm: + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + "state_in": state_in, + } + else: + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(batch) action = convert_to_numpy( fwd_out["action_dist"].sample().squeeze(0) ) @@ -319,9 +327,8 @@ def test_forward_train(self): } if lstm: assert "state_out" in fwd_out - for k, v in states.items(): - step[f"state_in_{k}"] = v - states[k] = fwd_out["state_out"][k] + step[f"state_in"] = state_in + state_in = fwd_out["state_out"] batch.append(step) obs = new_obs tstep += 1 diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 54228e66ce60..bb1bf8701fa1 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -56,6 +56,8 @@ class PPOModuleConfig(RLModuleConfig): pi_config: FCConfig = None vf_config: FCConfig = None + pi_encoder_config: FCConfig = None + vf_encoder_config: FCConfig = None shared_encoder_config: FCConfig = None free_log_std: bool = False shared_encoder: bool = True @@ -74,22 +76,22 @@ def setup(self) -> None: if self.config.shared_encoder: self.shared_encoder = self.config.shared_encoder_config.build() - self.encoder_pi = IdentityEncoder(self.config.pi_config) - self.encoder_vf = IdentityEncoder(self.config.vf_config) + self.pi_encoder = IdentityEncoder(self.config.pi_encoder_config) + self.vf_encoder = IdentityEncoder(self.config.vf_encoder_config) else: self.shared_encoder = IdentityEncoder(self.config.shared_encoder_config) - self.encoder_pi = self.config.pi_config.build() - self.encoder_vf = self.config.vf_config.build() + self.pi_encoder = self.config.pi_encoder_config.build() + self.vf_encoder = self.config.vf_encoder_config.build() self.pi = FCNet( - input_dim=self.config.pi_config.output_dim, - output_dim=2, + input_dim=self.config.pi_encoder_config.output_dim, + output_dim=self.config.pi_config.output_dim, hidden_layers=self.config.pi_config.hidden_layers, activation=self.config.pi_config.activation, ) self.vf = FCNet( - input_dim=self.config.vf_config.output_dim, + input_dim=self.config.vf_encoder_config.output_dim, output_dim=1, hidden_layers=self.config.vf_config.hidden_layers, activation=self.config.vf_config.activation, @@ -139,6 +141,8 @@ def from_model_config( output_dim=model_config["fcnet_hiddens"][-1], ) + pi_encoder_config = FCConfig() + vf_encoder_config = FCConfig() pi_config = FCConfig() vf_config = FCConfig() @@ -156,7 +160,7 @@ def from_model_config( # build pi network shared_encoder_config.input_dim = observation_space.shape[0] - pi_config.input_dim = shared_encoder_config.output_dim + pi_encoder_config.input_dim = shared_encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n @@ -174,6 +178,8 @@ def from_model_config( shared_encoder_config=shared_encoder_config, pi_config=pi_config, vf_config=vf_config, + pi_encoder_config=pi_encoder_config, + vf_encoder_config=vf_encoder_config, free_log_std=free_log_std, shared_encoder=vf_share_layers, ) @@ -187,7 +193,7 @@ def get_initial_state(self) -> NestedDict: if isinstance(self.shared_encoder, LSTMEncoder): return self.shared_encoder.get_inital_state() else: - return self.encoder_pi.get_inital_state() + return self.pi_encoder.get_inital_state() return {} @override(RLModule) @@ -201,7 +207,7 @@ def output_specs_inference(self) -> ModelSpec: @override(RLModule) def 
_forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.encoder_pi(encoder_out) + encoder_out_pi = self.pi_encoder(encoder_out) action_logits = self.pi(encoder_out_pi["embedding"]) if self._is_discrete: @@ -242,8 +248,8 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: policy and the new policy during training. """ encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.encoder_pi(encoder_out) - encoder_out_vf = self.encoder_vf(encoder_out) + encoder_out_pi = self.pi_encoder(encoder_out) + encoder_out_vf = self.vf_encoder(encoder_out) action_logits = self.pi(encoder_out_pi["embedding"]) output = {} @@ -292,8 +298,8 @@ def output_specs_train(self) -> ModelSpec: @override(RLModule) def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.encoder_pi(encoder_out) - encoder_out_vf = self.encoder_vf(encoder_out) + encoder_out_pi = self.pi_encoder(encoder_out) + encoder_out_vf = self.vf_encoder(encoder_out) action_logits = self.pi(encoder_out_pi["embedding"]) vf = self.vf(encoder_out_vf["embedding"]) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index 2b5e02bed9ae..d99a1a9e7c76 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -63,7 +63,7 @@ def __init__(self, config: EncoderConfig) -> None: self._output_spec = self.output_spec() def get_inital_state(self): - raise [] + return [] def input_spec(self): return ModelSpec() From 74d213b5a791b44989f8a85581a19bd9d277a696 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Sun, 18 Dec 2022 22:34:21 +0100 Subject: [PATCH 05/24] mutually exclusive encoders, tests passing Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 82 ++++++++++--------- .../ppo/torch/ppo_torch_rl_module.py | 62 ++++++++------ rllib/core/rl_module/encoder.py | 29 ++++--- 3 files changed, 98 insertions(+), 75 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index acc80398fd68..c19b34dde06b 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -13,6 +13,7 @@ get_ppo_loss, ) from ray.rllib.core.rl_module.encoder import ( + IdentityConfig, FCConfig, LSTMConfig, ) @@ -32,38 +33,42 @@ # Get a model config that we expect from the catalog def get_expected_model_config(env, lstm, shared_encoder): assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
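     # Rank-3 (image) observation spaces are ruled out above; they would need a
     # convolutional encoder ("VisionNet") rather than the FC/LSTM encoders
     # configured below.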
+ obs_dim = env.observation_space.shape[0] - shared_encoder_config = FCConfig( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ) - - if lstm: - pi_encoder_config = LSTMConfig( - hidden_dim=32, - batch_first=True, + if shared_encoder: + assert not lstm, "LSTM can only be used in PI" + shared_encoder_config = FCConfig( + input_dim=obs_dim, + hidden_layers=[32], + activation="ReLU", output_dim=32, - num_layers=1, ) + pi_encoder_config = IdentityConfig(output_dim=32) + vf_encoder_config = IdentityConfig(output_dim=32) else: - pi_encoder_config = FCConfig( + shared_encoder_config = IdentityConfig(output_dim=obs_dim) + if lstm: + pi_encoder_config = LSTMConfig( + input_dim=obs_dim, + hidden_dim=32, + batch_first=True, + output_dim=32, + num_layers=1, + ) + else: + pi_encoder_config = FCConfig( + input_dim=obs_dim, + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + vf_encoder_config = FCConfig( + input_dim=obs_dim, output_dim=32, hidden_layers=[32], activation="ReLU", ) - vf_encoder_config = FCConfig( - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - - if not shared_encoder: - shared_encoder_config = None - pi_encoder_config.input_dim = env.observation_space.shape[0] - vf_encoder_config.input_dim = env.observation_space.shape[0] - pi_config = FCConfig() vf_config = FCConfig() @@ -120,7 +125,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init() + ray.init(local_mode=True) @classmethod def tearDownClass(cls): @@ -250,13 +255,12 @@ def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False, True]: + # TODO: LSTM = True + for lstm in [False]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" ) - import gym - env = gym.make(env_name) config = get_expected_model_config(env, lstm, shared_encoder) @@ -267,7 +271,7 @@ def test_rollouts(self): if lstm: batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": module.pi_encoder.get_inital_state() + "state_in": module.pi_encoder.get_inital_state(), } else: batch = { @@ -284,20 +288,19 @@ def test_forward_train(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False, True]: + # TODO: LSTM = True + for lstm in [False]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" ) - import gym - env = gym.make(env_name) config = get_expected_model_config(env, lstm, shared_encoder) module = PPOTorchRLModule(config) # collect a batch of data - batch = [] + batches = [] obs = env.reset() tstep = 0 if lstm: @@ -305,20 +308,21 @@ def test_forward_train(self): state_in = module.pi_encoder.get_inital_state() while tstep < 10: if lstm: - batch = { + input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], "state_in": state_in, + SampleBatch.SEQ_LENS: np.array([1]), } else: - batch = { + input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None] } - fwd_out = module.forward_exploration(batch) + fwd_out = module.forward_exploration(input_batch) action = convert_to_numpy( fwd_out["action_dist"].sample().squeeze(0) ) new_obs, reward, done, _ = env.step(action) - step = { + output_batch = { SampleBatch.OBS: obs, SampleBatch.NEXT_OBS: new_obs, 
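                                # One transition per dict; these dicts are
                                # stacked into a single batch further below via
                                # tree.map_structure.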
SampleBatch.ACTIONS: action, @@ -327,14 +331,14 @@ def test_forward_train(self): } if lstm: assert "state_out" in fwd_out - step[f"state_in"] = state_in + output_batch["state_in"] = state_in state_in = fwd_out["state_out"] - batch.append(step) + batches.append(output_batch) obs = new_obs tstep += 1 # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batch) + batch = tree.map_structure(lambda *x: list(x), *batches) # convert dict of lists to dict of tensors fwd_in = { k: convert_to_torch_tensor(np.array(v)) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index bb1bf8701fa1..b6e4e27517b8 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -19,8 +19,9 @@ FCNet, FCConfig, LSTMConfig, - IdentityEncoder, + IdentityConfig, LSTMEncoder, + ENCODER_OUT, ) @@ -74,14 +75,9 @@ def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." - if self.config.shared_encoder: - self.shared_encoder = self.config.shared_encoder_config.build() - self.pi_encoder = IdentityEncoder(self.config.pi_encoder_config) - self.vf_encoder = IdentityEncoder(self.config.vf_encoder_config) - else: - self.shared_encoder = IdentityEncoder(self.config.shared_encoder_config) - self.pi_encoder = self.config.pi_encoder_config.build() - self.vf_encoder = self.config.vf_encoder_config.build() + self.shared_encoder = self.config.shared_encoder_config.build() + self.pi_encoder = self.config.pi_encoder_config.build() + self.vf_encoder = self.config.vf_encoder_config.build() self.pi = FCNet( input_dim=self.config.pi_encoder_config.output_dim, @@ -121,28 +117,44 @@ def from_model_config( else: raise ValueError(f"Unsupported activation: {activation}") + obs_dim = observation_space.shape[0] fcnet_hiddens = model_config["fcnet_hiddens"] vf_share_layers = model_config["vf_share_layers"] free_log_std = model_config["free_log_std"] use_lstm = model_config["use_lstm"] + if vf_share_layers: + shared_encoder_config = FCConfig( + input_dim=obs_dim, + hidden_layers=fcnet_hiddens, + activation=activation, + output_dim=model_config["fcnet_hiddens"][-1], + ) + else: + shared_encoder_config = IdentityConfig(output_dim=obs_dim) + if use_lstm: - assert vf_share_layers, "LSTM not supported with vf_share_layers=False" - shared_encoder_config = LSTMConfig( + pi_encoder_config = LSTMConfig( + input_dim=shared_encoder_config.output_dim, hidden_dim=model_config["lstm_cell_size"], batch_first=not model_config["_time_major"], output_dim=model_config["lstm_cell_size"], num_layers=1, ) else: - shared_encoder_config = FCConfig( + pi_encoder_config = FCConfig( + input_dim=shared_encoder_config.output_dim, hidden_layers=fcnet_hiddens, activation=activation, output_dim=model_config["fcnet_hiddens"][-1], ) - pi_encoder_config = FCConfig() - vf_encoder_config = FCConfig() + vf_encoder_config = FCConfig( + input_dim=shared_encoder_config.output_dim, + hidden_layers=fcnet_hiddens, + activation=activation, + output_dim=model_config["fcnet_hiddens"][-1], + ) pi_config = FCConfig() vf_config = FCConfig() @@ -161,14 +173,15 @@ def from_model_config( # build pi network shared_encoder_config.input_dim = observation_space.shape[0] pi_encoder_config.input_dim = shared_encoder_config.output_dim - + pi_config.input_dim = pi_encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = 
action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 # build vf network - vf_config.input_dim = shared_encoder_config.output_dim + vf_encoder_config.input_dim = shared_encoder_config.output_dim + vf_config.input_dim = vf_encoder_config.output_dim vf_config.output_dim = 1 config_ = PPOModuleConfig( @@ -206,9 +219,10 @@ def output_specs_inference(self) -> ModelSpec: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi["embedding"]) + shared_enc_out = self.shared_encoder(batch) + pi_enc_out = self.pi_encoder(shared_enc_out) + + action_logits = self.pi(pi_enc_out[ENCODER_OUT]) if self._is_discrete: action = torch.argmax(action_logits, dim=-1) @@ -217,7 +231,7 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = TorchDeterministic(action) output = {SampleBatch.ACTION_DIST: action_dist} - output["state_out"] = encoder_out_pi.get("state_out", {}) + output["state_out"] = pi_enc_out.get("state_out", {}) return output @override(RLModule) @@ -250,7 +264,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) encoder_out_pi = self.pi_encoder(encoder_out) encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi["embedding"]) + action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) output = {} if self._is_discrete: @@ -264,7 +278,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output[SampleBatch.ACTION_DIST] = action_dist # compute the value function - output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf["embedding"]).squeeze(-1) + output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf[ENCODER_OUT]).squeeze(-1) output["state_out"] = encoder_out_pi.get("state_out", {}) return output @@ -301,8 +315,8 @@ def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out_pi = self.pi_encoder(encoder_out) encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi["embedding"]) - vf = self.vf(encoder_out_vf["embedding"]) + action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) + vf = self.vf(encoder_out_vf[ENCODER_OUT]) if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index d99a1a9e7c76..1c4d6232d55c 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -13,6 +13,9 @@ # TODO (Kourosh): Find a better / more straight fwd approach for sub-components +ENCODER_OUT = "encoder_out" +STATE_IN = "state_in" + @dataclass class EncoderConfig: @@ -26,6 +29,14 @@ class EncoderConfig: output_dim: int = None +@dataclass +class IdentityConfig(EncoderConfig): + """Configuration for an identity encoder.""" + + def build(self): + return IdentityEncoder(self) + + @dataclass class FCConfig(EncoderConfig): """Configuration for a fully connected network. 
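As an aside on the pattern these encoder changes rely on — dataclass configs
that each know how to build() their module, with identity configs standing in
for disabled stages — a minimal self-contained sketch; the Tiny* names are
illustrative stand-ins, not the actual RLlib classes:

    from dataclasses import dataclass
    import torch.nn as nn

    @dataclass
    class TinyEncoderConfig:
        input_dim: int = None
        output_dim: int = None

        def build(self) -> nn.Module:
            # Identity stage: no parameters; output_dim mirrors input_dim.
            return nn.Identity()

    @dataclass
    class TinyFCConfig(TinyEncoderConfig):
        hidden_dim: int = 32

        def build(self) -> nn.Module:
            return nn.Sequential(
                nn.Linear(self.input_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Linear(self.hidden_dim, self.output_dim),
            )

    # Swapping a component in or out then becomes a pure config choice:
    shared = TinyFCConfig(input_dim=4, output_dim=32).build()
    passthrough = TinyEncoderConfig(input_dim=4, output_dim=4).build()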
@@ -97,11 +108,11 @@ def input_spec(self): def output_spec(self): return ModelSpec( - {"embedding": TorchTensorSpec("b, h", h=self.config.output_dim)} + {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} ) def _forward(self, input_dict): - return {"embedding": self.net(input_dict[SampleBatch.OBS])} + return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} class LSTMEncoder(Encoder): @@ -129,7 +140,7 @@ def input_spec(self): { # bxt is just a name for better readability to indicated padded batch SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), - "state_in": { + STATE_IN: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -144,7 +155,7 @@ def output_spec(self): config = self.config return ModelSpec( { - "embedding": TorchTensorSpec("bxt, h", h=config.output_dim), + ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), "state_out": { "h": TorchTensorSpec("b, h", h=config.hidden_dim), "c": TorchTensorSpec("b, h", h=config.hidden_dim), @@ -154,7 +165,7 @@ def output_spec(self): def forward(self, input_dict: SampleBatch): x = input_dict[SampleBatch.OBS] - states = input_dict["state_in"] + states = input_dict[STATE_IN] # states are batch-first when coming in states = tree.map_structure(lambda x: x.transpose(0, 1), states) @@ -171,7 +182,7 @@ def forward(self, input_dict: SampleBatch): x = x.view(-1, x.shape[-1]) return { - "embedding": x, + ENCODER_OUT: x, "state_out": tree.map_structure(lambda x: x.transpose(0, 1), states_o), } @@ -180,11 +191,5 @@ class IdentityEncoder(Encoder): def __init__(self, config: EncoderConfig) -> None: super().__init__(config) - def input_spec(self): - return ModelSpec() - - def output_spec(self): - return ModelSpec() - def _forward(self, input_dict): return input_dict From ddf8596d83e052d3569006327dd4840f385fdf6f Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 19 Dec 2022 15:41:09 +0100 Subject: [PATCH 06/24] add lstm code Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 51 +++++++++++++------ 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index c19b34dde06b..d56701184e20 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -3,6 +3,7 @@ import numpy as np import tree import gym +import torch import ray import ray.rllib.algorithms.ppo as ppo @@ -255,8 +256,11 @@ def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - # TODO: LSTM = True - for lstm in [False]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -268,15 +272,17 @@ def test_rollouts(self): obs = env.reset() + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + } + if lstm: - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": module.pi_encoder.get_inital_state(), - } - else: - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } + state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + batch["state_in"] = state_in + batch["seq_lens"] = torch.Tensor([1]) if fwd_fn == 
"forward_exploration": module.forward_exploration(batch) @@ -288,8 +294,11 @@ def test_forward_train(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - # TODO: LSTM = True - for lstm in [False]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -306,6 +315,10 @@ def test_forward_train(self): if lstm: # TODO (Artur): Multiple states state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + output_states = state_in while tstep < 10: if lstm: input_batch = { @@ -331,7 +344,12 @@ def test_forward_train(self): } if lstm: assert "state_out" in fwd_out - output_batch["state_in"] = state_in + if tstep > 0: # First states are already added + output_states = tree.map_structure( + lambda *s: torch.cat((s[0], s[1])), + output_states, + state_in, + ) state_in = fwd_out["state_out"] batches.append(output_batch) obs = new_obs @@ -344,10 +362,13 @@ def test_forward_train(self): k: convert_to_torch_tensor(np.array(v)) for k, v in batch.items() } + if lstm: + fwd_in["state_in"] = output_states + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) # forward train - # before training make sure it's on the right device and it's on - # trianing mode + # before training make sure module is on the right device and in + # training mode module.to("cpu") module.train() fwd_out = module.forward_train(fwd_in) From 269a5dd0626182f8ac65aca2f5f311aea6da9364 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 19 Dec 2022 16:05:10 +0100 Subject: [PATCH 07/24] add underscore to forward method Signed-off-by: Artur Niederfahrenhorst --- rllib/core/rl_module/encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index 1c4d6232d55c..7793b3d2c7d2 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -163,7 +163,7 @@ def output_spec(self): } ) - def forward(self, input_dict: SampleBatch): + def _forward(self, input_dict: SampleBatch): x = input_dict[SampleBatch.OBS] states = input_dict[STATE_IN] # states are batch-first when coming in From b00b9eed137fd067e77c0eac160f5a68c0d25254 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 10:41:19 +0100 Subject: [PATCH 08/24] better docs for get expected model config Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index d56701184e20..20ee0e072405 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -31,8 +31,19 @@ from ray.rllib.utils.torch_utils import convert_to_torch_tensor -# Get a model config that we expect from the catalog -def get_expected_model_config(env, lstm, shared_encoder): +def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: + """Get a PPOModuleConfig that we would expect from the catalog otherwise. 
+ + Args: + env: Environment for which we build the model later + lstm: If True, build recurrent pi encoder + shared_encoder: If True, build a shared encoder for pi and vf, where pi + encoder and vf encoder will be identity. If False, the shared encoder + will be identity. + + Returns: + A PPOModuleConfig containing the relevant configs to build PPORLModule + """ assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" obs_dim = env.observation_space.shape[0] From 8157f63388acfe2f221d07170db60d1d910a5fa5 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 11:56:19 +0100 Subject: [PATCH 09/24] kourosh's comments Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 227 ++++++++++++++++++ .../ppo/tests/test_ppo_with_rl_module.py | 216 +---------------- 2 files changed, 228 insertions(+), 215 deletions(-) create mode 100644 rllib/algorithms/ppo/tests/test_ppo_rl_module.py diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py new file mode 100644 index 000000000000..7adde54cfffd --- /dev/null +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -0,0 +1,227 @@ +import ray +import unittest +import numpy as np +import gym +import torch +import tree + +from ray.rllib import SampleBatch +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule, \ + get_ppo_loss, PPOModuleConfig +from ray.rllib.core.rl_module.encoder import FCConfig, IdentityConfig, LSTMConfig +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import convert_to_torch_tensor + + +def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: + """Get a PPOModuleConfig that we would expect from the catalog otherwise. + + Args: + env: Environment for which we build the model later + lstm: If True, build recurrent pi encoder + shared_encoder: If True, build a shared encoder for pi and vf, where pi + encoder and vf encoder will be identity. If False, the shared encoder + will be identity. + + Returns: + A PPOModuleConfig containing the relevant configs to build PPORLModule + """ + assert len(env.observation_space.shape) == 1, "No multidimensional obs space " \ + "supported." 
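+    # Hand-built stand-in for the model config that the model catalog is
+    # eventually expected to provide.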
+ obs_dim = env.observation_space.shape[0] + + if shared_encoder: + assert not lstm, "LSTM can only be used in PI" + shared_encoder_config = FCConfig( + input_dim=obs_dim, + hidden_layers=[32], + activation="ReLU", + output_dim=32, + ) + pi_encoder_config = IdentityConfig(output_dim=32) + vf_encoder_config = IdentityConfig(output_dim=32) + else: + shared_encoder_config = IdentityConfig(output_dim=obs_dim) + if lstm: + pi_encoder_config = LSTMConfig( + input_dim=obs_dim, + hidden_dim=32, + batch_first=True, + output_dim=32, + num_layers=1, + ) + else: + pi_encoder_config = FCConfig( + input_dim=obs_dim, + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + vf_encoder_config = FCConfig( + input_dim=obs_dim, + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + + pi_config = FCConfig() + vf_config = FCConfig() + + if isinstance(env.action_space, gym.spaces.Discrete): + pi_config.output_dim = env.action_space.n + else: + pi_config.output_dim = env.action_space.shape[0] * 2 + + return PPOModuleConfig( + observation_space=env.observation_space, + action_space=env.action_space, + shared_encoder_config=shared_encoder_config, + pi_encoder_config=pi_encoder_config, + vf_encoder_config=vf_encoder_config, + pi_config=pi_config, + vf_config=vf_config, + shared_encoder=shared_encoder, + ) + + +class TestPPO(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_rollouts(self): + # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space + for env_name in ["CartPole-v1", "Pendulum-v1"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + obs = env.reset() + + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + } + + if lstm: + state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + batch["state_in"] = state_in + batch["seq_lens"] = torch.Tensor([1]) + + if fwd_fn == "forward_exploration": + module.forward_exploration(batch) + elif fwd_fn == "forward_inference": + module.forward_inference(batch) + + + def test_forward_train(self): + # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space + for env_name in ["CartPole-v1", "Pendulum-v1"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + # collect a batch of data + batches = [] + obs = env.reset() + tstep = 0 + if lstm: + # TODO (Artur): Multiple states + state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + output_states = state_in + while tstep < 10: + if lstm: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + "state_in": state_in, + 
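+                            # SEQ_LENS holds the number of valid timesteps per
+                            # padded sequence; each sampled step is fed as its
+                            # own length-1 sequence here.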
SampleBatch.SEQ_LENS: np.array([1]), + } + else: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(input_batch) + action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, done, _ = env.step(action) + output_batch = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.DONES: np.array(done), + } + if lstm: + assert "state_out" in fwd_out + if tstep > 0: # First states are already added + + # Extend nested batches of states + output_states = tree.map_structure( + lambda *s: torch.cat((s[0], s[1])), + output_states, + state_in, + ) + state_in = fwd_out["state_out"] + batches.append(output_batch) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batches) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) + for k, v in batch.items() + } + if lstm: + fwd_in["state_in"] = output_states + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) + + # forward train + # before training make sure module is on the right device and in + # training mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + pass + self.assertIsNotNone(param.grad) + diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 20ee0e072405..9f2f42c15238 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,104 +1,18 @@ import unittest import numpy as np -import tree -import gym -import torch import ray import ray.rllib.algorithms.ppo as ppo from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( - PPOModuleConfig, - PPOTorchRLModule, - get_ppo_loss, -) -from ray.rllib.core.rl_module.encoder import ( - IdentityConfig, - FCConfig, - LSTMConfig, -) from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.test_utils import ( check, check_compute_single_action, check_train_results, framework_iterator, ) -from ray.rllib.utils.torch_utils import convert_to_torch_tensor - - -def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: - """Get a PPOModuleConfig that we would expect from the catalog otherwise. - - Args: - env: Environment for which we build the model later - lstm: If True, build recurrent pi encoder - shared_encoder: If True, build a shared encoder for pi and vf, where pi - encoder and vf encoder will be identity. If False, the shared encoder - will be identity. - - Returns: - A PPOModuleConfig containing the relevant configs to build PPORLModule - """ - assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
- obs_dim = env.observation_space.shape[0] - - if shared_encoder: - assert not lstm, "LSTM can only be used in PI" - shared_encoder_config = FCConfig( - input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", - output_dim=32, - ) - pi_encoder_config = IdentityConfig(output_dim=32) - vf_encoder_config = IdentityConfig(output_dim=32) - else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) - if lstm: - pi_encoder_config = LSTMConfig( - input_dim=obs_dim, - hidden_dim=32, - batch_first=True, - output_dim=32, - num_layers=1, - ) - else: - pi_encoder_config = FCConfig( - input_dim=obs_dim, - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - vf_encoder_config = FCConfig( - input_dim=obs_dim, - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - - pi_config = FCConfig() - vf_config = FCConfig() - - if isinstance(env.action_space, gym.spaces.Discrete): - pi_config.output_dim = env.action_space.n - else: - pi_config.output_dim = env.action_space.shape[0] * 2 - - return PPOModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - shared_encoder_config=shared_encoder_config, - pi_encoder_config=pi_encoder_config, - vf_encoder_config=vf_encoder_config, - pi_config=pi_config, - vf_config=vf_config, - shared_encoder=shared_encoder, - ) class MyCallbacks(DefaultCallbacks): @@ -137,7 +51,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init(local_mode=True) + ray.init() @classmethod def tearDownClass(cls): @@ -262,134 +176,6 @@ def test_ppo_exploration_setup(self): check(np.mean(actions), 1.5, atol=0.2) trainer.stop() - def test_rollouts(self): - # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - for env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" - ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - obs = env.reset() - - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - } - - if lstm: - state_in = module.pi_encoder.get_inital_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - batch["state_in"] = state_in - batch["seq_lens"] = torch.Tensor([1]) - - if fwd_fn == "forward_exploration": - module.forward_exploration(batch) - elif fwd_fn == "forward_inference": - module.forward_inference(batch) - - def test_forward_train(self): - # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - for env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" - ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - # collect a batch of data - batches = [] - obs = env.reset() - tstep = 0 - if lstm: - # TODO (Artur): Multiple states - state_in = module.pi_encoder.get_inital_state() - state_in = 
tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - output_states = state_in - while tstep < 10: - if lstm: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": state_in, - SampleBatch.SEQ_LENS: np.array([1]), - } - else: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy( - fwd_out["action_dist"].sample().squeeze(0) - ) - new_obs, reward, done, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.DONES: np.array(done), - } - if lstm: - assert "state_out" in fwd_out - if tstep > 0: # First states are already added - output_states = tree.map_structure( - lambda *s: torch.cat((s[0], s[1])), - output_states, - state_in, - ) - state_in = fwd_out["state_out"] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batches) - # convert dict of lists to dict of tensors - fwd_in = { - k: convert_to_torch_tensor(np.array(v)) - for k, v in batch.items() - } - if lstm: - fwd_in["state_in"] = output_states - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) - - # forward train - # before training make sure module is on the right device and in - # training mode - module.to("cpu") - module.train() - fwd_out = module.forward_train(fwd_in) - loss = get_ppo_loss(fwd_in, fwd_out) - loss.backward() - - # check that all neural net parameters have gradients - for param in module.parameters(): - self.assertIsNotNone(param.grad) - if __name__ == "__main__": import pytest From a174623a324187626a37551bc1834b1585b0c861 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 16:51:27 +0100 Subject: [PATCH 10/24] lstm fixed, tests working Signed-off-by: Artur Niederfahrenhorst --- rllib/BUILD | 7 +++++++ rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 14 ++++++++------ rllib/core/rl_module/encoder.py | 8 ++++++-- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 92b8884efea4..cbf59391c52c 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1095,6 +1095,13 @@ py_test( srcs = ["algorithms/ppo/tests/test_ppo_with_rl_module.py"] ) +py_test( + name = "test_ppo_rl_module", + tags = ["team:rllib", "algorithms_dir"], + size = "large", + srcs = ["algorithms/ppo/tests/test_ppo_rl_module.py"] +) + # PPO Reproducibility py_test( name = "test_repro_ppo", diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 7adde54cfffd..927c0c9c7665 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -6,8 +6,11 @@ import tree from ray.rllib import SampleBatch -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule, \ - get_ppo_loss, PPOModuleConfig +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( + PPOTorchRLModule, + get_ppo_loss, + PPOModuleConfig, +) from ray.rllib.core.rl_module.encoder import FCConfig, IdentityConfig, LSTMConfig from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -26,8 +29,9 @@ def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: Returns: A PPOModuleConfig containing the relevant configs to build 
PPORLModule """ - assert len(env.observation_space.shape) == 1, "No multidimensional obs space " \ - "supported." + assert len(env.observation_space.shape) == 1, ( + "No multidimensional obs space " "supported." + ) obs_dim = env.observation_space.shape[0] if shared_encoder: @@ -131,7 +135,6 @@ def test_rollouts(self): elif fwd_fn == "forward_inference": module.forward_inference(batch) - def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: @@ -224,4 +227,3 @@ def test_forward_train(self): for param in module.parameters(): pass self.assertIsNotNone(param.grad) - diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index 7793b3d2c7d2..729a973fb76e 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -157,8 +157,12 @@ def output_spec(self): { ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), "state_out": { - "h": TorchTensorSpec("b, h", h=config.hidden_dim), - "c": TorchTensorSpec("b, h", h=config.hidden_dim), + "h": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), }, } ) From 3a4ea0122879738332eb8b174c5d1ab86c6ef13a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 16:54:48 +0100 Subject: [PATCH 11/24] add state out Signed-off-by: Artur Niederfahrenhorst --- .../algorithms/ppo/tests/test_ppo_rl_module.py | 18 ++++++++++++------ rllib/core/rl_module/encoder.py | 5 +++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 927c0c9c7665..77f2894fa434 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -11,7 +11,13 @@ get_ppo_loss, PPOModuleConfig, ) -from ray.rllib.core.rl_module.encoder import FCConfig, IdentityConfig, LSTMConfig +from ray.rllib.core.rl_module.encoder import ( + FCConfig, + IdentityConfig, + LSTMConfig, + STATE_IN, + STATE_OUT, +) from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -127,7 +133,7 @@ def test_rollouts(self): state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) - batch["state_in"] = state_in + batch[STATE_IN] = state_in batch["seq_lens"] = torch.Tensor([1]) if fwd_fn == "forward_exploration": @@ -169,7 +175,7 @@ def test_forward_train(self): if lstm: input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": state_in, + STATE_IN: state_in, SampleBatch.SEQ_LENS: np.array([1]), } else: @@ -189,7 +195,7 @@ def test_forward_train(self): SampleBatch.DONES: np.array(done), } if lstm: - assert "state_out" in fwd_out + assert STATE_OUT in fwd_out if tstep > 0: # First states are already added # Extend nested batches of states @@ -198,7 +204,7 @@ def test_forward_train(self): output_states, state_in, ) - state_in = fwd_out["state_out"] + state_in = fwd_out[STATE_OUT] batches.append(output_batch) obs = new_obs tstep += 1 @@ -211,7 +217,7 @@ def test_forward_train(self): for k, v in batch.items() } if lstm: - fwd_in["state_in"] = output_states + fwd_in[STATE_IN] = output_states fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) # forward train diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index 729a973fb76e..ceb9cd82f85a 100644 --- 
a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -15,6 +15,7 @@ ENCODER_OUT = "encoder_out" STATE_IN = "state_in" +STATE_OUT = "state_out" @dataclass @@ -156,7 +157,7 @@ def output_spec(self): return ModelSpec( { ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - "state_out": { + STATE_OUT: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -187,7 +188,7 @@ def _forward(self, input_dict: SampleBatch): return { ENCODER_OUT: x, - "state_out": tree.map_structure(lambda x: x.transpose(0, 1), states_o), + STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), } From 5ee9a4d6042fafb3754a3b40fa3e69cd69f57786 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 17:30:09 +0100 Subject: [PATCH 12/24] add __main__ to test Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 77f2894fa434..f13790046828 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -233,3 +233,10 @@ def test_forward_train(self): for param in module.parameters(): pass self.assertIsNotNone(param.grad) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) From eef53af3840a583dc8d5b720dd7f81cf85cc3e2c Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 21:42:57 +0100 Subject: [PATCH 13/24] change lstm testing according to kourosh's comment Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 26 +++++++------------ .../ppo/torch/ppo_torch_rl_module.py | 13 +++++----- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index f13790046828..b8acee1fed98 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -22,7 +22,9 @@ from ray.rllib.utils.torch_utils import convert_to_torch_tensor -def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: +def get_expected_model_config( + env: gym.Env, lstm: bool, shared_encoder: bool +) -> PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. 
Args: @@ -114,8 +116,8 @@ def test_rollouts(self): # TODO (Artur): Implement continue print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" + f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" + f"={lstm}" ) env = gym.make(env_name) @@ -134,11 +136,11 @@ def test_rollouts(self): lambda x: x[None], convert_to_torch_tensor(state_in) ) batch[STATE_IN] = state_in - batch["seq_lens"] = torch.Tensor([1]) + batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) if fwd_fn == "forward_exploration": module.forward_exploration(batch) - elif fwd_fn == "forward_inference": + else: module.forward_inference(batch) def test_forward_train(self): @@ -170,7 +172,7 @@ def test_forward_train(self): state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) - output_states = state_in + initial_state = state_in while tstep < 10: if lstm: input_batch = { @@ -196,14 +198,6 @@ def test_forward_train(self): } if lstm: assert STATE_OUT in fwd_out - if tstep > 0: # First states are already added - - # Extend nested batches of states - output_states = tree.map_structure( - lambda *s: torch.cat((s[0], s[1])), - output_states, - state_in, - ) state_in = fwd_out[STATE_OUT] batches.append(output_batch) obs = new_obs @@ -217,8 +211,8 @@ def test_forward_train(self): for k, v in batch.items() } if lstm: - fwd_in[STATE_IN] = output_states - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) + fwd_in[STATE_IN] = initial_state + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) # forward train # before training make sure module is on the right device and in diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b6e4e27517b8..2502466d496f 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -201,13 +201,12 @@ def from_model_config( return module def get_initial_state(self) -> NestedDict: - if isinstance(self.config.shared_encoder_config, LSTMConfig): - # TODO (Kourosh): How does this work in RLlib today? 
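For context, this is how the tests in this patch consume `get_initial_state` when rolling out an LSTM module. A condensed sketch, assuming `get_expected_model_config` (defined in the test file above) is in scope and using the old single-return `env.reset()` API, which PATCH 19 later migrates to the tuple-returning gym API:

    import gym
    import torch
    import tree

    from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
    from ray.rllib.core.rl_module.encoder import STATE_IN
    from ray.rllib.policy.sample_batch import SampleBatch
    from ray.rllib.utils.torch_utils import convert_to_torch_tensor

    env = gym.make("Pendulum-v1")
    config = get_expected_model_config(env, lstm=True, shared_encoder=False)
    module = PPOTorchRLModule(config)

    # Add a batch dimension to every leaf of the initial state, then seed it
    # into the input batch together with a length-1 sequence marker.
    state_in = tree.map_structure(
        lambda x: x[None], convert_to_torch_tensor(module.get_initial_state())
    )
    batch = {
        SampleBatch.OBS: convert_to_torch_tensor(env.reset())[None],
        STATE_IN: state_in,
        SampleBatch.SEQ_LENS: torch.Tensor([1]),
    }
    fwd_out = module.forward_exploration(batch)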
- if isinstance(self.shared_encoder, LSTMEncoder): - return self.shared_encoder.get_inital_state() - else: - return self.pi_encoder.get_inital_state() - return {} + if isinstance(self.shared_encoder, LSTMEncoder): + return self.shared_encoder.get_inital_state() + elif isinstance(self.pi_encoder, LSTMEncoder): + return self.pi_encoder.get_inital_state() + else: + return NestedDict({}) @override(RLModule) def input_specs_inference(self) -> ModelSpec: From 3d1ebde0de7ff764a44b962503f8c0158d6849a6 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 21:46:34 +0100 Subject: [PATCH 14/24] fix get_initial_state Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 5 ++--- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 4 ++-- rllib/core/rl_module/encoder.py | 4 ++-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index b8acee1fed98..c6e9b4395bf5 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -131,7 +131,7 @@ def test_rollouts(self): } if lstm: - state_in = module.pi_encoder.get_inital_state() + state_in = module.get_initial_state() state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) @@ -167,8 +167,7 @@ def test_forward_train(self): obs = env.reset() tstep = 0 if lstm: - # TODO (Artur): Multiple states - state_in = module.pi_encoder.get_inital_state() + state_in = module.get_initial_state() state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 2502466d496f..a96e0b096493 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -202,9 +202,9 @@ def from_model_config( def get_initial_state(self) -> NestedDict: if isinstance(self.shared_encoder, LSTMEncoder): - return self.shared_encoder.get_inital_state() + return self.shared_encoder.get_initial_state() elif isinstance(self.pi_encoder, LSTMEncoder): - return self.pi_encoder.get_inital_state() + return self.pi_encoder.get_initial_state() else: return NestedDict({}) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index ceb9cd82f85a..67b197802394 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -74,7 +74,7 @@ def __init__(self, config: EncoderConfig) -> None: self._input_spec = self.input_spec() self._output_spec = self.output_spec() - def get_inital_state(self): + def get_initial_state(self): return [] def input_spec(self): @@ -128,7 +128,7 @@ def __init__(self, config: LSTMConfig) -> None: ) self.linear = nn.Linear(config.hidden_dim, config.output_dim) - def get_inital_state(self): + def get_initial_state(self): config = self.config return { "h": torch.zeros(config.num_layers, config.hidden_dim), From fdab59e1f5cd8fa9261f457bb792db592384d8f1 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 22:02:36 +0100 Subject: [PATCH 15/24] remove useless forward_exploration/forward_inference branch Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 149 +++++++++--------- 1 file changed, 74 insertions(+), 75 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 
c6e9b4395bf5..0908a2ad7b80 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -146,86 +146,85 @@ def test_rollouts(self): def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" + for shared_encoder in [False, True]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue + print( + f"[ENV={env_name}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + # collect a batch of data + batches = [] + obs = env.reset() + tstep = 0 + if lstm: + state_in = module.get_initial_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - # collect a batch of data - batches = [] - obs = env.reset() - tstep = 0 + initial_state = state_in + while tstep < 10: if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - initial_state = state_in - while tstep < 10: - if lstm: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - STATE_IN: state_in, - SampleBatch.SEQ_LENS: np.array([1]), - } - else: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy( - fwd_out["action_dist"].sample().squeeze(0) - ) - new_obs, reward, done, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.DONES: np.array(done), + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + STATE_IN: state_in, + SampleBatch.SEQ_LENS: np.array([1]), } - if lstm: - assert STATE_OUT in fwd_out - state_in = fwd_out[STATE_OUT] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batches) - # convert dict of lists to dict of tensors - fwd_in = { - k: convert_to_torch_tensor(np.array(v)) - for k, v in batch.items() + else: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(input_batch) + action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, done, _ = env.step(action) + output_batch = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.DONES: np.array(done), } if lstm: - fwd_in[STATE_IN] = initial_state - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) - - # forward train - # before training make sure module is on the right device and in - # training mode - module.to("cpu") - module.train() - fwd_out = module.forward_train(fwd_in) - loss = get_ppo_loss(fwd_in, fwd_out) - 
loss.backward() - - # check that all neural net parameters have gradients - for param in module.parameters(): - pass - self.assertIsNotNone(param.grad) + assert STATE_OUT in fwd_out + state_in = fwd_out[STATE_OUT] + batches.append(output_batch) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batches) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) + for k, v in batch.items() + } + if lstm: + fwd_in[STATE_IN] = initial_state + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) + + # forward train + # before training make sure module is on the right device and in + # training mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + pass + self.assertIsNotNone(param.grad) if __name__ == "__main__": From d7a9f17eb878239456280ebbabb3aecc1d1aae18 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 22:05:03 +0100 Subject: [PATCH 16/24] revert changes to test_ppo_with_rl_module.py Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 9f2f42c15238..8feb601b9f00 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,9 +1,9 @@ -import unittest - import numpy as np +import unittest import ray import ray.rllib.algorithms.ppo as ppo + from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY From b1bb02f4cab523462f5979af5393eaa0c3f1a4b9 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 22 Dec 2022 17:15:37 +0100 Subject: [PATCH 17/24] remove pass Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 0908a2ad7b80..cbbd995250fa 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -223,7 +223,6 @@ def test_forward_train(self): # check that all neural net parameters have gradients for param in module.parameters(): - pass self.assertIsNotNone(param.grad) From fe9226ea3415016b2a81a0223af27936eb8e44e2 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 22 Dec 2022 17:36:43 +0100 Subject: [PATCH 18/24] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 11 +- .../ppo/torch/ppo_torch_rl_module.py | 90 +++++++++++--- rllib/core/rl_module/encoder.py | 71 ++++++----- rllib/models/base_model.py | 114 +++++++++++++++--- rllib/models/torch/torch_modelv2.py | 2 +- 5 files changed, 215 insertions(+), 73 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index cbbd995250fa..9633daa68b2d 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -15,8 +15,6 @@ FCConfig, IdentityConfig, LSTMConfig, - 
STATE_IN, - STATE_OUT, ) from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -135,7 +133,7 @@ def test_rollouts(self): state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) - batch[STATE_IN] = state_in + batch["state_in_0"] = state_in batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) if fwd_fn == "forward_exploration": @@ -175,7 +173,7 @@ def test_forward_train(self): if lstm: input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - STATE_IN: state_in, + "state_in_0": state_in, SampleBatch.SEQ_LENS: np.array([1]), } else: @@ -195,8 +193,7 @@ def test_forward_train(self): SampleBatch.DONES: np.array(done), } if lstm: - assert STATE_OUT in fwd_out - state_in = fwd_out[STATE_OUT] + state_in = fwd_out["state_out_0"] batches.append(output_batch) obs = new_obs tstep += 1 @@ -209,7 +206,7 @@ def test_forward_train(self): for k, v in batch.items() } if lstm: - fwd_in[STATE_IN] = initial_state + fwd_in["state_in_0"] = initial_state fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) # forward train diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 5e65ae236c17..f6687012af8e 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -21,8 +21,8 @@ LSTMConfig, IdentityConfig, LSTMEncoder, - ENCODER_OUT, ) +from ray.rllib.models.base_model import BaseModelIOKeys torch, nn = try_import_torch() @@ -218,10 +218,27 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - shared_enc_out = self.shared_encoder(batch) - pi_enc_out = self.pi_encoder(shared_enc_out) + shared_enc_out = self.shared_encoder( + batch, + input_mapping={ + BaseModelIOKeys.IN: SampleBatch.OBS, + BaseModelIOKeys.STATE_I: "state_in_0", + }, + ) + encoder_out_pi = self.pi_encoder( + shared_enc_out, + input_mapping={ + BaseModelIOKeys.IN: self.shared_encoder.io_map[BaseModelIOKeys.OUT], + BaseModelIOKeys.STATE_I: "state_in_0", + }, + ) - action_logits = self.pi(pi_enc_out[ENCODER_OUT]) + action_logits = self.pi( + encoder_out_pi, + input_mapping={ + BaseModelIOKeys.IN: self.pi_encoder.io_map[BaseModelIOKeys.OUT], + }, + ) if self._is_discrete: action = torch.argmax(action_logits, dim=-1) @@ -230,12 +247,12 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = TorchDeterministic(action) output = {SampleBatch.ACTION_DIST: action_dist} - output["state_out"] = pi_enc_out.get("state_out", {}) + output["state_out"] = encoder_out_pi.get("state_out", {}) return output @override(RLModule) def input_specs_exploration(self): - return self.shared_encoder.input_spec() + return self.shared_encoder.input_spec @override(RLModule) def output_specs_exploration(self) -> SpecDict: @@ -260,10 +277,33 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: policy distribution to be used for computing KL divergence between the old policy and the new policy during training. 
""" - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) + shared_enc_out = self.shared_encoder( + batch, + input_mapping={ + BaseModelIOKeys.IN: SampleBatch.OBS, + BaseModelIOKeys.STATE_IN: "state_in_0", + }, + ) + encoder_out_pi = self.pi_encoder( + shared_enc_out, + input_mapping={ + BaseModelIOKeys.IN: self.shared_encoder.io_map[BaseModelIOKeys.OUT], + BaseModelIOKeys.STATE_IN: "state_in_0", + }, + ) + encoder_out_vf = self.vf_encoder( + shared_enc_out, + input_mapping={ + BaseModelIOKeys.IN: self.shared_encoder.io_map[BaseModelIOKeys.OUT], + }, + ) + + action_logits = self.pi( + encoder_out_pi, + input_mapping={ + BaseModelIOKeys.IN: self.pi_encoder.io_map[BaseModelIOKeys.OUT], + }, + ) output = {} if self._is_discrete: @@ -277,7 +317,13 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output[SampleBatch.ACTION_DIST] = action_dist # compute the value function - output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf[ENCODER_OUT]).squeeze(-1) + vf_out = self.vf( + encoder_out_vf, + input_mapping={ + BaseModelIOKeys.IN: self.vf_encoder.io_map[BaseModelIOKeys.OUT], + }, + ) + output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) output["state_out"] = encoder_out_pi.get("state_out", {}) return output @@ -289,7 +335,7 @@ def input_specs_train(self) -> SpecDict: action_dim = self.config.action_space.shape[0] action_spec = TorchTensorSpec("b, h", h=action_dim) - spec_dict = self.shared_encoder.input_spec() + spec_dict = self.shared_encoder.input_spec spec_dict.update({SampleBatch.ACTIONS: action_spec}) if SampleBatch.OBS in spec_dict: spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS] @@ -310,12 +356,22 @@ def output_specs_train(self) -> SpecDict: @override(RLModule) def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) + encoder_out = self.shared_encoder.forward( + batch, input_map={self.shared_encoder.io_map.IN: SampleBatch.OBS} + ) + encoder_out_pi = self.pi_encoder.forward( + encoder_out, + input_map={self.pi_encoder.io_map.IN: self.shared_encoder.io_map.OUT}, + ) + encoder_out_vf = self.vf_encoder.forward( + encoder_out, + input_map={self.vf_encoder.io_map.IN: self.shared_encoder.io_map.OUT}, + ) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) - vf = self.vf(encoder_out_vf[ENCODER_OUT]) + action_logits = self.pi.forward( + encoder_out_pi, io_map=self.pi_encoder.input_map + ) + vf = self.vf.forward(encoder_out_vf, io_map=self.pi_encoder.input_map) if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index f3bb22b46900..aa2f519747b2 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -5,18 +5,13 @@ from dataclasses import dataclass, field +from ray.rllib.utils.annotations import override from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.torch.primitives import FCNet - -# TODO (Kourosh): Find a better / more straight fwd approach for sub-components - 
-ENCODER_OUT = "encoder_out" -STATE_IN = "state_in" -STATE_OUT = "state_out" +from ray.rllib.models.base_model import Model, BaseModelIOKeys @dataclass @@ -68,12 +63,11 @@ def build(self): return LSTMEncoder(self) -class Encoder(nn.Module): +class Encoder(nn.Module, Model): def __init__(self, config: EncoderConfig) -> None: - super().__init__() + nn.Module.__init__(self) + Model.__init__(self) self.config = config - self._input_spec = self.input_spec() - self._output_spec = self.output_spec() def get_initial_state(self): return [] @@ -84,14 +78,6 @@ def input_spec(self): def output_spec(self): return SpecDict() - @check_input_specs("_input_spec") - @check_output_specs("_output_spec") - def forward(self, input_dict): - return self._forward(input_dict) - - def _forward(self, input_dict): - raise NotImplementedError - class FullyConnectedEncoder(Encoder): def __init__(self, config: FCConfig) -> None: @@ -104,21 +90,38 @@ def __init__(self, config: FCConfig) -> None: activation=config.activation, ) + @property + @override(Model) def input_spec(self): return SpecDict( - {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} + { + self.io_mapping[BaseModelIOKeys.IN]: TorchTensorSpec( + "b, h", h=self.config.input_dim + ) + } ) + @property + @override(Model) def output_spec(self): return SpecDict( - {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} + { + self.io_mapping[BaseModelIOKeys.IN]: TorchTensorSpec( + "b, h", h=self.config.output_dim + ) + } ) def _forward(self, input_dict): - return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} + return {self.ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} -class LSTMEncoder(Encoder): +class RecurrentEncoder(Encoder): + def __init__(self, config: EncoderConfig): + super().__init__(config=config) + + +class LSTMEncoder(RecurrentEncoder): def __init__(self, config: LSTMConfig) -> None: super().__init__(config) @@ -137,13 +140,15 @@ def get_initial_state(self): "c": torch.zeros(config.num_layers, config.hidden_dim), } + @property + @override(Model) def input_spec(self): config = self.config return SpecDict( { # bxt is just a name for better readability to indicated padded batch - SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), - STATE_IN: { + self.ENCODER_IN: TorchTensorSpec("bxt, h", h=config.input_dim), + self.STATE_IN: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -154,12 +159,14 @@ def input_spec(self): } ) + @property + @override(Model) def output_spec(self): config = self.config return SpecDict( { - ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - STATE_OUT: { + self.ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), + self.STATE_OUT: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -171,8 +178,8 @@ def output_spec(self): ) def _forward(self, input_dict: SampleBatch): - x = input_dict[SampleBatch.OBS] - states = input_dict[STATE_IN] + x = input_dict[self.STATE_IN] + states = input_dict[self.STATE_IN] # states are batch-first when coming in states = tree.map_structure(lambda x: x.transpose(0, 1), states) @@ -189,14 +196,14 @@ def _forward(self, input_dict: SampleBatch): x = x.view(-1, x.shape[-1]) return { - ENCODER_OUT: x, - STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + self.ENCODER_OUT: x, + self.STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), } class IdentityEncoder(Encoder): def __init__(self, config: EncoderConfig) -> None: - super().__init__(config) + 
super().__init__(config=config) def _forward(self, input_dict): return input_dict diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py index c006af27f6f1..35aeb7da4943 100644 --- a/rllib/models/base_model.py +++ b/rllib/models/base_model.py @@ -1,22 +1,15 @@ -# Copyright 2021 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - import abc +from enum import Enum from typing import Optional, Tuple +from collections import defaultdict +from typing import Mapping +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import ( + check_input_specs, + check_output_specs, +) -from ray.rllib.models.temp_spec_classes import TensorDict, SpecDict, ModelConfig +from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig from ray.rllib.utils.annotations import ( DeveloperAPI, OverrideToImplementCustomLogic, @@ -30,6 +23,82 @@ UnrollOutputType = Tuple[TensorDict, TensorDict] +@ExperimentalAPI +class BaseModelIOKeys(Enum): + IN = "in" + OUT = "out" + STATE_IN = "state_in" + STATE_OUT = "state_out" + + +class ModelIOMapping(Mapping): + """A mapping from general ModelIOKeys to their instance-based counterparts. + + In order to distinguish keys in input- and outputs-specs of multiple instances of + a give model, each instance is supposed to have distinct keys. + This Mapping provides a way to generate distinct keys per instance of + ModelIOMapping. + + """ + + __init_counters__ = defaultdict(lambda: 0) + + def __init__(self, model_name: str): + self._name: str = model_name + self._init_idx: str = str(self.__init_counters__[model_name]) + self.__init_counters__[model_name] += 1 + self._valid_keys = set() + + def __getitem__(self, item: str): + if item in self._valid_keys: + return self._name + "_" + item + "_" + self._init_idx + else: + raise KeyError( + "`{}` is not a key of ModelIOKeyGenerator with name `{}` " + "and index `{}`. 
Valid keys are `{}`".format( + item, self._name, self._init_idx, self._valid_keys + ) + ) + + def add(self, key: str): + self._valid_keys.add(key) + + def __repr__(self): + return ( + "ModelIOKeyGenerator for model {} with index {} and valid keys {" + "}".format(self._name, self._init_idx, self._valid_keys) + ) + + def __iter__(self): + return self._valid_keys.__iter__() + + def __len__(self): + return self._valid_keys.__len__() + + def __contains__(self): + return self._valid_keys.__contains__() + + def keys(self): + return iter(self._valid_keys) + + def items(self): + return iter([(k, self[k]) for k in self._valid_keys]) + + def values(self): + return iter([self[k] for k in self._valid_keys]) + + def get(self, name): + raise NotImplementedError + + def __eq__(self, other: "ModelIOMapping") -> bool: + assert isinstance(other, ModelIOMapping) + return self._valid_keys.__eq__(other._valid_keys) + + def __ne__(self, other: "ModelIOMapping") -> bool: + assert isinstance(other, ModelIOMapping) + return self._valid_keys.__ne__(other._valid_keys) + + @ExperimentalAPI class RecurrentModel(abc.ABC): """The base model all other models are based on. @@ -57,6 +126,11 @@ class RecurrentModel(abc.ABC): def __init__(self, name: Optional[str] = None): self._name = name or self.__class__.__name__ + self.io_mapping = ModelIOMapping(self._name) + self.io_mapping.add(BaseModelIOKeys.IN) + self.io_mapping.add(BaseModelIOKeys.OUT) + self.io_mapping.add(BaseModelIOKeys.STATE_IN) + self.io_mapping.add(BaseModelIOKeys.STATE_OUT) @property def name(self) -> str: @@ -272,6 +346,14 @@ def _unroll( outputs = self._forward(inputs, **kwargs) return outputs, TensorDict() + def forward(self, input_dict, input_mapping: Mapping = None) -> ForwardOutputType: + if input_mapping: + for forward_key, input_dict_key in input_mapping.items(): + input_dict[self.io_mapping[forward_key]] = input_dict[input_dict_key] + return check_input_specs("input_spec")( + (check_output_specs("outputs_spec")(self._forward(input_dict))) + ) + @abc.abstractmethod def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: """Computes the output of this module for each timestep. 
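The purpose of `ModelIOMapping.__init_counters__` above is that two instances of the same model class generate distinct concrete keys, so both can write into one flat tensor dict without collisions. A small sketch of the intended behavior; note that `__getitem__` as written in this patch concatenates raw keys, so passing enum members only works once PATCH 20 below changes the concatenation to `str(item)`:

    from ray.rllib.models.base_model import BaseModelIOKeys, ModelIOMapping

    a = ModelIOMapping("encoder")
    a.add(BaseModelIOKeys.IN)
    b = ModelIOMapping("encoder")
    b.add(BaseModelIOKeys.IN)

    # Same logical key, different instances -> different concrete keys,
    # roughly "encoder_<key>_0" vs. "encoder_<key>_1".
    assert a[BaseModelIOKeys.IN] != b[BaseModelIOKeys.IN]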
diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py index b56bf425fb6f..728d6b2fa431 100644 --- a/rllib/models/torch/torch_modelv2.py +++ b/rllib/models/torch/torch_modelv2.py @@ -67,7 +67,7 @@ def variables( ) -> Union[List[TensorType], Dict[str, TensorType]]: p = list(self.parameters()) if as_dict: - return {k: p[i] for i, k in enumerate(self.state_dict().keys())} + return {k: p[i] for i, k in enumerate(self.state_dict().io_map())} return p @override(ModelV2) From 8f36e450c7548866347f2b3796b05afc00be3b4e Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 22 Dec 2022 18:07:17 +0100 Subject: [PATCH 19/24] fix gym incompatability Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 9 +++++---- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index cbbd995250fa..8e3fa4201820 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -124,7 +124,7 @@ def test_rollouts(self): config = get_expected_model_config(env, lstm, shared_encoder) module = PPOTorchRLModule(config) - obs = env.reset() + obs, _ = env.reset() batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], @@ -163,7 +163,7 @@ def test_forward_train(self): # collect a batch of data batches = [] - obs = env.reset() + obs, _ = env.reset() tstep = 0 if lstm: state_in = module.get_initial_state() @@ -186,13 +186,14 @@ def test_forward_train(self): action = convert_to_numpy( fwd_out["action_dist"].sample().squeeze(0) ) - new_obs, reward, done, _ = env.step(action) + new_obs, reward, terminated, truncated, _ = env.step(action) output_batch = { SampleBatch.OBS: obs, SampleBatch.NEXT_OBS: new_obs, SampleBatch.ACTIONS: action, SampleBatch.REWARDS: np.array(reward), - SampleBatch.DONES: np.array(done), + SampleBatch.TERMINATEDS: np.array(terminated), + SampleBatch.TRUNCATEDS: np.array(truncated), } if lstm: assert STATE_OUT in fwd_out diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 5e65ae236c17..7db520559e75 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -23,6 +23,7 @@ LSTMEncoder, ENCODER_OUT, ) +from rllib.utils.gym import convert_old_gym_space_to_gymnasium_space torch, nn = try_import_torch() @@ -93,7 +94,10 @@ def setup(self) -> None: activation=self.config.vf_config.activation, ) - self._is_discrete = isinstance(self.config.action_space, gym.spaces.Discrete) + self._is_discrete = isinstance( + convert_old_gym_space_to_gymnasium_space(self.config.action_space), + gym.spaces.Discrete, + ) @classmethod @override(RLModule) From 30be028be19c970e79d9e18c7587d8b6db1ced08 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Sun, 25 Dec 2022 12:36:30 +0100 Subject: [PATCH 20/24] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 4 +- .../ppo/tests/test_ppo_with_rl_module.py | 2 +- .../ppo/torch/ppo_torch_rl_module.py | 57 +++++++++++-------- rllib/core/rl_module/encoder.py | 48 ++++++++++------ rllib/models/base_model.py | 39 ++++++------- 5 files changed, 85 insertions(+), 65 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 0e5668a5313d..63ea3d0608e4 
100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -108,7 +108,7 @@ def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [True, False]: + for lstm in [False, True]: if lstm and shared_encoder: # Not yet implemented # TODO (Artur): Implement @@ -145,7 +145,7 @@ def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: for shared_encoder in [False, True]: - for lstm in [True, False]: + for lstm in [False, True]: if lstm and shared_encoder: # Not yet implemented # TODO (Artur): Implement diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 8feb601b9f00..eb6d392f9a1c 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -98,7 +98,7 @@ def test_ppo_compilation_and_schedule_mixins(self): for env in ["CartPole-v1", "Pendulum-v1"]: print("Env={}".format(env)) # TODO (Kourosh): for now just do lstm=False - for lstm in [False]: + for lstm in [False, True]: print("LSTM={}".format(lstm)) config.training( model=dict( diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 12c72f42e1e8..36b1cd534076 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -222,28 +222,28 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - shared_enc_out = self.shared_encoder( + x = self.shared_encoder( batch, input_mapping={ BaseModelIOKeys.IN: SampleBatch.OBS, BaseModelIOKeys.STATE_I: "state_in_0", }, ) - encoder_out_pi = self.pi_encoder( - shared_enc_out, + x = self.pi_encoder( + x, input_mapping={ - BaseModelIOKeys.IN: self.shared_encoder.io_map[BaseModelIOKeys.OUT], + BaseModelIOKeys.IN: self.shared_encoder.io[BaseModelIOKeys.OUT], BaseModelIOKeys.STATE_I: "state_in_0", }, ) - action_logits = self.pi( - encoder_out_pi, + x = self.pi( + x, input_mapping={ - BaseModelIOKeys.IN: self.pi_encoder.io_map[BaseModelIOKeys.OUT], + BaseModelIOKeys.IN: self.pi_encoder.io[BaseModelIOKeys.OUT], }, ) - + action_logits = x["action_logits"] if self._is_discrete: action = torch.argmax(action_logits, dim=-1) else: @@ -251,7 +251,7 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = TorchDeterministic(action) output = {SampleBatch.ACTION_DIST: action_dist} - output["state_out"] = encoder_out_pi.get("state_out", {}) + output["state_out"] = x.get("state_out", {}) return output @override(RLModule) @@ -281,34 +281,41 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: policy distribution to be used for computing KL divergence between the old policy and the new policy during training. 
""" - shared_enc_out = self.shared_encoder( + x = self.shared_encoder( batch, input_mapping={ - BaseModelIOKeys.IN: SampleBatch.OBS, - BaseModelIOKeys.STATE_IN: "state_in_0", + self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS, + self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", }, ) - encoder_out_pi = self.pi_encoder( - shared_enc_out, + x = self.pi_encoder( + x, input_mapping={ - BaseModelIOKeys.IN: self.shared_encoder.io_map[BaseModelIOKeys.OUT], - BaseModelIOKeys.STATE_IN: "state_in_0", + self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], + self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", }, ) - encoder_out_vf = self.vf_encoder( - shared_enc_out, + x = self.vf_encoder( + x, input_mapping={ - BaseModelIOKeys.IN: self.shared_encoder.io_map[BaseModelIOKeys.OUT], + self.vf_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], }, ) - action_logits = self.pi( - encoder_out_pi, + x = self.pi( + x, input_mapping={ - BaseModelIOKeys.IN: self.pi_encoder.io_map[BaseModelIOKeys.OUT], + self.pi.io[BaseModelIOKeys.OUT]: self.pi_encoder.io[ + BaseModelIOKeys.OUT + ], }, ) + action_logits = x[self.pi.io[BaseModelIOKeys.OUT]] output = {} if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) @@ -322,13 +329,13 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: # compute the value function vf_out = self.vf( - encoder_out_vf, + x, input_mapping={ - BaseModelIOKeys.IN: self.vf_encoder.io_map[BaseModelIOKeys.OUT], + self.vf.io.IN: self.vf_encoder.io.OUT, }, ) output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) - output["state_out"] = encoder_out_pi.get("state_out", {}) + output["state_out"] = x.get("state_out", {}) return output @override(RLModule) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index aa2f519747b2..a83303a4d99f 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -3,6 +3,7 @@ import tree from typing import List +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from dataclasses import dataclass, field from ray.rllib.utils.annotations import override @@ -63,7 +64,7 @@ def build(self): return LSTMEncoder(self) -class Encoder(nn.Module, Model): +class Encoder(Model, nn.Module): def __init__(self, config: EncoderConfig) -> None: nn.Module.__init__(self) Model.__init__(self) @@ -72,9 +73,11 @@ def __init__(self, config: EncoderConfig) -> None: def get_initial_state(self): return [] + @property def input_spec(self): return SpecDict() + @property def output_spec(self): return SpecDict() @@ -95,7 +98,7 @@ def __init__(self, config: FCConfig) -> None: def input_spec(self): return SpecDict( { - self.io_mapping[BaseModelIOKeys.IN]: TorchTensorSpec( + self.io[BaseModelIOKeys.IN]: TorchTensorSpec( "b, h", h=self.config.input_dim ) } @@ -106,14 +109,17 @@ def input_spec(self): def output_spec(self): return SpecDict( { - self.io_mapping[BaseModelIOKeys.IN]: TorchTensorSpec( + self.io[BaseModelIOKeys.out]: TorchTensorSpec( "b, h", h=self.config.output_dim ) } ) - def _forward(self, input_dict): - return {self.ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} + @check_input_specs("input_spec") + @check_output_specs("output_spec") + def _forward(self, input_dict, **kwargs): + inputs = input_dict[self.io[BaseModelIOKeys.IN]] + return {self.io[BaseModelIOKeys.OUT]: self.net(inputs)} class RecurrentEncoder(Encoder): @@ -147,8 +153,10 @@ def input_spec(self): return 
SpecDict( { # bxt is just a name for better readability to indicated padded batch - self.ENCODER_IN: TorchTensorSpec("bxt, h", h=config.input_dim), - self.STATE_IN: { + self.io[BaseModelIOKeys.IN]: TorchTensorSpec( + "bxt, h", h=config.input_dim + ), + self.io[BaseModelIOKeys.STATE_IN]: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -165,8 +173,10 @@ def output_spec(self): config = self.config return SpecDict( { - self.ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - self.STATE_OUT: { + self.io[BaseModelIOKeys.OUT]: TorchTensorSpec( + "bxt, h", h=config.output_dim + ), + self.io[BaseModelIOKeys.STATE_OUT]: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -177,9 +187,11 @@ def output_spec(self): } ) - def _forward(self, input_dict: SampleBatch): - x = input_dict[self.STATE_IN] - states = input_dict[self.STATE_IN] + @check_input_specs("input_spec") + @check_output_specs("output_spec") + def _forward(self, input_dict: SampleBatch, **kwargs): + x = input_dict[self.io[BaseModelIOKeys.IN]] + states = input_dict[self.io[BaseModelIOKeys.STATE_IN]] # states are batch-first when coming in states = tree.map_structure(lambda x: x.transpose(0, 1), states) @@ -196,8 +208,10 @@ def _forward(self, input_dict: SampleBatch): x = x.view(-1, x.shape[-1]) return { - self.ENCODER_OUT: x, - self.STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + self.io[BaseModelIOKeys.OUT]: x, + self.io[BaseModelIOKeys.STATE_OUT]: tree.map_structure( + lambda x: x.transpose(0, 1), states_o + ), } @@ -205,5 +219,7 @@ class IdentityEncoder(Encoder): def __init__(self, config: EncoderConfig) -> None: super().__init__(config=config) - def _forward(self, input_dict): - return input_dict + @check_input_specs("input_spec") + @check_output_specs("output_spec") + def _forward(self, input_dict, **kwargs): + return {self.io[BaseModelIOKeys.OUT]: input_dict[self.io[BaseModelIOKeys.IN]]} diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py index 35aeb7da4943..7c62202a3ccc 100644 --- a/rllib/models/base_model.py +++ b/rllib/models/base_model.py @@ -4,10 +4,6 @@ from collections import defaultdict from typing import Mapping from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.specs.checker import ( - check_input_specs, - check_output_specs, -) from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig from ray.rllib.utils.annotations import ( @@ -25,10 +21,10 @@ @ExperimentalAPI class BaseModelIOKeys(Enum): - IN = "in" - OUT = "out" - STATE_IN = "state_in" - STATE_OUT = "state_out" + IN: str = "in" + OUT: str = "out" + STATE_IN: str = "state_in" + STATE_OUT: str = "state_out" class ModelIOMapping(Mapping): @@ -49,9 +45,9 @@ def __init__(self, model_name: str): self.__init_counters__[model_name] += 1 self._valid_keys = set() - def __getitem__(self, item: str): + def __getitem__(self, item): if item in self._valid_keys: - return self._name + "_" + item + "_" + self._init_idx + return self._name + "_" + str(item) + "_" + self._init_idx else: raise KeyError( "`{}` is not a key of ModelIOKeyGenerator with name `{}` " @@ -60,7 +56,7 @@ def __getitem__(self, item: str): ) ) - def add(self, key: str): + def add(self, key): self._valid_keys.add(key) def __repr__(self): @@ -126,11 +122,11 @@ class RecurrentModel(abc.ABC): def __init__(self, name: Optional[str] = None): self._name = name or self.__class__.__name__ - self.io_mapping = ModelIOMapping(self._name) - 
self.io_mapping.add(BaseModelIOKeys.IN) - self.io_mapping.add(BaseModelIOKeys.OUT) - self.io_mapping.add(BaseModelIOKeys.STATE_IN) - self.io_mapping.add(BaseModelIOKeys.STATE_OUT) + self.io = ModelIOMapping(self._name) + self.io.add(BaseModelIOKeys.IN) + self.io.add(BaseModelIOKeys.OUT) + self.io.add(BaseModelIOKeys.STATE_IN) + self.io.add(BaseModelIOKeys.STATE_OUT) @property def name(self) -> str: @@ -346,13 +342,14 @@ def _unroll( outputs = self._forward(inputs, **kwargs) return outputs, TensorDict() - def forward(self, input_dict, input_mapping: Mapping = None) -> ForwardOutputType: + def forward( + self, input_dict, input_mapping: Mapping = None, **kwargs + ) -> ForwardOutputType: if input_mapping: for forward_key, input_dict_key in input_mapping.items(): - input_dict[self.io_mapping[forward_key]] = input_dict[input_dict_key] - return check_input_specs("input_spec")( - (check_output_specs("outputs_spec")(self._forward(input_dict))) - ) + input_dict[forward_key] = input_dict[input_dict_key] + + return self._forward(input_dict, **kwargs) @abc.abstractmethod def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: From b1f8064d5f642418934abe103bf3cc856eafe6fb Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 26 Dec 2022 21:17:53 +0100 Subject: [PATCH 21/24] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 20 ++- .../ppo/torch/ppo_torch_rl_module.py | 123 +++++++++--------- rllib/core/rl_module/encoder.py | 2 +- rllib/models/base_model.py | 7 +- rllib/utils/nested_dict.py | 7 +- 5 files changed, 81 insertions(+), 78 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 63ea3d0608e4..95b1283d3843 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -74,8 +74,18 @@ def get_expected_model_config( activation="ReLU", ) - pi_config = FCConfig() - vf_config = FCConfig() + pi_config = FCConfig( + input_dim=pi_encoder_config.output_dim, + hidden_layers=[16], + ) + if isinstance(env.action_space, gym.spaces.Discrete): + pi_config.output_dim = env.action_space.n + else: + pi_config.output_dim = env.action_space.shape[0] * 2 + + vf_config = FCConfig( + input_dim=vf_encoder_config.output_dim, hidden_layers=[16], output_dim=1 + ) if isinstance(env.action_space, gym.spaces.Discrete): pi_config.output_dim = env.action_space.n @@ -115,7 +125,7 @@ def test_rollouts(self): continue print( f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" - f"={lstm}" + f"={lstm} | [FWD={fwd_fn}" ) env = gym.make(env_name) @@ -144,8 +154,8 @@ def test_rollouts(self): def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: - for shared_encoder in [False, True]: - for lstm in [False, True]: + for shared_encoder in [False]: + for lstm in [True]: if lstm and shared_encoder: # Not yet implemented # TODO (Artur): Implement diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 36b1cd534076..f3f041c754d6 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -16,7 +16,6 @@ TorchDiagGaussian, ) from ray.rllib.core.rl_module.encoder import ( - FCNet, FCConfig, LSTMConfig, IdentityConfig, @@ -79,20 +78,8 @@ def setup(self) -> None: self.shared_encoder = self.config.shared_encoder_config.build() self.pi_encoder = 
self.config.pi_encoder_config.build() self.vf_encoder = self.config.vf_encoder_config.build() - - self.pi = FCNet( - input_dim=self.config.pi_encoder_config.output_dim, - output_dim=self.config.pi_config.output_dim, - hidden_layers=self.config.pi_config.hidden_layers, - activation=self.config.pi_config.activation, - ) - - self.vf = FCNet( - input_dim=self.config.vf_encoder_config.output_dim, - output_dim=1, - hidden_layers=self.config.vf_config.hidden_layers, - activation=self.config.vf_config.activation, - ) + self.pi = self.config.pi_config.build() + self.vf = self.config.vf_config.build() self._is_discrete = isinstance( convert_old_gym_space_to_gymnasium_space(self.config.action_space), @@ -212,10 +199,6 @@ def get_initial_state(self) -> NestedDict: else: return NestedDict({}) - @override(RLModule) - def input_specs_inference(self) -> SpecDict: - return self.input_specs_exploration() - @override(RLModule) def output_specs_inference(self) -> SpecDict: return SpecDict({SampleBatch.ACTION_DIST: TorchDeterministic}) @@ -225,25 +208,27 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: x = self.shared_encoder( batch, input_mapping={ - BaseModelIOKeys.IN: SampleBatch.OBS, - BaseModelIOKeys.STATE_I: "state_in_0", + self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS, + self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", }, ) x = self.pi_encoder( x, input_mapping={ - BaseModelIOKeys.IN: self.shared_encoder.io[BaseModelIOKeys.OUT], - BaseModelIOKeys.STATE_I: "state_in_0", + self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], + self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", }, ) x = self.pi( x, input_mapping={ - BaseModelIOKeys.IN: self.pi_encoder.io[BaseModelIOKeys.OUT], + self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT], }, ) - action_logits = x["action_logits"] + action_logits = x[self.pi.io[BaseModelIOKeys.OUT]] if self._is_discrete: action = torch.argmax(action_logits, dim=-1) else: @@ -251,13 +236,9 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = TorchDeterministic(action) output = {SampleBatch.ACTION_DIST: action_dist} - output["state_out"] = x.get("state_out", {}) + output["state_out_0"] = x.get("state_out", {}) return output - @override(RLModule) - def input_specs_exploration(self): - return self.shared_encoder.input_spec - @override(RLModule) def output_specs_exploration(self) -> SpecDict: specs = {SampleBatch.ACTION_DIST: self.__get_action_dist_type()} @@ -309,14 +290,12 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: x = self.pi( x, input_mapping={ - self.pi.io[BaseModelIOKeys.OUT]: self.pi_encoder.io[ - BaseModelIOKeys.OUT - ], + self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT], }, ) - action_logits = x[self.pi.io[BaseModelIOKeys.OUT]] output = {} + action_logits = x[self.pi.io[BaseModelIOKeys.OUT]] if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) output[SampleBatch.ACTION_DIST_INPUTS] = {"logits": action_logits} @@ -331,27 +310,17 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: vf_out = self.vf( x, input_mapping={ - self.vf.io.IN: self.vf_encoder.io.OUT, + self.vf.io[BaseModelIOKeys.IN]: self.vf_encoder.io[BaseModelIOKeys.OUT], }, ) - output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) - output["state_out"] = x.get("state_out", {}) - return output + output[SampleBatch.VF_PREDS] = vf_out[self.vf.io[BaseModelIOKeys.OUT]].squeeze( + 
diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py
index a83303a4d99f..46656200d014 100644
--- a/rllib/core/rl_module/encoder.py
+++ b/rllib/core/rl_module/encoder.py
@@ -109,7 +109,7 @@ def input_spec(self):
     def output_spec(self):
         return SpecDict(
             {
-                self.io[BaseModelIOKeys.out]: TorchTensorSpec(
+                self.io[BaseModelIOKeys.OUT]: TorchTensorSpec(
                     "b, h", h=self.config.output_dim
                 )
             }
diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py
index 7c62202a3ccc..fb80c2efdacc 100644
--- a/rllib/models/base_model.py
+++ b/rllib/models/base_model.py
@@ -347,9 +347,10 @@ def forward(
     ) -> ForwardOutputType:
         if input_mapping:
             for forward_key, input_dict_key in input_mapping.items():
-                input_dict[forward_key] = input_dict[input_dict_key]
-
-        return self._forward(input_dict, **kwargs)
+                if input_dict_key in input_dict:
+                    input_dict[forward_key] = input_dict[input_dict_key]
+        input_dict.update(self._forward(input_dict, **kwargs))
+        return input_dict
 
     @abc.abstractmethod
     def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType:
diff --git a/rllib/utils/nested_dict.py b/rllib/utils/nested_dict.py
index 1e4d308d1ef5..05108a0acc07 100644
--- a/rllib/utils/nested_dict.py
+++ b/rllib/utils/nested_dict.py
@@ -166,10 +166,7 @@ def get(
         k = _flatten_index(k)
 
         if k not in self:
-            if default is not None:
-                return default
-            else:
-                raise KeyError(k)
+            return default
 
         data_ptr = self._data
         for key in k:
@@ -180,6 +177,8 @@ def get(
         return data_ptr
 
     def __getitem__(self, k: SeqStrType) -> T:
+        if k not in self:
+            raise KeyError(k)
         output = self.get(k)
         return output
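With this `NestedDict` change the lookup semantics split cleanly: `get()` returns the default for missing keys instead of raising, while `__getitem__` takes over raising `KeyError`. That split is what lets the `x.get(...STATE_OUT...)` probes above return `None` for stateless models. A usage sketch, assuming the `get` signature (truncated in the hunk header) defaults to `default=None`:

    # Illustrative usage of NestedDict after the hunk above; assumes
    # get()'s default parameter is None when omitted.
    from ray.rllib.utils.nested_dict import NestedDict

    d = NestedDict({"a": {"b": 1}})
    assert d.get(("a", "b")) == 1            # nested lookup by key sequence
    assert d.get("missing") is None          # missing key -> default, no raise
    assert d.get("missing", default=5) == 5  # explicit default still honored
    try:
        _ = d["missing"]                     # __getitem__ still raises
    except KeyError:
        pass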
From 05cc03a03a7105d209956d6cf00ab33dce8daed7 Mon Sep 17 00:00:00 2001
From: Artur Niederfahrenhorst
Date: Tue, 27 Dec 2022 14:25:15 +0100
Subject: [PATCH 22/24] fix lstm test

Signed-off-by: Artur Niederfahrenhorst
---
 rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py | 2 +-
 rllib/algorithms/ppo/torch/ppo_torch_rl_module.py     | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
index eb6d392f9a1c..8feb601b9f00 100644
--- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
+++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -98,7 +98,7 @@ def test_ppo_compilation_and_schedule_mixins(self):
         for env in ["CartPole-v1", "Pendulum-v1"]:
             print("Env={}".format(env))
             # TODO (Kourosh): for now just do lstm=False
-            for lstm in [False, True]:
+            for lstm in [False]:
                 print("LSTM={}".format(lstm))
                 config.training(
                     model=dict(
diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
index f3f041c754d6..074889d034cf 100644
--- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
+++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -319,7 +319,10 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
 
         shared_encoder_state = x.get(self.shared_encoder.io[BaseModelIOKeys.STATE_OUT])
         pi_encoder_state = x.get(self.pi_encoder.io[BaseModelIOKeys.STATE_OUT])
-        output["state_out_0"] = shared_encoder_state or pi_encoder_state
+
+        state_out = shared_encoder_state or pi_encoder_state
+        if state_out:
+            output["state_out_0"] = state_out
         return output
 
     @override(RLModule)
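The guard added in this patch matters because at most one of the two encoders is recurrent: `shared_encoder_state or pi_encoder_state` picks whichever state exists, and the new `if state_out:` check keeps a fully stateless module from emitting an empty `state_out_0`. The selection logic in isolation (plain Python, not RLlib code):

    # Behavior of the state_out fallback after the fix above.
    def select_state_out(shared_encoder_state, pi_encoder_state):
        output = {}
        state_out = shared_encoder_state or pi_encoder_state
        if state_out:  # skip the key entirely for stateless models
            output["state_out_0"] = state_out
        return output

    assert select_state_out(None, None) == {}  # no LSTM anywhere
    assert select_state_out({"h": 1}, None) == {"state_out_0": {"h": 1}}
    assert select_state_out(None, {"h": 2}) == {"state_out_0": {"h": 2}}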
From 66ac43c0ea9f8e28e7356c9d1cbaceaa3ac7ad10 Mon Sep 17 00:00:00 2001
From: Artur Niederfahrenhorst
Date: Tue, 27 Dec 2022 14:34:52 +0100
Subject: [PATCH 23/24] add missing ray.

Signed-off-by: Artur Niederfahrenhorst
---
 rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
index 7db520559e75..40c0d1fd302c 100644
--- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
+++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -23,7 +23,7 @@
     LSTMEncoder,
     ENCODER_OUT,
 )
-from rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
+from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
 
 torch, nn = try_import_torch()
 
From 1d09f0336aa83a4b953d5162e39dd9f2be8195b1 Mon Sep 17 00:00:00 2001
From: Artur Niederfahrenhorst
Date: Tue, 3 Jan 2023 11:16:16 +0100
Subject: [PATCH 24/24] typo

Signed-off-by: Artur Niederfahrenhorst
---
 rllib/models/base_model.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py
index fb80c2efdacc..516bfa3b387f 100644
--- a/rllib/models/base_model.py
+++ b/rllib/models/base_model.py
@@ -50,7 +50,7 @@ def __getitem__(self, item):
             return self._name + "_" + str(item) + "_" + self._init_idx
         else:
             raise KeyError(
-                "`{}` is not a key of ModelIOKeyGenerator with name `{}` "
+                "`{}` is not a key of ModelIOMapping for model_name `{}` "
                 "and index `{}`. Valid keys are `{}`".format(
                     item, self._name, self._init_idx, self._valid_keys
                 )
             )
@@ -60,9 +60,8 @@ def add(self, key):
         self._valid_keys.add(key)
 
     def __repr__(self):
-        return (
-            "ModelIOKeyGenerator for model {} with index {} and valid keys {"
-            "}".format(self._name, self._init_idx, self._valid_keys)
+        return "ModelIOMapping for model {} with index {} and valid keys {" "}".format(
+            self._name, self._init_idx, self._valid_keys
         )
 
     def __iter__(self):
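PATCH 24 only reworks `ModelIOMapping`'s error message and `__repr__`, but the surrounding hunks show the naming scheme those messages describe: the mapping turns a generic key such as `out` into a per-model string like `pi_out_0`. A toy reconstruction of just that behavior, mirroring the snippet above (the membership check on `_valid_keys` is an assumption, since the `if` line falls outside the hunk):

    # Toy version of the key namespacing described above; illustrative only,
    # not RLlib's actual ModelIOMapping class.
    class ToyIOMapping:
        def __init__(self, name: str, init_idx: int):
            self._name = name
            self._init_idx = str(init_idx)
            self._valid_keys = set()

        def add(self, key: str) -> None:
            self._valid_keys.add(key)

        def __getitem__(self, item: str) -> str:
            if item in self._valid_keys:  # assumed guard, elided in the hunk
                return self._name + "_" + str(item) + "_" + self._init_idx
            raise KeyError(
                "`{}` is not a key of ModelIOMapping for model_name `{}` "
                "and index `{}`. Valid keys are `{}`".format(
                    item, self._name, self._init_idx, self._valid_keys
                )
            )

    io = ToyIOMapping("pi", 0)
    io.add("out")
    assert io["out"] == "pi_out_0"  # a namespaced key other models can map to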