Incorporate some latest changes to optim.py (#1359)
* init commit

* black formatted

* isort formatted
JinZr authored Nov 2, 2023
1 parent 23913f6 commit 9e5a5d7
169 changes: 120 additions & 49 deletions egs/librispeech/ASR/zipformer/optim.py
@@ -22,7 +22,7 @@

 import torch
 from lhotse.utils import fix_random_seed
-from torch import Tensor
+from torch import Tensor, nn
 from torch.optim import Optimizer


@@ -116,7 +116,7 @@ def batched_params(self, param_group, group_params_names):

         yield tuples  # <-- calling code will do the actual optimization here!

-        for (stacked_params, _state, _names), batch in zip(tuples, batches):
+        for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
             for i, p in enumerate(batch):  # batch is list of Parameter
                 p.copy_(stacked_params[i])

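To make the stacking contract above concrete, here is a minimal sketch (not from this commit; tensor names and sizes are invented) of the round trip that batched_params performs: parameters of identical shape are stacked along a new leading dim, updated as one fused tensor, and copied back into the original leaves, which is what the `p.copy_(stacked_params[i])` loop does.

import torch

# Hypothetical round trip mirroring batched_params(): stack three
# same-shape parameters, apply one fused update, then copy back.
params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
stacked = torch.stack([p.detach() for p in params])  # (3, 4); dim 0 is the stacking dim
stacked -= 0.01  # stand-in for an optimizer update on the fused tensor
with torch.no_grad():
    for i, p in enumerate(params):
        p.copy_(stacked[i])  # same copy-back as in the hunk above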
@@ -181,6 +181,7 @@ def __init__(
         size_update_period=4,
         clipping_update_period=100,
     ):
+
         defaults = dict(
             lr=lr,
             clipping_scale=clipping_scale,
@@ -326,7 +327,9 @@ def step(self, closure=None):
         batch = True

         for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
             with self.batched_params(group["params"], group_params_names) as batches:
+
                 # batches is list of pairs (stacked_param, state). stacked_param is like
                 # a regular parameter, and will have a .grad, but the 1st dim corresponds to
                 # a stacking dim, it is not a real dim.
@@ -423,16 +426,19 @@ def _get_clipping_scale(
             # parameters' state won't have been initialized yet.
             return 1.0
         clipping_update_period = group["clipping_update_period"]
+        scalar_lr_scale = group["scalar_lr_scale"]

         tot_sumsq = torch.tensor(0.0, device=first_p.device)
-        for p, state, param_names in tuples:
+        for (p, state, param_names) in tuples:
             grad = p.grad
             if grad.is_sparse:
                 raise RuntimeError(
                     "ScaledAdam optimizer does not support sparse gradients"
                 )
             if p.numel() == p.shape[0]:  # a batch of scalars
-                tot_sumsq += (grad**2).sum()  # sum() to change shape [1] to []
+                tot_sumsq += (grad**2).sum() * (
+                    scalar_lr_scale**2
+                )  # sum() to change shape [1] to []
             else:
                 tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()

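The hunk above changes how scalar parameters count toward the gradient norm used for clipping: their squared-gradient sum is now weighted by scalar_lr_scale**2, so it matches the scale of the update those scalars actually receive. A hedged sketch of the norm computation, with fabricated values:

import torch

scalar_lr_scale = 0.1                 # made-up; a per-group option in ScaledAdam
scalar_grad = torch.tensor([0.5])     # gradient of a "batch of scalars" parameter
tensor_grad = torch.randn(2, 8)       # gradient of a stacked tensor parameter
param_rms = torch.full((2, 1), 0.05)  # stand-in for state["param_rms"]

tot_sumsq = (scalar_grad**2).sum() * scalar_lr_scale**2  # new scalar weighting
tot_sumsq += ((tensor_grad * param_rms) ** 2).sum()
tot_norm = tot_sumsq.sqrt()  # the norm later compared to the clipping threshold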
@@ -443,64 +449,72 @@
             )
         first_state["model_norms"][step % clipping_update_period] = tot_norm

-        if step % clipping_update_period == 0:
+        irregular_estimate_steps = [
+            i for i in [10, 20, 40] if i < clipping_update_period
+        ]
+        if step % clipping_update_period == 0 or step in irregular_estimate_steps:
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
             sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+            if step in irregular_estimate_steps:
+                sorted_norms = sorted_norms[-step:]
+            num_norms = sorted_norms.numel()
             quartiles = []
             for n in range(0, 5):
-                index = min(
-                    clipping_update_period - 1, (clipping_update_period // 4) * n
-                )
+                index = min(num_norms - 1, (num_norms // 4) * n)
                 quartiles.append(sorted_norms[index].item())

             median = quartiles[2]
             threshold = clipping_scale * median
+            if step in irregular_estimate_steps:
+                # use larger thresholds on first few steps of estimating threshold,
+                # as norm may be changing rapidly.
+                threshold = threshold * 2.0
             first_state["model_norm_threshold"] = threshold
             percent_clipped = (
-                first_state["num_clipped"] * 100.0 / clipping_update_period
+                first_state["num_clipped"] * 100.0 / num_norms
                 if "num_clipped" in first_state
                 else 0.0
             )
             first_state["num_clipped"] = 0
             quartiles = " ".join(["%.3e" % x for x in quartiles])
-            logging.info(
+            logging.warn(
                 f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
                 f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
             )

-        if step < clipping_update_period:
-            return 1.0  # We have not yet estimated a norm to clip to.
-        else:
-            try:
-                model_norm_threshold = first_state["model_norm_threshold"]
-            except KeyError:
-                logging.info(
-                    "Warning: model_norm_threshold not in state: possibly "
-                    "you changed config when restarting, adding clipping_scale option?"
-                )
-                return 1.0
-            ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
-            if ans < 1.0:
-                first_state["num_clipped"] += 1
-            if ans < 0.1:
-                logging.warn(
-                    f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
-                )
-                if self.show_dominant_parameters:
-                    assert p.shape[0] == len(param_names)
-                    self._show_gradient_dominating_parameter(tuples, tot_sumsq)
-            if ans != ans:  # e.g. ans is nan
-                ans = 0.0
-            if ans == 0.0:
-                for p, state, param_names in tuples:
-                    p.grad.zero_()  # get rid of infinity()
-
-            return ans
+        try:
+            model_norm_threshold = first_state["model_norm_threshold"]
+        except KeyError:
+            return 1.0  # threshold has not yet been set.
+
+        ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+        if ans != ans:  # e.g. ans is nan
+            ans = 0.0
+        if ans < 1.0:
+            first_state["num_clipped"] += 1
+        if ans < 0.1:
+            logging.warn(
+                f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+            )
+            if self.show_dominant_parameters:
+                assert p.shape[0] == len(param_names)
+                self._show_gradient_dominating_parameter(
+                    tuples, tot_sumsq, group["scalar_lr_scale"]
+                )
+        if ans == 0.0:
+            for (p, state, param_names) in tuples:
+                p.grad.zero_()  # get rid of infinity()
+
+        return ans

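The rewritten tail above estimates the clipping threshold from quartiles of recent gradient norms, and now also does so early (steps 10, 20, 40) with a doubled threshold while the estimate is still noisy. A self-contained sketch of that quartile-and-threshold computation (not repository code), with a fabricated norm history and clipping_scale=2.0 as a typical setting:

import torch

norms = torch.tensor([1.0, 1.2, 0.9, 5.0, 1.1, 1.3, 0.95, 1.05])  # stand-in for first_state["model_norms"]
clipping_scale = 2.0

sorted_norms = norms.sort()[0]
num_norms = sorted_norms.numel()
quartiles = [
    sorted_norms[min(num_norms - 1, (num_norms // 4) * n)].item() for n in range(5)
]
threshold = clipping_scale * quartiles[2]  # 2 x median, as in the diff
early_threshold = threshold * 2.0  # what steps 10/20/40 would use

tot_norm = 5.0  # a pathological batch
ans = min(1.0, threshold / (tot_norm + 1.0e-20))  # factor the gradients get scaled by
print(quartiles, threshold, ans)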
     def _show_gradient_dominating_parameter(
-        self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor
+        self,
+        tuples: List[Tuple[Tensor, dict, List[str]]],
+        tot_sumsq: Tensor,
+        scalar_lr_scale: float,
     ):
         """
         Show information of parameter which dominates tot_sumsq.
@@ -516,29 +530,30 @@ def _show_gradient_dominating_parameter(
             from tuples, we still pass it to save some time.
         """
         all_sumsq_orig = {}
-        for p, state, batch_param_names in tuples:
+        for (p, state, batch_param_names) in tuples:
             # p is a stacked batch parameters.
             batch_grad = p.grad
             if p.numel() == p.shape[0]:  # a batch of scalars
-                batch_sumsq_orig = batch_grad**2
                 # Dummy values used by following `zip` statement.
-                batch_rms_orig = torch.ones(p.shape[0])
+                batch_rms_orig = torch.full(
+                    p.shape, scalar_lr_scale, device=batch_grad.device
+                )
             else:
                 batch_rms_orig = state["param_rms"]
-                batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(
+            batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
+            if batch_grad.ndim > 1:
+                # need to guard it with if-statement because sum() sums over
+                # all dims if dim == ().
+                batch_sumsq_orig = batch_sumsq_orig.sum(
                     dim=list(range(1, batch_grad.ndim))
                 )

             for name, sumsq_orig, rms, grad in zip(
                 batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
             ):
                 proportion_orig = sumsq_orig / tot_sumsq
                 all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)

-        assert torch.isclose(
-            sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
-            torch.tensor(1.0),
-        )
         sorted_by_proportion = {
             k: v
             for k, v in sorted(
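One subtlety the new comment above calls out: Tensor.sum() reduces over all dims when dim is empty, so the per-entry reduction must be guarded with batch_grad.ndim > 1. A quick demonstration with invented shapes:

import torch

# For a stacked parameter of shape (batch, *param_shape) we want one
# sum of squares per stacked entry, i.e. reduce over every dim but 0.
batch_grad = torch.randn(3, 4, 5)
sumsq = (batch_grad**2).sum(dim=list(range(1, batch_grad.ndim)))
print(sumsq.shape)  # torch.Size([3]): one value per entry in the stack

# For a batch of scalars, batch_grad.ndim == 1 and dim would be [],
# which sum() treats as "reduce everything" -- hence the guard.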
@@ -552,7 +567,7 @@ def _show_gradient_dominating_parameter(
             dominant_rms,
             dominant_grad,
         ) = sorted_by_proportion[dominant_param_name]
-        logging.info(
+        logging.warn(
             f"Parameter dominating tot_sumsq {dominant_param_name}"
             f" with proportion {dominant_proportion:.2f},"
             f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
@@ -826,7 +841,7 @@ def _set_lrs(self):
     def print_lr(self, is_verbose, group, lr):
         """Display the current learning rate."""
         if is_verbose:
-            logging.info(
+            logging.warn(
                 f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
                 f" of group {group} to {lr:.4e}."
             )
@@ -841,8 +856,14 @@ class Eden(LRScheduler):
     where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
     and then stays constant at 1.

+    If you don't have the concept of epochs, or one epoch takes a very long time,
+    you can replace the notion of 'epoch' with some measure of the amount of data
+    processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
+    some measure representing "quite a lot of data": say, one fifth or one third
+    of an entire training run, but it doesn't matter much. You could also use
+    Eden2 which has only the notion of batches.

-    E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+    We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam

     Args:
         optimizer: the optimizer to change the learning rates on
return [x * factor * warmup_factor for x in self.base_lrs]


class Eden2(LRScheduler):
"""
Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
only batches.
The basic formula (before warmup) is:
lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
and then stays constant at 1.
E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
Args:
optimizer: the optimizer to change the learning rates on
lr_batches: the number of batches after which we start significantly
decreasing the learning rate, suggest 5000.
"""

def __init__(
self,
optimizer: Optimizer,
lr_batches: Union[int, float],
warmup_batches: Union[int, float] = 500.0,
warmup_start: float = 0.5,
verbose: bool = False,
):
super().__init__(optimizer, verbose)
self.lr_batches = lr_batches
self.warmup_batches = warmup_batches

assert 0.0 <= warmup_start <= 1.0, warmup_start
self.warmup_start = warmup_start

def get_lr(self):
factor = (
(self.batch**2 + self.lr_batches**2) / self.lr_batches**2
) ** -0.5
warmup_factor = (
1.0
if self.batch >= self.warmup_batches
else self.warmup_start
+ (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
# else 0.5 + 0.5 * (self.batch / self.warmup_batches)
)

return [x * factor * warmup_factor for x in self.base_lrs]


def _test_eden():
m = torch.nn.Linear(100, 100)
optim = ScaledAdam(m.parameters(), lr=0.03)
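Finally, a hedged usage sketch for the new Eden2 class, assuming optim.py is importable and that the LRScheduler base class in this file exposes a per-batch step_batch() hook; the ScaledAdam call mirrors _test_eden above, and the model and loop are invented:

import torch

from optim import Eden2, ScaledAdam  # the module changed by this commit

model = torch.nn.Linear(100, 100)
optim = ScaledAdam(model.parameters(), lr=0.03)  # mirrors _test_eden above
scheduler = Eden2(optim, lr_batches=5000)

for step in range(10):
    loss = model(torch.randn(8, 100)).abs().sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
    scheduler.step_batch()  # advances the batch count that get_lr() reads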
