diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py
index 30ab1a6e134a..a2f22929893e 100644
--- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py
+++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -15,8 +15,6 @@
     FCConfig,
     IdentityConfig,
     LSTMConfig,
-    STATE_IN,
-    STATE_OUT,
 )
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
@@ -76,8 +74,14 @@ def get_expected_model_config(
         activation="ReLU",
     )
 
-    pi_config = FCConfig()
-    vf_config = FCConfig()
+    pi_config = FCConfig(
+        input_dim=pi_encoder_config.output_dim,
+        hidden_layers=[16],
+    )
+
+    vf_config = FCConfig(
+        input_dim=vf_encoder_config.output_dim, hidden_layers=[16], output_dim=1
+    )
 
     if isinstance(env.action_space, gym.spaces.Discrete):
         pi_config.output_dim = env.action_space.n
@@ -110,14 +118,14 @@ def test_rollouts(self):
         for env_name in ["CartPole-v1", "Pendulum-v1"]:
             for fwd_fn in ["forward_exploration", "forward_inference"]:
                 for shared_encoder in [False, True]:
-                    for lstm in [True, False]:
+                    for lstm in [False, True]:
                         if lstm and shared_encoder:
                             # Not yet implemented
                             # TODO (Artur): Implement
                             continue
                         print(
                             f"[ENV={env_name}] | [SHARED={shared_encoder}] | LSTM"
-                            f"={lstm}"
+                            f"={lstm} | FWD={fwd_fn}"
                         )
                         env = gym.make(env_name)
@@ -135,7 +143,7 @@ def test_rollouts(self):
                         state_in = tree.map_structure(
                             lambda x: x[None], convert_to_torch_tensor(state_in)
                         )
-                        batch[STATE_IN] = state_in
+                        batch["state_in_0"] = state_in
                         batch[SampleBatch.SEQ_LENS] = torch.Tensor([1])
 
                         if fwd_fn == "forward_exploration":
@@ -146,8 +154,8 @@ def test_forward_train(self):
         # TODO: Add BreakoutNoFrameskip-v4 to cover a 3D obs space
         for env_name in ["CartPole-v1", "Pendulum-v1"]:
-            for shared_encoder in [False, True]:
-                for lstm in [True, False]:
+            for shared_encoder in [False]:
+                for lstm in [True]:
                     if lstm and shared_encoder:
                         # Not yet implemented
                         # TODO (Artur): Implement
@@ -175,7 +183,7 @@ def test_forward_train(self):
                     if lstm:
                         input_batch = {
                             SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
-                            STATE_IN: state_in,
+                            "state_in_0": state_in,
                             SampleBatch.SEQ_LENS: np.array([1]),
                         }
                     else:
@@ -196,8 +204,7 @@ def test_forward_train(self):
                             SampleBatch.TRUNCATEDS: np.array(truncated),
                         }
                         if lstm:
-                            assert STATE_OUT in fwd_out
-                            state_in = fwd_out[STATE_OUT]
+                            state_in = fwd_out["state_out_0"]
                         batches.append(output_batch)
                         obs = new_obs
                         tstep += 1
@@ -210,7 +217,7 @@ def test_forward_train(self):
                         for k, v in batch.items()
                     }
                     if lstm:
-                        fwd_in[STATE_IN] = initial_state
+                        fwd_in["state_in_0"] = initial_state
                         fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10])
 
                     # forward train
diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
index 40c0d1fd302c..a3d173d13334 100644
--- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
+++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
@@ -16,14 +16,13 @@
     TorchDiagGaussian,
 )
 from ray.rllib.core.rl_module.encoder import (
-    FCNet,
     FCConfig,
     LSTMConfig,
     IdentityConfig,
     LSTMEncoder,
-    ENCODER_OUT,
 )
 from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space
+from ray.rllib.models.base_model import BaseModelIOKeys
 
 torch, nn = try_import_torch()
@@ -79,20 +78,8 @@ def setup(self) -> None:
         self.shared_encoder = self.config.shared_encoder_config.build()
         self.pi_encoder = self.config.pi_encoder_config.build()
         self.vf_encoder = self.config.vf_encoder_config.build()
-
-        self.pi = FCNet(
-            input_dim=self.config.pi_encoder_config.output_dim,
-            output_dim=self.config.pi_config.output_dim,
-            hidden_layers=self.config.pi_config.hidden_layers,
-            activation=self.config.pi_config.activation,
-        )
-
-        self.vf = FCNet(
-            input_dim=self.config.vf_encoder_config.output_dim,
-            output_dim=1,
-            hidden_layers=self.config.vf_config.hidden_layers,
-            activation=self.config.vf_config.activation,
-        )
+        self.pi = self.config.pi_config.build()
+        self.vf = self.config.vf_config.build()
 
         self._is_discrete = isinstance(
             convert_old_gym_space_to_gymnasium_space(self.config.action_space),
@@ -212,21 +199,36 @@ def get_initial_state(self) -> NestedDict:
         else:
             return NestedDict({})
 
-    @override(RLModule)
-    def input_specs_inference(self) -> SpecDict:
-        return self.input_specs_exploration()
-
     @override(RLModule)
     def output_specs_inference(self) -> SpecDict:
         return SpecDict({SampleBatch.ACTION_DIST: TorchDeterministic})
 
     @override(RLModule)
     def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
-        shared_enc_out = self.shared_encoder(batch)
-        pi_enc_out = self.pi_encoder(shared_enc_out)
-
-        action_logits = self.pi(pi_enc_out[ENCODER_OUT])
+        x = self.shared_encoder(
+            batch,
+            input_mapping={
+                self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS,
+                self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.pi_encoder(
+            x,
+            input_mapping={
+                self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[
+                    BaseModelIOKeys.OUT
+                ],
+                self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
+            },
+        )
+        x = self.pi(
+            x,
+            input_mapping={
+                self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )
+        action_logits = x[self.pi.io[BaseModelIOKeys.OUT]]
         if self._is_discrete:
             action = torch.argmax(action_logits, dim=-1)
         else:
@@ -234,13 +236,15 @@ def _forward_inference(self, batch: NestedDict) -> Mapping[str, Any]:
         action_dist = TorchDeterministic(action)
         output = {SampleBatch.ACTION_DIST: action_dist}
-        output["state_out"] = pi_enc_out.get("state_out", {})
+        # The encoders emit instance-scoped state-out keys, so look the state up
+        # via the io mapping (a raw "state_out" key is never present anymore).
+        shared_encoder_state = x.get(self.shared_encoder.io[BaseModelIOKeys.STATE_OUT])
+        pi_encoder_state = x.get(self.pi_encoder.io[BaseModelIOKeys.STATE_OUT])
+        state_out = shared_encoder_state or pi_encoder_state
+        if state_out:
+            output["state_out_0"] = state_out
         return output
 
-    @override(RLModule)
-    def input_specs_exploration(self):
-        return self.shared_encoder.input_spec()
-
     @override(RLModule)
     def output_specs_exploration(self) -> SpecDict:
         specs = {SampleBatch.ACTION_DIST: self.__get_action_dist_type()}
@@ -264,12 +262,40 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:
             policy distribution to be used for computing KL divergence between the
             old policy and the new policy during training.
""" - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) + x = self.shared_encoder( + batch, + input_mapping={ + self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS, + self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", + }, + ) + x = self.pi_encoder( + x, + input_mapping={ + self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], + self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", + }, + ) + x = self.vf_encoder( + x, + input_mapping={ + self.vf_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], + }, + ) + + x = self.pi( + x, + input_mapping={ + self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT], + }, + ) output = {} + action_logits = x[self.pi.io[BaseModelIOKeys.OUT]] if self._is_discrete: action_dist = TorchCategorical(logits=action_logits) output[SampleBatch.ACTION_DIST_INPUTS] = {"logits": action_logits} @@ -281,24 +307,23 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]: output[SampleBatch.ACTION_DIST] = action_dist # compute the value function - output[SampleBatch.VF_PREDS] = self.vf(encoder_out_vf[ENCODER_OUT]).squeeze(-1) - output["state_out"] = encoder_out_pi.get("state_out", {}) - return output + vf_out = self.vf( + x, + input_mapping={ + self.vf.io[BaseModelIOKeys.IN]: self.vf_encoder.io[BaseModelIOKeys.OUT], + }, + ) + output[SampleBatch.VF_PREDS] = vf_out[self.vf.io[BaseModelIOKeys.OUT]].squeeze( + -1 + ) - @override(RLModule) - def input_specs_train(self) -> SpecDict: - if self._is_discrete: - action_spec = TorchTensorSpec("b") - else: - action_dim = self.config.action_space.shape[0] - action_spec = TorchTensorSpec("b, h", h=action_dim) - - spec_dict = self.shared_encoder.input_spec() - spec_dict.update({SampleBatch.ACTIONS: action_spec}) - if SampleBatch.OBS in spec_dict: - spec_dict[SampleBatch.NEXT_OBS] = spec_dict[SampleBatch.OBS] - spec = SpecDict(spec_dict) - return spec + shared_encoder_state = x.get(self.shared_encoder.io[BaseModelIOKeys.STATE_OUT]) + pi_encoder_state = x.get(self.pi_encoder.io[BaseModelIOKeys.STATE_OUT]) + + state_out = shared_encoder_state or pi_encoder_state + if state_out: + output["state_out_0"] = state_out + return output @override(RLModule) def output_specs_train(self) -> SpecDict: @@ -314,12 +339,46 @@ def output_specs_train(self) -> SpecDict: @override(RLModule) def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]: - encoder_out = self.shared_encoder(batch) - encoder_out_pi = self.pi_encoder(encoder_out) - encoder_out_vf = self.vf_encoder(encoder_out) + x = self.shared_encoder( + batch, + input_mapping={ + self.shared_encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS, + self.shared_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", + }, + ) + x = self.pi_encoder( + x, + input_mapping={ + self.pi_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], + self.pi_encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0", + }, + ) + x = self.vf_encoder( + x, + input_mapping={ + self.vf_encoder.io[BaseModelIOKeys.IN]: self.shared_encoder.io[ + BaseModelIOKeys.OUT + ], + }, + ) - action_logits = self.pi(encoder_out_pi[ENCODER_OUT]) - vf = self.vf(encoder_out_vf[ENCODER_OUT]) + x = self.pi( + x, + input_mapping={ + self.pi.io[BaseModelIOKeys.IN]: self.pi_encoder.io[BaseModelIOKeys.OUT], + }, + ) + + action_logits = 
+        action_logits = x[self.pi.io[BaseModelIOKeys.OUT]]
+
+        vf_out = self.vf(
+            x,
+            input_mapping={
+                self.vf.io[BaseModelIOKeys.IN]: self.vf_encoder.io[BaseModelIOKeys.OUT],
+            },
+        )
 
         if self._is_discrete:
             action_dist = TorchCategorical(logits=action_logits)
@@ -333,11 +392,16 @@ def _forward_train(self, batch: NestedDict) -> Mapping[str, Any]:
         output = {
             SampleBatch.ACTION_DIST: action_dist,
             SampleBatch.ACTION_LOGP: logp,
-            SampleBatch.VF_PREDS: vf.squeeze(-1),
+            SampleBatch.VF_PREDS: vf_out[self.vf.io[BaseModelIOKeys.OUT]].squeeze(-1),
             "entropy": entropy,
         }
 
-        output["state_out"] = encoder_out_pi.get("state_out", {})
+        # As in the other forward passes, state lives under instance-scoped keys.
+        shared_encoder_state = x.get(self.shared_encoder.io[BaseModelIOKeys.STATE_OUT])
+        pi_encoder_state = x.get(self.pi_encoder.io[BaseModelIOKeys.STATE_OUT])
+        state_out = shared_encoder_state or pi_encoder_state
+        if state_out:
+            output["state_out_0"] = state_out
         return output
 
     def __get_action_dist_type(self):
diff --git a/rllib/core/rl_module/encoder.py b/rllib/core/rl_module/encoder.py
index f3bb22b46900..46656200d014 100644
--- a/rllib/core/rl_module/encoder.py
+++ b/rllib/core/rl_module/encoder.py
@@ -3,20 +3,16 @@
 import tree
 from typing import List
 
+from ray.rllib.models.specs.checker import check_input_specs, check_output_specs
 from dataclasses import dataclass, field
+from ray.rllib.utils.annotations import override
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.rnn_sequencing import add_time_dimension
 from ray.rllib.models.specs.specs_dict import SpecDict
-from ray.rllib.models.specs.checker import check_input_specs, check_output_specs
 from ray.rllib.models.specs.specs_torch import TorchTensorSpec
 from ray.rllib.models.torch.primitives import FCNet
-
-# TODO (Kourosh): Find a better / more straight fwd approach for sub-components
-
-ENCODER_OUT = "encoder_out"
-STATE_IN = "state_in"
-STATE_OUT = "state_out"
+from ray.rllib.models.base_model import Model, BaseModelIOKeys
 
 
 @dataclass
@@ -68,30 +64,23 @@ def build(self):
         return LSTMEncoder(self)
 
 
-class Encoder(nn.Module):
+class Encoder(Model, nn.Module):
     def __init__(self, config: EncoderConfig) -> None:
-        super().__init__()
+        nn.Module.__init__(self)
+        Model.__init__(self)
         self.config = config
-        self._input_spec = self.input_spec()
-        self._output_spec = self.output_spec()
 
     def get_initial_state(self):
         return []
 
+    @property
     def input_spec(self):
         return SpecDict()
 
+    @property
     def output_spec(self):
         return SpecDict()
 
-    @check_input_specs("_input_spec")
-    @check_output_specs("_output_spec")
-    def forward(self, input_dict):
-        return self._forward(input_dict)
-
-    def _forward(self, input_dict):
-        raise NotImplementedError
-
 
 class FullyConnectedEncoder(Encoder):
     def __init__(self, config: FCConfig) -> None:
@@ -104,21 +93,44 @@ def __init__(self, config: FCConfig) -> None:
             activation=config.activation,
         )
 
+    @property
+    @override(Model)
     def input_spec(self):
         return SpecDict(
-            {SampleBatch.OBS: TorchTensorSpec("b, h", h=self.config.input_dim)}
+            {
+                self.io[BaseModelIOKeys.IN]: TorchTensorSpec(
+                    "b, h", h=self.config.input_dim
+                )
+            }
         )
 
+    @property
+    @override(Model)
     def output_spec(self):
         return SpecDict(
-            {ENCODER_OUT: TorchTensorSpec("b, h", h=self.config.output_dim)}
+            {
+                self.io[BaseModelIOKeys.OUT]: TorchTensorSpec(
+                    "b, h", h=self.config.output_dim
+                )
+            }
         )
 
-    def _forward(self, input_dict):
-        return {ENCODER_OUT: self.net(input_dict[SampleBatch.OBS])}
+    @check_input_specs("input_spec")
+    @check_output_specs("output_spec")
+    def _forward(self, input_dict, **kwargs):
+        inputs = input_dict[self.io[BaseModelIOKeys.IN]]
+        return {self.io[BaseModelIOKeys.OUT]: self.net(inputs)}
 
 
-class LSTMEncoder(Encoder):
+class RecurrentEncoder(Encoder):
+    def __init__(self, config: EncoderConfig):
+        super().__init__(config=config)
+
+
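+# NOTE: For now, `RecurrentEncoder` is only a marker base class that groups
+# stateful encoders (such as `LSTMEncoder` below) under a common type; it adds
+# no behavior on top of `Encoder` yet.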
+class LSTMEncoder(RecurrentEncoder):
     def __init__(self, config: LSTMConfig) -> None:
         super().__init__(config)
 
@@ -137,13 +146,17 @@ def get_initial_state(self):
             "c": torch.zeros(config.num_layers, config.hidden_dim),
         }
 
+    @property
+    @override(Model)
     def input_spec(self):
         config = self.config
         return SpecDict(
             {
                 # bxt is just a name for better readability to indicated padded batch
-                SampleBatch.OBS: TorchTensorSpec("bxt, h", h=config.input_dim),
-                STATE_IN: {
+                self.io[BaseModelIOKeys.IN]: TorchTensorSpec(
+                    "bxt, h", h=config.input_dim
+                ),
+                self.io[BaseModelIOKeys.STATE_IN]: {
                     "h": TorchTensorSpec(
                         "b, l, h", h=config.hidden_dim, l=config.num_layers
                     ),
@@ -154,12 +167,16 @@ def input_spec(self):
             }
         )
 
+    @property
+    @override(Model)
     def output_spec(self):
         config = self.config
         return SpecDict(
             {
-                ENCODER_OUT: TorchTensorSpec("bxt, h", h=config.output_dim),
-                STATE_OUT: {
+                self.io[BaseModelIOKeys.OUT]: TorchTensorSpec(
+                    "bxt, h", h=config.output_dim
+                ),
+                self.io[BaseModelIOKeys.STATE_OUT]: {
                     "h": TorchTensorSpec(
                         "b, l, h", h=config.hidden_dim, l=config.num_layers
                     ),
@@ -170,9 +187,11 @@ def output_spec(self):
             }
         )
 
-    def _forward(self, input_dict: SampleBatch):
-        x = input_dict[SampleBatch.OBS]
-        states = input_dict[STATE_IN]
+    @check_input_specs("input_spec")
+    @check_output_specs("output_spec")
+    def _forward(self, input_dict: SampleBatch, **kwargs):
+        x = input_dict[self.io[BaseModelIOKeys.IN]]
+        states = input_dict[self.io[BaseModelIOKeys.STATE_IN]]
         # states are batch-first when coming in
         states = tree.map_structure(lambda x: x.transpose(0, 1), states)
@@ -189,14 +208,18 @@ def _forward(self, input_dict: SampleBatch):
         x = x.view(-1, x.shape[-1])
 
         return {
-            ENCODER_OUT: x,
-            STATE_OUT: tree.map_structure(lambda x: x.transpose(0, 1), states_o),
+            self.io[BaseModelIOKeys.OUT]: x,
+            self.io[BaseModelIOKeys.STATE_OUT]: tree.map_structure(
+                lambda x: x.transpose(0, 1), states_o
+            ),
         }
 
 
 class IdentityEncoder(Encoder):
     def __init__(self, config: EncoderConfig) -> None:
-        super().__init__(config)
+        super().__init__(config=config)
 
-    def _forward(self, input_dict):
-        return input_dict
+    @check_input_specs("input_spec")
+    @check_output_specs("output_spec")
+    def _forward(self, input_dict, **kwargs):
+        return {self.io[BaseModelIOKeys.OUT]: input_dict[self.io[BaseModelIOKeys.IN]]}
diff --git a/rllib/models/base_model.py b/rllib/models/base_model.py
index c006af27f6f1..516bfa3b387f 100644
--- a/rllib/models/base_model.py
+++ b/rllib/models/base_model.py
@@ -1,22 +1,11 @@
-# Copyright 2021 DeepMind Technologies Limited.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
 import abc
+from enum import Enum
 from typing import Optional, Tuple
+from collections import defaultdict
+from typing import Mapping
+from ray.rllib.models.specs.specs_dict import SpecDict
 
-from ray.rllib.models.temp_spec_classes import TensorDict, SpecDict, ModelConfig
+from ray.rllib.models.temp_spec_classes import TensorDict, ModelConfig
 from ray.rllib.utils.annotations import (
     DeveloperAPI,
     OverrideToImplementCustomLogic,
@@ -30,6 +19,81 @@
 UnrollOutputType = Tuple[TensorDict, TensorDict]
 
 
+@ExperimentalAPI
+class BaseModelIOKeys(Enum):
+    IN = "in"
+    OUT = "out"
+    STATE_IN = "state_in"
+    STATE_OUT = "state_out"
+
+
+class ModelIOMapping(Mapping):
+    """A mapping from general ModelIOKeys to their instance-based counterparts.
+
+    To distinguish the keys in the input- and output-specs of multiple instances
+    of a given model, each instance needs its own distinct keys. This mapping
+    generates such instance-scoped keys by suffixing the model name with a
+    per-name instance counter.
+
+    """
+
+    __init_counters__ = defaultdict(int)
+
+    def __init__(self, model_name: str):
+        self._name: str = model_name
+        self._init_idx: str = str(self.__init_counters__[model_name])
+        self.__init_counters__[model_name] += 1
+        self._valid_keys = set()
+
+    def __getitem__(self, item):
+        if item in self._valid_keys:
+            return self._name + "_" + str(item) + "_" + self._init_idx
+        else:
+            raise KeyError(
+                "`{}` is not a key of ModelIOMapping for model_name `{}` "
+                "and index `{}`. Valid keys are `{}`".format(
+                    item, self._name, self._init_idx, self._valid_keys
+                )
+            )
+
+    def add(self, key):
+        self._valid_keys.add(key)
+
+    def __repr__(self):
+        return "ModelIOMapping for model {} with index {} and valid keys {}".format(
+            self._name, self._init_idx, self._valid_keys
+        )
+
+    def __iter__(self):
+        return iter(self._valid_keys)
+
+    def __len__(self):
+        return len(self._valid_keys)
+
+    def __contains__(self, item):
+        return item in self._valid_keys
+
+    def keys(self):
+        return iter(self._valid_keys)
+
+    def items(self):
+        return iter([(k, self[k]) for k in self._valid_keys])
+
+    def values(self):
+        return iter([self[k] for k in self._valid_keys])
+
+    def get(self, name):
+        raise NotImplementedError
+
+    def __eq__(self, other: "ModelIOMapping") -> bool:
+        assert isinstance(other, ModelIOMapping)
+        return self._valid_keys == other._valid_keys
+
+    def __ne__(self, other: "ModelIOMapping") -> bool:
+        assert isinstance(other, ModelIOMapping)
+        return self._valid_keys != other._valid_keys
+
+
 @ExperimentalAPI
 class RecurrentModel(abc.ABC):
     """The base model all other models are based on.
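A minimal sketch of how the instance-scoped keys above behave (everything except
the exact rendered strings is taken from this diff; the rendering depends on how
`str()` formats the enum member):

    io = ModelIOMapping("encoder")   # first instance named "encoder" -> index 0
    io.add(BaseModelIOKeys.IN)
    io[BaseModelIOKeys.IN]           # "encoder_" + str(BaseModelIOKeys.IN) + "_0"

    io2 = ModelIOMapping("encoder")  # second "encoder" instance -> index 1
    io2.add(BaseModelIOKeys.IN)      # so its keys never collide with io's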
@@ -57,6 +121,11 @@ class RecurrentModel(abc.ABC):
 
     def __init__(self, name: Optional[str] = None):
         self._name = name or self.__class__.__name__
+        self.io = ModelIOMapping(self._name)
+        self.io.add(BaseModelIOKeys.IN)
+        self.io.add(BaseModelIOKeys.OUT)
+        self.io.add(BaseModelIOKeys.STATE_IN)
+        self.io.add(BaseModelIOKeys.STATE_OUT)
 
     @property
     def name(self) -> str:
@@ -272,6 +341,20 @@ def _unroll(
         outputs = self._forward(inputs, **kwargs)
         return outputs, TensorDict()
 
+    def forward(
+        self, input_dict, input_mapping: Mapping = None, **kwargs
+    ) -> ForwardOutputType:
+        if input_mapping:
+            # Alias the caller's keys to this instance's io keys so that
+            # _forward() can look up its inputs under self.io[...].
+            for forward_key, input_dict_key in input_mapping.items():
+                if input_dict_key in input_dict:
+                    input_dict[forward_key] = input_dict[input_dict_key]
+        # Merge the outputs back into the (mutated) input dict and return it;
+        # this is what allows models to be chained via input_mapping.
+        input_dict.update(self._forward(input_dict, **kwargs))
+        return input_dict
+
     @abc.abstractmethod
     def _forward(self, inputs: TensorDict, **kwargs) -> ForwardOutputType:
         """Computes the output of this module for each timestep.
diff --git a/rllib/utils/nested_dict.py b/rllib/utils/nested_dict.py
index 1e4d308d1ef5..05108a0acc07 100644
--- a/rllib/utils/nested_dict.py
+++ b/rllib/utils/nested_dict.py
@@ -166,10 +166,7 @@ def get(
         k = _flatten_index(k)
 
         if k not in self:
-            if default is not None:
-                return default
-            else:
-                raise KeyError(k)
+            return default
 
         data_ptr = self._data
         for key in k:
@@ -180,6 +177,8 @@ def get(
         return data_ptr
 
     def __getitem__(self, k: SeqStrType) -> T:
+        if k not in self:
+            raise KeyError(k)
         output = self.get(k)
         return output
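Taken together, the new pieces compose as in this condensed sketch (`encoder`
and `pi` stand for any two `Model` instances; `forward`, `io`, and
`BaseModelIOKeys` come from the diff above, the rest is illustrative):

    batch = {SampleBatch.OBS: obs, "state_in_0": state}
    x = encoder.forward(
        batch,
        input_mapping={
            encoder.io[BaseModelIOKeys.IN]: SampleBatch.OBS,
            encoder.io[BaseModelIOKeys.STATE_IN]: "state_in_0",
        },
    )
    # forward() aliases the mapped keys, runs _forward(), and merges the outputs
    # back into the dict, so a downstream model can consume encoder's OUT key:
    x = pi.forward(
        x,
        input_mapping={pi.io[BaseModelIOKeys.IN]: encoder.io[BaseModelIOKeys.OUT]},
    )
    action_logits = x[pi.io[BaseModelIOKeys.OUT]]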