Revert PPO tanh-normal #310

Merged
merged 1 commit into from Jul 10, 2024
132 changes: 27 additions & 105 deletions sheeprl/algos/ppo/agent.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import copy
import math
from math import prod
from typing import Any, Dict, List, Optional, Sequence, Tuple

@@ -14,7 +13,6 @@
from torch.distributions import Distribution, Independent, Normal, OneHotCategorical

from sheeprl.models.models import MLP, MultiEncoder, NatureCNN
from sheeprl.utils.distribution import safeatanh, safetanh
from sheeprl.utils.fabric import get_single_device_fabric


@@ -52,37 +50,26 @@ def __init__(
self.keys = keys
self.input_dim = input_dim
self.output_dim = features_dim if features_dim else dense_units
if mlp_layers == 0:
self.model = nn.Identity()
self.output_dim = input_dim
else:
self.model = MLP(
input_dim,
features_dim,
[dense_units] * mlp_layers,
activation=dense_act,
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
)
self.model = MLP(
input_dim,
features_dim,
[dense_units] * mlp_layers,
activation=dense_act,
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
)

def forward(self, obs: Dict[str, Tensor]) -> Tensor:
x = torch.cat([obs[k] for k in self.keys], dim=-1)
return self.model(x)


class PPOActor(nn.Module):
def __init__(
self,
actor_backbone: torch.nn.Module,
actor_heads: torch.nn.ModuleList,
is_continuous: bool,
distribution: str = "auto",
) -> None:
def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None:
super().__init__()
self.actor_backbone = actor_backbone
self.actor_heads = actor_heads
self.is_continuous = is_continuous
self.distribution = distribution

def forward(self, x: Tensor) -> List[Tensor]:
x = self.actor_backbone(x)
@@ -106,21 +93,6 @@ def __init__(
super().__init__()
self.is_continuous = is_continuous
self.distribution_cfg = distribution_cfg
self.distribution = distribution_cfg.get("type", "auto").lower()
if self.distribution not in ("auto", "normal", "tanh_normal", "discrete"):
raise ValueError(
"The distribution must be on of: `auto`, `discrete`, `normal`, `tanh_normal` and `trunc_normal`. "
f"Found: {self.distribution}"
)
if self.distribution == "discrete" and is_continuous:
raise ValueError("You have choose a discrete distribution but `is_continuous` is true")
elif self.distribution != "discrete" and not is_continuous:
raise ValueError("You have choose a continuous distribution but `is_continuous` is false")
if self.distribution == "auto":
if is_continuous:
self.distribution = "normal"
else:
self.distribution = "discrete"
self.actions_dim = actions_dim
in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys])
mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys])
@@ -177,40 +149,7 @@ def __init__(
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
else:
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous, self.distribution)

def _normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
if actions is None:
actions = normal.sample()
else:
# always composed by a tuple of one element containing all the
# continuous actions
actions = actions[0]
log_prob = normal.log_prob(actions)
return actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)

def _tanh_normal(self, actor_out: Tensor, actions: Optional[List[Tensor]] = None) -> Tuple[Tensor, Tensor, Tensor]:
mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
if actions is None:
actions = normal.sample().float()
tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
else:
# always composed by a tuple of one element containing all the
# continuous actions
tanh_actions = actions[0].float()
actions = safeatanh(actions, eps=torch.finfo(actions.dtype).resolution)
log_prob = normal.log_prob(actions)
log_prob -= 2.0 * (
torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
- tanh_actions
- torch.nn.functional.softplus(-2.0 * tanh_actions)
).sum(-1, keepdim=False)
return tanh_actions, log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1)
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous)

def forward(
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
@@ -219,11 +158,17 @@ def forward(
actor_out: List[Tensor] = self.actor(feat)
values = self.critic(feat)
if self.is_continuous:
if self.distribution == "normal":
actions, log_prob, entropy = self._normal(actor_out[0], actions)
elif self.distribution == "tanh_normal":
actions, log_prob, entropy = self._tanh_normal(actor_out[0], actions)
return tuple([actions]), log_prob, entropy, values
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
if actions is None:
actions = normal.sample()
else:
# always composed by a tuple of one element containing all the
# continuous actions
actions = actions[0]
log_prob = normal.log_prob(actions)
return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values
else:
should_append = False
actions_logprobs: List[Tensor] = []
@@ -253,38 +198,17 @@ def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.
self.critic = critic
self.actor = actor

def _normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
log_prob = normal.log_prob(actions)
return actions, log_prob.unsqueeze(dim=-1)

def _tanh_normal(self, actor_out: Tensor) -> Tuple[Tensor, Tensor]:
mean, log_std = torch.chunk(actor_out, chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample().float()
tanh_actions = safetanh(actions, eps=torch.finfo(actions.dtype).resolution)
log_prob = normal.log_prob(actions)
log_prob -= 2.0 * (
torch.log(torch.tensor([2.0], dtype=actions.dtype, device=actions.device))
- tanh_actions
- torch.nn.functional.softplus(-2.0 * tanh_actions)
).sum(-1, keepdim=False)
return tanh_actions, log_prob.unsqueeze(dim=-1)

def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
feat = self.feature_extractor(obs)
values = self.critic(feat)
actor_out: List[Tensor] = self.actor(feat)
if self.actor.is_continuous:
if self.actor.distribution == "normal":
actions, log_prob = self._normal(actor_out[0])
elif self.actor.distribution == "tanh_normal":
actions, log_prob = self._tanh_normal(actor_out[0])
return tuple([actions]), log_prob, values
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
log_prob = normal.log_prob(actions)
return tuple([actions]), log_prob.unsqueeze(dim=-1), values
else:
actions_dist: List[Distribution] = []
actions_logprobs: List[Tensor] = []
@@ -314,8 +238,6 @@ def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
if self.actor.distribution == "tanh_normal":
actions = safeatanh(actions, eps=torch.finfo(actions.dtype).resolution)
return tuple([actions])
else:
actions: List[Tensor] = []
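For reference on what this revert strips from `agent.py`: the deleted `_tanh_normal` paths sampled from a Normal, squashed the sample with `safetanh`, and subtracted a softplus-based change-of-variables term from the Normal log-probability. Below is a minimal sketch of that standard tanh-squashing correction, written against the pre-tanh sample; the function name and signature are illustrative only and not part of sheeprl's API.

```python
import torch
from torch.distributions import Independent, Normal

# Illustrative sketch (not sheeprl's API): log-prob of a tanh-squashed Normal.
# Uses the numerically stable identity
#   log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u)),
# summed over the action dimension and subtracted from the base log-prob.
def tanh_normal_log_prob(mean: torch.Tensor, log_std: torch.Tensor, pre_tanh: torch.Tensor) -> torch.Tensor:
    base = Independent(Normal(mean, log_std.exp()), 1)
    log_prob = base.log_prob(pre_tanh)
    correction = 2.0 * (
        torch.log(torch.tensor(2.0, dtype=pre_tanh.dtype, device=pre_tanh.device))
        - pre_tanh
        - torch.nn.functional.softplus(-2.0 * pre_tanh)
    ).sum(-1)
    return log_prob - correction
```

When evaluating stored (already squashed) actions, the reverted code first mapped them back through `safeatanh` before computing the base log-probability.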
11 changes: 4 additions & 7 deletions sheeprl/algos/ppo/loss.py
@@ -52,14 +52,11 @@ def value_loss(
) -> Tensor:
if not clip_vloss:
values_pred = new_values
return F.mse_loss(values_pred, returns, reduction=reduction)
# return F.mse_loss(values_pred, returns, reduction=reduction)
else:
v_loss_unclipped = (new_values - returns) ** 2
v_clipped = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
v_loss_clipped = (v_clipped - returns) ** 2
v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
v_loss = 0.5 * v_loss_max.mean()
return v_loss
values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
# return torch.max((new_values - returns) ** 2, (values_pred - returns) ** 2).mean()
return F.mse_loss(values_pred, returns, reduction=reduction)


def entropy_loss(entropy: Tensor, reduction: str = "mean") -> Tensor:
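The `value_loss` change above replaces the element-wise max of clipped and unclipped squared errors with a plain MSE against the clamped value prediction (the older expressions are left commented out in the diff). A side-by-side sketch of the two variants, with illustrative names rather than sheeprl's exact signature:

```python
import torch
import torch.nn.functional as F

# Illustrative comparison of the two clipped value-loss variants touched here.
def value_loss_max_form(new_values, old_values, returns, clip_coef):
    # PPO2-style: element-wise max of unclipped and clipped squared errors.
    v_clipped = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
    v_loss_max = torch.max((new_values - returns) ** 2, (v_clipped - returns) ** 2)
    return 0.5 * v_loss_max.mean()

def value_loss_clamped_mse(new_values, old_values, returns, clip_coef):
    # Variant kept by this PR: MSE against the clamped value prediction.
    values_pred = old_values + torch.clamp(new_values - old_values, -clip_coef, clip_coef)
    return F.mse_loss(values_pred, returns)
```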
5 changes: 2 additions & 3 deletions sheeprl/algos/ppo/ppo.py
@@ -169,7 +169,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if is_continuous
else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
)
clip_rewards_fn = lambda r: np.tanh(r) if cfg.env.clip_rewards else r
# Create the actor and critic models
agent, player = build_agent(
fabric,
@@ -305,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
vals = player.get_values(real_next_obs).cpu().numpy()
rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape)
dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
rewards = clip_rewards_fn(rewards).reshape(cfg.env.num_envs, -1).astype(np.float32)
rewards = rewards.reshape(cfg.env.num_envs, -1)

# Update the step data
step_data["dones"] = dones[np.newaxis]
@@ -348,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"],
local_data["rewards"].to(torch.float64),
local_data["values"],
local_data["dones"],
next_values,
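The `ppo.py` edits drop the optional `np.tanh` reward clipping and hand rewards to `gae` in float64. As a reminder of what that call computes, here is a minimal, time-major GAE sketch; the argument names and tensor layout are assumptions for illustration and do not reproduce sheeprl's `gae` signature.

```python
import torch

# Illustrative GAE sketch: rewards, values, dones are [T, num_envs] tensors,
# next_values holds the bootstrap values for the step after the rollout.
def gae(rewards, values, dones, next_values, gamma: float, lmbda: float):
    advantages = torch.zeros_like(rewards)
    last_adv = torch.zeros_like(next_values)
    next_vals = next_values
    for t in reversed(range(rewards.shape[0])):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_vals * not_done - values[t]
        last_adv = delta + gamma * lmbda * not_done * last_adv
        advantages[t] = last_adv
        next_vals = values[t]
    returns = advantages + values
    return returns, advantages
```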
4 changes: 0 additions & 4 deletions sheeprl/configs/exp/ppo.yaml
@@ -13,10 +13,6 @@ algo:
mlp_keys:
encoder: [state]

# Distribution
distribution:
type: "auto"

# Buffer
buffer:
share_data: False
18 changes: 3 additions & 15 deletions sheeprl/utils/distribution.py
@@ -2,22 +2,12 @@

import math
from numbers import Number
from typing import Callable, Union
from typing import Callable

import torch
import torch.nn.functional as F
from packaging import version
from torch import Tensor, autograd
from torch import distributions as d
from torch.distributions import (
Bernoulli,
Categorical,
Distribution,
Independent,
Transform,
TransformedDistribution,
constraints,
)
from torch import Tensor
from torch.distributions import Bernoulli, Categorical, Distribution, constraints
from torch.distributions.kl import _kl_categorical_categorical, register_kl
from torch.distributions.utils import broadcast_all

@@ -317,7 +307,6 @@ class OneHotCategoricalValidateArgs(Distribution):
probs (Tensor): event probabilities
logits (Tensor): event log probabilities (unnormalized)
"""

arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True
@@ -402,7 +391,6 @@ class OneHotCategoricalStraightThroughValidateArgs(OneHotCategoricalValidateArgs
[1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation
(Bengio et al, 2013)
"""

has_rsample = True

def rsample(self, sample_shape=torch.Size()):
13 changes: 0 additions & 13 deletions sheeprl/utils/utils.py
@@ -298,16 +298,3 @@ def load_state_dict(self, state_dict: Mapping[str, Any]):
self._prev = state_dict["_prev"]
self._pretrain_steps = state_dict["_pretrain_steps"]
return self


# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L156
def safetanh(x, eps):
lim = 1.0 - eps
y = x.tanh()
return y.clamp(-lim, lim)


# https://github.com/pytorch/rl/blob/824f6d192e88c115790cf046e4df416ce2d7aaf6/torchrl/modules/distributions/utils.py#L161
def safeatanh(y, eps):
lim = 1.0 - eps
return y.clamp(-lim, lim).atanh()
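The `safetanh`/`safeatanh` helpers removed here (taken from torchrl, per the links in the deleted comments) clamp values just inside (-1, 1) so that the inverse `atanh` stays finite. A small standalone illustration of why the clamp matters:

```python
import torch

# atanh diverges at exactly +/-1, so squashed actions must be nudged inside
# the open interval (-1, 1) before inversion.
y = torch.tensor([1.0, -1.0, 0.5])
print(torch.atanh(y))              # tensor([inf, -inf, 0.5493])

eps = torch.finfo(y.dtype).resolution
lim = 1.0 - eps
print(y.clamp(-lim, lim).atanh())  # all values finite
```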