Add learning rate scheduler to BasePolicy #598

Merged 6 commits on Apr 17, 2022
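For context, here is a rough sketch of how the feature this PR introduces would be used. It is not taken from the PR itself: the `lr_scheduler` keyword and the once-per-`update()` stepping come from the diff below, while the policy class, networks, shapes, and hyperparameters are purely illustrative.

```python
import torch

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

# Illustrative setup: the network classes, shapes and hyperparameters are
# placeholders, not part of this PR.
net = Net(state_shape=4, hidden_sizes=[64, 64])
actor = Actor(net, action_shape=2)
critic = Critic(net)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=3e-4)

# Decay the learning rate linearly over an assumed budget of policy updates.
max_update_num = 1000
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda n: max(0.0, 1 - n / max_update_num)
)

policy = PPOPolicy(
    actor,
    critic,
    optim,
    dist_fn=torch.distributions.Categorical,
    # New in this PR: the scheduler is stored by BasePolicy and stepped once
    # at the end of every policy.update() call.
    lr_scheduler=lr_scheduler,
)
```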
10 changes: 5 additions & 5 deletions Makefile
@@ -1,7 +1,7 @@
SHELL=/bin/bash
PROJECT_NAME=tianshou
PROJECT_PATH=${PROJECT_NAME}/
LINT_PATHS=${PROJECT_PATH} test/ docs/conf.py examples/ setup.py
PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} test docs/conf.py examples -type f -name "*.py")

check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade
check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade
@@ -19,18 +19,18 @@ mypy:
lint:
$(call check_install, flake8)
$(call check_install_extra, bugbear, flake8_bugbear)
flake8 ${LINT_PATHS} --count --show-source --statistics
flake8 ${PYTHON_FILES} --count --show-source --statistics

format:
$(call check_install, isort)
isort ${LINT_PATHS}
isort ${PYTHON_FILES}
$(call check_install, yapf)
yapf -ir ${LINT_PATHS}
yapf -ir ${PYTHON_FILES}

check-codestyle:
$(call check_install, isort)
$(call check_install, yapf)
isort --check ${LINT_PATHS} && yapf -r -d ${LINT_PATHS}
isort --check ${PYTHON_FILES} && yapf -r -d ${PYTHON_FILES}

check-docstyle:
$(call check_install, pydocstyle)
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
@@ -18,6 +18,7 @@ numpy
ndarray
stackoverflow
tensorboard
state_dict
len
tac
fqf
2 changes: 1 addition & 1 deletion examples/atari/README.md
@@ -10,7 +10,7 @@ pip install envpool

After that, `atari_wrapper` will automatically switch to envpool's Atari env. EnvPool's implementation is much faster (about 2\~3x for pure execution speed, 1.5x for the overall RL training pipeline) than the Python vectorized env implementation, and its behavior is consistent with that approach (the OpenAI wrapper), which is described below.

For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://ppo-details.cleanrl.dev/2021/11/05/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).
For more information, please refer to EnvPool's [GitHub](https://github.com/sail-sg/envpool/), [Docs](https://envpool.readthedocs.io/en/latest/api/atari.html), and [3rd-party report](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/#solving-pong-in-5-minutes-with-ppo--envpool).

## ALE-py

42 changes: 41 additions & 1 deletion test/base/test_utils.py
@@ -2,7 +2,7 @@
import torch

from tianshou.exploration import GaussianNoise, OUNoise
from tianshou.utils import MovAvg, RunningMeanStd
from tianshou.utils import MovAvg, MultipleLRSchedulers, RunningMeanStd
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import RecurrentActorProb, RecurrentCritic

@@ -99,8 +99,48 @@ def test_net():
assert list(net(data, act).shape) == [bsz, 1]


def test_lr_schedulers():
initial_lr_1 = 10.0
step_size_1 = 1
gamma_1 = 0.5
net_1 = torch.nn.Linear(2, 3)
optim_1 = torch.optim.Adam(net_1.parameters(), lr=initial_lr_1)
sched_1 = torch.optim.lr_scheduler.StepLR(
optim_1, step_size=step_size_1, gamma=gamma_1
)

initial_lr_2 = 5.0
step_size_2 = 2
gamma_2 = 0.3
net_2 = torch.nn.Linear(3, 2)
optim_2 = torch.optim.Adam(net_2.parameters(), lr=initial_lr_2)
sched_2 = torch.optim.lr_scheduler.StepLR(
optim_2, step_size=step_size_2, gamma=gamma_2
)
schedulers = MultipleLRSchedulers(sched_1, sched_2)
for _ in range(10):
loss_1 = (torch.ones((1, 3)) - net_1(torch.ones((1, 2)))).sum()
optim_1.zero_grad()
loss_1.backward()
optim_1.step()
loss_2 = (torch.ones((1, 2)) - net_2(torch.ones((1, 3)))).sum()
optim_2.zero_grad()
loss_2.backward()
optim_2.step()
schedulers.step()
assert (
optim_1.state_dict()["param_groups"][0]["lr"] ==
(initial_lr_1 * gamma_1**(10 // step_size_1))
)
assert (
optim_2.state_dict()["param_groups"][0]["lr"] ==
(initial_lr_2 * gamma_2**(10 // step_size_2))
)


if __name__ == '__main__':
test_noise()
test_moving_average()
test_rms()
test_net()
test_lr_schedulers()
6 changes: 6 additions & 0 deletions tianshou/policy/base.py
@@ -9,6 +9,7 @@
from torch import nn

from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.utils import MultipleLRSchedulers


class BasePolicy(ABC, nn.Module):
@@ -64,6 +65,8 @@ def __init__(
action_space: Optional[gym.Space] = None,
action_scaling: bool = False,
action_bound_method: str = "",
lr_scheduler: Optional[Union[torch.optim.lr_scheduler.LambdaLR,
MultipleLRSchedulers]] = None,
) -> None:
super().__init__()
self.observation_space = observation_space
@@ -79,6 +82,7 @@ def __init__(
# can be one of ("clip", "tanh", ""), empty string means no bounding
assert action_bound_method in ("", "clip", "tanh")
self.action_bound_method = action_bound_method
self.lr_scheduler = lr_scheduler
self._compile()

def set_agent_id(self, agent_id: int) -> None:
@@ -272,6 +276,8 @@ def update(self, sample_size: int, buffer: Optional[ReplayBuffer],
batch = self.process_fn(batch, buffer, indices)
result = self.learn(batch, **kwargs)
self.post_process_fn(batch, buffer, indices)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
self.updating = False
return result

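With the scheduler now owned and stepped by `BasePolicy.update()`, subclasses no longer need their own `lr_scheduler` bookkeeping; they can simply forward the keyword through `**kwargs` (the `pg.py` diff further down removes exactly that duplicated logic). A minimal sketch of the resulting pattern follows; `MyPolicy` is a hypothetical subclass, not part of this PR.

```python
from typing import Any, Dict

import torch

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class MyPolicy(BasePolicy):
    """Hypothetical subclass showing the post-PR pattern: pass lr_scheduler
    (and other BasePolicy options) through **kwargs instead of storing and
    stepping the scheduler yourself."""

    def __init__(
        self,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)  # BasePolicy now keeps self.lr_scheduler
        self.model = model
        self.optim = optim

    def forward(self, batch: Batch, state=None, **kwargs: Any) -> Batch:
        obs = torch.as_tensor(batch.obs, dtype=torch.float32)
        logits = self.model(obs)
        return Batch(logits=logits, act=logits.argmax(dim=-1))

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
        # ... compute the loss, call loss.backward() and self.optim.step() ...
        # No scheduler handling here: BasePolicy.update() calls
        # self.lr_scheduler.step() after learn() returns (see the diff above).
        return {"loss": 0.0}
```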
2 changes: 2 additions & 0 deletions tianshou/policy/imitation/base.py
@@ -15,6 +15,8 @@ class ImitationPolicy(BasePolicy):
:class:`~tianshou.policy.BasePolicy`. (s -> a)
:param torch.optim.Optimizer optim: for optimizing the model.
:param gym.Space action_space: env's action space.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/imitation/bcq.py
@@ -36,6 +36,8 @@ class BCQPolicy(BasePolicy):
:param int num_sampled_action: the number of sampled actions in calculating
target Q. The algorithm samples several actions using VAE, and perturbs
each action to get the target Q. Default to 10.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/imitation/cql.py
@@ -46,6 +46,8 @@ class CQLPolicy(SACPolicy):
:param float clip_grad: clip_grad for updating critic network. Default to 1.0.
:param Union[str, torch.device] device: which device to create this model on.
Default to "cpu".
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/imitation/discrete_bcq.py
@@ -27,6 +27,8 @@ class DiscreteBCQPolicy(DQNPolicy):
logits. Default to 1e-2.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/imitation/discrete_cql.py
@@ -23,6 +23,8 @@ class DiscreteCQLPolicy(QRDQNPolicy):
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param float min_q_weight: the weight for the cql loss.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::
Please refer to :class:`~tianshou.policy.QRDQNPolicy` for more detailed
2 changes: 2 additions & 0 deletions tianshou/policy/imitation/discrete_crr.py
@@ -29,6 +29,8 @@ class DiscreteCRRPolicy(PGPolicy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::
Please refer to :class:`~tianshou.policy.PGPolicy` for more detailed
2 changes: 2 additions & 0 deletions tianshou/policy/imitation/gail.py
@@ -59,6 +59,8 @@ class GAILPolicy(PPOPolicy):
optimizer in each policy.update(). Default to None (no lr_scheduler).
:param bool deterministic_eval: whether to use deterministic action instead of
stochastic action sampled by the policy. Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelbased/icm.py
@@ -17,6 +17,8 @@ class ICMPolicy(BasePolicy):
:param torch.optim.Optimizer optim: a torch.optim for optimizing the model.
:param float lr_scale: the scaling factor for ICM learning.
:param float forward_loss_weight: the weight for forward model loss.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelbased/psrl.py
@@ -18,6 +18,8 @@ class PSRLModel(object):
of rewards, with shape (n_state, n_action).
:param float discount_factor: in [0, 1].
:param float epsilon: for precision control in value iteration.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).
"""

def __init__(
3 changes: 0 additions & 3 deletions tianshou/policy/modelfree/a2c.py
@@ -146,9 +146,6 @@ def learn( # type: ignore
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {
"loss": losses,
2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/c51.py
@@ -25,6 +25,8 @@ class C51Policy(DQNPolicy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/ddpg.py
@@ -32,6 +32,8 @@ class DDPGPolicy(BasePolicy):
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/discrete_sac.py
@@ -28,6 +28,8 @@ class DiscreteSACPolicy(SACPolicy):
alpha is automatically tuned.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/dqn.py
@@ -26,6 +26,8 @@ class DQNPolicy(BasePolicy):
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param bool is_double: use double dqn. Default to True.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/fqf.py
@@ -27,6 +27,8 @@ class FQFPolicy(QRDQNPolicy):
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/iqn.py
@@ -26,6 +26,8 @@ class IQNPolicy(QRDQNPolicy):
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

4 changes: 0 additions & 4 deletions tianshou/policy/modelfree/npg.py
@@ -127,10 +127,6 @@ def learn( # type: ignore
vf_losses.append(vf_loss.item())
kls.append(kl.item())

# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {
"loss/actor": actor_losses,
"loss/vf": vf_losses,
5 changes: 0 additions & 5 deletions tianshou/policy/modelfree/pg.py
@@ -44,7 +44,6 @@ def __init__(
reward_normalization: bool = False,
action_scaling: bool = True,
action_bound_method: str = "clip",
lr_scheduler: Optional[torch.optim.lr_scheduler.LambdaLR] = None,
deterministic_eval: bool = False,
**kwargs: Any,
) -> None:
@@ -55,7 +54,6 @@ def __init__(
)
self.actor = model
self.optim = optim
self.lr_scheduler = lr_scheduler
self.dist_fn = dist_fn
assert 0.0 <= discount_factor <= 1.0, "discount factor should be in [0, 1]"
self._gamma = discount_factor
@@ -137,8 +135,5 @@ def learn( # type: ignore
loss.backward()
self.optim.step()
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {"loss": losses}
3 changes: 0 additions & 3 deletions tianshou/policy/modelfree/ppo.py
@@ -152,9 +152,6 @@ def learn( # type: ignore
vf_losses.append(vf_loss.item())
ent_losses.append(ent_loss.item())
losses.append(loss.item())
# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {
"loss": losses,
2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/qrdqn.py
@@ -23,6 +23,8 @@ class QRDQNPolicy(DQNPolicy):
you do not use the target network).
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/rainbow.py
@@ -23,6 +23,8 @@ class RainbowPolicy(C51Policy):
you do not use the target network). Default to 0.
:param bool reward_normalization: normalize the reward to Normal(0, 1).
Default to False.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/sac.py
@@ -42,6 +42,8 @@ class SACPolicy(DDPGPolicy):
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

2 changes: 2 additions & 0 deletions tianshou/policy/modelfree/td3.py
@@ -40,6 +40,8 @@ class TD3Policy(DDPGPolicy):
Default to "clip".
:param Optional[gym.Space] action_space: env's action space, mandatory if you want
to use option "action_scaling" or "action_bound_method". Default to None.
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
optimizer in each policy.update(). Default to None (no lr_scheduler).

.. seealso::

4 changes: 0 additions & 4 deletions tianshou/policy/modelfree/trpo.py
@@ -146,10 +146,6 @@ def learn( # type: ignore
step_sizes.append(step_size.item())
kls.append(kl.item())

# update learning rate if lr_scheduler is given
if self.lr_scheduler is not None:
self.lr_scheduler.step()

return {
"loss/actor": actor_losses,
"loss/vf": vf_losses,
2 changes: 2 additions & 0 deletions tianshou/utils/__init__.py
@@ -4,6 +4,7 @@
from tianshou.utils.logger.base import BaseLogger, LazyLogger
from tianshou.utils.logger.tensorboard import BasicLogger, TensorboardLogger
from tianshou.utils.logger.wandb import WandbLogger
from tianshou.utils.lr_scheduler import MultipleLRSchedulers
from tianshou.utils.statistics import MovAvg, RunningMeanStd
from tianshou.utils.warning import deprecation

@@ -17,4 +18,5 @@
"LazyLogger",
"WandbLogger",
"deprecation",
"MultipleLRSchedulers",
]
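The new `tianshou/utils/lr_scheduler.py` module itself does not appear in this diff. Judging from how `MultipleLRSchedulers` is constructed and stepped in `test_utils.py` and typed in `base.py`, it is presumably a thin wrapper along the lines of the sketch below; the `state_dict`/`load_state_dict` methods are an assumption (included because checkpointing a policy typically needs them) and should be checked against the actual module.

```python
from typing import Dict, List

import torch


class MultipleLRSchedulers:
    """Presumed sketch: wrap several torch lr schedulers so a policy can
    advance all of them with a single lr_scheduler.step() call."""

    def __init__(self, *args: torch.optim.lr_scheduler.LambdaLR):
        self.schedulers = args

    def step(self) -> None:
        # Called once per policy.update(); steps every wrapped scheduler.
        for scheduler in self.schedulers:
            scheduler.step()

    # Assumed for checkpointing; verify against tianshou/utils/lr_scheduler.py.
    def state_dict(self) -> List[Dict]:
        return [s.state_dict() for s in self.schedulers]

    def load_state_dict(self, state_dict: List[Dict]) -> None:
        for s, sd in zip(self.schedulers, state_dict):
            s.load_state_dict(sd)
```

Usage mirrors the test above: wrap one scheduler per optimizer (e.g. actor and critic) with `MultipleLRSchedulers(sched_1, sched_2)` and pass the result as the `lr_scheduler` argument of a policy.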