diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py
index 63887488..98c7e807 100644
--- a/sheeprl/algos/ppo/agent.py
+++ b/sheeprl/algos/ppo/agent.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import copy
-import math
 from math import prod
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
@@ -14,7 +13,6 @@
 from torch.distributions import Distribution, Independent, Normal, OneHotCategorical
 
 from sheeprl.models.models import MLP, MultiEncoder, NatureCNN
-from sheeprl.utils.distribution import safeatanh, safetanh
 from sheeprl.utils.fabric import get_single_device_fabric
 
 
@@ -52,18 +50,14 @@ def __init__(
         self.keys = keys
         self.input_dim = input_dim
         self.output_dim = features_dim if features_dim else dense_units
-        if mlp_layers == 0:
-            self.model = nn.Identity()
-            self.output_dim = input_dim
-        else:
-            self.model = MLP(
-                input_dim,
-                features_dim,
-                [dense_units] * mlp_layers,
-                activation=dense_act,
-                norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
-                norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
-            )
+        self.model = MLP(
+            input_dim,
+            features_dim,
+            [dense_units] * mlp_layers,
+            activation=dense_act,
+            norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
+            norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
+        )
 
     def forward(self, obs: Dict[str, Tensor]) -> Tensor:
         x = torch.cat([obs[k] for k in self.keys], dim=-1)
@@ -71,18 +65,11 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor:
 
 
 class PPOActor(nn.Module):
-    def __init__(
-        self,
-        actor_backbone: torch.nn.Module,
-        actor_heads: torch.nn.ModuleList,
-        is_continuous: bool,
-        distribution: str = "auto",
-    ) -> None:
+    def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None:
         super().__init__()
         self.actor_backbone = actor_backbone
         self.actor_heads = actor_heads
         self.is_continuous = is_continuous
-        self.distribution = distribution
 
     def forward(self, x: Tensor) -> List[Tensor]:
         x = self.actor_backbone(x)
@@ -106,21 +93,6 @@ def __init__(
         super().__init__()
         self.is_continuous = is_continuous
         self.distribution_cfg = distribution_cfg
-        self.distribution = distribution_cfg.get("type", "auto").lower()
-        if self.distribution not in ("auto", "normal", "tanh_normal", "discrete"):
-            raise ValueError(
-                "The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. "
-                f"Found: {self.distribution}"
-            )
-        if self.distribution == "discrete" and is_continuous:
-            raise ValueError("You have choose a discrete distribution but `is_continuous` is true")
-        elif self.distribution != "discrete" and not is_continuous:
-            raise ValueError("You have choose a continuous distribution but `is_continuous` is false")
-        if self.distribution == "auto":
-            if is_continuous:
-                self.distribution = "normal"
-            else:
-                self.distribution = "discrete"
         self.actions_dim = actions_dim
         in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys])
         mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys])
@@ -177,40 +149,7 @@ def __init__(
             actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
         else:
             actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
-        self.actor = PPOActor(actor_backbone, actor_heads, is_continuous, self.distribution)
-
-    def _normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
-        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
-        std = log_std.exp()
-        normal = Independent(Normal(mean, std), 1)
-        if actions is None:
-            actions = normal.sample()
-        else:
-            # always composed by a tuple of one element containing all the
-            # continuous actions
-            actions = actions[0]
-        log_prob = normal.log_prob(actions)
-        return actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)
-
-    def _tanh_normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
-        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
-        std = log_std.exp()
-        normal = Independent(Normal(mean, std), 1)
-        if actions is None:
-            actions = normal.sample().float()
-            tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
-        else:
-            # always composed by a tuple of one element containing all the
-            # continuous actions
-            tanh_actions = actions[0].float()
-            actions = safeatanh(actions, eps=torch.finfo(actions.dtype).resolution)
-        log_prob = normal.log_prob(actions)
-        log_prob -= 2.0 * (
-            torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
-            - tanh_actions
-            - torch.nn.functional.softplus(-2.0 * tanh_actions)
-        ).sum(-1, keepdim=False)
-        return tanh_actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)
+        self.actor = PPOActor(actor_backbone, actor_heads, is_continuous)
 
     def forward(
         self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
@@ -219,11 +158,17 @@ def forward(
         actor_out: List[Tensor] = self.actor(feat)
         values = self.critic(feat)
         if self.is_continuous:
-            if self.distribution == "normal":
-                actions, log_prob, entropy = self._normal(actor_out[0], actions)
-            elif self.distribution == "tanh_normal":
-                actions, log_prob, entropy = self._tanh_normal(actor_out[0], actions)
-            return tuple([actions]), log_prob, entropy, values
+            mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
+            std = log_std.exp()
+            normal = Independent(Normal(mean, std), 1)
+            if actions is None:
+                actions = normal.sample()
+            else:
+                # always composed by a tuple of one element containing all the
+                # continuous actions
+                actions = actions[0]
+            log_prob = normal.log_prob(actions)
+            return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values
         else:
             should_append = False
             actions_logprobs: List[Tensor] = []
@@ -253,38 +198,17 @@ def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.
         self.critic = critic
         self.actor = actor
 
-    def _normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
-        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
-        std = log_std.exp()
-        normal = Independent(Normal(mean, std), 1)
-        actions = normal.sample()
-        log_prob = normal.log_prob(actions)
-        return actions, log_prob.unsqueeze(dim=-1)
-
-    def _tanh_normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
-        mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
-        std = log_std.exp()
-        normal = Independent(Normal(mean, std), 1)
-        actions = normal.sample().float()
-        tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
-        log_prob = normal.log_prob(actions)
-        log_prob -= 2.0 * (
-            torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
-            - tanh_actions
-            - torch.nn.functional.softplus(-2.0 * tanh_actions)
-        ).sum(-1, keepdim=False)
-        return tanh_actions, log_prob.unsqueeze(dim=-1)
-
     def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
         feat = self.feature_extractor(obs)
         values = self.critic(feat)
         actor_out: List[Tensor] = self.actor(feat)
         if self.actor.is_continuous:
-            if self.actor.distribution == "normal":
-                actions, log_prob = self._normal(actor_out[0])
-            elif self.actor.distribution == "tanh_normal":
-                actions, log_prob = self._tanh_normal(actor_out[0])
-            return tuple([actions]), log_prob, values
+            mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
+            std = log_std.exp()
+            normal = Independent(Normal(mean, std), 1)
+            actions = normal.sample()
+            log_prob = normal.log_prob(actions)
+            return tuple([actions]), log_prob.unsqueeze(dim=-1), values
         else:
             actions_dist: List[Distribution] = []
             actions_logprobs: List[Tensor] = []
@@ -314,8 +238,6 @@ def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[
             std = log_std.exp()
             normal = Independent(Normal(mean, std), 1)
             actions = normal.sample()
-            if self.actor.distribution == "tanh_normal":
-                actions = safeatanh(actions, eps=torch.finfo(actions.dtype).resolution)
             return tuple([actions])
         else:
             actions: List[Tensor] = []
diff --git a/sheeprl/algos/ppo/loss.py b/sheeprl/algos/ppo/loss.py
index 15209a47..5422da54 100644
--- a/sheeprl/algos/ppo/loss.py
+++ b/sheeprl/algos/ppo/loss.py
@@ -52,14 +52,10 @@ def value_loss(
 ) -> Tensor:
     if not clip_vloss:
         values_pred = new_values
         return F.mse_loss(values_pred, returns, reduction=reduction)
     else:
-        v_loss_unclipped = (new_values - returns) ** 2
-        v_clipped = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
-        v_loss_clipped = (v_clipped - returns) ** 2
-        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
-        v_loss = 0.5 * v_loss_max.mean()
-        return v_loss
+        values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
+        return F.mse_loss(values_pred, returns, reduction=reduction)
 
 
 def entropy_loss(entropy: Tensor, reduction: str = "mean") -> Tensor:
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 205489d9..95057f2d 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -169,7 +169,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         if is_continuous
         else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
     )
-    clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r
 
     # Create the actor and critic models
     agent, player = build_agent(
         fabric,
@@ -305,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                     vals = player.get_values(real_next_obs).cpu().numpy()
                     rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape)
             dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
-            rewards = clip_rewards_fn(rewards).reshape(cfg.env.num_envs, -1).astype(np.float32)
+            rewards = rewards.reshape(cfg.env.num_envs, -1)
 
             # Update the step data
             step_data["dones"] = dones[np.newaxis]
@@ -348,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
         next_values = player.get_values(torch_obs)
         returns, advantages = gae(
-            local_data["rewards"],
+            local_data["rewards"].to(torch.float64),
            local_data["values"],
            local_data["dones"],
            next_values,
diff --git a/sheeprl/configs/exp/ppo.yaml b/sheeprl/configs/exp/ppo.yaml
index f149611d..c5c05719 100644
--- a/sheeprl/configs/exp/ppo.yaml
+++ b/sheeprl/configs/exp/ppo.yaml
@@ -13,10 +13,6 @@ algo:
   mlp_keys:
     encoder: [state]
 
-# Distribution
-distribution:
-  type: "auto"
-
 # Buffer
 buffer:
   share_data: False
diff --git a/sheeprl/utils/distribution.py b/sheeprl/utils/distribution.py
index 810e4252..31765bb6 100644
--- a/sheeprl/utils/distribution.py
+++ b/sheeprl/utils/distribution.py
@@ -2,22 +2,12 @@
 
 import math
 from numbers import Number
-from typing import Callable, Union
+from typing import Callable
 
 import torch
 import torch.nn.functional as F
-from packaging import version
-from torch import Tensor, autograd
-from torch import distributions as d
-from torch.distributions import (
-    Bernoulli,
-    Categorical,
-    Distribution,
-    Independent,
-    Transform,
-    TransformedDistribution,
-    constraints,
-)
+from torch import Tensor
+from torch.distributions import Bernoulli, Categorical, Distribution, constraints
 from torch.distributions.kl import _kl_categorical_categorical, register_kl
 from torch.distributions.utils import broadcast_all
 
@@ -317,7 +307,6 @@ class OneHotCategoricalValidateArgs(Distribution):
         probs (Tensor): event probabilities
         logits (Tensor): event log probabilities (unnormalized)
     """
-
     arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
     support = constraints.one_hot
     has_enumerate_support = True
@@ -402,7 +391,6 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs
     [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional
     Computation (Bengio et al, 2013)
     """
-
     has_rsample = True
 
     def rsample(self, sample_shape=torch.Size()):
diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py
index 971c4192..74bf8a35 100644
--- a/sheeprl/utils/utils.py
+++ b/sheeprl/utils/utils.py
@@ -298,16 +298,3 @@ def load_state_dict(self, state_dict: Mapping[str, Any]):
         self._prev = state_dict["_prev"]
         self._pretrain_steps = state_dict["_pretrain_steps"]
         return self
-
-
-# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L156
-def safetanh(x, eps):
-    lim = 1.0 - eps
-    y = x.tanh()
-    return y.clamp(-lim, lim)
-
-
-# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L161
-def safeatanh(y, eps):
-    lim = 1.0 - eps
-    return y.clamp(-lim, lim).atanh()
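
Note for reviewers: with `tanh_normal` support removed, continuous-action PPO samples from a plain (non-squashed) diagonal Gaussian, and the clipped value loss reduces to an MSE against the clamped value prediction. Below is a minimal, self-contained sketch of the two code paths this patch leaves in place; the tensor shapes and the `actor_out` variable are hypothetical, chosen only for illustration.

import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal

# Continuous-action path kept by the patch: the single actor head emits
# concatenated (mean, log_std); Independent(..., 1) sums log-prob and
# entropy over the action dimensions.
actor_out = torch.randn(4, 6)  # hypothetical: batch of 4, 3 action dims
mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
normal = Independent(Normal(mean, log_std.exp()), 1)
actions = normal.sample()
log_prob = normal.log_prob(actions).unsqueeze(dim=-1)  # shape (4, 1)
entropy = normal.entropy().unsqueeze(dim=-1)           # shape (4, 1)

# Clipped value loss after the patch: MSE against the clamped prediction,
# replacing the previous 0.5 * max(unclipped, clipped) squared-error form.
new_values, old_values, returns = torch.randn(3, 4, 1).unbind(dim=0)
clip_coef = 0.2
values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
v_loss = F.mse_loss(values_pred, returns)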