Skip to content

Commit

Permalink
Quick try with CrossQ
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Nov 4, 2024
1 parent 9589326 commit 6e533e0
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 9 deletions.
5 changes: 3 additions & 2 deletions sbx/common/jax_layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union

import flax.linen as nn
import jax
Expand Down Expand Up @@ -212,11 +212,12 @@ class SimbaResidualBlock(nn.Module):
hidden_dim: int
activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
scale_factor: int = 4
norm_layer: Type[nn.Module] = nn.LayerNorm

@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
residual = x
x = nn.LayerNorm()(x)
x = self.norm_layer()(x)
x = nn.Dense(self.hidden_dim * self.scale_factor, kernel_init=nn.initializers.he_normal())(x)
x = self.activation_fn(x)
x = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal())(x)
Expand Down
3 changes: 2 additions & 1 deletion sbx/crossq/crossq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from sbx.common.off_policy_algorithm import OffPolicyAlgorithmJax
from sbx.common.type_aliases import BatchNormTrainState, ReplayBufferSamplesNp
from sbx.crossq.policies import CrossQPolicy
from sbx.crossq.policies import CrossQPolicy, SimbaCrossQPolicy


class EntropyCoef(nn.Module):
Expand All @@ -42,6 +42,7 @@ def __call__(self) -> float:
class CrossQ(OffPolicyAlgorithmJax):
policy_aliases: ClassVar[Dict[str, Type[CrossQPolicy]]] = { # type: ignore[assignment]
"MlpPolicy": CrossQPolicy,
"SimbaPolicy": SimbaCrossQPolicy,
# Minimal dict support using flatten()
"MultiInputPolicy": CrossQPolicy,
}
Expand Down
151 changes: 146 additions & 5 deletions sbx/crossq/policies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

import flax.linen as nn
import jax
Expand All @@ -10,7 +11,7 @@
from stable_baselines3.common.type_aliases import Schedule

from sbx.common.distributions import TanhTransformedDistribution
from sbx.common.jax_layers import BatchRenorm
from sbx.common.jax_layers import BatchRenorm, SimbaResidualBlock
from sbx.common.policies import BaseJaxPolicy, Flatten
from sbx.common.type_aliases import BatchNormTrainState

Expand Down Expand Up @@ -48,8 +49,52 @@ def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) ->
x = nn.LayerNorm()(x)
x = self.activation_fn(x)
if self.use_batch_norm:
x = BatchRenorm(use_running_average=not train, momentum=self.batch_norm_momentum)(x)
x = BatchRenorm(
use_running_average=not train,
momentum=self.batch_norm_momentum,
warmup_steps=self.renorm_warmup_steps,
)(x)

x = nn.Dense(1)(x)
return x


class SimbaCritic(nn.Module):
    """Single Q-value network: BatchRenorm -> Dense -> residual blocks -> BatchRenorm -> Dense(1).

    The observation is flattened and concatenated with the action on the last
    axis before being fed to the network.

    Every normalization layer (input, inside each ``SimbaResidualBlock`` and
    output) is created from the same factory, so they all share
    ``batch_norm_momentum`` and ``renorm_warmup_steps``.
    """

    # Width of each residual block; ``net_arch[0]`` is also the width of the
    # input projection, so the sequence must not be empty.
    net_arch: Sequence[int]
    # Dropout is applied after each residual block when this is > 0.
    dropout_rate: Optional[float] = None
    batch_norm_momentum: float = 0.99
    renorm_warmup_steps: int = 100_000
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    # Forwarded to ``SimbaResidualBlock`` (hidden expansion factor of the block).
    scale_factor: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray, action: jnp.ndarray, train: bool = False) -> jnp.ndarray:
        """Compute the Q-value for an (observation, action) pair.

        :param x: Observation; flattened before use.
        :param action: Action, concatenated to the flattened observation on the last axis.
        :param train: When ``True`` the BatchRenorm layers use batch statistics,
            otherwise they use their running averages.
        :return: Output of the final ``Dense(1)`` layer.
        """
        x = Flatten()(x)
        x = jnp.concatenate([x, action], -1)

        # One factory for all norm layers. Previously the partial handed to the
        # residual blocks dropped ``warmup_steps``, so those layers silently
        # fell back to BatchRenorm's default warm-up instead of the configured one.
        norm_layer = partial(
            BatchRenorm,
            use_running_average=not train,
            momentum=self.batch_norm_momentum,
            warmup_steps=self.renorm_warmup_steps,
        )

        x = norm_layer()(x)
        x = nn.Dense(self.net_arch[0])(x)

        for n_units in self.net_arch:
            x = SimbaResidualBlock(
                n_units,
                self.activation_fn,
                self.scale_factor,
                norm_layer,  # type: ignore[arg-type]
            )(x)
            # TODO: double check where to put the dropout
            if self.dropout_rate is not None and self.dropout_rate > 0:
                # deterministic=False: dropout stays active regardless of ``train``.
                x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=False)

        # Final normalization of the residual stream before the output head.
        x = norm_layer()(x)
        x = nn.Dense(1)(x)
        return x

Expand Down Expand Up @@ -88,6 +133,42 @@ def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
return q_values


class SimbaVectorCritic(nn.Module):
    """Ensemble of ``n_critics`` ``SimbaCritic`` networks evaluated with ``nn.vmap``.

    Each ensemble member has its own parameters and batch statistics, and all
    members receive the same ``(obs, action)`` input.
    """

    net_arch: Sequence[int]
    use_layer_norm: bool = False  # ignored
    use_batch_norm: bool = True  # not forwarded to SimbaCritic
    batch_norm_momentum: float = 0.99
    renorm_warmup_steps: int = 100_000
    dropout_rate: Optional[float] = None
    n_critics: int = 2
    activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu
    scale_factor: int = 4

    @nn.compact
    def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, train: bool = False):
        """Evaluate every critic on the same input.

        :param obs: Observation, passed unchanged to each critic.
        :param action: Action, passed unchanged to each critic.
        :param train: Forwarded to each ``SimbaCritic``.
        :return: Critic outputs stacked along axis 0 (one entry per critic).
        """
        # Idea taken from https://github.com/perrin-isir/xpag
        # Similar to https://github.com/tinkoff-ai/CORL for PyTorch
        critic_ensemble_cls = nn.vmap(
            SimbaCritic,
            # parameters not shared between the critics
            variable_axes={"params": 0, "batch_stats": 0},
            # different initializations
            split_rngs={"params": True, "dropout": True, "batch_stats": True},
            # inputs are broadcast to all critics, outputs are stacked
            in_axes=None,
            out_axes=0,
            axis_size=self.n_critics,
        )
        critic_ensemble = critic_ensemble_cls(
            net_arch=self.net_arch,
            dropout_rate=self.dropout_rate,
            batch_norm_momentum=self.batch_norm_momentum,
            renorm_warmup_steps=self.renorm_warmup_steps,
            activation_fn=self.activation_fn,
            scale_factor=self.scale_factor,
        )
        return critic_ensemble(obs, action, train)


class Actor(nn.Module):
net_arch: Sequence[int]
action_dim: int
Expand Down Expand Up @@ -159,6 +240,8 @@ def __init__(
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
share_features_extractor: bool = False,
actor_class: Type[nn.Module] = Actor,
vector_critic_class: Type[nn.Module] = VectorCritic,
):
if optimizer_kwargs is None:
# Note: the default value for b1 is 0.9 in Adam.
Expand All @@ -183,6 +266,8 @@ def __init__(
self.batch_norm_momentum = batch_norm_momentum
self.batch_norm_actor = batch_norm_actor
self.renorm_warmup_steps = renorm_warmup_steps
self.actor_class = actor_class
self.vector_critic_class = vector_critic_class

if net_arch is not None:
if isinstance(net_arch, list):
Expand Down Expand Up @@ -216,7 +301,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
obs = jnp.array([self.observation_space.sample()])
action = jnp.array([self.action_space.sample()])

self.actor = Actor(
self.actor = self.actor_class(
action_dim=int(np.prod(self.action_space.shape)),
net_arch=self.net_arch_pi,
use_batch_norm=self.batch_norm_actor,
Expand Down Expand Up @@ -244,7 +329,7 @@ def build(self, key: jax.Array, lr_schedule: Schedule, qf_learning_rate: float)
),
)

self.qf = VectorCritic(
self.qf = self.vector_critic_class(
dropout_rate=self.dropout_rate,
use_layer_norm=self.layer_norm,
use_batch_norm=self.batch_norm,
Expand Down Expand Up @@ -319,3 +404,59 @@ def _predict(self, observation: np.ndarray, deterministic: bool = False) -> np.n
if not self.use_sde:
self.reset_noise()
return self.sample_action(self.actor_state, observation, self.noise_key)


class SimbaCrossQPolicy(CrossQPolicy):
    """``CrossQPolicy`` whose default critic is ``SimbaVectorCritic``.

    The constructor takes the parameters listed below and forwards them to
    ``CrossQPolicy.__init__`` positionally, in the same order.
    The actor still defaults to the plain ``Actor``.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Box,
        lr_schedule: Schedule,
        net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None,
        dropout_rate: float = 0,
        layer_norm: bool = False,
        batch_norm: bool = True,
        batch_norm_actor: bool = True,
        batch_norm_momentum: float = 0.99,
        renorm_warmup_steps: int = 100000,
        use_sde: bool = False,
        activation_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2,
        features_extractor_class=None,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Callable[..., optax.GradientTransformation] = optax.adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        n_critics: int = 2,
        share_features_extractor: bool = False,
        actor_class: Type[nn.Module] = Actor,  # TODO: replace with Simba actor
        vector_critic_class: Type[nn.Module] = SimbaVectorCritic,
    ):
        # Forwarded positionally: the order below must match the parent signature.
        parent_args = (
            # spaces and schedule
            observation_space,
            action_space,
            lr_schedule,
            # network shape and regularization
            net_arch,
            dropout_rate,
            layer_norm,
            batch_norm,
            batch_norm_actor,
            batch_norm_momentum,
            renorm_warmup_steps,
            # action distribution settings
            use_sde,
            activation_fn,
            log_std_init,
            use_expln,
            clip_mean,
            # feature extraction
            features_extractor_class,
            features_extractor_kwargs,
            normalize_images,
            # optimization
            optimizer_class,
            optimizer_kwargs,
            # critic ensemble and module classes
            n_critics,
            share_features_extractor,
            actor_class,
            vector_critic_class,
        )
        super().__init__(*parent_args)
2 changes: 1 addition & 1 deletion sbx/sac/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(
# AdamW for simba
optimizer_class: Callable[..., optax.GradientTransformation] = optax.adamw,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
n_critics: int = 2,
n_critics: int = 1,
share_features_extractor: bool = False,
actor_class: Type[nn.Module] = SimbaActor,
vector_critic_class: Type[nn.Module] = SimbaVectorCritic,
Expand Down

0 comments on commit 6e533e0

Please sign in to comment.