Skip to content

Commit

Permalink
Merge pull request #195 from WenjieDu/lr_scheduler
Browse files Browse the repository at this point in the history
Add learning-rate schedulers
  • Loading branch information
WenjieDu authored Sep 28, 2023
2 parents 28b9fdc + 1798ecf commit fc79142
Show file tree
Hide file tree
Showing 18 changed files with 1,053 additions and 13 deletions.
9 changes: 9 additions & 0 deletions docs/pypots.optim.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ pypots.optim.base module
:undoc-members:
:show-inheritance:
:inherited-members:

pypots.optim.lr_scheduler module
------------------------------

.. automodule:: pypots.optim.lr_scheduler
:members:
:undoc-members:
:show-inheritance:
:inherited-members:
9 changes: 7 additions & 2 deletions pypots/optim/adadelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GLP-v3

from typing import Iterable
from typing import Iterable, Optional

from torch.optim import Adadelta as torch_Adadelta

from .base import Optimizer
from .lr_scheduler.base import LRScheduler


class Adadelta(Optimizer):
Expand Down Expand Up @@ -39,8 +40,9 @@ def __init__(
rho: float = 0.9,
eps: float = 1e-08,
weight_decay: float = 0.01,
lr_scheduler: Optional[LRScheduler] = None,
):
super().__init__(lr)
super().__init__(lr, lr_scheduler)
self.rho = rho
self.eps = eps
self.weight_decay = weight_decay
Expand All @@ -61,3 +63,6 @@ def init_optimizer(self, params: Iterable) -> None:
eps=self.eps,
weight_decay=self.weight_decay,
)

if self.lr_scheduler is not None:
self.lr_scheduler.init_scheduler(self.torch_optimizer)
9 changes: 7 additions & 2 deletions pypots/optim/adagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GLP-v3

from typing import Iterable
from typing import Iterable, Optional

from torch.optim import Adagrad as torch_Adagrad

from .base import Optimizer
from .lr_scheduler.base import LRScheduler


class Adagrad(Optimizer):
Expand Down Expand Up @@ -43,8 +44,9 @@ def __init__(
weight_decay: float = 0.01,
initial_accumulator_value: float = 0.01, # it is set as 0 in the torch implementation, but delta shouldn't be 0
eps: float = 1e-08,
lr_scheduler: Optional[LRScheduler] = None,
):
super().__init__(lr)
super().__init__(lr, lr_scheduler)
self.lr_decay = lr_decay
self.weight_decay = weight_decay
self.initial_accumulator_value = initial_accumulator_value
Expand All @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None:
initial_accumulator_value=self.initial_accumulator_value,
eps=self.eps,
)

if self.lr_scheduler is not None:
self.lr_scheduler.init_scheduler(self.torch_optimizer)
9 changes: 7 additions & 2 deletions pypots/optim/adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GLP-v3

from typing import Iterable, Tuple
from typing import Iterable, Tuple, Optional

from torch.optim import Adam as torch_Adam

from .base import Optimizer
from .lr_scheduler.base import LRScheduler


class Adam(Optimizer):
Expand Down Expand Up @@ -42,8 +43,9 @@ def __init__(
eps: float = 1e-08,
weight_decay: float = 0,
amsgrad: bool = False,
lr_scheduler: Optional[LRScheduler] = None,
):
super().__init__(lr)
super().__init__(lr, lr_scheduler)
self.betas = betas
self.eps = eps
self.weight_decay = weight_decay
Expand All @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None:
weight_decay=self.weight_decay,
amsgrad=self.amsgrad,
)

if self.lr_scheduler is not None:
self.lr_scheduler.init_scheduler(self.torch_optimizer)
9 changes: 7 additions & 2 deletions pypots/optim/adamw.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GLP-v3

from typing import Iterable, Tuple
from typing import Iterable, Tuple, Optional

from torch.optim import AdamW as torch_AdamW

from .base import Optimizer
from .lr_scheduler.base import LRScheduler


class AdamW(Optimizer):
Expand Down Expand Up @@ -42,8 +43,9 @@ def __init__(
eps: float = 1e-08,
weight_decay: float = 0.01,
amsgrad: bool = False,
lr_scheduler: Optional[LRScheduler] = None,
):
super().__init__(lr)
super().__init__(lr, lr_scheduler)
self.betas = betas
self.eps = eps
self.weight_decay = weight_decay
Expand All @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None:
weight_decay=self.weight_decay,
amsgrad=self.amsgrad,
)

if self.lr_scheduler is not None:
self.lr_scheduler.init_scheduler(self.torch_optimizer)
8 changes: 7 additions & 1 deletion pypots/optim/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterable, Optional

from .lr_scheduler.base import LRScheduler


class Optimizer(ABC):
"""The base wrapper for PyTorch optimizers, also is the base class for all optimizers in pypots.optim.
Expand All @@ -35,9 +37,10 @@ class Optimizer(ABC):
"""

def __init__(self, lr):
def __init__(self, lr, lr_scheduler: Optional[LRScheduler] = None):
self.lr = lr
self.torch_optimizer = None
self.lr_scheduler = lr_scheduler

@abstractmethod
def init_optimizer(self, params: Iterable) -> None:
Expand Down Expand Up @@ -97,6 +100,9 @@ def step(self, closure: Optional[Callable] = None) -> None:
"""
self.torch_optimizer.step(closure)

if self.lr_scheduler is not None:
self.lr_scheduler.step()

def zero_grad(self, set_to_none: bool = True) -> None:
"""Sets the gradients of all optimized ``torch.Tensor`` to zero.
Expand Down
29 changes: 29 additions & 0 deletions pypots/optim/lr_scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch,
the only difference that is also why we implement them is that you don't have to pass according optimizers
into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer
after initialization and call their `init_scheduler()` method in pypots.optim.Optimizer.init_optimizer() to initialize
schedulers together with optimizers.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GLP-v3

from .lambda_lrs import LambdaLR
from .multiplicative_lrs import MultiplicativeLR
from .step_lrs import StepLR
from .multistep_lrs import MultiStepLR
from .constant_lrs import ConstantLR
from .exponential_lrs import ExponentialLR
from .linear_lrs import LinearLR


__all__ = [
"LambdaLR",
"MultiplicativeLR",
"StepLR",
"MultiStepLR",
"ConstantLR",
"ExponentialLR",
"LinearLR",
]
162 changes: 162 additions & 0 deletions pypots/optim/lr_scheduler/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""
The base class for learning rate schedulers. This class is adapted from PyTorch,
please refer to torch.optim.lr_scheduler for more details.
"""

# Created by Wenjie Du <wenjay.du@gmail.com>
# License: GLP-v3

import weakref
from abc import ABC, abstractmethod
from functools import wraps

from torch.optim import Optimizer

from ...utils.logging import logger


class LRScheduler(ABC):
"""Base class for PyPOTS learning rate schedulers.
Parameters
----------
last_epoch: int
The index of last epoch. Default: -1.
verbose: If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""

def __init__(self, last_epoch=-1, verbose=False):
self.last_epoch = last_epoch
self.verbose = verbose
self.optimizer = None
self.base_lrs = None
self._last_lr = None
self._step_count = 0

def init_scheduler(self, optimizer):
"""Initialize the scheduler. This method should be called in pypots.optim.Optimizer.init_optimizer()
to initialize the scheduler together with the optimizer.
Parameters
----------
optimizer: torch.optim.Optimizer,
The optimizer to be scheduled.
"""

# Attach optimizer
if not isinstance(optimizer, Optimizer):
raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
self.optimizer = optimizer

# Initialize epoch and base learning rates
if self.last_epoch == -1:
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
else:
for i, group in enumerate(optimizer.param_groups):
if "initial_lr" not in group:
raise KeyError(
"param 'initial_lr' is not specified "
"in param_groups[{}] when resuming an optimizer".format(i)
)
self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups]

# Following https://github.com/pytorch/pytorch/issues/20124
# We would like to ensure that `lr_scheduler.step()` is called after
# `optimizer.step()`
def with_counter(method):
if getattr(method, "_with_counter", False):
# `optimizer.step()` has already been replaced, return.
return method

# Keep a weak reference to the optimizer instance to prevent
# cyclic references.
instance_ref = weakref.ref(method.__self__)
# Get the unbound method for the same purpose.
func = method.__func__
cls = instance_ref().__class__
del method

@wraps(func)
def wrapper(*args, **kwargs):
instance = instance_ref()
instance._step_count += 1
wrapped = func.__get__(instance, cls)
return wrapped(*args, **kwargs)

# Note that the returned function here is no longer a bound method,
# so attributes like `__func__` and `__self__` no longer exist.
wrapper._with_counter = True
return wrapper

self.optimizer.step = with_counter(self.optimizer.step)
self.optimizer._step_count = 0

@abstractmethod
def get_lr(self):
"""Compute learning rate."""
# Compute learning rate using chainable form of the scheduler
raise NotImplementedError

def get_last_lr(self):
"""Return last computed learning rate by current scheduler."""
return self._last_lr

@staticmethod
def print_lr(is_verbose, group, lr):
"""Display the current learning rate."""
if is_verbose:
logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.")

def step(self):
"""Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()``
after ``pypots.optim.Optimizer.torch_optimizer.step()``.
"""
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
if self._step_count == 1:
if not hasattr(self.optimizer.step, "_with_counter"):
logger.warning(
"Seems like `optimizer.step()` has been overridden after learning rate scheduler "
"initialization. Please, make sure to call `optimizer.step()` before "
"`lr_scheduler.step()`. See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
)

# Just check if there were two first lr_scheduler.step() calls before optimizer.step()
elif self.optimizer._step_count < 1:
logger.warning.warn(
"Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
"In PyTorch 1.1.0 and later, you should call them in the opposite order: "
"`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
"will result in PyTorch skipping the first value of the learning rate schedule. "
"See more details at "
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
)
self._step_count += 1

class _enable_get_lr_call:
def __init__(self, o):
self.o = o

def __enter__(self):
self.o._get_lr_called_within_step = True
return self

def __exit__(self, type, value, traceback):
self.o._get_lr_called_within_step = False

with _enable_get_lr_call(self):
self.last_epoch += 1
values = self.get_lr()

for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group["lr"] = lr
self.print_lr(self.verbose, i, lr)

self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
Loading

0 comments on commit fc79142

Please sign in to comment.