From a45e1cfb795f5b465994cad2dc7bfab61d8c1008 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 11:30:40 +0100 Subject: [PATCH 1/8] Fix big when saving/loading q-net alone --- docs/misc/changelog.rst | 2 +- stable_baselines3/dqn/policies.py | 1 - stable_baselines3/version.txt | 2 +- tests/test_save_load.py | 76 +++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 10ff71dc9..e2cfb474d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Pre-Release 0.11.0a2 (WIP) +Pre-Release 0.11.0a4 (WIP) ------------------------------- Breaking Changes: diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index f72424ec1..cd0d17e41 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -74,7 +74,6 @@ def _get_data(self) -> Dict[str, Any]: features_dim=self.features_dim, activation_fn=self.activation_fn, features_extractor=self.features_extractor, - epsilon=self.epsilon, ) ) return data diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index a09c7eb7a..1b742ef03 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.11.0a2 +0.11.0a4 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index f7c5521b0..b5b733f47 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -408,6 +408,82 @@ def test_save_load_policy(tmp_path, model_class, policy_str): os.remove(tmp_path / "actor.pkl") +@pytest.mark.parametrize("model_class", [DQN]) +@pytest.mark.parametrize("policy_str", ["MlpPolicy", "CnnPolicy"]) +def test_save_load_q_net(tmp_path, model_class, policy_str): + """ + Test saving and loading q-network/quantile net only. + + :param model_class: (BaseAlgorithm) A RL model + :param policy_str: (str) Name of the policy. + """ + kwargs = dict(policy_kwargs=dict(net_arch=[16])) + if policy_str == "MlpPolicy": + env = select_env(model_class) + else: + if model_class in [DQN]: + # Avoid memory error when using replay buffer + # Reduce the size of the features + kwargs = dict( + buffer_size=250, + learning_starts=100, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) + env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == DQN) + + env = DummyVecEnv([lambda: env]) + + # create model + model = model_class(policy_str, env, verbose=1, **kwargs) + model.learn(total_timesteps=300) + + env.reset() + observations = np.concatenate([env.step([env.action_space.sample()])[0] for _ in range(10)], axis=0) + + q_net = model.q_net + q_net_class = q_net.__class__ + + # Get dictionary of current parameters + params = deepcopy(q_net.state_dict()) + + # Modify all parameters to be random values + random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items()) + + # Update model parameters with the new random values + q_net.load_state_dict(random_params) + + new_params = q_net.state_dict() + # Check that all params are different now + for k in params: + assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected." + + params = new_params + + # get selected actions + selected_actions, _ = q_net.predict(observations, deterministic=True) + + # Save and load q_net + q_net.save(tmp_path / "q_net.pkl") + + del q_net + + q_net = q_net_class.load(tmp_path / "q_net.pkl") + + # check if params are still the same after load + new_params = q_net.state_dict() + + # Check that all params are the same as before save load procedure now + for key in params: + assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load." + + # check if model still selects the same actions + new_selected_actions, _ = q_net.predict(observations, deterministic=True) + assert np.allclose(selected_actions, new_selected_actions, 1e-4) + + # clear file from os + os.remove(tmp_path / "q_net.pkl") + + @pytest.mark.parametrize("pathtype", [str, pathlib.Path]) def test_open_file_str_pathlib(tmp_path, pathtype): # check that suffix isn't added because we used open_path first From 43d829e1121589b7cf0c3c7fa714b0ddb1bad1b7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 11:44:55 +0100 Subject: [PATCH 2/8] Rename variables to match SB3-contrib --- docs/misc/changelog.rst | 1 + stable_baselines3/dqn/dqn.py | 18 +++++++++--------- stable_baselines3/sac/sac.py | 16 ++++++++-------- stable_baselines3/td3/td3.py | 14 +++++++------- 4 files changed, 25 insertions(+), 24 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e2cfb474d..e30493d2c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -44,6 +44,7 @@ Others: - Add signatures to callable type annotations (@ernestum) - Improve error message in ``NatureCNN`` - Added checks for supported action spaces to improve clarity of error messages for the user +- Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib. Documentation: ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 045c377d3..2818e53dc 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -155,23 +155,23 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) with th.no_grad(): - # Compute the target Q values - target_q = self.q_net_target(replay_data.next_observations) + # Compute the next Q-values using the target network + next_q_values = self.q_net_target(replay_data.next_observations) # Follow greedy policy: use the one with the highest value - target_q, _ = target_q.max(dim=1) + next_q_values, _ = next_q_values.max(dim=1) # Avoid potential broadcast issue - target_q = target_q.reshape(-1, 1) + next_q_values = next_q_values.reshape(-1, 1) # 1-step TD target - target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values - # Get current Q estimates - current_q = self.q_net(replay_data.observations) + # Get current Q-values estimates + current_q_values = self.q_net(replay_data.observations) # Retrieve the q-values for the actions from the replay buffer - current_q = th.gather(current_q, dim=1, index=replay_data.actions.long()) + current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long()) # Compute Huber loss (less sensitive to outliers) - loss = F.smooth_l1_loss(current_q, target_q) + loss = F.smooth_l1_loss(current_q_values, target_q) losses.append(loss.item()) # Optimize the policy diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index a0c299a8f..cd7a41356 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -223,20 +223,20 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: with th.no_grad(): # Select action according to policy next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) - # Compute the target Q value: min over all critics targets - targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) - target_q, _ = th.min(targets, dim=1, keepdim=True) + # Compute the next Q values: min over all critics targets + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) + next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) # add entropy term - target_q = target_q - ent_coef * next_log_prob.reshape(-1, 1) + next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1) # td error + entropy term - q_backup = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values - # Get current Q estimates for each critic network + # Get current Q-values estimates for each critic network # using action from the replay buffer - current_q_estimates = self.critic(replay_data.observations, replay_data.actions) + current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates]) + critic_loss = 0.5 * sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) critic_losses.append(critic_loss.item()) # Optimize the critic diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index ed74830d3..1a3d0597c 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -146,16 +146,16 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: noise = noise.clamp(-self.target_noise_clip, self.target_noise_clip) next_actions = (self.actor_target(replay_data.next_observations) + noise).clamp(-1, 1) - # Compute the target Q value: min over all critics targets - targets = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) - target_q, _ = th.min(targets, dim=1, keepdim=True) - target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * target_q + # Compute the next Q-values: min over all critics targets + next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1) + next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True) + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values - # Get current Q estimates for each critic network - current_q_estimates = self.critic(replay_data.observations, replay_data.actions) + # Get current Q-values estimates for each critic network + current_q_values = self.critic(replay_data.observations, replay_data.actions) # Compute critic loss - critic_loss = sum([F.mse_loss(current_q, target_q) for current_q in current_q_estimates]) + critic_loss = sum([F.mse_loss(current_q, target_q_values) for current_q in current_q_values]) critic_losses.append(critic_loss.item()) # Optimize the critics From ff3756d1a1c5162aab80e85108f834612d1363f7 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 11:54:23 +0100 Subject: [PATCH 3/8] Update docker image --- .gitlab-ci.yml | 2 +- docs/misc/changelog.rst | 1 + scripts/build_docker.sh | 4 ++-- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 4813df4ca..6a31f4f19 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:0.9.0a2 +image: stablebaselines/stable-baselines3-cpu:0.11.0a4 type-check: script: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e30493d2c..42018e5d2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -45,6 +45,7 @@ Others: - Improve error message in ``NatureCNN`` - Added checks for supported action spaces to improve clarity of error messages for the user - Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib. +- Updated docker base image to Ubuntu 18.04 Documentation: ^^^^^^^^^^^^^^ diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 0c599a6dd..13ac86b17 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,7 +1,7 @@ #!/bin/bash -CPU_PARENT=ubuntu:16.04 -GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu16.04 +CPU_PARENT=ubuntu:18.04 +GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) From 7e2e6fae52479e75328c8e05f781613200d7beb6 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 11:54:39 +0100 Subject: [PATCH 4/8] Set min version for tensorboard --- docs/misc/changelog.rst | 1 + setup.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 42018e5d2..e41f38ba1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -46,6 +46,7 @@ Others: - Added checks for supported action spaces to improve clarity of error messages for the user - Renamed variables in the ``train()`` method of ``SAC``, ``TD3`` and ``DQN`` to match SB3-Contrib. - Updated docker base image to Ubuntu 18.04 +- Set tensorboard min version to 2.2.0 (earlier version are apparently not working with PyTorch) Documentation: ^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 72146adba..9b2e649a7 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,7 @@ "atari_py~=0.2.0", "pillow", # Tensorboard support - "tensorboard", + "tensorboard>=2.2.0", # Checking memory taken by replay buffer "psutil", ], From 14c1a11afbe8df4047abacf70377c7a0a1a0c13a Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 12:26:09 +0100 Subject: [PATCH 5/8] Add SB3-Contrib to doc --- README.md | 12 ++++- docs/guide/algos.rst | 4 ++ docs/guide/rl_tips.rst | 7 ++- docs/guide/sb3_contrib.rst | 96 ++++++++++++++++++++++++++++++++++++++ docs/index.rst | 9 ++-- docs/misc/changelog.rst | 1 + setup.py | 5 +- 7 files changed, 124 insertions(+), 10 deletions(-) create mode 100644 docs/guide/sb3_contrib.rst diff --git a/README.md b/README.md index f3baf41ed..600c5da31 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ # Stable Baselines3 -Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). +Stable Baselines3 (SB3) is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). You can read a detailed presentation of Stable Baselines in the [Medium article](https://medium.com/@araffin/stable-baselines-a-fork-of-openai-baselines-reinforcement-learning-made-easy-df87c4b2fc82). @@ -50,7 +50,6 @@ A migration guide from SB2 to SB3 can be found in the [documentation](https://st Documentation is available online: [https://stable-baselines3.readthedocs.io/](https://stable-baselines3.readthedocs.io/) - ## RL Baselines3 Zoo: A Collection of Trained RL Agents [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo). is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3. @@ -68,6 +67,15 @@ Github repo: https://github.com/DLR-RM/rl-baselines3-zoo Documentation: https://stable-baselines3.readthedocs.io/en/master/guide/rl_zoo.html +## SB3-Contrib: Experimental RL Features + +We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) + +This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Truncated Quantile Critics (TQC) or Quantile Regression DQN (QR-DQN). + +Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) + + ## Installation **Note:** Stable-Baselines3 supports PyTorch 1.4+. diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 2ca362d98..887bfb9b3 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -31,6 +31,10 @@ Actions ``gym.spaces``: - ``MultiBinary``: A list of possible actions, where each timestep any of the actions can be used in any combination. +.. note:: + + More algorithms (like QR-DQN or TQC) are implemented in our :ref:`contrib repo `. + .. note:: Some logging values (like ``ep_rew_mean``, ``ep_len_mean``) are only available when using a ``Monitor`` wrapper diff --git a/docs/guide/rl_tips.rst b/docs/guide/rl_tips.rst index 29199a1c8..c207c2148 100644 --- a/docs/guide/rl_tips.rst +++ b/docs/guide/rl_tips.rst @@ -87,8 +87,6 @@ Looking at the training curve (episode reward function of the timesteps) is a go - - We suggest you reading `Deep Reinforcement Learning that Matters `_ for a good discussion about RL evaluation. You can also take a look at this `blog post `_ @@ -122,6 +120,7 @@ Discrete Actions - Single Process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ DQN with extensions (double DQN, prioritized replay, ...) are the recommended algorithms. +We notably provide QR-DQN in our :ref:`contrib repo `. DQN is usually slower to train (regarding wall clock time) but is the most sample efficient (because of its replay buffer). Discrete Actions - Multiprocessed @@ -136,7 +135,7 @@ Continuous Actions Continuous Actions - Single Process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Current State Of The Art (SOTA) algorithms are ``SAC`` and ``TD3``. +Current State Of The Art (SOTA) algorithms are ``SAC``, ``TD3`` and ``TQC`` (available in our :ref:`contrib repo `). Please use the hyperparameters in the `RL zoo `_ for best results. @@ -156,7 +155,7 @@ Goal Environment ----------------- If your environment follows the ``GoalEnv`` interface (cf :ref:`HER `), then you should use -HER + (SAC/TD3/DDPG/DQN) depending on the action space. +HER + (SAC/TD3/DDPG/DQN/TQC) depending on the action space. .. note:: diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst new file mode 100644 index 000000000..2bfbe260f --- /dev/null +++ b/docs/guide/sb3_contrib.rst @@ -0,0 +1,96 @@ +.. _sb3_contrib: + +================== +SB3 Contrib +================== + +We implement experimental features in a separate contrib repository: +`SB3-Contrib`_ + +This allows Stable-Baselines3 (SB3) to maintain a stable and compact core, while still +providing the latest features, like Truncated Quantile Critics (TQC) or +Quantile Regression DQN (QR-DQN). + +Why create this repository? +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Over the span of stable-baselines and stable-baselines3, the community +has been eager to contribute in form of better logging utilities, +environment wrappers, extended support (e.g. different action spaces) +and learning algorithms. + +However sometimes these utilities were too niche to be considered for +stable-baselines or proved to be too difficult to integrate well into +existing code without a mess. sb3-contrib aims to fix this by not +requiring the neatest code integration with existing code and not +setting limits on what is too niche: almost everything remotely useful +goes! We hope this allows to extend the known quality of +stable-baselines style and documentation beyond the relatively small +scope of utilities of the main repository. + +Features +-------- + +See documentation for the full list of included features. + +**RL Algorithms**: + +- `Truncated Quantile Critics (TQC)`_ +- `Quantile Regression DQN (QR-DQN)`_ + +**Gym Wrappers**: + +- `Time Feature Wrapper`_ + +Documentation +------------- + +Documentation is available online: https://sb3-contrib.readthedocs.io/ + +Installation +------------ + +To install Stable-Baselines3 contrib with pip, execute: + +:: + + pip install sb3-contrib + +We recommend to use the ``master`` version of Stable Baselines3 and SB3-Contrib. + +To install Stable Baselines3 ``master`` version: + +:: + + pip install git+https://github.com/DLR-RM/stable-baselines3 + +To install Stable Baselines3 ``master`` version: + +:: + + pip install git+https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + + +Example +------- + +SB3-Contrib follows SB3 interface and folder structure. So, if you are familiar with SB3, +using SB3-Contrib should be easy too. + +Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. + +.. code-block:: python + + from sb3_contrib import QRDQN + + policy_kwargs = dict(n_quantiles=50) + model = QRDQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("qrdqn_cartpole") + + + +.. _SB3-Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib +.. _Truncated Quantile Critics (TQC): https://arxiv.org/abs/2005.04269 +.. _Quantile Regression DQN (QR-DQN): https://arxiv.org/abs/1710.10044 +.. _Time Feature Wrapper: https://arxiv.org/abs/1712.00378 diff --git a/docs/index.rst b/docs/index.rst index c60e5f397..61ac1d59b 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,10 +3,10 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to Stable Baselines3 docs! - RL Baselines Made Easy -=========================================================== +Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations +======================================================================== -`Stable Baselines3 `_ is a set of improved implementations of reinforcement learning algorithms in PyTorch. +`Stable Baselines3 (SB3) `_ is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of `Stable Baselines `_. @@ -16,6 +16,8 @@ RL Baselines3 Zoo (collection of pre-trained agents): https://github.com/DLR-RM/ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and do hyperparameter tuning. +SB3 Contrib (experimental RL code, latest algorithms): https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + Main Features -------------- @@ -45,6 +47,7 @@ Main Features guide/callbacks guide/tensorboard guide/rl_zoo + guide/sb3_contrib guide/imitation guide/migration guide/checking_nan diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index e41f38ba1..d3b7b3fd6 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -58,6 +58,7 @@ Documentation: - Added example of learning rate schedule - Added SUMO-RL as example project (@LucasAlegre) - Fix docstring of classes in atari_wrappers.py which were inside the constructor (@LucasAlegre) +- Added SB3-Contrib page Pre-Release 0.10.0 (2020-10-28) ------------------------------- diff --git a/setup.py b/setup.py index 9b2e649a7..0ef4e9ba1 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ # Stable Baselines3 -Stable Baselines3 is a set of improved implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). +Stable Baselines3 is a set of reliable implementations of reinforcement learning algorithms in PyTorch. It is the next major version of [Stable Baselines](https://github.com/hill-a/stable-baselines). These algorithms will make it easier for the research community and industry to replicate, refine, and identify new ideas, and will create good baselines to build projects on top of. We expect these tools will be used as a base around which new ideas can be added, and as a tool for comparing a new approach against existing ones. We also hope that the simplicity of these tools will allow beginners to experiment with a more advanced toolset, without being buried in implementation details. @@ -29,6 +29,9 @@ RL Baselines3 Zoo: https://github.com/DLR-RM/rl-baselines3-zoo +SB3 Contrib: +https://github.com/Stable-Baselines-Team/stable-baselines3-contrib + ## Quick example Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms using Gym. From cc4dac42d1b55683e4d1e92944d4925e4ed79346 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 12:52:32 +0100 Subject: [PATCH 6/8] Update DQN --- stable_baselines3/dqn/dqn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 2818e53dc..83772020b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -162,7 +162,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: # Avoid potential broadcast issue next_q_values = next_q_values.reshape(-1, 1) # 1-step TD target - target_q = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values + target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values # Get current Q-values estimates current_q_values = self.q_net(replay_data.observations) @@ -171,7 +171,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long()) # Compute Huber loss (less sensitive to outliers) - loss = F.smooth_l1_loss(current_q_values, target_q) + loss = F.smooth_l1_loss(current_q_values, target_q_values) losses.append(loss.item()) # Optimize the policy From 9cf72be5c9368e047b6086b9a30b61390fffa6bb Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 15:22:51 +0100 Subject: [PATCH 7/8] Apply suggestions from code review Co-authored-by: Adam Gleave --- docs/guide/sb3_contrib.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 2bfbe260f..8ae24ac8d 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -21,7 +21,7 @@ and learning algorithms. However sometimes these utilities were too niche to be considered for stable-baselines or proved to be too difficult to integrate well into -existing code without a mess. sb3-contrib aims to fix this by not +the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful goes! We hope this allows to extend the known quality of @@ -64,7 +64,7 @@ To install Stable Baselines3 ``master`` version: pip install git+https://github.com/DLR-RM/stable-baselines3 -To install Stable Baselines3 ``master`` version: +To install Stable Baselines3 contrib ``master`` version: :: @@ -74,7 +74,7 @@ To install Stable Baselines3 ``master`` version: Example ------- -SB3-Contrib follows SB3 interface and folder structure. So, if you are familiar with SB3, +SB3-Contrib follows the SB3 API and folder structure. So, if you are familiar with SB3, using SB3-Contrib should be easy too. Here is an example of training a Quantile Regression DQN (QR-DQN) agent on the CartPole environment. From 56b8f65723ac753eb33165e1da0a94d2c48e3dcd Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Mon, 21 Dec 2020 15:24:12 +0100 Subject: [PATCH 8/8] Update wording --- docs/guide/sb3_contrib.rst | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 8ae24ac8d..3d2d15e6f 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -24,9 +24,10 @@ stable-baselines or proved to be too difficult to integrate well into the existing code without creating a mess. sb3-contrib aims to fix this by not requiring the neatest code integration with existing code and not setting limits on what is too niche: almost everything remotely useful -goes! We hope this allows to extend the known quality of -stable-baselines style and documentation beyond the relatively small -scope of utilities of the main repository. +goes! +We hope this allows us to provide reliable implementations +following stable-baselines usual standards (consistent style, documentation, etc) +beyond the relatively small scope of utilities in the main repository. Features --------