added AdamW and SGDW optimizers and their LR scheduler
for issue flairNLP#220
Kashif Rasul committed Nov 23, 2018
1 parent eeff246 · commit 741743d
Showing 1 changed file with 318 additions and 0 deletions.
@@ -0,0 +1,318 @@
import math
from functools import partial

import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.optim.lr_scheduler import ReduceLROnPlateau

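# Background sketch (added for clarity; phrasing is an editor's summary, not text
# from the commit): in the sense of `Fixing Weight Decay Regularization in Adam`
# (https://arxiv.org/abs/1711.05101), "decoupled" weight decay shrinks the
# parameters directly rather than adding weight_decay * p to the gradient.
# Schematically, with d_p the (possibly momentum-adjusted) gradient:
#
#     L2 regularization:      p <- p - lr * (d_p + weight_decay * p)
#     decoupled weight decay: p <- p - lr * d_p - weight_decay * p
#
# SGDW.step and AdamW.step below apply the second form; note that in this
# implementation the decay term is not scaled by the learning rate.
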
class SGDW(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum) with
    weight decay from the paper `Fixing Weight Decay Regularization in Adam`_.

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Example:
        >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9,
        ...                  weight_decay=1e-5)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                # Decoupled weight decay: shrink the parameters directly
                # instead of adding weight_decay * p to the gradient.
                if weight_decay != 0:
                    p.data.add_(-weight_decay, p.data)

                p.data.add_(-group['lr'], d_p)

        return loss

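# Illustrative usage sketch (not part of the committed module; the toy tensor and
# hyperparameters are arbitrary, and the PyTorch API assumed is the one this file
# targets, roughly 0.4/1.0). One SGDW step applies the gradient update and,
# separately, a multiplicative shrink of the parameters by weight_decay.
def _sgdw_example():
    p = torch.ones(3, requires_grad=True)
    optimizer = SGDW([p], lr=0.1, momentum=0.9, weight_decay=1e-2)
    optimizer.zero_grad()
    loss = (p ** 2).sum()
    loss.backward()    # grad of sum(p^2) w.r.t. p is 2 * p
    optimizer.step()   # first step: p <- (1 - 1e-2) * p - 0.1 * grad
    return p.data      # approximately tensor([0.7900, 0.7900, 0.7900])
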
class AdamW(Optimizer):
    r"""Implements AdamW optimizer.

    Adam has been proposed in `Adam\: A Method for Stochastic Optimization`_.
    AdamW uses the weight decay method from the paper
    `Fixing Weight Decay Regularization in Adam`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay factor (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            (default: False)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Decoupled weight decay: shrink the parameters directly
                # instead of folding weight_decay * p into the gradient.
                if group['weight_decay'] != 0:
                    p.data.add_(-group['weight_decay'], p.data)

                p.data.addcdiv_(-step_size, exp_avg, denom)

        return loss

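# Illustrative usage sketch (not part of the committed module; the toy tensor
# and hyperparameters are arbitrary). On the first AdamW step the bias-corrected
# Adam update has magnitude roughly lr, plus the decoupled weight_decay shrink.
def _adamw_example():
    p = torch.ones(3, requires_grad=True)
    optimizer = AdamW([p], lr=1e-3, betas=(0.9, 0.999), weight_decay=1e-2)
    optimizer.zero_grad()
    loss = (p ** 2).sum()
    loss.backward()
    optimizer.step()
    # First step: exp_avg = 0.1 * grad, exp_avg_sq = 0.001 * grad^2, and the bias
    # corrections make step_size * exp_avg / (sqrt(exp_avg_sq) + eps) ~ lr * sign(grad),
    # so p <- (1 - 1e-2) * p - 1e-3, i.e. approximately 0.989 per coordinate.
    return p.data
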
class ReduceLRWDOnPlateau(ReduceLROnPlateau):
    """Reduce learning rate and weight decay when a metric has stopped
    improving. Models often benefit from reducing the learning rate by
    a factor of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, the learning rate and weight decay factor are reduced for
    optimizers that implement the weight decay method from the paper
    `Fixing Weight Decay Regularization in Adam`_.

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
        >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """
    def step(self, metrics, epoch=None):
        current = metrics
        if epoch is None:
            epoch = self.last_epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self._reduce_weight_decay(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_weight_decay(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            if param_group['weight_decay'] != 0:
                old_weight_decay = float(param_group['weight_decay'])
                new_weight_decay = max(old_weight_decay * self.factor, self.min_lrs[i])
                if old_weight_decay - new_weight_decay > self.eps:
                    param_group['weight_decay'] = new_weight_decay
                    if self.verbose:
                        print('Epoch {:5d}: reducing weight decay factor'
                              ' of group {} to {:.4e}.'.format(epoch, i, new_weight_decay))
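# Illustrative usage sketch (not part of the committed module; the optimizer,
# metric values, and hyperparameters are arbitrary, and the base-class behavior
# assumed is that of the PyTorch version this file targets). When the monitored
# metric fails to improve for more than `patience` epochs, both the learning rate
# and the weight decay factor of every parameter group are multiplied by `factor`.
def _reduce_lr_wd_example():
    p = torch.ones(3, requires_grad=True)
    optimizer = AdamW([p], lr=0.1, weight_decay=1e-3)
    scheduler = ReduceLRWDOnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    for val_loss in [1.0, 1.0, 1.0, 1.0]:  # a validation metric that never improves
        scheduler.step(val_loss)
    group = optimizer.param_groups[0]
    # After the plateau is detected: lr 0.1 -> 0.05, weight_decay 1e-3 -> 5e-4.
    return group['lr'], group['weight_decay']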