Verbose param for schedulers that don't have it #38726 (#41580)
Summary:
Verbose param for schedulers that don't have it pytorch/pytorch#38726

Pull Request resolved: pytorch/pytorch#41580

Reviewed By: izdeby

Differential Revision: D22671163

Pulled By: vincentqb

fbshipit-source-id: 53a6c9e929141d411b6846bc25f3fe7f46fdf3be
guol-fnst authored and facebook-github-bot committed Jul 23, 2020
1 parent 37e7f0c commit 17f76f9
Showing 1 changed file with 62 additions and 25 deletions.
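For reference, a minimal usage sketch of the new flag (assuming a PyTorch build that includes this change; the model and optimizer below are placeholders, not part of the commit). With verbose=True, the new print_lr helper reports each parameter group's learning rate on every step:

import torch

# Placeholder model and optimizer, for illustration only.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

# `verbose` is the parameter this commit adds to the schedulers built on
# _LRScheduler (StepLR, MultiStepLR, ExponentialLR, CosineAnnealingLR, ...).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=2, gamma=0.1, verbose=True)

for epoch in range(4):
    optimizer.step()
    scheduler.step()
# Each step prints a line of the form
#   Adjusting learning rate of group 0 to 5.0000e-03.
# as formatted by the print_lr method added in this diff.

Passing verbose through super().__init__ keeps the printing logic in one place, the print_lr method on the base class, rather than duplicating it in every scheduler.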
87 changes: 62 additions & 25 deletions torch/optim/lr_scheduler.py
@@ -23,7 +23,7 @@

class _LRScheduler(object):

def __init__(self, optimizer, last_epoch=-1):
def __init__(self, optimizer, last_epoch=-1, verbose=False):

# Attach optimizer
if not isinstance(optimizer, Optimizer):
@@ -74,6 +74,7 @@ def wrapper(*args, **kwargs):
self.optimizer.step = with_counter(self.optimizer.step)
self.optimizer._step_count = 0
self._step_count = 0
self.verbose = verbose

self.step()

@@ -103,6 +104,18 @@ def get_lr(self):
# Compute learning rate using chainable form of the scheduler
raise NotImplementedError

def print_lr(self, is_verbose, group, lr, epoch=None):
"""Display the current learning rate.
"""
if is_verbose:
if epoch is None:
print('Adjusting learning rate'
' of group {} to {:.4e}.'.format(group, lr))
else:
print('Epoch {:5d}: adjusting learning rate'
' of group {} to {:.4e}.'.format(epoch, group, lr))


def step(self, epoch=None):
# Raise a warning if old pattern is detected
# https://github.com/pytorch/pytorch/issues/20124
@@ -147,8 +160,10 @@ def __exit__(self, type, value, traceback):
else:
values = self.get_lr()

for param_group, lr in zip(self.optimizer.param_groups, values):
for i, data in enumerate(zip(self.optimizer.param_groups, values)):
param_group, lr = data
param_group['lr'] = lr
self.print_lr(self.verbose, i, lr, epoch)

self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

@@ -163,6 +178,8 @@ class LambdaLR(_LRScheduler):
factor given an integer parameter epoch, or a list of such
functions, one for each group in optimizer.param_groups.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> # Assuming optimizer has two groups.
@@ -175,7 +192,7 @@ class LambdaLR(_LRScheduler):
>>> scheduler.step()
"""

def __init__(self, optimizer, lr_lambda, last_epoch=-1):
def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
self.optimizer = optimizer

if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
@@ -186,7 +203,7 @@ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
len(optimizer.param_groups), len(lr_lambda)))
self.lr_lambdas = list(lr_lambda)
self.last_epoch = last_epoch
super(LambdaLR, self).__init__(optimizer, last_epoch)
super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)

def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
@@ -245,6 +262,8 @@ class MultiplicativeLR(_LRScheduler):
factor given an integer parameter epoch, or a list of such
functions, one for each group in optimizer.param_groups.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> lmbda = lambda epoch: 0.95
@@ -255,7 +274,7 @@ class MultiplicativeLR(_LRScheduler):
>>> scheduler.step()
"""

def __init__(self, optimizer, lr_lambda, last_epoch=-1):
def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
self.optimizer = optimizer

if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
@@ -266,7 +285,7 @@ def __init__(self, optimizer, lr_lambda, last_epoch=-1):
len(optimizer.param_groups), len(lr_lambda)))
self.lr_lambdas = list(lr_lambda)
self.last_epoch = last_epoch
super(MultiplicativeLR, self).__init__(optimizer, last_epoch)
super(MultiplicativeLR, self).__init__(optimizer, last_epoch, verbose)

def state_dict(self):
"""Returns the state of the scheduler as a :class:`dict`.
@@ -326,6 +345,8 @@ class StepLR(_LRScheduler):
gamma (float): Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
@@ -340,10 +361,10 @@
>>> scheduler.step()
"""

def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):
def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
self.step_size = step_size
self.gamma = gamma
super(StepLR, self).__init__(optimizer, last_epoch)
super(StepLR, self).__init__(optimizer, last_epoch, verbose)

def get_lr(self):
if not self._get_lr_called_within_step:
@@ -372,6 +393,8 @@ class MultiStepLR(_LRScheduler):
gamma (float): Multiplicative factor of learning rate decay.
Default: 0.1.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> # Assuming optimizer uses lr = 0.05 for all groups
@@ -385,10 +408,10 @@
>>> scheduler.step()
"""

def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
self.milestones = Counter(milestones)
self.gamma = gamma
super(MultiStepLR, self).__init__(optimizer, last_epoch)
super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)

def get_lr(self):
if not self._get_lr_called_within_step:
@@ -414,11 +437,13 @@ class ExponentialLR(_LRScheduler):
optimizer (Optimizer): Wrapped optimizer.
gamma (float): Multiplicative factor of learning rate decay.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
"""

def __init__(self, optimizer, gamma, last_epoch=-1):
def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
self.gamma = gamma
super(ExponentialLR, self).__init__(optimizer, last_epoch)
super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)

def get_lr(self):
if not self._get_lr_called_within_step:
@@ -468,15 +493,17 @@ class CosineAnnealingLR(_LRScheduler):
T_max (int): Maximum number of iterations.
eta_min (float): Minimum learning rate. Default: 0.
last_epoch (int): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""

def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):
def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
self.T_max = T_max
self.eta_min = eta_min
super(CosineAnnealingLR, self).__init__(optimizer, last_epoch)
super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)

def get_lr(self):
if not self._get_lr_called_within_step:
@@ -522,8 +549,6 @@ class ReduceLROnPlateau(object):
with no improvement, and will only decrease the LR after the
3rd epoch if the loss still hasn't improved then.
Default: 10.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
threshold (float): Threshold for measuring the new optimum,
to only focus on significant changes. Default: 1e-4.
threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
@@ -539,6 +564,8 @@
eps (float): Minimal decay applied to lr. If the difference
between new and old lr is smaller than eps, the update is
ignored. Default: 1e-8.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
@@ -551,8 +578,8 @@
"""

def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
verbose=False, threshold=1e-4, threshold_mode='rel',
cooldown=0, min_lr=0, eps=1e-8):
threshold=1e-4, threshold_mode='rel', cooldown=0,
min_lr=0, eps=1e-8, verbose=False):

if factor >= 1.0:
raise ValueError('Factor should be < 1.0.')
@@ -749,6 +776,8 @@ class CyclicLR(_LRScheduler):
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning.
Default: -1
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
@@ -777,7 +806,8 @@ def __init__(self,
cycle_momentum=True,
base_momentum=0.8,
max_momentum=0.9,
last_epoch=-1):
last_epoch=-1,
verbose=False):

# Attach optimizer
if not isinstance(optimizer, Optimizer):
@@ -830,7 +860,7 @@ def __init__(self,
self.base_momentums = list(map(lambda group: group['momentum'], optimizer.param_groups))
self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

super(CyclicLR, self).__init__(optimizer, last_epoch)
super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
self.base_lrs = base_lrs

def _format_param(self, name, optimizer, param):
@@ -917,12 +947,14 @@ class CosineAnnealingWarmRestarts(_LRScheduler):
T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
eta_min (float, optional): Minimum learning rate. Default: 0.
last_epoch (int, optional): The index of last epoch. Default: -1.
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
.. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
https://arxiv.org/abs/1608.03983
"""

def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
if T_0 <= 0 or not isinstance(T_0, int):
raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
if T_mult < 1 or not isinstance(T_mult, int):
@@ -932,7 +964,7 @@ def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
self.T_mult = T_mult
self.eta_min = eta_min

super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)
super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

self.T_cur = self.last_epoch

@@ -1008,8 +1040,10 @@ def __exit__(self, type, value, traceback):
return self

with _enable_get_lr_call(self):
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
param_group, lr = data
param_group['lr'] = lr
self.print_lr(self.verbose, i, lr, epoch)

self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

@@ -1090,6 +1124,8 @@ class OneCycleLR(_LRScheduler):
number of *batches* computed, not the total number of epochs computed.
When last_epoch=-1, the schedule is started from the beginning.
Default: -1
verbose (bool): If ``True``, prints a message to stdout for
each update. Default: ``False``.
Example:
>>> data_loader = torch.utils.data.DataLoader(...)
@@ -1117,7 +1153,8 @@ def __init__(self,
max_momentum=0.95,
div_factor=25.,
final_div_factor=1e4,
last_epoch=-1):
last_epoch=-1,
verbose=False):

# Validate optimizer
if not isinstance(optimizer, Optimizer):
@@ -1179,7 +1216,7 @@ def __init__(self,
group['max_momentum'] = m_momentum
group['base_momentum'] = b_momentum

super(OneCycleLR, self).__init__(optimizer, last_epoch)
super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)

def _format_param(self, name, optimizer, param):
"""Return correctly formatted lr/momentum for each param group."""
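A side note grounded in the diff above: ReduceLROnPlateau keeps its verbose flag, but the parameter moves from the fifth positional slot to the end of the signature, so code that passed it positionally would be affected. A minimal sketch of keyword-based construction that stays valid under the new ordering (the model and optimizer are placeholders):

import torch

# Placeholder model and optimizer, for illustration only.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Passing arguments by keyword avoids relying on the parameter order,
# which this commit changes (verbose is now the last parameter).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=10, verbose=True)

val_loss = 1.0  # placeholder validation metric
scheduler.step(val_loss)  # prints a message when the LR is reduced and verbose is True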