diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index bce58cd55c3e..d53815989e09 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -14,7 +14,7 @@ ) from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY -from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI +from ray.rllib.core.rl_module.apis import TargetNetworkAPI, ValueFunctionAPI from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import convert_to_numpy @@ -35,6 +35,10 @@ def compute_loss_for_module( batch: Dict, fwd_out: Dict[str, TensorType], ) -> TensorType: + module = self.module[module_id].unwrapped() + assert isinstance(module, TargetNetworkAPI) + assert isinstance(module, ValueFunctionAPI) + # TODO (sven): Now that we do the +1ts trick to be less vulnerable about # bootstrap values at the end of rollouts in the new stack, we might make # this a more flexible, configurable parameter for users, e.g. @@ -51,10 +55,9 @@ def compute_loss_for_module( ) size_loss_mask = torch.sum(loss_mask) - module = self.module[module_id].unwrapped() - assert isinstance(module, TargetNetworkAPI) - - values = fwd_out[Columns.VF_PREDS] + values = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) action_dist_cls_train = module.get_train_action_dist_cls() target_policy_dist = action_dist_cls_train.from_logits( diff --git a/rllib/algorithms/bc/bc_catalog.py b/rllib/algorithms/bc/bc_catalog.py index 6f4a1f8468fa..a1b5e8970e57 100644 --- a/rllib/algorithms/bc/bc_catalog.py +++ b/rllib/algorithms/bc/bc_catalog.py @@ -64,7 +64,7 @@ def build_pi_head(self, framework: str) -> Model: The default behavior is to build the head from the pi_head_config. This can be overridden to build a custom policy head as a means of configuring - the behavior of a BCRLModule implementation. + the behavior of a BC specific RLModule implementation. Args: framework: The framework to use. Either "torch" or "tf2". 
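Note on the recurring learner-side pattern in this patch: losses no longer read a precomputed `Columns.VF_PREDS` out of `fwd_out`; they call `module.compute_values()` and pass along any embeddings the train forward pass already produced. A minimal sketch of that call pattern, for reference only (the helper name `value_loss_from_module` and the plain unmasked mean are illustrative assumptions, not code from this PR):

import torch

from ray.rllib.core.columns import Columns


def value_loss_from_module(module, batch, fwd_out):
    """Hypothetical helper showing the new value-computation call pattern."""
    # Reuse embeddings returned by the module's `_forward_train()`, if present;
    # otherwise `compute_values()` runs its own (critic-)encoder pass internally.
    values = module.compute_values(
        batch, embeddings=fwd_out.get(Columns.EMBEDDINGS)
    )
    # Simple squared-error value loss against the value targets in the batch.
    return torch.mean(torch.pow(values - batch[Columns.VALUE_TARGETS], 2.0))
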
diff --git a/rllib/algorithms/bc/bc_rl_module.py b/rllib/algorithms/bc/bc_rl_module.py deleted file mode 100644 index 6099ffb9ea81..000000000000 --- a/rllib/algorithms/bc/bc_rl_module.py +++ /dev/null @@ -1,82 +0,0 @@ -import abc -from typing import Any, Dict, List, Union - -from ray.rllib.core.columns import Columns -from ray.rllib.core.models.base import ENCODER_OUT -from ray.rllib.core.models.specs.typing import SpecType -from ray.rllib.core.rl_module.rl_module import RLModule -from ray.rllib.utils.annotations import override -from ray.rllib.utils.typing import TensorType -from ray.util.annotations import DeveloperAPI - - -@DeveloperAPI(stability="alpha") -class BCRLModule(RLModule, abc.ABC): - @override(RLModule) - def setup(self): - # __sphinx_doc_begin__ - # Build models from catalog - self.encoder = self.catalog.build_encoder(framework=self.framework) - self.pi = self.catalog.build_pi_head(framework=self.framework) - - @override(RLModule) - def get_initial_state(self) -> Union[dict, List[TensorType]]: - if hasattr(self.encoder, "get_initial_state"): - return self.encoder.get_initial_state() - else: - return {} - - @override(RLModule) - def output_specs_inference(self) -> SpecType: - return self.output_specs_exploration() - - @override(RLModule) - def output_specs_exploration(self) -> SpecType: - return [Columns.ACTION_DIST_INPUTS] - - @override(RLModule) - def output_specs_train(self) -> SpecType: - return self.output_specs_exploration() - - @override(RLModule) - def _forward_inference(self, batch: Dict, **kwargs) -> Dict[str, Any]: - """BC forward pass during inference. - - See the `BCTorchRLModule._forward_exploration` method for - implementation details. - """ - return self._forward_exploration(batch) - - @override(RLModule) - def _forward_exploration(self, batch: Dict, **kwargs) -> Dict[str, Any]: - """BC forward pass during exploration. - - Besides the action distribution this method also returns a possible - state in case a stateful encoder is used. - - Note that for BC `_forward_train`, `_forward_exploration`, and - `_forward_inference` return the same items and therefore only - `_forward_exploration` is implemented and is used by the two other - forward methods. - """ - output = {} - - # State encodings. - encoder_outs = self.encoder(batch) - if Columns.STATE_OUT in encoder_outs: - output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] - - # Actions. - action_logits = self.pi(encoder_outs[ENCODER_OUT]) - output[Columns.ACTION_DIST_INPUTS] = action_logits - - return output - - @override(RLModule) - def _forward_train(self, batch: Dict, **kwargs) -> Dict[str, Any]: - """BC forward pass during training. - - See the `BCTorchRLModule._forward_exploration` method for - implementation details. 
- """ - return self._forward_exploration(batch) diff --git a/rllib/algorithms/bc/torch/bc_torch_rl_module.py b/rllib/algorithms/bc/torch/bc_torch_rl_module.py index 76f9d26020a6..6d328bb56819 100644 --- a/rllib/algorithms/bc/torch/bc_torch_rl_module.py +++ b/rllib/algorithms/bc/torch/bc_torch_rl_module.py @@ -1,6 +1,32 @@ -from ray.rllib.algorithms.bc.bc_rl_module import BCRLModule +from typing import Any, Dict + +from ray.rllib.core import Columns +from ray.rllib.core.models.base import ENCODER_OUT +from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule +from ray.rllib.utils.annotations import override + + +class BCTorchRLModule(TorchRLModule): + @override(RLModule) + def setup(self): + # __sphinx_doc_begin__ + # Build models from catalog + self.encoder = self.catalog.build_encoder(framework=self.framework) + self.pi = self.catalog.build_pi_head(framework=self.framework) + + @override(RLModule) + def _forward(self, batch: Dict, **kwargs) -> Dict[str, Any]: + """Generic BC forward pass (for all phases of training/evaluation).""" + output = {} + + # State encodings. + encoder_outs = self.encoder(batch) + if Columns.STATE_OUT in encoder_outs: + output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] + # Actions. + action_logits = self.pi(encoder_outs[ENCODER_OUT]) + output[Columns.ACTION_DIST_INPUTS] = action_logits -class BCTorchRLModule(TorchRLModule, BCRLModule): - pass + return output diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index c98c0947be05..256e3b48fb79 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -28,6 +28,8 @@ def compute_loss_for_module( batch: Dict, fwd_out: Dict[str, TensorType], ) -> TensorType: + module = self.module[module_id].unwrapped() + # TODO (sven): Now that we do the +1ts trick to be less vulnerable about # bootstrap values at the end of rollouts in the new stack, we might make # this a more flexible, configurable parameter for users, e.g. @@ -46,17 +48,17 @@ def compute_loss_for_module( # Behavior actions logp and target actions logp. behaviour_actions_logp = batch[Columns.ACTION_LOGP] - target_policy_dist = ( - self.module[module_id] - .unwrapped() - .get_train_action_dist_cls() - .from_logits(fwd_out[Columns.ACTION_DIST_INPUTS]) + target_policy_dist = module.get_train_action_dist_cls().from_logits( + fwd_out[Columns.ACTION_DIST_INPUTS] ) target_actions_logp = target_policy_dist.logp(batch[Columns.ACTIONS]) # Values and bootstrap values. 
+ values = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) values_time_major = make_time_major( - fwd_out[Columns.VF_PREDS], + values, trajectory_len=rollout_frag_or_episode_len, recurrent_seq_len=recurrent_seq_len, ) diff --git a/rllib/algorithms/marwil/marwil_rl_module.py b/rllib/algorithms/marwil/marwil_rl_module.py index a2a1b13dc4f4..a0e5a40db4f9 100644 --- a/rllib/algorithms/marwil/marwil_rl_module.py +++ b/rllib/algorithms/marwil/marwil_rl_module.py @@ -10,6 +10,7 @@ @DeveloperAPI(stability="alpha") class MARWILRLModule(RLModule, ValueFunctionAPI, abc.ABC): + @override(RLModule) def setup(self): # Build models from catalog self.encoder = self.catalog.build_actor_critic_encoder(framework=self.framework) @@ -47,7 +48,4 @@ def input_specs_train(self) -> SpecDict: @override(RLModule) def output_specs_train(self) -> SpecDict: - return [ - Columns.VF_PREDS, - Columns.ACTION_DIST_INPUTS, - ] + return [Columns.ACTION_DIST_INPUTS] diff --git a/rllib/algorithms/marwil/torch/marwil_torch_learner.py b/rllib/algorithms/marwil/torch/marwil_torch_learner.py index afd9bc748270..58905920655d 100644 --- a/rllib/algorithms/marwil/torch/marwil_torch_learner.py +++ b/rllib/algorithms/marwil/torch/marwil_torch_learner.py @@ -31,6 +31,7 @@ def compute_loss_for_module( batch: Dict[str, Any], fwd_out: Dict[str, TensorType] ) -> TensorType: + module = self.module[module_id].unwrapped() # Possibly apply masking to some sub loss terms and to the total loss term # at the end. Masking could be used for RNN-based model (zero padded `batch`) @@ -46,9 +47,7 @@ def possibly_masked_mean(data_): else: possibly_masked_mean = torch.mean - action_dist_class_train = ( - self.module[module_id].unwrapped().get_train_action_dist_cls() - ) + action_dist_class_train = module.get_train_action_dist_cls() curr_action_dist = action_dist_class_train.from_logits( fwd_out[Columns.ACTION_DIST_INPUTS] ) @@ -64,7 +63,9 @@ def possibly_masked_mean(data_): # Otherwise, compute advantages. else: # cumulative_rewards = batch[Columns.ADVANTAGES] - value_fn_out = fwd_out[Columns.VF_PREDS] + value_fn_out = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) advantages = batch[Columns.VALUE_TARGETS] - value_fn_out advantages_squared_mean = possibly_masked_mean(torch.pow(advantages, 2.0)) diff --git a/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py b/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py index fe774e9041f4..9f5098504046 100644 --- a/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py +++ b/rllib/algorithms/marwil/torch/marwil_torch_rl_module.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.algorithms.marwil.marwil_rl_module import MARWILRLModule from ray.rllib.core.columns import Columns @@ -42,7 +42,6 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: "flag `inference_only=False` when building the module." ) output = {} - # Shared encoder. encoder_outs = self.encoder(batch) if Columns.STATE_OUT in encoder_outs: @@ -63,18 +62,23 @@ def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: # (similar to IMPALA's v-trace architecture). This would also get rid of the # second Connector pass currently necessary. @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - # Separate vf-encoder. 
- if hasattr(self.encoder, "critic_encoder"): - if self.is_stateful(): - # The recurrent encoders expect a `(state_in, h)` key in the - # input dict while the key returned is `(state_in, critic, h)`. - batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] - encoder_outs = self.encoder.critic_encoder(batch)[ENCODER_OUT] - # Shared encoder. - else: - encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC] + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + if embeddings is None: + # Separate vf-encoder. + if hasattr(self.encoder, "critic_encoder"): + if self.is_stateful(): + # The recurrent encoders expect a `(state_in, h)` key in the + # input dict while the key returned is `(state_in, critic, h)`. + batch[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] + embeddings = self.encoder.critic_encoder(batch)[ENCODER_OUT] + # Shared encoder. + else: + embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] # Value head. - vf_out = self.vf(encoder_outs) + vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/algorithms/ppo/ppo_rl_module.py b/rllib/algorithms/ppo/ppo_rl_module.py index 85ce1ea6b90a..5c48ab7af7b6 100644 --- a/rllib/algorithms/ppo/ppo_rl_module.py +++ b/rllib/algorithms/ppo/ppo_rl_module.py @@ -72,10 +72,7 @@ def input_specs_train(self) -> SpecDict: @override(RLModule) def output_specs_train(self) -> SpecDict: - return [ - Columns.VF_PREDS, - Columns.ACTION_DIST_INPUTS, - ] + return [Columns.ACTION_DIST_INPUTS] @OverrideToImplementCustomLogic_CallToSuperRecommended @override(InferenceOnlyAPI) diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index 2473fbafa32e..a1ff80592eb9 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -26,21 +26,7 @@ torch, nn = try_import_torch() -def get_expected_module_config( - env: gym.Env, - model_config_dict: dict, - observation_space: gym.spaces.Space, -) -> RLModuleConfig: - """Get a PPOModuleConfig that we would expect from the catalog otherwise. - - Args: - env: Environment for which we build the model later - model_config_dict: Model config to use for the catalog - observation_space: Observation space to use for the catalog. - - Returns: - A PPOModuleConfig containing the relevant configs to build PPORLModule - """ +def get_expected_module_config(env, model_config_dict, observation_space): config = RLModuleConfig( observation_space=observation_space, action_space=env.action_space, @@ -52,22 +38,7 @@ def get_expected_module_config( def dummy_torch_ppo_loss(module, batch, fwd_out): - """Dummy PPO loss function for testing purposes. - - Will eventually use the actual PPO loss function implemented in PPO. - - Args: - batch: Batch used for training. - fwd_out: Forward output of the model. - - Returns: - Loss tensor - """ - # TODO: we should replace these components later with real ppo components when - # RLOptimizer and RLModule are integrated together. 
- # this is not exactly a ppo loss, just something to show that the - # forward train works - adv = batch[Columns.REWARDS] - fwd_out[Columns.VF_PREDS] + adv = batch[Columns.REWARDS] - module.compute_values(batch) action_dist_class = module.get_train_action_dist_cls() action_probs = action_dist_class.from_logits( fwd_out[Columns.ACTION_DIST_INPUTS] @@ -80,19 +51,7 @@ def dummy_torch_ppo_loss(module, batch, fwd_out): def dummy_tf_ppo_loss(module, batch, fwd_out): - """Dummy PPO loss function for testing purposes. - - Will eventually use the actual PPO loss function implemented in PPO. - - Args: - module: PPOTfRLModule - batch: Batch used for training. - fwd_out: Forward output of the model. - - Returns: - Loss tensor - """ - adv = batch[Columns.REWARDS] - fwd_out[Columns.VF_PREDS] + adv = batch[Columns.REWARDS] - module.compute_values(batch) action_dist_class = module.get_train_action_dist_cls() action_probs = action_dist_class.from_logits( fwd_out[Columns.ACTION_DIST_INPUTS] @@ -180,7 +139,7 @@ def test_rollouts(self): def test_forward_train(self): # TODO: Add FrozenLake-v1 to cover LSTM case. - frameworks = ["tf2", "torch"] + frameworks = ["torch", "tf2"] env_names = ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"] lstm = [False, True] config_combinations = [frameworks, env_names, lstm] diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index e242e2f89229..021b68e505de 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -58,7 +58,7 @@ def _forward_train(self, batch: Dict): return output @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: + def compute_values(self, batch: Dict[str, Any], embeddings=None) -> TensorType: infos = batch.pop(Columns.INFOS, None) batch = tree.map_structure(lambda s: tf.convert_to_tensor(s), batch) if infos is not None: diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index f866165e2243..8cff87e4fc2a 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -40,10 +40,12 @@ def compute_loss_for_module( batch: Dict[str, Any], fwd_out: Dict[str, TensorType], ) -> TensorType: + module = self.module[module_id].unwrapped() + # Possibly apply masking to some sub loss terms and to the total loss term # at the end. Masking could be used for RNN-based model (zero padded `batch`) # and for PPO's batched value function (and bootstrap value) computations, - # for which we add an additional (artificial) timestep to each episode to + # for which we add an (artificial) timestep to each episode to # simplify the actual computation. if Columns.LOSS_MASK in batch: mask = batch[Columns.LOSS_MASK] @@ -55,12 +57,8 @@ def possibly_masked_mean(data_): else: possibly_masked_mean = torch.mean - action_dist_class_train = ( - self.module[module_id].unwrapped().get_train_action_dist_cls() - ) - action_dist_class_exploration = ( - self.module[module_id].unwrapped().get_exploration_action_dist_cls() - ) + action_dist_class_train = module.get_train_action_dist_cls() + action_dist_class_exploration = module.get_exploration_action_dist_cls() curr_action_dist = action_dist_class_train.from_logits( fwd_out[Columns.ACTION_DIST_INPUTS] @@ -91,12 +89,14 @@ def possibly_masked_mean(data_): # Compute a value function loss. 
if config.use_critic: - value_fn_out = fwd_out[Columns.VF_PREDS] + value_fn_out = module.compute_values( + batch, embeddings=fwd_out.get(Columns.EMBEDDINGS) + ) vf_loss = torch.pow(value_fn_out - batch[Postprocessing.VALUE_TARGETS], 2.0) vf_loss_clipped = torch.clamp(vf_loss, 0, config.vf_clip_param) mean_vf_loss = possibly_masked_mean(vf_loss_clipped) mean_vf_unclipped_loss = possibly_masked_mean(vf_loss) - # Ignore the value function. + # Ignore the value function -> Set all to 0.0. else: z = torch.tensor(0.0, device=surrogate_loss.device) value_fn_out = mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = z diff --git a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py index 2f3283cf1ca0..86df016326a1 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.algorithms.ppo.ppo_rl_module import PPORLModule from ray.rllib.core.columns import Columns @@ -17,63 +17,50 @@ class PPOTorchRLModule(TorchRLModule, PPORLModule): framework: str = "torch" @override(RLModule) - def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]: + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Default forward pass (used for inference and exploration).""" output = {} - # Encoder forward pass. encoder_outs = self.encoder(batch) + # Stateful encoder? if Columns.STATE_OUT in encoder_outs: output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] - # Pi head. output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - return output @override(RLModule) - def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - return self._forward_inference(batch) - - @override(RLModule) - def _forward_train(self, batch: Dict[str, Any]) -> Dict[str, Any]: - if self.config.inference_only: - raise RuntimeError( - "Trying to train a module that is not a learner module. Set the " - "flag `inference_only=False` when building the module." - ) + def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Train forward pass (keep features for possible shared value func. call).""" output = {} - - # Shared encoder. encoder_outs = self.encoder(batch) + output[Columns.EMBEDDINGS] = encoder_outs[ENCODER_OUT][CRITIC] if Columns.STATE_OUT in encoder_outs: output[Columns.STATE_OUT] = encoder_outs[Columns.STATE_OUT] - - # Value head. - vf_out = self.vf(encoder_outs[ENCODER_OUT][CRITIC]) - # Squeeze out last dim (value function node). - output[Columns.VF_PREDS] = vf_out.squeeze(-1) - - # Policy head. - action_logits = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) - output[Columns.ACTION_DIST_INPUTS] = action_logits - + output[Columns.ACTION_DIST_INPUTS] = self.pi(encoder_outs[ENCODER_OUT][ACTOR]) return output @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - # Separate vf-encoder. - if hasattr(self.encoder, "critic_encoder"): - batch_ = batch - if self.is_stateful(): - # The recurrent encoders expect a `(state_in, h)` key in the - # input dict while the key returned is `(state_in, critic, h)`. - batch_ = batch.copy() - batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] - encoder_outs = self.encoder.critic_encoder(batch_)[ENCODER_OUT] - # Shared encoder. 
- else: - encoder_outs = self.encoder(batch)[ENCODER_OUT][CRITIC] + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + if embeddings is None: + # Separate vf-encoder. + if hasattr(self.encoder, "critic_encoder"): + batch_ = batch + if self.is_stateful(): + # The recurrent encoders expect a `(state_in, h)` key in the + # input dict while the key returned is `(state_in, critic, h)`. + batch_ = batch.copy() + batch_[Columns.STATE_IN] = batch[Columns.STATE_IN][CRITIC] + embeddings = self.encoder.critic_encoder(batch_)[ENCODER_OUT] + # Shared encoder. + else: + embeddings = self.encoder(batch)[ENCODER_OUT][CRITIC] + # Value head. - vf_out = self.vf(encoder_outs) + vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/core/columns.py b/rllib/core/columns.py index 0944d521e2c1..2fc722f5c724 100644 --- a/rllib/core/columns.py +++ b/rllib/core/columns.py @@ -44,6 +44,7 @@ class Columns: # Common extra RLModule output keys. STATE_IN = "state_in" STATE_OUT = "state_out" + EMBEDDINGS = "embeddings" ACTION_DIST_INPUTS = "action_dist_inputs" ACTION_PROB = "action_prob" ACTION_LOGP = "action_logp" diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index dcb088ac7450..f0db18a79295 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -691,7 +691,7 @@ def filter_param_dict_for_optimizer( def get_param_ref(self, param: Param) -> Hashable: """Returns a hashable reference to a trainable parameter. - This should be overriden in framework specific specialization. For example in + This should be overridden in framework specific specialization. For example in torch it will return the parameter itself, while in tf it returns the .ref() of the variable. The purpose is to retrieve a unique reference to the parameters. @@ -706,7 +706,7 @@ def get_param_ref(self, param: Param) -> Hashable: def get_parameters(self, module: RLModule) -> Sequence[Param]: """Returns the list of parameters of a module. - This should be overriden in framework specific learner. For example in torch it + This should be overridden in framework specific learner. For example in torch it will return .parameters(), while in tf it returns .trainable_variables. Args: diff --git a/rllib/core/rl_module/apis/value_function_api.py b/rllib/core/rl_module/apis/value_function_api.py index 06f0afccc19e..43280228badb 100644 --- a/rllib/core/rl_module/apis/value_function_api.py +++ b/rllib/core/rl_module/apis/value_function_api.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.utils.typing import TensorType @@ -7,14 +7,24 @@ class ValueFunctionAPI(abc.ABC): """An API to be implemented by RLModules for handling value function-based learning. - RLModules implementing this API must override the `compute_values` method.""" + RLModules implementing this API must override the `compute_values` method. + """ @abc.abstractmethod - def compute_values(self, batch: Dict[str, Any]) -> TensorType: + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: """Computes the value estimates given `batch`. Args: batch: The batch to compute value function estimates for. + embeddings: Optional embeddings already computed from the `batch` (by + another forward pass through the model's encoder (or other subcomponent + that computes an embedding). 
For example, the caller of thie method + should provide `embeddings` - if available - to avoid duplicate passes + through a shared encoder. Returns: A tensor of shape (B,) or (B, T) (in case the input `batch` has a diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py index a4b0deedce1e..a9f8a7a606f8 100644 --- a/rllib/core/rl_module/multi_rl_module.py +++ b/rllib/core/rl_module/multi_rl_module.py @@ -20,7 +20,6 @@ from ray.rllib.core import COMPONENT_MULTI_RL_MODULE_SPEC from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec -from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import force_list from ray.rllib.utils.annotations import ( ExperimentalAPI, @@ -74,7 +73,7 @@ def __init__(self, config: Optional["MultiRLModuleConfig"] = None) -> None: def setup(self): """Sets up the underlying RLModules.""" self._rl_modules = {} - self.__check_module_configs(self.config.modules) + self._check_module_configs(self.config.modules) # Make sure all individual RLModules have the same framework OR framework=None. framework = None for module_id, module_spec in self.config.modules.items(): @@ -85,6 +84,91 @@ def setup(self): assert self._rl_modules[module_id].framework in [None, framework] self.framework = framework + @override(RLModule) + def _forward( + self, + batch: Dict[ModuleID, Any], + **kwargs, + ) -> Dict[ModuleID, Dict[str, Any]]: + """Generic forward pass method, used in all phases of training and evaluation. + + If you need a more nuanced distinction between forward passes in the different + phases of training and evaluation, override the following methods instead: + For distinct action computation logic w/o exploration, override the + `self._forward_inference()` method. + For distinct action computation logic with exploration, override the + `self._forward_exploration()` method. + For distinct forward pass logic before loss computation, override the + `self._forward_train()` method. + + Args: + batch: The input batch, a dict mapping from ModuleID to individual modules' + batches. + **kwargs: Additional keyword arguments. + + Returns: + The output of the forward pass. + """ + return { + mid: self._rl_modules[mid]._forward(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @override(RLModule) + def _forward_inference( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used for action computation without exploration behavior. + + Override this method only, if you need specific behavior for non-exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_inference(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @override(RLModule) + def _forward_exploration( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used for action computation with exploration behavior. + + Override this method only, if you need specific behavior for exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. 
+ """ + return { + mid: self._rl_modules[mid]._forward_exploration(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + + @override(RLModule) + def _forward_train( + self, batch: Dict[str, Any], **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Forward-pass used before the loss computation (training). + + Override this method only, if you need specific behavior and outputs for your + loss computations. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return { + mid: self._rl_modules[mid]._forward_train(batch[mid], **kwargs) + for mid in batch.keys() + if mid in self + } + @OverrideToImplementCustomLogic @override(RLModule) def get_initial_state(self) -> Any: @@ -105,50 +189,6 @@ def is_stateful(self) -> bool: ) return bool(any(sa_init_state for sa_init_state in initial_state.values())) - @classmethod - def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]): - """Checks the module configs for validity. - - The module_configs be a mapping from module_ids to RLModuleSpec - objects. - - Args: - module_configs: The module configs to check. - - Raises: - ValueError: If the module configs are invalid. - """ - for module_id, module_spec in module_configs.items(): - if not isinstance(module_spec, RLModuleSpec): - raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") - - def items(self) -> ItemsView[ModuleID, RLModule]: - """Returns a keys view over the module IDs in this MultiRLModule.""" - return self._rl_modules.items() - - def keys(self) -> KeysView[ModuleID]: - """Returns a keys view over the module IDs in this MultiRLModule.""" - return self._rl_modules.keys() - - def values(self) -> ValuesView[ModuleID]: - """Returns a keys view over the module IDs in this MultiRLModule.""" - return self._rl_modules.values() - - def __len__(self) -> int: - """Returns the number of RLModules within this MultiRLModule.""" - return len(self._rl_modules) - - @override(RLModule) - def as_multi_rl_module(self) -> "MultiRLModule": - """Returns self in order to match `RLModule.as_multi_rl_module()` behavior. - - This method is overridden to avoid double wrapping. - - Returns: - The instance itself. - """ - return self - def add_module( self, module_id: ModuleID, @@ -271,70 +311,24 @@ def get( return default return self._rl_modules[module_id] - @override(RLModule) - def output_specs_train(self) -> SpecType: - return [] - - @override(RLModule) - def output_specs_inference(self) -> SpecType: - return [] - - @override(RLModule) - def output_specs_exploration(self) -> SpecType: - return [] - - @override(RLModule) - def _default_input_specs(self) -> SpecType: - """MultiRLModule should not check the input specs. - - The underlying single-agent RLModules will check the input specs. - """ - return [] - - @override(RLModule) - def _forward_train( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_train pass. - - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - individual modules' batches). - - Returns: - The output of the forward_train pass the specified modules. 
- """ - return self._run_forward_pass("forward_train", batch, **kwargs) - - @override(RLModule) - def _forward_inference( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_inference pass. - - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - individual modules' batches). + def items(self) -> ItemsView[ModuleID, RLModule]: + """Returns an ItemsView over the module IDs in this MultiRLModule.""" + return self._rl_modules.items() - Returns: - The output of the forward_inference pass the specified modules. - """ - return self._run_forward_pass("forward_inference", batch, **kwargs) + def keys(self) -> KeysView[ModuleID]: + """Returns a KeysView over the module IDs in this MultiRLModule.""" + return self._rl_modules.keys() - @override(RLModule) - def _forward_exploration( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_exploration pass. + def values(self) -> ValuesView[ModuleID]: + """Returns a ValuesView over the module IDs in this MultiRLModule.""" + return self._rl_modules.values() - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - individual modules' batches). + def __len__(self) -> int: + """Returns the number of RLModules within this MultiRLModule.""" + return len(self._rl_modules) - Returns: - The output of the forward_exploration pass the specified modules. - """ - return self._run_forward_pass("forward_exploration", batch, **kwargs) + def __repr__(self) -> str: + return f"MARL({pprint.pformat(self._rl_modules)})" @override(RLModule) def get_state( @@ -409,39 +403,53 @@ def set_state(self, state: StateDict) -> None: def get_checkpointable_components(self) -> List[Tuple[str, Checkpointable]]: return list(self._rl_modules.items()) - def __repr__(self) -> str: - return f"MARL({pprint.pformat(self._rl_modules)})" + @override(RLModule) + def output_specs_train(self) -> SpecType: + return [] - def _run_forward_pass( - self, - forward_fn_name: str, - batch: Dict[ModuleID, Any], - **kwargs, - ) -> Dict[ModuleID, Dict[ModuleID, Any]]: - """This is a helper method that runs the forward pass for the given module. + @override(RLModule) + def output_specs_inference(self) -> SpecType: + return [] + + @override(RLModule) + def output_specs_exploration(self) -> SpecType: + return [] - It uses forward_fn_name to get the forward pass method from the RLModule - (e.g. forward_train vs. forward_exploration) and runs it on the given batch. + @override(RLModule) + def _default_input_specs(self) -> SpecType: + """MultiRLModule should not check the input specs. - Args: - forward_fn_name: The name of the forward pass method to run. - batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). - **kwargs: Additional keyword arguments to pass to the forward function. + The underlying single-agent RLModules will check the input specs. + """ + return [] + + @override(RLModule) + def as_multi_rl_module(self) -> "MultiRLModule": + """Returns self in order to match `RLModule.as_multi_rl_module()` behavior. + + This method is overridden to avoid double wrapping. Returns: - The output of the forward pass the specified modules. The output is a - mapping from module ID to the output of the forward pass. + The instance itself. 
""" + return self - outputs = {} - for module_id in batch.keys(): - self._check_module_exists(module_id) - rl_module = self._rl_modules[module_id] - forward_fn = getattr(rl_module, forward_fn_name) - outputs[module_id] = forward_fn(batch[module_id], **kwargs) + @classmethod + def _check_module_configs(cls, module_configs: Dict[ModuleID, Any]): + """Checks the module configs for validity. + + The module_configs be a mapping from module_ids to RLModuleSpec + objects. + + Args: + module_configs: The module configs to check. - return outputs + Raises: + ValueError: If the module configs are invalid. + """ + for module_id, module_spec in module_configs.items(): + if not isinstance(module_spec, RLModuleSpec): + raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") def _check_module_exists(self, module_id: ModuleID) -> None: if module_id not in self._rl_modules: @@ -457,7 +465,7 @@ class MultiRLModuleSpec: """A utility spec class to make it constructing MultiRLModules easier. Users can extend this class to modify the behavior of base class. For example to - share neural networks across the modules, the build method can be overriden to + share neural networks across the modules, the build method can be overridden to create the shared module first and then pass it to custom module classes that would then use it as a shared module. diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index 88debf2204c7..1ddd33471a29 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -4,13 +4,6 @@ import gymnasium as gym -if TYPE_CHECKING: - from ray.rllib.core.rl_module.multi_rl_module import ( - MultiRLModule, - MultiRLModuleSpec, - ) - from ray.rllib.core.models.catalog import Catalog - from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns from ray.rllib.core.models.specs.typing import SpecType @@ -34,9 +27,16 @@ serialize_type, deserialize_type, ) -from ray.rllib.utils.typing import SampleBatchType, StateDict +from ray.rllib.utils.typing import StateDict from ray.util.annotations import PublicAPI +if TYPE_CHECKING: + from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, + ) + from ray.rllib.core.models.catalog import Catalog + @PublicAPI(stability="alpha") @dataclass @@ -476,21 +476,6 @@ def setup(self): """ return None - @OverrideToImplementCustomLogic - def get_train_action_dist_cls(self) -> Type[Distribution]: - """Returns the action distribution class for this RLModule used for training. - - This class is used to get the correct action distribution class to be used by - the training components. In case that no action distribution class is needed, - this method can return None. - - Note that RLlib's distribution classes all implement the `Distribution` - interface. This requires two special methods: `Distribution.from_logits()` and - `Distribution.to_deterministic()`. See the documentation of the - :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. - """ - raise NotImplementedError - @OverrideToImplementCustomLogic def get_exploration_action_dist_cls(self) -> Type[Distribution]: """Returns the action distribution class for this RLModule used for exploration. @@ -522,77 +507,49 @@ def get_inference_action_dist_cls(self) -> Type[Distribution]: raise NotImplementedError @OverrideToImplementCustomLogic - def get_initial_state(self) -> Any: - """Returns the initial state of the RLModule. - - This can be used for recurrent models. 
- """ - return {} + def get_train_action_dist_cls(self) -> Type[Distribution]: + """Returns the action distribution class for this RLModule used for training. - @OverrideToImplementCustomLogic - def is_stateful(self) -> bool: - """Returns False if the initial state is an empty dict (or None). + This class is used to get the correct action distribution class to be used by + the training components. In case that no action distribution class is needed, + this method can return None. - By default, RLlib assumes that the module is non-recurrent if the initial - state is an empty dict and recurrent otherwise. - This behavior can be overridden by implementing this method. + Note that RLlib's distribution classes all implement the `Distribution` + interface. This requires two special methods: `Distribution.from_logits()` and + `Distribution.to_deterministic()`. See the documentation of the + :py:class:`~ray.rllib.models.distributions.Distribution` class for more details. """ - initial_state = self.get_initial_state() - assert isinstance(initial_state, dict), ( - "The initial state of an RLModule must be a dict, but is " - f"{type(initial_state)} instead." - ) - return bool(initial_state) - - @OverrideToImplementCustomLogic_CallToSuperRecommended - def output_specs_inference(self) -> SpecType: - """Returns the output specs of the `forward_inference()` method. + raise NotImplementedError - Override this method to customize the output specs of the inference call. - The default implementation requires the `forward_inference()` method to return - a dict that has `action_dist` key and its value is an instance of - `Distribution`. - """ - return [Columns.ACTION_DIST_INPUTS] + @OverrideToImplementCustomLogic + def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """Generic forward pass method, used in all phases of training and evaluation. + + If you need a more nuanced distinction between forward passes in the different + phases of training and evaluation, override the following methods instead: + For distinct action computation logic w/o exploration, override the + `self._forward_inference()` method. + For distinct action computation logic with exploration, override the + `self._forward_exploration()` method. + For distinct forward pass logic before loss computation, override the + `self._forward_train()` method. - @OverrideToImplementCustomLogic_CallToSuperRecommended - def output_specs_exploration(self) -> SpecType: - """Returns the output specs of the `forward_exploration()` method. + Args: + batch: The input batch. + **kwargs: Additional keyword arguments. - Override this method to customize the output specs of the exploration call. - The default implementation requires the `forward_exploration()` method to return - a dict that has `action_dist` key and its value is an instance of - `Distribution`. + Returns: + The output of the forward pass. 
""" - return [Columns.ACTION_DIST_INPUTS] - - def output_specs_train(self) -> SpecType: - """Returns the output specs of the forward_train method.""" return {} - def input_specs_inference(self) -> SpecType: - """Returns the input specs of the forward_inference method.""" - return self._default_input_specs() - - def input_specs_exploration(self) -> SpecType: - """Returns the input specs of the forward_exploration method.""" - return self._default_input_specs() - - def input_specs_train(self) -> SpecType: - """Returns the input specs of the forward_train method.""" - return self._default_input_specs() - - def _default_input_specs(self) -> SpecType: - """Returns the default input specs.""" - return [Columns.OBS] - @check_input_specs("_input_specs_inference") @check_output_specs("_output_specs_inference") - def forward_inference(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: - """Forward-pass during evaluation, called from the sampler. + def forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during evaluation, called from the sampler. - This method should not be overriden to implement a custom forward inference - method. Instead, override the _forward_inference method. + This method should not be overridden. Override the `self._forward_inference()` + method instead. Args: batch: The input batch. This input batch should comply with @@ -605,17 +562,25 @@ def forward_inference(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: """ return self._forward_inference(batch, **kwargs) - @abc.abstractmethod + @OverrideToImplementCustomLogic def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Forward-pass during evaluation. See forward_inference for details.""" + """Forward-pass used for action computation without exploration behavior. + + Override this method only, if you need specific behavior for non-exploratory + action computation behavior. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) @check_input_specs("_input_specs_exploration") @check_output_specs("_output_specs_exploration") - def forward_exploration(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: - """Forward-pass during exploration, called from the sampler. + def forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during exploration, called from the sampler. - This method should not be overriden to implement a custom forward exploration - method. Instead, override the _forward_exploration method. + This method should not be overridden. Override the `self._forward_exploration()` + method instead. Args: batch: The input batch. This input batch should comply with @@ -628,15 +593,25 @@ def forward_exploration(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any """ return self._forward_exploration(batch, **kwargs) - @abc.abstractmethod + @OverrideToImplementCustomLogic def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Forward-pass during exploration. See forward_exploration for details.""" + """Forward-pass used for action computation with exploration behavior. + + Override this method only, if you need specific behavior for exploratory + action computation behavior. 
If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) @check_input_specs("_input_specs_train") @check_output_specs("_output_specs_train") - def forward_train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: - """Forward-pass during training called from the learner. This method should - not be overriden. Instead, override the _forward_train method. + def forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! Forward-pass during training called from the learner. + + This method should not be overridden. Override the `self._forward_train()` + method instead. Args: batch: The input batch. This input batch should comply with @@ -655,9 +630,42 @@ def forward_train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: ) return self._forward_train(batch, **kwargs) - @abc.abstractmethod + @OverrideToImplementCustomLogic def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """Forward-pass during training. See forward_train for details.""" + """Forward-pass used before the loss computation (training). + + Override this method only, if you need specific behavior and outputs for your + loss computations. If you have only one generic behavior for all + phases of training and evaluation, override `self._forward()` instead. + + By default, this calls the generic `self._forward()` method. + """ + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + def get_initial_state(self) -> Any: + """Returns the initial state of the RLModule, in case this is a stateful module. + + Returns: + A tensor or any nested struct of tensors, representing an initial state for + this (stateful) RLModule. + """ + return {} + + @OverrideToImplementCustomLogic + def is_stateful(self) -> bool: + """By default, returns False if the initial state is an empty dict (or None). + + By default, RLlib assumes that the module is non-recurrent, if the initial + state is an empty dict and recurrent otherwise. + This behavior can be customized by overriding this method. + """ + initial_state = self.get_initial_state() + assert isinstance(initial_state, dict), ( + "The initial state of an RLModule must be a dict, but is " + f"{type(initial_state)} instead." + ) + return bool(initial_state) @OverrideToImplementCustomLogic @override(Checkpointable) @@ -701,6 +709,48 @@ def get_ctor_args_and_kwargs(self): {}, # **kwargs ) + @OverrideToImplementCustomLogic_CallToSuperRecommended + def output_specs_inference(self) -> SpecType: + """Returns the output specs of the `forward_inference()` method. + + Override this method to customize the output specs of the inference call. + The default implementation requires the `forward_inference()` method to return + a dict that has `action_dist` key and its value is an instance of + `Distribution`. + """ + return [Columns.ACTION_DIST_INPUTS] + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def output_specs_exploration(self) -> SpecType: + """Returns the output specs of the `forward_exploration()` method. + + Override this method to customize the output specs of the exploration call. + The default implementation requires the `forward_exploration()` method to return + a dict that has `action_dist` key and its value is an instance of + `Distribution`. 
+ """ + return [Columns.ACTION_DIST_INPUTS] + + def output_specs_train(self) -> SpecType: + """Returns the output specs of the forward_train method.""" + return {} + + def input_specs_inference(self) -> SpecType: + """Returns the input specs of the forward_inference method.""" + return self._default_input_specs() + + def input_specs_exploration(self) -> SpecType: + """Returns the input specs of the forward_exploration method.""" + return self._default_input_specs() + + def input_specs_train(self) -> SpecType: + """Returns the input specs of the forward_train method.""" + return self._default_input_specs() + + def _default_input_specs(self) -> SpecType: + """Returns the default input specs.""" + return [Columns.OBS] + def as_multi_rl_module(self) -> "MultiRLModule": """Returns a multi-agent wrapper around this module.""" from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule diff --git a/rllib/core/rl_module/torch/torch_rl_module.py b/rllib/core/rl_module/torch/torch_rl_module.py index db33cf9f9e0e..536631db96a8 100644 --- a/rllib/core/rl_module/torch/torch_rl_module.py +++ b/rllib/core/rl_module/torch/torch_rl_module.py @@ -68,15 +68,6 @@ def __init__(self, *args, **kwargs) -> None: if target is not None: del target - @override(nn.Module) - def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: - """forward pass of the module. - - This is aliased to forward_train because Torch DDP requires a forward method to - be implemented for backpropagation to work. - """ - return self.forward_train(batch, **kwargs) - def compile(self, compile_config: TorchCompileConfig): """Compile the forward methods of this module. @@ -88,6 +79,20 @@ def compile(self, compile_config: TorchCompileConfig): """ return compile_wrapper(self, compile_config) + @OverrideToImplementCustomLogic + def _forward_inference(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + # By default, calls the generic `_forward()` method, but with a no-grad context + # for performance reasons. + with torch.no_grad(): + return self._forward(batch, **kwargs) + + @OverrideToImplementCustomLogic + def _forward_exploration(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + # By default, calls the generic `_forward()` method, but with a no-grad context + # for performance reasons. + with torch.no_grad(): + return self._forward(batch, **kwargs) + @OverrideToImplementCustomLogic @override(RLModule) def get_state( @@ -156,6 +161,24 @@ def get_exploration_action_dist_cls(self) -> Type[TorchDistribution]: def get_train_action_dist_cls(self) -> Type[TorchDistribution]: return self.get_inference_action_dist_cls() + @override(nn.Module) + def forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]: + """DO NOT OVERRIDE! + + This is aliased to `self.forward_train` because Torch DDP requires a forward + method to be implemented for backpropagation to work. + + Instead, override: + `_forward()` to define a generic forward pass for all phases (exploration, + inference, training) + `_forward_inference()` to define the forward pass for action inference in + deployment/production (no exploration). + `_forward_exploration()` to define the forward pass for action inference during + training sample collection (w/ exploration behavior). + `_forward_train()` to define the forward pass prior to loss computation. 
+ """ + return self.forward_train(batch, **kwargs) + class TorchDDPRLModule(RLModule, nn.parallel.DistributedDataParallel): def __init__(self, *args, **kwargs) -> None: @@ -165,8 +188,8 @@ def __init__(self, *args, **kwargs) -> None: self.config = self.unwrapped().config @override(RLModule) - def get_train_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: - return self.unwrapped().get_train_action_dist_cls(*args, **kwargs) + def get_inference_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: + return self.unwrapped().get_inference_action_dist_cls(*args, **kwargs) @override(RLModule) def get_exploration_action_dist_cls( @@ -175,8 +198,8 @@ def get_exploration_action_dist_cls( return self.unwrapped().get_exploration_action_dist_cls(*args, **kwargs) @override(RLModule) - def get_inference_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: - return self.unwrapped().get_inference_action_dist_cls(*args, **kwargs) + def get_train_action_dist_cls(self, *args, **kwargs) -> Type[TorchDistribution]: + return self.unwrapped().get_train_action_dist_cls(*args, **kwargs) @override(RLModule) def get_initial_state(self) -> Any: @@ -187,8 +210,8 @@ def is_stateful(self) -> bool: return self.unwrapped().is_stateful() @override(RLModule) - def _forward_train(self, *args, **kwargs): - return self(*args, **kwargs) + def _forward(self, *args, **kwargs): + return self.unwrapped()._forward(*args, **kwargs) @override(RLModule) def _forward_inference(self, *args, **kwargs) -> Dict[str, Any]: @@ -198,6 +221,10 @@ def _forward_inference(self, *args, **kwargs) -> Dict[str, Any]: def _forward_exploration(self, *args, **kwargs) -> Dict[str, Any]: return self.unwrapped()._forward_exploration(*args, **kwargs) + @override(RLModule) + def _forward_train(self, *args, **kwargs): + return self(*args, **kwargs) + @override(RLModule) def get_state(self, *args, **kwargs): return self.unwrapped().get_state(*args, **kwargs) diff --git a/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py b/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py index e18512dbf47e..348dfb2af142 100644 --- a/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py +++ b/rllib/examples/offline_rl/train_w_bc_finetune_w_ppo.py @@ -176,11 +176,12 @@ def _forward_train(self, batch, **kwargs): } @override(ValueFunctionAPI) - def compute_values(self, batch): - # Compute features ... - features = self._encoder(batch)[ENCODER_OUT] + def compute_values(self, batch, embeddings=None): + # Compute embeddings ... + if embeddings is None: + embeddings = self._encoder(batch)[ENCODER_OUT] # then values using our value head. - return self._vf(features).squeeze(-1) + return self._vf(embeddings).squeeze(-1) if __name__ == "__main__": diff --git a/rllib/examples/rl_modules/classes/action_masking_rlm.py b/rllib/examples/rl_modules/classes/action_masking_rlm.py index e948b8c1a1ef..2a71b66fa109 100644 --- a/rllib/examples/rl_modules/classes/action_masking_rlm.py +++ b/rllib/examples/rl_modules/classes/action_masking_rlm.py @@ -99,14 +99,14 @@ def _forward_train( return self._mask_action_logits(outs, batch["action_mask"]) @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, TensorType]): + def compute_values(self, batch: Dict[str, TensorType], embeddings=None): # Preprocess the batch to extract the `observations` to `Columns.OBS`. action_mask, batch = self._preprocess_batch(batch) # NOTE: Because we manipulate the batch we need to add the `action_mask` # to the batch to access them in `_forward_train`. 
batch["action_mask"] = action_mask # Call the super's method to compute values for GAE. - return super().compute_values(batch) + return super().compute_values(batch, embeddings) def _preprocess_batch( self, batch: Dict[str, TensorType], **kwargs diff --git a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py index d0ff7650a166..bbfcb6982151 100644 --- a/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py +++ b/rllib/examples/rl_modules/classes/autoregressive_actions_rlm.py @@ -212,10 +212,8 @@ def pi( @override(TorchRLModule) def _forward_inference(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: - # Encoder forward pass. encoder_out = self.encoder(batch) - # Policy head forward pass. return self.pi(encoder_out[ENCODER_OUT], inference=True) @@ -225,21 +223,16 @@ def _forward_exploration( ) -> Dict[str, TensorType]: # Encoder forward pass. encoder_out = self.encoder(batch) - # Policy head forward pass. return self.pi(encoder_out[ENCODER_OUT], inference=False) @override(TorchRLModule) def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: - outs = {} - # Encoder forward pass. encoder_out = self.encoder(batch) - # Policy head forward pass. outs.update(self.pi(encoder_out[ENCODER_OUT])) - # Value function head forward pass. vf_out = self.vf(encoder_out[ENCODER_OUT]) outs[Columns.VF_PREDS] = vf_out.squeeze(-1) @@ -247,13 +240,11 @@ def _forward_train(self, batch: Dict[str, TensorType]) -> Dict[str, TensorType]: return outs @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, TensorType]): - - # Encoder forward pass. - encoder_outs = self.encoder(batch)[ENCODER_OUT] - + def compute_values(self, batch: Dict[str, TensorType], embeddings=None): + # Encoder forward pass to get `embeddings`, if necessary. + if embeddings is None: + embeddings = self.encoder(batch)[ENCODER_OUT] # Value head forward pass. - vf_out = self.vf(encoder_outs) - + vf_out = self.vf(embeddings) # Squeeze out last dimension (single node value head). return vf_out.squeeze(-1) diff --git a/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py b/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py index c03f61d82028..ed1efbc1fc17 100644 --- a/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py +++ b/rllib/examples/rl_modules/classes/intrinsic_curiosity_model_rlm.py @@ -238,12 +238,8 @@ def compute_self_supervised_loss( # Inference and exploration not supported (this is a world-model that should only # be used for training). @override(TorchRLModule) - def _forward_inference(self, batch, **kwargs): + def _forward(self, batch, **kwargs): raise NotImplementedError( "`IntrinsicCuriosityModel` should only be used for training! " - "Use `forward_train()` instead." + "Only calls to `forward_train()` supported." 
) - - @override(TorchRLModule) - def _forward_exploration(self, batch, **kwargs): - return self._forward_inference(batch) diff --git a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py index 87363c267a7a..5e8fae5f2e5e 100644 --- a/rllib/examples/rl_modules/classes/lstm_containing_rlm.py +++ b/rllib/examples/rl_modules/classes/lstm_containing_rlm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional import numpy as np @@ -23,12 +23,12 @@ class LSTMContainingRLModule(TorchRLModule, ValueFunctionAPI): B = 10 # batch size T = 5 # seq len - f = 25 # feature dim + e = 25 # embedding dim CELL = 32 # LSTM cell size # Construct the RLModule. rl_module_config = RLModuleConfig( - observation_space=gym.spaces.Box(-1.0, 1.0, (f,), np.float32), + observation_space=gym.spaces.Box(-1.0, 1.0, (e,), np.float32), action_space=gym.spaces.Discrete(4), model_config_dict={"lstm_cell_size": CELL} ) @@ -36,7 +36,7 @@ class LSTMContainingRLModule(TorchRLModule, ValueFunctionAPI): # Create some dummy input. obs = torch.from_numpy( - np.random.random_sample(size=(B, T, f) + np.random.random_sample(size=(B, T, e) ).astype(np.float32)) state_in = my_net.get_initial_state() # Repeat state_in across batch. @@ -77,7 +77,7 @@ def setup(self): # Get the LSTM cell size from our RLModuleConfig's (self.config) # `model_config_dict` property: self._lstm_cell_size = self.config.model_config_dict.get("lstm_cell_size", 256) - self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=False) + self._lstm = nn.LSTM(in_size, self._lstm_cell_size, batch_first=True) in_size = self._lstm_cell_size # Build a sequential stack. @@ -94,7 +94,7 @@ def setup(self): self._fc_net = nn.Sequential(*layers) # Logits layer (no bias, no activation). - self._logits = nn.Linear(in_size, self.config.action_space.n) + self._pi_head = nn.Linear(in_size, self.config.action_space.n) # Single-node value layer. self._values = nn.Linear(in_size, 1) @@ -106,70 +106,50 @@ def get_initial_state(self) -> Any: } @override(TorchRLModule) - def _forward_inference(self, batch, **kwargs): - # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - _, state_out, logits = self._compute_features_state_out_and_logits(batch) + def _forward(self, batch, **kwargs): + # Compute the basic 1D embedding tensor (inputs to policy- and value-heads). + embeddings, state_outs = self._compute_embeddings_and_state_outs(batch) + logits = self._pi_head(embeddings) # Return logits as ACTION_DIST_INPUTS (categorical distribution). # Note that the default `GetActions` connector piece (in the EnvRunner) will # take care of argmax-"sampling" from the logits to yield the inference (greedy) # action. return { - Columns.STATE_OUT: state_out, Columns.ACTION_DIST_INPUTS: logits, + Columns.STATE_OUT: state_outs, } - @override(TorchRLModule) - def _forward_exploration(self, batch, **kwargs): - # Exact same as `_forward_inference`. - # Note that the default `GetActions` connector piece (in the EnvRunner) will - # take care of stochastic sampling from the Categorical defined by the logits - # to yield the exploration action. - return self._forward_inference(batch, **kwargs) - @override(TorchRLModule) def _forward_train(self, batch, **kwargs): - # Compute the basic 1D feature tensor (inputs to policy- and value-heads). 
- features, state_out, logits = self._compute_features_state_out_and_logits(batch) - # Besides the action logits, we also have to return value predictions here - # (to be used inside the loss function). - values = self._values(features).squeeze(-1) + # Same logic as _forward, but also return embeddings to be used by value + # function branch during training. + embeddings, state_outs = self._compute_embeddings_and_state_outs(batch) + logits = self._pi_head(embeddings) return { - Columns.STATE_OUT: state_out, Columns.ACTION_DIST_INPUTS: logits, - Columns.VF_PREDS: values, + Columns.STATE_OUT: state_outs, + Columns.EMBEDDINGS: embeddings, } # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used # by value-based methods like PPO or IMPALA. @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - obs = batch[Columns.OBS] - state_in = batch[Columns.STATE_IN] - h, c = state_in["h"], state_in["c"] - # Unsqueeze the layer dim (we only have 1 LSTM layer. - features, _ = self._lstm( - obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major - (h.unsqueeze(0), c.unsqueeze(0)), - ) - # Make batch-major again. - features = features.permute(1, 0, 2) - # Push through our FC net. - features = self._fc_net(features) - return self._values(features).squeeze(-1) - - def _compute_features_state_out_and_logits(self, batch): + def compute_values( + self, batch: Dict[str, Any], embeddings: Optional[Any] = None + ) -> TensorType: + if embeddings is None: + embeddings, _ = self._compute_embeddings_and_state_outs(batch) + values = self._values(embeddings).squeeze(-1) + return values + + def _compute_embeddings_and_state_outs(self, batch): obs = batch[Columns.OBS] state_in = batch[Columns.STATE_IN] h, c = state_in["h"], state_in["c"] - # Unsqueeze the layer dim (we only have 1 LSTM layer. - features, (h, c) = self._lstm( - obs.permute(1, 0, 2), # we have to permute, b/c our LSTM is time-major - (h.unsqueeze(0), c.unsqueeze(0)), - ) - # Make batch-major again. - features = features.permute(1, 0, 2) + # Unsqueeze the layer dim (we only have 1 LSTM layer). + embeddings, (h, c) = self._lstm(obs, (h.unsqueeze(0), c.unsqueeze(0))) # Push through our FC net. - features = self._fc_net(features) - logits = self._logits(features) - return features, {"h": h.squeeze(0), "c": c.squeeze(0)}, logits + embeddings = self._fc_net(embeddings) + # Squeeze the layer dim (we only have 1 LSTM layer). 
+ return embeddings, {"h": h.squeeze(0), "c": c.squeeze(0)} diff --git a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py index bf8e4731ceef..0fa166a610a7 100644 --- a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py +++ b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Dict +from typing import Any, Dict, Optional import tree from ray.rllib.core import Columns, DEFAULT_POLICY_ID @@ -181,7 +181,7 @@ def _forward_pass(self, batch, inference=True): return output @override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]): + def compute_values(self, batch: Dict[str, Any], embeddings: Optional[Any] = None): self._forward_pass(batch, inference=False) v_preds = self._model_v2.value_function() if Columns.STATE_IN in batch and Columns.SEQ_LENS in batch: diff --git a/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py index 22acf3939e8f..317b3e3c8c09 100644 --- a/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py +++ b/rllib/examples/rl_modules/classes/tiny_atari_cnn_rlm.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI @@ -49,7 +49,6 @@ class TinyAtariCNN(TorchRLModule, ValueFunctionAPI): num_all_params = sum(int(np.prod(p.size())) for p in my_net.parameters()) print(f"num params = {num_all_params}") - """ @override(TorchRLModule) @@ -122,42 +121,44 @@ def setup(self): normc_initializer(0.01)(self._values.weight) @override(TorchRLModule) - def _forward_inference(self, batch, **kwargs): + def _forward(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - _, logits = self._compute_features_and_logits(batch) - # Return logits as ACTION_DIST_INPUTS (categorical distribution). - return {Columns.ACTION_DIST_INPUTS: logits} - - @override(TorchRLModule) - def _forward_exploration(self, batch, **kwargs): - return self._forward_inference(batch, **kwargs) + _, logits = self._compute_embeddings_and_logits(batch) + # Return features and logits as ACTION_DIST_INPUTS (categorical distribution). + return { + Columns.ACTION_DIST_INPUTS: logits, + } @override(TorchRLModule) def _forward_train(self, batch, **kwargs): # Compute the basic 1D feature tensor (inputs to policy- and value-heads). - features, logits = self._compute_features_and_logits(batch) - # Besides the action logits, we also have to return value predictions here - # (to be used inside the loss function). - values = self._values(features).squeeze(-1) + embeddings, logits = self._compute_embeddings_and_logits(batch) + # Return features and logits as ACTION_DIST_INPUTS (categorical distribution). return { Columns.ACTION_DIST_INPUTS: logits, - Columns.VF_PREDS: values, + Columns.EMBEDDINGS: embeddings, } # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used # by value-based methods like PPO or IMPALA. 
@override(ValueFunctionAPI) - def compute_values(self, batch: Dict[str, Any]) -> TensorType: - obs = batch[Columns.OBS] - features = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) - features = torch.squeeze(features, dim=[-1, -2]) - return self._values(features).squeeze(-1) - - def _compute_features_and_logits(self, batch): + def compute_values( + self, + batch: Dict[str, Any], + embeddings: Optional[Any] = None, + ) -> TensorType: + # Embeddings not provided -> We need to compute them first. + if embeddings is None: + obs = batch[Columns.OBS] + embeddings = self._base_cnn_stack(obs.permute(0, 3, 1, 2)) + embeddings = torch.squeeze(embeddings, dim=[-1, -2]) + return self._values(embeddings).squeeze(-1) + + def _compute_embeddings_and_logits(self, batch): obs = batch[Columns.OBS].permute(0, 3, 1, 2) - features = self._base_cnn_stack(obs) - logits = self._logits(features) + embeddings = self._base_cnn_stack(obs) + logits = self._logits(embeddings) return ( - torch.squeeze(features, dim=[-1, -2]), + torch.squeeze(embeddings, dim=[-1, -2]), torch.squeeze(logits, dim=[-1, -2]), ) diff --git a/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py b/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py index a76490257e4d..c84d2277afff 100644 --- a/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py +++ b/rllib/examples/rl_modules/classes/vpg_using_shared_encoder_rlm.py @@ -19,11 +19,11 @@ def setup(self): super().setup() # Incoming feature dim from the shared encoder. - feature_dim = self.model_config["feature_dim"] + embedding_dim = self.model_config["embedding_dim"] hidden_dim = self.model_config["hidden_dim"] self._pi_head = nn.Sequential( - nn.Linear(feature_dim, hidden_dim), + nn.Linear(embedding_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, self.action_space.n), ) @@ -60,7 +60,7 @@ class VPGTorchMultiRLModuleWithSharedEncoder(MultiRLModule): from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec - FEATURE_DIM = 64 # encoder output (feature) dim + EMBEDDING_DIM = 64 # encoder output (embedding) dim HIDDEN_DIM = 64 # hidden dim for the policy nets config.rl_module( @@ -69,20 +69,20 @@ class VPGTorchMultiRLModuleWithSharedEncoder(MultiRLModule): # Central/shared encoder net. SHARED_ENCODER_ID: RLModuleSpec( module_class=SharedTorchEncoder, - model_config_dict={"feature_dim": FEATURE_DIM}, + model_config_dict={"embedding_dim": EMBEDDING_DIM}, ), # Arbitrary number of policy nets (w/o encoder sub-net).
"p0": RLModuleSpec( module_class=VPGTorchRLModuleUsingSharedEncoder, model_config_dict={ - "feature_dim": FEATURE_DIM, + "embedding_dim": EMBEDDING_DIM, "hidden_dim": HIDDEN_DIM, }, ), "p1": RLModuleSpec( module_class=VPGTorchRLModuleUsingSharedEncoder, model_config_dict={ - "feature_dim": FEATURE_DIM, + "embedding_dim": EMBEDDING_DIM, "hidden_dim": HIDDEN_DIM, }, ), @@ -138,10 +138,10 @@ def setup(self): super().setup() input_dim = self.observation_space.shape[0] - feature_dim = self.model_config["feature_dim"] + embedding_dim = self.model_config["embedding_dim"] self._encoder = nn.Sequential( - nn.Linear(input_dim, feature_dim), + nn.Linear(input_dim, embedding_dim), ) @override(RLModule) diff --git a/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py b/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py index e7aa087413e5..f716758bae21 100644 --- a/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py +++ b/rllib/tuned_examples/dqn/multi_agent_cartpole_dqn.py @@ -31,7 +31,7 @@ .environment(env="multi_agent_cartpole", env_config={"num_agents": args.num_agents}) .training( lr=0.0005 * (args.num_gpus or 1) ** 0.5, - train_batch_size_per_learner=32, + train_batch_size_per_learner=48, replay_buffer_config={ "type": "MultiAgentPrioritizedEpisodeReplayBuffer", "capacity": 50000, @@ -46,9 +46,9 @@ ) .rl_module( model_config_dict={ - "fcnet_hiddens": [256], + "fcnet_hiddens": [256, 256], "fcnet_activation": "tanh", - "epsilon": [(0, 1.0), (10000, 0.02)], + "epsilon": [(0, 1.0), (20000, 0.02)], "fcnet_bias_initializer": "zeros_", "post_fcnet_bias_initializer": "zeros_", "post_fcnet_hiddens": [256], @@ -65,7 +65,7 @@ stop = { NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, # `episode_return_mean` is the sum of all agents/policies' returns. - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 250.0 * args.num_agents, + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": 200.0 * args.num_agents, } if __name__ == "__main__": diff --git a/rllib/utils/annotations.py b/rllib/utils/annotations.py index d06a45dcb49b..6824412b354f 100644 --- a/rllib/utils/annotations.py +++ b/rllib/utils/annotations.py @@ -171,7 +171,7 @@ def loss(self, ...): ... """ - obj.__is_overriden__ = False + obj.__is_overridden__ = False return obj @@ -196,7 +196,7 @@ def setup(self, config): super().setup(config) # ... or here (after having called super()'s setup method. """ - obj.__is_overriden__ = False + obj.__is_overridden__ = False return obj @@ -206,7 +206,7 @@ def is_overridden(obj): Note, this only works for API calls decorated with OverrideToImplementCustomLogic or OverrideToImplementCustomLogic_CallToSuperRecommended. """ - return getattr(obj, "__is_overriden__", True) + return getattr(obj, "__is_overridden__", True) # Backward compatibility.