From 9c6353b836236303af0bf4056a9f3969a5dc1ffb Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Sat, 16 Apr 2022 18:44:29 -0700 Subject: [PATCH 1/6] Add LR scheduler to BasePolicy --- test/base/test_utils.py | 42 ++++++++++++++++++++++- tianshou/policy/base.py | 4 +++ tianshou/policy/imitation/base.py | 2 ++ tianshou/policy/imitation/bcq.py | 2 ++ tianshou/policy/imitation/cql.py | 2 ++ tianshou/policy/imitation/discrete_bcq.py | 2 ++ tianshou/policy/imitation/discrete_cql.py | 2 ++ tianshou/policy/imitation/discrete_crr.py | 2 ++ tianshou/policy/imitation/gail.py | 2 ++ tianshou/policy/modelbased/icm.py | 2 ++ tianshou/policy/modelbased/psrl.py | 2 ++ tianshou/policy/modelfree/a2c.py | 3 -- tianshou/policy/modelfree/c51.py | 2 ++ tianshou/policy/modelfree/ddpg.py | 2 ++ tianshou/policy/modelfree/discrete_sac.py | 2 ++ tianshou/policy/modelfree/dqn.py | 2 ++ tianshou/policy/modelfree/fqf.py | 2 ++ tianshou/policy/modelfree/iqn.py | 2 ++ tianshou/policy/modelfree/npg.py | 4 --- tianshou/policy/modelfree/pg.py | 5 --- tianshou/policy/modelfree/ppo.py | 3 -- tianshou/policy/modelfree/qrdqn.py | 2 ++ tianshou/policy/modelfree/rainbow.py | 2 ++ tianshou/policy/modelfree/sac.py | 2 ++ tianshou/policy/modelfree/td3.py | 2 ++ tianshou/policy/modelfree/trpo.py | 4 --- tianshou/utils/__init__.py | 2 ++ tianshou/utils/lr_scheduler.py | 20 +++++++++++ 28 files changed, 105 insertions(+), 20 deletions(-) create mode 100644 tianshou/utils/lr_scheduler.py diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 38bf5d40e..269223130 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -2,7 +2,7 @@ import torch from tianshou.exploration import GaussianNoise, OUNoise -from tianshou.utils import MovAvg, RunningMeanStd +from tianshou.utils import MovAvg, RunningMeanStd, MultipleLRSchedulers from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic @@ -99,8 +99,48 @@ def test_net(): assert list(net(data, act).shape) == [bsz, 1] +def test_lr_schedulers(): + initial_lr_1 = 10.0 + step_size_1 = 1 + gamma_1 = 0.5 + net_1 = torch.nn.Linear(2, 3) + optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1) + sched_1 = torch.optim.lr_scheduler.StepLR( + optim_1, step_size=step_size_1, gamma=gamma_1 + ) + + initial_lr_2 = 5.0 + step_size_2 = 2 + gamma_2 = 0.3 + net_2 = torch.nn.Linear(3, 2) + optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2) + sched_2 = torch.optim.lr_scheduler.StepLR( + optim_2, step_size=step_size_2, gamma=gamma_2 + ) + schedulers = MultipleLRSchedulers(sched_1, sched_2) + for _ in range(10): + loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum() + optim_1.zero_grad() + loss_1.backward() + optim_1.step() + loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum() + optim_2.zero_grad() + loss_2.backward() + optim_2.step() + schedulers.step() + assert ( + optim_1.state_dict()["param_groups"][0]["lr"] == + (initial_lr_1 * gamma_1**(10 // step_size_1)) + ) + assert ( + optim_2.state_dict()["param_groups"][0]["lr"] == + (initial_lr_2 * gamma_2**(10 // step_size_2)) + ) + + if __name__ == '__main__': test_noise() test_moving_average() test_rms() test_net() + test_lr_schedulers() diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index fd055abb3..0e7fa4a49 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -64,6 +64,7 @@ def __init__( action_space: Optional[gym.Space] = None, action_scaling: bool = False, action_bound_method: str = "", + 
lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, ) -> None: super().__init__() self.observation_space = observation_space @@ -79,6 +80,7 @@ def __init__( # can be one of ("clip", "tanh", ""), empty string means no bounding assert action_bound_method in ("", "clip", "tanh") self.action_bound_method = action_bound_method + self.lr_scheduler = lr_scheduler self._compile() def set_agent_id(self, agent_id: int) -> None: @@ -272,6 +274,8 @@ def update(self, sample_size: int, buffer: Optional[ReplayBuffer], batch = self.process_fn(batch, buffer, indices) result = self.learn(batch, **kwargs) self.post_process_fn(batch, buffer, indices) + if self.lr_scheduler is not None: + self.lr_scheduler.step() self.updating = False return result diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 405b8c762..20c6a6212 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -15,6 +15,8 @@ class ImitationPolicy(BasePolicy): :class:`~tianshou.policy.BasePolicy`. (s -> a) :param torch.optim.Optimizer optim: for optimizing the model. :param gym.Space action_space: env's action space. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/imitation/bcq.py b/tianshou/policy/imitation/bcq.py index afd9be90d..89facd7d5 100644 --- a/tianshou/policy/imitation/bcq.py +++ b/tianshou/policy/imitation/bcq.py @@ -36,6 +36,8 @@ class BCQPolicy(BasePolicy): :param int num_sampled_action: the number of sampled actions in calculating target Q. The algorithm samples several actions using VAE, and perturbs each action to get the target Q. Default to 10. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/imitation/cql.py b/tianshou/policy/imitation/cql.py index 102ede058..dd890abb4 100644 --- a/tianshou/policy/imitation/cql.py +++ b/tianshou/policy/imitation/cql.py @@ -46,6 +46,8 @@ class CQLPolicy(SACPolicy): :param float clip_grad: clip_grad for updating critic network. Default to 1.0. :param Union[str, torch.device] device: which device to create this model on. Default to "cpu". + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index bca9b09af..95b5dea5f 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -27,6 +27,8 @@ class DiscreteBCQPolicy(DQNPolicy): logits. Default to 1e-2. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index 1adbb26f7..217e96001 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -23,6 +23,8 @@ class DiscreteCQLPolicy(QRDQNPolicy): :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. :param float min_q_weight: the weight for the cql loss. 
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index b182ead13..edbd25d08 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -29,6 +29,8 @@ class DiscreteCRRPolicy(PGPolicy): you do not use the target network). Default to 0. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index d4779326b..434e12e60 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -59,6 +59,8 @@ class GAILPolicy(PPOPolicy): optimizer in each policy.update(). Default to None (no lr_scheduler). :param bool deterministic_eval: whether to use deterministic action instead of stochastic action sampled by the policy. Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 5a723c10b..66ab1bbc4 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -17,6 +17,8 @@ class ICMPolicy(BasePolicy): :param torch.optim.Optimizer optim: a torch.optim for optimizing the model. :param float lr_scale: the scaling factor for ICM learning. :param float forward_loss_weight: the weight for forward model loss. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelbased/psrl.py b/tianshou/policy/modelbased/psrl.py index 3caab0d63..8e7473adc 100644 --- a/tianshou/policy/modelbased/psrl.py +++ b/tianshou/policy/modelbased/psrl.py @@ -18,6 +18,8 @@ class PSRLModel(object): of rewards, with shape (n_state, n_action). :param float discount_factor: in [0, 1]. :param float epsilon: for precision control in value iteration. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). """ def __init__( diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2ad5bb3b2..b218cbf81 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -146,9 +146,6 @@ def learn( # type: ignore vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) losses.append(loss.item()) - # update learning rate if lr_scheduler is given - if self.lr_scheduler is not None: - self.lr_scheduler.step() return { "loss": losses, diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index e49b2d697..3ebc1cfa8 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -25,6 +25,8 @@ class C51Policy(DQNPolicy): you do not use the target network). Default to 0. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. 
+ :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 0a779c661..0637d05ab 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -32,6 +32,8 @@ class DDPGPolicy(BasePolicy): Default to "clip". :param Optional[gym.Space] action_space: env's action space, mandatory if you want to use option "action_scaling" or "action_bound_method". Default to None. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index 33e06da30..2a626a722 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -28,6 +28,8 @@ class DiscreteSACPolicy(SACPolicy): alpha is automatically tuned. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index d03c3e3cb..593de15fd 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -26,6 +26,8 @@ class DQNPolicy(BasePolicy): :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. :param bool is_double: use double dqn. Default to True. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 054781a7a..9eee122d3 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -27,6 +27,8 @@ class FQFPolicy(QRDQNPolicy): you do not use the target network). :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index 502dd693d..74b8d78d9 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -26,6 +26,8 @@ class IQNPolicy(QRDQNPolicy): you do not use the target network). :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. 
seealso:: diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index ce91fdb6e..e3ab3087c 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -127,10 +127,6 @@ def learn( # type: ignore vf_losses.append(vf_loss.item()) kls.append(kl.item()) - # update learning rate if lr_scheduler is given - if self.lr_scheduler is not None: - self.lr_scheduler.step() - return { "loss/actor": actor_losses, "loss/vf": vf_losses, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index 9149a383b..1557c1ad6 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -44,7 +44,6 @@ def __init__( reward_normalization: bool = False, action_scaling: bool = True, action_bound_method: str = "clip", - lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, deterministic_eval: bool = False, **kwargs: Any, ) -> None: @@ -55,7 +54,6 @@ def __init__( ) self.actor = model self.optim = optim - self.lr_scheduler = lr_scheduler self.dist_fn = dist_fn assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]" self._gamma = discount_factor @@ -137,8 +135,5 @@ def learn( # type: ignore loss.backward() self.optim.step() losses.append(loss.item()) - # update learning rate if lr_scheduler is given - if self.lr_scheduler is not None: - self.lr_scheduler.step() return {"loss": losses} diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index fe5aa2fa6..3c19daf1e 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -152,9 +152,6 @@ def learn( # type: ignore vf_losses.append(vf_loss.item()) ent_losses.append(ent_loss.item()) losses.append(loss.item()) - # update learning rate if lr_scheduler is given - if self.lr_scheduler is not None: - self.lr_scheduler.step() return { "loss": losses, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index ea4913f62..39dde3d4a 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -23,6 +23,8 @@ class QRDQNPolicy(DQNPolicy): you do not use the target network). :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/rainbow.py b/tianshou/policy/modelfree/rainbow.py index 9028258d7..773abddc4 100644 --- a/tianshou/policy/modelfree/rainbow.py +++ b/tianshou/policy/modelfree/rainbow.py @@ -23,6 +23,8 @@ class RainbowPolicy(C51Policy): you do not use the target network). Default to 0. :param bool reward_normalization: normalize the reward to Normal(0, 1). Default to False. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index abe707d01..5f17427a7 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -42,6 +42,8 @@ class SACPolicy(DDPGPolicy): Default to "clip". :param Optional[gym.Space] action_space: env's action space, mandatory if you want to use option "action_scaling" or "action_bound_method". Default to None. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). 
Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 8ad31db06..003e09968 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -40,6 +40,8 @@ class TD3Policy(DDPGPolicy): Default to "clip". :param Optional[gym.Space] action_space: env's action space, mandatory if you want to use option "action_scaling" or "action_bound_method". Default to None. + :param lr_scheduler: a learning rate scheduler that adjusts the learning rate in + optimizer in each policy.update(). Default to None (no lr_scheduler). .. seealso:: diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index 2803a2df1..00688af69 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -146,10 +146,6 @@ def learn( # type: ignore step_sizes.append(step_size.item()) kls.append(kl.item()) - # update learning rate if lr_scheduler is given - if self.lr_scheduler is not None: - self.lr_scheduler.step() - return { "loss/actor": actor_losses, "loss/vf": vf_losses, diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index 2d7ea4906..e630ccf8d 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -6,6 +6,7 @@ from tianshou.utils.logger.wandb import WandbLogger from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.warning import deprecation +from tianshou.utils.lr_scheduler import MultipleLRSchedulers __all__ = [ "MovAvg", @@ -17,4 +18,5 @@ "LazyLogger", "WandbLogger", "deprecation", + "MultipleLRSchedulers", ] diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py new file mode 100644 index 000000000..93e05f1c9 --- /dev/null +++ b/tianshou/utils/lr_scheduler.py @@ -0,0 +1,20 @@ +from typing import List, Dict + +import torch + + +class MultipleLRSchedulers: + + def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR): + self.schedulers = args + + def step(self) -> None: + for scheduler in self.schedulers: + scheduler.step() + + def state_dict(self) -> List[Dict]: + return [s.state_dict() for s in self.schedulers] + + def load_state_dict(self, state_dict: List[Dict]) -> None: + for (s, sd) in zip(self.schedulers, state_dict): + s.__dict__.update(sd) From 188b55f42da64bbbb6b8ab77b6264a3825c3bdc4 Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Sat, 16 Apr 2022 18:56:16 -0700 Subject: [PATCH 2/6] fix type annotation --- tianshou/policy/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 0e7fa4a49..325a85b62 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -9,6 +9,7 @@ from torch import nn from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as +from tianshou.utils import MultipleLRSchedulers class BasePolicy(ABC, nn.Module): @@ -64,7 +65,8 @@ def __init__( action_space: Optional[gym.Space] = None, action_scaling: bool = False, action_bound_method: str = "", - lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None, + lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR, + MultipleLRSchedulers]] = None, ) -> None: super().__init__() self.observation_space = observation_space From d86775bd692d79b26a26735056638a28508f466a Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Sat, 16 Apr 2022 19:35:45 -0700 Subject: [PATCH 3/6] fix isort --- test/base/test_utils.py | 2 +- tianshou/utils/__init__.py | 2 +- tianshou/utils/lr_scheduler.py | 2 +- 3 files 
changed, 3 insertions(+), 3 deletions(-) diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 269223130..b99b0742b 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -2,7 +2,7 @@ import torch from tianshou.exploration import GaussianNoise, OUNoise -from tianshou.utils import MovAvg, RunningMeanStd, MultipleLRSchedulers +from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd from tianshou.utils.net.common import MLP, Net from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic diff --git a/tianshou/utils/__init__.py b/tianshou/utils/__init__.py index e630ccf8d..8acd06c4a 100644 --- a/tianshou/utils/__init__.py +++ b/tianshou/utils/__init__.py @@ -4,9 +4,9 @@ from tianshou.utils.logger.base import BaseLogger, LazyLogger from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger from tianshou.utils.logger.wandb import WandbLogger +from tianshou.utils.lr_scheduler import MultipleLRSchedulers from tianshou.utils.statistics import MovAvg, RunningMeanStd from tianshou.utils.warning import deprecation -from tianshou.utils.lr_scheduler import MultipleLRSchedulers __all__ = [ "MovAvg", diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py index 93e05f1c9..6aea16998 100644 --- a/tianshou/utils/lr_scheduler.py +++ b/tianshou/utils/lr_scheduler.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import Dict, List import torch From 10bc5ca146d074a1f261b09719ba479219968cdf Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Sat, 16 Apr 2022 20:00:06 -0700 Subject: [PATCH 4/6] fix docstrings --- tianshou/utils/lr_scheduler.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py index 6aea16998..75f85e1ef 100644 --- a/tianshou/utils/lr_scheduler.py +++ b/tianshou/utils/lr_scheduler.py @@ -4,17 +4,32 @@ class MultipleLRSchedulers: + """A wrapper for multiple learning rate schedulers. + + Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step()` is called, \ + it calls the step() method of each of the schedulers that it contains. + """ def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR): self.schedulers = args def step(self) -> None: + """Take a step in each of the learning rate schedulers.""" for scheduler in self.schedulers: scheduler.step() def state_dict(self) -> List[Dict]: + """Get state dictionaries for each of the learning rate schedulers. + + :return: A list of state_dicts of learning rate schedulers + """ return [s.state_dict() for s in self.schedulers] def load_state_dict(self, state_dict: List[Dict]) -> None: + """Load states from dictionaries. + + :param List[Dict] state_dict: A list of learning rate scheduler + state_dicts, in the same order as the schedulers. + """ for (s, sd) in zip(self.schedulers, state_dict): s.__dict__.update(sd) From 1e78b4d951f303b8909c0fc47f432b7c151ec550 Mon Sep 17 00:00:00 2001 From: Alex Nikulkov Date: Sat, 16 Apr 2022 20:55:00 -0700 Subject: [PATCH 5/6] fix spelling error --- tianshou/utils/lr_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py index 75f85e1ef..beb66ea7e 100644 --- a/tianshou/utils/lr_scheduler.py +++ b/tianshou/utils/lr_scheduler.py @@ -21,7 +21,7 @@ def step(self) -> None: def state_dict(self) -> List[Dict]: """Get state dictionaries for each of the learning rate schedulers. 
- :return: A list of state_dicts of learning rate schedulers + :return: A list of state dictionaries of learning rate schedulers """ return [s.state_dict() for s in self.schedulers] @@ -29,7 +29,7 @@ def load_state_dict(self, state_dict: List[Dict]) -> None: """Load states from dictionaries. :param List[Dict] state_dict: A list of learning rate scheduler - state_dicts, in the same order as the schedulers. + state dictionaries, in the same order as the schedulers. """ for (s, sd) in zip(self.schedulers, state_dict): s.__dict__.update(sd) From 86b181dbdf4f7bca43bc66f262e95be3c70a2f11 Mon Sep 17 00:00:00 2001 From: Zixu Chen Date: Sun, 17 Apr 2022 10:06:22 -0400 Subject: [PATCH 6/6] polish --- Makefile | 10 +++++----- docs/spelling_wordlist.txt | 1 + examples/atari/README.md | 2 +- tianshou/utils/lr_scheduler.py | 17 ++++++++++++----- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/Makefile b/Makefile index b4c4a623e..b9967f886 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ SHELL=/bin/bash PROJECT_NAME=tianshou PROJECT_PATH=${PROJECT_NAME}/ -LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py +PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} test docs/conf.py examples -type f -name "*.py") check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade @@ -19,18 +19,18 @@ mypy: lint: $(call check_install, flake8) $(call check_install_extra, bugbear, flake8_bugbear) - flake8 ${LINT_PATHS} --count --show-source --statistics + flake8 ${PYTHON_FILES} --count --show-source --statistics format: $(call check_install, isort) - isort ${LINT_PATHS} + isort ${PYTHON_FILES} $(call check_install, yapf) - yapf -ir ${LINT_PATHS} + yapf -ir ${PYTHON_FILES} check-codestyle: $(call check_install, isort) $(call check_install, yapf) - isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS} + isort --check ${PYTHON_FILES} && yapf -r -d ${PYTHON_FILES} check-docstyle: $(call check_install, pydocstyle) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 820fb8867..8a8f9755a 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -18,6 +18,7 @@ numpy ndarray stackoverflow tensorboard +state_dict len tac fqf diff --git a/examples/atari/README.md b/examples/atari/README.md index 561255b20..313a6fa28 100644 --- a/examples/atari/README.md +++ b/examples/atari/README.md @@ -10,7 +10,7 @@ pip install envpool After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x faster for pure execution speed, 1.5x for overall RL training pipeline) than python vectorized env implementation, and it's behavior is consistent to that approach (OpenAI wrapper), which will describe below. -For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool). +For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool). 
## ALE-py diff --git a/tianshou/utils/lr_scheduler.py b/tianshou/utils/lr_scheduler.py index beb66ea7e..36a08b207 100644 --- a/tianshou/utils/lr_scheduler.py +++ b/tianshou/utils/lr_scheduler.py @@ -6,8 +6,15 @@ class MultipleLRSchedulers: """A wrapper for multiple learning rate schedulers. - Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step()` is called, \ + Every time :meth:`~tianshou.utils.MultipleLRSchedulers.step` is called, it calls the step() method of each of the schedulers that it contains. + Example usage: + :: + + scheduler1 = ConstantLR(opt1, factor=0.1, total_iters=2) + scheduler2 = ExponentialLR(opt2, gamma=0.9) + scheduler = MultipleLRSchedulers(scheduler1, scheduler2) + policy = PPOPolicy(..., lr_scheduler=scheduler) """ def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR): @@ -19,17 +26,17 @@ def step(self) -> None: scheduler.step() def state_dict(self) -> List[Dict]: - """Get state dictionaries for each of the learning rate schedulers. + """Get state_dict for each of the learning rate schedulers. - :return: A list of state dictionaries of learning rate schedulers + :return: A list of state_dict of learning rate schedulers. """ return [s.state_dict() for s in self.schedulers] def load_state_dict(self, state_dict: List[Dict]) -> None: - """Load states from dictionaries. + """Load states from state_dict. :param List[Dict] state_dict: A list of learning rate scheduler - state dictionaries, in the same order as the schedulers. + state_dict, in the same order as the schedulers. """ for (s, sd) in zip(self.schedulers, state_dict): s.__dict__.update(sd)