diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py
index 878d5faa8c8a..cf6cb5f7e041 100644
--- a/rllib/algorithms/cql/cql.py
+++ b/rllib/algorithms/cql/cql.py
@@ -2,6 +2,12 @@
 from typing import Optional, Type, Union

 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
+from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
+from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
+from ray.rllib.algorithms.sac.sac import (
+    SAC,
+    SACConfig,
+)
 from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
     AddObservationsFromEpisodesToBatch,
 )
@@ -9,12 +15,7 @@
     AddNextObservationsFromEpisodesToTrainBatch,
 )
 from ray.rllib.core.learner.learner import Learner
-from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy
-from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy
-from ray.rllib.algorithms.sac.sac import (
-    SAC,
-    SACConfig,
-)
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.execution.rollout_ops import (
     synchronous_parallel_sample,
 )
@@ -48,7 +49,7 @@
     SAMPLE_TIMER,
     TIMERS,
 )
-from ray.rllib.utils.typing import ResultDict
+from ray.rllib.utils.typing import ResultDict, RLModuleSpecType

 tf1, tf, tfv = try_import_tf()
 tfp = try_import_tfp()
@@ -83,7 +84,14 @@ def __init__(self, algo_class=None):
         self.lagrangian = False
         self.lagrangian_thresh = 5.0
         self.min_q_weight = 5.0
+        self.deterministic_backup = True
         self.lr = 3e-4
+        # Note, the new API stack defines a separate learning rate for each
+        # component. The base learning rate `lr` must therefore be set to
+        # `None` when using the new stack.
+        self.actor_lr = 1e-4
+        self.critic_lr = 1e-3
+        self.alpha_lr = 1e-3

         # Changes to Algorithm's/SACConfig's default:

@@ -105,6 +113,7 @@ def training(
         lagrangian: Optional[bool] = NotProvided,
         lagrangian_thresh: Optional[float] = NotProvided,
         min_q_weight: Optional[float] = NotProvided,
+        deterministic_backup: Optional[bool] = NotProvided,
         **kwargs,
     ) -> "CQLConfig":
         """Sets the training-related configuration.
@@ -116,6 +125,8 @@
             lagrangian: Whether to use the Lagrangian for Alpha Prime (in CQL loss).
             lagrangian_thresh: Lagrangian threshold.
             min_q_weight: in Q weight multiplier.
+            deterministic_backup: Whether to omit the entropy term from the
+                Bellman backup target. Defaults to `True`.

         Returns:
             This updated AlgorithmConfig object.
@@ -135,6 +146,8 @@
             self.lagrangian_thresh = lagrangian_thresh
         if min_q_weight is not NotProvided:
             self.min_q_weight = min_q_weight
+        if deterministic_backup is not NotProvided:
+            self.deterministic_backup = deterministic_backup

         return self

@@ -234,6 +247,27 @@ def validate(self) -> None:
                 "Set this hyperparameter in the `AlgorithmConfig.offline_data`."
             )

+    @override(SACConfig)
+    def get_default_rl_module_spec(self) -> RLModuleSpecType:
+        from ray.rllib.algorithms.sac.sac_catalog import SACCatalog
+
+        if self.framework_str == "torch":
+            from ray.rllib.algorithms.cql.torch.cql_torch_rl_module import (
+                CQLTorchRLModule,
+            )
+
+            return RLModuleSpec(module_class=CQLTorchRLModule, catalog_class=SACCatalog)
+        else:
+            raise ValueError(
+                f"The framework {self.framework_str} is not supported. " "Use `torch`."
+ ) + + @property + def _model_config_auto_includes(self): + return super()._model_config_auto_includes | { + "num_actions": self.num_actions, + } + class CQL(SAC): """CQL (derived from SAC).""" diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index a28547393c2a..4d74e2f22c73 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -1,4 +1,3 @@ -import tree from typing import Dict from ray.air.constants import TRAINING_ITERATION @@ -52,17 +51,6 @@ def compute_loss_for_module( fwd_out[Columns.ACTION_DIST_INPUTS] ) - # Sample actions for the current state. Note that we need to apply the - # reparameterization trick here to avoid the expectation over actions. - actions_curr = ( - action_dist_curr.rsample() - if not config._deterministic_loss - # If deterministic, we use the mean.s - else action_dist_curr.to_deterministic().sample() - ) - # Compute the log probabilities for the current state (for the alpha loss) - logps_curr = action_dist_curr.logp(actions_curr) - # Optimize also the hyperparameter `alpha` by using the current policy # evaluated at the current state (from offline data). Note, in contrast # to the original SAC loss, here the `alpha` and actor losses are @@ -71,13 +59,9 @@ def compute_loss_for_module( # to optimize and monotonic function. Original equation uses alpha. alpha_loss = -torch.mean( self.curr_log_alpha[module_id] - * (logps_curr.detach() + self.target_entropy[module_id]) + * (fwd_out["logp_resampled"].detach() + self.target_entropy[module_id]) ) - # Get the current batch size. Note, this size might vary in case the - # last batch contains less than `train_batch_size_per_learner` examples. - batch_size = batch[Columns.OBS].shape[0] - # Get the current alpha. alpha = torch.exp(self.curr_log_alpha[module_id]) # Start training with behavior cloning and turn to the classic Soft-Actor Critic @@ -86,36 +70,20 @@ def compute_loss_for_module( self.metrics.peek((ALL_MODULES, TRAINING_ITERATION), default=0) >= config.bc_iters ): - # Calculate current Q-values. - batch_curr = { - Columns.OBS: batch[Columns.OBS], - # Use the actions sampled from the current policy. - Columns.ACTIONS: actions_curr, - } - # Note, if `twin_q` is `True`, `compute_q_values` computes the minimum - # of the `qf` and `qf_twin` and returns this minimum. - q_curr = self.module[module_id].compute_q_values(batch_curr) - actor_loss = torch.mean(alpha.detach() * logps_curr - q_curr) + actor_loss = torch.mean( + alpha.detach() * fwd_out["logp_resampled"] - fwd_out["q_curr"] + ) else: # Use log-probabilities of the current action distribution to clone # the behavior policy (selected actions in data) in the first `bc_iters` # training iterations. bc_logps_curr = action_dist_curr.logp(batch[Columns.ACTIONS]) - actor_loss = torch.mean(alpha.detach() * logps_curr - bc_logps_curr) + actor_loss = torch.mean( + alpha.detach() * fwd_out["logp_resampled"] - bc_logps_curr + ) # The critic loss is composed of the standard SAC Critic L2 loss and the # CQL entropy loss. - action_dist_next = action_dist_class.from_logits( - fwd_out["action_dist_inputs_next"] - ) - # Sample the actions for the next state. - actions_next = ( - # Note, we do not need to backpropagate through the - # next actions. - action_dist_next.sample() - if not config._deterministic_loss - else action_dist_next.to_deterministic().sample() - ) # Get the Q-values for the actually selected actions in the offline data. 
# In the critic loss we use these as predictions. @@ -123,23 +91,21 @@ def compute_loss_for_module( if config.twin_q: q_twin_selected = fwd_out[QF_TWIN_PREDS] - # Compute Q-values from the target Q network for the next state with the - # sampled actions for the next state. - q_batch_next = { - Columns.OBS: batch[Columns.NEXT_OBS], - Columns.ACTIONS: actions_next, - } - # Note, if `twin_q` is `True`, `SACTorchRLModule.forward_target` calculates - # the Q-values for both, `qf_target` and `qf_twin_target` and - # returns the minimum. - q_target_next = self.module[module_id].forward_target(q_batch_next) + if not config.deterministic_backup: + q_next = ( + fwd_out["q_target_next"] + - alpha.detach() * fwd_out["logp_next_resampled"] + ) + else: + q_next = fwd_out["q_target_next"] # Now mask all Q-values with terminating next states in the targets. - q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_target_next + q_next_masked = (1.0 - batch[Columns.TERMINATEDS].float()) * q_next # Compute the right hand side of the Bellman equation. Detach this node # from the computation graph as we do not want to backpropagate through - # the target netowrk when optimizing the Q loss. + # the target network when optimizing the Q loss. + # TODO (simon, sven): Kumar et al. (2020) use here also a reward scaler. q_selected_target = ( # TODO (simon): Add an `n_step` option to the `AddNextObsToBatch` connector. batch[Columns.REWARDS] @@ -171,132 +137,30 @@ def compute_loss_for_module( # Now calculate the CQL loss (we use the entropy version of the CQL algorithm). # Note, the entropy version performs best in shown experiments. - # Generate random actions (from the mu distribution as named in Kumar et - # al. (2020)) - low = torch.tensor( - self.module[module_id].config.action_space.low, - device=fwd_out[QF_PREDS].device, - ) - high = torch.tensor( - self.module[module_id].config.action_space.high, - device=fwd_out[QF_PREDS].device, - ) - num_samples = batch[Columns.ACTIONS].shape[0] * config.num_actions - actions_rand_repeat = low + (high - low) * torch.rand( - (num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device - ) - - # Sample current and next actions (from the pi distribution as named in Kumar - # et al. (2020)) using repeated observations. - actions_curr_repeat, logps_curr_repeat, obs_curr_repeat = self._repeat_actions( - action_dist_class, batch[Columns.OBS], config.num_actions, module_id - ) - actions_next_repeat, logps_next_repeat, obs_next_repeat = self._repeat_actions( - action_dist_class, batch[Columns.NEXT_OBS], config.num_actions, module_id - ) - # Calculate the Q-values for all actions. - batch_rand_repeat = { - Columns.OBS: obs_curr_repeat, - Columns.ACTIONS: actions_rand_repeat, - } - # Note, we need here the Q-values from the base Q-value function - # and not the minimum with an eventual Q-value twin. - q_rand_repeat = ( - self.module[module_id] - ._qf_forward_train_helper( - batch_rand_repeat, - self.module[module_id].qf_encoder, - self.module[module_id].qf, - ) - .view(batch_size, config.num_actions, 1) - ) - # Calculate twin Q-values for the random actions, if needed. 
- if config.twin_q: - q_twin_rand_repeat = ( - self.module[module_id] - ._qf_forward_train_helper( - batch_rand_repeat, - self.module[module_id].qf_twin_encoder, - self.module[module_id].qf_twin, - ) - .view(batch_size, config.num_actions, 1) - ) - del batch_rand_repeat - batch_curr_repeat = { - Columns.OBS: obs_curr_repeat, - Columns.ACTIONS: actions_curr_repeat, - } - q_curr_repeat = ( - self.module[module_id] - ._qf_forward_train_helper( - batch_curr_repeat, - self.module[module_id].qf_encoder, - self.module[module_id].qf, - ) - .view(batch_size, config.num_actions, 1) - ) - # Calculate twin Q-values for the repeated actions from the current policy, - # if needed. - if config.twin_q: - q_twin_curr_repeat = ( - self.module[module_id] - ._qf_forward_train_helper( - batch_curr_repeat, - self.module[module_id].qf_twin_encoder, - self.module[module_id].qf_twin, - ) - .view(batch_size, config.num_actions, 1) - ) - del batch_curr_repeat - batch_next_repeat = { - # Note, we use here the current observations b/c we want to keep the - # state fix while sampling the actions. - Columns.OBS: obs_curr_repeat, - Columns.ACTIONS: actions_next_repeat, - } - q_next_repeat = ( - self.module[module_id] - ._qf_forward_train_helper( - batch_next_repeat, - self.module[module_id].qf_encoder, - self.module[module_id].qf, - ) - .view(batch_size, config.num_actions, 1) - ) - # Calculate also the twin Q-values for the current policy and next actions, - # if needed. - if config.twin_q: - q_twin_next_repeat = ( - self.module[module_id] - ._qf_forward_train_helper( - batch_next_repeat, - self.module[module_id].qf_twin_encoder, - self.module[module_id].qf_twin, - ) - .view(batch_size, config.num_actions, 1) - ) - del batch_next_repeat - - # Compute the log-probabilities for the random actions. + # Compute the log-probabilities for the random actions (note, we generate random + # actions (from the mu distribution as named in Kumar et al. (2020))). + # Note, all actions, action log-probabilities and Q-values are already computed + # by the module's `_forward_train` method. # TODO (simon): This is the density for a discrete uniform, however, actions # come from a continuous one. So actually this density should use (1/(high-low)) # instead of (1/2). random_density = torch.log( torch.pow( + 0.5, torch.tensor( - actions_curr_repeat.shape[-1], device=actions_curr_repeat.device + fwd_out["actions_curr_repeat"].shape[-1], + device=fwd_out["actions_curr_repeat"].device, ), - 0.5, ) ) # Merge all Q-values and subtract the log-probabilities (note, we use the # entropy version of CQL). 
q_repeat = torch.cat( [ - q_rand_repeat - random_density, - q_next_repeat - logps_next_repeat.detach(), - q_curr_repeat - logps_curr_repeat.detach(), + fwd_out["q_rand_repeat"] - random_density, + fwd_out["q_next_repeat"] - fwd_out["logps_next_repeat"].detach(), + fwd_out["q_curr_repeat"] - fwd_out["logps_curr_repeat"].detach(), ], dim=1, ) @@ -313,9 +177,11 @@ def compute_loss_for_module( if config.twin_q: q_twin_repeat = torch.cat( [ - q_twin_rand_repeat - random_density, - q_twin_next_repeat - logps_next_repeat.detach(), - q_twin_curr_repeat - logps_curr_repeat.detach(), + fwd_out["q_twin_rand_repeat"] - random_density, + fwd_out["q_twin_next_repeat"] + - fwd_out["logps_next_repeat"].detach(), + fwd_out["q_twin_curr_repeat"] + - fwd_out["logps_curr_repeat"].detach(), ], dim=1, ) @@ -350,11 +216,12 @@ def compute_loss_for_module( "alpha_value": alpha, "log_alpha_value": torch.log(alpha), "target_entropy": self.target_entropy[module_id], - "actions_curr_policy": torch.mean(actions_curr), - LOGPS_KEY: torch.mean(logps_curr), - QF_MEAN_KEY: torch.mean(q_curr_repeat), - QF_MAX_KEY: torch.max(q_curr_repeat), - QF_MIN_KEY: torch.min(q_curr_repeat), + LOGPS_KEY: torch.mean( + fwd_out["logp_resampled"] + ), # torch.mean(logps_curr), + QF_MEAN_KEY: torch.mean(fwd_out["q_curr_repeat"]), + QF_MAX_KEY: torch.max(fwd_out["q_curr_repeat"]), + QF_MIN_KEY: torch.min(fwd_out["q_curr_repeat"]), TD_ERROR_MEAN_KEY: torch.mean(td_error), }, key=module_id, @@ -389,7 +256,7 @@ def compute_gradients( # Compute the gradients for the component and module. self.metrics.peek((module_id, optim_name + "_loss")).backward( - retain_graph=True + retain_graph=False if optim_name in ["policy", "alpha"] else True ) # Store the gradients for the component and module. # TODO (simon): Check another time the graph for overlapping @@ -406,65 +273,3 @@ def compute_gradients( ) return grads - - def _repeat_tensor(self, tensor, repeat): - """Generates a repeated version of a tensor. - - The repetition is done similar `np.repeat` and repeats each value - instead of the complete vector. - - Args: - tensor: The tensor to be repeated. - repeat: How often each value in the tensor should be repeated. - - Returns: - A tensor holding `repeat` repeated values of the input `tensor` - """ - # Insert the new dimension at axis 1 into the tensor. - t_repeat = tensor.unsqueeze(1) - # Repeat the tensor along the new dimension. - t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1) - # Stack the repeated values into the batch dimension. - t_repeat = t_repeat.view(-1, *tensor.shape[1:]) - # Return the repeated tensor. - return t_repeat - - def _repeat_actions(self, action_dist_class, obs, num_actions, module_id): - """Generated actions for repeated observations. - - The `num_actions` define a multiplier used for generating `num_actions` - as many actions as the batch size. Observations are repeated and then a - model forward pass is made. - - Args: - action_dist_class: The action distribution class to be sued for sampling - actions. - obs: A batched observation tensor. - num_actions: The multiplier for actions, i.e. how much more actions - than the batch size should be generated. - module_id: The module ID to be used when calling the forward pass. - - Returns: - A tuple containing the sampled actions, their log-probabilities and the - repeated observations. - """ - # Receive the batch size. - batch_size = obs.shape[0] - # Repeat the observations `num_actions` times. 
-        obs_repeat = tree.map_structure(
-            lambda t: self._repeat_tensor(t, num_actions), obs
-        )
-        # Generate a batch for the forward pass.
-        temp_batch = {Columns.OBS: obs_repeat}
-        # Run the forward pass in inference mode.
-        fwd_out = self.module[module_id].forward_inference(temp_batch)
-        # Generate the squashed Gaussian from the model's logits.
-        action_dist = action_dist_class.from_logits(fwd_out[Columns.ACTION_DIST_INPUTS])
-        # Sample the actions. Note, we want to make a backward pass through
-        # these actions.
-        actions = action_dist.rsample()
-        # Compute the action log-probabilities.
-        action_logps = action_dist.logp(actions).view(batch_size, num_actions, 1)
-
-        # Return
-        return actions, action_logps, obs_repeat
diff --git a/rllib/algorithms/cql/torch/cql_torch_rl_module.py b/rllib/algorithms/cql/torch/cql_torch_rl_module.py
new file mode 100644
index 000000000000..8edb5fcf5c32
--- /dev/null
+++ b/rllib/algorithms/cql/torch/cql_torch_rl_module.py
@@ -0,0 +1,201 @@
+import tree
+from typing import Any, Dict, Optional
+
+from ray.rllib.algorithms.sac.sac_learner import (
+    QF_PREDS,
+    QF_TWIN_PREDS,
+)
+from ray.rllib.algorithms.sac.torch.sac_torch_rl_module import SACTorchRLModule
+
+from ray.rllib.core.columns import Columns
+from ray.rllib.core.models.base import ENCODER_OUT
+from ray.rllib.utils.annotations import override
+from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.typing import TensorType
+
+torch, nn = try_import_torch()
+
+
+class CQLTorchRLModule(SACTorchRLModule):
+    @override(SACTorchRLModule)
+    def _forward_train(self, batch: Dict) -> Dict[str, Any]:
+        # Call the super method.
+        fwd_out = super()._forward_train(batch)
+
+        # Make sure we perform a "straight-through gradient" pass here,
+        # ignoring the gradients of the policy net, however, still recording
+        # the gradients of the Q-nets evaluated on the resampled actions below.
+        # This is different from doing `.detach()` or `with torch.no_grad()`,
+        # as these two methods would fully block all gradient recordings,
+        # including the needed Q-network ones.
+        all_params = list(self.pi_encoder.parameters()) + list(self.pi.parameters())
+        # if self.twin_q:
+        #     all_params += list(self.qf_twin.parameters()) + list(
+        #         self.qf_twin_encoder.parameters()
+        #     )
+
+        for param in all_params:
+            param.requires_grad = False
+
+        # Compute the repeated actions, action log-probabilities and Q-values for all
+        # observations.
+        # First for the random actions (from the mu-distribution as named by Kumar et
+        # al. (2020)).
+        low = torch.tensor(
+            self.config.action_space.low,
+            device=fwd_out[QF_PREDS].device,
+        )
+        high = torch.tensor(
+            self.config.action_space.high,
+            device=fwd_out[QF_PREDS].device,
+        )
+        num_samples = (
+            batch[Columns.ACTIONS].shape[0]
+            * self.config.model_config_dict["num_actions"]
+        )
+        actions_rand_repeat = low + (high - low) * torch.rand(
+            (num_samples, low.shape[0]), device=fwd_out[QF_PREDS].device
+        )
+
+        # Compute the Q-values for these random actions on the repeated
+        # observations.
+        rand_repeat_out = self._repeat_actions(batch[Columns.OBS], actions_rand_repeat)
+        (fwd_out["actions_rand_repeat"], fwd_out["q_rand_repeat"]) = (
+            rand_repeat_out[Columns.ACTIONS],
+            rand_repeat_out[QF_PREDS],
+        )
+        # Sample current and next actions (from the pi distribution as named in Kumar
+        # et al. (2020)) using repeated observations.
+        # Second for the current observations and the current action distribution.
+        curr_repeat_out = self._repeat_actions(batch[Columns.OBS])
+        (
+            fwd_out["actions_curr_repeat"],
+            fwd_out["logps_curr_repeat"],
+            fwd_out["q_curr_repeat"],
+        ) = (
+            curr_repeat_out[Columns.ACTIONS],
+            curr_repeat_out[Columns.ACTION_LOGP],
+            curr_repeat_out[QF_PREDS],
+        )
+        # Then, for the next observations and the current action distribution.
+        next_repeat_out = self._repeat_actions(batch[Columns.NEXT_OBS])
+        (
+            fwd_out["actions_next_repeat"],
+            fwd_out["logps_next_repeat"],
+            fwd_out["q_next_repeat"],
+        ) = (
+            next_repeat_out[Columns.ACTIONS],
+            next_repeat_out[Columns.ACTION_LOGP],
+            next_repeat_out[QF_PREDS],
+        )
+        if self.twin_q:
+            # First for the random actions from the mu-distribution.
+            fwd_out["q_twin_rand_repeat"] = rand_repeat_out[QF_TWIN_PREDS]
+            # Second for the current observations and the current action distribution.
+            fwd_out["q_twin_curr_repeat"] = curr_repeat_out[QF_TWIN_PREDS]
+            # Then, for the next observations and the current action distribution.
+            fwd_out["q_twin_next_repeat"] = next_repeat_out[QF_TWIN_PREDS]
+        # Reset the gradient requirements for the frozen policy parameters.
+        for param in all_params:
+            param.requires_grad = True
+
+        return fwd_out
+
+    def _repeat_tensor(self, tensor: TensorType, repeat: int) -> TensorType:
+        """Generates a repeated version of a tensor.
+
+        The repetition is done similarly to `np.repeat` and repeats each value
+        instead of the complete vector.
+
+        Args:
+            tensor: The tensor to be repeated.
+            repeat: How often each value in the tensor should be repeated.
+
+        Returns:
+            A tensor holding `repeat` repeated values of the input `tensor`.
+        """
+        # Insert the new dimension at axis 1 into the tensor.
+        t_repeat = tensor.unsqueeze(1)
+        # Repeat the tensor along the new dimension.
+        t_repeat = torch.repeat_interleave(t_repeat, repeat, dim=1)
+        # Stack the repeated values into the batch dimension.
+        t_repeat = t_repeat.view(-1, *tensor.shape[1:])
+        # Return the repeated tensor.
+        return t_repeat
+
+    def _repeat_actions(
+        self, obs: TensorType, actions: Optional[TensorType] = None
+    ) -> Dict[str, TensorType]:
+        """Generates actions and Q-values for repeated observations.
+
+        The `self.config.model_config_dict["num_actions"]` setting defines a
+        multiplier for generating `num_actions` times as many actions as the
+        batch size. Observations are repeated and then a model forward pass is made.
+
+        Args:
+            obs: A batched observation tensor.
+            actions: An optional batched actions tensor.
+
+        Returns:
+            A dictionary holding the (sampled or passed-in) actions, the log
+            probabilities (of sampled actions), the Q-values and, if available,
+            the twin Q-values.
+        """
+        output = {}
+        # Get the batch size.
+        batch_size = obs.shape[0]
+        # Get the number of actions to sample.
+        num_actions = self.config.model_config_dict["num_actions"]
+        # Repeat the observations `num_actions` times.
+        obs_repeat = tree.map_structure(
+            lambda t: self._repeat_tensor(t, num_actions), obs
+        )
+        # Generate a batch for the forward pass.
+        temp_batch = {Columns.OBS: obs_repeat}
+        if actions is None:
+            # TODO (simon): Run the forward pass in inference mode.
+            # Compute the action logits.
+            pi_encoder_outs = self.pi_encoder(temp_batch)
+            action_logits = self.pi(pi_encoder_outs[ENCODER_OUT])
+            # Generate the squashed Gaussian from the model's logits.
+            action_dist = self.get_train_action_dist_cls().from_logits(action_logits)
+            # Sample the actions. Note, we want to make a backward pass through
+            # these actions.
+ output[Columns.ACTIONS] = action_dist.rsample() + # Compute the action log-probabilities. + output[Columns.ACTION_LOGP] = action_dist.logp( + output[Columns.ACTIONS] + ).view(batch_size, num_actions, 1) + else: + output[Columns.ACTIONS] = actions + + # Compute all Q-values. + temp_batch.update( + { + Columns.ACTIONS: output[Columns.ACTIONS], + } + ) + output.update( + { + QF_PREDS: self._qf_forward_train_helper( + temp_batch, + self.qf_encoder, + self.qf, + ).view(batch_size, num_actions, 1) + } + ) + # If we have a twin-Q network, compute its Q-values, too. + if self.twin_q: + output.update( + { + QF_TWIN_PREDS: self._qf_forward_train_helper( + temp_batch, + self.qf_twin_encoder, + self.qf_twin, + ).view(batch_size, num_actions, 1) + } + ) + del temp_batch + + # Return + return output diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py index da9675d1f473..4cffb877bdc7 100644 --- a/rllib/algorithms/sac/torch/sac_torch_learner.py +++ b/rllib/algorithms/sac/torch/sac_torch_learner.py @@ -169,6 +169,7 @@ def compute_loss_for_module( # Hence, we can't do `fwd_out[q_curr].detach()`! # Note further, we minimize here, while the original equation in Haarnoja et # al. (2018) considers maximization. + # TODO (simon): Rename to `resampled` to `current`. actor_loss = torch.mean( alpha.detach() * fwd_out["logp_resampled"] - fwd_out["q_curr"] ) diff --git a/rllib/algorithms/sac/torch/sac_torch_rl_module.py b/rllib/algorithms/sac/torch/sac_torch_rl_module.py index 957e6a9ebf32..878d55532bdc 100644 --- a/rllib/algorithms/sac/torch/sac_torch_rl_module.py +++ b/rllib/algorithms/sac/torch/sac_torch_rl_module.py @@ -141,12 +141,12 @@ def _forward_train(self, batch: Dict) -> Dict[str, Any]: # here). This is different from doing `.detach()` or `with torch.no_grads()`, # as these two methds would fully block all gradient recordings, including # the needed policy ones. - all_params = ( - list(self.qf.parameters()) - + list(self.qf_encoder.parameters()) - + list(self.qf_twin.parameters()) - + list(self.qf_twin_encoder.parameters()) - ) + all_params = list(self.qf.parameters()) + list(self.qf_encoder.parameters()) + if self.twin_q: + all_params += list(self.qf_twin.parameters()) + list( + self.qf_twin_encoder.parameters() + ) + for param in all_params: param.requires_grad = False output["q_curr"] = self.compute_q_values(q_batch_curr) diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index 018181ccc178..b52c6603f629 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -1106,6 +1106,9 @@ def _finalize_fn(batch: Dict[str, numpy.ndarray]) -> Dict[str, Any]: ) # Call `after_gradient_based_update` to allow for non-gradient based # cleanups-, logging-, and update logic to happen. + # TODO (simon): Check, if this should stay here, when running multiple + # gradient steps inside the iterator loop above (could be a complete epoch) + # the target networks might need to be updated earlier. self.after_gradient_based_update(timesteps=timesteps or {}) # Reduce results across all minibatch update steps. 
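A standalone PyTorch sketch (not part of the diff) of the repeat-and-evaluate pattern that `CQLTorchRLModule._forward_train` and `_repeat_actions` implement above: observations are repeated per row (np.repeat-style), random "mu" actions are drawn uniformly within the action bounds, and the resulting Q-values are reshaped to `(batch_size, num_actions, 1)` so the CQL logsumexp can be taken over the action samples. The toy dimensions and the `q_net` stand-in are illustrative assumptions, not RLlib API.

import torch

batch_size, obs_dim, act_dim, num_actions = 4, 3, 2, 5

obs = torch.randn(batch_size, obs_dim)
low = torch.tensor([-1.0, -1.0])
high = torch.tensor([1.0, 1.0])

# Repeat each observation `num_actions` times along the batch dimension
# (per row, as `_repeat_tensor` does, not block-wise like `torch.tile`).
obs_repeat = torch.repeat_interleave(obs.unsqueeze(1), num_actions, dim=1)
obs_repeat = obs_repeat.view(-1, obs_dim)  # (batch_size * num_actions, obs_dim)

# Uniform "mu" actions within the action bounds (as in Kumar et al., 2020).
actions_rand = low + (high - low) * torch.rand(batch_size * num_actions, act_dim)

# A stand-in Q-function; in the module this is the qf encoder plus head.
q_net = torch.nn.Linear(obs_dim + act_dim, 1)
q_values = q_net(torch.cat([obs_repeat, actions_rand], dim=-1))

# Reshape so the CQL regularizer can aggregate over the sampled actions.
q_values = q_values.view(batch_size, num_actions, 1)
print(q_values.shape)  # torch.Size([4, 5, 1])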
diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py index ea7813f50d04..db844938e339 100644 --- a/rllib/offline/offline_prelearner.py +++ b/rllib/offline/offline_prelearner.py @@ -69,7 +69,7 @@ class OfflinePreLearner: the `__call__` method and `_map_to_episodes` can be overridden to induce custom logic for the complete transformation pipeline (`__call__`) or for converting to episodes only ('_map_to_episodes`). For an example - how this class can be sued to also compute values and advantages see + how this class can be used to also compute values and advantages see `rllib.algorithm.marwil.marwil_prelearner.MAWRILOfflinePreLearner`. Custom `OfflinePreLearner` classes can be passed into diff --git a/rllib/tuned_examples/cql/pendulum_cql.py b/rllib/tuned_examples/cql/pendulum_cql.py index 8d5f47be2780..f821cdb8859d 100644 --- a/rllib/tuned_examples/cql/pendulum_cql.py +++ b/rllib/tuned_examples/cql/pendulum_cql.py @@ -38,9 +38,15 @@ dataset_num_iters_per_learner=1 if args.num_gpus == 0 else None, ) .training( - bc_iters=100, - train_batch_size_per_learner=2000, - twin_q=False, + bc_iters=200, + tau=9.5e-3, + min_q_weight=5.0, + train_batch_size_per_learner=2048, + twin_q=True, + actor_lr=1.7e-3 * (args.num_gpus or 1) ** 0.5, + critic_lr=2.5e-3 * (args.num_gpus or 1) ** 0.5, + alpha_lr=1e-3 * (args.num_gpus or 1) ** 0.5, + lr=None, ) .reporting( min_time_s_per_iteration=10, @@ -48,7 +54,7 @@ ) .evaluation( evaluation_interval=1, - evaluation_num_env_runners=2, + evaluation_num_env_runners=0, evaluation_duration=10, evaluation_config={ "explore": False, @@ -56,7 +62,6 @@ ) ) - stop = { f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -700.0, NUM_ENV_STEPS_SAMPLED_LIFETIME: 800000,
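As a usage sketch of the configuration surface touched by this diff (the new `deterministic_backup` flag and the per-component learning rates), a minimal `CQLConfig` could look as follows. The values mirror the tuned Pendulum example above and are not prescriptive; the `offline_data(...)` settings required for an actual offline-RL run are omitted for brevity.

from ray.rllib.algorithms.cql import CQLConfig

config = (
    CQLConfig()
    .environment("Pendulum-v1")
    .training(
        bc_iters=200,
        twin_q=True,
        min_q_weight=5.0,
        # New flag: drop the entropy term from the Bellman backup target.
        deterministic_backup=True,
        # The new stack uses per-component learning rates; the base `lr`
        # must be set to None in that case.
        actor_lr=1.7e-3,
        critic_lr=2.5e-3,
        alpha_lr=1e-3,
        lr=None,
    )
)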