From b1ddd7396280851c166822bbd04d8e45148ed65b Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 17:40:09 +0100 Subject: [PATCH 01/51] initial Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 130 ++++++++++++++++++ .../ppo/torch/ppo_torch_rl_module.py | 84 ++++++++--- 2 files changed, 195 insertions(+), 19 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 4f37ffd6476c..2b7dcc75cbed 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,12 +1,21 @@ import numpy as np +import tree import unittest import ray import ray.rllib.algorithms.ppo as ppo +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( + get_expected_model_config, + PPOTorchRLModule, + get_ppo_loss, +) +from ray.rllib.utils.torch_utils import convert_to_torch_tensor from ray.rllib.utils.test_utils import ( check, check_compute_single_action, @@ -188,6 +197,127 @@ def test_ppo_exploration_setup(self): check(np.mean(actions), 1.5, atol=0.2) trainer.stop() + def test_torch_model_creation(self): + pass + + def test_torch_model_creation_lstm(self): + pass + + def test_rollouts(self): + for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [False]: # , True]" + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + import gym + + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + obs = env.reset() + if lstm: + states = [ + s.get_initial_state() + for s in ( + module.shared_encoder, + module.encoder_vf, + module.encoder_pi, + ) + ] + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + **{f"state_in_{i}": s for i, s in enumerate(states)}, + } + else: + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + + if fwd_fn == "forward_exploration": + module.forward_exploration(batch) + elif fwd_fn == "forward_inference": + module.forward_inference(batch) + + def test_forward_train(self): + for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [False]: # , True]" + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + import gym + + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + # collect a batch of data + batch = [] + obs = env.reset() + tstep = 0 + if lstm: + states = {} + for i, model in enumerate( + [ + module.shared_encoder, + module.encoder_pi, + module.encoder_vf, + ] + ): + states[i] = model.get_inital_state() + while tstep < 10: + fwd_out = module.forward_exploration( + {"obs": convert_to_torch_tensor(obs)[None]} + ) + action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, done, _ = 
env.step(action) + step = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.DONES: np.array(done), + } + if lstm: + assert "state_out" in fwd_out + for k, v in states.items(): + step[f"state_in_{k}"] = v + states[k] = fwd_out["state_out"][k] + batch.append(step) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batch) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) + for k, v in batch.items() + } + + # forward train + # before training make sure it's on the right device and it's on + # trianing mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + self.assertIsNotNone(param.grad) + if __name__ == "__main__": import pytest diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b3cc51a732f3..0065b0eb73ff 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -42,6 +42,50 @@ def get_ppo_loss(fwd_in, fwd_out): return loss +# TODO: Most of the neural network, and model specs in this file will eventually be +# retreived from the model catalog. That includes FCNet, Encoder, etc. +def get_expected_model_config(env, lstm, shared_encoder): + assert not env.observation_space.shape[-1] == 3, "Implement VisionNet first" + if lstm: + encoder_config_class = LSTMConfig + else: + encoder_config_class = FCConfig + + pi_config = encoder_config_class( + hidden_layers=[32], + activation="ReLU", + ) + vf_config = encoder_config_class( + input_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + + if shared_encoder: + shared_encoder_config = ( + encoder_config_class( + input_dim=env.observation_space.shape[0], + hidden_layers=[32], + activation="ReLU", + ), + ) + pi_config.input_dim = 32 + vf_config.input_dim = 32 + else: + shared_encoder_config = None + pi_config.input_dim = env.observation_space.shape[0] + vf_config.input_dim = env.observation_space.shape[0] + + return PPOModuleConfig( + observation_space=env.observation_space, + action_space=env.action_space, + shared_encoder_config=shared_encoder_config, + pi_config=pi_config, + vf_config=vf_config, + shared_encoder=shared_encoder, + ) + + @dataclass class PPOModuleConfig(RLModuleConfig): """Configuration for the PPO module. @@ -58,8 +102,6 @@ class PPOModuleConfig(RLModuleConfig): pi_config: FCConfig = None vf_config: FCConfig = None - pi_encoder_config: FCConfig = None - vf_encoder_config: FCConfig = None shared_encoder_config: FCConfig = None free_log_std: bool = False shared_encoder: bool = True @@ -76,9 +118,14 @@ def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." 
- self.shared_encoder = self.config.shared_encoder_config.build() - self.pi_encoder = self.config.pi_encoder_config.build() - self.vf_encoder = self.config.vf_encoder_config.build() + if self.config.shared_encoder: + self.shared_encoder = self.config.shared_encoder_config.build() + self.encoder_pi = IdentityEncoder(self.config.pi_config) + self.encoder_vf = IdentityEncoder(self.config.vf_config) + else: + self.shared_encoder = IdentityEncoder(self.config.shared_encoder_config) + self.encoder_pi = self.config.pi_config.build() + self.encoder_vf = self.config.vf_config.build() self.pi = FCNet( input_dim=self.config.pi_encoder_config.output_dim, @@ -138,16 +185,15 @@ def from_model_config( shared_encoder_config = IdentityConfig(output_dim=obs_dim) if use_lstm: - pi_encoder_config = LSTMConfig( - input_dim=shared_encoder_config.output_dim, + assert vf_share_layers, "LSTM not supported with vf_share_layers=False" + shared_encoder_config = LSTMConfig( hidden_dim=model_config["lstm_cell_size"], batch_first=not model_config["_time_major"], output_dim=model_config["lstm_cell_size"], num_layers=1, ) else: - pi_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, + shared_encoder_config = FCConfig( hidden_layers=fcnet_hiddens, activation=activation, output_dim=model_config["fcnet_hiddens"][-1], @@ -176,16 +222,15 @@ def from_model_config( # build pi network shared_encoder_config.input_dim = observation_space.shape[0] - pi_encoder_config.input_dim = shared_encoder_config.output_dim - pi_config.input_dim = pi_encoder_config.output_dim + pi_config.input_dim = shared_encoder_config.output_dim + if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 # build vf network - vf_encoder_config.input_dim = shared_encoder_config.output_dim - vf_config.input_dim = vf_encoder_config.output_dim + vf_config.input_dim = shared_encoder_config.output_dim vf_config.output_dim = 1 config_ = PPOModuleConfig( @@ -205,12 +250,13 @@ def from_model_config( return module def get_initial_state(self) -> NestedDict: - if isinstance(self.shared_encoder, LSTMEncoder): - return self.shared_encoder.get_initial_state() - elif isinstance(self.pi_encoder, LSTMEncoder): - return self.pi_encoder.get_initial_state() - else: - return NestedDict({}) + if isinstance(self.config.shared_encoder_config, LSTMConfig): + # TODO (Kourosh): How does this work in RLlib today? + if isinstance(self.shared_encoder, LSTMEncoder): + return self.shared_encoder.get_inital_state() + else: + return self.encoder_pi.get_inital_state() + return {} @override(RLModule) def input_specs_inference(self) -> SpecDict: From 5489c89bb755a2c20a91e7ec1fe1a44292c5a8bd Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 17:58:47 +0100 Subject: [PATCH 02/51] tests complete Signed-off-by: Artur Niederfahrenhorst --- .../ppo/torch/ppo_torch_rl_module.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 0065b0eb73ff..0b28566001ac 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -45,29 +45,27 @@ def get_ppo_loss(fwd_in, fwd_out): # TODO: Most of the neural network, and model specs in this file will eventually be # retreived from the model catalog. That includes FCNet, Encoder, etc. 
def get_expected_model_config(env, lstm, shared_encoder): - assert not env.observation_space.shape[-1] == 3, "Implement VisionNet first" if lstm: encoder_config_class = LSTMConfig else: encoder_config_class = FCConfig pi_config = encoder_config_class( + output_dim=32, hidden_layers=[32], activation="ReLU", ) vf_config = encoder_config_class( - input_dim=32, + output_dim=32, hidden_layers=[32], activation="ReLU", ) if shared_encoder: - shared_encoder_config = ( - encoder_config_class( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ), + shared_encoder_config = encoder_config_class( + input_dim=env.observation_space.shape[0], + hidden_layers=[32], + activation="ReLU", ) pi_config.input_dim = 32 vf_config.input_dim = 32 @@ -128,14 +126,14 @@ def setup(self) -> None: self.encoder_vf = self.config.vf_config.build() self.pi = FCNet( - input_dim=self.config.pi_encoder_config.output_dim, - output_dim=self.config.pi_config.output_dim, + input_dim=self.config.pi_config.output_dim, + output_dim=2, hidden_layers=self.config.pi_config.hidden_layers, activation=self.config.pi_config.activation, ) self.vf = FCNet( - input_dim=self.config.vf_encoder_config.output_dim, + input_dim=self.config.vf_config.output_dim, output_dim=1, hidden_layers=self.config.vf_config.hidden_layers, activation=self.config.vf_config.activation, From 0dc91e2b4305918d7cda831983a6dce11f22367f Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 18:07:23 +0100 Subject: [PATCH 03/51] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 69 +++++++++++++++---- .../ppo/torch/ppo_torch_rl_module.py | 42 ----------- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 2b7dcc75cbed..c261edef5edc 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,27 +1,74 @@ +import unittest + import numpy as np import tree -import unittest import ray import ray.rllib.algorithms.ppo as ppo - -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( - get_expected_model_config, + PPOModuleConfig, PPOTorchRLModule, get_ppo_loss, ) -from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.core.rl_module.encoder import ( + FCConfig, + LSTMConfig, +) +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.test_utils import ( check, check_compute_single_action, check_train_results, framework_iterator, ) +from ray.rllib.utils.torch_utils import convert_to_torch_tensor + + +# TODO: Most of the neural network, and model specs in this file will eventually be +# retreived from the model catalog. That includes FCNet, Encoder, etc. +def get_expected_model_config(env, lstm, shared_encoder): + assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
+ if lstm: + encoder_config_class = LSTMConfig + else: + encoder_config_class = FCConfig + + pi_config = encoder_config_class( + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + vf_config = encoder_config_class( + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + + if shared_encoder: + shared_encoder_config = encoder_config_class( + input_dim=env.observation_space.shape[0], + hidden_layers=[32], + activation="ReLU", + ) + pi_config.input_dim = 32 + vf_config.input_dim = 32 + else: + shared_encoder_config = None + pi_config.input_dim = env.observation_space.shape[0] + vf_config.input_dim = env.observation_space.shape[0] + + return PPOModuleConfig( + observation_space=env.observation_space, + action_space=env.action_space, + shared_encoder_config=shared_encoder_config, + pi_config=pi_config, + vf_config=vf_config, + shared_encoder=shared_encoder, + ) def get_model_config(framework, lstm=False): @@ -197,12 +244,6 @@ def test_ppo_exploration_setup(self): check(np.mean(actions), 1.5, atol=0.2) trainer.stop() - def test_torch_model_creation(self): - pass - - def test_torch_model_creation_lstm(self): - pass - def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: for fwd_fn in ["forward_exploration", "forward_inference"]: diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 0b28566001ac..4be461f1281a 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -42,48 +42,6 @@ def get_ppo_loss(fwd_in, fwd_out): return loss -# TODO: Most of the neural network, and model specs in this file will eventually be -# retreived from the model catalog. That includes FCNet, Encoder, etc. -def get_expected_model_config(env, lstm, shared_encoder): - if lstm: - encoder_config_class = LSTMConfig - else: - encoder_config_class = FCConfig - - pi_config = encoder_config_class( - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - vf_config = encoder_config_class( - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - - if shared_encoder: - shared_encoder_config = encoder_config_class( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ) - pi_config.input_dim = 32 - vf_config.input_dim = 32 - else: - shared_encoder_config = None - pi_config.input_dim = env.observation_space.shape[0] - vf_config.input_dim = env.observation_space.shape[0] - - return PPOModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - shared_encoder_config=shared_encoder_config, - pi_config=pi_config, - vf_config=vf_config, - shared_encoder=shared_encoder, - ) - - @dataclass class PPOModuleConfig(RLModuleConfig): """Configuration for the PPO module. 
From a9a94983588425d09ace6381a5cc46b9189a9035 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 16 Dec 2022 19:06:05 +0100 Subject: [PATCH 04/51] wip Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 103 ++++++++++-------- .../ppo/torch/ppo_torch_rl_module.py | 37 +++---- rllib/core/rl_module/encoder.py | 2 +- 3 files changed, 73 insertions(+), 69 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index c261edef5edc..979d2eb22145 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -2,6 +2,7 @@ import numpy as np import tree +import gym import ray import ray.rllib.algorithms.ppo as ppo @@ -28,43 +29,55 @@ from ray.rllib.utils.torch_utils import convert_to_torch_tensor -# TODO: Most of the neural network, and model specs in this file will eventually be -# retreived from the model catalog. That includes FCNet, Encoder, etc. +# Get a model config that we expect from the catalog def get_expected_model_config(env, lstm, shared_encoder): assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" - if lstm: - encoder_config_class = LSTMConfig - else: - encoder_config_class = FCConfig - pi_config = encoder_config_class( - output_dim=32, + shared_encoder_config = FCConfig( + input_dim=env.observation_space.shape[0], hidden_layers=[32], activation="ReLU", ) - vf_config = encoder_config_class( + + if lstm: + pi_encoder_config = LSTMConfig( + hidden_dim=32, + batch_first=True, + output_dim=32, + num_layers=1, + ) + else: + pi_encoder_config = FCConfig( + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + + vf_encoder_config = FCConfig( output_dim=32, hidden_layers=[32], activation="ReLU", ) - if shared_encoder: - shared_encoder_config = encoder_config_class( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ) - pi_config.input_dim = 32 - vf_config.input_dim = 32 - else: + if not shared_encoder: shared_encoder_config = None - pi_config.input_dim = env.observation_space.shape[0] - vf_config.input_dim = env.observation_space.shape[0] + pi_encoder_config.input_dim = env.observation_space.shape[0] + vf_encoder_config.input_dim = env.observation_space.shape[0] + + pi_config = FCConfig() + vf_config = FCConfig() + + if isinstance(env.action_space, gym.spaces.Discrete): + pi_config.output_dim = env.action_space.n + else: + pi_config.output_dim = env.action_space.shape[0] * 2 return PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, shared_encoder_config=shared_encoder_config, + pi_encoder_config=pi_encoder_config, + vf_encoder_config=vf_encoder_config, pi_config=pi_config, vf_config=vf_config, shared_encoder=shared_encoder, @@ -245,10 +258,11 @@ def test_ppo_exploration_setup(self): trainer.stop() def test_rollouts(self): - for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space + for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False]: # , True]" + for lstm in [False, True]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -261,18 +275,11 @@ def test_rollouts(self): module = PPOTorchRLModule(config) obs = env.reset() + if lstm: - states = [ - 
s.get_initial_state() - for s in ( - module.shared_encoder, - module.encoder_vf, - module.encoder_pi, - ) - ] batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - **{f"state_in_{i}": s for i, s in enumerate(states)}, + "state_in": module.pi_encoder.get_inital_state() } else: batch = { @@ -285,10 +292,11 @@ def test_rollouts(self): module.forward_inference(batch) def test_forward_train(self): - for env_name in ["CartPole-v1", "Pendulum-v1"]: # , "BreakoutNoFrameskip-v4"]: + # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space + for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False]: # , True]" + for lstm in [False, True]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -305,19 +313,19 @@ def test_forward_train(self): obs = env.reset() tstep = 0 if lstm: - states = {} - for i, model in enumerate( - [ - module.shared_encoder, - module.encoder_pi, - module.encoder_vf, - ] - ): - states[i] = model.get_inital_state() + # TODO (Artur): Multiple states + state_in = module.pi_encoder.get_inital_state() while tstep < 10: - fwd_out = module.forward_exploration( - {"obs": convert_to_torch_tensor(obs)[None]} - ) + if lstm: + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + "state_in": state_in, + } + else: + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(batch) action = convert_to_numpy( fwd_out["action_dist"].sample().squeeze(0) ) @@ -331,9 +339,8 @@ def test_forward_train(self): } if lstm: assert "state_out" in fwd_out - for k, v in states.items(): - step[f"state_in_{k}"] = v - states[k] = fwd_out["state_out"][k] + step[f"state_in"] = state_in + state_in = fwd_out["state_out"] batch.append(step) obs = new_obs tstep += 1 diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 4be461f1281a..640ca2a3e523 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -58,6 +58,8 @@ class PPOModuleConfig(RLModuleConfig): pi_config: FCConfig = None vf_config: FCConfig = None + pi_encoder_config: FCConfig = None + vf_encoder_config: FCConfig = None shared_encoder_config: FCConfig = None free_log_std: bool = False shared_encoder: bool = True @@ -76,22 +78,22 @@ def setup(self) -> None: if self.config.shared_encoder: self.shared_encoder = self.config.shared_encoder_config.build() - self.encoder_pi = IdentityEncoder(self.config.pi_config) - self.encoder_vf = IdentityEncoder(self.config.vf_config) + self.pi_encoder = IdentityEncoder(self.config.pi_encoder_config) + self.vf_encoder = IdentityEncoder(self.config.vf_encoder_config) else: self.shared_encoder = IdentityEncoder(self.config.shared_encoder_config) - self.encoder_pi = self.config.pi_config.build() - self.encoder_vf = self.config.vf_config.build() + self.pi_encoder = self.config.pi_encoder_config.build() + self.vf_encoder = self.config.vf_encoder_config.build() self.pi = FCNet( - input_dim=self.config.pi_config.output_dim, - output_dim=2, + input_dim=self.config.pi_encoder_config.output_dim, + output_dim=self.config.pi_config.output_dim, hidden_layers=self.config.pi_config.hidden_layers, activation=self.config.pi_config.activation, ) self.vf = FCNet( - input_dim=self.config.vf_config.output_dim, + input_dim=self.config.vf_encoder_config.output_dim, output_dim=1, 
hidden_layers=self.config.vf_config.hidden_layers, activation=self.config.vf_config.activation, @@ -155,12 +157,8 @@ def from_model_config( output_dim=model_config["fcnet_hiddens"][-1], ) - vf_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], - ) + pi_encoder_config = FCConfig() + vf_encoder_config = FCConfig() pi_config = FCConfig() vf_config = FCConfig() @@ -178,7 +176,7 @@ def from_model_config( # build pi network shared_encoder_config.input_dim = observation_space.shape[0] - pi_config.input_dim = shared_encoder_config.output_dim + pi_encoder_config.input_dim = shared_encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n @@ -211,7 +209,7 @@ def get_initial_state(self) -> NestedDict: if isinstance(self.shared_encoder, LSTMEncoder): return self.shared_encoder.get_inital_state() else: - return self.encoder_pi.get_inital_state() + return self.pi_encoder.get_inital_state() return {} @override(RLModule) @@ -224,10 +222,9 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - shared_enc_out = self.shared_encoder(batch) - pi_enc_out = self.pi_encoder(shared_enc_out) - - action_logits = self.pi(pi_enc_out[ENCODER_OUT]) + encoder_out = self.shared_encoder(batch) + encoder_out_pi = self.pi_encoder(encoder_out) + action_logits = self.pi(encoder_out_pi["embedding"]) if self._is_discrete: action = torch.argmax(action_logits, dim=-1) @@ -269,7 +266,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) encoder_out_pi = self.pi_encoder(encoder_out) encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) + action_logits = self.pi(encoder_out_pi["embedding"]) output = {} if self._is_discrete: diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index f3bb22b46900..c18a735f4864 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -75,7 +75,7 @@ def __init__(self, config: EncoderConfig) -> None: self._input_spec = self.input_spec() self._output_spec = self.output_spec() - def get_initial_state(self): + def get_inital_state(self): return [] def input_spec(self): From 9930d79482994bb27e16fef32984c50925efb52b Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Sun, 18 Dec 2022 22:34:21 +0100 Subject: [PATCH 05/51] mutually exclusive encoders, tests passing Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 82 ++++++++++--------- .../ppo/torch/ppo_torch_rl_module.py | 40 ++++----- rllib/core/rl_module/encoder.py | 15 ++-- 3 files changed, 69 insertions(+), 68 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 979d2eb22145..3653cf67328e 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -13,6 +13,7 @@ get_ppo_loss, ) from ray.rllib.core.rl_module.encoder import ( + IdentityConfig, FCConfig, LSTMConfig, ) @@ -32,38 +33,42 @@ # Get a model config that we expect from the catalog def get_expected_model_config(env, lstm, shared_encoder): assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
+ obs_dim = env.observation_space.shape[0] - shared_encoder_config = FCConfig( - input_dim=env.observation_space.shape[0], - hidden_layers=[32], - activation="ReLU", - ) - - if lstm: - pi_encoder_config = LSTMConfig( - hidden_dim=32, - batch_first=True, + if shared_encoder: + assert not lstm, "LSTM can only be used in PI" + shared_encoder_config = FCConfig( + input_dim=obs_dim, + hidden_layers=[32], + activation="ReLU", output_dim=32, - num_layers=1, ) + pi_encoder_config = IdentityConfig(output_dim=32) + vf_encoder_config = IdentityConfig(output_dim=32) else: - pi_encoder_config = FCConfig( + shared_encoder_config = IdentityConfig(output_dim=obs_dim) + if lstm: + pi_encoder_config = LSTMConfig( + input_dim=obs_dim, + hidden_dim=32, + batch_first=True, + output_dim=32, + num_layers=1, + ) + else: + pi_encoder_config = FCConfig( + input_dim=obs_dim, + output_dim=32, + hidden_layers=[32], + activation="ReLU", + ) + vf_encoder_config = FCConfig( + input_dim=obs_dim, output_dim=32, hidden_layers=[32], activation="ReLU", ) - vf_encoder_config = FCConfig( - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - - if not shared_encoder: - shared_encoder_config = None - pi_encoder_config.input_dim = env.observation_space.shape[0] - vf_encoder_config.input_dim = env.observation_space.shape[0] - pi_config = FCConfig() vf_config = FCConfig() @@ -147,7 +152,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init() + ray.init(local_mode=True) @classmethod def tearDownClass(cls): @@ -262,13 +267,12 @@ def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False, True]: + # TODO: LSTM = True + for lstm in [False]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" ) - import gym - env = gym.make(env_name) config = get_expected_model_config(env, lstm, shared_encoder) @@ -279,7 +283,7 @@ def test_rollouts(self): if lstm: batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": module.pi_encoder.get_inital_state() + "state_in": module.pi_encoder.get_inital_state(), } else: batch = { @@ -296,20 +300,19 @@ def test_forward_train(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - for lstm in [False, True]: + # TODO: LSTM = True + for lstm in [False]: print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" ) - import gym - env = gym.make(env_name) config = get_expected_model_config(env, lstm, shared_encoder) module = PPOTorchRLModule(config) # collect a batch of data - batch = [] + batches = [] obs = env.reset() tstep = 0 if lstm: @@ -317,20 +320,21 @@ def test_forward_train(self): state_in = module.pi_encoder.get_inital_state() while tstep < 10: if lstm: - batch = { + input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], "state_in": state_in, + SampleBatch.SEQ_LENS: np.array([1]), } else: - batch = { + input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None] } - fwd_out = module.forward_exploration(batch) + fwd_out = module.forward_exploration(input_batch) action = convert_to_numpy( fwd_out["action_dist"].sample().squeeze(0) ) new_obs, reward, done, _ = env.step(action) - step = { + output_batch = { SampleBatch.OBS: obs, SampleBatch.NEXT_OBS: new_obs, 
SampleBatch.ACTIONS: action, @@ -339,14 +343,14 @@ def test_forward_train(self): } if lstm: assert "state_out" in fwd_out - step[f"state_in"] = state_in + output_batch["state_in"] = state_in state_in = fwd_out["state_out"] - batch.append(step) + batches.append(output_batch) obs = new_obs tstep += 1 # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batch) + batch = tree.map_structure(lambda *x: list(x), *batches) # convert dict of lists to dict of tensors fwd_in = { k: convert_to_torch_tensor(np.array(v)) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 640ca2a3e523..98c4ff646110 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -76,14 +76,9 @@ def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." - if self.config.shared_encoder: - self.shared_encoder = self.config.shared_encoder_config.build() - self.pi_encoder = IdentityEncoder(self.config.pi_encoder_config) - self.vf_encoder = IdentityEncoder(self.config.vf_encoder_config) - else: - self.shared_encoder = IdentityEncoder(self.config.shared_encoder_config) - self.pi_encoder = self.config.pi_encoder_config.build() - self.vf_encoder = self.config.vf_encoder_config.build() + self.shared_encoder = self.config.shared_encoder_config.build() + self.pi_encoder = self.config.pi_encoder_config.build() + self.vf_encoder = self.config.vf_encoder_config.build() self.pi = FCNet( input_dim=self.config.pi_encoder_config.output_dim, @@ -143,22 +138,27 @@ def from_model_config( shared_encoder_config = IdentityConfig(output_dim=obs_dim) if use_lstm: - assert vf_share_layers, "LSTM not supported with vf_share_layers=False" - shared_encoder_config = LSTMConfig( + pi_encoder_config = LSTMConfig( + input_dim=shared_encoder_config.output_dim, hidden_dim=model_config["lstm_cell_size"], batch_first=not model_config["_time_major"], output_dim=model_config["lstm_cell_size"], num_layers=1, ) else: - shared_encoder_config = FCConfig( + pi_encoder_config = FCConfig( + input_dim=shared_encoder_config.output_dim, hidden_layers=fcnet_hiddens, activation=activation, output_dim=model_config["fcnet_hiddens"][-1], ) - pi_encoder_config = FCConfig() - vf_encoder_config = FCConfig() + vf_encoder_config = FCConfig( + input_dim=shared_encoder_config.output_dim, + hidden_layers=fcnet_hiddens, + activation=activation, + output_dim=model_config["fcnet_hiddens"][-1], + ) pi_config = FCConfig() vf_config = FCConfig() @@ -177,14 +177,15 @@ def from_model_config( # build pi network shared_encoder_config.input_dim = observation_space.shape[0] pi_encoder_config.input_dim = shared_encoder_config.output_dim - + pi_config.input_dim = pi_encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 # build vf network - vf_config.input_dim = shared_encoder_config.output_dim + vf_encoder_config.input_dim = shared_encoder_config.output_dim + vf_config.input_dim = vf_encoder_config.output_dim vf_config.output_dim = 1 config_ = PPOModuleConfig( @@ -222,9 +223,10 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - action_logits = 
self.pi(encoder_out_pi["embedding"]) + shared_enc_out = self.shared_encoder(batch) + pi_enc_out = self.pi_encoder(shared_enc_out) + + action_logits = self.pi(pi_enc_out[ENCODER_OUT]) if self._is_discrete: action = torch.argmax(action_logits, dim=-1) @@ -266,7 +268,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) encoder_out_pi = self.pi_encoder(encoder_out) encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi["embedding"]) + action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) output = {} if self._is_discrete: diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index c18a735f4864..bc7ddfac9e79 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -16,7 +16,6 @@ ENCODER_OUT = "encoder_out" STATE_IN = "state_in" -STATE_OUT = "state_out" @dataclass @@ -110,7 +109,7 @@ def input_spec(self): ) def output_spec(self): - return SpecDict( + return ModelSpec( {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} ) @@ -159,13 +158,9 @@ def output_spec(self): return SpecDict( { ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - STATE_OUT: { - "h": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - "c": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), + "state_out": { + "h": TorchTensorSpec("b, h", h=config.hidden_dim), + "c": TorchTensorSpec("b, h", h=config.hidden_dim), }, } ) @@ -190,7 +185,7 @@ def _forward(self, input_dict: SampleBatch): return { ENCODER_OUT: x, - STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + "state_out": tree.map_structure(lambda x: x.transpose(0, 1), states_o), } From 109e56d2d0d7f489444c8091a85d9ee4ad560282 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 19 Dec 2022 15:41:09 +0100 Subject: [PATCH 06/51] add lstm code Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 51 +++++++++++++------ 1 file changed, 36 insertions(+), 15 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 3653cf67328e..fd5dce3e3b9d 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -3,6 +3,7 @@ import numpy as np import tree import gym +import torch import ray import ray.rllib.algorithms.ppo as ppo @@ -267,8 +268,11 @@ def test_rollouts(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - # TODO: LSTM = True - for lstm in [False]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -280,15 +284,17 @@ def test_rollouts(self): obs = env.reset() + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + } + if lstm: - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": module.pi_encoder.get_inital_state(), - } - else: - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } + state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + batch["state_in"] = state_in + batch["seq_lens"] = torch.Tensor([1]) if fwd_fn == "forward_exploration": 
module.forward_exploration(batch) @@ -300,8 +306,11 @@ def test_forward_train(self): for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: for shared_encoder in [False, True]: - # TODO: LSTM = True - for lstm in [False]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue print( f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}" @@ -318,6 +327,10 @@ def test_forward_train(self): if lstm: # TODO (Artur): Multiple states state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + output_states = state_in while tstep < 10: if lstm: input_batch = { @@ -343,7 +356,12 @@ def test_forward_train(self): } if lstm: assert "state_out" in fwd_out - output_batch["state_in"] = state_in + if tstep > 0: # First states are already added + output_states = tree.map_structure( + lambda *s: torch.cat((s[0], s[1])), + output_states, + state_in, + ) state_in = fwd_out["state_out"] batches.append(output_batch) obs = new_obs @@ -356,10 +374,13 @@ def test_forward_train(self): k: convert_to_torch_tensor(np.array(v)) for k, v in batch.items() } + if lstm: + fwd_in["state_in"] = output_states + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) # forward train - # before training make sure it's on the right device and it's on - # trianing mode + # before training make sure module is on the right device and in + # training mode module.to("cpu") module.train() fwd_out = module.forward_train(fwd_in) From 532c8c05044fbcac62e2bce4b19d16b1cb376417 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 10:41:19 +0100 Subject: [PATCH 07/51] better docs for get expected model config Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_with_rl_module.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index fd5dce3e3b9d..b09e98d77c32 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -31,8 +31,19 @@ from ray.rllib.utils.torch_utils import convert_to_torch_tensor -# Get a model config that we expect from the catalog -def get_expected_model_config(env, lstm, shared_encoder): +def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: + """Get a PPOModuleConfig that we would expect from the catalog otherwise. + + Args: + env: Environment for which we build the model later + lstm: If True, build recurrent pi encoder + shared_encoder: If True, build a shared encoder for pi and vf, where pi + encoder and vf encoder will be identity. If False, the shared encoder + will be identity. + + Returns: + A PPOModuleConfig containing the relevant configs to build PPORLModule + """ assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
obs_dim = env.observation_space.shape[0] From 31ae2a064d6aacec9984d9861aa346827ef13ddc Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 11:56:19 +0100 Subject: [PATCH 08/51] kourosh's comments Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 395 ++++++------------ .../ppo/tests/test_ppo_with_rl_module.py | 216 +--------- 2 files changed, 133 insertions(+), 478 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 6b0e0161a5a6..7adde54cfffd 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -1,39 +1,19 @@ -import itertools import ray import unittest import numpy as np -import gymnasium as gym +import gym import torch -import tensorflow as tf import tree from ray.rllib import SampleBatch -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( - PPOTorchRLModule, - PPOModuleConfig, -) -from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( - PPOTfRLModule, - PPOTfModuleConfig, -) -from ray.rllib.core.rl_module.encoder import ( - FCConfig, - IdentityConfig, - LSTMConfig, - STATE_IN, - STATE_OUT, -) -from ray.rllib.core.rl_module.encoder_tf import ( - FCTfConfig, - IdentityTfConfig, -) +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule, \ + get_ppo_loss, PPOModuleConfig +from ray.rllib.core.rl_module.encoder import FCConfig, IdentityConfig, LSTMConfig from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor -def get_expected_model_config_torch( - env: gym.Env, lstm: bool, shared_encoder: bool -) -> PPOModuleConfig: +def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. Args: @@ -46,9 +26,8 @@ def get_expected_model_config_torch( Returns: A PPOModuleConfig containing the relevant configs to build PPORLModule """ - assert len(env.observation_space.shape) == 1, ( - "No multidimensional obs space " "supported." - ) + assert len(env.observation_space.shape) == 1, "No multidimensional obs space " \ + "supported." obs_dim = env.observation_space.shape[0] if shared_encoder: @@ -105,99 +84,6 @@ def get_expected_model_config_torch( ) -def get_expected_model_config_tf( - env: gym.Env, shared_encoder: bool -) -> PPOTfModuleConfig: - """Get a PPOTfModuleConfig that we would expect from the catalog otherwise. - - Args: - env: Environment for which we build the model later - shared_encoder: If True, build a shared encoder for pi and vf, where pi - encoder and vf encoder will be identity. If False, the shared encoder - will be identity. - - Returns: - A PPOTfModuleConfig containing the relevant configs to build PPOTfRLModule. - """ - assert len(env.observation_space.shape) == 1, ( - "No multidimensional obs space " "supported." 
- ) - obs_dim = env.observation_space.shape[0] - - if shared_encoder: - shared_encoder_config = FCTfConfig( - input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", - output_dim=32, - ) - else: - shared_encoder_config = IdentityTfConfig(output_dim=obs_dim) - pi_config = FCConfig() - vf_config = FCConfig() - pi_config.input_dim = vf_config.input_dim = shared_encoder_config.output_dim - - if isinstance(env.action_space, gym.spaces.Discrete): - pi_config.output_dim = env.action_space.n - else: - pi_config.output_dim = env.action_space.shape[0] * 2 - - pi_config.hidden_layers = vf_config.hidden_layers = [32] - pi_config.activation = vf_config.activation = "ReLU" - - return PPOTfModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - shared_encoder_config=shared_encoder_config, - pi_config=pi_config, - vf_config=vf_config, - shared_encoder=shared_encoder, - ) - - -def dummy_torch_ppo_loss(batch, fwd_out): - """Dummy PPO loss function for testing purposes. - - Will eventually use the actual PPO loss function implemented in the PPOTfTrainer. - - Args: - batch: SampleBatch used for training. - fwd_out: Forward output of the model. - - Returns: - Loss tensor - """ - # TODO: we should replace these components later with real ppo components when - # RLOptimizer and RLModule are integrated together. - # this is not exactly a ppo loss, just something to show that the - # forward train works - adv = batch[SampleBatch.REWARDS] - fwd_out[SampleBatch.VF_PREDS] - actor_loss = -(fwd_out[SampleBatch.ACTION_LOGP] * adv).mean() - critic_loss = (adv**2).mean() - loss = actor_loss + critic_loss - - return loss - - -def dummy_tf_ppo_loss(batch, fwd_out): - """Dummy PPO loss function for testing purposes. - - Will eventually use the actual PPO loss function implemented in the PPOTfTrainer. - - Args: - batch: SampleBatch used for training. - fwd_out: Forward output of the model. 
- - Returns: - Loss tensor - """ - adv = batch[SampleBatch.REWARDS] - fwd_out[SampleBatch.VF_PREDS] - action_probs = fwd_out[SampleBatch.ACTION_DIST].logp(batch[SampleBatch.ACTIONS]) - actor_loss = -tf.reduce_mean(action_probs * adv) - critic_loss = tf.reduce_mean(tf.square(adv)) - return actor_loss + critic_loss - - class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): @@ -207,152 +93,135 @@ def setUpClass(cls): def tearDownClass(cls): ray.shutdown() - def get_ppo_module(self, framwework, env, lstm, shared_encoder): - if framwework == "torch": - config = get_expected_model_config_torch(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - else: - config = get_expected_model_config_tf(env, shared_encoder) - module = PPOTfRLModule(config) - return module - - def get_input_batch_from_obs(self, framework, obs): - if framework == "torch": - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - } - else: - batch = {SampleBatch.OBS: np.array([obs])} - return batch - def test_rollouts(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - frameworks = ["torch", "tf2"] - env_names = ["CartPole-v1", "Pendulum-v1"] - fwd_fns = ["forward_exploration", "forward_inference"] - shared_encoders = [False, True] - ltsms = [False, True] - config_combinations = [frameworks, env_names, fwd_fns, shared_encoders, ltsms] - for config in itertools.product(*config_combinations): - fw, env_name, fwd_fn, shared_encoder, lstm = config - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - if lstm and fw == "tf2": - # LSTM not implemented in TF2 yet - continue - print(f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" f"={lstm}") - env = gym.make(env_name) - module = self.get_ppo_module(fw, env, lstm, shared_encoder) - - obs, _ = env.reset() - - batch = self.get_input_batch_from_obs(fw, obs) + for env_name in ["CartPole-v1", "Pendulum-v1"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + obs = env.reset() + + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + } + + if lstm: + state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + batch["state_in"] = state_in + batch["seq_lens"] = torch.Tensor([1]) + + if fwd_fn == "forward_exploration": + module.forward_exploration(batch) + elif fwd_fn == "forward_inference": + module.forward_inference(batch) - if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - batch[STATE_IN] = state_in - batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) - - if fwd_fn == "forward_exploration": - module.forward_exploration(batch) - else: - module.forward_inference(batch) def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - frameworks = ["torch", "tf2"] - env_names = ["CartPole-v1", "Pendulum-v1"] - shared_encoders = [False, True] - ltsms = [False, True] - config_combinations = [frameworks, env_names, shared_encoders, ltsms] - for config in itertools.product(*config_combinations): - 
fw, env_name, shared_encoder, lstm = config - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - if lstm and fw == "tf2": - # LSTM not implemented in TF2 yet - continue - print(f"[ENV={env_name}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}") - env = gym.make(env_name) - - module = self.get_ppo_module(fw, env, lstm, shared_encoder) - - # collect a batch of data - batches = [] - obs, _ = env.reset() - tstep = 0 - if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - initial_state = state_in - while tstep < 10: - if lstm: - input_batch = self.get_input_batch_from_obs(fw, obs) - input_batch[STATE_IN] = state_in - input_batch[SampleBatch.SEQ_LENS] = np.array([1]) - else: - input_batch = self.get_input_batch_from_obs(fw, obs) - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy(fwd_out["action_dist"].sample()[0]) - new_obs, reward, terminated, truncated, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.TERMINATEDS: np.array(terminated), - SampleBatch.TRUNCATEDS: np.array(truncated), - } - if lstm: - assert STATE_OUT in fwd_out - state_in = fwd_out[STATE_OUT] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: np.array(x), *batches) - # convert dict of lists to dict of tensors - if fw == "torch": - fwd_in = { - k: convert_to_torch_tensor(np.array(v)) for k, v in batch.items() - } - if lstm: - fwd_in[STATE_IN] = initial_state - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) - - # forward train - # before training make sure module is on the right device - # and in training mode - module.to("cpu") - module.train() - fwd_out = module.forward_train(fwd_in) - loss = dummy_torch_ppo_loss(fwd_in, fwd_out) - loss.backward() - - # check that all neural net parameters have gradients - for param in module.parameters(): - self.assertIsNotNone(param.grad) - else: - with tf.GradientTape() as tape: - fwd_out = module.forward_train(batch) - loss = dummy_tf_ppo_loss(batch, fwd_out) - grads = tape.gradient(loss, module.trainable_variables) - for grad in grads: - self.assertIsNotNone(grad) - - -if __name__ == "__main__": - import pytest - import sys + for env_name in ["CartPole-v1", "Pendulum-v1"]: + for fwd_fn in ["forward_exploration", "forward_inference"]: + for shared_encoder in [False, True]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue + print( + f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + # collect a batch of data + batches = [] + obs = env.reset() + tstep = 0 + if lstm: + # TODO (Artur): Multiple states + state_in = module.pi_encoder.get_inital_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + output_states = state_in + while tstep < 10: + if lstm: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + "state_in": state_in, + SampleBatch.SEQ_LENS: np.array([1]), + } + else: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(input_batch) + 
action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, done, _ = env.step(action) + output_batch = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.DONES: np.array(done), + } + if lstm: + assert "state_out" in fwd_out + if tstep > 0: # First states are already added + + # Extend nested batches of states + output_states = tree.map_structure( + lambda *s: torch.cat((s[0], s[1])), + output_states, + state_in, + ) + state_in = fwd_out["state_out"] + batches.append(output_batch) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batches) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) + for k, v in batch.items() + } + if lstm: + fwd_in["state_in"] = output_states + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) + + # forward train + # before training make sure module is on the right device and in + # training mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + pass + self.assertIsNotNone(param.grad) - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index b09e98d77c32..b063523e99b6 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,104 +1,18 @@ import unittest import numpy as np -import tree -import gym -import torch import ray import ray.rllib.algorithms.ppo as ppo from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( - PPOModuleConfig, - PPOTorchRLModule, - get_ppo_loss, -) -from ray.rllib.core.rl_module.encoder import ( - IdentityConfig, - FCConfig, - LSTMConfig, -) from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY -from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.test_utils import ( check, check_compute_single_action, check_train_results, framework_iterator, ) -from ray.rllib.utils.torch_utils import convert_to_torch_tensor - - -def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: - """Get a PPOModuleConfig that we would expect from the catalog otherwise. - - Args: - env: Environment for which we build the model later - lstm: If True, build recurrent pi encoder - shared_encoder: If True, build a shared encoder for pi and vf, where pi - encoder and vf encoder will be identity. If False, the shared encoder - will be identity. - - Returns: - A PPOModuleConfig containing the relevant configs to build PPORLModule - """ - assert not len(env.observation_space.shape) == 3, "Implement VisionNet first!" 
- obs_dim = env.observation_space.shape[0] - - if shared_encoder: - assert not lstm, "LSTM can only be used in PI" - shared_encoder_config = FCConfig( - input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", - output_dim=32, - ) - pi_encoder_config = IdentityConfig(output_dim=32) - vf_encoder_config = IdentityConfig(output_dim=32) - else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) - if lstm: - pi_encoder_config = LSTMConfig( - input_dim=obs_dim, - hidden_dim=32, - batch_first=True, - output_dim=32, - num_layers=1, - ) - else: - pi_encoder_config = FCConfig( - input_dim=obs_dim, - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - vf_encoder_config = FCConfig( - input_dim=obs_dim, - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - - pi_config = FCConfig() - vf_config = FCConfig() - - if isinstance(env.action_space, gym.spaces.Discrete): - pi_config.output_dim = env.action_space.n - else: - pi_config.output_dim = env.action_space.shape[0] * 2 - - return PPOModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - shared_encoder_config=shared_encoder_config, - pi_encoder_config=pi_encoder_config, - vf_encoder_config=vf_encoder_config, - pi_config=pi_config, - vf_config=vf_config, - shared_encoder=shared_encoder, - ) def get_model_config(framework, lstm=False): @@ -164,7 +78,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init(local_mode=True) + ray.init() @classmethod def tearDownClass(cls): @@ -274,134 +188,6 @@ def test_ppo_exploration_setup(self): check(np.mean(actions), 1.5, atol=0.2) trainer.stop() - def test_rollouts(self): - # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - for env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" - ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - obs = env.reset() - - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - } - - if lstm: - state_in = module.pi_encoder.get_inital_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - batch["state_in"] = state_in - batch["seq_lens"] = torch.Tensor([1]) - - if fwd_fn == "forward_exploration": - module.forward_exploration(batch) - elif fwd_fn == "forward_inference": - module.forward_inference(batch) - - def test_forward_train(self): - # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - for env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" - ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - # collect a batch of data - batches = [] - obs = env.reset() - tstep = 0 - if lstm: - # TODO (Artur): Multiple states - state_in = module.pi_encoder.get_inital_state() - state_in = 
tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - output_states = state_in - while tstep < 10: - if lstm: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": state_in, - SampleBatch.SEQ_LENS: np.array([1]), - } - else: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy( - fwd_out["action_dist"].sample().squeeze(0) - ) - new_obs, reward, done, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.DONES: np.array(done), - } - if lstm: - assert "state_out" in fwd_out - if tstep > 0: # First states are already added - output_states = tree.map_structure( - lambda *s: torch.cat((s[0], s[1])), - output_states, - state_in, - ) - state_in = fwd_out["state_out"] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batches) - # convert dict of lists to dict of tensors - fwd_in = { - k: convert_to_torch_tensor(np.array(v)) - for k, v in batch.items() - } - if lstm: - fwd_in["state_in"] = output_states - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) - - # forward train - # before training make sure module is on the right device and in - # training mode - module.to("cpu") - module.train() - fwd_out = module.forward_train(fwd_in) - loss = get_ppo_loss(fwd_in, fwd_out) - loss.backward() - - # check that all neural net parameters have gradients - for param in module.parameters(): - self.assertIsNotNone(param.grad) - if __name__ == "__main__": import pytest From 018e223e88a97c914d1533a75f13f0b553e94772 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 16:51:27 +0100 Subject: [PATCH 09/51] lstm fixed, tests working Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 14 ++++++++------ rllib/core/rl_module/encoder.py | 8 ++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 7adde54cfffd..927c0c9c7665 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -6,8 +6,11 @@ import tree from ray.rllib import SampleBatch -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule, \ - get_ppo_loss, PPOModuleConfig +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( + PPOTorchRLModule, + get_ppo_loss, + PPOModuleConfig, +) from ray.rllib.core.rl_module.encoder import FCConfig, IdentityConfig, LSTMConfig from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -26,8 +29,9 @@ def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: Returns: A PPOModuleConfig containing the relevant configs to build PPORLModule """ - assert len(env.observation_space.shape) == 1, "No multidimensional obs space " \ - "supported." + assert len(env.observation_space.shape) == 1, ( + "No multidimensional obs space " "supported." 
+ ) obs_dim = env.observation_space.shape[0] if shared_encoder: @@ -131,7 +135,6 @@ def test_rollouts(self): elif fwd_fn == "forward_inference": module.forward_inference(batch) - def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: @@ -224,4 +227,3 @@ def test_forward_train(self): for param in module.parameters(): pass self.assertIsNotNone(param.grad) - diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index bc7ddfac9e79..6fb817c264bf 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -159,8 +159,12 @@ def output_spec(self): { ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), "state_out": { - "h": TorchTensorSpec("b, h", h=config.hidden_dim), - "c": TorchTensorSpec("b, h", h=config.hidden_dim), + "h": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), }, } ) From b38d501882ca61a11e1e8150e2bee6974c02a617 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 16:54:48 +0100 Subject: [PATCH 10/51] add state out Signed-off-by: Artur Niederfahrenhorst --- .../algorithms/ppo/tests/test_ppo_rl_module.py | 18 ++++++++++++------ rllib/core/rl_module/encoder.py | 5 +++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 927c0c9c7665..77f2894fa434 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -11,7 +11,13 @@ get_ppo_loss, PPOModuleConfig, ) -from ray.rllib.core.rl_module.encoder import FCConfig, IdentityConfig, LSTMConfig +from ray.rllib.core.rl_module.encoder import ( + FCConfig, + IdentityConfig, + LSTMConfig, + STATE_IN, + STATE_OUT, +) from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -127,7 +133,7 @@ def test_rollouts(self): state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) - batch["state_in"] = state_in + batch[STATE_IN] = state_in batch["seq_lens"] = torch.Tensor([1]) if fwd_fn == "forward_exploration": @@ -169,7 +175,7 @@ def test_forward_train(self): if lstm: input_batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - "state_in": state_in, + STATE_IN: state_in, SampleBatch.SEQ_LENS: np.array([1]), } else: @@ -189,7 +195,7 @@ def test_forward_train(self): SampleBatch.DONES: np.array(done), } if lstm: - assert "state_out" in fwd_out + assert STATE_OUT in fwd_out if tstep > 0: # First states are already added # Extend nested batches of states @@ -198,7 +204,7 @@ def test_forward_train(self): output_states, state_in, ) - state_in = fwd_out["state_out"] + state_in = fwd_out[STATE_OUT] batches.append(output_batch) obs = new_obs tstep += 1 @@ -211,7 +217,7 @@ def test_forward_train(self): for k, v in batch.items() } if lstm: - fwd_in["state_in"] = output_states + fwd_in[STATE_IN] = output_states fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) # forward train diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index 6fb817c264bf..fef1bb322ec0 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -16,6 +16,7 @@ ENCODER_OUT = "encoder_out" STATE_IN = "state_in" +STATE_OUT = "state_out" @dataclass @@ -158,7 +159,7 @@ def output_spec(self): return 
SpecDict( { ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - "state_out": { + STATE_OUT: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -189,7 +190,7 @@ def _forward(self, input_dict: SampleBatch): return { ENCODER_OUT: x, - "state_out": tree.map_structure(lambda x: x.transpose(0, 1), states_o), + STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), } From 462fc4d9149143187feb39f458c84d2b49512f9f Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 17:30:09 +0100 Subject: [PATCH 11/51] add __main__ to test Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 77f2894fa434..f13790046828 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -233,3 +233,10 @@ def test_forward_train(self): for param in module.parameters(): pass self.assertIsNotNone(param.grad) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) From f93795a6cdc94ffecb423451cafbd3d45591f7ad Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 21:42:57 +0100 Subject: [PATCH 12/51] change lstm testing according to kourosh's comment Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 26 +++++++------------ .../ppo/torch/ppo_torch_rl_module.py | 13 +++++----- 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index f13790046828..b8acee1fed98 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -22,7 +22,9 @@ from ray.rllib.utils.torch_utils import convert_to_torch_tensor -def get_expected_model_config(env, lstm, shared_encoder) -> PPOModuleConfig: +def get_expected_model_config( + env: gym.Env, lstm: bool, shared_encoder: bool +) -> PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. 
Args: @@ -114,8 +116,8 @@ def test_rollouts(self): # TODO (Artur): Implement continue print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" + f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" + f"={lstm}" ) env = gym.make(env_name) @@ -134,11 +136,11 @@ def test_rollouts(self): lambda x: x[None], convert_to_torch_tensor(state_in) ) batch[STATE_IN] = state_in - batch["seq_lens"] = torch.Tensor([1]) + batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) if fwd_fn == "forward_exploration": module.forward_exploration(batch) - elif fwd_fn == "forward_inference": + else: module.forward_inference(batch) def test_forward_train(self): @@ -170,7 +172,7 @@ def test_forward_train(self): state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) - output_states = state_in + initial_state = state_in while tstep < 10: if lstm: input_batch = { @@ -196,14 +198,6 @@ def test_forward_train(self): } if lstm: assert STATE_OUT in fwd_out - if tstep > 0: # First states are already added - - # Extend nested batches of states - output_states = tree.map_structure( - lambda *s: torch.cat((s[0], s[1])), - output_states, - state_in, - ) state_in = fwd_out[STATE_OUT] batches.append(output_batch) obs = new_obs @@ -217,8 +211,8 @@ def test_forward_train(self): for k, v in batch.items() } if lstm: - fwd_in[STATE_IN] = output_states - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([1] * 10) + fwd_in[STATE_IN] = initial_state + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) # forward train # before training make sure module is on the right device and in diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 98c4ff646110..dadf59ef2925 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -205,13 +205,12 @@ def from_model_config( return module def get_initial_state(self) -> NestedDict: - if isinstance(self.config.shared_encoder_config, LSTMConfig): - # TODO (Kourosh): How does this work in RLlib today? 
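# --- Illustrative sketch (editorial note, not part of the patch) -------------
# In this patch the tests take the module's initial LSTM state (the nested
# dict produced by get_initial_state, refactored around this hunk) as
# (num_layers, hidden_dim) tensors, add a leading batch axis with
# tree.map_structure, and pass a SEQ_LENS tensor alongside it. A minimal
# standalone version of that batching step; the shapes (num_layers=1,
# hidden_dim=32) mirror the test configs and are illustrative only:
import torch
import tree  # dm-tree, as used throughout these tests

initial_state = {
    "h": torch.zeros(1, 32),  # (num_layers, hidden_dim)
    "c": torch.zeros(1, 32),
}
batched_state = tree.map_structure(lambda s: s[None], initial_state)
assert batched_state["h"].shape == (1, 1, 32)  # (batch, num_layers, hidden_dim)
seq_lens = torch.Tensor([1])  # one timestep per sequence during rollout
# --- end of sketch ------------------------------------------------------------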
- if isinstance(self.shared_encoder, LSTMEncoder): - return self.shared_encoder.get_inital_state() - else: - return self.pi_encoder.get_inital_state() - return {} + if isinstance(self.shared_encoder, LSTMEncoder): + return self.shared_encoder.get_inital_state() + elif isinstance(self.pi_encoder, LSTMEncoder): + return self.pi_encoder.get_inital_state() + else: + return NestedDict({}) @override(RLModule) def input_specs_inference(self) -> SpecDict: From 6ae344324321bcbe18d1f107d6fbf89157488022 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 21:46:34 +0100 Subject: [PATCH 13/51] fix get_initial_state Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 5 ++--- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 4 ++-- rllib/core/rl_module/encoder.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index b8acee1fed98..c6e9b4395bf5 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -131,7 +131,7 @@ def test_rollouts(self): } if lstm: - state_in = module.pi_encoder.get_inital_state() + state_in = module.get_initial_state() state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) @@ -167,8 +167,7 @@ def test_forward_train(self): obs = env.reset() tstep = 0 if lstm: - # TODO (Artur): Multiple states - state_in = module.pi_encoder.get_inital_state() + state_in = module.get_initial_state() state_in = tree.map_structure( lambda x: x[None], convert_to_torch_tensor(state_in) ) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index dadf59ef2925..b3cc51a732f3 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -206,9 +206,9 @@ def from_model_config( def get_initial_state(self) -> NestedDict: if isinstance(self.shared_encoder, LSTMEncoder): - return self.shared_encoder.get_inital_state() + return self.shared_encoder.get_initial_state() elif isinstance(self.pi_encoder, LSTMEncoder): - return self.pi_encoder.get_inital_state() + return self.pi_encoder.get_initial_state() else: return NestedDict({}) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index fef1bb322ec0..a4389c80ea27 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -75,7 +75,7 @@ def __init__(self, config: EncoderConfig) -> None: self._input_spec = self.input_spec() self._output_spec = self.output_spec() - def get_inital_state(self): + def get_initial_state(self): return [] def input_spec(self): From 548f42aa53e5c49b3fa9ef1b3210eb64293f9a2f Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 22:02:36 +0100 Subject: [PATCH 14/51] remove useless forward_exploration/forward_inference branch Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 149 +++++++++--------- 1 file changed, 74 insertions(+), 75 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index c6e9b4395bf5..0908a2ad7b80 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -146,86 +146,85 @@ def test_rollouts(self): def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for 
env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [FWD={fwd_fn}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" + for shared_encoder in [False, True]: + for lstm in [True, False]: + if lstm and shared_encoder: + # Not yet implemented + # TODO (Artur): Implement + continue + print( + f"[ENV={env_name}] | [SHARED=" + f"{shared_encoder}] | LSTM={lstm}" + ) + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm, shared_encoder) + module = PPOTorchRLModule(config) + + # collect a batch of data + batches = [] + obs = env.reset() + tstep = 0 + if lstm: + state_in = module.get_initial_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - # collect a batch of data - batches = [] - obs = env.reset() - tstep = 0 + initial_state = state_in + while tstep < 10: if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - initial_state = state_in - while tstep < 10: - if lstm: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - STATE_IN: state_in, - SampleBatch.SEQ_LENS: np.array([1]), - } - else: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy( - fwd_out["action_dist"].sample().squeeze(0) - ) - new_obs, reward, done, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.DONES: np.array(done), + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + STATE_IN: state_in, + SampleBatch.SEQ_LENS: np.array([1]), } - if lstm: - assert STATE_OUT in fwd_out - state_in = fwd_out[STATE_OUT] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batches) - # convert dict of lists to dict of tensors - fwd_in = { - k: convert_to_torch_tensor(np.array(v)) - for k, v in batch.items() + else: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(input_batch) + action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, done, _ = env.step(action) + output_batch = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.DONES: np.array(done), } if lstm: - fwd_in[STATE_IN] = initial_state - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) - - # forward train - # before training make sure module is on the right device and in - # training mode - module.to("cpu") - module.train() - fwd_out = module.forward_train(fwd_in) - loss = get_ppo_loss(fwd_in, fwd_out) - loss.backward() - - # check that all neural net parameters have gradients - for param in module.parameters(): - pass - self.assertIsNotNone(param.grad) + assert STATE_OUT in fwd_out + state_in = fwd_out[STATE_OUT] + batches.append(output_batch) + obs = new_obs + tstep += 1 + + # convert the 
list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batches) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) + for k, v in batch.items() + } + if lstm: + fwd_in[STATE_IN] = initial_state + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) + + # forward train + # before training make sure module is on the right device and in + # training mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + pass + self.assertIsNotNone(param.grad) if __name__ == "__main__": From 92ae5100696b9fadf83d08a98c4ab7131294af39 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 21 Dec 2022 22:05:03 +0100 Subject: [PATCH 15/51] revert changes to test_ppo_with_rl_module.py Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index b063523e99b6..4f37ffd6476c 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -1,9 +1,9 @@ -import unittest - import numpy as np +import unittest import ray import ray.rllib.algorithms.ppo as ppo + from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY From 04228db6f62cbc8df244abf2318744dc383c018c Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 22 Dec 2022 17:15:37 +0100 Subject: [PATCH 16/51] remove pass Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 0908a2ad7b80..cbbd995250fa 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -223,7 +223,6 @@ def test_forward_train(self): # check that all neural net parameters have gradients for param in module.parameters(): - pass self.assertIsNotNone(param.grad) From 0ab7809517d892cc0917b445ecc4add3c4b1f793 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 22 Dec 2022 18:07:17 +0100 Subject: [PATCH 17/51] fix gym incompatability Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 9 +++++---- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index cbbd995250fa..8e3fa4201820 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -124,7 +124,7 @@ def test_rollouts(self): config = get_expected_model_config(env, lstm, shared_encoder) module = PPOTorchRLModule(config) - obs = env.reset() + obs, _ = env.reset() batch = { SampleBatch.OBS: convert_to_torch_tensor(obs)[None], @@ -163,7 +163,7 @@ def test_forward_train(self): # collect a batch of data batches = [] - obs = env.reset() + obs, _ = env.reset() tstep = 0 if lstm: state_in = module.get_initial_state() @@ -186,13 +186,14 @@ 
def test_forward_train(self): action = convert_to_numpy( fwd_out["action_dist"].sample().squeeze(0) ) - new_obs, reward, done, _ = env.step(action) + new_obs, reward, terminated, truncated, _ = env.step(action) output_batch = { SampleBatch.OBS: obs, SampleBatch.NEXT_OBS: new_obs, SampleBatch.ACTIONS: action, SampleBatch.REWARDS: np.array(reward), - SampleBatch.DONES: np.array(done), + SampleBatch.TERMINATEDS: np.array(terminated), + SampleBatch.TRUNCATEDS: np.array(truncated), } if lstm: assert STATE_OUT in fwd_out diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b3cc51a732f3..800c676b9731 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -23,7 +23,7 @@ LSTMEncoder, ENCODER_OUT, ) -from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space +from rllib.utils.gym import convert_old_gym_space_to_gymnasium_space torch, nn = try_import_torch() From 0b7536ceae35075bcc96f00410940ea5c2cceebd Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 27 Dec 2022 14:34:52 +0100 Subject: [PATCH 18/51] add missing ray. Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 800c676b9731..b3cc51a732f3 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -23,7 +23,7 @@ LSTMEncoder, ENCODER_OUT, ) -from rllib.utils.gym import convert_old_gym_space_to_gymnasium_space +from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space torch, nn = try_import_torch() From 77f991adb63d656370431c298b940184fe7dfae6 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 5 Jan 2023 17:32:19 +0100 Subject: [PATCH 19/51] test_ppo_rl_module working Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 38 ++++++++++- .../ppo/tests/test_ppo_with_rl_module.py | 2 +- .../ppo/torch/ppo_torch_rl_module.py | 64 +++++++++--------- rllib/core/rl_module/encoder.py | 66 +++++++++++-------- rllib/models/base_model.py | 59 ++++++++++++----- 5 files changed, 151 insertions(+), 78 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 8e3fa4201820..ccdb61f6d375 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -18,6 +18,7 @@ STATE_IN, STATE_OUT, ) +from ray.rllib.models.base_model import BaseModelIOKeys, ModelIOKeyHelper from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -43,17 +44,24 @@ def get_expected_model_config( obs_dim = env.observation_space.shape[0] if shared_encoder: + shared_encoder_kh = ModelIOKeyHelper("shared_encoder") assert not lstm, "LSTM can only be used in PI" shared_encoder_config = FCConfig( input_dim=obs_dim, hidden_layers=[32], activation="ReLU", output_dim=32, + input_key=SampleBatch.OBS, + output_key=shared_encoder_kh.create(BaseModelIOKeys.OUT), ) pi_encoder_config = IdentityConfig(output_dim=32) vf_encoder_config = IdentityConfig(output_dim=32) + pi_input_key = shared_encoder_config.output_key + vf_input_key = shared_encoder_config.output_key else: shared_encoder_config = IdentityConfig(output_dim=obs_dim) + 
pi_encoder_kh = ModelIOKeyHelper("pi_encoder") + vf_encoder_kh = ModelIOKeyHelper("vf_encoder") if lstm: pi_encoder_config = LSTMConfig( input_dim=obs_dim, @@ -61,6 +69,10 @@ def get_expected_model_config( batch_first=True, output_dim=32, num_layers=1, + input_key=SampleBatch.OBS, + state_in_key="state_in_0", + output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), + state_out_key="state_out_0", ) else: pi_encoder_config = FCConfig( @@ -68,16 +80,36 @@ def get_expected_model_config( output_dim=32, hidden_layers=[32], activation="ReLU", + input_key=SampleBatch.OBS, + output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), ) vf_encoder_config = FCConfig( input_dim=obs_dim, output_dim=32, hidden_layers=[32], activation="ReLU", + input_key=SampleBatch.OBS, + output_key=vf_encoder_kh.create(BaseModelIOKeys.OUT), ) - - pi_config = FCConfig() - vf_config = FCConfig() + pi_input_key = pi_encoder_config.output_key + vf_input_key = vf_encoder_config.output_key + pi_kh = ModelIOKeyHelper("pi_encoder") + pi_config = FCConfig( + input_dim=pi_encoder_config.output_dim, + hidden_layers=[32], + activation="ReLU", + input_key=pi_input_key, + output_key=pi_kh.create(BaseModelIOKeys.OUT), + ) + vf_kh = ModelIOKeyHelper("pi_encoder") + vf_config = FCConfig( + input_dim=vf_encoder_config.output_dim, + output_dim=1, + hidden_layers=[32], + activation="ReLU", + input_key=vf_input_key, + output_key=vf_kh.create(BaseModelIOKeys.OUT), + ) if isinstance(env.action_space, gym.spaces.Discrete): pi_config.output_dim = env.action_space.n diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 4f37ffd6476c..84100e07ae0e 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -78,7 +78,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init() + ray.init(local_mode=True) @classmethod def tearDownClass(cls): diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b3cc51a732f3..74f6ef247baf 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -16,12 +16,11 @@ TorchDiagGaussian, ) from ray.rllib.core.rl_module.encoder import ( - FCNet, FCConfig, LSTMConfig, IdentityConfig, LSTMEncoder, - ENCODER_OUT, + STATE_OUT, ) from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space @@ -79,20 +78,8 @@ def setup(self) -> None: self.shared_encoder = self.config.shared_encoder_config.build() self.pi_encoder = self.config.pi_encoder_config.build() self.vf_encoder = self.config.vf_encoder_config.build() - - self.pi = FCNet( - input_dim=self.config.pi_encoder_config.output_dim, - output_dim=self.config.pi_config.output_dim, - hidden_layers=self.config.pi_config.hidden_layers, - activation=self.config.pi_config.activation, - ) - - self.vf = FCNet( - input_dim=self.config.vf_encoder_config.output_dim, - output_dim=1, - hidden_layers=self.config.vf_config.hidden_layers, - activation=self.config.vf_config.activation, - ) + self.pi = self.config.pi_config.build() + self.vf = self.config.vf_config.build() self._is_discrete = isinstance( convert_old_gym_space_to_gymnasium_space(self.config.action_space), @@ -222,10 +209,10 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - 
shared_enc_out = self.shared_encoder(batch) - pi_enc_out = self.pi_encoder(shared_enc_out) - - action_logits = self.pi(pi_enc_out[ENCODER_OUT]) + encoder_out = self.shared_encoder(batch) + encoder_out_pi = self.pi_encoder(encoder_out) + pi_out = self.pi(encoder_out_pi) + action_logits = pi_out[self.pi.config.output_key] if self._is_discrete: action = torch.argmax(action_logits, dim=-1) @@ -234,12 +221,17 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = TorchDeterministic(action) output = {SampleBatch.ACTION_DIST: action_dist} - output["state_out"] = pi_enc_out.get("state_out", {}) + + if hasattr(self.shared_encoder.config, "state_out_key"): + output[STATE_OUT] = encoder_out[self.shared_encoder.config.state_out_key] + if hasattr(self.pi_encoder.config, "state_out_key"): + output[STATE_OUT] = encoder_out_pi[self.pi_encoder.config.state_out_key] + return output @override(RLModule) def input_specs_exploration(self): - return self.shared_encoder.input_spec() + return self.shared_encoder.input_spec @override(RLModule) def output_specs_exploration(self) -> SpecDict: @@ -267,7 +259,9 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) encoder_out_pi = self.pi_encoder(encoder_out) encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) + pi_out = self.pi(encoder_out_pi) + action_logits = pi_out[self.pi.config.output_key] + vf_out = self.vf(encoder_out_vf) output = {} if self._is_discrete: @@ -280,9 +274,13 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output[SampleBatch.ACTION_DIST_INPUTS] = {"loc": loc, "scale": scale} output[SampleBatch.ACTION_DIST] = action_dist + if hasattr(self.shared_encoder.config, "state_out_key"): + output[STATE_OUT] = encoder_out[self.shared_encoder.config.state_out_key] + if hasattr(self.pi_encoder.config, "state_out_key"): + output[STATE_OUT] = encoder_out_pi[self.pi_encoder.config.state_out_key] + # compute the value function - output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf[ENCODER_OUT]).squeeze(-1) - output["state_out"] = encoder_out_pi.get("state_out", {}) + output[SampleBatch.VF_PREDS] = vf_out[self.vf.config.output_key].squeeze(-1) return output @override(RLModule) @@ -293,7 +291,7 @@ def input_specs_train(self) -> SpecDict: action_dim = self.config.action_space.shape[0] action_spec = TorchTensorSpec("b, h", h=action_dim) - spec_dict = self.shared_encoder.input_spec() + spec_dict = self.shared_encoder.input_spec spec_dict.update({SampleBatch.ACTIONS: action_spec}) if SampleBatch.OBS in spec_dict: spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS] @@ -317,9 +315,9 @@ def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: encoder_out = self.shared_encoder(batch) encoder_out_pi = self.pi_encoder(encoder_out) encoder_out_vf = self.vf_encoder(encoder_out) - - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) - vf = self.vf(encoder_out_vf[ENCODER_OUT]) + pi_out = self.pi(encoder_out_pi) + action_logits = pi_out[self.pi.config.output_key] + vf_out = self.vf(encoder_out_vf) if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) @@ -333,11 +331,15 @@ def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: output = { SampleBatch.ACTION_DIST: action_dist, SampleBatch.ACTION_LOGP: logp, - SampleBatch.VF_PREDS: vf.squeeze(-1), + SampleBatch.VF_PREDS: vf_out[self.vf.config.output_key].squeeze(-1), "entropy": entropy, } - output["state_out"] = 
encoder_out_pi.get("state_out", {}) + if hasattr(self.shared_encoder.config, "state_out_key"): + output[STATE_OUT] = encoder_out[self.shared_encoder.config.state_out_key] + if hasattr(self.pi_encoder.config, "state_out_key"): + output[STATE_OUT] = encoder_out_pi[self.pi_encoder.config.state_out_key] + return output def __get_action_dist_type(self): diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index a4389c80ea27..bbe1d5084994 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field +from ray.rllib.models.base_model import Model from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict @@ -15,8 +16,8 @@ # TODO (Kourosh): Find a better / more straight fwd approach for sub-components ENCODER_OUT = "encoder_out" -STATE_IN = "state_in" -STATE_OUT = "state_out" +STATE_IN = "state_in_0" +STATE_OUT = "state_out_0" @dataclass @@ -29,6 +30,8 @@ class EncoderConfig: """ output_dim: int = None + input_key: str = None + output_key: str = None @dataclass @@ -63,29 +66,36 @@ class LSTMConfig(EncoderConfig): hidden_dim: int = None num_layers: int = None batch_first: bool = True + state_in_key: str = None + state_out_key: str = None def build(self): return LSTMEncoder(self) -class Encoder(nn.Module): +class Encoder(Model, nn.Module): def __init__(self, config: EncoderConfig) -> None: - super().__init__() + nn.Module.__init__(self) + Model.__init__(self) self.config = config - self._input_spec = self.input_spec() - self._output_spec = self.output_spec() def get_initial_state(self): return [] + @property def input_spec(self): - return SpecDict() + return SpecDict( + {self.config.input_key: TorchTensorSpec("b, h", h=self.config.input_dim)} + ) + @property def output_spec(self): - return SpecDict() + return SpecDict( + {self.config.output_key: TorchTensorSpec("b, h", h=self.config.output_dim)} + ) - @check_input_specs("_input_spec") - @check_output_specs("_output_spec") + @check_input_specs("input_spec") + @check_output_specs("output_spec") def forward(self, input_dict): return self._forward(input_dict) @@ -104,18 +114,8 @@ def __init__(self, config: FCConfig) -> None: activation=config.activation, ) - def input_spec(self): - return SpecDict( - {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} - ) - - def output_spec(self): - return ModelSpec( - {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} - ) - def _forward(self, input_dict): - return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} + return {self.config.output_key: self.net(input_dict[self.config.input_key])} class LSTMEncoder(Encoder): @@ -137,13 +137,14 @@ def get_initial_state(self): "c": torch.zeros(config.num_layers, config.hidden_dim), } + @property def input_spec(self): config = self.config return SpecDict( { # bxt is just a name for better readability to indicated padded batch - SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), - STATE_IN: { + self.config.input_key: TorchTensorSpec("bxt, h", h=config.input_dim), + self.config.state_in_key: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -154,12 +155,13 @@ def input_spec(self): } ) + @property def output_spec(self): config = self.config return SpecDict( { - ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - STATE_OUT: { + self.config.output_key: TorchTensorSpec("bxt, h", 
h=config.output_dim), + self.config.state_out_key: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -189,8 +191,10 @@ def _forward(self, input_dict: SampleBatch): x = x.view(-1, x.shape[-1]) return { - ENCODER_OUT: x, - STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + self.config.output_key: x, + self.config.state_out_key: tree.map_structure( + lambda x: x.transpose(0, 1), states_o + ), } @@ -198,5 +202,13 @@ class IdentityEncoder(Encoder): def __init__(self, config: EncoderConfig) -> None: super().__init__(config) + @property + def input_spec(self): + return SpecDict() + + @property + def output_spec(self): + return SpecDict() + def _forward(self, input_dict): return input_dict diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py index c006af27f6f1..769ce08f3d96 100644 --- a/rllib/models/base_model.py +++ b/rllib/models/base_model.py @@ -1,22 +1,11 @@ -# Copyright 2021 DeepMind Technologies Limited. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - import abc +from enum import Enum from typing import Optional, Tuple +from collections import defaultdict +from typing import Mapping +from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.temp_spec_classes import TensorDict, SpecDict, ModelConfig +from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig from ray.rllib.utils.annotations import ( DeveloperAPI, OverrideToImplementCustomLogic, @@ -30,6 +19,34 @@ UnrollOutputType = Tuple[TensorDict, TensorDict] +@ExperimentalAPI +class BaseModelIOKeys(Enum): + IN: str = "in" + OUT: str = "out" + STATE_IN: str = "state_in" + STATE_OUT: str = "state_out" + + +class ModelIOKeyHelper: + """Creates unique IO keys for models. + + In order to distinguish keys in input- and outputs-specs of multiple instances of + a give model, each instance is supposed to have distinct keys. + This helper provides a way to generate distinct keys per instance of + ModelIOMapping. + """ + + __init_counters__ = defaultdict(lambda: 0) + + def __init__(self, model_name: str): + self._name: str = model_name + self._init_idx: str = str(self.__init_counters__[model_name]) + self.__init_counters__[model_name] += 1 + + def create(self, key): + return self._name + "_" + str(key) + "_" + self._init_idx + + @ExperimentalAPI class RecurrentModel(abc.ABC): """The base model all other models are based on. @@ -272,6 +289,16 @@ def _unroll( outputs = self._forward(inputs, **kwargs) return outputs, TensorDict() + def forward( + self, input_dict, input_mapping: Mapping = None, **kwargs + ) -> ForwardOutputType: + if input_mapping: + for forward_key, input_dict_key in input_mapping.items(): + if input_dict_key in input_dict: + input_dict[forward_key] = input_dict[input_dict_key] + input_dict.update(self._forward(input_dict, **kwargs)) + return input_dict + @abc.abstractmethod def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: """Computes the output of this module for each timestep. 
From 096a6126de90cf34af2a4b08861611bfaefe7c39 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 6 Jan 2023 01:44:26 +0100 Subject: [PATCH 20/51] ppo_torch_rl_module tests working Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 4 +- .../ppo/torch/ppo_torch_rl_module.py | 79 +++++++++++++------ 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index ccdb61f6d375..96e6b4e2a5e6 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -93,7 +93,7 @@ def get_expected_model_config( ) pi_input_key = pi_encoder_config.output_key vf_input_key = vf_encoder_config.output_key - pi_kh = ModelIOKeyHelper("pi_encoder") + pi_kh = ModelIOKeyHelper("pi") pi_config = FCConfig( input_dim=pi_encoder_config.output_dim, hidden_layers=[32], @@ -101,7 +101,7 @@ def get_expected_model_config( input_key=pi_input_key, output_key=pi_kh.create(BaseModelIOKeys.OUT), ) - vf_kh = ModelIOKeyHelper("pi_encoder") + vf_kh = ModelIOKeyHelper("vf") vf_config = FCConfig( input_dim=vf_encoder_config.output_dim, output_dim=1, diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 74f6ef247baf..a7531b2ac8d3 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -22,6 +22,7 @@ LSTMEncoder, STATE_OUT, ) +from ray.rllib.models.base_model import BaseModelIOKeys, ModelIOKeyHelper from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space @@ -115,40 +116,72 @@ def from_model_config( use_lstm = model_config["use_lstm"] if vf_share_layers: + shared_encoder_kh = ModelIOKeyHelper("shared_encoder") shared_encoder_config = FCConfig( input_dim=obs_dim, hidden_layers=fcnet_hiddens, activation=activation, output_dim=model_config["fcnet_hiddens"][-1], + input_key=SampleBatch.OBS, + output_key=shared_encoder_kh.create(BaseModelIOKeys.OUT), ) + pi_encoder_config = IdentityConfig( + output_dim=model_config["fcnet_hiddens"][-1] + ) + vf_encoder_config = IdentityConfig( + output_dim=model_config["fcnet_hiddens"][-1] + ) + pi_input_key = shared_encoder_config.output_key + vf_input_key = shared_encoder_config.output_key else: shared_encoder_config = IdentityConfig(output_dim=obs_dim) - - if use_lstm: - pi_encoder_config = LSTMConfig( - input_dim=shared_encoder_config.output_dim, - hidden_dim=model_config["lstm_cell_size"], - batch_first=not model_config["_time_major"], - output_dim=model_config["lstm_cell_size"], - num_layers=1, + pi_encoder_kh = ModelIOKeyHelper("pi_encoder") + + if use_lstm: + pi_encoder_config = LSTMConfig( + input_dim=shared_encoder_config.output_dim, + hidden_dim=model_config["lstm_cell_size"], + batch_first=not model_config["_time_major"], + output_dim=model_config["lstm_cell_size"], + num_layers=1, + input_key=SampleBatch.OBS, + state_in_key="state_in_0", + output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), + state_out_key="state_out_0", + ) + else: + pi_encoder_config = FCConfig( + input_dim=shared_encoder_config.output_dim, + hidden_layers=fcnet_hiddens, + activation=activation, + output_dim=model_config["fcnet_hiddens"][-1], + input_key=SampleBatch.OBS, + output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), + ) + + vf_encoder_kh = ModelIOKeyHelper("vf_encoder") + vf_encoder_config = FCConfig( + input_dim=shared_encoder_config.output_dim, + 
hidden_layers=fcnet_hiddens, + activation=activation, + output_dim=model_config["fcnet_hiddens"][-1], + input_key=SampleBatch.OBS, + output_key=vf_encoder_kh.create(BaseModelIOKeys.OUT), + ) + pi_input_key = pi_encoder_config.output_key + vf_input_key = vf_encoder_config.output_key + + pi_kh = ModelIOKeyHelper("pi") + vf_kh = ModelIOKeyHelper("vf") + pi_config = FCConfig( + input_key=pi_input_key, + output_key=pi_kh.create(BaseModelIOKeys.OUT), ) - else: - pi_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], + vf_config = FCConfig( + input_key=vf_input_key, + output_key=vf_kh.create(BaseModelIOKeys.OUT), ) - vf_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], - ) - pi_config = FCConfig() - vf_config = FCConfig() - assert isinstance( observation_space, gym.spaces.Box ), "This simple PPOModule only supports Box observation space." From a373538e472686bf5e1435cef317c08369448b0d Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 19 Jan 2023 00:00:11 -0800 Subject: [PATCH 21/51] feedback from kourosh from last week Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 263 +++++++----------- .../ppo/torch/ppo_torch_rl_module.py | 229 ++++++--------- rllib/core/rl_module/encoder.py | 158 ++++------- rllib/core/rl_module/fc.py | 58 ++++ rllib/models/base_model.py | 72 ++--- 5 files changed, 325 insertions(+), 455 deletions(-) create mode 100644 rllib/core/rl_module/fc.py diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 96e6b4e2a5e6..644c93599c9b 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -11,10 +11,10 @@ get_ppo_loss, PPOModuleConfig, ) +from ray.rllib.core.rl_module.fc import FCConfig from ray.rllib.core.rl_module.encoder import ( - FCConfig, - IdentityConfig, - LSTMConfig, + FCEncoderConfig, + LSTMEncoderConfig, STATE_IN, STATE_OUT, ) @@ -24,16 +24,14 @@ def get_expected_model_config( - env: gym.Env, lstm: bool, shared_encoder: bool + env: gym.Env, + lstm: bool, ) -> PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. Args: env: Environment for which we build the model later lstm: If True, build recurrent pi encoder - shared_encoder: If True, build a shared encoder for pi and vf, where pi - encoder and vf encoder will be identity. If False, the shared encoder - will be identity. 
Returns: A PPOModuleConfig containing the relevant configs to build PPORLModule @@ -43,72 +41,33 @@ def get_expected_model_config( ) obs_dim = env.observation_space.shape[0] - if shared_encoder: - shared_encoder_kh = ModelIOKeyHelper("shared_encoder") - assert not lstm, "LSTM can only be used in PI" - shared_encoder_config = FCConfig( + if lstm: + shared_encoder_config = LSTMEncoderConfig( input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", + hidden_dim=32, + batch_first=True, + num_layers=1, output_dim=32, input_key=SampleBatch.OBS, output_key=shared_encoder_kh.create(BaseModelIOKeys.OUT), ) - pi_encoder_config = IdentityConfig(output_dim=32) - vf_encoder_config = IdentityConfig(output_dim=32) - pi_input_key = shared_encoder_config.output_key - vf_input_key = shared_encoder_config.output_key else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) - pi_encoder_kh = ModelIOKeyHelper("pi_encoder") - vf_encoder_kh = ModelIOKeyHelper("vf_encoder") - if lstm: - pi_encoder_config = LSTMConfig( - input_dim=obs_dim, - hidden_dim=32, - batch_first=True, - output_dim=32, - num_layers=1, - input_key=SampleBatch.OBS, - state_in_key="state_in_0", - output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), - state_out_key="state_out_0", - ) - else: - pi_encoder_config = FCConfig( - input_dim=obs_dim, - output_dim=32, - hidden_layers=[32], - activation="ReLU", - input_key=SampleBatch.OBS, - output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), - ) - vf_encoder_config = FCConfig( + shared_encoder_config = FCEncoderConfig( input_dim=obs_dim, - output_dim=32, hidden_layers=[32], activation="ReLU", - input_key=SampleBatch.OBS, - output_key=vf_encoder_kh.create(BaseModelIOKeys.OUT), + output_dim=32, ) - pi_input_key = pi_encoder_config.output_key - vf_input_key = vf_encoder_config.output_key - pi_kh = ModelIOKeyHelper("pi") + pi_config = FCConfig( - input_dim=pi_encoder_config.output_dim, + input_dim=32, hidden_layers=[32], activation="ReLU", - input_key=pi_input_key, - output_key=pi_kh.create(BaseModelIOKeys.OUT), ) - vf_kh = ModelIOKeyHelper("vf") vf_config = FCConfig( - input_dim=vf_encoder_config.output_dim, - output_dim=1, - hidden_layers=[32], + input_dim=32, + hidden_layers=[32, 1], activation="ReLU", - input_key=vf_input_key, - output_key=vf_kh.create(BaseModelIOKeys.OUT), ) if isinstance(env.action_space, gym.spaces.Discrete): @@ -120,11 +79,8 @@ def get_expected_model_config( observation_space=env.observation_space, action_space=env.action_space, shared_encoder_config=shared_encoder_config, - pi_encoder_config=pi_encoder_config, - vf_encoder_config=vf_encoder_config, pi_config=pi_config, vf_config=vf_config, - shared_encoder=shared_encoder, ) @@ -141,122 +97,105 @@ def test_rollouts(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: for fwd_fn in ["forward_exploration", "forward_inference"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" - f"={lstm}" - ) - env = gym.make(env_name) + for lstm in [False, True]: + print(f"[ENV={env_name}] | LSTM={lstm}") + env = gym.make(env_name) - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) + config = get_expected_model_config(env, lstm) + module = PPOTorchRLModule(config) - obs, _ = env.reset() + obs, _ = env.reset() - batch = { - SampleBatch.OBS: 
convert_to_torch_tensor(obs)[None], - } + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + } - if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - batch[STATE_IN] = state_in - batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) + if lstm: + state_in = module.get_initial_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + batch[STATE_IN] = state_in + batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) - if fwd_fn == "forward_exploration": - module.forward_exploration(batch) - else: - module.forward_inference(batch) + if fwd_fn == "forward_exploration": + module.forward_exploration(batch) + else: + module.forward_inference(batch) def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space for env_name in ["CartPole-v1", "Pendulum-v1"]: - for shared_encoder in [False, True]: - for lstm in [True, False]: - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue - print( - f"[ENV={env_name}] | [SHARED=" - f"{shared_encoder}] | LSTM={lstm}" + for lstm in [False, True]: + print(f"[ENV={env_name}] | LSTM={lstm}") + env = gym.make(env_name) + + config = get_expected_model_config(env, lstm) + module = PPOTorchRLModule(config) + + # collect a batch of data + batches = [] + obs, _ = env.reset() + tstep = 0 + if lstm: + state_in = module.get_initial_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) ) - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm, shared_encoder) - module = PPOTorchRLModule(config) - - # collect a batch of data - batches = [] - obs, _ = env.reset() - tstep = 0 + initial_state = state_in + while tstep < 10: if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - initial_state = state_in - while tstep < 10: - if lstm: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - STATE_IN: state_in, - SampleBatch.SEQ_LENS: np.array([1]), - } - else: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy( - fwd_out["action_dist"].sample().squeeze(0) - ) - new_obs, reward, terminated, truncated, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.TERMINATEDS: np.array(terminated), - SampleBatch.TRUNCATEDS: np.array(truncated), + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + STATE_IN: state_in, + SampleBatch.SEQ_LENS: np.array([1]), } - if lstm: - assert STATE_OUT in fwd_out - state_in = fwd_out[STATE_OUT] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batches) - # convert dict of lists to dict of tensors - fwd_in = { - k: convert_to_torch_tensor(np.array(v)) - for k, v in batch.items() + else: + input_batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None] + } + fwd_out = module.forward_exploration(input_batch) + action = convert_to_numpy( + fwd_out["action_dist"].sample().squeeze(0) + ) + new_obs, reward, terminated, truncated, _ = env.step(action) + output_batch = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + 
SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: np.array(reward), + SampleBatch.TERMINATEDS: np.array(terminated), + SampleBatch.TRUNCATEDS: np.array(truncated), } if lstm: - fwd_in[STATE_IN] = initial_state - fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) - - # forward train - # before training make sure module is on the right device and in - # training mode - module.to("cpu") - module.train() - fwd_out = module.forward_train(fwd_in) - loss = get_ppo_loss(fwd_in, fwd_out) - loss.backward() - - # check that all neural net parameters have gradients - for param in module.parameters(): - self.assertIsNotNone(param.grad) + assert STATE_OUT in fwd_out + state_in = fwd_out[STATE_OUT] + batches.append(output_batch) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: list(x), *batches) + # convert dict of lists to dict of tensors + fwd_in = { + k: convert_to_torch_tensor(np.array(v)) for k, v in batch.items() + } + if lstm: + fwd_in[STATE_IN] = initial_state + fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) + + # forward train + # before training make sure module is on the right device and in + # training mode + module.to("cpu") + module.train() + fwd_out = module.forward_train(fwd_in) + loss = get_ppo_loss(fwd_in, fwd_out) + loss.backward() + + # check that all neural net parameters have gradients + for param in module.parameters(): + self.assertIsNotNone(param.grad) if __name__ == "__main__": diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index a7531b2ac8d3..b9ed311c077f 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -1,13 +1,17 @@ from dataclasses import dataclass -import gymnasium as gym from typing import Mapping, Any, Union -from ray.rllib.core.rl_module.torch import TorchRLModule +import gymnasium as gym + +from ray.rllib.core.rl_module.encoder import ( + FCConfig, + FCEncoderConfig, + LSTMEncoderConfig, + LSTMEncoder, +) from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.annotations import override -from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.utils.framework import try_import_torch +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.models.base_model import STATE_OUT from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.torch.torch_distributions import ( @@ -15,16 +19,12 @@ TorchDeterministic, TorchDiagGaussian, ) -from ray.rllib.core.rl_module.encoder import ( - FCConfig, - LSTMConfig, - IdentityConfig, - LSTMEncoder, - STATE_OUT, -) -from ray.rllib.models.base_model import BaseModelIOKeys, ModelIOKeyHelper +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space - +from ray.rllib.utils.nested_dict import NestedDict +from rllib.core.rl_module.encoder import ENCODER_OUT torch, nn = try_import_torch() @@ -47,22 +47,22 @@ class PPOModuleConfig(RLModuleConfig): """Configuration for the PPO module. Attributes: + observation_space: The observation space of the environment. + action_space: The action space of the environment. 
+ shared_encoder_config: The configuration for the encoder network. pi_config: The configuration for the policy network. vf_config: The configuration for the value network. - shared_encoder_config: The configuration for the encoder network. free_log_std: For DiagGaussian action distributions, make the second half of the model outputs floating bias variables instead of state-dependent. This only has an effect is using the default fully connected net. - shared_encoder: Whether to share the encoder between the pi and value """ + observation_space: gym.Space = None + action_space: gym.Space = None + shared_encoder_config: FCConfig = None pi_config: FCConfig = None vf_config: FCConfig = None - pi_encoder_config: FCConfig = None - vf_encoder_config: FCConfig = None - shared_encoder_config: FCConfig = None free_log_std: bool = False - shared_encoder: bool = True class PPOTorchRLModule(TorchRLModule): @@ -72,13 +72,13 @@ def __init__(self, config: PPOModuleConfig) -> None: self.setup() def setup(self) -> None: - assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." + assert self.config.shared_encoder_config, ( + "shared encoder config must be " "provided." + ) self.shared_encoder = self.config.shared_encoder_config.build() - self.pi_encoder = self.config.pi_encoder_config.build() - self.vf_encoder = self.config.vf_encoder_config.build() self.pi = self.config.pi_config.build() self.vf = self.config.vf_config.build() @@ -111,77 +111,39 @@ def from_model_config( obs_dim = observation_space.shape[0] fcnet_hiddens = model_config["fcnet_hiddens"] - vf_share_layers = model_config["vf_share_layers"] free_log_std = model_config["free_log_std"] - use_lstm = model_config["use_lstm"] + assert ( + model_config.get("vf_share_layers") is False + ), "`vf_share_layers=False` is no longer supported." 
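# --- Illustrative sketch (editorial note, not part of the patch) -------------
# Example of the model_config keys consumed by this classmethod in the hunks
# above and below; the values are illustrative and no keys beyond those
# visible in this diff are implied. vf_share_layers must be False per the
# assert above; use_lstm selects an LSTM shared encoder instead of an FC one.
example_model_config = {
    "fcnet_hiddens": [256, 256],
    "free_log_std": False,
    "vf_share_layers": False,
    "use_lstm": True,
    "lstm_cell_size": 32,
    "_time_major": False,
}
# --- end of sketch ------------------------------------------------------------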
- if vf_share_layers: - shared_encoder_kh = ModelIOKeyHelper("shared_encoder") - shared_encoder_config = FCConfig( + if model_config["use_lstm"]: + shared_encoder_config = LSTMEncoderConfig( input_dim=obs_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], - input_key=SampleBatch.OBS, - output_key=shared_encoder_kh.create(BaseModelIOKeys.OUT), - ) - pi_encoder_config = IdentityConfig( - output_dim=model_config["fcnet_hiddens"][-1] - ) - vf_encoder_config = IdentityConfig( - output_dim=model_config["fcnet_hiddens"][-1] + hidden_dim=model_config["lstm_cell_size"], + batch_first=not model_config["_time_major"], + num_layers=1, + output_dim=model_config["lstm_cell_size"], ) - pi_input_key = shared_encoder_config.output_key - vf_input_key = shared_encoder_config.output_key else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) - pi_encoder_kh = ModelIOKeyHelper("pi_encoder") - - if use_lstm: - pi_encoder_config = LSTMConfig( - input_dim=shared_encoder_config.output_dim, - hidden_dim=model_config["lstm_cell_size"], - batch_first=not model_config["_time_major"], - output_dim=model_config["lstm_cell_size"], - num_layers=1, - input_key=SampleBatch.OBS, - state_in_key="state_in_0", - output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), - state_out_key="state_out_0", - ) - else: - pi_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], - input_key=SampleBatch.OBS, - output_key=pi_encoder_kh.create(BaseModelIOKeys.OUT), - ) - - vf_encoder_kh = ModelIOKeyHelper("vf_encoder") - vf_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], - input_key=SampleBatch.OBS, - output_key=vf_encoder_kh.create(BaseModelIOKeys.OUT), - ) - pi_input_key = pi_encoder_config.output_key - vf_input_key = vf_encoder_config.output_key - - pi_kh = ModelIOKeyHelper("pi") - vf_kh = ModelIOKeyHelper("vf") - pi_config = FCConfig( - input_key=pi_input_key, - output_key=pi_kh.create(BaseModelIOKeys.OUT), - ) - vf_config = FCConfig( - input_key=vf_input_key, - output_key=vf_kh.create(BaseModelIOKeys.OUT), + shared_encoder_config = FCEncoderConfig( + input_dim=obs_dim, + hidden_layers=fcnet_hiddens[:-1], + activation=activation, + output_dim=fcnet_hiddens[-1], ) + pi_config = FCConfig( + input_dim=shared_encoder_config.output_dim, + hidden_layers=[32], + activation="ReLU", + ) + vf_config = FCConfig( + input_dim=shared_encoder_config.output_dim, + hidden_layers=[32, 1], + activation="ReLU", + output_dim=1, + ) + assert isinstance( observation_space, gym.spaces.Box ), "This simple PPOModule only supports Box observation space." 
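# A minimal sketch of how the simplified configs built above chain together,
# assuming CartPole-v1 spaces and fcnet_hiddens=[256]; every dim below is an
# illustrative assumption, not a value taken from the patch itself.
#
#   shared_encoder_config = FCEncoderConfig(
#       input_dim=4,                  # obs_dim of CartPole-v1
#       hidden_layers=[256],
#       activation="ReLU",
#       output_dim=256,
#   )
#   pi_config = FCConfig(
#       input_dim=shared_encoder_config.output_dim,
#       hidden_layers=[32],
#       activation="ReLU",
#       output_dim=2,                 # Discrete(2) -> two action logits
#   )
#   vf_config = FCConfig(
#       input_dim=shared_encoder_config.output_dim,
#       hidden_layers=[32, 1],
#       activation="ReLU",
#       output_dim=1,                 # scalar value estimate
#   )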
@@ -194,31 +156,21 @@ def from_model_config( "This simple PPOModule only supports Discrete and Box action space.", ) - # build pi network + # build policy network head shared_encoder_config.input_dim = observation_space.shape[0] - pi_encoder_config.input_dim = shared_encoder_config.output_dim - pi_config.input_dim = pi_encoder_config.output_dim + pi_config.input_dim = shared_encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 - # build vf network - vf_encoder_config.input_dim = shared_encoder_config.output_dim - vf_config.input_dim = vf_encoder_config.output_dim - vf_config.output_dim = 1 - config_ = PPOModuleConfig( observation_space=observation_space, action_space=action_space, - max_seq_len=model_config["max_seq_len"], shared_encoder_config=shared_encoder_config, pi_config=pi_config, vf_config=vf_config, - pi_encoder_config=pi_encoder_config, - vf_encoder_config=vf_encoder_config, free_log_std=free_log_std, - shared_encoder=vf_share_layers, ) module = PPOTorchRLModule(config_) @@ -227,8 +179,6 @@ def from_model_config( def get_initial_state(self) -> NestedDict: if isinstance(self.shared_encoder, LSTMEncoder): return self.shared_encoder.get_initial_state() - elif isinstance(self.pi_encoder, LSTMEncoder): - return self.pi_encoder.get_initial_state() else: return NestedDict({}) @@ -242,23 +192,20 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: + output = {} + encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - pi_out = self.pi(encoder_out_pi) - action_logits = pi_out[self.pi.config.output_key] + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + # Actions + action_logits = self.pi(encoder_out[ENCODER_OUT]) if self._is_discrete: action = torch.argmax(action_logits, dim=-1) else: action, _ = action_logits.chunk(2, dim=-1) - action_dist = TorchDeterministic(action) - output = {SampleBatch.ACTION_DIST: action_dist} - - if hasattr(self.shared_encoder.config, "state_out_key"): - output[STATE_OUT] = encoder_out[self.shared_encoder.config.state_out_key] - if hasattr(self.pi_encoder.config, "state_out_key"): - output[STATE_OUT] = encoder_out_pi[self.pi_encoder.config.state_out_key] + output[SampleBatch.ACTION_DIST] = action_dist return output @@ -289,14 +236,20 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: policy distribution to be used for computing KL divergence between the old policy and the new policy during training. 
""" + output = {} + + # Shared encoder encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) - pi_out = self.pi(encoder_out_pi) - action_logits = pi_out[self.pi.config.output_key] - vf_out = self.vf(encoder_out_vf) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] - output = {} + # Value head + vf_out = self.vf(encoder_out[ENCODER_OUT]) + output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) + + # Policy head + pi_out = self.pi(encoder_out[ENCODER_OUT]) + action_logits = pi_out if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) output[SampleBatch.ACTION_DIST_INPUTS] = {"logits": action_logits} @@ -307,13 +260,6 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output[SampleBatch.ACTION_DIST_INPUTS] = {"loc": loc, "scale": scale} output[SampleBatch.ACTION_DIST] = action_dist - if hasattr(self.shared_encoder.config, "state_out_key"): - output[STATE_OUT] = encoder_out[self.shared_encoder.config.state_out_key] - if hasattr(self.pi_encoder.config, "state_out_key"): - output[STATE_OUT] = encoder_out_pi[self.pi_encoder.config.state_out_key] - - # compute the value function - output[SampleBatch.VF_PREDS] = vf_out[self.vf.config.output_key].squeeze(-1) return output @override(RLModule) @@ -345,33 +291,30 @@ def output_specs_train(self) -> SpecDict: @override(RLModule) def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: + output = {} + + # Shared encoder encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) - pi_out = self.pi(encoder_out_pi) - action_logits = pi_out[self.pi.config.output_key] - vf_out = self.vf(encoder_out_vf) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + + # Value head + vf_out = self.vf(encoder_out[ENCODER_OUT]) + output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) + # Policy head + pi_out = self.pi(encoder_out[ENCODER_OUT]) + action_logits = pi_out if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) else: mu, scale = action_logits.chunk(2, dim=-1) action_dist = TorchDiagGaussian(mu, scale.exp()) - logp = action_dist.logp(batch[SampleBatch.ACTIONS]) entropy = action_dist.entropy() - - output = { - SampleBatch.ACTION_DIST: action_dist, - SampleBatch.ACTION_LOGP: logp, - SampleBatch.VF_PREDS: vf_out[self.vf.config.output_key].squeeze(-1), - "entropy": entropy, - } - - if hasattr(self.shared_encoder.config, "state_out_key"): - output[STATE_OUT] = encoder_out[self.shared_encoder.config.state_out_key] - if hasattr(self.pi_encoder.config, "state_out_key"): - output[STATE_OUT] = encoder_out_pi[self.pi_encoder.config.state_out_key] + output[SampleBatch.ACTION_DIST] = action_dist + output[SampleBatch.ACTION_LOGP] = logp + output["entropy"] = entropy return output diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index bbe1d5084994..ca2e85542e60 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -1,126 +1,68 @@ import torch import torch.nn as nn import tree -from typing import List -from dataclasses import dataclass, field +from dataclasses import dataclass -from ray.rllib.models.base_model import Model +from ray.rllib.models.base_model import ( + ModelConfig, + Model, + STATE_IN, + STATE_OUT, + ForwardOutputType, +) +from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch 
from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.models.torch.primitives import FCNet +from ray.rllib.core.rl_module.fc import FC, FCConfig -# TODO (Kourosh): Find a better / more straight fwd approach for sub-components -ENCODER_OUT = "encoder_out" -STATE_IN = "state_in_0" -STATE_OUT = "state_out_0" +ENCODER_OUT: str = "encoder_out" @dataclass -class EncoderConfig: - """Configuration for an encoder network. - - Attributes: - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. - """ - - output_dim: int = None - input_key: str = None - output_key: str = None - - -@dataclass -class IdentityConfig(EncoderConfig): - """Configuration for an identity encoder.""" - +class FCEncoderConfig(FCConfig): def build(self): - return IdentityEncoder(self) - - -@dataclass -class FCConfig(EncoderConfig): - """Configuration for a fully connected network. - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. - """ - - input_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - - def build(self): - return FullyConnectedEncoder(self) - - -@dataclass -class LSTMConfig(EncoderConfig): - input_dim: int = None - hidden_dim: int = None - num_layers: int = None - batch_first: bool = True - state_in_key: str = None - state_out_key: str = None - - def build(self): - return LSTMEncoder(self) - - -class Encoder(Model, nn.Module): - def __init__(self, config: EncoderConfig) -> None: - nn.Module.__init__(self) - Model.__init__(self) - self.config = config + return FCEncoder(self) - def get_initial_state(self): - return [] +class FCEncoder(FC): @property def input_spec(self): return SpecDict( - {self.config.input_key: TorchTensorSpec("b, h", h=self.config.input_dim)} + {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} ) @property def output_spec(self): return SpecDict( - {self.config.output_key: TorchTensorSpec("b, h", h=self.config.output_dim)} + {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} ) - @check_input_specs("input_spec") - @check_output_specs("output_spec") - def forward(self, input_dict): - return self._forward(input_dict) + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} - def _forward(self, input_dict): - raise NotImplementedError +@dataclass +class LSTMEncoderConfig(ModelConfig): + input_dim: int = None + hidden_dim: int = None + num_layers: int = None + batch_first: bool = True -class FullyConnectedEncoder(Encoder): - def __init__(self, config: FCConfig) -> None: - super().__init__(config) - - self.net = FCNet( - input_dim=config.input_dim, - hidden_layers=config.hidden_layers, - output_dim=config.output_dim, - activation=config.activation, - ) - - def _forward(self, input_dict): - return {self.config.output_key: self.net(input_dict[self.config.input_key])} + def build(self): + return LSTMEncoder(self) -class 
LSTMEncoder(Encoder): - def __init__(self, config: LSTMConfig) -> None: - super().__init__(config) +class LSTMEncoder(Model, nn.Module): + def __init__(self, config: LSTMEncoderConfig) -> None: + nn.Module.__init__(self) + Model.__init__(self, config) self.lstm = nn.LSTM( config.input_dim, @@ -152,6 +94,7 @@ def input_spec(self): "b, l, h", h=config.hidden_dim, l=config.num_layers ), }, + SampleBatch.SEQ_LENS: None, } ) @@ -172,15 +115,17 @@ def output_spec(self): } ) - def _forward(self, input_dict: SampleBatch): - x = input_dict[SampleBatch.OBS] - states = input_dict[STATE_IN] + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + x = inputs[SampleBatch.OBS] + states = inputs[STATE_IN] # states are batch-first when coming in states = tree.map_structure(lambda x: x.transpose(0, 1), states) x = add_time_dimension( x, - seq_lens=input_dict[SampleBatch.SEQ_LENS], + seq_lens=inputs[SampleBatch.SEQ_LENS], framework="torch", time_major=not self.config.batch_first, ) @@ -198,17 +143,20 @@ def _forward(self, input_dict: SampleBatch): } -class IdentityEncoder(Encoder): - def __init__(self, config: EncoderConfig) -> None: - super().__init__(config) +@dataclass +class IdentityConfig(ModelConfig): + """Configuration for an identity encoder.""" - @property - def input_spec(self): - return SpecDict() + def build(self): + return IdentityEncoder(self) - @property - def output_spec(self): - return SpecDict() - def _forward(self, input_dict): - return input_dict +class IdentityEncoder(Model): + def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + pass + + def __init__(self, config: IdentityConfig) -> None: + super().__init__(config) + + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return inputs diff --git a/rllib/core/rl_module/fc.py b/rllib/core/rl_module/fc.py new file mode 100644 index 000000000000..2e037fa03d40 --- /dev/null +++ b/rllib/core/rl_module/fc.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import List + +import torch.nn as nn + +from ray.rllib.models.base_model import Model, ModelConfig, ForwardOutputType +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_torch import TorchTensorSpec +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.models.torch.primitives import FCNet + + +@dataclass +class FCConfig(ModelConfig): + """Configuration for a fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer (except for the + output). + output_activation: The activation function to use for the output layer. 
+ """ + + input_dim: int = None + hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) + activation: str = "ReLU" + output_activation: str = "ReLU" + + def build(self) -> Model: + return FC(self) + + +class FC(Model, nn.Module): + def __init__(self, config: FCConfig) -> None: + nn.Module.__init__(self) + Model.__init__(self, config) + + self.net = FCNet( + input_dim=config.input_dim, + hidden_layers=config.hidden_layers, + output_dim=config.output_dim, + activation=config.activation, + ) + + @property + def input_spec(self): + return TorchTensorSpec("b, h", h=self.config.input_dim) + + @property + def output_spec(self): + return TorchTensorSpec("b, h", h=self.config.output_dim) + + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return self.net(inputs) diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py index 769ce08f3d96..cadf65539a31 100644 --- a/rllib/models/base_model.py +++ b/rllib/models/base_model.py @@ -1,9 +1,7 @@ import abc -from enum import Enum from typing import Optional, Tuple -from collections import defaultdict -from typing import Mapping from ray.rllib.models.specs.specs_dict import SpecDict +from dataclasses import dataclass from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig from ray.rllib.utils.annotations import ( @@ -18,33 +16,8 @@ # [Output, Recurrent State(s)] UnrollOutputType = Tuple[TensorDict, TensorDict] - -@ExperimentalAPI -class BaseModelIOKeys(Enum): - IN: str = "in" - OUT: str = "out" - STATE_IN: str = "state_in" - STATE_OUT: str = "state_out" - - -class ModelIOKeyHelper: - """Creates unique IO keys for models. - - In order to distinguish keys in input- and outputs-specs of multiple instances of - a give model, each instance is supposed to have distinct keys. - This helper provides a way to generate distinct keys per instance of - ModelIOMapping. - """ - - __init_counters__ = defaultdict(lambda: 0) - - def __init__(self, model_name: str): - self._name: str = model_name - self._init_idx: str = str(self.__init_counters__[model_name]) - self.__init_counters__[model_name] += 1 - - def create(self, key): - return self._name + "_" + str(key) + "_" + self._init_idx +STATE_IN: str = "state_in" +STATE_OUT: str = "state_out" @ExperimentalAPI @@ -72,8 +45,9 @@ class RecurrentModel(abc.ABC): name: An optional name for the module """ - def __init__(self, name: Optional[str] = None): + def __init__(self, config: ModelConfig, name: Optional[str] = None): self._name = name or self.__class__.__name__ + self.config = config @property def name(self) -> str: @@ -289,18 +263,8 @@ def _unroll( outputs = self._forward(inputs, **kwargs) return outputs, TensorDict() - def forward( - self, input_dict, input_mapping: Mapping = None, **kwargs - ) -> ForwardOutputType: - if input_mapping: - for forward_key, input_dict_key in input_mapping.items(): - if input_dict_key in input_dict: - input_dict[forward_key] = input_dict[input_dict_key] - input_dict.update(self._forward(input_dict, **kwargs)) - return input_dict - @abc.abstractmethod - def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: """Computes the output of this module for each timestep. 
Args: @@ -319,7 +283,7 @@ def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: @ExperimentalAPI class ModelIO(abc.ABC): - """Abstract class defining how to save and load model weights + """Abstract class defining how to save and load model weights. Args: config: The ModelConfig passed to the underlying model @@ -336,7 +300,7 @@ def config(self) -> ModelConfig: @DeveloperAPI @abc.abstractmethod def save(self, path: str) -> None: - """Save model weights to a path + """Save model weights to a path. Args: path: The path on disk where weights are to be saved @@ -349,7 +313,7 @@ def save(self, path: str) -> None: @DeveloperAPI @abc.abstractmethod def load(self, path: str) -> RecurrentModel: - """Load model weights from a path + """Load model weights from a path. Args: path: The path on disk where to load weights from @@ -358,3 +322,21 @@ def load(self, path: str) -> RecurrentModel: model.load("/tmp/model_path.cpt") """ raise NotImplementedError + + +@ExperimentalAPI +@dataclass +class ModelConfig(abc.ABC): + """Configuration for an encoder network. + + Attributes: + output_dim: The output dimension of the network. if None, the last layer would + be the last hidden layer. + """ + + output_dim: int = None + + @abc.abstractmethod + def build(self) -> RecurrentModel: + """Builds the model.""" + raise NotImplementedError From c2ba97cda521fc9f410253c1c276da25558088a2 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 19 Jan 2023 15:20:38 -0800 Subject: [PATCH 22/51] solution 3 Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/ppo.py | 29 +++ .../ppo/tests/test_ppo_rl_module.py | 169 +++++++++++++----- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 57 ++---- .../ppo/torch/ppo_torch_rl_module.py | 46 ++--- rllib/core/rl_module/encoder_tf.py | 37 ---- rllib/core/rl_module/model_configs.py | 77 ++++++++ rllib/core/rl_module/tf/encoder.py | 129 +++++++++++++ rllib/core/rl_module/tf/fcmodel.py | 34 ++++ rllib/core/rl_module/{ => torch}/encoder.py | 40 +---- .../rl_module/{fc.py => torch/fcmodel.py} | 32 +--- rllib/models/base_model.py | 2 +- rllib/models/configs/encoder.py | 83 --------- rllib/models/tf/primitives.py | 11 +- rllib/models/torch/primitives.py | 14 +- 14 files changed, 455 insertions(+), 305 deletions(-) delete mode 100644 rllib/core/rl_module/encoder_tf.py create mode 100644 rllib/core/rl_module/model_configs.py create mode 100644 rllib/core/rl_module/tf/encoder.py create mode 100644 rllib/core/rl_module/tf/fcmodel.py rename rllib/core/rl_module/{ => torch}/encoder.py (85%) rename rllib/core/rl_module/{fc.py => torch/fcmodel.py} (52%) delete mode 100644 rllib/models/configs/encoder.py diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 89b7955c149b..fee7cdc0429b 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -10,8 +10,14 @@ """ import logging +from dataclasses import dataclass from typing import List, Optional, Type, Union, TYPE_CHECKING +import gymnasium as gym + +from ray.rllib.core.rl_module.model_configs import FCConfig +from ray.rllib.core.rl_module.rl_module import RLModuleConfig + from ray.util.debug import log_once from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided @@ -465,3 +471,26 @@ def __getitem__(self, item): DEFAULT_CONFIG = _deprecated_default_config() + + +@dataclass +class PPOModuleConfig(RLModuleConfig): + """Configuration for the PPO module. 
+ + Attributes: + observation_space: The observation space of the environment. + action_space: The action space of the environment. + shared_encoder_config: The configuration for the encoder network. + pi_config: The configuration for the policy network. + vf_config: The configuration for the value network. + free_log_std: For DiagGaussian action distributions, make the second half of + the model outputs floating bias variables instead of state-dependent. This + only has an effect is using the default fully connected net. + """ + + observation_space: gym.Space = None + action_space: gym.Space = None + shared_encoder_config: FCConfig = None + pi_config: FCConfig = None + vf_config: FCConfig = None + free_log_std: bool = False diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 644c93599c9b..95546b3c6ea7 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -8,17 +8,20 @@ from ray.rllib import SampleBatch from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule, - get_ppo_loss, - PPOModuleConfig, ) -from ray.rllib.core.rl_module.fc import FCConfig -from ray.rllib.core.rl_module.encoder import ( +from rllib.algorithms.ppo.ppo import PPOModuleConfig +from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( + PPOTfModule, +) +from ray.rllib.core.rl_module.model_configs import ( + FCConfig, FCEncoderConfig, LSTMEncoderConfig, +) +from ray.rllib.core.rl_module.torch.encoder import ( STATE_IN, STATE_OUT, ) -from ray.rllib.models.base_model import BaseModelIOKeys, ModelIOKeyHelper from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -84,6 +87,49 @@ def get_expected_model_config( ) +def dummy_torch_ppo_loss(batch, fwd_out): + """Dummy PPO loss function for testing purposes. + + Will eventually use the actual PPO loss function implemented in the PPOTfTrainer. + + Args: + batch: SampleBatch used for training. + fwd_out: Forward output of the model. + + Returns: + Loss tensor + """ + # TODO: we should replace these components later with real ppo components when + # RLOptimizer and RLModule are integrated together. + # this is not exactly a ppo loss, just something to show that the + # forward train works + adv = batch[SampleBatch.REWARDS] - fwd_out[SampleBatch.VF_PREDS] + actor_loss = -(fwd_out[SampleBatch.ACTION_LOGP] * adv).mean() + critic_loss = (adv**2).mean() + loss = actor_loss + critic_loss + + return loss + + +def dummy_tf_ppo_loss(batch, fwd_out): + """Dummy PPO loss function for testing purposes. + + Will eventually use the actual PPO loss function implemented in the PPOTfTrainer. + + Args: + batch: SampleBatch used for training. + fwd_out: Forward output of the model. 
+ + Returns: + Loss tensor + """ + adv = batch[SampleBatch.REWARDS] - fwd_out[SampleBatch.VF_PREDS] + action_probs = fwd_out[SampleBatch.ACTION_DIST].logp(batch[SampleBatch.ACTIONS]) + actor_loss = -tf.reduce_mean(action_probs * adv) + critic_loss = tf.reduce_mean(tf.square(adv)) + return actor_loss + critic_loss + + class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): @@ -93,50 +139,83 @@ def setUpClass(cls): def tearDownClass(cls): ray.shutdown() + def get_ppo_module(self, framework, env, lstm): + config = get_expected_model_config(env, lstm) + if framework == "torch": + module = PPOTorchRLModule(config) + else: + module = PPOTfModule(config) + return module + + def get_input_batch_from_obs(self, framework, obs): + if framework == "torch": + batch = { + SampleBatch.OBS: convert_to_torch_tensor(obs)[None], + } + else: + batch = {SampleBatch.OBS: tf.convert_to_tensor([obs])} + return batch + def test_rollouts(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - for env_name in ["CartPole-v1", "Pendulum-v1"]: - for fwd_fn in ["forward_exploration", "forward_inference"]: - for lstm in [False, True]: - print(f"[ENV={env_name}] | LSTM={lstm}") - env = gym.make(env_name) + frameworks = ["torch", "tf2"] + env_names = ["CartPole-v1", "Pendulum-v1"] + fwd_fns = ["forward_exploration", "forward_inference"] + lstm = [False, True] + config_combinations = [frameworks, env_names, fwd_fns, lstm] + for config in itertools.product(*config_combinations): + fw, env_name, fwd_fn, lstm = config + if lstm and fw == "tf2": + # LSTM not implemented in TF2 yet + continue + print(f"[FW={fw} | [ENV={env_name}] | [FWD={fwd_fn}] | LSTM" f"={lstm}") + env = gym.make(env_name) + module = self.get_ppo_module(framework=fw, env=env, lstm=lstm) - config = get_expected_model_config(env, lstm) - module = PPOTorchRLModule(config) + obs, _ = env.reset() - obs, _ = env.reset() + batch = self.get_input_batch_from_obs(fw, obs) - batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - } + if lstm: + state_in = module.get_initial_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + batch[STATE_IN] = state_in + batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) - if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - batch[STATE_IN] = state_in - batch[SampleBatch.SEQ_LENS] = torch.Tensor([1]) - - if fwd_fn == "forward_exploration": - module.forward_exploration(batch) - else: - module.forward_inference(batch) + if fwd_fn == "forward_exploration": + module.forward_exploration(batch) + else: + module.forward_inference(batch) def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space - for env_name in ["CartPole-v1", "Pendulum-v1"]: - for lstm in [False, True]: - print(f"[ENV={env_name}] | LSTM={lstm}") - env = gym.make(env_name) - - config = get_expected_model_config(env, lstm) - module = PPOTorchRLModule(config) - - # collect a batch of data - batches = [] - obs, _ = env.reset() - tstep = 0 + frameworks = ["torch", "tf2"] + env_names = ["CartPole-v1", "Pendulum-v1"] + lstm = [False, True] + config_combinations = [frameworks, env_names, lstm] + for config in itertools.product(*config_combinations): + fw, env_name, lstm = config + if lstm and fw == "tf2": + # LSTM not implemented in TF2 yet + continue + print(f"[FW={fw} | [ENV={env_name}] | LSTM={lstm}") + env = gym.make(env_name) + + module = self.get_ppo_module(fw, env, lstm) + + # 
collect a batch of data + batches = [] + obs, _ = env.reset() + tstep = 0 + if lstm: + state_in = module.get_initial_state() + state_in = tree.map_structure( + lambda x: x[None], convert_to_torch_tensor(state_in) + ) + initial_state = state_in + while tstep < 10: if lstm: state_in = module.get_initial_state() state_in = tree.map_structure( @@ -196,6 +275,16 @@ def test_forward_train(self): # check that all neural net parameters have gradients for param in module.parameters(): self.assertIsNotNone(param.grad) + else: + batch = tree.map_structure( + lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch + ) + with tf.GradientTape() as tape: + fwd_out = module.forward_train(batch) + loss = dummy_tf_ppo_loss(batch, fwd_out) + grads = tape.gradient(loss, module.trainable_variables) + for grad in grads: + self.assertIsNotNone(grad) if __name__ == "__main__": diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 5568d06a4224..28230091a5cb 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -1,41 +1,25 @@ -from dataclasses import dataclass import gymnasium as gym from typing import Mapping, Any, List from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.core.rl_module.encoder_tf import FCTfConfig, IdentityTfConfig +from ray.rllib.core.rl_module.model_configs import FCConfig, IdentityConfig from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian from ray.rllib.models.tf.primitives import FCNet +from ray.rllib.core.rl_module.tf.encoder import ENCODER_OUT +from ray.rllib.algorithms.ppo.ppo import PPOModuleConfig tf1, tf, _ = try_import_tf() tf1.enable_eager_execution() -@dataclass -class PPOTfModuleConfig(RLModuleConfig): - """Configuration for the PPO module. - - Attributes: - pi_config: The configuration for the policy network. - vf_config: The configuration for the value network. - """ - - observation_space: gym.Space = None - action_space: gym.Space = None - pi_config: FCTfConfig = None - vf_config: FCTfConfig = None - shared_encoder_config: FCTfConfig = None - shared_encoder: bool = True - - -class PPOTfRLModule(TfRLModule): - def __init__(self, config: PPOTfModuleConfig): +class PPOTfModule(TfRLModule): + def __init__(self, config: RLModuleConfig): super().__init__() self.config = config self.setup() @@ -43,7 +27,7 @@ def __init__(self, config: PPOTfModuleConfig): def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." 
- self.shared_encoder = self.config.shared_encoder_config.build() + self.shared_encoder = self.config.shared_encoder_config.build(framework="tf") self.pi = FCNet( input_dim=self.config.shared_encoder_config.output_dim, @@ -77,10 +61,9 @@ def output_specs_train(self) -> List[str]: @override(TfRLModule) def _forward_train(self, batch: NestedDict): - obs = batch[SampleBatch.OBS] - encoder_out = self.shared_encoder(obs) - action_logits = self.pi(encoder_out) - vf = self.vf(encoder_out) + encoder_out = self.shared_encoder(batch) + action_logits = self.pi(encoder_out[ENCODER_OUT]) + vf = self.vf(encoder_out[ENCODER_OUT]) if self._is_discrete: action_dist = Categorical(action_logits) @@ -105,10 +88,9 @@ def output_specs_inference(self) -> List[str]: @override(TfRLModule) def _forward_inference(self, batch) -> Mapping[str, Any]: - obs = batch[SampleBatch.OBS] - encoder_out = self.shared_encoder(obs) + encoder_out = self.shared_encoder(batch) - action_logits = self.pi(encoder_out) + action_logits = self.pi(encoder_out[ENCODER_OUT]) if self._is_discrete: action = tf.math.argmax(action_logits, axis=-1) @@ -135,11 +117,10 @@ def output_specs_exploration(self) -> List[str]: @override(TfRLModule) def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: - obs = batch[SampleBatch.OBS] - encoder_out = self.shared_encoder(obs) + encoder_out = self.shared_encoder(batch) - action_logits = self.pi(encoder_out) - vf = self.vf(encoder_out) + action_logits = self.pi(encoder_out[ENCODER_OUT]) + vf = self.vf(encoder_out[ENCODER_OUT]) if self._is_discrete: action_dist = Categorical(action_logits) @@ -180,14 +161,14 @@ def from_model_config( if use_lstm: raise ValueError("LSTM not supported by PPOTfRLModule yet.") if vf_share_layers: - shared_encoder_config = FCTfConfig( + shared_encoder_config = FCConfig( input_dim=obs_dim, hidden_layers=fcnet_hiddens, activation=activation, output_dim=model_config["fcnet_hiddens"][-1], ) else: - shared_encoder_config = IdentityTfConfig(output_dim=obs_dim) + shared_encoder_config = IdentityConfig(output_dim=obs_dim) assert isinstance( observation_space, gym.spaces.Box ), "This simple PPOModule only supports Box observation space." 
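# A minimal sketch of the two shared-encoder branches constructed above,
# assuming CartPole-v1 (obs_dim=4) and fcnet_hiddens=[256, 256]; the concrete
# values are illustrative assumptions only.
#
#   # vf_share_layers=True -> policy and value heads share one FC trunk:
#   shared_encoder_config = FCConfig(
#       input_dim=4,
#       hidden_layers=[256, 256],
#       activation=activation,
#       output_dim=256,
#   )
#   # vf_share_layers=False -> the "encoder" passes the raw observation through:
#   shared_encoder_config = IdentityConfig(output_dim=4)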
@@ -199,8 +180,8 @@ def from_model_config( assert isinstance(action_space, (gym.spaces.Discrete, gym.spaces.Box)), ( "This simple PPOModule only supports Discrete and Box action space.", ) - pi_config = FCTfConfig() - vf_config = FCTfConfig() + pi_config = FCConfig() + vf_config = FCConfig() shared_encoder_config.input_dim = observation_space.shape[0] pi_config.input_dim = shared_encoder_config.output_dim pi_config.hidden_layers = fcnet_hiddens @@ -212,7 +193,7 @@ def from_model_config( vf_config.input_dim = shared_encoder_config.output_dim vf_config.hidden_layers = fcnet_hiddens vf_config.output_dim = 1 - config_ = PPOTfModuleConfig( + config_ = PPOModuleConfig( pi_config=pi_config, vf_config=vf_config, shared_encoder_config=shared_encoder_config, diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b9ed311c077f..253db0629c67 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -1,15 +1,15 @@ -from dataclasses import dataclass from typing import Mapping, Any, Union import gymnasium as gym -from ray.rllib.core.rl_module.encoder import ( - FCConfig, - FCEncoderConfig, +from ray.rllib.core.rl_module.torch.encoder import ( + ENCODER_OUT, +) +from ray.rllib.core.rl_module.model_configs import ( LSTMEncoderConfig, - LSTMEncoder, ) -from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig +from ray.rllib.core.rl_module.model_configs import FCConfig, FCEncoderConfig +from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.base_model import STATE_OUT from ray.rllib.models.specs.specs_dict import SpecDict @@ -24,7 +24,7 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict -from rllib.core.rl_module.encoder import ENCODER_OUT +from rllib.algorithms.ppo.ppo import PPOModuleConfig torch, nn = try_import_torch() @@ -42,29 +42,6 @@ def get_ppo_loss(fwd_in, fwd_out): return loss -@dataclass -class PPOModuleConfig(RLModuleConfig): - """Configuration for the PPO module. - - Attributes: - observation_space: The observation space of the environment. - action_space: The action space of the environment. - shared_encoder_config: The configuration for the encoder network. - pi_config: The configuration for the policy network. - vf_config: The configuration for the value network. - free_log_std: For DiagGaussian action distributions, make the second half of - the model outputs floating bias variables instead of state-dependent. This - only has an effect is using the default fully connected net. - """ - - observation_space: gym.Space = None - action_space: gym.Space = None - shared_encoder_config: FCConfig = None - pi_config: FCConfig = None - vf_config: FCConfig = None - free_log_std: bool = False - - class PPOTorchRLModule(TorchRLModule): def __init__(self, config: PPOModuleConfig) -> None: super().__init__() @@ -78,9 +55,10 @@ def setup(self) -> None: "shared encoder config must be " "provided." 
) - self.shared_encoder = self.config.shared_encoder_config.build() - self.pi = self.config.pi_config.build() - self.vf = self.config.vf_config.build() + # TODO(Artur): Unify to tf and torch setup(framework) + self.shared_encoder = self.config.shared_encoder_config.build(framework="torch") + self.pi = self.config.pi_config.build(framework="torch") + self.vf = self.config.vf_config.build(framework="torch") self._is_discrete = isinstance( convert_old_gym_space_to_gymnasium_space(self.config.action_space), @@ -177,7 +155,7 @@ def from_model_config( return module def get_initial_state(self) -> NestedDict: - if isinstance(self.shared_encoder, LSTMEncoder): + if hasattr(self.shared_encoder, "get_initial_state"): return self.shared_encoder.get_initial_state() else: return NestedDict({}) diff --git a/rllib/core/rl_module/encoder_tf.py b/rllib/core/rl_module/encoder_tf.py deleted file mode 100644 index 5c517f6c745d..000000000000 --- a/rllib/core/rl_module/encoder_tf.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass, field -from typing import List - -from ray.rllib.core.rl_module.encoder import EncoderConfig -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.models.tf.primitives import FCNet, IdentityNetwork - -tf1, tf, tfv = try_import_tf() - - -@dataclass -class FCTfConfig(EncoderConfig): - """Configuration for a fully connected network. - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. - """ - - input_dim: int = None - output_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - - def build(self): - return FCNet( - self.input_dim, self.hidden_layers, self.output_dim, self.activation - ) - - -@dataclass -class IdentityTfConfig(EncoderConfig): - """A network that returns the input as the output.""" - - def build(self): - return IdentityNetwork() diff --git a/rllib/core/rl_module/model_configs.py b/rllib/core/rl_module/model_configs.py new file mode 100644 index 000000000000..74398eefc5e6 --- /dev/null +++ b/rllib/core/rl_module/model_configs.py @@ -0,0 +1,77 @@ +from dataclasses import dataclass, field +from typing import List +import functools + +from ray.rllib.models.torch.primitives import Identity +from ray.rllib.models.base_model import ModelConfig, Model + + +def check_framework(fn): + @functools.wraps(fn) + def checked_build(self, framework, **kwargs): + if framework not in ("torch", "tf", "tf2"): + raise ValueError(f"Framework {framework} not supported.") + return fn(self, framework, **kwargs) + + return checked_build + + +@dataclass +class FCConfig(ModelConfig): + """Configuration for a fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer (except for the + output). + output_activation: The activation function to use for the output layer. 
+ """ + + input_dim: int = None + hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) + activation: str = "ReLU" + output_activation: str = "ReLU" + + @check_framework + def build(self, framework: str = "torch") -> Model: + if framework == "torch": + from ray.rllib.core.rl_module.torch.fcmodel import FCModel + else: + from ray.rllib.core.rl_module.tf.fcmodel import FCModel + return FCModel(self) + + +@dataclass +class FCEncoderConfig(FCConfig): + def build(self, framework: str = "torch"): + if framework == "torch": + from ray.rllib.core.rl_module.torch.encoder import FCEncoder + else: + from ray.rllib.core.rl_module.tf.encoder import FCEncoder + return FCEncoder(self) + + +@dataclass +class LSTMEncoderConfig(ModelConfig): + input_dim: int = None + hidden_dim: int = None + num_layers: int = None + batch_first: bool = True + + @check_framework + def build(self, framework: str = "torch"): + if not framework == "torch": + raise ValueError("Only torch framework supported.") + from rllib.core.rl_module.torch.encoder import LSTMEncoder + + return LSTMEncoder(self) + + +@dataclass +class IdentityConfig(ModelConfig): + """Configuration for an identity encoder.""" + + @check_framework + def build(self, framework: str = "torch"): + return Identity(self) diff --git a/rllib/core/rl_module/tf/encoder.py b/rllib/core/rl_module/tf/encoder.py new file mode 100644 index 000000000000..fd3cc94cd3cd --- /dev/null +++ b/rllib/core/rl_module/tf/encoder.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import tree + +from ray.rllib.models.base_model import ( + Model, + STATE_IN, + STATE_OUT, + ForwardOutputType, + ModelConfig, +) +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.rnn_sequencing import add_time_dimension +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_tf import TFTensorSpecs +from ray.rllib.core.rl_module.tf.fcmodel import FCModel +from ray.rllib.core.rl_module.torch.encoder import ENCODER_OUT + + +class FCEncoder(FCModel): + @property + def input_spec(self): + return SpecDict( + {SampleBatch.OBS: TFTensorSpecs("b, h", h=self.config.input_dim)} + ) + + @property + def output_spec(self): + return SpecDict({ENCODER_OUT: TFTensorSpecs("b, h", h=self.config.output_dim)}) + + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} + + +class LSTMEncoder(Model, nn.Module): + def __init__(self, config: ModelConfig) -> None: + nn.Module.__init__(self) + Model.__init__(self, config) + + self.lstm = nn.LSTM( + config.input_dim, + config.hidden_dim, + config.num_layers, + batch_first=config.batch_first, + ) + self.linear = nn.Linear(config.hidden_dim, config.output_dim) + + def get_initial_state(self): + config = self.config + return { + "h": torch.zeros(config.num_layers, config.hidden_dim), + "c": torch.zeros(config.num_layers, config.hidden_dim), + } + + @property + def input_spec(self): + config = self.config + return SpecDict( + { + # bxt is just a name for better readability to indicated padded batch + SampleBatch.OBS: TFTensorSpecs("bxt, h", h=config.input_dim), + STATE_IN: { + "h": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TFTensorSpecs( + "b, l, h", 
h=config.hidden_dim, l=config.num_layers + ), + }, + SampleBatch.SEQ_LENS: None, + } + ) + + @property + def output_spec(self): + config = self.config + return SpecDict( + { + ENCODER_OUT: TFTensorSpecs("bxt, h", h=config.output_dim), + STATE_OUT: { + "h": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + }, + } + ) + + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + x = inputs[SampleBatch.OBS] + states = inputs[STATE_IN] + # states are batch-first when coming in + states = tree.map_structure(lambda x: x.transpose(0, 1), states) + + x = add_time_dimension( + x, + seq_lens=inputs[SampleBatch.SEQ_LENS], + framework="torch", + time_major=not self.config.batch_first, + ) + states_o = {} + x, (states_o["h"], states_o["c"]) = self.lstm(x, (states["h"], states["c"])) + + x = self.linear(x) + x = x.view(-1, x.shape[-1]) + + return { + ENCODER_OUT: x, + STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + } + + +class Identity(Model): + def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + pass + + def __init__(self, config: ModelConfig) -> None: + super().__init__(config) + + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return inputs diff --git a/rllib/core/rl_module/tf/fcmodel.py b/rllib/core/rl_module/tf/fcmodel.py new file mode 100644 index 000000000000..5a0c6eca4939 --- /dev/null +++ b/rllib/core/rl_module/tf/fcmodel.py @@ -0,0 +1,34 @@ +import torch.nn as nn + +from ray.rllib.models.base_model import Model, ForwardOutputType +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_tf import TFTensorSpecs +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.models.tf.primitives import FCNet +from rllib.models.base_model import ModelConfig + + +class FCModel(Model, nn.Module): + def __init__(self, config: ModelConfig) -> None: + nn.Module.__init__(self) + Model.__init__(self, config) + + self.net = FCNet( + input_dim=config.input_dim, + hidden_layers=config.hidden_layers, + output_dim=config.output_dim, + activation=config.activation, + ) + + @property + def input_spec(self): + return TFTensorSpecs("b, h", h=self.config.input_dim) + + @property + def output_spec(self): + return TFTensorSpecs("b, h", h=self.config.output_dim) + + @check_input_specs("input_spec", filter=True, cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return self.net(inputs) diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/torch/encoder.py similarity index 85% rename from rllib/core/rl_module/encoder.py rename to rllib/core/rl_module/torch/encoder.py index ca2e85542e60..ff6d2fa7e24d 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/torch/encoder.py @@ -2,14 +2,12 @@ import torch.nn as nn import tree -from dataclasses import dataclass - from ray.rllib.models.base_model import ( - ModelConfig, Model, STATE_IN, STATE_OUT, ForwardOutputType, + ModelConfig, ) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch @@ -17,19 +15,12 @@ from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, 
check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.core.rl_module.fc import FC, FCConfig - +from ray.rllib.core.rl_module.torch.fcmodel import FCModel ENCODER_OUT: str = "encoder_out" -@dataclass -class FCEncoderConfig(FCConfig): - def build(self): - return FCEncoder(self) - - -class FCEncoder(FC): +class FCEncoder(FCModel): @property def input_spec(self): return SpecDict( @@ -48,19 +39,8 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} -@dataclass -class LSTMEncoderConfig(ModelConfig): - input_dim: int = None - hidden_dim: int = None - num_layers: int = None - batch_first: bool = True - - def build(self): - return LSTMEncoder(self) - - class LSTMEncoder(Model, nn.Module): - def __init__(self, config: LSTMEncoderConfig) -> None: + def __init__(self, config: ModelConfig) -> None: nn.Module.__init__(self) Model.__init__(self, config) @@ -143,19 +123,11 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: } -@dataclass -class IdentityConfig(ModelConfig): - """Configuration for an identity encoder.""" - - def build(self): - return IdentityEncoder(self) - - -class IdentityEncoder(Model): +class Identity(Model): def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: pass - def __init__(self, config: IdentityConfig) -> None: + def __init__(self, config: ModelConfig) -> None: super().__init__(config) def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: diff --git a/rllib/core/rl_module/fc.py b/rllib/core/rl_module/torch/fcmodel.py similarity index 52% rename from rllib/core/rl_module/fc.py rename to rllib/core/rl_module/torch/fcmodel.py index 2e037fa03d40..4b3c1ba00f4f 100644 --- a/rllib/core/rl_module/fc.py +++ b/rllib/core/rl_module/torch/fcmodel.py @@ -1,39 +1,15 @@ -from dataclasses import dataclass -from dataclasses import field -from typing import List - import torch.nn as nn -from ray.rllib.models.base_model import Model, ModelConfig, ForwardOutputType +from ray.rllib.models.base_model import Model, ForwardOutputType from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.models.torch.primitives import FCNet +from rllib.models.base_model import ModelConfig -@dataclass -class FCConfig(ModelConfig): - """Configuration for a fully connected network. - - Attributes: - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. 
- """ - - input_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - output_activation: str = "ReLU" - - def build(self) -> Model: - return FC(self) - - -class FC(Model, nn.Module): - def __init__(self, config: FCConfig) -> None: +class FCModel(Model, nn.Module): + def __init__(self, config: ModelConfig) -> None: nn.Module.__init__(self) Model.__init__(self, config) diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py index cadf65539a31..4024dd59fa80 100644 --- a/rllib/models/base_model.py +++ b/rllib/models/base_model.py @@ -337,6 +337,6 @@ class ModelConfig(abc.ABC): output_dim: int = None @abc.abstractmethod - def build(self) -> RecurrentModel: + def build(self, framework: str = "torch") -> RecurrentModel: """Builds the model.""" raise NotImplementedError diff --git a/rllib/models/configs/encoder.py b/rllib/models/configs/encoder.py deleted file mode 100644 index 38a7f305123a..000000000000 --- a/rllib/models/configs/encoder.py +++ /dev/null @@ -1,83 +0,0 @@ -import abc -from dataclasses import dataclass -from typing import TYPE_CHECKING, Tuple - -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.torch.encoders.vector import TorchVectorEncoder - -if TYPE_CHECKING: - from ray.rllib.models.torch.encoders.vector import Encoder - - -@dataclass -class EncoderConfig: - """The base config for encoder models. - - Each config should define a `build` method that builds a model from the config. - - All user-configurable parameters known before runtime - (e.g. framework, activation, num layers, etc.) should be defined as attributes. - - Parameters unknown before runtime (e.g. the output size of the module providing - input for this module) should be passed as arguments to `build`. This should be - as few params as possible. - - `build` should return an instance of the encoder associated with the config. - - Attributes: - framework_str: The tensor framework to construct a model for. - This can be 'torch', 'tf2', or 'jax'. - """ - - framework_str: str = "torch" - - @abc.abstractmethod - def build(self, input_spec: SpecDict, **kwargs) -> "Encoder": - """Builds the EncoderConfig into an Encoder instance""" - - -@dataclass -class VectorEncoderConfig(EncoderConfig): - """An MLP encoder mappings tensors with shape [..., feature] to [..., output]. - - Attributes: - activation: The type of activation function to use between hidden layers. - Options are 'relu', 'swish', 'tanh', or 'linear' - final_activation: The activation function to use after the final linear layer. - Options are the same as for activation. - hidden_layer_sizes: A list, where each element represents the number of neurons - in that layer. For example, [128, 64] would produce a two-layer MLP with - 128 hidden neurons and 64 hidden neurons. - output_key: Write the output of the encoder to this key in the NestedDict. - """ - - activation: str = "relu" - final_activation: str = "linear" - hidden_layer_sizes: Tuple[int, ...] = (128, 128) - output_key: str = "encoding" - - def build(self, input_spec: SpecDict) -> TorchVectorEncoder: - """Build the config into a VectorEncoder model instance. - - Args: - input_spec: The output spec of the previous module(s) that will feed - inputs to this encoder. - - Returns: - A VectorEncoder of the specified framework. 
- """ - assert ( - len(self.hidden_layer_sizes) > 1 - ), "Must have at least a single hidden layer" - for k in input_spec.shallow_keys(): - assert isinstance( - input_spec[k].shape[-1], int - ), "Input spec {k} does not define the size of the feature (last) dimension" - - if self.framework_str == "torch": - return TorchVectorEncoder(input_spec, self) - else: - raise NotImplementedError( - "{self.__class__.__name__} not implemented" - " for framework {self.framework}" - ) diff --git a/rllib/models/tf/primitives.py b/rllib/models/tf/primitives.py index 395ce9863135..48b96ff1237e 100644 --- a/rllib/models/tf/primitives.py +++ b/rllib/models/tf/primitives.py @@ -1,11 +1,10 @@ from typing import List from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf +from rllib.models.torch.primitives import Identity _, tf, _ = try_import_tf() -# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done. - class FCNet(tf.keras.Model): """A simple fully connected network. @@ -47,9 +46,5 @@ def call(self, inputs, training=None, mask=None): return self.network(inputs) -class IdentityNetwork(tf.keras.Model): - """A network that returns the input as the output.""" - - @override(tf.keras.Model) - def call(self, inputs, training=None, mask=None): - return inputs +# Reuse our torch IdentityEncoder here because it is not framework-specific. +Identity = Identity diff --git a/rllib/models/torch/primitives.py b/rllib/models/torch/primitives.py index 191a0ff35e5a..46d0ae8a7fc6 100644 --- a/rllib/models/torch/primitives.py +++ b/rllib/models/torch/primitives.py @@ -1,10 +1,12 @@ from typing import List, Optional + +from ray.rllib.models.base_model import Model, ForwardOutputType +from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.utils.framework import try_import_torch +from rllib.models.base_model import ModelConfig torch, nn = try_import_torch() -# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done. - class FCNet(nn.Module): """A simple fully connected network. 
@@ -52,3 +54,11 @@ def __init__( def forward(self, x): return self.layers(x) + + +class Identity(Model): + def __init__(self, config: ModelConfig) -> None: + super().__init__(config) + + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return inputs From 6ef042ed462a1fb003b3fc616c1bb208aeb9678a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 16:01:12 -0800 Subject: [PATCH 23/51] some larger refactors Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/ppo.py | 2 +- .../ppo/tests/test_ppo_rl_module.py | 8 +- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 8 +- .../ppo/torch/ppo_torch_rl_module.py | 16 +- rllib/models/base_model.py | 342 ------------------ rllib/models/experimental/README.rst | 2 + rllib/models/experimental/__init__.py | 0 rllib/models/experimental/base.py | 152 ++++++++ .../experimental}/model_configs.py | 28 +- rllib/models/experimental/tf/__init__.py | 0 .../experimental}/tf/encoder.py | 56 ++- .../experimental}/tf/fcmodel.py | 18 +- rllib/models/experimental/tf/primitives.py | 86 +++++ rllib/models/experimental/torch/__init__.py | 0 .../experimental}/torch/encoder.py | 52 ++- .../experimental}/torch/fcmodel.py | 14 +- .../{ => experimental}/torch/primitives.py | 53 ++- rllib/models/specs/checker.py | 12 + rllib/models/tf/primitives.py | 50 --- 19 files changed, 432 insertions(+), 467 deletions(-) delete mode 100644 rllib/models/base_model.py create mode 100644 rllib/models/experimental/README.rst create mode 100644 rllib/models/experimental/__init__.py create mode 100644 rllib/models/experimental/base.py rename rllib/{core/rl_module => models/experimental}/model_configs.py (69%) create mode 100644 rllib/models/experimental/tf/__init__.py rename rllib/{core/rl_module => models/experimental}/tf/encoder.py (68%) rename rllib/{core/rl_module => models/experimental}/tf/fcmodel.py (67%) create mode 100644 rllib/models/experimental/tf/primitives.py create mode 100644 rllib/models/experimental/torch/__init__.py rename rllib/{core/rl_module => models/experimental}/torch/encoder.py (73%) rename rllib/{core/rl_module => models/experimental}/torch/fcmodel.py (69%) rename rllib/models/{ => experimental}/torch/primitives.py (53%) delete mode 100644 rllib/models/tf/primitives.py diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index fee7cdc0429b..ff370dd00995 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -15,7 +15,7 @@ import gymnasium as gym -from ray.rllib.core.rl_module.model_configs import FCConfig +from ray.rllib.models.experimental.model_configs import FCConfig from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.util.debug import log_once diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 95546b3c6ea7..80c8d2f72513 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -11,14 +11,14 @@ ) from rllib.algorithms.ppo.ppo import PPOModuleConfig from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( - PPOTfModule, + PPOTfRLModule, ) -from ray.rllib.core.rl_module.model_configs import ( +from ray.rllib.models.experimental.model_configs import ( FCConfig, FCEncoderConfig, LSTMEncoderConfig, ) -from ray.rllib.core.rl_module.torch.encoder import ( +from ray.rllib.models.experimental.torch.encoder import ( STATE_IN, STATE_OUT, ) @@ -144,7 +144,7 @@ def get_ppo_module(self, framework, env, lstm): if framework == "torch": 
module = PPOTorchRLModule(config) else: - module = PPOTfModule(config) + module = PPOTfRLModule(config) return module def get_input_batch_from_obs(self, framework, obs): diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 28230091a5cb..760017c38825 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -3,14 +3,14 @@ from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.core.rl_module.model_configs import FCConfig, IdentityConfig +from ray.rllib.models.experimental.model_configs import FCConfig, IdentityConfig from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian -from ray.rllib.models.tf.primitives import FCNet -from ray.rllib.core.rl_module.tf.encoder import ENCODER_OUT +from ray.rllib.models.experimental.tf.primitives import FCNet +from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT from ray.rllib.algorithms.ppo.ppo import PPOModuleConfig @@ -18,7 +18,7 @@ tf1.enable_eager_execution() -class PPOTfModule(TfRLModule): +class PPOTfRLModule(TfRLModule): def __init__(self, config: RLModuleConfig): super().__init__() self.config = config diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 253db0629c67..4c4839dfb1e1 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -2,16 +2,16 @@ import gymnasium as gym -from ray.rllib.core.rl_module.torch.encoder import ( - ENCODER_OUT, -) -from ray.rllib.core.rl_module.model_configs import ( - LSTMEncoderConfig, -) -from ray.rllib.core.rl_module.model_configs import FCConfig, FCEncoderConfig from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch import TorchRLModule -from ray.rllib.models.base_model import STATE_OUT +from ray.rllib.models.experimental.base import STATE_OUT +from ray.rllib.models.experimental.model_configs import FCConfig, FCEncoderConfig +from ray.rllib.models.experimental.model_configs import ( + LSTMEncoderConfig, +) +from ray.rllib.models.experimental.torch.encoder import ( + ENCODER_OUT, +) from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.torch.torch_distributions import ( diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py deleted file mode 100644 index 4024dd59fa80..000000000000 --- a/rllib/models/base_model.py +++ /dev/null @@ -1,342 +0,0 @@ -import abc -from typing import Optional, Tuple -from ray.rllib.models.specs.specs_dict import SpecDict -from dataclasses import dataclass - -from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig -from ray.rllib.utils.annotations import ( - DeveloperAPI, - OverrideToImplementCustomLogic, - override, - ExperimentalAPI, -) - - -ForwardOutputType = TensorDict -# [Output, Recurrent State(s)] -UnrollOutputType = Tuple[TensorDict, TensorDict] - -STATE_IN: str = "state_in" -STATE_OUT: str = "state_out" - - -@ExperimentalAPI -class RecurrentModel(abc.ABC): - """The base model all other models 
are based on. - - RLlib models all inherit from the recurrent base class, which can be chained - together with other models. - - The models input and output TensorDicts. Which keys the models read/write to - and the desired tensor shapes must be defined in input_spec, output_spec, - prev_state_spec, and next_state_spec. - - The `unroll` function gets the model inputs and previous recurrent state, and - outputs the model outputs and next recurrent state. Note all ins/outs must match - the specs. Users should override `_unroll` rather than `unroll`. - - `initial_state` returns the "next" state for the first recurrent iteration. Again, - users should override `_initial_state` instead. - - For non-recurrent models, users may use Model instead, and override - `_forward` which does not make use of recurrent states. - - Args: - name: An optional name for the module - """ - - def __init__(self, config: ModelConfig, name: Optional[str] = None): - self._name = name or self.__class__.__name__ - self.config = config - - @property - def name(self) -> str: - """Returns the name of this module.""" - return self._name - - @property - @abc.abstractmethod - def input_spec(self) -> SpecDict: - """Returns the spec of the input of this module.""" - - @property - @abc.abstractmethod - def prev_state_spec(self) -> SpecDict: - """Returns the spec of the prev_state of this module.""" - - @property - @abc.abstractmethod - def output_spec(self) -> SpecDict: - """Returns the spec of the output of this module.""" - - @property - @abc.abstractmethod - def next_state_spec(self) -> SpecDict: - """Returns the spec of the next_state of this module.""" - - @abc.abstractmethod - def _initial_state(self) -> TensorDict: - """Initial state of the component. - If this component returns a next_state in its unroll function, then - this function provides the initial state. - Subclasses should override this function instead of `initial_state`, which - adds additional checks. - - Returns: - A TensorDict containing the state before the first step. - """ - - @DeveloperAPI - def initial_state(self) -> TensorDict: - """Initial state of the component. - If this component returns a next_state in its unroll function, then - this function provides the initial state. - - Returns: - A TensorDict containing the state before the first step. - - Examples: - >>> state = model.initial_state() - >>> state # TensorDict(...) - """ - initial_state = self._initial_state() - self.next_state_spec.validate(initial_state) - return initial_state - - @abc.abstractmethod - def _unroll( - self, inputs: TensorDict, prev_state: TensorDict, **kwargs - ) -> UnrollOutputType: - """Computes the output of the module over the timesteps within the batch. - Subclasses should override this function instead of `unroll`, which - adds additional checks. - - Args: - inputs: A TensorDict of inputs - prev_state: A TensorDict containing the next_state of the last - timestep of the previous unroll. - kwargs: For forwards compatibility - - Returns: - outputs: A TensorDict of outputs - next_state: A dict containing the state to be passed - as the first state of the next rollout. - """ - - @DeveloperAPI - def unroll( - self, inputs: TensorDict, prev_state: TensorDict, **kwargs - ) -> UnrollOutputType: - """Computes the output of the module over the timesteps within the batch. - - Args: - inputs: A TensorDict containing inputs to the model - prev_state: A TensorDict containing containing the - next_state of the last timestep of the previous unroll. 
- kwargs: For forwards compatibility - - Returns: - outputs: A TensorDict containing model outputs - next_state: A TensorDict containing the - state to be passed as the first state of the next rollout. - - Examples: - >>> output, state = model.unroll(TensorDict(...), TensorDict(...)) - >>> output # TensorDict(...) - >>> state # TensorDict(...) - - """ - self.input_spec.validate(inputs) - self.prev_state_spec.validate(prev_state) - # We hide inputs not specified in input_spec to prevent accidental use. - inputs = inputs.filter(self.input_spec) - prev_state = prev_state.filter(self.prev_state_spec) - inputs, prev_state = self._update_inputs_and_prev_state(inputs, prev_state) - outputs, next_state = self._unroll(inputs, prev_state, **kwargs) - self.output_spec.validate(outputs) - self.next_state_spec.validate(next_state) - outputs, next_state = self._update_outputs_and_next_state(outputs, next_state) - return outputs, next_state - - @OverrideToImplementCustomLogic - def _update_inputs_and_prev_state( - self, inputs: TensorDict, prev_state: TensorDict - ) -> Tuple[TensorDict, TensorDict]: - """Override this function to add additional checks and optionally update inputs. - - Args: - inputs: TensorDict containing inputs to the model - prev_state: The previous recurrent state - - Returns: - inputs: Potentially modified inputs - prev_state: Potentially modified recurrent state - """ - return inputs, prev_state - - @OverrideToImplementCustomLogic - def _update_outputs_and_next_state( - self, outputs: TensorDict, next_state: TensorDict - ) -> Tuple[TensorDict, TensorDict]: - """Override this function to add additional checks and optionally update - outputs. - - Args: - outputs: TensorDict output by the model - next_state: Recurrent state output by the model - - Returns: - outputs: Potentially modified TensorDict output by the model - next_state: Potentially modified recurrent state output by the model - """ - return outputs, next_state - - -class Model(RecurrentModel): - """A RecurrentModel made non-recurrent by ignoring - the input/output states. - - As a convienience, users may override _forward instead of _unroll, - which hides model states. - - Args: - name: An optional name for the module - """ - - @property - @override(RecurrentModel) - def prev_state_spec(self) -> SpecDict: - return SpecDict() - - @property - @override(RecurrentModel) - def next_state_spec(self) -> SpecDict: - return SpecDict() - - @override(RecurrentModel) - def _initial_state(self) -> TensorDict: - return TensorDict() - - @override(RecurrentModel) - def _update_inputs_and_prev_state( - self, inputs: TensorDict, prev_state: TensorDict - ) -> Tuple[TensorDict, TensorDict]: - inputs = self._update_inputs(inputs) - return inputs, prev_state - - @OverrideToImplementCustomLogic - def _update_inputs(self, inputs: TensorDict) -> TensorDict: - """Override this function to add additional checks and optionally update inputs. - - Args: - inputs: TensorDict containing inputs to the model - - Returns: - inputs: Potentially modified inputs - """ - return inputs - - @override(RecurrentModel) - def _update_outputs_and_next_state( - self, outputs: TensorDict, next_state: TensorDict - ) -> Tuple[TensorDict, TensorDict]: - outputs = self._update_outputs(outputs) - return outputs, next_state - - @OverrideToImplementCustomLogic - def _update_outputs(self, outputs: TensorDict) -> TensorDict: - """Override this function to add additional checks and optionally update - outputs. 
- - Args: - outputs: TensorDict output by the model - - Returns: - outputs: Potentially modified TensorDict output by the model - """ - return outputs - - @override(RecurrentModel) - def _unroll( - self, inputs: TensorDict, prev_state: TensorDict, **kwargs - ) -> UnrollOutputType: - outputs = self._forward(inputs, **kwargs) - return outputs, TensorDict() - - @abc.abstractmethod - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - """Computes the output of this module for each timestep. - - Args: - inputs: A TensorDict containing model inputs - kwargs: For forwards compatibility - - Returns: - outputs: A TensorDict containing model outputs - - Examples: - # This is abstract, see the torch/tf/jax implementations - >>> out = model._forward(TensorDict({"in": np.arange(10)})) - >>> out # TensorDict(...) - """ - - -@ExperimentalAPI -class ModelIO(abc.ABC): - """Abstract class defining how to save and load model weights. - - Args: - config: The ModelConfig passed to the underlying model - """ - - def __init__(self, config: ModelConfig) -> None: - self._config = config - - @DeveloperAPI - @property - def config(self) -> ModelConfig: - return self._config - - @DeveloperAPI - @abc.abstractmethod - def save(self, path: str) -> None: - """Save model weights to a path. - - Args: - path: The path on disk where weights are to be saved - - Examples: - model.save("/tmp/model_path.cpt") - """ - raise NotImplementedError - - @DeveloperAPI - @abc.abstractmethod - def load(self, path: str) -> RecurrentModel: - """Load model weights from a path. - - Args: - path: The path on disk where to load weights from - - Examples: - model.load("/tmp/model_path.cpt") - """ - raise NotImplementedError - - -@ExperimentalAPI -@dataclass -class ModelConfig(abc.ABC): - """Configuration for an encoder network. - - Attributes: - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. - """ - - output_dim: int = None - - @abc.abstractmethod - def build(self, framework: str = "torch") -> RecurrentModel: - """Builds the model.""" - raise NotImplementedError diff --git a/rllib/models/experimental/README.rst b/rllib/models/experimental/README.rst new file mode 100644 index 000000000000..a45919dadc77 --- /dev/null +++ b/rllib/models/experimental/README.rst @@ -0,0 +1,2 @@ +This folder holds models that are under development and to be used with RLModules in upcoming versions of RLLib. +They are not yet ready for use in the current version of RLLib. \ No newline at end of file diff --git a/rllib/models/experimental/__init__.py b/rllib/models/experimental/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py new file mode 100644 index 000000000000..2786f596ba71 --- /dev/null +++ b/rllib/models/experimental/base.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass +import abc + +from ray.rllib.models.specs.checker import ( + check_input_specs, + check_output_specs, + is_input_decorated, + is_output_decorated, +) +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.utils.typing import TensorType +from ray.rllib.utils.annotations import ExperimentalAPI + +ForwardOutputType = TensorDict +STATE_IN: str = "state_in" +STATE_OUT: str = "state_out" + + +def _not_decorated_message(input_or_output): + return ( + f"__call__ not decorated with {input_or_output} specification. 
Decorate " + f"with @check_{input_or_output}_specs() to define a specification. See " + f"BaseModel for examples." + ) + + +@ExperimentalAPI +@dataclass +class ModelConfig(abc.ABC): + """Configuration for a model. + + Attributes: + output_dim: The output dimension of the network. if None, the last layer would + be the last hidden layer. + """ + + output_dim: int = None + + @abc.abstractmethod + def build(self, framework: str = "torch"): + """Builds the model. + + Args: + framework: The framework to use for building the model. + """ + raise NotImplementedError + + +class Model: + """Base class for RLlib models.""" + + def __init__(self, config: ModelConfig): + self.config = config + + def get_initial_state(self): + """Returns the initial state of the model.""" + return {} + + @property + @abc.abstractmethod + def output_spec(self) -> SpecDict: + """Returns the outputs spec of this model. + + This can include the state specs as well. + + Examples: + >>> ... + """ + # If no checking is needed, we can simply return an empty spec. + return SpecDict() + + @property + @abc.abstractmethod + def input_spec(self) -> SpecDict: + """Returns the input spec of this model. + + This can include the state specs as well. + + Examples: + >>> ... + """ + # If no checking is needed, we can simply return an empty spec. + return SpecDict() + + @check_input_specs("input_spec", filter=True, cache=True) + @check_output_specs("output_spec", cache=True) + @abc.abstractmethod + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + """Computes the output of this module for each timestep. + + Outputs and inputs should be subjected to spec checking. + + Args: + inputs: A TensorDict containing model inputs + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict containing model outputs + + Examples: + # This is abstract, see the torch/tf/jax implementations + >>> out = model(TensorDict({"in": np.arange(10)})) + >>> out # TensorDict(...) + """ + raise NotImplementedError + + +class Encoder(Model): + """The base class for all encoders Rllib produces. + + Encoders are used to encode observations into a latent space in RLModules. + Therefore, their input_spec usually contains the observation space dimensions. + Their output_spec usually contains the latent space dimensions. + Encoders can be recurrent, in which case they should also have state_specs. + """ + + def __init__(self, config: dict): + super().__init__(config) + + def get_initial_state(self) -> TensorType: + """Returns the initial state of the encoder. + + This is the initial state of the encoder. + It can be left empty if this encoder is not stateful. + + Examples: + >>> ... + """ + return {} + + @check_input_specs("input_spec", cache=True) + @check_output_specs("output_spec", cache=True) + @abc.abstractmethod + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + """Computes the output of this module for each timestep. + + Outputs and inputs are subjected to spec checking. + + Args: + inputs: A TensorDict containing model inputs + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict containing model outputs + + Examples: + # This is abstract, see the torch/tf/jax implementations + >>> out = model(TensorDict({"in": np.arange(10)})) + >>> out # TensorDict(...) 
+ """ + raise NotImplementedError diff --git a/rllib/core/rl_module/model_configs.py b/rllib/models/experimental/model_configs.py similarity index 69% rename from rllib/core/rl_module/model_configs.py rename to rllib/models/experimental/model_configs.py index 74398eefc5e6..bb9175048a8a 100644 --- a/rllib/core/rl_module/model_configs.py +++ b/rllib/models/experimental/model_configs.py @@ -2,11 +2,10 @@ from typing import List import functools -from ray.rllib.models.torch.primitives import Identity -from ray.rllib.models.base_model import ModelConfig, Model +from ray.rllib.models.experimental.base import ModelConfig, Model -def check_framework(fn): +def _check_framework(fn): @functools.wraps(fn) def checked_build(self, framework, **kwargs): if framework not in ("torch", "tf", "tf2"): @@ -33,12 +32,12 @@ class FCConfig(ModelConfig): activation: str = "ReLU" output_activation: str = "ReLU" - @check_framework + @_check_framework def build(self, framework: str = "torch") -> Model: if framework == "torch": - from ray.rllib.core.rl_module.torch.fcmodel import FCModel + from ray.rllib.models.experimental.torch.fcmodel import FCModel else: - from ray.rllib.core.rl_module.tf.fcmodel import FCModel + from ray.rllib.models.experimental.tf.fcmodel import FCModel return FCModel(self) @@ -46,9 +45,9 @@ def build(self, framework: str = "torch") -> Model: class FCEncoderConfig(FCConfig): def build(self, framework: str = "torch"): if framework == "torch": - from ray.rllib.core.rl_module.torch.encoder import FCEncoder + from ray.rllib.models.experimental.torch.encoder import FCEncoder else: - from ray.rllib.core.rl_module.tf.encoder import FCEncoder + from ray.rllib.models.experimental.tf.encoder import FCEncoder return FCEncoder(self) @@ -59,11 +58,11 @@ class LSTMEncoderConfig(ModelConfig): num_layers: int = None batch_first: bool = True - @check_framework + @_check_framework def build(self, framework: str = "torch"): if not framework == "torch": raise ValueError("Only torch framework supported.") - from rllib.core.rl_module.torch.encoder import LSTMEncoder + from rllib.models.experimental.torch.encoder import LSTMEncoder return LSTMEncoder(self) @@ -72,6 +71,11 @@ def build(self, framework: str = "torch"): class IdentityConfig(ModelConfig): """Configuration for an identity encoder.""" - @check_framework + @_check_framework def build(self, framework: str = "torch"): - return Identity(self) + if framework == "torch": + from rllib.models.experimental.torch.encoder import IdentityEncoder + else: + from rllib.models.experimental.tf.encoder import IdentityEncoder + + return IdentityEncoder(self) diff --git a/rllib/models/experimental/tf/__init__.py b/rllib/models/experimental/tf/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/core/rl_module/tf/encoder.py b/rllib/models/experimental/tf/encoder.py similarity index 68% rename from rllib/core/rl_module/tf/encoder.py rename to rllib/models/experimental/tf/encoder.py index fd3cc94cd3cd..d70369c25b37 100644 --- a/rllib/core/rl_module/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -2,8 +2,9 @@ import torch.nn as nn import tree -from ray.rllib.models.base_model import ( +from ray.rllib.models.experimental.base import ( Model, + Encoder, STATE_IN, STATE_OUT, ForwardOutputType, @@ -11,15 +12,29 @@ ) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.experimental.tf.primitives import FCNet from ray.rllib.policy.rnn_sequencing import 
add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_tf import TFTensorSpecs -from ray.rllib.core.rl_module.tf.fcmodel import FCModel -from ray.rllib.core.rl_module.torch.encoder import ENCODER_OUT +from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT +from ray.rllib.models.experimental.tf.primitives import TFModel -class FCEncoder(FCModel): +class FCEncoder(Encoder, TFModel): + """A fully connected encoder.""" + + def __init__(self, config: ModelConfig) -> None: + Encoder.__init__(self, config) + TFModel.__init__(self, config) + + self.net = FCNet( + input_dim=config.input_dim, + hidden_layers=config.hidden_layers, + output_dim=config.output_dim, + activation=config.activation, + ) + @property def input_spec(self): return SpecDict( @@ -30,16 +45,18 @@ def input_spec(self): def output_spec(self): return SpecDict({ENCODER_OUT: TFTensorSpecs("b, h", h=self.config.output_dim)}) - @check_input_specs("input_spec", filter=True, cache=False) + @check_input_specs("input_spec", cache=False) @check_output_specs("output_spec", cache=False) - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} -class LSTMEncoder(Model, nn.Module): +class LSTMEncoder(Encoder, TFModel): + """An encoder that uses an LSTM cell and a linear layer.""" + def __init__(self, config: ModelConfig) -> None: - nn.Module.__init__(self) - Model.__init__(self, config) + Encoder.__init__(self, config) + TFModel.__init__(self, config) self.lstm = nn.LSTM( config.input_dim, @@ -94,7 +111,7 @@ def output_spec(self): @check_input_specs("input_spec", filter=True, cache=False) @check_output_specs("output_spec", cache=False) - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: x = inputs[SampleBatch.OBS] states = inputs[STATE_IN] # states are batch-first when coming in @@ -118,12 +135,25 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: } -class Identity(Model): +class IdentityEncoder(TFModel): def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: pass def __init__(self, config: ModelConfig) -> None: super().__init__(config) - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - return inputs + @property + def input_spec(self): + return SpecDict( + # Use the output dim as input dim because identity. 
+ {SampleBatch.OBS: TFTensorSpecs("b, h", h=self.config.output_dim)} + ) + + @property + def output_spec(self): + return SpecDict({ENCODER_OUT: TFTensorSpecs("b, h", h=self.config.output_dim)}) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return {ENCODER_OUT: inputs[SampleBatch.OBS]} diff --git a/rllib/core/rl_module/tf/fcmodel.py b/rllib/models/experimental/tf/fcmodel.py similarity index 67% rename from rllib/core/rl_module/tf/fcmodel.py rename to rllib/models/experimental/tf/fcmodel.py index 5a0c6eca4939..ff9faf2780bf 100644 --- a/rllib/core/rl_module/tf/fcmodel.py +++ b/rllib/models/experimental/tf/fcmodel.py @@ -1,17 +1,17 @@ -import torch.nn as nn - -from ray.rllib.models.base_model import Model, ForwardOutputType from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_tf import TFTensorSpecs +from ray.rllib.utils import try_import_tf from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.models.tf.primitives import FCNet -from rllib.models.base_model import ModelConfig +from ray.rllib.models.tf.primitives import FCNet, TFModel +from rllib.models.experimental.base import ModelConfig, ForwardOutputType + +tf1, tf, tfv = try_import_tf() -class FCModel(Model, nn.Module): +class FCModel(tf.Module, TFModel): def __init__(self, config: ModelConfig) -> None: - nn.Module.__init__(self) - Model.__init__(self, config) + tf.Module.__init__(self) + TFModel.__init__(self, config) self.net = FCNet( input_dim=config.input_dim, @@ -30,5 +30,5 @@ def output_spec(self): @check_input_specs("input_spec", filter=True, cache=False) @check_output_specs("output_spec", cache=False) - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return self.net(inputs) diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py new file mode 100644 index 000000000000..71ddd8393f6a --- /dev/null +++ b/rllib/models/experimental/tf/primitives.py @@ -0,0 +1,86 @@ +from typing import List +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import ( + is_input_decorated, + is_output_decorated, +) +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.models.experimental.base import ModelConfig, Model, ForwardOutputType +from ray.rllib.utils.typing import TensorType +from typing import Tuple + +_, tf, _ = try_import_tf() + + +def _call_not_decorated(input_or_output): + return ( + f"forward not decorated with {input_or_output} specification. Decorate " + f"with @check_{input_or_output}_specs() to define a specification. See " + f"BaseModel for examples." + ) + + +class TFModel(Model): + """Base class for RLlib models. + + This class is used to define the general interface for RLlib models and checks + whether inputs and outputs are checked with `check_input_specs()` and + `check_output_specs()` respectively. + """ + + def __init__(self, config): + super().__init__(config) + assert is_input_decorated(self.__call__), _call_not_decorated("input") + assert is_output_decorated(self.__call__), _call_not_decorated("output") + + def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: + """Returns the output of this model for the given input. 
+ + Args: + input_dict: The input tensors. + + Returns: + Tuple[TensorDict, List[TensorType]]: The output tensors. + """ + raise NotImplementedError + + +class FCNet(tf.Module): + """A simple fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + output_dim: The output dimension of the network. + activation: The activation function to use after each layer. + Currently "Linear" (no activation) and "ReLU" are supported. + """ + + def __init__( + self, + input_dim: int, + hidden_layers: List[int], + output_dim: int, + activation: str = "linear", + ): + super().__init__() + + assert activation in ("linear", "ReLU", "Tanh"), ( + "Activation function not " "supported" + ) + assert input_dim is not None, "Input dimension must not be None" + assert output_dim is not None, "Output dimension must not be None" + layers = [] + activation = activation.lower() + # input = tf.keras.layers.Dense(input_dim, activation=activation) + layers.append(tf.keras.Input(shape=(input_dim,))) + for i in range(len(hidden_layers)): + layers.append( + tf.keras.layers.Dense(hidden_layers[i], activation=activation) + ) + layers.append(tf.keras.layers.Dense(output_dim)) + self.network = tf.keras.Sequential(layers) + + def __call__(self, inputs): + return self.network(inputs) diff --git a/rllib/models/experimental/torch/__init__.py b/rllib/models/experimental/torch/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/core/rl_module/torch/encoder.py b/rllib/models/experimental/torch/encoder.py similarity index 73% rename from rllib/core/rl_module/torch/encoder.py rename to rllib/models/experimental/torch/encoder.py index ff6d2fa7e24d..7578b3c1ec4f 100644 --- a/rllib/core/rl_module/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -2,8 +2,7 @@ import torch.nn as nn import tree -from ray.rllib.models.base_model import ( - Model, +from ray.rllib.models.experimental.base import ( STATE_IN, STATE_OUT, ForwardOutputType, @@ -11,23 +10,40 @@ ) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.core.rl_module.torch.fcmodel import FCModel +from ray.rllib.models.experimental.torch.primitives import FCNet, TorchModel +from ray.rllib.models.experimental.base import Encoder ENCODER_OUT: str = "encoder_out" -class FCEncoder(FCModel): +class FCEncoder(TorchModel, Encoder): + """A fully connected encoder.""" + + def __init__(self, config: ModelConfig) -> None: + TorchModel.__init__(self, config) + Encoder.__init__(self, config) + + self.net = FCNet( + input_dim=config.input_dim, + hidden_layers=config.hidden_layers, + output_dim=config.output_dim, + activation=config.activation, + ) + @property + @override(TorchModel) def input_spec(self): return SpecDict( {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} ) @property + @override(TorchModel) def output_spec(self): return SpecDict( {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} @@ -39,10 +55,11 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} 
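# ---------------------------------------------------------------------------
# Rough usage sketch for the spec-checked FCEncoder pattern above (illustrative
# only; `FCEncoderConfig`, the field values, and the call sequence below are
# assumptions based on the configs introduced in this patch series):
#
#     import torch
#     from ray.rllib.policy.sample_batch import SampleBatch
#
#     config = FCEncoderConfig(
#         input_dim=4, hidden_layers=[32], output_dim=32, activation="ReLU"
#     )
#     encoder = config.build(framework="torch")  # builds an FCEncoder
#     batch = {SampleBatch.OBS: torch.rand(8, 4)}
#     # __call__ runs forward(); the check_input_specs / check_output_specs
#     # decorators validate the batch against input_spec and output_spec.
#     out = encoder(batch)
#     assert out[ENCODER_OUT].shape == (8, 32)
# ---------------------------------------------------------------------------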
-class LSTMEncoder(Model, nn.Module): +class LSTMEncoder(TorchModel, Encoder): + """An encoder that uses an LSTM cell and a linear layer.""" + def __init__(self, config: ModelConfig) -> None: - nn.Module.__init__(self) - Model.__init__(self, config) + TorchModel.__init__(self, config) self.lstm = nn.LSTM( config.input_dim, @@ -60,6 +77,7 @@ def get_initial_state(self): } @property + @override(TorchModel) def input_spec(self): config = self.config return SpecDict( @@ -79,6 +97,7 @@ def input_spec(self): ) @property + @override(TorchModel) def output_spec(self): config = self.config return SpecDict( @@ -123,12 +142,27 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: } -class Identity(Model): +class IdentityEncoder(TorchModel): def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: pass def __init__(self, config: ModelConfig) -> None: super().__init__(config) + @property + def input_spec(self): + return SpecDict( + # Use the output dim as input dim because identity. + {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.output_dim)} + ) + + @property + def output_spec(self): + return SpecDict( + {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} + ) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - return inputs + return {ENCODER_OUT: inputs[SampleBatch.OBS]} diff --git a/rllib/core/rl_module/torch/fcmodel.py b/rllib/models/experimental/torch/fcmodel.py similarity index 69% rename from rllib/core/rl_module/torch/fcmodel.py rename to rllib/models/experimental/torch/fcmodel.py index 4b3c1ba00f4f..71e0b8ea0e96 100644 --- a/rllib/core/rl_module/torch/fcmodel.py +++ b/rllib/models/experimental/torch/fcmodel.py @@ -1,17 +1,18 @@ import torch.nn as nn -from ray.rllib.models.base_model import Model, ForwardOutputType +from ray.rllib.models.experimental.base import ForwardOutputType, Model, ModelConfig from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.models.torch.primitives import FCNet -from rllib.models.base_model import ModelConfig +from ray.rllib.models.experimental.torch.primitives import FCNet +from ray.rllib.models.experimental.torch.primitives import TorchModel +from ray.rllib.utils.annotations import override -class FCModel(Model, nn.Module): +class FCModel(TorchModel, nn.Module): def __init__(self, config: ModelConfig) -> None: nn.Module.__init__(self) - Model.__init__(self, config) + TorchModel.__init__(self, config) self.net = FCNet( input_dim=config.input_dim, @@ -21,14 +22,17 @@ def __init__(self, config: ModelConfig) -> None: ) @property + @override(Model) def input_spec(self): return TorchTensorSpec("b, h", h=self.config.input_dim) @property + @override(Model) def output_spec(self): return TorchTensorSpec("b, h", h=self.config.output_dim) @check_input_specs("input_spec", filter=True, cache=False) @check_output_specs("output_spec", cache=False) + @override(TorchModel) def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return self.net(inputs) diff --git a/rllib/models/torch/primitives.py b/rllib/models/experimental/torch/primitives.py similarity index 53% rename from rllib/models/torch/primitives.py rename to rllib/models/experimental/torch/primitives.py index 46d0ae8a7fc6..b3ba9cee0ac6 100644 --- 
a/rllib/models/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -1,13 +1,54 @@ from typing import List, Optional +from typing import Tuple -from ray.rllib.models.base_model import Model, ForwardOutputType +from ray.rllib.models.experimental.base import Model, ForwardOutputType +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import ( + is_input_decorated, + is_output_decorated, +) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.utils.framework import try_import_torch -from rllib.models.base_model import ModelConfig +from ray.rllib.utils.typing import TensorType +from rllib.models.experimental.base import ModelConfig torch, nn = try_import_torch() +def _forward_not_decorated(input_or_output): + return ( + f"forward not decorated with {input_or_output} specification. Decorate " + f"with @check_{input_or_output}_specs() to define a specification. See " + f"BaseModel for examples." + ) + + +class TorchModel(nn.Module, Model): + """Base class for torch models. + + This class is used to define the general interface for torch models and checks + whether inputs and outputs are checked with `check_input_specs()` and + `check_output_specs()` respectively. + """ + + def __init__(self, config: ModelConfig): + nn.Module.__init__(self) + Model.__init__(self, config) + assert is_input_decorated(self.forward), _forward_not_decorated("input") + assert is_output_decorated(self.forward), _forward_not_decorated("output") + + def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: + """Returns the output of this model for the given input. + + Args: + input_dict: The input tensors. + + Returns: + Tuple[TensorDict, List[TensorType]]: The output tensors. + """ + raise NotImplementedError + + class FCNet(nn.Module): """A simple fully connected network. @@ -54,11 +95,3 @@ def __init__( def forward(self, x): return self.layers(x) - - -class Identity(Model): - def __init__(self, config: ModelConfig) -> None: - super().__init__(config) - - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - return inputs diff --git a/rllib/models/specs/checker.py b/rllib/models/specs/checker.py index b7ca04c74325..cfb5d5eeb27e 100644 --- a/rllib/models/specs/checker.py +++ b/rllib/models/specs/checker.py @@ -336,3 +336,15 @@ def wrapper(self, input_data, **kwargs): return wrapper return decorator + + +@DeveloperAPI(stability="alpha") +def is_input_decorated(obj: object) -> bool: + """Returns True if the object is decorated with `check_input_specs`.""" + return hasattr(obj, "__checked_input_specs__") + + +@DeveloperAPI(stability="alpha") +def is_output_decorated(obj: object) -> bool: + """Returns True if the object is decorated with `check_output_specs`.""" + return hasattr(obj, "__checked_output_specs__") diff --git a/rllib/models/tf/primitives.py b/rllib/models/tf/primitives.py deleted file mode 100644 index 48b96ff1237e..000000000000 --- a/rllib/models/tf/primitives.py +++ /dev/null @@ -1,50 +0,0 @@ -from typing import List -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_tf -from rllib.models.torch.primitives import Identity - -_, tf, _ = try_import_tf() - - -class FCNet(tf.keras.Model): - """A simple fully connected network. - - Attributes: - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - output_dim: The output dimension of the network. 
- activation: The activation function to use after each layer. - Currently "Linear" (no activation) and "ReLU" are supported. - """ - - def __init__( - self, - input_dim: int, - hidden_layers: List[int], - output_dim: int, - activation: str = "linear", - ): - super().__init__() - - if activation not in ("linear", "ReLU", "Tanh"): - raise ValueError("Activation function not supported") - assert input_dim is not None, "Input dimension must not be None" - assert output_dim is not None, "Output dimension must not be None" - layers = [] - activation = activation.lower() - # input = tf.keras.layers.Dense(input_dim, activation=activation) - layers.append(tf.keras.Input(shape=(input_dim,))) - for i in range(len(hidden_layers)): - layers.append( - tf.keras.layers.Dense(hidden_layers[i], activation=activation) - ) - layers.append(tf.keras.layers.Dense(output_dim)) - self.network = tf.keras.Sequential(layers) - - @override(tf.keras.Model) - def call(self, inputs, training=None, mask=None): - return self.network(inputs) - - -# Reuse our torch IdentityEncoder here because it is not framework-specific. -Identity = Identity From d8d8c724fa375a6fab0cd93b1d9a06f1307a244a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 16:06:06 -0800 Subject: [PATCH 24/51] rename PPOModuleConfig for upcoming release Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/ppo.py | 4 ++-- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 6 +++--- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index ff370dd00995..cf3f1be7a850 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -474,8 +474,8 @@ def __getitem__(self, item): @dataclass -class PPOModuleConfig(RLModuleConfig): - """Configuration for the PPO module. +class __PPOModuleConfig(RLModuleConfig): + """Configuration for the PPO RLModule. Attributes: observation_space: The observation space of the environment. diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 80c8d2f72513..fbae9c105a84 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -9,7 +9,7 @@ from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule, ) -from rllib.algorithms.ppo.ppo import PPOModuleConfig +from rllib.algorithms.ppo.ppo import __PPOModuleConfig from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( PPOTfRLModule, ) @@ -29,7 +29,7 @@ def get_expected_model_config( env: gym.Env, lstm: bool, -) -> PPOModuleConfig: +) -> __PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. 
Args: @@ -78,7 +78,7 @@ def get_expected_model_config( else: pi_config.output_dim = env.action_space.shape[0] * 2 - return PPOModuleConfig( + return __PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, shared_encoder_config=shared_encoder_config, diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 4c4839dfb1e1..4f7a29de113d 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -24,7 +24,7 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict -from rllib.algorithms.ppo.ppo import PPOModuleConfig +from rllib.algorithms.ppo.ppo import __PPOModuleConfig torch, nn = try_import_torch() @@ -43,7 +43,7 @@ def get_ppo_loss(fwd_in, fwd_out): class PPOTorchRLModule(TorchRLModule): - def __init__(self, config: PPOModuleConfig) -> None: + def __init__(self, config: __PPOModuleConfig) -> None: super().__init__() self.config = config self.setup() @@ -142,7 +142,7 @@ def from_model_config( else: pi_config.output_dim = action_space.shape[0] * 2 - config_ = PPOModuleConfig( + config_ = __PPOModuleConfig( observation_space=observation_space, action_space=action_space, shared_encoder_config=shared_encoder_config, From 503b7dd0c52dfadfca30308953a384782f8a6928 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 16:37:04 -0800 Subject: [PATCH 25/51] cleanup Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/ppo.py | 2 +- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 7 ++++--- rllib/models/experimental/base.py | 2 -- rllib/models/experimental/tf/encoder.py | 1 - rllib/models/experimental/tf/primitives.py | 7 +++---- rllib/models/experimental/torch/primitives.py | 7 +++---- rllib/models/specs/checker.py | 3 +-- 7 files changed, 12 insertions(+), 17 deletions(-) diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index cf3f1be7a850..1e8eae9b345c 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -474,7 +474,7 @@ def __getitem__(self, item): @dataclass -class __PPOModuleConfig(RLModuleConfig): +class PPOModuleConfig(RLModuleConfig): """Configuration for the PPO RLModule. 
Attributes: diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 4f7a29de113d..15160bd21173 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -23,8 +23,9 @@ from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space +from ray.rllib.algorithms.ppo.ppo import PPOModuleConfig from ray.rllib.utils.nested_dict import NestedDict -from rllib.algorithms.ppo.ppo import __PPOModuleConfig + torch, nn = try_import_torch() @@ -43,7 +44,7 @@ def get_ppo_loss(fwd_in, fwd_out): class PPOTorchRLModule(TorchRLModule): - def __init__(self, config: __PPOModuleConfig) -> None: + def __init__(self, config: PPOModuleConfig) -> None: super().__init__() self.config = config self.setup() @@ -142,7 +143,7 @@ def from_model_config( else: pi_config.output_dim = action_space.shape[0] * 2 - config_ = __PPOModuleConfig( + config_ = PPOModuleConfig( observation_space=observation_space, action_space=action_space, shared_encoder_config=shared_encoder_config, diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index 2786f596ba71..dcbd4c943e9c 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -4,8 +4,6 @@ from ray.rllib.models.specs.checker import ( check_input_specs, check_output_specs, - is_input_decorated, - is_output_decorated, ) from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.temp_spec_classes import TensorDict diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index d70369c25b37..fb3939421800 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -3,7 +3,6 @@ import tree from ray.rllib.models.experimental.base import ( - Model, Encoder, STATE_IN, STATE_OUT, diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py index 71ddd8393f6a..868a85ce01a8 100644 --- a/rllib/models/experimental/tf/primitives.py +++ b/rllib/models/experimental/tf/primitives.py @@ -1,12 +1,11 @@ from typing import List from ray.rllib.utils.framework import try_import_tf -from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import ( - is_input_decorated, + input_is_decorated, is_output_decorated, ) from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.models.experimental.base import ModelConfig, Model, ForwardOutputType +from ray.rllib.models.experimental.base import Model from ray.rllib.utils.typing import TensorType from typing import Tuple @@ -31,7 +30,7 @@ class TFModel(Model): def __init__(self, config): super().__init__(config) - assert is_input_decorated(self.__call__), _call_not_decorated("input") + assert input_is_decorated(self.__call__), _call_not_decorated("input") assert is_output_decorated(self.__call__), _call_not_decorated("output") def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index b3ba9cee0ac6..3a4b5ee92c30 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -1,10 +1,9 @@ from typing import List, Optional from typing import Tuple -from ray.rllib.models.experimental.base import Model, 
ForwardOutputType -from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.experimental.base import Model from ray.rllib.models.specs.checker import ( - is_input_decorated, + input_is_decorated, is_output_decorated, ) from ray.rllib.models.temp_spec_classes import TensorDict @@ -34,7 +33,7 @@ class TorchModel(nn.Module, Model): def __init__(self, config: ModelConfig): nn.Module.__init__(self) Model.__init__(self, config) - assert is_input_decorated(self.forward), _forward_not_decorated("input") + assert input_is_decorated(self.forward), _forward_not_decorated("input") assert is_output_decorated(self.forward), _forward_not_decorated("output") def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: diff --git a/rllib/models/specs/checker.py b/rllib/models/specs/checker.py index cfb5d5eeb27e..2661d1750086 100644 --- a/rllib/models/specs/checker.py +++ b/rllib/models/specs/checker.py @@ -338,8 +338,7 @@ def wrapper(self, input_data, **kwargs): return decorator -@DeveloperAPI(stability="alpha") -def is_input_decorated(obj: object) -> bool: +def input_is_decorated(obj: object) -> bool: """Returns True if the object is decorated with `check_input_specs`.""" return hasattr(obj, "__checked_input_specs__") From 3c30b4201df61ad5fb9e41f9a76f5a1c9410014b Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 16:43:53 -0800 Subject: [PATCH 26/51] rename configs Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/{model_configs.py => configs.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename rllib/models/experimental/{model_configs.py => configs.py} (100%) diff --git a/rllib/models/experimental/model_configs.py b/rllib/models/experimental/configs.py similarity index 100% rename from rllib/models/experimental/model_configs.py rename to rllib/models/experimental/configs.py From 5687454e9fff24d3d0a3b0ea1062beb81b078aad Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 17:01:20 -0800 Subject: [PATCH 27/51] remove rebase artifacts Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/ppo.py | 3 +- .../ppo/tests/test_ppo_rl_module.py | 86 ++--- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 2 +- .../ppo/torch/ppo_torch_rl_module.py | 4 +- rllib/core/rl_module/encoder.py | 202 +++++++++++ rllib/core/rl_module/encoder_tf.py | 37 ++ rllib/models/base_model.py | 333 ++++++++++++++++++ rllib/models/configs/encoder.py | 83 +++++ rllib/models/experimental/model_configs.py | 81 +++++ rllib/models/experimental/torch/encoder.py | 14 +- rllib/models/tf/primitives.py | 55 +++ rllib/models/torch/primitives.py | 54 +++ 12 files changed, 893 insertions(+), 61 deletions(-) create mode 100644 rllib/core/rl_module/encoder.py create mode 100644 rllib/core/rl_module/encoder_tf.py create mode 100644 rllib/models/base_model.py create mode 100644 rllib/models/configs/encoder.py create mode 100644 rllib/models/experimental/model_configs.py create mode 100644 rllib/models/tf/primitives.py create mode 100644 rllib/models/torch/primitives.py diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 1e8eae9b345c..f19c742b2c04 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -15,7 +15,7 @@ import gymnasium as gym -from ray.rllib.models.experimental.model_configs import FCConfig +from ray.rllib.models.experimental.configs import FCConfig from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.util.debug import log_once @@ -473,6 +473,7 @@ def 
__getitem__(self, item): DEFAULT_CONFIG = _deprecated_default_config() +@ExperimentalAPI @dataclass class PPOModuleConfig(RLModuleConfig): """Configuration for the PPO RLModule. diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index fbae9c105a84..d1a10d916a82 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -1,19 +1,21 @@ +import itertools import ray import unittest import numpy as np -import gym +import gymnasium as gym import torch +import tensorflow as tf import tree from ray.rllib import SampleBatch from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule, ) -from rllib.algorithms.ppo.ppo import __PPOModuleConfig +from rllib.algorithms.ppo.ppo import PPOModuleConfig from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( PPOTfRLModule, ) -from ray.rllib.models.experimental.model_configs import ( +from ray.rllib.models.experimental.configs import ( FCConfig, FCEncoderConfig, LSTMEncoderConfig, @@ -29,7 +31,7 @@ def get_expected_model_config( env: gym.Env, lstm: bool, -) -> __PPOModuleConfig: +) -> PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. Args: @@ -51,8 +53,6 @@ def get_expected_model_config( batch_first=True, num_layers=1, output_dim=32, - input_key=SampleBatch.OBS, - output_key=shared_encoder_kh.create(BaseModelIOKeys.OUT), ) else: shared_encoder_config = FCEncoderConfig( @@ -78,7 +78,7 @@ def get_expected_model_config( else: pi_config.output_dim = env.action_space.shape[0] * 2 - return __PPOModuleConfig( + return PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, shared_encoder_config=shared_encoder_config, @@ -217,45 +217,33 @@ def test_forward_train(self): initial_state = state_in while tstep < 10: if lstm: - state_in = module.get_initial_state() - state_in = tree.map_structure( - lambda x: x[None], convert_to_torch_tensor(state_in) - ) - initial_state = state_in - while tstep < 10: - if lstm: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None], - STATE_IN: state_in, - SampleBatch.SEQ_LENS: np.array([1]), - } - else: - input_batch = { - SampleBatch.OBS: convert_to_torch_tensor(obs)[None] - } - fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy( - fwd_out["action_dist"].sample().squeeze(0) - ) - new_obs, reward, terminated, truncated, _ = env.step(action) - output_batch = { - SampleBatch.OBS: obs, - SampleBatch.NEXT_OBS: new_obs, - SampleBatch.ACTIONS: action, - SampleBatch.REWARDS: np.array(reward), - SampleBatch.TERMINATEDS: np.array(terminated), - SampleBatch.TRUNCATEDS: np.array(truncated), - } - if lstm: - assert STATE_OUT in fwd_out - state_in = fwd_out[STATE_OUT] - batches.append(output_batch) - obs = new_obs - tstep += 1 - - # convert the list of dicts to dict of lists - batch = tree.map_structure(lambda *x: list(x), *batches) - # convert dict of lists to dict of tensors + input_batch = self.get_input_batch_from_obs(fw, obs) + input_batch[STATE_IN] = state_in + input_batch[SampleBatch.SEQ_LENS] = np.array([1]) + else: + input_batch = self.get_input_batch_from_obs(fw, obs) + fwd_out = module.forward_exploration(input_batch) + action = convert_to_numpy(fwd_out["action_dist"].sample()[0]) + new_obs, reward, terminated, truncated, _ = env.step(action) + output_batch = { + SampleBatch.OBS: obs, + SampleBatch.NEXT_OBS: new_obs, + SampleBatch.ACTIONS: action, + SampleBatch.REWARDS: 
np.array(reward), + SampleBatch.TERMINATEDS: np.array(terminated), + SampleBatch.TRUNCATEDS: np.array(truncated), + } + if lstm: + assert STATE_OUT in fwd_out + state_in = fwd_out[STATE_OUT] + batches.append(output_batch) + obs = new_obs + tstep += 1 + + # convert the list of dicts to dict of lists + batch = tree.map_structure(lambda *x: np.array(x), *batches) + # convert dict of lists to dict of tensors + if fw == "torch": fwd_in = { k: convert_to_torch_tensor(np.array(v)) for k, v in batch.items() } @@ -264,12 +252,12 @@ def test_forward_train(self): fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10]) # forward train - # before training make sure module is on the right device and in - # training mode + # before training make sure module is on the right device + # and in training mode module.to("cpu") module.train() fwd_out = module.forward_train(fwd_in) - loss = get_ppo_loss(fwd_in, fwd_out) + loss = dummy_torch_ppo_loss(fwd_in, fwd_out) loss.backward() # check that all neural net parameters have gradients diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 760017c38825..adcba4289621 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -3,7 +3,7 @@ from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.experimental.model_configs import FCConfig, IdentityConfig +from ray.rllib.models.experimental.configs import FCConfig, IdentityConfig from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 15160bd21173..43c9bf4d414e 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -5,8 +5,8 @@ from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.experimental.base import STATE_OUT -from ray.rllib.models.experimental.model_configs import FCConfig, FCEncoderConfig -from ray.rllib.models.experimental.model_configs import ( +from ray.rllib.models.experimental.configs import FCConfig, FCEncoderConfig +from ray.rllib.models.experimental.configs import ( LSTMEncoderConfig, ) from ray.rllib.models.experimental.torch.encoder import ( diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py new file mode 100644 index 000000000000..f3bb22b46900 --- /dev/null +++ b/rllib/core/rl_module/encoder.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import tree +from typing import List + +from dataclasses import dataclass, field + +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.rnn_sequencing import add_time_dimension +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_torch import TorchTensorSpec +from ray.rllib.models.torch.primitives import FCNet + +# TODO (Kourosh): Find a better / more straight fwd approach for sub-components + +ENCODER_OUT = "encoder_out" +STATE_IN = "state_in" +STATE_OUT = "state_out" + + +@dataclass +class EncoderConfig: + """Configuration for an encoder network. 
+ + Attributes: + output_dim: The output dimension of the network. if None, the last layer would + be the last hidden layer. + """ + + output_dim: int = None + + +@dataclass +class IdentityConfig(EncoderConfig): + """Configuration for an identity encoder.""" + + def build(self): + return IdentityEncoder(self) + + +@dataclass +class FCConfig(EncoderConfig): + """Configuration for a fully connected network. + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer (except for the + output). + output_activation: The activation function to use for the output layer. + """ + + input_dim: int = None + hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) + activation: str = "ReLU" + + def build(self): + return FullyConnectedEncoder(self) + + +@dataclass +class LSTMConfig(EncoderConfig): + input_dim: int = None + hidden_dim: int = None + num_layers: int = None + batch_first: bool = True + + def build(self): + return LSTMEncoder(self) + + +class Encoder(nn.Module): + def __init__(self, config: EncoderConfig) -> None: + super().__init__() + self.config = config + self._input_spec = self.input_spec() + self._output_spec = self.output_spec() + + def get_initial_state(self): + return [] + + def input_spec(self): + return SpecDict() + + def output_spec(self): + return SpecDict() + + @check_input_specs("_input_spec") + @check_output_specs("_output_spec") + def forward(self, input_dict): + return self._forward(input_dict) + + def _forward(self, input_dict): + raise NotImplementedError + + +class FullyConnectedEncoder(Encoder): + def __init__(self, config: FCConfig) -> None: + super().__init__(config) + + self.net = FCNet( + input_dim=config.input_dim, + hidden_layers=config.hidden_layers, + output_dim=config.output_dim, + activation=config.activation, + ) + + def input_spec(self): + return SpecDict( + {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} + ) + + def output_spec(self): + return SpecDict( + {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} + ) + + def _forward(self, input_dict): + return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} + + +class LSTMEncoder(Encoder): + def __init__(self, config: LSTMConfig) -> None: + super().__init__(config) + + self.lstm = nn.LSTM( + config.input_dim, + config.hidden_dim, + config.num_layers, + batch_first=config.batch_first, + ) + self.linear = nn.Linear(config.hidden_dim, config.output_dim) + + def get_initial_state(self): + config = self.config + return { + "h": torch.zeros(config.num_layers, config.hidden_dim), + "c": torch.zeros(config.num_layers, config.hidden_dim), + } + + def input_spec(self): + config = self.config + return SpecDict( + { + # bxt is just a name for better readability to indicated padded batch + SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), + STATE_IN: { + "h": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + }, + } + ) + + def output_spec(self): + config = self.config + return SpecDict( + { + ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), + STATE_OUT: { + "h": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TorchTensorSpec( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + }, + } + ) + + def _forward(self, input_dict: SampleBatch): + x = input_dict[SampleBatch.OBS] + states = 
input_dict[STATE_IN] + # states are batch-first when coming in + states = tree.map_structure(lambda x: x.transpose(0, 1), states) + + x = add_time_dimension( + x, + seq_lens=input_dict[SampleBatch.SEQ_LENS], + framework="torch", + time_major=not self.config.batch_first, + ) + states_o = {} + x, (states_o["h"], states_o["c"]) = self.lstm(x, (states["h"], states["c"])) + + x = self.linear(x) + x = x.view(-1, x.shape[-1]) + + return { + ENCODER_OUT: x, + STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + } + + +class IdentityEncoder(Encoder): + def __init__(self, config: EncoderConfig) -> None: + super().__init__(config) + + def _forward(self, input_dict): + return input_dict diff --git a/rllib/core/rl_module/encoder_tf.py b/rllib/core/rl_module/encoder_tf.py new file mode 100644 index 000000000000..5c517f6c745d --- /dev/null +++ b/rllib/core/rl_module/encoder_tf.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass, field +from typing import List + +from ray.rllib.core.rl_module.encoder import EncoderConfig +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.models.tf.primitives import FCNet, IdentityNetwork + +tf1, tf, tfv = try_import_tf() + + +@dataclass +class FCTfConfig(EncoderConfig): + """Configuration for a fully connected network. + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer (except for the + output). + output_activation: The activation function to use for the output layer. + """ + + input_dim: int = None + output_dim: int = None + hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) + activation: str = "ReLU" + + def build(self): + return FCNet( + self.input_dim, self.hidden_layers, self.output_dim, self.activation + ) + + +@dataclass +class IdentityTfConfig(EncoderConfig): + """A network that returns the input as the output.""" + + def build(self): + return IdentityNetwork() diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py new file mode 100644 index 000000000000..c006af27f6f1 --- /dev/null +++ b/rllib/models/base_model.py @@ -0,0 +1,333 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import abc +from typing import Optional, Tuple + +from ray.rllib.models.temp_spec_classes import TensorDict, SpecDict, ModelConfig +from ray.rllib.utils.annotations import ( + DeveloperAPI, + OverrideToImplementCustomLogic, + override, + ExperimentalAPI, +) + + +ForwardOutputType = TensorDict +# [Output, Recurrent State(s)] +UnrollOutputType = Tuple[TensorDict, TensorDict] + + +@ExperimentalAPI +class RecurrentModel(abc.ABC): + """The base model all other models are based on. + + RLlib models all inherit from the recurrent base class, which can be chained + together with other models. + + The models input and output TensorDicts. 
Which keys the models read/write to + and the desired tensor shapes must be defined in input_spec, output_spec, + prev_state_spec, and next_state_spec. + + The `unroll` function gets the model inputs and previous recurrent state, and + outputs the model outputs and next recurrent state. Note all ins/outs must match + the specs. Users should override `_unroll` rather than `unroll`. + + `initial_state` returns the "next" state for the first recurrent iteration. Again, + users should override `_initial_state` instead. + + For non-recurrent models, users may use Model instead, and override + `_forward` which does not make use of recurrent states. + + Args: + name: An optional name for the module + """ + + def __init__(self, name: Optional[str] = None): + self._name = name or self.__class__.__name__ + + @property + def name(self) -> str: + """Returns the name of this module.""" + return self._name + + @property + @abc.abstractmethod + def input_spec(self) -> SpecDict: + """Returns the spec of the input of this module.""" + + @property + @abc.abstractmethod + def prev_state_spec(self) -> SpecDict: + """Returns the spec of the prev_state of this module.""" + + @property + @abc.abstractmethod + def output_spec(self) -> SpecDict: + """Returns the spec of the output of this module.""" + + @property + @abc.abstractmethod + def next_state_spec(self) -> SpecDict: + """Returns the spec of the next_state of this module.""" + + @abc.abstractmethod + def _initial_state(self) -> TensorDict: + """Initial state of the component. + If this component returns a next_state in its unroll function, then + this function provides the initial state. + Subclasses should override this function instead of `initial_state`, which + adds additional checks. + + Returns: + A TensorDict containing the state before the first step. + """ + + @DeveloperAPI + def initial_state(self) -> TensorDict: + """Initial state of the component. + If this component returns a next_state in its unroll function, then + this function provides the initial state. + + Returns: + A TensorDict containing the state before the first step. + + Examples: + >>> state = model.initial_state() + >>> state # TensorDict(...) + """ + initial_state = self._initial_state() + self.next_state_spec.validate(initial_state) + return initial_state + + @abc.abstractmethod + def _unroll( + self, inputs: TensorDict, prev_state: TensorDict, **kwargs + ) -> UnrollOutputType: + """Computes the output of the module over the timesteps within the batch. + Subclasses should override this function instead of `unroll`, which + adds additional checks. + + Args: + inputs: A TensorDict of inputs + prev_state: A TensorDict containing the next_state of the last + timestep of the previous unroll. + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict of outputs + next_state: A dict containing the state to be passed + as the first state of the next rollout. + """ + + @DeveloperAPI + def unroll( + self, inputs: TensorDict, prev_state: TensorDict, **kwargs + ) -> UnrollOutputType: + """Computes the output of the module over the timesteps within the batch. + + Args: + inputs: A TensorDict containing inputs to the model + prev_state: A TensorDict containing containing the + next_state of the last timestep of the previous unroll. + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict containing model outputs + next_state: A TensorDict containing the + state to be passed as the first state of the next rollout. 
+ + Examples: + >>> output, state = model.unroll(TensorDict(...), TensorDict(...)) + >>> output # TensorDict(...) + >>> state # TensorDict(...) + + """ + self.input_spec.validate(inputs) + self.prev_state_spec.validate(prev_state) + # We hide inputs not specified in input_spec to prevent accidental use. + inputs = inputs.filter(self.input_spec) + prev_state = prev_state.filter(self.prev_state_spec) + inputs, prev_state = self._update_inputs_and_prev_state(inputs, prev_state) + outputs, next_state = self._unroll(inputs, prev_state, **kwargs) + self.output_spec.validate(outputs) + self.next_state_spec.validate(next_state) + outputs, next_state = self._update_outputs_and_next_state(outputs, next_state) + return outputs, next_state + + @OverrideToImplementCustomLogic + def _update_inputs_and_prev_state( + self, inputs: TensorDict, prev_state: TensorDict + ) -> Tuple[TensorDict, TensorDict]: + """Override this function to add additional checks and optionally update inputs. + + Args: + inputs: TensorDict containing inputs to the model + prev_state: The previous recurrent state + + Returns: + inputs: Potentially modified inputs + prev_state: Potentially modified recurrent state + """ + return inputs, prev_state + + @OverrideToImplementCustomLogic + def _update_outputs_and_next_state( + self, outputs: TensorDict, next_state: TensorDict + ) -> Tuple[TensorDict, TensorDict]: + """Override this function to add additional checks and optionally update + outputs. + + Args: + outputs: TensorDict output by the model + next_state: Recurrent state output by the model + + Returns: + outputs: Potentially modified TensorDict output by the model + next_state: Potentially modified recurrent state output by the model + """ + return outputs, next_state + + +class Model(RecurrentModel): + """A RecurrentModel made non-recurrent by ignoring + the input/output states. + + As a convienience, users may override _forward instead of _unroll, + which hides model states. + + Args: + name: An optional name for the module + """ + + @property + @override(RecurrentModel) + def prev_state_spec(self) -> SpecDict: + return SpecDict() + + @property + @override(RecurrentModel) + def next_state_spec(self) -> SpecDict: + return SpecDict() + + @override(RecurrentModel) + def _initial_state(self) -> TensorDict: + return TensorDict() + + @override(RecurrentModel) + def _update_inputs_and_prev_state( + self, inputs: TensorDict, prev_state: TensorDict + ) -> Tuple[TensorDict, TensorDict]: + inputs = self._update_inputs(inputs) + return inputs, prev_state + + @OverrideToImplementCustomLogic + def _update_inputs(self, inputs: TensorDict) -> TensorDict: + """Override this function to add additional checks and optionally update inputs. + + Args: + inputs: TensorDict containing inputs to the model + + Returns: + inputs: Potentially modified inputs + """ + return inputs + + @override(RecurrentModel) + def _update_outputs_and_next_state( + self, outputs: TensorDict, next_state: TensorDict + ) -> Tuple[TensorDict, TensorDict]: + outputs = self._update_outputs(outputs) + return outputs, next_state + + @OverrideToImplementCustomLogic + def _update_outputs(self, outputs: TensorDict) -> TensorDict: + """Override this function to add additional checks and optionally update + outputs. 
+ + Args: + outputs: TensorDict output by the model + + Returns: + outputs: Potentially modified TensorDict output by the model + """ + return outputs + + @override(RecurrentModel) + def _unroll( + self, inputs: TensorDict, prev_state: TensorDict, **kwargs + ) -> UnrollOutputType: + outputs = self._forward(inputs, **kwargs) + return outputs, TensorDict() + + @abc.abstractmethod + def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + """Computes the output of this module for each timestep. + + Args: + inputs: A TensorDict containing model inputs + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict containing model outputs + + Examples: + # This is abstract, see the torch/tf/jax implementations + >>> out = model._forward(TensorDict({"in": np.arange(10)})) + >>> out # TensorDict(...) + """ + + +@ExperimentalAPI +class ModelIO(abc.ABC): + """Abstract class defining how to save and load model weights + + Args: + config: The ModelConfig passed to the underlying model + """ + + def __init__(self, config: ModelConfig) -> None: + self._config = config + + @DeveloperAPI + @property + def config(self) -> ModelConfig: + return self._config + + @DeveloperAPI + @abc.abstractmethod + def save(self, path: str) -> None: + """Save model weights to a path + + Args: + path: The path on disk where weights are to be saved + + Examples: + model.save("/tmp/model_path.cpt") + """ + raise NotImplementedError + + @DeveloperAPI + @abc.abstractmethod + def load(self, path: str) -> RecurrentModel: + """Load model weights from a path + + Args: + path: The path on disk where to load weights from + + Examples: + model.load("/tmp/model_path.cpt") + """ + raise NotImplementedError diff --git a/rllib/models/configs/encoder.py b/rllib/models/configs/encoder.py new file mode 100644 index 000000000000..38a7f305123a --- /dev/null +++ b/rllib/models/configs/encoder.py @@ -0,0 +1,83 @@ +import abc +from dataclasses import dataclass +from typing import TYPE_CHECKING, Tuple + +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.torch.encoders.vector import TorchVectorEncoder + +if TYPE_CHECKING: + from ray.rllib.models.torch.encoders.vector import Encoder + + +@dataclass +class EncoderConfig: + """The base config for encoder models. + + Each config should define a `build` method that builds a model from the config. + + All user-configurable parameters known before runtime + (e.g. framework, activation, num layers, etc.) should be defined as attributes. + + Parameters unknown before runtime (e.g. the output size of the module providing + input for this module) should be passed as arguments to `build`. This should be + as few params as possible. + + `build` should return an instance of the encoder associated with the config. + + Attributes: + framework_str: The tensor framework to construct a model for. + This can be 'torch', 'tf2', or 'jax'. + """ + + framework_str: str = "torch" + + @abc.abstractmethod + def build(self, input_spec: SpecDict, **kwargs) -> "Encoder": + """Builds the EncoderConfig into an Encoder instance""" + + +@dataclass +class VectorEncoderConfig(EncoderConfig): + """An MLP encoder mappings tensors with shape [..., feature] to [..., output]. + + Attributes: + activation: The type of activation function to use between hidden layers. + Options are 'relu', 'swish', 'tanh', or 'linear' + final_activation: The activation function to use after the final linear layer. + Options are the same as for activation. 
+ hidden_layer_sizes: A list, where each element represents the number of neurons + in that layer. For example, [128, 64] would produce a two-layer MLP with + 128 hidden neurons and 64 hidden neurons. + output_key: Write the output of the encoder to this key in the NestedDict. + """ + + activation: str = "relu" + final_activation: str = "linear" + hidden_layer_sizes: Tuple[int, ...] = (128, 128) + output_key: str = "encoding" + + def build(self, input_spec: SpecDict) -> TorchVectorEncoder: + """Build the config into a VectorEncoder model instance. + + Args: + input_spec: The output spec of the previous module(s) that will feed + inputs to this encoder. + + Returns: + A VectorEncoder of the specified framework. + """ + assert ( + len(self.hidden_layer_sizes) > 1 + ), "Must have at least a single hidden layer" + for k in input_spec.shallow_keys(): + assert isinstance( + input_spec[k].shape[-1], int + ), "Input spec {k} does not define the size of the feature (last) dimension" + + if self.framework_str == "torch": + return TorchVectorEncoder(input_spec, self) + else: + raise NotImplementedError( + "{self.__class__.__name__} not implemented" + " for framework {self.framework}" + ) diff --git a/rllib/models/experimental/model_configs.py b/rllib/models/experimental/model_configs.py new file mode 100644 index 000000000000..bb9175048a8a --- /dev/null +++ b/rllib/models/experimental/model_configs.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass, field +from typing import List +import functools + +from ray.rllib.models.experimental.base import ModelConfig, Model + + +def _check_framework(fn): + @functools.wraps(fn) + def checked_build(self, framework, **kwargs): + if framework not in ("torch", "tf", "tf2"): + raise ValueError(f"Framework {framework} not supported.") + return fn(self, framework, **kwargs) + + return checked_build + + +@dataclass +class FCConfig(ModelConfig): + """Configuration for a fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer (except for the + output). + output_activation: The activation function to use for the output layer. 
+ """ + + input_dim: int = None + hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) + activation: str = "ReLU" + output_activation: str = "ReLU" + + @_check_framework + def build(self, framework: str = "torch") -> Model: + if framework == "torch": + from ray.rllib.models.experimental.torch.fcmodel import FCModel + else: + from ray.rllib.models.experimental.tf.fcmodel import FCModel + return FCModel(self) + + +@dataclass +class FCEncoderConfig(FCConfig): + def build(self, framework: str = "torch"): + if framework == "torch": + from ray.rllib.models.experimental.torch.encoder import FCEncoder + else: + from ray.rllib.models.experimental.tf.encoder import FCEncoder + return FCEncoder(self) + + +@dataclass +class LSTMEncoderConfig(ModelConfig): + input_dim: int = None + hidden_dim: int = None + num_layers: int = None + batch_first: bool = True + + @_check_framework + def build(self, framework: str = "torch"): + if not framework == "torch": + raise ValueError("Only torch framework supported.") + from rllib.models.experimental.torch.encoder import LSTMEncoder + + return LSTMEncoder(self) + + +@dataclass +class IdentityConfig(ModelConfig): + """Configuration for an identity encoder.""" + + @_check_framework + def build(self, framework: str = "torch"): + if framework == "torch": + from rllib.models.experimental.torch.encoder import IdentityEncoder + else: + from rllib.models.experimental.tf.encoder import IdentityEncoder + + return IdentityEncoder(self) diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py index 7578b3c1ec4f..bb7c46214e2f 100644 --- a/rllib/models/experimental/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -83,8 +83,8 @@ def input_spec(self): return SpecDict( { # bxt is just a name for better readability to indicated padded batch - self.config.input_key: TorchTensorSpec("bxt, h", h=config.input_dim), - self.config.state_in_key: { + SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), + STATE_IN: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -102,8 +102,8 @@ def output_spec(self): config = self.config return SpecDict( { - self.config.output_key: TorchTensorSpec("bxt, h", h=config.output_dim), - self.config.state_out_key: { + ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), + STATE_OUT: { "h": TorchTensorSpec( "b, l, h", h=config.hidden_dim, l=config.num_layers ), @@ -135,10 +135,8 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: x = x.view(-1, x.shape[-1]) return { - self.config.output_key: x, - self.config.state_out_key: tree.map_structure( - lambda x: x.transpose(0, 1), states_o - ), + ENCODER_OUT: x, + STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), } diff --git a/rllib/models/tf/primitives.py b/rllib/models/tf/primitives.py new file mode 100644 index 000000000000..395ce9863135 --- /dev/null +++ b/rllib/models/tf/primitives.py @@ -0,0 +1,55 @@ +from typing import List +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf + +_, tf, _ = try_import_tf() + +# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done. + + +class FCNet(tf.keras.Model): + """A simple fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layers: The sizes of the hidden layers. + output_dim: The output dimension of the network. 
+ activation: The activation function to use after each layer. + Currently "Linear" (no activation) and "ReLU" are supported. + """ + + def __init__( + self, + input_dim: int, + hidden_layers: List[int], + output_dim: int, + activation: str = "linear", + ): + super().__init__() + + if activation not in ("linear", "ReLU", "Tanh"): + raise ValueError("Activation function not supported") + assert input_dim is not None, "Input dimension must not be None" + assert output_dim is not None, "Output dimension must not be None" + layers = [] + activation = activation.lower() + # input = tf.keras.layers.Dense(input_dim, activation=activation) + layers.append(tf.keras.Input(shape=(input_dim,))) + for i in range(len(hidden_layers)): + layers.append( + tf.keras.layers.Dense(hidden_layers[i], activation=activation) + ) + layers.append(tf.keras.layers.Dense(output_dim)) + self.network = tf.keras.Sequential(layers) + + @override(tf.keras.Model) + def call(self, inputs, training=None, mask=None): + return self.network(inputs) + + +class IdentityNetwork(tf.keras.Model): + """A network that returns the input as the output.""" + + @override(tf.keras.Model) + def call(self, inputs, training=None, mask=None): + return inputs diff --git a/rllib/models/torch/primitives.py b/rllib/models/torch/primitives.py new file mode 100644 index 000000000000..191a0ff35e5a --- /dev/null +++ b/rllib/models/torch/primitives.py @@ -0,0 +1,54 @@ +from typing import List, Optional +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + +# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done. + + +class FCNet(nn.Module): + """A simple fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + output_dim: The output dimension of the network. if None, the last layer would + be the last hidden layer. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer. 
+ """ + + def __init__( + self, + input_dim: int, + hidden_layers: List[int], + output_dim: Optional[int] = None, + activation: str = "linear", + ): + super().__init__() + self.input_dim = input_dim + self.hidden_layers = hidden_layers + + activation_class = getattr(nn, activation, lambda: None)() + self.layers = [] + self.layers.append(nn.Linear(self.input_dim, self.hidden_layers[0])) + for i in range(len(self.hidden_layers) - 1): + if activation != "linear": + self.layers.append(activation_class) + self.layers.append( + nn.Linear(self.hidden_layers[i], self.hidden_layers[i + 1]) + ) + + if output_dim is not None: + if activation != "linear": + self.layers.append(activation_class) + self.layers.append(nn.Linear(self.hidden_layers[-1], output_dim)) + + if output_dim is None: + self.output_dim = hidden_layers[-1] + else: + self.output_dim = output_dim + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x): + return self.layers(x) From bec7f470c4e72450cb8773e15c05851cf46312ad Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 17:31:13 -0800 Subject: [PATCH 28/51] minor fixes Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/ppo.py | 30 ----------- .../ppo/tests/test_ppo_with_rl_module.py | 2 +- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 3 +- .../ppo/torch/ppo_torch_rl_module.py | 30 +++++++++-- rllib/models/experimental/base.py | 8 --- rllib/models/experimental/configs.py | 52 ++++++++++++++----- 6 files changed, 67 insertions(+), 58 deletions(-) diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index f19c742b2c04..89b7955c149b 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -10,14 +10,8 @@ """ import logging -from dataclasses import dataclass from typing import List, Optional, Type, Union, TYPE_CHECKING -import gymnasium as gym - -from ray.rllib.models.experimental.configs import FCConfig -from ray.rllib.core.rl_module.rl_module import RLModuleConfig - from ray.util.debug import log_once from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided @@ -471,27 +465,3 @@ def __getitem__(self, item): DEFAULT_CONFIG = _deprecated_default_config() - - -@ExperimentalAPI -@dataclass -class PPOModuleConfig(RLModuleConfig): - """Configuration for the PPO RLModule. - - Attributes: - observation_space: The observation space of the environment. - action_space: The action space of the environment. - shared_encoder_config: The configuration for the encoder network. - pi_config: The configuration for the policy network. - vf_config: The configuration for the value network. - free_log_std: For DiagGaussian action distributions, make the second half of - the model outputs floating bias variables instead of state-dependent. This - only has an effect is using the default fully connected net. 
- """ - - observation_space: gym.Space = None - action_space: gym.Space = None - shared_encoder_config: FCConfig = None - pi_config: FCConfig = None - vf_config: FCConfig = None - free_log_std: bool = False diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 84100e07ae0e..4f37ffd6476c 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -78,7 +78,7 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs): class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): - ray.init(local_mode=True) + ray.init() @classmethod def tearDownClass(cls): diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index adcba4289621..48538316bdfb 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -11,8 +11,7 @@ from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian from ray.rllib.models.experimental.tf.primitives import FCNet from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT -from ray.rllib.algorithms.ppo.ppo import PPOModuleConfig - +from rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig tf1, tf, _ = try_import_tf() tf1.enable_eager_execution() diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 43c9bf4d414e..2ced64f5a3a1 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -1,8 +1,9 @@ +from dataclasses import dataclass from typing import Mapping, Any, Union import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.experimental.base import STATE_OUT from ray.rllib.models.experimental.configs import FCConfig, FCEncoderConfig @@ -20,10 +21,9 @@ TorchDiagGaussian, ) from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import override, ExperimentalAPI from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space -from ray.rllib.algorithms.ppo.ppo import PPOModuleConfig from ray.rllib.utils.nested_dict import NestedDict @@ -43,6 +43,30 @@ def get_ppo_loss(fwd_in, fwd_out): return loss +@ExperimentalAPI +@dataclass +class PPOModuleConfig(RLModuleConfig): + """Configuration for the PPO RLModule. + + Attributes: + observation_space: The observation space of the environment. + action_space: The action space of the environment. + shared_encoder_config: The configuration for the encoder network. + pi_config: The configuration for the policy network. + vf_config: The configuration for the value network. + free_log_std: For DiagGaussian action distributions, make the second half of + the model outputs floating bias variables instead of state-dependent. This + only has an effect is using the default fully connected net. 
+ """ + + observation_space: gym.Space = None + action_space: gym.Space = None + shared_encoder_config: FCConfig = None + pi_config: FCConfig = None + vf_config: FCConfig = None + free_log_std: bool = False + + class PPOTorchRLModule(TorchRLModule): def __init__(self, config: PPOModuleConfig) -> None: super().__init__() diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index dcbd4c943e9c..67cf3c4d4add 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -15,14 +15,6 @@ STATE_OUT: str = "state_out" -def _not_decorated_message(input_or_output): - return ( - f"__call__ not decorated with {input_or_output} specification. Decorate " - f"with @check_{input_or_output}_specs() to define a specification. See " - f"BaseModel for examples." - ) - - @ExperimentalAPI @dataclass class ModelConfig(abc.ABC): diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index bb9175048a8a..985ec145cac8 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -1,18 +1,42 @@ from dataclasses import dataclass, field -from typing import List +from typing import List, Callable import functools from ray.rllib.models.experimental.base import ModelConfig, Model +from ray.rllib.utils.annotations import DeveloperAPI -def _check_framework(fn): - @functools.wraps(fn) - def checked_build(self, framework, **kwargs): - if framework not in ("torch", "tf", "tf2"): - raise ValueError(f"Framework {framework} not supported.") - return fn(self, framework, **kwargs) +@DeveloperAPI +def _framework_implemented(torch: bool = True, tf: bool = True): + """Decorator to check if a model was implemented in a framework. - return checked_build + Args: + torch: Whether we can build this model with torch. + tf: Whether we can build this model with tf. + + Returns: + The decorated function. + + Raises: + ValueError: If the framework is not available to build. 
+ """ + accepted = [] + if torch: + accepted.append("torch") + if tf: + accepted.append("tf") + accepted.append("tf2") + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + def checked_build(self, framework, **kwargs): + if framework not in accepted: + raise ValueError(f"Framework {framework} not supported.") + return fn(self, framework, **kwargs) + + return checked_build + + return decorator @dataclass @@ -32,7 +56,7 @@ class FCConfig(ModelConfig): activation: str = "ReLU" output_activation: str = "ReLU" - @_check_framework + @_framework_implemented() def build(self, framework: str = "torch") -> Model: if framework == "torch": from ray.rllib.models.experimental.torch.fcmodel import FCModel @@ -43,6 +67,7 @@ def build(self, framework: str = "torch") -> Model: @dataclass class FCEncoderConfig(FCConfig): + @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": from ray.rllib.models.experimental.torch.encoder import FCEncoder @@ -58,11 +83,10 @@ class LSTMEncoderConfig(ModelConfig): num_layers: int = None batch_first: bool = True - @_check_framework + @_framework_implemented(tf=False) def build(self, framework: str = "torch"): - if not framework == "torch": - raise ValueError("Only torch framework supported.") - from rllib.models.experimental.torch.encoder import LSTMEncoder + if framework == "torch": + from rllib.models.experimental.torch.encoder import LSTMEncoder return LSTMEncoder(self) @@ -71,7 +95,7 @@ def build(self, framework: str = "torch"): class IdentityConfig(ModelConfig): """Configuration for an identity encoder.""" - @_check_framework + @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": from rllib.models.experimental.torch.encoder import IdentityEncoder From da887b6f4c95a790e820f5a2207cdaba40238156 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 23:03:10 -0800 Subject: [PATCH 29/51] fix import Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/torch/primitives.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index 3a4b5ee92c30..97b2a9034017 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -9,7 +9,7 @@ from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TensorType -from rllib.models.experimental.base import ModelConfig +from ray.rllib.models.experimental.base import ModelConfig torch, nn = try_import_torch() From 42e0d49bd2c31bf9a12c8d41199f406e3d02a08a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 23 Jan 2023 23:04:47 -0800 Subject: [PATCH 30/51] fix misspelling Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rllib/models/experimental/README.rst b/rllib/models/experimental/README.rst index a45919dadc77..2ef2007403e2 100644 --- a/rllib/models/experimental/README.rst +++ b/rllib/models/experimental/README.rst @@ -1,2 +1,2 @@ -This folder holds models that are under development and to be used with RLModules in upcoming versions of RLLib. -They are not yet ready for use in the current version of RLLib. \ No newline at end of file +This folder holds models that are under development and to be used with RLModules in upcoming versions of RLlib. 
+They are not yet ready for use in the current version of RLlib. \ No newline at end of file From aef9875ccab5dc00a88111f4b846b863b49cda7b Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 24 Jan 2023 00:05:23 -0800 Subject: [PATCH 31/51] fix import Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index d1a10d916a82..ad12eaf4d1a2 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -1,20 +1,21 @@ import itertools -import ray import unittest -import numpy as np + import gymnasium as gym -import torch +import numpy as np import tensorflow as tf +import torch import tree +import ray from ray.rllib import SampleBatch -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( - PPOTorchRLModule, -) -from rllib.algorithms.ppo.ppo import PPOModuleConfig from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( PPOTfRLModule, ) +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( + PPOTorchRLModule, +) from ray.rllib.models.experimental.configs import ( FCConfig, FCEncoderConfig, From d0d5277398e833bac513a1a76edacbf1f10d838e Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 24 Jan 2023 00:06:23 -0800 Subject: [PATCH 32/51] also fix other imports Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 2 +- rllib/models/experimental/configs.py | 6 +++--- rllib/models/experimental/tf/fcmodel.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 48538316bdfb..4b1f3c153e87 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -11,7 +11,7 @@ from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian from ray.rllib.models.experimental.tf.primitives import FCNet from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT -from rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig tf1, tf, _ = try_import_tf() tf1.enable_eager_execution() diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index 985ec145cac8..2142f6ef0cfc 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -86,7 +86,7 @@ class LSTMEncoderConfig(ModelConfig): @_framework_implemented(tf=False) def build(self, framework: str = "torch"): if framework == "torch": - from rllib.models.experimental.torch.encoder import LSTMEncoder + from ray.rllib.models.experimental.torch.encoder import LSTMEncoder return LSTMEncoder(self) @@ -98,8 +98,8 @@ class IdentityConfig(ModelConfig): @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": - from rllib.models.experimental.torch.encoder import IdentityEncoder + from ray.rllib.models.experimental.torch.encoder import IdentityEncoder else: - from rllib.models.experimental.tf.encoder import IdentityEncoder + from ray.rllib.models.experimental.tf.encoder import IdentityEncoder return IdentityEncoder(self) diff --git 
a/rllib/models/experimental/tf/fcmodel.py b/rllib/models/experimental/tf/fcmodel.py index ff9faf2780bf..c40727512514 100644 --- a/rllib/models/experimental/tf/fcmodel.py +++ b/rllib/models/experimental/tf/fcmodel.py @@ -3,7 +3,7 @@ from ray.rllib.utils import try_import_tf from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.models.tf.primitives import FCNet, TFModel -from rllib.models.experimental.base import ModelConfig, ForwardOutputType +from ray.rllib.models.experimental.base import ModelConfig, ForwardOutputType tf1, tf, tfv = try_import_tf() From 3011262ccef4d320e3899a2f3295934dcc262334 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 24 Jan 2023 12:49:42 -0800 Subject: [PATCH 33/51] typo Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index 67cf3c4d4add..06b37f1adb9c 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -97,7 +97,7 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: class Encoder(Model): - """The base class for all encoders Rllib produces. + """The base class for all encoders RLlib produces. Encoders are used to encode observations into a latent space in RLModules. Therefore, their input_spec usually contains the observation space dimensions. From 3591fabf24a8b7fa61c20a8e26a6a55aa45c9d2a Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 24 Jan 2023 15:54:14 -0800 Subject: [PATCH 34/51] delete unneeded configs Signed-off-by: Artur Niederfahrenhorst --- rllib/models/configs/encoder.py | 83 --------------------------------- 1 file changed, 83 deletions(-) delete mode 100644 rllib/models/configs/encoder.py diff --git a/rllib/models/configs/encoder.py b/rllib/models/configs/encoder.py deleted file mode 100644 index 38a7f305123a..000000000000 --- a/rllib/models/configs/encoder.py +++ /dev/null @@ -1,83 +0,0 @@ -import abc -from dataclasses import dataclass -from typing import TYPE_CHECKING, Tuple - -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.torch.encoders.vector import TorchVectorEncoder - -if TYPE_CHECKING: - from ray.rllib.models.torch.encoders.vector import Encoder - - -@dataclass -class EncoderConfig: - """The base config for encoder models. - - Each config should define a `build` method that builds a model from the config. - - All user-configurable parameters known before runtime - (e.g. framework, activation, num layers, etc.) should be defined as attributes. - - Parameters unknown before runtime (e.g. the output size of the module providing - input for this module) should be passed as arguments to `build`. This should be - as few params as possible. - - `build` should return an instance of the encoder associated with the config. - - Attributes: - framework_str: The tensor framework to construct a model for. - This can be 'torch', 'tf2', or 'jax'. - """ - - framework_str: str = "torch" - - @abc.abstractmethod - def build(self, input_spec: SpecDict, **kwargs) -> "Encoder": - """Builds the EncoderConfig into an Encoder instance""" - - -@dataclass -class VectorEncoderConfig(EncoderConfig): - """An MLP encoder mappings tensors with shape [..., feature] to [..., output]. - - Attributes: - activation: The type of activation function to use between hidden layers. 
- Options are 'relu', 'swish', 'tanh', or 'linear' - final_activation: The activation function to use after the final linear layer. - Options are the same as for activation. - hidden_layer_sizes: A list, where each element represents the number of neurons - in that layer. For example, [128, 64] would produce a two-layer MLP with - 128 hidden neurons and 64 hidden neurons. - output_key: Write the output of the encoder to this key in the NestedDict. - """ - - activation: str = "relu" - final_activation: str = "linear" - hidden_layer_sizes: Tuple[int, ...] = (128, 128) - output_key: str = "encoding" - - def build(self, input_spec: SpecDict) -> TorchVectorEncoder: - """Build the config into a VectorEncoder model instance. - - Args: - input_spec: The output spec of the previous module(s) that will feed - inputs to this encoder. - - Returns: - A VectorEncoder of the specified framework. - """ - assert ( - len(self.hidden_layer_sizes) > 1 - ), "Must have at least a single hidden layer" - for k in input_spec.shallow_keys(): - assert isinstance( - input_spec[k].shape[-1], int - ), "Input spec {k} does not define the size of the feature (last) dimension" - - if self.framework_str == "torch": - return TorchVectorEncoder(input_spec, self) - else: - raise NotImplementedError( - "{self.__class__.__name__} not implemented" - " for framework {self.framework}" - ) From ab2c301f52d3e31f7374215edd0ad1b3fdb2f11e Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 24 Jan 2023 15:57:07 -0800 Subject: [PATCH 35/51] remove unneeded model configs Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/model_configs.py | 81 ---------------------- 1 file changed, 81 deletions(-) delete mode 100644 rllib/models/experimental/model_configs.py diff --git a/rllib/models/experimental/model_configs.py b/rllib/models/experimental/model_configs.py deleted file mode 100644 index bb9175048a8a..000000000000 --- a/rllib/models/experimental/model_configs.py +++ /dev/null @@ -1,81 +0,0 @@ -from dataclasses import dataclass, field -from typing import List -import functools - -from ray.rllib.models.experimental.base import ModelConfig, Model - - -def _check_framework(fn): - @functools.wraps(fn) - def checked_build(self, framework, **kwargs): - if framework not in ("torch", "tf", "tf2"): - raise ValueError(f"Framework {framework} not supported.") - return fn(self, framework, **kwargs) - - return checked_build - - -@dataclass -class FCConfig(ModelConfig): - """Configuration for a fully connected network. - - Attributes: - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. 
- """ - - input_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - output_activation: str = "ReLU" - - @_check_framework - def build(self, framework: str = "torch") -> Model: - if framework == "torch": - from ray.rllib.models.experimental.torch.fcmodel import FCModel - else: - from ray.rllib.models.experimental.tf.fcmodel import FCModel - return FCModel(self) - - -@dataclass -class FCEncoderConfig(FCConfig): - def build(self, framework: str = "torch"): - if framework == "torch": - from ray.rllib.models.experimental.torch.encoder import FCEncoder - else: - from ray.rllib.models.experimental.tf.encoder import FCEncoder - return FCEncoder(self) - - -@dataclass -class LSTMEncoderConfig(ModelConfig): - input_dim: int = None - hidden_dim: int = None - num_layers: int = None - batch_first: bool = True - - @_check_framework - def build(self, framework: str = "torch"): - if not framework == "torch": - raise ValueError("Only torch framework supported.") - from rllib.models.experimental.torch.encoder import LSTMEncoder - - return LSTMEncoder(self) - - -@dataclass -class IdentityConfig(ModelConfig): - """Configuration for an identity encoder.""" - - @_check_framework - def build(self, framework: str = "torch"): - if framework == "torch": - from rllib.models.experimental.torch.encoder import IdentityEncoder - else: - from rllib.models.experimental.tf.encoder import IdentityEncoder - - return IdentityEncoder(self) From e7aa5282036c20893582cfac507147f6b5a4cea0 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 25 Jan 2023 21:57:11 -0800 Subject: [PATCH 36/51] sven's nits Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 6 +- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 30 ++++---- .../ppo/torch/ppo_torch_rl_module.py | 28 ++++---- rllib/models/experimental/base.py | 68 +++---------------- rllib/models/experimental/configs.py | 38 +++++++---- rllib/models/experimental/encoder.py | 60 ++++++++++++++++ rllib/models/experimental/tf/encoder.py | 28 ++++---- rllib/models/experimental/tf/fcmodel.py | 2 +- rllib/models/experimental/tf/primitives.py | 8 +-- rllib/models/experimental/torch/encoder.py | 21 +++--- rllib/models/experimental/torch/fcmodel.py | 6 +- rllib/models/experimental/torch/primitives.py | 6 +- rllib/models/specs/checker.py | 5 +- 13 files changed, 164 insertions(+), 142 deletions(-) create mode 100644 rllib/models/experimental/encoder.py diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index ad12eaf4d1a2..64bbf84e6d82 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -48,7 +48,7 @@ def get_expected_model_config( obs_dim = env.observation_space.shape[0] if lstm: - shared_encoder_config = LSTMEncoderConfig( + encoder_config = LSTMEncoderConfig( input_dim=obs_dim, hidden_dim=32, batch_first=True, @@ -56,7 +56,7 @@ def get_expected_model_config( output_dim=32, ) else: - shared_encoder_config = FCEncoderConfig( + encoder_config = FCEncoderConfig( input_dim=obs_dim, hidden_layers=[32], activation="ReLU", @@ -82,7 +82,7 @@ def get_expected_model_config( return PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, - shared_encoder_config=shared_encoder_config, + encoder_config=encoder_config, pi_config=pi_config, vf_config=vf_config, ) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py 
b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 4b1f3c153e87..2b96b1d1c4a6 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -9,7 +9,7 @@ from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian -from ray.rllib.models.experimental.tf.primitives import FCNet +from ray.rllib.models.experimental.tf.primitives import TfFCNet from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig @@ -26,17 +26,17 @@ def __init__(self, config: RLModuleConfig): def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." - self.shared_encoder = self.config.shared_encoder_config.build(framework="tf") + self.encoder = self.config.encoder_config.build(framework="tf") - self.pi = FCNet( - input_dim=self.config.shared_encoder_config.output_dim, + self.pi = TfFCNet( + input_dim=self.config.encoder_config.output_dim, output_dim=self.config.pi_config.output_dim, hidden_layers=self.config.pi_config.hidden_layers, activation=self.config.pi_config.activation, ) - self.vf = FCNet( - input_dim=self.config.shared_encoder_config.output_dim, + self.vf = TfFCNet( + input_dim=self.config.encoder_config.output_dim, output_dim=1, hidden_layers=self.config.vf_config.hidden_layers, activation=self.config.vf_config.activation, @@ -60,7 +60,7 @@ def output_specs_train(self) -> List[str]: @override(TfRLModule) def _forward_train(self, batch: NestedDict): - encoder_out = self.shared_encoder(batch) + encoder_out = self.encoder(batch) action_logits = self.pi(encoder_out[ENCODER_OUT]) vf = self.vf(encoder_out[ENCODER_OUT]) @@ -87,7 +87,7 @@ def output_specs_inference(self) -> List[str]: @override(TfRLModule) def _forward_inference(self, batch) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) + encoder_out = self.encoder(batch) action_logits = self.pi(encoder_out[ENCODER_OUT]) @@ -116,7 +116,7 @@ def output_specs_exploration(self) -> List[str]: @override(TfRLModule) def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) + encoder_out = self.encoder(batch) action_logits = self.pi(encoder_out[ENCODER_OUT]) vf = self.vf(encoder_out[ENCODER_OUT]) @@ -160,14 +160,14 @@ def from_model_config( if use_lstm: raise ValueError("LSTM not supported by PPOTfRLModule yet.") if vf_share_layers: - shared_encoder_config = FCConfig( + encoder_config = FCConfig( input_dim=obs_dim, hidden_layers=fcnet_hiddens, activation=activation, output_dim=model_config["fcnet_hiddens"][-1], ) else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) + encoder_config = IdentityConfig(output_dim=obs_dim) assert isinstance( observation_space, gym.spaces.Box ), "This simple PPOModule only supports Box observation space." 
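# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the diff above): how the renamed
# `encoder_config` field and the framework-aware `build()` adopted in this hunk
# are meant to be used. The environment name and layer sizes are example values
# mirroring the test helper `get_expected_model_config`; treat class names
# returned by `build()` as assumptions based on the configs in this PR.
import gymnasium as gym

from ray.rllib.models.experimental.configs import FCEncoderConfig

env = gym.make("CartPole-v1")
encoder_config = FCEncoderConfig(
    input_dim=env.observation_space.shape[0],
    hidden_layers=[32],
    activation="ReLU",
    output_dim=32,
)
# The same config builds an encoder for either framework:
tf_encoder = encoder_config.build(framework="tf")
torch_encoder = encoder_config.build(framework="torch")
# ---------------------------------------------------------------------------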
@@ -181,21 +181,21 @@ def from_model_config( ) pi_config = FCConfig() vf_config = FCConfig() - shared_encoder_config.input_dim = observation_space.shape[0] - pi_config.input_dim = shared_encoder_config.output_dim + encoder_config.input_dim = observation_space.shape[0] + pi_config.input_dim = encoder_config.output_dim pi_config.hidden_layers = fcnet_hiddens if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 # build vf network - vf_config.input_dim = shared_encoder_config.output_dim + vf_config.input_dim = encoder_config.output_dim vf_config.hidden_layers = fcnet_hiddens vf_config.output_dim = 1 config_ = PPOModuleConfig( pi_config=pi_config, vf_config=vf_config, - shared_encoder_config=shared_encoder_config, + encoder_config=encoder_config, observation_space=observation_space, action_space=action_space, ) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 2ced64f5a3a1..b589c34c56b2 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -5,7 +5,7 @@ from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig from ray.rllib.core.rl_module.torch import TorchRLModule -from ray.rllib.models.experimental.base import STATE_OUT +from ray.rllib.models.experimental.encoder import STATE_OUT from ray.rllib.models.experimental.configs import FCConfig, FCEncoderConfig from ray.rllib.models.experimental.configs import ( LSTMEncoderConfig, @@ -51,7 +51,7 @@ class PPOModuleConfig(RLModuleConfig): Attributes: observation_space: The observation space of the environment. action_space: The action space of the environment. - shared_encoder_config: The configuration for the encoder network. + encoder_config: The configuration for the encoder network. pi_config: The configuration for the policy network. vf_config: The configuration for the value network. free_log_std: For DiagGaussian action distributions, make the second half of @@ -59,9 +59,7 @@ class PPOModuleConfig(RLModuleConfig): only has an effect is using the default fully connected net. """ - observation_space: gym.Space = None - action_space: gym.Space = None - shared_encoder_config: FCConfig = None + encoder_config: FCConfig = None pi_config: FCConfig = None vf_config: FCConfig = None free_log_std: bool = False @@ -76,12 +74,10 @@ def __init__(self, config: PPOModuleConfig) -> None: def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." - assert self.config.shared_encoder_config, ( - "shared encoder config must be " "provided." - ) + assert self.config.encoder_config, "shared encoder config must be " "provided." # TODO(Artur): Unify to tf and torch setup(framework) - self.shared_encoder = self.config.shared_encoder_config.build(framework="torch") + self.shared_encoder = self.config.encoder_config.build(framework="torch") self.pi = self.config.pi_config.build(framework="torch") self.vf = self.config.vf_config.build(framework="torch") @@ -120,7 +116,7 @@ def from_model_config( ), "`vf_share_layers=False` is no longer supported." 
if model_config["use_lstm"]: - shared_encoder_config = LSTMEncoderConfig( + encoder_config = LSTMEncoderConfig( input_dim=obs_dim, hidden_dim=model_config["lstm_cell_size"], batch_first=not model_config["_time_major"], @@ -128,7 +124,7 @@ def from_model_config( output_dim=model_config["lstm_cell_size"], ) else: - shared_encoder_config = FCEncoderConfig( + encoder_config = FCEncoderConfig( input_dim=obs_dim, hidden_layers=fcnet_hiddens[:-1], activation=activation, @@ -136,12 +132,12 @@ def from_model_config( ) pi_config = FCConfig( - input_dim=shared_encoder_config.output_dim, + input_dim=encoder_config.output_dim, hidden_layers=[32], activation="ReLU", ) vf_config = FCConfig( - input_dim=shared_encoder_config.output_dim, + input_dim=encoder_config.output_dim, hidden_layers=[32, 1], activation="ReLU", output_dim=1, @@ -160,8 +156,8 @@ def from_model_config( ) # build policy network head - shared_encoder_config.input_dim = observation_space.shape[0] - pi_config.input_dim = shared_encoder_config.output_dim + encoder_config.input_dim = observation_space.shape[0] + pi_config.input_dim = encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: @@ -170,7 +166,7 @@ def from_model_config( config_ = PPOModuleConfig( observation_space=observation_space, action_space=action_space, - shared_encoder_config=shared_encoder_config, + encoder_config=encoder_config, pi_config=pi_config, vf_config=vf_config, free_log_std=free_log_std, diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index 06b37f1adb9c..79e1eac67824 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -7,12 +7,9 @@ ) from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.utils.typing import TensorType from ray.rllib.utils.annotations import ExperimentalAPI ForwardOutputType = TensorDict -STATE_IN: str = "state_in" -STATE_OUT: str = "state_out" @ExperimentalAPI @@ -21,8 +18,8 @@ class ModelConfig(abc.ABC): """Configuration for a model. Attributes: - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. + output_dim: The output dimension of the network. If None, the output_dim will + be the number of nodes in the last hidden layer. """ output_dim: int = None @@ -38,7 +35,13 @@ def build(self, framework: str = "torch"): class Model: - """Base class for RLlib models.""" + """Framework-agnostic base class for RLlib models. + + Models are low-level neural network components that offer input- and + output-specification, a forward method, and a get_initial_state method. They are + therefore not algorithm-specific. Models are composed in RLModules, where tensors + are passed through them. + """ def __init__(self, config: ModelConfig): self.config = config @@ -79,63 +82,14 @@ def input_spec(self) -> SpecDict: def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: """Computes the output of this module for each timestep. - Outputs and inputs should be subjected to spec checking. - - Args: - inputs: A TensorDict containing model inputs - kwargs: For forwards compatibility - - Returns: - outputs: A TensorDict containing model outputs - - Examples: - # This is abstract, see the torch/tf/jax implementations - >>> out = model(TensorDict({"in": np.arange(10)})) - >>> out # TensorDict(...) 
- """ - raise NotImplementedError - - -class Encoder(Model): - """The base class for all encoders RLlib produces. - - Encoders are used to encode observations into a latent space in RLModules. - Therefore, their input_spec usually contains the observation space dimensions. - Their output_spec usually contains the latent space dimensions. - Encoders can be recurrent, in which case they should also have state_specs. - """ - - def __init__(self, config: dict): - super().__init__(config) - - def get_initial_state(self) -> TensorType: - """Returns the initial state of the encoder. - - This is the initial state of the encoder. - It can be left empty if this encoder is not stateful. - - Examples: - >>> ... - """ - return {} - - @check_input_specs("input_spec", cache=True) - @check_output_specs("output_spec", cache=True) - @abc.abstractmethod - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - """Computes the output of this module for each timestep. - - Outputs and inputs are subjected to spec checking. + Outputs and inputs should be subject to spec checking. Args: inputs: A TensorDict containing model inputs kwargs: For forwards compatibility - Returns: - outputs: A TensorDict containing model outputs - Examples: - # This is abstract, see the torch/tf/jax implementations + # This is abstract, see the torch/tf2/jax implementations >>> out = model(TensorDict({"in": np.arange(10)})) >>> out # TensorDict(...) """ diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index 2142f6ef0cfc..4f59ed964351 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -7,12 +7,12 @@ @DeveloperAPI -def _framework_implemented(torch: bool = True, tf: bool = True): +def _framework_implemented(torch: bool = True, tf2: bool = True): """Decorator to check if a model was implemented in a framework. Args: torch: Whether we can build this model with torch. - tf: Whether we can build this model with tf. + tf2: Whether we can build this model with tf. Returns: The decorated function. 
@@ -23,7 +23,7 @@ def _framework_implemented(torch: bool = True, tf: bool = True): accepted = [] if torch: accepted.append("torch") - if tf: + if tf2: accepted.append("tf") accepted.append("tf2") @@ -59,10 +59,13 @@ class FCConfig(ModelConfig): @_framework_implemented() def build(self, framework: str = "torch") -> Model: if framework == "torch": - from ray.rllib.models.experimental.torch.fcmodel import FCModel + from ray.rllib.models.experimental.torch.fcmodel import TorchFCModel + + return TorchFCModel(self) else: - from ray.rllib.models.experimental.tf.fcmodel import FCModel - return FCModel(self) + from ray.rllib.models.experimental.tf.fcmodel import TfFCModel + + return TfFCModel(self) @dataclass @@ -70,10 +73,13 @@ class FCEncoderConfig(FCConfig): @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": - from ray.rllib.models.experimental.torch.encoder import FCEncoder + from ray.rllib.models.experimental.torch.encoder import TorchFCEncoder + + return TorchFCEncoder(self) else: - from ray.rllib.models.experimental.tf.encoder import FCEncoder - return FCEncoder(self) + from ray.rllib.models.experimental.tf.encoder import TfFCEncoder + + return TfFCEncoder(self) @dataclass @@ -83,12 +89,12 @@ class LSTMEncoderConfig(ModelConfig): num_layers: int = None batch_first: bool = True - @_framework_implemented(tf=False) + @_framework_implemented(tf2=False) def build(self, framework: str = "torch"): if framework == "torch": - from ray.rllib.models.experimental.torch.encoder import LSTMEncoder + from ray.rllib.models.experimental.torch.encoder import TorchLSTMEncoder - return LSTMEncoder(self) + return TorchLSTMEncoder(self) @dataclass @@ -98,8 +104,10 @@ class IdentityConfig(ModelConfig): @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": - from ray.rllib.models.experimental.torch.encoder import IdentityEncoder + from ray.rllib.models.experimental.torch.encoder import TorchIdentityEncoder + + return TorchIdentityEncoder(self) else: - from ray.rllib.models.experimental.tf.encoder import IdentityEncoder + from ray.rllib.models.experimental.tf.encoder import TfIdentityEncoder - return IdentityEncoder(self) + return TfIdentityEncoder(self) diff --git a/rllib/models/experimental/encoder.py b/rllib/models/experimental/encoder.py new file mode 100644 index 000000000000..d9506bb2423b --- /dev/null +++ b/rllib/models/experimental/encoder.py @@ -0,0 +1,60 @@ +import abc + +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.utils.typing import TensorType +from rllib.models.experimental.base import Model, ForwardOutputType + +STATE_IN: str = "state_in" +STATE_OUT: str = "state_out" + + +class Encoder(Model): + """The framework-agnostic base class for all encoders RLlib produces. + + Encoders are used to encode observations into a latent space in RLModules. + Therefore, their input_spec usually contains the observation space dimensions. + Their output_spec usually contains the latent space dimensions. + Encoders can be recurrent, in which case they should also have state_specs. + + Encoders encode observations into a latent space that serve as input to heads. + Outputs of encoders are generally of shape (B, latent_dim) or (B, T, latent_dim). + That is, for time-series data, we encode into the latent space for each time step. + This should be reflected in the output_spec. 
+ """ + + def __init__(self, config: dict): + super().__init__(config) + + def get_initial_state(self) -> TensorType: + """Returns the initial state of the encoder. + + This is the initial state of the encoder. + It can be left empty if this encoder is not stateful. + + Examples: + >>> ... + """ + return {} + + @check_input_specs("input_spec", cache=True) + @check_output_specs("output_spec", cache=True) + @abc.abstractmethod + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + """Computes the output of this module for each timestep. + + Outputs and inputs are subjected to spec checking. + + Args: + inputs: A TensorDict containing model inputs + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict containing model outputs + + Examples: + # This is abstract, see the torch/tf2/jax implementations + >>> out = model(TensorDict({"in": np.arange(10)})) + >>> out # TensorDict(...) + """ + raise NotImplementedError diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index fb3939421800..dd8dbbe7c563 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -3,31 +3,33 @@ import tree from ray.rllib.models.experimental.base import ( + ForwardOutputType, + ModelConfig, +) +from ray.rllib.models.experimental.encoder import ( Encoder, STATE_IN, STATE_OUT, - ForwardOutputType, - ModelConfig, ) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.experimental.tf.primitives import FCNet +from ray.rllib.models.experimental.tf.primitives import TfFCNet from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_tf import TFTensorSpecs from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT -from ray.rllib.models.experimental.tf.primitives import TFModel +from ray.rllib.models.experimental.tf.primitives import TfFCModel -class FCEncoder(Encoder, TFModel): +class TfFCEncoder(Encoder, TfFCModel): """A fully connected encoder.""" def __init__(self, config: ModelConfig) -> None: Encoder.__init__(self, config) - TFModel.__init__(self, config) + TfFCModel.__init__(self, config) - self.net = FCNet( + self.net = TfFCNet( input_dim=config.input_dim, hidden_layers=config.hidden_layers, output_dim=config.output_dim, @@ -50,12 +52,12 @@ def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} -class LSTMEncoder(Encoder, TFModel): +class LSTMEncoder(Encoder, TfFCModel): """An encoder that uses an LSTM cell and a linear layer.""" def __init__(self, config: ModelConfig) -> None: Encoder.__init__(self, config) - TFModel.__init__(self, config) + TfFCModel.__init__(self, config) self.lstm = nn.LSTM( config.input_dim, @@ -134,9 +136,11 @@ def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: } -class IdentityEncoder(TFModel): - def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - pass +class TfIdentityEncoder(TfFCModel): + """An encoder that does nothing but passing on inputs. + + We use this so that we avoid having many if/else statements in the RLModule. 
+ """ def __init__(self, config: ModelConfig) -> None: super().__init__(config) diff --git a/rllib/models/experimental/tf/fcmodel.py b/rllib/models/experimental/tf/fcmodel.py index c40727512514..18577ec2eb7f 100644 --- a/rllib/models/experimental/tf/fcmodel.py +++ b/rllib/models/experimental/tf/fcmodel.py @@ -8,7 +8,7 @@ tf1, tf, tfv = try_import_tf() -class FCModel(tf.Module, TFModel): +class TfFCModel(tf.Module, TFModel): def __init__(self, config: ModelConfig) -> None: tf.Module.__init__(self) TFModel.__init__(self, config) diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py index 868a85ce01a8..361d92f4dd86 100644 --- a/rllib/models/experimental/tf/primitives.py +++ b/rllib/models/experimental/tf/primitives.py @@ -1,7 +1,7 @@ from typing import List from ray.rllib.utils.framework import try_import_tf from ray.rllib.models.specs.checker import ( - input_is_decorated, + is_input_decorated, is_output_decorated, ) from ray.rllib.models.temp_spec_classes import TensorDict @@ -20,7 +20,7 @@ def _call_not_decorated(input_or_output): ) -class TFModel(Model): +class TfFCModel(Model, tf.Module): """Base class for RLlib models. This class is used to define the general interface for RLlib models and checks @@ -30,7 +30,7 @@ class TFModel(Model): def __init__(self, config): super().__init__(config) - assert input_is_decorated(self.__call__), _call_not_decorated("input") + assert is_input_decorated(self.__call__), _call_not_decorated("input") assert is_output_decorated(self.__call__), _call_not_decorated("output") def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: @@ -45,7 +45,7 @@ def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType] raise NotImplementedError -class FCNet(tf.Module): +class TfFCNet(tf.Module): """A simple fully connected network. 
Attributes: diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py index bb7c46214e2f..801b53ad72a5 100644 --- a/rllib/models/experimental/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -3,11 +3,14 @@ import tree from ray.rllib.models.experimental.base import ( - STATE_IN, - STATE_OUT, ForwardOutputType, ModelConfig, ) +from ray.rllib.models.experimental.encoder import ( + Encoder, + STATE_IN, + STATE_OUT, +) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override @@ -15,20 +18,19 @@ from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.models.experimental.torch.primitives import FCNet, TorchModel -from ray.rllib.models.experimental.base import Encoder +from ray.rllib.models.experimental.torch.primitives import TorchFCNet, TorchModel ENCODER_OUT: str = "encoder_out" -class FCEncoder(TorchModel, Encoder): +class TorchFCEncoder(TorchModel, Encoder): """A fully connected encoder.""" def __init__(self, config: ModelConfig) -> None: TorchModel.__init__(self, config) Encoder.__init__(self, config) - self.net = FCNet( + self.net = TorchFCNet( input_dim=config.input_dim, hidden_layers=config.hidden_layers, output_dim=config.output_dim, @@ -55,7 +57,7 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} -class LSTMEncoder(TorchModel, Encoder): +class TorchLSTMEncoder(TorchModel, Encoder): """An encoder that uses an LSTM cell and a linear layer.""" def __init__(self, config: ModelConfig) -> None: @@ -140,10 +142,7 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: } -class IdentityEncoder(TorchModel): - def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - pass - +class TorchIdentityEncoder(TorchModel): def __init__(self, config: ModelConfig) -> None: super().__init__(config) diff --git a/rllib/models/experimental/torch/fcmodel.py b/rllib/models/experimental/torch/fcmodel.py index 71e0b8ea0e96..dc2117bac462 100644 --- a/rllib/models/experimental/torch/fcmodel.py +++ b/rllib/models/experimental/torch/fcmodel.py @@ -4,17 +4,17 @@ from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.models.experimental.torch.primitives import FCNet +from ray.rllib.models.experimental.torch.primitives import TorchFCNet from ray.rllib.models.experimental.torch.primitives import TorchModel from ray.rllib.utils.annotations import override -class FCModel(TorchModel, nn.Module): +class TorchFCModel(TorchModel, nn.Module): def __init__(self, config: ModelConfig) -> None: nn.Module.__init__(self) TorchModel.__init__(self, config) - self.net = FCNet( + self.net = TorchFCNet( input_dim=config.input_dim, hidden_layers=config.hidden_layers, output_dim=config.output_dim, diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index 97b2a9034017..395e602c2555 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -3,7 +3,7 @@ from ray.rllib.models.experimental.base import Model from ray.rllib.models.specs.checker import 
( - input_is_decorated, + is_input_decorated, is_output_decorated, ) from ray.rllib.models.temp_spec_classes import TensorDict @@ -33,7 +33,7 @@ class TorchModel(nn.Module, Model): def __init__(self, config: ModelConfig): nn.Module.__init__(self) Model.__init__(self, config) - assert input_is_decorated(self.forward), _forward_not_decorated("input") + assert is_input_decorated(self.forward), _forward_not_decorated("input") assert is_output_decorated(self.forward), _forward_not_decorated("output") def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: @@ -48,7 +48,7 @@ def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]] raise NotImplementedError -class FCNet(nn.Module): +class TorchFCNet(nn.Module): """A simple fully connected network. Attributes: diff --git a/rllib/models/specs/checker.py b/rllib/models/specs/checker.py index 2661d1750086..1662c58aa07d 100644 --- a/rllib/models/specs/checker.py +++ b/rllib/models/specs/checker.py @@ -338,12 +338,13 @@ def wrapper(self, input_data, **kwargs): return decorator -def input_is_decorated(obj: object) -> bool: +@DeveloperAPI +def is_input_decorated(obj: object) -> bool: """Returns True if the object is decorated with `check_input_specs`.""" return hasattr(obj, "__checked_input_specs__") -@DeveloperAPI(stability="alpha") +@DeveloperAPI def is_output_decorated(obj: object) -> bool: """Returns True if the object is decorated with `check_output_specs`.""" return hasattr(obj, "__checked_output_specs__") From 2d709d50174e9fa0dfc179879cb714a6f43211a6 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 25 Jan 2023 22:06:55 -0800 Subject: [PATCH 37/51] sven's nits Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 6 +++--- rllib/models/experimental/configs.py | 12 ++++++------ rllib/models/experimental/tf/encoder.py | 16 ++++++++-------- rllib/models/experimental/tf/fcmodel.py | 2 +- rllib/models/experimental/tf/primitives.py | 6 +++--- rllib/models/experimental/torch/encoder.py | 6 +++--- rllib/models/experimental/torch/fcmodel.py | 6 +++--- rllib/models/experimental/torch/primitives.py | 4 ++-- 8 files changed, 29 insertions(+), 29 deletions(-) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 2b96b1d1c4a6..e358384d4a5f 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -9,7 +9,7 @@ from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian -from ray.rllib.models.experimental.tf.primitives import TfFCNet +from ray.rllib.models.experimental.tf.primitives import TfMLP from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig @@ -28,14 +28,14 @@ def setup(self) -> None: assert self.config.vf_config, "vf_config must be provided." 
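For context on the `input_is_decorated` -> `is_input_decorated` rename a few hunks up: the helpers only test for the marker attributes that the spec-checking decorators attach, which is also what `TorchModel.__init__` asserts about `forward`. A minimal hypothetical check (the `DummyModel` class is made up for illustration):

    from ray.rllib.models.specs.checker import (
        check_input_specs,
        check_output_specs,
        is_input_decorated,
        is_output_decorated,
    )

    class DummyModel:
        @check_input_specs("input_spec", cache=True)
        @check_output_specs("output_spec", cache=True)
        def forward(self, inputs):
            return inputs

    # Both helpers just look for the __checked_input_specs__ /
    # __checked_output_specs__ attributes set by the decorators.
    assert is_input_decorated(DummyModel.forward)
    assert is_output_decorated(DummyModel.forward)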
self.encoder = self.config.encoder_config.build(framework="tf") - self.pi = TfFCNet( + self.pi = TfMLP( input_dim=self.config.encoder_config.output_dim, output_dim=self.config.pi_config.output_dim, hidden_layers=self.config.pi_config.hidden_layers, activation=self.config.pi_config.activation, ) - self.vf = TfFCNet( + self.vf = TfMLP( input_dim=self.config.encoder_config.output_dim, output_dim=1, hidden_layers=self.config.vf_config.hidden_layers, diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index 4f59ed964351..b6b17a644b57 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -59,13 +59,13 @@ class FCConfig(ModelConfig): @_framework_implemented() def build(self, framework: str = "torch") -> Model: if framework == "torch": - from ray.rllib.models.experimental.torch.fcmodel import TorchFCModel + from ray.rllib.models.experimental.torch.fcmodel import TorchMLPModel - return TorchFCModel(self) + return TorchMLPModel(self) else: - from ray.rllib.models.experimental.tf.fcmodel import TfFCModel + from ray.rllib.models.experimental.tf.fcmodel import TfMLPModel - return TfFCModel(self) + return TfMLPModel(self) @dataclass @@ -73,9 +73,9 @@ class FCEncoderConfig(FCConfig): @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": - from ray.rllib.models.experimental.torch.encoder import TorchFCEncoder + from ray.rllib.models.experimental.torch.encoder import TorchMLPEncoder - return TorchFCEncoder(self) + return TorchMLPEncoder(self) else: from ray.rllib.models.experimental.tf.encoder import TfFCEncoder diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index dd8dbbe7c563..a63014e24eba 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -13,23 +13,23 @@ ) from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.experimental.tf.primitives import TfFCNet +from ray.rllib.models.experimental.tf.primitives import TfMLP from ray.rllib.policy.rnn_sequencing import add_time_dimension from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_tf import TFTensorSpecs from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT -from ray.rllib.models.experimental.tf.primitives import TfFCModel +from ray.rllib.models.experimental.tf.primitives import TfMLPModel -class TfFCEncoder(Encoder, TfFCModel): +class TfFCEncoder(Encoder, TfMLPModel): """A fully connected encoder.""" def __init__(self, config: ModelConfig) -> None: Encoder.__init__(self, config) - TfFCModel.__init__(self, config) + TfMLPModel.__init__(self, config) - self.net = TfFCNet( + self.net = TfMLP( input_dim=config.input_dim, hidden_layers=config.hidden_layers, output_dim=config.output_dim, @@ -52,12 +52,12 @@ def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} -class LSTMEncoder(Encoder, TfFCModel): +class LSTMEncoder(Encoder, TfMLPModel): """An encoder that uses an LSTM cell and a linear layer.""" def __init__(self, config: ModelConfig) -> None: Encoder.__init__(self, config) - TfFCModel.__init__(self, config) + TfMLPModel.__init__(self, config) self.lstm = nn.LSTM( config.input_dim, @@ -136,7 +136,7 @@ def __call__(self, inputs: TensorDict, **kwargs) -> 
ForwardOutputType: } -class TfIdentityEncoder(TfFCModel): +class TfIdentityEncoder(TfMLPModel): """An encoder that does nothing but passing on inputs. We use this so that we avoid having many if/else statements in the RLModule. diff --git a/rllib/models/experimental/tf/fcmodel.py b/rllib/models/experimental/tf/fcmodel.py index 18577ec2eb7f..a85ffeb1ea50 100644 --- a/rllib/models/experimental/tf/fcmodel.py +++ b/rllib/models/experimental/tf/fcmodel.py @@ -8,7 +8,7 @@ tf1, tf, tfv = try_import_tf() -class TfFCModel(tf.Module, TFModel): +class TfMLPModel(tf.Module, TFModel): def __init__(self, config: ModelConfig) -> None: tf.Module.__init__(self) TFModel.__init__(self, config) diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py index 361d92f4dd86..2199cdf2ebe8 100644 --- a/rllib/models/experimental/tf/primitives.py +++ b/rllib/models/experimental/tf/primitives.py @@ -20,7 +20,7 @@ def _call_not_decorated(input_or_output): ) -class TfFCModel(Model, tf.Module): +class TfMLPModel(Model, tf.Module): """Base class for RLlib models. This class is used to define the general interface for RLlib models and checks @@ -45,8 +45,8 @@ def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType] raise NotImplementedError -class TfFCNet(tf.Module): - """A simple fully connected network. +class TfMLP(tf.Module): + """A multi-layer perceptron. Attributes: input_dim: The input dimension of the network. It cannot be None. diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py index 801b53ad72a5..aa3bd071fe83 100644 --- a/rllib/models/experimental/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -18,19 +18,19 @@ from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.models.experimental.torch.primitives import TorchFCNet, TorchModel +from ray.rllib.models.experimental.torch.primitives import TorchMLP, TorchModel ENCODER_OUT: str = "encoder_out" -class TorchFCEncoder(TorchModel, Encoder): +class TorchMLPEncoder(TorchModel, Encoder): """A fully connected encoder.""" def __init__(self, config: ModelConfig) -> None: TorchModel.__init__(self, config) Encoder.__init__(self, config) - self.net = TorchFCNet( + self.net = TorchMLP( input_dim=config.input_dim, hidden_layers=config.hidden_layers, output_dim=config.output_dim, diff --git a/rllib/models/experimental/torch/fcmodel.py b/rllib/models/experimental/torch/fcmodel.py index dc2117bac462..17d7b3b32822 100644 --- a/rllib/models/experimental/torch/fcmodel.py +++ b/rllib/models/experimental/torch/fcmodel.py @@ -4,17 +4,17 @@ from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.models.experimental.torch.primitives import TorchFCNet +from ray.rllib.models.experimental.torch.primitives import TorchMLP from ray.rllib.models.experimental.torch.primitives import TorchModel from ray.rllib.utils.annotations import override -class TorchFCModel(TorchModel, nn.Module): +class TorchMLPModel(TorchModel, nn.Module): def __init__(self, config: ModelConfig) -> None: nn.Module.__init__(self) TorchModel.__init__(self, config) - self.net = TorchFCNet( + self.net = TorchMLP( input_dim=config.input_dim, hidden_layers=config.hidden_layers, 
output_dim=config.output_dim, diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index 395e602c2555..b15dfea7512f 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -48,8 +48,8 @@ def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]] raise NotImplementedError -class TorchFCNet(nn.Module): - """A simple fully connected network. +class TorchMLP(nn.Module): + """A multi-layer perceptron. Attributes: input_dim: The input dimension of the network. It cannot be None. From 871fe75ed878f84325d3606a2971f83aa4a7e494 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 25 Jan 2023 22:20:25 -0800 Subject: [PATCH 38/51] more refactors Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 6 +++--- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 8 ++++---- .../algorithms/ppo/torch/ppo_torch_rl_module.py | 12 ++++++------ rllib/core/rl_module/encoder.py | 4 ++-- rllib/models/experimental/configs.py | 6 +++--- rllib/models/experimental/encoder.py | 16 ++++++++-------- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 64bbf84e6d82..dc600e0fae9e 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -17,7 +17,7 @@ PPOTorchRLModule, ) from ray.rllib.models.experimental.configs import ( - FCConfig, + MLPConfig, FCEncoderConfig, LSTMEncoderConfig, ) @@ -63,12 +63,12 @@ def get_expected_model_config( output_dim=32, ) - pi_config = FCConfig( + pi_config = MLPConfig( input_dim=32, hidden_layers=[32], activation="ReLU", ) - vf_config = FCConfig( + vf_config = MLPConfig( input_dim=32, hidden_layers=[32, 1], activation="ReLU", diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index e358384d4a5f..1c4c8cba8a5a 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -3,7 +3,7 @@ from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.models.experimental.configs import FCConfig, IdentityConfig +from ray.rllib.models.experimental.configs import MLPConfig, IdentityConfig from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space @@ -160,7 +160,7 @@ def from_model_config( if use_lstm: raise ValueError("LSTM not supported by PPOTfRLModule yet.") if vf_share_layers: - encoder_config = FCConfig( + encoder_config = MLPConfig( input_dim=obs_dim, hidden_layers=fcnet_hiddens, activation=activation, @@ -179,8 +179,8 @@ def from_model_config( assert isinstance(action_space, (gym.spaces.Discrete, gym.spaces.Box)), ( "This simple PPOModule only supports Discrete and Box action space.", ) - pi_config = FCConfig() - vf_config = FCConfig() + pi_config = MLPConfig() + vf_config = MLPConfig() encoder_config.input_dim = observation_space.shape[0] pi_config.input_dim = encoder_config.output_dim pi_config.hidden_layers = fcnet_hiddens diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b589c34c56b2..9750036e4519 100644 --- 
a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -6,7 +6,7 @@ from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.experimental.encoder import STATE_OUT -from ray.rllib.models.experimental.configs import FCConfig, FCEncoderConfig +from ray.rllib.models.experimental.configs import MLPConfig, FCEncoderConfig from ray.rllib.models.experimental.configs import ( LSTMEncoderConfig, ) @@ -59,9 +59,9 @@ class PPOModuleConfig(RLModuleConfig): only has an effect is using the default fully connected net. """ - encoder_config: FCConfig = None - pi_config: FCConfig = None - vf_config: FCConfig = None + encoder_config: MLPConfig = None + pi_config: MLPConfig = None + vf_config: MLPConfig = None free_log_std: bool = False @@ -131,12 +131,12 @@ def from_model_config( output_dim=fcnet_hiddens[-1], ) - pi_config = FCConfig( + pi_config = MLPConfig( input_dim=encoder_config.output_dim, hidden_layers=[32], activation="ReLU", ) - vf_config = FCConfig( + vf_config = MLPConfig( input_dim=encoder_config.output_dim, hidden_layers=[32, 1], activation="ReLU", diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py index f3bb22b46900..e88bcfdce1e3 100644 --- a/rllib/core/rl_module/encoder.py +++ b/rllib/core/rl_module/encoder.py @@ -40,7 +40,7 @@ def build(self): @dataclass -class FCConfig(EncoderConfig): +class MLPConfig(EncoderConfig): """Configuration for a fully connected network. input_dim: The input dimension of the network. It cannot be None. hidden_layers: The sizes of the hidden layers. @@ -94,7 +94,7 @@ def _forward(self, input_dict): class FullyConnectedEncoder(Encoder): - def __init__(self, config: FCConfig) -> None: + def __init__(self, config: MLPConfig) -> None: super().__init__(config) self.net = FCNet( diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index b6b17a644b57..bb2b975ef77b 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -12,7 +12,7 @@ def _framework_implemented(torch: bool = True, tf2: bool = True): Args: torch: Whether we can build this model with torch. - tf2: Whether we can build this model with tf. + tf2: Whether we can build this model with tf2. Returns: The decorated function. @@ -40,7 +40,7 @@ def checked_build(self, framework, **kwargs): @dataclass -class FCConfig(ModelConfig): +class MLPConfig(ModelConfig): """Configuration for a fully connected network. Attributes: @@ -69,7 +69,7 @@ def build(self, framework: str = "torch") -> Model: @dataclass -class FCEncoderConfig(FCConfig): +class FCEncoderConfig(MLPConfig): @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": diff --git a/rllib/models/experimental/encoder.py b/rllib/models/experimental/encoder.py index d9506bb2423b..30b4852a9a6a 100644 --- a/rllib/models/experimental/encoder.py +++ b/rllib/models/experimental/encoder.py @@ -3,7 +3,7 @@ from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.utils.typing import TensorType -from rllib.models.experimental.base import Model, ForwardOutputType +from ray.rllib.models.experimental.base import Model, ForwardOutputType STATE_IN: str = "state_in" STATE_OUT: str = "state_out" @@ -13,8 +13,8 @@ class Encoder(Model): """The framework-agnostic base class for all encoders RLlib produces. 
Encoders are used to encode observations into a latent space in RLModules. - Therefore, their input_spec usually contains the observation space dimensions. - Their output_spec usually contains the latent space dimensions. + Therefore, their input_spec contains the observation space dimensions. + Similarly, their output_spec usually the latent space dimensions. Encoders can be recurrent, in which case they should also have state_specs. Encoders encode observations into a latent space that serve as input to heads. @@ -29,11 +29,12 @@ def __init__(self, config: dict): def get_initial_state(self) -> TensorType: """Returns the initial state of the encoder. - This is the initial state of the encoder. It can be left empty if this encoder is not stateful. Examples: - >>> ... + >>> encoder = Encoder(...) + >>> state = encoder.get_initial_state() + >>> out = encoder.forward({"obs": ..., STATE_IN: state}) """ return {} @@ -53,8 +54,7 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: outputs: A TensorDict containing model outputs Examples: - # This is abstract, see the torch/tf2/jax implementations - >>> out = model(TensorDict({"in": np.arange(10)})) - >>> out # TensorDict(...) + # This is abstract, see the framework implementations + >>> out = encoder.forward({"obs": np.arange(10)})) """ raise NotImplementedError From 0bafdd7edfc7bd5ed4f8a043d6efdec0f3b95f84 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Wed, 25 Jan 2023 22:25:24 -0800 Subject: [PATCH 39/51] another nit Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tests/test_ppo_rl_module.py | 4 ++-- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 4 ++-- rllib/models/experimental/configs.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index dc600e0fae9e..405b414adeef 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -18,7 +18,7 @@ ) from ray.rllib.models.experimental.configs import ( MLPConfig, - FCEncoderConfig, + MLPEncoderConfig, LSTMEncoderConfig, ) from ray.rllib.models.experimental.torch.encoder import ( @@ -56,7 +56,7 @@ def get_expected_model_config( output_dim=32, ) else: - encoder_config = FCEncoderConfig( + encoder_config = MLPEncoderConfig( input_dim=obs_dim, hidden_layers=[32], activation="ReLU", diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 9750036e4519..8936c4c51808 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -6,7 +6,7 @@ from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.models.experimental.encoder import STATE_OUT -from ray.rllib.models.experimental.configs import MLPConfig, FCEncoderConfig +from ray.rllib.models.experimental.configs import MLPConfig, MLPEncoderConfig from ray.rllib.models.experimental.configs import ( LSTMEncoderConfig, ) @@ -124,7 +124,7 @@ def from_model_config( output_dim=model_config["lstm_cell_size"], ) else: - encoder_config = FCEncoderConfig( + encoder_config = MLPEncoderConfig( input_dim=obs_dim, hidden_layers=fcnet_hiddens[:-1], activation=activation, diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index bb2b975ef77b..2b4a3479f49a 100644 --- 
a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -69,7 +69,7 @@ def build(self, framework: str = "torch") -> Model: @dataclass -class FCEncoderConfig(MLPConfig): +class MLPEncoderConfig(MLPConfig): @_framework_implemented() def build(self, framework: str = "torch"): if framework == "torch": From 69d96559e6e3e0afd048c8be23e6f2e087f6e443 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 26 Jan 2023 00:49:17 -0800 Subject: [PATCH 40/51] some more renaming Signed-off-by: Artur Niederfahrenhorst --- .../ppo/tests/test_ppo_rl_module.py | 12 ++--- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 16 +++--- .../ppo/torch/ppo_torch_rl_module.py | 12 ++--- rllib/models/experimental/configs.py | 13 ++--- rllib/models/experimental/tf/encoder.py | 4 +- rllib/models/experimental/tf/fcmodel.py | 4 +- rllib/models/experimental/tf/primitives.py | 31 ++++++++---- rllib/models/experimental/torch/encoder.py | 4 +- rllib/models/experimental/torch/fcmodel.py | 4 +- rllib/models/experimental/torch/primitives.py | 50 ++++++++++--------- 10 files changed, 83 insertions(+), 67 deletions(-) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 405b414adeef..18b917dd86e4 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -58,20 +58,20 @@ def get_expected_model_config( else: encoder_config = MLPEncoderConfig( input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", + hidden_layer_dims=[32], + hidden_layer_activation="ReLU", output_dim=32, ) pi_config = MLPConfig( input_dim=32, - hidden_layers=[32], - activation="ReLU", + hidden_layer_dims=[32], + hidden_layer_activation="ReLU", ) vf_config = MLPConfig( input_dim=32, - hidden_layers=[32, 1], - activation="ReLU", + hidden_layer_dims=[32, 1], + hidden_layer_activation="ReLU", ) if isinstance(env.action_space, gym.spaces.Discrete): diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 1c4c8cba8a5a..b9f550ecd243 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -31,15 +31,15 @@ def setup(self) -> None: self.pi = TfMLP( input_dim=self.config.encoder_config.output_dim, output_dim=self.config.pi_config.output_dim, - hidden_layers=self.config.pi_config.hidden_layers, - activation=self.config.pi_config.activation, + hidden_layer_dims=self.config.pi_config.hidden_layer_dims, + hidden_layer_activation=self.config.pi_config.hidden_layer_activation, ) self.vf = TfMLP( input_dim=self.config.encoder_config.output_dim, output_dim=1, - hidden_layers=self.config.vf_config.hidden_layers, - activation=self.config.vf_config.activation, + hidden_layer_dims=self.config.vf_config.hidden_layer_dims, + hidden_layer_activation=self.config.vf_config.hidden_layer_activation, ) self._is_discrete = isinstance( @@ -162,8 +162,8 @@ def from_model_config( if vf_share_layers: encoder_config = MLPConfig( input_dim=obs_dim, - hidden_layers=fcnet_hiddens, - activation=activation, + hidden_layer_dims=fcnet_hiddens, + hidden_layer_activation=activation, output_dim=model_config["fcnet_hiddens"][-1], ) else: @@ -183,14 +183,14 @@ def from_model_config( vf_config = MLPConfig() encoder_config.input_dim = observation_space.shape[0] pi_config.input_dim = encoder_config.output_dim - pi_config.hidden_layers = fcnet_hiddens + pi_config.hidden_layer_dims = fcnet_hiddens if isinstance(action_space, gym.spaces.Discrete): 
pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 # build vf network vf_config.input_dim = encoder_config.output_dim - vf_config.hidden_layers = fcnet_hiddens + vf_config.hidden_layer_dims = fcnet_hiddens vf_config.output_dim = 1 config_ = PPOModuleConfig( pi_config=pi_config, diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 8936c4c51808..8d07c14a7f73 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -126,20 +126,20 @@ def from_model_config( else: encoder_config = MLPEncoderConfig( input_dim=obs_dim, - hidden_layers=fcnet_hiddens[:-1], - activation=activation, + hidden_layer_dims=fcnet_hiddens[:-1], + hidden_layer_activation=activation, output_dim=fcnet_hiddens[-1], ) pi_config = MLPConfig( input_dim=encoder_config.output_dim, - hidden_layers=[32], - activation="ReLU", + hidden_layer_dims=[32], + hidden_layer_activation="ReLU", ) vf_config = MLPConfig( input_dim=encoder_config.output_dim, - hidden_layers=[32, 1], - activation="ReLU", + hidden_layer_dims=[32, 1], + hidden_layer_activation="ReLU", output_dim=1, ) diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index 2b4a3479f49a..41f7a74edc3f 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -45,16 +45,16 @@ class MLPConfig(ModelConfig): Attributes: input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). + hidden_layer_dims: The sizes of the hidden layers. + hidden_layer_activation: The activation function to use after each layer ( + except for the output). output_activation: The activation function to use for the output layer. 
""" input_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - output_activation: str = "ReLU" + hidden_layer_dims: List[int] = field(default_factory=lambda: [256, 256]) + hidden_layer_activation: str = "ReLU" + output_activation: str = "linear" @_framework_implemented() def build(self, framework: str = "torch") -> Model: @@ -88,6 +88,7 @@ class LSTMEncoderConfig(ModelConfig): hidden_dim: int = None num_layers: int = None batch_first: bool = True + output_activation: str = "linear" @_framework_implemented(tf2=False) def build(self, framework: str = "torch"): diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index a63014e24eba..9bbd3168d6f6 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -31,9 +31,9 @@ def __init__(self, config: ModelConfig) -> None: self.net = TfMLP( input_dim=config.input_dim, - hidden_layers=config.hidden_layers, + hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, - activation=config.activation, + hidden_layer_activation=config.hidden_layer_activation, ) @property diff --git a/rllib/models/experimental/tf/fcmodel.py b/rllib/models/experimental/tf/fcmodel.py index a85ffeb1ea50..86b2423a84ec 100644 --- a/rllib/models/experimental/tf/fcmodel.py +++ b/rllib/models/experimental/tf/fcmodel.py @@ -15,9 +15,9 @@ def __init__(self, config: ModelConfig) -> None: self.net = FCNet( input_dim=config.input_dim, - hidden_layers=config.hidden_layers, + hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, - activation=config.activation, + hidden_layer_activation=config.hidden_layer_activation, ) @property diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py index 2199cdf2ebe8..347d103a890f 100644 --- a/rllib/models/experimental/tf/primitives.py +++ b/rllib/models/experimental/tf/primitives.py @@ -7,6 +7,7 @@ from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.models.experimental.base import Model from ray.rllib.utils.typing import TensorType +from ray.rllib.models.utils import get_activation_fn from typing import Tuple _, tf, _ = try_import_tf() @@ -50,35 +51,47 @@ class TfMLP(tf.Module): Attributes: input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. + hidden_layer_dims: The sizes of the hidden layers. output_dim: The output dimension of the network. - activation: The activation function to use after each layer. + hidden_layer_activation: The activation function to use after each layer. Currently "Linear" (no activation) and "ReLU" are supported. + output_activation: The activation function to use for the output layer. 
""" def __init__( self, input_dim: int, - hidden_layers: List[int], + hidden_layer_dims: List[int], output_dim: int, - activation: str = "linear", + hidden_layer_activation: str = "linear", + output_activation: str = "linear", ): super().__init__() - assert activation in ("linear", "ReLU", "Tanh"), ( + assert hidden_layer_activation in ("linear", "ReLU", "Tanh"), ( "Activation function not " "supported" ) assert input_dim is not None, "Input dimension must not be None" assert output_dim is not None, "Output dimension must not be None" layers = [] - activation = activation.lower() + hidden_layer_activation = hidden_layer_activation.lower() # input = tf.keras.layers.Dense(input_dim, activation=activation) layers.append(tf.keras.Input(shape=(input_dim,))) - for i in range(len(hidden_layers)): + for i in range(len(hidden_layer_dims)): layers.append( - tf.keras.layers.Dense(hidden_layers[i], activation=activation) + tf.keras.layers.Dense( + hidden_layer_dims[i], activation=hidden_layer_activation + ) ) - layers.append(tf.keras.layers.Dense(output_dim)) + if output_activation != "linear": + output_activation = get_activation_fn(output_activation, framework="torch") + final_layer = tf.keras.layers.Dense( + output_dim, activation=output_activation + ) + else: + final_layer = tf.keras.layers.Dense(output_dim) + + layers.append(final_layer) self.network = tf.keras.Sequential(layers) def __call__(self, inputs): diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py index aa3bd071fe83..c438308106ed 100644 --- a/rllib/models/experimental/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -32,9 +32,9 @@ def __init__(self, config: ModelConfig) -> None: self.net = TorchMLP( input_dim=config.input_dim, - hidden_layers=config.hidden_layers, + hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, - activation=config.activation, + hidden_layer_activation=config.hidden_layer_activation, ) @property diff --git a/rllib/models/experimental/torch/fcmodel.py b/rllib/models/experimental/torch/fcmodel.py index 17d7b3b32822..6b9cb674b84d 100644 --- a/rllib/models/experimental/torch/fcmodel.py +++ b/rllib/models/experimental/torch/fcmodel.py @@ -16,9 +16,9 @@ def __init__(self, config: ModelConfig) -> None: self.net = TorchMLP( input_dim=config.input_dim, - hidden_layers=config.hidden_layers, + hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, - activation=config.activation, + hidden_layer_activation=config.hidden_layer_activation, ) @property diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index b15dfea7512f..445d97be0e5b 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -10,6 +10,7 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.typing import TensorType from ray.rllib.models.experimental.base import ModelConfig +from ray.rllib.models.utils import get_activation_fn torch, nn = try_import_torch() @@ -53,44 +54,45 @@ class TorchMLP(nn.Module): Attributes: input_dim: The input dimension of the network. It cannot be None. + hidden_layer_dims: The sizes of the hidden layers. output_dim: The output dimension of the network. if None, the last layer would be the last hidden layer. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer. + hidden_layer_activation: The activation function to use after each layer. 
+ output_activation: The activation function to use for the output layer. """ def __init__( self, input_dim: int, - hidden_layers: List[int], + hidden_layer_dims: List[int], output_dim: Optional[int] = None, - activation: str = "linear", + hidden_layer_activation: str = "linear", + output_activation: str = "linear", ): super().__init__() self.input_dim = input_dim - self.hidden_layers = hidden_layers - - activation_class = getattr(nn, activation, lambda: None)() - self.layers = [] - self.layers.append(nn.Linear(self.input_dim, self.hidden_layers[0])) - for i in range(len(self.hidden_layers) - 1): - if activation != "linear": - self.layers.append(activation_class) - self.layers.append( - nn.Linear(self.hidden_layers[i], self.hidden_layers[i + 1]) - ) + hidden_layer_dims = hidden_layer_dims - if output_dim is not None: - if activation != "linear": - self.layers.append(activation_class) - self.layers.append(nn.Linear(self.hidden_layers[-1], output_dim)) + activation_class = getattr(nn, hidden_layer_activation, lambda: None)() + layers = [] + layers.append(nn.Linear(input_dim, hidden_layer_dims[0])) + for i in range(len(hidden_layer_dims) - 1): + if hidden_layer_activation != "linear": + layers.append(activation_class) + layers.append(nn.Linear(hidden_layer_dims[i], hidden_layer_dims[i + 1])) - if output_dim is None: - self.output_dim = hidden_layers[-1] - else: + if output_dim is not None: + if hidden_layer_activation != "linear": + layers.append(activation_class) + layers.append(nn.Linear(hidden_layer_dims[-1], output_dim)) self.output_dim = output_dim + else: + self.output_dim = hidden_layer_dims[-1] + + if output_activation != "linear": + layers.append(get_activation_fn(output_activation, framework="torch")) - self.layers = nn.Sequential(*self.layers) + self.mlp = nn.Sequential(*layers) def forward(self, x): - return self.layers(x) + return self.mlp(x) From bd687abc5b82975ec4a34dca757a1b6226e091b7 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 26 Jan 2023 13:00:12 -0800 Subject: [PATCH 41/51] delete vectorencoder Signed-off-by: Artur Niederfahrenhorst --- rllib/BUILD | 8 - .../ppo/torch/ppo_torch_rl_module.py | 18 +- rllib/models/torch/__init__.py | 12 - rllib/models/torch/attention_net.py | 452 ------------ rllib/models/torch/complex_input_net.py | 238 ------- .../tests/test_torch_vector_encoder.py | 70 -- rllib/models/torch/encoders/vector.py | 97 --- rllib/models/torch/fcnet.py | 161 ----- rllib/models/torch/mingpt.py | 299 -------- rllib/models/torch/misc.py | 195 ------ rllib/models/torch/model.py | 220 ------ rllib/models/torch/modules/__init__.py | 13 - .../torch/modules/convtranspose2d_stack.py | 82 --- rllib/models/torch/modules/gru_gate.py | 65 -- .../torch/modules/multi_head_attention.py | 68 -- rllib/models/torch/modules/noisy_layer.py | 99 --- .../modules/relative_multi_head_attention.py | 175 ----- rllib/models/torch/modules/skip_connection.py | 41 -- rllib/models/torch/noop.py | 13 - rllib/models/torch/primitives.py | 54 -- rllib/models/torch/recurrent_net.py | 285 -------- rllib/models/torch/torch_action_dist.py | 648 ------------------ rllib/models/torch/torch_distributions.py | 257 ------- rllib/models/torch/torch_modelv2.py | 81 --- rllib/models/torch/visionnet.py | 293 -------- 25 files changed, 9 insertions(+), 3935 deletions(-) delete mode 100644 rllib/models/torch/__init__.py delete mode 100644 rllib/models/torch/attention_net.py delete mode 100644 rllib/models/torch/complex_input_net.py delete mode 100644 
rllib/models/torch/encoders/tests/test_torch_vector_encoder.py delete mode 100644 rllib/models/torch/encoders/vector.py delete mode 100644 rllib/models/torch/fcnet.py delete mode 100644 rllib/models/torch/mingpt.py delete mode 100644 rllib/models/torch/misc.py delete mode 100644 rllib/models/torch/model.py delete mode 100644 rllib/models/torch/modules/__init__.py delete mode 100644 rllib/models/torch/modules/convtranspose2d_stack.py delete mode 100644 rllib/models/torch/modules/gru_gate.py delete mode 100644 rllib/models/torch/modules/multi_head_attention.py delete mode 100644 rllib/models/torch/modules/noisy_layer.py delete mode 100644 rllib/models/torch/modules/relative_multi_head_attention.py delete mode 100644 rllib/models/torch/modules/skip_connection.py delete mode 100644 rllib/models/torch/noop.py delete mode 100644 rllib/models/torch/primitives.py delete mode 100644 rllib/models/torch/recurrent_net.py delete mode 100644 rllib/models/torch/torch_action_dist.py delete mode 100644 rllib/models/torch/torch_distributions.py delete mode 100644 rllib/models/torch/torch_modelv2.py delete mode 100644 rllib/models/torch/visionnet.py diff --git a/rllib/BUILD b/rllib/BUILD index 46642d737365..2c70e8e3d14b 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1974,14 +1974,6 @@ py_test( srcs = ["models/specs/tests/test_spec_dict.py"] ) -# test TorchVectorEncoder -py_test( - name = "test_torch_vector_encoder", - tags = ["team:rllib", "models"], - size = "small", - srcs = ["models/torch/encoders/tests/test_torch_vector_encoder.py"] -) - # -------------------------------------------------------------------- # Offline diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 8d07c14a7f73..9d60245c3176 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -76,8 +76,8 @@ def setup(self) -> None: assert self.config.vf_config, "vf_config must be provided." assert self.config.encoder_config, "shared encoder config must be " "provided." 
- # TODO(Artur): Unify to tf and torch setup(framework) - self.shared_encoder = self.config.encoder_config.build(framework="torch") + # TODO(Artur): Unify to tf and torch setup with ModelBuilder + self.encoder = self.config.encoder_config.build(framework="torch") self.pi = self.config.pi_config.build(framework="torch") self.vf = self.config.vf_config.build(framework="torch") @@ -176,8 +176,8 @@ def from_model_config( return module def get_initial_state(self) -> NestedDict: - if hasattr(self.shared_encoder, "get_initial_state"): - return self.shared_encoder.get_initial_state() + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() else: return NestedDict({}) @@ -193,7 +193,7 @@ def output_specs_inference(self) -> SpecDict: def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: output = {} - encoder_out = self.shared_encoder(batch) + encoder_out = self.encoder(batch) if STATE_OUT in encoder_out: output[STATE_OUT] = encoder_out[STATE_OUT] @@ -210,7 +210,7 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: @override(RLModule) def input_specs_exploration(self): - return self.shared_encoder.input_spec + return self.encoder.input_spec @override(RLModule) def output_specs_exploration(self) -> SpecDict: @@ -238,7 +238,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output = {} # Shared encoder - encoder_out = self.shared_encoder(batch) + encoder_out = self.encoder(batch) if STATE_OUT in encoder_out: output[STATE_OUT] = encoder_out[STATE_OUT] @@ -269,7 +269,7 @@ def input_specs_train(self) -> SpecDict: action_dim = self.config.action_space.shape[0] action_spec = TorchTensorSpec("b, h", h=action_dim) - spec_dict = self.shared_encoder.input_spec + spec_dict = self.encoder.input_spec spec_dict.update({SampleBatch.ACTIONS: action_spec}) if SampleBatch.OBS in spec_dict: spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS] @@ -293,7 +293,7 @@ def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: output = {} # Shared encoder - encoder_out = self.shared_encoder(batch) + encoder_out = self.encoder(batch) if STATE_OUT in encoder_out: output[STATE_OUT] = encoder_out[STATE_OUT] diff --git a/rllib/models/torch/__init__.py b/rllib/models/torch/__init__.py deleted file mode 100644 index abbe5ef60464..000000000000 --- a/rllib/models/torch/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -# from ray.rllib.models.torch.fcnet import FullyConnectedNetwork -# from ray.rllib.models.torch.recurrent_net import \ -# RecurrentNetwork -# from ray.rllib.models.torch.visionnet import VisionNetwork - -# __all__ = [ -# "FullyConnectedNetwork", -# "RecurrentNetwork", -# "TorchModelV2", -# "VisionNetwork", -# ] diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py deleted file mode 100644 index 454c0a555c97..000000000000 --- a/rllib/models/torch/attention_net.py +++ /dev/null @@ -1,452 +0,0 @@ -""" -[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar, - Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017. - https://arxiv.org/pdf/1706.03762.pdf -[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto - et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf -[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context. - Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019. 
- https://www.aclweb.org/anthology/P19-1285.pdf -""" -import gymnasium as gym -from gymnasium.spaces import Box, Discrete, MultiDiscrete -import numpy as np -import tree # pip install dm_tree -from typing import Dict, Optional, Union - -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.misc import SlimFC -from ray.rllib.models.torch.modules import ( - GRUGate, - RelativeMultiHeadAttention, - SkipConnection, -) -from ray.rllib.models.torch.recurrent_net import RecurrentNetwork -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot -from ray.rllib.utils.typing import ModelConfigDict, TensorType, List - -torch, nn = try_import_torch() - - -class GTrXLNet(RecurrentNetwork, nn.Module): - """A GTrXL net Model described in [2]. - - This is still in an experimental phase. - Can be used as a drop-in replacement for LSTMs in PPO and IMPALA. - For an example script, see: `ray/rllib/examples/attention_net.py`. - - To use this network as a replacement for an RNN, configure your Trainer - as follows: - - Examples: - >> config["model"]["custom_model"] = GTrXLNet - >> config["model"]["max_seq_len"] = 10 - >> config["model"]["custom_model_config"] = { - >> num_transformer_units=1, - >> attention_dim=32, - >> num_heads=2, - >> memory_tau=50, - >> etc.. - >> } - """ - - def __init__( - self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: Optional[int], - model_config: ModelConfigDict, - name: str, - *, - num_transformer_units: int = 1, - attention_dim: int = 64, - num_heads: int = 2, - memory_inference: int = 50, - memory_training: int = 50, - head_dim: int = 32, - position_wise_mlp_dim: int = 32, - init_gru_gate_bias: float = 2.0 - ): - """Initializes a GTrXLNet. - - Args: - num_transformer_units: The number of Transformer repeats to - use (denoted L in [2]). - attention_dim: The input and output dimensions of one - Transformer unit. - num_heads: The number of attention heads to use in parallel. - Denoted as `H` in [3]. - memory_inference: The number of timesteps to concat (time - axis) and feed into the next transformer unit as inference - input. The first transformer unit will receive this number of - past observations (plus the current one), instead. - memory_training: The number of timesteps to concat (time - axis) and feed into the next transformer unit as training - input (plus the actual input sequence of len=max_seq_len). - The first transformer unit will receive this number of - past observations (plus the input sequence), instead. - head_dim: The dimension of a single(!) attention head within - a multi-head attention unit. Denoted as `d` in [3]. - position_wise_mlp_dim: The dimension of the hidden layer - within the position-wise MLP (after the multi-head attention - block within one Transformer unit). This is the size of the - first of the two layers within the PositionwiseFeedforward. The - second layer always has size=`attention_dim`. - init_gru_gate_bias: Initial bias values for the GRU gates - (two GRUs per Transformer unit, one after the MHA, one after - the position-wise MLP). 
- """ - - super().__init__( - observation_space, action_space, num_outputs, model_config, name - ) - - nn.Module.__init__(self) - - self.num_transformer_units = num_transformer_units - self.attention_dim = attention_dim - self.num_heads = num_heads - self.memory_inference = memory_inference - self.memory_training = memory_training - self.head_dim = head_dim - self.max_seq_len = model_config["max_seq_len"] - self.obs_dim = observation_space.shape[0] - - self.linear_layer = SlimFC(in_size=self.obs_dim, out_size=self.attention_dim) - - self.layers = [self.linear_layer] - - attention_layers = [] - # 2) Create L Transformer blocks according to [2]. - for i in range(self.num_transformer_units): - # RelativeMultiHeadAttention part. - MHA_layer = SkipConnection( - RelativeMultiHeadAttention( - in_dim=self.attention_dim, - out_dim=self.attention_dim, - num_heads=num_heads, - head_dim=head_dim, - input_layernorm=True, - output_activation=nn.ReLU, - ), - fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias), - ) - - # Position-wise MultiLayerPerceptron part. - E_layer = SkipConnection( - nn.Sequential( - torch.nn.LayerNorm(self.attention_dim), - SlimFC( - in_size=self.attention_dim, - out_size=position_wise_mlp_dim, - use_bias=False, - activation_fn=nn.ReLU, - ), - SlimFC( - in_size=position_wise_mlp_dim, - out_size=self.attention_dim, - use_bias=False, - activation_fn=nn.ReLU, - ), - ), - fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias), - ) - - # Build a list of all attanlayers in order. - attention_layers.extend([MHA_layer, E_layer]) - - # Create a Sequential such that all parameters inside the attention - # layers are automatically registered with this top-level model. - self.attention_layers = nn.Sequential(*attention_layers) - self.layers.extend(attention_layers) - - # Final layers if num_outputs not None. - self.logits = None - self.values_out = None - # Last value output. - self._value_out = None - # Postprocess GTrXL output with another hidden layer. - if self.num_outputs is not None: - self.logits = SlimFC( - in_size=self.attention_dim, - out_size=self.num_outputs, - activation_fn=nn.ReLU, - ) - - # Value function used by all RLlib Torch RL implementations. - self.values_out = SlimFC( - in_size=self.attention_dim, out_size=1, activation_fn=None - ) - else: - self.num_outputs = self.attention_dim - - # Setup trajectory views (`memory-inference` x past memory outs). - for i in range(self.num_transformer_units): - space = Box(-1.0, 1.0, shape=(self.attention_dim,)) - self.view_requirements["state_in_{}".format(i)] = ViewRequirement( - "state_out_{}".format(i), - shift="-{}:-1".format(self.memory_inference), - # Repeat the incoming state every max-seq-len times. - batch_repeat_value=self.max_seq_len, - space=space, - ) - self.view_requirements["state_out_{}".format(i)] = ViewRequirement( - space=space, used_for_training=False - ) - - @override(ModelV2) - def forward( - self, input_dict, state: List[TensorType], seq_lens: TensorType - ) -> (TensorType, List[TensorType]): - assert seq_lens is not None - - # Add the needed batch rank (tf Models' Input requires this). - observations = input_dict[SampleBatch.OBS] - # Add the time dim to observations. - B = len(seq_lens) - T = observations.shape[0] // B - observations = torch.reshape( - observations, [-1, T] + list(observations.shape[1:]) - ) - - all_out = observations - memory_outs = [] - for i in range(len(self.layers)): - # MHA layers which need memory passed in. 
- if i % 2 == 1: - all_out = self.layers[i](all_out, memory=state[i // 2]) - # Either self.linear_layer (initial obs -> attn. dim layer) or - # MultiLayerPerceptrons. The output of these layers is always the - # memory for the next forward pass. - else: - all_out = self.layers[i](all_out) - memory_outs.append(all_out) - - # Discard last output (not needed as a memory since it's the last - # layer). - memory_outs = memory_outs[:-1] - - if self.logits is not None: - out = self.logits(all_out) - self._value_out = self.values_out(all_out) - out_dim = self.num_outputs - else: - out = all_out - out_dim = self.attention_dim - - return torch.reshape(out, [-1, out_dim]), [ - torch.reshape(m, [-1, self.attention_dim]) for m in memory_outs - ] - - # TODO: (sven) Deprecate this once trajectory view API has fully matured. - @override(RecurrentNetwork) - def get_initial_state(self) -> List[np.ndarray]: - return [] - - @override(ModelV2) - def value_function(self) -> TensorType: - assert ( - self._value_out is not None - ), "Must call forward first AND must have value branch!" - return torch.reshape(self._value_out, [-1]) - - -class AttentionWrapper(TorchModelV2, nn.Module): - """GTrXL wrapper serving as interface for ModelV2s that set use_attention.""" - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: int, - model_config: ModelConfigDict, - name: str, - ): - - nn.Module.__init__(self) - super().__init__(obs_space, action_space, None, model_config, name) - - self.use_n_prev_actions = model_config["attention_use_n_prev_actions"] - self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"] - - self.action_space_struct = get_base_struct_from_space(self.action_space) - self.action_dim = 0 - - for space in tree.flatten(self.action_space_struct): - if isinstance(space, Discrete): - self.action_dim += space.n - elif isinstance(space, MultiDiscrete): - self.action_dim += np.sum(space.nvec) - elif space.shape is not None: - self.action_dim += int(np.product(space.shape)) - else: - self.action_dim += int(len(space)) - - # Add prev-action/reward nodes to input to LSTM. - if self.use_n_prev_actions: - self.num_outputs += self.use_n_prev_actions * self.action_dim - if self.use_n_prev_rewards: - self.num_outputs += self.use_n_prev_rewards - - cfg = model_config - - self.attention_dim = cfg["attention_dim"] - - if self.num_outputs is not None: - in_space = gym.spaces.Box( - float("-inf"), float("inf"), shape=(self.num_outputs,), dtype=np.float32 - ) - else: - in_space = obs_space - - # Construct GTrXL sub-module w/ num_outputs=None (so it does not - # create a logits/value output; we'll do this ourselves in this wrapper - # here). - self.gtrxl = GTrXLNet( - in_space, - action_space, - None, - model_config, - "gtrxl", - num_transformer_units=cfg["attention_num_transformer_units"], - attention_dim=self.attention_dim, - num_heads=cfg["attention_num_heads"], - head_dim=cfg["attention_head_dim"], - memory_inference=cfg["attention_memory_inference"], - memory_training=cfg["attention_memory_training"], - position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"], - init_gru_gate_bias=cfg["attention_init_gru_gate_bias"], - ) - - # Set final num_outputs to correct value (depending on action space). - self.num_outputs = num_outputs - - # Postprocess GTrXL output with another hidden layer and compute - # values. 
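# In GTrXLNet.forward() above, layers alternate between non-attention blocks
# (even indices) and memory-consuming attention blocks (odd indices); the output
# of every even layer is collected as the memory for the next call, and the last
# collected entry is dropped. A stripped-down sketch of that bookkeeping, with
# identity stand-ins for the real blocks and hypothetical sizes:
import torch
import torch.nn as nn

layers = nn.ModuleList(nn.Identity() for _ in range(5))  # linear, MHA, MLP, MHA, MLP
x = torch.randn(1, 5, 8)                                  # [batch, time, attention_dim]
memory = [torch.zeros(1, 5, 8) for _ in range(2)]         # one chunk per transformer unit
memory_outs = []
for i, layer in enumerate(layers):
    if i % 2 == 1:
        x = layer(x)            # real code: layer(x, memory=memory[i // 2])
    else:
        x = layer(x)
        memory_outs.append(x)   # even-layer outputs become next-pass memory
memory_outs = memory_outs[:-1]  # the final output is not needed as memory
assert len(memory_outs) == 2    # matches the number of transformer units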
- self._logits_branch = SlimFC( - in_size=self.attention_dim, - out_size=self.num_outputs, - activation_fn=None, - initializer=torch.nn.init.xavier_uniform_, - ) - self._value_branch = SlimFC( - in_size=self.attention_dim, - out_size=1, - activation_fn=None, - initializer=torch.nn.init.xavier_uniform_, - ) - - self.view_requirements = self.gtrxl.view_requirements - self.view_requirements["obs"].space = self.obs_space - - # Add prev-a/r to this model's view, if required. - if self.use_n_prev_actions: - self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement( - SampleBatch.ACTIONS, - space=self.action_space, - shift="-{}:-1".format(self.use_n_prev_actions), - ) - if self.use_n_prev_rewards: - self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement( - SampleBatch.REWARDS, shift="-{}:-1".format(self.use_n_prev_rewards) - ) - - @override(RecurrentNetwork) - def forward( - self, - input_dict: Dict[str, TensorType], - state: List[TensorType], - seq_lens: TensorType, - ) -> (TensorType, List[TensorType]): - assert seq_lens is not None - # Push obs through "unwrapped" net's `forward()` first. - wrapped_out, _ = self._wrapped_forward(input_dict, [], None) - - # Concat. prev-action/reward if required. - prev_a_r = [] - - # Prev actions. - if self.use_n_prev_actions: - prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS] - # If actions are not processed yet (in their original form as - # have been sent to environment): - # Flatten/one-hot into 1D array. - if self.model_config["_disable_action_flattening"]: - # Merge prev n actions into flat tensor. - flat = flatten_inputs_to_1d_tensor( - prev_n_actions, - spaces_struct=self.action_space_struct, - time_axis=True, - ) - # Fold time-axis into flattened data. - flat = torch.reshape(flat, [flat.shape[0], -1]) - prev_a_r.append(flat) - # If actions are already flattened (but not one-hot'd yet!), - # one-hot discrete/multi-discrete actions here and concatenate the - # n most recent actions together. - else: - if isinstance(self.action_space, Discrete): - for i in range(self.use_n_prev_actions): - prev_a_r.append( - one_hot( - prev_n_actions[:, i].float(), space=self.action_space - ) - ) - elif isinstance(self.action_space, MultiDiscrete): - for i in range( - 0, self.use_n_prev_actions, self.action_space.shape[0] - ): - prev_a_r.append( - one_hot( - prev_n_actions[ - :, i : i + self.action_space.shape[0] - ].float(), - space=self.action_space, - ) - ) - else: - prev_a_r.append( - torch.reshape( - prev_n_actions.float(), - [-1, self.use_n_prev_actions * self.action_dim], - ) - ) - # Prev rewards. - if self.use_n_prev_rewards: - prev_a_r.append( - torch.reshape( - input_dict[SampleBatch.PREV_REWARDS].float(), - [-1, self.use_n_prev_rewards], - ) - ) - - # Concat prev. actions + rewards to the "main" input. - if prev_a_r: - wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1) - - # Then through our GTrXL. - input_dict["obs_flat"] = input_dict["obs"] = wrapped_out - - self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens) - model_out = self._logits_branch(self._features) - return model_out, memory_outs - - @override(ModelV2) - def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]: - return [ - torch.zeros( - self.gtrxl.view_requirements["state_in_{}".format(i)].space.shape - ) - for i in range(self.gtrxl.num_transformer_units) - ] - - @override(ModelV2) - def value_function(self) -> TensorType: - assert self._features is not None, "Must call forward() first!" 
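# AttentionWrapper.forward() above optionally one-hot encodes the n most recent
# discrete actions and concatenates them (plus previous rewards) onto the wrapped
# model's output before feeding the GTrXL. With hypothetical shapes, that
# concatenation reduces to:
import torch
import torch.nn.functional as F

n_prev, num_actions = 2, 3
prev_actions = torch.tensor([[0, 2], [1, 1]])                   # [batch, n_prev]
prev_rewards = torch.tensor([[0.5, -1.0], [0.0, 2.0]])          # [batch, n_prev]
wrapped_out = torch.randn(2, 8)                                  # wrapped model features
pieces = [wrapped_out]
pieces += [F.one_hot(prev_actions[:, i], num_actions).float() for i in range(n_prev)]
pieces += [prev_rewards]
gtrxl_in = torch.cat(pieces, dim=1)                              # [batch, 8 + 2*3 + 2]
assert gtrxl_in.shape == (2, 16)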
- return torch.reshape(self._value_branch(self._features), [-1]) diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py deleted file mode 100644 index f3cb4311521d..000000000000 --- a/rllib/models/torch/complex_input_net.py +++ /dev/null @@ -1,238 +0,0 @@ -from gymnasium.spaces import Box, Discrete, MultiDiscrete -import numpy as np -import tree # pip install dm_tree - -# TODO (sven): add IMPALA-style option. -# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet -from ray.rllib.models.torch.misc import ( - normc_initializer as torch_normc_initializer, - SlimFC, -) -from ray.rllib.models.catalog import ModelCatalog -from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.models.utils import get_filter_config -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.spaces.space_utils import flatten_space -from ray.rllib.utils.torch_utils import one_hot - -torch, nn = try_import_torch() - - -class ComplexInputNetwork(TorchModelV2, nn.Module): - """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s). - - Note: This model should be used for complex (Dict or Tuple) observation - spaces that have one or more image components. - - The data flow is as follows: - - `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT` - `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out` - `out` -> (optional) FC-stack -> `out2` - `out2` -> action (logits) and value heads. - """ - - def __init__(self, obs_space, action_space, num_outputs, model_config, name): - self.original_space = ( - obs_space.original_space - if hasattr(obs_space, "original_space") - else obs_space - ) - - self.processed_obs_space = ( - self.original_space - if model_config.get("_disable_preprocessor_api") - else obs_space - ) - - nn.Module.__init__(self) - TorchModelV2.__init__( - self, self.original_space, action_space, num_outputs, model_config, name - ) - - self.flattened_input_space = flatten_space(self.original_space) - - # Atari type CNNs or IMPALA type CNNs (with residual layers)? - # self.cnn_type = self.model_config["custom_model_config"].get( - # "conv_type", "atari") - - # Build the CNN(s) given obs_space's image components. - self.cnns = nn.ModuleDict() - self.one_hot = nn.ModuleDict() - self.flatten_dims = {} - self.flatten = nn.ModuleDict() - concat_size = 0 - for i, component in enumerate(self.flattened_input_space): - i = str(i) - # Image space. - if len(component.shape) == 3 and isinstance(component, Box): - config = { - "conv_filters": model_config["conv_filters"] - if "conv_filters" in model_config - else get_filter_config(component.shape), - "conv_activation": model_config.get("conv_activation"), - "post_fcnet_hiddens": [], - } - # if self.cnn_type == "atari": - self.cnns[i] = ModelCatalog.get_model_v2( - component, - action_space, - num_outputs=None, - model_config=config, - framework="torch", - name="cnn_{}".format(i), - ) - # TODO (sven): add IMPALA-style option. - # else: - # cnn = TorchImpalaVisionNet( - # component, - # action_space, - # num_outputs=None, - # model_config=config, - # name="cnn_{}".format(i)) - - concat_size += self.cnns[i].num_outputs - self.add_module("cnn_{}".format(i), self.cnns[i]) - # Discrete|MultiDiscrete inputs -> One-hot encode. 
- elif isinstance(component, (Discrete, MultiDiscrete)): - if isinstance(component, Discrete): - size = component.n - else: - size = np.sum(component.nvec) - config = { - "fcnet_hiddens": model_config["fcnet_hiddens"], - "fcnet_activation": model_config.get("fcnet_activation"), - "post_fcnet_hiddens": [], - } - self.one_hot[i] = ModelCatalog.get_model_v2( - Box(-1.0, 1.0, (size,), np.float32), - action_space, - num_outputs=None, - model_config=config, - framework="torch", - name="one_hot_{}".format(i), - ) - concat_size += self.one_hot[i].num_outputs - self.add_module("one_hot_{}".format(i), self.one_hot[i]) - # Everything else (1D Box). - else: - size = int(np.product(component.shape)) - config = { - "fcnet_hiddens": model_config["fcnet_hiddens"], - "fcnet_activation": model_config.get("fcnet_activation"), - "post_fcnet_hiddens": [], - } - self.flatten[i] = ModelCatalog.get_model_v2( - Box(-1.0, 1.0, (size,), np.float32), - action_space, - num_outputs=None, - model_config=config, - framework="torch", - name="flatten_{}".format(i), - ) - self.flatten_dims[i] = size - concat_size += self.flatten[i].num_outputs - self.add_module("flatten_{}".format(i), self.flatten[i]) - - # Optional post-concat FC-stack. - post_fc_stack_config = { - "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []), - "fcnet_activation": model_config.get("post_fcnet_activation", "relu"), - } - self.post_fc_stack = ModelCatalog.get_model_v2( - Box(float("-inf"), float("inf"), shape=(concat_size,), dtype=np.float32), - self.action_space, - None, - post_fc_stack_config, - framework="torch", - name="post_fc_stack", - ) - - # Actions and value heads. - self.logits_layer = None - self.value_layer = None - self._value_out = None - - if num_outputs: - # Action-distribution head. - self.logits_layer = SlimFC( - in_size=self.post_fc_stack.num_outputs, - out_size=num_outputs, - activation_fn=None, - initializer=torch_normc_initializer(0.01), - ) - # Create the value branch model. - self.value_layer = SlimFC( - in_size=self.post_fc_stack.num_outputs, - out_size=1, - activation_fn=None, - initializer=torch_normc_initializer(0.01), - ) - else: - self.num_outputs = concat_size - - @override(ModelV2) - def forward(self, input_dict, state, seq_lens): - if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: - orig_obs = input_dict[SampleBatch.OBS] - else: - orig_obs = restore_original_dimensions( - input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="torch" - ) - # Push observations through the different components - # (CNNs, one-hot + FC, etc..). - outs = [] - for i, component in enumerate(tree.flatten(orig_obs)): - i = str(i) - if i in self.cnns: - cnn_out, _ = self.cnns[i](SampleBatch({SampleBatch.OBS: component})) - outs.append(cnn_out) - elif i in self.one_hot: - if component.dtype in [ - torch.int8, - torch.int16, - torch.int32, - torch.int64, - torch.uint8, - ]: - one_hot_in = { - SampleBatch.OBS: one_hot( - component, self.flattened_input_space[int(i)] - ) - } - else: - one_hot_in = {SampleBatch.OBS: component} - one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in)) - outs.append(one_hot_out) - else: - nn_out, _ = self.flatten[i]( - SampleBatch( - { - SampleBatch.OBS: torch.reshape( - component, [-1, self.flatten_dims[i]] - ) - } - ) - ) - outs.append(nn_out) - - # Concat all outputs and the non-image inputs. - out = torch.cat(outs, dim=1) - # Push through (optional) FC-stack (this may be an empty stack). 
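# ComplexInputNetwork above pushes each component of a Dict/Tuple observation
# through its own branch (CNN for images, one-hot for discrete, flatten for 1D
# Boxes) and concatenates the flat outputs before an optional post-FC stack. A
# bare-bones sketch of that flow with hypothetical sizes (the per-branch FC
# stacks are omitted here for brevity):
import torch
import torch.nn as nn
import torch.nn.functional as F

image = torch.randn(2, 3, 16, 16)                 # image component
category = torch.tensor([1, 3])                   # Discrete(4) component
vector = torch.randn(2, 5)                        # 1D Box component

cnn_branch = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.ReLU(), nn.Flatten())
outs = [
    cnn_branch(image),                            # [2, 4 * 14 * 14]
    F.one_hot(category, num_classes=4).float(),   # [2, 4]
    vector,                                       # [2, 5]
]
concat = torch.cat(outs, dim=1)
post_fc_stack = nn.Sequential(nn.Linear(concat.shape[1], 64), nn.ReLU())
logits = nn.Linear(64, 6)(post_fc_stack(concat))  # action head on top of the stack
assert logits.shape == (2, 6)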
- out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out})) - - # No logits/value branches. - if self.logits_layer is None: - return out, [] - - # Logits- and value branches. - logits, values = self.logits_layer(out), self.value_layer(out) - self._value_out = torch.reshape(values, [-1]) - return logits, [] - - @override(ModelV2) - def value_function(self): - return self._value_out diff --git a/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py b/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py deleted file mode 100644 index 0f3743be2f43..000000000000 --- a/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py +++ /dev/null @@ -1,70 +0,0 @@ -import unittest - -import torch - -from ray.rllib.models.configs.encoder import VectorEncoderConfig -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.utils.nested_dict import NestedDict - - -class TestConfig(unittest.TestCase): - def test_error_no_feature_dim(self): - """Ensure we error out if we don't know the input dim""" - input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c")}) - c = VectorEncoderConfig() - with self.assertRaises(AssertionError): - c.build(input_spec) - - def test_default_build(self): - """Test building with the default config""" - input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) - c = VectorEncoderConfig() - c.build(input_spec) - - def test_nonlinear_final_build(self): - input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) - c = VectorEncoderConfig(final_activation="relu") - c.build(input_spec) - - def test_default_forward(self): - """Test the default config/model _forward implementation""" - input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) - c = VectorEncoderConfig() - m = c.build(input_spec) - inputs = NestedDict({"bork": torch.rand((2, 4, 3))}) - outputs, _ = m.unroll(inputs, NestedDict()) - self.assertEqual(outputs[c.output_key].shape[-1], c.hidden_layer_sizes[-1]) - self.assertEqual(outputs[c.output_key].shape[:-1], (2, 4)) - - def test_two_inputs_forward(self): - """Test the default model when we have two items in the input_spec. 
- These two items will be concatenated and fed thru the mlp.""" - """Test the default config/model _forward implementation""" - input_spec = SpecDict( - { - "bork": TorchTensorSpec("a, b, c", c=3), - "dork": TorchTensorSpec("x, y, z", z=5), - } - ) - c = VectorEncoderConfig() - m = c.build(input_spec) - self.assertEqual(m.net[0].in_features, 8) - inputs = NestedDict( - {"bork": torch.rand((2, 4, 3)), "dork": torch.rand((2, 4, 5))} - ) - outputs, _ = m.unroll(inputs, NestedDict()) - self.assertEqual(outputs[c.output_key].shape[-1], c.hidden_layer_sizes[-1]) - self.assertEqual(outputs[c.output_key].shape[:-1], (2, 4)) - - def test_deep_build(self): - input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)}) - c = VectorEncoderConfig() - c.build(input_spec) - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/models/torch/encoders/vector.py b/rllib/models/torch/encoders/vector.py deleted file mode 100644 index 91ef65d71f44..000000000000 --- a/rllib/models/torch/encoders/vector.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import TYPE_CHECKING -from ray.rllib.models.specs.specs_torch import TorchTensorSpec - -import torch -from torch import nn - -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.torch.model import TorchModel -from ray.rllib.models.utils import get_activation_fn -from ray.rllib.utils.nested_dict import NestedDict - -from ray.rllib.models.utils import input_to_output_spec - -if TYPE_CHECKING: - from ray.rllib.models.configs.encoder import VectorEncoderConfig - - -class TorchVectorEncoder(TorchModel): - """A torch implementation of an MLP encoder. - - This encoder concatenates inputs along the last dimension, - then pushes them through a series of linear layers and nonlinear activations. - """ - - @property - def input_spec(self) -> SpecDict: - return self._input_spec - - @property - def output_spec(self) -> SpecDict: - return self._output_spec - - def __init__( - self, - input_spec: SpecDict, - config: "VectorEncoderConfig", - ): - super().__init__(config=config) - # Setup input and output specs - self._input_spec = input_spec - self._output_spec = input_to_output_spec( - input_spec=input_spec, - num_input_feature_dims=1, - output_key=config.output_key, - output_feature_spec=TorchTensorSpec("f", f=config.hidden_layer_sizes[-1]), - ) - # Returns the size of the feature dimension for the input tensors - prev_size = sum(v.shape[-1] for v in input_spec.values()) - - # Construct layers - layers = [] - activation = ( - None - if config.activation == "linear" - else get_activation_fn(config.activation, framework=config.framework_str)() - ) - for size in config.hidden_layer_sizes[:-1]: - layers += [nn.Linear(prev_size, size)] - layers += [activation] if activation is not None else [] - prev_size = size - - # Final layer - layers += [ - nn.Linear(config.hidden_layer_sizes[-2], config.hidden_layer_sizes[-1]) - ] - if config.final_activation != "linear": - layers += [ - get_activation_fn( - config.final_activation, framework=config.framework_str - )() - ] - - self.net = nn.Sequential(*layers) - - def _forward(self, inputs: NestedDict) -> NestedDict: - """Runs the forward pass of the MLP. Call this via unroll(). 
- - Args: - inputs: The nested dictionary of inputs - - Returns: - The nested dictionary of outputs - """ - # Ensure all inputs have matching dims before concat - # so we can emit an informative error message - first_key, first_tensor = list(inputs.items())[0] - for k, tensor in inputs.items(): - assert tensor.shape[:-1] == first_tensor.shape[:-1], ( - "Inputs have mismatching dimensions, all dims but the last should " - f"be equal: {first_key}: {first_tensor.shape} != {k}: {tensor.shape}" - ) - - # Concatenate all input along the feature dim - x = torch.cat(list(inputs.values()), dim=-1) - [out_key] = self.output_spec.keys() - inputs[out_key] = self.net(x) - return inputs diff --git a/rllib/models/torch/fcnet.py b/rllib/models/torch/fcnet.py deleted file mode 100644 index 97bb9096bb64..000000000000 --- a/rllib/models/torch/fcnet.py +++ /dev/null @@ -1,161 +0,0 @@ -import logging -import numpy as np -import gymnasium as gym - -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, normc_initializer -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict - -torch, nn = try_import_torch() - -logger = logging.getLogger(__name__) - - -class FullyConnectedNetwork(TorchModelV2, nn.Module): - """Generic fully connected network.""" - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: int, - model_config: ModelConfigDict, - name: str, - ): - TorchModelV2.__init__( - self, obs_space, action_space, num_outputs, model_config, name - ) - nn.Module.__init__(self) - - hiddens = list(model_config.get("fcnet_hiddens", [])) + list( - model_config.get("post_fcnet_hiddens", []) - ) - activation = model_config.get("fcnet_activation") - if not model_config.get("fcnet_hiddens", []): - activation = model_config.get("post_fcnet_activation") - no_final_linear = model_config.get("no_final_linear") - self.vf_share_layers = model_config.get("vf_share_layers") - self.free_log_std = model_config.get("free_log_std") - # Generate free-floating bias variables for the second half of - # the outputs. - if self.free_log_std: - assert num_outputs % 2 == 0, ( - "num_outputs must be divisible by two", - num_outputs, - ) - num_outputs = num_outputs // 2 - - layers = [] - prev_layer_size = int(np.product(obs_space.shape)) - self._logits = None - - # Create layers 0 to second-last. - for size in hiddens[:-1]: - layers.append( - SlimFC( - in_size=prev_layer_size, - out_size=size, - initializer=normc_initializer(1.0), - activation_fn=activation, - ) - ) - prev_layer_size = size - - # The last layer is adjusted to be of size num_outputs, but it's a - # layer with activation. - if no_final_linear and num_outputs: - layers.append( - SlimFC( - in_size=prev_layer_size, - out_size=num_outputs, - initializer=normc_initializer(1.0), - activation_fn=activation, - ) - ) - prev_layer_size = num_outputs - # Finish the layers with the provided sizes (`hiddens`), plus - - # iff num_outputs > 0 - a last linear layer of size num_outputs. 
- else: - if len(hiddens) > 0: - layers.append( - SlimFC( - in_size=prev_layer_size, - out_size=hiddens[-1], - initializer=normc_initializer(1.0), - activation_fn=activation, - ) - ) - prev_layer_size = hiddens[-1] - if num_outputs: - self._logits = SlimFC( - in_size=prev_layer_size, - out_size=num_outputs, - initializer=normc_initializer(0.01), - activation_fn=None, - ) - else: - self.num_outputs = ([int(np.product(obs_space.shape))] + hiddens[-1:])[ - -1 - ] - - # Layer to add the log std vars to the state-dependent means. - if self.free_log_std and self._logits: - self._append_free_log_std = AppendBiasLayer(num_outputs) - - self._hidden_layers = nn.Sequential(*layers) - - self._value_branch_separate = None - if not self.vf_share_layers: - # Build a parallel set of hidden layers for the value net. - prev_vf_layer_size = int(np.product(obs_space.shape)) - vf_layers = [] - for size in hiddens: - vf_layers.append( - SlimFC( - in_size=prev_vf_layer_size, - out_size=size, - activation_fn=activation, - initializer=normc_initializer(1.0), - ) - ) - prev_vf_layer_size = size - self._value_branch_separate = nn.Sequential(*vf_layers) - - self._value_branch = SlimFC( - in_size=prev_layer_size, - out_size=1, - initializer=normc_initializer(0.01), - activation_fn=None, - ) - # Holds the current "base" output (before logits layer). - self._features = None - # Holds the last input, in case value branch is separate. - self._last_flat_in = None - - @override(TorchModelV2) - def forward( - self, - input_dict: Dict[str, TensorType], - state: List[TensorType], - seq_lens: TensorType, - ) -> (TensorType, List[TensorType]): - obs = input_dict["obs_flat"].float() - self._last_flat_in = obs.reshape(obs.shape[0], -1) - self._features = self._hidden_layers(self._last_flat_in) - logits = self._logits(self._features) if self._logits else self._features - if self.free_log_std: - logits = self._append_free_log_std(logits) - return logits, state - - @override(TorchModelV2) - def value_function(self) -> TensorType: - assert self._features is not None, "must call forward() first" - if self._value_branch_separate: - out = self._value_branch( - self._value_branch_separate(self._last_flat_in) - ).squeeze(1) - else: - out = self._value_branch(self._features).squeeze(1) - return out diff --git a/rllib/models/torch/mingpt.py b/rllib/models/torch/mingpt.py deleted file mode 100644 index 00a192e9ec91..000000000000 --- a/rllib/models/torch/mingpt.py +++ /dev/null @@ -1,299 +0,0 @@ -# LICENSE: MIT -""" -Adapted from https://github.com/karpathy/minGPT - -Full definition of a GPT Language Model, all of it in this single file. 
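# With `free_log_std=True`, FullyConnectedNetwork above halves `num_outputs`, lets
# the logits layer emit only the action means, and appends state-independent,
# learnable log-std parameters afterwards (AppendBiasLayer). A sketch of the
# effect for a hypothetical 2-dimensional continuous action:
import torch
import torch.nn as nn

num_outputs = 4                                   # 2 means + 2 log-stds
means_layer = nn.Linear(8, num_outputs // 2)      # network only outputs the means
free_log_std = nn.Parameter(torch.zeros(num_outputs // 2))

features = torch.randn(3, 8)
means = means_layer(features)                                    # [3, 2]
logits = torch.cat([means, free_log_std.expand(3, -1)], dim=1)   # [3, 4]
assert logits.shape == (3, num_outputs)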
-References: -1) the official GPT-2 TensorFlow implementation released by OpenAI: -https://github.com/openai/gpt-2/blob/master/src/model.py -2) huggingface/transformers PyTorch implementation: -https://github.com/huggingface/transformers/blob/main/src/transformers - /models/gpt2/modeling_gpt2.py -""" - -import math -from dataclasses import dataclass -from typing import Tuple - -import torch -import torch.nn as nn -from torch.nn import functional as F - -from ray.rllib.utils.annotations import DeveloperAPI - - -@DeveloperAPI -@dataclass -class GPTConfig: - # block size must be provided - block_size: int - - # transformer config - n_layer: int = 12 - n_head: int = 12 - n_embed: int = 768 - - # dropout config - embed_pdrop: float = 0.1 - resid_pdrop: float = 0.1 - attn_pdrop: float = 0.1 - - -class NewGELU(nn.Module): - """ - Implementation of the GELU activation function currently in Google BERT - repo (identical to OpenAI GPT). - Reference: Gaussian Error Linear Units (GELU) paper: - https://arxiv.org/abs/1606.08415 - """ - - def forward(self, x): - return ( - 0.5 - * x - * ( - 1.0 - + torch.tanh( - math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) - ) - ) - ) - - -class CausalSelfAttention(nn.Module): - """ - Vanilla multi-head masked self-attention layer with a projection at the end. - It is possible to use torch.nn.MultiheadAttention here but I am including an - explicit implementation here to show that there is nothing too scary here. - """ - - def __init__(self, config: GPTConfig): - super().__init__() - assert config.n_embed % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed) - # output projection - self.c_proj = nn.Linear(config.n_embed, config.n_embed) - # regularization - self.attn_dropout = nn.Dropout(config.attn_pdrop) - self.resid_dropout = nn.Dropout(config.resid_pdrop) - # causal mask to ensure that attention is only applied to the left - # in the input sequence - self.register_buffer( - "bias", - torch.tril(torch.ones(config.block_size, config.block_size)).view( - 1, 1, config.block_size, config.block_size - ), - ) - self.n_head = config.n_head - self.n_embed = config.n_embed - - def forward(self, x, attention_masks=None): - # batch size, sequence length, embedding dimensionality (n_embed) - B, T, C = x.size() - - # calculate query, key, values for all heads in batch and move head - # forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embed, dim=2) - # (B, nh, T, hs) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) - # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) - # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) - - # causal self-attention; Self-attend: - # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) - if attention_masks is not None: - att = att + attention_masks - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - # re-assemble all head outputs side by side - y = y.transpose(1, 2).contiguous().view(B, T, C) - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y, att - - -class Block(nn.Module): - """an unassuming Transformer block""" - - def __init__(self, config: GPTConfig): - super().__init__() - self.ln_1 = 
nn.LayerNorm(config.n_embed) - self.attn = CausalSelfAttention(config) - self.ln_2 = nn.LayerNorm(config.n_embed) - self.mlp = nn.ModuleDict( - dict( - c_fc=nn.Linear(config.n_embed, 4 * config.n_embed), - c_proj=nn.Linear(4 * config.n_embed, config.n_embed), - act=NewGELU(), - dropout=nn.Dropout(config.resid_pdrop), - ) - ) - - def forward(self, x, attention_masks=None): - # Multi-head attention sub-layer. - x_att, att = self.attn(self.ln_1(x), attention_masks=attention_masks) - # Residual of multi-head attention sub-layer. - x = x + x_att - - # Position-wise FFN sub-layer: fc + activation + fc + dropout - x_ffn = self.mlp.dropout(self.mlp.c_proj(self.mlp.act(self.mlp.c_fc(x)))) - # Residual of position-wise FFN sub-layer. - x = x + x_ffn - return x, att - - -@DeveloperAPI -def configure_gpt_optimizer( - model: nn.Module, - learning_rate: float, - weight_decay: float, - betas: Tuple[float, float] = (0.9, 0.95), - **kwargs, -) -> torch.optim.Optimizer: - """ - This long function is unfortunately doing something very simple and is - being very defensive: We are separating out all parameters of the model - into two buckets: those that will experience weight decay for regularization - and those that won't (biases, and layernorm/embedding weights). We are then - returning the PyTorch optimizer object. - """ - - # separate out all parameters to those that will and won't experience - # regularizing weight decay - decay = set() - no_decay = set() - whitelist_w_modules = (torch.nn.Linear,) - blacklist_w_modules = (torch.nn.LayerNorm, torch.nn.Embedding) - for mn, m in model.named_modules(): - for pn, p in m.named_parameters(): - fpn = "%s.%s" % (mn, pn) if mn else pn # full param name - # random note: because named_modules and named_parameters are - # recursive we will see the same tensors p many many times. but - # doing it this way allows us to know which parent module any - # tensor p belongs to... - if pn.endswith("bias"): - # all biases will not be decayed - no_decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, whitelist_w_modules): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif pn.endswith("weight") and isinstance(m, blacklist_w_modules): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - # validate that we considered every parameter - param_dict = {pn: p for pn, p in model.named_parameters()} - inter_params = decay & no_decay - union_params = decay | no_decay - assert ( - len(inter_params) == 0 - ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" - assert len(param_dict.keys() - union_params) == 0, ( - f"parameters {str(param_dict.keys() - union_params)} were not " - f"separated into either decay/no_decay set!" 
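# CausalSelfAttention above registers a lower-triangular buffer and fills the
# masked positions with -inf before the softmax, so token t can only attend to
# tokens <= t. The masking step in isolation, for a hypothetical block_size of 4:
import torch
import torch.nn.functional as F

T = 4
att = torch.randn(1, 1, T, T)                          # [batch, heads, T, T] raw scores
bias = torch.tril(torch.ones(T, T)).view(1, 1, T, T)   # causal mask buffer
att = att.masked_fill(bias[:, :, :T, :T] == 0, float("-inf"))
att = F.softmax(att, dim=-1)
# Each row now sums to 1 and puts zero weight on future positions.
assert torch.allclose(att.sum(-1), torch.ones(1, 1, T))
assert att[0, 0, 0, 1:].abs().sum() == 0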
- ) - - # create the pytorch optimizer object - optim_groups = [ - { - "params": [param_dict[pn] for pn in sorted(decay)], - "weight_decay": weight_decay, - }, - { - "params": [param_dict[pn] for pn in sorted(no_decay)], - "weight_decay": 0.0, - }, - ] - optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **kwargs) - return optimizer - - -@DeveloperAPI -class GPT(nn.Module): - """GPT Transformer Model""" - - def __init__(self, config: GPTConfig): - super().__init__() - assert config.block_size is not None - self.block_size = config.block_size - - self.transformer = nn.ModuleDict( - dict( - drop=nn.Dropout(config.embed_pdrop), - h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), - ln_f=nn.LayerNorm(config.n_embed), - ) - ) - - # init all weights, and apply a special scaled init to the residual - # projections, per GPT-2 paper - self.apply(self._init_weights) - for pn, p in self.named_parameters(): - if pn.endswith("c_proj.weight"): - torch.nn.init.normal_( - p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) - ) - - def _init_weights(self, module): - if isinstance(module, nn.Linear): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - if module.bias is not None: - torch.nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - elif isinstance(module, nn.LayerNorm): - torch.nn.init.zeros_(module.bias) - torch.nn.init.ones_(module.weight) - - def forward(self, input_embeds, attention_masks=None, return_attentions=False): - """ - input_embeds: [batch_size x seq_len x n_embed] - attention_masks: [batch_size x seq_len], 0 don't attend, 1 attend - """ - B, T, C = input_embeds.size() - assert T <= self.block_size, ( - f"Cannot forward sequence of length {T}, " - f"block size is only {self.block_size}" - ) - - if attention_masks is not None: - _B, _T = attention_masks.size() - assert _B == B and _T == T - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_len] - # So we can broadcast to - # [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular - # masking of causal attention used in OpenAI GPT, we just need - # to prepare the broadcast dimension here. - attention_masks = attention_masks[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend - # and 0.0 for masked positions, this operation will create a - # tensor which is 0.0 for positions we want to attend and -inf - # for masked positions. Since we are adding it to the raw scores - # before the softmax, this is effectively the same as removing - # these entirely. 
- attention_masks = attention_masks.to(dtype=input_embeds.dtype) - attention_masks = (1.0 - attention_masks) * -1e9 - - # forward the GPT model itself - x = self.transformer.drop(input_embeds) - - atts = [] - for block in self.transformer.h: - x, att = block(x, attention_masks=attention_masks) - atts.append(att) - x = self.transformer.ln_f(x) - - if return_attentions: - return x, atts - else: - return x diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py deleted file mode 100644 index 29a02e365b31..000000000000 --- a/rllib/models/torch/misc.py +++ /dev/null @@ -1,195 +0,0 @@ -""" Code adapted from https://github.com/ikostrikov/pytorch-a3c""" -import numpy as np -from typing import Union, Tuple, Any, List - -from ray.rllib.models.utils import get_activation_fn -from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType - -torch, nn = try_import_torch() - - -@DeveloperAPI -def normc_initializer(std: float = 1.0) -> Any: - def initializer(tensor): - tensor.data.normal_(0, 1) - tensor.data *= std / torch.sqrt(tensor.data.pow(2).sum(1, keepdim=True)) - - return initializer - - -@DeveloperAPI -def same_padding( - in_size: Tuple[int, int], - filter_size: Tuple[int, int], - stride_size: Union[int, Tuple[int, int]], -) -> (Union[int, Tuple[int, int]], Tuple[int, int]): - """Note: Padding is added to match TF conv2d `same` padding. See - www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution - - Args: - in_size: Rows (Height), Column (Width) for input - stride_size (Union[int,Tuple[int, int]]): Rows (Height), column (Width) - for stride. If int, height == width. - filter_size: Rows (Height), column (Width) for filter - - Returns: - padding: For input into torch.nn.ZeroPad2d. - output: Output shape after padding and convolution. - """ - in_height, in_width = in_size - if isinstance(filter_size, int): - filter_height, filter_width = filter_size, filter_size - else: - filter_height, filter_width = filter_size - if isinstance(stride_size, (int, float)): - stride_height, stride_width = int(stride_size), int(stride_size) - else: - stride_height, stride_width = int(stride_size[0]), int(stride_size[1]) - - out_height = np.ceil(float(in_height) / float(stride_height)) - out_width = np.ceil(float(in_width) / float(stride_width)) - - pad_along_height = int( - ((out_height - 1) * stride_height + filter_height - in_height) - ) - pad_along_width = int(((out_width - 1) * stride_width + filter_width - in_width)) - pad_top = pad_along_height // 2 - pad_bottom = pad_along_height - pad_top - pad_left = pad_along_width // 2 - pad_right = pad_along_width - pad_left - padding = (pad_left, pad_right, pad_top, pad_bottom) - output = (out_height, out_width) - return padding, output - - -@DeveloperAPI -class SlimConv2d(nn.Module): - """Simple mock of tf.slim Conv2d""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]], - padding: Union[int, Tuple[int, int]], - # Defaulting these to nn.[..] will break soft torch import. - initializer: Any = "default", - activation_fn: Any = "default", - bias_init: float = 0, - ): - """Creates a standard Conv2d layer, similar to torch.nn.Conv2d - - Args: - in_channels: Number of input channels - out_channels: Number of output channels - kernel: If int, the kernel is - a tuple(x,x). Elsewise, the tuple can be specified - stride: Controls the stride - for the cross-correlation. 
If int, the stride is a - tuple(x,x). Elsewise, the tuple can be specified - padding: Controls the amount - of implicit zero-paddings during the conv operation - initializer: Initializer function for kernel weights - activation_fn: Activation function at the end of layer - bias_init: Initalize bias weights to bias_init const - """ - super(SlimConv2d, self).__init__() - layers = [] - # Padding layer. - if padding: - layers.append(nn.ZeroPad2d(padding)) - # Actual Conv2D layer (including correct initialization logic). - conv = nn.Conv2d(in_channels, out_channels, kernel, stride) - if initializer: - if initializer == "default": - initializer = nn.init.xavier_uniform_ - initializer(conv.weight) - nn.init.constant_(conv.bias, bias_init) - layers.append(conv) - # Activation function (if any; default=ReLu). - if isinstance(activation_fn, str): - if activation_fn == "default": - activation_fn = nn.ReLU - else: - activation_fn = get_activation_fn(activation_fn, "torch") - if activation_fn is not None: - layers.append(activation_fn()) - # Put everything in sequence. - self._model = nn.Sequential(*layers) - - def forward(self, x: TensorType) -> TensorType: - return self._model(x) - - -@DeveloperAPI -class SlimFC(nn.Module): - """Simple PyTorch version of `linear` function""" - - def __init__( - self, - in_size: int, - out_size: int, - initializer: Any = None, - activation_fn: Any = None, - use_bias: bool = True, - bias_init: float = 0.0, - ): - """Creates a standard FC layer, similar to torch.nn.Linear - - Args: - in_size: Input size for FC Layer - out_size: Output size for FC Layer - initializer: Initializer function for FC layer weights - activation_fn: Activation function at the end of layer - use_bias: Whether to add bias weights or not - bias_init: Initalize bias weights to bias_init const - """ - super(SlimFC, self).__init__() - layers = [] - # Actual nn.Linear layer (including correct initialization logic). - linear = nn.Linear(in_size, out_size, bias=use_bias) - if initializer is None: - initializer = nn.init.xavier_uniform_ - initializer(linear.weight) - if use_bias is True: - nn.init.constant_(linear.bias, bias_init) - layers.append(linear) - # Activation function (if any; default=None (linear)). - if isinstance(activation_fn, str): - activation_fn = get_activation_fn(activation_fn, "torch") - if activation_fn is not None: - layers.append(activation_fn()) - # Put everything in sequence. 
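# same_padding() above reproduces TF "same" padding: it computes the asymmetric
# (left, right, top, bottom) zero-padding needed so a spatial map of size in_size
# comes out as ceil(in_size / stride), which SlimConv2d then applies via
# nn.ZeroPad2d. A worked example for an 84x84 input with an 8x8 filter and
# stride 4 (a common Atari first layer), per axis:
import math

in_size, filt, stride = 84, 8, 4
out = math.ceil(in_size / stride)                         # 21
pad_along = (out - 1) * stride + filt - in_size           # (21-1)*4 + 8 - 84 = 4
padding = (pad_along // 2, pad_along - pad_along // 2)    # (before, after) per axis
assert (out, pad_along, padding) == (21, 4, (2, 2))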
- self._model = nn.Sequential(*layers) - - def forward(self, x: TensorType) -> TensorType: - return self._model(x) - - -@DeveloperAPI -class AppendBiasLayer(nn.Module): - """Simple bias appending layer for free_log_std.""" - - def __init__(self, num_bias_vars: int): - super().__init__() - self.log_std = torch.nn.Parameter(torch.as_tensor([0.0] * num_bias_vars)) - self.register_parameter("log_std", self.log_std) - - def forward(self, x: TensorType) -> TensorType: - out = torch.cat([x, self.log_std.unsqueeze(0).repeat([len(x), 1])], axis=1) - return out - - -@DeveloperAPI -class Reshape(nn.Module): - """Standard module that reshapes/views a tensor""" - - def __init__(self, shape: List): - super().__init__() - self.shape = shape - - def forward(self, x): - return x.view(*self.shape) diff --git a/rllib/models/torch/model.py b/rllib/models/torch/model.py deleted file mode 100644 index 30f341d63238..000000000000 --- a/rllib/models/torch/model.py +++ /dev/null @@ -1,220 +0,0 @@ -import torch -from torch import nn -import tree - -from ray.rllib.utils.annotations import ( - DeveloperAPI, - override, -) -from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig -from ray.rllib.models.base_model import RecurrentModel, Model, ModelIO - - -class TorchModelIO(ModelIO): - """Save/Load mixin for torch models - - Examples: - >>> model.save("/tmp/model_weights.cpt") - >>> model.load("/tmp/model_weights.cpt") - """ - - @DeveloperAPI - @override(ModelIO) - def save(self, path: str) -> None: - """Saves the state dict to the specified path - - Args: - path: Path on disk the checkpoint is saved to - - """ - torch.save(self.state_dict(), path) - - @DeveloperAPI - @override(ModelIO) - def load(self, path: str) -> RecurrentModel: - """Loads the state dict from the specified path - - Args: - path: Path on disk to load the checkpoint from - """ - self.load_state_dict(torch.load(path)) - - -class TorchRecurrentModel(RecurrentModel, nn.Module, TorchModelIO): - """The base class for recurrent pytorch models. - - If implementing a custom recurrent model, you likely want to inherit - from this model. You should make sure to call super().__init__(config) - in your __init__. 
- - Args: - config: The config used to construct the model - - Required Attributes: - input_spec: SpecDict: Denotes the input keys and shapes passed to `unroll` - output_spec: SpecDict: Denotes the output keys and shapes returned from - `unroll` - prev_state_spec: SpecDict: Denotes the keys and shapes for the input - recurrent states to the model - next_state_spec: SpecDict: Denotes the keys and shapes for the - recurrent states output by the model - - Required Overrides: - # Define unrolling (forward pass) over a sequence of inputs - _unroll(self, inputs: TensorDict, prev_state: TensorDict, **kwargs) - -> Tuple[TensorDict, TensorDict] - - Optional Overrides: - # Define the initial state, if a zero tensor is insufficient - # the returned TensorDict must match the prev_state_spec - _initial_state(self) -> TensorDict - - # Additional checks on the input and recurrent state before `_unroll` - _update_inputs_and_prev_state(inputs: TensorDict, prev_state: TensorDict) - -> Tuple[TensorDict, TensorDict] - - # Additional checks on the output and the output recurrent state - # after `_unroll` - _update_outputs_and_next_state(outputs: TensorDict, next_state: TensorDict) - -> Tuple[TensorDict, TensorDict] - - # Save model weights to path - save(self, path: str) -> None - - # Load model weights from path - load(self, path: str) -> None - - Examples: - >>> class MyCustomModel(TorchRecurrentModel): - ... def __init__(self, config): - ... super().__init__(config) - ... - ... self.lstm = nn.LSTM( - ... input_size, recurrent_size, batch_first=True - ... ) - ... self.project = nn.Linear(recurrent_size, output_size) - ... - ... @property - ... def input_spec(self): - ... return SpecDict( - ... {"obs": "batch time hidden"}, hidden=self.config.input_size - ... ) - ... - ... @property - ... def output_spec(self): - ... return SpecDict( - ... {"logits": "batch time logits"}, logits=self.config.output_size - ... ) - ... - ... @property - ... def prev_state_spec(self): - ... return SpecDict( - ... {"input_state": "batch recur"}, recur=self.config.recurrent_size - ... ) - ... - ... @property - ... def next_state_spec(self): - ... return SpecDict( - ... {"output_state": "batch recur"}, - ... recur=self.config.recurrent_size - ... ) - ... - ... def _unroll(self, inputs, prev_state, **kwargs): - ... output, state = self.lstm(inputs["obs"], prev_state["input_state"]) - ... output = self.project(output) - ... return TensorDict( - ... {"logits": output}), TensorDict({"output_state": state} - ... ) - - """ - - def __init__(self, config: ModelConfig) -> None: - RecurrentModel.__init__(self) - nn.Module.__init__(self) - TorchModelIO.__init__(self, config) - - @override(RecurrentModel) - def _initial_state(self) -> TensorDict: - """Returns the initial recurrent state - - This defaults to all zeros and can be overidden to return - nonzero tensors. - - Returns: - A TensorDict that matches the initial_state_spec - """ - return TensorDict( - tree.map_structure( - lambda spec: torch.zeros(spec.shape, dtype=spec.dtype), - self.initial_state_spec, - ) - ) - - -class TorchModel(Model, nn.Module, TorchModelIO): - """The base class for non-recurrent pytorch models. - - If implementing a custom pytorch model, you likely want to - inherit from this class. You should make sure to call super().__init__(config) - in your __init__. 
- - Args: - config: The config used to construct the model - - Required Attributes: - input_spec: SpecDict: Denotes the input keys and shapes passed to `_forward` - output_spec: SpecDict: Denotes the output keys and shapes returned from - `_forward` - - Required Overrides: - # Define unrolling (forward pass) over a sequence of inputs - _forward(self, inputs: TensorDict, **kwargs) - -> TensorDict - - Optional Overrides: - # Additional checks on the input before `_forward` - _update_inputs(inputs: TensorDict) -> TensorDict - - # Additional checks on the output after `_forward` - _update_outputs(outputs: TensorDict) -> TensorDict - - # Save model weights to path - save(self, path: str) -> None - - # Load model weights from path - load(self, path: str) -> None - - Examples: - >>> class MyCustomModel(TorchModel): - ... def __init__(self, config): - ... super().__init__(config) - ... self.mlp = nn.Sequential( - ... nn.Linear(input_size, hidden_size), - ... nn.ReLU(), - ... nn.Linear(hidden_size, hidden_size), - ... nn.ReLU(), - ... nn.Linear(hidden_size, output_size) - ... ) - ... - ... @property - ... def input_spec(self): - ... return SpecDict( - ... {"obs": "batch time hidden"}, hidden=self.config.input_size - ... ) - ... - ... @property - ... def output_spec(self): - ... return SpecDict( - ... {"logits": "batch time logits"}, logits=self.config.output_size - ... ) - ... - ... def _forward(self, inputs, **kwargs): - ... output = self.mlp(inputs["obs"]) - ... return TensorDict({"logits": output}) - - """ - - def __init__(self, config: ModelConfig) -> None: - Model.__init__(self) - nn.Module.__init__(self) - TorchModelIO.__init__(self, config) diff --git a/rllib/models/torch/modules/__init__.py b/rllib/models/torch/modules/__init__.py deleted file mode 100644 index 2585dcc77abe..000000000000 --- a/rllib/models/torch/modules/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from ray.rllib.models.torch.modules.gru_gate import GRUGate -from ray.rllib.models.torch.modules.multi_head_attention import MultiHeadAttention -from ray.rllib.models.torch.modules.relative_multi_head_attention import ( - RelativeMultiHeadAttention, -) -from ray.rllib.models.torch.modules.skip_connection import SkipConnection - -__all__ = [ - "GRUGate", - "RelativeMultiHeadAttention", - "SkipConnection", - "MultiHeadAttention", -] diff --git a/rllib/models/torch/modules/convtranspose2d_stack.py b/rllib/models/torch/modules/convtranspose2d_stack.py deleted file mode 100644 index f991400d3df0..000000000000 --- a/rllib/models/torch/modules/convtranspose2d_stack.py +++ /dev/null @@ -1,82 +0,0 @@ -from typing import Tuple - -from ray.rllib.models.torch.misc import Reshape -from ray.rllib.models.utils import get_activation_fn, get_initializer -from ray.rllib.utils.framework import try_import_torch - -torch, nn = try_import_torch() -if torch: - import torch.distributions as td - - -class ConvTranspose2DStack(nn.Module): - """ConvTranspose2D decoder generating an image distribution from a vector.""" - - def __init__( - self, - *, - input_size: int, - filters: Tuple[Tuple[int]] = ( - (1024, 5, 2), - (128, 5, 2), - (64, 6, 2), - (32, 6, 2), - ), - initializer="default", - bias_init=0, - activation_fn: str = "relu", - output_shape: Tuple[int] = (3, 64, 64) - ): - """Initializes a TransposedConv2DStack instance. - - Args: - input_size: The size of the 1D input vector, from which to - generate the image distribution. - filters (Tuple[Tuple[int]]): Tuple of filter setups (1 for each - ConvTranspose2D layer): [in_channels, kernel, stride]. 
- initializer (Union[str]): - bias_init: The initial bias values to use. - activation_fn: Activation function descriptor (str). - output_shape (Tuple[int]): Shape of the final output image. - """ - super().__init__() - self.activation = get_activation_fn(activation_fn, framework="torch") - self.output_shape = output_shape - initializer = get_initializer(initializer, framework="torch") - - in_channels = filters[0][0] - self.layers = [ - # Map from 1D-input vector to correct initial size for the - # Conv2DTransposed stack. - nn.Linear(input_size, in_channels), - # Reshape from the incoming 1D vector (input_size) to 1x1 image - # format (channels first). - Reshape([-1, in_channels, 1, 1]), - ] - for i, (_, kernel, stride) in enumerate(filters): - out_channels = ( - filters[i + 1][0] if i < len(filters) - 1 else output_shape[0] - ) - conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride) - # Apply initializer. - initializer(conv_transp.weight) - nn.init.constant_(conv_transp.bias, bias_init) - self.layers.append(conv_transp) - # Apply activation function, if provided and if not last layer. - if self.activation is not None and i < len(filters) - 1: - self.layers.append(self.activation()) - - # num-outputs == num-inputs for next layer. - in_channels = out_channels - - self._model = nn.Sequential(*self.layers) - - def forward(self, x): - # x is [batch, hor_length, input_size] - batch_dims = x.shape[:-1] - model_out = self._model(x) - - # Equivalent to making a multivariate diag. - reshape_size = batch_dims + self.output_shape - mean = model_out.view(*reshape_size) - return td.Independent(td.Normal(mean, 1.0), len(self.output_shape)) diff --git a/rllib/models/torch/modules/gru_gate.py b/rllib/models/torch/modules/gru_gate.py deleted file mode 100644 index 7eee53534d6d..000000000000 --- a/rllib/models/torch/modules/gru_gate.py +++ /dev/null @@ -1,65 +0,0 @@ -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.framework import TensorType - -torch, nn = try_import_torch() - - -class GRUGate(nn.Module): - """Implements a gated recurrent unit for use in AttentionNet""" - - def __init__(self, dim: int, init_bias: int = 0.0, **kwargs): - """ - input_shape (torch.Tensor): dimension of the input - init_bias: Bias added to every input to stabilize training - """ - super().__init__(**kwargs) - # Xavier initialization of torch tensors - self._w_r = nn.Parameter(torch.zeros(dim, dim)) - self._w_z = nn.Parameter(torch.zeros(dim, dim)) - self._w_h = nn.Parameter(torch.zeros(dim, dim)) - nn.init.xavier_uniform_(self._w_r) - nn.init.xavier_uniform_(self._w_z) - nn.init.xavier_uniform_(self._w_h) - self.register_parameter("_w_r", self._w_r) - self.register_parameter("_w_z", self._w_z) - self.register_parameter("_w_h", self._w_h) - - self._u_r = nn.Parameter(torch.zeros(dim, dim)) - self._u_z = nn.Parameter(torch.zeros(dim, dim)) - self._u_h = nn.Parameter(torch.zeros(dim, dim)) - nn.init.xavier_uniform_(self._u_r) - nn.init.xavier_uniform_(self._u_z) - nn.init.xavier_uniform_(self._u_h) - self.register_parameter("_u_r", self._u_r) - self.register_parameter("_u_z", self._u_z) - self.register_parameter("_u_h", self._u_h) - - self._bias_z = nn.Parameter( - torch.zeros( - dim, - ).fill_(init_bias) - ) - self.register_parameter("_bias_z", self._bias_z) - - def forward(self, inputs: TensorType, **kwargs) -> TensorType: - # Pass in internal state first. 
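# GRUGate above implements the gating used in GTrXL: given the skip input h and
# the transformed input X, it computes reset/update gates and blends h with a
# candidate state. A sketch in plain tensor form, with random weights and a
# hypothetical dim of 4:
import torch

dim = 4
h, X = torch.randn(2, dim), torch.randn(2, dim)
W_r, W_z, W_h = (torch.randn(dim, dim) for _ in range(3))
U_r, U_z, U_h = (torch.randn(dim, dim) for _ in range(3))
bias_z = torch.full((dim,), 2.0)                     # init_gru_gate_bias

r = torch.sigmoid(X @ W_r + h @ U_r)                 # reset gate
z = torch.sigmoid(X @ W_z + h @ U_z - bias_z)        # update gate (bias keeps z small early on)
h_candidate = torch.tanh(X @ W_h + (h * r) @ U_h)
out = (1 - z) * h + z * h_candidate                  # gated skip connection
assert out.shape == h.shape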
- h, X = inputs - - r = torch.tensordot(X, self._w_r, dims=1) + torch.tensordot( - h, self._u_r, dims=1 - ) - r = torch.sigmoid(r) - - z = ( - torch.tensordot(X, self._w_z, dims=1) - + torch.tensordot(h, self._u_z, dims=1) - - self._bias_z - ) - z = torch.sigmoid(z) - - h_next = torch.tensordot(X, self._w_h, dims=1) + torch.tensordot( - (h * r), self._u_h, dims=1 - ) - h_next = torch.tanh(h_next) - - return (1 - z) * h + z * h_next diff --git a/rllib/models/torch/modules/multi_head_attention.py b/rllib/models/torch/modules/multi_head_attention.py deleted file mode 100644 index 68413bde025b..000000000000 --- a/rllib/models/torch/modules/multi_head_attention.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar, - Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017. - https://arxiv.org/pdf/1706.03762.pdf -""" -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.models.torch.misc import SlimFC -from ray.rllib.utils.torch_utils import sequence_mask -from ray.rllib.utils.framework import TensorType - -torch, nn = try_import_torch() - - -class MultiHeadAttention(nn.Module): - """A multi-head attention layer described in [1].""" - - def __init__( - self, in_dim: int, out_dim: int, num_heads: int, head_dim: int, **kwargs - ): - """ - in_dim: Dimension of input - out_dim: Dimension of output - num_heads: Number of attention heads - head_dim: Output dimension of each attention head - """ - super().__init__(**kwargs) - - # No bias or non-linearity. - self._num_heads = num_heads - self._head_dim = head_dim - self._qkv_layer = SlimFC( - in_size=in_dim, out_size=3 * num_heads * head_dim, use_bias=False - ) - - self._linear_layer = SlimFC( - in_size=num_heads * head_dim, out_size=out_dim, use_bias=False - ) - - def forward(self, inputs: TensorType) -> TensorType: - L = list(inputs.size())[1] # length of segment - H = self._num_heads # number of attention heads - D = self._head_dim # attention head dimension - - qkv = self._qkv_layer(inputs) - - queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1) - queries = queries[:, -L:] # only query based on the segment - - queries = torch.reshape(queries, [-1, L, H, D]) - keys = torch.reshape(keys, [-1, L, H, D]) - values = torch.reshape(values, [-1, L, H, D]) - - score = torch.einsum("bihd,bjhd->bijh", queries, keys) - score = score / D**0.5 - - # causal mask of the same length as the sequence - mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype) - mask = mask[None, :, :, None] - mask = mask.float() - - masked_score = score * mask + 1e30 * (mask - 1.0) - wmat = nn.functional.softmax(masked_score, dim=2) - - out = torch.einsum("bijh,bjhd->bihd", wmat, values) - shape = list(out.size())[:2] + [H * D] - # temp = torch.cat(temp2, [H * D], dim=0) - out = torch.reshape(out, shape) - return self._linear_layer(out) diff --git a/rllib/models/torch/modules/noisy_layer.py b/rllib/models/torch/modules/noisy_layer.py deleted file mode 100644 index 8a9fe999cf79..000000000000 --- a/rllib/models/torch/modules/noisy_layer.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np - -from ray.rllib.models.utils import get_activation_fn -from ray.rllib.utils.framework import try_import_torch, TensorType - -torch, nn = try_import_torch() - - -class NoisyLayer(nn.Module): - r"""A Layer that adds learnable Noise to some previous layer's outputs. 
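# MultiHeadAttention above computes per-head scaled dot-product attention with a
# causal mask via two einsums. The core of that computation, for hypothetical
# sizes (batch B=2, length L=5, heads H=2, head dim D=3):
import torch

B, L, H, D = 2, 5, 2, 3
queries, keys, values = (torch.randn(B, L, H, D) for _ in range(3))

score = torch.einsum("bihd,bjhd->bijh", queries, keys) / D ** 0.5
causal = torch.tril(torch.ones(L, L))[None, :, :, None]    # 1 where key j <= query i
masked_score = score * causal + 1e30 * (causal - 1.0)      # large negative where masked
weights = torch.softmax(masked_score, dim=2)               # normalize over keys j
out = torch.einsum("bijh,bjhd->bihd", weights, values).reshape(B, L, H * D)
assert out.shape == (B, L, H * D)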
- - Consists of: - - a common dense layer: y = w^{T}x + b - - a noisy layer: y = (w + \epsilon_w*\sigma_w)^{T}x + - (b+\epsilon_b*\sigma_b) - , where \epsilon are random variables sampled from factorized normal - distributions and \sigma are trainable variables which are expected to - vanish along the training procedure. - """ - - def __init__( - self, in_size: int, out_size: int, sigma0: float, activation: str = "relu" - ): - """Initializes a NoisyLayer object. - - Args: - in_size: Input size for Noisy Layer - out_size: Output size for Noisy Layer - sigma0: Initialization value for sigma_b (bias noise) - activation: Non-linear activation for Noisy Layer - """ - super().__init__() - - self.in_size = in_size - self.out_size = out_size - self.sigma0 = sigma0 - self.activation = get_activation_fn(activation, framework="torch") - if self.activation is not None: - self.activation = self.activation() - - sigma_w = nn.Parameter( - torch.from_numpy( - np.random.uniform( - low=-1.0 / np.sqrt(float(self.in_size)), - high=1.0 / np.sqrt(float(self.in_size)), - size=[self.in_size, out_size], - ) - ).float() - ) - self.register_parameter("sigma_w", sigma_w) - sigma_b = nn.Parameter( - torch.from_numpy( - np.full( - shape=[out_size], fill_value=sigma0 / np.sqrt(float(self.in_size)) - ) - ).float() - ) - self.register_parameter("sigma_b", sigma_b) - - w = nn.Parameter( - torch.from_numpy( - np.full( - shape=[self.in_size, self.out_size], - fill_value=6 / np.sqrt(float(in_size) + float(out_size)), - ) - ).float() - ) - self.register_parameter("w", w) - b = nn.Parameter(torch.from_numpy(np.zeros([out_size])).float()) - self.register_parameter("b", b) - - def forward(self, inputs: TensorType) -> TensorType: - epsilon_in = self._f_epsilon( - torch.normal( - mean=torch.zeros([self.in_size]), std=torch.ones([self.in_size]) - ).to(inputs.device) - ) - epsilon_out = self._f_epsilon( - torch.normal( - mean=torch.zeros([self.out_size]), std=torch.ones([self.out_size]) - ).to(inputs.device) - ) - epsilon_w = torch.matmul( - torch.unsqueeze(epsilon_in, -1), other=torch.unsqueeze(epsilon_out, 0) - ) - epsilon_b = epsilon_out - - action_activation = ( - torch.matmul(inputs, self.w + self.sigma_w * epsilon_w) - + self.b - + self.sigma_b * epsilon_b - ) - - if self.activation is not None: - action_activation = self.activation(action_activation) - return action_activation - - def _f_epsilon(self, x: TensorType) -> TensorType: - return torch.sign(x) * torch.pow(torch.abs(x), 0.5) diff --git a/rllib/models/torch/modules/relative_multi_head_attention.py b/rllib/models/torch/modules/relative_multi_head_attention.py deleted file mode 100644 index d3ff9cf59eee..000000000000 --- a/rllib/models/torch/modules/relative_multi_head_attention.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import Union - -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.models.torch.misc import SlimFC -from ray.rllib.utils.torch_utils import sequence_mask -from ray.rllib.utils.typing import TensorType - -torch, nn = try_import_torch() - - -class RelativePositionEmbedding(nn.Module): - """Creates a [seq_length x seq_length] matrix for rel. pos encoding. - - Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding - matrix. - - Args: - seq_length: The max. sequence length (time axis). - out_dim: The number of nodes to go into the first Tranformer - layer with. - - Returns: - torch.Tensor: The encoding matrix Phi. 
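A standalone sketch of the factorized noise that NoisyLayer assembles in its forward pass (sign(x) * sqrt(|x|) per axis, outer product for the weight noise); the layer sizes and sigma values below are illustrative assumptions.

import torch

def f(x):                                                    # f(x) = sign(x) * sqrt(|x|)
    return torch.sign(x) * torch.sqrt(torch.abs(x))

in_size, out_size = 4, 3
eps_in, eps_out = f(torch.randn(in_size)), f(torch.randn(out_size))
eps_w = eps_in.unsqueeze(-1) @ eps_out.unsqueeze(0)          # [in, out] factorized weight noise
eps_b = eps_out                                              # bias noise

w = torch.zeros(in_size, out_size)                           # stand-ins for the learned parameters
b = torch.zeros(out_size)
sigma_w = torch.full((in_size, out_size), 0.1)
sigma_b = torch.full((out_size,), 0.1)

x = torch.randn(2, in_size)
y = x @ (w + sigma_w * eps_w) + b + sigma_b * eps_b
print(y.shape)                                               # torch.Size([2, 3])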
- """ - - def __init__(self, out_dim, **kwargs): - super().__init__() - self.out_dim = out_dim - - out_range = torch.arange(0, self.out_dim, 2.0) - inverse_freq = 1 / (10000 ** (out_range / self.out_dim)) - self.register_buffer("inverse_freq", inverse_freq) - - def forward(self, seq_length): - pos_input = torch.arange(seq_length - 1, -1, -1.0, dtype=torch.float).to( - self.inverse_freq.device - ) - sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq) - pos_embeddings = torch.cat( - [torch.sin(sinusoid_input), torch.cos(sinusoid_input)], dim=-1 - ) - return pos_embeddings[:, None, :] - - -class RelativeMultiHeadAttention(nn.Module): - """A RelativeMultiHeadAttention layer as described in [3]. - - Uses segment level recurrence with state reuse. - """ - - def __init__( - self, - in_dim: int, - out_dim: int, - num_heads: int, - head_dim: int, - input_layernorm: bool = False, - output_activation: Union[str, callable] = None, - **kwargs - ): - """Initializes a RelativeMultiHeadAttention nn.Module object. - - Args: - in_dim (int): - out_dim: The output dimension of this module. Also known as - "attention dim". - num_heads: The number of attention heads to use. - Denoted `H` in [2]. - head_dim: The dimension of a single(!) attention head - Denoted `D` in [2]. - input_layernorm: Whether to prepend a LayerNorm before - everything else. Should be True for building a GTrXL. - output_activation (Union[str, callable]): Optional activation - function or activation function specifier (str). - Should be "relu" for GTrXL. - **kwargs: - """ - super().__init__(**kwargs) - - # No bias or non-linearity. - self._num_heads = num_heads - self._head_dim = head_dim - - # 3=Query, key, and value inputs. - self._qkv_layer = SlimFC( - in_size=in_dim, out_size=3 * num_heads * head_dim, use_bias=False - ) - - self._linear_layer = SlimFC( - in_size=num_heads * head_dim, - out_size=out_dim, - use_bias=False, - activation_fn=output_activation, - ) - - self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim)) - self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim)) - nn.init.xavier_uniform_(self._uvar) - nn.init.xavier_uniform_(self._vvar) - self.register_parameter("_uvar", self._uvar) - self.register_parameter("_vvar", self._vvar) - - self._pos_proj = SlimFC( - in_size=in_dim, out_size=num_heads * head_dim, use_bias=False - ) - self._rel_pos_embedding = RelativePositionEmbedding(out_dim) - - self._input_layernorm = None - if input_layernorm: - self._input_layernorm = torch.nn.LayerNorm(in_dim) - - def forward(self, inputs: TensorType, memory: TensorType = None) -> TensorType: - T = list(inputs.size())[1] # length of segment (time) - H = self._num_heads # number of attention heads - d = self._head_dim # attention head dimension - - # Add previous memory chunk (as const, w/o gradient) to input. - # Tau (number of (prev) time slices in each memory chunk). - Tau = list(memory.shape)[1] - inputs = torch.cat((memory.detach(), inputs), dim=1) - - # Apply the Layer-Norm. - if self._input_layernorm is not None: - inputs = self._input_layernorm(inputs) - - qkv = self._qkv_layer(inputs) - - queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1) - # Cut out Tau memory timesteps from query. 
- queries = queries[:, -T:] - - queries = torch.reshape(queries, [-1, T, H, d]) - keys = torch.reshape(keys, [-1, Tau + T, H, d]) - values = torch.reshape(values, [-1, Tau + T, H, d]) - - R = self._pos_proj(self._rel_pos_embedding(Tau + T)) - R = torch.reshape(R, [Tau + T, H, d]) - - # b=batch - # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) - # h=head - # d=head-dim (over which we will reduce-sum) - score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys) - pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R) - score = score + self.rel_shift(pos_score) - score = score / d**0.5 - - # causal mask of the same length as the sequence - mask = sequence_mask(torch.arange(Tau + 1, Tau + T + 1), dtype=score.dtype).to( - score.device - ) - mask = mask[None, :, :, None] - - masked_score = score * mask + 1e30 * (mask.float() - 1.0) - wmat = nn.functional.softmax(masked_score, dim=2) - - out = torch.einsum("bijh,bjhd->bihd", wmat, values) - shape = list(out.shape)[:2] + [H * d] - out = torch.reshape(out, shape) - - return self._linear_layer(out) - - @staticmethod - def rel_shift(x: TensorType) -> TensorType: - # Transposed version of the shift approach described in [3]. - # https://github.com/kimiyoung/transformer-xl/blob/ - # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31 - x_size = list(x.shape) - - x = torch.nn.functional.pad(x, (0, 0, 1, 0, 0, 0, 0, 0)) - x = torch.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]]) - x = x[:, 1:, :, :] - x = torch.reshape(x, x_size) - - return x diff --git a/rllib/models/torch/modules/skip_connection.py b/rllib/models/torch/modules/skip_connection.py deleted file mode 100644 index 8bc155eda9ca..000000000000 --- a/rllib/models/torch/modules/skip_connection.py +++ /dev/null @@ -1,41 +0,0 @@ -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType -from typing import Optional - -torch, nn = try_import_torch() - - -class SkipConnection(nn.Module): - """Skip connection layer. - - Adds the original input to the output (regular residual layer) OR uses - input as hidden state input to a given fan_in_layer. - """ - - def __init__( - self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, **kwargs - ): - """Initializes a SkipConnection nn Module object. - - Args: - layer (nn.Module): Any layer processing inputs. - fan_in_layer (Optional[nn.Module]): An optional - layer taking two inputs: The original input and the output - of `layer`. - """ - super().__init__(**kwargs) - self._layer = layer - self._fan_in_layer = fan_in_layer - - def forward(self, inputs: TensorType, **kwargs) -> TensorType: - # del kwargs - outputs = self._layer(inputs, **kwargs) - # Residual case, just add inputs to outputs. - if self._fan_in_layer is None: - outputs = outputs + inputs - # Fan-in e.g. RNN: Call fan-in with `inputs` and `outputs`. - else: - # NOTE: In the GRU case, `inputs` is the state input. - outputs = self._fan_in_layer((inputs, outputs)) - - return outputs diff --git a/rllib/models/torch/noop.py b/rllib/models/torch/noop.py deleted file mode 100644 index 8b0705b11874..000000000000 --- a/rllib/models/torch/noop.py +++ /dev/null @@ -1,13 +0,0 @@ -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils.annotations import override - - -class TorchNoopModel(TorchModelV2): - """Trivial model that just returns the obs flattened. 
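A standalone sketch of the two SkipConnection modes above: the plain residual path, and the fan-in path that hands (original input, layer output) to a second module (the GRU gate above can serve as such a fan-in layer); module shapes are assumptions.

import torch
import torch.nn as nn

dim = 8
layer = nn.Linear(dim, dim)
x = torch.randn(2, dim)

residual_out = layer(x) + x                                  # fan_in_layer is None: add input back

fan_in = nn.Linear(2 * dim, dim)                             # toy fan-in module combining both
fan_in_out = fan_in(torch.cat([x, layer(x)], dim=-1))
print(residual_out.shape, fan_in_out.shape)                  # both torch.Size([2, 8])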
- - This is the model used if use_state_preprocessor=False.""" - - @override(ModelV2) - def forward(self, input_dict, state, seq_lens): - return input_dict["obs_flat"].float(), state diff --git a/rllib/models/torch/primitives.py b/rllib/models/torch/primitives.py deleted file mode 100644 index 191a0ff35e5a..000000000000 --- a/rllib/models/torch/primitives.py +++ /dev/null @@ -1,54 +0,0 @@ -from typing import List, Optional -from ray.rllib.utils.framework import try_import_torch - -torch, nn = try_import_torch() - -# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done. - - -class FCNet(nn.Module): - """A simple fully connected network. - - Attributes: - input_dim: The input dimension of the network. It cannot be None. - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer. - """ - - def __init__( - self, - input_dim: int, - hidden_layers: List[int], - output_dim: Optional[int] = None, - activation: str = "linear", - ): - super().__init__() - self.input_dim = input_dim - self.hidden_layers = hidden_layers - - activation_class = getattr(nn, activation, lambda: None)() - self.layers = [] - self.layers.append(nn.Linear(self.input_dim, self.hidden_layers[0])) - for i in range(len(self.hidden_layers) - 1): - if activation != "linear": - self.layers.append(activation_class) - self.layers.append( - nn.Linear(self.hidden_layers[i], self.hidden_layers[i + 1]) - ) - - if output_dim is not None: - if activation != "linear": - self.layers.append(activation_class) - self.layers.append(nn.Linear(self.hidden_layers[-1], output_dim)) - - if output_dim is None: - self.output_dim = hidden_layers[-1] - else: - self.output_dim = output_dim - - self.layers = nn.Sequential(*self.layers) - - def forward(self, x): - return self.layers(x) diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py deleted file mode 100644 index ec3f7b3b797c..000000000000 --- a/rllib/models/torch/recurrent_net.py +++ /dev/null @@ -1,285 +0,0 @@ -import numpy as np -import gymnasium as gym -from gymnasium.spaces import Discrete, MultiDiscrete -import tree # pip install dm_tree -from typing import Dict, List, Union, Tuple - -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.torch.misc import SlimFC -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.policy.rnn_sequencing import add_time_dimension -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.view_requirement import ViewRequirement -from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot -from ray.rllib.utils.typing import ModelConfigDict, TensorType - -torch, nn = try_import_torch() - - -@DeveloperAPI -class RecurrentNetwork(TorchModelV2): - """Helper class to simplify implementing RNN models with TorchModelV2. - - Instead of implementing forward(), you can implement forward_rnn() which - takes batches with the time dimension added already. 
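A usage sketch for the FCNet primitive above, showing the plain-torch stack it assembles for one illustrative configuration (dimensions assumed).

import torch
import torch.nn as nn

# FCNet(input_dim=4, hidden_layers=[32, 32], output_dim=2, activation="ReLU")
# builds a stack equivalent to:
net = nn.Sequential(
    nn.Linear(4, 32),
    nn.ReLU(),
    nn.Linear(32, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
)
out = net(torch.randn(8, 4))
print(out.shape)                                             # torch.Size([8, 2])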
- - Here is an example implementation for a subclass - ``MyRNNClass(RecurrentNetwork, nn.Module)``:: - - def __init__(self, obs_space, num_outputs): - nn.Module.__init__(self) - super().__init__(obs_space, action_space, num_outputs, - model_config, name) - self.obs_size = _get_size(obs_space) - self.rnn_hidden_dim = model_config["lstm_cell_size"] - self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim) - self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) - self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs) - - self.value_branch = nn.Linear(self.rnn_hidden_dim, 1) - self._cur_value = None - - @override(ModelV2) - def get_initial_state(self): - # Place hidden states on same device as model. - h = [self.fc1.weight.new( - 1, self.rnn_hidden_dim).zero_().squeeze(0)] - return h - - @override(ModelV2) - def value_function(self): - assert self._cur_value is not None, "must call forward() first" - return self._cur_value - - @override(RecurrentNetwork) - def forward_rnn(self, input_dict, state, seq_lens): - x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float())) - h_in = state[0].reshape(-1, self.rnn_hidden_dim) - h = self.rnn(x, h_in) - q = self.fc2(h) - self._cur_value = self.value_branch(h).squeeze(1) - return q, [h] - """ - - @override(ModelV2) - def forward( - self, - input_dict: Dict[str, TensorType], - state: List[TensorType], - seq_lens: TensorType, - ) -> Tuple[TensorType, List[TensorType]]: - """Adds time dimension to batch before sending inputs to forward_rnn(). - - You should implement forward_rnn() in your subclass.""" - flat_inputs = input_dict["obs_flat"].float() - # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max() - # as input_dict may have extra zero-padding beyond seq_lens.max(). - # Use add_time_dimension to handle this - self.time_major = self.model_config.get("_time_major", False) - inputs = add_time_dimension( - flat_inputs, - seq_lens=seq_lens, - framework="torch", - time_major=self.time_major, - ) - output, new_state = self.forward_rnn(inputs, state, seq_lens) - output = torch.reshape(output, [-1, self.num_outputs]) - return output, new_state - - def forward_rnn( - self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType - ) -> Tuple[TensorType, List[TensorType]]: - """Call the model with the given input tensors and state. - - Args: - inputs: Observation tensor with shape [B, T, obs_size]. - state: List of state tensors, each with shape [B, size]. - seq_lens: 1D tensor holding input sequence lengths. - Note: len(seq_lens) == B. - - Returns: - (outputs, new_state): The model output tensor of shape - [B, T, num_outputs] and the list of new state tensors each with - shape [B, size]. - - Examples: - def forward_rnn(self, inputs, state, seq_lens): - model_out, h, c = self.rnn_model([inputs, seq_lens] + state) - return model_out, [h, c] - """ - raise NotImplementedError("You must implement this for an RNN model") - - -class LSTMWrapper(RecurrentNetwork, nn.Module): - """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.""" - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: int, - model_config: ModelConfigDict, - name: str, - ): - - nn.Module.__init__(self) - super(LSTMWrapper, self).__init__( - obs_space, action_space, None, model_config, name - ) - - # At this point, self.num_outputs is the number of nodes coming - # from the wrapped (underlying) model. In other words, self.num_outputs - # is the input size for the LSTM layer. 
- # If None, set it to the observation space. - if self.num_outputs is None: - self.num_outputs = int(np.product(self.obs_space.shape)) - - self.cell_size = model_config["lstm_cell_size"] - self.time_major = model_config.get("_time_major", False) - self.use_prev_action = model_config["lstm_use_prev_action"] - self.use_prev_reward = model_config["lstm_use_prev_reward"] - - self.action_space_struct = get_base_struct_from_space(self.action_space) - self.action_dim = 0 - - for space in tree.flatten(self.action_space_struct): - if isinstance(space, Discrete): - self.action_dim += space.n - elif isinstance(space, MultiDiscrete): - self.action_dim += np.sum(space.nvec) - elif space.shape is not None: - self.action_dim += int(np.product(space.shape)) - else: - self.action_dim += int(len(space)) - - # Add prev-action/reward nodes to input to LSTM. - if self.use_prev_action: - self.num_outputs += self.action_dim - if self.use_prev_reward: - self.num_outputs += 1 - - # Define actual LSTM layer (with num_outputs being the nodes coming - # from the wrapped (underlying) layer). - self.lstm = nn.LSTM( - self.num_outputs, self.cell_size, batch_first=not self.time_major - ) - - # Set self.num_outputs to the number of output nodes desired by the - # caller of this constructor. - self.num_outputs = num_outputs - - # Postprocess LSTM output with another hidden layer and compute values. - self._logits_branch = SlimFC( - in_size=self.cell_size, - out_size=self.num_outputs, - activation_fn=None, - initializer=torch.nn.init.xavier_uniform_, - ) - self._value_branch = SlimFC( - in_size=self.cell_size, - out_size=1, - activation_fn=None, - initializer=torch.nn.init.xavier_uniform_, - ) - - # __sphinx_doc_begin__ - # Add prev-a/r to this model's view, if required. - if model_config["lstm_use_prev_action"]: - self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement( - SampleBatch.ACTIONS, space=self.action_space, shift=-1 - ) - if model_config["lstm_use_prev_reward"]: - self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement( - SampleBatch.REWARDS, shift=-1 - ) - # __sphinx_doc_end__ - - @override(RecurrentNetwork) - def forward( - self, - input_dict: Dict[str, TensorType], - state: List[TensorType], - seq_lens: TensorType, - ) -> Tuple[TensorType, List[TensorType]]: - assert seq_lens is not None - # Push obs through "unwrapped" net's `forward()` first. - wrapped_out, _ = self._wrapped_forward(input_dict, [], None) - - # Concat. prev-action/reward if required. - prev_a_r = [] - - # Prev actions. - if self.model_config["lstm_use_prev_action"]: - prev_a = input_dict[SampleBatch.PREV_ACTIONS] - # If actions are not processed yet (in their original form as - # have been sent to environment): - # Flatten/one-hot into 1D array. - if self.model_config["_disable_action_flattening"]: - prev_a_r.append( - flatten_inputs_to_1d_tensor( - prev_a, spaces_struct=self.action_space_struct, time_axis=False - ) - ) - # If actions are already flattened (but not one-hot'd yet!), - # one-hot discrete/multi-discrete actions here. - else: - if isinstance(self.action_space, (Discrete, MultiDiscrete)): - prev_a = one_hot(prev_a.float(), self.action_space) - else: - prev_a = prev_a.float() - prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim])) - # Prev rewards. - if self.model_config["lstm_use_prev_reward"]: - prev_a_r.append( - torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(), [-1, 1]) - ) - - # Concat prev. actions + rewards to the "main" input. 
- if prev_a_r: - wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1) - - # Push everything through our LSTM. - input_dict["obs_flat"] = wrapped_out - return super().forward(input_dict, state, seq_lens) - - @override(RecurrentNetwork) - def forward_rnn( - self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType - ) -> Tuple[TensorType, List[TensorType]]: - # Don't show paddings to RNN(?) - # TODO: (sven) For now, only allow, iff time_major=True to not break - # anything retrospectively (time_major not supported previously). - # max_seq_len = inputs.shape[0] - # time_major = self.model_config["_time_major"] - # if time_major and max_seq_len > 1: - # inputs = torch.nn.utils.rnn.pack_padded_sequence( - # inputs, seq_lens, - # batch_first=not time_major, enforce_sorted=False) - self._features, [h, c] = self.lstm( - inputs, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)] - ) - # Re-apply paddings. - # if time_major and max_seq_len > 1: - # self._features, _ = torch.nn.utils.rnn.pad_packed_sequence( - # self._features, - # batch_first=not time_major) - model_out = self._logits_branch(self._features) - return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)] - - @override(ModelV2) - def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]: - # Place hidden states on same device as model. - linear = next(self._logits_branch._model.children()) - h = [ - linear.weight.new(1, self.cell_size).zero_().squeeze(0), - linear.weight.new(1, self.cell_size).zero_().squeeze(0), - ] - return h - - @override(ModelV2) - def value_function(self) -> TensorType: - assert self._features is not None, "must call forward() first" - return torch.reshape(self._value_branch(self._features), [-1]) diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py deleted file mode 100644 index dadbec72f2f1..000000000000 --- a/rllib/models/torch/torch_action_dist.py +++ /dev/null @@ -1,648 +0,0 @@ -import functools -import gymnasium as gym -from math import log -import numpy as np -import tree # pip install dm_tree -from typing import Optional - -from ray.rllib.models.action_dist import ActionDistribution -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT -from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space -from ray.rllib.utils.typing import TensorType, List, Union, Tuple, ModelConfigDict - -torch, nn = try_import_torch() - - -@DeveloperAPI -class TorchDistributionWrapper(ActionDistribution): - """Wrapper class for torch.distributions.""" - - @override(ActionDistribution) - def __init__(self, inputs: List[TensorType], model: TorchModelV2): - # If inputs are not a torch Tensor, make them one and make sure they - # are on the correct device. - if not isinstance(inputs, torch.Tensor): - inputs = torch.from_numpy(inputs) - if isinstance(model, TorchModelV2): - inputs = inputs.to(next(model.parameters()).device) - super().__init__(inputs, model) - # Store the last sample here. 
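A standalone sketch of the batch-first LSTM call pattern used by LSTMWrapper.forward_rnn: zero states stored per sample, then unsqueezed to the [num_layers, B, cell] layout nn.LSTM expects; sizes are illustrative assumptions.

import torch
import torch.nn as nn

B, T, in_size, cell = 3, 5, 16, 32
lstm = nn.LSTM(in_size, cell, batch_first=True)

h0 = torch.zeros(B, cell)                                    # per-sample initial states ...
c0 = torch.zeros(B, cell)
out, (h, c) = lstm(
    torch.randn(B, T, in_size),                              # [B, T, features]
    (h0.unsqueeze(0), c0.unsqueeze(0)),                      # ... unsqueezed to [1, B, cell]
)
print(out.shape, h.squeeze(0).shape)                         # [3, 5, 32] and [3, 32]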
- self.last_sample = None - - @override(ActionDistribution) - def logp(self, actions: TensorType) -> TensorType: - return self.dist.log_prob(actions) - - @override(ActionDistribution) - def entropy(self) -> TensorType: - return self.dist.entropy() - - @override(ActionDistribution) - def kl(self, other: ActionDistribution) -> TensorType: - return torch.distributions.kl.kl_divergence(self.dist, other.dist) - - @override(ActionDistribution) - def sample(self) -> TensorType: - self.last_sample = self.dist.sample() - return self.last_sample - - @override(ActionDistribution) - def sampled_action_logp(self) -> TensorType: - assert self.last_sample is not None - return self.logp(self.last_sample) - - -@DeveloperAPI -class TorchCategorical(TorchDistributionWrapper): - """Wrapper class for PyTorch Categorical distribution.""" - - @override(ActionDistribution) - def __init__( - self, - inputs: List[TensorType], - model: TorchModelV2 = None, - temperature: float = 1.0, - ): - if temperature != 1.0: - assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" - inputs /= temperature - super().__init__(inputs, model) - self.dist = torch.distributions.categorical.Categorical(logits=self.inputs) - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - self.last_sample = self.dist.probs.argmax(dim=1) - return self.last_sample - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape( - action_space: gym.Space, model_config: ModelConfigDict - ) -> Union[int, np.ndarray]: - return action_space.n - - -@DeveloperAPI -def get_torch_categorical_class_with_temperature(t: float): - """TorchCategorical distribution class that has customized default temperature.""" - - class TorchCategoricalWithTemperature(TorchCategorical): - def __init__(self, inputs, model=None, temperature=t): - super().__init__(inputs, model, temperature) - - return TorchCategoricalWithTemperature - - -@DeveloperAPI -class TorchMultiCategorical(TorchDistributionWrapper): - """MultiCategorical distribution for MultiDiscrete action spaces.""" - - @override(TorchDistributionWrapper) - def __init__( - self, - inputs: List[TensorType], - model: TorchModelV2, - input_lens: Union[List[int], np.ndarray, Tuple[int, ...]], - action_space=None, - ): - super().__init__(inputs, model) - # If input_lens is np.ndarray or list, force-make it a tuple. - inputs_split = self.inputs.split(tuple(input_lens), dim=1) - self.cats = [ - torch.distributions.categorical.Categorical(logits=input_) - for input_ in inputs_split - ] - # Used in case we are dealing with an Int Box. - self.action_space = action_space - - @override(TorchDistributionWrapper) - def sample(self) -> TensorType: - arr = [cat.sample() for cat in self.cats] - sample_ = torch.stack(arr, dim=1) - if isinstance(self.action_space, gym.spaces.Box): - sample_ = torch.reshape(sample_, [-1] + list(self.action_space.shape)) - self.last_sample = sample_ - return sample_ - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - arr = [torch.argmax(cat.probs, -1) for cat in self.cats] - sample_ = torch.stack(arr, dim=1) - if isinstance(self.action_space, gym.spaces.Box): - sample_ = torch.reshape(sample_, [-1] + list(self.action_space.shape)) - self.last_sample = sample_ - return sample_ - - @override(TorchDistributionWrapper) - def logp(self, actions: TensorType) -> TensorType: - # # If tensor is provided, unstack it into list. 
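A standalone sketch of the temperature scaling TorchCategorical applies before building the distribution, i.e. probs = softmax(logits / temperature); the logits below are made-up values.

import torch

logits = torch.tensor([[2.0, 1.0, 0.0]])
for temperature in (0.1, 1.0, 10.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)
# Low temperature sharpens the distribution toward the argmax;
# high temperature flattens it toward uniform.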
- if isinstance(actions, torch.Tensor): - if isinstance(self.action_space, gym.spaces.Box): - actions = torch.reshape( - actions, [-1, int(np.prod(self.action_space.shape))] - ) - actions = torch.unbind(actions, dim=1) - logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)]) - return torch.sum(logps, dim=0) - - @override(ActionDistribution) - def multi_entropy(self) -> TensorType: - return torch.stack([cat.entropy() for cat in self.cats], dim=1) - - @override(TorchDistributionWrapper) - def entropy(self) -> TensorType: - return torch.sum(self.multi_entropy(), dim=1) - - @override(ActionDistribution) - def multi_kl(self, other: ActionDistribution) -> TensorType: - return torch.stack( - [ - torch.distributions.kl.kl_divergence(cat, oth_cat) - for cat, oth_cat in zip(self.cats, other.cats) - ], - dim=1, - ) - - @override(TorchDistributionWrapper) - def kl(self, other: ActionDistribution) -> TensorType: - return torch.sum(self.multi_kl(other), dim=1) - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape( - action_space: gym.Space, model_config: ModelConfigDict - ) -> Union[int, np.ndarray]: - # Int Box. - if isinstance(action_space, gym.spaces.Box): - assert action_space.dtype.name.startswith("int") - low_ = np.min(action_space.low) - high_ = np.max(action_space.high) - assert np.all(action_space.low == low_) - assert np.all(action_space.high == high_) - np.prod(action_space.shape, dtype=np.int32) * (high_ - low_ + 1) - # MultiDiscrete space. - else: - # `nvec` is already integer. No need to cast. - return np.sum(action_space.nvec) - - -@ExperimentalAPI -class TorchSlateMultiCategorical(TorchCategorical): - """MultiCategorical distribution for MultiDiscrete action spaces. - - The action space must be uniform, meaning all nvec items have the same size, e.g. - MultiDiscrete([10, 10, 10]), where 10 is the number of candidates to pick from - and 3 is the slate size (pick 3 out of 10). When picking candidates, no candidate - must be picked more than once. - """ - - def __init__( - self, - inputs: List[TensorType], - model: TorchModelV2 = None, - temperature: float = 1.0, - action_space: Optional[gym.spaces.MultiDiscrete] = None, - all_slates=None, - ): - assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" - # Allow softmax formula w/ temperature != 1.0: - # Divide inputs by temperature. - super().__init__(inputs / temperature, model) - self.action_space = action_space - # Assert uniformness of the action space (all discrete buckets have the same - # size). - assert isinstance(self.action_space, gym.spaces.MultiDiscrete) and all( - n == self.action_space.nvec[0] for n in self.action_space.nvec - ) - self.all_slates = all_slates - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - # Get a sample from the underlying Categorical (batch of ints). - sample = super().deterministic_sample() - # Use the sampled ints to pick the actual slates. - return torch.take_along_dim(self.all_slates, sample.long(), dim=-1) - - @override(ActionDistribution) - def logp(self, x: TensorType) -> TensorType: - # TODO: Implement. 
- return torch.ones_like(self.inputs[:, 0]) - - -@DeveloperAPI -class TorchDiagGaussian(TorchDistributionWrapper): - """Wrapper class for PyTorch Normal distribution.""" - - @override(ActionDistribution) - def __init__( - self, - inputs: List[TensorType], - model: TorchModelV2, - *, - action_space: Optional[gym.spaces.Space] = None - ): - super().__init__(inputs, model) - mean, log_std = torch.chunk(self.inputs, 2, dim=1) - self.log_std = log_std - self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std)) - # Remember to squeeze action samples in case action space is Box(shape) - self.zero_action_dim = action_space and action_space.shape == () - - @override(TorchDistributionWrapper) - def sample(self) -> TensorType: - sample = super().sample() - if self.zero_action_dim: - return torch.squeeze(sample, dim=-1) - return sample - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - self.last_sample = self.dist.mean - return self.last_sample - - @override(TorchDistributionWrapper) - def logp(self, actions: TensorType) -> TensorType: - return super().logp(actions).sum(-1) - - @override(TorchDistributionWrapper) - def entropy(self) -> TensorType: - return super().entropy().sum(-1) - - @override(TorchDistributionWrapper) - def kl(self, other: ActionDistribution) -> TensorType: - return super().kl(other).sum(-1) - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape( - action_space: gym.Space, model_config: ModelConfigDict - ) -> Union[int, np.ndarray]: - return np.prod(action_space.shape, dtype=np.int32) * 2 - - -@DeveloperAPI -class TorchSquashedGaussian(TorchDistributionWrapper): - """A tanh-squashed Gaussian distribution defined by: mean, std, low, high. - - The distribution will never return low or high exactly, but - `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively. - """ - - def __init__( - self, - inputs: List[TensorType], - model: TorchModelV2, - low: float = -1.0, - high: float = 1.0, - ): - """Parameterizes the distribution via `inputs`. - - Args: - low: The lowest possible sampling value - (excluding this value). - high: The highest possible sampling value - (excluding this value). - """ - super().__init__(inputs, model) - # Split inputs into mean and log(std). - mean, log_std = torch.chunk(self.inputs, 2, dim=-1) - # Clip `scale` values (coming from NN) to reasonable values. - log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT) - std = torch.exp(log_std) - self.dist = torch.distributions.normal.Normal(mean, std) - assert np.all(np.less(low, high)) - self.low = low - self.high = high - self.mean = mean - self.std = std - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - self.last_sample = self._squash(self.dist.mean) - return self.last_sample - - @override(TorchDistributionWrapper) - def sample(self) -> TensorType: - # Use the reparameterization version of `dist.sample` to allow for - # the results to be backprop'able e.g. in a loss term. - - normal_sample = self.dist.rsample() - self.last_sample = self._squash(normal_sample) - return self.last_sample - - @override(ActionDistribution) - def logp(self, x: TensorType) -> TensorType: - # Unsquash values (from [low,high] to ]-inf,inf[) - unsquashed_values = self._unsquash(x) - # Get log prob of unsquashed values from our Normal. - log_prob_gaussian = self.dist.log_prob(unsquashed_values) - # For safety reasons, clamp somehow, only then sum up. 
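A standalone sketch of how a flat 2*K network output parameterizes the diagonal Gaussian above: first half mean, second half log standard deviation, with log-probs summed over the K independent dimensions (K and the batch size are assumptions).

import torch

K = 3
nn_out = torch.randn(4, 2 * K)                               # batch of 4 flat model outputs
mean, log_std = torch.chunk(nn_out, 2, dim=1)
dist = torch.distributions.Normal(mean, torch.exp(log_std))
action = dist.sample()                                       # [4, K]
logp = dist.log_prob(action).sum(-1)                         # sum over independent dims -> [4]
print(action.shape, logp.shape)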
- log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100) - log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1) - # Get log-prob for squashed Gaussian. - unsquashed_values_tanhd = torch.tanh(unsquashed_values) - log_prob = log_prob_gaussian - torch.sum( - torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1 - ) - return log_prob - - def sample_logp(self): - z = self.dist.rsample() - actions = self._squash(z) - return actions, torch.sum( - self.dist.log_prob(z) - torch.log(1 - actions * actions + SMALL_NUMBER), - dim=-1, - ) - - @override(TorchDistributionWrapper) - def entropy(self) -> TensorType: - raise ValueError("Entropy not defined for SquashedGaussian!") - - @override(TorchDistributionWrapper) - def kl(self, other: ActionDistribution) -> TensorType: - raise ValueError("KL not defined for SquashedGaussian!") - - def _squash(self, raw_values: TensorType) -> TensorType: - # Returned values are within [low, high] (including `low` and `high`). - squashed = ((torch.tanh(raw_values) + 1.0) / 2.0) * ( - self.high - self.low - ) + self.low - return torch.clamp(squashed, self.low, self.high) - - def _unsquash(self, values: TensorType) -> TensorType: - normed_values = (values - self.low) / (self.high - self.low) * 2.0 - 1.0 - # Stabilize input to atanh. - save_normed_values = torch.clamp( - normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER - ) - unsquashed = torch.atanh(save_normed_values) - return unsquashed - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape( - action_space: gym.Space, model_config: ModelConfigDict - ) -> Union[int, np.ndarray]: - return np.prod(action_space.shape, dtype=np.int32) * 2 - - -@DeveloperAPI -class TorchBeta(TorchDistributionWrapper): - """ - A Beta distribution is defined on the interval [0, 1] and parameterized by - shape parameters alpha and beta (also called concentration parameters). - - PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z - with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta) - and Gamma(n) = (n - 1)! - """ - - def __init__( - self, - inputs: List[TensorType], - model: TorchModelV2, - low: float = 0.0, - high: float = 1.0, - ): - super().__init__(inputs, model) - # Stabilize input parameters (possibly coming from a linear layer). - self.inputs = torch.clamp(self.inputs, log(SMALL_NUMBER), -log(SMALL_NUMBER)) - self.inputs = torch.log(torch.exp(self.inputs) + 1.0) + 1.0 - self.low = low - self.high = high - alpha, beta = torch.chunk(self.inputs, 2, dim=-1) - # Note: concentration0==beta, concentration1=alpha (!) - self.dist = torch.distributions.Beta(concentration1=alpha, concentration0=beta) - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - self.last_sample = self._squash(self.dist.mean) - return self.last_sample - - @override(TorchDistributionWrapper) - def sample(self) -> TensorType: - # Use the reparameterization version of `dist.sample` to allow for - # the results to be backprop'able e.g. in a loss term. 
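A standalone sketch of the squash/unsquash round trip and the tanh correction term used in the squashed-Gaussian logp above; the bounds and sample values are assumptions.

import torch

low, high, eps = -2.0, 2.0, 1e-6
u = torch.randn(5)                                           # unsquashed Gaussian draw
a = ((torch.tanh(u) + 1.0) / 2.0) * (high - low) + low       # squash into [low, high]
normed = (a - low) / (high - low) * 2.0 - 1.0
u_back = torch.atanh(torch.clamp(normed, -1.0 + eps, 1.0 - eps))
print(torch.allclose(u, u_back, atol=1e-3))                  # True up to float error

base = torch.distributions.Normal(0.0, 1.0)
# Correction applied above: subtract log(1 - tanh(u)^2 + eps) per action dim.
logp = base.log_prob(u).sum() - torch.log(1 - torch.tanh(u) ** 2 + eps).sum()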
- normal_sample = self.dist.rsample() - self.last_sample = self._squash(normal_sample) - return self.last_sample - - @override(ActionDistribution) - def logp(self, x: TensorType) -> TensorType: - unsquashed_values = self._unsquash(x) - return torch.sum(self.dist.log_prob(unsquashed_values), dim=-1) - - def _squash(self, raw_values: TensorType) -> TensorType: - return raw_values * (self.high - self.low) + self.low - - def _unsquash(self, values: TensorType) -> TensorType: - return (values - self.low) / (self.high - self.low) - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape( - action_space: gym.Space, model_config: ModelConfigDict - ) -> Union[int, np.ndarray]: - return np.prod(action_space.shape, dtype=np.int32) * 2 - - -@DeveloperAPI -class TorchDeterministic(TorchDistributionWrapper): - """Action distribution that returns the input values directly. - - This is similar to DiagGaussian with standard deviation zero (thus only - requiring the "mean" values as NN output). - """ - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - return self.inputs - - @override(TorchDistributionWrapper) - def sampled_action_logp(self) -> TensorType: - return torch.zeros((self.inputs.size()[0],), dtype=torch.float32) - - @override(TorchDistributionWrapper) - def sample(self) -> TensorType: - return self.deterministic_sample() - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape( - action_space: gym.Space, model_config: ModelConfigDict - ) -> Union[int, np.ndarray]: - return np.prod(action_space.shape, dtype=np.int32) - - -@DeveloperAPI -class TorchMultiActionDistribution(TorchDistributionWrapper): - """Action distribution that operates on multiple, possibly nested actions.""" - - def __init__(self, inputs, model, *, child_distributions, input_lens, action_space): - """Initializes a TorchMultiActionDistribution object. - - Args: - inputs (torch.Tensor): A single tensor of shape [BATCH, size]. - model (TorchModelV2): The TorchModelV2 object used to produce - inputs for this distribution. - child_distributions (any[torch.Tensor]): Any struct - that contains the child distribution classes to use to - instantiate the child distributions from `inputs`. This could - be an already flattened list or a struct according to - `action_space`. - input_lens (any[int]): A flat list or a nested struct of input - split lengths used to split `inputs`. - action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex - and possibly nested action space. - """ - if not isinstance(inputs, torch.Tensor): - inputs = torch.from_numpy(inputs) - if isinstance(model, TorchModelV2): - inputs = inputs.to(next(model.parameters()).device) - super().__init__(inputs, model) - - self.action_space_struct = get_base_struct_from_space(action_space) - - self.input_lens = tree.flatten(input_lens) - flat_child_distributions = tree.flatten(child_distributions) - split_inputs = torch.split(inputs, self.input_lens, dim=1) - self.flat_child_distributions = tree.map_structure( - lambda dist, input_: dist(input_, model), - flat_child_distributions, - list(split_inputs), - ) - - @override(ActionDistribution) - def logp(self, x): - if isinstance(x, np.ndarray): - x = torch.Tensor(x) - # Single tensor input (all merged). 
- if isinstance(x, torch.Tensor): - split_indices = [] - for dist in self.flat_child_distributions: - if isinstance(dist, TorchCategorical): - split_indices.append(1) - elif ( - isinstance(dist, TorchMultiCategorical) - and dist.action_space is not None - ): - split_indices.append(int(np.prod(dist.action_space.shape))) - else: - sample = dist.sample() - # Cover Box(shape=()) case. - if len(sample.shape) == 1: - split_indices.append(1) - else: - split_indices.append(sample.size()[1]) - split_x = list(torch.split(x, split_indices, dim=1)) - # Structured or flattened (by single action component) input. - else: - split_x = tree.flatten(x) - - def map_(val, dist): - # Remove extra categorical dimension. - if isinstance(dist, TorchCategorical): - val = (torch.squeeze(val, dim=-1) if len(val.shape) > 1 else val).int() - return dist.logp(val) - - # Remove extra categorical dimension and take the logp of each - # component. - flat_logps = tree.map_structure(map_, split_x, self.flat_child_distributions) - - return functools.reduce(lambda a, b: a + b, flat_logps) - - @override(ActionDistribution) - def kl(self, other): - kl_list = [ - d.kl(o) - for d, o in zip( - self.flat_child_distributions, other.flat_child_distributions - ) - ] - return functools.reduce(lambda a, b: a + b, kl_list) - - @override(ActionDistribution) - def entropy(self): - entropy_list = [d.entropy() for d in self.flat_child_distributions] - return functools.reduce(lambda a, b: a + b, entropy_list) - - @override(ActionDistribution) - def sample(self): - child_distributions = tree.unflatten_as( - self.action_space_struct, self.flat_child_distributions - ) - return tree.map_structure(lambda s: s.sample(), child_distributions) - - @override(ActionDistribution) - def deterministic_sample(self): - child_distributions = tree.unflatten_as( - self.action_space_struct, self.flat_child_distributions - ) - return tree.map_structure( - lambda s: s.deterministic_sample(), child_distributions - ) - - @override(TorchDistributionWrapper) - def sampled_action_logp(self): - p = self.flat_child_distributions[0].sampled_action_logp() - for c in self.flat_child_distributions[1:]: - p += c.sampled_action_logp() - return p - - @override(ActionDistribution) - def required_model_output_shape(self, action_space, model_config): - return np.sum(self.input_lens, dtype=np.int32) - - -@DeveloperAPI -class TorchDirichlet(TorchDistributionWrapper): - """Dirichlet distribution for continuous actions that are between - [0,1] and sum to 1. - - e.g. actions that represent resource allocation.""" - - def __init__(self, inputs, model): - """Input is a tensor of logits. The exponential of logits is used to - parametrize the Dirichlet distribution as all parameters need to be - positive. An arbitrary small epsilon is added to the concentration - parameters to be zero due to numerical error. - - See issue #4440 for more details. - """ - self.epsilon = torch.tensor(1e-7).to(inputs.device) - concentration = torch.exp(inputs) + self.epsilon - self.dist = torch.distributions.dirichlet.Dirichlet( - concentration=concentration, - validate_args=True, - ) - super().__init__(concentration, model) - - @override(ActionDistribution) - def deterministic_sample(self) -> TensorType: - self.last_sample = nn.functional.softmax(self.dist.concentration) - return self.last_sample - - @override(ActionDistribution) - def logp(self, x): - # Support of Dirichlet are positive real numbers. x is already - # an array of positive numbers, but we clip to avoid zeros due to - # numerical errors. 
- x = torch.max(x, self.epsilon) - x = x / torch.sum(x, dim=-1, keepdim=True) - return self.dist.log_prob(x) - - @override(ActionDistribution) - def entropy(self): - return self.dist.entropy() - - @override(ActionDistribution) - def kl(self, other): - return self.dist.kl_divergence(other.dist) - - @staticmethod - @override(ActionDistribution) - def required_model_output_shape(action_space, model_config): - return np.prod(action_space.shape, dtype=np.int32) diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py deleted file mode 100644 index 809c516897ef..000000000000 --- a/rllib/models/torch/torch_distributions.py +++ /dev/null @@ -1,257 +0,0 @@ -"""The main difference between this and the old ActionDistribution is that this one -has more explicit input args. So that the input format does not have to be guessed from -the code. This matches the design pattern of torch distribution which developers may -already be familiar with. -""" -import gymnasium as gym -import numpy as np -from typing import Optional -import abc - - -from ray.rllib.models.distributions import Distribution -from ray.rllib.utils.annotations import override, DeveloperAPI -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType, Union, Tuple, ModelConfigDict - -torch, nn = try_import_torch() - - -@DeveloperAPI -class TorchDistribution(Distribution, abc.ABC): - """Wrapper class for torch.distributions.""" - - def __init__(self, *args, **kwargs): - super().__init__() - self._dist = self._get_torch_distribution(*args, **kwargs) - - @abc.abstractmethod - def _get_torch_distribution( - self, *args, **kwargs - ) -> torch.distributions.Distribution: - """Returns the torch.distributions.Distribution object to use.""" - - @override(Distribution) - def logp(self, value: TensorType, **kwargs) -> TensorType: - return self._dist.log_prob(value, **kwargs) - - @override(Distribution) - def entropy(self) -> TensorType: - return self._dist.entropy() - - @override(Distribution) - def kl(self, other: "Distribution") -> TensorType: - return torch.distributions.kl.kl_divergence(self._dist, other._dist) - - @override(Distribution) - def sample( - self, *, sample_shape=torch.Size(), return_logp: bool = False - ) -> Union[TensorType, Tuple[TensorType, TensorType]]: - sample = self._dist.sample(sample_shape) - if return_logp: - return sample, self.logp(sample) - return sample - - @override(Distribution) - def rsample( - self, *, sample_shape=torch.Size(), return_logp: bool = False - ) -> Union[TensorType, Tuple[TensorType, TensorType]]: - rsample = self._dist.rsample(sample_shape) - if return_logp: - return rsample, self.logp(rsample) - return rsample - - -@DeveloperAPI -class TorchCategorical(TorchDistribution): - """Wrapper class for PyTorch Categorical distribution. - - Creates a categorical distribution parameterized by either :attr:`probs` or - :attr:`logits` (but not both). - - Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is - ``probs.size(-1)``. - - If `probs` is 1-dimensional with length-`K`, each element is the relative - probability of sampling the class at that index. - - If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of - relative probability vectors. - - Example:: - >>> m = TorchCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) - >>> m.sample(sample_shape=(2,)) # equal probability of 0, 1, 2, 3 - tensor([3, 4]) - - Args: - probs: The probablities of each event. 
- logits: Event log probabilities (unnormalized) - temperature: In case of using logits, this parameter can be used to determine - the sharpness of the distribution. i.e. - ``probs = softmax(logits / temperature)``. The temperature must be strictly - positive. A low value (e.g. 1e-10) will result in argmax sampling while a - larger value will result in uniform sampling. - """ - - @override(TorchDistribution) - def __init__( - self, - probs: torch.Tensor = None, - logits: torch.Tensor = None, - temperature: float = 1.0, - ) -> None: - super().__init__(probs=probs, logits=logits, temperature=temperature) - - @override(TorchDistribution) - def _get_torch_distribution( - self, - probs: torch.Tensor = None, - logits: torch.Tensor = None, - temperature: float = 1.0, - ) -> torch.distributions.Distribution: - if logits is not None: - assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" - logits /= temperature - return torch.distributions.categorical.Categorical(probs, logits) - - @staticmethod - @override(Distribution) - def required_model_output_shape( - space: gym.Space, model_config: ModelConfigDict - ) -> Tuple[int, ...]: - return (space.n,) - - -@DeveloperAPI -class TorchDiagGaussian(TorchDistribution): - """Wrapper class for PyTorch Normal distribution. - - Creates a normal distribution parameterized by :attr:`loc` and :attr:`scale`. In - case of multi-dimensional distribution, the variance is assumed to be diagonal. - - Example:: - - >>> m = Normal(loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0, 1.0])) - >>> m.sample(sample_shape=(2,)) # 2d normal dist with loc=0 and scale=1 - tensor([[ 0.1046, -0.6120], [ 0.234, 0.556]]) - - >>> # scale is None - >>> m = Normal(loc=torch.tensor([0.0, 1.0])) - >>> m.sample(sample_shape=(2,)) # normally distributed with loc=0 and scale=1 - tensor([0.1046, 0.6120]) - - - Args: - loc: mean of the distribution (often referred to as mu). If scale is None, the - second half of the `loc` will be used as the log of scale. - scale: standard deviation of the distribution (often referred to as sigma). - Has to be positive. - """ - - @override(TorchDistribution) - def __init__( - self, - loc: Union[float, torch.Tensor], - scale: Optional[Union[float, torch.Tensor]] = None, - ): - super().__init__(loc=loc, scale=scale) - - def _get_torch_distribution( - self, loc, scale=None - ) -> torch.distributions.Distribution: - if scale is None: - loc, log_std = torch.chunk(self.inputs, 2, dim=1) - scale = torch.exp(log_std) - return torch.distributions.normal.Normal(loc, scale) - - @override(TorchDistribution) - def logp(self, value: TensorType) -> TensorType: - return super().logp(value).sum(-1) - - @override(TorchDistribution) - def entropy(self) -> TensorType: - return super().entropy().sum(-1) - - @override(TorchDistribution) - def kl(self, other: "TorchDistribution") -> TensorType: - return super().kl(other).sum(-1) - - @staticmethod - @override(Distribution) - def required_model_output_shape( - space: gym.Space, model_config: ModelConfigDict - ) -> Tuple[int, ...]: - return tuple(np.prod(space.shape, dtype=np.int32) * 2) - - -@DeveloperAPI -class TorchDeterministic(Distribution): - """The distribution that returns the input values directly. - - This is similar to DiagGaussian with standard deviation zero (thus only - requiring the "mean" values as NN output). - - Note: entropy is always zero, ang logp and kl are not implemented. 
- - Example:: - - >>> m = TorchDeterministic(loc=torch.tensor([0.0, 0.0])) - >>> m.sample(sample_shape=(2,)) - tensor([[ 0.0, 0.0], [ 0.0, 0.0]]) - - Args: - loc: the determinsitic value to return - """ - - @override(Distribution) - def __init__(self, loc: torch.Tensor) -> None: - super().__init__() - self.loc = loc - - @override(Distribution) - def sample( - self, - *, - sample_shape: Tuple[int, ...] = None, - return_logp: bool = False, - **kwargs, - ) -> Union[TensorType, Tuple[TensorType, TensorType]]: - if return_logp: - raise ValueError(f"Cannot return logp for {self.__class__.__name__}.") - - if sample_shape is None: - sample_shape = torch.Size() - - device = self.loc.device - dtype = self.loc.dtype - shape = sample_shape + self.loc.shape - return torch.ones(shape, device=device, dtype=dtype) * self.loc - - def rsample( - self, - *, - sample_shape: Tuple[int, ...] = None, - return_logp: bool = False, - **kwargs, - ) -> Union[TensorType, Tuple[TensorType, TensorType]]: - raise NotImplementedError - - @override(Distribution) - def logp(self, value: TensorType, **kwargs) -> TensorType: - raise ValueError(f"Cannot return logp for {self.__class__.__name__}.") - - @override(Distribution) - def entropy(self, **kwargs) -> TensorType: - raise torch.zeros_like(self.loc) - - @override(Distribution) - def kl(self, other: "Distribution", **kwargs) -> TensorType: - raise ValueError(f"Cannot return kl for {self.__class__.__name__}.") - - @staticmethod - @override(Distribution) - def required_model_output_shape( - space: gym.Space, model_config: ModelConfigDict - ) -> Tuple[int, ...]: - # TODO: This was copied from previous code. Is this correct? add unit test. - return tuple(np.prod(space.shape, dtype=np.int32)) diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py deleted file mode 100644 index b56bf425fb6f..000000000000 --- a/rllib/models/torch/torch_modelv2.py +++ /dev/null @@ -1,81 +0,0 @@ -import gymnasium as gym -from typing import Dict, List, Union - -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import ModelConfigDict, TensorType - -_, nn = try_import_torch() - - -@PublicAPI -class TorchModelV2(ModelV2): - """Torch version of ModelV2. - - Note that this class by itself is not a valid model unless you - inherit from nn.Module and implement forward() in a subclass.""" - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: int, - model_config: ModelConfigDict, - name: str, - ): - """Initialize a TorchModelV2. - - Here is an example implementation for a subclass - ``MyModelClass(TorchModelV2, nn.Module)``:: - - def __init__(self, *args, **kwargs): - TorchModelV2.__init__(self, *args, **kwargs) - nn.Module.__init__(self) - self._hidden_layers = nn.Sequential(...) - self._logits = ... - self._value_branch = ... - """ - - if not isinstance(self, nn.Module): - raise ValueError( - "Subclasses of TorchModelV2 must also inherit from " - "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)" - ) - - ModelV2.__init__( - self, - obs_space, - action_space, - num_outputs, - model_config, - name, - framework="torch", - ) - - # Dict to store per multi-gpu tower stats into. - # In PyTorch multi-GPU, we use a single TorchPolicy and copy - # it's Model(s) n times (1 copy for each GPU). When computing the loss - # on each tower, we cannot store the stats (e.g. 
`entropy`) inside the - # policy object as this would lead to race conditions between the - # different towers all accessing the same property at the same time. - self.tower_stats = {} - - @override(ModelV2) - def variables( - self, as_dict: bool = False - ) -> Union[List[TensorType], Dict[str, TensorType]]: - p = list(self.parameters()) - if as_dict: - return {k: p[i] for i, k in enumerate(self.state_dict().keys())} - return p - - @override(ModelV2) - def trainable_variables( - self, as_dict: bool = False - ) -> Union[List[TensorType], Dict[str, TensorType]]: - if as_dict: - return { - k: v for k, v in self.variables(as_dict=True).items() if v.requires_grad - } - return [v for v in self.variables() if v.requires_grad] diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py deleted file mode 100644 index 32153b1e2e80..000000000000 --- a/rllib/models/torch/visionnet.py +++ /dev/null @@ -1,293 +0,0 @@ -import numpy as np -from typing import Dict, List -import gymnasium as gym - -from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 -from ray.rllib.models.torch.misc import ( - normc_initializer, - same_padding, - SlimConv2d, - SlimFC, -) -from ray.rllib.models.utils import get_activation_fn, get_filter_config -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import ModelConfigDict, TensorType - -torch, nn = try_import_torch() - - -class VisionNetwork(TorchModelV2, nn.Module): - """Generic vision network.""" - - def __init__( - self, - obs_space: gym.spaces.Space, - action_space: gym.spaces.Space, - num_outputs: int, - model_config: ModelConfigDict, - name: str, - ): - - if not model_config.get("conv_filters"): - model_config["conv_filters"] = get_filter_config(obs_space.shape) - - TorchModelV2.__init__( - self, obs_space, action_space, num_outputs, model_config, name - ) - nn.Module.__init__(self) - - activation = self.model_config.get("conv_activation") - filters = self.model_config["conv_filters"] - assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!" - - # Post FC net config. - post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", []) - post_fcnet_activation = get_activation_fn( - model_config.get("post_fcnet_activation"), framework="torch" - ) - - no_final_linear = self.model_config.get("no_final_linear") - vf_share_layers = self.model_config.get("vf_share_layers") - - # Whether the last layer is the output of a Flattened (rather than - # a n x (1,1) Conv2D). - self.last_layer_is_flattened = False - self._logits = None - - layers = [] - (w, h, in_channels) = obs_space.shape - - in_size = [w, h] - for out_channels, kernel, stride in filters[:-1]: - padding, out_size = same_padding(in_size, kernel, stride) - layers.append( - SlimConv2d( - in_channels, - out_channels, - kernel, - stride, - padding, - activation_fn=activation, - ) - ) - in_channels = out_channels - in_size = out_size - - out_channels, kernel, stride = filters[-1] - - # No final linear: Last layer has activation function and exits with - # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending - # on `post_fcnet_...` settings). - if no_final_linear and num_outputs: - out_channels = out_channels if post_fcnet_hiddens else num_outputs - layers.append( - SlimConv2d( - in_channels, - out_channels, - kernel, - stride, - None, # padding=valid - activation_fn=activation, - ) - ) - - # Add (optional) post-fc-stack after last Conv2D layer. 
- layer_sizes = post_fcnet_hiddens[:-1] + ( - [num_outputs] if post_fcnet_hiddens else [] - ) - for i, out_size in enumerate(layer_sizes): - layers.append( - SlimFC( - in_size=out_channels, - out_size=out_size, - activation_fn=post_fcnet_activation, - initializer=normc_initializer(1.0), - ) - ) - out_channels = out_size - - # Finish network normally (w/o overriding last layer size with - # `num_outputs`), then add another linear one of size `num_outputs`. - else: - layers.append( - SlimConv2d( - in_channels, - out_channels, - kernel, - stride, - None, # padding=valid - activation_fn=activation, - ) - ) - - # num_outputs defined. Use that to create an exact - # `num_output`-sized (1,1)-Conv2D. - if num_outputs: - in_size = [ - np.ceil((in_size[0] - kernel[0]) / stride), - np.ceil((in_size[1] - kernel[1]) / stride), - ] - padding, _ = same_padding(in_size, [1, 1], [1, 1]) - if post_fcnet_hiddens: - layers.append(nn.Flatten()) - in_size = out_channels - # Add (optional) post-fc-stack after last Conv2D layer. - for i, out_size in enumerate(post_fcnet_hiddens + [num_outputs]): - layers.append( - SlimFC( - in_size=in_size, - out_size=out_size, - activation_fn=post_fcnet_activation - if i < len(post_fcnet_hiddens) - 1 - else None, - initializer=normc_initializer(1.0), - ) - ) - in_size = out_size - # Last layer is logits layer. - self._logits = layers.pop() - - else: - self._logits = SlimConv2d( - out_channels, - num_outputs, - [1, 1], - 1, - padding, - activation_fn=None, - ) - - # num_outputs not known -> Flatten, then set self.num_outputs - # to the resulting number of nodes. - else: - self.last_layer_is_flattened = True - layers.append(nn.Flatten()) - - self._convs = nn.Sequential(*layers) - - # If our num_outputs still unknown, we need to do a test pass to - # figure out the output dimensions. This could be the case, if we have - # the Flatten layer at the end. - if self.num_outputs is None: - # Create a B=1 dummy sample and push it through out conv-net. - dummy_in = ( - torch.from_numpy(self.obs_space.sample()) - .permute(2, 0, 1) - .unsqueeze(0) - .float() - ) - dummy_out = self._convs(dummy_in) - self.num_outputs = dummy_out.shape[1] - - # Build the value layers - self._value_branch_separate = self._value_branch = None - if vf_share_layers: - self._value_branch = SlimFC( - out_channels, 1, initializer=normc_initializer(0.01), activation_fn=None - ) - else: - vf_layers = [] - (w, h, in_channels) = obs_space.shape - in_size = [w, h] - for out_channels, kernel, stride in filters[:-1]: - padding, out_size = same_padding(in_size, kernel, stride) - vf_layers.append( - SlimConv2d( - in_channels, - out_channels, - kernel, - stride, - padding, - activation_fn=activation, - ) - ) - in_channels = out_channels - in_size = out_size - - out_channels, kernel, stride = filters[-1] - vf_layers.append( - SlimConv2d( - in_channels, - out_channels, - kernel, - stride, - None, - activation_fn=activation, - ) - ) - - vf_layers.append( - SlimConv2d( - in_channels=out_channels, - out_channels=1, - kernel=1, - stride=1, - padding=None, - activation_fn=None, - ) - ) - self._value_branch_separate = nn.Sequential(*vf_layers) - - # Holds the current "base" output (before logits layer). 
- self._features = None - - @override(TorchModelV2) - def forward( - self, - input_dict: Dict[str, TensorType], - state: List[TensorType], - seq_lens: TensorType, - ) -> (TensorType, List[TensorType]): - self._features = input_dict["obs"].float() - # Permuate b/c data comes in as [B, dim, dim, channels]: - self._features = self._features.permute(0, 3, 1, 2) - conv_out = self._convs(self._features) - # Store features to save forward pass when getting value_function out. - if not self._value_branch_separate: - self._features = conv_out - - if not self.last_layer_is_flattened: - if self._logits: - conv_out = self._logits(conv_out) - if len(conv_out.shape) == 4: - if conv_out.shape[2] != 1 or conv_out.shape[3] != 1: - raise ValueError( - "Given `conv_filters` ({}) do not result in a [B, {} " - "(`num_outputs`), 1, 1] shape (but in {})! Please " - "adjust your Conv2D stack such that the last 2 dims " - "are both 1.".format( - self.model_config["conv_filters"], - self.num_outputs, - list(conv_out.shape), - ) - ) - logits = conv_out.squeeze(3) - logits = logits.squeeze(2) - else: - logits = conv_out - return logits, state - else: - return conv_out, state - - @override(TorchModelV2) - def value_function(self) -> TensorType: - assert self._features is not None, "must call forward() first" - if self._value_branch_separate: - value = self._value_branch_separate(self._features) - value = value.squeeze(3) - value = value.squeeze(2) - return value.squeeze(1) - else: - if not self.last_layer_is_flattened: - features = self._features.squeeze(3) - features = features.squeeze(2) - else: - features = self._features - return self._value_branch(features).squeeze(1) - - def _hidden_layers(self, obs: TensorType) -> TensorType: - res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major - res = res.squeeze(3) - res = res.squeeze(2) - return res From 6085e05f9304f6b8547bf8152ab74bd65fbaadb8 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Thu, 26 Jan 2023 13:33:03 -0800 Subject: [PATCH 42/51] add back torch folder Signed-off-by: Artur Niederfahrenhorst --- rllib/models/torch/__init__.py | 12 + rllib/models/torch/attention_net.py | 452 ++++++++++++ rllib/models/torch/complex_input_net.py | 238 +++++++ rllib/models/torch/fcnet.py | 161 +++++ rllib/models/torch/mingpt.py | 299 ++++++++ rllib/models/torch/misc.py | 195 ++++++ rllib/models/torch/model.py | 220 ++++++ rllib/models/torch/modules/__init__.py | 13 + .../torch/modules/convtranspose2d_stack.py | 82 +++ rllib/models/torch/modules/gru_gate.py | 65 ++ .../torch/modules/multi_head_attention.py | 68 ++ rllib/models/torch/modules/noisy_layer.py | 99 +++ .../modules/relative_multi_head_attention.py | 175 +++++ rllib/models/torch/modules/skip_connection.py | 41 ++ rllib/models/torch/noop.py | 13 + rllib/models/torch/primitives.py | 54 ++ rllib/models/torch/recurrent_net.py | 285 ++++++++ rllib/models/torch/torch_action_dist.py | 648 ++++++++++++++++++ rllib/models/torch/torch_distributions.py | 257 +++++++ rllib/models/torch/torch_modelv2.py | 81 +++ rllib/models/torch/visionnet.py | 293 ++++++++ 21 files changed, 3751 insertions(+) create mode 100644 rllib/models/torch/__init__.py create mode 100644 rllib/models/torch/attention_net.py create mode 100644 rllib/models/torch/complex_input_net.py create mode 100644 rllib/models/torch/fcnet.py create mode 100644 rllib/models/torch/mingpt.py create mode 100644 rllib/models/torch/misc.py create mode 100644 rllib/models/torch/model.py create mode 100644 rllib/models/torch/modules/__init__.py create 
mode 100644 rllib/models/torch/modules/convtranspose2d_stack.py create mode 100644 rllib/models/torch/modules/gru_gate.py create mode 100644 rllib/models/torch/modules/multi_head_attention.py create mode 100644 rllib/models/torch/modules/noisy_layer.py create mode 100644 rllib/models/torch/modules/relative_multi_head_attention.py create mode 100644 rllib/models/torch/modules/skip_connection.py create mode 100644 rllib/models/torch/noop.py create mode 100644 rllib/models/torch/primitives.py create mode 100644 rllib/models/torch/recurrent_net.py create mode 100644 rllib/models/torch/torch_action_dist.py create mode 100644 rllib/models/torch/torch_distributions.py create mode 100644 rllib/models/torch/torch_modelv2.py create mode 100644 rllib/models/torch/visionnet.py diff --git a/rllib/models/torch/__init__.py b/rllib/models/torch/__init__.py new file mode 100644 index 000000000000..abbe5ef60464 --- /dev/null +++ b/rllib/models/torch/__init__.py @@ -0,0 +1,12 @@ +# from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +# from ray.rllib.models.torch.fcnet import FullyConnectedNetwork +# from ray.rllib.models.torch.recurrent_net import \ +# RecurrentNetwork +# from ray.rllib.models.torch.visionnet import VisionNetwork + +# __all__ = [ +# "FullyConnectedNetwork", +# "RecurrentNetwork", +# "TorchModelV2", +# "VisionNetwork", +# ] diff --git a/rllib/models/torch/attention_net.py b/rllib/models/torch/attention_net.py new file mode 100644 index 000000000000..454c0a555c97 --- /dev/null +++ b/rllib/models/torch/attention_net.py @@ -0,0 +1,452 @@ +""" +[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar, + Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017. + https://arxiv.org/pdf/1706.03762.pdf +[2] - Stabilizing Transformers for Reinforcement Learning - E. Parisotto + et al. - DeepMind - 2019. https://arxiv.org/pdf/1910.06764.pdf +[3] - Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context. + Z. Dai, Z. Yang, et al. - Carnegie Mellon U - 2019. + https://www.aclweb.org/anthology/P19-1285.pdf +""" +import gymnasium as gym +from gymnasium.spaces import Box, Discrete, MultiDiscrete +import numpy as np +import tree # pip install dm_tree +from typing import Dict, Optional, Union + +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.models.torch.modules import ( + GRUGate, + RelativeMultiHeadAttention, + SkipConnection, +) +from ray.rllib.models.torch.recurrent_net import RecurrentNetwork +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot +from ray.rllib.utils.typing import ModelConfigDict, TensorType, List + +torch, nn = try_import_torch() + + +class GTrXLNet(RecurrentNetwork, nn.Module): + """A GTrXL net Model described in [2]. + + This is still in an experimental phase. + Can be used as a drop-in replacement for LSTMs in PPO and IMPALA. + For an example script, see: `ray/rllib/examples/attention_net.py`. 
+ + To use this network as a replacement for an RNN, configure your Trainer + as follows: + + Examples: + >> config["model"]["custom_model"] = GTrXLNet + >> config["model"]["max_seq_len"] = 10 + >> config["model"]["custom_model_config"] = { + >> num_transformer_units=1, + >> attention_dim=32, + >> num_heads=2, + >> memory_tau=50, + >> etc.. + >> } + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: Optional[int], + model_config: ModelConfigDict, + name: str, + *, + num_transformer_units: int = 1, + attention_dim: int = 64, + num_heads: int = 2, + memory_inference: int = 50, + memory_training: int = 50, + head_dim: int = 32, + position_wise_mlp_dim: int = 32, + init_gru_gate_bias: float = 2.0 + ): + """Initializes a GTrXLNet. + + Args: + num_transformer_units: The number of Transformer repeats to + use (denoted L in [2]). + attention_dim: The input and output dimensions of one + Transformer unit. + num_heads: The number of attention heads to use in parallel. + Denoted as `H` in [3]. + memory_inference: The number of timesteps to concat (time + axis) and feed into the next transformer unit as inference + input. The first transformer unit will receive this number of + past observations (plus the current one), instead. + memory_training: The number of timesteps to concat (time + axis) and feed into the next transformer unit as training + input (plus the actual input sequence of len=max_seq_len). + The first transformer unit will receive this number of + past observations (plus the input sequence), instead. + head_dim: The dimension of a single(!) attention head within + a multi-head attention unit. Denoted as `d` in [3]. + position_wise_mlp_dim: The dimension of the hidden layer + within the position-wise MLP (after the multi-head attention + block within one Transformer unit). This is the size of the + first of the two layers within the PositionwiseFeedforward. The + second layer always has size=`attention_dim`. + init_gru_gate_bias: Initial bias values for the GRU gates + (two GRUs per Transformer unit, one after the MHA, one after + the position-wise MLP). + """ + + super().__init__( + observation_space, action_space, num_outputs, model_config, name + ) + + nn.Module.__init__(self) + + self.num_transformer_units = num_transformer_units + self.attention_dim = attention_dim + self.num_heads = num_heads + self.memory_inference = memory_inference + self.memory_training = memory_training + self.head_dim = head_dim + self.max_seq_len = model_config["max_seq_len"] + self.obs_dim = observation_space.shape[0] + + self.linear_layer = SlimFC(in_size=self.obs_dim, out_size=self.attention_dim) + + self.layers = [self.linear_layer] + + attention_layers = [] + # 2) Create L Transformer blocks according to [2]. + for i in range(self.num_transformer_units): + # RelativeMultiHeadAttention part. + MHA_layer = SkipConnection( + RelativeMultiHeadAttention( + in_dim=self.attention_dim, + out_dim=self.attention_dim, + num_heads=num_heads, + head_dim=head_dim, + input_layernorm=True, + output_activation=nn.ReLU, + ), + fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias), + ) + + # Position-wise MultiLayerPerceptron part. 
+ E_layer = SkipConnection( + nn.Sequential( + torch.nn.LayerNorm(self.attention_dim), + SlimFC( + in_size=self.attention_dim, + out_size=position_wise_mlp_dim, + use_bias=False, + activation_fn=nn.ReLU, + ), + SlimFC( + in_size=position_wise_mlp_dim, + out_size=self.attention_dim, + use_bias=False, + activation_fn=nn.ReLU, + ), + ), + fan_in_layer=GRUGate(self.attention_dim, init_gru_gate_bias), + ) + + # Build a list of all attanlayers in order. + attention_layers.extend([MHA_layer, E_layer]) + + # Create a Sequential such that all parameters inside the attention + # layers are automatically registered with this top-level model. + self.attention_layers = nn.Sequential(*attention_layers) + self.layers.extend(attention_layers) + + # Final layers if num_outputs not None. + self.logits = None + self.values_out = None + # Last value output. + self._value_out = None + # Postprocess GTrXL output with another hidden layer. + if self.num_outputs is not None: + self.logits = SlimFC( + in_size=self.attention_dim, + out_size=self.num_outputs, + activation_fn=nn.ReLU, + ) + + # Value function used by all RLlib Torch RL implementations. + self.values_out = SlimFC( + in_size=self.attention_dim, out_size=1, activation_fn=None + ) + else: + self.num_outputs = self.attention_dim + + # Setup trajectory views (`memory-inference` x past memory outs). + for i in range(self.num_transformer_units): + space = Box(-1.0, 1.0, shape=(self.attention_dim,)) + self.view_requirements["state_in_{}".format(i)] = ViewRequirement( + "state_out_{}".format(i), + shift="-{}:-1".format(self.memory_inference), + # Repeat the incoming state every max-seq-len times. + batch_repeat_value=self.max_seq_len, + space=space, + ) + self.view_requirements["state_out_{}".format(i)] = ViewRequirement( + space=space, used_for_training=False + ) + + @override(ModelV2) + def forward( + self, input_dict, state: List[TensorType], seq_lens: TensorType + ) -> (TensorType, List[TensorType]): + assert seq_lens is not None + + # Add the needed batch rank (tf Models' Input requires this). + observations = input_dict[SampleBatch.OBS] + # Add the time dim to observations. + B = len(seq_lens) + T = observations.shape[0] // B + observations = torch.reshape( + observations, [-1, T] + list(observations.shape[1:]) + ) + + all_out = observations + memory_outs = [] + for i in range(len(self.layers)): + # MHA layers which need memory passed in. + if i % 2 == 1: + all_out = self.layers[i](all_out, memory=state[i // 2]) + # Either self.linear_layer (initial obs -> attn. dim layer) or + # MultiLayerPerceptrons. The output of these layers is always the + # memory for the next forward pass. + else: + all_out = self.layers[i](all_out) + memory_outs.append(all_out) + + # Discard last output (not needed as a memory since it's the last + # layer). + memory_outs = memory_outs[:-1] + + if self.logits is not None: + out = self.logits(all_out) + self._value_out = self.values_out(all_out) + out_dim = self.num_outputs + else: + out = all_out + out_dim = self.attention_dim + + return torch.reshape(out, [-1, out_dim]), [ + torch.reshape(m, [-1, self.attention_dim]) for m in memory_outs + ] + + # TODO: (sven) Deprecate this once trajectory view API has fully matured. + @override(RecurrentNetwork) + def get_initial_state(self) -> List[np.ndarray]: + return [] + + @override(ModelV2) + def value_function(self) -> TensorType: + assert ( + self._value_out is not None + ), "Must call forward first AND must have value branch!" 
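# A hedged sketch of the memory bookkeeping implemented in forward() above: the
# layer list alternates [linear, MHA_0, MLP_0, MHA_1, MLP_1, ...]; odd indices
# (the gated attention layers) consume `state[i // 2]` as memory, while the
# outputs of the even indices become the memory for the next pass. Plain torch
# stand-ins and made-up sizes are used here instead of the real layers:
import torch

attention_dim = 32
layers = [torch.nn.Linear(4, attention_dim)] + [
    torch.nn.Identity(),  # stand-in for one gated MHA layer (odd index)
    torch.nn.Identity(),  # stand-in for one position-wise MLP layer (even index)
]
state = [torch.zeros(2, 5, attention_dim)]  # one memory tensor per transformer unit

x = torch.randn(2, 5, 4)  # [batch, time, obs_dim]
memory_outs = []
for i, layer in enumerate(layers):
    if i % 2 == 1:
        x = layer(x)  # real model: self.layers[i](x, memory=state[i // 2])
    else:
        x = layer(x)
        memory_outs.append(x)  # even-layer outputs feed the next pass' memory
memory_outs = memory_outs[:-1]  # the last output is not needed as memory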
+ return torch.reshape(self._value_out, [-1]) + + +class AttentionWrapper(TorchModelV2, nn.Module): + """GTrXL wrapper serving as interface for ModelV2s that set use_attention.""" + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + ): + + nn.Module.__init__(self) + super().__init__(obs_space, action_space, None, model_config, name) + + self.use_n_prev_actions = model_config["attention_use_n_prev_actions"] + self.use_n_prev_rewards = model_config["attention_use_n_prev_rewards"] + + self.action_space_struct = get_base_struct_from_space(self.action_space) + self.action_dim = 0 + + for space in tree.flatten(self.action_space_struct): + if isinstance(space, Discrete): + self.action_dim += space.n + elif isinstance(space, MultiDiscrete): + self.action_dim += np.sum(space.nvec) + elif space.shape is not None: + self.action_dim += int(np.product(space.shape)) + else: + self.action_dim += int(len(space)) + + # Add prev-action/reward nodes to input to LSTM. + if self.use_n_prev_actions: + self.num_outputs += self.use_n_prev_actions * self.action_dim + if self.use_n_prev_rewards: + self.num_outputs += self.use_n_prev_rewards + + cfg = model_config + + self.attention_dim = cfg["attention_dim"] + + if self.num_outputs is not None: + in_space = gym.spaces.Box( + float("-inf"), float("inf"), shape=(self.num_outputs,), dtype=np.float32 + ) + else: + in_space = obs_space + + # Construct GTrXL sub-module w/ num_outputs=None (so it does not + # create a logits/value output; we'll do this ourselves in this wrapper + # here). + self.gtrxl = GTrXLNet( + in_space, + action_space, + None, + model_config, + "gtrxl", + num_transformer_units=cfg["attention_num_transformer_units"], + attention_dim=self.attention_dim, + num_heads=cfg["attention_num_heads"], + head_dim=cfg["attention_head_dim"], + memory_inference=cfg["attention_memory_inference"], + memory_training=cfg["attention_memory_training"], + position_wise_mlp_dim=cfg["attention_position_wise_mlp_dim"], + init_gru_gate_bias=cfg["attention_init_gru_gate_bias"], + ) + + # Set final num_outputs to correct value (depending on action space). + self.num_outputs = num_outputs + + # Postprocess GTrXL output with another hidden layer and compute + # values. + self._logits_branch = SlimFC( + in_size=self.attention_dim, + out_size=self.num_outputs, + activation_fn=None, + initializer=torch.nn.init.xavier_uniform_, + ) + self._value_branch = SlimFC( + in_size=self.attention_dim, + out_size=1, + activation_fn=None, + initializer=torch.nn.init.xavier_uniform_, + ) + + self.view_requirements = self.gtrxl.view_requirements + self.view_requirements["obs"].space = self.obs_space + + # Add prev-a/r to this model's view, if required. + if self.use_n_prev_actions: + self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement( + SampleBatch.ACTIONS, + space=self.action_space, + shift="-{}:-1".format(self.use_n_prev_actions), + ) + if self.use_n_prev_rewards: + self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement( + SampleBatch.REWARDS, shift="-{}:-1".format(self.use_n_prev_rewards) + ) + + @override(RecurrentNetwork) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> (TensorType, List[TensorType]): + assert seq_lens is not None + # Push obs through "unwrapped" net's `forward()` first. + wrapped_out, _ = self._wrapped_forward(input_dict, [], None) + + # Concat. 
prev-action/reward if required. + prev_a_r = [] + + # Prev actions. + if self.use_n_prev_actions: + prev_n_actions = input_dict[SampleBatch.PREV_ACTIONS] + # If actions are not processed yet (in their original form as + # have been sent to environment): + # Flatten/one-hot into 1D array. + if self.model_config["_disable_action_flattening"]: + # Merge prev n actions into flat tensor. + flat = flatten_inputs_to_1d_tensor( + prev_n_actions, + spaces_struct=self.action_space_struct, + time_axis=True, + ) + # Fold time-axis into flattened data. + flat = torch.reshape(flat, [flat.shape[0], -1]) + prev_a_r.append(flat) + # If actions are already flattened (but not one-hot'd yet!), + # one-hot discrete/multi-discrete actions here and concatenate the + # n most recent actions together. + else: + if isinstance(self.action_space, Discrete): + for i in range(self.use_n_prev_actions): + prev_a_r.append( + one_hot( + prev_n_actions[:, i].float(), space=self.action_space + ) + ) + elif isinstance(self.action_space, MultiDiscrete): + for i in range( + 0, self.use_n_prev_actions, self.action_space.shape[0] + ): + prev_a_r.append( + one_hot( + prev_n_actions[ + :, i : i + self.action_space.shape[0] + ].float(), + space=self.action_space, + ) + ) + else: + prev_a_r.append( + torch.reshape( + prev_n_actions.float(), + [-1, self.use_n_prev_actions * self.action_dim], + ) + ) + # Prev rewards. + if self.use_n_prev_rewards: + prev_a_r.append( + torch.reshape( + input_dict[SampleBatch.PREV_REWARDS].float(), + [-1, self.use_n_prev_rewards], + ) + ) + + # Concat prev. actions + rewards to the "main" input. + if prev_a_r: + wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1) + + # Then through our GTrXL. + input_dict["obs_flat"] = input_dict["obs"] = wrapped_out + + self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens) + model_out = self._logits_branch(self._features) + return model_out, memory_outs + + @override(ModelV2) + def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]: + return [ + torch.zeros( + self.gtrxl.view_requirements["state_in_{}".format(i)].space.shape + ) + for i in range(self.gtrxl.num_transformer_units) + ] + + @override(ModelV2) + def value_function(self) -> TensorType: + assert self._features is not None, "Must call forward() first!" + return torch.reshape(self._value_branch(self._features), [-1]) diff --git a/rllib/models/torch/complex_input_net.py b/rllib/models/torch/complex_input_net.py new file mode 100644 index 000000000000..f3cb4311521d --- /dev/null +++ b/rllib/models/torch/complex_input_net.py @@ -0,0 +1,238 @@ +from gymnasium.spaces import Box, Discrete, MultiDiscrete +import numpy as np +import tree # pip install dm_tree + +# TODO (sven): add IMPALA-style option. 
+# from ray.rllib.examples.models.impala_vision_nets import TorchImpalaVisionNet +from ray.rllib.models.torch.misc import ( + normc_initializer as torch_normc_initializer, + SlimFC, +) +from ray.rllib.models.catalog import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2, restore_original_dimensions +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.models.utils import get_filter_config +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.spaces.space_utils import flatten_space +from ray.rllib.utils.torch_utils import one_hot + +torch, nn = try_import_torch() + + +class ComplexInputNetwork(TorchModelV2, nn.Module): + """TorchModelV2 concat'ing CNN outputs to flat input(s), followed by FC(s). + + Note: This model should be used for complex (Dict or Tuple) observation + spaces that have one or more image components. + + The data flow is as follows: + + `obs` (e.g. Tuple[img0, img1, discrete0]) -> `CNN0 + CNN1 + ONE-HOT` + `CNN0 + CNN1 + ONE-HOT` -> concat all flat outputs -> `out` + `out` -> (optional) FC-stack -> `out2` + `out2` -> action (logits) and value heads. + """ + + def __init__(self, obs_space, action_space, num_outputs, model_config, name): + self.original_space = ( + obs_space.original_space + if hasattr(obs_space, "original_space") + else obs_space + ) + + self.processed_obs_space = ( + self.original_space + if model_config.get("_disable_preprocessor_api") + else obs_space + ) + + nn.Module.__init__(self) + TorchModelV2.__init__( + self, self.original_space, action_space, num_outputs, model_config, name + ) + + self.flattened_input_space = flatten_space(self.original_space) + + # Atari type CNNs or IMPALA type CNNs (with residual layers)? + # self.cnn_type = self.model_config["custom_model_config"].get( + # "conv_type", "atari") + + # Build the CNN(s) given obs_space's image components. + self.cnns = nn.ModuleDict() + self.one_hot = nn.ModuleDict() + self.flatten_dims = {} + self.flatten = nn.ModuleDict() + concat_size = 0 + for i, component in enumerate(self.flattened_input_space): + i = str(i) + # Image space. + if len(component.shape) == 3 and isinstance(component, Box): + config = { + "conv_filters": model_config["conv_filters"] + if "conv_filters" in model_config + else get_filter_config(component.shape), + "conv_activation": model_config.get("conv_activation"), + "post_fcnet_hiddens": [], + } + # if self.cnn_type == "atari": + self.cnns[i] = ModelCatalog.get_model_v2( + component, + action_space, + num_outputs=None, + model_config=config, + framework="torch", + name="cnn_{}".format(i), + ) + # TODO (sven): add IMPALA-style option. + # else: + # cnn = TorchImpalaVisionNet( + # component, + # action_space, + # num_outputs=None, + # model_config=config, + # name="cnn_{}".format(i)) + + concat_size += self.cnns[i].num_outputs + self.add_module("cnn_{}".format(i), self.cnns[i]) + # Discrete|MultiDiscrete inputs -> One-hot encode. 
+ elif isinstance(component, (Discrete, MultiDiscrete)): + if isinstance(component, Discrete): + size = component.n + else: + size = np.sum(component.nvec) + config = { + "fcnet_hiddens": model_config["fcnet_hiddens"], + "fcnet_activation": model_config.get("fcnet_activation"), + "post_fcnet_hiddens": [], + } + self.one_hot[i] = ModelCatalog.get_model_v2( + Box(-1.0, 1.0, (size,), np.float32), + action_space, + num_outputs=None, + model_config=config, + framework="torch", + name="one_hot_{}".format(i), + ) + concat_size += self.one_hot[i].num_outputs + self.add_module("one_hot_{}".format(i), self.one_hot[i]) + # Everything else (1D Box). + else: + size = int(np.product(component.shape)) + config = { + "fcnet_hiddens": model_config["fcnet_hiddens"], + "fcnet_activation": model_config.get("fcnet_activation"), + "post_fcnet_hiddens": [], + } + self.flatten[i] = ModelCatalog.get_model_v2( + Box(-1.0, 1.0, (size,), np.float32), + action_space, + num_outputs=None, + model_config=config, + framework="torch", + name="flatten_{}".format(i), + ) + self.flatten_dims[i] = size + concat_size += self.flatten[i].num_outputs + self.add_module("flatten_{}".format(i), self.flatten[i]) + + # Optional post-concat FC-stack. + post_fc_stack_config = { + "fcnet_hiddens": model_config.get("post_fcnet_hiddens", []), + "fcnet_activation": model_config.get("post_fcnet_activation", "relu"), + } + self.post_fc_stack = ModelCatalog.get_model_v2( + Box(float("-inf"), float("inf"), shape=(concat_size,), dtype=np.float32), + self.action_space, + None, + post_fc_stack_config, + framework="torch", + name="post_fc_stack", + ) + + # Actions and value heads. + self.logits_layer = None + self.value_layer = None + self._value_out = None + + if num_outputs: + # Action-distribution head. + self.logits_layer = SlimFC( + in_size=self.post_fc_stack.num_outputs, + out_size=num_outputs, + activation_fn=None, + initializer=torch_normc_initializer(0.01), + ) + # Create the value branch model. + self.value_layer = SlimFC( + in_size=self.post_fc_stack.num_outputs, + out_size=1, + activation_fn=None, + initializer=torch_normc_initializer(0.01), + ) + else: + self.num_outputs = concat_size + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + if SampleBatch.OBS in input_dict and "obs_flat" in input_dict: + orig_obs = input_dict[SampleBatch.OBS] + else: + orig_obs = restore_original_dimensions( + input_dict[SampleBatch.OBS], self.processed_obs_space, tensorlib="torch" + ) + # Push observations through the different components + # (CNNs, one-hot + FC, etc..). + outs = [] + for i, component in enumerate(tree.flatten(orig_obs)): + i = str(i) + if i in self.cnns: + cnn_out, _ = self.cnns[i](SampleBatch({SampleBatch.OBS: component})) + outs.append(cnn_out) + elif i in self.one_hot: + if component.dtype in [ + torch.int8, + torch.int16, + torch.int32, + torch.int64, + torch.uint8, + ]: + one_hot_in = { + SampleBatch.OBS: one_hot( + component, self.flattened_input_space[int(i)] + ) + } + else: + one_hot_in = {SampleBatch.OBS: component} + one_hot_out, _ = self.one_hot[i](SampleBatch(one_hot_in)) + outs.append(one_hot_out) + else: + nn_out, _ = self.flatten[i]( + SampleBatch( + { + SampleBatch.OBS: torch.reshape( + component, [-1, self.flatten_dims[i]] + ) + } + ) + ) + outs.append(nn_out) + + # Concat all outputs and the non-image inputs. + out = torch.cat(outs, dim=1) + # Push through (optional) FC-stack (this may be an empty stack). 
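# In forward() above, Discrete/MultiDiscrete observation components are one-hot
# encoded before being pushed through their FC sub-model; the encoded width is
# space.n (Discrete) or sum(space.nvec) (MultiDiscrete), matching `concat_size`.
# A small sketch with an assumed example space and value:
import torch
from gymnasium.spaces import MultiDiscrete
from ray.rllib.utils.torch_utils import one_hot

component = torch.tensor([[2, 0]])                   # one MultiDiscrete([3, 2]) sample
encoded = one_hot(component, MultiDiscrete([3, 2]))
# encoded has shape [1, 5]: a length-3 and a length-2 one-hot block, concatenated.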
+ out, _ = self.post_fc_stack(SampleBatch({SampleBatch.OBS: out})) + + # No logits/value branches. + if self.logits_layer is None: + return out, [] + + # Logits- and value branches. + logits, values = self.logits_layer(out), self.value_layer(out) + self._value_out = torch.reshape(values, [-1]) + return logits, [] + + @override(ModelV2) + def value_function(self): + return self._value_out diff --git a/rllib/models/torch/fcnet.py b/rllib/models/torch/fcnet.py new file mode 100644 index 000000000000..97bb9096bb64 --- /dev/null +++ b/rllib/models/torch/fcnet.py @@ -0,0 +1,161 @@ +import logging +import numpy as np +import gymnasium as gym + +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.models.torch.misc import SlimFC, AppendBiasLayer, normc_initializer +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict + +torch, nn = try_import_torch() + +logger = logging.getLogger(__name__) + + +class FullyConnectedNetwork(TorchModelV2, nn.Module): + """Generic fully connected network.""" + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + ): + TorchModelV2.__init__( + self, obs_space, action_space, num_outputs, model_config, name + ) + nn.Module.__init__(self) + + hiddens = list(model_config.get("fcnet_hiddens", [])) + list( + model_config.get("post_fcnet_hiddens", []) + ) + activation = model_config.get("fcnet_activation") + if not model_config.get("fcnet_hiddens", []): + activation = model_config.get("post_fcnet_activation") + no_final_linear = model_config.get("no_final_linear") + self.vf_share_layers = model_config.get("vf_share_layers") + self.free_log_std = model_config.get("free_log_std") + # Generate free-floating bias variables for the second half of + # the outputs. + if self.free_log_std: + assert num_outputs % 2 == 0, ( + "num_outputs must be divisible by two", + num_outputs, + ) + num_outputs = num_outputs // 2 + + layers = [] + prev_layer_size = int(np.product(obs_space.shape)) + self._logits = None + + # Create layers 0 to second-last. + for size in hiddens[:-1]: + layers.append( + SlimFC( + in_size=prev_layer_size, + out_size=size, + initializer=normc_initializer(1.0), + activation_fn=activation, + ) + ) + prev_layer_size = size + + # The last layer is adjusted to be of size num_outputs, but it's a + # layer with activation. + if no_final_linear and num_outputs: + layers.append( + SlimFC( + in_size=prev_layer_size, + out_size=num_outputs, + initializer=normc_initializer(1.0), + activation_fn=activation, + ) + ) + prev_layer_size = num_outputs + # Finish the layers with the provided sizes (`hiddens`), plus - + # iff num_outputs > 0 - a last linear layer of size num_outputs. + else: + if len(hiddens) > 0: + layers.append( + SlimFC( + in_size=prev_layer_size, + out_size=hiddens[-1], + initializer=normc_initializer(1.0), + activation_fn=activation, + ) + ) + prev_layer_size = hiddens[-1] + if num_outputs: + self._logits = SlimFC( + in_size=prev_layer_size, + out_size=num_outputs, + initializer=normc_initializer(0.01), + activation_fn=None, + ) + else: + self.num_outputs = ([int(np.product(obs_space.shape))] + hiddens[-1:])[ + -1 + ] + + # Layer to add the log std vars to the state-dependent means. 
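# With free_log_std=True, num_outputs was halved above and the AppendBiasLayer
# attached just below re-appends the log-stds as state-independent, learnable
# bias variables. A hedged mini example of that layer (sizes are illustrative):
import torch
from ray.rllib.models.torch.misc import AppendBiasLayer

append_log_std = AppendBiasLayer(num_bias_vars=2)   # one free log-std per action dim
means = torch.randn(4, 2)                           # state-dependent means, [batch, 2]
dist_inputs = append_log_std(means)                 # -> [4, 4]: means + shared log-stds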
+ if self.free_log_std and self._logits: + self._append_free_log_std = AppendBiasLayer(num_outputs) + + self._hidden_layers = nn.Sequential(*layers) + + self._value_branch_separate = None + if not self.vf_share_layers: + # Build a parallel set of hidden layers for the value net. + prev_vf_layer_size = int(np.product(obs_space.shape)) + vf_layers = [] + for size in hiddens: + vf_layers.append( + SlimFC( + in_size=prev_vf_layer_size, + out_size=size, + activation_fn=activation, + initializer=normc_initializer(1.0), + ) + ) + prev_vf_layer_size = size + self._value_branch_separate = nn.Sequential(*vf_layers) + + self._value_branch = SlimFC( + in_size=prev_layer_size, + out_size=1, + initializer=normc_initializer(0.01), + activation_fn=None, + ) + # Holds the current "base" output (before logits layer). + self._features = None + # Holds the last input, in case value branch is separate. + self._last_flat_in = None + + @override(TorchModelV2) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> (TensorType, List[TensorType]): + obs = input_dict["obs_flat"].float() + self._last_flat_in = obs.reshape(obs.shape[0], -1) + self._features = self._hidden_layers(self._last_flat_in) + logits = self._logits(self._features) if self._logits else self._features + if self.free_log_std: + logits = self._append_free_log_std(logits) + return logits, state + + @override(TorchModelV2) + def value_function(self) -> TensorType: + assert self._features is not None, "must call forward() first" + if self._value_branch_separate: + out = self._value_branch( + self._value_branch_separate(self._last_flat_in) + ).squeeze(1) + else: + out = self._value_branch(self._features).squeeze(1) + return out diff --git a/rllib/models/torch/mingpt.py b/rllib/models/torch/mingpt.py new file mode 100644 index 000000000000..00a192e9ec91 --- /dev/null +++ b/rllib/models/torch/mingpt.py @@ -0,0 +1,299 @@ +# LICENSE: MIT +""" +Adapted from https://github.com/karpathy/minGPT + +Full definition of a GPT Language Model, all of it in this single file. +References: +1) the official GPT-2 TensorFlow implementation released by OpenAI: +https://github.com/openai/gpt-2/blob/master/src/model.py +2) huggingface/transformers PyTorch implementation: +https://github.com/huggingface/transformers/blob/main/src/transformers + /models/gpt2/modeling_gpt2.py +""" + +import math +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from ray.rllib.utils.annotations import DeveloperAPI + + +@DeveloperAPI +@dataclass +class GPTConfig: + # block size must be provided + block_size: int + + # transformer config + n_layer: int = 12 + n_head: int = 12 + n_embed: int = 768 + + # dropout config + embed_pdrop: float = 0.1 + resid_pdrop: float = 0.1 + attn_pdrop: float = 0.1 + + +class NewGELU(nn.Module): + """ + Implementation of the GELU activation function currently in Google BERT + repo (identical to OpenAI GPT). + Reference: Gaussian Error Linear Units (GELU) paper: + https://arxiv.org/abs/1606.08415 + """ + + def forward(self, x): + return ( + 0.5 + * x + * ( + 1.0 + + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)) + ) + ) + ) + + +class CausalSelfAttention(nn.Module): + """ + Vanilla multi-head masked self-attention layer with a projection at the end. 
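# A minimal smoke test of the FullyConnectedNetwork defined above. The spaces,
# sizes and config values are illustrative; only fcnet_hiddens/fcnet_activation
# are set, all other model_config keys fall back to their .get() defaults.
import gymnasium as gym
import numpy as np
import torch
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork

model = FullyConnectedNetwork(
    obs_space=gym.spaces.Box(-1.0, 1.0, (4,), np.float32),
    action_space=gym.spaces.Discrete(2),
    num_outputs=2,
    model_config={"fcnet_hiddens": [32, 32], "fcnet_activation": "relu"},
    name="fc_example",
)
obs = torch.randn(8, 4)
logits, _ = model.forward({"obs": obs, "obs_flat": obs}, [], None)  # -> [8, 2]
value = model.value_function()                                      # -> [8]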
+ It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config: GPTConfig): + super().__init__() + assert config.n_embed % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed) + # output projection + self.c_proj = nn.Linear(config.n_embed, config.n_embed) + # regularization + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + # causal mask to ensure that attention is only applied to the left + # in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + self.n_head = config.n_head + self.n_embed = config.n_embed + + def forward(self, x, attention_masks=None): + # batch size, sequence length, embedding dimensionality (n_embed) + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head + # forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embed, dim=2) + # (B, nh, T, hs) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) + + # causal self-attention; Self-attend: + # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + if attention_masks is not None: + att = att + attention_masks + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + # re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(B, T, C) + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y, att + + +class Block(nn.Module): + """an unassuming Transformer block""" + + def __init__(self, config: GPTConfig): + super().__init__() + self.ln_1 = nn.LayerNorm(config.n_embed) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.n_embed) + self.mlp = nn.ModuleDict( + dict( + c_fc=nn.Linear(config.n_embed, 4 * config.n_embed), + c_proj=nn.Linear(4 * config.n_embed, config.n_embed), + act=NewGELU(), + dropout=nn.Dropout(config.resid_pdrop), + ) + ) + + def forward(self, x, attention_masks=None): + # Multi-head attention sub-layer. + x_att, att = self.attn(self.ln_1(x), attention_masks=attention_masks) + # Residual of multi-head attention sub-layer. + x = x + x_att + + # Position-wise FFN sub-layer: fc + activation + fc + dropout + x_ffn = self.mlp.dropout(self.mlp.c_proj(self.mlp.act(self.mlp.c_fc(x)))) + # Residual of position-wise FFN sub-layer. + x = x + x_ffn + return x, att + + +@DeveloperAPI +def configure_gpt_optimizer( + model: nn.Module, + learning_rate: float, + weight_decay: float, + betas: Tuple[float, float] = (0.9, 0.95), + **kwargs, +) -> torch.optim.Optimizer: + """ + This long function is unfortunately doing something very simple and is + being very defensive: We are separating out all parameters of the model + into two buckets: those that will experience weight decay for regularization + and those that won't (biases, and layernorm/embedding weights). We are then + returning the PyTorch optimizer object. 
+ """ + + # separate out all parameters to those that will and won't experience + # regularizing weight decay + decay = set() + no_decay = set() + whitelist_w_modules = (torch.nn.Linear,) + blacklist_w_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = "%s.%s" % (mn, pn) if mn else pn # full param name + # random note: because named_modules and named_parameters are + # recursive we will see the same tensors p many many times. but + # doing it this way allows us to know which parent module any + # tensor p belongs to... + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_w_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_w_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in model.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert ( + len(inter_params) == 0 + ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert len(param_dict.keys() - union_params) == 0, ( + f"parameters {str(param_dict.keys() - union_params)} were not " + f"separated into either decay/no_decay set!" + ) + + # create the pytorch optimizer object + optim_groups = [ + { + "params": [param_dict[pn] for pn in sorted(decay)], + "weight_decay": weight_decay, + }, + { + "params": [param_dict[pn] for pn in sorted(no_decay)], + "weight_decay": 0.0, + }, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **kwargs) + return optimizer + + +@DeveloperAPI +class GPT(nn.Module): + """GPT Transformer Model""" + + def __init__(self, config: GPTConfig): + super().__init__() + assert config.block_size is not None + self.block_size = config.block_size + + self.transformer = nn.ModuleDict( + dict( + drop=nn.Dropout(config.embed_pdrop), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=nn.LayerNorm(config.n_embed), + ) + ) + + # init all weights, and apply a special scaled init to the residual + # projections, per GPT-2 paper + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer) + ) + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + + def forward(self, input_embeds, attention_masks=None, return_attentions=False): + """ + input_embeds: [batch_size x seq_len x n_embed] + attention_masks: [batch_size x seq_len], 0 don't attend, 1 attend + """ + B, T, C = input_embeds.size() + assert T <= self.block_size, ( + f"Cannot forward sequence of length {T}, " + f"block size is only {self.block_size}" + ) + + if attention_masks is not None: + _B, _T = attention_masks.size() + assert _B == B and _T == T + # We create a 3D attention mask from a 2D tensor mask. 
+ # Sizes are [batch_size, 1, 1, to_seq_len] + # So we can broadcast to + # [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular + # masking of causal attention used in OpenAI GPT, we just need + # to prepare the broadcast dimension here. + attention_masks = attention_masks[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend + # and 0.0 for masked positions, this operation will create a + # tensor which is 0.0 for positions we want to attend and -inf + # for masked positions. Since we are adding it to the raw scores + # before the softmax, this is effectively the same as removing + # these entirely. + attention_masks = attention_masks.to(dtype=input_embeds.dtype) + attention_masks = (1.0 - attention_masks) * -1e9 + + # forward the GPT model itself + x = self.transformer.drop(input_embeds) + + atts = [] + for block in self.transformer.h: + x, att = block(x, attention_masks=attention_masks) + atts.append(att) + x = self.transformer.ln_f(x) + + if return_attentions: + return x, atts + else: + return x diff --git a/rllib/models/torch/misc.py b/rllib/models/torch/misc.py new file mode 100644 index 000000000000..29a02e365b31 --- /dev/null +++ b/rllib/models/torch/misc.py @@ -0,0 +1,195 @@ +""" Code adapted from https://github.com/ikostrikov/pytorch-a3c""" +import numpy as np +from typing import Union, Tuple, Any, List + +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() + + +@DeveloperAPI +def normc_initializer(std: float = 1.0) -> Any: + def initializer(tensor): + tensor.data.normal_(0, 1) + tensor.data *= std / torch.sqrt(tensor.data.pow(2).sum(1, keepdim=True)) + + return initializer + + +@DeveloperAPI +def same_padding( + in_size: Tuple[int, int], + filter_size: Tuple[int, int], + stride_size: Union[int, Tuple[int, int]], +) -> (Union[int, Tuple[int, int]], Tuple[int, int]): + """Note: Padding is added to match TF conv2d `same` padding. See + www.tensorflow.org/versions/r0.12/api_docs/python/nn/convolution + + Args: + in_size: Rows (Height), Column (Width) for input + stride_size (Union[int,Tuple[int, int]]): Rows (Height), column (Width) + for stride. If int, height == width. + filter_size: Rows (Height), column (Width) for filter + + Returns: + padding: For input into torch.nn.ZeroPad2d. + output: Output shape after padding and convolution. 
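# A self-contained usage sketch of the minGPT stack above (all sizes are
# illustrative). Note that this class consumes pre-computed embeddings; there is
# no token-embedding or LM head here.
import torch
from ray.rllib.models.torch.mingpt import GPT, GPTConfig, configure_gpt_optimizer

cfg = GPTConfig(block_size=8, n_layer=2, n_head=2, n_embed=32)
model = GPT(cfg)

input_embeds = torch.randn(4, 8, 32)     # [batch, seq_len, n_embed]
attention_masks = torch.ones(4, 8)       # 1 = attend, 0 = masked out
out = model(input_embeds, attention_masks=attention_masks)  # -> [4, 8, 32]

# Biases and LayerNorm/embedding weights are excluded from weight decay, as
# described in configure_gpt_optimizer above.
optim = configure_gpt_optimizer(model, learning_rate=1e-4, weight_decay=0.1)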
+ """ + in_height, in_width = in_size + if isinstance(filter_size, int): + filter_height, filter_width = filter_size, filter_size + else: + filter_height, filter_width = filter_size + if isinstance(stride_size, (int, float)): + stride_height, stride_width = int(stride_size), int(stride_size) + else: + stride_height, stride_width = int(stride_size[0]), int(stride_size[1]) + + out_height = np.ceil(float(in_height) / float(stride_height)) + out_width = np.ceil(float(in_width) / float(stride_width)) + + pad_along_height = int( + ((out_height - 1) * stride_height + filter_height - in_height) + ) + pad_along_width = int(((out_width - 1) * stride_width + filter_width - in_width)) + pad_top = pad_along_height // 2 + pad_bottom = pad_along_height - pad_top + pad_left = pad_along_width // 2 + pad_right = pad_along_width - pad_left + padding = (pad_left, pad_right, pad_top, pad_bottom) + output = (out_height, out_width) + return padding, output + + +@DeveloperAPI +class SlimConv2d(nn.Module): + """Simple mock of tf.slim Conv2d""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + padding: Union[int, Tuple[int, int]], + # Defaulting these to nn.[..] will break soft torch import. + initializer: Any = "default", + activation_fn: Any = "default", + bias_init: float = 0, + ): + """Creates a standard Conv2d layer, similar to torch.nn.Conv2d + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel: If int, the kernel is + a tuple(x,x). Elsewise, the tuple can be specified + stride: Controls the stride + for the cross-correlation. If int, the stride is a + tuple(x,x). Elsewise, the tuple can be specified + padding: Controls the amount + of implicit zero-paddings during the conv operation + initializer: Initializer function for kernel weights + activation_fn: Activation function at the end of layer + bias_init: Initalize bias weights to bias_init const + """ + super(SlimConv2d, self).__init__() + layers = [] + # Padding layer. + if padding: + layers.append(nn.ZeroPad2d(padding)) + # Actual Conv2D layer (including correct initialization logic). + conv = nn.Conv2d(in_channels, out_channels, kernel, stride) + if initializer: + if initializer == "default": + initializer = nn.init.xavier_uniform_ + initializer(conv.weight) + nn.init.constant_(conv.bias, bias_init) + layers.append(conv) + # Activation function (if any; default=ReLu). + if isinstance(activation_fn, str): + if activation_fn == "default": + activation_fn = nn.ReLU + else: + activation_fn = get_activation_fn(activation_fn, "torch") + if activation_fn is not None: + layers.append(activation_fn()) + # Put everything in sequence. 
+ self._model = nn.Sequential(*layers) + + def forward(self, x: TensorType) -> TensorType: + return self._model(x) + + +@DeveloperAPI +class SlimFC(nn.Module): + """Simple PyTorch version of `linear` function""" + + def __init__( + self, + in_size: int, + out_size: int, + initializer: Any = None, + activation_fn: Any = None, + use_bias: bool = True, + bias_init: float = 0.0, + ): + """Creates a standard FC layer, similar to torch.nn.Linear + + Args: + in_size: Input size for FC Layer + out_size: Output size for FC Layer + initializer: Initializer function for FC layer weights + activation_fn: Activation function at the end of layer + use_bias: Whether to add bias weights or not + bias_init: Initalize bias weights to bias_init const + """ + super(SlimFC, self).__init__() + layers = [] + # Actual nn.Linear layer (including correct initialization logic). + linear = nn.Linear(in_size, out_size, bias=use_bias) + if initializer is None: + initializer = nn.init.xavier_uniform_ + initializer(linear.weight) + if use_bias is True: + nn.init.constant_(linear.bias, bias_init) + layers.append(linear) + # Activation function (if any; default=None (linear)). + if isinstance(activation_fn, str): + activation_fn = get_activation_fn(activation_fn, "torch") + if activation_fn is not None: + layers.append(activation_fn()) + # Put everything in sequence. + self._model = nn.Sequential(*layers) + + def forward(self, x: TensorType) -> TensorType: + return self._model(x) + + +@DeveloperAPI +class AppendBiasLayer(nn.Module): + """Simple bias appending layer for free_log_std.""" + + def __init__(self, num_bias_vars: int): + super().__init__() + self.log_std = torch.nn.Parameter(torch.as_tensor([0.0] * num_bias_vars)) + self.register_parameter("log_std", self.log_std) + + def forward(self, x: TensorType) -> TensorType: + out = torch.cat([x, self.log_std.unsqueeze(0).repeat([len(x), 1])], axis=1) + return out + + +@DeveloperAPI +class Reshape(nn.Module): + """Standard module that reshapes/views a tensor""" + + def __init__(self, shape: List): + super().__init__() + self.shape = shape + + def forward(self, x): + return x.view(*self.shape) diff --git a/rllib/models/torch/model.py b/rllib/models/torch/model.py new file mode 100644 index 000000000000..30f341d63238 --- /dev/null +++ b/rllib/models/torch/model.py @@ -0,0 +1,220 @@ +import torch +from torch import nn +import tree + +from ray.rllib.utils.annotations import ( + DeveloperAPI, + override, +) +from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig +from ray.rllib.models.base_model import RecurrentModel, Model, ModelIO + + +class TorchModelIO(ModelIO): + """Save/Load mixin for torch models + + Examples: + >>> model.save("/tmp/model_weights.cpt") + >>> model.load("/tmp/model_weights.cpt") + """ + + @DeveloperAPI + @override(ModelIO) + def save(self, path: str) -> None: + """Saves the state dict to the specified path + + Args: + path: Path on disk the checkpoint is saved to + + """ + torch.save(self.state_dict(), path) + + @DeveloperAPI + @override(ModelIO) + def load(self, path: str) -> RecurrentModel: + """Loads the state dict from the specified path + + Args: + path: Path on disk to load the checkpoint from + """ + self.load_state_dict(torch.load(path)) + + +class TorchRecurrentModel(RecurrentModel, nn.Module, TorchModelIO): + """The base class for recurrent pytorch models. + + If implementing a custom recurrent model, you likely want to inherit + from this model. You should make sure to call super().__init__(config) + in your __init__. 
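# SlimFC above is the basic fully-connected building block used by most networks
# in this patch. A quick shape check with the column-normalized initializer
# (sizes are illustrative):
import torch
from ray.rllib.models.torch.misc import SlimFC, normc_initializer

layer = SlimFC(in_size=4, out_size=2,
               initializer=normc_initializer(0.01), activation_fn="relu")
out = layer(torch.randn(8, 4))   # -> [8, 2]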
+ + Args: + config: The config used to construct the model + + Required Attributes: + input_spec: SpecDict: Denotes the input keys and shapes passed to `unroll` + output_spec: SpecDict: Denotes the output keys and shapes returned from + `unroll` + prev_state_spec: SpecDict: Denotes the keys and shapes for the input + recurrent states to the model + next_state_spec: SpecDict: Denotes the keys and shapes for the + recurrent states output by the model + + Required Overrides: + # Define unrolling (forward pass) over a sequence of inputs + _unroll(self, inputs: TensorDict, prev_state: TensorDict, **kwargs) + -> Tuple[TensorDict, TensorDict] + + Optional Overrides: + # Define the initial state, if a zero tensor is insufficient + # the returned TensorDict must match the prev_state_spec + _initial_state(self) -> TensorDict + + # Additional checks on the input and recurrent state before `_unroll` + _update_inputs_and_prev_state(inputs: TensorDict, prev_state: TensorDict) + -> Tuple[TensorDict, TensorDict] + + # Additional checks on the output and the output recurrent state + # after `_unroll` + _update_outputs_and_next_state(outputs: TensorDict, next_state: TensorDict) + -> Tuple[TensorDict, TensorDict] + + # Save model weights to path + save(self, path: str) -> None + + # Load model weights from path + load(self, path: str) -> None + + Examples: + >>> class MyCustomModel(TorchRecurrentModel): + ... def __init__(self, config): + ... super().__init__(config) + ... + ... self.lstm = nn.LSTM( + ... input_size, recurrent_size, batch_first=True + ... ) + ... self.project = nn.Linear(recurrent_size, output_size) + ... + ... @property + ... def input_spec(self): + ... return SpecDict( + ... {"obs": "batch time hidden"}, hidden=self.config.input_size + ... ) + ... + ... @property + ... def output_spec(self): + ... return SpecDict( + ... {"logits": "batch time logits"}, logits=self.config.output_size + ... ) + ... + ... @property + ... def prev_state_spec(self): + ... return SpecDict( + ... {"input_state": "batch recur"}, recur=self.config.recurrent_size + ... ) + ... + ... @property + ... def next_state_spec(self): + ... return SpecDict( + ... {"output_state": "batch recur"}, + ... recur=self.config.recurrent_size + ... ) + ... + ... def _unroll(self, inputs, prev_state, **kwargs): + ... output, state = self.lstm(inputs["obs"], prev_state["input_state"]) + ... output = self.project(output) + ... return TensorDict( + ... {"logits": output}), TensorDict({"output_state": state} + ... ) + + """ + + def __init__(self, config: ModelConfig) -> None: + RecurrentModel.__init__(self) + nn.Module.__init__(self) + TorchModelIO.__init__(self, config) + + @override(RecurrentModel) + def _initial_state(self) -> TensorDict: + """Returns the initial recurrent state + + This defaults to all zeros and can be overidden to return + nonzero tensors. + + Returns: + A TensorDict that matches the initial_state_spec + """ + return TensorDict( + tree.map_structure( + lambda spec: torch.zeros(spec.shape, dtype=spec.dtype), + self.initial_state_spec, + ) + ) + + +class TorchModel(Model, nn.Module, TorchModelIO): + """The base class for non-recurrent pytorch models. + + If implementing a custom pytorch model, you likely want to + inherit from this class. You should make sure to call super().__init__(config) + in your __init__. 
+ + Args: + config: The config used to construct the model + + Required Attributes: + input_spec: SpecDict: Denotes the input keys and shapes passed to `_forward` + output_spec: SpecDict: Denotes the output keys and shapes returned from + `_forward` + + Required Overrides: + # Define unrolling (forward pass) over a sequence of inputs + _forward(self, inputs: TensorDict, **kwargs) + -> TensorDict + + Optional Overrides: + # Additional checks on the input before `_forward` + _update_inputs(inputs: TensorDict) -> TensorDict + + # Additional checks on the output after `_forward` + _update_outputs(outputs: TensorDict) -> TensorDict + + # Save model weights to path + save(self, path: str) -> None + + # Load model weights from path + load(self, path: str) -> None + + Examples: + >>> class MyCustomModel(TorchModel): + ... def __init__(self, config): + ... super().__init__(config) + ... self.mlp = nn.Sequential( + ... nn.Linear(input_size, hidden_size), + ... nn.ReLU(), + ... nn.Linear(hidden_size, hidden_size), + ... nn.ReLU(), + ... nn.Linear(hidden_size, output_size) + ... ) + ... + ... @property + ... def input_spec(self): + ... return SpecDict( + ... {"obs": "batch time hidden"}, hidden=self.config.input_size + ... ) + ... + ... @property + ... def output_spec(self): + ... return SpecDict( + ... {"logits": "batch time logits"}, logits=self.config.output_size + ... ) + ... + ... def _forward(self, inputs, **kwargs): + ... output = self.mlp(inputs["obs"]) + ... return TensorDict({"logits": output}) + + """ + + def __init__(self, config: ModelConfig) -> None: + Model.__init__(self) + nn.Module.__init__(self) + TorchModelIO.__init__(self, config) diff --git a/rllib/models/torch/modules/__init__.py b/rllib/models/torch/modules/__init__.py new file mode 100644 index 000000000000..2585dcc77abe --- /dev/null +++ b/rllib/models/torch/modules/__init__.py @@ -0,0 +1,13 @@ +from ray.rllib.models.torch.modules.gru_gate import GRUGate +from ray.rllib.models.torch.modules.multi_head_attention import MultiHeadAttention +from ray.rllib.models.torch.modules.relative_multi_head_attention import ( + RelativeMultiHeadAttention, +) +from ray.rllib.models.torch.modules.skip_connection import SkipConnection + +__all__ = [ + "GRUGate", + "RelativeMultiHeadAttention", + "SkipConnection", + "MultiHeadAttention", +] diff --git a/rllib/models/torch/modules/convtranspose2d_stack.py b/rllib/models/torch/modules/convtranspose2d_stack.py new file mode 100644 index 000000000000..f991400d3df0 --- /dev/null +++ b/rllib/models/torch/modules/convtranspose2d_stack.py @@ -0,0 +1,82 @@ +from typing import Tuple + +from ray.rllib.models.torch.misc import Reshape +from ray.rllib.models.utils import get_activation_fn, get_initializer +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() +if torch: + import torch.distributions as td + + +class ConvTranspose2DStack(nn.Module): + """ConvTranspose2D decoder generating an image distribution from a vector.""" + + def __init__( + self, + *, + input_size: int, + filters: Tuple[Tuple[int]] = ( + (1024, 5, 2), + (128, 5, 2), + (64, 6, 2), + (32, 6, 2), + ), + initializer="default", + bias_init=0, + activation_fn: str = "relu", + output_shape: Tuple[int] = (3, 64, 64) + ): + """Initializes a TransposedConv2DStack instance. + + Args: + input_size: The size of the 1D input vector, from which to + generate the image distribution. + filters (Tuple[Tuple[int]]): Tuple of filter setups (1 for each + ConvTranspose2D layer): [in_channels, kernel, stride]. 
+ initializer (Union[str]): + bias_init: The initial bias values to use. + activation_fn: Activation function descriptor (str). + output_shape (Tuple[int]): Shape of the final output image. + """ + super().__init__() + self.activation = get_activation_fn(activation_fn, framework="torch") + self.output_shape = output_shape + initializer = get_initializer(initializer, framework="torch") + + in_channels = filters[0][0] + self.layers = [ + # Map from 1D-input vector to correct initial size for the + # Conv2DTransposed stack. + nn.Linear(input_size, in_channels), + # Reshape from the incoming 1D vector (input_size) to 1x1 image + # format (channels first). + Reshape([-1, in_channels, 1, 1]), + ] + for i, (_, kernel, stride) in enumerate(filters): + out_channels = ( + filters[i + 1][0] if i < len(filters) - 1 else output_shape[0] + ) + conv_transp = nn.ConvTranspose2d(in_channels, out_channels, kernel, stride) + # Apply initializer. + initializer(conv_transp.weight) + nn.init.constant_(conv_transp.bias, bias_init) + self.layers.append(conv_transp) + # Apply activation function, if provided and if not last layer. + if self.activation is not None and i < len(filters) - 1: + self.layers.append(self.activation()) + + # num-outputs == num-inputs for next layer. + in_channels = out_channels + + self._model = nn.Sequential(*self.layers) + + def forward(self, x): + # x is [batch, hor_length, input_size] + batch_dims = x.shape[:-1] + model_out = self._model(x) + + # Equivalent to making a multivariate diag. + reshape_size = batch_dims + self.output_shape + mean = model_out.view(*reshape_size) + return td.Independent(td.Normal(mean, 1.0), len(self.output_shape)) diff --git a/rllib/models/torch/modules/gru_gate.py b/rllib/models/torch/modules/gru_gate.py new file mode 100644 index 000000000000..7eee53534d6d --- /dev/null +++ b/rllib/models/torch/modules/gru_gate.py @@ -0,0 +1,65 @@ +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.framework import TensorType + +torch, nn = try_import_torch() + + +class GRUGate(nn.Module): + """Implements a gated recurrent unit for use in AttentionNet""" + + def __init__(self, dim: int, init_bias: int = 0.0, **kwargs): + """ + input_shape (torch.Tensor): dimension of the input + init_bias: Bias added to every input to stabilize training + """ + super().__init__(**kwargs) + # Xavier initialization of torch tensors + self._w_r = nn.Parameter(torch.zeros(dim, dim)) + self._w_z = nn.Parameter(torch.zeros(dim, dim)) + self._w_h = nn.Parameter(torch.zeros(dim, dim)) + nn.init.xavier_uniform_(self._w_r) + nn.init.xavier_uniform_(self._w_z) + nn.init.xavier_uniform_(self._w_h) + self.register_parameter("_w_r", self._w_r) + self.register_parameter("_w_z", self._w_z) + self.register_parameter("_w_h", self._w_h) + + self._u_r = nn.Parameter(torch.zeros(dim, dim)) + self._u_z = nn.Parameter(torch.zeros(dim, dim)) + self._u_h = nn.Parameter(torch.zeros(dim, dim)) + nn.init.xavier_uniform_(self._u_r) + nn.init.xavier_uniform_(self._u_z) + nn.init.xavier_uniform_(self._u_h) + self.register_parameter("_u_r", self._u_r) + self.register_parameter("_u_z", self._u_z) + self.register_parameter("_u_h", self._u_h) + + self._bias_z = nn.Parameter( + torch.zeros( + dim, + ).fill_(init_bias) + ) + self.register_parameter("_bias_z", self._bias_z) + + def forward(self, inputs: TensorType, **kwargs) -> TensorType: + # Pass in internal state first. 
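+        # A sketch of the gating math implemented below (h = previous hidden
+        # state / skip input, X = new input coming from the attention layer):
+        #   r      = sigmoid(X W_r + h U_r)
+        #   z      = sigmoid(X W_z + h U_z - b_z)
+        #   h_next = tanh(X W_h + (h * r) U_h)
+        #   output = (1 - z) * h + z * h_next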
+ h, X = inputs + + r = torch.tensordot(X, self._w_r, dims=1) + torch.tensordot( + h, self._u_r, dims=1 + ) + r = torch.sigmoid(r) + + z = ( + torch.tensordot(X, self._w_z, dims=1) + + torch.tensordot(h, self._u_z, dims=1) + - self._bias_z + ) + z = torch.sigmoid(z) + + h_next = torch.tensordot(X, self._w_h, dims=1) + torch.tensordot( + (h * r), self._u_h, dims=1 + ) + h_next = torch.tanh(h_next) + + return (1 - z) * h + z * h_next diff --git a/rllib/models/torch/modules/multi_head_attention.py b/rllib/models/torch/modules/multi_head_attention.py new file mode 100644 index 000000000000..68413bde025b --- /dev/null +++ b/rllib/models/torch/modules/multi_head_attention.py @@ -0,0 +1,68 @@ +""" +[1] - Attention Is All You Need - Vaswani, Jones, Shazeer, Parmar, + Uszkoreit, Gomez, Kaiser - Google Brain/Research, U Toronto - 2017. + https://arxiv.org/pdf/1706.03762.pdf +""" +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.utils.torch_utils import sequence_mask +from ray.rllib.utils.framework import TensorType + +torch, nn = try_import_torch() + + +class MultiHeadAttention(nn.Module): + """A multi-head attention layer described in [1].""" + + def __init__( + self, in_dim: int, out_dim: int, num_heads: int, head_dim: int, **kwargs + ): + """ + in_dim: Dimension of input + out_dim: Dimension of output + num_heads: Number of attention heads + head_dim: Output dimension of each attention head + """ + super().__init__(**kwargs) + + # No bias or non-linearity. + self._num_heads = num_heads + self._head_dim = head_dim + self._qkv_layer = SlimFC( + in_size=in_dim, out_size=3 * num_heads * head_dim, use_bias=False + ) + + self._linear_layer = SlimFC( + in_size=num_heads * head_dim, out_size=out_dim, use_bias=False + ) + + def forward(self, inputs: TensorType) -> TensorType: + L = list(inputs.size())[1] # length of segment + H = self._num_heads # number of attention heads + D = self._head_dim # attention head dimension + + qkv = self._qkv_layer(inputs) + + queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1) + queries = queries[:, -L:] # only query based on the segment + + queries = torch.reshape(queries, [-1, L, H, D]) + keys = torch.reshape(keys, [-1, L, H, D]) + values = torch.reshape(values, [-1, L, H, D]) + + score = torch.einsum("bihd,bjhd->bijh", queries, keys) + score = score / D**0.5 + + # causal mask of the same length as the sequence + mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype) + mask = mask[None, :, :, None] + mask = mask.float() + + masked_score = score * mask + 1e30 * (mask - 1.0) + wmat = nn.functional.softmax(masked_score, dim=2) + + out = torch.einsum("bijh,bjhd->bihd", wmat, values) + shape = list(out.size())[:2] + [H * D] + # temp = torch.cat(temp2, [H * D], dim=0) + out = torch.reshape(out, shape) + return self._linear_layer(out) diff --git a/rllib/models/torch/modules/noisy_layer.py b/rllib/models/torch/modules/noisy_layer.py new file mode 100644 index 000000000000..8a9fe999cf79 --- /dev/null +++ b/rllib/models/torch/modules/noisy_layer.py @@ -0,0 +1,99 @@ +import numpy as np + +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.utils.framework import try_import_torch, TensorType + +torch, nn = try_import_torch() + + +class NoisyLayer(nn.Module): + r"""A Layer that adds learnable Noise to some previous layer's outputs. 
+ + Consists of: + - a common dense layer: y = w^{T}x + b + - a noisy layer: y = (w + \epsilon_w*\sigma_w)^{T}x + + (b+\epsilon_b*\sigma_b) + , where \epsilon are random variables sampled from factorized normal + distributions and \sigma are trainable variables which are expected to + vanish along the training procedure. + """ + + def __init__( + self, in_size: int, out_size: int, sigma0: float, activation: str = "relu" + ): + """Initializes a NoisyLayer object. + + Args: + in_size: Input size for Noisy Layer + out_size: Output size for Noisy Layer + sigma0: Initialization value for sigma_b (bias noise) + activation: Non-linear activation for Noisy Layer + """ + super().__init__() + + self.in_size = in_size + self.out_size = out_size + self.sigma0 = sigma0 + self.activation = get_activation_fn(activation, framework="torch") + if self.activation is not None: + self.activation = self.activation() + + sigma_w = nn.Parameter( + torch.from_numpy( + np.random.uniform( + low=-1.0 / np.sqrt(float(self.in_size)), + high=1.0 / np.sqrt(float(self.in_size)), + size=[self.in_size, out_size], + ) + ).float() + ) + self.register_parameter("sigma_w", sigma_w) + sigma_b = nn.Parameter( + torch.from_numpy( + np.full( + shape=[out_size], fill_value=sigma0 / np.sqrt(float(self.in_size)) + ) + ).float() + ) + self.register_parameter("sigma_b", sigma_b) + + w = nn.Parameter( + torch.from_numpy( + np.full( + shape=[self.in_size, self.out_size], + fill_value=6 / np.sqrt(float(in_size) + float(out_size)), + ) + ).float() + ) + self.register_parameter("w", w) + b = nn.Parameter(torch.from_numpy(np.zeros([out_size])).float()) + self.register_parameter("b", b) + + def forward(self, inputs: TensorType) -> TensorType: + epsilon_in = self._f_epsilon( + torch.normal( + mean=torch.zeros([self.in_size]), std=torch.ones([self.in_size]) + ).to(inputs.device) + ) + epsilon_out = self._f_epsilon( + torch.normal( + mean=torch.zeros([self.out_size]), std=torch.ones([self.out_size]) + ).to(inputs.device) + ) + epsilon_w = torch.matmul( + torch.unsqueeze(epsilon_in, -1), other=torch.unsqueeze(epsilon_out, 0) + ) + epsilon_b = epsilon_out + + action_activation = ( + torch.matmul(inputs, self.w + self.sigma_w * epsilon_w) + + self.b + + self.sigma_b * epsilon_b + ) + + if self.activation is not None: + action_activation = self.activation(action_activation) + return action_activation + + def _f_epsilon(self, x: TensorType) -> TensorType: + return torch.sign(x) * torch.pow(torch.abs(x), 0.5) diff --git a/rllib/models/torch/modules/relative_multi_head_attention.py b/rllib/models/torch/modules/relative_multi_head_attention.py new file mode 100644 index 000000000000..d3ff9cf59eee --- /dev/null +++ b/rllib/models/torch/modules/relative_multi_head_attention.py @@ -0,0 +1,175 @@ +from typing import Union + +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.utils.torch_utils import sequence_mask +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() + + +class RelativePositionEmbedding(nn.Module): + """Creates a [seq_length x seq_length] matrix for rel. pos encoding. + + Denoted as Phi in [2] and [3]. Phi is the standard sinusoid encoding + matrix. + + Args: + seq_length: The max. sequence length (time axis). + out_dim: The number of nodes to go into the first Tranformer + layer with. + + Returns: + torch.Tensor: The encoding matrix Phi. 
+ """ + + def __init__(self, out_dim, **kwargs): + super().__init__() + self.out_dim = out_dim + + out_range = torch.arange(0, self.out_dim, 2.0) + inverse_freq = 1 / (10000 ** (out_range / self.out_dim)) + self.register_buffer("inverse_freq", inverse_freq) + + def forward(self, seq_length): + pos_input = torch.arange(seq_length - 1, -1, -1.0, dtype=torch.float).to( + self.inverse_freq.device + ) + sinusoid_input = torch.einsum("i,j->ij", pos_input, self.inverse_freq) + pos_embeddings = torch.cat( + [torch.sin(sinusoid_input), torch.cos(sinusoid_input)], dim=-1 + ) + return pos_embeddings[:, None, :] + + +class RelativeMultiHeadAttention(nn.Module): + """A RelativeMultiHeadAttention layer as described in [3]. + + Uses segment level recurrence with state reuse. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_heads: int, + head_dim: int, + input_layernorm: bool = False, + output_activation: Union[str, callable] = None, + **kwargs + ): + """Initializes a RelativeMultiHeadAttention nn.Module object. + + Args: + in_dim (int): + out_dim: The output dimension of this module. Also known as + "attention dim". + num_heads: The number of attention heads to use. + Denoted `H` in [2]. + head_dim: The dimension of a single(!) attention head + Denoted `D` in [2]. + input_layernorm: Whether to prepend a LayerNorm before + everything else. Should be True for building a GTrXL. + output_activation (Union[str, callable]): Optional activation + function or activation function specifier (str). + Should be "relu" for GTrXL. + **kwargs: + """ + super().__init__(**kwargs) + + # No bias or non-linearity. + self._num_heads = num_heads + self._head_dim = head_dim + + # 3=Query, key, and value inputs. + self._qkv_layer = SlimFC( + in_size=in_dim, out_size=3 * num_heads * head_dim, use_bias=False + ) + + self._linear_layer = SlimFC( + in_size=num_heads * head_dim, + out_size=out_dim, + use_bias=False, + activation_fn=output_activation, + ) + + self._uvar = nn.Parameter(torch.zeros(num_heads, head_dim)) + self._vvar = nn.Parameter(torch.zeros(num_heads, head_dim)) + nn.init.xavier_uniform_(self._uvar) + nn.init.xavier_uniform_(self._vvar) + self.register_parameter("_uvar", self._uvar) + self.register_parameter("_vvar", self._vvar) + + self._pos_proj = SlimFC( + in_size=in_dim, out_size=num_heads * head_dim, use_bias=False + ) + self._rel_pos_embedding = RelativePositionEmbedding(out_dim) + + self._input_layernorm = None + if input_layernorm: + self._input_layernorm = torch.nn.LayerNorm(in_dim) + + def forward(self, inputs: TensorType, memory: TensorType = None) -> TensorType: + T = list(inputs.size())[1] # length of segment (time) + H = self._num_heads # number of attention heads + d = self._head_dim # attention head dimension + + # Add previous memory chunk (as const, w/o gradient) to input. + # Tau (number of (prev) time slices in each memory chunk). + Tau = list(memory.shape)[1] + inputs = torch.cat((memory.detach(), inputs), dim=1) + + # Apply the Layer-Norm. + if self._input_layernorm is not None: + inputs = self._input_layernorm(inputs) + + qkv = self._qkv_layer(inputs) + + queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1) + # Cut out Tau memory timesteps from query. 
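+        # Keys and values still span all Tau + T timesteps (memory plus current
+        # segment); only the current T timesteps act as queries.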
+ queries = queries[:, -T:] + + queries = torch.reshape(queries, [-1, T, H, d]) + keys = torch.reshape(keys, [-1, Tau + T, H, d]) + values = torch.reshape(values, [-1, Tau + T, H, d]) + + R = self._pos_proj(self._rel_pos_embedding(Tau + T)) + R = torch.reshape(R, [Tau + T, H, d]) + + # b=batch + # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space) + # h=head + # d=head-dim (over which we will reduce-sum) + score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys) + pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R) + score = score + self.rel_shift(pos_score) + score = score / d**0.5 + + # causal mask of the same length as the sequence + mask = sequence_mask(torch.arange(Tau + 1, Tau + T + 1), dtype=score.dtype).to( + score.device + ) + mask = mask[None, :, :, None] + + masked_score = score * mask + 1e30 * (mask.float() - 1.0) + wmat = nn.functional.softmax(masked_score, dim=2) + + out = torch.einsum("bijh,bjhd->bihd", wmat, values) + shape = list(out.shape)[:2] + [H * d] + out = torch.reshape(out, shape) + + return self._linear_layer(out) + + @staticmethod + def rel_shift(x: TensorType) -> TensorType: + # Transposed version of the shift approach described in [3]. + # https://github.com/kimiyoung/transformer-xl/blob/ + # 44781ed21dbaec88b280f74d9ae2877f52b492a5/tf/model.py#L31 + x_size = list(x.shape) + + x = torch.nn.functional.pad(x, (0, 0, 1, 0, 0, 0, 0, 0)) + x = torch.reshape(x, [x_size[0], x_size[2] + 1, x_size[1], x_size[3]]) + x = x[:, 1:, :, :] + x = torch.reshape(x, x_size) + + return x diff --git a/rllib/models/torch/modules/skip_connection.py b/rllib/models/torch/modules/skip_connection.py new file mode 100644 index 000000000000..8bc155eda9ca --- /dev/null +++ b/rllib/models/torch/modules/skip_connection.py @@ -0,0 +1,41 @@ +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType +from typing import Optional + +torch, nn = try_import_torch() + + +class SkipConnection(nn.Module): + """Skip connection layer. + + Adds the original input to the output (regular residual layer) OR uses + input as hidden state input to a given fan_in_layer. + """ + + def __init__( + self, layer: nn.Module, fan_in_layer: Optional[nn.Module] = None, **kwargs + ): + """Initializes a SkipConnection nn Module object. + + Args: + layer (nn.Module): Any layer processing inputs. + fan_in_layer (Optional[nn.Module]): An optional + layer taking two inputs: The original input and the output + of `layer`. + """ + super().__init__(**kwargs) + self._layer = layer + self._fan_in_layer = fan_in_layer + + def forward(self, inputs: TensorType, **kwargs) -> TensorType: + # del kwargs + outputs = self._layer(inputs, **kwargs) + # Residual case, just add inputs to outputs. + if self._fan_in_layer is None: + outputs = outputs + inputs + # Fan-in e.g. RNN: Call fan-in with `inputs` and `outputs`. + else: + # NOTE: In the GRU case, `inputs` is the state input. + outputs = self._fan_in_layer((inputs, outputs)) + + return outputs diff --git a/rllib/models/torch/noop.py b/rllib/models/torch/noop.py new file mode 100644 index 000000000000..8b0705b11874 --- /dev/null +++ b/rllib/models/torch/noop.py @@ -0,0 +1,13 @@ +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.annotations import override + + +class TorchNoopModel(TorchModelV2): + """Trivial model that just returns the obs flattened. 
+ + This is the model used if use_state_preprocessor=False.""" + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + return input_dict["obs_flat"].float(), state diff --git a/rllib/models/torch/primitives.py b/rllib/models/torch/primitives.py new file mode 100644 index 000000000000..191a0ff35e5a --- /dev/null +++ b/rllib/models/torch/primitives.py @@ -0,0 +1,54 @@ +from typing import List, Optional +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + +# TODO (Kourosh): Find a better hierarchy for the primitives after the POC is done. + + +class FCNet(nn.Module): + """A simple fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + output_dim: The output dimension of the network. if None, the last layer would + be the last hidden layer. + hidden_layers: The sizes of the hidden layers. + activation: The activation function to use after each layer. + """ + + def __init__( + self, + input_dim: int, + hidden_layers: List[int], + output_dim: Optional[int] = None, + activation: str = "linear", + ): + super().__init__() + self.input_dim = input_dim + self.hidden_layers = hidden_layers + + activation_class = getattr(nn, activation, lambda: None)() + self.layers = [] + self.layers.append(nn.Linear(self.input_dim, self.hidden_layers[0])) + for i in range(len(self.hidden_layers) - 1): + if activation != "linear": + self.layers.append(activation_class) + self.layers.append( + nn.Linear(self.hidden_layers[i], self.hidden_layers[i + 1]) + ) + + if output_dim is not None: + if activation != "linear": + self.layers.append(activation_class) + self.layers.append(nn.Linear(self.hidden_layers[-1], output_dim)) + + if output_dim is None: + self.output_dim = hidden_layers[-1] + else: + self.output_dim = output_dim + + self.layers = nn.Sequential(*self.layers) + + def forward(self, x): + return self.layers(x) diff --git a/rllib/models/torch/recurrent_net.py b/rllib/models/torch/recurrent_net.py new file mode 100644 index 000000000000..ec3f7b3b797c --- /dev/null +++ b/rllib/models/torch/recurrent_net.py @@ -0,0 +1,285 @@ +import numpy as np +import gymnasium as gym +from gymnasium.spaces import Discrete, MultiDiscrete +import tree # pip install dm_tree +from typing import Dict, List, Union, Tuple + +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.misc import SlimFC +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.policy.rnn_sequencing import add_time_dimension +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.view_requirement import ViewRequirement +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.torch_utils import flatten_inputs_to_1d_tensor, one_hot +from ray.rllib.utils.typing import ModelConfigDict, TensorType + +torch, nn = try_import_torch() + + +@DeveloperAPI +class RecurrentNetwork(TorchModelV2): + """Helper class to simplify implementing RNN models with TorchModelV2. + + Instead of implementing forward(), you can implement forward_rnn() which + takes batches with the time dimension added already. 
+ + Here is an example implementation for a subclass + ``MyRNNClass(RecurrentNetwork, nn.Module)``:: + + def __init__(self, obs_space, num_outputs): + nn.Module.__init__(self) + super().__init__(obs_space, action_space, num_outputs, + model_config, name) + self.obs_size = _get_size(obs_space) + self.rnn_hidden_dim = model_config["lstm_cell_size"] + self.fc1 = nn.Linear(self.obs_size, self.rnn_hidden_dim) + self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) + self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs) + + self.value_branch = nn.Linear(self.rnn_hidden_dim, 1) + self._cur_value = None + + @override(ModelV2) + def get_initial_state(self): + # Place hidden states on same device as model. + h = [self.fc1.weight.new( + 1, self.rnn_hidden_dim).zero_().squeeze(0)] + return h + + @override(ModelV2) + def value_function(self): + assert self._cur_value is not None, "must call forward() first" + return self._cur_value + + @override(RecurrentNetwork) + def forward_rnn(self, input_dict, state, seq_lens): + x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float())) + h_in = state[0].reshape(-1, self.rnn_hidden_dim) + h = self.rnn(x, h_in) + q = self.fc2(h) + self._cur_value = self.value_branch(h).squeeze(1) + return q, [h] + """ + + @override(ModelV2) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> Tuple[TensorType, List[TensorType]]: + """Adds time dimension to batch before sending inputs to forward_rnn(). + + You should implement forward_rnn() in your subclass.""" + flat_inputs = input_dict["obs_flat"].float() + # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max() + # as input_dict may have extra zero-padding beyond seq_lens.max(). + # Use add_time_dimension to handle this + self.time_major = self.model_config.get("_time_major", False) + inputs = add_time_dimension( + flat_inputs, + seq_lens=seq_lens, + framework="torch", + time_major=self.time_major, + ) + output, new_state = self.forward_rnn(inputs, state, seq_lens) + output = torch.reshape(output, [-1, self.num_outputs]) + return output, new_state + + def forward_rnn( + self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType + ) -> Tuple[TensorType, List[TensorType]]: + """Call the model with the given input tensors and state. + + Args: + inputs: Observation tensor with shape [B, T, obs_size]. + state: List of state tensors, each with shape [B, size]. + seq_lens: 1D tensor holding input sequence lengths. + Note: len(seq_lens) == B. + + Returns: + (outputs, new_state): The model output tensor of shape + [B, T, num_outputs] and the list of new state tensors each with + shape [B, size]. + + Examples: + def forward_rnn(self, inputs, state, seq_lens): + model_out, h, c = self.rnn_model([inputs, seq_lens] + state) + return model_out, [h, c] + """ + raise NotImplementedError("You must implement this for an RNN model") + + +class LSTMWrapper(RecurrentNetwork, nn.Module): + """An LSTM wrapper serving as an interface for ModelV2s that set use_lstm.""" + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + ): + + nn.Module.__init__(self) + super(LSTMWrapper, self).__init__( + obs_space, action_space, None, model_config, name + ) + + # At this point, self.num_outputs is the number of nodes coming + # from the wrapped (underlying) model. In other words, self.num_outputs + # is the input size for the LSTM layer. 
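+        # (This wrapper itself was initialized with num_outputs=None above; the
+        # attribute is re-purposed here as the LSTM's input size and only reset
+        # to the caller's requested output size further down.)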
+ # If None, set it to the observation space. + if self.num_outputs is None: + self.num_outputs = int(np.product(self.obs_space.shape)) + + self.cell_size = model_config["lstm_cell_size"] + self.time_major = model_config.get("_time_major", False) + self.use_prev_action = model_config["lstm_use_prev_action"] + self.use_prev_reward = model_config["lstm_use_prev_reward"] + + self.action_space_struct = get_base_struct_from_space(self.action_space) + self.action_dim = 0 + + for space in tree.flatten(self.action_space_struct): + if isinstance(space, Discrete): + self.action_dim += space.n + elif isinstance(space, MultiDiscrete): + self.action_dim += np.sum(space.nvec) + elif space.shape is not None: + self.action_dim += int(np.product(space.shape)) + else: + self.action_dim += int(len(space)) + + # Add prev-action/reward nodes to input to LSTM. + if self.use_prev_action: + self.num_outputs += self.action_dim + if self.use_prev_reward: + self.num_outputs += 1 + + # Define actual LSTM layer (with num_outputs being the nodes coming + # from the wrapped (underlying) layer). + self.lstm = nn.LSTM( + self.num_outputs, self.cell_size, batch_first=not self.time_major + ) + + # Set self.num_outputs to the number of output nodes desired by the + # caller of this constructor. + self.num_outputs = num_outputs + + # Postprocess LSTM output with another hidden layer and compute values. + self._logits_branch = SlimFC( + in_size=self.cell_size, + out_size=self.num_outputs, + activation_fn=None, + initializer=torch.nn.init.xavier_uniform_, + ) + self._value_branch = SlimFC( + in_size=self.cell_size, + out_size=1, + activation_fn=None, + initializer=torch.nn.init.xavier_uniform_, + ) + + # __sphinx_doc_begin__ + # Add prev-a/r to this model's view, if required. + if model_config["lstm_use_prev_action"]: + self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement( + SampleBatch.ACTIONS, space=self.action_space, shift=-1 + ) + if model_config["lstm_use_prev_reward"]: + self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement( + SampleBatch.REWARDS, shift=-1 + ) + # __sphinx_doc_end__ + + @override(RecurrentNetwork) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> Tuple[TensorType, List[TensorType]]: + assert seq_lens is not None + # Push obs through "unwrapped" net's `forward()` first. + wrapped_out, _ = self._wrapped_forward(input_dict, [], None) + + # Concat. prev-action/reward if required. + prev_a_r = [] + + # Prev actions. + if self.model_config["lstm_use_prev_action"]: + prev_a = input_dict[SampleBatch.PREV_ACTIONS] + # If actions are not processed yet (in their original form as + # have been sent to environment): + # Flatten/one-hot into 1D array. + if self.model_config["_disable_action_flattening"]: + prev_a_r.append( + flatten_inputs_to_1d_tensor( + prev_a, spaces_struct=self.action_space_struct, time_axis=False + ) + ) + # If actions are already flattened (but not one-hot'd yet!), + # one-hot discrete/multi-discrete actions here. + else: + if isinstance(self.action_space, (Discrete, MultiDiscrete)): + prev_a = one_hot(prev_a.float(), self.action_space) + else: + prev_a = prev_a.float() + prev_a_r.append(torch.reshape(prev_a, [-1, self.action_dim])) + # Prev rewards. + if self.model_config["lstm_use_prev_reward"]: + prev_a_r.append( + torch.reshape(input_dict[SampleBatch.PREV_REWARDS].float(), [-1, 1]) + ) + + # Concat prev. actions + rewards to the "main" input. 
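+        # After this concat, the feature size matches the input size the LSTM was
+        # built with in __init__ (wrapped outputs, plus action_dim and/or 1 for
+        # the prev. reward, if enabled).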
+ if prev_a_r: + wrapped_out = torch.cat([wrapped_out] + prev_a_r, dim=1) + + # Push everything through our LSTM. + input_dict["obs_flat"] = wrapped_out + return super().forward(input_dict, state, seq_lens) + + @override(RecurrentNetwork) + def forward_rnn( + self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType + ) -> Tuple[TensorType, List[TensorType]]: + # Don't show paddings to RNN(?) + # TODO: (sven) For now, only allow, iff time_major=True to not break + # anything retrospectively (time_major not supported previously). + # max_seq_len = inputs.shape[0] + # time_major = self.model_config["_time_major"] + # if time_major and max_seq_len > 1: + # inputs = torch.nn.utils.rnn.pack_padded_sequence( + # inputs, seq_lens, + # batch_first=not time_major, enforce_sorted=False) + self._features, [h, c] = self.lstm( + inputs, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)] + ) + # Re-apply paddings. + # if time_major and max_seq_len > 1: + # self._features, _ = torch.nn.utils.rnn.pad_packed_sequence( + # self._features, + # batch_first=not time_major) + model_out = self._logits_branch(self._features) + return model_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)] + + @override(ModelV2) + def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]: + # Place hidden states on same device as model. + linear = next(self._logits_branch._model.children()) + h = [ + linear.weight.new(1, self.cell_size).zero_().squeeze(0), + linear.weight.new(1, self.cell_size).zero_().squeeze(0), + ] + return h + + @override(ModelV2) + def value_function(self) -> TensorType: + assert self._features is not None, "must call forward() first" + return torch.reshape(self._value_branch(self._features), [-1]) diff --git a/rllib/models/torch/torch_action_dist.py b/rllib/models/torch/torch_action_dist.py new file mode 100644 index 000000000000..dadbec72f2f1 --- /dev/null +++ b/rllib/models/torch/torch_action_dist.py @@ -0,0 +1,648 @@ +import functools +import gymnasium as gym +from math import log +import numpy as np +import tree # pip install dm_tree +from typing import Optional + +from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import SMALL_NUMBER, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT +from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space +from ray.rllib.utils.typing import TensorType, List, Union, Tuple, ModelConfigDict + +torch, nn = try_import_torch() + + +@DeveloperAPI +class TorchDistributionWrapper(ActionDistribution): + """Wrapper class for torch.distributions.""" + + @override(ActionDistribution) + def __init__(self, inputs: List[TensorType], model: TorchModelV2): + # If inputs are not a torch Tensor, make them one and make sure they + # are on the correct device. + if not isinstance(inputs, torch.Tensor): + inputs = torch.from_numpy(inputs) + if isinstance(model, TorchModelV2): + inputs = inputs.to(next(model.parameters()).device) + super().__init__(inputs, model) + # Store the last sample here. 
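+        # (Used by `sampled_action_logp()` below, which asserts that `sample()`
+        # or `deterministic_sample()` has been called at least once.)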
+ self.last_sample = None + + @override(ActionDistribution) + def logp(self, actions: TensorType) -> TensorType: + return self.dist.log_prob(actions) + + @override(ActionDistribution) + def entropy(self) -> TensorType: + return self.dist.entropy() + + @override(ActionDistribution) + def kl(self, other: ActionDistribution) -> TensorType: + return torch.distributions.kl.kl_divergence(self.dist, other.dist) + + @override(ActionDistribution) + def sample(self) -> TensorType: + self.last_sample = self.dist.sample() + return self.last_sample + + @override(ActionDistribution) + def sampled_action_logp(self) -> TensorType: + assert self.last_sample is not None + return self.logp(self.last_sample) + + +@DeveloperAPI +class TorchCategorical(TorchDistributionWrapper): + """Wrapper class for PyTorch Categorical distribution.""" + + @override(ActionDistribution) + def __init__( + self, + inputs: List[TensorType], + model: TorchModelV2 = None, + temperature: float = 1.0, + ): + if temperature != 1.0: + assert temperature > 0.0, "Categorical `temperature` must be > 0.0!" + inputs /= temperature + super().__init__(inputs, model) + self.dist = torch.distributions.categorical.Categorical(logits=self.inputs) + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + self.last_sample = self.dist.probs.argmax(dim=1) + return self.last_sample + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape( + action_space: gym.Space, model_config: ModelConfigDict + ) -> Union[int, np.ndarray]: + return action_space.n + + +@DeveloperAPI +def get_torch_categorical_class_with_temperature(t: float): + """TorchCategorical distribution class that has customized default temperature.""" + + class TorchCategoricalWithTemperature(TorchCategorical): + def __init__(self, inputs, model=None, temperature=t): + super().__init__(inputs, model, temperature) + + return TorchCategoricalWithTemperature + + +@DeveloperAPI +class TorchMultiCategorical(TorchDistributionWrapper): + """MultiCategorical distribution for MultiDiscrete action spaces.""" + + @override(TorchDistributionWrapper) + def __init__( + self, + inputs: List[TensorType], + model: TorchModelV2, + input_lens: Union[List[int], np.ndarray, Tuple[int, ...]], + action_space=None, + ): + super().__init__(inputs, model) + # If input_lens is np.ndarray or list, force-make it a tuple. + inputs_split = self.inputs.split(tuple(input_lens), dim=1) + self.cats = [ + torch.distributions.categorical.Categorical(logits=input_) + for input_ in inputs_split + ] + # Used in case we are dealing with an Int Box. + self.action_space = action_space + + @override(TorchDistributionWrapper) + def sample(self) -> TensorType: + arr = [cat.sample() for cat in self.cats] + sample_ = torch.stack(arr, dim=1) + if isinstance(self.action_space, gym.spaces.Box): + sample_ = torch.reshape(sample_, [-1] + list(self.action_space.shape)) + self.last_sample = sample_ + return sample_ + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + arr = [torch.argmax(cat.probs, -1) for cat in self.cats] + sample_ = torch.stack(arr, dim=1) + if isinstance(self.action_space, gym.spaces.Box): + sample_ = torch.reshape(sample_, [-1] + list(self.action_space.shape)) + self.last_sample = sample_ + return sample_ + + @override(TorchDistributionWrapper) + def logp(self, actions: TensorType) -> TensorType: + # # If tensor is provided, unstack it into list. 
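+        # For Int-Box action spaces the actions arrive as one [B, prod(shape)]
+        # tensor; reshape and then unbind into one column per sub-action so each
+        # Categorical gets its own action column.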
+        if isinstance(actions, torch.Tensor):
+            if isinstance(self.action_space, gym.spaces.Box):
+                actions = torch.reshape(
+                    actions, [-1, int(np.prod(self.action_space.shape))]
+                )
+            actions = torch.unbind(actions, dim=1)
+        logps = torch.stack([cat.log_prob(act) for cat, act in zip(self.cats, actions)])
+        return torch.sum(logps, dim=0)
+
+    @override(ActionDistribution)
+    def multi_entropy(self) -> TensorType:
+        return torch.stack([cat.entropy() for cat in self.cats], dim=1)
+
+    @override(TorchDistributionWrapper)
+    def entropy(self) -> TensorType:
+        return torch.sum(self.multi_entropy(), dim=1)
+
+    @override(ActionDistribution)
+    def multi_kl(self, other: ActionDistribution) -> TensorType:
+        return torch.stack(
+            [
+                torch.distributions.kl.kl_divergence(cat, oth_cat)
+                for cat, oth_cat in zip(self.cats, other.cats)
+            ],
+            dim=1,
+        )
+
+    @override(TorchDistributionWrapper)
+    def kl(self, other: ActionDistribution) -> TensorType:
+        return torch.sum(self.multi_kl(other), dim=1)
+
+    @staticmethod
+    @override(ActionDistribution)
+    def required_model_output_shape(
+        action_space: gym.Space, model_config: ModelConfigDict
+    ) -> Union[int, np.ndarray]:
+        # Int Box.
+        if isinstance(action_space, gym.spaces.Box):
+            assert action_space.dtype.name.startswith("int")
+            low_ = np.min(action_space.low)
+            high_ = np.max(action_space.high)
+            assert np.all(action_space.low == low_)
+            assert np.all(action_space.high == high_)
+            return np.prod(action_space.shape, dtype=np.int32) * (high_ - low_ + 1)
+        # MultiDiscrete space.
+        else:
+            # `nvec` is already integer. No need to cast.
+            return np.sum(action_space.nvec)
+
+
+@ExperimentalAPI
+class TorchSlateMultiCategorical(TorchCategorical):
+    """MultiCategorical distribution for MultiDiscrete action spaces.
+
+    The action space must be uniform, meaning all nvec items have the same size, e.g.
+    MultiDiscrete([10, 10, 10]), where 10 is the number of candidates to pick from
+    and 3 is the slate size (pick 3 out of 10). Within one slate, no candidate may
+    be picked more than once.
+    """
+
+    def __init__(
+        self,
+        inputs: List[TensorType],
+        model: TorchModelV2 = None,
+        temperature: float = 1.0,
+        action_space: Optional[gym.spaces.MultiDiscrete] = None,
+        all_slates=None,
+    ):
+        assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
+        # Allow softmax formula w/ temperature != 1.0:
+        # Divide inputs by temperature.
+        super().__init__(inputs / temperature, model)
+        self.action_space = action_space
+        # Assert uniformity of the action space (all discrete buckets have the same
+        # size).
+        assert isinstance(self.action_space, gym.spaces.MultiDiscrete) and all(
+            n == self.action_space.nvec[0] for n in self.action_space.nvec
+        )
+        self.all_slates = all_slates
+
+    @override(ActionDistribution)
+    def deterministic_sample(self) -> TensorType:
+        # Get a sample from the underlying Categorical (batch of ints).
+        sample = super().deterministic_sample()
+        # Use the sampled ints to pick the actual slates.
+        return torch.take_along_dim(self.all_slates, sample.long(), dim=-1)
+
+    @override(ActionDistribution)
+    def logp(self, x: TensorType) -> TensorType:
+        # TODO: Implement.
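+        # Placeholder until then: returns a constant log-prob of 1.0 for every
+        # batch item (shape [B]), regardless of `x`.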
+ return torch.ones_like(self.inputs[:, 0]) + + +@DeveloperAPI +class TorchDiagGaussian(TorchDistributionWrapper): + """Wrapper class for PyTorch Normal distribution.""" + + @override(ActionDistribution) + def __init__( + self, + inputs: List[TensorType], + model: TorchModelV2, + *, + action_space: Optional[gym.spaces.Space] = None + ): + super().__init__(inputs, model) + mean, log_std = torch.chunk(self.inputs, 2, dim=1) + self.log_std = log_std + self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std)) + # Remember to squeeze action samples in case action space is Box(shape) + self.zero_action_dim = action_space and action_space.shape == () + + @override(TorchDistributionWrapper) + def sample(self) -> TensorType: + sample = super().sample() + if self.zero_action_dim: + return torch.squeeze(sample, dim=-1) + return sample + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + self.last_sample = self.dist.mean + return self.last_sample + + @override(TorchDistributionWrapper) + def logp(self, actions: TensorType) -> TensorType: + return super().logp(actions).sum(-1) + + @override(TorchDistributionWrapper) + def entropy(self) -> TensorType: + return super().entropy().sum(-1) + + @override(TorchDistributionWrapper) + def kl(self, other: ActionDistribution) -> TensorType: + return super().kl(other).sum(-1) + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape( + action_space: gym.Space, model_config: ModelConfigDict + ) -> Union[int, np.ndarray]: + return np.prod(action_space.shape, dtype=np.int32) * 2 + + +@DeveloperAPI +class TorchSquashedGaussian(TorchDistributionWrapper): + """A tanh-squashed Gaussian distribution defined by: mean, std, low, high. + + The distribution will never return low or high exactly, but + `low`+SMALL_NUMBER or `high`-SMALL_NUMBER respectively. + """ + + def __init__( + self, + inputs: List[TensorType], + model: TorchModelV2, + low: float = -1.0, + high: float = 1.0, + ): + """Parameterizes the distribution via `inputs`. + + Args: + low: The lowest possible sampling value + (excluding this value). + high: The highest possible sampling value + (excluding this value). + """ + super().__init__(inputs, model) + # Split inputs into mean and log(std). + mean, log_std = torch.chunk(self.inputs, 2, dim=-1) + # Clip `scale` values (coming from NN) to reasonable values. + log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT, MAX_LOG_NN_OUTPUT) + std = torch.exp(log_std) + self.dist = torch.distributions.normal.Normal(mean, std) + assert np.all(np.less(low, high)) + self.low = low + self.high = high + self.mean = mean + self.std = std + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + self.last_sample = self._squash(self.dist.mean) + return self.last_sample + + @override(TorchDistributionWrapper) + def sample(self) -> TensorType: + # Use the reparameterization version of `dist.sample` to allow for + # the results to be backprop'able e.g. in a loss term. + + normal_sample = self.dist.rsample() + self.last_sample = self._squash(normal_sample) + return self.last_sample + + @override(ActionDistribution) + def logp(self, x: TensorType) -> TensorType: + # Unsquash values (from [low,high] to ]-inf,inf[) + unsquashed_values = self._unsquash(x) + # Get log prob of unsquashed values from our Normal. + log_prob_gaussian = self.dist.log_prob(unsquashed_values) + # For safety reasons, clamp somehow, only then sum up. 
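+        # (The clamp below bounds each per-dimension log-prob to [-100, 100] so
+        # extreme unsquashed values cannot produce +/-inf before the sum.)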
+ log_prob_gaussian = torch.clamp(log_prob_gaussian, -100, 100) + log_prob_gaussian = torch.sum(log_prob_gaussian, dim=-1) + # Get log-prob for squashed Gaussian. + unsquashed_values_tanhd = torch.tanh(unsquashed_values) + log_prob = log_prob_gaussian - torch.sum( + torch.log(1 - unsquashed_values_tanhd**2 + SMALL_NUMBER), dim=-1 + ) + return log_prob + + def sample_logp(self): + z = self.dist.rsample() + actions = self._squash(z) + return actions, torch.sum( + self.dist.log_prob(z) - torch.log(1 - actions * actions + SMALL_NUMBER), + dim=-1, + ) + + @override(TorchDistributionWrapper) + def entropy(self) -> TensorType: + raise ValueError("Entropy not defined for SquashedGaussian!") + + @override(TorchDistributionWrapper) + def kl(self, other: ActionDistribution) -> TensorType: + raise ValueError("KL not defined for SquashedGaussian!") + + def _squash(self, raw_values: TensorType) -> TensorType: + # Returned values are within [low, high] (including `low` and `high`). + squashed = ((torch.tanh(raw_values) + 1.0) / 2.0) * ( + self.high - self.low + ) + self.low + return torch.clamp(squashed, self.low, self.high) + + def _unsquash(self, values: TensorType) -> TensorType: + normed_values = (values - self.low) / (self.high - self.low) * 2.0 - 1.0 + # Stabilize input to atanh. + save_normed_values = torch.clamp( + normed_values, -1.0 + SMALL_NUMBER, 1.0 - SMALL_NUMBER + ) + unsquashed = torch.atanh(save_normed_values) + return unsquashed + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape( + action_space: gym.Space, model_config: ModelConfigDict + ) -> Union[int, np.ndarray]: + return np.prod(action_space.shape, dtype=np.int32) * 2 + + +@DeveloperAPI +class TorchBeta(TorchDistributionWrapper): + """ + A Beta distribution is defined on the interval [0, 1] and parameterized by + shape parameters alpha and beta (also called concentration parameters). + + PDF(x; alpha, beta) = x**(alpha - 1) (1 - x)**(beta - 1) / Z + with Z = Gamma(alpha) Gamma(beta) / Gamma(alpha + beta) + and Gamma(n) = (n - 1)! + """ + + def __init__( + self, + inputs: List[TensorType], + model: TorchModelV2, + low: float = 0.0, + high: float = 1.0, + ): + super().__init__(inputs, model) + # Stabilize input parameters (possibly coming from a linear layer). + self.inputs = torch.clamp(self.inputs, log(SMALL_NUMBER), -log(SMALL_NUMBER)) + self.inputs = torch.log(torch.exp(self.inputs) + 1.0) + 1.0 + self.low = low + self.high = high + alpha, beta = torch.chunk(self.inputs, 2, dim=-1) + # Note: concentration0==beta, concentration1=alpha (!) + self.dist = torch.distributions.Beta(concentration1=alpha, concentration0=beta) + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + self.last_sample = self._squash(self.dist.mean) + return self.last_sample + + @override(TorchDistributionWrapper) + def sample(self) -> TensorType: + # Use the reparameterization version of `dist.sample` to allow for + # the results to be backprop'able e.g. in a loss term. 
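+        # Note: despite the variable name below, this is an rsample from the Beta
+        # distribution; `_squash()` then maps it affinely from [0, 1] to
+        # [low, high].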
+ normal_sample = self.dist.rsample() + self.last_sample = self._squash(normal_sample) + return self.last_sample + + @override(ActionDistribution) + def logp(self, x: TensorType) -> TensorType: + unsquashed_values = self._unsquash(x) + return torch.sum(self.dist.log_prob(unsquashed_values), dim=-1) + + def _squash(self, raw_values: TensorType) -> TensorType: + return raw_values * (self.high - self.low) + self.low + + def _unsquash(self, values: TensorType) -> TensorType: + return (values - self.low) / (self.high - self.low) + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape( + action_space: gym.Space, model_config: ModelConfigDict + ) -> Union[int, np.ndarray]: + return np.prod(action_space.shape, dtype=np.int32) * 2 + + +@DeveloperAPI +class TorchDeterministic(TorchDistributionWrapper): + """Action distribution that returns the input values directly. + + This is similar to DiagGaussian with standard deviation zero (thus only + requiring the "mean" values as NN output). + """ + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + return self.inputs + + @override(TorchDistributionWrapper) + def sampled_action_logp(self) -> TensorType: + return torch.zeros((self.inputs.size()[0],), dtype=torch.float32) + + @override(TorchDistributionWrapper) + def sample(self) -> TensorType: + return self.deterministic_sample() + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape( + action_space: gym.Space, model_config: ModelConfigDict + ) -> Union[int, np.ndarray]: + return np.prod(action_space.shape, dtype=np.int32) + + +@DeveloperAPI +class TorchMultiActionDistribution(TorchDistributionWrapper): + """Action distribution that operates on multiple, possibly nested actions.""" + + def __init__(self, inputs, model, *, child_distributions, input_lens, action_space): + """Initializes a TorchMultiActionDistribution object. + + Args: + inputs (torch.Tensor): A single tensor of shape [BATCH, size]. + model (TorchModelV2): The TorchModelV2 object used to produce + inputs for this distribution. + child_distributions (any[torch.Tensor]): Any struct + that contains the child distribution classes to use to + instantiate the child distributions from `inputs`. This could + be an already flattened list or a struct according to + `action_space`. + input_lens (any[int]): A flat list or a nested struct of input + split lengths used to split `inputs`. + action_space (Union[gym.spaces.Dict,gym.spaces.Tuple]): The complex + and possibly nested action space. + """ + if not isinstance(inputs, torch.Tensor): + inputs = torch.from_numpy(inputs) + if isinstance(model, TorchModelV2): + inputs = inputs.to(next(model.parameters()).device) + super().__init__(inputs, model) + + self.action_space_struct = get_base_struct_from_space(action_space) + + self.input_lens = tree.flatten(input_lens) + flat_child_distributions = tree.flatten(child_distributions) + split_inputs = torch.split(inputs, self.input_lens, dim=1) + self.flat_child_distributions = tree.map_structure( + lambda dist, input_: dist(input_, model), + flat_child_distributions, + list(split_inputs), + ) + + @override(ActionDistribution) + def logp(self, x): + if isinstance(x, np.ndarray): + x = torch.Tensor(x) + # Single tensor input (all merged). 
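+        # Work out how many columns of `x` belong to each child distribution so
+        # the merged tensor can be split back into per-component actions.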
+ if isinstance(x, torch.Tensor): + split_indices = [] + for dist in self.flat_child_distributions: + if isinstance(dist, TorchCategorical): + split_indices.append(1) + elif ( + isinstance(dist, TorchMultiCategorical) + and dist.action_space is not None + ): + split_indices.append(int(np.prod(dist.action_space.shape))) + else: + sample = dist.sample() + # Cover Box(shape=()) case. + if len(sample.shape) == 1: + split_indices.append(1) + else: + split_indices.append(sample.size()[1]) + split_x = list(torch.split(x, split_indices, dim=1)) + # Structured or flattened (by single action component) input. + else: + split_x = tree.flatten(x) + + def map_(val, dist): + # Remove extra categorical dimension. + if isinstance(dist, TorchCategorical): + val = (torch.squeeze(val, dim=-1) if len(val.shape) > 1 else val).int() + return dist.logp(val) + + # Remove extra categorical dimension and take the logp of each + # component. + flat_logps = tree.map_structure(map_, split_x, self.flat_child_distributions) + + return functools.reduce(lambda a, b: a + b, flat_logps) + + @override(ActionDistribution) + def kl(self, other): + kl_list = [ + d.kl(o) + for d, o in zip( + self.flat_child_distributions, other.flat_child_distributions + ) + ] + return functools.reduce(lambda a, b: a + b, kl_list) + + @override(ActionDistribution) + def entropy(self): + entropy_list = [d.entropy() for d in self.flat_child_distributions] + return functools.reduce(lambda a, b: a + b, entropy_list) + + @override(ActionDistribution) + def sample(self): + child_distributions = tree.unflatten_as( + self.action_space_struct, self.flat_child_distributions + ) + return tree.map_structure(lambda s: s.sample(), child_distributions) + + @override(ActionDistribution) + def deterministic_sample(self): + child_distributions = tree.unflatten_as( + self.action_space_struct, self.flat_child_distributions + ) + return tree.map_structure( + lambda s: s.deterministic_sample(), child_distributions + ) + + @override(TorchDistributionWrapper) + def sampled_action_logp(self): + p = self.flat_child_distributions[0].sampled_action_logp() + for c in self.flat_child_distributions[1:]: + p += c.sampled_action_logp() + return p + + @override(ActionDistribution) + def required_model_output_shape(self, action_space, model_config): + return np.sum(self.input_lens, dtype=np.int32) + + +@DeveloperAPI +class TorchDirichlet(TorchDistributionWrapper): + """Dirichlet distribution for continuous actions that are between + [0,1] and sum to 1. + + e.g. actions that represent resource allocation.""" + + def __init__(self, inputs, model): + """Input is a tensor of logits. The exponential of logits is used to + parametrize the Dirichlet distribution as all parameters need to be + positive. An arbitrary small epsilon is added to the concentration + parameters to be zero due to numerical error. + + See issue #4440 for more details. + """ + self.epsilon = torch.tensor(1e-7).to(inputs.device) + concentration = torch.exp(inputs) + self.epsilon + self.dist = torch.distributions.dirichlet.Dirichlet( + concentration=concentration, + validate_args=True, + ) + super().__init__(concentration, model) + + @override(ActionDistribution) + def deterministic_sample(self) -> TensorType: + self.last_sample = nn.functional.softmax(self.dist.concentration) + return self.last_sample + + @override(ActionDistribution) + def logp(self, x): + # Support of Dirichlet are positive real numbers. x is already + # an array of positive numbers, but we clip to avoid zeros due to + # numerical errors. 
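+        # Clip exact zeros up to epsilon, then renormalize so the values lie on
+        # the probability simplex before evaluating the Dirichlet log-prob.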
+ x = torch.max(x, self.epsilon) + x = x / torch.sum(x, dim=-1, keepdim=True) + return self.dist.log_prob(x) + + @override(ActionDistribution) + def entropy(self): + return self.dist.entropy() + + @override(ActionDistribution) + def kl(self, other): + return self.dist.kl_divergence(other.dist) + + @staticmethod + @override(ActionDistribution) + def required_model_output_shape(action_space, model_config): + return np.prod(action_space.shape, dtype=np.int32) diff --git a/rllib/models/torch/torch_distributions.py b/rllib/models/torch/torch_distributions.py new file mode 100644 index 000000000000..809c516897ef --- /dev/null +++ b/rllib/models/torch/torch_distributions.py @@ -0,0 +1,257 @@ +"""The main difference between this and the old ActionDistribution is that this one +has more explicit input args. So that the input format does not have to be guessed from +the code. This matches the design pattern of torch distribution which developers may +already be familiar with. +""" +import gymnasium as gym +import numpy as np +from typing import Optional +import abc + + +from ray.rllib.models.distributions import Distribution +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import TensorType, Union, Tuple, ModelConfigDict + +torch, nn = try_import_torch() + + +@DeveloperAPI +class TorchDistribution(Distribution, abc.ABC): + """Wrapper class for torch.distributions.""" + + def __init__(self, *args, **kwargs): + super().__init__() + self._dist = self._get_torch_distribution(*args, **kwargs) + + @abc.abstractmethod + def _get_torch_distribution( + self, *args, **kwargs + ) -> torch.distributions.Distribution: + """Returns the torch.distributions.Distribution object to use.""" + + @override(Distribution) + def logp(self, value: TensorType, **kwargs) -> TensorType: + return self._dist.log_prob(value, **kwargs) + + @override(Distribution) + def entropy(self) -> TensorType: + return self._dist.entropy() + + @override(Distribution) + def kl(self, other: "Distribution") -> TensorType: + return torch.distributions.kl.kl_divergence(self._dist, other._dist) + + @override(Distribution) + def sample( + self, *, sample_shape=torch.Size(), return_logp: bool = False + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: + sample = self._dist.sample(sample_shape) + if return_logp: + return sample, self.logp(sample) + return sample + + @override(Distribution) + def rsample( + self, *, sample_shape=torch.Size(), return_logp: bool = False + ) -> Union[TensorType, Tuple[TensorType, TensorType]]: + rsample = self._dist.rsample(sample_shape) + if return_logp: + return rsample, self.logp(rsample) + return rsample + + +@DeveloperAPI +class TorchCategorical(TorchDistribution): + """Wrapper class for PyTorch Categorical distribution. + + Creates a categorical distribution parameterized by either :attr:`probs` or + :attr:`logits` (but not both). + + Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is + ``probs.size(-1)``. + + If `probs` is 1-dimensional with length-`K`, each element is the relative + probability of sampling the class at that index. + + If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of + relative probability vectors. + + Example:: + >>> m = TorchCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + >>> m.sample(sample_shape=(2,)) # equal probability of 0, 1, 2, 3 + tensor([3, 4]) + + Args: + probs: The probablities of each event. 
+        logits: Event log probabilities (unnormalized).
+        temperature: In case of using logits, this parameter can be used to determine
+            the sharpness of the distribution. i.e.
+            ``probs = softmax(logits / temperature)``. The temperature must be strictly
+            positive. A low value (e.g. 1e-10) will result in argmax sampling while a
+            larger value will result in uniform sampling.
+    """
+
+    @override(TorchDistribution)
+    def __init__(
+        self,
+        probs: torch.Tensor = None,
+        logits: torch.Tensor = None,
+        temperature: float = 1.0,
+    ) -> None:
+        super().__init__(probs=probs, logits=logits, temperature=temperature)
+
+    @override(TorchDistribution)
+    def _get_torch_distribution(
+        self,
+        probs: torch.Tensor = None,
+        logits: torch.Tensor = None,
+        temperature: float = 1.0,
+    ) -> torch.distributions.Distribution:
+        if logits is not None:
+            assert temperature > 0.0, "Categorical `temperature` must be > 0.0!"
+            # Avoid mutating the caller's tensor in-place.
+            logits = logits / temperature
+        return torch.distributions.categorical.Categorical(probs, logits)
+
+    @staticmethod
+    @override(Distribution)
+    def required_model_output_shape(
+        space: gym.Space, model_config: ModelConfigDict
+    ) -> Tuple[int, ...]:
+        return (space.n,)
+
+
+@DeveloperAPI
+class TorchDiagGaussian(TorchDistribution):
+    """Wrapper class for PyTorch Normal distribution.
+
+    Creates a normal distribution parameterized by :attr:`loc` and :attr:`scale`. In
+    case of multi-dimensional distribution, the variance is assumed to be diagonal.
+
+    Example::
+
+        >>> m = TorchDiagGaussian(
+        ...     loc=torch.tensor([0.0, 0.0]), scale=torch.tensor([1.0, 1.0])
+        ... )
+        >>> m.sample(sample_shape=(2,))  # 2d normal dist with loc=0 and scale=1
+        tensor([[ 0.1046, -0.6120], [ 0.234, 0.556]])
+
+        >>> # scale is None: the second half of `loc` is interpreted as log-std.
+        >>> m = TorchDiagGaussian(loc=torch.tensor([0.0, 0.0]))
+        >>> m.sample(sample_shape=(2,))  # normally distributed with loc=0 and scale=1
+        tensor([0.1046, 0.6120])
+
+    Args:
+        loc: mean of the distribution (often referred to as mu). If scale is None, the
+            second half of the `loc` will be used as the log of scale.
+        scale: standard deviation of the distribution (often referred to as sigma).
+            Has to be positive.
+    """
+
+    @override(TorchDistribution)
+    def __init__(
+        self,
+        loc: Union[float, torch.Tensor],
+        scale: Optional[Union[float, torch.Tensor]] = None,
+    ):
+        super().__init__(loc=loc, scale=scale)
+
+    def _get_torch_distribution(
+        self, loc, scale=None
+    ) -> torch.distributions.Distribution:
+        if scale is None:
+            # Interpret the second half of `loc` as the log of the scale.
+            loc, log_std = torch.chunk(loc, 2, dim=1)
+            scale = torch.exp(log_std)
+        return torch.distributions.normal.Normal(loc, scale)
+
+    @override(TorchDistribution)
+    def logp(self, value: TensorType) -> TensorType:
+        return super().logp(value).sum(-1)
+
+    @override(TorchDistribution)
+    def entropy(self) -> TensorType:
+        return super().entropy().sum(-1)
+
+    @override(TorchDistribution)
+    def kl(self, other: "TorchDistribution") -> TensorType:
+        return super().kl(other).sum(-1)
+
+    @staticmethod
+    @override(Distribution)
+    def required_model_output_shape(
+        space: gym.Space, model_config: ModelConfigDict
+    ) -> Tuple[int, ...]:
+        return (int(np.prod(space.shape, dtype=np.int32)) * 2,)
+
+
+@DeveloperAPI
+class TorchDeterministic(Distribution):
+    """The distribution that returns the input values directly.
+
+    This is similar to DiagGaussian with standard deviation zero (thus only
+    requiring the "mean" values as NN output).
+
+    Note: entropy is always zero, and logp and kl are not implemented.
+
+    Example::
+
+        >>> m = TorchDeterministic(loc=torch.tensor([0.0, 0.0]))
+        >>> m.sample(sample_shape=(2,))
+        tensor([[ 0.0, 0.0], [ 0.0, 0.0]])
+
+    Args:
+        loc: the deterministic value to return
+    """
+
+    @override(Distribution)
+    def __init__(self, loc: torch.Tensor) -> None:
+        super().__init__()
+        self.loc = loc
+
+    @override(Distribution)
+    def sample(
+        self,
+        *,
+        sample_shape: Tuple[int, ...] = None,
+        return_logp: bool = False,
+        **kwargs,
+    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
+        if return_logp:
+            raise ValueError(f"Cannot return logp for {self.__class__.__name__}.")
+
+        if sample_shape is None:
+            sample_shape = torch.Size()
+
+        device = self.loc.device
+        dtype = self.loc.dtype
+        shape = sample_shape + self.loc.shape
+        return torch.ones(shape, device=device, dtype=dtype) * self.loc
+
+    def rsample(
+        self,
+        *,
+        sample_shape: Tuple[int, ...] = None,
+        return_logp: bool = False,
+        **kwargs,
+    ) -> Union[TensorType, Tuple[TensorType, TensorType]]:
+        raise NotImplementedError
+
+    @override(Distribution)
+    def logp(self, value: TensorType, **kwargs) -> TensorType:
+        raise ValueError(f"Cannot return logp for {self.__class__.__name__}.")
+
+    @override(Distribution)
+    def entropy(self, **kwargs) -> TensorType:
+        # A deterministic "distribution" always has zero entropy.
+        return torch.zeros_like(self.loc)
+
+    @override(Distribution)
+    def kl(self, other: "Distribution", **kwargs) -> TensorType:
+        raise ValueError(f"Cannot return kl for {self.__class__.__name__}.")
+
+    @staticmethod
+    @override(Distribution)
+    def required_model_output_shape(
+        space: gym.Space, model_config: ModelConfigDict
+    ) -> Tuple[int, ...]:
+        # TODO: This was copied from previous code. Is this correct? add unit test.
+        return (np.prod(space.shape, dtype=np.int32),)
diff --git a/rllib/models/torch/torch_modelv2.py b/rllib/models/torch/torch_modelv2.py
new file mode 100644
index 000000000000..b56bf425fb6f
--- /dev/null
+++ b/rllib/models/torch/torch_modelv2.py
@@ -0,0 +1,81 @@
+import gymnasium as gym
+from typing import Dict, List, Union
+
+from ray.rllib.models.modelv2 import ModelV2
+from ray.rllib.utils.annotations import override, PublicAPI
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.typing import ModelConfigDict, TensorType
+
+_, nn = try_import_torch()
+
+
+@PublicAPI
+class TorchModelV2(ModelV2):
+    """Torch version of ModelV2.
+
+    Note that this class by itself is not a valid model unless you
+    inherit from nn.Module and implement forward() in a subclass."""
+
+    def __init__(
+        self,
+        obs_space: gym.spaces.Space,
+        action_space: gym.spaces.Space,
+        num_outputs: int,
+        model_config: ModelConfigDict,
+        name: str,
+    ):
+        """Initialize a TorchModelV2.
+
+        Here is an example implementation for a subclass
+        ``MyModelClass(TorchModelV2, nn.Module)``::
+
+            def __init__(self, *args, **kwargs):
+                TorchModelV2.__init__(self, *args, **kwargs)
+                nn.Module.__init__(self)
+                self._hidden_layers = nn.Sequential(...)
+                self._logits = ...
+                self._value_branch = ...
+        """
+
+        if not isinstance(self, nn.Module):
+            raise ValueError(
+                "Subclasses of TorchModelV2 must also inherit from "
+                "nn.Module, e.g., MyModel(TorchModelV2, nn.Module)"
+            )
+
+        ModelV2.__init__(
+            self,
+            obs_space,
+            action_space,
+            num_outputs,
+            model_config,
+            name,
+            framework="torch",
+        )
+
+        # Dict to store per multi-gpu tower stats into.
+        # In PyTorch multi-GPU, we use a single TorchPolicy and copy
+        # its Model(s) n times (1 copy for each GPU). When computing the loss
+        # on each tower, we cannot store the stats (e.g.
`entropy`) inside the + # policy object as this would lead to race conditions between the + # different towers all accessing the same property at the same time. + self.tower_stats = {} + + @override(ModelV2) + def variables( + self, as_dict: bool = False + ) -> Union[List[TensorType], Dict[str, TensorType]]: + p = list(self.parameters()) + if as_dict: + return {k: p[i] for i, k in enumerate(self.state_dict().keys())} + return p + + @override(ModelV2) + def trainable_variables( + self, as_dict: bool = False + ) -> Union[List[TensorType], Dict[str, TensorType]]: + if as_dict: + return { + k: v for k, v in self.variables(as_dict=True).items() if v.requires_grad + } + return [v for v in self.variables() if v.requires_grad] diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py new file mode 100644 index 000000000000..32153b1e2e80 --- /dev/null +++ b/rllib/models/torch/visionnet.py @@ -0,0 +1,293 @@ +import numpy as np +from typing import Dict, List +import gymnasium as gym + +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.models.torch.misc import ( + normc_initializer, + same_padding, + SlimConv2d, + SlimFC, +) +from ray.rllib.models.utils import get_activation_fn, get_filter_config +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.typing import ModelConfigDict, TensorType + +torch, nn = try_import_torch() + + +class VisionNetwork(TorchModelV2, nn.Module): + """Generic vision network.""" + + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: int, + model_config: ModelConfigDict, + name: str, + ): + + if not model_config.get("conv_filters"): + model_config["conv_filters"] = get_filter_config(obs_space.shape) + + TorchModelV2.__init__( + self, obs_space, action_space, num_outputs, model_config, name + ) + nn.Module.__init__(self) + + activation = self.model_config.get("conv_activation") + filters = self.model_config["conv_filters"] + assert len(filters) > 0, "Must provide at least 1 entry in `conv_filters`!" + + # Post FC net config. + post_fcnet_hiddens = model_config.get("post_fcnet_hiddens", []) + post_fcnet_activation = get_activation_fn( + model_config.get("post_fcnet_activation"), framework="torch" + ) + + no_final_linear = self.model_config.get("no_final_linear") + vf_share_layers = self.model_config.get("vf_share_layers") + + # Whether the last layer is the output of a Flattened (rather than + # a n x (1,1) Conv2D). + self.last_layer_is_flattened = False + self._logits = None + + layers = [] + (w, h, in_channels) = obs_space.shape + + in_size = [w, h] + for out_channels, kernel, stride in filters[:-1]: + padding, out_size = same_padding(in_size, kernel, stride) + layers.append( + SlimConv2d( + in_channels, + out_channels, + kernel, + stride, + padding, + activation_fn=activation, + ) + ) + in_channels = out_channels + in_size = out_size + + out_channels, kernel, stride = filters[-1] + + # No final linear: Last layer has activation function and exits with + # num_outputs nodes (this could be a 1x1 conv or a FC layer, depending + # on `post_fcnet_...` settings). + if no_final_linear and num_outputs: + out_channels = out_channels if post_fcnet_hiddens else num_outputs + layers.append( + SlimConv2d( + in_channels, + out_channels, + kernel, + stride, + None, # padding=valid + activation_fn=activation, + ) + ) + + # Add (optional) post-fc-stack after last Conv2D layer. 
+ layer_sizes = post_fcnet_hiddens[:-1] + ( + [num_outputs] if post_fcnet_hiddens else [] + ) + for i, out_size in enumerate(layer_sizes): + layers.append( + SlimFC( + in_size=out_channels, + out_size=out_size, + activation_fn=post_fcnet_activation, + initializer=normc_initializer(1.0), + ) + ) + out_channels = out_size + + # Finish network normally (w/o overriding last layer size with + # `num_outputs`), then add another linear one of size `num_outputs`. + else: + layers.append( + SlimConv2d( + in_channels, + out_channels, + kernel, + stride, + None, # padding=valid + activation_fn=activation, + ) + ) + + # num_outputs defined. Use that to create an exact + # `num_output`-sized (1,1)-Conv2D. + if num_outputs: + in_size = [ + np.ceil((in_size[0] - kernel[0]) / stride), + np.ceil((in_size[1] - kernel[1]) / stride), + ] + padding, _ = same_padding(in_size, [1, 1], [1, 1]) + if post_fcnet_hiddens: + layers.append(nn.Flatten()) + in_size = out_channels + # Add (optional) post-fc-stack after last Conv2D layer. + for i, out_size in enumerate(post_fcnet_hiddens + [num_outputs]): + layers.append( + SlimFC( + in_size=in_size, + out_size=out_size, + activation_fn=post_fcnet_activation + if i < len(post_fcnet_hiddens) - 1 + else None, + initializer=normc_initializer(1.0), + ) + ) + in_size = out_size + # Last layer is logits layer. + self._logits = layers.pop() + + else: + self._logits = SlimConv2d( + out_channels, + num_outputs, + [1, 1], + 1, + padding, + activation_fn=None, + ) + + # num_outputs not known -> Flatten, then set self.num_outputs + # to the resulting number of nodes. + else: + self.last_layer_is_flattened = True + layers.append(nn.Flatten()) + + self._convs = nn.Sequential(*layers) + + # If our num_outputs still unknown, we need to do a test pass to + # figure out the output dimensions. This could be the case, if we have + # the Flatten layer at the end. + if self.num_outputs is None: + # Create a B=1 dummy sample and push it through out conv-net. + dummy_in = ( + torch.from_numpy(self.obs_space.sample()) + .permute(2, 0, 1) + .unsqueeze(0) + .float() + ) + dummy_out = self._convs(dummy_in) + self.num_outputs = dummy_out.shape[1] + + # Build the value layers + self._value_branch_separate = self._value_branch = None + if vf_share_layers: + self._value_branch = SlimFC( + out_channels, 1, initializer=normc_initializer(0.01), activation_fn=None + ) + else: + vf_layers = [] + (w, h, in_channels) = obs_space.shape + in_size = [w, h] + for out_channels, kernel, stride in filters[:-1]: + padding, out_size = same_padding(in_size, kernel, stride) + vf_layers.append( + SlimConv2d( + in_channels, + out_channels, + kernel, + stride, + padding, + activation_fn=activation, + ) + ) + in_channels = out_channels + in_size = out_size + + out_channels, kernel, stride = filters[-1] + vf_layers.append( + SlimConv2d( + in_channels, + out_channels, + kernel, + stride, + None, + activation_fn=activation, + ) + ) + + vf_layers.append( + SlimConv2d( + in_channels=out_channels, + out_channels=1, + kernel=1, + stride=1, + padding=None, + activation_fn=None, + ) + ) + self._value_branch_separate = nn.Sequential(*vf_layers) + + # Holds the current "base" output (before logits layer). 
+ self._features = None + + @override(TorchModelV2) + def forward( + self, + input_dict: Dict[str, TensorType], + state: List[TensorType], + seq_lens: TensorType, + ) -> (TensorType, List[TensorType]): + self._features = input_dict["obs"].float() + # Permuate b/c data comes in as [B, dim, dim, channels]: + self._features = self._features.permute(0, 3, 1, 2) + conv_out = self._convs(self._features) + # Store features to save forward pass when getting value_function out. + if not self._value_branch_separate: + self._features = conv_out + + if not self.last_layer_is_flattened: + if self._logits: + conv_out = self._logits(conv_out) + if len(conv_out.shape) == 4: + if conv_out.shape[2] != 1 or conv_out.shape[3] != 1: + raise ValueError( + "Given `conv_filters` ({}) do not result in a [B, {} " + "(`num_outputs`), 1, 1] shape (but in {})! Please " + "adjust your Conv2D stack such that the last 2 dims " + "are both 1.".format( + self.model_config["conv_filters"], + self.num_outputs, + list(conv_out.shape), + ) + ) + logits = conv_out.squeeze(3) + logits = logits.squeeze(2) + else: + logits = conv_out + return logits, state + else: + return conv_out, state + + @override(TorchModelV2) + def value_function(self) -> TensorType: + assert self._features is not None, "must call forward() first" + if self._value_branch_separate: + value = self._value_branch_separate(self._features) + value = value.squeeze(3) + value = value.squeeze(2) + return value.squeeze(1) + else: + if not self.last_layer_is_flattened: + features = self._features.squeeze(3) + features = features.squeeze(2) + else: + features = self._features + return self._value_branch(features).squeeze(1) + + def _hidden_layers(self, obs: TensorType) -> TensorType: + res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major + res = res.squeeze(3) + res = res.squeeze(2) + return res From f46b569f0669a3983cd4a609ef3fa65af5ca70d7 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 27 Jan 2023 11:53:21 -0800 Subject: [PATCH 43/51] fix model names and some nits Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/base.py | 25 +++---------------- rllib/models/experimental/tf/encoder.py | 12 ++++----- rllib/models/experimental/tf/fcmodel.py | 9 +++---- rllib/models/experimental/tf/primitives.py | 17 ++++++++++--- rllib/models/experimental/torch/primitives.py | 25 +++++++++++-------- 5 files changed, 41 insertions(+), 47 deletions(-) diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index 79e1eac67824..305fe82c2e17 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -38,14 +38,14 @@ class Model: """Framework-agnostic base class for RLlib models. Models are low-level neural network components that offer input- and - output-specification, a forward method, and a get_initial_state method. They are - therefore not algorithm-specific. Models are composed in RLModules, where tensors - are passed through them. + output-specification, a forward method, and a get_initial_state method. Models + are composed in RLModules. """ def __init__(self, config: ModelConfig): self.config = config + @abc.abstractmethod def get_initial_state(self): """Returns the initial state of the model.""" return {} @@ -75,22 +75,3 @@ def input_spec(self) -> SpecDict: """ # If no checking is needed, we can simply return an empty spec. 
return SpecDict() - - @check_input_specs("input_spec", filter=True, cache=True) - @check_output_specs("output_spec", cache=True) - @abc.abstractmethod - def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: - """Computes the output of this module for each timestep. - - Outputs and inputs should be subject to spec checking. - - Args: - inputs: A TensorDict containing model inputs - kwargs: For forwards compatibility - - Examples: - # This is abstract, see the torch/tf2/jax implementations - >>> out = model(TensorDict({"in": np.arange(10)})) - >>> out # TensorDict(...) - """ - raise NotImplementedError diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index 9bbd3168d6f6..a7f7effe4698 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -19,15 +19,15 @@ from ray.rllib.models.specs.checker import check_input_specs, check_output_specs from ray.rllib.models.specs.specs_tf import TFTensorSpecs from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT -from ray.rllib.models.experimental.tf.primitives import TfMLPModel +from ray.rllib.models.experimental.tf.primitives import TfModel -class TfFCEncoder(Encoder, TfMLPModel): +class TfFCEncoder(Encoder, TfModel): """A fully connected encoder.""" def __init__(self, config: ModelConfig) -> None: Encoder.__init__(self, config) - TfMLPModel.__init__(self, config) + TfModel.__init__(self, config) self.net = TfMLP( input_dim=config.input_dim, @@ -52,12 +52,12 @@ def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} -class LSTMEncoder(Encoder, TfMLPModel): +class LSTMEncoder(Encoder, TfModel): """An encoder that uses an LSTM cell and a linear layer.""" def __init__(self, config: ModelConfig) -> None: Encoder.__init__(self, config) - TfMLPModel.__init__(self, config) + TfModel.__init__(self, config) self.lstm = nn.LSTM( config.input_dim, @@ -136,7 +136,7 @@ def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: } -class TfIdentityEncoder(TfMLPModel): +class TfIdentityEncoder(TfModel): """An encoder that does nothing but passing on inputs. We use this so that we avoid having many if/else statements in the RLModule. 
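The encoders above follow a config -> build -> call pattern: a dataclass config describes the network, build(framework=...) returns the framework-specific model, and the model maps SampleBatch.OBS to ENCODER_OUT under automatic spec checking. A minimal sketch of that flow, assuming the MLPEncoderConfig fields shown in configs.py (input_dim, hidden_layer_dims, output_dim) and assuming a plain dict is accepted where a TensorDict is expected; this experimental API may still change:

    import torch

    from ray.rllib.models.experimental.configs import MLPEncoderConfig
    from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT
    from ray.rllib.policy.sample_batch import SampleBatch

    # Describe a small fully connected encoder: 4 observation features -> 16 latents.
    config = MLPEncoderConfig(input_dim=4, hidden_layer_dims=[32, 32], output_dim=16)
    encoder = config.build(framework="torch")

    # The encoder consumes SampleBatch.OBS and returns a dict keyed by ENCODER_OUT.
    batch = {SampleBatch.OBS: torch.randn(2, 4)}
    out = encoder(batch)
    assert out[ENCODER_OUT].shape == (2, 16)
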
diff --git a/rllib/models/experimental/tf/fcmodel.py b/rllib/models/experimental/tf/fcmodel.py index 86b2423a84ec..0c824cd21779 100644 --- a/rllib/models/experimental/tf/fcmodel.py +++ b/rllib/models/experimental/tf/fcmodel.py @@ -2,18 +2,17 @@ from ray.rllib.models.specs.specs_tf import TFTensorSpecs from ray.rllib.utils import try_import_tf from ray.rllib.models.temp_spec_classes import TensorDict -from ray.rllib.models.tf.primitives import FCNet, TFModel +from ray.rllib.models.experimental.tf.primitives import TfMLP, TfModel from ray.rllib.models.experimental.base import ModelConfig, ForwardOutputType tf1, tf, tfv = try_import_tf() -class TfMLPModel(tf.Module, TFModel): +class TfMLPModel(TfModel): def __init__(self, config: ModelConfig) -> None: - tf.Module.__init__(self) - TFModel.__init__(self, config) + TfModel.__init__(self, config) - self.net = FCNet( + self.net = TfMLP( input_dim=config.input_dim, hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py index 347d103a890f..10a3c964f8b1 100644 --- a/rllib/models/experimental/tf/primitives.py +++ b/rllib/models/experimental/tf/primitives.py @@ -9,6 +9,10 @@ from ray.rllib.utils.typing import TensorType from ray.rllib.models.utils import get_activation_fn from typing import Tuple +from ray.rllib.models.specs.checker import ( + check_input_specs, + check_output_specs, +) _, tf, _ = try_import_tf() @@ -21,7 +25,7 @@ def _call_not_decorated(input_or_output): ) -class TfMLPModel(Model, tf.Module): +class TfModel(Model, tf.Module): """Base class for RLlib models. This class is used to define the general interface for RLlib models and checks @@ -31,9 +35,16 @@ class TfMLPModel(Model, tf.Module): def __init__(self, config): super().__init__(config) - assert is_input_decorated(self.__call__), _call_not_decorated("input") - assert is_output_decorated(self.__call__), _call_not_decorated("output") + # automatically apply spec checking + if not is_input_decorated(self.__call__): + self.__call__ = check_input_specs("input_spec", filter=True, cache=True)( + self.__call__ + ) + if not is_output_decorated(self.__call__): + self.__call__ = check_output_specs("output_spec", cache=True)(self.__call__) + @check_input_specs("input_spec", cache=True) + @check_output_specs("output_spec", cache=True) def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: """Returns the output of this model for the given input. diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index 445d97be0e5b..a7b4cba176ca 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -11,18 +11,14 @@ from ray.rllib.utils.typing import TensorType from ray.rllib.models.experimental.base import ModelConfig from ray.rllib.models.utils import get_activation_fn +from ray.rllib.models.specs.checker import ( + check_input_specs, + check_output_specs, +) torch, nn = try_import_torch() -def _forward_not_decorated(input_or_output): - return ( - f"forward not decorated with {input_or_output} specification. Decorate " - f"with @check_{input_or_output}_specs() to define a specification. See " - f"BaseModel for examples." - ) - - class TorchModel(nn.Module, Model): """Base class for torch models. 
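Both TfModel and TorchModel now wrap an undecorated __call__ / forward with the spec-checking decorators at construction time instead of asserting that subclasses applied them. A rough sketch of that pattern in isolation, assuming is_input_decorated and is_output_decorated live in the same ray.rllib.models.specs.checker module as the decorators; illustrative only, not the exact RLlib implementation:

    from ray.rllib.models.specs.checker import (
        check_input_specs,
        check_output_specs,
        is_input_decorated,
        is_output_decorated,
    )

    def ensure_spec_checked(method):
        """Return `method` wrapped with input/output spec checking, unless the
        subclass already decorated it itself."""
        if not is_input_decorated(method):
            method = check_input_specs("input_spec", filter=True, cache=True)(method)
        if not is_output_decorated(method):
            method = check_output_specs("output_spec", cache=True)(method)
        return method

    # Inside a model constructor this amounts to, e.g.:
    #     self.forward = ensure_spec_checked(self.forward)
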
@@ -34,9 +30,16 @@ class TorchModel(nn.Module, Model): def __init__(self, config: ModelConfig): nn.Module.__init__(self) Model.__init__(self, config) - assert is_input_decorated(self.forward), _forward_not_decorated("input") - assert is_output_decorated(self.forward), _forward_not_decorated("output") - + # automatically apply spec checking + if not is_input_decorated(self.forward): + self.forward = check_input_specs("input_spec", filter=True, cache=True)( + self.forward + ) + if not is_output_decorated(self.forward): + self.forward = check_output_specs("output_spec", cache=True)(self.forward) + + @check_input_specs("input_spec", cache=True) + @check_output_specs("output_spec", cache=True) def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: """Returns the output of this model for the given input. From 868d442b10d2608048190322d5fb4c5a1f3b5ea3 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 27 Jan 2023 12:47:35 -0800 Subject: [PATCH 44/51] renaming and lint Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/base.py | 4 ---- rllib/models/experimental/configs.py | 4 ++-- rllib/models/experimental/tf/{fcmodel.py => mlp.py} | 0 rllib/models/experimental/torch/{fcmodel.py => mlp.py} | 0 4 files changed, 2 insertions(+), 6 deletions(-) rename rllib/models/experimental/tf/{fcmodel.py => mlp.py} (100%) rename rllib/models/experimental/torch/{fcmodel.py => mlp.py} (100%) diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index 305fe82c2e17..9f9a13726317 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -1,10 +1,6 @@ from dataclasses import dataclass import abc -from ray.rllib.models.specs.checker import ( - check_input_specs, - check_output_specs, -) from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.temp_spec_classes import TensorDict from ray.rllib.utils.annotations import ExperimentalAPI diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index 41f7a74edc3f..de2de16f6cf2 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -59,11 +59,11 @@ class MLPConfig(ModelConfig): @_framework_implemented() def build(self, framework: str = "torch") -> Model: if framework == "torch": - from ray.rllib.models.experimental.torch.fcmodel import TorchMLPModel + from ray.rllib.models.experimental.torch.mlp import TorchMLPModel return TorchMLPModel(self) else: - from ray.rllib.models.experimental.tf.fcmodel import TfMLPModel + from ray.rllib.models.experimental.tf.mlp import TfMLPModel return TfMLPModel(self) diff --git a/rllib/models/experimental/tf/fcmodel.py b/rllib/models/experimental/tf/mlp.py similarity index 100% rename from rllib/models/experimental/tf/fcmodel.py rename to rllib/models/experimental/tf/mlp.py diff --git a/rllib/models/experimental/torch/fcmodel.py b/rllib/models/experimental/torch/mlp.py similarity index 100% rename from rllib/models/experimental/torch/fcmodel.py rename to rllib/models/experimental/torch/mlp.py From 8326719c8cc28db1acf262979d056df82256b3b8 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 27 Jan 2023 14:23:39 -0800 Subject: [PATCH 45/51] self-review Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/torch/ppo_torch_rl_module.py | 10 +++++----- rllib/models/experimental/base.py | 5 ++--- rllib/models/experimental/configs.py | 11 ++++++----- rllib/models/experimental/tf/encoder.py | 4 ++-- 
rllib/models/experimental/tf/mlp.py | 2 +- rllib/models/experimental/tf/primitives.py | 4 +--- rllib/models/experimental/torch/encoder.py | 14 +++++++------- rllib/models/experimental/torch/mlp.py | 6 +++--- rllib/models/experimental/torch/primitives.py | 4 +--- rllib/models/torch/primitives.py | 4 ++-- 10 files changed, 30 insertions(+), 34 deletions(-) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 9d60245c3176..4b02f9823fa7 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -45,15 +45,15 @@ def get_ppo_loss(fwd_in, fwd_out): @ExperimentalAPI @dataclass -class PPOModuleConfig(RLModuleConfig): - """Configuration for the PPO RLModule. +class PPOModuleConfig(RLModuleConfig): # TODO (Artur): Move to Torch-unspecific file + """Configuration for the PPORLModule. Attributes: observation_space: The observation space of the environment. action_space: The action space of the environment. encoder_config: The configuration for the encoder network. - pi_config: The configuration for the policy network. - vf_config: The configuration for the value network. + pi_config: The configuration for the policy head. + vf_config: The configuration for the value function head. free_log_std: For DiagGaussian action distributions, make the second half of the model outputs floating bias variables instead of state-dependent. This only has an effect is using the default fully connected net. @@ -76,7 +76,7 @@ def setup(self) -> None: assert self.config.vf_config, "vf_config must be provided." assert self.config.encoder_config, "shared encoder config must be " "provided." - # TODO(Artur): Unify to tf and torch setup with ModelBuilder + # TODO(Artur): Unify to tf and torch setup with Catalog self.encoder = self.config.encoder_config.build(framework="torch") self.pi = self.config.pi_config.build(framework="torch") self.vf = self.config.vf_config.build(framework="torch") diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py index 9f9a13726317..bcf822770e97 100644 --- a/rllib/models/experimental/base.py +++ b/rllib/models/experimental/base.py @@ -11,11 +11,10 @@ @ExperimentalAPI @dataclass class ModelConfig(abc.ABC): - """Configuration for a model. + """Base class for model configurations. Attributes: - output_dim: The output dimension of the network. If None, the output_dim will - be the number of nodes in the last hidden layer. + output_dim: The output dimension of the network. 
""" output_dim: int = None diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py index de2de16f6cf2..6a6c84772063 100644 --- a/rllib/models/experimental/configs.py +++ b/rllib/models/experimental/configs.py @@ -3,6 +3,7 @@ import functools from ray.rllib.models.experimental.base import ModelConfig, Model +from ray.rllib.models.experimental.encoder import Encoder from ray.rllib.utils.annotations import DeveloperAPI @@ -71,15 +72,15 @@ def build(self, framework: str = "torch") -> Model: @dataclass class MLPEncoderConfig(MLPConfig): @_framework_implemented() - def build(self, framework: str = "torch"): + def build(self, framework: str = "torch") -> Encoder: if framework == "torch": from ray.rllib.models.experimental.torch.encoder import TorchMLPEncoder return TorchMLPEncoder(self) else: - from ray.rllib.models.experimental.tf.encoder import TfFCEncoder + from ray.rllib.models.experimental.tf.encoder import TfMLPEncoder - return TfFCEncoder(self) + return TfMLPEncoder(self) @dataclass @@ -91,7 +92,7 @@ class LSTMEncoderConfig(ModelConfig): output_activation: str = "linear" @_framework_implemented(tf2=False) - def build(self, framework: str = "torch"): + def build(self, framework: str = "torch") -> Encoder: if framework == "torch": from ray.rllib.models.experimental.torch.encoder import TorchLSTMEncoder @@ -103,7 +104,7 @@ class IdentityConfig(ModelConfig): """Configuration for an identity encoder.""" @_framework_implemented() - def build(self, framework: str = "torch"): + def build(self, framework: str = "torch") -> Model: if framework == "torch": from ray.rllib.models.experimental.torch.encoder import TorchIdentityEncoder diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index a7f7effe4698..b991797bf7f3 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -22,7 +22,7 @@ from ray.rllib.models.experimental.tf.primitives import TfModel -class TfFCEncoder(Encoder, TfModel): +class TfMLPEncoder(Encoder, TfModel): """A fully connected encoder.""" def __init__(self, config: ModelConfig) -> None: @@ -110,7 +110,7 @@ def output_spec(self): } ) - @check_input_specs("input_spec", filter=True, cache=False) + @check_input_specs("input_spec", cache=False) @check_output_specs("output_spec", cache=False) def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: x = inputs[SampleBatch.OBS] diff --git a/rllib/models/experimental/tf/mlp.py b/rllib/models/experimental/tf/mlp.py index 0c824cd21779..00471991f512 100644 --- a/rllib/models/experimental/tf/mlp.py +++ b/rllib/models/experimental/tf/mlp.py @@ -27,7 +27,7 @@ def input_spec(self): def output_spec(self): return TFTensorSpecs("b, h", h=self.config.output_dim) - @check_input_specs("input_spec", filter=True, cache=False) + @check_input_specs("input_spec", cache=False) @check_output_specs("output_spec", cache=False) def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return self.net(inputs) diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py index 10a3c964f8b1..a7c45c11afe0 100644 --- a/rllib/models/experimental/tf/primitives.py +++ b/rllib/models/experimental/tf/primitives.py @@ -37,9 +37,7 @@ def __init__(self, config): super().__init__(config) # automatically apply spec checking if not is_input_decorated(self.__call__): - self.__call__ = check_input_specs("input_spec", filter=True, cache=True)( - self.__call__ - ) + self.__call__ = 
check_input_specs("input_spec", cache=True)(self.__call__) if not is_output_decorated(self.__call__): self.__call__ = check_output_specs("output_spec", cache=True)(self.__call__) diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py index c438308106ed..53b997763acc 100644 --- a/rllib/models/experimental/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -39,19 +39,19 @@ def __init__(self, config: ModelConfig) -> None: @property @override(TorchModel) - def input_spec(self): + def input_spec(self) -> SpecDict: return SpecDict( {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} ) @property @override(TorchModel) - def output_spec(self): + def output_spec(self) -> SpecDict: return SpecDict( {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} ) - @check_input_specs("input_spec", filter=True, cache=False) + @check_input_specs("input_spec", cache=False) @check_output_specs("output_spec", cache=False) def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} @@ -80,7 +80,7 @@ def get_initial_state(self): @property @override(TorchModel) - def input_spec(self): + def input_spec(self) -> SpecDict: config = self.config return SpecDict( { @@ -100,7 +100,7 @@ def input_spec(self): @property @override(TorchModel) - def output_spec(self): + def output_spec(self) -> SpecDict: config = self.config return SpecDict( { @@ -147,14 +147,14 @@ def __init__(self, config: ModelConfig) -> None: super().__init__(config) @property - def input_spec(self): + def input_spec(self) -> SpecDict: return SpecDict( # Use the output dim as input dim because identity. {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.output_dim)} ) @property - def output_spec(self): + def output_spec(self) -> SpecDict: return SpecDict( {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} ) diff --git a/rllib/models/experimental/torch/mlp.py b/rllib/models/experimental/torch/mlp.py index 6b9cb674b84d..31b45285f933 100644 --- a/rllib/models/experimental/torch/mlp.py +++ b/rllib/models/experimental/torch/mlp.py @@ -23,15 +23,15 @@ def __init__(self, config: ModelConfig) -> None: @property @override(Model) - def input_spec(self): + def input_spec(self) -> TorchTensorSpec: return TorchTensorSpec("b, h", h=self.config.input_dim) @property @override(Model) - def output_spec(self): + def output_spec(self) -> TorchTensorSpec: return TorchTensorSpec("b, h", h=self.config.output_dim) - @check_input_specs("input_spec", filter=True, cache=False) + @check_input_specs("input_spec", cache=False) @check_output_specs("output_spec", cache=False) @override(TorchModel) def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py index a7b4cba176ca..11156380133d 100644 --- a/rllib/models/experimental/torch/primitives.py +++ b/rllib/models/experimental/torch/primitives.py @@ -32,9 +32,7 @@ def __init__(self, config: ModelConfig): Model.__init__(self, config) # automatically apply spec checking if not is_input_decorated(self.forward): - self.forward = check_input_specs("input_spec", filter=True, cache=True)( - self.forward - ) + self.forward = check_input_specs("input_spec", cache=True)(self.forward) if not is_output_decorated(self.forward): self.forward = check_output_specs("output_spec", cache=True)(self.forward) diff --git a/rllib/models/torch/primitives.py 
b/rllib/models/torch/primitives.py index 191a0ff35e5a..eaa43a6db3d4 100644 --- a/rllib/models/torch/primitives.py +++ b/rllib/models/torch/primitives.py @@ -11,8 +11,8 @@ class FCNet(nn.Module): Attributes: input_dim: The input dimension of the network. It cannot be None. - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. + output_dim: The output dimension of the network. If None, the output_dim will + be the number of nodes in the last hidden layer. hidden_layers: The sizes of the hidden layers. activation: The activation function to use after each layer. """ From cede51fe3f878212395f89dcd8dc9ecbcb5ecb74 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 27 Jan 2023 15:18:41 -0800 Subject: [PATCH 46/51] output activations Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/tf/mlp.py | 1 + rllib/models/experimental/torch/mlp.py | 1 + 2 files changed, 2 insertions(+) diff --git a/rllib/models/experimental/tf/mlp.py b/rllib/models/experimental/tf/mlp.py index 00471991f512..beebabd02c68 100644 --- a/rllib/models/experimental/tf/mlp.py +++ b/rllib/models/experimental/tf/mlp.py @@ -17,6 +17,7 @@ def __init__(self, config: ModelConfig) -> None: hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, hidden_layer_activation=config.hidden_layer_activation, + output_activation=config.output_activation, ) @property diff --git a/rllib/models/experimental/torch/mlp.py b/rllib/models/experimental/torch/mlp.py index 31b45285f933..4d5afa11d502 100644 --- a/rllib/models/experimental/torch/mlp.py +++ b/rllib/models/experimental/torch/mlp.py @@ -19,6 +19,7 @@ def __init__(self, config: ModelConfig) -> None: hidden_layer_dims=config.hidden_layer_dims, output_dim=config.output_dim, hidden_layer_activation=config.hidden_layer_activation, + output_activation=config.output_activation, ) @property From 8bc5c087ec33c2ba751857f0303f0c33be4e409d Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 27 Jan 2023 15:23:24 -0800 Subject: [PATCH 47/51] remove useless init and add comment to torch encoder Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/tf/encoder.py | 3 --- rllib/models/experimental/torch/encoder.py | 6 ++++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py index b991797bf7f3..4ee2ea4433ea 100644 --- a/rllib/models/experimental/tf/encoder.py +++ b/rllib/models/experimental/tf/encoder.py @@ -142,9 +142,6 @@ class TfIdentityEncoder(TfModel): We use this so that we avoid having many if/else statements in the RLModule. """ - def __init__(self, config: ModelConfig) -> None: - super().__init__(config) - @property def input_spec(self): return SpecDict( diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py index 53b997763acc..cb87f9b8dc78 100644 --- a/rllib/models/experimental/torch/encoder.py +++ b/rllib/models/experimental/torch/encoder.py @@ -143,8 +143,10 @@ def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: class TorchIdentityEncoder(TorchModel): - def __init__(self, config: ModelConfig) -> None: - super().__init__(config) + """An encoder that does nothing but passing on inputs. + + We use this so that we avoid having many if/else statements in the RLModule. 
+ """ @property def input_spec(self) -> SpecDict: From 7846d1b9e43d689f8a5246b3be8a4b4cc9ee1d41 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Fri, 27 Jan 2023 15:57:02 -0800 Subject: [PATCH 48/51] remove useless constructor Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/encoder.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rllib/models/experimental/encoder.py b/rllib/models/experimental/encoder.py index 30b4852a9a6a..da2da4e7916b 100644 --- a/rllib/models/experimental/encoder.py +++ b/rllib/models/experimental/encoder.py @@ -22,10 +22,6 @@ class Encoder(Model): That is, for time-series data, we encode into the latent space for each time step. This should be reflected in the output_spec. """ - - def __init__(self, config: dict): - super().__init__(config) - def get_initial_state(self) -> TensorType: """Returns the initial state of the encoder. From d2b4aee492c4c1b5ce161a02620057d656ca8cc9 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Mon, 30 Jan 2023 11:07:40 -0800 Subject: [PATCH 49/51] lint Signed-off-by: Artur Niederfahrenhorst --- rllib/models/experimental/encoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rllib/models/experimental/encoder.py b/rllib/models/experimental/encoder.py index da2da4e7916b..bf1e85d4ff8e 100644 --- a/rllib/models/experimental/encoder.py +++ b/rllib/models/experimental/encoder.py @@ -22,6 +22,7 @@ class Encoder(Model): That is, for time-series data, we encode into the latent space for each time step. This should be reflected in the output_spec. """ + def get_initial_state(self) -> TensorType: """Returns the initial state of the encoder. From 166e4ebce6a477b910ac84257717f72208411ad0 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 31 Jan 2023 10:27:34 -0800 Subject: [PATCH 50/51] kourohs's nits Signed-off-by: Artur Niederfahrenhorst --- .../ppo/torch/ppo_torch_rl_module.py | 2 +- rllib/core/rl_module/encoder.py | 202 ------------------ rllib/core/rl_module/encoder_tf.py | 37 ---- 3 files changed, 1 insertion(+), 240 deletions(-) delete mode 100644 rllib/core/rl_module/encoder.py delete mode 100644 rllib/core/rl_module/encoder_tf.py diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 4b02f9823fa7..2c26c1a02f3c 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -45,7 +45,7 @@ def get_ppo_loss(fwd_in, fwd_out): @ExperimentalAPI @dataclass -class PPOModuleConfig(RLModuleConfig): # TODO (Artur): Move to Torch-unspecific file +class PPOModuleConfig(RLModuleConfig): # TODO (Artur): Move to non-torch-specific file """Configuration for the PPORLModule. 
Attributes: diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py deleted file mode 100644 index e88bcfdce1e3..000000000000 --- a/rllib/core/rl_module/encoder.py +++ /dev/null @@ -1,202 +0,0 @@ -import torch -import torch.nn as nn -import tree -from typing import List - -from dataclasses import dataclass, field - -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.rnn_sequencing import add_time_dimension -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.specs.checker import check_input_specs, check_output_specs -from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from ray.rllib.models.torch.primitives import FCNet - -# TODO (Kourosh): Find a better / more straight fwd approach for sub-components - -ENCODER_OUT = "encoder_out" -STATE_IN = "state_in" -STATE_OUT = "state_out" - - -@dataclass -class EncoderConfig: - """Configuration for an encoder network. - - Attributes: - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. - """ - - output_dim: int = None - - -@dataclass -class IdentityConfig(EncoderConfig): - """Configuration for an identity encoder.""" - - def build(self): - return IdentityEncoder(self) - - -@dataclass -class MLPConfig(EncoderConfig): - """Configuration for a fully connected network. - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. - """ - - input_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - - def build(self): - return FullyConnectedEncoder(self) - - -@dataclass -class LSTMConfig(EncoderConfig): - input_dim: int = None - hidden_dim: int = None - num_layers: int = None - batch_first: bool = True - - def build(self): - return LSTMEncoder(self) - - -class Encoder(nn.Module): - def __init__(self, config: EncoderConfig) -> None: - super().__init__() - self.config = config - self._input_spec = self.input_spec() - self._output_spec = self.output_spec() - - def get_initial_state(self): - return [] - - def input_spec(self): - return SpecDict() - - def output_spec(self): - return SpecDict() - - @check_input_specs("_input_spec") - @check_output_specs("_output_spec") - def forward(self, input_dict): - return self._forward(input_dict) - - def _forward(self, input_dict): - raise NotImplementedError - - -class FullyConnectedEncoder(Encoder): - def __init__(self, config: MLPConfig) -> None: - super().__init__(config) - - self.net = FCNet( - input_dim=config.input_dim, - hidden_layers=config.hidden_layers, - output_dim=config.output_dim, - activation=config.activation, - ) - - def input_spec(self): - return SpecDict( - {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} - ) - - def output_spec(self): - return SpecDict( - {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} - ) - - def _forward(self, input_dict): - return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} - - -class LSTMEncoder(Encoder): - def __init__(self, config: LSTMConfig) -> None: - super().__init__(config) - - self.lstm = nn.LSTM( - config.input_dim, - config.hidden_dim, - config.num_layers, - batch_first=config.batch_first, - ) - self.linear = nn.Linear(config.hidden_dim, config.output_dim) - - def get_initial_state(self): - config = self.config 
- return { - "h": torch.zeros(config.num_layers, config.hidden_dim), - "c": torch.zeros(config.num_layers, config.hidden_dim), - } - - def input_spec(self): - config = self.config - return SpecDict( - { - # bxt is just a name for better readability to indicated padded batch - SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), - STATE_IN: { - "h": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - "c": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - }, - } - ) - - def output_spec(self): - config = self.config - return SpecDict( - { - ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), - STATE_OUT: { - "h": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - "c": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - }, - } - ) - - def _forward(self, input_dict: SampleBatch): - x = input_dict[SampleBatch.OBS] - states = input_dict[STATE_IN] - # states are batch-first when coming in - states = tree.map_structure(lambda x: x.transpose(0, 1), states) - - x = add_time_dimension( - x, - seq_lens=input_dict[SampleBatch.SEQ_LENS], - framework="torch", - time_major=not self.config.batch_first, - ) - states_o = {} - x, (states_o["h"], states_o["c"]) = self.lstm(x, (states["h"], states["c"])) - - x = self.linear(x) - x = x.view(-1, x.shape[-1]) - - return { - ENCODER_OUT: x, - STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), - } - - -class IdentityEncoder(Encoder): - def __init__(self, config: EncoderConfig) -> None: - super().__init__(config) - - def _forward(self, input_dict): - return input_dict diff --git a/rllib/core/rl_module/encoder_tf.py b/rllib/core/rl_module/encoder_tf.py deleted file mode 100644 index 5c517f6c745d..000000000000 --- a/rllib/core/rl_module/encoder_tf.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass, field -from typing import List - -from ray.rllib.core.rl_module.encoder import EncoderConfig -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.models.tf.primitives import FCNet, IdentityNetwork - -tf1, tf, tfv = try_import_tf() - - -@dataclass -class FCTfConfig(EncoderConfig): - """Configuration for a fully connected network. - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. 
- """ - - input_dim: int = None - output_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - - def build(self): - return FCNet( - self.input_dim, self.hidden_layers, self.output_dim, self.activation - ) - - -@dataclass -class IdentityTfConfig(EncoderConfig): - """A network that returns the input as the output.""" - - def build(self): - return IdentityNetwork() From c70224fa67a716c4a67b8b10bd2db556a3240701 Mon Sep 17 00:00:00 2001 From: Artur Niederfahrenhorst Date: Tue, 31 Jan 2023 10:34:14 -0800 Subject: [PATCH 51/51] unify torch + tf Signed-off-by: Artur Niederfahrenhorst --- rllib/algorithms/ppo/tf/ppo_tf_rl_module.py | 52 +++++++++++++-------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index b9f550ecd243..d24ba43e7a01 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -1,17 +1,20 @@ -import gymnasium as gym from typing import Mapping, Any, List + +import gymnasium as gym + +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule -from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.models.experimental.configs import MLPConfig, IdentityConfig +from ray.rllib.models.experimental.encoder import STATE_OUT +from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT +from ray.rllib.models.experimental.tf.primitives import TfMLP +from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian -from ray.rllib.models.experimental.tf.primitives import TfMLP -from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig tf1, tf, _ = try_import_tf() tf1.enable_eager_execution() @@ -60,9 +63,14 @@ def output_specs_train(self) -> List[str]: @override(TfRLModule) def _forward_train(self, batch: NestedDict): + output = {} + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + + # Actions action_logits = self.pi(encoder_out[ENCODER_OUT]) - vf = self.vf(encoder_out[ENCODER_OUT]) if self._is_discrete: action_dist = Categorical(action_logits) @@ -71,10 +79,10 @@ def _forward_train(self, batch: NestedDict): action_logits, None, action_space=self.config.action_space ) - output = { - SampleBatch.ACTION_DIST: action_dist, - SampleBatch.VF_PREDS: tf.squeeze(vf, axis=-1), - } + vf = self.vf(encoder_out[ENCODER_OUT]) + output[SampleBatch.ACTION_DIST] = action_dist + output[SampleBatch.VF_PREDS] = tf.squeeze(vf, axis=-1) + return output @override(TfRLModule) @@ -87,7 +95,11 @@ def output_specs_inference(self) -> List[str]: @override(TfRLModule) def _forward_inference(self, batch) -> Mapping[str, Any]: + output = {} + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] action_logits = self.pi(encoder_out[ENCODER_OUT]) @@ -97,9 +109,8 @@ 
def _forward_inference(self, batch) -> Mapping[str, Any]: action, _ = tf.split(action_logits, num_or_size_splits=2, axis=1) action_dist = Deterministic(action, model=None) - output = { - SampleBatch.ACTION_DIST: action_dist, - } + output[SampleBatch.ACTION_DIST] = action_dist + return output @override(TfRLModule) @@ -116,7 +127,10 @@ def output_specs_exploration(self) -> List[str]: @override(TfRLModule) def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: + output = {} encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] action_logits = self.pi(encoder_out[ENCODER_OUT]) vf = self.vf(encoder_out[ENCODER_OUT]) @@ -127,11 +141,11 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = DiagGaussian( action_logits, None, action_space=self.config.action_space ) - output = { - SampleBatch.ACTION_DIST: action_dist, - SampleBatch.ACTION_DIST_INPUTS: action_logits, - SampleBatch.VF_PREDS: tf.squeeze(vf, axis=-1), - } + + output[SampleBatch.ACTION_DIST] = action_dist + output[SampleBatch.ACTION_DIST_INPUTS] = action_logits + output[SampleBatch.VF_PREDS] = tf.squeeze(vf, axis=-1) + return output @classmethod