Skip to content

Commit

Permalink
Add CrossQ (#28)
Browse files Browse the repository at this point in the history
* Added support for large values for gradient_steps to SAC, TD3, and TQC by replacing the unrolled loop with jax.lax.fori_loop

* Add comments

* Hotfix for train signature

* Fixed start index for dynamic_slice_in_dim

* Rename policy delay

* Fix type annotation

* Add CrossQ POC

* Remove old annotations

* Add actor BN

* Concatenate obs/next obs, first working example

* Deactivate batchnorm for actor

* Fix off-by-one and improve type annotation

* Fix typo

* Update type annotation

* Update off-by-one

* Implemented CrossQ

* Added CrossQ to README

* clean up and comments

* refactored and added comments

* Update doc

* Cleanup CrossQ and BatchRenorm

* Update tests

* Fix for new tfp version

* Clean-up: Removed unused variables and fixed typo

* Cleaner variable names for BatchReNorm

Co-authored-by: Jan Schneider <33448112+jan1854@users.noreply.github.com>

* Allow to change the number of warmup steps

* Update SB3 dependency

* Deprecate DroQ class

* [ci skip] Update comments

---------

Co-authored-by: Jan Schneider <Jan.Schneider1997@gmail.com>
Co-authored-by: Daniel Palenicek <daniel.palenicek@tu-darmstadt.de>
Co-authored-by: Jan Schneider <Jan.Schneider@tuebingen.mpg.de>
Co-authored-by: Jan Schneider <33448112+jan1854@users.noreply.github.com>
  • Loading branch information
5 people authored Apr 3, 2024
1 parent 655f4a3 commit c8db73f
Show file tree
Hide file tree
Showing 10 changed files with 1,103 additions and 42 deletions.
20 changes: 15 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Implemented algorithms:
- [Deep Q Network (DQN)](https://arxiv.org/abs/1312.5602)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/abs/1802.09477)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/abs/1509.02971)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)


### Install using pip
Expand All @@ -36,7 +37,7 @@ pip install sbx-rl
```python
import gymnasium as gym

from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

env = gym.make("Pendulum-v1", render_mode="human")

Expand All @@ -61,15 +62,17 @@ Since SBX shares the SB3 API, it is compatible with the [RL Zoo](https://github.
import rl_zoo3
import rl_zoo3.train
from rl_zoo3.train import train
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.train.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

Expand All @@ -89,15 +92,17 @@ The same goes for the enjoy script:
import rl_zoo3
import rl_zoo3.enjoy
from rl_zoo3.enjoy import enjoy
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, DroQ
from sbx import DDPG, DQN, PPO, SAC, TD3, TQC, CrossQ

rl_zoo3.ALGOS["ddpg"] = DDPG
rl_zoo3.ALGOS["dqn"] = DQN
rl_zoo3.ALGOS["droq"] = DroQ
# See note below to use DroQ configuration
# rl_zoo3.ALGOS["droq"] = DroQ
rl_zoo3.ALGOS["sac"] = SAC
rl_zoo3.ALGOS["ppo"] = PPO
rl_zoo3.ALGOS["td3"] = TD3
rl_zoo3.ALGOS["tqc"] = TQC
rl_zoo3.ALGOS["crossq"] = CrossQ
rl_zoo3.enjoy.ALGOS = rl_zoo3.ALGOS
rl_zoo3.exp_manager.ALGOS = rl_zoo3.ALGOS

Expand Down Expand Up @@ -125,6 +130,11 @@ and then using the RL Zoo script defined above: `python train.py --algo sac --en
We recommend playing with the `policy_delay` and `gradient_steps` parameters for better speed/efficiency.
Having a higher learning rate for the q-value function is also helpful: `qf_learning_rate: !!float 1e-3`.

Note: when using the DroQ configuration with CrossQ, you should set `layer_norm=False` as there is already batch normalization.

## Benchmark

A partial benchmark can be found on [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sbx) where you can also find several [reports](https://wandb.ai/openrlbenchmark/sbx/reportlist).


## Citing the Project
Expand Down
2 changes: 2 additions & 0 deletions sbx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from sbx.crossq import CrossQ
from sbx.ddpg import DDPG
from sbx.dqn import DQN
from sbx.droq import DroQ
Expand All @@ -14,6 +15,7 @@
__version__ = file_handler.read().strip()

__all__ = [
"CrossQ",
"DDPG",
"DQN",
"DroQ",
Expand Down
206 changes: 206 additions & 0 deletions sbx/common/jax_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from flax.linen.module import Module, compact, merge_param
from flax.linen.normalization import _canonicalize_axes, _compute_stats, _normalize
from jax.nn import initializers

PRNGKey = Any
Array = Any
Shape = Tuple[int, ...]
Dtype = Any # this could be a real type?
Axes = Union[int, Sequence[int]]


class BatchRenorm(Module):
    """BatchRenorm Module (https://arxiv.org/abs/1702.03275).

    Adapted from flax.linen.normalization.BatchNorm.

    BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
    BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
    This makes it less prone to suffer from "outlier" batches that can happen
    during very long training runs and, therefore, more robust overall.
    During the warmup phase, it behaves exactly like a BatchNorm layer.

    Usage Note:
      If we define a model with BatchRenorm, for example::

        BRN = BatchRenorm(use_running_average=False, momentum=0.99, epsilon=0.001, dtype=jnp.float32)

      The initialized variables dict will contain in addition to a 'params'
      collection a separate 'batch_stats' collection that will contain all the
      running statistics for all the BatchRenorm layers in a model::

        vars_initialized = BRN.init(key, x)  # {'params': ..., 'batch_stats': ...}

      We then update the batch_stats during training by specifying that the
      `batch_stats` collection is mutable in the `apply` method for our module::

        vars_in = {'params': params, 'batch_stats': old_batch_stats}
        y, mutated_vars = BRN.apply(vars_in, x, mutable=['batch_stats'])
        new_batch_stats = mutated_vars['batch_stats']

      During eval we would define BRN with `use_running_average=True` and use the
      batch_stats collection from training to set the statistics. In this case
      we are not mutating the batch statistics collection, and needn't mark it
      mutable::

        vars_in = {'params': params, 'batch_stats': training_batch_stats}
        y = BRN.apply(vars_in, x)

    Attributes:
      use_running_average: if True, the statistics stored in batch_stats will be
        used. Else the running statistics will be first updated and then used to normalize.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      warmup_steps: number of update steps during which the layer behaves like
        plain BatchNorm before switching to the renormalized statistics.
      dtype: the dtype of the result (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      use_bias: if True, bias (beta) is added.
      use_scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
        example, `[[0, 1], [2, 3]]` would independently batch-normalize over the
        examples on the first two and last two devices. See `jax.lax.psum` for
        more details.
      use_fast_variance: If true, use a faster, but less numerically stable,
        calculation for the variance.
    """

    use_running_average: Optional[bool] = None
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 0.001
    warmup_steps: int = 100_000
    dtype: Optional[Dtype] = None
    param_dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
    axis_name: Optional[str] = None
    axis_index_groups: Any = None
    # This parameter was added in flax.linen 0.7.2 (08/2023)
    # commented out to be compatible with a wider range of jax versions
    # TODO: re-activate in some months (04/2024)
    # use_fast_variance: bool = True

    @compact
    def __call__(self, x, use_running_average: Optional[bool] = None):
        """Normalizes the input using batch statistics.

        NOTE:
        During initialization (when `self.is_initializing()` is `True`) the running
        average of the batch statistics will not be updated. Therefore, the inputs
        fed during initialization don't need to match that of the actual input
        distribution and the reduction axis (set with `axis_name`) does not have
        to exist.

        Args:
          x: the input to be normalized.
          use_running_average: if true, the statistics stored in batch_stats will be
            used instead of computing the batch statistics on the input.

        Returns:
          Normalized inputs (the same shape as inputs).
        """

        # The call-site argument (if given) overrides the attribute set at construction.
        use_running_average = merge_param("use_running_average", self.use_running_average, use_running_average)
        feature_axes = _canonicalize_axes(x.ndim, self.axis)
        # Statistics are reduced over every axis that is not a feature axis (i.e. batch axes).
        reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)
        feature_shape = [x.shape[ax] for ax in feature_axes]

        # Running (exponential moving average) mean and variance, one entry per feature.
        ra_mean = self.variable(
            "batch_stats",
            "mean",
            lambda s: jnp.zeros(s, jnp.float32),
            feature_shape,
        )
        ra_var = self.variable("batch_stats", "var", lambda s: jnp.ones(s, jnp.float32), feature_shape)

        # Clipping bounds for the renorm correction factors, stored as constants
        # in batch_stats (the identity lambda just returns the initial value).
        r_max = self.variable(
            "batch_stats",
            "r_max",
            lambda s: s,
            3,
        )
        d_max = self.variable(
            "batch_stats",
            "d_max",
            lambda s: s,
            5,
        )
        # Update-step counter used to decide when the warmup phase is over.
        steps = self.variable(
            "batch_stats",
            "steps",
            lambda s: s,
            0,
        )

        if use_running_average:
            # Eval mode: normalize with the stored running statistics only.
            custom_mean = ra_mean.value
            custom_var = ra_var.value
        else:
            batch_mean, batch_var = _compute_stats(
                x,
                reduction_axes,
                dtype=self.dtype,
                axis_name=self.axis_name if not self.is_initializing() else None,
                axis_index_groups=self.axis_index_groups,
                # use_fast_variance=self.use_fast_variance,
            )
            if self.is_initializing():
                # During init: use the batch statistics directly and do not
                # touch the running statistics (see NOTE in the docstring).
                custom_mean = batch_mean
                custom_var = batch_var
            else:
                std = jnp.sqrt(batch_var + self.epsilon)
                ra_std = jnp.sqrt(ra_var.value + self.epsilon)
                # scale correction r = sigma_batch / sigma_running (no gradient through it)
                r = jax.lax.stop_gradient(std / ra_std)
                r = jnp.clip(r, 1 / r_max.value, r_max.value)
                # bias correction d = (mu_batch - mu_running) / sigma_running (no gradient)
                d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std)
                d = jnp.clip(d, -d_max.value, d_max.value)

                # BatchNorm normalization, using minibatch stats and running average stats
                # Because we use _normalize, this is equivalent to
                # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
                # where sigma = sqrt(var)
                affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r
                affine_var = batch_var / (r**2)

                # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
                # the constraints are linearly relaxed to r_max/d_max over 40k steps
                # Here we only have a warmup phase
                is_warmed_up = jnp.greater_equal(steps.value, self.warmup_steps).astype(jnp.float32)
                # Before warmup_steps: plain BatchNorm (batch stats); after: renormalized stats.
                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean
                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var

                # Exponential moving average update of the running statistics.
                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean
                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var
                steps.value += 1

        # Normalize with the selected statistics and apply the learned scale/bias.
        return _normalize(
            self,
            x,
            custom_mean,
            custom_var,
            reduction_axes,
            feature_axes,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_bias,
            self.use_scale,
            self.bias_init,
            self.scale_init,
        )
4 changes: 4 additions & 0 deletions sbx/common/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ class RLTrainState(TrainState): # type: ignore[misc]
target_params: flax.core.FrozenDict # type: ignore[misc]


class BatchNormTrainState(TrainState):  # type: ignore[misc]
    # Train state that additionally carries the flax "batch_stats" collection
    # (running statistics of BatchNorm/BatchRenorm layers) alongside the params.
    batch_stats: flax.core.FrozenDict  # type: ignore[misc]


class ReplayBufferSamplesNp(NamedTuple):
observations: np.ndarray
actions: np.ndarray
Expand Down
3 changes: 3 additions & 0 deletions sbx/crossq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from sbx.crossq.crossq import CrossQ

__all__ = ["CrossQ"]
Loading

0 comments on commit c8db73f

Please sign in to comment.