From 8712997b1b2e972c14c81484c5729f441c9212b9 Mon Sep 17 00:00:00 2001
From: Tobias Birchler
Date: Wed, 4 Oct 2023 23:05:00 +0200
Subject: [PATCH 1/4] Fix log_interval handling in OffPolicyAlgorithm (#1708)

---
 stable_baselines3/common/base_class.py           |  2 +-
 stable_baselines3/common/off_policy_algorithm.py | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py
index 5e8759990..f5b972bfe 100644
--- a/stable_baselines3/common/base_class.py
+++ b/stable_baselines3/common/base_class.py
@@ -523,7 +523,7 @@ def learn(
 
         :param total_timesteps: The total number of samples (env steps) to train on
         :param callback: callback(s) called at every step with state of the algorithm.
-        :param log_interval: The number of episodes before logging.
+        :param log_interval: The number of training iterations (rollout collection plus agent updates) between logging.
         :param tb_log_name: the name of the run for TensorBoard logging
         :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
         :param progress_bar: Display a progress bar using tqdm and rich.
diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 2caaf8e97..76e277ecc 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -309,6 +309,8 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfOffPolicyAlgorithm:
+        iteration = 0
+
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
             callback,
@@ -327,12 +329,16 @@ def learn(
                 callback=callback,
                 learning_starts=self.learning_starts,
                 replay_buffer=self.replay_buffer,
-                log_interval=log_interval,
             )
 
             if rollout.continue_training is False:
                 break
 
+            iteration += 1
+
+            if log_interval is not None and iteration % log_interval == 0:
+                self._dump_logs()
+
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                 # If no `gradient_steps` is specified,
                 # do as many gradients steps as steps performed during the rollout
@@ -502,7 +508,6 @@ def collect_rollouts(
         replay_buffer: ReplayBuffer,
         action_noise: Optional[ActionNoise] = None,
         learning_starts: int = 0,
-        log_interval: Optional[int] = None,
     ) -> RolloutReturn:
         """
         Collect experiences and store them into a ``ReplayBuffer``.
@@ -520,7 +525,6 @@ def collect_rollouts(
             in addition to the stochastic policy for SAC.
         :param learning_starts: Number of steps before learning for the warm-up phase.
         :param replay_buffer:
-        :param log_interval: Log data every ``log_interval`` episodes
         :return:
         """
         # Switch to eval mode (this affects batch norm / dropout)
@@ -583,9 +587,6 @@ def collect_rollouts(
                         kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                         action_noise.reset(**kwargs)
 
-                    # Log training infos
-                    if log_interval is not None and self._episode_num % log_interval == 0:
-                        self._dump_logs()
         callback.on_rollout_end()
 
         return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)

From 574307a5661efe419ac961ee7156abc53c29ac1d Mon Sep 17 00:00:00 2001
From: Tobias Birchler
Date: Wed, 4 Oct 2023 23:05:18 +0200
Subject: [PATCH 2/4] Update changelog (#1708)

---
 docs/misc/changelog.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 3fade42aa..e653fc37f 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -14,6 +14,7 @@ New Features:
 ^^^^^^^^^^^^^
 - Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
 - Improved error message when mixing Gym API with VecEnv API (see GH#1694)
+- Added iteration-based logging in ``OffPolicyAlgorithm``: logs are now dumped every ``log_interval`` training iterations instead of every ``log_interval`` episodes (@tobiabir)
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -1466,3 +1467,4 @@ And all the contributors:
 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan
 @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
 @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
+@tobiabir

From 43dc8f9d27b9f61839094fc50a004668a98848cd Mon Sep 17 00:00:00 2001
From: Tobias Birchler
Date: Wed, 4 Oct 2023 23:07:08 +0200
Subject: [PATCH 3/4] Fix order of logging and training (#1708)

---
 stable_baselines3/common/off_policy_algorithm.py | 10 +++++-----
 stable_baselines3/common/on_policy_algorithm.py  |  7 ++++---
 tests/test_run.py                                |  2 +-
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 76e277ecc..04325e6e0 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -334,11 +334,6 @@ def learn(
             if rollout.continue_training is False:
                 break
 
-            iteration += 1
-
-            if log_interval is not None and iteration % log_interval == 0:
-                self._dump_logs()
-
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                 # If no `gradient_steps` is specified,
                 # do as many gradients steps as steps performed during the rollout
@@ -347,6 +342,11 @@ def learn(
                 if gradient_steps > 0:
                     self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
 
+            iteration += 1
+
+            if log_interval is not None and iteration % log_interval == 0:
+                self._dump_logs()
+
         callback.on_training_end()
 
         return self
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 1e0f9e6c9..2641b915f 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -268,9 +268,12 @@ def learn(
             if continue_training is False:
                 break
 
-            iteration += 1
             self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
 
+            self.train()
+
+            iteration += 1
+
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
                 assert self.ep_info_buffer is not None
@@ -285,8 +288,6 @@ def learn(
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
 
-            self.train()
-
         callback.on_training_end()
 
         return self
diff --git a/tests/test_run.py b/tests/test_run.py
index 31c7b956e..8bc43dce1 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -229,7 +229,7 @@ def test_ppo_warnings():
     # in that case
     with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
         model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
-        model.learn(64)
+        model.learn(64, log_interval=2)
 
     loss = model.logger.name_to_value["train/loss"]
     assert loss > 0

From d0d19123409209672ed49f5e585fe6c0488750ee Mon Sep 17 00:00:00 2001
From: Tobias Birchler
Date: Fri, 6 Oct 2023 13:21:16 +0200
Subject: [PATCH 4/4] Update log_interval default values (#1708)

---
 stable_baselines3/ddpg/ddpg.py | 2 +-
 stable_baselines3/dqn/dqn.py   | 2 +-
 stable_baselines3/sac/sac.py   | 2 +-
 stable_baselines3/td3/td3.py   | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index c311b2357..7c666e876 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -115,7 +115,7 @@ def learn(
         self: SelfDDPG,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "DDPG",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py
index 42e3d0df0..d4ed512b1 100644
--- a/stable_baselines3/dqn/dqn.py
+++ b/stable_baselines3/dqn/dqn.py
@@ -259,7 +259,7 @@ def learn(
         self: SelfDQN,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "DQN",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index bf0fa5028..ec3d63181 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -299,7 +299,7 @@ def learn(
         self: SelfSAC,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "SAC",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index a06ce67e0..e15f18e60 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -214,7 +214,7 @@ def learn(
         self: SelfTD3,
         total_timesteps: int,
         callback: MaybeCallback = None,
-        log_interval: int = 4,
+        log_interval: int = 1000,
         tb_log_name: str = "TD3",
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
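
Usage sketch (not code from this patch series; it only assumes the standard stable-baselines3 API with the changes above applied): after these patches, off-policy algorithms dump logs every ``log_interval`` training iterations, right after the gradient updates of that iteration, instead of every ``log_interval`` completed episodes.

    from stable_baselines3 import SAC

    # SAC defaults to train_freq=1, so one training iteration collects one
    # environment step; log_interval=1000 therefore dumps logs roughly every
    # 1000 steps, which is why the default moves from 4 (episodes) to 1000.
    model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)

    # Logging now happens after train(), so the "train/*" metrics appear in
    # the same dump as the rollout statistics they correspond to.
    model.learn(total_timesteps=5_000, log_interval=1000)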