diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index f852340aff1eb..34940edd43524 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -3362,9 +3362,9 @@ def rl_module( """Sets the config's RLModule settings. Args: - model_config_dict: The default model config dictionary for `RLModule`s. This - is used for any `RLModule` if not otherwise specified in the - `rl_module_spec`. + model_config: The DefaultModelConfig object (or a config dictionary) passed + as `model_config` arg into each RLModule's constructor. This is used + for all RLModules, if not otherwise specified through `rl_module_spec`. rl_module_spec: The RLModule spec to use for this config. It can be either a RLModuleSpec or a MultiRLModuleSpec. If the observation_space, action_space, catalog_class, or the model config is @@ -4186,7 +4186,7 @@ def model_config(self): This method combines the auto configuration `self _model_config_auto_includes` defined by an algorithm with the user-defined configuration in - `self._model_config_dict`.This configuration dictionary is used to + `self._model_config`.This configuration dictionary is used to configure the `RLModule` in the new stack and the `ModelV2` in the old stack. diff --git a/rllib/algorithms/bc/torch/bc_torch_rl_module.py b/rllib/algorithms/bc/torch/bc_torch_rl_module.py index 6d328bb568195..a547047d7f417 100644 --- a/rllib/algorithms/bc/torch/bc_torch_rl_module.py +++ b/rllib/algorithms/bc/torch/bc_torch_rl_module.py @@ -20,13 +20,12 @@ def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]: """Generic BC forward pass (for all phases of training/evaluation).""" output = {} - # State encodings. + # Encoder forward pass. encoder_outs = self.encoder(batch) if Columns.STATE_OUT in encoder_outs: output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] # Actions. - action_logits = self.pi(encoder_outs[ENCODER_OUT]) - output[Columns.ACTION_DIST_INPUTS] = action_logits + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT]) return output diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py index 93f242d3af833..5c2584d2ed821 100644 --- a/rllib/algorithms/marwil/tests/test_marwil.py +++ b/rllib/algorithms/marwil/tests/test_marwil.py @@ -167,14 +167,13 @@ def possibly_masked_mean(data_): # Calculate our own expected values (to then compare against the # agent's loss output). - fwd_out = ( - algo.learner_group._learner.module[DEFAULT_MODULE_ID] - .unwrapped() - .forward_train({k: v for k, v in batch[DEFAULT_MODULE_ID].items()}) + module = algo.learner_group._learner.module[DEFAULT_MODULE_ID].unwrapped() + fwd_out = module.forward_train( + {k: v for k, v in batch[DEFAULT_MODULE_ID].items()} ) advantages = ( batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy() - - fwd_out["vf_preds"].detach().cpu().numpy() + - module.compute_values(batch[DEFAULT_MODULE_ID]).detach().cpu().numpy() ) advantages_squared = possibly_masked_mean(np.square(advantages)) c_2 = 100.0 + 1e-8 * (advantages_squared - 100.0) diff --git a/rllib/algorithms/ppo/ppo_rl_module.py b/rllib/algorithms/ppo/ppo_rl_module.py index 00027ed3e7e70..30ca5d843df1d 100644 --- a/rllib/algorithms/ppo/ppo_rl_module.py +++ b/rllib/algorithms/ppo/ppo_rl_module.py @@ -32,7 +32,7 @@ def setup(self): self.inference_only = False # If this is an `inference_only` Module, we'll have to pass this information # to the encoder config as well. 
- if self.config.inference_only and self.framework == "torch": + if self.inference_only and self.framework == "torch": self.catalog.actor_critic_encoder_config.inference_only = True # Build models from catalog. @@ -54,6 +54,6 @@ def get_non_inference_attributes(self) -> List[str]: """Return attributes, which are NOT inference-only (only used for training).""" return ["vf"] + ( [] - if self.config.model_config_dict.get("vf_share_layers") + if self.model_config.get("vf_share_layers") else ["encoder.critic_encoder"] ) diff --git a/rllib/algorithms/ppo/tests/test_ppo.py b/rllib/algorithms/ppo/tests/test_ppo.py index 02bb6d144ad32..ae51de75389dc 100644 --- a/rllib/algorithms/ppo/tests/test_ppo.py +++ b/rllib/algorithms/ppo/tests/test_ppo.py @@ -8,12 +8,12 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.learner.learner import DEFAULT_OPTIMIZER, LR_KEY - +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.utils.metrics import LEARNER_RESULTS from ray.rllib.utils.test_utils import check, check_train_results_new_api_stack -def get_model_config(framework, lstm=False): +def get_model_config(lstm=False): return ( dict( use_lstm=True, @@ -102,9 +102,7 @@ def test_ppo_compilation_and_schedule_mixins(self): print("Env={}".format(env)) for lstm in [False]: print("LSTM={}".format(lstm)) - config.rl_module( - model_config_dict=get_model_config("torch", lstm=lstm) - ).framework(eager_tracing=False) + config.rl_module(model_config=get_model_config(lstm=lstm)) algo = config.build(env=env) # TODO: Maybe add an API to get the Learner(s) instances within @@ -143,12 +141,12 @@ def test_ppo_free_log_std(self): num_env_runners=1, ) .rl_module( - model_config_dict={ - "fcnet_hiddens": [10], - "fcnet_activation": "linear", - "free_log_std": True, - "vf_share_layers": True, - } + model_config=DefaultModelConfig( + fcnet_hiddens=[10], + fcnet_activation="linear", + free_log_std=True, + vf_share_layers=True, + ), ) .training( gamma=0.99, diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 94caeba7d1abd..de3d3f42f424b 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -21,17 +21,6 @@ torch, nn = try_import_torch() -def get_expected_module_config(env, model_config_dict, observation_space): - config = RLModuleConfig( - observation_space=observation_space, - action_space=env.action_space, - model_config_dict=model_config_dict, - catalog_class=PPOCatalog, - ) - - return config - - def dummy_torch_ppo_loss(module, batch, fwd_out): adv = batch[Columns.REWARDS] - module.compute_values(batch) action_dist_class = module.get_train_action_dist_cls() @@ -46,12 +35,12 @@ def dummy_torch_ppo_loss(module, batch, fwd_out): def _get_ppo_module(env, lstm, observation_space): - model_config_dict = {"use_lstm": lstm} - config = get_expected_module_config( - env, model_config_dict=model_config_dict, observation_space=observation_space + return PPOTorchRLModule( + observation_space=observation_space, + action_space=env.action_space, + model_config=DefaultModelConfig(use_lstm=lstm), + catalog_class=PPOCatalog, ) - module = PPOTorchRLModule(config) - return module def _get_input_batch_from_obs(obs, lstm): diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 86df016326a10..ea0b32e96381f 100644 --- 
a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -3,7 +3,7 @@ from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule from ray.rllib.core.columns import Columns from ray.rllib.core.models.base import ACTOR, CRITIC, ENCODER_OUT -from ray.rllib.core.rl_module.apis import ValueFunctionAPI +from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch import TorchRLModule from ray.rllib.utils.annotations import override @@ -14,8 +14,6 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule): - framework: str = "torch" - @override(RLModule) def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: """Default forward pass (used for inference and exploration).""" @@ -31,7 +29,7 @@ def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: @override(RLModule) def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Train forward pass (keep features for possible shared value func. call).""" + """Train forward pass (keep embeddings for possible shared value func. call).""" output = {} encoder_outs = self.encoder(batch) output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC] diff --git a/rllib/algorithms/sac/sac_rl_module.py b/rllib/algorithms/sac/sac_rl_module.py index cff6e97bbdf9d..832df79e9eded 100644 --- a/rllib/algorithms/sac/sac_rl_module.py +++ b/rllib/algorithms/sac/sac_rl_module.py @@ -75,7 +75,7 @@ def setup(self): # Build heads. self.pi = self.catalog.build_pi_head(framework=self.framework) - if not self.config.inference_only or self.framework != "torch": + if not self.inference_only or self.framework != "torch": self.qf = self.catalog.build_qf_head(framework=self.framework) # If necessary build also a twin Q heads. if self.twin_q: diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index 641dfc640230a..1d7a32e87a2ac 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -404,30 +404,6 @@ def get_default_rl_module_spec(self): spec, expected = self._get_expected_marl_spec(config, CustomRLModule1) self._assertEqualMARLSpecs(spec, expected) - # expected module should become the passed module if we pass it in. - spec, expected = self._get_expected_marl_spec( - config, CustomRLModule2, passed_module_class=CustomRLModule2 - ) - self._assertEqualMARLSpecs(spec, expected) - ######################################## - # This is an alternative way to ask the algorithm to assign a specific type of - # RLModule class to ALL module_ids. - config = ( - SingleAgentAlgoConfig() - .api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - .rl_module( - rl_module_spec=MultiRLModuleSpec( - module_specs=RLModuleSpec(module_class=CustomRLModule1) - ), - ) - ) - - spec, expected = self._get_expected_marl_spec(config, CustomRLModule1) - self._assertEqualMARLSpecs(spec, expected) - # expected module should become the passed module if we pass it in. 
spec, expected = self._get_expected_marl_spec( config, CustomRLModule2, passed_module_class=CustomRLModule2 @@ -490,30 +466,6 @@ def get_default_rl_module_spec(self): lambda: config.rl_module_spec, ) - ######################################## - # This is the case where we ask the algorithm to use its default - # MultiRLModuleSpec, and the MultiRLModuleSpec has defined its - # RLModuleSpecs. - config = MultiAgentAlgoConfig().api_stack( - enable_rl_module_and_learner=True, - enable_env_runner_and_connector_v2=True, - ) - - spec, expected = self._get_expected_marl_spec( - config, - DiscreteBCTorchModule, - expected_multi_rl_module_class=CustomMultiRLModule1, - ) - self._assertEqualMARLSpecs(spec, expected) - - spec, expected = self._get_expected_marl_spec( - config, - CustomRLModule1, - passed_module_class=CustomRLModule1, - expected_multi_rl_module_class=CustomMultiRLModule1, - ) - self._assertEqualMARLSpecs(spec, expected) - if __name__ == "__main__": import pytest diff --git a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py index d73a73878f25b..b9979da368d36 100644 --- a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py +++ b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py @@ -143,7 +143,7 @@ def test_e2e_load_complex_multi_rl_module(self): module_class=PPOTorchRLModule, observation_space=env.get_observation_space(0), action_space=env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [64]}, + model_config=DefaultModelConfig(fcnet_hiddens=[64]), catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, ) @@ -283,7 +283,7 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self): module_class=PPOTorchRLModule, observation_space=env.get_observation_space(0), action_space=env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [64]}, + model_config=DefaultModelConfig(fcnet_hiddens=[64]), catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, ) diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py index 31f52b2f29b6b..136dd713e01af 100644 --- a/rllib/core/models/catalog.py +++ b/rllib/core/models/catalog.py @@ -102,11 +102,21 @@ def __init__( if view_requirements != DEPRECATED_VALUE: deprecation_warning(old="Catalog(view_requirements=..)", error=True) + # TODO (sven): The following logic won't be needed anymore, once we get rid of + # Catalogs entirely. We will assert directly inside the algo's DefaultRLModule + # class that the `model_config` is a DefaultModelConfig. Thus users won't be + # able to pass in partial config dicts into a default model (alternatively, we + # could automatically augment the user provided dict by the default config + # dataclass object only(!) for default modules). 
+ if dataclasses.is_dataclass(model_config_dict): + model_config_dict = dataclasses.asdict(model_config_dict) + default_config = dataclasses.asdict(DefaultModelConfig()) + # end: TODO + self.observation_space = observation_space self.action_space = action_space - # TODO (Artur): Make model defaults a dataclass - self._model_config_dict = {**MODEL_DEFAULTS, **model_config_dict} + self._model_config_dict = default_config | model_config_dict self._latent_dims = None self._determine_components_hook() diff --git a/rllib/core/rl_module/apis/inference_only_api.py b/rllib/core/rl_module/apis/inference_only_api.py index 7a6c689354f67..2b9db54212a4e 100644 --- a/rllib/core/rl_module/apis/inference_only_api.py +++ b/rllib/core/rl_module/apis/inference_only_api.py @@ -7,9 +7,9 @@ class InferenceOnlyAPI(abc.ABC): Only the `get_non_inference_attributes` method needs to get implemented for an RLModule to have the following functionality: - - On EnvRunners (or when self.config.inference_only=True), RLlib will remove + - On EnvRunners (or when self.inference_only=True), RLlib will remove those parts of the model not required for action computation. - - An RLModule on a Learner (where `self.config.inference_only=False`) will + - An RLModule on a Learner (where `self.inference_only=False`) will return only those weights from `get_state()` that are part of its inference-only version, thus possibly saving network traffic/time. """ diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index f755c5efb1b49..a447d084533b9 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -1,5 +1,5 @@ import copy -import dataclasses +from dataclasses import dataclass, field import logging import pprint from typing import ( @@ -18,6 +18,8 @@ ValuesView, ) +import gymnasium as gym + from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec @@ -137,7 +139,6 @@ def __init__( def setup(self): """Sets up the underlying, individual RLModules.""" self._rl_modules = {} - self._check_module_configs(self.config.modules) # Make sure all individual RLModules have the same framework OR framework=None. framework = None for module_id, rl_module_spec in self.rl_module_specs.items(): @@ -287,19 +288,6 @@ def add_module( # has `inference_only=False`. if not module.inference_only: self.inference_only = False - - # Check framework of incoming RLModule against `self.framework`. - if module.framework is not None: - if self.framework is None: - self.framework = module.framework - elif module.framework != self.framework: - raise ValueError( - f"Framework ({module.framework}) of incoming RLModule does NOT " - f"match framework ({self.framework}) of MultiRLModule! If the " - f"added module should not be trained, try setting its framework " - f"to None." - ) - self._rl_modules[module_id] = module # Update our RLModuleSpecs dict, such that - if written to disk - # it'll allow for proper restoring this instance through `.from_checkpoint()`. @@ -460,12 +448,12 @@ def set_state(self, state: StateDict) -> None: ) # Go through all of our current modules and check, whether they are listed # in the given MultiRLModuleSpec. If not, erase them from `self`. 
- for module_id, module in self._rl_modules.items(): - if module_id not in multi_rl_module_spec.module_specs: + for module_id, module in self._rl_modules.copy().items(): + if module_id not in multi_rl_module_spec.rl_module_specs: self.remove_module(module_id, raise_err_if_not_found=True) # Go through all the modules in the given MultiRLModuleSpec and if # they are not present in `self`, add them. - for module_id, module_spec in multi_rl_module_spec.module_specs.items(): + for module_id, module_spec in multi_rl_module_spec.rl_module_specs.items(): if module_id not in self: self.add_module(module_id, module_spec.build(), override=False) @@ -542,6 +530,20 @@ def _check_module_configs(cls, module_configs: Dict[ModuleID, Any]): if not isinstance(module_spec, RLModuleSpec): raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") + @classmethod + def _check_module_specs(cls, rl_module_specs: Dict[ModuleID, RLModuleSpec]): + """Checks the individual RLModuleSpecs for validity. + + Args: + rl_module_specs: Dict mapping ModuleIDs to the respective RLModuleSpec. + + Raises: + ValueError: If any RLModuleSpec is invalid. + """ + for module_id, rl_module_spec in rl_module_specs.items(): + if not isinstance(rl_module_spec, RLModuleSpec): + raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") + def _check_module_exists(self, module_id: ModuleID) -> None: if module_id not in self._rl_modules: raise KeyError( @@ -664,11 +666,7 @@ def build(self, module_id: Optional[ModuleID] = None) -> RLModule: observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=( - dataclasses.asdict(self.model_config) - if dataclasses.is_dataclass(self.model_config) - else self.model_config - ), + model_config=self.model_config, rl_module_specs=self.rl_module_specs, ) # Older custom model might still require the old `MultiRLModuleConfig` under @@ -861,7 +859,7 @@ def get_rl_module_config(self): "module2: [RLModuleSpec], ..}, inference_only=..)", error=False, ) -@dataclasses.dataclass +@dataclass class MultiRLModuleConfig: inference_only: bool = False modules: Dict[ModuleID, RLModuleSpec] = dataclasses.field(default_factory=dict) diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 7724f83af373d..f1fb5b337cc54 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -8,14 +8,18 @@ from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.core.models.specs.typing import SpecType +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.models.distributions import Distribution from ray.rllib.utils.annotations import ( - ExperimentalAPI, override, OverrideToImplementCustomLogic, ) from ray.rllib.utils.checkpoints import Checkpointable -from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.deprecation import ( + Deprecated, + DEPRECATED_VALUE, + deprecation_warning, +) from ray.rllib.utils.serialization import ( gym_space_from_dict, gym_space_to_dict, @@ -94,7 +98,7 @@ def build(self) -> "RLModule": observation_space=self.observation_space, action_space=self.action_space, inference_only=self.inference_only, - model_config=self._get_model_config(), + model_config=self.model_config, catalog_class=self.catalog_class, ) # Older custom model might still require the old `RLModuleConfig` under @@ -232,82 +236,18 @@ def _get_model_config(self): else (self.model_config or {}) ) - 
-@ExperimentalAPI -@dataclass -class RLModuleConfig: - """A utility config class to make it constructing RLModules easier. - - Args: - observation_space: The observation space of the RLModule. This may differ - from the observation space of the environment. For example, a discrete - observation space of an environment, would usually correspond to a - one-hot encoded observation space of the RLModule because of preprocessing. - action_space: The action space of the RLModule. - inference_only: Whether the RLModule should be configured in its inference-only - state, in which those components not needed for action computing (for - example a value function or a target network) might be missing. - Note that `inference_only=True` AND `learner_only=True` is not allowed. - learner_only: Whether this RLModule should only be built on Learner workers, but - NOT on EnvRunners. Useful for RLModules inside a MultiRLModule that are only - used for training, for example a shared value function in a multi-agent - setup or a world model in a curiosity-learning setup. - Note that `inference_only=True` AND `learner_only=True` is not allowed. - model_config_dict: The model config dict to use. - catalog_class: The Catalog class to use. - """ - - observation_space: gym.Space = None - action_space: gym.Space = None - inference_only: bool = False - learner_only: bool = False - model_config_dict: Dict[str, Any] = field(default_factory=dict) - catalog_class: Type["Catalog"] = None - - def get_catalog(self) -> Optional["Catalog"]: - """Returns the catalog for this config, if a class is provided.""" - if self.catalog_class is not None: - return self.catalog_class( - observation_space=self.observation_space, - action_space=self.action_space, - model_config_dict=self.model_config_dict, - ) - return None - - def to_dict(self): - """Returns a serialized representation of the config. - - NOTE: This should be JSON-able. Users can test this by calling - json.dumps(config.to_dict()). - - """ - catalog_class_path = ( - serialize_type(self.catalog_class) if self.catalog_class else "" - ) - return { - "observation_space": gym_space_to_dict(self.observation_space), - "action_space": gym_space_to_dict(self.action_space), - "inference_only": self.inference_only, - "learner_only": self.learner_only, - "model_config_dict": self.model_config_dict, - "catalog_class_path": catalog_class_path, - } - - @classmethod - def from_dict(cls, d: Dict[str, Any]): - """Creates a config from a serialized representation.""" - catalog_class = ( - None - if d["catalog_class_path"] == "" - else deserialize_type(d["catalog_class_path"]) - ) - return cls( - observation_space=gym_space_from_dict(d["observation_space"]), - action_space=gym_space_from_dict(d["action_space"]), - inference_only=d["inference_only"], - learner_only=d["learner_only"], - model_config_dict=d["model_config_dict"], - catalog_class=catalog_class, + @Deprecated( + new="RLModule(*, observation_space=.., action_space=.., ....)", + error=False, + ) + def get_rl_module_config(self): + return RLModuleConfig( + observation_space=self.observation_space, + action_space=self.action_space, + inference_only=self.inference_only, + learner_only=self.learner_only, + model_config_dict=self._get_model_config(), + catalog_class=self.catalog_class, ) @@ -456,13 +396,50 @@ def __init__( # primitive components based on obs- and action spaces. self.catalog = None - # TODO (sven): Deprecate Catalog and replace with utility functions to create - # primitive components based on obs- and action spaces. 
- self.catalog = None - try: - self.catalog = self.config.get_catalog() - except Exception: - pass + # Deprecated + self.config = config + if self.config != DEPRECATED_VALUE: + deprecation_warning( + old="RLModule(config=[RLModuleConfig])", + new="RLModule(observation_space=.., action_space=.., inference_only=..," + " learner_only=.., model_config=..)", + error=False, + ) + self.observation_space = self.config.observation_space + self.action_space = self.config.action_space + self.inference_only = self.config.inference_only + self.learner_only = self.config.learner_only + self.model_config = self.config.model_config_dict + try: + self.catalog = self.config.get_catalog() + except Exception: + pass + else: + self.observation_space = observation_space + self.action_space = action_space + self.inference_only = inference_only + self.learner_only = learner_only + self.model_config = model_config + try: + self.catalog = catalog_class( + observation_space=self.observation_space, + action_space=self.action_space, + model_config_dict=self.model_config, + ) + except Exception: + pass + + # TODO (sven): Deprecate this. We keep it here for now in case users + # still have custom models (or subclasses of RLlib default models) + # into which they pass in a `config` argument. + self.config = RLModuleConfig( + observation_space=self.observation_space, + action_space=self.action_space, + inference_only=self.inference_only, + learner_only=self.learner_only, + model_config_dict=self.model_config, + catalog_class=catalog_class, + ) self.action_dist_cls = None if self.catalog is not None: @@ -790,3 +767,57 @@ def input_specs_train(self) -> SpecType: def _default_input_specs(self) -> SpecType: """Returns the default input specs.""" return [Columns.OBS] + + +@Deprecated( + old="RLModule(config=[RLModuleConfig object])", + new="RLModule(observation_space=.., action_space=.., inference_only=.., " + "model_config=.., catalog_class=..)", + error=False, +) +@dataclass +class RLModuleConfig: + observation_space: gym.Space = None + action_space: gym.Space = None + inference_only: bool = False + learner_only: bool = False + model_config_dict: Dict[str, Any] = field(default_factory=dict) + catalog_class: Type["Catalog"] = None + + def get_catalog(self) -> Optional["Catalog"]: + if self.catalog_class is not None: + return self.catalog_class( + observation_space=self.observation_space, + action_space=self.action_space, + model_config_dict=self.model_config_dict, + ) + return None + + def to_dict(self): + catalog_class_path = ( + serialize_type(self.catalog_class) if self.catalog_class else "" + ) + return { + "observation_space": gym_space_to_dict(self.observation_space), + "action_space": gym_space_to_dict(self.action_space), + "inference_only": self.inference_only, + "learner_only": self.learner_only, + "model_config_dict": self.model_config_dict, + "catalog_class_path": catalog_class_path, + } + + @classmethod + def from_dict(cls, d: Dict[str, Any]): + catalog_class = ( + None + if d["catalog_class_path"] == "" + else deserialize_type(d["catalog_class_path"]) + ) + return cls( + observation_space=gym_space_from_dict(d["observation_space"]), + action_space=gym_space_from_dict(d["action_space"]), + inference_only=d["inference_only"], + learner_only=d["learner_only"], + model_config_dict=d["model_config_dict"], + catalog_class=catalog_class, + ) diff --git a/rllib/core/rl_module/tests/test_multi_rl_module.py b/rllib/core/rl_module/tests/test_multi_rl_module.py index 02aeab9d8901e..800d36061f287 100644 --- 
a/rllib/core/rl_module/tests/test_multi_rl_module.py +++ b/rllib/core/rl_module/tests/test_multi_rl_module.py @@ -2,9 +2,9 @@ import unittest from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC, DEFAULT_MODULE_ID -from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModuleConfig -from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig -from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule +from ray.rllib.examples.rl_modules.classes.vpg_rlm import VPGTorchRLModule from ray.rllib.env.multi_agent_env import make_multi_agent from ray.rllib.utils.test_utils import check @@ -15,17 +15,17 @@ def test_from_config(self): env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) module1 = RLModuleSpec( - module_class=DiscreteBCTorchModule, + module_class=VPGTorchRLModule, observation_space=env.get_observation_space(0), action_space=env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, + model_config={"hidden_dim": 32}, ) module2 = RLModuleSpec( - module_class=DiscreteBCTorchModule, + module_class=VPGTorchRLModule, observation_space=env.get_observation_space(0), action_space=env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, + model_config={"hidden_dim": 32}, ) multi_rl_module = MultiRLModule( @@ -41,12 +41,10 @@ def test_as_multi_rl_module(self): env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - multi_rl_module = DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + multi_rl_module = VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ).as_multi_rl_module() self.assertNotIsInstance(multi_rl_module, VPGTorchRLModule) @@ -62,12 +60,10 @@ def test_get_state_and_set_state(self): env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - module = DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + module = VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ).as_multi_rl_module() state = module.get_state() @@ -81,12 +77,10 @@ def test_get_state_and_set_state(self): set(module[DEFAULT_MODULE_ID].get_state().keys()), ) - module2 = DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + module2 = VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ).as_multi_rl_module() state2 = module2.get_state() check(state[DEFAULT_MODULE_ID], state2[DEFAULT_MODULE_ID], false=True) @@ -101,22 +95,18 @@ def test_add_remove_modules(self): env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - module = DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + module = VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, 
).as_multi_rl_module() module.add_module( "test", - DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ), ) self.assertEqual(set(module.keys()), {DEFAULT_MODULE_ID, "test"}) @@ -128,24 +118,20 @@ def test_add_remove_modules(self): ValueError, lambda: module.add_module( DEFAULT_MODULE_ID, - DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ), ), ) module.add_module( DEFAULT_MODULE_ID, - DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ), override=True, ) @@ -154,32 +140,26 @@ def test_save_to_path_and_from_checkpoint(self): """Test saving and loading from checkpoint after adding / removing modules.""" env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - module = DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + module = VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ).as_multi_rl_module() module.add_module( "test", - DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [32]}, - ) + VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 32}, ), ) module.add_module( "test2", - DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [128]}, - ) + VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 128}, ), ) @@ -205,12 +185,10 @@ def test_save_to_path_and_from_checkpoint(self): # Check that - after adding a new module - the checkpoint is correct. module.add_module( "test3", - DiscreteBCTorchModule( - config=RLModuleConfig( - env.get_observation_space(0), - env.get_action_space(0), - model_config_dict={"fcnet_hiddens": [120]}, - ) + VPGTorchRLModule( + observation_space=env.get_observation_space(0), + action_space=env.get_action_space(0), + model_config={"hidden_dim": 120}, ), ) # Check that - after adding a module - the checkpoint is correct. 
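For reference, a minimal sketch of the constructor-style instantiation the updated tests above now rely on (it mirrors the test code: spaces come from the multi-agent CartPole helper, and VPGTorchRLModule is the example module imported in this test file; nothing here is part of the patch itself):

# Sketch based on the updated test_multi_rl_module.py above: build an RLModule
# directly via keyword args (no RLModuleConfig object) and wrap it into a
# MultiRLModule.
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.examples.rl_modules.classes.vpg_rlm import VPGTorchRLModule

env = make_multi_agent("CartPole-v0")({"num_agents": 2})

# New-style constructor: observation/action spaces and `model_config` are plain
# keyword arguments.
module = VPGTorchRLModule(
    observation_space=env.get_observation_space(0),
    action_space=env.get_action_space(0),
    model_config={"hidden_dim": 32},
)

# Wrap the single module into a MultiRLModule and add a second module under a
# new ModuleID, as the add/remove tests do.
multi_rl_module = module.as_multi_rl_module()
multi_rl_module.add_module(
    "test",
    VPGTorchRLModule(
        observation_space=env.get_observation_space(0),
        action_space=env.get_action_space(0),
        model_config={"hidden_dim": 32},
    ),
)

The same keyword pattern applies to RLModuleSpec, which now takes `model_config=...` (a DefaultModelConfig or a plain dict for custom modules) instead of the old `model_config_dict=...`.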
diff --git a/rllib/core/rl_module/tf/tf_rl_module.py b/rllib/core/rl_module/tf/tf_rl_module.py index 5b6ea84dbb3a4..144ba00953e62 100644 --- a/rllib/core/rl_module/tf/tf_rl_module.py +++ b/rllib/core/rl_module/tf/tf_rl_module.py @@ -64,14 +64,14 @@ def set_state(self, state: StateDict) -> None: def get_inference_action_dist_cls(self) -> Type[TfDistribution]: if self.action_dist_cls is not None: return self.action_dist_cls - elif isinstance(self.config.action_space, gym.spaces.Discrete): + elif isinstance(self.action_space, gym.spaces.Discrete): return TfCategorical - elif isinstance(self.config.action_space, gym.spaces.Box): + elif isinstance(self.action_space, gym.spaces.Box): return TfDiagGaussian else: raise ValueError( f"Default action distribution for action space " - f"{self.config.action_space} not supported! Either set the " + f"{self.action_space} not supported! Either set the " f"`self.action_dist_cls` property in your RLModule's `setup()` method " f"to a subclass of `ray.rllib.models.tf.tf_distributions." f"TfDistribution` or - if you need different distributions for " diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index d9dfc9fdc6ed2..84f0f7c85275c 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -49,10 +49,10 @@ def __init__(self, *args, **kwargs) -> None: nn.Module.__init__(self) RLModule.__init__(self, *args, **kwargs) - # If an inference-only class AND self.config.inference_only is True, + # If an inference-only class AND self.inference_only is True, # remove all attributes that are returned by # `self.get_non_inference_attributes()`. - if self.config.inference_only and isinstance(self, InferenceOnlyAPI): + if self.inference_only and isinstance(self, InferenceOnlyAPI): for attr in self.get_non_inference_attributes(): parts = attr.split(".") if not hasattr(self, parts[0]): @@ -109,7 +109,7 @@ def get_state( # InferenceOnlyAPI). if ( inference_only - and not self.config.inference_only + and not self.inference_only and isinstance(self, InferenceOnlyAPI) ): attr = self.get_non_inference_attributes() @@ -135,14 +135,14 @@ def set_state(self, state: StateDict) -> None: def get_inference_action_dist_cls(self) -> Type[TorchDistribution]: if self.action_dist_cls is not None: return self.action_dist_cls - elif isinstance(self.config.action_space, gym.spaces.Discrete): + elif isinstance(self.action_space, gym.spaces.Discrete): return TorchCategorical - elif isinstance(self.config.action_space, gym.spaces.Box): + elif isinstance(self.action_space, gym.spaces.Box): return TorchDiagGaussian else: raise ValueError( f"Default action distribution for action space " - f"{self.config.action_space} not supported! Either set the " + f"{self.action_space} not supported! Either set the " f"`self.action_dist_cls` property in your RLModule's `setup()` method " f"to a subclass of `ray.rllib.models.torch.torch_distributions." 
f"TorchDistribution` or - if you need different distributions for " diff --git a/rllib/examples/connectors/flatten_observations_dict_space.py b/rllib/examples/connectors/flatten_observations_dict_space.py index b9f1f48977a1a..564df75c6b9d7 100644 --- a/rllib/examples/connectors/flatten_observations_dict_space.py +++ b/rllib/examples/connectors/flatten_observations_dict_space.py @@ -121,11 +121,11 @@ def _env_to_module_pipeline(env): lr=0.0003, ) .rl_module( - model_config_dict={ - "fcnet_hiddens": [32], - "fcnet_activation": "linear", - "vf_share_layers": True, - }, + model_config=DefaultModelConfig( + fcnet_hiddens=[32], + fcnet_activation="linear", + vf_share_layers=True, + ), ) ) @@ -149,11 +149,6 @@ def _env_to_module_pipeline(env): vf_loss_coeff=0.05, entropy_coeff=0.0, ) - base_config.rl_module( - model_config_dict={ - "vf_share_layers": True, - } - ) # Run everything as configured. run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/connectors/frame_stacking.py b/rllib/examples/connectors/frame_stacking.py index dbbc04866e27e..554bd1c8f20d3 100644 --- a/rllib/examples/connectors/frame_stacking.py +++ b/rllib/examples/connectors/frame_stacking.py @@ -197,14 +197,12 @@ def _env_creator(cfg): grad_clip_by="global_norm", ) .rl_module( - model_config_dict=dict( - { - "vf_share_layers": True, - "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], - "conv_activation": "relu", - "post_fcnet_hiddens": [256], - }, - ) + model_config=DefaultModelConfig( + vf_share_layers=True, + conv_filters=[(16, 4, 2), (32, 4, 2), (64, 4, 2), (128, 4, 2)], + conv_activation="relu", + head_fcnet_hiddens=[256], + ), ) ) diff --git a/rllib/examples/connectors/mean_std_filtering.py b/rllib/examples/connectors/mean_std_filtering.py index 6867f3b715725..e4511bdb888e5 100644 --- a/rllib/examples/connectors/mean_std_filtering.py +++ b/rllib/examples/connectors/mean_std_filtering.py @@ -151,12 +151,12 @@ def observation(self, observation): vf_loss_coeff=0.01, ) .rl_module( - model_config_dict={ - "fcnet_activation": "relu", - "fcnet_weights_initializer": torch.nn.init.xavier_uniform_, - "fcnet_bias_initializer": torch.nn.init.constant_, - "fcnet_bias_initializer_config": {"val": 0.0}, - } + model_config=DefaultModelConfig( + fcnet_activation="relu", + fcnet_kernel_initializer=torch.nn.init.xavier_uniform_, + fcnet_bias_initializer=torch.nn.init.constant_, + fcnet_bias_initializer_kwargs={"val": 0.0}, + ), ) # In case you would like to run with a evaluation EnvRunners, make sure your # `evaluation_config` key contains the `use_worker_filter_stats=False` setting diff --git a/rllib/examples/connectors/prev_actions_prev_rewards.py b/rllib/examples/connectors/prev_actions_prev_rewards.py index ee12bad84a7d5..1fa1e6681b90d 100644 --- a/rllib/examples/connectors/prev_actions_prev_rewards.py +++ b/rllib/examples/connectors/prev_actions_prev_rewards.py @@ -141,15 +141,16 @@ def _env_to_module(env): vf_loss_coeff=0.01, ) .rl_module( - model_config_dict={ - "use_lstm": True, - "max_seq_len": 20, - "fcnet_hiddens": [32], - "fcnet_activation": "linear", - "vf_share_layers": True, - "fcnet_weights_initializer": nn.init.xavier_uniform_, - "fcnet_bias_initializer": functools.partial(nn.init.constant_, 0.0), - } + model_config=DefaultModelConfig( + use_lstm=True, + max_seq_len=20, + fcnet_hiddens=[32], + fcnet_activation="linear", + fcnet_kernel_initializer=nn.init.xavier_uniform_, + fcnet_bias_initializer=nn.init.constant_, + fcnet_bias_initializer_kwargs={"val": 0.0}, + 
vf_share_layers=True, + ), ) ) diff --git a/rllib/examples/envs/env_rendering_and_recording.py b/rllib/examples/envs/env_rendering_and_recording.py index 4e8c5254cc378..bf5b2d69b844f 100644 --- a/rllib/examples/envs/env_rendering_and_recording.py +++ b/rllib/examples/envs/env_rendering_and_recording.py @@ -275,12 +275,12 @@ def _env_creator(cfg): if base_config.is_atari: base_config.rl_module( - model_config_dict={ - "vf_share_layers": True, - "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], - "conv_activation": "relu", - "post_fcnet_hiddens": [256], - }, + model_config=DefaultModelConfig( + conv_filters=[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], + conv_activation="relu", + head_fcnet_hiddens=[256], + vf_share_layers=True, + ), ) run_rllib_example_script_experiment(base_config, args) diff --git a/rllib/examples/learners/custom_loss_fn_simple.py b/rllib/examples/learners/custom_loss_fn_simple.py index 1ecdcbea3ae2b..9877fa10cddf0 100644 --- a/rllib/examples/learners/custom_loss_fn_simple.py +++ b/rllib/examples/learners/custom_loss_fn_simple.py @@ -135,9 +135,7 @@ class for details on how to override the main (PPO) loss function. lr=args.lr, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - }, + model_config=DefaultModelConfig(vf_share_layers=True), ) ) diff --git a/rllib/examples/learners/separate_vf_lr_and_optimizer.py b/rllib/examples/learners/separate_vf_lr_and_optimizer.py index 2c9eb96fc9521..1e5359f1162b9 100644 --- a/rllib/examples/learners/separate_vf_lr_and_optimizer.py +++ b/rllib/examples/learners/separate_vf_lr_and_optimizer.py @@ -126,11 +126,9 @@ class for details on how to override the main (torch) `configure_optimizers_for_ lr=args.lr_policy, ) .rl_module( - model_config_dict={ - # Another very important setting is this here. Make sure you use - # completely separate NNs for policy and value-functions. - "vf_share_layers": False, - }, + # Another very important setting is this here. Make sure you use + # completely separate NNs for policy and value-functions. + model_config=DefaultModelConfig(vf_share_layers=False), ) ) diff --git a/rllib/examples/multi_agent/self_play_with_open_spiel.py b/rllib/examples/multi_agent/self_play_with_open_spiel.py index 815120870d68e..8f0b63dbf017d 100644 --- a/rllib/examples/multi_agent/self_play_with_open_spiel.py +++ b/rllib/examples/multi_agent/self_play_with_open_spiel.py @@ -161,9 +161,7 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): policies_to_train=["main"], ) .rl_module( - model_config_dict={ - "fcnet_hiddens": [512, 512], - }, + model_config=DefaultModelConfig(fcnet_hiddens=[512, 512]), rl_module_spec=MultiRLModuleSpec( rl_module_specs={ "main": RLModuleSpec(), diff --git a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py index 4cb6b9effe349..38d2004d6cbf6 100644 --- a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py +++ b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py @@ -74,9 +74,8 @@ def setup(self): # Assume a simple Box(1D) tensor as input shape. 
in_size = self.observation_space.shape[0] - # Get the LSTM cell size from our RLModuleConfig's (self.config) - # `model_config_dict` property: - self._lstm_cell_size = self.config.model_config_dict.get("lstm_cell_size", 256) + # Get the LSTM cell size from the `model_config` attribute: + self._lstm_cell_size = self.model_config.get("lstm_cell_size", 256) self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=True) in_size = self._lstm_cell_size @@ -94,7 +93,7 @@ def setup(self): self._fc_net = nn.Sequential(*layers) # Logits layer (no bias, no activation). - self._pi_head = nn.Linear(in_size, self.config.action_space.n) + self._pi_head = nn.Linear(in_size, self.action_space.n) # Single-node value layer. self._values = nn.Linear(in_size, 1) diff --git a/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py b/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py index c84d2277afff6..c1f5359a812eb 100644 --- a/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py +++ b/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py @@ -29,23 +29,10 @@ def setup(self): ) @override(RLModule) - def _forward_inference(self, batch): - with torch.no_grad(): - return self._common_forward(batch) - - @override(RLModule) - def _forward_exploration(self, batch): - with torch.no_grad(): - return self._common_forward(batch) - - @override(RLModule) - def _forward_train(self, batch): - return self._common_forward(batch) - - def _common_forward(self, batch): + def _forward(self, batch, **kwargs): # Features can be found in the batch under the "encoder_features" key. - features = batch["encoder_features"] - logits = self._pi_head(features) + embeddings = batch["encoder_embeddings"] + logits = self._pi_head(embeddings) return {Columns.ACTION_DIST_INPUTS: logits} @@ -69,19 +56,19 @@ class VPGTorchMultiRLModuleWithSharedEncoder(MultiRLModule): # Central/shared encoder net. SHARED_ENCODER_ID: RLModuleSpec( module_class=SharedTorchEncoder, - model_config_dict={"embedding_dim": EMBEDDING_DIM}, + model_config={"embedding_dim": EMBEDDING_DIM}, ), # Arbitrary number of policy nets (w/o encoder sub-net). "p0": RLModuleSpec( module_class=VPGTorchRLModuleUsingSharedEncoder, - model_config_dict={ + model_config={ "embedding_dim": EMBEDDING_DIM, "hidden_dim": HIDDEN_DIM, }, ), "p1": RLModuleSpec( module_class=VPGTorchRLModuleUsingSharedEncoder, - model_config_dict={ + model_config={ "embedding_dim": EMBEDDING_DIM, "hidden_dim": HIDDEN_DIM, }, @@ -109,7 +96,7 @@ def setup(self): ) @override(MultiRLModule) - def _run_forward_pass(self, forward_fn_name, batch, **kwargs): + def _forward(self, forward_fn_name, batch, **kwargs): outputs = {} encoder_forward_fn = getattr( self._rl_modules[SHARED_ENCODER_ID], forward_fn_name @@ -122,9 +109,9 @@ def _run_forward_pass(self, forward_fn_name, batch, **kwargs): # Pass policy's observations through shared encoder to get the features for # this policy. - features = encoder_forward_fn(batch[policy_id]) + embeddings = encoder_forward_fn(batch[policy_id]) # Pass the policy's features through the policy net. 
- batch[policy_id]["encoder_features"] = features + batch[policy_id]["encoder_embeddings"] = embeddings outputs[policy_id] = forward_fn(batch[policy_id], **kwargs) return outputs @@ -144,21 +131,7 @@ def setup(self): nn.Linear(input_dim, embedding_dim), ) - @override(RLModule) - def _forward_inference(self, batch): - with torch.no_grad(): - return self._common_forward(batch) - - @override(RLModule) - def _forward_exploration(self, batch): - with torch.no_grad(): - return self._common_forward(batch) - - @override(RLModule) - def _forward_train(self, batch): - return self._common_forward(batch) - - def _common_forward(self, batch): + def _forward(self, batch, **kwargs): # Pass observations through the encoder and return outputs. - features = self._encoder(batch[Columns.OBS]) - return {"encoder_features": features} + embeddings = self._encoder(batch[Columns.OBS]) + return {"encoder_embeddings": embeddings} diff --git a/rllib/models/catalog.py b/rllib/models/catalog.py index 0a227dfc27da0..bff7ac243e621 100644 --- a/rllib/models/catalog.py +++ b/rllib/models/catalog.py @@ -73,15 +73,7 @@ "post_fcnet_bias_initializer": None, "post_fcnet_bias_initializer_config": None, "free_log_std": False, - # Whether to clip the log standard deviation when using a Gaussian (or any - # other continuous control distribution). This can stabilize training and avoid - # very small or large log standard deviations leading to numerical instabilities - # which can turn network outputs to `nan`. The default is to clamp the log std - # in between -20 and 20. "log_std_clip_param": 20.0, - # Whether to skip the final linear layer used to resize the hidden layer - # outputs to size `num_outputs`. If True, then the last hidden layer - # should already match num_outputs. "no_final_linear": False, "vf_share_layers": True, "use_lstm": False, diff --git a/rllib/tuned_examples/appo/cartpole_appo.py b/rllib/tuned_examples/appo/cartpole_appo.py index e1256524bb5b6..06ffd7dc77f1e 100644 --- a/rllib/tuned_examples/appo/cartpole_appo.py +++ b/rllib/tuned_examples/appo/cartpole_appo.py @@ -1,4 +1,5 @@ from ray.rllib.algorithms.appo import APPOConfig +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.utils.test_utils import add_rllib_example_script_args parser = add_rllib_example_script_args( @@ -24,9 +25,7 @@ entropy_coeff=0.0, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - }, + model_config=DefaultModelConfig(vf_share_layers=True), ) ) diff --git a/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py b/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py index bceed3d2068fa..cb6f9be28bb21 100644 --- a/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py +++ b/rllib/tuned_examples/appo/multi_agent_cartpole_appo.py @@ -34,9 +34,7 @@ entropy_coeff=0.0, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - }, + model_config=DefaultModelConfig(vf_share_layers=True), ) .multi_agent( policy_mapping_fn=(lambda agent_id, episode, **kwargs: f"p{agent_id}"), diff --git a/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py b/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py index 7f88d6164c5c5..4437d05730521 100644 --- a/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py +++ b/rllib/tuned_examples/appo/multi_agent_stateless_cartpole_appo.py @@ -42,10 +42,10 @@ grad_clip=20.0, ) .rl_module( - model_config_dict={ - "use_lstm": True, - "max_seq_len": 20, - }, + model_config=DefaultModelConfig( + use_lstm=True, + max_seq_len=20, + 
), ) .multi_agent( policy_mapping_fn=(lambda agent_id, episode, **kwargs: f"p{agent_id}"), diff --git a/rllib/tuned_examples/appo/stateless_cartpole_appo.py b/rllib/tuned_examples/appo/stateless_cartpole_appo.py index 6f2b0265b8123..43df2f3ff302c 100644 --- a/rllib/tuned_examples/appo/stateless_cartpole_appo.py +++ b/rllib/tuned_examples/appo/stateless_cartpole_appo.py @@ -35,11 +35,11 @@ grad_clip=20.0, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - "use_lstm": True, - "max_seq_len": 20, - }, + model_config=DefaultModelConfig( + vf_share_layers=True, + use_lstm=True, + max_seq_len=20, + ), ) ) diff --git a/rllib/tuned_examples/bc/benchmark_atari_pong_bc.py b/rllib/tuned_examples/bc/benchmark_atari_pong_bc.py index 26174a074ceb8..f5d7727bb68a5 100644 --- a/rllib/tuned_examples/bc/benchmark_atari_pong_bc.py +++ b/rllib/tuned_examples/bc/benchmark_atari_pong_bc.py @@ -272,12 +272,12 @@ def _env_creator(cfg): learner_connector=_make_learner_connector, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], - "conv_activation": "relu", - "post_fcnet_hiddens": [256], - } + model_config=DefaultModelConfig( + vf_share_layers=True, + conv_filters=[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], + conv_activation="relu", + post_fcnet_hiddens=[256], + ), ) ) diff --git a/rllib/tuned_examples/bc/cartpole_bc.py b/rllib/tuned_examples/bc/cartpole_bc.py index 8df1efe701613..bae72495fcbe5 100644 --- a/rllib/tuned_examples/bc/cartpole_bc.py +++ b/rllib/tuned_examples/bc/cartpole_bc.py @@ -66,7 +66,7 @@ # The number of iterations to be run per learner when in multi-learner # mode in a single RLlib training iteration. Leave this to `None` to # run an entire epoch on the dataset during a single RLlib training - # iteration. For single-learner mode 1 is the only option. + # iteration. For single-learner mode, 1 is the only option. dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None, ) .training( @@ -74,7 +74,11 @@ # To increase learning speed with multiple learners, # increase the learning rate correspondingly. 
lr=0.0008 * max(1, args.num_gpus**0.5), - train_batch_size_per_learner=1024, + ) + .rl_module( + model_config=DefaultModelConfig( + fcnet_hiddens=[256, 256], + ), ) ) diff --git a/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py b/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py index 1cbe44a92f091..eabd261ceac4b 100644 --- a/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py +++ b/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py @@ -31,7 +31,7 @@ ) .environment(env="multi_agent_cartpole", env_config={"num_agents": args.num_agents}) .training( - lr=0.0005 * (args.num_gpus or 1) ** 0.5, + lr=0.00065 * (args.num_gpus or 1) ** 0.5, train_batch_size_per_learner=48, replay_buffer_config={ "type": "MultiAgentPrioritizedEpisodeReplayBuffer", @@ -47,14 +47,13 @@ epsilon=[(0, 1.0), (20000, 0.02)], ) .rl_module( - model_config_dict={ - "fcnet_hiddens": [256, 256], - "fcnet_activation": "tanh", - "epsilon": [(0, 1.0), (20000, 0.02)], - "fcnet_bias_initializer": "zeros_", - "post_fcnet_bias_initializer": "zeros_", - "post_fcnet_hiddens": [256], - }, + model_config=DefaultModelConfig( + fcnet_hiddens=[256, 256], + fcnet_activation="tanh", + fcnet_bias_initializer="zeros_", + head_fcnet_bias_initializer="zeros_", + head_fcnet_hiddens=[256], + ), ) ) diff --git a/rllib/tuned_examples/impala/cartpole_impala.py b/rllib/tuned_examples/impala/cartpole_impala.py index 4e78a0f55fa17..00373e986ad09 100644 --- a/rllib/tuned_examples/impala/cartpole_impala.py +++ b/rllib/tuned_examples/impala/cartpole_impala.py @@ -31,9 +31,9 @@ entropy_coeff=0.0, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - }, + model_config=DefaultModelConfig( + vf_share_layers=True, + ), ) ) diff --git a/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py b/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py index 6a00e2f299e80..932cc0ca16080 100644 --- a/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py +++ b/rllib/tuned_examples/impala/multi_agent_cartpole_impala.py @@ -38,9 +38,9 @@ entropy_coeff=0.0, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - }, + model_config=DefaultModelConfig( + vf_share_layers=True, + ), ) .multi_agent( policy_mapping_fn=(lambda agent_id, episode, **kwargs: f"p{agent_id}"), diff --git a/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py b/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py index 6a3a793edefce..63f26bf8a9203 100644 --- a/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py +++ b/rllib/tuned_examples/impala/multi_agent_stateless_cartpole_impala.py @@ -44,10 +44,10 @@ grad_clip=20.0, ) .rl_module( - model_config_dict={ - "use_lstm": True, - "max_seq_len": 20, - }, + model_config=DefaultModelConfig( + use_lstm=True, + max_seq_len=20, + ), ) .multi_agent( policy_mapping_fn=(lambda agent_id, episode, **kwargs: f"p{agent_id}"), diff --git a/rllib/tuned_examples/impala/pendulum_impala.py b/rllib/tuned_examples/impala/pendulum_impala.py index 9373908214cd7..3f9ecad3cf0c4 100644 --- a/rllib/tuned_examples/impala/pendulum_impala.py +++ b/rllib/tuned_examples/impala/pendulum_impala.py @@ -31,10 +31,10 @@ entropy_coeff=[[0, 0.1], [2000000, 0.0]], ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - "fcnet_hiddens": [512, 512], - }, + model_config=DefaultModelConfig( + vf_share_layers=True, + fcnet_hiddens=[512, 512], + ), ) ) diff --git a/rllib/tuned_examples/impala/stateless_cartpole_impala.py b/rllib/tuned_examples/impala/stateless_cartpole_impala.py index 
2e4838c15168d..1c0376de55c5f 100644 --- a/rllib/tuned_examples/impala/stateless_cartpole_impala.py +++ b/rllib/tuned_examples/impala/stateless_cartpole_impala.py @@ -35,11 +35,11 @@ entropy_coeff=0.0, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - "use_lstm": True, - "max_seq_len": 20, - }, + model_config=DefaultModelConfig( + vf_share_layers=True, + use_lstm=True, + max_seq_len=20, + ), ) ) diff --git a/rllib/tuned_examples/ppo/atari_ppo.py b/rllib/tuned_examples/ppo/atari_ppo.py index 7cc28d441d2f5..7abcfdff245ef 100644 --- a/rllib/tuned_examples/ppo/atari_ppo.py +++ b/rllib/tuned_examples/ppo/atari_ppo.py @@ -72,12 +72,12 @@ def _env_creator(cfg): grad_clip_by="global_norm", ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - "conv_filters": [[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], - "conv_activation": "relu", - "post_fcnet_hiddens": [256], - } + model_config=DefaultModelConfig( + conv_filters=[[16, 4, 2], [32, 4, 2], [64, 4, 2], [128, 4, 2]], + conv_activation="relu", + head_fcnet_hiddens=[256], + vf_share_layers=True, + ), ) ) diff --git a/rllib/tuned_examples/ppo/cartpole_ppo.py b/rllib/tuned_examples/ppo/cartpole_ppo.py index 18f11b9d8ffa4..de33650280b0e 100644 --- a/rllib/tuned_examples/ppo/cartpole_ppo.py +++ b/rllib/tuned_examples/ppo/cartpole_ppo.py @@ -1,4 +1,5 @@ from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.utils.test_utils import add_rllib_example_script_args parser = add_rllib_example_script_args(default_reward=450.0, default_timesteps=300000) @@ -16,11 +17,11 @@ vf_loss_coeff=0.01, ) .rl_module( - model_config_dict={ - "fcnet_hiddens": [32], - "fcnet_activation": "linear", - "vf_share_layers": True, - } + model_config=DefaultModelConfig( + fcnet_hiddens=[32], + fcnet_activation="linear", + vf_share_layers=True, + ), ) ) diff --git a/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py b/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py index 3b67ca6153e0d..9ad40c4c2b479 100644 --- a/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py +++ b/rllib/tuned_examples/ppo/multi_agent_pendulum_ppo.py @@ -1,5 +1,6 @@ from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.connectors.env_to_module import MeanStdFilter +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, @@ -38,9 +39,7 @@ lambda_=0.5, ) .rl_module( - model_config_dict={ - "fcnet_activation": "relu", - }, + model_config=DefaultModelConfig(fcnet_activation="relu"), ) .multi_agent( policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}", diff --git a/rllib/tuned_examples/ppo/multi_agent_stateless_cartpole_ppo.py b/rllib/tuned_examples/ppo/multi_agent_stateless_cartpole_ppo.py index 00e449e4a63e0..d700cb7ab0c8d 100644 --- a/rllib/tuned_examples/ppo/multi_agent_stateless_cartpole_ppo.py +++ b/rllib/tuned_examples/ppo/multi_agent_stateless_cartpole_ppo.py @@ -42,10 +42,10 @@ vf_loss_coeff=0.05, ) .rl_module( - model_config_dict={ - "use_lstm": True, - "max_seq_len": 20, - }, + model_config=DefaultModelConfig( + use_lstm=True, + max_seq_len=20, + ), ) .multi_agent( policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}", diff --git a/rllib/tuned_examples/ppo/pendulum_ppo.py b/rllib/tuned_examples/ppo/pendulum_ppo.py index ef0e3ce32acdb..d381b529f0fce 100644 --- a/rllib/tuned_examples/ppo/pendulum_ppo.py +++ 
b/rllib/tuned_examples/ppo/pendulum_ppo.py @@ -1,5 +1,6 @@ from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.connectors.env_to_module import MeanStdFilter +from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.utils.test_utils import add_rllib_example_script_args parser = add_rllib_example_script_args(default_timesteps=400000, default_reward=-300) @@ -25,9 +26,7 @@ # num_epochs=8, ) .rl_module( - model_config_dict={ - "fcnet_activation": "relu", - }, + model_config=DefaultModelConfig(fcnet_activation="relu"), ) ) diff --git a/rllib/tuned_examples/ppo/stateless_cartpole_ppo.py b/rllib/tuned_examples/ppo/stateless_cartpole_ppo.py index ef49776f31dc0..36880cc9b8610 100644 --- a/rllib/tuned_examples/ppo/stateless_cartpole_ppo.py +++ b/rllib/tuned_examples/ppo/stateless_cartpole_ppo.py @@ -29,11 +29,11 @@ vf_loss_coeff=0.05, ) .rl_module( - model_config_dict={ - "vf_share_layers": True, - "use_lstm": True, - "max_seq_len": 20, - }, + model_config=DefaultModelConfig( + vf_share_layers=True, + use_lstm=True, + max_seq_len=20, + ), ) )
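The tuned-example changes above all follow the same migration; here is a minimal, self-contained sketch of the new `model_config` argument at the AlgorithmConfig level (a hedged illustration only: it assumes CartPole as a stand-in environment and explicitly enables the new API stack, matching the settings used elsewhere in this patch):

# Sketch of the new config API: `AlgorithmConfig.rl_module()` now takes
# `model_config` (a DefaultModelConfig dataclass, or a plain dict for custom
# RLModules) instead of the old `model_config_dict`.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("CartPole-v1")
    .rl_module(
        model_config=DefaultModelConfig(
            fcnet_hiddens=[32],
            fcnet_activation="linear",
            vf_share_layers=True,
        ),
    )
)

algo = config.build()
print(algo.train())

For per-module overrides, the same DefaultModelConfig object can be passed to an individual RLModuleSpec via its `model_config` argument, as shown in the restore tests earlier in this diff.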