diff --git a/rllib/BUILD b/rllib/BUILD index 46642d737365..2c70e8e3d14b 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1974,14 +1974,6 @@ py_test( srcs = ["models/specs/tests/test_spec_dict.py"] ) -# test TorchVectorEncoder -py_test( - name = "test_torch_vector_encoder", - tags = ["team:rllib", "models"], - size = "small", - srcs = ["models/torch/encoders/tests/test_torch_vector_encoder.py"] -) - # -------------------------------------------------------------------- # Offline diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 6b0e0161a5a6..18b917dd86e4 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -1,47 +1,43 @@ import itertools -import ray import unittest -import numpy as np + import gymnasium as gym -import torch +import numpy as np import tensorflow as tf +import torch import tree +import ray from ray.rllib import SampleBatch +from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( + PPOTfRLModule, +) +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule, - PPOModuleConfig, ) -from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import ( - PPOTfRLModule, - PPOTfModuleConfig, +from ray.rllib.models.experimental.configs import ( + MLPConfig, + MLPEncoderConfig, + LSTMEncoderConfig, ) -from ray.rllib.core.rl_module.encoder import ( - FCConfig, - IdentityConfig, - LSTMConfig, +from ray.rllib.models.experimental.torch.encoder import ( STATE_IN, STATE_OUT, ) -from ray.rllib.core.rl_module.encoder_tf import ( - FCTfConfig, - IdentityTfConfig, -) from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.torch_utils import convert_to_torch_tensor -def get_expected_model_config_torch( - env: gym.Env, lstm: bool, shared_encoder: bool +def get_expected_model_config( + env: gym.Env, + lstm: bool, ) -> PPOModuleConfig: """Get a PPOModuleConfig that we would expect from the catalog otherwise. Args: env: Environment for which we build the model later lstm: If True, build recurrent pi encoder - shared_encoder: If True, build a shared encoder for pi and vf, where pi - encoder and vf encoder will be identity. If False, the shared encoder - will be identity. 
Returns: A PPOModuleConfig containing the relevant configs to build PPORLModule @@ -51,107 +47,44 @@ def get_expected_model_config_torch( ) obs_dim = env.observation_space.shape[0] - if shared_encoder: - assert not lstm, "LSTM can only be used in PI" - shared_encoder_config = FCConfig( + if lstm: + encoder_config = LSTMEncoderConfig( input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", + hidden_dim=32, + batch_first=True, + num_layers=1, output_dim=32, ) - pi_encoder_config = IdentityConfig(output_dim=32) - vf_encoder_config = IdentityConfig(output_dim=32) else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) - if lstm: - pi_encoder_config = LSTMConfig( - input_dim=obs_dim, - hidden_dim=32, - batch_first=True, - output_dim=32, - num_layers=1, - ) - else: - pi_encoder_config = FCConfig( - input_dim=obs_dim, - output_dim=32, - hidden_layers=[32], - activation="ReLU", - ) - vf_encoder_config = FCConfig( + encoder_config = MLPEncoderConfig( input_dim=obs_dim, + hidden_layer_dims=[32], + hidden_layer_activation="ReLU", output_dim=32, - hidden_layers=[32], - activation="ReLU", ) - pi_config = FCConfig() - vf_config = FCConfig() - - if isinstance(env.action_space, gym.spaces.Discrete): - pi_config.output_dim = env.action_space.n - else: - pi_config.output_dim = env.action_space.shape[0] * 2 - - return PPOModuleConfig( - observation_space=env.observation_space, - action_space=env.action_space, - shared_encoder_config=shared_encoder_config, - pi_encoder_config=pi_encoder_config, - vf_encoder_config=vf_encoder_config, - pi_config=pi_config, - vf_config=vf_config, - shared_encoder=shared_encoder, + pi_config = MLPConfig( + input_dim=32, + hidden_layer_dims=[32], + hidden_layer_activation="ReLU", ) - - -def get_expected_model_config_tf( - env: gym.Env, shared_encoder: bool -) -> PPOTfModuleConfig: - """Get a PPOTfModuleConfig that we would expect from the catalog otherwise. - - Args: - env: Environment for which we build the model later - shared_encoder: If True, build a shared encoder for pi and vf, where pi - encoder and vf encoder will be identity. If False, the shared encoder - will be identity. - - Returns: - A PPOTfModuleConfig containing the relevant configs to build PPOTfRLModule. - """ - assert len(env.observation_space.shape) == 1, ( - "No multidimensional obs space " "supported." 
+ vf_config = MLPConfig( + input_dim=32, + hidden_layer_dims=[32, 1], + hidden_layer_activation="ReLU", ) - obs_dim = env.observation_space.shape[0] - - if shared_encoder: - shared_encoder_config = FCTfConfig( - input_dim=obs_dim, - hidden_layers=[32], - activation="ReLU", - output_dim=32, - ) - else: - shared_encoder_config = IdentityTfConfig(output_dim=obs_dim) - pi_config = FCConfig() - vf_config = FCConfig() - pi_config.input_dim = vf_config.input_dim = shared_encoder_config.output_dim if isinstance(env.action_space, gym.spaces.Discrete): pi_config.output_dim = env.action_space.n else: pi_config.output_dim = env.action_space.shape[0] * 2 - pi_config.hidden_layers = vf_config.hidden_layers = [32] - pi_config.activation = vf_config.activation = "ReLU" - - return PPOTfModuleConfig( + return PPOModuleConfig( observation_space=env.observation_space, action_space=env.action_space, - shared_encoder_config=shared_encoder_config, + encoder_config=encoder_config, pi_config=pi_config, vf_config=vf_config, - shared_encoder=shared_encoder, ) @@ -207,12 +140,11 @@ def setUpClass(cls): def tearDownClass(cls): ray.shutdown() - def get_ppo_module(self, framwework, env, lstm, shared_encoder): - if framwework == "torch": - config = get_expected_model_config_torch(env, lstm, shared_encoder) + def get_ppo_module(self, framework, env, lstm): + config = get_expected_model_config(env, lstm) + if framework == "torch": module = PPOTorchRLModule(config) else: - config = get_expected_model_config_tf(env, shared_encoder) module = PPOTfRLModule(config) return module @@ -222,7 +154,7 @@ def get_input_batch_from_obs(self, framework, obs): SampleBatch.OBS: convert_to_torch_tensor(obs)[None], } else: - batch = {SampleBatch.OBS: np.array([obs])} + batch = {SampleBatch.OBS: tf.convert_to_tensor([obs])} return batch def test_rollouts(self): @@ -230,21 +162,16 @@ def test_rollouts(self): frameworks = ["torch", "tf2"] env_names = ["CartPole-v1", "Pendulum-v1"] fwd_fns = ["forward_exploration", "forward_inference"] - shared_encoders = [False, True] - ltsms = [False, True] - config_combinations = [frameworks, env_names, fwd_fns, shared_encoders, ltsms] + lstm = [False, True] + config_combinations = [frameworks, env_names, fwd_fns, lstm] for config in itertools.product(*config_combinations): - fw, env_name, fwd_fn, shared_encoder, lstm = config - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue + fw, env_name, fwd_fn, lstm = config if lstm and fw == "tf2": # LSTM not implemented in TF2 yet continue - print(f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM" f"={lstm}") + print(f"[FW={fw} | [ENV={env_name}] | [FWD={fwd_fn}] | LSTM" f"={lstm}") env = gym.make(env_name) - module = self.get_ppo_module(fw, env, lstm, shared_encoder) + module = self.get_ppo_module(framework=fw, env=env, lstm=lstm) obs, _ = env.reset() @@ -267,22 +194,17 @@ def test_forward_train(self): # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space frameworks = ["torch", "tf2"] env_names = ["CartPole-v1", "Pendulum-v1"] - shared_encoders = [False, True] - ltsms = [False, True] - config_combinations = [frameworks, env_names, shared_encoders, ltsms] + lstm = [False, True] + config_combinations = [frameworks, env_names, lstm] for config in itertools.product(*config_combinations): - fw, env_name, shared_encoder, lstm = config - if lstm and shared_encoder: - # Not yet implemented - # TODO (Artur): Implement - continue + fw, env_name, lstm = config if lstm and fw == "tf2": # LSTM not implemented in TF2 yet continue 
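For orientation, here is an illustrative sketch of what the test helper above produces for a non-recurrent CartPole-v1 module, spelled out by hand. It assumes the experimental config classes introduced by this PR and simply mirrors the code in get_expected_model_config(); it is not part of the patch.

import gymnasium as gym
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
    PPOModuleConfig,
    PPOTorchRLModule,
)
from ray.rllib.models.experimental.configs import MLPConfig, MLPEncoderConfig

env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]  # 4 for CartPole-v1

encoder_config = MLPEncoderConfig(
    input_dim=obs_dim,
    hidden_layer_dims=[32],
    hidden_layer_activation="ReLU",
    output_dim=32,
)
pi_config = MLPConfig(
    input_dim=32, hidden_layer_dims=[32], hidden_layer_activation="ReLU"
)
vf_config = MLPConfig(
    input_dim=32, hidden_layer_dims=[32, 1], hidden_layer_activation="ReLU"
)
pi_config.output_dim = env.action_space.n  # Discrete action space: one logit per action

config = PPOModuleConfig(
    observation_space=env.observation_space,
    action_space=env.action_space,
    encoder_config=encoder_config,
    pi_config=pi_config,
    vf_config=vf_config,
)
module = PPOTorchRLModule(config)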
- print(f"[ENV={env_name}] | [SHARED=" f"{shared_encoder}] | LSTM={lstm}") + print(f"[FW={fw} | [ENV={env_name}] | LSTM={lstm}") env = gym.make(env_name) - module = self.get_ppo_module(fw, env, lstm, shared_encoder) + module = self.get_ppo_module(fw, env, lstm) # collect a batch of data batches = [] @@ -343,6 +265,9 @@ def test_forward_train(self): for param in module.parameters(): self.assertIsNotNone(param.grad) else: + batch = tree.map_structure( + lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch + ) with tf.GradientTape() as tape: fwd_out = module.forward_train(batch) loss = dummy_tf_ppo_loss(batch, fwd_out) diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index 5568d06a4224..d24ba43e7a01 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -1,41 +1,27 @@ -from dataclasses import dataclass -import gymnasium as gym from typing import Mapping, Any, List + +import gymnasium as gym + +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOModuleConfig from ray.rllib.core.rl_module.rl_module import RLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule +from ray.rllib.models.experimental.configs import MLPConfig, IdentityConfig +from ray.rllib.models.experimental.encoder import STATE_OUT +from ray.rllib.models.experimental.tf.encoder import ENCODER_OUT +from ray.rllib.models.experimental.tf.primitives import TfMLP +from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.core.rl_module.encoder_tf import FCTfConfig, IdentityTfConfig from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.models.tf.tf_action_dist import Categorical, Deterministic, DiagGaussian -from ray.rllib.models.tf.primitives import FCNet - tf1, tf, _ = try_import_tf() tf1.enable_eager_execution() -@dataclass -class PPOTfModuleConfig(RLModuleConfig): - """Configuration for the PPO module. - - Attributes: - pi_config: The configuration for the policy network. - vf_config: The configuration for the value network. - """ - - observation_space: gym.Space = None - action_space: gym.Space = None - pi_config: FCTfConfig = None - vf_config: FCTfConfig = None - shared_encoder_config: FCTfConfig = None - shared_encoder: bool = True - - class PPOTfRLModule(TfRLModule): - def __init__(self, config: PPOTfModuleConfig): + def __init__(self, config: RLModuleConfig): super().__init__() self.config = config self.setup() @@ -43,20 +29,20 @@ def __init__(self, config: PPOTfModuleConfig): def setup(self) -> None: assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." 
- self.shared_encoder = self.config.shared_encoder_config.build() + self.encoder = self.config.encoder_config.build(framework="tf") - self.pi = FCNet( - input_dim=self.config.shared_encoder_config.output_dim, + self.pi = TfMLP( + input_dim=self.config.encoder_config.output_dim, output_dim=self.config.pi_config.output_dim, - hidden_layers=self.config.pi_config.hidden_layers, - activation=self.config.pi_config.activation, + hidden_layer_dims=self.config.pi_config.hidden_layer_dims, + hidden_layer_activation=self.config.pi_config.hidden_layer_activation, ) - self.vf = FCNet( - input_dim=self.config.shared_encoder_config.output_dim, + self.vf = TfMLP( + input_dim=self.config.encoder_config.output_dim, output_dim=1, - hidden_layers=self.config.vf_config.hidden_layers, - activation=self.config.vf_config.activation, + hidden_layer_dims=self.config.vf_config.hidden_layer_dims, + hidden_layer_activation=self.config.vf_config.hidden_layer_activation, ) self._is_discrete = isinstance( @@ -77,10 +63,14 @@ def output_specs_train(self) -> List[str]: @override(TfRLModule) def _forward_train(self, batch: NestedDict): - obs = batch[SampleBatch.OBS] - encoder_out = self.shared_encoder(obs) - action_logits = self.pi(encoder_out) - vf = self.vf(encoder_out) + output = {} + + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + + # Actions + action_logits = self.pi(encoder_out[ENCODER_OUT]) if self._is_discrete: action_dist = Categorical(action_logits) @@ -89,10 +79,10 @@ def _forward_train(self, batch: NestedDict): action_logits, None, action_space=self.config.action_space ) - output = { - SampleBatch.ACTION_DIST: action_dist, - SampleBatch.VF_PREDS: tf.squeeze(vf, axis=-1), - } + vf = self.vf(encoder_out[ENCODER_OUT]) + output[SampleBatch.ACTION_DIST] = action_dist + output[SampleBatch.VF_PREDS] = tf.squeeze(vf, axis=-1) + return output @override(TfRLModule) @@ -105,10 +95,13 @@ def output_specs_inference(self) -> List[str]: @override(TfRLModule) def _forward_inference(self, batch) -> Mapping[str, Any]: - obs = batch[SampleBatch.OBS] - encoder_out = self.shared_encoder(obs) + output = {} - action_logits = self.pi(encoder_out) + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + + action_logits = self.pi(encoder_out[ENCODER_OUT]) if self._is_discrete: action = tf.math.argmax(action_logits, axis=-1) @@ -116,9 +109,8 @@ def _forward_inference(self, batch) -> Mapping[str, Any]: action, _ = tf.split(action_logits, num_or_size_splits=2, axis=1) action_dist = Deterministic(action, model=None) - output = { - SampleBatch.ACTION_DIST: action_dist, - } + output[SampleBatch.ACTION_DIST] = action_dist + return output @override(TfRLModule) @@ -135,11 +127,13 @@ def output_specs_exploration(self) -> List[str]: @override(TfRLModule) def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: - obs = batch[SampleBatch.OBS] - encoder_out = self.shared_encoder(obs) + output = {} + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] - action_logits = self.pi(encoder_out) - vf = self.vf(encoder_out) + action_logits = self.pi(encoder_out[ENCODER_OUT]) + vf = self.vf(encoder_out[ENCODER_OUT]) if self._is_discrete: action_dist = Categorical(action_logits) @@ -147,11 +141,11 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: action_dist = DiagGaussian( action_logits, None, action_space=self.config.action_space ) - output = 
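All three forward methods in the refactored modules now follow the same shape: the encoder consumes the whole batch and returns a dict, the heads read ENCODER_OUT, and any recurrent state is passed through under STATE_OUT. The following is a sketch of that pattern, not the literal module code; the "action_logits" and "vf_preds" keys are placeholders for the SampleBatch constants used in the real methods.

from ray.rllib.models.experimental.encoder import STATE_OUT
from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT

def forward_common(module, batch):
    output = {}
    # The encoder returns a dict keyed by ENCODER_OUT (plus STATE_OUT if recurrent).
    encoder_out = module.encoder(batch)
    if STATE_OUT in encoder_out:
        output[STATE_OUT] = encoder_out[STATE_OUT]
    latent = encoder_out[ENCODER_OUT]
    output["action_logits"] = module.pi(latent)  # policy head
    output["vf_preds"] = module.vf(latent)       # value head; callers squeeze the last dim
    return output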
{ - SampleBatch.ACTION_DIST: action_dist, - SampleBatch.ACTION_DIST_INPUTS: action_logits, - SampleBatch.VF_PREDS: tf.squeeze(vf, axis=-1), - } + + output[SampleBatch.ACTION_DIST] = action_dist + output[SampleBatch.ACTION_DIST_INPUTS] = action_logits + output[SampleBatch.VF_PREDS] = tf.squeeze(vf, axis=-1) + return output @classmethod @@ -180,14 +174,14 @@ def from_model_config( if use_lstm: raise ValueError("LSTM not supported by PPOTfRLModule yet.") if vf_share_layers: - shared_encoder_config = FCTfConfig( + encoder_config = MLPConfig( input_dim=obs_dim, - hidden_layers=fcnet_hiddens, - activation=activation, + hidden_layer_dims=fcnet_hiddens, + hidden_layer_activation=activation, output_dim=model_config["fcnet_hiddens"][-1], ) else: - shared_encoder_config = IdentityTfConfig(output_dim=obs_dim) + encoder_config = IdentityConfig(output_dim=obs_dim) assert isinstance( observation_space, gym.spaces.Box ), "This simple PPOModule only supports Box observation space." @@ -199,23 +193,23 @@ def from_model_config( assert isinstance(action_space, (gym.spaces.Discrete, gym.spaces.Box)), ( "This simple PPOModule only supports Discrete and Box action space.", ) - pi_config = FCTfConfig() - vf_config = FCTfConfig() - shared_encoder_config.input_dim = observation_space.shape[0] - pi_config.input_dim = shared_encoder_config.output_dim - pi_config.hidden_layers = fcnet_hiddens + pi_config = MLPConfig() + vf_config = MLPConfig() + encoder_config.input_dim = observation_space.shape[0] + pi_config.input_dim = encoder_config.output_dim + pi_config.hidden_layer_dims = fcnet_hiddens if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 # build vf network - vf_config.input_dim = shared_encoder_config.output_dim - vf_config.hidden_layers = fcnet_hiddens + vf_config.input_dim = encoder_config.output_dim + vf_config.hidden_layer_dims = fcnet_hiddens vf_config.output_dim = 1 - config_ = PPOTfModuleConfig( + config_ = PPOModuleConfig( pi_config=pi_config, vf_config=vf_config, - shared_encoder_config=shared_encoder_config, + encoder_config=encoder_config, observation_space=observation_space, action_space=action_space, ) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index b3cc51a732f3..2c26c1a02f3c 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -1,13 +1,18 @@ from dataclasses import dataclass -import gymnasium as gym from typing import Mapping, Any, Union -from ray.rllib.core.rl_module.torch import TorchRLModule +import gymnasium as gym + from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.annotations import override -from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.utils.framework import try_import_torch +from ray.rllib.core.rl_module.torch import TorchRLModule +from ray.rllib.models.experimental.encoder import STATE_OUT +from ray.rllib.models.experimental.configs import MLPConfig, MLPEncoderConfig +from ray.rllib.models.experimental.configs import ( + LSTMEncoderConfig, +) +from ray.rllib.models.experimental.torch.encoder import ( + ENCODER_OUT, +) from ray.rllib.models.specs.specs_dict import SpecDict from ray.rllib.models.specs.specs_torch import TorchTensorSpec from ray.rllib.models.torch.torch_distributions import ( @@ -15,15 +20,11 @@ TorchDeterministic, 
TorchDiagGaussian, ) -from ray.rllib.core.rl_module.encoder import ( - FCNet, - FCConfig, - LSTMConfig, - IdentityConfig, - LSTMEncoder, - ENCODER_OUT, -) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override, ExperimentalAPI +from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space +from ray.rllib.utils.nested_dict import NestedDict torch, nn = try_import_torch() @@ -42,27 +43,26 @@ def get_ppo_loss(fwd_in, fwd_out): return loss +@ExperimentalAPI @dataclass -class PPOModuleConfig(RLModuleConfig): - """Configuration for the PPO module. +class PPOModuleConfig(RLModuleConfig): # TODO (Artur): Move to non-torch-specific file + """Configuration for the PPORLModule. Attributes: - pi_config: The configuration for the policy network. - vf_config: The configuration for the value network. - shared_encoder_config: The configuration for the encoder network. + observation_space: The observation space of the environment. + action_space: The action space of the environment. + encoder_config: The configuration for the encoder network. + pi_config: The configuration for the policy head. + vf_config: The configuration for the value function head. free_log_std: For DiagGaussian action distributions, make the second half of the model outputs floating bias variables instead of state-dependent. This only has an effect is using the default fully connected net. - shared_encoder: Whether to share the encoder between the pi and value """ - pi_config: FCConfig = None - vf_config: FCConfig = None - pi_encoder_config: FCConfig = None - vf_encoder_config: FCConfig = None - shared_encoder_config: FCConfig = None + encoder_config: MLPConfig = None + pi_config: MLPConfig = None + vf_config: MLPConfig = None free_log_std: bool = False - shared_encoder: bool = True class PPOTorchRLModule(TorchRLModule): @@ -72,27 +72,14 @@ def __init__(self, config: PPOModuleConfig) -> None: self.setup() def setup(self) -> None: - assert self.config.pi_config, "pi_config must be provided." assert self.config.vf_config, "vf_config must be provided." + assert self.config.encoder_config, "shared encoder config must be " "provided." 
- self.shared_encoder = self.config.shared_encoder_config.build() - self.pi_encoder = self.config.pi_encoder_config.build() - self.vf_encoder = self.config.vf_encoder_config.build() - - self.pi = FCNet( - input_dim=self.config.pi_encoder_config.output_dim, - output_dim=self.config.pi_config.output_dim, - hidden_layers=self.config.pi_config.hidden_layers, - activation=self.config.pi_config.activation, - ) - - self.vf = FCNet( - input_dim=self.config.vf_encoder_config.output_dim, - output_dim=1, - hidden_layers=self.config.vf_config.hidden_layers, - activation=self.config.vf_config.activation, - ) + # TODO(Artur): Unify to tf and torch setup with Catalog + self.encoder = self.config.encoder_config.build(framework="torch") + self.pi = self.config.pi_config.build(framework="torch") + self.vf = self.config.vf_config.build(framework="torch") self._is_discrete = isinstance( convert_old_gym_space_to_gymnasium_space(self.config.action_space), @@ -123,44 +110,38 @@ def from_model_config( obs_dim = observation_space.shape[0] fcnet_hiddens = model_config["fcnet_hiddens"] - vf_share_layers = model_config["vf_share_layers"] free_log_std = model_config["free_log_std"] - use_lstm = model_config["use_lstm"] + assert ( + model_config.get("vf_share_layers") is False + ), "`vf_share_layers=False` is no longer supported." - if vf_share_layers: - shared_encoder_config = FCConfig( + if model_config["use_lstm"]: + encoder_config = LSTMEncoderConfig( input_dim=obs_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], - ) - else: - shared_encoder_config = IdentityConfig(output_dim=obs_dim) - - if use_lstm: - pi_encoder_config = LSTMConfig( - input_dim=shared_encoder_config.output_dim, hidden_dim=model_config["lstm_cell_size"], batch_first=not model_config["_time_major"], - output_dim=model_config["lstm_cell_size"], num_layers=1, + output_dim=model_config["lstm_cell_size"], ) else: - pi_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], + encoder_config = MLPEncoderConfig( + input_dim=obs_dim, + hidden_layer_dims=fcnet_hiddens[:-1], + hidden_layer_activation=activation, + output_dim=fcnet_hiddens[-1], ) - vf_encoder_config = FCConfig( - input_dim=shared_encoder_config.output_dim, - hidden_layers=fcnet_hiddens, - activation=activation, - output_dim=model_config["fcnet_hiddens"][-1], + pi_config = MLPConfig( + input_dim=encoder_config.output_dim, + hidden_layer_dims=[32], + hidden_layer_activation="ReLU", + ) + vf_config = MLPConfig( + input_dim=encoder_config.output_dim, + hidden_layer_dims=[32, 1], + hidden_layer_activation="ReLU", + output_dim=1, ) - pi_config = FCConfig() - vf_config = FCConfig() assert isinstance( observation_space, gym.spaces.Box @@ -174,41 +155,29 @@ def from_model_config( "This simple PPOModule only supports Discrete and Box action space.", ) - # build pi network - shared_encoder_config.input_dim = observation_space.shape[0] - pi_encoder_config.input_dim = shared_encoder_config.output_dim - pi_config.input_dim = pi_encoder_config.output_dim + # build policy network head + encoder_config.input_dim = observation_space.shape[0] + pi_config.input_dim = encoder_config.output_dim if isinstance(action_space, gym.spaces.Discrete): pi_config.output_dim = action_space.n else: pi_config.output_dim = action_space.shape[0] * 2 - # build vf network - vf_encoder_config.input_dim = shared_encoder_config.output_dim - 
vf_config.input_dim = vf_encoder_config.output_dim - vf_config.output_dim = 1 - config_ = PPOModuleConfig( observation_space=observation_space, action_space=action_space, - max_seq_len=model_config["max_seq_len"], - shared_encoder_config=shared_encoder_config, + encoder_config=encoder_config, pi_config=pi_config, vf_config=vf_config, - pi_encoder_config=pi_encoder_config, - vf_encoder_config=vf_encoder_config, free_log_std=free_log_std, - shared_encoder=vf_share_layers, ) module = PPOTorchRLModule(config_) return module def get_initial_state(self) -> NestedDict: - if isinstance(self.shared_encoder, LSTMEncoder): - return self.shared_encoder.get_initial_state() - elif isinstance(self.pi_encoder, LSTMEncoder): - return self.pi_encoder.get_initial_state() + if hasattr(self.encoder, "get_initial_state"): + return self.encoder.get_initial_state() else: return NestedDict({}) @@ -222,24 +191,26 @@ def output_specs_inference(self) -> SpecDict: @override(RLModule) def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]: - shared_enc_out = self.shared_encoder(batch) - pi_enc_out = self.pi_encoder(shared_enc_out) + output = {} - action_logits = self.pi(pi_enc_out[ENCODER_OUT]) + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + # Actions + action_logits = self.pi(encoder_out[ENCODER_OUT]) if self._is_discrete: action = torch.argmax(action_logits, dim=-1) else: action, _ = action_logits.chunk(2, dim=-1) - action_dist = TorchDeterministic(action) - output = {SampleBatch.ACTION_DIST: action_dist} - output["state_out"] = pi_enc_out.get("state_out", {}) + output[SampleBatch.ACTION_DIST] = action_dist + return output @override(RLModule) def input_specs_exploration(self): - return self.shared_encoder.input_spec() + return self.encoder.input_spec @override(RLModule) def output_specs_exploration(self) -> SpecDict: @@ -264,12 +235,20 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: policy distribution to be used for computing KL divergence between the old policy and the new policy during training. 
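As a quick numeric illustration of the inference path above (plain torch, independent of RLlib): discrete actions come from an argmax over the logits, while for continuous actions the pi head emits 2 * action_dim outputs and the first half is used as the deterministic action.

import torch

# Discrete case: greedy action is the argmax over the logits.
logits = torch.randn(2, 4)                    # batch of 2, Discrete(4)
greedy_action = torch.argmax(logits, dim=-1)  # shape (2,)

# Continuous case: first half of the outputs is the action mean,
# the second half would parameterize the scale.
logits = torch.randn(2, 6)                    # Box action space with dim 3
action, _ = logits.chunk(2, dim=-1)           # shape (2, 3)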
""" - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) - output = {} + + # Shared encoder + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] + + # Value head + vf_out = self.vf(encoder_out[ENCODER_OUT]) + output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) + + # Policy head + pi_out = self.pi(encoder_out[ENCODER_OUT]) + action_logits = pi_out if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) output[SampleBatch.ACTION_DIST_INPUTS] = {"logits": action_logits} @@ -280,9 +259,6 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output[SampleBatch.ACTION_DIST_INPUTS] = {"loc": loc, "scale": scale} output[SampleBatch.ACTION_DIST] = action_dist - # compute the value function - output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf[ENCODER_OUT]).squeeze(-1) - output["state_out"] = encoder_out_pi.get("state_out", {}) return output @override(RLModule) @@ -293,7 +269,7 @@ def input_specs_train(self) -> SpecDict: action_dim = self.config.action_space.shape[0] action_spec = TorchTensorSpec("b, h", h=action_dim) - spec_dict = self.shared_encoder.input_spec() + spec_dict = self.encoder.input_spec spec_dict.update({SampleBatch.ACTIONS: action_spec}) if SampleBatch.OBS in spec_dict: spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS] @@ -314,30 +290,31 @@ def output_specs_train(self) -> SpecDict: @override(RLModule) def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) + output = {} + + # Shared encoder + encoder_out = self.encoder(batch) + if STATE_OUT in encoder_out: + output[STATE_OUT] = encoder_out[STATE_OUT] - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) - vf = self.vf(encoder_out_vf[ENCODER_OUT]) + # Value head + vf_out = self.vf(encoder_out[ENCODER_OUT]) + output[SampleBatch.VF_PREDS] = vf_out.squeeze(-1) + # Policy head + pi_out = self.pi(encoder_out[ENCODER_OUT]) + action_logits = pi_out if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) else: mu, scale = action_logits.chunk(2, dim=-1) action_dist = TorchDiagGaussian(mu, scale.exp()) - logp = action_dist.logp(batch[SampleBatch.ACTIONS]) entropy = action_dist.entropy() + output[SampleBatch.ACTION_DIST] = action_dist + output[SampleBatch.ACTION_LOGP] = logp + output["entropy"] = entropy - output = { - SampleBatch.ACTION_DIST: action_dist, - SampleBatch.ACTION_LOGP: logp, - SampleBatch.VF_PREDS: vf.squeeze(-1), - "entropy": entropy, - } - - output["state_out"] = encoder_out_pi.get("state_out", {}) return output def __get_action_dist_type(self): diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py deleted file mode 100644 index f3bb22b46900..000000000000 --- a/rllib/core/rl_module/encoder.py +++ /dev/null @@ -1,202 +0,0 @@ -import torch -import torch.nn as nn -import tree -from typing import List - -from dataclasses import dataclass, field - -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.rnn_sequencing import add_time_dimension -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.specs.checker import check_input_specs, check_output_specs -from ray.rllib.models.specs.specs_torch import TorchTensorSpec -from 
ray.rllib.models.torch.primitives import FCNet - -# TODO (Kourosh): Find a better / more straight fwd approach for sub-components - -ENCODER_OUT = "encoder_out" -STATE_IN = "state_in" -STATE_OUT = "state_out" - - -@dataclass -class EncoderConfig: - """Configuration for an encoder network. - - Attributes: - output_dim: The output dimension of the network. if None, the last layer would - be the last hidden layer. - """ - - output_dim: int = None - - -@dataclass -class IdentityConfig(EncoderConfig): - """Configuration for an identity encoder.""" - - def build(self): - return IdentityEncoder(self) - - -@dataclass -class FCConfig(EncoderConfig): - """Configuration for a fully connected network. - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. - """ - - input_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - - def build(self): - return FullyConnectedEncoder(self) - - -@dataclass -class LSTMConfig(EncoderConfig): - input_dim: int = None - hidden_dim: int = None - num_layers: int = None - batch_first: bool = True - - def build(self): - return LSTMEncoder(self) - - -class Encoder(nn.Module): - def __init__(self, config: EncoderConfig) -> None: - super().__init__() - self.config = config - self._input_spec = self.input_spec() - self._output_spec = self.output_spec() - - def get_initial_state(self): - return [] - - def input_spec(self): - return SpecDict() - - def output_spec(self): - return SpecDict() - - @check_input_specs("_input_spec") - @check_output_specs("_output_spec") - def forward(self, input_dict): - return self._forward(input_dict) - - def _forward(self, input_dict): - raise NotImplementedError - - -class FullyConnectedEncoder(Encoder): - def __init__(self, config: FCConfig) -> None: - super().__init__(config) - - self.net = FCNet( - input_dim=config.input_dim, - hidden_layers=config.hidden_layers, - output_dim=config.output_dim, - activation=config.activation, - ) - - def input_spec(self): - return SpecDict( - {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} - ) - - def output_spec(self): - return SpecDict( - {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} - ) - - def _forward(self, input_dict): - return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])} - - -class LSTMEncoder(Encoder): - def __init__(self, config: LSTMConfig) -> None: - super().__init__(config) - - self.lstm = nn.LSTM( - config.input_dim, - config.hidden_dim, - config.num_layers, - batch_first=config.batch_first, - ) - self.linear = nn.Linear(config.hidden_dim, config.output_dim) - - def get_initial_state(self): - config = self.config - return { - "h": torch.zeros(config.num_layers, config.hidden_dim), - "c": torch.zeros(config.num_layers, config.hidden_dim), - } - - def input_spec(self): - config = self.config - return SpecDict( - { - # bxt is just a name for better readability to indicated padded batch - SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim), - STATE_IN: { - "h": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - "c": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - }, - } - ) - - def output_spec(self): - config = self.config - return SpecDict( - { - ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim), 
- STATE_OUT: { - "h": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - "c": TorchTensorSpec( - "b, l, h", h=config.hidden_dim, l=config.num_layers - ), - }, - } - ) - - def _forward(self, input_dict: SampleBatch): - x = input_dict[SampleBatch.OBS] - states = input_dict[STATE_IN] - # states are batch-first when coming in - states = tree.map_structure(lambda x: x.transpose(0, 1), states) - - x = add_time_dimension( - x, - seq_lens=input_dict[SampleBatch.SEQ_LENS], - framework="torch", - time_major=not self.config.batch_first, - ) - states_o = {} - x, (states_o["h"], states_o["c"]) = self.lstm(x, (states["h"], states["c"])) - - x = self.linear(x) - x = x.view(-1, x.shape[-1]) - - return { - ENCODER_OUT: x, - STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), - } - - -class IdentityEncoder(Encoder): - def __init__(self, config: EncoderConfig) -> None: - super().__init__(config) - - def _forward(self, input_dict): - return input_dict diff --git a/rllib/core/rl_module/encoder_tf.py b/rllib/core/rl_module/encoder_tf.py deleted file mode 100644 index 5c517f6c745d..000000000000 --- a/rllib/core/rl_module/encoder_tf.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass, field -from typing import List - -from ray.rllib.core.rl_module.encoder import EncoderConfig -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.models.tf.primitives import FCNet, IdentityNetwork - -tf1, tf, tfv = try_import_tf() - - -@dataclass -class FCTfConfig(EncoderConfig): - """Configuration for a fully connected network. - input_dim: The input dimension of the network. It cannot be None. - hidden_layers: The sizes of the hidden layers. - activation: The activation function to use after each layer (except for the - output). - output_activation: The activation function to use for the output layer. - """ - - input_dim: int = None - output_dim: int = None - hidden_layers: List[int] = field(default_factory=lambda: [256, 256]) - activation: str = "ReLU" - - def build(self): - return FCNet( - self.input_dim, self.hidden_layers, self.output_dim, self.activation - ) - - -@dataclass -class IdentityTfConfig(EncoderConfig): - """A network that returns the input as the output.""" - - def build(self): - return IdentityNetwork() diff --git a/rllib/models/configs/encoder.py b/rllib/models/configs/encoder.py deleted file mode 100644 index 38a7f305123a..000000000000 --- a/rllib/models/configs/encoder.py +++ /dev/null @@ -1,83 +0,0 @@ -import abc -from dataclasses import dataclass -from typing import TYPE_CHECKING, Tuple - -from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.torch.encoders.vector import TorchVectorEncoder - -if TYPE_CHECKING: - from ray.rllib.models.torch.encoders.vector import Encoder - - -@dataclass -class EncoderConfig: - """The base config for encoder models. - - Each config should define a `build` method that builds a model from the config. - - All user-configurable parameters known before runtime - (e.g. framework, activation, num layers, etc.) should be defined as attributes. - - Parameters unknown before runtime (e.g. the output size of the module providing - input for this module) should be passed as arguments to `build`. This should be - as few params as possible. - - `build` should return an instance of the encoder associated with the config. - - Attributes: - framework_str: The tensor framework to construct a model for. - This can be 'torch', 'tf2', or 'jax'. 
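The state transposes in the (removed and newly added) LSTM encoders exist because the module stores recurrent states batch-first, (B, num_layers, hidden), while torch.nn.LSTM expects (num_layers, B, hidden). A standalone sketch of just that shape contract, with made-up dimensions:

import torch
import torch.nn as nn

B, T, obs_dim, hidden, num_layers = 4, 5, 8, 32, 1
lstm = nn.LSTM(obs_dim, hidden, num_layers, batch_first=True)

x = torch.randn(B, T, obs_dim)
# Batch-first states as stored by the module ...
h = torch.zeros(B, num_layers, hidden)
c = torch.zeros(B, num_layers, hidden)
# ... transposed to (num_layers, B, hidden) before being fed to the LSTM.
out, (h_n, c_n) = lstm(x, (h.transpose(0, 1), c.transpose(0, 1)))
print(out.shape)  # torch.Size([4, 5, 32])
print(h_n.shape)  # torch.Size([1, 4, 32])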
- """ - - framework_str: str = "torch" - - @abc.abstractmethod - def build(self, input_spec: SpecDict, **kwargs) -> "Encoder": - """Builds the EncoderConfig into an Encoder instance""" - - -@dataclass -class VectorEncoderConfig(EncoderConfig): - """An MLP encoder mappings tensors with shape [..., feature] to [..., output]. - - Attributes: - activation: The type of activation function to use between hidden layers. - Options are 'relu', 'swish', 'tanh', or 'linear' - final_activation: The activation function to use after the final linear layer. - Options are the same as for activation. - hidden_layer_sizes: A list, where each element represents the number of neurons - in that layer. For example, [128, 64] would produce a two-layer MLP with - 128 hidden neurons and 64 hidden neurons. - output_key: Write the output of the encoder to this key in the NestedDict. - """ - - activation: str = "relu" - final_activation: str = "linear" - hidden_layer_sizes: Tuple[int, ...] = (128, 128) - output_key: str = "encoding" - - def build(self, input_spec: SpecDict) -> TorchVectorEncoder: - """Build the config into a VectorEncoder model instance. - - Args: - input_spec: The output spec of the previous module(s) that will feed - inputs to this encoder. - - Returns: - A VectorEncoder of the specified framework. - """ - assert ( - len(self.hidden_layer_sizes) > 1 - ), "Must have at least a single hidden layer" - for k in input_spec.shallow_keys(): - assert isinstance( - input_spec[k].shape[-1], int - ), "Input spec {k} does not define the size of the feature (last) dimension" - - if self.framework_str == "torch": - return TorchVectorEncoder(input_spec, self) - else: - raise NotImplementedError( - "{self.__class__.__name__} not implemented" - " for framework {self.framework}" - ) diff --git a/rllib/models/experimental/README.rst b/rllib/models/experimental/README.rst new file mode 100644 index 000000000000..2ef2007403e2 --- /dev/null +++ b/rllib/models/experimental/README.rst @@ -0,0 +1,2 @@ +This folder holds models that are under development and to be used with RLModules in upcoming versions of RLlib. +They are not yet ready for use in the current version of RLlib. \ No newline at end of file diff --git a/rllib/models/experimental/__init__.py b/rllib/models/experimental/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/models/experimental/base.py b/rllib/models/experimental/base.py new file mode 100644 index 000000000000..bcf822770e97 --- /dev/null +++ b/rllib/models/experimental/base.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +import abc + +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.utils.annotations import ExperimentalAPI + +ForwardOutputType = TensorDict + + +@ExperimentalAPI +@dataclass +class ModelConfig(abc.ABC): + """Base class for model configurations. + + Attributes: + output_dim: The output dimension of the network. + """ + + output_dim: int = None + + @abc.abstractmethod + def build(self, framework: str = "torch"): + """Builds the model. + + Args: + framework: The framework to use for building the model. + """ + raise NotImplementedError + + +class Model: + """Framework-agnostic base class for RLlib models. + + Models are low-level neural network components that offer input- and + output-specification, a forward method, and a get_initial_state method. Models + are composed in RLModules. 
+ """ + + def __init__(self, config: ModelConfig): + self.config = config + + @abc.abstractmethod + def get_initial_state(self): + """Returns the initial state of the model.""" + return {} + + @property + @abc.abstractmethod + def output_spec(self) -> SpecDict: + """Returns the outputs spec of this model. + + This can include the state specs as well. + + Examples: + >>> ... + """ + # If no checking is needed, we can simply return an empty spec. + return SpecDict() + + @property + @abc.abstractmethod + def input_spec(self) -> SpecDict: + """Returns the input spec of this model. + + This can include the state specs as well. + + Examples: + >>> ... + """ + # If no checking is needed, we can simply return an empty spec. + return SpecDict() diff --git a/rllib/models/experimental/configs.py b/rllib/models/experimental/configs.py new file mode 100644 index 000000000000..6a6c84772063 --- /dev/null +++ b/rllib/models/experimental/configs.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass, field +from typing import List, Callable +import functools + +from ray.rllib.models.experimental.base import ModelConfig, Model +from ray.rllib.models.experimental.encoder import Encoder +from ray.rllib.utils.annotations import DeveloperAPI + + +@DeveloperAPI +def _framework_implemented(torch: bool = True, tf2: bool = True): + """Decorator to check if a model was implemented in a framework. + + Args: + torch: Whether we can build this model with torch. + tf2: Whether we can build this model with tf2. + + Returns: + The decorated function. + + Raises: + ValueError: If the framework is not available to build. + """ + accepted = [] + if torch: + accepted.append("torch") + if tf2: + accepted.append("tf") + accepted.append("tf2") + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + def checked_build(self, framework, **kwargs): + if framework not in accepted: + raise ValueError(f"Framework {framework} not supported.") + return fn(self, framework, **kwargs) + + return checked_build + + return decorator + + +@dataclass +class MLPConfig(ModelConfig): + """Configuration for a fully connected network. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layer_dims: The sizes of the hidden layers. + hidden_layer_activation: The activation function to use after each layer ( + except for the output). + output_activation: The activation function to use for the output layer. 
+ """ + + input_dim: int = None + hidden_layer_dims: List[int] = field(default_factory=lambda: [256, 256]) + hidden_layer_activation: str = "ReLU" + output_activation: str = "linear" + + @_framework_implemented() + def build(self, framework: str = "torch") -> Model: + if framework == "torch": + from ray.rllib.models.experimental.torch.mlp import TorchMLPModel + + return TorchMLPModel(self) + else: + from ray.rllib.models.experimental.tf.mlp import TfMLPModel + + return TfMLPModel(self) + + +@dataclass +class MLPEncoderConfig(MLPConfig): + @_framework_implemented() + def build(self, framework: str = "torch") -> Encoder: + if framework == "torch": + from ray.rllib.models.experimental.torch.encoder import TorchMLPEncoder + + return TorchMLPEncoder(self) + else: + from ray.rllib.models.experimental.tf.encoder import TfMLPEncoder + + return TfMLPEncoder(self) + + +@dataclass +class LSTMEncoderConfig(ModelConfig): + input_dim: int = None + hidden_dim: int = None + num_layers: int = None + batch_first: bool = True + output_activation: str = "linear" + + @_framework_implemented(tf2=False) + def build(self, framework: str = "torch") -> Encoder: + if framework == "torch": + from ray.rllib.models.experimental.torch.encoder import TorchLSTMEncoder + + return TorchLSTMEncoder(self) + + +@dataclass +class IdentityConfig(ModelConfig): + """Configuration for an identity encoder.""" + + @_framework_implemented() + def build(self, framework: str = "torch") -> Model: + if framework == "torch": + from ray.rllib.models.experimental.torch.encoder import TorchIdentityEncoder + + return TorchIdentityEncoder(self) + else: + from ray.rllib.models.experimental.tf.encoder import TfIdentityEncoder + + return TfIdentityEncoder(self) diff --git a/rllib/models/experimental/encoder.py b/rllib/models/experimental/encoder.py new file mode 100644 index 000000000000..bf1e85d4ff8e --- /dev/null +++ b/rllib/models/experimental/encoder.py @@ -0,0 +1,57 @@ +import abc + +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.utils.typing import TensorType +from ray.rllib.models.experimental.base import Model, ForwardOutputType + +STATE_IN: str = "state_in" +STATE_OUT: str = "state_out" + + +class Encoder(Model): + """The framework-agnostic base class for all encoders RLlib produces. + + Encoders are used to encode observations into a latent space in RLModules. + Therefore, their input_spec contains the observation space dimensions. + Similarly, their output_spec usually the latent space dimensions. + Encoders can be recurrent, in which case they should also have state_specs. + + Encoders encode observations into a latent space that serve as input to heads. + Outputs of encoders are generally of shape (B, latent_dim) or (B, T, latent_dim). + That is, for time-series data, we encode into the latent space for each time step. + This should be reflected in the output_spec. + """ + + def get_initial_state(self) -> TensorType: + """Returns the initial state of the encoder. + + It can be left empty if this encoder is not stateful. + + Examples: + >>> encoder = Encoder(...) + >>> state = encoder.get_initial_state() + >>> out = encoder.forward({"obs": ..., STATE_IN: state}) + """ + return {} + + @check_input_specs("input_spec", cache=True) + @check_output_specs("output_spec", cache=True) + @abc.abstractmethod + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + """Computes the output of this module for each timestep. 
+ + Outputs and inputs are subjected to spec checking. + + Args: + inputs: A TensorDict containing model inputs + kwargs: For forwards compatibility + + Returns: + outputs: A TensorDict containing model outputs + + Examples: + # This is abstract, see the framework implementations + >>> out = encoder.forward({"obs": np.arange(10)})) + """ + raise NotImplementedError diff --git a/rllib/models/experimental/tf/__init__.py b/rllib/models/experimental/tf/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/models/experimental/tf/encoder.py b/rllib/models/experimental/tf/encoder.py new file mode 100644 index 000000000000..4ee2ea4433ea --- /dev/null +++ b/rllib/models/experimental/tf/encoder.py @@ -0,0 +1,159 @@ +import torch +import torch.nn as nn +import tree + +from ray.rllib.models.experimental.base import ( + ForwardOutputType, + ModelConfig, +) +from ray.rllib.models.experimental.encoder import ( + Encoder, + STATE_IN, + STATE_OUT, +) +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.experimental.tf.primitives import TfMLP +from ray.rllib.policy.rnn_sequencing import add_time_dimension +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_tf import TFTensorSpecs +from ray.rllib.models.experimental.torch.encoder import ENCODER_OUT +from ray.rllib.models.experimental.tf.primitives import TfModel + + +class TfMLPEncoder(Encoder, TfModel): + """A fully connected encoder.""" + + def __init__(self, config: ModelConfig) -> None: + Encoder.__init__(self, config) + TfModel.__init__(self, config) + + self.net = TfMLP( + input_dim=config.input_dim, + hidden_layer_dims=config.hidden_layer_dims, + output_dim=config.output_dim, + hidden_layer_activation=config.hidden_layer_activation, + ) + + @property + def input_spec(self): + return SpecDict( + {SampleBatch.OBS: TFTensorSpecs("b, h", h=self.config.input_dim)} + ) + + @property + def output_spec(self): + return SpecDict({ENCODER_OUT: TFTensorSpecs("b, h", h=self.config.output_dim)}) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} + + +class LSTMEncoder(Encoder, TfModel): + """An encoder that uses an LSTM cell and a linear layer.""" + + def __init__(self, config: ModelConfig) -> None: + Encoder.__init__(self, config) + TfModel.__init__(self, config) + + self.lstm = nn.LSTM( + config.input_dim, + config.hidden_dim, + config.num_layers, + batch_first=config.batch_first, + ) + self.linear = nn.Linear(config.hidden_dim, config.output_dim) + + def get_initial_state(self): + config = self.config + return { + "h": torch.zeros(config.num_layers, config.hidden_dim), + "c": torch.zeros(config.num_layers, config.hidden_dim), + } + + @property + def input_spec(self): + config = self.config + return SpecDict( + { + # bxt is just a name for better readability to indicated padded batch + SampleBatch.OBS: TFTensorSpecs("bxt, h", h=config.input_dim), + STATE_IN: { + "h": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + }, + SampleBatch.SEQ_LENS: None, + } + ) + + @property + def output_spec(self): + config = self.config + return SpecDict( + { + 
ENCODER_OUT: TFTensorSpecs("bxt, h", h=config.output_dim), + STATE_OUT: { + "h": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + "c": TFTensorSpecs( + "b, l, h", h=config.hidden_dim, l=config.num_layers + ), + }, + } + ) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + x = inputs[SampleBatch.OBS] + states = inputs[STATE_IN] + # states are batch-first when coming in + states = tree.map_structure(lambda x: x.transpose(0, 1), states) + + x = add_time_dimension( + x, + seq_lens=inputs[SampleBatch.SEQ_LENS], + framework="torch", + time_major=not self.config.batch_first, + ) + states_o = {} + x, (states_o["h"], states_o["c"]) = self.lstm(x, (states["h"], states["c"])) + + x = self.linear(x) + x = x.view(-1, x.shape[-1]) + + return { + ENCODER_OUT: x, + STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o), + } + + +class TfIdentityEncoder(TfModel): + """An encoder that does nothing but passing on inputs. + + We use this so that we avoid having many if/else statements in the RLModule. + """ + + @property + def input_spec(self): + return SpecDict( + # Use the output dim as input dim because identity. + {SampleBatch.OBS: TFTensorSpecs("b, h", h=self.config.output_dim)} + ) + + @property + def output_spec(self): + return SpecDict({ENCODER_OUT: TFTensorSpecs("b, h", h=self.config.output_dim)}) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return {ENCODER_OUT: inputs[SampleBatch.OBS]} diff --git a/rllib/models/experimental/tf/mlp.py b/rllib/models/experimental/tf/mlp.py new file mode 100644 index 000000000000..beebabd02c68 --- /dev/null +++ b/rllib/models/experimental/tf/mlp.py @@ -0,0 +1,34 @@ +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_tf import TFTensorSpecs +from ray.rllib.utils import try_import_tf +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.models.experimental.tf.primitives import TfMLP, TfModel +from ray.rllib.models.experimental.base import ModelConfig, ForwardOutputType + +tf1, tf, tfv = try_import_tf() + + +class TfMLPModel(TfModel): + def __init__(self, config: ModelConfig) -> None: + TfModel.__init__(self, config) + + self.net = TfMLP( + input_dim=config.input_dim, + hidden_layer_dims=config.hidden_layer_dims, + output_dim=config.output_dim, + hidden_layer_activation=config.hidden_layer_activation, + output_activation=config.output_activation, + ) + + @property + def input_spec(self): + return TFTensorSpecs("b, h", h=self.config.input_dim) + + @property + def output_spec(self): + return TFTensorSpecs("b, h", h=self.config.output_dim) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) + def __call__(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return self.net(inputs) diff --git a/rllib/models/experimental/tf/primitives.py b/rllib/models/experimental/tf/primitives.py new file mode 100644 index 000000000000..a7c45c11afe0 --- /dev/null +++ b/rllib/models/experimental/tf/primitives.py @@ -0,0 +1,107 @@ +from typing import List +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.models.specs.checker import ( + is_input_decorated, + is_output_decorated, +) +from ray.rllib.models.temp_spec_classes import TensorDict +from 
ray.rllib.models.experimental.base import Model +from ray.rllib.utils.typing import TensorType +from ray.rllib.models.utils import get_activation_fn +from typing import Tuple +from ray.rllib.models.specs.checker import ( + check_input_specs, + check_output_specs, +) + +_, tf, _ = try_import_tf() + + +def _call_not_decorated(input_or_output): + return ( + f"forward not decorated with {input_or_output} specification. Decorate " + f"with @check_{input_or_output}_specs() to define a specification. See " + f"BaseModel for examples." + ) + + +class TfModel(Model, tf.Module): + """Base class for RLlib models. + + This class is used to define the general interface for RLlib models and checks + whether inputs and outputs are checked with `check_input_specs()` and + `check_output_specs()` respectively. + """ + + def __init__(self, config): + super().__init__(config) + # automatically apply spec checking + if not is_input_decorated(self.__call__): + self.__call__ = check_input_specs("input_spec", cache=True)(self.__call__) + if not is_output_decorated(self.__call__): + self.__call__ = check_output_specs("output_spec", cache=True)(self.__call__) + + @check_input_specs("input_spec", cache=True) + @check_output_specs("output_spec", cache=True) + def __call__(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]: + """Returns the output of this model for the given input. + + Args: + input_dict: The input tensors. + + Returns: + Tuple[TensorDict, List[TensorType]]: The output tensors. + """ + raise NotImplementedError + + +class TfMLP(tf.Module): + """A multi-layer perceptron. + + Attributes: + input_dim: The input dimension of the network. It cannot be None. + hidden_layer_dims: The sizes of the hidden layers. + output_dim: The output dimension of the network. + hidden_layer_activation: The activation function to use after each layer. + Currently "Linear" (no activation) and "ReLU" are supported. + output_activation: The activation function to use for the output layer. 
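For readers less familiar with Keras: the TfMLP described here amounts to a plain Sequential stack. A standalone sketch (not the RLlib class itself) of roughly what TfMLP(input_dim=4, hidden_layer_dims=[32, 32], output_dim=2, hidden_layer_activation="ReLU") ends up building:

import tensorflow as tf

net = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dense(2),  # linear output layer
    ]
)
print(net(tf.random.normal((5, 4))).shape)  # (5, 2)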
+ """ + + def __init__( + self, + input_dim: int, + hidden_layer_dims: List[int], + output_dim: int, + hidden_layer_activation: str = "linear", + output_activation: str = "linear", + ): + super().__init__() + + assert hidden_layer_activation in ("linear", "ReLU", "Tanh"), ( + "Activation function not " "supported" + ) + assert input_dim is not None, "Input dimension must not be None" + assert output_dim is not None, "Output dimension must not be None" + layers = [] + hidden_layer_activation = hidden_layer_activation.lower() + # input = tf.keras.layers.Dense(input_dim, activation=activation) + layers.append(tf.keras.Input(shape=(input_dim,))) + for i in range(len(hidden_layer_dims)): + layers.append( + tf.keras.layers.Dense( + hidden_layer_dims[i], activation=hidden_layer_activation + ) + ) + if output_activation != "linear": + output_activation = get_activation_fn(output_activation, framework="torch") + final_layer = tf.keras.layers.Dense( + output_dim, activation=output_activation + ) + else: + final_layer = tf.keras.layers.Dense(output_dim) + + layers.append(final_layer) + self.network = tf.keras.Sequential(layers) + + def __call__(self, inputs): + return self.network(inputs) diff --git a/rllib/models/experimental/torch/__init__.py b/rllib/models/experimental/torch/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rllib/models/experimental/torch/encoder.py b/rllib/models/experimental/torch/encoder.py new file mode 100644 index 000000000000..cb87f9b8dc78 --- /dev/null +++ b/rllib/models/experimental/torch/encoder.py @@ -0,0 +1,167 @@ +import torch +import torch.nn as nn +import tree + +from ray.rllib.models.experimental.base import ( + ForwardOutputType, + ModelConfig, +) +from ray.rllib.models.experimental.encoder import ( + Encoder, + STATE_IN, + STATE_OUT, +) +from ray.rllib.models.temp_spec_classes import TensorDict +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.policy.rnn_sequencing import add_time_dimension +from ray.rllib.models.specs.specs_dict import SpecDict +from ray.rllib.models.specs.checker import check_input_specs, check_output_specs +from ray.rllib.models.specs.specs_torch import TorchTensorSpec +from ray.rllib.models.experimental.torch.primitives import TorchMLP, TorchModel + +ENCODER_OUT: str = "encoder_out" + + +class TorchMLPEncoder(TorchModel, Encoder): + """A fully connected encoder.""" + + def __init__(self, config: ModelConfig) -> None: + TorchModel.__init__(self, config) + Encoder.__init__(self, config) + + self.net = TorchMLP( + input_dim=config.input_dim, + hidden_layer_dims=config.hidden_layer_dims, + output_dim=config.output_dim, + hidden_layer_activation=config.hidden_layer_activation, + ) + + @property + @override(TorchModel) + def input_spec(self) -> SpecDict: + return SpecDict( + {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)} + ) + + @property + @override(TorchModel) + def output_spec(self) -> SpecDict: + return SpecDict( + {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)} + ) + + @check_input_specs("input_spec", cache=False) + @check_output_specs("output_spec", cache=False) + def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType: + return {ENCODER_OUT: self.net(inputs[SampleBatch.OBS])} + + +class TorchLSTMEncoder(TorchModel, Encoder): + """An encoder that uses an LSTM cell and a linear layer.""" + + def __init__(self, config: ModelConfig) -> None: + TorchModel.__init__(self, config) + + self.lstm = 
+            config.input_dim,
+            config.hidden_dim,
+            config.num_layers,
+            batch_first=config.batch_first,
+        )
+        self.linear = nn.Linear(config.hidden_dim, config.output_dim)
+
+    def get_initial_state(self):
+        config = self.config
+        return {
+            "h": torch.zeros(config.num_layers, config.hidden_dim),
+            "c": torch.zeros(config.num_layers, config.hidden_dim),
+        }
+
+    @property
+    @override(TorchModel)
+    def input_spec(self) -> SpecDict:
+        config = self.config
+        return SpecDict(
+            {
+                # "bxt" is just a name for better readability, indicating a
+                # padded batch (batch and time folded into one dimension).
+                SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim),
+                STATE_IN: {
+                    "h": TorchTensorSpec(
+                        "b, l, h", h=config.hidden_dim, l=config.num_layers
+                    ),
+                    "c": TorchTensorSpec(
+                        "b, l, h", h=config.hidden_dim, l=config.num_layers
+                    ),
+                },
+                SampleBatch.SEQ_LENS: None,
+            }
+        )
+
+    @property
+    @override(TorchModel)
+    def output_spec(self) -> SpecDict:
+        config = self.config
+        return SpecDict(
+            {
+                ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim),
+                STATE_OUT: {
+                    "h": TorchTensorSpec(
+                        "b, l, h", h=config.hidden_dim, l=config.num_layers
+                    ),
+                    "c": TorchTensorSpec(
+                        "b, l, h", h=config.hidden_dim, l=config.num_layers
+                    ),
+                },
+            }
+        )
+
+    @check_input_specs("input_spec", filter=True, cache=False)
+    @check_output_specs("output_spec", cache=False)
+    def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType:
+        x = inputs[SampleBatch.OBS]
+        states = inputs[STATE_IN]
+        # States come in batch-first; transpose to layer-first for nn.LSTM.
+        states = tree.map_structure(lambda x: x.transpose(0, 1), states)
+
+        x = add_time_dimension(
+            x,
+            seq_lens=inputs[SampleBatch.SEQ_LENS],
+            framework="torch",
+            time_major=not self.config.batch_first,
+        )
+        states_o = {}
+        x, (states_o["h"], states_o["c"]) = self.lstm(x, (states["h"], states["c"]))
+
+        x = self.linear(x)
+        # Fold the time dimension back into the batch dimension.
+        x = x.view(-1, x.shape[-1])
+
+        return {
+            ENCODER_OUT: x,
+            STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o),
+        }
+
+
+class TorchIdentityEncoder(TorchModel):
+    """An encoder that passes its inputs through unchanged.
+
+    We use this to avoid having many if/else statements in the RLModule.
+    """
+
+    @property
+    def input_spec(self) -> SpecDict:
+        return SpecDict(
+            # Use the output dim as input dim, because this encoder is the identity.
+            {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.output_dim)}
+        )
+
+    @property
+    def output_spec(self) -> SpecDict:
+        return SpecDict(
+            {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)}
+        )
+
+    @check_input_specs("input_spec", cache=False)
+    @check_output_specs("output_spec", cache=False)
+    def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType:
+        return {ENCODER_OUT: inputs[SampleBatch.OBS]}
diff --git a/rllib/models/experimental/torch/mlp.py b/rllib/models/experimental/torch/mlp.py
new file mode 100644
index 000000000000..4d5afa11d502
--- /dev/null
+++ b/rllib/models/experimental/torch/mlp.py
@@ -0,0 +1,39 @@
+import torch.nn as nn
+
+from ray.rllib.models.experimental.base import ForwardOutputType, Model, ModelConfig
+from ray.rllib.models.specs.checker import check_input_specs, check_output_specs
+from ray.rllib.models.specs.specs_torch import TorchTensorSpec
+from ray.rllib.models.temp_spec_classes import TensorDict
+from ray.rllib.models.experimental.torch.primitives import TorchMLP, TorchModel
+from ray.rllib.utils.annotations import override
+
+
+class TorchMLPModel(TorchModel, nn.Module):
+    """An MLP model that wraps the TorchMLP primitive with input/output specs."""
+
+    def __init__(self, config: ModelConfig) -> None:
+        nn.Module.__init__(self)
+        TorchModel.__init__(self, config)
+
+        self.net = TorchMLP(
+            input_dim=config.input_dim,
+            hidden_layer_dims=config.hidden_layer_dims,
+            output_dim=config.output_dim,
+            hidden_layer_activation=config.hidden_layer_activation,
+            output_activation=config.output_activation,
+        )
+
+    @property
+    @override(Model)
+    def input_spec(self) -> TorchTensorSpec:
+        return TorchTensorSpec("b, h", h=self.config.input_dim)
+
+    @property
+    @override(Model)
+    def output_spec(self) -> TorchTensorSpec:
+        return TorchTensorSpec("b, h", h=self.config.output_dim)
+
+    @check_input_specs("input_spec", cache=False)
+    @check_output_specs("output_spec", cache=False)
+    @override(TorchModel)
+    def forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType:
+        return self.net(inputs)
diff --git a/rllib/models/experimental/torch/primitives.py b/rllib/models/experimental/torch/primitives.py
new file mode 100644
index 000000000000..11156380133d
--- /dev/null
+++ b/rllib/models/experimental/torch/primitives.py
@@ -0,0 +1,99 @@
+from typing import List, Optional, Tuple
+
+from ray.rllib.models.experimental.base import Model, ModelConfig
+from ray.rllib.models.specs.checker import (
+    check_input_specs,
+    check_output_specs,
+    is_input_decorated,
+    is_output_decorated,
+)
+from ray.rllib.models.temp_spec_classes import TensorDict
+from ray.rllib.models.utils import get_activation_fn
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.typing import TensorType
+
+torch, nn = try_import_torch()
+
+
+class TorchModel(nn.Module, Model):
+    """Base class for torch models.
+
+    This class defines the general interface for RLlib torch models and ensures
+    that inputs and outputs are validated with `check_input_specs()` and
+    `check_output_specs()`, respectively.
+    """
+
+    def __init__(self, config: ModelConfig):
+        nn.Module.__init__(self)
+        Model.__init__(self, config)
+        # Automatically apply spec checking if the subclass did not decorate
+        # its forward() itself.
+        if not is_input_decorated(self.forward):
+            self.forward = check_input_specs("input_spec", cache=True)(self.forward)
+        if not is_output_decorated(self.forward):
+            self.forward = check_output_specs("output_spec", cache=True)(self.forward)
+
+    @check_input_specs("input_spec", cache=True)
+    @check_output_specs("output_spec", cache=True)
+    def forward(self, input_dict: TensorDict) -> Tuple[TensorDict, List[TensorType]]:
+        """Returns the output of this model for the given input.
+
+        Args:
+            input_dict: The input tensors.
+
+        Returns:
+            Tuple[TensorDict, List[TensorType]]: The output tensors.
+        """
+        raise NotImplementedError
+
+
+class TorchMLP(nn.Module):
+    """A multi-layer perceptron.
+
+    Attributes:
+        input_dim: The input dimension of the network. It cannot be None.
+        hidden_layer_dims: The sizes of the hidden layers.
+        output_dim: The output dimension of the network. If None, the output_dim
+            will be the number of nodes in the last hidden layer.
+        hidden_layer_activation: The activation function to use after each hidden
+            layer. Must be the name of a `torch.nn` activation class (e.g. "ReLU"
+            or "Tanh"), or "linear" for no activation.
+        output_activation: The activation function to use for the output layer.
+    """
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_layer_dims: List[int],
+        output_dim: Optional[int] = None,
+        hidden_layer_activation: str = "linear",
+        output_activation: str = "linear",
+    ):
+        super().__init__()
+        self.input_dim = input_dim
+
+        activation_class = getattr(nn, hidden_layer_activation, lambda: None)()
+        layers = []
+        layers.append(nn.Linear(input_dim, hidden_layer_dims[0]))
+        for i in range(len(hidden_layer_dims) - 1):
+            if hidden_layer_activation != "linear":
+                layers.append(activation_class)
+            layers.append(nn.Linear(hidden_layer_dims[i], hidden_layer_dims[i + 1]))
+
+        if output_dim is not None:
+            if hidden_layer_activation != "linear":
+                layers.append(activation_class)
+            layers.append(nn.Linear(hidden_layer_dims[-1], output_dim))
+            self.output_dim = output_dim
+        else:
+            self.output_dim = hidden_layer_dims[-1]
+
+        if output_activation != "linear":
+            layers.append(get_activation_fn(output_activation, framework="torch")())
+
+        self.mlp = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.mlp(x)
diff --git a/rllib/models/specs/checker.py b/rllib/models/specs/checker.py
index b7ca04c74325..1662c58aa07d 100644
--- a/rllib/models/specs/checker.py
+++ b/rllib/models/specs/checker.py
@@ -336,3 +336,15 @@ def wrapper(self, input_data, **kwargs):
         return wrapper
 
     return decorator
+
+
+@DeveloperAPI
+def is_input_decorated(obj: object) -> bool:
+    """Returns True if the object is decorated with `check_input_specs`."""
+    return hasattr(obj, "__checked_input_specs__")
+
+
+@DeveloperAPI
+def is_output_decorated(obj: object) -> bool:
+    """Returns True if the object is decorated with `check_output_specs`."""
+    return hasattr(obj, "__checked_output_specs__")
diff --git a/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py b/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py
deleted file mode 100644
index 0f3743be2f43..000000000000
--- a/rllib/models/torch/encoders/tests/test_torch_vector_encoder.py
+++ /dev/null
@@ -1,70 +0,0 @@
-import unittest
-
-import torch
-
-from ray.rllib.models.configs.encoder import VectorEncoderConfig
-from ray.rllib.models.specs.specs_dict import SpecDict
-from ray.rllib.models.specs.specs_torch import TorchTensorSpec
-from ray.rllib.utils.nested_dict import NestedDict
-
-
-
-class TestConfig(unittest.TestCase):
-    def test_error_no_feature_dim(self):
-        """Ensure we error out if we don't know the input dim"""
-        input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c")})
-        c = VectorEncoderConfig()
-        with self.assertRaises(AssertionError):
-            c.build(input_spec)
-
-    def test_default_build(self):
-        """Test building with the default config"""
-        input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)})
-        c = VectorEncoderConfig()
-        c.build(input_spec)
-
-    def test_nonlinear_final_build(self):
-        input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)})
-        c = VectorEncoderConfig(final_activation="relu")
-        c.build(input_spec)
-
-    def test_default_forward(self):
-        """Test the default config/model _forward implementation"""
-        input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)})
-        c = VectorEncoderConfig()
-        m = c.build(input_spec)
-        inputs = NestedDict({"bork": torch.rand((2, 4, 3))})
-        outputs, _ = m.unroll(inputs, NestedDict())
-        self.assertEqual(outputs[c.output_key].shape[-1], c.hidden_layer_sizes[-1])
-        self.assertEqual(outputs[c.output_key].shape[:-1], (2, 4))
-
-    def test_two_inputs_forward(self):
-        """Test the default model when we have two items in the input_spec.
-        These two items will be concatenated and fed thru the mlp."""
-        """Test the default config/model _forward implementation"""
-        input_spec = SpecDict(
-            {
-                "bork": TorchTensorSpec("a, b, c", c=3),
-                "dork": TorchTensorSpec("x, y, z", z=5),
-            }
-        )
-        c = VectorEncoderConfig()
-        m = c.build(input_spec)
-        self.assertEqual(m.net[0].in_features, 8)
-        inputs = NestedDict(
-            {"bork": torch.rand((2, 4, 3)), "dork": torch.rand((2, 4, 5))}
-        )
-        outputs, _ = m.unroll(inputs, NestedDict())
-        self.assertEqual(outputs[c.output_key].shape[-1], c.hidden_layer_sizes[-1])
-        self.assertEqual(outputs[c.output_key].shape[:-1], (2, 4))
-
-    def test_deep_build(self):
-        input_spec = SpecDict({"bork": TorchTensorSpec("a, b, c", c=3)})
-        c = VectorEncoderConfig()
-        c.build(input_spec)
-
-
-if __name__ == "__main__":
-    import pytest
-    import sys
-
-    sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/models/torch/encoders/vector.py b/rllib/models/torch/encoders/vector.py
deleted file mode 100644
index 91ef65d71f44..000000000000
--- a/rllib/models/torch/encoders/vector.py
+++ /dev/null
@@ -1,97 +0,0 @@
-from typing import TYPE_CHECKING
-from ray.rllib.models.specs.specs_torch import TorchTensorSpec
-
-import torch
-from torch import nn
-
-from ray.rllib.models.specs.specs_dict import SpecDict
-from ray.rllib.models.torch.model import TorchModel
-from ray.rllib.models.utils import get_activation_fn
-from ray.rllib.utils.nested_dict import NestedDict
-
-from ray.rllib.models.utils import input_to_output_spec
-
-if TYPE_CHECKING:
-    from ray.rllib.models.configs.encoder import VectorEncoderConfig
-
-
-class TorchVectorEncoder(TorchModel):
-    """A torch implementation of an MLP encoder.
-
-    This encoder concatenates inputs along the last dimension,
-    then pushes them through a series of linear layers and nonlinear activations.
-    """
-
-    @property
-    def input_spec(self) -> SpecDict:
-        return self._input_spec
-
-    @property
-    def output_spec(self) -> SpecDict:
-        return self._output_spec
-
-    def __init__(
-        self,
-        input_spec: SpecDict,
-        config: "VectorEncoderConfig",
-    ):
-        super().__init__(config=config)
-        # Setup input and output specs
-        self._input_spec = input_spec
-        self._output_spec = input_to_output_spec(
-            input_spec=input_spec,
-            num_input_feature_dims=1,
-            output_key=config.output_key,
-            output_feature_spec=TorchTensorSpec("f", f=config.hidden_layer_sizes[-1]),
-        )
-        # Returns the size of the feature dimension for the input tensors
-        prev_size = sum(v.shape[-1] for v in input_spec.values())
-
-        # Construct layers
-        layers = []
-        activation = (
-            None
-            if config.activation == "linear"
-            else get_activation_fn(config.activation, framework=config.framework_str)()
-        )
-        for size in config.hidden_layer_sizes[:-1]:
-            layers += [nn.Linear(prev_size, size)]
-            layers += [activation] if activation is not None else []
-            prev_size = size
-
-        # Final layer
-        layers += [
-            nn.Linear(config.hidden_layer_sizes[-2], config.hidden_layer_sizes[-1])
-        ]
-        if config.final_activation != "linear":
-            layers += [
-                get_activation_fn(
-                    config.final_activation, framework=config.framework_str
-                )()
-            ]
-
-        self.net = nn.Sequential(*layers)
-
-    def _forward(self, inputs: NestedDict) -> NestedDict:
-        """Runs the forward pass of the MLP. Call this via unroll().
-
-        Args:
-            inputs: The nested dictionary of inputs
-
-        Returns:
-            The nested dictionary of outputs
-        """
-        # Ensure all inputs have matching dims before concat
-        # so we can emit an informative error message
-        first_key, first_tensor = list(inputs.items())[0]
-        for k, tensor in inputs.items():
-            assert tensor.shape[:-1] == first_tensor.shape[:-1], (
-                "Inputs have mismatching dimensions, all dims but the last should "
-                f"be equal: {first_key}: {first_tensor.shape} != {k}: {tensor.shape}"
-            )
-
-        # Concatenate all input along the feature dim
-        x = torch.cat(list(inputs.values()), dim=-1)
-        [out_key] = self.output_spec.keys()
-        inputs[out_key] = self.net(x)
-        return inputs
diff --git a/rllib/models/torch/primitives.py b/rllib/models/torch/primitives.py
index 191a0ff35e5a..eaa43a6db3d4 100644
--- a/rllib/models/torch/primitives.py
+++ b/rllib/models/torch/primitives.py
@@ -11,8 +11,8 @@ class FCNet(nn.Module):
 
     Attributes:
         input_dim: The input dimension of the network. It cannot be None.
-        output_dim: The output dimension of the network. if None, the last layer would
-            be the last hidden layer.
+        output_dim: The output dimension of the network. If None, the output_dim will
+            be the number of nodes in the last hidden layer.
         hidden_layers: The sizes of the hidden layers.
         activation: The activation function to use after each layer.
     """
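
Usage sketch (not part of the diff): the snippet below shows one way the new
TorchMLP primitive from rllib/models/experimental/torch/primitives.py could be
exercised once this change is applied. The dimensions and the "ReLU" activation
are illustrative assumptions, not values taken from any RLlib default config.

import torch

from ray.rllib.models.experimental.torch.primitives import TorchMLP

# Builds Linear(4, 32) -> ReLU -> Linear(32, 32) -> ReLU -> Linear(32, 8).
net = TorchMLP(
    input_dim=4,
    hidden_layer_dims=[32, 32],
    output_dim=8,
    hidden_layer_activation="ReLU",
)
out = net(torch.rand(2, 4))
assert out.shape == (2, 8)  # batch of 2, output_dim of 8

Subclasses of TorchModel (e.g. TorchMLPEncoder above) additionally get their
forward() wrapped with check_input_specs()/check_output_specs(), so malformed
observation batches fail fast with a spec error instead of deep inside the
network.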