Update doc: SB3-Contrib #267

Merged
merged 8 commits on Dec 21, 2020
2 changes: 1 addition & 1 deletion .gitlab-ci.yml
@@ -1,4 +1,4 @@
image: stablebaselines/stable-baselines3-cpu:0.9.0a2
image: stablebaselines/stable-baselines3-cpu:0.11.0a4

type-check:
script:
12 changes: 10 additions & 2 deletions README.md
@@ -8,7 +8,7 @@

# Stable Baselines3

Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).

You can read a detailed presentation of Stable Baselines in the [Medium article](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-reinforcement-learning-made-easy-df87c4b2fc82).

@@ -50,7 +50,6 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st

Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/)


## RL Baselines3 Zoo: A Collection of Trained RL Agents

[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.
@@ -68,6 +67,15 @@ Github repo: https://github.com/DLR-RM/rl-baselines3-zoo

Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html

## SB3-Contrib: Experimental RL Features

We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib)

This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC) or Quantile Regression DQN (QR-DQN).

Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/)


## Installation

**Note:** Stable-Baselines3 supports PyTorch 1.4+.
4 changes: 4 additions & 0 deletions docs/guide/algos.rst
@@ -31,6 +31,10 @@ Actions ``gym.spaces``:
- ``MultiBinary``: A list of possible actions, where at each timestep any of the actions can be used in any combination.


.. note::

More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo <sb3_contrib>`.

.. note::

Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper
7 changes: 3 additions & 4 deletions docs/guide/rl_tips.rst
@@ -87,8 +87,6 @@ Looking at the training curve (episode reward function of the timesteps) is a go





We suggest you read `Deep Reinforcement Learning that Matters <https://arxiv.org/abs/1709.06560>`_ for a good discussion about RL evaluation.

You can also take a look at this `blog post <https://openlab-flowers.inria.fr/t/how-many-random-seeds-should-i-use-statistical-power-analysis-in-deep-reinforcement-learning-experiments/457>`_
@@ -122,6 +120,7 @@ Discrete Actions - Single Process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms.
We notably provide QR-DQN in our :ref:`contrib repo <sb3_contrib>`.
DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer).

Discrete Actions - Multiprocessed
@@ -136,7 +135,7 @@ Continuous Actions
Continuous Actions - Single Process
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Current State Of The Art (SOTA) algorithms are ``SAC`` and ``TD3``.
Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo <sb3_contrib>`).
Please use the hyperparameters in the `RL zoo <https://github.com/DLR-RM/rl-baselines3-zoo>`_ for best results.
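
For illustration, TQC from SB3-Contrib plugs into the same SB3 API; a minimal sketch (the hyperparameter value below is illustrative, not a tuned setting from the zoo):

.. code-block:: python

    from sb3_contrib import TQC

    # TQC follows the usual SB3 interface; see the RL zoo for tuned hyperparameters
    model = TQC("MlpPolicy", "Pendulum-v0", top_quantiles_to_drop_per_net=2, verbose=1)
    model.learn(total_timesteps=10000)
    model.save("tqc_pendulum")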


Expand All @@ -156,7 +155,7 @@ Goal Environment
-----------------

If your environment follows the ``GoalEnv`` interface (cf :ref:`HER <her>`), then you should use
HER + (SAC/TD3/DDPG/DQN) depending on the action space.
HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space.


.. note::
97 changes: 97 additions & 0 deletions docs/guide/sb3_contrib.rst
@@ -0,0 +1,97 @@
.. _sb3_contrib:

==================
SB3 Contrib
==================

We implement experimental features in a separate contrib repository:
`SB3-Contrib`_

This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still
providing the latest features, like Truncated Quantile Critics (TQC) or
Quantile Regression DQN (QR-DQN).

Why create this repository?
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Over the span of stable-baselines and stable-baselines3, the community
has been eager to contribute in the form of better logging utilities,
environment wrappers, extended support (e.g. different action spaces)
and learning algorithms.

However, sometimes these utilities were too niche to be considered for
stable-baselines, or proved too difficult to integrate cleanly into
the existing code. sb3-contrib aims to fix this by not requiring
the neatest integration with the existing code base and by not
setting limits on what is too niche: almost everything remotely useful
goes!
We hope this allows us to provide reliable implementations
following stable-baselines' usual standards (consistent style, documentation, etc.)
beyond the relatively small scope of utilities in the main repository.

Features
--------

See documentation for the full list of included features.

**RL Algorithms**:

- `Truncated Quantile Critics (TQC)`_
- `Quantile Regression DQN (QR-DQN)`_

**Gym Wrappers**:

- `Time Feature Wrapper`_
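
The `Time Feature Wrapper`_ wraps an existing Gym environment; a minimal sketch, assuming the ``sb3_contrib.common.wrappers.TimeFeatureWrapper`` import path and its default arguments:

.. code-block:: python

    import gym

    from sb3_contrib.common.wrappers import TimeFeatureWrapper

    # Appends a remaining-time feature to the observation (Box observation spaces)
    env = TimeFeatureWrapper(gym.make("Pendulum-v0"))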

Documentation
-------------

Documentation is available online: https://sb3-contrib.readthedocs.io/

Installation
------------

To install Stable-Baselines3 contrib with pip, execute:

::

pip install sb3-contrib

We recommend using the ``master`` version of Stable Baselines3 and SB3-Contrib.

To install Stable Baselines3 ``master`` version:

::

pip install git+https://github.com/DLR-RM/stable-baselines3

To install Stable Baselines3 contrib ``master`` version:

::

pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib


Example
-------

SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3,
using SB3-Contrib should be easy too.

Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.

.. code-block:: python

from sb3_contrib import QRDQN

policy_kwargs = dict(n_quantiles=50)
model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")
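
The saved agent can then be reloaded and run; a short sketch using the standard SB3 ``load``/``predict`` API and the classic Gym step interface:

.. code-block:: python

    import gym

    from sb3_contrib import QRDQN

    env = gym.make("CartPole-v1")
    model = QRDQN.load("qrdqn_cartpole")

    obs = env.reset()
    for _ in range(1000):
        # deterministic=True selects the greedy action
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        if done:
            obs = env.reset()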



.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
.. _Truncated Quantile Critics (TQC): https://arxiv.org/abs/2005.04269
.. _Quantile Regression DQN (QR-DQN): https://arxiv.org/abs/1710.10044
.. _Time Feature Wrapper: https://arxiv.org/abs/1712.00378
9 changes: 6 additions & 3 deletions docs/index.rst
@@ -3,10 +3,10 @@
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.

Welcome to Stable Baselines3 docs! - RL Baselines Made Easy
===========================================================
Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations
========================================================================

`Stable Baselines3 <https://github.com/DLR-RM/stable-baselines3>`_ is a set of improved implementations of reinforcement learning algorithms in PyTorch.
`Stable Baselines3 (SB3) <https://github.com/DLR-RM/stable-baselines3>`_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch.
It is the next major version of `Stable Baselines <https://github.com/hill-a/stable-baselines>`_.


@@ -16,6 +16,8 @@ RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/

RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning.

SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib


Main Features
--------------
@@ -45,6 +47,7 @@ Main Features
guide/callbacks
guide/tensorboard
guide/rl_zoo
guide/sb3_contrib
guide/imitation
guide/migration
guide/checking_nan
6 changes: 5 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

Pre-Release 0.11.0a2 (WIP)
Pre-Release 0.11.0a4 (WIP)
-------------------------------

Breaking Changes:
@@ -44,6 +44,9 @@ Others:
- Add signatures to callable type annotations (@ernestum)
- Improve error message in ``NatureCNN``
- Added checks for supported action spaces to improve clarity of error messages for the user
- Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib.
- Updated docker base image to Ubuntu 18.04
- Set tensorboard min version to 2.2.0 (earlier versions are apparently not working with PyTorch)

Documentation:
^^^^^^^^^^^^^^
@@ -55,6 +58,7 @@ Documentation:
- Added example of learning rate schedule
- Added SUMO-RL as example project (@LucasAlegre)
- Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre)
- Added SB3-Contrib page

Pre-Release 0.10.0 (2020-10-28)
-------------------------------
4 changes: 2 additions & 2 deletions scripts/build_docker.sh
@@ -1,7 +1,7 @@
#!/bin/bash

CPU_PARENT=ubuntu:16.04
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04
CPU_PARENT=ubuntu:18.04
GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04

TAG=stablebaselines/stable-baselines3
VERSION=$(cat ./stable_baselines3/version.txt)
7 changes: 5 additions & 2 deletions setup.py
@@ -10,7 +10,7 @@

# Stable Baselines3

Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).
Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines).

These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details.

@@ -29,6 +29,9 @@
RL Baselines3 Zoo:
https://github.com/DLR-RM/rl-baselines3-zoo

SB3 Contrib:
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

## Quick example

Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym.
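
A minimal sketch of that sklearn-like usage (illustrative only; the full quick example is elided between the hunks of this diff):

```python
import gym

from stable_baselines3 import PPO

env = gym.make("CartPole-v1")

model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

obs = env.reset()
for _ in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()
```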
@@ -112,7 +115,7 @@
"atari_py~=0.2.0",
"pillow",
# Tensorboard support
"tensorboard",
"tensorboard>=2.2.0",
# Checking memory taken by replay buffer
"psutil",
],
18 changes: 9 additions & 9 deletions stable_baselines3/dqn/dqn.py
@@ -155,23 +155,23 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

with th.no_grad():
# Compute the target Q values
target_q = self.q_net_target(replay_data.next_observations)
# Compute the next Q-values using the target network
next_q_values = self.q_net_target(replay_data.next_observations)
# Follow greedy policy: use the one with the highest value
target_q, _ = target_q.max(dim=1)
next_q_values, _ = next_q_values.max(dim=1)
# Avoid potential broadcast issue
target_q = target_q.reshape(-1, 1)
next_q_values = next_q_values.reshape(-1, 1)
# 1-step TD target
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

# Get current Q estimates
current_q = self.q_net(replay_data.observations)
# Get current Q-values estimates
current_q_values = self.q_net(replay_data.observations)

# Retrieve the q-values for the actions from the replay buffer
current_q = th.gather(current_q, dim=1, index=replay_data.actions.long())
current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

# Compute Huber loss (less sensitive to outliers)
loss = F.smooth_l1_loss(current_q, target_q)
loss = F.smooth_l1_loss(current_q_values, target_q_values)
losses.append(loss.item())

# Optimize the policy
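For reference, the renaming in ``dqn.py`` above does not change the computation: ``target_q_values`` is still the standard 1-step TD target, and the loss is the Huber loss against it (a restatement of the diff, not new behaviour):

$$
y = r + \gamma \,(1 - d)\, \max_{a'} Q_{\theta^{-}}(s', a'), \qquad
\mathcal{L} = \operatorname{SmoothL1}\big(Q_{\theta}(s, a),\, y\big)
$$

where $\theta^{-}$ denotes the target network (``q_net_target``) and $d$ the episode-termination flag.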
1 change: 0 additions & 1 deletion stable_baselines3/dqn/policies.py
@@ -74,7 +74,6 @@ def _get_data(self) -> Dict[str, Any]:
features_dim=self.features_dim,
activation_fn=self.activation_fn,
features_extractor=self.features_extractor,
epsilon=self.epsilon,
)
)
return data
16 changes: 8 additions & 8 deletions stable_baselines3/sac/sac.py
@@ -223,20 +223,20 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
with th.no_grad():
# Select action according to policy
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute the target Q value: min over all critics targets
targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
target_q, _ = th.min(targets, dim=1, keepdim=True)
# Compute the next Q values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
# add entropy term
target_q = target_q - ent_coef * next_log_prob.reshape(-1, 1)
next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
# td error + entropy term
q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

# Get current Q estimates for each critic network
# Get current Q-values estimates for each critic network
# using action from the replay buffer
current_q_estimates = self.critic(replay_data.observations, replay_data.actions)
current_q_values = self.critic(replay_data.observations, replay_data.actions)

# Compute critic loss
critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates])
critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
critic_losses.append(critic_loss.item())

# Optimize the critic
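Likewise in ``sac.py``, the renamed ``target_q_values`` is the entropy-regularized soft TD target, and the critic loss sums the per-critic MSE (scaled by 0.5), exactly as the diff shows:

$$
y = r + \gamma \,(1 - d)\Big(\min_{i} Q_{\bar{\theta}_i}(s', a') - \alpha \log \pi(a' \mid s')\Big),
\qquad a' \sim \pi(\cdot \mid s')
$$

with $\alpha$ the entropy coefficient (``ent_coef``) and $\bar{\theta}_i$ the target critics.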
14 changes: 7 additions & 7 deletions stable_baselines3/td3/td3.py
@@ -146,16 +146,16 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip)
next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1)

# Compute the target Q value: min over all critics targets
targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
target_q, _ = th.min(targets, dim=1, keepdim=True)
target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q
# Compute the next Q-values: min over all critics targets
next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

# Get current Q estimates for each critic network
current_q_estimates = self.critic(replay_data.observations, replay_data.actions)
# Get current Q-values estimates for each critic network
current_q_values = self.critic(replay_data.observations, replay_data.actions)

# Compute critic loss
critic_loss = sum([F.mse_loss(current_q, target_q) for current_q in current_q_estimates])
critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values])
critic_losses.append(critic_loss.item())

# Optimize the critics
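And in ``td3.py`` the target is the same clipped double-Q construction without the entropy term, using target policy smoothing for the next action (again just restating the diff):

$$
a' = \operatorname{clip}\big(\pi_{\bar{\phi}}(s') + \epsilon,\, -1,\, 1\big), \qquad
y = r + \gamma \,(1 - d)\, \min_{i} Q_{\bar{\theta}_i}(s', a')
$$

where $\epsilon$ is Gaussian noise clamped to ``target_noise_clip``.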
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
0.11.0a2
0.11.0a4