Skip to content

Commit

Permalink
added AdamW and SGDW optimizer and their LR schedular
Browse files Browse the repository at this point in the history
for issue flairNLP#220
  • Loading branch information
Kashif Rasul committed Nov 23, 2018
1 parent eeff246 commit 741743d
Showing 1 changed file with 318 additions and 0 deletions.
318 changes: 318 additions & 0 deletions flair/optim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
import math
from functools import partial

import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.optim.lr_scheduler import ReduceLROnPlateau

class SGDW(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum) with
weight decay from the paper `Fixing Weight Decay Regularization in Adam`_.
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay factor (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
.. _Fixing Weight Decay Regularization in Adam:
https://arxiv.org/abs/1711.05101
Example:
>>> optimizer = torch.optim.SGDW(model.parameters(), lr=0.1, momentum=0.9,
weight_decay=1e-5)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et. al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et. al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""

def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(SGDW, self).__init__(params, defaults)

def __setstate__(self, state):
super(SGDW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']

for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data

if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
buf.mul_(momentum).add_(d_p)
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)
if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf

if weight_decay != 0:
p.data.add_(-weight_decay, p.data)

p.data.add_(-group['lr'], d_p)

return loss


class AdamW(Optimizer):
r"""Implements AdamW optimizer.
Adam has been proposed in `Adam\: A Method for Stochastic Optimization`_.
AdamW uses the weight decay method from the paper
`Fixing Weight Decay Regularization in Adam`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay factor (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Fixing Weight Decay Regularization in Adam:
https://arxiv.org/abs/1711.05101
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""

def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamW, self).__init__(params, defaults)

def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']

state = self.state[p]

# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']

state['step'] += 1

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(1 - beta1, grad)
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])

bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

if group['weight_decay'] != 0:
p.data.add_(-group['weight_decay'], p.data)

p.data.addcdiv_(-step_size, exp_avg, denom)

return loss


class ReduceLRWDOnPlateau(ReduceLROnPlateau):
"""Reduce learning rate and weight decay when a metric has stopped
improving. Models often benefit from reducing the learning rate by
a factor of 2-10 once learning stagnates. This scheduler reads a metric
quantity and if no improvement is seen for a 'patience' number
of epochs, the learning rate and weight decay factor is reduced for
optimizers that implement the the weight decay method from the paper
`Fixing Weight Decay Regularization in Adam`_.
.. _Fixing Weight Decay Regularization in Adam:
https://arxiv.org/abs/1711.05101
Args:
optimizer (Optimizer): Wrapped optimizer.
mode (str): One of `min`, `max`. In `min` mode, lr will
be reduced when the quantity monitored has stopped
decreasing; in `max` mode it will be reduced when the
quantity monitored has stopped increasing. Default: 'min'.
factor (float): Factor by which the learning rate will be
reduced. new_lr = lr * factor. Default: 0.1.
patience (int): Number of epochs with no improvement after
which learning rate will be reduced. For example, if
`patience = 2`, then we will ignore the first 2 epochs
with no improvement, and will only decrease the LR after the
3rd epoch if the loss still hasn't improved then.
Default: 10.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
threshold (float): Threshold for measuring the new optimum,
to only focus on significant changes. Default: 1e-4.
threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
dynamic_threshold = best * ( 1 + threshold ) in 'max'
mode or best * ( 1 - threshold ) in `min` mode.
In `abs` mode, dynamic_threshold = best + threshold in
`max` mode or best - threshold in `min` mode. Default: 'rel'.
cooldown (int): Number of epochs to wait before resuming
normal operation after lr has been reduced. Default: 0.
min_lr (float or list): A scalar or a list of scalars. A
lower bound on the learning rate of all param groups
or each group respectively. Default: 0.
eps (float): Minimal decay applied to lr. If the difference
between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
Example:
>>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
>>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
>>> for epoch in range(10):
>>> train(...)
>>> val_loss = validate(...)
>>> # Note that step should be called after validate()
>>> scheduler.step(val_loss)
"""
def step(self, metrics, epoch=None):
current = metrics
if epoch is None:
epoch = self.last_epoch = self.last_epoch + 1
self.last_epoch = epoch

if self.is_better(current, self.best):
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1

if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0 # ignore any bad epochs in cooldown

if self.num_bad_epochs > self.patience:
self._reduce_lr(epoch)
self._reduce_weight_decay(epoch)
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0

def _reduce_weight_decay(self, epoch):
for i, param_group in enumerate(self.optimizer.param_groups):
if param_group['weight_decay'] != 0:
old_weight_decay = float(param_group['weight_decay'])
new_weight_decay = max(old_weight_decay * self.factor, self.min_lrs[i])
if old_weight_decay - new_weight_decay > self.eps:
param_group['weight_decay'] = new_weight_decay
if self.verbose:
print('Epoch {:5d}: reducing weight decay factor'
' of group {} to {:.4e}.'.format(epoch, i, new_weight_decay))

0 comments on commit 741743d

Please sign in to comment.