
Commit

Remove trainer.num_optimizer and support multiple optimizers in default trainer
kamo-naoyuki committed Feb 25, 2021
1 parent 75eb357 commit 4561969
Showing 2 changed files with 49 additions and 33 deletions.
5 changes: 0 additions & 5 deletions espnet2/tasks/abs_task.py
@@ -974,11 +974,6 @@ def print_config(cls, file=sys.stdout) -> None:

@classmethod
def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None):
if cls.num_optimizers != cls.trainer.num_optimizers:
raise RuntimeError(
f"Task.num_optimizers != Task.trainer.num_optimizers: "
f"{cls.num_optimizers} != {cls.trainer.num_optimizers}"
)
assert check_argument_types()
print(get_commandline_args(), file=sys.stderr)
if args is None:
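For context, a minimal sketch (not part of this commit; the subclass name is hypothetical) of what the removed check used to enforce: a task that builds several optimizers previously also had to point at a Trainer subclass whose num_optimizers matched. With the default trainer now accepting any number of optimizers, only the task-side attribute remains.

    class TwoOptimizerTask(AbsTask):
        # The task still declares how many optimizers it builds ...
        num_optimizers = 2
        # ... but no longer needs a custom trainer with a matching
        # num_optimizers attribute; the default Trainer handles the list.
        trainer = Trainer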
77 changes: 49 additions & 28 deletions espnet2/train/trainer.py
@@ -81,8 +81,6 @@ class Trainer:
and override the methods if necessary - at least "train_one_epoch()"
>>> class TwoOptimizerTrainer(Trainer):
... num_optimizers: int = 1
...
... @classmethod
... def add_arguments(cls, parser):
... ...
@@ -99,9 +97,6 @@ class Trainer:
"""

# If you need more than one optimizers, change this value in inheritance
num_optimizers: int = 1

def __init__(self):
raise RuntimeError("This class can't be instantiated.")

@@ -143,6 +138,7 @@ def run(
assert check_argument_types()
# NOTE(kamo): Don't check the type more strictly as far trainer_options
assert is_dataclass(trainer_options), type(trainer_options)
assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))

start_epoch = reporter.get_epoch() + 1
if start_epoch == max_epoch + 1:
@@ -345,12 +341,6 @@ def train_one_epoch(
) -> bool:
assert check_argument_types()

# Note(kamo): assumes one optimizer
assert cls.num_optimizers == 1, cls.num_optimizers
assert len(optimizers) == 1, len(optimizers)
optimizer = optimizers[0]
scheduler = schedulers[0]

grad_noise = options.grad_noise
accum_grad = options.accum_grad
grad_clip = options.grad_clip
@@ -391,7 +381,26 @@

with autocast(scaler is not None):
with reporter.measure_time("forward_time"):
loss, stats, weight = model(**batch)
retval = model(**batch)

# Note(kamo):
# Supporting two patterns for the returned value from the model
# a. dict type
if isinstance(retval, dict):
loss = retval["loss"]
stats = retval["stats"]
weight = retval["weight"]

if "optim_idx" in retval:
optim_idx = retval["optim_idx"]
else:
optim_idx = None

# b. tuple or list type
else:
loss, stats, weight = retval
optim_idx = None

stats = {k: v for k, v in stats.items() if v is not None}
if ngpu > 1 or distributed:
# Apply weighted averaging for loss and stats
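To illustrate the two return patterns handled above, a standalone sketch (not from ESPnet; model, batch keys, and losses are illustrative) of a forward() using the dict form, where "optim_idx" tells the trainer which optimizer to step for the current batch:

    import torch


    class ToyGANStyleModel(torch.nn.Module):
        """Illustrative only: alternates between two losses/optimizers."""

        def __init__(self):
            super().__init__()
            self.generator = torch.nn.Linear(8, 8)
            self.discriminator = torch.nn.Linear(8, 1)

        def forward(self, feats, update_discriminator: bool = False):
            if update_discriminator:
                loss = self.discriminator(self.generator(feats).detach()).mean()
                optim_idx = 1  # step only optimizers[1]
            else:
                loss = -self.discriminator(self.generator(feats)).mean()
                optim_idx = 0  # step only optimizers[0]
            stats = {"loss": loss.detach()}
            weight = feats.new_tensor(float(feats.size(0)))
            # The tuple form (loss, stats, weight) is still accepted and
            # implies optim_idx=None, i.e. all optimizers are stepped.
            return {"loss": loss, "stats": stats, "weight": weight, "optim_idx": optim_idx}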
@@ -425,7 +434,10 @@
if iiter % accum_grad == 0:
if scaler is not None:
# Unscales the gradients of optimizer's assigned params in-place
scaler.unscale_(optimizer)
for iopt, optimizer in enumerate(optimizers):
if optim_idx is not None and iopt != optim_idx:
continue
scaler.unscale_(optimizer)

# gradient noise injection
if grad_noise:
@@ -459,31 +471,40 @@
# Note that if the gradient has inf/nan values,
# scaler.step skips optimizer.step().
if scaler is not None:
scaler.step(optimizer)
scaler.update()
for iopt, optimizer in enumerate(optimizers):
if optim_idx is not None and iopt != optim_idx:
continue
scaler.step(optimizer)
scaler.update()

else:
all_steps_are_invalid = False
with reporter.measure_time("optim_step_time"):
if scaler is not None:
# scaler.step() first unscales the gradients of
# the optimizer's assigned params.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
else:
optimizer.step()
if isinstance(scheduler, AbsBatchStepScheduler):
scheduler.step()
optimizer.zero_grad()
for iopt, (optimizer, scheduler) in enumerate(
    zip(optimizers, schedulers)
):
if optim_idx is not None and iopt != optim_idx:
continue
if scaler is not None:
# scaler.step() first unscales the gradients of
# the optimizer's assigned params.
scaler.step(optimizer)
# Updates the scale for next iteration.
scaler.update()
else:
optimizer.step()
if isinstance(scheduler, AbsBatchStepScheduler):
scheduler.step()
optimizer.zero_grad()

# Register lr and train/load time[sec/step],
# where step refers to accum_grad * mini-batch
reporter.register(
dict(
{
f"lr_{i}": pg["lr"]
for i, pg in enumerate(optimizer.param_groups)
f"optim{i}_lr{j}": pg["lr"]
for i, optimizer in enumerate(optimizers)
for j, pg in enumerate(optimizer.param_groups)
if "lr" in pg
},
train_time=time.perf_counter() - start_time,
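The new reporting keys encode both the optimizer index and the parameter-group index. A runnable sketch (toy model and optimizers, not from this commit) of the dict produced by the comprehension above:

    import torch

    model = torch.nn.Linear(4, 4)
    optimizers = [
        torch.optim.Adam(model.parameters(), lr=1e-3),
        torch.optim.SGD(model.parameters(), lr=1e-2),
    ]

    lr_stats = {
        f"optim{i}_lr{j}": pg["lr"]
        for i, optimizer in enumerate(optimizers)
        for j, pg in enumerate(optimizer.param_groups)
        if "lr" in pg
    }
    print(lr_stats)  # {'optim0_lr0': 0.001, 'optim1_lr0': 0.01}

Compared with the previous lr_{i} keys, the optim{i}_ prefix keeps the learning rates of different optimizers apart in the reported stats.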
