Commit 30459cc

address comments
hwangjeff committed Apr 12, 2022
1 parent c747eca commit 30459cc
Showing 3 changed files with 56 additions and 16 deletions.
examples/asr/librispeech_conformer_rnnt/eval.py (5 additions, 5 deletions)
@@ -42,28 +42,28 @@ def run_eval(args):
 def cli_main():
     parser = ArgumentParser()
     parser.add_argument(
-        "--checkpoint_path",
+        "--checkpoint-path",
         type=pathlib.Path,
         help="Path to checkpoint to use for evaluation.",
     )
     parser.add_argument(
-        "--global_stats_path",
+        "--global-stats-path",
         default=pathlib.Path("global_stats.json"),
         type=pathlib.Path,
         help="Path to JSON file containing feature means and stddevs.",
     )
     parser.add_argument(
-        "--librispeech_path",
+        "--librispeech-path",
         type=pathlib.Path,
         help="Path to LibriSpeech datasets.",
     )
     parser.add_argument(
-        "--sp_model_path",
+        "--sp-model-path",
         type=pathlib.Path,
         help="Path to SentencePiece model.",
     )
     parser.add_argument(
-        "--use_cuda",
+        "--use-cuda",
         action="store_true",
         default=False,
         help="Run using CUDA.",
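Note on the rename above: argparse replaces hyphens in long option names with underscores when it builds attribute names on the parsed namespace, so downstream code can keep reading args.checkpoint_path, args.use_cuda, and so on. A minimal sketch, not part of this commit:

# Sketch: the hyphenated flag still surfaces as an underscore attribute.
import pathlib
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--checkpoint-path", type=pathlib.Path)
args = parser.parse_args(["--checkpoint-path", "model.ckpt"])
print(args.checkpoint_path)  # model.ckpt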
examples/asr/librispeech_conformer_rnnt/lightning.py (46 additions, 6 deletions)
@@ -130,17 +130,37 @@ def forward(self, input):


 class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
-    def __init__(self, optimizer, warmup_updates, force_anneal_epoch, anneal_factor, last_epoch=-1, verbose=False):
-        self.warmup_updates = warmup_updates
-        self.force_anneal_epoch = force_anneal_epoch
+    r"""Learning rate scheduler that performs linear warmup and exponential annealing.
+
+    Args:
+        optimizer (torch.optim.Optimizer): optimizer to use.
+        warmup_steps (int): number of scheduler steps for which to warm up learning rate.
+        force_anneal_step (int): scheduler step at which annealing of learning rate begins.
+        anneal_factor (float): factor to scale base learning rate by at each annealing step.
+        last_epoch (int, optional): The index of last epoch. (Default: -1)
+        verbose (bool, optional): If ``True``, prints a message to stdout for
+            each update. (Default: ``False``)
+    """
+
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        warmup_steps: int,
+        force_anneal_step: int,
+        anneal_factor: float,
+        last_epoch=-1,
+        verbose=False,
+    ):
+        self.warmup_steps = warmup_steps
+        self.force_anneal_step = force_anneal_step
         self.anneal_factor = anneal_factor
         super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
 
     def get_lr(self):
-        if self._step_count < self.force_anneal_epoch:
-            return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
+        if self._step_count < self.force_anneal_step:
+            return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
         else:
-            scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_epoch)
+            scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
             return [scaling_factor * base_lr for base_lr in self.base_lrs]
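For orientation, a minimal usage sketch of the revised scheduler. It assumes the WarmupLR class above is in scope (for example, the snippet is appended to lightning.py); the hyperparameter values are illustrative, not the ones used for training:

import torch

model = torch.nn.Linear(8, 8)  # placeholder module for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=8e-4)
scheduler = WarmupLR(optimizer, warmup_steps=40, force_anneal_step=120, anneal_factor=0.96)

for _ in range(200):
    optimizer.step()  # a real loop would compute a loss and call backward() first
    scheduler.step()
    # While _step_count < force_anneal_step, each base LR is scaled by
    # min(1.0, _step_count / warmup_steps); afterwards it is scaled by
    # anneal_factor ** (_step_count - force_anneal_step).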


@@ -272,6 +292,26 @@ def forward(self, batch: Batch):
         return post_process_hypos(hypotheses, self.sp_model)[0][0]
 
     def training_step(self, batch: Batch, batch_idx):
+        """Custom training step.
+
+        By default, DDP does the following on each train step:
+        - For each GPU, compute loss and gradient on shard of training data.
+        - Sync and average gradients across all GPUs. The final gradient
+          is (sum of gradients across all GPUs) / N, where N is the world
+          size (total number of GPUs).
+        - Update parameters on each GPU.
+
+        Here, we do the following:
+        - For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
+          the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
+        - Sync and average gradients across all GPUs. The final gradient
+          is (sum of gradients across all GPUs) / B_total.
+        - Update parameters on each GPU.
+
+        Doing so allows us to account for the variability in batch sizes that
+        variable-length sequential data commonly yields.
+        """
+
         opt = self.optimizers()
         opt.zero_grad()
         loss = self._step(batch, batch_idx, "train")
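The gradient bookkeeping described in the new docstring can be illustrated with a small, hypothetical helper. This is a sketch of the idea only, not the code inside _step, and it assumes torch.distributed has already been initialized by DDP:

import torch
import torch.distributed as dist

def scale_loss_for_ddp(summed_loss: torch.Tensor, local_batch_size: int) -> torch.Tensor:
    # DDP averages gradients over N workers. Scaling each worker's summed loss
    # by N / B_total therefore yields a synced gradient equal to the gradient of
    # the mean loss over all B_total utterances, independent of per-GPU batch size.
    world_size = dist.get_world_size()  # N
    batch_total = torch.tensor(float(local_batch_size), device=summed_loss.device)
    dist.all_reduce(batch_total, op=dist.ReduceOp.SUM)  # B_total across all GPUs
    return summed_loss * world_size / batch_total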
examples/asr/librispeech_conformer_rnnt/train.py (5 additions, 5 deletions)
@@ -53,29 +53,29 @@ def run_train(args):
 def cli_main():
     parser = ArgumentParser()
     parser.add_argument(
-        "--exp_dir",
+        "--exp-dir",
         default=pathlib.Path("./exp"),
         type=pathlib.Path,
         help="Directory to save checkpoints and logs to. (Default: './exp')",
     )
     parser.add_argument(
-        "--global_stats_path",
+        "--global-stats-path",
         default=pathlib.Path("global_stats.json"),
         type=pathlib.Path,
         help="Path to JSON file containing feature means and stddevs.",
     )
     parser.add_argument(
-        "--librispeech_path",
+        "--librispeech-path",
         type=pathlib.Path,
         help="Path to LibriSpeech datasets.",
     )
     parser.add_argument(
-        "--sp_model_path",
+        "--sp-model-path",
         type=pathlib.Path,
         help="Path to SentencePiece model.",
     )
     parser.add_argument(
-        "--num_nodes",
+        "--num-nodes",
         default=4,
         type=int,
         help="Number of nodes to use for training. (Default: 4)",
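run_train itself is not shown in this hunk; as a purely hypothetical sketch of how the renamed flags might be consumed, with illustrative Trainer settings that are not taken from the repository:

import pytorch_lightning as pl

def build_trainer(args):
    # Hypothetical helper: --exp-dir and --num-nodes still parse to
    # args.exp_dir and args.num_nodes.
    return pl.Trainer(
        default_root_dir=args.exp_dir,  # checkpoints and logs
        num_nodes=args.num_nodes,       # multi-node training
        accelerator="gpu",              # illustrative, not from the diff
        devices=8,                      # illustrative, not from the diff
    )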
