diff --git a/docs/pypots.optim.rst b/docs/pypots.optim.rst index 8badeb1c..2bcc93f2 100644 --- a/docs/pypots.optim.rst +++ b/docs/pypots.optim.rst @@ -54,3 +54,12 @@ pypots.optim.base module :undoc-members: :show-inheritance: :inherited-members: + +pypots.optim.lr_scheduler module +------------------------------ + +.. automodule:: pypots.optim.lr_scheduler + :members: + :undoc-members: + :show-inheritance: + :inherited-members: diff --git a/pypots/optim/adadelta.py b/pypots/optim/adadelta.py index ac4726d3..59e98f2a 100644 --- a/pypots/optim/adadelta.py +++ b/pypots/optim/adadelta.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import Adadelta as torch_Adadelta from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adadelta(Optimizer): @@ -39,8 +40,9 @@ def __init__( rho: float = 0.9, eps: float = 1e-08, weight_decay: float = 0.01, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.rho = rho self.eps = eps self.weight_decay = weight_decay @@ -61,3 +63,6 @@ def init_optimizer(self, params: Iterable) -> None: eps=self.eps, weight_decay=self.weight_decay, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adagrad.py b/pypots/optim/adagrad.py index e4374244..8a10f06c 100644 --- a/pypots/optim/adagrad.py +++ b/pypots/optim/adagrad.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import Adagrad as torch_Adagrad from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adagrad(Optimizer): @@ -43,8 +44,9 @@ def __init__( weight_decay: float = 0.01, initial_accumulator_value: float = 0.01, # it is set as 0 in the torch implementation, but delta shouldn't be 0 eps: float = 1e-08, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.lr_decay = lr_decay self.weight_decay = weight_decay self.initial_accumulator_value = initial_accumulator_value @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: initial_accumulator_value=self.initial_accumulator_value, eps=self.eps, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adam.py b/pypots/optim/adam.py index d308b27e..c5e0e1af 100644 --- a/pypots/optim/adam.py +++ b/pypots/optim/adam.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional from torch.optim import Adam as torch_Adam from .base import Optimizer +from .lr_scheduler.base import LRScheduler class Adam(Optimizer): @@ -42,8 +43,9 @@ def __init__( eps: float = 1e-08, weight_decay: float = 0, amsgrad: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.betas = betas self.eps = eps self.weight_decay = weight_decay @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: weight_decay=self.weight_decay, amsgrad=self.amsgrad, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/adamw.py b/pypots/optim/adamw.py index b93d8a74..6a5191e4 100644 --- a/pypots/optim/adamw.py +++ b/pypots/optim/adamw.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: 
GLP-v3 -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional from torch.optim import AdamW as torch_AdamW from .base import Optimizer +from .lr_scheduler.base import LRScheduler class AdamW(Optimizer): @@ -42,8 +43,9 @@ def __init__( eps: float = 1e-08, weight_decay: float = 0.01, amsgrad: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.betas = betas self.eps = eps self.weight_decay = weight_decay @@ -66,3 +68,6 @@ def init_optimizer(self, params: Iterable) -> None: weight_decay=self.weight_decay, amsgrad=self.amsgrad, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/base.py b/pypots/optim/base.py index f1bb9637..db09fb3a 100644 --- a/pypots/optim/base.py +++ b/pypots/optim/base.py @@ -19,6 +19,8 @@ from abc import ABC, abstractmethod from typing import Callable, Iterable, Optional +from .lr_scheduler.base import LRScheduler + class Optimizer(ABC): """The base wrapper for PyTorch optimizers, also is the base class for all optimizers in pypots.optim. @@ -35,9 +37,10 @@ class Optimizer(ABC): """ - def __init__(self, lr): + def __init__(self, lr, lr_scheduler: Optional[LRScheduler] = None): self.lr = lr self.torch_optimizer = None + self.lr_scheduler = lr_scheduler @abstractmethod def init_optimizer(self, params: Iterable) -> None: @@ -97,6 +100,9 @@ def step(self, closure: Optional[Callable] = None) -> None: """ self.torch_optimizer.step(closure) + if self.lr_scheduler is not None: + self.lr_scheduler.step() + def zero_grad(self, set_to_none: bool = True) -> None: """Sets the gradients of all optimized ``torch.Tensor`` to zero. diff --git a/pypots/optim/lr_scheduler/__init__.py b/pypots/optim/lr_scheduler/__init__.py new file mode 100644 index 00000000..ddb14350 --- /dev/null +++ b/pypots/optim/lr_scheduler/__init__.py @@ -0,0 +1,29 @@ +""" +Learning rate schedulers available in PyPOTS. Their functionalities are the same with those in PyTorch, +the only difference that is also why we implement them is that you don't have to pass according optimizers +into them immediately while initializing them. Instead, you can pass them into pypots.optim.Optimizer +after initialization and call their `init_scheduler()` method in pypots.optim.Optimizer.init_optimizer() to initialize +schedulers together with optimizers. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .lambda_lrs import LambdaLR +from .multiplicative_lrs import MultiplicativeLR +from .step_lrs import StepLR +from .multistep_lrs import MultiStepLR +from .constant_lrs import ConstantLR +from .exponential_lrs import ExponentialLR +from .linear_lrs import LinearLR + + +__all__ = [ + "LambdaLR", + "MultiplicativeLR", + "StepLR", + "MultiStepLR", + "ConstantLR", + "ExponentialLR", + "LinearLR", +] diff --git a/pypots/optim/lr_scheduler/base.py b/pypots/optim/lr_scheduler/base.py new file mode 100644 index 00000000..0aeffd8b --- /dev/null +++ b/pypots/optim/lr_scheduler/base.py @@ -0,0 +1,162 @@ +""" +The base class for learning rate schedulers. This class is adapted from PyTorch, +please refer to torch.optim.lr_scheduler for more details. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +import weakref +from abc import ABC, abstractmethod +from functools import wraps + +from torch.optim import Optimizer + +from ...utils.logging import logger + + +class LRScheduler(ABC): + """Base class for PyPOTS learning rate schedulers. 
+ + Parameters + ---------- + last_epoch: int + The index of last epoch. Default: -1. + + verbose: If ``True``, prints a message to stdout for + each update. Default: ``False``. + + """ + + def __init__(self, last_epoch=-1, verbose=False): + self.last_epoch = last_epoch + self.verbose = verbose + self.optimizer = None + self.base_lrs = None + self._last_lr = None + self._step_count = 0 + + def init_scheduler(self, optimizer): + """Initialize the scheduler. This method should be called in pypots.optim.Optimizer.init_optimizer() + to initialize the scheduler together with the optimizer. + + Parameters + ---------- + optimizer: torch.optim.Optimizer, + The optimizer to be scheduled. + + """ + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if self.last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault("initial_lr", group["lr"]) + else: + for i, group in enumerate(optimizer.param_groups): + if "initial_lr" not in group: + raise KeyError( + "param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i) + ) + self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, "_with_counter", False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.optimizer._step_count = 0 + + @abstractmethod + def get_lr(self): + """Compute learning rate.""" + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def get_last_lr(self): + """Return last computed learning rate by current scheduler.""" + return self._last_lr + + @staticmethod + def print_lr(is_verbose, group, lr): + """Display the current learning rate.""" + if is_verbose: + logger.info(f"Adjusting learning rate of group {group} to {lr:.4e}.") + + def step(self): + """Step could be called after every batch update. This should be called in ``pypots.optim.Optimizer.step()`` + after ``pypots.optim.Optimizer.torch_optimizer.step()``. + """ + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + logger.warning( + "Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. 
See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + ) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + logger.warning( + "Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", + ) + self._step_count += 1 + + class _enable_get_lr_call: + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + with _enable_get_lr_call(self): + self.last_epoch += 1 + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group["lr"] = lr + self.print_lr(self.verbose, i, lr) + + self._last_lr = [group["lr"] for group in self.optimizer.param_groups] diff --git a/pypots/optim/lr_scheduler/constant_lrs.py b/pypots/optim/lr_scheduler/constant_lrs.py new file mode 100644 index 00000000..3f6ae1a3 --- /dev/null +++ b/pypots/optim/lr_scheduler/constant_lrs.py @@ -0,0 +1,84 @@ +""" +Constant learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class ConstantLR(LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the number of epoch reaches + a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously with other changes + to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + factor: float, default=1./3. + The number we multiply learning rate until the milestone. + + total_iters: int, default=5, + The number of steps that the scheduler decays the learning rate. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.ConstantLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> # xdoctest: +SKIP + >>> scheduler = ConstantLR(factor=0.5, total_iters=4) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + if factor > 1.0 or factor < 0: + raise ValueError( + "Constant multiplicative factor expected to be between 0 and 1."
+ ) + + self.factor = factor + self.total_iters = total_iters + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [group["lr"] * self.factor for group in self.optimizer.param_groups] + + if self.last_epoch > self.total_iters or (self.last_epoch != self.total_iters): + return [group["lr"] for group in self.optimizer.param_groups] + + if self.last_epoch == self.total_iters: + return [ + group["lr"] * (1.0 / self.factor) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/exponential_lrs.py b/pypots/optim/lr_scheduler/exponential_lrs.py new file mode 100644 index 00000000..ed7e960f --- /dev/null +++ b/pypots/optim/lr_scheduler/exponential_lrs.py @@ -0,0 +1,55 @@ +""" +Exponential learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + gamma: float, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.ExponentialLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> scheduler = ExponentialLR(gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, gamma, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs] diff --git a/pypots/optim/lr_scheduler/lambda_lrs.py b/pypots/optim/lr_scheduler/lambda_lrs.py new file mode 100644 index 00000000..5471cee6 --- /dev/null +++ b/pypots/optim/lr_scheduler/lambda_lrs.py @@ -0,0 +1,79 @@ +""" +Lambda learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from typing import Callable, Union + +from .base import LRScheduler, logger + + +class LambdaLR(LRScheduler): + """Sets the learning rate of each parameter group to the initial lr times a given function. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + lr_lambda: Callable or list, + A function which computes a multiplicative factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + + last_epoch: int, + The index of last epoch. Default: -1. + + verbose: bool, + If ``True``, prints a message to stdout for each update. Default: ``False``. 
+ + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.LambdaLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> lambda1 = lambda epoch: epoch // 30 + >>> scheduler = LambdaLR(lr_lambda=lambda1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__( + self, + lr_lambda: Union[Callable, list], + last_epoch: int = -1, + verbose: bool = False, + ): + super().__init__(last_epoch, verbose) + self.lr_lambda = lr_lambda + self.lr_lambdas = None + + def init_scheduler(self, optimizer): + if not isinstance(self.lr_lambda, list) and not isinstance( + self.lr_lambda, tuple + ): + self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) + else: + if len(self.lr_lambda) != len(optimizer.param_groups): + raise ValueError( + "Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(self.lr_lambda) + ) + ) + self.lr_lambdas = list(self.lr_lambda) + + super().init_scheduler(optimizer) + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`." + ) + + return [ + base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs) + ] diff --git a/pypots/optim/lr_scheduler/linear_lrs.py b/pypots/optim/lr_scheduler/linear_lrs.py new file mode 100644 index 00000000..a1e8e1e6 --- /dev/null +++ b/pypots/optim/lr_scheduler/linear_lrs.py @@ -0,0 +1,115 @@ +""" +Linear learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small multiplicative factor until + the number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can happen simultaneously + with other changes to the learning rate from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + start_factor: float, default=1.0 / 3, + The number we multiply learning rate in the first epoch. The multiplication factor changes towards + end_factor in the following epochs. + + end_factor: float, default=1.0, + The number we multiply learning rate at the end of linear changing process. + + total_iters: int, default=5, + The number of iterations that multiplicative factor reaches to 1. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.LinearLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. 
+ + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> # xdoctest: +SKIP + >>> scheduler = LinearLR(start_factor=0.5, total_iters=4) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__( + self, + start_factor=1.0 / 3, + end_factor=1.0, + total_iters=5, + last_epoch=-1, + verbose=False, + ): + super().__init__(last_epoch, verbose) + if start_factor > 1.0 or start_factor < 0: + raise ValueError( + "Starting multiplicative factor expected to be between 0 and 1." + ) + + if end_factor > 1.0 or end_factor < 0: + raise ValueError( + "Ending multiplicative factor expected to be between 0 and 1." + ) + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch == 0: + return [ + group["lr"] * self.start_factor for group in self.optimizer.param_groups + ] + + if self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + return [ + group["lr"] + * ( + 1.0 + + (self.end_factor - self.start_factor) + / ( + self.total_iters * self.start_factor + + (self.last_epoch - 1) * (self.end_factor - self.start_factor) + ) + ) + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + return [ + base_lr + * ( + self.start_factor + + (self.end_factor - self.start_factor) + * min(self.total_iters, self.last_epoch) + / self.total_iters + ) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/multiplicative_lrs.py b/pypots/optim/lr_scheduler/multiplicative_lrs.py new file mode 100644 index 00000000..5dbc18ea --- /dev/null +++ b/pypots/optim/lr_scheduler/multiplicative_lrs.py @@ -0,0 +1,77 @@ +""" +Multiplicative learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + + +from .base import LRScheduler, logger + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given in the specified function. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + lr_lambda: Callable or list, + A function which computes a multiplicative factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + + last_epoch: int, + The index of last epoch. Default: -1. + + verbose: bool, + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.MultiplicativeLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. 
+ + Example + ------- + >>> lmbda = lambda epoch: 0.95 + >>> # xdoctest: +SKIP + >>> scheduler = MultiplicativeLR(lr_lambda=lmbda) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, lr_lambda, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + self.lr_lambda = lr_lambda + self.lr_lambdas = None + + def init_scheduler(self, optimizer): + if not isinstance(self.lr_lambda, list) and not isinstance( + self.lr_lambda, tuple + ): + self.lr_lambdas = [self.lr_lambda] * len(optimizer.param_groups) + else: + if len(self.lr_lambda) != len(optimizer.param_groups): + raise ValueError( + "Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(self.lr_lambda) + ) + ) + self.lr_lambdas = list(self.lr_lambda) + + super().init_scheduler(optimizer) + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch > 0: + return [ + group["lr"] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups) + ] + else: + return [group["lr"] for group in self.optimizer.param_groups] diff --git a/pypots/optim/lr_scheduler/multistep_lrs.py b/pypots/optim/lr_scheduler/multistep_lrs.py new file mode 100644 index 00000000..567570e9 --- /dev/null +++ b/pypots/optim/lr_scheduler/multistep_lrs.py @@ -0,0 +1,75 @@ +""" +Multistep learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from bisect import bisect_right +from collections import Counter + +from .base import LRScheduler, logger + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. + Notice that such decay can happen simultaneously with other changes to the learning rate from outside this + scheduler. When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + milestones: list, + List of epoch indices. Must be increasing. + + gamma: float, default=0.1, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.MultiStepLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. 
+ + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> # xdoctest: +SKIP + >>> scheduler = MultiStepLR(milestones=[30,80], gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, milestones, gamma=0.1, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + self.milestones = Counter(milestones) + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if self.last_epoch not in self.milestones: + return [group["lr"] for group in self.optimizer.param_groups] + return [ + group["lr"] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + milestones = list(sorted(self.milestones.elements())) + return [ + base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/lr_scheduler/step_lrs.py b/pypots/optim/lr_scheduler/step_lrs.py new file mode 100644 index 00000000..29f72bb8 --- /dev/null +++ b/pypots/optim/lr_scheduler/step_lrs.py @@ -0,0 +1,70 @@ +""" +Step learning rate scheduler. +""" + +# Created by Wenjie Du +# License: GLP-v3 + +from .base import LRScheduler, logger + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every step_size epochs. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Parameters + ---------- + step_size: int, + Period of learning rate decay. + + gamma: float, default=0.1, + Multiplicative factor of learning rate decay. + + last_epoch: int + The index of last epoch. Default: -1. + + verbose: bool + If ``True``, prints a message to stdout for each update. Default: ``False``. + + Notes + ----- + This class works the same with ``torch.optim.lr_scheduler.StepLR``. + The only difference that is also why we implement them is that you don't have to pass according optimizers + into them immediately while initializing them. + + Example + ------- + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... 
+ >>> # xdoctest: +SKIP + >>> scheduler = StepLR(step_size=30, gamma=0.1) + >>> adam = pypots.optim.Adam(lr=1e-3, lr_scheduler=scheduler) + + """ + + def __init__(self, step_size, gamma=0.1, last_epoch=-1, verbose=False): + super().__init__(last_epoch, verbose) + + self.step_size = step_size + self.gamma = gamma + + def get_lr(self): + if not self._get_lr_called_within_step: + logger.warning( + "To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", + ) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group["lr"] for group in self.optimizer.param_groups] + return [group["lr"] * self.gamma for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs + ] diff --git a/pypots/optim/rmsprop.py b/pypots/optim/rmsprop.py index 65a817ca..f00da68d 100644 --- a/pypots/optim/rmsprop.py +++ b/pypots/optim/rmsprop.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import RMSprop as torch_RMSprop from .base import Optimizer +from .lr_scheduler.base import LRScheduler class RMSprop(Optimizer): @@ -47,8 +48,9 @@ def __init__( eps: float = 1e-08, centered: bool = False, weight_decay: float = 0, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.momentum = momentum self.alpha = alpha self.eps = eps @@ -73,3 +75,6 @@ def init_optimizer(self, params: Iterable) -> None: centered=self.centered, weight_decay=self.weight_decay, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/pypots/optim/sgd.py b/pypots/optim/sgd.py index 4696db91..34cd07f0 100644 --- a/pypots/optim/sgd.py +++ b/pypots/optim/sgd.py @@ -6,11 +6,12 @@ # Created by Wenjie Du # License: GLP-v3 -from typing import Iterable +from typing import Iterable, Optional from torch.optim import SGD as torch_SGD from .base import Optimizer +from .lr_scheduler.base import LRScheduler class SGD(Optimizer): @@ -43,8 +44,9 @@ def __init__( weight_decay: float = 0, dampening: float = 0, nesterov: bool = False, + lr_scheduler: Optional[LRScheduler] = None, ): - super().__init__(lr) + super().__init__(lr, lr_scheduler) self.momentum = momentum self.weight_decay = weight_decay self.dampening = dampening @@ -67,3 +69,6 @@ def init_optimizer(self, params: Iterable) -> None: dampening=self.dampening, nesterov=self.nesterov, ) + + if self.lr_scheduler is not None: + self.lr_scheduler.init_scheduler(self.torch_optimizer) diff --git a/tests/optim/lr_schedulers.py b/tests/optim/lr_schedulers.py new file mode 100644 index 00000000..e7748f91 --- /dev/null +++ b/tests/optim/lr_schedulers.py @@ -0,0 +1,249 @@ +""" +Test cases for the learning rate schedulers. 
+""" + +# Created by Wenjie Du +# License: GLP-v3 + +import unittest + +import numpy as np +import pytest + +from pypots.imputation import SAITS +from pypots.optim import Adam, AdamW, Adadelta, Adagrad, RMSprop, SGD +from pypots.optim.lr_scheduler import ( + LambdaLR, + ConstantLR, + ExponentialLR, + LinearLR, + StepLR, + MultiStepLR, + MultiplicativeLR, +) +from pypots.utils.logging import logger +from pypots.utils.metrics import cal_mae +from tests.global_test_config import DATA +from tests.optim.config import EPOCHS, TEST_SET, TRAIN_SET, VAL_SET + + +class TestLRSchedulers(unittest.TestCase): + logger.info("Running tests for learning rate schedulers...") + + # init lambda_lrs + lambda_lrs = LambdaLR(lr_lambda=lambda epoch: epoch // 30, verbose=True) + + # init multiplicative_lrs + multiplicative_lrs = MultiplicativeLR(lr_lambda=lambda epoch: 0.95, verbose=True) + + # init step_lrs + step_lrs = StepLR(step_size=30, gamma=0.1, verbose=True) + + # init multistep_lrs + multistep_lrs = MultiStepLR(milestones=[30, 80], gamma=0.1, verbose=True) + + # init constant_lrs + constant_lrs = ConstantLR(factor=0.5, total_iters=4, verbose=True) + + # init linear_lrs + linear_lrs = LinearLR(start_factor=0.5, total_iters=4, verbose=True) + + # init exponential_lrs + exponential_lrs = ExponentialLR(gamma=0.9, verbose=True) + + @pytest.mark.xdist_group(name="lrs-lambda") + def test_0_lambda_lrs(self): + logger.info("Running tests for Adam + LambdaLRS...") + + adam = Adam(lr=0.001, weight_decay=1e-5, lr_scheduler=self.lambda_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adam, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-multiplicative") + def test_1_multiplicative_lrs(self): + logger.info("Running tests for Adamw + MultiplicativeLRS...") + + adamw = AdamW(lr=0.001, weight_decay=1e-5, lr_scheduler=self.multiplicative_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-step") + def test_2_step_lrs(self): + logger.info("Running tests for Adadelta + StepLRS...") + + adamw = Adadelta(lr=0.001, lr_scheduler=self.step_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adamw, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." 
+ test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-multistep") + def test_3_multistep_lrs(self): + logger.info("Running tests for Adadelta + MultiStepLRS...") + + adagrad = Adagrad(lr=0.001, lr_scheduler=self.multistep_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=adagrad, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-constant") + def test_4_constant_lrs(self): + logger.info("Running tests for RMSprop + ConstantLRS...") + + # initialize a SAITS model for testing DatasetForMIT and BaseDataset + rmsprop = RMSprop(lr=0.001, lr_scheduler=self.constant_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=rmsprop, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-linear") + def test_5_linear_lrs(self): + logger.info("Running tests for SGD + MultiStepLRS...") + + sgd = SGD(lr=0.001, lr_scheduler=self.linear_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}") + + @pytest.mark.xdist_group(name="lrs-exponential") + def test_6_exponential_lrs(self): + logger.info("Running tests for SGD + ExponentialLRS...") + + sgd = SGD(lr=0.001, lr_scheduler=self.exponential_lrs) + saits = SAITS( + DATA["n_steps"], + DATA["n_features"], + n_layers=1, + d_model=128, + d_inner=64, + n_heads=2, + d_k=64, + d_v=64, + dropout=0.1, + optimizer=sgd, + epochs=EPOCHS, + ) + saits.fit(TRAIN_SET, VAL_SET) + imputed_X = saits.impute(TEST_SET) + assert not np.isnan( + imputed_X + ).any(), "Output still has missing values after running impute()." + test_MAE = cal_mae( + imputed_X, DATA["test_X_intact"], DATA["test_X_indicating_mask"] + ) + logger.info(f"SAITS test_MAE: {test_MAE}")
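
A minimal usage sketch of the lr_scheduler support added in this patch, mirroring the constructs exercised by the tests above. Only StepLR, Adam(lr_scheduler=...), SAITS, fit(), and impute() come from the changed code; the dataset variables (train_set, val_set, test_set), the sequence/feature sizes, and epochs=10 are illustrative placeholders.

    # Sketch only: the scheduler is created without an optimizer. Adam.init_optimizer()
    # binds it to the underlying torch optimizer via init_scheduler(), and Adam.step()
    # then calls scheduler.step() after every torch_optimizer.step(), so the learning
    # rate is adjusted per training step (batch), not per epoch.
    from pypots.imputation import SAITS
    from pypots.optim import Adam
    from pypots.optim.lr_scheduler import StepLR

    scheduler = StepLR(step_size=30, gamma=0.1)  # decay lr by 10x every 30 steps
    adam = Adam(lr=1e-3, weight_decay=1e-5, lr_scheduler=scheduler)

    # train_set / val_set / test_set are placeholders for PyPOTS dataset dicts,
    # e.g. {"X": ...} arrays shaped [n_samples, n_steps, n_features].
    saits = SAITS(
        48,  # n_steps, placeholder sequence length
        37,  # n_features, placeholder feature dimension
        n_layers=1,
        d_model=128,
        d_inner=64,
        n_heads=2,
        d_k=64,
        d_v=64,
        dropout=0.1,
        optimizer=adam,
        epochs=10,  # placeholder epoch count
    )
    saits.fit(train_set, val_set)
    imputation = saits.impute(test_set)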